pir.cpp 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386
#include "pir.hpp"

#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <iostream>
#include <memory>
#include <numeric>
#include <sstream>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

using namespace std;
using namespace seal;
using namespace seal::util;
  5. std::vector<std::uint64_t> get_dimensions(std::uint64_t num_of_plaintexts,
  6. std::uint32_t d) {
  7. assert(d > 0);
  8. assert(num_of_plaintexts > 0);
  9. std::uint64_t root =
  10. max(static_cast<uint32_t>(2),
  11. static_cast<uint32_t>(floor(pow(num_of_plaintexts, 1.0 / d))));
  12. std::vector<std::uint64_t> dimensions(d, root);
  13. for (int i = 0; i < d; i++) {
  14. if (accumulate(dimensions.begin(), dimensions.end(), 1,
  15. multiplies<uint64_t>()) > num_of_plaintexts) {
  16. break;
  17. }
  18. dimensions[i] += 1;
  19. }
  20. std::uint32_t prod = accumulate(dimensions.begin(), dimensions.end(), 1,
  21. multiplies<uint64_t>());
  22. assert(prod >= num_of_plaintexts);
  23. return dimensions;
  24. }
  25. void gen_encryption_params(std::uint32_t N, std::uint32_t logt,
  26. seal::EncryptionParameters &enc_params) {
  27. enc_params.set_poly_modulus_degree(N);
  28. enc_params.set_coeff_modulus(CoeffModulus::BFVDefault(N));
  29. enc_params.set_plain_modulus(PlainModulus::Batching(N, logt + 1));
  30. // the +1 above ensures we get logt bits for each plaintext coefficient.
  31. // Otherwise the coefficient modulus t will be logt bits, but only floor(t) =
  32. // logt-1 (whp) will be usable (since we need to ensure that all data in the
  33. // coefficient is < t).
  34. }
  35. void verify_encryption_params(const seal::EncryptionParameters &enc_params) {
  36. SEALContext context(enc_params, true);
  37. if (!context.parameters_set()) {
  38. throw invalid_argument("SEAL parameters not valid.");
  39. }
  40. if (!context.using_keyswitching()) {
  41. throw invalid_argument("SEAL parameters do not support key switching.");
  42. }
  43. if (!context.first_context_data()->qualifiers().using_batching) {
  44. throw invalid_argument("SEAL parameters do not support batching.");
  45. }
  46. BatchEncoder batch_encoder(context);
  47. size_t slot_count = batch_encoder.slot_count();
  48. if (slot_count != enc_params.poly_modulus_degree()) {
  49. throw invalid_argument("Slot count not equal to poly modulus degree - this "
  50. "will cause issues downstream.");
  51. }
  52. return;
  53. }
  54. void gen_pir_params(uint64_t ele_num, uint64_t ele_size, uint32_t d,
  55. const EncryptionParameters &enc_params,
  56. PirParams &pir_params, bool enable_symmetric,
  57. bool enable_batching, bool enable_mswitching) {
  58. std::uint32_t N = enc_params.poly_modulus_degree();
  59. Modulus t = enc_params.plain_modulus();
  60. std::uint32_t logt = floor(log2(t.value())); // # of usable bits
  61. std::uint64_t elements_per_plaintext;
  62. std::uint64_t num_of_plaintexts;
  63. if (enable_batching) {
  64. elements_per_plaintext = elements_per_ptxt(logt, N, ele_size);
  65. num_of_plaintexts = plaintexts_per_db(logt, N, ele_num, ele_size);
  66. } else {
  67. elements_per_plaintext = 1;
  68. num_of_plaintexts = ele_num;
  69. }
  70. vector<uint64_t> nvec = get_dimensions(num_of_plaintexts, d);
  71. uint32_t expansion_ratio = 0;
  72. for (uint32_t i = 0; i < enc_params.coeff_modulus().size(); ++i) {
  73. double logqi = log2(enc_params.coeff_modulus()[i].value());
  74. expansion_ratio += ceil(logqi / logt);
  75. }
  76. pir_params.enable_symmetric = enable_symmetric;
  77. pir_params.enable_batching = enable_batching;
  78. pir_params.enable_mswitching = enable_mswitching;
  79. pir_params.ele_num = ele_num;
  80. pir_params.ele_size = ele_size;
  81. pir_params.elements_per_plaintext = elements_per_plaintext;
  82. pir_params.num_of_plaintexts = num_of_plaintexts;
  83. pir_params.d = d;
  84. pir_params.expansion_ratio = expansion_ratio << 1;
  85. pir_params.nvec = nvec;
  86. pir_params.slot_count = N;
  87. }
  88. void print_pir_params(const PirParams &pir_params) {
  89. std::uint32_t prod =
  90. accumulate(pir_params.nvec.begin(), pir_params.nvec.end(), 1,
  91. multiplies<uint64_t>());
  92. cout << "PIR Parameters" << endl;
  93. cout << "number of elements: " << pir_params.ele_num << endl;
  94. cout << "element size: " << pir_params.ele_size << endl;
  95. cout << "elements per BFV plaintext: " << pir_params.elements_per_plaintext
  96. << endl;
  97. cout << "dimensions for d-dimensional hyperrectangle: " << pir_params.d
  98. << endl;
  99. cout << "number of BFV plaintexts (before padding): "
  100. << pir_params.num_of_plaintexts << endl;
  101. cout << "Number of BFV plaintexts after padding (to fill d-dimensional "
  102. "hyperrectangle): "
  103. << prod << endl;
  104. cout << "expansion ratio: " << pir_params.expansion_ratio << endl;
  105. cout << "Using symmetric encryption: " << pir_params.enable_symmetric << endl;
  106. cout << "Using recursive mod switching: " << pir_params.enable_mswitching
  107. << endl;
  108. cout << "slot count: " << pir_params.slot_count << endl;
  109. cout << "==============================" << endl;
  110. }
  111. void print_seal_params(const EncryptionParameters &enc_params) {
  112. std::uint32_t N = enc_params.poly_modulus_degree();
  113. Modulus t = enc_params.plain_modulus();
  114. std::uint32_t logt = floor(log2(t.value()));
  115. cout << "SEAL encryption parameters" << endl;
  116. cout << "Degree of polynomial modulus (N): " << N << endl;
  117. cout << "Size of plaintext modulus (log t):" << logt << endl;
  118. cout << "There are " << enc_params.coeff_modulus().size()
  119. << " coefficient modulus:" << endl;
  120. for (uint32_t i = 0; i < enc_params.coeff_modulus().size(); ++i) {
  121. double logqi = log2(enc_params.coeff_modulus()[i].value());
  122. cout << "Size of coefficient modulus " << i << " (log q_" << i
  123. << "): " << logqi << endl;
  124. }
  125. cout << "==============================" << endl;
  126. }
  127. // Number of coefficients needed to represent a database element
  128. uint64_t coefficients_per_element(uint32_t logt, uint64_t ele_size) {
  129. return ceil(8 * ele_size / (double)logt);
  130. }
  131. // Number of database elements that can fit in a single FV plaintext
  132. uint64_t elements_per_ptxt(uint32_t logt, uint64_t N, uint64_t ele_size) {
  133. uint64_t coeff_per_ele = coefficients_per_element(logt, ele_size);
  134. uint64_t ele_per_ptxt = N / coeff_per_ele;
  135. assert(ele_per_ptxt > 0);
  136. return ele_per_ptxt;
  137. }
  138. // Number of FV plaintexts needed to represent the database
  139. uint64_t plaintexts_per_db(uint32_t logt, uint64_t N, uint64_t ele_num,
  140. uint64_t ele_size) {
  141. uint64_t ele_per_ptxt = elements_per_ptxt(logt, N, ele_size);
  142. return ceil((double)ele_num / ele_per_ptxt);
  143. }
  144. vector<uint64_t> bytes_to_coeffs(uint32_t limit, const uint8_t *bytes,
  145. uint64_t size) {
  146. uint64_t size_out = coefficients_per_element(limit, size);
  147. vector<uint64_t> output(size_out);
  148. uint32_t room = limit;
  149. uint64_t *target = &output[0];
  150. for (uint32_t i = 0; i < size; i++) {
  151. uint8_t src = bytes[i];
  152. uint32_t rest = 8;
  153. while (rest) {
  154. if (room == 0) {
  155. target++;
  156. room = limit;
  157. }
  158. uint32_t shift = rest;
  159. if (room < rest) {
  160. shift = room;
  161. }
  162. *target = *target << shift;
  163. *target = *target | (src >> (8 - shift));
  164. src = src << shift;
  165. room -= shift;
  166. rest -= shift;
  167. }
  168. }
  169. *target = *target << room;
  170. return output;
  171. }
  172. void coeffs_to_bytes(uint32_t limit, const vector<uint64_t> &coeffs,
  173. uint8_t *output, uint32_t size_out, uint32_t ele_size) {
  174. uint32_t room = 8;
  175. uint32_t j = 0;
  176. uint8_t *target = output;
  177. uint32_t bits_left = ele_size * 8;
  178. for (uint32_t i = 0; i < coeffs.size(); i++) {
  179. if (bits_left == 0) {
  180. bits_left = ele_size * 8;
  181. }
  182. uint64_t src = coeffs[i];
  183. uint32_t rest = min(limit, bits_left);
  184. while (rest && j < size_out) {
  185. uint32_t shift = rest;
  186. if (room < rest) {
  187. shift = room;
  188. }
  189. target[j] = target[j] << shift;
  190. target[j] = target[j] | (src >> (limit - shift));
  191. src = src << shift;
  192. room -= shift;
  193. rest -= shift;
  194. bits_left -= shift;
  195. if (room == 0) {
  196. j++;
  197. room = 8;
  198. }
  199. }
  200. }
  201. }
  202. void vector_to_plaintext(const vector<uint64_t> &coeffs, Plaintext &plain) {
  203. uint32_t coeff_count = coeffs.size();
  204. plain.resize(coeff_count);
  205. util::set_uint(coeffs.data(), coeff_count, plain.data());
  206. }
  207. vector<uint64_t> compute_indices(uint64_t desiredIndex, vector<uint64_t> Nvec) {
  208. uint32_t num = Nvec.size();
  209. uint64_t product = 1;
  210. for (uint32_t i = 0; i < num; i++) {
  211. product *= Nvec[i];
  212. }
  213. uint64_t j = desiredIndex;
  214. vector<uint64_t> result;
  215. for (uint32_t i = 0; i < num; i++) {
  216. product /= Nvec[i];
  217. uint64_t ji = j / product;
  218. result.push_back(ji);
  219. j -= ji * product;
  220. }
  221. return result;
  222. }
  223. uint64_t invert_mod(uint64_t m, const seal::Modulus &mod) {
  224. if (mod.uint64_count() > 1) {
  225. cout << "Mod too big to invert";
  226. }
  227. uint64_t inverse = 0;
  228. if (!seal::util::try_invert_uint_mod(m, mod.value(), inverse)) {
  229. cout << "Could not invert value";
  230. }
  231. return inverse;
  232. }
  233. uint32_t compute_expansion_ratio(EncryptionParameters params) {
  234. uint32_t expansion_ratio = 0;
  235. uint32_t pt_bits_per_coeff = log2(params.plain_modulus().value());
  236. for (size_t i = 0; i < params.coeff_modulus().size(); ++i) {
  237. double coeff_bit_size = log2(params.coeff_modulus()[i].value());
  238. expansion_ratio += ceil(coeff_bit_size / pt_bits_per_coeff);
  239. }
  240. return expansion_ratio;
  241. }
  242. vector<Plaintext> decompose_to_plaintexts(EncryptionParameters params,
  243. const Ciphertext &ct) {
  244. const uint32_t pt_bits_per_coeff = log2(params.plain_modulus().value());
  245. const auto coeff_count = params.poly_modulus_degree();
  246. const auto coeff_mod_count = params.coeff_modulus().size();
  247. const uint64_t pt_bitmask = (1 << pt_bits_per_coeff) - 1;
  248. vector<Plaintext> result(compute_expansion_ratio(params) * ct.size());
  249. auto pt_iter = result.begin();
  250. for (size_t poly_index = 0; poly_index < ct.size(); ++poly_index) {
  251. for (size_t coeff_mod_index = 0; coeff_mod_index < coeff_mod_count;
  252. ++coeff_mod_index) {
  253. const double coeff_bit_size =
  254. log2(params.coeff_modulus()[coeff_mod_index].value());
  255. const size_t local_expansion_ratio =
  256. ceil(coeff_bit_size / pt_bits_per_coeff);
  257. size_t shift = 0;
  258. for (size_t i = 0; i < local_expansion_ratio; ++i) {
  259. pt_iter->resize(coeff_count);
  260. for (size_t c = 0; c < coeff_count; ++c) {
  261. (*pt_iter)[c] =
  262. (ct.data(poly_index)[coeff_mod_index * coeff_count + c] >>
  263. shift) &
  264. pt_bitmask;
  265. }
  266. ++pt_iter;
  267. shift += pt_bits_per_coeff;
  268. }
  269. }
  270. }
  271. return result;
  272. }
  273. void compose_to_ciphertext(EncryptionParameters params,
  274. vector<Plaintext>::const_iterator pt_iter,
  275. const size_t ct_poly_count, Ciphertext &ct) {
  276. const uint32_t pt_bits_per_coeff = log2(params.plain_modulus().value());
  277. const auto coeff_count = params.poly_modulus_degree();
  278. const auto coeff_mod_count = params.coeff_modulus().size();
  279. ct.resize(ct_poly_count);
  280. for (size_t poly_index = 0; poly_index < ct_poly_count; ++poly_index) {
  281. for (size_t coeff_mod_index = 0; coeff_mod_index < coeff_mod_count;
  282. ++coeff_mod_index) {
  283. const double coeff_bit_size =
  284. log2(params.coeff_modulus()[coeff_mod_index].value());
  285. const size_t local_expansion_ratio =
  286. ceil(coeff_bit_size / pt_bits_per_coeff);
  287. size_t shift = 0;
  288. for (size_t i = 0; i < local_expansion_ratio; ++i) {
  289. for (size_t c = 0; c < pt_iter->coeff_count(); ++c) {
  290. if (shift == 0) {
  291. ct.data(poly_index)[coeff_mod_index * coeff_count + c] =
  292. (*pt_iter)[c];
  293. } else {
  294. ct.data(poly_index)[coeff_mod_index * coeff_count + c] +=
  295. ((*pt_iter)[c] << shift);
  296. }
  297. }
  298. ++pt_iter;
  299. shift += pt_bits_per_coeff;
  300. }
  301. }
  302. }
  303. }
  304. void compose_to_ciphertext(EncryptionParameters params,
  305. const vector<Plaintext> &pts, Ciphertext &ct) {
  306. return compose_to_ciphertext(
  307. params, pts.begin(), pts.size() / compute_expansion_ratio(params), ct);
  308. }
  309. PirQuery deserialize_query(uint32_t d, uint32_t count, string s,
  310. uint32_t len_ciphertext,
  311. shared_ptr<SEALContext> context) {
  312. vector<vector<Ciphertext>> q;
  313. std::istringstream input(s);
  314. for (uint32_t i = 0; i < d; i++) {
  315. vector<Ciphertext> cs;
  316. for (uint32_t i = 0; i < count; i++) {
  317. Ciphertext c;
  318. c.load(*context, input);
  319. cs.push_back(c);
  320. }
  321. q.push_back(cs);
  322. }
  323. return q;
  324. }
  325. string serialize_galoiskeys(Serializable<GaloisKeys> g) {
  326. std::ostringstream output;
  327. g.save(output);
  328. return output.str();
  329. }
  330. GaloisKeys *deserialize_galoiskeys(string s, shared_ptr<SEALContext> context) {
  331. GaloisKeys *g = new GaloisKeys();
  332. std::istringstream input(s);
  333. g->load(*context, input);
  334. return g;
  335. }