Browse Source

Renamed some things

Andrew Beams 3 years ago
parent
commit
144490f9c3
2 changed files with 27 additions and 37 deletions
  1. 23 33
      pir_client.cpp
  2. 4 4
      pir_client.hpp

+ 23 - 33
pir_client.cpp

@@ -4,24 +4,23 @@ using namespace std;
 using namespace seal;
 using namespace seal::util;
 
-PIRClient::PIRClient(const EncryptionParameters &params,
-                     const PirParams &pir_parms) :
-    params_(params){
+PIRClient::PIRClient(const EncryptionParameters &enc_params,
+                     const PirParams &pir_params) :
+    enc_params_(enc_params),
+    pir_params_(pir_params){
 
-    newcontext_ = make_shared<SEALContext>(params, true);
+    context_ = make_shared<SEALContext>(enc_params, true);
 
-    pir_params_ = pir_parms;
-
-    keygen_ = make_unique<KeyGenerator>(*newcontext_);
+    keygen_ = make_unique<KeyGenerator>(*context_);
     
     PublicKey public_key;
     keygen_->create_public_key(public_key);
-    encryptor_ = make_unique<Encryptor>(*newcontext_, public_key);
+    encryptor_ = make_unique<Encryptor>(*context_, public_key);
 
     SecretKey secret_key = keygen_->secret_key();
-    decryptor_ = make_unique<Decryptor>(*newcontext_, secret_key);
+    decryptor_ = make_unique<Decryptor>(*context_, secret_key);
 
-    evaluator_ = make_unique<Evaluator>(*newcontext_);
+    evaluator_ = make_unique<Evaluator>(*context_);
 }
 
 
@@ -30,9 +29,9 @@ PirQuery PIRClient::generate_query(uint64_t desiredIndex) {
     indices_ = compute_indices(desiredIndex, pir_params_.nvec);
 
     vector<vector<Ciphertext> > result(pir_params_.d);
-    int N = params_.poly_modulus_degree(); 
+    int N = enc_params_.poly_modulus_degree(); 
 
-    Plaintext pt(params_.poly_modulus_degree());
+    Plaintext pt(enc_params_.poly_modulus_degree());
     for (uint32_t i = 0; i < indices_.size(); i++) {
         uint32_t num_ptxts = ceil( (pir_params_.nvec[i] + 0.0) / N);
         // initialize result. 
@@ -59,7 +58,7 @@ PirQuery PIRClient::generate_query(uint64_t desiredIndex) {
                 uint64_t log_total = ceil(log2(total));
 
                 cout << "Client: Inverting " << pow(2, log_total) << endl;
-                pt[real_index] = invert_mod(pow(2, log_total), params_.plain_modulus());
+                pt[real_index] = invert_mod(pow(2, log_total), enc_params_.plain_modulus());
             }
             Ciphertext dest;
             encryptor_->encrypt(pt, dest);
@@ -71,16 +70,16 @@ PirQuery PIRClient::generate_query(uint64_t desiredIndex) {
 }
 
 uint64_t PIRClient::get_fv_index(uint64_t element_idx, uint64_t ele_size) {
-    auto N = params_.poly_modulus_degree();
-    auto logt = floor(log2(params_.plain_modulus().value()));
+    auto N = enc_params_.poly_modulus_degree();
+    auto logt = floor(log2(enc_params_.plain_modulus().value()));
 
     auto ele_per_ptxt = elements_per_ptxt(logt, N, ele_size);
     return static_cast<uint64_t>(element_idx / ele_per_ptxt);
 }
 
 uint64_t PIRClient::get_fv_offset(uint64_t element_idx, uint64_t ele_size) {
-    uint32_t N = params_.poly_modulus_degree();
-    uint32_t logt = floor(log2(params_.plain_modulus().value()));
+    uint32_t N = enc_params_.poly_modulus_degree();
+    uint32_t logt = floor(log2(enc_params_.plain_modulus().value()));
 
     uint64_t ele_per_ptxt = elements_per_ptxt(logt, N, ele_size);
     return element_idx % ele_per_ptxt;
@@ -92,7 +91,7 @@ Plaintext PIRClient::decode_reply(PirReply reply) {
 
     vector<Ciphertext> temp = reply;
 
-    uint64_t t = params_.plain_modulus().value();
+    uint64_t t = enc_params_.plain_modulus().value();
 
     for (uint32_t i = 0; i < recursion_level; i++) {
         cout << "Client: " << i + 1 << "/ " << recursion_level << "-th decryption layer started." << endl; 
@@ -142,7 +141,7 @@ Plaintext PIRClient::decode_reply(PirReply reply) {
 GaloisKeys PIRClient::generate_galois_keys() {
     // Generate the Galois keys needed for coeff_select.
     vector<uint32_t> galois_elts;
-    int N = params_.poly_modulus_degree();
+    int N = enc_params_.poly_modulus_degree();
     int logN = get_power_of_two(N);
 
     //cout << "printing galois elements...";
@@ -159,12 +158,12 @@ GaloisKeys PIRClient::generate_galois_keys() {
 
 Ciphertext PIRClient::compose_to_ciphertext(vector<Plaintext> plains) {
     size_t encrypted_count = 2;
-    auto coeff_count = params_.poly_modulus_degree();
-    auto coeff_mod_count = params_.coeff_modulus().size();
-    uint64_t plainMod = params_.plain_modulus().value();
+    auto coeff_count = enc_params_.poly_modulus_degree();
+    auto coeff_mod_count = enc_params_.coeff_modulus().size();
+    uint64_t plainMod = enc_params_.plain_modulus().value();
     int logt = floor(log2(plainMod)); 
 
-    Ciphertext result(*newcontext_);
+    Ciphertext result(*context_);
     result.resize(encrypted_count);
 
     // A triple for loop. Going over polys, moduli, and decomposed index.
@@ -176,7 +175,7 @@ Ciphertext PIRClient::compose_to_ciphertext(vector<Plaintext> plains) {
             // create a polynomial to store the current decomposition value
             // which will be copied into the array to populate it at the current
             // index.
-            double logqj = log2(params_.coeff_modulus()[j].value());
+            double logqj = log2(enc_params_.coeff_modulus()[j].value());
             int expansion_ratio = ceil(logqj / logt);
             uint64_t cur = 1;
             // cout << "Client: expansion_ratio = " << expansion_ratio << endl; 
@@ -194,17 +193,8 @@ Ciphertext PIRClient::compose_to_ciphertext(vector<Plaintext> plains) {
                         *(encrypted_pointer + m + j * coeff_count) += *(plain_coeff + m) * cur;
                     }
                 }
-                // *(encrypted_pointer + coeff_count - 1 + j * coeff_count) = 0;
                 cur <<= logt;
             }
-
-            // XXX: Reduction modulo qj. This is needed?
-            /*
-            for (int m = 0; m < coeff_count; m++) {
-                *(encrypted_pointer + m + j * coeff_count) %=
-                    params_.coeff_modulus()[j].value();
-            }
-            */
         }
     }
 

+ 4 - 4
pir_client.hpp

@@ -8,8 +8,8 @@ using namespace std;
 
 class PIRClient {
   public:
-    PIRClient(const seal::EncryptionParameters &parms,
-               const PirParams &pirparms);
+    PIRClient(const seal::EncryptionParameters &encparms,
+               const PirParams &pirparams);
 
     PirQuery generate_query(std::uint64_t desiredIndex);
     seal::Plaintext decode_reply(PirReply reply);
@@ -22,14 +22,14 @@ class PIRClient {
 
 
   private:
-    seal::EncryptionParameters params_;
+    seal::EncryptionParameters enc_params_;
     PirParams pir_params_;
 
     std::unique_ptr<seal::Encryptor> encryptor_;
     std::unique_ptr<seal::Decryptor> decryptor_;
     std::unique_ptr<seal::Evaluator> evaluator_;
     std::unique_ptr<seal::KeyGenerator> keygen_;
-    std::shared_ptr<seal::SEALContext> newcontext_;
+    std::shared_ptr<seal::SEALContext> context_;
 
     vector<uint64_t> indices_; // the indices for retrieval. 
     vector<uint64_t> inverse_scales_;