|
@@ -4,34 +4,31 @@ using namespace std;
|
|
|
using namespace seal;
|
|
|
using namespace seal::util;
|
|
|
|
|
|
-PIRServer::PIRServer(const EncryptionParameters &expanded_params, const PirParams &pir_params) {
|
|
|
- expanded_params_ = expanded_params;
|
|
|
- pir_params_ = pir_params;
|
|
|
- SEALContext context(expanded_params);
|
|
|
- evaluator_.reset(new Evaluator(context));
|
|
|
- is_db_preprocessed_ = false;
|
|
|
-}
|
|
|
-
|
|
|
-PIRServer::~PIRServer() {
|
|
|
- delete db_;
|
|
|
+PIRServer::PIRServer(const EncryptionParameters &expanded_params, const PirParams &pir_params) :
|
|
|
+ expanded_params_(expanded_params),
|
|
|
+ pir_params_(pir_params),
|
|
|
+ is_db_preprocessed_(false)
|
|
|
+{
|
|
|
+ auto context = SEALContext::Create(expanded_params, false);
|
|
|
+ evaluator_ = make_unique<Evaluator>(context);
|
|
|
}
|
|
|
|
|
|
void PIRServer::update_parameters(const EncryptionParameters &expanded_params,
|
|
|
const PirParams &pir_params) {
|
|
|
|
|
|
// The only thing that can change is the plaintext modulus and pir_params
|
|
|
- assert(expanded_params.poly_modulus() == expanded_params_.poly_modulus());
|
|
|
+ assert(expanded_params.poly_modulus_degree() == expanded_params_.poly_modulus_degree());
|
|
|
assert(expanded_params.coeff_modulus() == expanded_params_.coeff_modulus());
|
|
|
|
|
|
expanded_params_ = expanded_params;
|
|
|
pir_params_ = pir_params;
|
|
|
- SEALContext context(expanded_params);
|
|
|
- evaluator_.reset(new Evaluator(context));
|
|
|
+ auto context = SEALContext::Create(expanded_params);
|
|
|
+ evaluator_ = make_unique<Evaluator>(context);
|
|
|
is_db_preprocessed_ = false;
|
|
|
|
|
|
// Update all the galois keys
|
|
|
for (std::pair<const int, GaloisKeys> &key : galoisKeys_) {
|
|
|
- key.second.hash_block() = expanded_params_.hash_block();
|
|
|
+ key.second.parms_id() = expanded_params_.parms_id();
|
|
|
}
|
|
|
}
|
|
|
|
|
@@ -39,7 +36,8 @@ void PIRServer::preprocess_database() {
|
|
|
if (!is_db_preprocessed_) {
|
|
|
|
|
|
for (uint32_t i = 0; i < db_->size(); i++) {
|
|
|
- evaluator_->transform_to_ntt(db_->operator[](i));
|
|
|
+ evaluator_->transform_to_ntt_inplace(
|
|
|
+ db_->operator[](i), expanded_params_.parms_id());
|
|
|
}
|
|
|
|
|
|
is_db_preprocessed_ = true;
|
|
@@ -47,19 +45,20 @@ void PIRServer::preprocess_database() {
|
|
|
}
|
|
|
|
|
|
// Server takes over ownership of db and will free it when it exits
|
|
|
-void PIRServer::set_database(vector<Plaintext> *db) {
|
|
|
- if (db == nullptr) {
|
|
|
+void PIRServer::set_database(unique_ptr<vector<Plaintext>> &&db) {
|
|
|
+ if (!db) {
|
|
|
throw invalid_argument("db cannot be null");
|
|
|
}
|
|
|
|
|
|
- db_ = db;
|
|
|
+ db_ = move(db);
|
|
|
is_db_preprocessed_ = false;
|
|
|
}
|
|
|
|
|
|
-void PIRServer::set_database(const uint8_t *bytes, uint64_t ele_num, uint64_t ele_size) {
|
|
|
+void PIRServer::set_database(const std::unique_ptr<const std::uint8_t[]> &bytes,
|
|
|
+ uint64_t ele_num, uint64_t ele_size) {
|
|
|
|
|
|
uint32_t logtp = ceil(log2(expanded_params_.plain_modulus().value()));
|
|
|
- uint32_t N = expanded_params_.poly_modulus().coeff_count() - 1;
|
|
|
+ uint32_t N = expanded_params_.poly_modulus_degree();
|
|
|
|
|
|
// number of FV plaintexts needed to represent all elements
|
|
|
uint64_t total = plaintexts_per_db(logtp, N, ele_num, ele_size);
|
|
@@ -72,7 +71,7 @@ void PIRServer::set_database(const uint8_t *bytes, uint64_t ele_num, uint64_t el
|
|
|
uint64_t matrix_plaintexts = prod;
|
|
|
assert(total <= matrix_plaintexts);
|
|
|
|
|
|
- vector<Plaintext> *result = new vector<Plaintext>();
|
|
|
+ auto result = make_unique<vector<Plaintext>>();
|
|
|
result->reserve(matrix_plaintexts);
|
|
|
|
|
|
uint64_t ele_per_ptxt = elements_per_ptxt(logtp, N, ele_size);
|
|
@@ -98,7 +97,7 @@ void PIRServer::set_database(const uint8_t *bytes, uint64_t ele_num, uint64_t el
|
|
|
}
|
|
|
|
|
|
// Get the coefficients of the elements that will be packed in plaintext i
|
|
|
- vector<uint64_t> coefficients = bytes_to_coeffs(logtp, bytes + offset, process_bytes);
|
|
|
+ vector<uint64_t> coefficients = bytes_to_coeffs(logtp, bytes.get() + offset, process_bytes);
|
|
|
offset += process_bytes;
|
|
|
|
|
|
uint64_t used = coefficients.size();
|
|
@@ -112,7 +111,7 @@ void PIRServer::set_database(const uint8_t *bytes, uint64_t ele_num, uint64_t el
|
|
|
|
|
|
Plaintext plain;
|
|
|
vector_to_plaintext(coefficients, plain);
|
|
|
- result->push_back(plain);
|
|
|
+ result->push_back(move(plain));
|
|
|
}
|
|
|
|
|
|
// Add padding to make database a matrix
|
|
@@ -134,11 +133,11 @@ void PIRServer::set_database(const uint8_t *bytes, uint64_t ele_num, uint64_t el
|
|
|
result->push_back(plain);
|
|
|
}
|
|
|
|
|
|
- set_database(result);
|
|
|
+ set_database(move(result));
|
|
|
}
|
|
|
|
|
|
void PIRServer::set_galois_key(std::uint32_t client_id, seal::GaloisKeys galkey) {
|
|
|
- galkey.hash_block() = expanded_params_.hash_block();
|
|
|
+ galkey.parms_id() = expanded_params_.parms_id();
|
|
|
galoisKeys_[client_id] = galkey;
|
|
|
}
|
|
|
|
|
@@ -151,12 +150,12 @@ PirReply PIRServer::generate_reply(PirQuery query, uint32_t client_id) {
|
|
|
product *= nvec[i];
|
|
|
}
|
|
|
|
|
|
- int coeff_count = expanded_params_.poly_modulus().coeff_count();
|
|
|
+ auto coeff_count = expanded_params_.poly_modulus_degree();
|
|
|
|
|
|
- vector<Plaintext> *cur = db_;
|
|
|
+ vector<Plaintext> *cur = db_.get();
|
|
|
vector<Plaintext> intermediate_plain; // decompose....
|
|
|
|
|
|
- auto my_pool = MemoryPoolHandle::New();
|
|
|
+ auto pool = MemoryManager::GetPool();
|
|
|
|
|
|
for (uint32_t i = 0; i < nvec.size(); i++) {
|
|
|
uint64_t n_i = nvec[i];
|
|
@@ -164,13 +163,13 @@ PirReply PIRServer::generate_reply(PirQuery query, uint32_t client_id) {
|
|
|
|
|
|
// Transform expanded query to NTT, and ...
|
|
|
for (uint32_t jj = 0; jj < expanded_query.size(); jj++) {
|
|
|
- evaluator_->transform_to_ntt(expanded_query[jj]);
|
|
|
+ evaluator_->transform_to_ntt_inplace(expanded_query[jj]);
|
|
|
}
|
|
|
|
|
|
// Transform plaintext to NTT. If database is pre-processed, can skip
|
|
|
if ((!is_db_preprocessed_) || i > 0) {
|
|
|
for (uint32_t jj = 0; jj < cur->size(); jj++) {
|
|
|
- evaluator_->transform_to_ntt((*cur)[jj]);
|
|
|
+ evaluator_->transform_to_ntt_inplace((*cur)[jj], expanded_params_.parms_id());
|
|
|
}
|
|
|
}
|
|
|
|
|
@@ -180,17 +179,16 @@ PirReply PIRServer::generate_reply(PirQuery query, uint32_t client_id) {
|
|
|
Ciphertext temp;
|
|
|
|
|
|
for (uint64_t k = 0; k < product; k++) {
|
|
|
- evaluator_->multiply_plain_ntt(expanded_query[0], (*cur)[k], intermediate[k]);
|
|
|
+ evaluator_->multiply_plain(expanded_query[0], (*cur)[k], intermediate[k]);
|
|
|
|
|
|
for (uint64_t j = 1; j < n_i; j++) {
|
|
|
- evaluator_->multiply_plain_ntt(expanded_query[j], (*cur)[k + j * product], temp);
|
|
|
- evaluator_->add(intermediate[k],
|
|
|
- temp); // Adds to first component.
|
|
|
+ evaluator_->multiply_plain(expanded_query[j], (*cur)[k + j * product], temp);
|
|
|
+ evaluator_->add_inplace(intermediate[k], temp); // Adds to first component.
|
|
|
}
|
|
|
}
|
|
|
|
|
|
for (uint32_t jj = 0; jj < intermediate.size(); jj++) {
|
|
|
- evaluator_->transform_from_ntt(intermediate[jj]);
|
|
|
+ evaluator_->transform_from_ntt_inplace(intermediate[jj]);
|
|
|
}
|
|
|
|
|
|
if (i == nvec.size() - 1) {
|
|
@@ -200,18 +198,18 @@ PirReply PIRServer::generate_reply(PirQuery query, uint32_t client_id) {
|
|
|
intermediate_plain.reserve(pir_params_.expansion_ratio * product);
|
|
|
cur = &intermediate_plain;
|
|
|
|
|
|
- util::Pointer tempplain_ptr(allocate_zero_poly(
|
|
|
- pir_params_.expansion_ratio * product, coeff_count, my_pool));
|
|
|
+ auto tempplain = util::allocate<Plaintext>(
|
|
|
+ pir_params_.expansion_ratio * product,
|
|
|
+ pool, coeff_count);
|
|
|
|
|
|
for (uint64_t rr = 0; rr < product; rr++) {
|
|
|
|
|
|
decompose_to_plaintexts_ptr(intermediate[rr],
|
|
|
- tempplain_ptr.get() +
|
|
|
- rr * pir_params_.expansion_ratio * coeff_count);
|
|
|
+ tempplain.get() + rr * pir_params_.expansion_ratio);
|
|
|
|
|
|
for (uint32_t jj = 0; jj < pir_params_.expansion_ratio; jj++) {
|
|
|
- int offset = rr * pir_params_.expansion_ratio * coeff_count + jj * coeff_count;
|
|
|
- intermediate_plain.emplace_back(coeff_count, tempplain_ptr.get() + offset);
|
|
|
+ auto offset = rr * pir_params_.expansion_ratio + jj;
|
|
|
+ intermediate_plain.emplace_back(tempplain[offset]);
|
|
|
}
|
|
|
}
|
|
|
|
|
@@ -240,7 +238,7 @@ inline vector<Ciphertext> PIRServer::expand_query(const Ciphertext &encrypted, u
|
|
|
Plaintext two("2");
|
|
|
|
|
|
vector<int> galois_elts;
|
|
|
- int n = expanded_params_.poly_modulus().coeff_count() - 1;
|
|
|
+ auto n = expanded_params_.poly_modulus_degree();
|
|
|
|
|
|
for (uint32_t i = 0; i < logm; i++) {
|
|
|
galois_elts.push_back((n + exponentiate_uint64(2, i)) / exponentiate_uint64(2, i));
|
|
@@ -297,9 +295,9 @@ inline vector<Ciphertext> PIRServer::expand_query(const Ciphertext &encrypted, u
|
|
|
inline void PIRServer::multiply_power_of_X(const Ciphertext &encrypted, Ciphertext &destination,
|
|
|
uint32_t index) {
|
|
|
|
|
|
- int coeff_mod_count = expanded_params_.coeff_modulus().size();
|
|
|
- int coeff_count = expanded_params_.poly_modulus().coeff_count();
|
|
|
- int encrypted_count = encrypted.size();
|
|
|
+ auto coeff_mod_count = expanded_params_.coeff_modulus().size();
|
|
|
+ auto coeff_count = expanded_params_.poly_modulus_degree();
|
|
|
+ auto encrypted_count = encrypted.size();
|
|
|
|
|
|
// First copy over.
|
|
|
destination = encrypted;
|
|
@@ -316,12 +314,12 @@ inline void PIRServer::multiply_power_of_X(const Ciphertext &encrypted, Cipherte
|
|
|
}
|
|
|
}
|
|
|
|
|
|
-inline void PIRServer::decompose_to_plaintexts_ptr(const Ciphertext &encrypted, uint64_t *plain_ptr) {
|
|
|
+inline void PIRServer::decompose_to_plaintexts_ptr(const Ciphertext &encrypted, Plaintext *plain_ptr) {
|
|
|
|
|
|
vector<Plaintext> result;
|
|
|
- int coeff_count = expanded_params_.poly_modulus().coeff_count();
|
|
|
- int coeff_mod_count = expanded_params_.coeff_modulus().size();
|
|
|
- int encrypted_count = encrypted.size();
|
|
|
+ auto coeff_count = expanded_params_.poly_modulus_degree();
|
|
|
+ auto coeff_mod_count = expanded_params_.coeff_modulus().size();
|
|
|
+ auto encrypted_count = encrypted.size();
|
|
|
|
|
|
// Generate powers of t.
|
|
|
uint64_t plainModMinusOne = expanded_params_.plain_modulus().value() - 1;
|
|
@@ -344,9 +342,9 @@ inline void PIRServer::decompose_to_plaintexts_ptr(const Ciphertext &encrypted,
|
|
|
for (int k = 0; k < expansion_ratio; k++) {
|
|
|
// Decompose here
|
|
|
for (int m = 0; m < coeff_count; m++) {
|
|
|
- *plain_ptr =
|
|
|
+ plain_ptr[i * coeff_mod_count * expansion_ratio
|
|
|
+ + j * expansion_ratio + k][m] =
|
|
|
(*(encrypted_pointer + m + (j * coeff_count)) >> curexp) & plainModMinusOne;
|
|
|
- plain_ptr++;
|
|
|
}
|
|
|
curexp += exp;
|
|
|
}
|
|
@@ -356,10 +354,10 @@ inline void PIRServer::decompose_to_plaintexts_ptr(const Ciphertext &encrypted,
|
|
|
|
|
|
vector<Plaintext> PIRServer::decompose_to_plaintexts(const Ciphertext &encrypted) {
|
|
|
vector<Plaintext> result;
|
|
|
- int coeff_count = expanded_params_.poly_modulus().coeff_count();
|
|
|
- int coeff_mod_count = expanded_params_.coeff_modulus().size();
|
|
|
- int plain_bit_count = expanded_params_.plain_modulus().bit_count();
|
|
|
- int encrypted_count = encrypted.size();
|
|
|
+ auto coeff_count = expanded_params_.poly_modulus_degree();
|
|
|
+ auto coeff_mod_count = expanded_params_.coeff_modulus().size();
|
|
|
+ auto plain_bit_count = expanded_params_.plain_modulus().bit_count();
|
|
|
+ auto encrypted_count = encrypted.size();
|
|
|
|
|
|
// Generate powers of t.
|
|
|
uint64_t plainMod = expanded_params_.plain_modulus().value();
|
|
@@ -379,16 +377,14 @@ vector<Plaintext> PIRServer::decompose_to_plaintexts(const Ciphertext &encrypted
|
|
|
uint64_t cur = 1;
|
|
|
for (int k = 0; k < expansion_ratio; k++) {
|
|
|
// Decompose here
|
|
|
- BigPoly temp;
|
|
|
- temp.resize(coeff_count, plain_bit_count);
|
|
|
- temp.set_zero();
|
|
|
- uint64_t *plain_coeff = temp.data();
|
|
|
- for (int m = 0; m < coeff_count; m++) {
|
|
|
- *(plain_coeff + m) =
|
|
|
- (*(encrypted_pointer + m + (j * coeff_count)) / cur) % plainMod;
|
|
|
- }
|
|
|
-
|
|
|
- result.push_back(Plaintext(temp));
|
|
|
+ Plaintext temp(coeff_count);
|
|
|
+ transform(encrypted_pointer + (j * coeff_count),
|
|
|
+ encrypted_pointer + ((j + 1) * coeff_count),
|
|
|
+ temp.data(),
|
|
|
+ [cur, &plainMod](auto &in) { return (in / cur) % plainMod; }
|
|
|
+ );
|
|
|
+
|
|
|
+ result.emplace_back(move(temp));
|
|
|
cur *= plainMod;
|
|
|
}
|
|
|
}
|