pir.cpp 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367
  1. #include "pir.hpp"
  2. using namespace std;
  3. using namespace seal;
  4. using namespace seal::util;
  /**
   * Chooses the side lengths of a d-dimensional hyperrectangle that can hold
   * `num_of_plaintexts` plaintexts.
   *
   * Every side starts at max(2, floor(num_of_plaintexts^(1/d))). Sides are then
   * grown by one, round-robin starting with the first, until the product of all
   * sides is at least `num_of_plaintexts`. An exact fit is not padded further.
   *
   * @param num_of_plaintexts number of plaintexts to place (must be > 0)
   * @param d number of dimensions (must be > 0)
   * @return d side lengths whose product is >= num_of_plaintexts
   */
  std::vector<std::uint64_t> get_dimensions(std::uint64_t num_of_plaintexts, std::uint32_t d) {
    assert(d > 0);
    assert(num_of_plaintexts > 0);

    // std::accumulate accumulates in the type of its init argument, so the
    // init must be a uint64_t; a plain `1` would multiply in `int` and overflow.
    const auto volume = [](const std::vector<std::uint64_t> &dims) {
      return std::accumulate(dims.begin(), dims.end(), std::uint64_t{1},
                             std::multiplies<std::uint64_t>());
    };

    // Keep the root in 64 bits: narrowing it to 32 bits would corrupt it for
    // d == 1 and more than 2^32 plaintexts.
    const double estimate =
        std::floor(std::pow(static_cast<double>(num_of_plaintexts), 1.0 / d));
    const std::uint64_t root =
        std::max<std::uint64_t>(2, static_cast<std::uint64_t>(estimate));

    std::vector<std::uint64_t> dimensions(d, root);

    // Normally at most d increments are needed. The loop is not capped at d, so
    // it still ends with enough room if pow() underestimated the root, instead
    // of relying on an assert that disappears under NDEBUG.
    for (std::size_t i = 0; volume(dimensions) < num_of_plaintexts; i = (i + 1) % d) {
      dimensions[i] += 1;
    }

    assert(volume(dimensions) >= num_of_plaintexts);
    return dimensions;
  }
  20. void gen_encryption_params(std::uint32_t N, std::uint32_t logt,
  21. seal::EncryptionParameters &enc_params){
  22. enc_params.set_poly_modulus_degree(N);
  23. enc_params.set_coeff_modulus(CoeffModulus::BFVDefault(N));
  24. enc_params.set_plain_modulus(PlainModulus::Batching(N, logt+1));
  25. // the +1 above ensures we get logt bits for each plaintext coefficient. Otherwise
  26. // the coefficient modulus t will be logt bits, but only floor(t) = logt-1 (whp)
  27. // will be usable (since we need to ensure that all data in the coefficient is < t).
  28. }
  29. void verify_encryption_params(const seal::EncryptionParameters &enc_params){
  30. SEALContext context(enc_params, true);
  31. if(!context.parameters_set()){
  32. throw invalid_argument("SEAL parameters not valid.");
  33. }
  34. if(!context.using_keyswitching()){
  35. throw invalid_argument("SEAL parameters do not support key switching.");
  36. }
  37. if(!context.first_context_data()->qualifiers().using_batching){
  38. throw invalid_argument("SEAL parameters do not support batching.");
  39. }
  40. BatchEncoder batch_encoder(context);
  41. size_t slot_count = batch_encoder.slot_count();
  42. if(slot_count != enc_params.poly_modulus_degree()){
  43. throw invalid_argument("Slot count not equal to poly modulus degree - this will cause issues downstream.");
  44. }
  45. return;
  46. }
  47. void gen_pir_params(uint64_t ele_num, uint64_t ele_size, uint32_t d,
  48. const EncryptionParameters &enc_params, PirParams &pir_params,
  49. bool enable_symmetric, bool enable_batching, bool enable_mswitching){
  50. std::uint32_t N = enc_params.poly_modulus_degree();
  51. Modulus t = enc_params.plain_modulus();
  52. std::uint32_t logt = floor(log2(t.value())); // # of usable bits
  53. std::uint64_t elements_per_plaintext;
  54. std::uint64_t num_of_plaintexts;
  55. if(enable_batching){
  56. elements_per_plaintext = elements_per_ptxt(logt, N, ele_size);
  57. num_of_plaintexts = plaintexts_per_db(logt, N, ele_num, ele_size);
  58. }
  59. else{
  60. elements_per_plaintext = 1;
  61. num_of_plaintexts = ele_num;
  62. }
  63. vector<uint64_t> nvec = get_dimensions(num_of_plaintexts, d);
  64. uint32_t expansion_ratio = 0;
  65. for (uint32_t i = 0; i < enc_params.coeff_modulus().size(); ++i) {
  66. double logqi = log2(enc_params.coeff_modulus()[i].value());
  67. expansion_ratio += ceil(logqi / logt);
  68. }
  69. pir_params.enable_symmetric = enable_symmetric;
  70. pir_params.enable_batching = enable_batching;
  71. pir_params.enable_mswitching = enable_mswitching;
  72. pir_params.ele_num = ele_num;
  73. pir_params.ele_size = ele_size;
  74. pir_params.elements_per_plaintext = elements_per_plaintext;
  75. pir_params.num_of_plaintexts = num_of_plaintexts;
  76. pir_params.d = d;
  77. pir_params.expansion_ratio = expansion_ratio << 1;
  78. pir_params.nvec = nvec;
  79. pir_params.slot_count = N;
  80. }
  81. void print_pir_params(const PirParams &pir_params){
  82. std::uint32_t prod = accumulate(pir_params.nvec.begin(), pir_params.nvec.end(), 1, multiplies<uint64_t>());
  83. cout << "PIR Parameters" << endl;
  84. cout << "number of elements: " << pir_params.ele_num << endl;
  85. cout << "element size: " << pir_params.ele_size << endl;
  86. cout << "elements per BFV plaintext: " << pir_params.elements_per_plaintext << endl;
  87. cout << "dimensions for d-dimensional hyperrectangle: " << pir_params.d << endl;
  88. cout << "number of BFV plaintexts (before padding): " << pir_params.num_of_plaintexts << endl;
  89. cout << "Number of BFV plaintexts after padding (to fill d-dimensional hyperrectangle): " << prod << endl;
  90. cout << "expansion ratio: " << pir_params.expansion_ratio << endl;
  91. cout << "Using symmetric encryption: " << pir_params.enable_symmetric << endl;
  92. cout << "Using recursive mod switching: " << pir_params.enable_mswitching << endl;
  93. cout << "slot count: " << pir_params.slot_count << endl;
  94. cout << "=============================="<< endl;
  95. }
  96. void print_seal_params(const EncryptionParameters &enc_params){
  97. std::uint32_t N = enc_params.poly_modulus_degree();
  98. Modulus t = enc_params.plain_modulus();
  99. std::uint32_t logt = floor(log2(t.value()));
  100. cout << "SEAL encryption parameters" << endl;
  101. cout << "Degree of polynomial modulus (N): " << N << endl;
  102. cout << "Size of plaintext modulus (log t):" << logt << endl;
  103. cout << "There are " << enc_params.coeff_modulus().size() << " coefficient modulus:" << endl;
  104. for (uint32_t i = 0; i < enc_params.coeff_modulus().size(); ++i) {
  105. double logqi = log2(enc_params.coeff_modulus()[i].value());
  106. cout << "Size of coefficient modulus " << i << " (log q_" << i << "): " << logqi << endl;
  107. }
  108. cout << "=============================="<< endl;
  109. }
  // Number of coefficients needed to represent a database element.
  //
  // Each coefficient carries `logt` bits, so an element of `ele_size` bytes
  // needs ceil(8 * ele_size / logt) coefficients. The ceiling is computed in
  // integer arithmetic, which is exact for every input. A double-based ceil()
  // can round wrongly once 8 * ele_size exceeds 2^53.
  uint64_t coefficients_per_element(uint32_t logt, uint64_t ele_size) {
    assert(logt > 0);  // each coefficient must carry at least one bit
    const uint64_t bits = 8 * ele_size;
    return bits / logt + (bits % logt != 0 ? 1 : 0);
  }
  114. // Number of database elements that can fit in a single FV plaintext
  115. uint64_t elements_per_ptxt(uint32_t logt, uint64_t N, uint64_t ele_size) {
  116. uint64_t coeff_per_ele = coefficients_per_element(logt, ele_size);
  117. uint64_t ele_per_ptxt = N / coeff_per_ele;
  118. assert(ele_per_ptxt > 0);
  119. return ele_per_ptxt;
  120. }
  121. // Number of FV plaintexts needed to represent the database
  122. uint64_t plaintexts_per_db(uint32_t logt, uint64_t N, uint64_t ele_num, uint64_t ele_size) {
  123. uint64_t ele_per_ptxt = elements_per_ptxt(logt, N, ele_size);
  124. return ceil((double)ele_num / ele_per_ptxt);
  125. }
  126. vector<uint64_t> bytes_to_coeffs(uint32_t limit, const uint8_t *bytes, uint64_t size) {
  127. uint64_t size_out = coefficients_per_element(limit, size);
  128. vector<uint64_t> output(size_out);
  129. uint32_t room = limit;
  130. uint64_t *target = &output[0];
  131. for (uint32_t i = 0; i < size; i++) {
  132. uint8_t src = bytes[i];
  133. uint32_t rest = 8;
  134. while (rest) {
  135. if (room == 0) {
  136. target++;
  137. room = limit;
  138. }
  139. uint32_t shift = rest;
  140. if (room < rest) {
  141. shift = room;
  142. }
  143. *target = *target << shift;
  144. *target = *target | (src >> (8 - shift));
  145. src = src << shift;
  146. room -= shift;
  147. rest -= shift;
  148. }
  149. }
  150. *target = *target << room;
  151. return output;
  152. }
  // Unpacks `limit`-bit coefficients into bytes, writing at most `size_out`
  // bytes to `output`. It mirrors the packing in bytes_to_coeffs.
  //
  // Data is grouped into elements of `ele_size` bytes. Each coefficient yields
  // up to `limit` bits, taken from the top of its `limit`-bit field. The last
  // coefficient of an element yields only the bits the element still needs.
  //
  // Bits are shifted into the low end of the current output byte. Every byte
  // receives 8 bits in total, so its initial contents are shifted out.
  void coeffs_to_bytes(uint32_t limit, const vector<uint64_t> &coeffs, uint8_t *output,
                       uint32_t size_out, uint32_t ele_size) {
    const uint32_t bits_per_element = ele_size * 8;
    uint32_t bits_left = bits_per_element;  // bits of the current element still to emit
    uint32_t free_bits = 8;                 // bits still to fill in output[byte]
    uint32_t byte = 0;

    for (const uint64_t coeff : coeffs) {
      if (bits_left == 0) {
        bits_left = bits_per_element;  // a new element starts with this coefficient
      }
      uint64_t src = coeff;
      for (uint32_t rest = std::min(limit, bits_left); rest != 0 && byte < size_out;) {
        const uint32_t take = std::min(free_bits, rest);
        // The top `take` bits of the field go into the low end of the byte.
        // After the first chunk, bits already consumed from `src` also land
        // above `take`. They are pushed out of the byte before it is complete,
        // and the cast to uint8_t discards anything above bit 7.
        output[byte] = static_cast<uint8_t>((output[byte] << take) | (src >> (limit - take)));
        src <<= take;
        free_bits -= take;
        rest -= take;
        bits_left -= take;
        if (free_bits == 0) {
          ++byte;
          free_bits = 8;
        }
      }
    }
  }
  182. void vector_to_plaintext(const vector<uint64_t> &coeffs, Plaintext &plain) {
  183. uint32_t coeff_count = coeffs.size();
  184. plain.resize(coeff_count);
  185. util::set_uint(coeffs.data(), coeff_count, plain.data());
  186. }
  // Converts the flat index `desiredIndex` into one coordinate per dimension of
  // the hyperrectangle whose side lengths are `Nvec`. The layout is row-major:
  // the last dimension varies fastest. The first coordinate is not reduced
  // modulo Nvec[0], so an out-of-range index shows up there.
  std::vector<std::uint64_t> compute_indices(std::uint64_t desiredIndex, std::vector<std::uint64_t> Nvec) {
    // Start from the total volume and divide it down to each dimension's stride.
    std::uint64_t stride = 1;
    for (const std::uint64_t extent : Nvec) {
      stride *= extent;
    }

    std::vector<std::uint64_t> coordinates;
    coordinates.reserve(Nvec.size());
    std::uint64_t remaining = desiredIndex;
    for (const std::uint64_t extent : Nvec) {
      stride /= extent;  // stride of this dimension = product of the later sides
      coordinates.push_back(remaining / stride);
      remaining %= stride;
    }
    return coordinates;
  }
  203. uint64_t invert_mod(uint64_t m, const seal::Modulus& mod) {
  204. if (mod.uint64_count() > 1) {
  205. cout << "Mod too big to invert";
  206. }
  207. uint64_t inverse = 0;
  208. if (!seal::util::try_invert_uint_mod(m, mod.value(), inverse)) {
  209. cout << "Could not invert value";
  210. }
  211. return inverse;
  212. }
  213. uint32_t compute_expansion_ratio(EncryptionParameters params) {
  214. uint32_t expansion_ratio = 0;
  215. uint32_t pt_bits_per_coeff = log2(params.plain_modulus().value());
  216. for (size_t i = 0; i < params.coeff_modulus().size(); ++i) {
  217. double coeff_bit_size = log2(params.coeff_modulus()[i].value());
  218. expansion_ratio += ceil(coeff_bit_size / pt_bits_per_coeff);
  219. }
  220. return expansion_ratio;
  221. }
  222. vector<Plaintext> decompose_to_plaintexts(EncryptionParameters params, const Ciphertext& ct) {
  223. const uint32_t pt_bits_per_coeff = log2(params.plain_modulus().value());
  224. const auto coeff_count = params.poly_modulus_degree();
  225. const auto coeff_mod_count = params.coeff_modulus().size();
  226. const uint64_t pt_bitmask = (1 << pt_bits_per_coeff) - 1;
  227. vector<Plaintext> result(compute_expansion_ratio(params) * ct.size());
  228. auto pt_iter = result.begin();
  229. for (size_t poly_index = 0; poly_index < ct.size(); ++poly_index) {
  230. for (size_t coeff_mod_index = 0; coeff_mod_index < coeff_mod_count;
  231. ++coeff_mod_index) {
  232. const double coeff_bit_size =
  233. log2(params.coeff_modulus()[coeff_mod_index].value());
  234. const size_t local_expansion_ratio =
  235. ceil(coeff_bit_size / pt_bits_per_coeff);
  236. size_t shift = 0;
  237. for (size_t i = 0; i < local_expansion_ratio; ++i) {
  238. pt_iter->resize(coeff_count);
  239. for (size_t c = 0; c < coeff_count; ++c) {
  240. (*pt_iter)[c] =
  241. (ct.data(poly_index)[coeff_mod_index * coeff_count + c] >>
  242. shift) &
  243. pt_bitmask;
  244. }
  245. ++pt_iter;
  246. shift += pt_bits_per_coeff;
  247. }
  248. }
  249. }
  250. return result;
  251. }
  252. void compose_to_ciphertext(EncryptionParameters params, vector<Plaintext>::const_iterator pt_iter,
  253. const size_t ct_poly_count, Ciphertext& ct) {
  254. const uint32_t pt_bits_per_coeff = log2(params.plain_modulus().value());
  255. const auto coeff_count = params.poly_modulus_degree();
  256. const auto coeff_mod_count = params.coeff_modulus().size();
  257. ct.resize(ct_poly_count);
  258. for (size_t poly_index = 0; poly_index < ct_poly_count; ++poly_index) {
  259. for (size_t coeff_mod_index = 0; coeff_mod_index < coeff_mod_count;
  260. ++coeff_mod_index) {
  261. const double coeff_bit_size =
  262. log2(params.coeff_modulus()[coeff_mod_index].value());
  263. const size_t local_expansion_ratio =
  264. ceil(coeff_bit_size / pt_bits_per_coeff);
  265. size_t shift = 0;
  266. for (size_t i = 0; i < local_expansion_ratio; ++i) {
  267. for (size_t c = 0; c < pt_iter->coeff_count(); ++c) {
  268. if (shift == 0) {
  269. ct.data(poly_index)[coeff_mod_index * coeff_count + c] =
  270. (*pt_iter)[c];
  271. } else {
  272. ct.data(poly_index)[coeff_mod_index * coeff_count + c] +=
  273. ((*pt_iter)[c] << shift);
  274. }
  275. }
  276. ++pt_iter;
  277. shift += pt_bits_per_coeff;
  278. }
  279. }
  280. }
  281. }
  282. void compose_to_ciphertext(EncryptionParameters params, const vector<Plaintext>& pts, Ciphertext& ct) {
  283. return compose_to_ciphertext(params, pts.begin(), pts.size() / compute_expansion_ratio(params), ct);
  284. }
  285. PirQuery deserialize_query(uint32_t d, uint32_t count, string s, uint32_t len_ciphertext,
  286. shared_ptr<SEALContext> context) {
  287. vector<vector<Ciphertext>> q;
  288. std::istringstream input(s);
  289. for (uint32_t i = 0; i < d; i++) {
  290. vector<Ciphertext> cs;
  291. for (uint32_t i = 0; i < count; i++) {
  292. Ciphertext c;
  293. c.load(*context, input);
  294. cs.push_back(c);
  295. }
  296. q.push_back(cs);
  297. }
  298. return q;
  299. }
  300. string serialize_galoiskeys(Serializable<GaloisKeys> g) {
  301. std::ostringstream output;
  302. g.save(output);
  303. return output.str();
  304. }
  305. GaloisKeys *deserialize_galoiskeys(string s, shared_ptr<SEALContext> context) {
  306. GaloisKeys *g = new GaloisKeys();
  307. std::istringstream input(s);
  308. g->load(*context, input);
  309. return g;
  310. }