|
@@ -4,40 +4,40 @@ using namespace std;
|
|
|
using namespace seal;
|
|
|
using namespace seal::util;
|
|
|
|
|
|
-PIRServer::PIRServer(const EncryptionParameters &expanded_params, const PirParams &pir_params) :
|
|
|
- expanded_params_(expanded_params),
|
|
|
+PIRServer::PIRServer(const EncryptionParameters ¶ms, const PirParams &pir_params) :
|
|
|
+ params_(params),
|
|
|
pir_params_(pir_params),
|
|
|
is_db_preprocessed_(false)
|
|
|
{
|
|
|
- auto context = SEALContext::Create(expanded_params, false);
|
|
|
+ auto context = SEALContext::Create(params, false);
|
|
|
evaluator_ = make_unique<Evaluator>(context);
|
|
|
}
|
|
|
|
|
|
-void PIRServer::update_parameters(const EncryptionParameters &expanded_params,
|
|
|
- const PirParams &pir_params) {
|
|
|
+// void PIRServer::update_parameters(const EncryptionParameters &expanded_params,
|
|
|
+// const PirParams &pir_params) {
|
|
|
|
|
|
- // The only thing that can change is the plaintext modulus and pir_params
|
|
|
- assert(expanded_params.poly_modulus_degree() == expanded_params_.poly_modulus_degree());
|
|
|
- assert(expanded_params.coeff_modulus() == expanded_params_.coeff_modulus());
|
|
|
+// // The only thing that can change is the plaintext modulus and pir_params
|
|
|
+// assert(expanded_params.poly_modulus_degree() == params_.poly_modulus_degree());
|
|
|
+// assert(expanded_params.coeff_modulus() == params_.coeff_modulus());
|
|
|
|
|
|
- expanded_params_ = expanded_params;
|
|
|
- pir_params_ = pir_params;
|
|
|
- auto context = SEALContext::Create(expanded_params);
|
|
|
- evaluator_ = make_unique<Evaluator>(context);
|
|
|
- is_db_preprocessed_ = false;
|
|
|
+// params_ = expanded_params;
|
|
|
+// pir_params_ = pir_params;
|
|
|
+// auto context = SEALContext::Create(expanded_params);
|
|
|
+// evaluator_ = make_unique<Evaluator>(context);
|
|
|
+// is_db_preprocessed_ = false;
|
|
|
|
|
|
- // Update all the galois keys
|
|
|
- for (std::pair<const int, GaloisKeys> &key : galoisKeys_) {
|
|
|
- key.second.parms_id() = expanded_params_.parms_id();
|
|
|
- }
|
|
|
-}
|
|
|
+// // Update all the galois keys
|
|
|
+// for (std::pair<const int, GaloisKeys> &key : galoisKeys_) {
|
|
|
+// key.second.parms_id() = params_.parms_id();
|
|
|
+// }
|
|
|
+// }
|
|
|
|
|
|
void PIRServer::preprocess_database() {
|
|
|
if (!is_db_preprocessed_) {
|
|
|
|
|
|
for (uint32_t i = 0; i < db_->size(); i++) {
|
|
|
evaluator_->transform_to_ntt_inplace(
|
|
|
- db_->operator[](i), expanded_params_.parms_id());
|
|
|
+ db_->operator[](i), params_.parms_id());
|
|
|
}
|
|
|
|
|
|
is_db_preprocessed_ = true;
|
|
@@ -57,11 +57,11 @@ void PIRServer::set_database(unique_ptr<vector<Plaintext>> &&db) {
|
|
|
void PIRServer::set_database(const std::unique_ptr<const std::uint8_t[]> &bytes,
|
|
|
uint64_t ele_num, uint64_t ele_size) {
|
|
|
|
|
|
- uint32_t logtp = ceil(log2(expanded_params_.plain_modulus().value()));
|
|
|
- uint32_t N = expanded_params_.poly_modulus_degree();
|
|
|
+ uint32_t logt = floor(log2(params_.plain_modulus().value()));
|
|
|
+ uint32_t N = params_.poly_modulus_degree();
|
|
|
|
|
|
// number of FV plaintexts needed to represent all elements
|
|
|
- uint64_t total = plaintexts_per_db(logtp, N, ele_num, ele_size);
|
|
|
+ uint64_t total = plaintexts_per_db(logt, N, ele_num, ele_size);
|
|
|
|
|
|
// number of FV plaintexts needed to create the d-dimensional matrix
|
|
|
uint64_t prod = 1;
|
|
@@ -74,12 +74,12 @@ void PIRServer::set_database(const std::unique_ptr<const std::uint8_t[]> &bytes,
|
|
|
auto result = make_unique<vector<Plaintext>>();
|
|
|
result->reserve(matrix_plaintexts);
|
|
|
|
|
|
- uint64_t ele_per_ptxt = elements_per_ptxt(logtp, N, ele_size);
|
|
|
+ uint64_t ele_per_ptxt = elements_per_ptxt(logt, N, ele_size);
|
|
|
uint64_t bytes_per_ptxt = ele_per_ptxt * ele_size;
|
|
|
|
|
|
uint64_t db_size = ele_num * ele_size;
|
|
|
|
|
|
- uint64_t coeff_per_ptxt = ele_per_ptxt * coefficients_per_element(logtp, ele_size);
|
|
|
+ uint64_t coeff_per_ptxt = ele_per_ptxt * coefficients_per_element(logt, ele_size);
|
|
|
assert(coeff_per_ptxt <= N);
|
|
|
|
|
|
uint32_t offset = 0;
|
|
@@ -97,7 +97,7 @@ void PIRServer::set_database(const std::unique_ptr<const std::uint8_t[]> &bytes,
|
|
|
}
|
|
|
|
|
|
// Get the coefficients of the elements that will be packed in plaintext i
|
|
|
- vector<uint64_t> coefficients = bytes_to_coeffs(logtp, bytes.get() + offset, process_bytes);
|
|
|
+ vector<uint64_t> coefficients = bytes_to_coeffs(logt, bytes.get() + offset, process_bytes);
|
|
|
offset += process_bytes;
|
|
|
|
|
|
uint64_t used = coefficients.size();
|
|
@@ -137,7 +137,7 @@ void PIRServer::set_database(const std::unique_ptr<const std::uint8_t[]> &bytes,
|
|
|
}
|
|
|
|
|
|
void PIRServer::set_galois_key(std::uint32_t client_id, seal::GaloisKeys galkey) {
|
|
|
- galkey.parms_id() = expanded_params_.parms_id();
|
|
|
+ galkey.parms_id() = params_.parms_id();
|
|
|
galoisKeys_[client_id] = galkey;
|
|
|
}
|
|
|
|
|
@@ -150,7 +150,7 @@ PirReply PIRServer::generate_reply(PirQuery query, uint32_t client_id) {
|
|
|
product *= nvec[i];
|
|
|
}
|
|
|
|
|
|
- auto coeff_count = expanded_params_.poly_modulus_degree();
|
|
|
+ auto coeff_count = params_.poly_modulus_degree();
|
|
|
|
|
|
vector<Plaintext> *cur = db_.get();
|
|
|
vector<Plaintext> intermediate_plain; // decompose....
|
|
@@ -169,7 +169,7 @@ PirReply PIRServer::generate_reply(PirQuery query, uint32_t client_id) {
|
|
|
// Transform plaintext to NTT. If database is pre-processed, can skip
|
|
|
if ((!is_db_preprocessed_) || i > 0) {
|
|
|
for (uint32_t jj = 0; jj < cur->size(); jj++) {
|
|
|
- evaluator_->transform_to_ntt_inplace((*cur)[jj], expanded_params_.parms_id());
|
|
|
+ evaluator_->transform_to_ntt_inplace((*cur)[jj], params_.parms_id());
|
|
|
}
|
|
|
}
|
|
|
|
|
@@ -227,7 +227,7 @@ inline vector<Ciphertext> PIRServer::expand_query(const Ciphertext &encrypted, u
|
|
|
uint32_t client_id) {
|
|
|
|
|
|
#ifdef DEBUG
|
|
|
- uint64_t plainMod = expanded_params_.plain_modulus().value();
|
|
|
+ uint64_t plainMod = params_.plain_modulus().value();
|
|
|
cout << "PIRServer side plain modulus = " << plainMod << endl;
|
|
|
#endif
|
|
|
|
|
@@ -238,7 +238,7 @@ inline vector<Ciphertext> PIRServer::expand_query(const Ciphertext &encrypted, u
|
|
|
Plaintext two("2");
|
|
|
|
|
|
vector<int> galois_elts;
|
|
|
- auto n = expanded_params_.poly_modulus_degree();
|
|
|
+ auto n = params_.poly_modulus_degree();
|
|
|
|
|
|
for (uint32_t i = 0; i < logm; i++) {
|
|
|
galois_elts.push_back((n + exponentiate_uint64(2, i)) / exponentiate_uint64(2, i));
|
|
@@ -295,8 +295,8 @@ inline vector<Ciphertext> PIRServer::expand_query(const Ciphertext &encrypted, u
|
|
|
inline void PIRServer::multiply_power_of_X(const Ciphertext &encrypted, Ciphertext &destination,
|
|
|
uint32_t index) {
|
|
|
|
|
|
- auto coeff_mod_count = expanded_params_.coeff_modulus().size();
|
|
|
- auto coeff_count = expanded_params_.poly_modulus_degree();
|
|
|
+ auto coeff_mod_count = params_.coeff_modulus().size();
|
|
|
+ auto coeff_count = params_.poly_modulus_degree();
|
|
|
auto encrypted_count = encrypted.size();
|
|
|
|
|
|
// First copy over.
|
|
@@ -308,7 +308,7 @@ inline void PIRServer::multiply_power_of_X(const Ciphertext &encrypted, Cipherte
|
|
|
for (int j = 0; j < coeff_mod_count; j++) {
|
|
|
negacyclic_shift_poly_coeffmod(encrypted.data(i) + (j * coeff_count),
|
|
|
coeff_count - 1, index,
|
|
|
- expanded_params_.coeff_modulus()[j],
|
|
|
+ params_.coeff_modulus()[j],
|
|
|
destination.data(i) + (j * coeff_count));
|
|
|
}
|
|
|
}
|
|
@@ -317,12 +317,12 @@ inline void PIRServer::multiply_power_of_X(const Ciphertext &encrypted, Cipherte
|
|
|
inline void PIRServer::decompose_to_plaintexts_ptr(const Ciphertext &encrypted, Plaintext *plain_ptr) {
|
|
|
|
|
|
vector<Plaintext> result;
|
|
|
- auto coeff_count = expanded_params_.poly_modulus_degree();
|
|
|
- auto coeff_mod_count = expanded_params_.coeff_modulus().size();
|
|
|
+ auto coeff_count = params_.poly_modulus_degree();
|
|
|
+ auto coeff_mod_count = params_.coeff_modulus().size();
|
|
|
auto encrypted_count = encrypted.size();
|
|
|
|
|
|
// Generate powers of t.
|
|
|
- uint64_t plainModMinusOne = expanded_params_.plain_modulus().value() - 1;
|
|
|
+ uint64_t plainModMinusOne = params_.plain_modulus().value() - 1;
|
|
|
int exp = ceil(log2(plainModMinusOne + 1));
|
|
|
|
|
|
// A triple for loop. Going over polys, moduli, and decomposed index.
|
|
@@ -334,7 +334,7 @@ inline void PIRServer::decompose_to_plaintexts_ptr(const Ciphertext &encrypted,
|
|
|
// create a polynomial to store the current decomposition value
|
|
|
// which will be copied into the array to populate it at the current
|
|
|
// index.
|
|
|
- int logqj = log2(expanded_params_.coeff_modulus()[j].value());
|
|
|
+ int logqj = log2(params_.coeff_modulus()[j].value());
|
|
|
int expansion_ratio = ceil(logqj + exp - 1) / exp;
|
|
|
|
|
|
// cout << "expansion ratio = " << expansion_ratio << endl;
|
|
@@ -354,13 +354,13 @@ inline void PIRServer::decompose_to_plaintexts_ptr(const Ciphertext &encrypted,
|
|
|
|
|
|
vector<Plaintext> PIRServer::decompose_to_plaintexts(const Ciphertext &encrypted) {
|
|
|
vector<Plaintext> result;
|
|
|
- auto coeff_count = expanded_params_.poly_modulus_degree();
|
|
|
- auto coeff_mod_count = expanded_params_.coeff_modulus().size();
|
|
|
- auto plain_bit_count = expanded_params_.plain_modulus().bit_count();
|
|
|
+ auto coeff_count = params_.poly_modulus_degree();
|
|
|
+ auto coeff_mod_count = params_.coeff_modulus().size();
|
|
|
+ auto plain_bit_count = params_.plain_modulus().bit_count();
|
|
|
auto encrypted_count = encrypted.size();
|
|
|
|
|
|
// Generate powers of t.
|
|
|
- uint64_t plainMod = expanded_params_.plain_modulus().value();
|
|
|
+ uint64_t plainMod = params_.plain_modulus().value();
|
|
|
|
|
|
// A triple for loop. Going over polys, moduli, and decomposed index.
|
|
|
for (int i = 0; i < encrypted_count; i++) {
|
|
@@ -370,7 +370,7 @@ vector<Plaintext> PIRServer::decompose_to_plaintexts(const Ciphertext &encrypted
|
|
|
// create a polynomial to store the current decomposition value
|
|
|
// which will be copied into the array to populate it at the current
|
|
|
// index.
|
|
|
- int logqj = log2(expanded_params_.coeff_modulus()[j].value());
|
|
|
+ int logqj = log2(params_.coeff_modulus()[j].value());
|
|
|
int expansion_ratio = ceil(logqj / log2(plainMod));
|
|
|
|
|
|
// cout << "expansion ratio = " << expansion_ratio << endl;
|