浏览代码

Fixed the problems so that it works for recursion level d > 2

hao chen 5 年之前
父节点
当前提交
2eaafc25f4
共有 3 个文件被更改,包括 13 次插入5 次删除
  1. 4 3
      main.cpp
  2. 8 1
      pir_client.cpp
  3. 1 1
      pir_server.cpp

+ 4 - 3
main.cpp

@@ -16,7 +16,7 @@ int main(int argc, char *argv[]) {
 
     //uint64_t number_of_items = 1 << 11;
     //uint64_t number_of_items = 2048;
-    uint64_t number_of_items = 1 << 20;
+    uint64_t number_of_items = 1 << 12;
 
     uint64_t size_per_item = 288; // in bytes
     // uint64_t size_per_item = 1 << 10; // 1 KB.
@@ -25,7 +25,7 @@ int main(int argc, char *argv[]) {
     uint32_t N = 2048;
     // Recommended values: (logt, d) = (12, 2) or (8, 1). 
     uint32_t logt = 12; 
-    uint32_t d = 2;
+    uint32_t d = 5;
 
     EncryptionParameters params(scheme_type::BFV);
     PirParams pir_params;
@@ -34,7 +34,7 @@ int main(int argc, char *argv[]) {
     cout << "Generating all parameters" << endl;
     gen_params(number_of_items, size_per_item, N, logt, d, params, pir_params);
 
-    cout << "This may take some time ..." << endl;
+    cout << "Initializing the database (this may take some time) ..." << endl;
 
     // Create test database
     auto db(make_unique<uint8_t[]>(number_of_items * size_per_item));
@@ -135,6 +135,7 @@ int main(int argc, char *argv[]) {
     }
 
     // Output results
+    cout << "PIR reseult correct!" << endl;
     cout << "PIRServer pre-processing time: " << time_pre_us / 1000 << " ms" << endl;
     cout << "PIRServer reply generation time: " << time_server_us / 1000 << " ms"
          << endl;

+ 8 - 1
pir_client.cpp

@@ -86,10 +86,14 @@ PirQuery PIRClient::generate_query(uint64_t desiredIndex) {
         for (uint32_t j =0; j < num_ptxts; j++){
             pt.set_zero();
             if (indices_[i] > N*(j+1) || indices_[i] < N*j){
+#ifdef DEBUG
                 cout << "Client: coming here: so just encrypt zero." << endl; 
+#endif 
                 // just encrypt zero
             } else{
+#ifdef DEBUG
                 cout << "Client: encrypting a real thing " << endl; 
+#endif 
                 uint64_t real_index = indices_[i] - N*j; 
                 pt[real_index] = 1;
             }
@@ -137,10 +141,12 @@ Plaintext PIRClient::decode_reply(PirReply reply) {
         for (uint32_t j = 0; j < temp.size(); j++) {
             Plaintext ptxt;
             decryptor_->decrypt(temp[j], ptxt);
+#ifdef DEBUG
             cout << "Client: reply noise budget = " << decryptor_->invariant_noise_budget(temp[j]) << endl; 
+#endif
             // multiply by inverse_scale for every coefficient of ptxt
             for(int h = 0; h < ptxt.coeff_count(); h++){
-                ptxt[h] *= inverse_scales_[i]; 
+                ptxt[h] *= inverse_scales_[recursion_level -  1 - i]; 
                 ptxt[h] %= t; 
             }
             //cout << "decoded (and scaled) plaintext = " << ptxt.to_string() << endl;
@@ -155,6 +161,7 @@ Plaintext PIRClient::decode_reply(PirReply reply) {
                 // Combine into one ciphertext.
                 Ciphertext combined = compose_to_ciphertext(tempplain);
                 newtemp.push_back(combined);
+                tempplain.clear();
                 // cout << "Client: const term of ciphertext = " << combined[0] << endl; 
             }
         }

+ 1 - 1
pir_server.cpp

@@ -178,7 +178,7 @@ PirReply PIRServer::generate_reply(PirQuery query, uint32_t client_id, PIRClient
 
         uint64_t n_i = nvec[i];
         cout << "Server: n_i = " << n_i << endl; 
-        cout << "Server: expanding " << query[i].size() << "query ctxts" << endl;
+        cout << "Server: expanding " << query[i].size() << " query ctxts" << endl;
         for (uint32_t j = 0; j < query[i].size(); j++){
             uint64_t total = N; 
             if (j == query[i].size() - 1){