pir_client.cpp 8.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248
  1. #include "pir_client.hpp"
  2. using namespace std;
  3. using namespace seal;
  4. using namespace seal::util;
  5. PIRClient::PIRClient(const EncryptionParameters &params,
  6. const PirParams &pir_parms) :
  7. params_(params){
  8. newcontext_ = SEALContext::Create(params_);
  9. pir_params_ = pir_parms;
  10. keygen_ = make_unique<KeyGenerator>(newcontext_);
  11. encryptor_ = make_unique<Encryptor>(newcontext_, keygen_->public_key());
  12. SecretKey secret_key = keygen_->secret_key();
  13. decryptor_ = make_unique<Decryptor>(newcontext_, secret_key);
  14. evaluator_ = make_unique<Evaluator>(newcontext_);
  15. }
  16. PirQuery PIRClient::generate_query(uint64_t desiredIndex) {
  17. indices_ = compute_indices(desiredIndex, pir_params_.nvec);
  18. compute_inverse_scales();
  19. vector<vector<Ciphertext> > result(pir_params_.d);
  20. int N = params_.poly_modulus_degree();
  21. Plaintext pt(params_.poly_modulus_degree());
  22. for (uint32_t i = 0; i < indices_.size(); i++) {
  23. uint32_t num_ptxts = ceil( (pir_params_.nvec[i] + 0.0) / N);
  24. // initialize result.
  25. cout << "Client: index " << i + 1 << "/ " << indices_.size() << " = " << indices_[i] << endl;
  26. cout << "Client: number of ctxts needed for query = " << num_ptxts << endl;
  27. for (uint32_t j =0; j < num_ptxts; j++){
  28. pt.set_zero();
  29. if (indices_[i] > N*(j+1) || indices_[i] < N*j){
  30. #ifdef DEBUG
  31. cout << "Client: coming here: so just encrypt zero." << endl;
  32. #endif
  33. // just encrypt zero
  34. } else{
  35. #ifdef DEBUG
  36. cout << "Client: encrypting a real thing " << endl;
  37. #endif
  38. uint64_t real_index = indices_[i] - N*j;
  39. pt[real_index] = 1;
  40. }
  41. Ciphertext dest;
  42. encryptor_->encrypt(pt, dest);
  43. dest.parms_id() = params_.parms_id();
  44. result[i].push_back(dest);
  45. }
  46. }
  47. return result;
  48. }
  49. uint64_t PIRClient::get_fv_index(uint64_t element_idx, uint64_t ele_size) {
  50. auto N = params_.poly_modulus_degree();
  51. auto logt = floor(log2(params_.plain_modulus().value()));
  52. auto ele_per_ptxt = elements_per_ptxt(logt, N, ele_size);
  53. return static_cast<uint64_t>(element_idx / ele_per_ptxt);
  54. }
  55. uint64_t PIRClient::get_fv_offset(uint64_t element_idx, uint64_t ele_size) {
  56. uint32_t N = params_.poly_modulus_degree();
  57. uint32_t logt = floor(log2(params_.plain_modulus().value()));
  58. uint64_t ele_per_ptxt = elements_per_ptxt(logt, N, ele_size);
  59. return element_idx % ele_per_ptxt;
  60. }
  61. Plaintext PIRClient::decode_reply(PirReply reply) {
  62. uint32_t exp_ratio = pir_params_.expansion_ratio;
  63. uint32_t recursion_level = pir_params_.d;
  64. vector<Ciphertext> temp = reply;
  65. uint64_t t = params_.plain_modulus().value();
  66. for (uint32_t i = 0; i < recursion_level; i++) {
  67. cout << "Client: " << i + 1 << "/ " << recursion_level << "-th decryption layer started." << endl;
  68. vector<Ciphertext> newtemp;
  69. vector<Plaintext> tempplain;
  70. for (uint32_t j = 0; j < temp.size(); j++) {
  71. Plaintext ptxt;
  72. decryptor_->decrypt(temp[j], ptxt);
  73. #ifdef DEBUG
  74. cout << "Client: reply noise budget = " << decryptor_->invariant_noise_budget(temp[j]) << endl;
  75. #endif
  76. // multiply by inverse_scale for every coefficient of ptxt
  77. for(int h = 0; h < ptxt.coeff_count(); h++){
  78. ptxt[h] *= inverse_scales_[recursion_level - 1 - i];
  79. ptxt[h] %= t;
  80. }
  81. //cout << "decoded (and scaled) plaintext = " << ptxt.to_string() << endl;
  82. tempplain.push_back(ptxt);
  83. #ifdef DEBUG
  84. cout << "recursion level : " << i << " noise budget : ";
  85. cout << decryptor_->invariant_noise_budget(temp[j]) << endl;
  86. #endif
  87. if ((j + 1) % exp_ratio == 0 && j > 0) {
  88. // Combine into one ciphertext.
  89. Ciphertext combined = compose_to_ciphertext(tempplain);
  90. newtemp.push_back(combined);
  91. tempplain.clear();
  92. // cout << "Client: const term of ciphertext = " << combined[0] << endl;
  93. }
  94. }
  95. cout << "Client: done." << endl;
  96. cout << endl;
  97. if (i == recursion_level - 1) {
  98. assert(temp.size() == 1);
  99. return tempplain[0];
  100. } else {
  101. tempplain.clear();
  102. temp = newtemp;
  103. }
  104. }
  105. // This should never be called
  106. assert(0);
  107. Plaintext fail;
  108. return fail;
  109. }
  110. GaloisKeys PIRClient::generate_galois_keys() {
  111. // Generate the Galois keys needed for coeff_select.
  112. vector<uint64_t> galois_elts;
  113. int N = params_.poly_modulus_degree();
  114. int logN = get_power_of_two(N);
  115. //cout << "printing galois elements...";
  116. for (int i = 0; i < logN; i++) {
  117. galois_elts.push_back((N + exponentiate_uint64(2, i)) / exponentiate_uint64(2, i));
  118. //#ifdef DEBUG
  119. // cout << galois_elts.back() << ", ";
  120. //#endif
  121. }
  122. return keygen_->galois_keys(pir_params_.dbc, galois_elts);
  123. }
  124. Ciphertext PIRClient::compose_to_ciphertext(vector<Plaintext> plains) {
  125. size_t encrypted_count = 2;
  126. auto coeff_count = params_.poly_modulus_degree();
  127. auto coeff_mod_count = params_.coeff_modulus().size();
  128. uint64_t plainMod = params_.plain_modulus().value();
  129. int logt = floor(log2(plainMod));
  130. Ciphertext result(newcontext_);
  131. result.resize(encrypted_count);
  132. // A triple for loop. Going over polys, moduli, and decomposed index.
  133. for (int i = 0; i < encrypted_count; i++) {
  134. uint64_t *encrypted_pointer = result.data(i);
  135. for (int j = 0; j < coeff_mod_count; j++) {
  136. // populate one poly at a time.
  137. // create a polynomial to store the current decomposition value
  138. // which will be copied into the array to populate it at the current
  139. // index.
  140. double logqj = log2(params_.coeff_modulus()[j].value());
  141. int expansion_ratio = ceil(logqj / logt);
  142. uint64_t cur = 1;
  143. // cout << "Client: expansion_ratio = " << expansion_ratio << endl;
  144. for (int k = 0; k < expansion_ratio; k++) {
  145. // Compose here
  146. const uint64_t *plain_coeff =
  147. plains[k + j * (expansion_ratio) + i * (coeff_mod_count * expansion_ratio)]
  148. .data();
  149. for (int m = 0; m < coeff_count; m++) {
  150. if (k == 0) {
  151. *(encrypted_pointer + m + j * coeff_count) = *(plain_coeff + m) * cur;
  152. } else {
  153. *(encrypted_pointer + m + j * coeff_count) += *(plain_coeff + m) * cur;
  154. }
  155. }
  156. // *(encrypted_pointer + coeff_count - 1 + j * coeff_count) = 0;
  157. cur <<= logt;
  158. }
  159. // XXX: Reduction modulo qj. This is needed?
  160. /*
  161. for (int m = 0; m < coeff_count; m++) {
  162. *(encrypted_pointer + m + j * coeff_count) %=
  163. params_.coeff_modulus()[j].value();
  164. }
  165. */
  166. }
  167. }
  168. result.parms_id() = params_.parms_id();
  169. return result;
  170. }
  171. void PIRClient::compute_inverse_scales(){
  172. if (indices_.size() != pir_params_.nvec.size()){
  173. throw invalid_argument("size mismatch");
  174. }
  175. int logt = floor(log2(params_.plain_modulus().value()));
  176. uint64_t N = params_.poly_modulus_degree();
  177. uint64_t t = params_.plain_modulus().value();
  178. int logN = log2(N);
  179. int logm = logN;
  180. inverse_scales_.clear();
  181. for(int i = 0; i < pir_params_.nvec.size(); i++){
  182. uint64_t index_modN = indices_[i] % N;
  183. uint64_t numCtxt = ceil ( (pir_params_.nvec[i] + 0.0) / N); // number of query ciphertexts.
  184. uint64_t batchId = indices_[i] / N;
  185. if (batchId == numCtxt - 1) {
  186. cout << "Client: adjusting the logm value..." << endl;
  187. logm = ceil(log2((pir_params_.nvec[i] % N)));
  188. }
  189. uint64_t inverse_scale;
  190. int quo = logm / logt;
  191. int mod = logm % logt;
  192. inverse_scale = pow(2, logt - mod);
  193. if ((quo +1) %2 != 0){
  194. inverse_scale = params_.plain_modulus().value() - pow(2, logt - mod);
  195. }
  196. inverse_scales_.push_back(inverse_scale);
  197. if ( (inverse_scale << logm) % t != 1){
  198. throw logic_error("something wrong");
  199. }
  200. cout << "Client: logm, inverse scale, t = " << logm << ", " << inverse_scale << ", " << t << endl;
  201. }
  202. }