pir.cpp 7.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272
  #include "pir.hpp"

  #include <memory>

  using namespace std;
  using namespace seal;
  using namespace seal::util;
  // Chooses per-dimension sizes for laying out plaintext_num plaintexts in a
  // d-dimensional hypercube.
  //
  // Every dimension starts at max(2, r), where r is the exact integer d-th root
  // of plaintext_num (the largest r with r^d <= plaintext_num). Dimensions are
  // then incremented by one, front to back, until their product is at least
  // plaintext_num. The result always has d entries, each >= 2, and their
  // product covers every plaintext.
  //
  // @param plaintext_num number of plaintexts to cover; must be > 0.
  // @param d             number of dimensions; must be > 0.
  // @return              the d dimension sizes.
  std::vector<uint64_t> get_dimensions(uint64_t plaintext_num, uint32_t d) {
    assert(d > 0);
    assert(plaintext_num > 0);

    const uint64_t kSaturated = ~static_cast<uint64_t>(0);

    // Saturating multiply: yields kSaturated instead of wrapping around, so a
    // huge product can never look smaller than plaintext_num.
    auto sat_mul = [kSaturated](uint64_t a, uint64_t b) -> uint64_t {
      return (b != 0 && a > kSaturated / b) ? kSaturated : a * b;
    };
    // base^d, saturating.
    auto sat_pow = [&](uint64_t base) -> uint64_t {
      uint64_t result = 1;
      for (uint32_t i = 0; i < d; i++) {
        result = sat_mul(result, base);
      }
      return result;
    };
    // Product of all entries, saturating.
    auto product_of = [&](const std::vector<uint64_t> &dims) -> uint64_t {
      uint64_t result = 1;
      for (uint64_t dim : dims) {
        result = sat_mul(result, dim);
      }
      return result;
    };

    // The floating-point estimate can round to just below the true root for
    // perfect powers, so correct it with exact integer arithmetic.
    uint64_t root = static_cast<uint64_t>(
        std::floor(std::pow(static_cast<double>(plaintext_num), 1.0 / d)));
    while (root > 1 && sat_pow(root) > plaintext_num) {
      root--;
    }
    while (root < plaintext_num && sat_pow(root + 1) <= plaintext_num) {
      root++;
    }

    std::vector<uint64_t> dimensions(d, std::max<uint64_t>(2, root));

    // Start from the real product (not 1) so no dimension is bumped when the
    // initial shape already covers plaintext_num. Since
    // root^d <= plaintext_num < (root+1)^d, bumping at most all d entries is
    // always enough.
    uint64_t product = product_of(dimensions);
    for (uint32_t j = 0; j < d && product < plaintext_num; j++) {
      dimensions[j]++;
      product = product_of(dimensions);
    }
    return dimensions;
  }
  26. void gen_params(uint64_t ele_num, uint64_t ele_size, uint32_t N, uint32_t logt,
  27. uint32_t d, EncryptionParameters &params,
  28. PirParams &pir_params) {
  29. // Determine the maximum size of each dimension
  30. // plain modulus = a power of 2 plus 1
  31. uint64_t plain_mod = (static_cast<uint64_t>(1) << logt) + 1;
  32. uint64_t plaintext_num = plaintexts_per_db(logt, N, ele_num, ele_size);
  33. #ifdef DEBUG
  34. cout << "log(plain mod) before expand = " << logt << endl;
  35. cout << "number of FV plaintexts = " << plaintext_num << endl;
  36. #endif
  37. vector<SmallModulus> coeff_mod_array;
  38. uint32_t logq = 0;
  39. for (uint32_t i = 0; i < 1; i++) {
  40. coeff_mod_array.emplace_back(SmallModulus());
  41. coeff_mod_array[i] = DefaultParams::small_mods_60bit(i);
  42. logq += coeff_mod_array[i].bit_count();
  43. }
  44. params.set_poly_modulus_degree(N);
  45. params.set_coeff_modulus(coeff_mod_array);
  46. params.set_plain_modulus(plain_mod);
  47. vector<uint64_t> nvec = get_dimensions(plaintext_num, d);
  48. uint32_t expansion_ratio = 0;
  49. for (uint32_t i = 0; i < params.coeff_modulus().size(); ++i) {
  50. double logqi = log2(params.coeff_modulus()[i].value());
  51. cout << "PIR: logqi = " << logqi << endl;
  52. expansion_ratio += ceil(logqi / logt);
  53. }
  54. pir_params.d = d;
  55. pir_params.dbc = 6;
  56. pir_params.n = plaintext_num;
  57. pir_params.nvec = nvec;
  58. pir_params.expansion_ratio = expansion_ratio << 1; // because one ciphertext = two polys
  59. }
  60. uint32_t plainmod_after_expansion(uint32_t logt, uint32_t N, uint32_t d,
  61. uint64_t ele_num, uint64_t ele_size) {
  62. // Goal: find max logtp such that logtp + ceil(log(ceil(d_root(n)))) <= logt
  63. // where n = ceil(ele_num / floor(N*logtp / ele_size *8))
  64. for (uint32_t logtp = logt; logtp >= 2; logtp--) {
  65. uint64_t n = plaintexts_per_db(logtp, N, ele_num, ele_size);
  66. if (logtp == logt && n == 1) {
  67. return logtp - 1;
  68. }
  69. if ((double)logtp + ceil(log2(ceil(pow(n, 1.0/(double)d)))) <= logt) {
  70. return logtp;
  71. }
  72. }
  73. assert(0); // this should never happen
  74. return logt;
  75. }
  // Number of coefficients needed to represent one database element of
  // ele_size bytes when each coefficient carries logtp bits:
  // ceil(8 * ele_size / logtp).
  // Computed with integer arithmetic so the result stays exact for large
  // ele_size; a double loses precision once 8 * ele_size exceeds 2^53.
  // @param logtp bits per coefficient; must be > 0.
  uint64_t coefficients_per_element(uint32_t logtp, uint64_t ele_size) {
    assert(logtp > 0);
    const uint64_t bits = 8 * ele_size;
    return bits / logtp + (bits % logtp != 0 ? 1 : 0);
  }
  80. // Number of database elements that can fit in a single FV plaintext
  81. uint64_t elements_per_ptxt(uint32_t logt, uint64_t N, uint64_t ele_size) {
  82. uint64_t coeff_per_ele = coefficients_per_element(logt, ele_size);
  83. uint64_t ele_per_ptxt = N / coeff_per_ele;
  84. assert(ele_per_ptxt > 0);
  85. return ele_per_ptxt;
  86. }
  87. // Number of FV plaintexts needed to represent the database
  88. uint64_t plaintexts_per_db(uint32_t logtp, uint64_t N, uint64_t ele_num, uint64_t ele_size) {
  89. uint64_t ele_per_ptxt = elements_per_ptxt(logtp, N, ele_size);
  90. return ceil((double)ele_num / ele_per_ptxt);
  91. }
  92. vector<uint64_t> bytes_to_coeffs(uint32_t limit, const uint8_t *bytes, uint64_t size) {
  93. uint64_t size_out = coefficients_per_element(limit, size);
  94. vector<uint64_t> output(size_out);
  95. uint32_t room = limit;
  96. uint64_t *target = &output[0];
  97. for (uint32_t i = 0; i < size; i++) {
  98. uint8_t src = bytes[i];
  99. uint32_t rest = 8;
  100. while (rest) {
  101. if (room == 0) {
  102. target++;
  103. room = limit;
  104. }
  105. uint32_t shift = rest;
  106. if (room < rest) {
  107. shift = room;
  108. }
  109. *target = *target << shift;
  110. *target = *target | (src >> (8 - shift));
  111. src = src << shift;
  112. room -= shift;
  113. rest -= shift;
  114. }
  115. }
  116. *target = *target << room;
  117. return output;
  118. }
  119. void coeffs_to_bytes(uint32_t limit, const Plaintext &coeffs, uint8_t *output, uint32_t size_out) {
  120. uint32_t room = 8;
  121. uint32_t j = 0;
  122. uint8_t *target = output;
  123. for (uint32_t i = 0; i < coeffs.coeff_count(); i++) {
  124. uint64_t src = coeffs[i];
  125. uint32_t rest = limit;
  126. while (rest && j < size_out) {
  127. uint32_t shift = rest;
  128. if (room < rest) {
  129. shift = room;
  130. }
  131. target[j] = target[j] << shift;
  132. target[j] = target[j] | (src >> (limit - shift));
  133. src = src << shift;
  134. room -= shift;
  135. rest -= shift;
  136. if (room == 0) {
  137. j++;
  138. room = 8;
  139. }
  140. }
  141. }
  142. }
  143. void vector_to_plaintext(const vector<uint64_t> &coeffs, Plaintext &plain) {
  144. uint32_t coeff_count = coeffs.size();
  145. plain.resize(coeff_count);
  146. util::set_uint_uint(coeffs.data(), coeff_count, plain.data());
  147. }
  // Decomposes a flat index into mixed-radix digits, one per dimension, with
  // Nvec[0] as the most significant dimension. For
  // desiredIndex < prod(Nvec), each digit i lies in [0, Nvec[i]).
  // Out-of-range indices are not checked; the excess lands in the first digit.
  std::vector<uint64_t> compute_indices(uint64_t desiredIndex, std::vector<uint64_t> Nvec) {
    // Stride of the current dimension = product of all dimensions after it.
    uint64_t stride = 1;
    for (uint64_t dim : Nvec) {
      stride *= dim;
    }

    std::vector<uint64_t> digits;
    uint64_t remainder = desiredIndex;
    for (uint64_t dim : Nvec) {
      stride /= dim;
      const uint64_t digit = remainder / stride;
      digits.push_back(digit);
      remainder -= digit * stride;
    }
    return digits;
  }
  164. inline Ciphertext deserialize_ciphertext(string s) {
  165. Ciphertext c;
  166. std::istringstream input(s);
  167. c.unsafe_load(input);
  168. return c;
  169. }
  170. vector<Ciphertext> deserialize_ciphertexts(uint32_t count, string s, uint32_t len_ciphertext) {
  171. vector<Ciphertext> c;
  172. for (uint32_t i = 0; i < count; i++) {
  173. c.push_back(deserialize_ciphertext(s.substr(i * len_ciphertext, len_ciphertext)));
  174. }
  175. return c;
  176. }
  177. PirQuery deserialize_query(uint32_t d, uint32_t count, string s, uint32_t len_ciphertext) {
  178. vector<vector<Ciphertext>> c;
  179. for (uint32_t i = 0; i < d; i++) {
  180. c.push_back(deserialize_ciphertexts(
  181. count,
  182. s.substr(i * count * len_ciphertext, count * len_ciphertext),
  183. len_ciphertext)
  184. );
  185. }
  186. return c;
  187. }
  188. inline string serialize_ciphertext(Ciphertext c) {
  189. std::ostringstream output;
  190. c.save(output);
  191. return output.str();
  192. }
  193. string serialize_ciphertexts(vector<Ciphertext> c) {
  194. string s;
  195. for (uint32_t i = 0; i < c.size(); i++) {
  196. s.append(serialize_ciphertext(c[i]));
  197. }
  198. return s;
  199. }
  200. string serialize_query(vector<vector<Ciphertext>> c) {
  201. string s;
  202. for (uint32_t i = 0; i < c.size(); i++) {
  203. for (uint32_t j = 0; j < c[i].size(); j++) {
  204. s.append(serialize_ciphertext(c[i][j]));
  205. }
  206. }
  207. return s;
  208. }
  209. string serialize_galoiskeys(GaloisKeys g) {
  210. std::ostringstream output;
  211. g.save(output);
  212. return output.str();
  213. }
  214. GaloisKeys *deserialize_galoiskeys(string s) {
  215. GaloisKeys *g = new GaloisKeys();
  216. std::istringstream input(s);
  217. g->unsafe_load(input);
  218. return g;
  219. }