pir.cpp 22 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717
  #include "pir.hpp"

  #include <cmath>
  #include <cstdint>
  #include <cstdlib>
  #include <cstring>
  #include <memory>
  #include <sstream>
  #include <stdexcept>
  #include <string>
  #include <vector>

  using namespace std;
  using namespace seal;
  using namespace seal::util;
  6. PIRClient::PIRClient(const seal::EncryptionParameters &parms, pirParams & pirparms) {
  7. parms_ = parms;
  8. SEALContext context(parms);
  9. keygen_.reset(new KeyGenerator(context));
  10. encryptor_.reset(new Encryptor(context, keygen_->public_key()));
  11. uint64_t plainMod = parms.plain_modulus().value();
  12. int N = pirparms.Nvec[0];
  13. int logN = ceil(log(N) / log(2));
  14. EncryptionParameters newparms = parms;
  15. newparms.set_plain_modulus(plainMod >> logN);
  16. newparms_ = newparms;
  17. SEALContext newcontext(newparms);
  18. SecretKey secret_key = keygen_->secret_key();
  19. secret_key.mutable_hash_block() = newparms.hash_block();
  20. decryptor_.reset(new Decryptor(newcontext, secret_key));
  21. evaluator_.reset(new Evaluator(newcontext));
  22. int expansion_ratio = 0;
  23. for (int i = 0; i < parms.coeff_modulus().size(); ++i)
  24. {
  25. double logqi = log(parms.coeff_modulus()[i].value());
  26. expansion_ratio += ceil(logqi / log(newparms.plain_modulus().value()));
  27. }
  28. pirparms.expansion_ratio_ = expansion_ratio << 1;
  29. pirparms_ = pirparms;
  30. }
  31. pirQuery PIRClient::generate_query(int desiredIndex) {
  32. vector<int> indices = compute_indices(desiredIndex, pirparms_.Nvec);
  33. vector<Ciphertext> result;
  34. for (int i = 0; i < indices.size(); i++) {
  35. Ciphertext dest;
  36. encryptor_->encrypt(Plaintext("1x^" + std::to_string(indices[i])), dest);
  37. result.push_back(dest);
  38. }
  39. return result;
  40. }
  41. Plaintext PIRClient::decode_reply(pirReply reply) {
  42. int exp_ratio = pirparms_.expansion_ratio_;
  43. vector<Ciphertext> temp = reply;
  44. int recursion_level = pirparms_.d;
  45. for (int i = 0; i < recursion_level; i++) {
  46. vector<Ciphertext> newtemp;
  47. vector<Plaintext> tempplain;
  48. for (int j = 0; j < temp.size(); j++) {
  49. Plaintext ptxt;
  50. decryptor_->decrypt(temp[j], ptxt);
  51. tempplain.push_back(ptxt);
  52. if ( (j + 1) % exp_ratio == 0 && j > 0) {
  53. // Combine into one ciphertext.
  54. Ciphertext combined = compose_to_ciphertext(tempplain);
  55. newtemp.push_back(combined);
  56. }
  57. }
  58. if (i == recursion_level - 1) {
  59. if (temp.size() != 1) throw;
  60. return tempplain[0];
  61. }
  62. else {
  63. tempplain.clear();
  64. temp = newtemp;
  65. }
  66. }
  67. }
  68. GaloisKeys PIRClient::generate_galois_keys() {
  69. vector<uint64_t> galois_elts;
  70. int n = parms_.poly_modulus().coeff_count() - 1;
  71. int logn = get_power_of_two(n);
  72. for (int i = 0; i < logn; i++)
  73. {
  74. galois_elts.push_back((n + exponentiate_uint64(2, i)) / exponentiate_uint64(2, i));
  75. }
  76. GaloisKeys galois_keys;
  77. keygen_->generate_galois_keys(pirparms_.dbc, galois_elts, galois_keys);
  78. return galois_keys;
  79. }
  80. void PIRClient::print_info(Ciphertext & encrypted)
  81. {
  82. Plaintext ptxt;
  83. decryptor_->decrypt(encrypted, ptxt);
  84. }
  // Splits a flat database index into one coordinate per dimension.
  //
  // Given dimension sizes Nvec = (N1, ..., Nd) with N = N1 * ... * Nd, and an
  // index j in [0, N), returns (j1, ..., jd) with 0 <= ji < Ni such that
  //   j = j1 * (N / N1) + j2 * (N / (N1 * N2)) + ... + jd.
  // This is a mixed-radix decomposition with the first dimension most
  // significant.
  //
  // @param desiredIndex  flat index; must satisfy 0 <= desiredIndex < N.
  // @param Nvec          dimension sizes; every entry must be positive.
  // @return              the coordinates (empty when Nvec is empty).
  // @throws std::invalid_argument if a dimension size is not positive, or if
  //         desiredIndex is outside [0, N).
  std::vector<int> compute_indices(int desiredIndex, std::vector<int> Nvec) {
    // 64-bit running product so that N cannot overflow an int.
    long long product = 1;
    for (const int Ni : Nvec) {
      if (Ni <= 0) {
        throw std::invalid_argument("compute_indices: dimension sizes must be positive");
      }
      product *= Ni;
    }
    if (desiredIndex < 0 || desiredIndex >= product) {
      throw std::invalid_argument("compute_indices: desiredIndex out of range");
    }
    std::vector<int> result;
    result.reserve(Nvec.size());
    long long remaining = desiredIndex;
    for (const int Ni : Nvec) {
      product /= Ni;  // weight of this dimension = product of the later sizes
      result.push_back(static_cast<int>(remaining / product));
      remaining %= product;
    }
    return result;
  }
  103. PIRServer::PIRServer(const seal::EncryptionParameters & parms, const pirParams &pirparams) {
  104. parms_ = parms;
  105. pirparams_ = pirparams;
  106. SEALContext context(parms);
  107. evaluator_.reset(new Evaluator(context));
  108. is_db_preprocessed_ = false;
  109. }
  110. void PIRServer::preprocess_database() {
  111. if (!is_db_preprocessed_) {
  112. for (int i = 0; i < dataBase_->size(); i++) {
  113. evaluator_->transform_to_ntt(dataBase_->operator[](i));
  114. }
  115. is_db_preprocessed_ = true;
  116. }
  117. }
  118. void PIRServer::set_database(vector<Plaintext> *db) {
  119. if (db == nullptr) {
  120. throw invalid_argument("db cannot be null");
  121. }
  122. dataBase_ = db;
  123. }
  124. pirReply PIRServer::generate_reply(pirQuery query, int client_id) {
  125. vector<int> Nvec = pirparams_.Nvec;
  126. uint64_t product = 1;
  127. for (int i = 0; i < Nvec.size(); i++) {
  128. product *= Nvec[i];
  129. }
  130. int coeff_count = parms_.poly_modulus().coeff_count();
  131. vector<Plaintext> *cur = dataBase_;
  132. vector<Plaintext> intermediate_plain; // decompose....
  133. auto my_pool = MemoryPoolHandle::New();
  134. for (int i = 0; i < Nvec.size(); i++) {
  135. int Ni = Nvec[i];
  136. vector<Ciphertext> expanded_query = expand_query(query[i], Ni, galoisKeys_[client_id]);
  137. #ifdef DEBUG
  138. cout << "query ciphertext check: " << endl;
  139. for (int tt = 0; tt < expanded_query.size(); tt++) {
  140. client.print_info(expanded_query[tt]);
  141. }
  142. #endif
  143. // Transform expanded query to NTT, and ...
  144. for (int jj = 0; jj < expanded_query.size(); jj++) {
  145. evaluator_->transform_to_ntt(expanded_query[jj]);
  146. }
  147. // Transform plaintext to NTT. If database is pre-processed, can skip
  148. if ((!is_db_preprocessed_) || i > 0) {
  149. for (int jj = 0; jj < cur->size(); jj++) {
  150. evaluator_->transform_to_ntt((*cur)[jj]);
  151. }
  152. }
  153. product /= Ni;
  154. vector<Ciphertext> intermediate(product);
  155. Ciphertext temp1;
  156. for (int k = 0; k < product; k++) {
  157. evaluator_->multiply_plain_ntt(expanded_query[0], (*cur)[k], intermediate[k]);
  158. for (int j = 1; j < Ni; j++) {
  159. evaluator_->multiply_plain_ntt(expanded_query[j], (*cur)[k + j*product], temp1);
  160. evaluator_->add(intermediate[k], temp1); // Adds to the first component.
  161. }
  162. }
  163. for (int jj = 0; jj < intermediate.size(); jj++) {
  164. evaluator_->transform_from_ntt(intermediate[jj]);
  165. }
  166. #ifdef DEBUG
  167. cout << "intermediate ciphertext check: " << endl;
  168. for (int tt = 0; tt < intermediate.size(); tt++) {
  169. cout << tt + 1 << " / " << intermediate.size() << " ";
  170. client.print_info(intermediate[tt]);
  171. }
  172. #endif
  173. if (i == Nvec.size() - 1) {
  174. return intermediate;
  175. } else {
  176. intermediate_plain.clear();
  177. intermediate_plain.reserve(pirparams_.expansion_ratio_ * product);
  178. cur = &intermediate_plain;
  179. util::Pointer tempplain_ptr(allocate_zero_poly(pirparams_.expansion_ratio_ * product, coeff_count, my_pool));
  180. for (int rr = 0; rr < product; rr++) {
  181. decompose_to_plaintexts_ptr(intermediate[rr], tempplain_ptr.get() + rr * pirparams_.expansion_ratio_* coeff_count);
  182. #ifdef DEBUG
  183. cout << "compose decompose check: " << endl;
  184. client.print_info(evaluator_->compose_to_ciphertext(tempplain));
  185. #endif
  186. for (int jj = 0; jj < pirparams_.expansion_ratio_; jj++){
  187. int offset = rr * pirparams_.expansion_ratio_* coeff_count + jj * coeff_count;
  188. intermediate_plain.emplace_back(coeff_count, tempplain_ptr.get() + offset);
  189. }
  190. }
  191. product *= pirparams_.expansion_ratio_; // multiply by expansion rate.
  192. }
  193. }
  194. }
  195. vector<Ciphertext> PIRServer::expand_query(const Ciphertext &encrypted, int d, const GaloisKeys &galkey) {
  196. uint64_t plainMod = parms_.plain_modulus().value();
  197. #ifdef DEBUG
  198. cout << "PIRServer side plain modulus = " << plainMod << endl;
  199. #endif
  200. // Assume that d is a power of 2. If not, round it to the next power of 2.
  201. int logd = ceil(log(d) / log(2));
  202. Plaintext two("2");
  203. vector<int> galois_elts;
  204. int n = parms_.poly_modulus().coeff_count() - 1;
  205. for (int i = 0; i < logd; i++) {
  206. galois_elts.push_back((n + exponentiate_uint64(2, i)) / exponentiate_uint64(2, i));
  207. }
  208. vector<Ciphertext> temp;
  209. temp.push_back(encrypted);
  210. Ciphertext tempctxt;
  211. Ciphertext tempctxt_rotated;
  212. Ciphertext tempctxt_shifted;
  213. Ciphertext tempctxt_rotatedshifted;
  214. int shift = 1;
  215. for (int i = 0; i < logd -1; i++) {
  216. vector<Ciphertext> newtemp(temp.size() << 1);
  217. int index_raw = (n << 1) - (1 << i);
  218. int index = (index_raw * galois_elts[i]) % (n << 1);
  219. for (int a = 0; a < temp.size(); a++) {
  220. evaluator_->apply_galois(temp[a], galois_elts[i], galkey, tempctxt_rotated); // Can be done in-place
  221. evaluator_->add(temp[a], tempctxt_rotated, newtemp[a]);
  222. multiply_power_of_X(temp[a], tempctxt_shifted, index_raw);
  223. multiply_power_of_X(tempctxt_rotated, tempctxt_rotatedshifted, index);
  224. evaluator_->add(tempctxt_shifted, tempctxt_rotatedshifted, newtemp[a+temp.size()]); // Enc(2^i x^j) if j = 0 (mod 2**i).
  225. }
  226. temp = newtemp;
  227. }
  228. // Last iteration of the loop
  229. vector<Ciphertext> newtemp(temp.size() << 1);
  230. int index_raw = (n << 1) - (1 << (logd - 1));
  231. int index = (index_raw * galois_elts[logd - 1]) % (n << 1);
  232. for (int a = 0; a < temp.size(); a++) {
  233. if(a >= (d - (1 << (logd - 1)))) { // corner case.
  234. evaluator_->multiply_plain(temp[a], two, newtemp[a]);// plain multiplication by 2.
  235. }
  236. else {
  237. evaluator_->apply_galois(temp[a], galois_elts[logd-1], galkey, tempctxt_rotated); // Can be done in-place
  238. evaluator_->add(temp[a], tempctxt_rotated, newtemp[a]);
  239. multiply_power_of_X(temp[a], tempctxt_shifted, index_raw);
  240. multiply_power_of_X(tempctxt_rotated, tempctxt_rotatedshifted, index);
  241. evaluator_->add(tempctxt_shifted, tempctxt_rotatedshifted, newtemp[a + temp.size()]); // Enc(2^i x^j) if j = 0 (mod 2**i).
  242. }
  243. }
  244. vector<Ciphertext>::const_iterator first = newtemp.begin();
  245. vector<Ciphertext>::const_iterator last = newtemp.begin() + d;
  246. vector<Ciphertext> newVec(first, last);
  247. return newVec;
  248. }
  249. void PIRServer::multiply_power_of_X(const Ciphertext &encrypted, Ciphertext & destination, int index)
  250. {
  251. // Extract parameter
  252. int coeff_mod_count = parms_.coeff_modulus().size();
  253. int coeff_count = parms_.poly_modulus().coeff_count();
  254. int coeff_bit_count = coeff_mod_count * bits_per_uint64;
  255. int encrypted_ptr_increment = coeff_count * coeff_mod_count;
  256. int encrypted_count = encrypted.size();
  257. // First copy over.
  258. destination = encrypted;
  259. // Prepare for destination
  260. // Multiply X^index for each ciphertext polynomial
  261. for (int i = 0; i < encrypted_count; i++)
  262. {
  263. for (int j = 0; j < coeff_mod_count; j++)
  264. {
  265. negacyclic_shift_poly_coeffmod(encrypted.pointer(i) + (j * coeff_count), coeff_count - 1, index, parms_.coeff_modulus()[j], destination.mutable_pointer(i) + (j * coeff_count));
  266. }
  267. }
  268. }
  269. Ciphertext PIRClient::compose_to_ciphertext(vector<Plaintext> plains) {
  270. Ciphertext result;
  271. int encrypted_count = 2;
  272. int coeff_count = newparms_.poly_modulus().coeff_count();
  273. int coeff_mod_count = newparms_.coeff_modulus().size();
  274. int array_poly_uint64_count = coeff_count * coeff_mod_count;
  275. result.reserve(newparms_, encrypted_count);
  276. int plain_bit_count = newparms_.plain_modulus().bit_count();
  277. uint64_t plainMod = newparms_.plain_modulus().value();
  278. // A triple for loop. Going over polys, moduli, and decomposed index.
  279. for (int i = 0; i < encrypted_count; i++) {
  280. uint64_t *encrypted_pointer = result.mutable_pointer(i);
  281. for (int j = 0; j < coeff_mod_count; j++)
  282. {
  283. // populate one poly at a time.
  284. // create a polynomial to store the current decomposition value which will be copied into the array to populate it at the current index.
  285. double logqj = log(newparms_.coeff_modulus()[j].value());
  286. int expansion_ratio = ceil(logqj / log(plainMod));
  287. uint64_t cur = 1;
  288. for (int k = 0; k < expansion_ratio; k++)
  289. {
  290. // Compose here
  291. const uint64_t *plain_coeff = plains[k + j*(expansion_ratio)+i*(coeff_mod_count*expansion_ratio)].pointer();
  292. for (int m = 0; m < coeff_count - 1; m++)
  293. {
  294. if (k == 0) {
  295. *(encrypted_pointer + m + j*coeff_count) = *(plain_coeff + m) * cur;
  296. }
  297. else {
  298. *(encrypted_pointer + m + j*coeff_count) += *(plain_coeff + m) * cur;
  299. }
  300. }
  301. *(encrypted_pointer + coeff_count - 1 + j*coeff_count) = 0;
  302. cur *= plainMod;
  303. }
  304. // Reduction modulo qj. This is needed?
  305. for (int m = 0; m < coeff_count; m++)
  306. {
  307. *(encrypted_pointer + m + j*coeff_count) %= newparms_.coeff_modulus()[j].value();
  308. }
  309. }
  310. }
  311. result.mutable_hash_block() = newparms_.hash_block();
  312. return result;
  313. }
  314. void PIRServer::decompose_to_plaintexts_ptr(const Ciphertext &encrypted, uint64_t* plain_ptr) {
  315. vector<Plaintext> result;
  316. int coeff_count = parms_.poly_modulus().coeff_count();
  317. int coeff_mod_count = parms_.coeff_modulus().size();
  318. int array_poly_uint64_count = coeff_count * coeff_mod_count;
  319. int plain_bit_count = parms_.plain_modulus().bit_count();
  320. int encrypted_count = encrypted.size();
  321. // Generate powers of t.
  322. uint64_t plainModMinusOne = parms_.plain_modulus().value() -1;
  323. int exp = ceil(log2(plainModMinusOne + 1));
  324. for (int i = 0; i < encrypted_count; i++) {
  325. const uint64_t * encrypted_pointer = encrypted.pointer(i);
  326. for (int j = 0; j < coeff_mod_count; j++)
  327. {
  328. // populate one poly at a time.
  329. // create a polynomial to store the current decomposition value which will be copied into the array to populate it at the current index.
  330. int shift = 0;
  331. int logqj = log2(parms_.coeff_modulus()[j].value());
  332. int expansion_ratio = (logqj + exp -1) / exp;
  333. uint64_t curexp = 0;
  334. for (int k = 0; k < expansion_ratio; k++)
  335. {
  336. // Decompose here
  337. for (int m = 0; m < coeff_count; m++)
  338. {
  339. *plain_ptr = (*(encrypted_pointer + m + (j * coeff_count)) >> curexp) & plainModMinusOne;
  340. plain_ptr++;
  341. }
  342. curexp += exp;
  343. }
  344. }
  345. }
  346. return;
  347. }
  348. std::vector<Plaintext> PIRServer::decompose_to_plaintexts(const Ciphertext &encrypted) {
  349. vector<Plaintext> result;
  350. int coeff_count = parms_.poly_modulus().coeff_count();
  351. int coeff_mod_count = parms_.coeff_modulus().size();
  352. int array_poly_uint64_count = coeff_count * coeff_mod_count;
  353. int plain_bit_count = parms_.plain_modulus().bit_count();
  354. int encrypted_count = encrypted.size();
  355. // Generate powers of t.
  356. uint64_t plainMod = parms_.plain_modulus().value();
  357. for (int i = 0; i < encrypted_count; i++) {
  358. const uint64_t * encrypted_pointer = encrypted.pointer(i);
  359. for (int j = 0; j < coeff_mod_count; j++)
  360. {
  361. // populate one poly at a time.
  362. // create a polynomial to store the current decomposition value which will be copied into the array to populate it at the current index.
  363. int shift = 0;
  364. int logqj = log(parms_.coeff_modulus()[j].value());
  365. int expansion_ratio = ceil(logqj / log(plainMod));
  366. uint64_t cur = 1;
  367. for (int k = 0; k < expansion_ratio; k++)
  368. {
  369. // Decompose here
  370. BigPoly temp;
  371. temp.resize(coeff_count, plain_bit_count);
  372. temp.set_zero();
  373. uint64_t *plain_coeff = temp.pointer();
  374. for (int m = 0; m < coeff_count; m++)
  375. {
  376. *(plain_coeff + m) = (*(encrypted_pointer + m + (j * coeff_count)) / cur) % plainMod;
  377. }
  378. result.push_back(Plaintext(temp));
  379. cur *= plainMod;
  380. }
  381. }
  382. }
  383. return result;
  384. }
  385. void vector_to_plaintext(const std::vector<std::uint64_t> &coeffs, Plaintext &plain)
  386. {
  387. int coeff_count = coeffs.size();
  388. plain.resize(coeff_count);
  389. util:set_uint_uint(coeffs.data(), coeff_count, plain.pointer());
  390. }
  391. string serialize_ciphertext(Ciphertext c) {
  392. std::stringstream output(std::ios::binary|std::ios::out);
  393. c.save(output);
  394. return output.str();
  395. }
  396. string serialize_ciphertexts(vector<Ciphertext> c) {
  397. string s;
  398. for(int i=0; i<c.size(); i++) {
  399. s.append(serialize_ciphertext(c[i]));
  400. }
  401. return s;
  402. }
  403. Ciphertext* deserialize_ciphertext(string s) {
  404. Ciphertext *c = new Ciphertext();
  405. std::stringstream input(std::ios::binary|std::ios::in);
  406. input.str(s);
  407. c->load(input);
  408. return c;
  409. }
  410. vector<Ciphertext> deserialize_ciphertexts(int count, string s, int len_ciphertext) {
  411. vector<Ciphertext> c;
  412. for(int i=0; i<count; i++) {
  413. c.push_back(*(deserialize_ciphertext(s.substr(i*len_ciphertext, len_ciphertext))));
  414. }
  415. return c;
  416. }
  417. string serialize_plaintext(Plaintext p) {
  418. std::stringstream output(std::ios::binary|std::ios::out);
  419. p.save(output);
  420. return output.str();
  421. }
  422. string serialize_plaintexts(vector<Plaintext> p) {
  423. string s;
  424. for(int i=0; i<p.size(); i++) {
  425. s.append(serialize_plaintext(p[i]));
  426. }
  427. return s;
  428. }
  429. Plaintext* deserialize_plaintext(string s) {
  430. Plaintext *c = new Plaintext();
  431. std::stringstream input(std::ios::binary|std::ios::in);
  432. input.str(s);
  433. c->load(input);
  434. return c;
  435. }
  436. vector<Plaintext> deserialize_plaintexts(int count, string s, int len_plaintext) {
  437. vector<Plaintext> p;
  438. for(int i=0; i<count; i++) {
  439. p.push_back(*(deserialize_plaintext(s.substr(i*len_plaintext, len_plaintext))));
  440. }
  441. return p;
  442. }
  443. string serialize_galoiskeys(GaloisKeys g) {
  444. std::stringstream output(std::ios::binary|std::ios::out);
  445. g.save(output);
  446. return output.str();
  447. }
  448. GaloisKeys* deserialize_galoiskeys(string s) {
  449. GaloisKeys *g = new GaloisKeys();
  450. std::stringstream input(std::ios::binary|std::ios::in);
  451. input.str(s);
  452. g->load(input);
  453. return g;
  454. }
  // Releases a buffer returned by the cpp_* functions in this file (they
  // allocate with calloc). A null pointer is ignored.
  void cpp_buffer_free(char *buf)
  {
    if (buf != nullptr) {
      free(buf);
    }
  }
  459. void*
  460. cpp_client_setup(uint64_t len_total_bytes, uint64_t num_db_entries) {
  461. uint64_t number_of_items = num_db_entries;
  462. uint64_t size_per_item = (len_total_bytes/num_db_entries) << 3;
  463. int n = 2048;
  464. int logt = 22;
  465. uint64_t plainMod = static_cast<uint64_t> (1) << logt;
  466. int number_of_plaintexts = ceil (((double)(number_of_items)* size_per_item / n) / logt );
  467. EncryptionParameters parms;
  468. parms.set_poly_modulus("1x^" + std::to_string(n) + " + 1");
  469. vector<SmallModulus> coeff_mod_array;
  470. int logq = 0;
  471. for (int i = 0; i < 1; ++i)
  472. {
  473. coeff_mod_array.emplace_back(SmallModulus());
  474. coeff_mod_array[i] = small_mods_60bit(i);
  475. logq += coeff_mod_array[i].bit_count();
  476. }
  477. parms.set_coeff_modulus(coeff_mod_array);
  478. parms.set_plain_modulus(plainMod);
  479. pirParams pirparms;
  480. int item_per_plaintext = floor((double)get_power_of_two(plainMod) *n / size_per_item);
  481. pirparms.d = 2;
  482. pirparms.alpha = 1;
  483. pirparms.dbc = 8;
  484. pirparms.N = number_of_plaintexts;
  485. int sqrt_items = ceil(sqrt(number_of_plaintexts));
  486. int bound1 = number_of_plaintexts / sqrt_items;
  487. int bound2 = sqrt_items;
  488. vector<int> Nvec = { bound1, bound2 };
  489. pirparms.Nvec = Nvec;
  490. PIRClient *client = new PIRClient(parms, pirparms);
  491. return (void*) client;
  492. }
  493. char*
  494. cpp_client_generate_query(void* pir, uint64_t chosen_idx, uint64_t* rlen_total_bytes, uint64_t* rnum_logical_entries) {
  495. pirQuery query = ((PIRClient*) pir)->generate_query(chosen_idx);
  496. string s = serialize_ciphertexts(query);
  497. *rlen_total_bytes = s.length();
  498. *rnum_logical_entries = query.size();
  499. char *outptr, *result;
  500. result = (char*)calloc(*rlen_total_bytes, sizeof(char));
  501. memcpy(result, s.c_str(), s.length());
  502. return result;
  503. }
  504. char*
  505. cpp_client_generate_galois_keys(void *pir, uint64_t *rlen_total_bytes) {
  506. GaloisKeys g = ((PIRClient*) pir)->generate_galois_keys();
  507. string s = serialize_galoiskeys(g); //.c_str();
  508. char *outptr, *result;
  509. result = (char*)calloc(s.length(), sizeof(char));
  510. memcpy(result, s.c_str(), s.length());
  511. *rlen_total_bytes = s.length();
  512. return result;
  513. }
  514. char*
  515. cpp_client_process_reply(void* pir, char* r, uint64_t len_total_bytes, uint64_t num_logical_entries, uint64_t* rlen_total_bytes)
  516. {
  517. string s(r);
  518. vector<Ciphertext> reply = deserialize_ciphertexts(num_logical_entries, s, 32828);
  519. Plaintext p = ((PIRClient*) pir)->decode_reply(reply);
  520. string resp = serialize_plaintext(p);
  521. *rlen_total_bytes = resp.length();
  522. char *result = (char*)calloc(*rlen_total_bytes, sizeof(char));
  523. memcpy(result, resp.c_str(), resp.length());
  524. return result;
  525. }
  526. void
  527. cpp_client_free(void *pir)
  528. {
  529. delete (PIRClient*) pir;
  530. }
  531. void*
  532. cpp_server_setup(uint64_t len_total_bytes, char *db, uint64_t num_logical_entries)
  533. {
  534. uint64_t max_entry_size_bytes = len_total_bytes/num_logical_entries;
  535. uint64_t number_of_items = num_logical_entries;
  536. uint64_t size_per_item = max_entry_size_bytes << 3; // 288 B.
  537. int n = 2048;
  538. int logt = 22;
  539. uint64_t plainMod = static_cast<uint64_t> (1) << logt;
  540. int number_of_plaintexts = ceil (((double)(number_of_items)* size_per_item / n) / logt );
  541. EncryptionParameters parms;
  542. parms.set_poly_modulus("1x^" + std::to_string(n) + " + 1");
  543. vector<SmallModulus> coeff_mod_array;
  544. int logq = 0;
  545. for (int i = 0; i < 1; ++i)
  546. {
  547. coeff_mod_array.emplace_back(SmallModulus());
  548. coeff_mod_array[i] = small_mods_60bit(i);
  549. logq += coeff_mod_array[i].bit_count();
  550. }
  551. parms.set_coeff_modulus(coeff_mod_array);
  552. parms.set_plain_modulus(plainMod);
  553. pirParams pirparms;
  554. int item_per_plaintext = floor((double)get_power_of_two(plainMod) *n / size_per_item);
  555. pirparms.d = 2;
  556. pirparms.alpha = 1;
  557. pirparms.dbc = 8;
  558. pirparms.N = number_of_plaintexts;
  559. int sqrt_items = ceil(sqrt(number_of_plaintexts));
  560. int bound1 = number_of_plaintexts / sqrt_items;
  561. int bound2 = sqrt_items;
  562. vector<int> Nvec = { bound1, bound2 };
  563. pirparms.Nvec = Nvec;
  564. PIRServer *server = new PIRServer(parms, pirparms);
  565. string d(db);
  566. vector<Plaintext> items = deserialize_plaintexts(num_logical_entries, d, max_entry_size_bytes);
  567. server->set_database(&items);
  568. server->preprocess_database();
  569. return (void*) server;
  570. }
  571. void
  572. cpp_server_set_galois_keys(void *pir, char *q, uint64_t len_total_bytes, int client_id)
  573. {
  574. string s(q);
  575. GaloisKeys *g = deserialize_galoiskeys(s);
  576. ((PIRServer*)pir)->set_galois_key(client_id, *g);
  577. }
  578. char*
  579. cpp_server_process_query(void* pir, char* q, uint64_t len_total_bytes, uint64_t num_logical_entries, uint64_t* rlen_total_bytes, uint64_t* rnum_logical_entries, int client_id)
  580. {
  581. string str(q);
  582. pirQuery query = deserialize_ciphertexts(num_logical_entries, str, len_total_bytes/num_logical_entries);
  583. pirReply reply = ((PIRServer*) pir)->generate_reply(query, client_id);
  584. string s = serialize_ciphertexts(reply);
  585. *rlen_total_bytes = s.length();
  586. *rnum_logical_entries = reply.size();
  587. char *outptr, *result;
  588. result = (char*)calloc(*rlen_total_bytes, sizeof(char));
  589. memcpy(result, s.c_str(), s.length());
  590. return result;
  591. }
  592. void
  593. cpp_server_free(void *pir)
  594. {
  595. delete (PIRServer*) pir;
  596. }