소스 검색

fixed mult by power of X problem (coeff count -1).
Now noise is okay but result is wrong. I think it's due to encoding problem.
Let me check.

hao chen 5 년 전
부모
커밋
8fb6ff7ad2
5개의 변경된 파일95개의 추가작업 그리고 14개의 파일을 삭제
  1. 9 6
      main.cpp
  2. 33 2
      pir_client.cpp
  3. 7 0
      pir_client.hpp
  4. 43 4
      pir_server.cpp
  5. 3 2
      pir_server.hpp

+ 9 - 6
main.cpp

@@ -15,15 +15,15 @@ using namespace seal;
 int main(int argc, char *argv[]) {
 
     // uint64_t number_of_items = 1 << 13;
-    // uint64_t number_of_items = 4096;
-    uint64_t number_of_items = 1 << 16;
+    uint64_t number_of_items = 128;
+    //uint64_t number_of_items = 1 << 16;
 
     uint64_t size_per_item = 288; // in bytes
     // uint64_t size_per_item = 1 << 10; // 1 KB.
     // uint64_t size_per_item = 10 << 10; // 10 KB.
 
     uint32_t N = 2048;
-    uint32_t logt = 20;
+    uint32_t logt = 15;
     uint32_t d = 1;
 
     EncryptionParameters params(scheme_type::BFV);
@@ -63,15 +63,17 @@ int main(int argc, char *argv[]) {
 
     // The following can be used to update parameters rather than creating new instances
     // (here it doesn't do anything).
-    cout << "Updating database size to: " << number_of_items << " elements" << endl;
+    // cout << "Updating database size to: " << number_of_items << " elements" << endl;
     // update_params(number_of_items, size_per_item, d, params, expanded_params, pir_params);
 
+    cout << "done" << endl;
 
 
     // Measure database setup
     auto time_pre_s = high_resolution_clock::now();
     server.set_database(move(db), number_of_items, size_per_item);
     server.preprocess_database();
+    cout << "database pre processed " << endl;
     auto time_pre_e = high_resolution_clock::now();
     auto time_pre_us = duration_cast<microseconds>(time_pre_e - time_pre_s).count();
 
@@ -85,11 +87,12 @@ int main(int argc, char *argv[]) {
     PirQuery query = client.generate_query(index);
     auto time_query_e = high_resolution_clock::now();
     auto time_query_us = duration_cast<microseconds>(time_query_e - time_query_s).count();
+    cout << "query generated" << endl;
 
     // Measure query processing (including expansion)
     auto time_server_s = high_resolution_clock::now();
-    PirQuery query_ser = deserialize_ciphertexts(d, serialize_ciphertexts(query), CIPHER_SIZE);
-    PirReply reply = server.generate_reply(query_ser, 0);
+    //PirQuery query_ser = deserialize_ciphertexts(d, serialize_ciphertexts(query), CIPHER_SIZE);
+    PirReply reply = server.generate_reply(query, 0, client);
     auto time_server_e = high_resolution_clock::now();
     auto time_server_us = duration_cast<microseconds>(time_server_e - time_server_s).count();
 

+ 33 - 2
pir_client.cpp

@@ -19,6 +19,27 @@ PIRClient::PIRClient(const EncryptionParameters &params,
 
     decryptor_ = make_unique<Decryptor>(newcontext_, secret_key);
     evaluator_ = make_unique<Evaluator>(newcontext_);
+
+    uint64_t t = params_.plain_modulus().value(); 
+    // 
+    int logt = floor(log2(params_.plain_modulus().value())); 
+    for(int i = 0; i < pir_params_.nvec.size(); i++){
+        uint64_t inverse_scale; 
+        int logm = ceil(log2(pir_params_.nvec[i]));  
+
+        int quo = logm / logt; 
+        int mod = logm % logt; 
+        inverse_scale = pow(2, logt - mod); 
+        if ((quo +1) %2 != 0){
+            inverse_scale =  params_.plain_modulus().value() - pow(2, logt - mod); 
+        }
+        inverse_scales_.push_back(inverse_scale); 
+        if ( (inverse_scale << logm)  % t != 1){
+            throw logic_error("something wrong"); 
+        }
+        cout << "logm, inverse scale, t = " << logm << ", " << inverse_scale << ", " << t << endl; 
+    }
+
 }
 
 // void PIRClient::update_parameters(const EncryptionParameters &expanded_params,
@@ -80,6 +101,8 @@ Plaintext PIRClient::decode_reply(PirReply reply) {
 
     vector<Ciphertext> temp = reply;
 
+    uint64_t t = params_.plain_modulus().value();
+
     for (uint32_t i = 0; i < recursion_level; i++) {
 
         vector<Ciphertext> newtemp;
@@ -88,6 +111,13 @@ Plaintext PIRClient::decode_reply(PirReply reply) {
         for (uint32_t j = 0; j < temp.size(); j++) {
             Plaintext ptxt;
             decryptor_->decrypt(temp[j], ptxt);
+            cout << " reply noise budget = " << decryptor_->invariant_noise_budget(temp[j]) << endl; 
+            // multiply by inverse_scale for every coefficient of ptxt
+            for(int h = 0; h < ptxt.coeff_count(); h++){
+                ptxt[h] *= inverse_scales_[i]; 
+                ptxt[h] %= t; 
+            }
+            cout << "decoded (and scaled) plaintext = " << ptxt.to_string() << endl;
             tempplain.push_back(ptxt);
 
 #ifdef DEBUG
@@ -123,11 +153,12 @@ GaloisKeys PIRClient::generate_galois_keys() {
     int N = params_.poly_modulus_degree();
     int logN = get_power_of_two(N);
 
+    cout << "printing galois elements...";
     for (int i = 0; i < logN; i++) {
         galois_elts.push_back((N + exponentiate_uint64(2, i)) / exponentiate_uint64(2, i));
-#ifdef DEBUG
+//#ifdef DEBUG
         cout << galois_elts.back() << ", ";
-#endif
+//#endif
     }
 
     return keygen_->galois_keys(pir_params_.dbc, galois_elts);

+ 7 - 0
pir_client.hpp

@@ -2,6 +2,9 @@
 
 #include "pir.hpp"
 #include <memory>
+#include <vector>
+
+using namespace std; 
 
 class PIRClient {
   public:
@@ -31,5 +34,9 @@ class PIRClient {
     std::unique_ptr<seal::KeyGenerator> keygen_;
     std::shared_ptr<seal::SEALContext> newcontext_;
 
+    vector<uint64_t> inverse_scales_; 
+
     seal::Ciphertext compose_to_ciphertext(std::vector<seal::Plaintext> plains);
+
+    friend class PIRServer;
 };

+ 43 - 4
pir_server.cpp

@@ -1,4 +1,5 @@
 #include "pir_server.hpp"
+#include "pir_client.hpp"
 
 using namespace std;
 using namespace seal;
@@ -111,6 +112,7 @@ void PIRServer::set_database(const std::unique_ptr<const std::uint8_t[]> &bytes,
 
         Plaintext plain;
         vector_to_plaintext(coefficients, plain);
+        cout << i << "-th encoded plaintext = " << plain.to_string() << endl; 
         result->push_back(move(plain));
     }
 
@@ -141,7 +143,7 @@ void PIRServer::set_galois_key(std::uint32_t client_id, seal::GaloisKeys galkey)
     galoisKeys_[client_id] = galkey;
 }
 
-PirReply PIRServer::generate_reply(PirQuery query, uint32_t client_id) {
+PirReply PIRServer::generate_reply(PirQuery query, uint32_t client_id, PIRClient &client) {
 
     vector<uint64_t> nvec = pir_params_.nvec;
     uint64_t product = 1;
@@ -159,7 +161,15 @@ PirReply PIRServer::generate_reply(PirQuery query, uint32_t client_id) {
 
     for (uint32_t i = 0; i < nvec.size(); i++) {
         uint64_t n_i = nvec[i];
-        vector<Ciphertext> expanded_query = expand_query(query[i], n_i, client_id);
+        vector<Ciphertext> expanded_query = expand_query(query[i], n_i, client_id, client);
+        cout << "Checking expanded query "; 
+        Plaintext tempPt; 
+        for (int h = 0 ; h < expanded_query.size(); h++){
+            cout << "noise budget = " << client.decryptor_->invariant_noise_budget(expanded_query[h]) << ", "; 
+            client.decryptor_->decrypt(expanded_query[h], tempPt); 
+            cout << tempPt.to_string()  << endl; 
+        }
+        cout << endl;
 
         // Transform expanded query to NTT, and ...
         for (uint32_t jj = 0; jj < expanded_query.size(); jj++) {
@@ -224,7 +234,7 @@ PirReply PIRServer::generate_reply(PirQuery query, uint32_t client_id) {
 }
 
 inline vector<Ciphertext> PIRServer::expand_query(const Ciphertext &encrypted, uint32_t m,
-                                           uint32_t client_id) {
+                                           uint32_t client_id, PIRClient &client) {
 
 #ifdef DEBUG
     uint64_t plainMod = params_.plain_modulus().value();
@@ -239,10 +249,16 @@ inline vector<Ciphertext> PIRServer::expand_query(const Ciphertext &encrypted, u
 
     vector<int> galois_elts;
     auto n = params_.poly_modulus_degree();
+    if (logm > ceil(log2(n))){
+        throw logic_error("m > n is not allowed."); 
+    }
 
+    cout << "galois elts at server: "; 
     for (uint32_t i = 0; i < logm; i++) {
         galois_elts.push_back((n + exponentiate_uint64(2, i)) / exponentiate_uint64(2, i));
+        cout << galois_elts.back() << ", "; 
     }
+    cout << endl;
 
     vector<Ciphertext> temp;
     temp.push_back(encrypted);
@@ -257,17 +273,36 @@ inline vector<Ciphertext> PIRServer::expand_query(const Ciphertext &encrypted, u
         // some scaling....
         int index_raw = (n << 1) - (1 << i);
         int index = (index_raw * galois_elts[i]) % (n << 1);
+        cout << i << "-th expansion round, noise budget = " << endl; 
 
         for (uint32_t a = 0; a < temp.size(); a++) {
 
             evaluator_->apply_galois(temp[a], galois_elts[i], galkey, tempctxt_rotated);
+
+            cout << "rotate " << client.decryptor_->invariant_noise_budget(tempctxt_rotated) << ", "; 
+
+
             evaluator_->add(temp[a], tempctxt_rotated, newtemp[a]);
             multiply_power_of_X(temp[a], tempctxt_shifted, index_raw);
+
+            cout << "mul by x^pow: " << client.decryptor_->invariant_noise_budget(tempctxt_shifted) << ", "; 
+
+
             multiply_power_of_X(tempctxt_rotated, tempctxt_rotatedshifted, index);
+
+            cout << "mul by x^pow: " << client.decryptor_->invariant_noise_budget(tempctxt_rotatedshifted) << ", "; 
+
+
             // Enc(2^i x^j) if j = 0 (mod 2**i).
             evaluator_->add(tempctxt_shifted, tempctxt_rotatedshifted, newtemp[a + temp.size()]);
         }
         temp = newtemp;
+
+        cout << "end: "; 
+        for (int h = 0; h < temp.size();h++){
+            cout << client.decryptor_->invariant_noise_budget(temp[h]) << ", "; 
+        }
+        cout << endl; 
     }
 
     // Last step of the loop
@@ -277,6 +312,7 @@ inline vector<Ciphertext> PIRServer::expand_query(const Ciphertext &encrypted, u
     for (uint32_t a = 0; a < temp.size(); a++) {
         if (a >= (m - (1 << (logm - 1)))) {                       // corner case.
             evaluator_->multiply_plain(temp[a], two, newtemp[a]); // plain multiplication by 2.
+            cout << client.decryptor_->invariant_noise_budget(newtemp[a]) << ", "; 
         } else {
             evaluator_->apply_galois(temp[a], galois_elts[logm - 1], galkey, tempctxt_rotated);
             evaluator_->add(temp[a], tempctxt_rotated, newtemp[a]);
@@ -299,6 +335,9 @@ inline void PIRServer::multiply_power_of_X(const Ciphertext &encrypted, Cipherte
     auto coeff_count = params_.poly_modulus_degree();
     auto encrypted_count = encrypted.size();
 
+    //cout << "coeff mod count for power of X = " << coeff_mod_count << endl; 
+    //cout << "coeff count for power of X = " << coeff_count << endl; 
+
     // First copy over.
     destination = encrypted;
 
@@ -307,7 +346,7 @@ inline void PIRServer::multiply_power_of_X(const Ciphertext &encrypted, Cipherte
     for (int i = 0; i < encrypted_count; i++) {
         for (int j = 0; j < coeff_mod_count; j++) {
             negacyclic_shift_poly_coeffmod(encrypted.data(i) + (j * coeff_count),
-                                           coeff_count - 1, index,
+                                           coeff_count, index,
                                            params_.coeff_modulus()[j],
                                            destination.data(i) + (j * coeff_count));
         }

+ 3 - 2
pir_server.hpp

@@ -4,6 +4,7 @@
 #include <map>
 #include <memory>
 #include <vector>
+#include "pir_client.hpp"
 
 class PIRServer {
   public:
@@ -19,9 +20,9 @@ class PIRServer {
     void preprocess_database();
 
     std::vector<seal::Ciphertext> expand_query(
-            const seal::Ciphertext &encrypted, std::uint32_t m, uint32_t client_id);
+            const seal::Ciphertext &encrypted, std::uint32_t m, uint32_t client_id, PIRClient &client);
 
-    PirReply generate_reply(PirQuery query, std::uint32_t client_id);
+    PirReply generate_reply(PirQuery query, std::uint32_t client_id, PIRClient &client);
 
     void set_galois_key(std::uint32_t client_id, seal::GaloisKeys galkey);