Browse Source

applied Kim's patch for 3.1

hao chen 5 years ago
parent
commit
7cdd24807e
8 changed files with 136 additions and 131 deletions
  1. 1 2
      CMakeLists.txt
  2. 18 8
      main.cpp
  3. 10 12
      pir.cpp
  4. 3 3
      pir.hpp
  5. 36 34
      pir_client.cpp
  6. 1 0
      pir_client.hpp
  7. 61 65
      pir_server.cpp
  8. 6 7
      pir_server.hpp

+ 1 - 2
CMakeLists.txt

@@ -11,10 +11,9 @@ add_executable(sealpir
 	pir_server.cpp
 )
 
-find_package(SEAL 2.3.1 EXACT REQUIRED)
+find_package(SEAL 3.1.0 EXACT REQUIRED)
 find_package(Threads REQUIRED)
 
 target_link_libraries(sealpir 
 	SEAL::seal
-	Threads::Threads
 )

+ 18 - 8
main.cpp

@@ -1,10 +1,15 @@
 #include "pir.hpp"
 #include "pir_client.hpp"
 #include "pir_server.hpp"
+#include <seal/seal.h>
 #include <chrono>
+#include <memory>
 #include <random>
+#include <cstdint>
+#include <cstddef>
 
-using namespace chrono;
+using namespace std::chrono;
+using namespace std;
 using namespace seal;
 
 int main(int argc, char *argv[]) {
@@ -21,8 +26,8 @@ int main(int argc, char *argv[]) {
     uint32_t logt = 20;
     uint32_t d = 2;
 
-    EncryptionParameters params;
-    EncryptionParameters expanded_params;
+    EncryptionParameters params(scheme_type::BFV);
+    EncryptionParameters expanded_params(scheme_type::BFV);
     PirParams pir_params;
 
     // Generates all parameters
@@ -30,12 +35,17 @@ int main(int argc, char *argv[]) {
     gen_params(number_of_items, size_per_item, N, logt, d, params, expanded_params, pir_params);
 
     // Create test database
-    uint8_t *db = (uint8_t *)malloc(number_of_items * size_per_item);
+    auto db(make_unique<uint8_t[]>(number_of_items * size_per_item));
+
+    // For testing purposes only
+    auto check_db(make_unique<uint8_t[]>(number_of_items * size_per_item));
 
     random_device rd;
     for (uint64_t i = 0; i < number_of_items; i++) {
         for (uint64_t j = 0; j < size_per_item; j++) {
-            *(db + (i * size_per_item) + j) = rd() % 256;
+            auto val = rd() % 256;
+            db.get()[(i * size_per_item) + j] = val;
+            check_db.get()[(i * size_per_item) + j] = val;
         }
     }
 
@@ -66,7 +76,7 @@ int main(int argc, char *argv[]) {
 
     // Measure database setup
     auto time_pre_s = high_resolution_clock::now();
-    server.set_database(db, number_of_items, size_per_item);
+    server.set_database(move(db), number_of_items, size_per_item);
     server.preprocess_database();
     auto time_pre_e = high_resolution_clock::now();
     auto time_pre_us = duration_cast<microseconds>(time_pre_e - time_pre_s).count();
@@ -101,9 +111,9 @@ int main(int argc, char *argv[]) {
 
     // Check that we retrieved the correct element
     for (uint32_t i = 0; i < size_per_item; i++) {
-        if (elems[(offset * size_per_item) + i] != db[(ele_index * size_per_item) + i]) {
+        if (elems[(offset * size_per_item) + i] != check_db.get()[(ele_index * size_per_item) + i]) {
             cout << "elems " << (int)elems[(offset * size_per_item) + i] << ", db "
-                 << (int)db[(ele_index * size_per_item) + i] << endl;
+                 << check_db.get()[(ele_index * size_per_item) + i] << endl;
             cout << "PIR result wrong!" << endl;
             return -1;
         }

+ 10 - 12
pir.cpp

@@ -58,11 +58,11 @@ void gen_params(uint64_t ele_num, uint64_t ele_size, uint32_t N, uint32_t logt,
         logq += coeff_mod_array[i].bit_count();
     }
 
-    params.set_poly_modulus("1x^" + to_string(N) + " + 1");
+    params.set_poly_modulus_degree(N);
     params.set_coeff_modulus(coeff_mod_array);
     params.set_plain_modulus(plain_mod);
 
-    expanded_params.set_poly_modulus("1x^" + to_string(N) + " + 1");
+    expanded_params.set_poly_modulus_degree(N);
     expanded_params.set_coeff_modulus(coeff_mod_array);
     expanded_params.set_plain_modulus(expanded_plain_mod);
 
@@ -86,7 +86,7 @@ void update_params(uint64_t ele_num, uint64_t ele_size, uint32_t d,
                    PirParams &pir_params) {
 
     uint32_t logt = ceil(log2(old_params.plain_modulus().value()));
-    uint32_t N = old_params.poly_modulus().coeff_count() - 1;
+    uint32_t N = old_params.poly_modulus_degree();
 
     // Determine the maximum size of each dimension
     uint32_t logtp = plainmod_after_expansion(logt, N, d, ele_num, ele_size);
@@ -100,7 +100,7 @@ void update_params(uint64_t ele_num, uint64_t ele_size, uint32_t d,
     cout << "number of FV plaintexts = " << plaintext_num << endl;
 #endif
 
-    expanded_params.set_poly_modulus(old_params.poly_modulus());
+    expanded_params.set_poly_modulus_degree(old_params.poly_modulus_degree());
     expanded_params.set_coeff_modulus(old_params.coeff_modulus());
     expanded_params.set_plain_modulus(expanded_plain_mod);
 
@@ -250,9 +250,8 @@ vector<uint64_t> compute_indices(uint64_t desiredIndex, vector<uint64_t> Nvec) {
 
 inline Ciphertext deserialize_ciphertext(string s) {
     Ciphertext c;
-    std::stringstream input(std::ios::binary | std::ios::in);
-    input.str(s);
-    c.load(input);
+    std::istringstream input(s);
+    c.unsafe_load(input);
     return c;
 }
 
@@ -265,7 +264,7 @@ vector<Ciphertext> deserialize_ciphertexts(uint32_t count, string s, uint32_t le
 }
 
 inline string serialize_ciphertext(Ciphertext c) {
-    std::stringstream output(std::ios::binary | std::ios::out);
+    std::ostringstream output;
     c.save(output);
     return output.str();
 }
@@ -279,15 +278,14 @@ string serialize_ciphertexts(vector<Ciphertext> c) {
 }
 
 string serialize_galoiskeys(GaloisKeys g) {
-    std::stringstream output(std::ios::binary | std::ios::out);
+    std::ostringstream output;
     g.save(output);
     return output.str();
 }
 
 GaloisKeys *deserialize_galoiskeys(string s) {
     GaloisKeys *g = new GaloisKeys();
-    std::stringstream input(std::ios::binary | std::ios::in);
-    input.str(s);
-    g->load(input);
+    std::istringstream input(s);
+    g->unsafe_load(input);
     return g;
 }

+ 3 - 3
pir.hpp

@@ -7,7 +7,7 @@
 #include <string>
 #include <vector>
 
-#define CIPHER_SIZE 32828
+#define CIPHER_SIZE 32841
 
 typedef std::vector<seal::Plaintext> Database;
 typedef std::vector<seal::Ciphertext> PirQuery;
@@ -60,7 +60,7 @@ void coeffs_to_bytes(std::uint32_t logtp, const seal::Plaintext &coeffs, std::ui
                      std::uint32_t size_out);
 
 // Takes a vector of coefficients and returns the corresponding FV plaintext
-void vector_to_plaintext(const vector<std::uint64_t> &coeffs, seal::Plaintext &plain);
+void vector_to_plaintext(const std::vector<std::uint64_t> &coeffs, seal::Plaintext &plain);
 
 // Since the database has d dimensions, and an item is a particular cell
 // in the d-dimensional hypercube, this function computes the corresponding
@@ -71,7 +71,7 @@ std::vector<std::uint64_t> compute_indices(std::uint64_t desiredIndex,
 // Serialize and deserialize ciphertexts to send them over the network
 std::vector<seal::Ciphertext> deserialize_ciphertexts(std::uint32_t count, std::string s,
                                                       std::uint32_t len_ciphertext);
-std::string serialize_ciphertexts(vector<seal::Ciphertext> c);
+std::string serialize_ciphertexts(std::vector<seal::Ciphertext> c);
 
 // Serialize and deserialize galois keys to send them over the network
 std::string serialize_galoiskeys(seal::GaloisKeys g);

+ 36 - 34
pir_client.cpp

@@ -5,42 +5,43 @@ using namespace seal;
 using namespace seal::util;
 
 PIRClient::PIRClient(const EncryptionParameters &params,
-                     const EncryptionParameters &expanded_params, const PirParams &pir_parms) {
+                     const EncryptionParameters &expanded_params, 
+                     const PirParams &pir_parms) :
+    params_(params),
+    expanded_params_(expanded_params) {
 
-    params_ = params;
-    SEALContext context(params);
-
-    expanded_params_ = expanded_params;
-    SEALContext newcontext(expanded_params);
+    auto context = SEALContext::Create(params_);
+    newcontext_ = SEALContext::Create(expanded_params_);
 
     pir_params_ = pir_parms;
 
-    keygen_.reset(new KeyGenerator(context));
-    encryptor_.reset(new Encryptor(context, keygen_->public_key()));
+    keygen_ = make_unique<KeyGenerator>(context);
+    encryptor_ = make_unique<Encryptor>(context, keygen_->public_key());
 
     SecretKey secret_key = keygen_->secret_key();
-    secret_key.hash_block() = expanded_params.hash_block();
+    secret_key.parms_id() = expanded_params.parms_id();
 
-    decryptor_.reset(new Decryptor(newcontext, secret_key));
-    evaluator_.reset(new Evaluator(newcontext));
+    decryptor_ = make_unique<Decryptor>(newcontext_, secret_key);
+    evaluator_ = make_unique<Evaluator>(newcontext_);
 }
 
 void PIRClient::update_parameters(const EncryptionParameters &expanded_params,
-                                  const PirParams &pir_params) {
+                                  const PirParams &pir_params) 
+{
 
     // The only thing that can change is the plaintext modulus and pir_params
-    assert(expanded_params.poly_modulus() == expanded_params_.poly_modulus());
+    assert(expanded_params.poly_modulus_degree() == expanded_params_.poly_modulus_degree());
     assert(expanded_params.coeff_modulus() == expanded_params_.coeff_modulus());
 
     expanded_params_ = expanded_params;
     pir_params_ = pir_params;
-    SEALContext newcontext(expanded_params);
+    auto newcontext = SEALContext::Create(expanded_params);
 
     SecretKey secret_key = keygen_->secret_key();
-    secret_key.hash_block() = expanded_params.hash_block();
+    secret_key.parms_id() = expanded_params.parms_id();
 
-    decryptor_.reset(new Decryptor(newcontext, secret_key));
-    evaluator_.reset(new Evaluator(newcontext));
+    decryptor_ = make_unique<Decryptor>(newcontext, secret_key);
+    evaluator_ = make_unique<Evaluator>(newcontext);
 }
 
 PirQuery PIRClient::generate_query(uint64_t desiredIndex) {
@@ -48,10 +49,13 @@ PirQuery PIRClient::generate_query(uint64_t desiredIndex) {
     vector<uint64_t> indices = compute_indices(desiredIndex, pir_params_.nvec);
     vector<Ciphertext> result;
 
+    Plaintext pt(expanded_params_.poly_modulus_degree());
     for (uint32_t i = 0; i < indices.size(); i++) {
+        pt.set_zero();
+        pt[indices[i]] = 1;
         Ciphertext dest;
-        encryptor_->encrypt(Plaintext("1x^" + std::to_string(indices[i])), dest);
-        dest.hash_block() = expanded_params_.hash_block();
+        encryptor_->encrypt(pt, dest);
+        dest.parms_id() = expanded_params_.parms_id();
         result.push_back(dest);
     }
 
@@ -59,15 +63,15 @@ PirQuery PIRClient::generate_query(uint64_t desiredIndex) {
 }
 
 uint64_t PIRClient::get_fv_index(uint64_t element_idx, uint64_t ele_size) {
-    uint32_t N = params_.poly_modulus().coeff_count() - 1;
-    uint32_t logtp = ceil(log2(expanded_params_.plain_modulus().value()));
+    auto N = params_.poly_modulus_degree();
+    auto logtp = ceil(log2(expanded_params_.plain_modulus().value()));
 
-    uint64_t ele_per_ptxt = elements_per_ptxt(logtp, N, ele_size);
-    return element_idx / ele_per_ptxt;
+    auto ele_per_ptxt = elements_per_ptxt(logtp, N, ele_size);
+    return static_cast<uint64_t>(element_idx / ele_per_ptxt);
 }
 
 uint64_t PIRClient::get_fv_offset(uint64_t element_idx, uint64_t ele_size) {
-    uint32_t N = params_.poly_modulus().coeff_count() - 1;
+    uint32_t N = params_.poly_modulus_degree();
     uint32_t logtp = ceil(log2(expanded_params_.plain_modulus().value()));
 
     uint64_t ele_per_ptxt = elements_per_ptxt(logtp, N, ele_size);
@@ -120,7 +124,7 @@ Plaintext PIRClient::decode_reply(PirReply reply) {
 GaloisKeys PIRClient::generate_galois_keys() {
     // Generate the Galois keys needed for coeff_select.
     vector<uint64_t> galois_elts;
-    int N = params_.poly_modulus().coeff_count() - 1;
+    int N = params_.poly_modulus_degree();
     int logN = get_power_of_two(N);
 
     for (int i = 0; i < logN; i++) {
@@ -130,19 +134,17 @@ GaloisKeys PIRClient::generate_galois_keys() {
 #endif
     }
 
-    GaloisKeys galois_keys;
-    keygen_->generate_galois_keys(pir_params_.dbc, galois_elts, galois_keys);
-    return galois_keys;
+    return keygen_->galois_keys(pir_params_.dbc, galois_elts);
 }
 
 Ciphertext PIRClient::compose_to_ciphertext(vector<Plaintext> plains) {
-    int encrypted_count = 2;
-    int coeff_count = expanded_params_.poly_modulus().coeff_count();
-    int coeff_mod_count = expanded_params_.coeff_modulus().size();
+    size_t encrypted_count = 2;
+    auto coeff_count = expanded_params_.poly_modulus_degree();
+    auto coeff_mod_count = expanded_params_.coeff_modulus().size();
     uint64_t plainMod = expanded_params_.plain_modulus().value();
 
-    Ciphertext result;
-    result.reserve(expanded_params_, encrypted_count);
+    Ciphertext result(newcontext_);
+    result.resize(encrypted_count);
 
     // A triple for loop. Going over polys, moduli, and decomposed index.
     for (int i = 0; i < encrypted_count; i++) {
@@ -188,6 +190,6 @@ Ciphertext PIRClient::compose_to_ciphertext(vector<Plaintext> plains) {
         }
     }
 
-    result.hash_block() = expanded_params_.hash_block();
+    result.parms_id() = expanded_params_.parms_id();
     return result;
 }

+ 1 - 0
pir_client.hpp

@@ -30,6 +30,7 @@ class PIRClient {
     std::unique_ptr<seal::Decryptor> decryptor_;
     std::unique_ptr<seal::Evaluator> evaluator_;
     std::unique_ptr<seal::KeyGenerator> keygen_;
+    std::shared_ptr<seal::SEALContext> newcontext_;
 
     seal::Ciphertext compose_to_ciphertext(std::vector<seal::Plaintext> plains);
 };

+ 61 - 65
pir_server.cpp

@@ -4,34 +4,31 @@ using namespace std;
 using namespace seal;
 using namespace seal::util;
 
-PIRServer::PIRServer(const EncryptionParameters &expanded_params, const PirParams &pir_params) {
-    expanded_params_ = expanded_params;
-    pir_params_ = pir_params;
-    SEALContext context(expanded_params);
-    evaluator_.reset(new Evaluator(context));
-    is_db_preprocessed_ = false;
-}
-
-PIRServer::~PIRServer() {
-    delete db_;
+PIRServer::PIRServer(const EncryptionParameters &expanded_params, const PirParams &pir_params) :
+    expanded_params_(expanded_params), 
+    pir_params_(pir_params),
+    is_db_preprocessed_(false)
+{
+    auto context = SEALContext::Create(expanded_params, false);
+    evaluator_ = make_unique<Evaluator>(context);
 }
 
 void PIRServer::update_parameters(const EncryptionParameters &expanded_params,
                                   const PirParams &pir_params) {
 
     // The only thing that can change is the plaintext modulus and pir_params
-    assert(expanded_params.poly_modulus() == expanded_params_.poly_modulus());
+    assert(expanded_params.poly_modulus_degree() == expanded_params_.poly_modulus_degree());
     assert(expanded_params.coeff_modulus() == expanded_params_.coeff_modulus());
 
     expanded_params_ = expanded_params;
     pir_params_ = pir_params;
-    SEALContext context(expanded_params);
-    evaluator_.reset(new Evaluator(context));
+    auto context = SEALContext::Create(expanded_params);
+    evaluator_ = make_unique<Evaluator>(context);
     is_db_preprocessed_ = false;
 
     // Update all the galois keys
     for (std::pair<const int, GaloisKeys> &key : galoisKeys_) {
-        key.second.hash_block() = expanded_params_.hash_block();
+        key.second.parms_id() = expanded_params_.parms_id();
     }
 }
 
@@ -39,7 +36,8 @@ void PIRServer::preprocess_database() {
     if (!is_db_preprocessed_) {
 
         for (uint32_t i = 0; i < db_->size(); i++) {
-            evaluator_->transform_to_ntt(db_->operator[](i));
+            evaluator_->transform_to_ntt_inplace(
+                db_->operator[](i), expanded_params_.parms_id());
         }
 
         is_db_preprocessed_ = true;
@@ -47,19 +45,20 @@ void PIRServer::preprocess_database() {
 }
 
 // Server takes over ownership of db and will free it when it exits
-void PIRServer::set_database(vector<Plaintext> *db) {
-    if (db == nullptr) {
+void PIRServer::set_database(unique_ptr<vector<Plaintext>> &&db) {
+    if (!db) {
         throw invalid_argument("db cannot be null");
     }
 
-    db_ = db;
+    db_ = move(db);
     is_db_preprocessed_ = false;
 }
 
-void PIRServer::set_database(const uint8_t *bytes, uint64_t ele_num, uint64_t ele_size) {
+void PIRServer::set_database(const std::unique_ptr<const std::uint8_t[]> &bytes, 
+    uint64_t ele_num, uint64_t ele_size) {
 
     uint32_t logtp = ceil(log2(expanded_params_.plain_modulus().value()));
-    uint32_t N = expanded_params_.poly_modulus().coeff_count() - 1;
+    uint32_t N = expanded_params_.poly_modulus_degree();
 
     // number of FV plaintexts needed to represent all elements
     uint64_t total = plaintexts_per_db(logtp, N, ele_num, ele_size);
@@ -72,7 +71,7 @@ void PIRServer::set_database(const uint8_t *bytes, uint64_t ele_num, uint64_t el
     uint64_t matrix_plaintexts = prod;
     assert(total <= matrix_plaintexts);
 
-    vector<Plaintext> *result = new vector<Plaintext>();
+    auto result = make_unique<vector<Plaintext>>();
     result->reserve(matrix_plaintexts);
 
     uint64_t ele_per_ptxt = elements_per_ptxt(logtp, N, ele_size);
@@ -98,7 +97,7 @@ void PIRServer::set_database(const uint8_t *bytes, uint64_t ele_num, uint64_t el
         }
 
         // Get the coefficients of the elements that will be packed in plaintext i
-        vector<uint64_t> coefficients = bytes_to_coeffs(logtp, bytes + offset, process_bytes);
+        vector<uint64_t> coefficients = bytes_to_coeffs(logtp, bytes.get() + offset, process_bytes);
         offset += process_bytes;
 
         uint64_t used = coefficients.size();
@@ -112,7 +111,7 @@ void PIRServer::set_database(const uint8_t *bytes, uint64_t ele_num, uint64_t el
 
         Plaintext plain;
         vector_to_plaintext(coefficients, plain);
-        result->push_back(plain);
+        result->push_back(move(plain));
     }
 
     // Add padding to make database a matrix
@@ -134,11 +133,11 @@ void PIRServer::set_database(const uint8_t *bytes, uint64_t ele_num, uint64_t el
         result->push_back(plain);
     }
 
-    set_database(result);
+    set_database(move(result));
 }
 
 void PIRServer::set_galois_key(std::uint32_t client_id, seal::GaloisKeys galkey) {
-    galkey.hash_block() = expanded_params_.hash_block();
+    galkey.parms_id() = expanded_params_.parms_id();
     galoisKeys_[client_id] = galkey;
 }
 
@@ -151,12 +150,12 @@ PirReply PIRServer::generate_reply(PirQuery query, uint32_t client_id) {
         product *= nvec[i];
     }
 
-    int coeff_count = expanded_params_.poly_modulus().coeff_count();
+    auto coeff_count = expanded_params_.poly_modulus_degree();
 
-    vector<Plaintext> *cur = db_;
+    vector<Plaintext> *cur = db_.get();
     vector<Plaintext> intermediate_plain; // decompose....
 
-    auto my_pool = MemoryPoolHandle::New();
+    auto pool = MemoryManager::GetPool();
 
     for (uint32_t i = 0; i < nvec.size(); i++) {
         uint64_t n_i = nvec[i];
@@ -164,13 +163,13 @@ PirReply PIRServer::generate_reply(PirQuery query, uint32_t client_id) {
 
         // Transform expanded query to NTT, and ...
         for (uint32_t jj = 0; jj < expanded_query.size(); jj++) {
-            evaluator_->transform_to_ntt(expanded_query[jj]);
+            evaluator_->transform_to_ntt_inplace(expanded_query[jj]);
         }
 
         // Transform plaintext to NTT. If database is pre-processed, can skip
         if ((!is_db_preprocessed_) || i > 0) {
             for (uint32_t jj = 0; jj < cur->size(); jj++) {
-                evaluator_->transform_to_ntt((*cur)[jj]);
+                evaluator_->transform_to_ntt_inplace((*cur)[jj], expanded_params_.parms_id());
             }
         }
 
@@ -180,17 +179,16 @@ PirReply PIRServer::generate_reply(PirQuery query, uint32_t client_id) {
         Ciphertext temp;
 
         for (uint64_t k = 0; k < product; k++) {
-            evaluator_->multiply_plain_ntt(expanded_query[0], (*cur)[k], intermediate[k]);
+            evaluator_->multiply_plain(expanded_query[0], (*cur)[k], intermediate[k]);
 
             for (uint64_t j = 1; j < n_i; j++) {
-                evaluator_->multiply_plain_ntt(expanded_query[j], (*cur)[k + j * product], temp);
-                evaluator_->add(intermediate[k],
-                                temp); // Adds to first component.
+                evaluator_->multiply_plain(expanded_query[j], (*cur)[k + j * product], temp);
+                evaluator_->add_inplace(intermediate[k], temp); // Adds to first component.
             }
         }
 
         for (uint32_t jj = 0; jj < intermediate.size(); jj++) {
-            evaluator_->transform_from_ntt(intermediate[jj]);
+            evaluator_->transform_from_ntt_inplace(intermediate[jj]);
         }
 
         if (i == nvec.size() - 1) {
@@ -200,18 +198,18 @@ PirReply PIRServer::generate_reply(PirQuery query, uint32_t client_id) {
             intermediate_plain.reserve(pir_params_.expansion_ratio * product);
             cur = &intermediate_plain;
 
-            util::Pointer tempplain_ptr(allocate_zero_poly(
-                pir_params_.expansion_ratio * product, coeff_count, my_pool));
+            auto tempplain = util::allocate<Plaintext>(
+                pir_params_.expansion_ratio * product,
+                pool, coeff_count);
 
             for (uint64_t rr = 0; rr < product; rr++) {
 
                 decompose_to_plaintexts_ptr(intermediate[rr],
-                                            tempplain_ptr.get() +
-                                                rr * pir_params_.expansion_ratio * coeff_count);
+                    tempplain.get() + rr * pir_params_.expansion_ratio);
 
                 for (uint32_t jj = 0; jj < pir_params_.expansion_ratio; jj++) {
-                    int offset = rr * pir_params_.expansion_ratio * coeff_count + jj * coeff_count;
-                    intermediate_plain.emplace_back(coeff_count, tempplain_ptr.get() + offset);
+                    auto offset = rr * pir_params_.expansion_ratio + jj;
+                    intermediate_plain.emplace_back(tempplain[offset]);
                 }
             }
 
@@ -240,7 +238,7 @@ inline vector<Ciphertext> PIRServer::expand_query(const Ciphertext &encrypted, u
     Plaintext two("2");
 
     vector<int> galois_elts;
-    int n = expanded_params_.poly_modulus().coeff_count() - 1;
+    auto n = expanded_params_.poly_modulus_degree();
 
     for (uint32_t i = 0; i < logm; i++) {
         galois_elts.push_back((n + exponentiate_uint64(2, i)) / exponentiate_uint64(2, i));
@@ -297,9 +295,9 @@ inline vector<Ciphertext> PIRServer::expand_query(const Ciphertext &encrypted, u
 inline void PIRServer::multiply_power_of_X(const Ciphertext &encrypted, Ciphertext &destination,
                                     uint32_t index) {
 
-    int coeff_mod_count = expanded_params_.coeff_modulus().size();
-    int coeff_count = expanded_params_.poly_modulus().coeff_count();
-    int encrypted_count = encrypted.size();
+    auto coeff_mod_count = expanded_params_.coeff_modulus().size();
+    auto coeff_count = expanded_params_.poly_modulus_degree();
+    auto encrypted_count = encrypted.size();
 
     // First copy over.
     destination = encrypted;
@@ -316,12 +314,12 @@ inline void PIRServer::multiply_power_of_X(const Ciphertext &encrypted, Cipherte
     }
 }
 
-inline void PIRServer::decompose_to_plaintexts_ptr(const Ciphertext &encrypted, uint64_t *plain_ptr) {
+inline void PIRServer::decompose_to_plaintexts_ptr(const Ciphertext &encrypted, Plaintext *plain_ptr) {
 
     vector<Plaintext> result;
-    int coeff_count = expanded_params_.poly_modulus().coeff_count();
-    int coeff_mod_count = expanded_params_.coeff_modulus().size();
-    int encrypted_count = encrypted.size();
+    auto coeff_count = expanded_params_.poly_modulus_degree();
+    auto coeff_mod_count = expanded_params_.coeff_modulus().size();
+    auto encrypted_count = encrypted.size();
 
     // Generate powers of t.
     uint64_t plainModMinusOne = expanded_params_.plain_modulus().value() - 1;
@@ -344,9 +342,9 @@ inline void PIRServer::decompose_to_plaintexts_ptr(const Ciphertext &encrypted,
             for (int k = 0; k < expansion_ratio; k++) {
                 // Decompose here
                 for (int m = 0; m < coeff_count; m++) {
-                    *plain_ptr =
+                    plain_ptr[i * coeff_mod_count * expansion_ratio
+                        + j * expansion_ratio + k][m] =
                         (*(encrypted_pointer + m + (j * coeff_count)) >> curexp) & plainModMinusOne;
-                    plain_ptr++;
                 }
                 curexp += exp;
             }
@@ -356,10 +354,10 @@ inline void PIRServer::decompose_to_plaintexts_ptr(const Ciphertext &encrypted,
 
 vector<Plaintext> PIRServer::decompose_to_plaintexts(const Ciphertext &encrypted) {
     vector<Plaintext> result;
-    int coeff_count = expanded_params_.poly_modulus().coeff_count();
-    int coeff_mod_count = expanded_params_.coeff_modulus().size();
-    int plain_bit_count = expanded_params_.plain_modulus().bit_count();
-    int encrypted_count = encrypted.size();
+    auto coeff_count = expanded_params_.poly_modulus_degree();
+    auto coeff_mod_count = expanded_params_.coeff_modulus().size();
+    auto plain_bit_count = expanded_params_.plain_modulus().bit_count();
+    auto encrypted_count = encrypted.size();
 
     // Generate powers of t.
     uint64_t plainMod = expanded_params_.plain_modulus().value();
@@ -379,16 +377,14 @@ vector<Plaintext> PIRServer::decompose_to_plaintexts(const Ciphertext &encrypted
             uint64_t cur = 1;
             for (int k = 0; k < expansion_ratio; k++) {
                 // Decompose here
-                BigPoly temp;
-                temp.resize(coeff_count, plain_bit_count);
-                temp.set_zero();
-                uint64_t *plain_coeff = temp.data();
-                for (int m = 0; m < coeff_count; m++) {
-                    *(plain_coeff + m) =
-                        (*(encrypted_pointer + m + (j * coeff_count)) / cur) % plainMod;
-                }
-
-                result.push_back(Plaintext(temp));
+                Plaintext temp(coeff_count);
+                transform(encrypted_pointer + (j * coeff_count), 
+                        encrypted_pointer + ((j + 1) * coeff_count), 
+                        temp.data(),
+                        [cur, &plainMod](auto &in) { return (in / cur) % plainMod; }
+                );
+
+                result.emplace_back(move(temp));
                 cur *= plainMod;
             }
         }

+ 6 - 7
pir_server.hpp

@@ -8,19 +8,18 @@
 class PIRServer {
   public:
     PIRServer(const seal::EncryptionParameters &expanded_params, const PirParams &pir_params);
-    ~PIRServer();
 
     void update_parameters(const seal::EncryptionParameters &expanded_params,
                            const PirParams &pir_params);
 
     // NOTE: server takes over ownership of db and frees it when it exits.
     // Caller cannot free db
-    void set_database(std::vector<seal::Plaintext> *db);
-    void set_database(const std::uint8_t *bytes, std::uint64_t ele_num, std::uint64_t ele_size);
+    void set_database(std::unique_ptr<std::vector<seal::Plaintext>> &&db);
+    void set_database(const std::unique_ptr<const std::uint8_t[]> &bytes, std::uint64_t ele_num, std::uint64_t ele_size);
     void preprocess_database();
 
-    std::vector<seal::Ciphertext> expand_query(const seal::Ciphertext &encrypted, std::uint32_t m,
-                                               uint32_t client_id);
+    std::vector<seal::Ciphertext> expand_query(
+            const seal::Ciphertext &encrypted, std::uint32_t m, uint32_t client_id);
 
     PirReply generate_reply(PirQuery query, std::uint32_t client_id);
 
@@ -29,12 +28,12 @@ class PIRServer {
   private:
     seal::EncryptionParameters expanded_params_; // SEAL parameters
     PirParams pir_params_;                       // PIR parameters
-    Database *db_ = nullptr;
+    std::unique_ptr<Database> db_;
     bool is_db_preprocessed_;
     std::map<int, seal::GaloisKeys> galoisKeys_;
     std::unique_ptr<seal::Evaluator> evaluator_;
 
-    void decompose_to_plaintexts_ptr(const seal::Ciphertext &encrypted, std::uint64_t *plain_ptr);
+    void decompose_to_plaintexts_ptr(const seal::Ciphertext &encrypted, seal::Plaintext *plain_ptr);
     std::vector<seal::Plaintext> decompose_to_plaintexts(const seal::Ciphertext &encrypted);
     void multiply_power_of_X(const seal::Ciphertext &encrypted, seal::Ciphertext &destination,
                              std::uint32_t index);