pir_client.cpp 7.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212
#include "pir_client.hpp"

#include <utility>

using namespace std;
using namespace seal;
using namespace seal::util;
  5. PIRClient::PIRClient(const EncryptionParameters &params,
  6. const PirParams &pir_parms) :
  7. params_(params){
  8. newcontext_ = make_shared<SEALContext>(params, true);
  9. pir_params_ = pir_parms;
  10. keygen_ = make_unique<KeyGenerator>(*newcontext_);
  11. PublicKey public_key;
  12. keygen_->create_public_key(public_key);
  13. encryptor_ = make_unique<Encryptor>(*newcontext_, public_key);
  14. SecretKey secret_key = keygen_->secret_key();
  15. decryptor_ = make_unique<Decryptor>(*newcontext_, secret_key);
  16. evaluator_ = make_unique<Evaluator>(*newcontext_);
  17. }
  18. PirQuery PIRClient::generate_query(uint64_t desiredIndex) {
  19. indices_ = compute_indices(desiredIndex, pir_params_.nvec);
  20. vector<vector<Ciphertext> > result(pir_params_.d);
  21. int N = params_.poly_modulus_degree();
  22. Plaintext pt(params_.poly_modulus_degree());
  23. for (uint32_t i = 0; i < indices_.size(); i++) {
  24. uint32_t num_ptxts = ceil( (pir_params_.nvec[i] + 0.0) / N);
  25. // initialize result.
  26. cout << "Client: index " << i + 1 << "/ " << indices_.size() << " = " << indices_[i] << endl;
  27. cout << "Client: number of ctxts needed for query = " << num_ptxts << endl;
  28. for (uint32_t j =0; j < num_ptxts; j++){
  29. pt.set_zero();
  30. if (indices_[i] > N*(j+1) || indices_[i] < N*j){
  31. #ifdef DEBUG
  32. cout << "Client: coming here: so just encrypt zero." << endl;
  33. #endif
  34. // just encrypt zero
  35. } else{
  36. #ifdef DEBUG
  37. cout << "Client: encrypting a real thing " << endl;
  38. #endif
  39. uint64_t real_index = indices_[i] - N*j;
  40. uint64_t n_i = pir_params_.nvec[i];
  41. uint64_t total = N;
  42. if (j == num_ptxts - 1){
  43. total = n_i % N;
  44. }
  45. uint64_t log_total = ceil(log2(total));
  46. cout << "Client: Inverting " << pow(2, log_total) << endl;
  47. pt[real_index] = invert_mod(pow(2, log_total), params_.plain_modulus());
  48. }
  49. Ciphertext dest;
  50. encryptor_->encrypt(pt, dest);
  51. result[i].push_back(dest);
  52. }
  53. }
  54. return result;
  55. }
  56. uint64_t PIRClient::get_fv_index(uint64_t element_idx, uint64_t ele_size) {
  57. auto N = params_.poly_modulus_degree();
  58. auto logt = floor(log2(params_.plain_modulus().value()));
  59. auto ele_per_ptxt = elements_per_ptxt(logt, N, ele_size);
  60. return static_cast<uint64_t>(element_idx / ele_per_ptxt);
  61. }
  62. uint64_t PIRClient::get_fv_offset(uint64_t element_idx, uint64_t ele_size) {
  63. uint32_t N = params_.poly_modulus_degree();
  64. uint32_t logt = floor(log2(params_.plain_modulus().value()));
  65. uint64_t ele_per_ptxt = elements_per_ptxt(logt, N, ele_size);
  66. return element_idx % ele_per_ptxt;
  67. }
  68. Plaintext PIRClient::decode_reply(PirReply reply) {
  69. uint32_t exp_ratio = pir_params_.expansion_ratio;
  70. uint32_t recursion_level = pir_params_.d;
  71. vector<Ciphertext> temp = reply;
  72. uint64_t t = params_.plain_modulus().value();
  73. for (uint32_t i = 0; i < recursion_level; i++) {
  74. cout << "Client: " << i + 1 << "/ " << recursion_level << "-th decryption layer started." << endl;
  75. vector<Ciphertext> newtemp;
  76. vector<Plaintext> tempplain;
  77. for (uint32_t j = 0; j < temp.size(); j++) {
  78. Plaintext ptxt;
  79. decryptor_->decrypt(temp[j], ptxt);
  80. #ifdef DEBUG
  81. cout << "Client: reply noise budget = " << decryptor_->invariant_noise_budget(temp[j]) << endl;
  82. #endif
  83. //cout << "decoded (and scaled) plaintext = " << ptxt.to_string() << endl;
  84. tempplain.push_back(ptxt);
  85. #ifdef DEBUG
  86. cout << "recursion level : " << i << " noise budget : ";
  87. cout << decryptor_->invariant_noise_budget(temp[j]) << endl;
  88. #endif
  89. if ((j + 1) % exp_ratio == 0 && j > 0) {
  90. // Combine into one ciphertext.
  91. Ciphertext combined = compose_to_ciphertext(tempplain);
  92. newtemp.push_back(combined);
  93. tempplain.clear();
  94. // cout << "Client: const term of ciphertext = " << combined[0] << endl;
  95. }
  96. }
  97. cout << "Client: done." << endl;
  98. cout << endl;
  99. if (i == recursion_level - 1) {
  100. assert(temp.size() == 1);
  101. return tempplain[0];
  102. } else {
  103. tempplain.clear();
  104. temp = newtemp;
  105. }
  106. }
  107. // This should never be called
  108. assert(0);
  109. Plaintext fail;
  110. return fail;
  111. }
  112. GaloisKeys PIRClient::generate_galois_keys() {
  113. // Generate the Galois keys needed for coeff_select.
  114. vector<uint32_t> galois_elts;
  115. int N = params_.poly_modulus_degree();
  116. int logN = get_power_of_two(N);
  117. //cout << "printing galois elements...";
  118. for (int i = 0; i < logN; i++) {
  119. galois_elts.push_back((N + exponentiate_uint(2, i)) / exponentiate_uint(2, i));
  120. //#ifdef DEBUG
  121. // cout << galois_elts.back() << ", ";
  122. //#endif
  123. }
  124. GaloisKeys gal_keys;
  125. keygen_->create_galois_keys(galois_elts, gal_keys);
  126. return gal_keys;
  127. }
  128. Ciphertext PIRClient::compose_to_ciphertext(vector<Plaintext> plains) {
  129. size_t encrypted_count = 2;
  130. auto coeff_count = params_.poly_modulus_degree();
  131. auto coeff_mod_count = params_.coeff_modulus().size();
  132. uint64_t plainMod = params_.plain_modulus().value();
  133. int logt = floor(log2(plainMod));
  134. Ciphertext result(*newcontext_);
  135. result.resize(encrypted_count);
  136. // A triple for loop. Going over polys, moduli, and decomposed index.
  137. for (int i = 0; i < encrypted_count; i++) {
  138. uint64_t *encrypted_pointer = result.data(i);
  139. for (int j = 0; j < coeff_mod_count; j++) {
  140. // populate one poly at a time.
  141. // create a polynomial to store the current decomposition value
  142. // which will be copied into the array to populate it at the current
  143. // index.
  144. double logqj = log2(params_.coeff_modulus()[j].value());
  145. int expansion_ratio = ceil(logqj / logt);
  146. uint64_t cur = 1;
  147. // cout << "Client: expansion_ratio = " << expansion_ratio << endl;
  148. for (int k = 0; k < expansion_ratio; k++) {
  149. // Compose here
  150. const uint64_t *plain_coeff =
  151. plains[k + j * (expansion_ratio) + i * (coeff_mod_count * expansion_ratio)]
  152. .data();
  153. for (int m = 0; m < coeff_count; m++) {
  154. if (k == 0) {
  155. *(encrypted_pointer + m + j * coeff_count) = *(plain_coeff + m) * cur;
  156. } else {
  157. *(encrypted_pointer + m + j * coeff_count) += *(plain_coeff + m) * cur;
  158. }
  159. }
  160. // *(encrypted_pointer + coeff_count - 1 + j * coeff_count) = 0;
  161. cur <<= logt;
  162. }
  163. // XXX: Reduction modulo qj. This is needed?
  164. /*
  165. for (int m = 0; m < coeff_count; m++) {
  166. *(encrypted_pointer + m + j * coeff_count) %=
  167. params_.coeff_modulus()[j].value();
  168. }
  169. */
  170. }
  171. }
  172. return result;
  173. }