Browse Source

Now works with Seal 3.6

Andrew Beams 3 years ago
parent
commit
187bb2b134
9 changed files with 259 additions and 195 deletions
  1. 6 9
      CMakeLists.txt
  2. 107 0
      encoding_test.cpp
  3. 30 7
      main.cpp
  4. 31 94
      pir.cpp
  5. 1 10
      pir.hpp
  6. 36 58
      pir_client.cpp
  7. 3 1
      pir_client.hpp
  8. 40 14
      pir_server.cpp
  9. 5 2
      pir_server.hpp

+ 6 - 9
CMakeLists.txt

@@ -6,16 +6,13 @@ set(CMAKE_CXX_STANDARD_REQUIRED ON)
 project(SealPIR VERSION 2.1 LANGUAGES CXX)
 set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_SOURCE_DIR}/bin)
 
-add_executable(main 
-	main.cpp
+add_executable(main main.cpp pir.cpp pir_client.cpp pir_server.cpp
 )
 
-add_library(sealpir STATIC
-  pir.cpp
-  pir_client.cpp
-  pir_server.cpp
-)
+add_executable(encoding_test encoding_test.cpp pir.cpp)
+
+find_package(SEAL 3.6 REQUIRED)
 
-find_package(SEAL 3.2.0 EXACT REQUIRED)
+target_link_libraries(main SEAL::seal)
 
-target_link_libraries(main sealpir SEAL::seal)
+target_link_libraries(encoding_test SEAL::seal)

+ 107 - 0
encoding_test.cpp

@@ -0,0 +1,107 @@
+#include "pir.hpp"
+#include "pir_client.hpp"
+#include "pir_server.hpp"
+#include <seal/seal.h>
+#include <memory>
+#include <random>
+#include <cstdint>
+#include <cstddef>
+
+using namespace std;
+using namespace seal;
+
+int main(int argc, char *argv[]) {
+    uint64_t number_of_items = 1 << 12;
+    uint64_t size_per_item = 288; // in bytes
+    uint32_t N = 4096;
+
+    // Recommended values: (logt, d) = (12, 2) or (8, 1). 
+    uint32_t logt = 20; 
+    uint32_t d = 1;
+
+    EncryptionParameters params(scheme_type::bfv);
+    PirParams pir_params;
+
+    // Generates all parameters
+    cout << "Main: Generating all parameters" << endl;
+    gen_params(number_of_items, size_per_item, N, logt, d, params, pir_params);
+
+    logt = floor(log2(params.plain_modulus().value()));
+
+    cout << "Main: Initializing the database (this may take some time) ..." << endl;
+
+    // Create test database
+    auto db(make_unique<uint8_t[]>(number_of_items * size_per_item));
+
+    // Copy of the database. We use this at the end to make sure we retrieved
+    // the correct element.
+    auto db_copy(make_unique<uint8_t[]>(number_of_items * size_per_item));
+
+    random_device rd;
+    for (uint64_t i = 0; i < number_of_items; i++) {
+        for (uint64_t j = 0; j < size_per_item; j++) {
+            uint8_t val = rd() % 256;
+            db.get()[(i * size_per_item) + j] = val;
+            db_copy.get()[(i * size_per_item) + j] = val;
+        }
+    }
+
+    shared_ptr<SEALContext> context = make_shared<SEALContext>(params, true);
+    unique_ptr<KeyGenerator> keygen = make_unique<KeyGenerator>(*context);
+    
+    PublicKey public_key;
+    keygen->create_public_key(public_key);
+    
+    unique_ptr<Encryptor> encryptor = make_unique<Encryptor>(*context, public_key);
+
+    SecretKey secret_key = keygen->secret_key();
+    unique_ptr<Decryptor> decryptor = make_unique<Decryptor>(*context, secret_key);
+
+    unique_ptr<Evaluator> evaluator = make_unique<Evaluator>(*context);
+
+
+    uint64_t ele_per_ptxt = elements_per_ptxt(logt, N, size_per_item);
+    uint64_t bytes_per_ptxt = ele_per_ptxt * size_per_item;
+
+    uint64_t db_size = number_of_items * size_per_item;
+
+    uint64_t coeff_per_ptxt = ele_per_ptxt * coefficients_per_element(logt, size_per_item);
+    assert(coeff_per_ptxt <= N);
+
+
+    vector<uint64_t> coefficients = bytes_to_coeffs(logt, db.get(), size_per_item);
+    uint64_t used = coefficients.size();
+
+    assert(used <= coeff_per_ptxt);
+
+    // Pad the rest with 1s
+    for (uint64_t j = 0; j < (N - used); j++) {
+        coefficients.push_back(1);
+    }
+
+    Plaintext plain;
+    vector_to_plaintext(coefficients, plain);
+
+    //cout << "Plaintext: " << plain.to_string() << endl;
+
+    vector<uint8_t> elems(N * logt / 8);
+    coeffs_to_bytes(logt, plain, elems.data(), (N * logt) / 8);
+
+    bool failed = false;
+    // Check that we retrieved the correct element
+    for (uint32_t i = 0; i < size_per_item; i++) {
+        if (elems[i] != db_copy.get()[i]) {
+            cout << "Main: elems " << (int)elems[i] << ", db "
+                 << (int) db_copy.get()[i] << endl;
+            cout << "Main: PIR result wrong at " << i <<  endl;
+            failed = true;
+        }
+    }
+    if(failed){
+        return -1;
+    }
+    else{
+        cout << "succeeded" << endl;
+    }
+
+}

+ 30 - 7
main.cpp

@@ -16,13 +16,13 @@ int main(int argc, char *argv[]) {
 
     uint64_t number_of_items = 1 << 12;
     uint64_t size_per_item = 288; // in bytes
-    uint32_t N = 2048;
+    uint32_t N = 4096;
 
     // Recommended values: (logt, d) = (12, 2) or (8, 1). 
-    uint32_t logt = 12; 
+    uint32_t logt = 20; 
     uint32_t d = 2;
 
-    EncryptionParameters params(scheme_type::BFV);
+    EncryptionParameters params(scheme_type::bfv);
     PirParams pir_params;
 
     // Generates all parameters
@@ -41,7 +41,7 @@ int main(int argc, char *argv[]) {
     random_device rd;
     for (uint64_t i = 0; i < number_of_items; i++) {
         for (uint64_t j = 0; j < size_per_item; j++) {
-            auto val = rd() % 256;
+            uint8_t val = rd() % 256;
             db.get()[(i * size_per_item) + j] = val;
             db_copy.get()[(i * size_per_item) + j] = val;
         }
@@ -87,7 +87,7 @@ int main(int argc, char *argv[]) {
 
     // Measure query processing (including expansion)
     auto time_server_s = high_resolution_clock::now();
-    PirReply reply = server.generate_reply(query, 0);
+    PirReply reply = server.generate_reply(query, 0, client);
     auto time_server_e = high_resolution_clock::now();
     auto time_server_us = duration_cast<microseconds>(time_server_e - time_server_s).count();
 
@@ -97,19 +97,42 @@ int main(int argc, char *argv[]) {
     auto time_decode_e = chrono::high_resolution_clock::now();
     auto time_decode_us = duration_cast<microseconds>(time_decode_e - time_decode_s).count();
 
+    Ciphertext one_ct = client.get_encrypted_one();
+    Ciphertext reply2 = server.generate_public_reply(one_ct, index);
+    Plaintext result2 = client.decrypt(reply2);
+
+    logt = floor(log2(params.plain_modulus().value()));
+
     // Convert from FV plaintext (polynomial) to database element at the client
     vector<uint8_t> elems(N * logt / 8);
     coeffs_to_bytes(logt, result, elems.data(), (N * logt) / 8);
 
+    vector<uint8_t> elems2(N * logt / 8);
+    coeffs_to_bytes(logt, result2, elems2.data(), (N * logt) / 8);
+
+    // Check that we retrieved the correct element
+    for (uint32_t i = 0; i < size_per_item; i++) {
+        if (elems[(offset * size_per_item) + i] != elems2[(offset * size_per_item) + i]) {
+            cout << "Main: elems " << (int)elems[(offset * size_per_item) + i] << ", elems2 "
+                 << (int)elems[(offset * size_per_item) + i] << endl;
+            cout << "Main: PIR results inconsistent at" << i << endl;
+            return -1;
+        }
+    }
+
+    bool failed = false;
     // Check that we retrieved the correct element
     for (uint32_t i = 0; i < size_per_item; i++) {
         if (elems[(offset * size_per_item) + i] != db_copy.get()[(ele_index * size_per_item) + i]) {
             cout << "Main: elems " << (int)elems[(offset * size_per_item) + i] << ", db "
                  << (int) db_copy.get()[(ele_index * size_per_item) + i] << endl;
-            cout << "Main: PIR result wrong!" << endl;
-            return -1;
+            cout << "Main: PIR result wrong at " << i <<  endl;
+            failed = true;
         }
     }
+    if(failed){
+        return -1;
+    }
 
     // Output results
     cout << "Main: PIR result correct!" << endl;

+ 31 - 94
pir.cpp

@@ -4,31 +4,26 @@ using namespace std;
 using namespace seal;
 using namespace seal::util;
 
-vector<uint64_t> get_dimensions(uint64_t plaintext_num, uint32_t d) {
+std::vector<std::uint64_t> get_dimensions(std::uint64_t num_of_plaintexts, std::uint32_t d) {
 
     assert(d > 0);
-    assert(plaintext_num > 0);
+    assert(num_of_plaintexts > 0);
 
-    vector<uint64_t> dimensions(d);
+    std::uint64_t root = max(static_cast<uint32_t>(2),static_cast<uint32_t>(floor(pow(num_of_plaintexts, 1.0/d))));
 
-    for (uint32_t i = 0; i < d; i++) {
-        dimensions[i] = std::max((uint32_t) 2, (uint32_t) floor(pow(plaintext_num, 1.0/d)));
-    }
-
-    uint32_t product = 1;
-    uint32_t j = 0;
+    std::vector<std::uint64_t> dimensions(d, root);
 
-    // if plaintext_num is not a d-power
-    if ((double) dimensions[0] != pow(plaintext_num, 1.0 / d)) {
-        while  (product < plaintext_num && j < d) {
-            product = 1;
-            dimensions[j++]++;
-            for (uint32_t i = 0; i < d; i++) {
-                product *= dimensions[i];
-            }
-        }
+    for(int i = 0; i < d; i++){
+        if(accumulate(dimensions.begin(), dimensions.end(), 1, multiplies<uint64_t>()) > num_of_plaintexts){
+            break;
+        } 
+        dimensions[i] += 1;
     }
 
+    std::uint32_t prod = accumulate(dimensions.begin(), dimensions.end(), 1, multiplies<uint64_t>());
+    cout << "Total:" << num_of_plaintexts << endl << "Prod: "
+     << prod << endl;
+    assert(prod > num_of_plaintexts);
     return dimensions;
 }
 
@@ -40,25 +35,22 @@ void gen_params(uint64_t ele_num, uint64_t ele_size, uint32_t N, uint32_t logt,
 
     // plain modulus = a power of 2 plus 1
     uint64_t plain_mod = (static_cast<uint64_t>(1) << logt) + 1;
-    uint64_t plaintext_num = plaintexts_per_db(logt, N, ele_num, ele_size);
 
 #ifdef DEBUG
     cout << "log(plain mod) before expand = " << logt << endl;
     cout << "number of FV plaintexts = " << plaintext_num << endl;
 #endif
 
-    vector<SmallModulus> coeff_mod_array;
-    uint32_t logq = 0;
+    params.set_poly_modulus_degree(N);
+    params.set_coeff_modulus(CoeffModulus::BFVDefault(N));
+    params.set_plain_modulus(PlainModulus::Batching(N, logt));
 
-    for (uint32_t i = 0; i < 1; i++) {
-        coeff_mod_array.emplace_back(SmallModulus());
-        coeff_mod_array[i] = DefaultParams::small_mods_60bit(i);
-        logq += coeff_mod_array[i].bit_count();
-    }
+    logt = floor(log2(params.plain_modulus().value()));
 
-    params.set_poly_modulus_degree(N);
-    params.set_coeff_modulus(coeff_mod_array);
-    params.set_plain_modulus(plain_mod);
+    cout << "logt: " << logt << endl << "N: " << N << endl <<
+    "ele_num: " << ele_num << endl << "ele_size: " << ele_size << endl;
+
+    uint64_t plaintext_num = plaintexts_per_db(logt, N, ele_num, ele_size);
 
     vector<uint64_t> nvec = get_dimensions(plaintext_num, d);
 
@@ -179,7 +171,7 @@ void coeffs_to_bytes(uint32_t limit, const Plaintext &coeffs, uint8_t *output, u
 void vector_to_plaintext(const vector<uint64_t> &coeffs, Plaintext &plain) {
     uint32_t coeff_count = coeffs.size();
     plain.resize(coeff_count);
-    util::set_uint_uint(coeffs.data(), coeff_count, plain.data());
+    util::set_uint(coeffs.data(), coeff_count, plain.data());
 }
 
 vector<uint64_t> compute_indices(uint64_t desiredIndex, vector<uint64_t> Nvec) {
@@ -205,68 +197,13 @@ vector<uint64_t> compute_indices(uint64_t desiredIndex, vector<uint64_t> Nvec) {
     return result;
 }
 
-inline Ciphertext deserialize_ciphertext(string s) {
-    Ciphertext c;
-    std::istringstream input(s);
-    c.unsafe_load(input);
-    return c;
-}
-
-
-vector<Ciphertext> deserialize_ciphertexts(uint32_t count, string s, uint32_t len_ciphertext) {
-    vector<Ciphertext> c;
-    for (uint32_t i = 0; i < count; i++) {
-        c.push_back(deserialize_ciphertext(s.substr(i * len_ciphertext, len_ciphertext)));
-    }
-    return c;
-}
-
-PirQuery deserialize_query(uint32_t d, uint32_t count, string s, uint32_t len_ciphertext) {
-    vector<vector<Ciphertext>> c;
-    for (uint32_t i = 0; i < d; i++) {
-        c.push_back(deserialize_ciphertexts(
-              count, 
-              s.substr(i * count * len_ciphertext, count * len_ciphertext),
-              len_ciphertext)
-        );
-    }
-    return c;
-}
-
-
-inline string serialize_ciphertext(Ciphertext c) {
-    std::ostringstream output;
-    c.save(output);
-    return output.str();
-}
-
-string serialize_ciphertexts(vector<Ciphertext> c) {
-    string s;
-    for (uint32_t i = 0; i < c.size(); i++) {
-        s.append(serialize_ciphertext(c[i]));
-    }
-    return s;
-}
-
-string serialize_query(vector<vector<Ciphertext>> c) {
-    string s;
-    for (uint32_t i = 0; i < c.size(); i++) {
-      for (uint32_t j = 0; j < c[i].size(); j++) {
-        s.append(serialize_ciphertext(c[i][j]));
-      }
-    }
-    return s;
-}
-
-string serialize_galoiskeys(GaloisKeys g) {
-    std::ostringstream output;
-    g.save(output);
-    return output.str();
-}
-
-GaloisKeys *deserialize_galoiskeys(string s) {
-    GaloisKeys *g = new GaloisKeys();
-    std::istringstream input(s);
-    g->unsafe_load(input);
-    return g;
+uint64_t InvertMod(uint64_t m, const seal::Modulus& mod) {
+  if (mod.uint64_count() > 1) {
+    cout << "Mod too big to invert";
+  }
+  uint64_t inverse = 0;
+  if (!seal::util::try_invert_uint_mod(m, mod.value(), inverse)) {
+    cout << "Could not invert value";
+  }
+  return inverse;
 }

+ 1 - 10
pir.hpp

@@ -62,13 +62,4 @@ void vector_to_plaintext(const std::vector<std::uint64_t> &coeffs, seal::Plainte
 std::vector<std::uint64_t> compute_indices(std::uint64_t desiredIndex,
                                            std::vector<std::uint64_t> nvec);
 
-// Serialize and deserialize ciphertexts to send them over the network
-PirQuery deserialize_query(std::uint32_t d, uint32_t count, std::string s, std::uint32_t len_ciphertext);
-std::vector<seal::Ciphertext> deserialize_ciphertexts(std::uint32_t count, std::string s,
-                                                      std::uint32_t len_ciphertext);
-std::string serialize_ciphertexts(std::vector<seal::Ciphertext> c);
-std::string serialize_query(std::vector<std::vector<seal::Ciphertext>> c);
-
-// Serialize and deserialize galois keys to send them over the network
-std::string serialize_galoiskeys(seal::GaloisKeys g);
-seal::GaloisKeys *deserialize_galoiskeys(std::string s);
+uint64_t InvertMod(uint64_t m, const seal::Modulus& mod);

+ 36 - 58
pir_client.cpp

@@ -8,17 +8,20 @@ PIRClient::PIRClient(const EncryptionParameters &params,
                      const PirParams &pir_parms) :
     params_(params){
 
-    newcontext_ = SEALContext::Create(params_);
+    newcontext_ = make_shared<SEALContext>(params, true);
 
     pir_params_ = pir_parms;
 
-    keygen_ = make_unique<KeyGenerator>(newcontext_);
-    encryptor_ = make_unique<Encryptor>(newcontext_, keygen_->public_key());
+    keygen_ = make_unique<KeyGenerator>(*newcontext_);
+    
+    PublicKey public_key;
+    keygen_->create_public_key(public_key);
+    encryptor_ = make_unique<Encryptor>(*newcontext_, public_key);
 
     SecretKey secret_key = keygen_->secret_key();
+    decryptor_ = make_unique<Decryptor>(*newcontext_, secret_key);
 
-    decryptor_ = make_unique<Decryptor>(newcontext_, secret_key);
-    evaluator_ = make_unique<Evaluator>(newcontext_);
+    evaluator_ = make_unique<Evaluator>(*newcontext_);
 }
 
 
@@ -26,8 +29,6 @@ PirQuery PIRClient::generate_query(uint64_t desiredIndex) {
 
     indices_ = compute_indices(desiredIndex, pir_params_.nvec);
 
-    compute_inverse_scales(); 
-
     vector<vector<Ciphertext> > result(pir_params_.d);
     int N = params_.poly_modulus_degree(); 
 
@@ -37,6 +38,7 @@ PirQuery PIRClient::generate_query(uint64_t desiredIndex) {
         // initialize result. 
         cout << "Client: index " << i + 1  <<  "/ " <<  indices_.size() << " = " << indices_[i] << endl; 
         cout << "Client: number of ctxts needed for query = " << num_ptxts << endl;
+        
         for (uint32_t j =0; j < num_ptxts; j++){
             pt.set_zero();
             if (indices_[i] > N*(j+1) || indices_[i] < N*j){
@@ -49,11 +51,18 @@ PirQuery PIRClient::generate_query(uint64_t desiredIndex) {
                 cout << "Client: encrypting a real thing " << endl; 
 #endif 
                 uint64_t real_index = indices_[i] - N*j; 
-                pt[real_index] = 1;
+                uint64_t n_i = pir_params_.nvec[i];
+                uint64_t total = N; 
+                if (j == num_ptxts - 1){
+                    total = n_i % N; 
+                }
+                uint64_t log_total = ceil(log2(total));
+
+                cout << "Client: Inverting " << pow(2, log_total) << endl;
+                pt[real_index] = InvertMod(pow(2, log_total), params_.plain_modulus());
             }
             Ciphertext dest;
             encryptor_->encrypt(pt, dest);
-            dest.parms_id() = params_.parms_id();
             result[i].push_back(dest);
         }   
     }
@@ -96,11 +105,7 @@ Plaintext PIRClient::decode_reply(PirReply reply) {
 #ifdef DEBUG
             cout << "Client: reply noise budget = " << decryptor_->invariant_noise_budget(temp[j]) << endl; 
 #endif
-            // multiply by inverse_scale for every coefficient of ptxt
-            for(int h = 0; h < ptxt.coeff_count(); h++){
-                ptxt[h] *= inverse_scales_[recursion_level -  1 - i]; 
-                ptxt[h] %= t; 
-            }
+            
             //cout << "decoded (and scaled) plaintext = " << ptxt.to_string() << endl;
             tempplain.push_back(ptxt);
 
@@ -136,19 +141,20 @@ Plaintext PIRClient::decode_reply(PirReply reply) {
 
 GaloisKeys PIRClient::generate_galois_keys() {
     // Generate the Galois keys needed for coeff_select.
-    vector<uint64_t> galois_elts;
+    vector<uint32_t> galois_elts;
     int N = params_.poly_modulus_degree();
     int logN = get_power_of_two(N);
 
     //cout << "printing galois elements...";
     for (int i = 0; i < logN; i++) {
-        galois_elts.push_back((N + exponentiate_uint64(2, i)) / exponentiate_uint64(2, i));
+        galois_elts.push_back((N + exponentiate_uint(2, i)) / exponentiate_uint(2, i));
 //#ifdef DEBUG
         // cout << galois_elts.back() << ", ";
 //#endif
     }
-
-    return keygen_->galois_keys(pir_params_.dbc, galois_elts);
+    GaloisKeys gal_keys;
+    keygen_->create_galois_keys(galois_elts, gal_keys);
+    return gal_keys;
 }
 
 Ciphertext PIRClient::compose_to_ciphertext(vector<Plaintext> plains) {
@@ -158,7 +164,7 @@ Ciphertext PIRClient::compose_to_ciphertext(vector<Plaintext> plains) {
     uint64_t plainMod = params_.plain_modulus().value();
     int logt = floor(log2(plainMod)); 
 
-    Ciphertext result(newcontext_);
+    Ciphertext result(*newcontext_);
     result.resize(encrypted_count);
 
     // A triple for loop. Going over polys, moduli, and decomposed index.
@@ -202,47 +208,19 @@ Ciphertext PIRClient::compose_to_ciphertext(vector<Plaintext> plains) {
         }
     }
 
-    result.parms_id() = params_.parms_id();
     return result;
 }
 
-
-void PIRClient::compute_inverse_scales(){
-    if (indices_.size() != pir_params_.nvec.size()){
-        throw invalid_argument("size mismatch"); 
-    }
-    int logt = floor(log2(params_.plain_modulus().value())); 
-
-    uint64_t N = params_.poly_modulus_degree(); 
-    uint64_t t = params_.plain_modulus().value();
-    int logN = log2(N);
-    int logm = logN;
-
-    inverse_scales_.clear(); 
-
-    for(int i = 0; i < pir_params_.nvec.size(); i++){
-        uint64_t index_modN = indices_[i] % N; 
-        uint64_t numCtxt = ceil ( (pir_params_.nvec[i] + 0.0) / N);  // number of query ciphertexts. 
-        uint64_t batchId = indices_[i] / N;  
-        if (batchId == numCtxt - 1) {
-            cout << "Client: adjusting the logm value..." << endl; 
-            logm = ceil(log2((pir_params_.nvec[i] % N)));
-        }
-
-        uint64_t inverse_scale; 
- 
-
-        int quo = logm / logt; 
-        int mod = logm % logt; 
-        inverse_scale = pow(2, logt - mod); 
-        if ((quo +1) %2 != 0){
-            inverse_scale =  params_.plain_modulus().value() - pow(2, logt - mod); 
-        }
-        inverse_scales_.push_back(inverse_scale); 
-        if ( (inverse_scale << logm)  % t != 1){
-            throw logic_error("something wrong"); 
-        }
-        cout << "Client: logm, inverse scale, t = " << logm << ", " << inverse_scale << ", " << t << endl; 
-    }
+Ciphertext PIRClient::get_encrypted_one(){
+    Ciphertext one_ct;
+    Plaintext one("1");
+    encryptor_->encrypt(one, one_ct);
+    return one_ct;
 }
 
+
+Plaintext PIRClient::decrypt(seal::Ciphertext ct){
+    Plaintext result;
+    decryptor_->decrypt(ct, result);
+    return result;
+}

+ 3 - 1
pir_client.hpp

@@ -20,7 +20,9 @@ class PIRClient {
     uint64_t get_fv_index(uint64_t element_idx, uint64_t ele_size);
     uint64_t get_fv_offset(uint64_t element_idx, uint64_t ele_size);
 
-    void compute_inverse_scales(); 
+    seal::Ciphertext get_encrypted_one();
+    seal::Plaintext decrypt(seal::Ciphertext ct);
+
 
   private:
     seal::EncryptionParameters params_;

+ 40 - 14
pir_server.cpp

@@ -10,8 +10,8 @@ PIRServer::PIRServer(const EncryptionParameters &params, const PirParams &pir_pa
     pir_params_(pir_params),
     is_db_preprocessed_(false)
 {
-    auto context = SEALContext::Create(params, false);
-    evaluator_ = make_unique<Evaluator>(context);
+    context_ = make_shared<SEALContext>(params, true);
+    evaluator_ = make_unique<Evaluator>(*context_);
 }
 
 void PIRServer::preprocess_database() {
@@ -19,7 +19,7 @@ void PIRServer::preprocess_database() {
 
         for (uint32_t i = 0; i < db_->size(); i++) {
             evaluator_->transform_to_ntt_inplace(
-                db_->operator[](i), params_.parms_id());
+                db_->operator[](i), context_->first_parms_id());
         }
 
         is_db_preprocessed_ = true;
@@ -42,6 +42,9 @@ void PIRServer::set_database(const std::unique_ptr<const std::uint8_t[]> &bytes,
     uint32_t logt = floor(log2(params_.plain_modulus().value()));
     uint32_t N = params_.poly_modulus_degree();
 
+    cout << "logt: " << logt << endl << "N: " << N << endl <<
+    "ele_num: " << ele_num << endl << "ele_size: " << ele_size << endl;
+
     // number of FV plaintexts needed to represent all elements
     uint64_t total = plaintexts_per_db(logt, N, ele_num, ele_size);
 
@@ -51,6 +54,9 @@ void PIRServer::set_database(const std::unique_ptr<const std::uint8_t[]> &bytes,
         prod *= pir_params_.nvec[i];
     }
     uint64_t matrix_plaintexts = prod;
+    cout << "Total:" << total << endl << "Prod: "
+     << matrix_plaintexts << endl;
+
     assert(total <= matrix_plaintexts);
 
     auto result = make_unique<vector<Plaintext>>();
@@ -123,11 +129,10 @@ void PIRServer::set_database(const std::unique_ptr<const std::uint8_t[]> &bytes,
 }
 
 void PIRServer::set_galois_key(std::uint32_t client_id, seal::GaloisKeys galkey) {
-    galkey.parms_id() = params_.parms_id();
     galoisKeys_[client_id] = galkey;
 }
 
-PirReply PIRServer::generate_reply(PirQuery query, uint32_t client_id) {
+PirReply PIRServer::generate_reply(PirQuery query, uint32_t client_id, const PIRClient& client) {
 
     vector<uint64_t> nvec = pir_params_.nvec;
     uint64_t product = 1;
@@ -161,10 +166,10 @@ PirReply PIRServer::generate_reply(PirQuery query, uint32_t client_id) {
         for (uint32_t j = 0; j < query[i].size(); j++){
             uint64_t total = N; 
             if (j == query[i].size() - 1){
-                total = n_i % N; 
+                total = ((n_i - 1) % N) + 1; 
             }
             cout << "-- expanding one query ctxt into " << total  << " ctxts "<< endl;
-            vector<Ciphertext> expanded_query_part = expand_query(query[i][j], total, client_id);
+            vector<Ciphertext> expanded_query_part = expand_query(query[i][j], total, client_id, client);
             expanded_query.insert(expanded_query.end(), std::make_move_iterator(expanded_query_part.begin()), 
                     std::make_move_iterator(expanded_query_part.end()));
             expanded_query_part.clear(); 
@@ -174,16 +179,20 @@ PirReply PIRServer::generate_reply(PirQuery query, uint32_t client_id) {
             cout << " size mismatch!!! " << expanded_query.size() << ", " << n_i << endl; 
         }    
 
-        /*
+        
         cout << "Checking expanded query " << endl; 
         Plaintext tempPt; 
         for (int h = 0 ; h < expanded_query.size(); h++){
-            cout << "noise budget = " << client.decryptor_->invariant_noise_budget(expanded_query[h]) << ", "; 
             client.decryptor_->decrypt(expanded_query[h], tempPt); 
+            if(tempPt.is_zero()){
+                continue;
+            }
+            cout << "index: " << h << ", ";
+            cout << "noise budget = " << client.decryptor_->invariant_noise_budget(expanded_query[h]) << ", "; 
             cout << tempPt.to_string()  << endl; 
         }
         cout << endl;
-        */
+        
 
         // Transform expanded query to NTT, and ...
         for (uint32_t jj = 0; jj < expanded_query.size(); jj++) {
@@ -193,7 +202,7 @@ PirReply PIRServer::generate_reply(PirQuery query, uint32_t client_id) {
         // Transform plaintext to NTT. If database is pre-processed, can skip
         if ((!is_db_preprocessed_) || i > 0) {
             for (uint32_t jj = 0; jj < cur->size(); jj++) {
-                evaluator_->transform_to_ntt_inplace((*cur)[jj], params_.parms_id());
+                evaluator_->transform_to_ntt_inplace((*cur)[jj], context_->first_parms_id());
             }
         }
 
@@ -257,8 +266,19 @@ PirReply PIRServer::generate_reply(PirQuery query, uint32_t client_id) {
     return fail;
 }
 
+Ciphertext PIRServer::generate_public_reply(Ciphertext one_ct, std::uint64_t desiredIndex){
+    vector<Plaintext> *cur = db_.get();
+    Ciphertext result;
+    evaluator_->transform_to_ntt_inplace(one_ct);
+    cout << "transformed" << endl;
+    evaluator_->multiply_plain(one_ct, (*cur)[desiredIndex], result);
+    cout << "reply generated" << endl;
+    evaluator_->transform_from_ntt_inplace(result);
+    return result;
+}
+
 inline vector<Ciphertext> PIRServer::expand_query(const Ciphertext &encrypted, uint32_t m,
-                                           uint32_t client_id) {
+                                           uint32_t client_id, const PIRClient& client) {
 
 #ifdef DEBUG
     uint64_t plainMod = params_.plain_modulus().value();
@@ -277,7 +297,7 @@ inline vector<Ciphertext> PIRServer::expand_query(const Ciphertext &encrypted, u
         throw logic_error("m > n is not allowed."); 
     }
     for (int i = 0; i < ceil(log2(n)); i++) {
-        galois_elts.push_back((n + exponentiate_uint64(2, i)) / exponentiate_uint64(2, i));
+        galois_elts.push_back((n + exponentiate_uint(2, i)) / exponentiate_uint(2, i));
     }
 
     vector<Ciphertext> temp;
@@ -344,13 +364,19 @@ inline vector<Ciphertext> PIRServer::expand_query(const Ciphertext &encrypted, u
     vector<Ciphertext>::const_iterator first = newtemp.begin();
     vector<Ciphertext>::const_iterator last = newtemp.begin() + m;
     vector<Ciphertext> newVec(first, last);
+
+    for(Ciphertext c: newVec){
+
+    }
+
+
     return newVec;
 }
 
 inline void PIRServer::multiply_power_of_X(const Ciphertext &encrypted, Ciphertext &destination,
                                     uint32_t index) {
 
-    auto coeff_mod_count = params_.coeff_modulus().size();
+    auto coeff_mod_count = params_.coeff_modulus().size() - 1;
     auto coeff_count = params_.poly_modulus_degree();
     auto encrypted_count = encrypted.size();
 

+ 5 - 2
pir_server.hpp

@@ -17,9 +17,11 @@ class PIRServer {
     void preprocess_database();
 
     std::vector<seal::Ciphertext> expand_query(
-            const seal::Ciphertext &encrypted, std::uint32_t m, uint32_t client_id);
+            const seal::Ciphertext &encrypted, std::uint32_t m, uint32_t client_id, const PIRClient& client);
 
-    PirReply generate_reply(PirQuery query, std::uint32_t client_id);
+    PirReply generate_reply(PirQuery query, std::uint32_t client_id, const PIRClient& client);
+    
+    seal::Ciphertext generate_public_reply(seal::Ciphertext one_ct, std::uint64_t desiredIndex);
 
     void set_galois_key(std::uint32_t client_id, seal::GaloisKeys galkey);
 
@@ -30,6 +32,7 @@ class PIRServer {
     bool is_db_preprocessed_;
     std::map<int, seal::GaloisKeys> galoisKeys_;
     std::unique_ptr<seal::Evaluator> evaluator_;
+    std::shared_ptr<seal::SEALContext> context_;
 
     void decompose_to_plaintexts_ptr(const seal::Ciphertext &encrypted, seal::Plaintext *plain_ptr, int logt);
     std::vector<seal::Plaintext> decompose_to_plaintexts(const seal::Ciphertext &encrypted);