pir_client.cpp 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300
  1. #include "pir_client.hpp"
  2. using namespace std;
  3. using namespace seal;
  4. using namespace seal::util;
  5. PIRClient::PIRClient(const EncryptionParameters &params,
  6. const PirParams &pir_parms) :
  7. params_(params){
  8. newcontext_ = SEALContext::Create(params_);
  9. pir_params_ = pir_parms;
  10. keygen_ = make_unique<KeyGenerator>(newcontext_);
  11. encryptor_ = make_unique<Encryptor>(newcontext_, keygen_->public_key());
  12. SecretKey secret_key = keygen_->secret_key();
  13. decryptor_ = make_unique<Decryptor>(newcontext_, secret_key);
  14. evaluator_ = make_unique<Evaluator>(newcontext_);
  15. uint64_t t = params_.plain_modulus().value();
  16. uint64_t N = params_.poly_modulus_degree();
  17. //
  18. // int logt = floor(log2(params_.plain_modulus().value()));
  19. // for(int i = 0; i < pir_params_.nvec.size(); i++){
  20. // uint64_t inverse_scale;
  21. // //
  22. // // If the number of items are less than N, then
  23. // // we use logm.
  24. // int logm = ceil(log2(min(N, pir_params_.nvec[i]))); // if nvec > n what do we do?
  25. // int quo = logm / logt;
  26. // int mod = logm % logt;
  27. // inverse_scale = pow(2, logt - mod);
  28. // if ((quo +1) %2 != 0){
  29. // inverse_scale = params_.plain_modulus().value() - pow(2, logt - mod);
  30. // }
  31. // inverse_scales_.push_back(inverse_scale);
  32. // if ( (inverse_scale << logm) % t != 1){
  33. // throw logic_error("something wrong");
  34. // }
  35. // cout << "logm, inverse scale, t = " << logm << ", " << inverse_scale << ", " << t << endl;
  36. // }
  37. }
  38. // void PIRClient::update_parameters(const EncryptionParameters &expanded_params,
  39. // const PirParams &pir_params)
  40. // {
  41. // // The only thing that can change is the plaintext modulus and pir_params
  42. // assert(expanded_params.poly_modulus_degree() == params_.poly_modulus_degree());
  43. // assert(expanded_params.coeff_modulus() == params_.coeff_modulus());
  44. // params_ = expanded_params;
  45. // pir_params_ = pir_params;
  46. // auto newcontext = SEALContext::Create(expanded_params);
  47. // SecretKey secret_key = keygen_->secret_key();
  48. // secret_key.parms_id() = expanded_params.parms_id();
  49. // decryptor_ = make_unique<Decryptor>(newcontext, secret_key);
  50. // evaluator_ = make_unique<Evaluator>(newcontext);
  51. // }
  52. PirQuery PIRClient::generate_query(uint64_t desiredIndex) {
  53. indices_ = compute_indices(desiredIndex, pir_params_.nvec);
  54. compute_inverse_scales();
  55. vector<vector<Ciphertext> > result(pir_params_.d);
  56. int N = params_.poly_modulus_degree();
  57. Plaintext pt(params_.poly_modulus_degree());
  58. for (uint32_t i = 0; i < indices_.size(); i++) {
  59. uint32_t num_ptxts = ceil( (pir_params_.nvec[i] +0.0) / N);
  60. // initialize result.
  61. cout << "Client: index " << i + 1 << "/ " << indices_.size() << " = " << indices_[i] << endl;
  62. cout << "Client: number of ctxts needed for query = " << num_ptxts << endl;
  63. for (uint32_t j =0; j < num_ptxts; j++){
  64. pt.set_zero();
  65. if (indices_[i] > N*(j+1) || indices_[i] < N*j){
  66. #ifdef DEBUG
  67. cout << "Client: coming here: so just encrypt zero." << endl;
  68. #endif
  69. // just encrypt zero
  70. } else{
  71. #ifdef DEBUG
  72. cout << "Client: encrypting a real thing " << endl;
  73. #endif
  74. uint64_t real_index = indices_[i] - N*j;
  75. pt[real_index] = 1;
  76. }
  77. Ciphertext dest;
  78. encryptor_->encrypt(pt, dest);
  79. dest.parms_id() = params_.parms_id();
  80. result[i].push_back(dest);
  81. }
  82. }
  83. return result;
  84. }
  85. uint64_t PIRClient::get_fv_index(uint64_t element_idx, uint64_t ele_size) {
  86. auto N = params_.poly_modulus_degree();
  87. auto logt = floor(log2(params_.plain_modulus().value()));
  88. auto ele_per_ptxt = elements_per_ptxt(logt, N, ele_size);
  89. return static_cast<uint64_t>(element_idx / ele_per_ptxt);
  90. }
  91. uint64_t PIRClient::get_fv_offset(uint64_t element_idx, uint64_t ele_size) {
  92. uint32_t N = params_.poly_modulus_degree();
  93. uint32_t logt = floor(log2(params_.plain_modulus().value()));
  94. uint64_t ele_per_ptxt = elements_per_ptxt(logt, N, ele_size);
  95. return element_idx % ele_per_ptxt;
  96. }
  97. Plaintext PIRClient::decode_reply(PirReply reply) {
  98. uint32_t exp_ratio = pir_params_.expansion_ratio;
  99. uint32_t recursion_level = pir_params_.d;
  100. vector<Ciphertext> temp = reply;
  101. uint64_t t = params_.plain_modulus().value();
  102. for (uint32_t i = 0; i < recursion_level; i++) {
  103. cout << "Client: " << i + 1 << "/ " << recursion_level << "-th decryption layer started." << endl;
  104. vector<Ciphertext> newtemp;
  105. vector<Plaintext> tempplain;
  106. for (uint32_t j = 0; j < temp.size(); j++) {
  107. Plaintext ptxt;
  108. decryptor_->decrypt(temp[j], ptxt);
  109. #ifdef DEBUG
  110. cout << "Client: reply noise budget = " << decryptor_->invariant_noise_budget(temp[j]) << endl;
  111. #endif
  112. // multiply by inverse_scale for every coefficient of ptxt
  113. for(int h = 0; h < ptxt.coeff_count(); h++){
  114. ptxt[h] *= inverse_scales_[recursion_level - 1 - i];
  115. ptxt[h] %= t;
  116. }
  117. //cout << "decoded (and scaled) plaintext = " << ptxt.to_string() << endl;
  118. tempplain.push_back(ptxt);
  119. #ifdef DEBUG
  120. cout << "recursion level : " << i << " noise budget : ";
  121. cout << decryptor_->invariant_noise_budget(temp[j]) << endl;
  122. #endif
  123. if ((j + 1) % exp_ratio == 0 && j > 0) {
  124. // Combine into one ciphertext.
  125. Ciphertext combined = compose_to_ciphertext(tempplain);
  126. newtemp.push_back(combined);
  127. tempplain.clear();
  128. // cout << "Client: const term of ciphertext = " << combined[0] << endl;
  129. }
  130. }
  131. cout << "Client: done." << endl;
  132. cout << endl;
  133. if (i == recursion_level - 1) {
  134. assert(temp.size() == 1);
  135. return tempplain[0];
  136. } else {
  137. tempplain.clear();
  138. temp = newtemp;
  139. }
  140. }
  141. // This should never be called
  142. assert(0);
  143. Plaintext fail;
  144. return fail;
  145. }
  146. GaloisKeys PIRClient::generate_galois_keys() {
  147. // Generate the Galois keys needed for coeff_select.
  148. vector<uint64_t> galois_elts;
  149. int N = params_.poly_modulus_degree();
  150. int logN = get_power_of_two(N);
  151. //cout << "printing galois elements...";
  152. for (int i = 0; i < logN; i++) {
  153. galois_elts.push_back((N + exponentiate_uint64(2, i)) / exponentiate_uint64(2, i));
  154. //#ifdef DEBUG
  155. // cout << galois_elts.back() << ", ";
  156. //#endif
  157. }
  158. return keygen_->galois_keys(pir_params_.dbc, galois_elts);
  159. }
  160. Ciphertext PIRClient::compose_to_ciphertext(vector<Plaintext> plains) {
  161. size_t encrypted_count = 2;
  162. auto coeff_count = params_.poly_modulus_degree();
  163. auto coeff_mod_count = params_.coeff_modulus().size();
  164. uint64_t plainMod = params_.plain_modulus().value();
  165. int logt = floor(log2(plainMod));
  166. Ciphertext result(newcontext_);
  167. result.resize(encrypted_count);
  168. // A triple for loop. Going over polys, moduli, and decomposed index.
  169. for (int i = 0; i < encrypted_count; i++) {
  170. uint64_t *encrypted_pointer = result.data(i);
  171. for (int j = 0; j < coeff_mod_count; j++) {
  172. // populate one poly at a time.
  173. // create a polynomial to store the current decomposition value
  174. // which will be copied into the array to populate it at the current
  175. // index.
  176. double logqj = log2(params_.coeff_modulus()[j].value());
  177. int expansion_ratio = ceil(logqj / logt);
  178. uint64_t cur = 1;
  179. // cout << "Client: expansion_ratio = " << expansion_ratio << endl;
  180. for (int k = 0; k < expansion_ratio; k++) {
  181. // Compose here
  182. const uint64_t *plain_coeff =
  183. plains[k + j * (expansion_ratio) + i * (coeff_mod_count * expansion_ratio)]
  184. .data();
  185. for (int m = 0; m < coeff_count; m++) {
  186. if (k == 0) {
  187. *(encrypted_pointer + m + j * coeff_count) = *(plain_coeff + m) * cur;
  188. } else {
  189. *(encrypted_pointer + m + j * coeff_count) += *(plain_coeff + m) * cur;
  190. }
  191. }
  192. // *(encrypted_pointer + coeff_count - 1 + j * coeff_count) = 0;
  193. cur <<= logt;
  194. }
  195. // XXX: Reduction modulo qj. This is needed?
  196. /*
  197. for (int m = 0; m < coeff_count; m++) {
  198. *(encrypted_pointer + m + j * coeff_count) %=
  199. params_.coeff_modulus()[j].value();
  200. }
  201. */
  202. }
  203. }
  204. result.parms_id() = params_.parms_id();
  205. return result;
  206. }
  207. void PIRClient::compute_inverse_scales(){
  208. if (indices_.size() != pir_params_.nvec.size()){
  209. throw invalid_argument("size mismatch");
  210. }
  211. int logt = floor(log2(params_.plain_modulus().value()));
  212. uint64_t N = params_.poly_modulus_degree();
  213. uint64_t t = params_.plain_modulus().value();
  214. int logN = log2(N);
  215. int logm = logN;
  216. inverse_scales_.clear();
  217. for(int i = 0; i < pir_params_.nvec.size(); i++){
  218. uint64_t index_modN = indices_[i] % N;
  219. uint64_t numCtxt = ceil ( (pir_params_.nvec[i] + 0.0) / N); // number of query ciphertexts.
  220. uint64_t batchId = indices_[i] / N;
  221. if (batchId == numCtxt - 1) {
  222. cout << "Client: adjusting the logm value..." << endl;
  223. logm = ceil(log2((pir_params_.nvec[i] % N)));
  224. }
  225. uint64_t inverse_scale;
  226. int quo = logm / logt;
  227. int mod = logm % logt;
  228. inverse_scale = pow(2, logt - mod);
  229. if ((quo +1) %2 != 0){
  230. inverse_scale = params_.plain_modulus().value() - pow(2, logt - mod);
  231. }
  232. inverse_scales_.push_back(inverse_scale);
  233. if ( (inverse_scale << logm) % t != 1){
  234. throw logic_error("something wrong");
  235. }
  236. cout << "Client: logm, inverse scale, t = " << logm << ", " << inverse_scale << ", " << t << endl;
  237. }
  238. }