Browse Source

Renamed params to enc_params on server too

Andrew Beams 3 years ago
parent
commit
e4f31fe3e0
2 changed files with 22 additions and 27 deletions
  1. 20 25
      pir_server.cpp
  2. 2 2
      pir_server.hpp

+ 20 - 25
pir_server.cpp

@@ -5,12 +5,12 @@ using namespace std;
 using namespace seal;
 using namespace seal::util;
 
-PIRServer::PIRServer(const EncryptionParameters &params, const PirParams &pir_params) :
-    params_(params), 
+PIRServer::PIRServer(const EncryptionParameters &enc_params, const PirParams &pir_params) :
+    enc_params_(enc_params), 
     pir_params_(pir_params),
     is_db_preprocessed_(false)
 {
-    context_ = make_shared<SEALContext>(params, true);
+    context_ = make_shared<SEALContext>(enc_params, true);
     evaluator_ = make_unique<Evaluator>(*context_);
 }
 
@@ -39,8 +39,8 @@ void PIRServer::set_database(unique_ptr<vector<Plaintext>> &&db) {
 void PIRServer::set_database(const std::unique_ptr<const std::uint8_t[]> &bytes, 
     uint64_t ele_num, uint64_t ele_size) {
 
-    uint32_t logt = floor(log2(params_.plain_modulus().value()));
-    uint32_t N = params_.poly_modulus_degree();
+    uint32_t logt = floor(log2(enc_params_.plain_modulus().value()));
+    uint32_t N = enc_params_.poly_modulus_degree();
 
     // number of FV plaintexts needed to represent all elements
     uint64_t total = plaintexts_per_db(logt, N, ele_num, ele_size);
@@ -139,7 +139,7 @@ PirReply PIRServer::generate_reply(PirQuery query, uint32_t client_id) {
         product *= nvec[i];
     }
 
-    auto coeff_count = params_.poly_modulus_degree();
+    auto coeff_count = enc_params_.poly_modulus_degree();
 
     vector<Plaintext> *cur = db_.get();
     vector<Plaintext> intermediate_plain; // decompose....
@@ -147,9 +147,9 @@ PirReply PIRServer::generate_reply(PirQuery query, uint32_t client_id) {
     auto pool = MemoryManager::GetPool();
 
 
-    int N = params_.poly_modulus_degree();
+    int N = enc_params_.poly_modulus_degree();
 
-    int logt = floor(log2(params_.plain_modulus().value()));
+    int logt = floor(log2(enc_params_.plain_modulus().value()));
 
     cout << "expansion ratio = " << pir_params_.expansion_ratio << endl; 
     for (uint32_t i = 0; i < nvec.size(); i++) {
@@ -252,11 +252,6 @@ PirReply PIRServer::generate_reply(PirQuery query, uint32_t client_id) {
 inline vector<Ciphertext> PIRServer::expand_query(const Ciphertext &encrypted, uint32_t m,
                                            uint32_t client_id) {
 
-#ifdef DEBUG
-    uint64_t plainMod = params_.plain_modulus().value();
-    cout << "PIRServer side plain modulus = " << plainMod << endl;
-#endif
-
     GaloisKeys &galkey = galoisKeys_[client_id];
 
     // Assume that m is a power of 2. If not, round it to the next power of 2.
@@ -264,7 +259,7 @@ inline vector<Ciphertext> PIRServer::expand_query(const Ciphertext &encrypted, u
     Plaintext two("2");
 
     vector<int> galois_elts;
-    auto n = params_.poly_modulus_degree();
+    auto n = enc_params_.poly_modulus_degree();
     if (logm > ceil(log2(n))){
         throw logic_error("m > n is not allowed."); 
     }
@@ -343,8 +338,8 @@ inline vector<Ciphertext> PIRServer::expand_query(const Ciphertext &encrypted, u
 inline void PIRServer::multiply_power_of_X(const Ciphertext &encrypted, Ciphertext &destination,
                                     uint32_t index) {
 
-    auto coeff_mod_count = params_.coeff_modulus().size() - 1;
-    auto coeff_count = params_.poly_modulus_degree();
+    auto coeff_mod_count = enc_params_.coeff_modulus().size() - 1;
+    auto coeff_count = enc_params_.poly_modulus_degree();
     auto encrypted_count = encrypted.size();
 
     //cout << "coeff mod count for power of X = " << coeff_mod_count << endl; 
@@ -359,7 +354,7 @@ inline void PIRServer::multiply_power_of_X(const Ciphertext &encrypted, Cipherte
         for (int j = 0; j < coeff_mod_count; j++) {
             negacyclic_shift_poly_coeffmod(encrypted.data(i) + (j * coeff_count),
                                            coeff_count, index,
-                                           params_.coeff_modulus()[j],
+                                           enc_params_.coeff_modulus()[j],
                                            destination.data(i) + (j * coeff_count));
         }
     }
@@ -368,8 +363,8 @@ inline void PIRServer::multiply_power_of_X(const Ciphertext &encrypted, Cipherte
 inline void PIRServer::decompose_to_plaintexts_ptr(const Ciphertext &encrypted, Plaintext *plain_ptr, int logt) {
 
     vector<Plaintext> result;
-    auto coeff_count = params_.poly_modulus_degree();
-    auto coeff_mod_count = params_.coeff_modulus().size();
+    auto coeff_count = enc_params_.poly_modulus_degree();
+    auto coeff_mod_count = enc_params_.coeff_modulus().size();
     auto encrypted_count = encrypted.size();
 
     uint64_t t1 = 1 << logt;  //  t1 <= t. 
@@ -384,7 +379,7 @@ inline void PIRServer::decompose_to_plaintexts_ptr(const Ciphertext &encrypted,
             // create a polynomial to store the current decomposition value
             // which will be copied into the array to populate it at the current
             // index.
-            double logqj = log2(params_.coeff_modulus()[j].value());
+            double logqj = log2(enc_params_.coeff_modulus()[j].value());
             //int expansion_ratio = ceil(logqj + exponent - 1) / exponent;
             int expansion_ratio =  ceil(logqj / logt); 
             // cout << "local expansion ratio = " << expansion_ratio << endl;
@@ -404,13 +399,13 @@ inline void PIRServer::decompose_to_plaintexts_ptr(const Ciphertext &encrypted,
 
 vector<Plaintext> PIRServer::decompose_to_plaintexts(const Ciphertext &encrypted) {
     vector<Plaintext> result;
-    auto coeff_count = params_.poly_modulus_degree();
-    auto coeff_mod_count = params_.coeff_modulus().size();
-    auto plain_bit_count = params_.plain_modulus().bit_count();
+    auto coeff_count = enc_params_.poly_modulus_degree();
+    auto coeff_mod_count = enc_params_.coeff_modulus().size();
+    auto plain_bit_count = enc_params_.plain_modulus().bit_count();
     auto encrypted_count = encrypted.size();
 
     // Generate powers of t.
-    uint64_t plainMod = params_.plain_modulus().value();
+    uint64_t plainMod = enc_params_.plain_modulus().value();
 
     // A triple for loop. Going over polys, moduli, and decomposed index.
     for (int i = 0; i < encrypted_count; i++) {
@@ -420,7 +415,7 @@ vector<Plaintext> PIRServer::decompose_to_plaintexts(const Ciphertext &encrypted
             // create a polynomial to store the current decomposition value
             // which will be copied into the array to populate it at the current
             // index.
-            int logqj = log2(params_.coeff_modulus()[j].value());
+            int logqj = log2(enc_params_.coeff_modulus()[j].value());
             int expansion_ratio = ceil(logqj / log2(plainMod));
 
             // cout << "expansion ratio = " << expansion_ratio << endl;

+ 2 - 2
pir_server.hpp

@@ -8,7 +8,7 @@
 
 class PIRServer {
   public:
-    PIRServer(const seal::EncryptionParameters &params, const PirParams &pir_params);
+    PIRServer(const seal::EncryptionParameters &enc_params, const PirParams &pir_params);
 
     // NOTE: server takes over ownership of db and frees it when it exits.
     // Caller cannot free db
@@ -24,7 +24,7 @@ class PIRServer {
     void set_galois_key(std::uint32_t client_id, seal::GaloisKeys galkey);
 
   private:
-    seal::EncryptionParameters params_; // SEAL parameters
+    seal::EncryptionParameters enc_params_; // SEAL parameters
     PirParams pir_params_;              // PIR parameters
     std::unique_ptr<Database> db_;
     bool is_db_preprocessed_;