Browse Source

Merge branch 'new_protocol' of vs-ssh.visualstudio.com:v3/haoche/SEALPIR2019/SEALPIR2019 into new_protocol

Kim Laine 5 years ago
parent
commit
c63ccdc55f
7 changed files with 233 additions and 96 deletions
  1. 25 9
      main.cpp
  2. 4 3
      pir.cpp
  3. 3 3
      pir.hpp
  4. 115 44
      pir_client.cpp
  5. 3 0
      pir_client.hpp
  6. 82 36
      pir_server.cpp
  7. 1 1
      pir_server.hpp

+ 25 - 9
main.cpp

@@ -14,17 +14,18 @@ using namespace seal;
 
 int main(int argc, char *argv[]) {
 
-    // uint64_t number_of_items = 1 << 13;
-    uint64_t number_of_items = 128;
-    //uint64_t number_of_items = 1 << 16;
+    //uint64_t number_of_items = 1 << 11;
+    //uint64_t number_of_items = 2048;
+    uint64_t number_of_items = 1 << 20;
 
     uint64_t size_per_item = 288; // in bytes
     // uint64_t size_per_item = 1 << 10; // 1 KB.
     // uint64_t size_per_item = 10 << 10; // 10 KB.
 
     uint32_t N = 2048;
-    uint32_t logt = 15;
-    uint32_t d = 1;
+    // Recommended values: (logt, d) = (12, 2) or (8, 1). 
+    uint32_t logt = 12; 
+    uint32_t d = 2;
 
     EncryptionParameters params(scheme_type::BFV);
     PirParams pir_params;
@@ -57,7 +58,7 @@ int main(int argc, char *argv[]) {
     GaloisKeys galois_keys = client.generate_galois_keys();
 
     // Set galois key
-    cout << "Setting Galois keys" << endl;
+    cout << "Main: Setting Galois keys...";
     server.set_galois_key(0, galois_keys);
 
 
@@ -79,15 +80,18 @@ int main(int argc, char *argv[]) {
 
     // Choose an index of an element in the DB
     uint64_t ele_index = rd() % number_of_items; // element in DB at random position
+    //uint64_t ele_index = 35; 
+    cout << "Main: element index = " << ele_index << " from [0, " << number_of_items -1 << "]" << endl;
     uint64_t index = client.get_fv_index(ele_index, size_per_item);   // index of FV plaintext
     uint64_t offset = client.get_fv_offset(ele_index, size_per_item); // offset in FV plaintext
-
     // Measure query generation
+    cout << "Main: FV index = " << index << ", FV offset = " << offset << endl; 
+
     auto time_query_s = high_resolution_clock::now();
     PirQuery query = client.generate_query(index);
     auto time_query_e = high_resolution_clock::now();
     auto time_query_us = duration_cast<microseconds>(time_query_e - time_query_s).count();
-    cout << "query generated" << endl;
+    cout << "Main: query generated" << endl;
 
     // Measure query processing (including expansion)
     auto time_server_s = high_resolution_clock::now();
@@ -105,6 +109,18 @@ int main(int argc, char *argv[]) {
     // Convert to elements
     vector<uint8_t> elems(N * logt / 8);
     coeffs_to_bytes(logt, result, elems.data(), (N * logt) / 8);
+    // cout << "printing the bytes...of the supposed item: "; 
+    // for (int i = 0; i < size_per_item; i++){
+    //     cout << (int) elems[offset*size_per_item + i] << ", "; 
+    // }
+    // cout << endl; 
+
+    // // cout << "offset = " << offset << endl; 
+
+    // cout << "printing the bytes of real item: "; 
+    // for (int i = 0; i < size_per_item; i++){
+    //     cout << (int) check_db.get()[ele_index *size_per_item + i] << ", "; 
+    // }
 
     // Check that we retrieved the correct element
     for (uint32_t i = 0; i < size_per_item; i++) {
@@ -118,7 +134,7 @@ int main(int argc, char *argv[]) {
 
     // Output results
     cout << "PIRServer pre-processing time: " << time_pre_us / 1000 << " ms" << endl;
-    cout << "PIRServer query processing generation time: " << time_server_us / 1000 << " ms"
+    cout << "PIRServer reply generation time: " << time_server_us / 1000 << " ms"
          << endl;
     cout << "PIRClient query generation time: " << time_query_us / 1000 << " ms" << endl;
     cout << "PIRClient answer decode time: " << time_decode_us / 1000 << " ms" << endl;

+ 4 - 3
pir.cpp

@@ -65,6 +65,7 @@ void gen_params(uint64_t ele_num, uint64_t ele_size, uint32_t N, uint32_t logt,
     uint32_t expansion_ratio = 0;
     for (uint32_t i = 0; i < params.coeff_modulus().size(); ++i) {
         double logqi = log2(params.coeff_modulus()[i].value());
+        cout << "PIR: logqi = " << logqi << endl; 
         expansion_ratio += ceil(logqi / logt);
     }
 
@@ -72,7 +73,7 @@ void gen_params(uint64_t ele_num, uint64_t ele_size, uint32_t N, uint32_t logt,
     pir_params.dbc = 6;
     pir_params.n = plaintext_num;
     pir_params.nvec = nvec;
-    pir_params.expansion_ratio = expansion_ratio << 1;
+    pir_params.expansion_ratio = expansion_ratio << 1; // because one ciphertext = two polys
 }
 
 void update_params(uint64_t ele_num, uint64_t ele_size, uint32_t d, 
@@ -142,8 +143,8 @@ uint64_t coefficients_per_element(uint32_t logtp, uint64_t ele_size) {
 }
 
 // Number of database elements that can fit in a single FV plaintext
-uint64_t elements_per_ptxt(uint32_t logtp, uint64_t N, uint64_t ele_size) {
-    uint64_t coeff_per_ele = coefficients_per_element(logtp, ele_size);
+uint64_t elements_per_ptxt(uint32_t logt, uint64_t N, uint64_t ele_size) {
+    uint64_t coeff_per_ele = coefficients_per_element(logt, ele_size);
     uint64_t ele_per_ptxt = N / coeff_per_ele;
     assert(ele_per_ptxt > 0);
     return ele_per_ptxt;

+ 3 - 3
pir.hpp

@@ -10,13 +10,13 @@
 #define CIPHER_SIZE 32841
 
 typedef std::vector<seal::Plaintext> Database;
-typedef std::vector<seal::Ciphertext> PirQuery;
+typedef std::vector< std::vector< seal::Ciphertext >> PirQuery;
 typedef std::vector<seal::Ciphertext> PirReply;
 
 struct PirParams {
     std::uint64_t n;                 // number of plaintexts in database
-    std::uint32_t d;                 // number of dimensions for the database (usually 2)
-    std::uint32_t expansion_ratio;   // ratio of plaintext to ciphertext
+    std::uint32_t d;                 // number of dimensions for the database (1 or 2)
+    std::uint32_t expansion_ratio;   // ratio of ciphertext to plaintext
     std::uint32_t dbc;               // decomposition bit count (used by relinearization)
     std::vector<std::uint64_t> nvec; // size of each of the d dimensions
 };

+ 115 - 44
pir_client.cpp

@@ -21,25 +21,32 @@ PIRClient::PIRClient(const EncryptionParameters &params,
     evaluator_ = make_unique<Evaluator>(newcontext_);
 
     uint64_t t = params_.plain_modulus().value(); 
-    // 
-    int logt = floor(log2(params_.plain_modulus().value())); 
-    for(int i = 0; i < pir_params_.nvec.size(); i++){
-        uint64_t inverse_scale; 
-        int logm = ceil(log2(pir_params_.nvec[i]));  
 
-        int quo = logm / logt; 
-        int mod = logm % logt; 
-        inverse_scale = pow(2, logt - mod); 
-        if ((quo +1) %2 != 0){
-            inverse_scale =  params_.plain_modulus().value() - pow(2, logt - mod); 
-        }
-        inverse_scales_.push_back(inverse_scale); 
-        if ( (inverse_scale << logm)  % t != 1){
-            throw logic_error("something wrong"); 
-        }
-        cout << "logm, inverse scale, t = " << logm << ", " << inverse_scale << ", " << t << endl; 
-    }
+    uint64_t N = params_.poly_modulus_degree(); 
+
 
+    // 
+    // int logt = floor(log2(params_.plain_modulus().value())); 
+
+    // for(int i = 0; i < pir_params_.nvec.size(); i++){
+    //     uint64_t inverse_scale; 
+    //     // 
+    //     // If the number of items are less than N, then 
+    //     // we use logm.
+    //     int logm = ceil(log2(min(N, pir_params_.nvec[i])));  // if nvec > n what do we do?
+
+    //     int quo = logm / logt; 
+    //     int mod = logm % logt; 
+    //     inverse_scale = pow(2, logt - mod); 
+    //     if ((quo +1) %2 != 0){
+    //         inverse_scale =  params_.plain_modulus().value() - pow(2, logt - mod); 
+    //     }
+    //     inverse_scales_.push_back(inverse_scale); 
+    //     if ( (inverse_scale << logm)  % t != 1){
+    //         throw logic_error("something wrong"); 
+    //     }
+    //     cout << "logm, inverse scale, t = " << logm << ", " << inverse_scale << ", " << t << endl; 
+    // }
 }
 
 // void PIRClient::update_parameters(const EncryptionParameters &expanded_params,
@@ -63,17 +70,36 @@ PIRClient::PIRClient(const EncryptionParameters &params,
 
 PirQuery PIRClient::generate_query(uint64_t desiredIndex) {
 
-    vector<uint64_t> indices = compute_indices(desiredIndex, pir_params_.nvec);
-    vector<Ciphertext> result;
+    indices_ = compute_indices(desiredIndex, pir_params_.nvec);
+
+    compute_inverse_scales(); 
+
+    vector<vector<Ciphertext> > result(pir_params_.d);
+    int N = params_.poly_modulus_degree(); 
 
     Plaintext pt(params_.poly_modulus_degree());
-    for (uint32_t i = 0; i < indices.size(); i++) {
-        pt.set_zero();
-        pt[indices[i]] = 1;
-        Ciphertext dest;
-        encryptor_->encrypt(pt, dest);
-        dest.parms_id() = params_.parms_id();
-        result.push_back(dest);
+    for (uint32_t i = 0; i < indices_.size(); i++) {
+        uint32_t num_ptxts = ceil( (pir_params_.nvec[i] +0.0) / N);
+        // initialize result. 
+        cout << "Client: index " << i + 1  <<  "/ " <<  indices_.size() << " = " << indices_[i] << endl; 
+        cout << "Client: number of ctxts needed for query = " << num_ptxts << endl;
+        for (uint32_t j =0; j < num_ptxts; j++){
+            pt.set_zero();
+            if (indices_[i] > N*(j+1) || indices_[i] < N*j){
+                cout << "Client: coming here: so just encrypt zero." << endl; 
+                // just encrypt zero
+            } else{
+                cout << "Client: encrypting a real thing " << endl; 
+                uint64_t real_index = indices_[i] - N*j; 
+                pt[real_index] = 1;
+            }
+            Ciphertext dest;
+            encryptor_->encrypt(pt, dest);
+            dest.parms_id() = params_.parms_id();
+            result[i].push_back(dest);
+        }   
+        
+
     }
 
     return result;
@@ -81,17 +107,17 @@ PirQuery PIRClient::generate_query(uint64_t desiredIndex) {
 
 uint64_t PIRClient::get_fv_index(uint64_t element_idx, uint64_t ele_size) {
     auto N = params_.poly_modulus_degree();
-    auto logtp = ceil(log2(params_.plain_modulus().value()));
+    auto logt = floor(log2(params_.plain_modulus().value()));
 
-    auto ele_per_ptxt = elements_per_ptxt(logtp, N, ele_size);
+    auto ele_per_ptxt = elements_per_ptxt(logt, N, ele_size);
     return static_cast<uint64_t>(element_idx / ele_per_ptxt);
 }
 
 uint64_t PIRClient::get_fv_offset(uint64_t element_idx, uint64_t ele_size) {
     uint32_t N = params_.poly_modulus_degree();
-    uint32_t logtp = ceil(log2(params_.plain_modulus().value()));
+    uint32_t logt = floor(log2(params_.plain_modulus().value()));
 
-    uint64_t ele_per_ptxt = elements_per_ptxt(logtp, N, ele_size);
+    uint64_t ele_per_ptxt = elements_per_ptxt(logt, N, ele_size);
     return element_idx % ele_per_ptxt;
 }
 
@@ -104,20 +130,20 @@ Plaintext PIRClient::decode_reply(PirReply reply) {
     uint64_t t = params_.plain_modulus().value();
 
     for (uint32_t i = 0; i < recursion_level; i++) {
-
+        cout << "Client: " << i + 1 << "/ " << recursion_level << "-th decryption layer started." << endl; 
         vector<Ciphertext> newtemp;
         vector<Plaintext> tempplain;
 
         for (uint32_t j = 0; j < temp.size(); j++) {
             Plaintext ptxt;
             decryptor_->decrypt(temp[j], ptxt);
-            cout << " reply noise budget = " << decryptor_->invariant_noise_budget(temp[j]) << endl; 
+            cout << "Client: reply noise budget = " << decryptor_->invariant_noise_budget(temp[j]) << endl; 
             // multiply by inverse_scale for every coefficient of ptxt
             for(int h = 0; h < ptxt.coeff_count(); h++){
                 ptxt[h] *= inverse_scales_[i]; 
                 ptxt[h] %= t; 
             }
-            cout << "decoded (and scaled) plaintext = " << ptxt.to_string() << endl;
+            //cout << "decoded (and scaled) plaintext = " << ptxt.to_string() << endl;
             tempplain.push_back(ptxt);
 
 #ifdef DEBUG
@@ -129,9 +155,11 @@ Plaintext PIRClient::decode_reply(PirReply reply) {
                 // Combine into one ciphertext.
                 Ciphertext combined = compose_to_ciphertext(tempplain);
                 newtemp.push_back(combined);
+                // cout << "Client: const term of ciphertext = " << combined[0] << endl; 
             }
         }
-
+        cout << "Client: done." << endl; 
+        cout << endl; 
         if (i == recursion_level - 1) {
             assert(temp.size() == 1);
             return tempplain[0];
@@ -153,11 +181,11 @@ GaloisKeys PIRClient::generate_galois_keys() {
     int N = params_.poly_modulus_degree();
     int logN = get_power_of_two(N);
 
-    cout << "printing galois elements...";
+    //cout << "printing galois elements...";
     for (int i = 0; i < logN; i++) {
         galois_elts.push_back((N + exponentiate_uint64(2, i)) / exponentiate_uint64(2, i));
 //#ifdef DEBUG
-        cout << galois_elts.back() << ", ";
+        // cout << galois_elts.back() << ", ";
 //#endif
     }
 
@@ -169,6 +197,7 @@ Ciphertext PIRClient::compose_to_ciphertext(vector<Plaintext> plains) {
     auto coeff_count = params_.poly_modulus_degree();
     auto coeff_mod_count = params_.coeff_modulus().size();
     uint64_t plainMod = params_.plain_modulus().value();
+    int logt = floor(log2(plainMod)); 
 
     Ciphertext result(newcontext_);
     result.resize(encrypted_count);
@@ -183,28 +212,25 @@ Ciphertext PIRClient::compose_to_ciphertext(vector<Plaintext> plains) {
             // which will be copied into the array to populate it at the current
             // index.
             double logqj = log2(params_.coeff_modulus()[j].value());
-            int expansion_ratio = ceil(logqj / log2(plainMod));
-
-            // cout << "expansion ratio = " << expansion_ratio << endl;
+            int expansion_ratio = ceil(logqj / logt);
             uint64_t cur = 1;
+            // cout << "Client: expansion_ratio = " << expansion_ratio << endl; 
 
             for (int k = 0; k < expansion_ratio; k++) {
-
                 // Compose here
                 const uint64_t *plain_coeff =
                     plains[k + j * (expansion_ratio) + i * (coeff_mod_count * expansion_ratio)]
                         .data();
 
-                for (int m = 0; m < coeff_count - 1; m++) {
+                for (int m = 0; m < coeff_count; m++) {
                     if (k == 0) {
                         *(encrypted_pointer + m + j * coeff_count) = *(plain_coeff + m) * cur;
                     } else {
                         *(encrypted_pointer + m + j * coeff_count) += *(plain_coeff + m) * cur;
                     }
                 }
-
-                *(encrypted_pointer + coeff_count - 1 + j * coeff_count) = 0;
-                cur *= plainMod;
+                // *(encrypted_pointer + coeff_count - 1 + j * coeff_count) = 0;
+                cur <<= logt;
             }
 
             // XXX: Reduction modulo qj. This is needed?
@@ -220,3 +246,48 @@ Ciphertext PIRClient::compose_to_ciphertext(vector<Plaintext> plains) {
     result.parms_id() = params_.parms_id();
     return result;
 }
+
+
+void PIRClient::compute_inverse_scales(){
+    if (indices_.size() != pir_params_.nvec.size()){
+        throw invalid_argument("size mismatch"); 
+    }
+    int logt = floor(log2(params_.plain_modulus().value())); 
+
+    uint64_t N = params_.poly_modulus_degree(); 
+    uint64_t t = params_.plain_modulus().value();
+    int logN = log2(N);
+    int logm = logN;
+
+    inverse_scales_.clear(); 
+
+    for(int i = 0; i < pir_params_.nvec.size(); i++){
+        uint64_t index_modN = indices_[i] % N; 
+        uint64_t numCtxt = ceil ( (pir_params_.nvec[i] + 0.0) / N);  // number of query ciphertexts. 
+        uint64_t batchId = indices_[i] / N;  
+        if (batchId == numCtxt - 1) {
+            cout << "Client: adjusting the logm value..." << endl; 
+            logm = ceil(log2((pir_params_.nvec[i] % N)));
+        }
+
+        uint64_t inverse_scale; 
+ 
+
+        int quo = logm / logt; 
+        int mod = logm % logt; 
+        inverse_scale = pow(2, logt - mod); 
+        if ((quo +1) %2 != 0){
+            inverse_scale =  params_.plain_modulus().value() - pow(2, logt - mod); 
+        }
+        inverse_scales_.push_back(inverse_scale); 
+        if ( (inverse_scale << logm)  % t != 1){
+            throw logic_error("something wrong"); 
+        }
+        cout << "Client: logm, inverse scale, t = " << logm << ", " << inverse_scale << ", " << t << endl; 
+    }
+
+
+
+
+}
+

+ 3 - 0
pir_client.hpp

@@ -23,6 +23,8 @@ class PIRClient {
     uint64_t get_fv_index(uint64_t element_idx, uint64_t ele_size);
     uint64_t get_fv_offset(uint64_t element_idx, uint64_t ele_size);
 
+    void compute_inverse_scales(); 
+
   private:
     // Should we store a decryptor and an encryptor?
     seal::EncryptionParameters params_;
@@ -34,6 +36,7 @@ class PIRClient {
     std::unique_ptr<seal::KeyGenerator> keygen_;
     std::shared_ptr<seal::SEALContext> newcontext_;
 
+    vector<uint64_t> indices_; // the indices for retrieval. 
     vector<uint64_t> inverse_scales_; 
 
     seal::Ciphertext compose_to_ciphertext(std::vector<seal::Plaintext> plains);

+ 82 - 36
pir_server.cpp

@@ -83,6 +83,12 @@ void PIRServer::set_database(const std::unique_ptr<const std::uint8_t[]> &bytes,
     uint64_t coeff_per_ptxt = ele_per_ptxt * coefficients_per_element(logt, ele_size);
     assert(coeff_per_ptxt <= N);
 
+    cout << "Server: total number of FV plaintext = " << total << endl;
+
+    cout << "Server: elements packed into each plaintext " << ele_per_ptxt << endl; 
+
+
+
     uint32_t offset = 0;
 
     for (uint64_t i = 0; i < total; i++) {
@@ -96,7 +102,6 @@ void PIRServer::set_database(const std::unique_ptr<const std::uint8_t[]> &bytes,
         } else {
             process_bytes = bytes_per_ptxt;
         }
-
         // Get the coefficients of the elements that will be packed in plaintext i
         vector<uint64_t> coefficients = bytes_to_coeffs(logt, bytes.get() + offset, process_bytes);
         offset += process_bytes;
@@ -112,7 +117,7 @@ void PIRServer::set_database(const std::unique_ptr<const std::uint8_t[]> &bytes,
 
         Plaintext plain;
         vector_to_plaintext(coefficients, plain);
-        cout << i << "-th encoded plaintext = " << plain.to_string() << endl; 
+        // cout << i << "-th encoded plaintext = " << plain.to_string() << endl; 
         result->push_back(move(plain));
     }
 
@@ -159,10 +164,39 @@ PirReply PIRServer::generate_reply(PirQuery query, uint32_t client_id, PIRClient
 
     auto pool = MemoryManager::GetPool();
 
+
+    int N = params_.poly_modulus_degree();
+
+    int logt = floor(log2(params_.plain_modulus().value()));
+
+    cout << "expansion ratio = " << pir_params_.expansion_ratio << endl; 
     for (uint32_t i = 0; i < nvec.size(); i++) {
+        cout << "Server: " << i + 1 << "-th recursion level started " << endl; 
+
+
+        vector<Ciphertext> expanded_query; 
+
         uint64_t n_i = nvec[i];
-        vector<Ciphertext> expanded_query = expand_query(query[i], n_i, client_id, client);
-        cout << "Checking expanded query "; 
+        cout << "Server: n_i = " << n_i << endl; 
+        cout << "Server: expanding " << query[i].size() << "query ctxts" << endl;
+        for (uint32_t j = 0; j < query[i].size(); j++){
+            uint64_t total = N; 
+            if (j == query[i].size() - 1){
+                total = n_i % N; 
+            }
+            cout << "-- expanding one query ctxt into " << total  << " ctxts "<< endl;
+            vector<Ciphertext> expanded_query_part = expand_query(query[i][j], total, client_id, client);
+            expanded_query.insert(expanded_query.end(), std::make_move_iterator(expanded_query_part.begin()), 
+                    std::make_move_iterator(expanded_query_part.end()));
+            expanded_query_part.clear(); 
+        }
+        cout << "Server: expansion done " << endl; 
+        if (expanded_query.size() != n_i) {
+            cout << " size mismatch!!! " << expanded_query.size() << ", " << n_i << endl; 
+        }    
+
+        /*
+        cout << "Checking expanded query " << endl; 
         Plaintext tempPt; 
         for (int h = 0 ; h < expanded_query.size(); h++){
             cout << "noise budget = " << client.decryptor_->invariant_noise_budget(expanded_query[h]) << ", "; 
@@ -170,6 +204,7 @@ PirReply PIRServer::generate_reply(PirQuery query, uint32_t client_id, PIRClient
             cout << tempPt.to_string()  << endl; 
         }
         cout << endl;
+        */
 
         // Transform expanded query to NTT, and ...
         for (uint32_t jj = 0; jj < expanded_query.size(); jj++) {
@@ -183,26 +218,40 @@ PirReply PIRServer::generate_reply(PirQuery query, uint32_t client_id, PIRClient
             }
         }
 
+        for (uint64_t k = 0; k < product; k++) {
+            if ((*cur)[k].is_zero()){
+                cout << k + 1 << "/ " << product <<  "-th ptxt = 0 " << endl; 
+            }
+        }
+
         product /= n_i;
 
-        vector<Ciphertext> intermediate(product);
+        vector<Ciphertext> intermediateCtxts(product);
         Ciphertext temp;
 
+
+
+
         for (uint64_t k = 0; k < product; k++) {
-            evaluator_->multiply_plain(expanded_query[0], (*cur)[k], intermediate[k]);
+
+            evaluator_->multiply_plain(expanded_query[0], (*cur)[k], intermediateCtxts[k]);
 
             for (uint64_t j = 1; j < n_i; j++) {
                 evaluator_->multiply_plain(expanded_query[j], (*cur)[k + j * product], temp);
-                evaluator_->add_inplace(intermediate[k], temp); // Adds to first component.
+                evaluator_->add_inplace(intermediateCtxts[k], temp); // Adds to first component.
             }
         }
 
-        for (uint32_t jj = 0; jj < intermediate.size(); jj++) {
-            evaluator_->transform_from_ntt_inplace(intermediate[jj]);
+        for (uint32_t jj = 0; jj < intermediateCtxts.size(); jj++) {
+            evaluator_->transform_from_ntt_inplace(intermediateCtxts[jj]);
+            // print intermediate ctxts? 
+            //cout << "const term of ctxt " << jj << " = " << intermediateCtxts[jj][0] << endl; 
         }
 
+
+
         if (i == nvec.size() - 1) {
-            return intermediate;
+            return intermediateCtxts;
         } else {
             intermediate_plain.clear();
             intermediate_plain.reserve(pir_params_.expansion_ratio * product);
@@ -214,19 +263,20 @@ PirReply PIRServer::generate_reply(PirQuery query, uint32_t client_id, PIRClient
 
             for (uint64_t rr = 0; rr < product; rr++) {
 
-                decompose_to_plaintexts_ptr(intermediate[rr],
-                    tempplain.get() + rr * pir_params_.expansion_ratio);
+                decompose_to_plaintexts_ptr(intermediateCtxts[rr],
+                    tempplain.get() + rr * pir_params_.expansion_ratio, logt);
 
                 for (uint32_t jj = 0; jj < pir_params_.expansion_ratio; jj++) {
                     auto offset = rr * pir_params_.expansion_ratio + jj;
                     intermediate_plain.emplace_back(tempplain[offset]);
                 }
             }
-
             product *= pir_params_.expansion_ratio; // multiply by expansion rate.
         }
+        cout << "Server: " << i + 1 << "-th recursion level finished " << endl; 
+        cout << endl;
     }
-
+    cout << "reply generated!  " << endl;
     // This should never get here
     assert(0);
     vector<Ciphertext> fail(1);
@@ -252,13 +302,9 @@ inline vector<Ciphertext> PIRServer::expand_query(const Ciphertext &encrypted, u
     if (logm > ceil(log2(n))){
         throw logic_error("m > n is not allowed."); 
     }
-
-    cout << "galois elts at server: "; 
-    for (uint32_t i = 0; i < logm; i++) {
+    for (int i = 0; i < ceil(log2(n)); i++) {
         galois_elts.push_back((n + exponentiate_uint64(2, i)) / exponentiate_uint64(2, i));
-        cout << galois_elts.back() << ", "; 
     }
-    cout << endl;
 
     vector<Ciphertext> temp;
     temp.push_back(encrypted);
@@ -267,44 +313,45 @@ inline vector<Ciphertext> PIRServer::expand_query(const Ciphertext &encrypted, u
     Ciphertext tempctxt_shifted;
     Ciphertext tempctxt_rotatedshifted;
 
+
     for (uint32_t i = 0; i < logm - 1; i++) {
         vector<Ciphertext> newtemp(temp.size() << 1);
         // temp[a] = (j0 = a (mod 2**i) ? ) : Enc(x^{j0 - a}) else Enc(0).  With
         // some scaling....
         int index_raw = (n << 1) - (1 << i);
+        // TODO: galois elements. 
         int index = (index_raw * galois_elts[i]) % (n << 1);
-        cout << i << "-th expansion round, noise budget = " << endl; 
 
         for (uint32_t a = 0; a < temp.size(); a++) {
 
             evaluator_->apply_galois(temp[a], galois_elts[i], galkey, tempctxt_rotated);
 
-            cout << "rotate " << client.decryptor_->invariant_noise_budget(tempctxt_rotated) << ", "; 
+            //cout << "rotate " << client.decryptor_->invariant_noise_budget(tempctxt_rotated) << ", "; 
 
 
             evaluator_->add(temp[a], tempctxt_rotated, newtemp[a]);
             multiply_power_of_X(temp[a], tempctxt_shifted, index_raw);
 
-            cout << "mul by x^pow: " << client.decryptor_->invariant_noise_budget(tempctxt_shifted) << ", "; 
+            //cout << "mul by x^pow: " << client.decryptor_->invariant_noise_budget(tempctxt_shifted) << ", "; 
 
 
             multiply_power_of_X(tempctxt_rotated, tempctxt_rotatedshifted, index);
 
-            cout << "mul by x^pow: " << client.decryptor_->invariant_noise_budget(tempctxt_rotatedshifted) << ", "; 
+            // cout << "mul by x^pow: " << client.decryptor_->invariant_noise_budget(tempctxt_rotatedshifted) << ", "; 
 
 
             // Enc(2^i x^j) if j = 0 (mod 2**i).
             evaluator_->add(tempctxt_shifted, tempctxt_rotatedshifted, newtemp[a + temp.size()]);
         }
         temp = newtemp;
-
+        /*
         cout << "end: "; 
         for (int h = 0; h < temp.size();h++){
             cout << client.decryptor_->invariant_noise_budget(temp[h]) << ", "; 
         }
         cout << endl; 
+        */
     }
-
     // Last step of the loop
     vector<Ciphertext> newtemp(temp.size() << 1);
     int index_raw = (n << 1) - (1 << (logm - 1));
@@ -312,7 +359,7 @@ inline vector<Ciphertext> PIRServer::expand_query(const Ciphertext &encrypted, u
     for (uint32_t a = 0; a < temp.size(); a++) {
         if (a >= (m - (1 << (logm - 1)))) {                       // corner case.
             evaluator_->multiply_plain(temp[a], two, newtemp[a]); // plain multiplication by 2.
-            cout << client.decryptor_->invariant_noise_budget(newtemp[a]) << ", "; 
+            // cout << client.decryptor_->invariant_noise_budget(newtemp[a]) << ", "; 
         } else {
             evaluator_->apply_galois(temp[a], galois_elts[logm - 1], galkey, tempctxt_rotated);
             evaluator_->add(temp[a], tempctxt_rotated, newtemp[a]);
@@ -353,17 +400,16 @@ inline void PIRServer::multiply_power_of_X(const Ciphertext &encrypted, Cipherte
     }
 }
 
-inline void PIRServer::decompose_to_plaintexts_ptr(const Ciphertext &encrypted, Plaintext *plain_ptr) {
+inline void PIRServer::decompose_to_plaintexts_ptr(const Ciphertext &encrypted, Plaintext *plain_ptr, int logt) {
 
     vector<Plaintext> result;
     auto coeff_count = params_.poly_modulus_degree();
     auto coeff_mod_count = params_.coeff_modulus().size();
     auto encrypted_count = encrypted.size();
 
-    // Generate powers of t.
-    uint64_t plainModMinusOne = params_.plain_modulus().value() - 1;
-    int exp = ceil(log2(plainModMinusOne + 1));
+    uint64_t t1 = 1 << logt;  //  t1 <= t. 
 
+    uint64_t t1minusone =  t1 -1; 
     // A triple for loop. Going over polys, moduli, and decomposed index.
 
     for (int i = 0; i < encrypted_count; i++) {
@@ -373,19 +419,19 @@ inline void PIRServer::decompose_to_plaintexts_ptr(const Ciphertext &encrypted,
             // create a polynomial to store the current decomposition value
             // which will be copied into the array to populate it at the current
             // index.
-            int logqj = log2(params_.coeff_modulus()[j].value());
-            int expansion_ratio = ceil(logqj + exp - 1) / exp;
-
-            // cout << "expansion ratio = " << expansion_ratio << endl;
+            double logqj = log2(params_.coeff_modulus()[j].value());
+            //int expansion_ratio = ceil(logqj + exponent - 1) / exponent;
+            int expansion_ratio =  ceil(logqj / logt); 
+            // cout << "local expansion ratio = " << expansion_ratio << endl;
             uint64_t curexp = 0;
             for (int k = 0; k < expansion_ratio; k++) {
                 // Decompose here
                 for (int m = 0; m < coeff_count; m++) {
                     plain_ptr[i * coeff_mod_count * expansion_ratio
                         + j * expansion_ratio + k][m] =
-                        (*(encrypted_pointer + m + (j * coeff_count)) >> curexp) & plainModMinusOne;
+                        (*(encrypted_pointer + m + (j * coeff_count)) >> curexp) & t1minusone;
                 }
-                curexp += exp;
+                curexp += logt;
             }
         }
     }

+ 1 - 1
pir_server.hpp

@@ -36,7 +36,7 @@ class PIRServer {
     std::map<int, seal::GaloisKeys> galoisKeys_;
     std::unique_ptr<seal::Evaluator> evaluator_;
 
-    void decompose_to_plaintexts_ptr(const seal::Ciphertext &encrypted, seal::Plaintext *plain_ptr);
+    void decompose_to_plaintexts_ptr(const seal::Ciphertext &encrypted, seal::Plaintext *plain_ptr, int logt);
     std::vector<seal::Plaintext> decompose_to_plaintexts(const seal::Ciphertext &encrypted);
     void multiply_power_of_X(const seal::Ciphertext &encrypted, seal::Ciphertext &destination,
                              std::uint32_t index);