pir_client.cpp 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293
  1. #include "pir_client.hpp"
  2. using namespace std;
  3. using namespace seal;
  4. using namespace seal::util;
  5. PIRClient::PIRClient(const EncryptionParameters &params,
  6. const PirParams &pir_parms) :
  7. params_(params){
  8. newcontext_ = SEALContext::Create(params_);
  9. pir_params_ = pir_parms;
  10. keygen_ = make_unique<KeyGenerator>(newcontext_);
  11. encryptor_ = make_unique<Encryptor>(newcontext_, keygen_->public_key());
  12. SecretKey secret_key = keygen_->secret_key();
  13. decryptor_ = make_unique<Decryptor>(newcontext_, secret_key);
  14. evaluator_ = make_unique<Evaluator>(newcontext_);
  15. uint64_t t = params_.plain_modulus().value();
  16. uint64_t N = params_.poly_modulus_degree();
  17. //
  18. // int logt = floor(log2(params_.plain_modulus().value()));
  19. // for(int i = 0; i < pir_params_.nvec.size(); i++){
  20. // uint64_t inverse_scale;
  21. // //
  22. // // If the number of items are less than N, then
  23. // // we use logm.
  24. // int logm = ceil(log2(min(N, pir_params_.nvec[i]))); // if nvec > n what do we do?
  25. // int quo = logm / logt;
  26. // int mod = logm % logt;
  27. // inverse_scale = pow(2, logt - mod);
  28. // if ((quo +1) %2 != 0){
  29. // inverse_scale = params_.plain_modulus().value() - pow(2, logt - mod);
  30. // }
  31. // inverse_scales_.push_back(inverse_scale);
  32. // if ( (inverse_scale << logm) % t != 1){
  33. // throw logic_error("something wrong");
  34. // }
  35. // cout << "logm, inverse scale, t = " << logm << ", " << inverse_scale << ", " << t << endl;
  36. // }
  37. }
  38. // void PIRClient::update_parameters(const EncryptionParameters &expanded_params,
  39. // const PirParams &pir_params)
  40. // {
  41. // // The only thing that can change is the plaintext modulus and pir_params
  42. // assert(expanded_params.poly_modulus_degree() == params_.poly_modulus_degree());
  43. // assert(expanded_params.coeff_modulus() == params_.coeff_modulus());
  44. // params_ = expanded_params;
  45. // pir_params_ = pir_params;
  46. // auto newcontext = SEALContext::Create(expanded_params);
  47. // SecretKey secret_key = keygen_->secret_key();
  48. // secret_key.parms_id() = expanded_params.parms_id();
  49. // decryptor_ = make_unique<Decryptor>(newcontext, secret_key);
  50. // evaluator_ = make_unique<Evaluator>(newcontext);
  51. // }
  52. PirQuery PIRClient::generate_query(uint64_t desiredIndex) {
  53. indices_ = compute_indices(desiredIndex, pir_params_.nvec);
  54. compute_inverse_scales();
  55. vector<vector<Ciphertext> > result(pir_params_.d);
  56. int N = params_.poly_modulus_degree();
  57. Plaintext pt(params_.poly_modulus_degree());
  58. for (uint32_t i = 0; i < indices_.size(); i++) {
  59. uint32_t num_ptxts = ceil( (pir_params_.nvec[i] +0.0) / N);
  60. // initialize result.
  61. cout << "Client: index " << i + 1 << "/ " << indices_.size() << " = " << indices_[i] << endl;
  62. cout << "Client: number of ctxts needed for query = " << num_ptxts << endl;
  63. for (uint32_t j =0; j < num_ptxts; j++){
  64. pt.set_zero();
  65. if (indices_[i] > N*(j+1) || indices_[i] < N*j){
  66. cout << "Client: coming here: so just encrypt zero." << endl;
  67. // just encrypt zero
  68. } else{
  69. cout << "Client: encrypting a real thing " << endl;
  70. uint64_t real_index = indices_[i] - N*j;
  71. pt[real_index] = 1;
  72. }
  73. Ciphertext dest;
  74. encryptor_->encrypt(pt, dest);
  75. dest.parms_id() = params_.parms_id();
  76. result[i].push_back(dest);
  77. }
  78. }
  79. return result;
  80. }
  81. uint64_t PIRClient::get_fv_index(uint64_t element_idx, uint64_t ele_size) {
  82. auto N = params_.poly_modulus_degree();
  83. auto logt = floor(log2(params_.plain_modulus().value()));
  84. auto ele_per_ptxt = elements_per_ptxt(logt, N, ele_size);
  85. return static_cast<uint64_t>(element_idx / ele_per_ptxt);
  86. }
  87. uint64_t PIRClient::get_fv_offset(uint64_t element_idx, uint64_t ele_size) {
  88. uint32_t N = params_.poly_modulus_degree();
  89. uint32_t logt = floor(log2(params_.plain_modulus().value()));
  90. uint64_t ele_per_ptxt = elements_per_ptxt(logt, N, ele_size);
  91. return element_idx % ele_per_ptxt;
  92. }
  93. Plaintext PIRClient::decode_reply(PirReply reply) {
  94. uint32_t exp_ratio = pir_params_.expansion_ratio;
  95. uint32_t recursion_level = pir_params_.d;
  96. vector<Ciphertext> temp = reply;
  97. uint64_t t = params_.plain_modulus().value();
  98. for (uint32_t i = 0; i < recursion_level; i++) {
  99. cout << "Client: " << i + 1 << "/ " << recursion_level << "-th decryption layer started." << endl;
  100. vector<Ciphertext> newtemp;
  101. vector<Plaintext> tempplain;
  102. for (uint32_t j = 0; j < temp.size(); j++) {
  103. Plaintext ptxt;
  104. decryptor_->decrypt(temp[j], ptxt);
  105. cout << "Client: reply noise budget = " << decryptor_->invariant_noise_budget(temp[j]) << endl;
  106. // multiply by inverse_scale for every coefficient of ptxt
  107. for(int h = 0; h < ptxt.coeff_count(); h++){
  108. ptxt[h] *= inverse_scales_[i];
  109. ptxt[h] %= t;
  110. }
  111. //cout << "decoded (and scaled) plaintext = " << ptxt.to_string() << endl;
  112. tempplain.push_back(ptxt);
  113. #ifdef DEBUG
  114. cout << "recursion level : " << i << " noise budget : ";
  115. cout << decryptor_->invariant_noise_budget(temp[j]) << endl;
  116. #endif
  117. if ((j + 1) % exp_ratio == 0 && j > 0) {
  118. // Combine into one ciphertext.
  119. Ciphertext combined = compose_to_ciphertext(tempplain);
  120. newtemp.push_back(combined);
  121. // cout << "Client: const term of ciphertext = " << combined[0] << endl;
  122. }
  123. }
  124. cout << "Client: done." << endl;
  125. cout << endl;
  126. if (i == recursion_level - 1) {
  127. assert(temp.size() == 1);
  128. return tempplain[0];
  129. } else {
  130. tempplain.clear();
  131. temp = newtemp;
  132. }
  133. }
  134. // This should never be called
  135. assert(0);
  136. Plaintext fail;
  137. return fail;
  138. }
  139. GaloisKeys PIRClient::generate_galois_keys() {
  140. // Generate the Galois keys needed for coeff_select.
  141. vector<uint64_t> galois_elts;
  142. int N = params_.poly_modulus_degree();
  143. int logN = get_power_of_two(N);
  144. //cout << "printing galois elements...";
  145. for (int i = 0; i < logN; i++) {
  146. galois_elts.push_back((N + exponentiate_uint64(2, i)) / exponentiate_uint64(2, i));
  147. //#ifdef DEBUG
  148. // cout << galois_elts.back() << ", ";
  149. //#endif
  150. }
  151. return keygen_->galois_keys(pir_params_.dbc, galois_elts);
  152. }
  153. Ciphertext PIRClient::compose_to_ciphertext(vector<Plaintext> plains) {
  154. size_t encrypted_count = 2;
  155. auto coeff_count = params_.poly_modulus_degree();
  156. auto coeff_mod_count = params_.coeff_modulus().size();
  157. uint64_t plainMod = params_.plain_modulus().value();
  158. int logt = floor(log2(plainMod));
  159. Ciphertext result(newcontext_);
  160. result.resize(encrypted_count);
  161. // A triple for loop. Going over polys, moduli, and decomposed index.
  162. for (int i = 0; i < encrypted_count; i++) {
  163. uint64_t *encrypted_pointer = result.data(i);
  164. for (int j = 0; j < coeff_mod_count; j++) {
  165. // populate one poly at a time.
  166. // create a polynomial to store the current decomposition value
  167. // which will be copied into the array to populate it at the current
  168. // index.
  169. double logqj = log2(params_.coeff_modulus()[j].value());
  170. int expansion_ratio = ceil(logqj / logt);
  171. uint64_t cur = 1;
  172. // cout << "Client: expansion_ratio = " << expansion_ratio << endl;
  173. for (int k = 0; k < expansion_ratio; k++) {
  174. // Compose here
  175. const uint64_t *plain_coeff =
  176. plains[k + j * (expansion_ratio) + i * (coeff_mod_count * expansion_ratio)]
  177. .data();
  178. for (int m = 0; m < coeff_count; m++) {
  179. if (k == 0) {
  180. *(encrypted_pointer + m + j * coeff_count) = *(plain_coeff + m) * cur;
  181. } else {
  182. *(encrypted_pointer + m + j * coeff_count) += *(plain_coeff + m) * cur;
  183. }
  184. }
  185. // *(encrypted_pointer + coeff_count - 1 + j * coeff_count) = 0;
  186. cur <<= logt;
  187. }
  188. // XXX: Reduction modulo qj. This is needed?
  189. /*
  190. for (int m = 0; m < coeff_count; m++) {
  191. *(encrypted_pointer + m + j * coeff_count) %=
  192. params_.coeff_modulus()[j].value();
  193. }
  194. */
  195. }
  196. }
  197. result.parms_id() = params_.parms_id();
  198. return result;
  199. }
  200. void PIRClient::compute_inverse_scales(){
  201. if (indices_.size() != pir_params_.nvec.size()){
  202. throw invalid_argument("size mismatch");
  203. }
  204. int logt = floor(log2(params_.plain_modulus().value()));
  205. uint64_t N = params_.poly_modulus_degree();
  206. uint64_t t = params_.plain_modulus().value();
  207. int logN = log2(N);
  208. int logm = logN;
  209. inverse_scales_.clear();
  210. for(int i = 0; i < pir_params_.nvec.size(); i++){
  211. uint64_t index_modN = indices_[i] % N;
  212. uint64_t numCtxt = ceil ( (pir_params_.nvec[i] + 0.0) / N); // number of query ciphertexts.
  213. uint64_t batchId = indices_[i] / N;
  214. if (batchId == numCtxt - 1) {
  215. cout << "Client: adjusting the logm value..." << endl;
  216. logm = ceil(log2((pir_params_.nvec[i] % N)));
  217. }
  218. uint64_t inverse_scale;
  219. int quo = logm / logt;
  220. int mod = logm % logt;
  221. inverse_scale = pow(2, logt - mod);
  222. if ((quo +1) %2 != 0){
  223. inverse_scale = params_.plain_modulus().value() - pow(2, logt - mod);
  224. }
  225. inverse_scales_.push_back(inverse_scale);
  226. if ( (inverse_scale << logm) % t != 1){
  227. throw logic_error("something wrong");
  228. }
  229. cout << "Client: logm, inverse scale, t = " << logm << ", " << inverse_scale << ", " << t << endl;
  230. }
  231. }