pir.cpp 7.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277
  1. #include "pir.hpp"
  2. using namespace std;
  3. using namespace seal;
  4. using namespace seal::util;
  /// Choose the side lengths of a d-dimensional hypercube that holds the database.
  ///
  /// Every dimension starts at max(2, floor(num_of_plaintexts^(1/d))). Dimensions are
  /// then incremented one at a time, cycling from index 0, until the product of all
  /// dimensions strictly exceeds num_of_plaintexts.
  ///
  /// @param num_of_plaintexts number of plaintexts the hypercube must hold (> 0).
  /// @param d                 number of dimensions (> 0).
  /// @return d side lengths whose product is > num_of_plaintexts
  ///         (empty if d == 0 and assertions are disabled).
  std::vector<std::uint64_t> get_dimensions(std::uint64_t num_of_plaintexts, std::uint32_t d) {
      assert(d > 0);
      assert(num_of_plaintexts > 0);
      if (d == 0) {
          // 1.0 / 0 would feed pow(n, inf) into an integer cast below (UB).
          return {};
      }
      // Accumulate in 64 bits. An `int` initial value would make std::accumulate
      // accumulate in, and return, int, silently truncating large products.
      const auto product_of = [](const std::vector<std::uint64_t> &dims) {
          return std::accumulate(dims.begin(), dims.end(), std::uint64_t{1},
                                 std::multiplies<std::uint64_t>());
      };
      // floor(pow(...)) may land one below the true integer root for perfect powers
      // (e.g. 125 with d = 3). The loop below therefore keeps cycling over the
      // dimensions instead of stopping after a single pass of d increments.
      const std::uint64_t root = std::max<std::uint64_t>(
          2, static_cast<std::uint64_t>(
                 std::floor(std::pow(static_cast<double>(num_of_plaintexts), 1.0 / d))));
      std::vector<std::uint64_t> dimensions(d, root);
      std::uint64_t prod = product_of(dimensions);
      for (std::uint32_t i = 0; prod <= num_of_plaintexts; i = (i + 1) % d) {
          dimensions[i] += 1;
          prod = product_of(dimensions);
      }
      std::cout << "Total:" << num_of_plaintexts << std::endl << "Prod: "
                << prod << std::endl;
      return dimensions;
  }
  22. void gen_params(uint64_t ele_num, uint64_t ele_size, uint32_t N, uint32_t logt,
  23. uint32_t d, EncryptionParameters &params,
  24. PirParams &pir_params) {
  25. // Determine the maximum size of each dimension
  26. // plain modulus = a power of 2 plus 1
  27. uint64_t plain_mod = (static_cast<uint64_t>(1) << logt) + 1;
  28. #ifdef DEBUG
  29. cout << "log(plain mod) before expand = " << logt << endl;
  30. cout << "number of FV plaintexts = " << plaintext_num << endl;
  31. #endif
  32. params.set_poly_modulus_degree(N);
  33. params.set_coeff_modulus(CoeffModulus::BFVDefault(N));
  34. params.set_plain_modulus(PlainModulus::Batching(N, logt));
  35. logt = floor(log2(params.plain_modulus().value()));
  36. cout << "logt: " << logt << endl << "N: " << N << endl <<
  37. "ele_num: " << ele_num << endl << "ele_size: " << ele_size << endl;
  38. uint64_t plaintext_num = plaintexts_per_db(logt, N, ele_num, ele_size);
  39. vector<uint64_t> nvec = get_dimensions(plaintext_num, d);
  40. uint32_t expansion_ratio = 0;
  41. for (uint32_t i = 0; i < params.coeff_modulus().size(); ++i) {
  42. double logqi = log2(params.coeff_modulus()[i].value());
  43. cout << "PIR: logqi = " << logqi << endl;
  44. expansion_ratio += ceil(logqi / logt);
  45. }
  46. pir_params.d = d;
  47. pir_params.dbc = 6;
  48. pir_params.n = plaintext_num;
  49. pir_params.nvec = nvec;
  50. pir_params.expansion_ratio = expansion_ratio << 1; // because one ciphertext = two polys
  51. }
  52. uint32_t plainmod_after_expansion(uint32_t logt, uint32_t N, uint32_t d,
  53. uint64_t ele_num, uint64_t ele_size) {
  54. // Goal: find max logtp such that logtp + ceil(log(ceil(d_root(n)))) <= logt
  55. // where n = ceil(ele_num / floor(N*logtp / ele_size *8))
  56. for (uint32_t logtp = logt; logtp >= 2; logtp--) {
  57. uint64_t n = plaintexts_per_db(logtp, N, ele_num, ele_size);
  58. if (logtp == logt && n == 1) {
  59. return logtp - 1;
  60. }
  61. if ((double)logtp + ceil(log2(ceil(pow(n, 1.0/(double)d)))) <= logt) {
  62. return logtp;
  63. }
  64. }
  65. assert(0); // this should never happen
  66. return logt;
  67. }
  // Number of coefficients needed to represent a database element. Each coefficient
  // carries logtp bits, so an element of ele_size bytes needs ceil(8 * ele_size / logtp)
  // coefficients. The division is exact integer arithmetic; a floating-point ceil
  // loses precision once 8 * ele_size exceeds 2^53.
  std::uint64_t coefficients_per_element(std::uint32_t logtp, std::uint64_t ele_size) {
      assert(logtp > 0);
      const std::uint64_t bits = 8 * ele_size;
      return bits / logtp + (bits % logtp != 0 ? 1 : 0);
  }
  72. // Number of database elements that can fit in a single FV plaintext
  73. uint64_t elements_per_ptxt(uint32_t logt, uint64_t N, uint64_t ele_size) {
  74. uint64_t coeff_per_ele = coefficients_per_element(logt, ele_size);
  75. uint64_t ele_per_ptxt = N / coeff_per_ele;
  76. assert(ele_per_ptxt > 0);
  77. return ele_per_ptxt;
  78. }
  79. // Number of FV plaintexts needed to represent the database
  80. uint64_t plaintexts_per_db(uint32_t logtp, uint64_t N, uint64_t ele_num, uint64_t ele_size) {
  81. uint64_t ele_per_ptxt = elements_per_ptxt(logtp, N, ele_size);
  82. return ceil((double)ele_num / ele_per_ptxt);
  83. }
  84. vector<uint64_t> bytes_to_coeffs(uint32_t limit, const uint8_t *bytes, uint64_t size) {
  85. uint64_t size_out = coefficients_per_element(limit, size);
  86. vector<uint64_t> output(size_out);
  87. uint32_t room = limit;
  88. uint64_t *target = &output[0];
  89. for (uint32_t i = 0; i < size; i++) {
  90. uint8_t src = bytes[i];
  91. uint32_t rest = 8;
  92. while (rest) {
  93. if (room == 0) {
  94. target++;
  95. room = limit;
  96. }
  97. uint32_t shift = rest;
  98. if (room < rest) {
  99. shift = room;
  100. }
  101. *target = *target << shift;
  102. *target = *target | (src >> (8 - shift));
  103. src = src << shift;
  104. room -= shift;
  105. rest -= shift;
  106. }
  107. }
  108. *target = *target << room;
  109. return output;
  110. }
  // Unpack the low `limit` bits of each coefficient of `coeffs`, most significant
  // bit first and in coefficient index order, into consecutive bytes of `output`.
  // This mirrors the packing done by bytes_to_coeffs.
  //
  // @param limit    bits taken from each coefficient.
  // @param coeffs   plaintext whose coefficients are unpacked.
  // @param output   destination buffer of at least size_out bytes. It is not cleared
  //                 first; each byte is built by shifting left, so prior contents are
  //                 shifted out once all 8 bits of that byte have been written.
  // @param size_out number of bytes to produce. Unpacking stops when byte index j
  //                 reaches size_out; any remaining coefficient bits are ignored.
  //
  // NOTE(review): if the coefficients run out before size_out bytes are complete, the
  // last partially written byte is neither left-aligned nor cleared (bytes_to_coeffs
  // ends with `<<= room`; this function has no counterpart) — confirm callers never
  // hit this case.
  // NOTE(review): assumes every coefficient is < 2^limit. Larger values would OR extra
  // high bits over bits already written to the current byte.
  111. void coeffs_to_bytes(uint32_t limit, const Plaintext &coeffs, uint8_t *output, uint32_t size_out) {
  // room: bits still free in the current output byte target[j].
  112. uint32_t room = 8;
  113. uint32_t j = 0;
  114. uint8_t *target = output;
  115. for (uint32_t i = 0; i < coeffs.coeff_count(); i++) {
  116. uint64_t src = coeffs[i];
  // rest: bits of this coefficient's limit-bit field not yet emitted.
  117. uint32_t rest = limit;
  118. while (rest && j < size_out) {
  // Emit at most `room` bits, so a piece never crosses a byte boundary.
  119. uint32_t shift = rest;
  120. if (room < rest) {
  121. shift = room;
  122. }
  123. target[j] = target[j] << shift;
  // Append the top `shift` bits of the remaining field. src is 64-bit and only shifted
  // left, so bits emitted earlier sit above the field and show up here as extra high
  // bits. The uint8_t assignment drops anything above bit 7. A continued coefficient
  // always starts a fresh byte (room was reset to 8), so the remaining extra bits lie
  // above the real data and are shifted out as the byte fills.
  124. target[j] = target[j] | (src >> (limit - shift));
  125. src = src << shift;
  126. room -= shift;
  127. rest -= shift;
  // Byte complete: move on to the next output byte.
  128. if (room == 0) {
  129. j++;
  130. room = 8;
  131. }
  132. }
  133. }
  134. }
  135. void vector_to_plaintext(const vector<uint64_t> &coeffs, Plaintext &plain) {
  136. uint32_t coeff_count = coeffs.size();
  137. plain.resize(coeff_count);
  138. util::set_uint(coeffs.data(), coeff_count, plain.data());
  139. }
  // Decompose a flat index into hypercube coordinates, most significant dimension first.
  //
  // The coordinate for dimension i is desiredIndex divided by the product of the
  // dimensions after i, reduced by the earlier coordinates. The first coordinate is not
  // reduced modulo Nvec[0], so an index beyond the hypercube shows up there.
  //
  // @param desiredIndex flat index to decompose.
  // @param Nvec         side lengths of the hypercube (no zeros).
  // @return one coordinate per entry of Nvec.
  std::vector<std::uint64_t> compute_indices(std::uint64_t desiredIndex, std::vector<std::uint64_t> Nvec) {
      std::uint64_t stride = std::accumulate(Nvec.begin(), Nvec.end(), std::uint64_t{1},
                                             std::multiplies<std::uint64_t>());
      std::uint64_t remainder = desiredIndex;
      std::vector<std::uint64_t> coords;
      coords.reserve(Nvec.size());
      for (const std::uint64_t dim : Nvec) {
          stride /= dim;  // product of the dimensions after this one
          coords.push_back(remainder / stride);
          remainder %= stride;
      }
      return coords;
  }
  156. uint64_t invert_mod(uint64_t m, const seal::Modulus& mod) {
  157. if (mod.uint64_count() > 1) {
  158. cout << "Mod too big to invert";
  159. }
  160. uint64_t inverse = 0;
  161. if (!seal::util::try_invert_uint_mod(m, mod.value(), inverse)) {
  162. cout << "Could not invert value";
  163. }
  164. return inverse;
  165. }
  166. inline Ciphertext deserialize_ciphertext(string s, shared_ptr<SEALContext> context) {
  167. Ciphertext c;
  168. std::istringstream input(s);
  169. c.unsafe_load(*context, input);
  170. return c;
  171. }
  172. vector<Ciphertext> deserialize_ciphertexts(uint32_t count, string s, uint32_t len_ciphertext,
  173. shared_ptr<SEALContext> context) {
  174. vector<Ciphertext> c;
  175. for (uint32_t i = 0; i < count; i++) {
  176. c.push_back(deserialize_ciphertext(s.substr(i * len_ciphertext, len_ciphertext), context));
  177. }
  178. return c;
  179. }
  180. PirQuery deserialize_query(uint32_t d, uint32_t count, string s, uint32_t len_ciphertext,
  181. shared_ptr<SEALContext> context) {
  182. vector<vector<Ciphertext>> c;
  183. for (uint32_t i = 0; i < d; i++) {
  184. c.push_back(deserialize_ciphertexts(
  185. count,
  186. s.substr(i * count * len_ciphertext, count * len_ciphertext),
  187. len_ciphertext, context)
  188. );
  189. }
  190. return c;
  191. }
  192. inline string serialize_ciphertext(Ciphertext c) {
  193. std::ostringstream output;
  194. c.save(output);
  195. return output.str();
  196. }
  197. string serialize_ciphertexts(vector<Ciphertext> c) {
  198. string s;
  199. for (uint32_t i = 0; i < c.size(); i++) {
  200. s.append(serialize_ciphertext(c[i]));
  201. }
  202. return s;
  203. }
  204. string serialize_query(vector<vector<Ciphertext>> c) {
  205. string s;
  206. for (uint32_t i = 0; i < c.size(); i++) {
  207. for (uint32_t j = 0; j < c[i].size(); j++) {
  208. s.append(serialize_ciphertext(c[i][j]));
  209. }
  210. }
  211. return s;
  212. }
  213. string serialize_galoiskeys(GaloisKeys g) {
  214. std::ostringstream output;
  215. g.save(output);
  216. return output.str();
  217. }
  218. GaloisKeys *deserialize_galoiskeys(string s, shared_ptr<SEALContext> context) {
  219. GaloisKeys *g = new GaloisKeys();
  220. std::istringstream input(s);
  221. g->unsafe_load(*context, input);
  222. return g;
  223. }