|
@@ -10,8 +10,8 @@ PIRServer::PIRServer(const EncryptionParameters ¶ms, const PirParams &pir_pa
|
|
|
pir_params_(pir_params),
|
|
|
is_db_preprocessed_(false)
|
|
|
{
|
|
|
- auto context = SEALContext::Create(params, false);
|
|
|
- evaluator_ = make_unique<Evaluator>(context);
|
|
|
+ context_ = make_shared<SEALContext>(params, true);
|
|
|
+ evaluator_ = make_unique<Evaluator>(*context_);
|
|
|
}
|
|
|
|
|
|
void PIRServer::preprocess_database() {
|
|
@@ -19,7 +19,7 @@ void PIRServer::preprocess_database() {
|
|
|
|
|
|
for (uint32_t i = 0; i < db_->size(); i++) {
|
|
|
evaluator_->transform_to_ntt_inplace(
|
|
|
- db_->operator[](i), params_.parms_id());
|
|
|
+ db_->operator[](i), context_->first_parms_id());
|
|
|
}
|
|
|
|
|
|
is_db_preprocessed_ = true;
|
|
@@ -42,6 +42,9 @@ void PIRServer::set_database(const std::unique_ptr<const std::uint8_t[]> &bytes,
|
|
|
uint32_t logt = floor(log2(params_.plain_modulus().value()));
|
|
|
uint32_t N = params_.poly_modulus_degree();
|
|
|
|
|
|
+ cout << "logt: " << logt << endl << "N: " << N << endl <<
|
|
|
+ "ele_num: " << ele_num << endl << "ele_size: " << ele_size << endl;
|
|
|
+
|
|
|
// number of FV plaintexts needed to represent all elements
|
|
|
uint64_t total = plaintexts_per_db(logt, N, ele_num, ele_size);
|
|
|
|
|
@@ -51,6 +54,9 @@ void PIRServer::set_database(const std::unique_ptr<const std::uint8_t[]> &bytes,
|
|
|
prod *= pir_params_.nvec[i];
|
|
|
}
|
|
|
uint64_t matrix_plaintexts = prod;
|
|
|
+ cout << "Total:" << total << endl << "Prod: "
|
|
|
+ << matrix_plaintexts << endl;
|
|
|
+
|
|
|
assert(total <= matrix_plaintexts);
|
|
|
|
|
|
auto result = make_unique<vector<Plaintext>>();
|
|
@@ -123,11 +129,10 @@ void PIRServer::set_database(const std::unique_ptr<const std::uint8_t[]> &bytes,
|
|
|
}
|
|
|
|
|
|
void PIRServer::set_galois_key(std::uint32_t client_id, seal::GaloisKeys galkey) {
|
|
|
- galkey.parms_id() = params_.parms_id();
|
|
|
galoisKeys_[client_id] = galkey;
|
|
|
}
|
|
|
|
|
|
-PirReply PIRServer::generate_reply(PirQuery query, uint32_t client_id) {
|
|
|
+PirReply PIRServer::generate_reply(PirQuery query, uint32_t client_id, const PIRClient& client) {
|
|
|
|
|
|
vector<uint64_t> nvec = pir_params_.nvec;
|
|
|
uint64_t product = 1;
|
|
@@ -161,10 +166,10 @@ PirReply PIRServer::generate_reply(PirQuery query, uint32_t client_id) {
|
|
|
for (uint32_t j = 0; j < query[i].size(); j++){
|
|
|
uint64_t total = N;
|
|
|
if (j == query[i].size() - 1){
|
|
|
- total = n_i % N;
|
|
|
+ total = ((n_i - 1) % N) + 1;
|
|
|
}
|
|
|
cout << "-- expanding one query ctxt into " << total << " ctxts "<< endl;
|
|
|
- vector<Ciphertext> expanded_query_part = expand_query(query[i][j], total, client_id);
|
|
|
+ vector<Ciphertext> expanded_query_part = expand_query(query[i][j], total, client_id, client);
|
|
|
expanded_query.insert(expanded_query.end(), std::make_move_iterator(expanded_query_part.begin()),
|
|
|
std::make_move_iterator(expanded_query_part.end()));
|
|
|
expanded_query_part.clear();
|
|
@@ -174,16 +179,20 @@ PirReply PIRServer::generate_reply(PirQuery query, uint32_t client_id) {
|
|
|
cout << " size mismatch!!! " << expanded_query.size() << ", " << n_i << endl;
|
|
|
}
|
|
|
|
|
|
- /*
|
|
|
+
|
|
|
cout << "Checking expanded query " << endl;
|
|
|
Plaintext tempPt;
|
|
|
for (int h = 0 ; h < expanded_query.size(); h++){
|
|
|
- cout << "noise budget = " << client.decryptor_->invariant_noise_budget(expanded_query[h]) << ", ";
|
|
|
client.decryptor_->decrypt(expanded_query[h], tempPt);
|
|
|
+ if(tempPt.is_zero()){
|
|
|
+ continue;
|
|
|
+ }
|
|
|
+ cout << "index: " << h << ", ";
|
|
|
+ cout << "noise budget = " << client.decryptor_->invariant_noise_budget(expanded_query[h]) << ", ";
|
|
|
cout << tempPt.to_string() << endl;
|
|
|
}
|
|
|
cout << endl;
|
|
|
- */
|
|
|
+
|
|
|
|
|
|
// Transform expanded query to NTT, and ...
|
|
|
for (uint32_t jj = 0; jj < expanded_query.size(); jj++) {
|
|
@@ -193,7 +202,7 @@ PirReply PIRServer::generate_reply(PirQuery query, uint32_t client_id) {
|
|
|
// Transform plaintext to NTT. If database is pre-processed, can skip
|
|
|
if ((!is_db_preprocessed_) || i > 0) {
|
|
|
for (uint32_t jj = 0; jj < cur->size(); jj++) {
|
|
|
- evaluator_->transform_to_ntt_inplace((*cur)[jj], params_.parms_id());
|
|
|
+ evaluator_->transform_to_ntt_inplace((*cur)[jj], context_->first_parms_id());
|
|
|
}
|
|
|
}
|
|
|
|
|
@@ -257,8 +266,19 @@ PirReply PIRServer::generate_reply(PirQuery query, uint32_t client_id) {
|
|
|
return fail;
|
|
|
}
|
|
|
|
|
|
+Ciphertext PIRServer::generate_public_reply(Ciphertext one_ct, std::uint64_t desiredIndex){
|
|
|
+ vector<Plaintext> *cur = db_.get();
|
|
|
+ Ciphertext result;
|
|
|
+ evaluator_->transform_to_ntt_inplace(one_ct);
|
|
|
+ cout << "transformed" << endl;
|
|
|
+ evaluator_->multiply_plain(one_ct, (*cur)[desiredIndex], result);
|
|
|
+ cout << "reply generated" << endl;
|
|
|
+ evaluator_->transform_from_ntt_inplace(result);
|
|
|
+ return result;
|
|
|
+}
|
|
|
+
|
|
|
inline vector<Ciphertext> PIRServer::expand_query(const Ciphertext &encrypted, uint32_t m,
|
|
|
- uint32_t client_id) {
|
|
|
+ uint32_t client_id, const PIRClient& client) {
|
|
|
|
|
|
#ifdef DEBUG
|
|
|
uint64_t plainMod = params_.plain_modulus().value();
|
|
@@ -277,7 +297,7 @@ inline vector<Ciphertext> PIRServer::expand_query(const Ciphertext &encrypted, u
|
|
|
throw logic_error("m > n is not allowed.");
|
|
|
}
|
|
|
for (int i = 0; i < ceil(log2(n)); i++) {
|
|
|
- galois_elts.push_back((n + exponentiate_uint64(2, i)) / exponentiate_uint64(2, i));
|
|
|
+ galois_elts.push_back((n + exponentiate_uint(2, i)) / exponentiate_uint(2, i));
|
|
|
}
|
|
|
|
|
|
vector<Ciphertext> temp;
|
|
@@ -344,13 +364,19 @@ inline vector<Ciphertext> PIRServer::expand_query(const Ciphertext &encrypted, u
|
|
|
vector<Ciphertext>::const_iterator first = newtemp.begin();
|
|
|
vector<Ciphertext>::const_iterator last = newtemp.begin() + m;
|
|
|
vector<Ciphertext> newVec(first, last);
|
|
|
+
|
|
|
+ for(Ciphertext c: newVec){
|
|
|
+
|
|
|
+ }
|
|
|
+
|
|
|
+
|
|
|
return newVec;
|
|
|
}
|
|
|
|
|
|
inline void PIRServer::multiply_power_of_X(const Ciphertext &encrypted, Ciphertext &destination,
|
|
|
uint32_t index) {
|
|
|
|
|
|
- auto coeff_mod_count = params_.coeff_modulus().size();
|
|
|
+ auto coeff_mod_count = params_.coeff_modulus().size() - 1;
|
|
|
auto coeff_count = params_.poly_modulus_degree();
|
|
|
auto encrypted_count = encrypted.size();
|
|
|
|