|
@@ -83,6 +83,12 @@ void PIRServer::set_database(const std::unique_ptr<const std::uint8_t[]> &bytes,
|
|
|
uint64_t coeff_per_ptxt = ele_per_ptxt * coefficients_per_element(logt, ele_size);
|
|
|
assert(coeff_per_ptxt <= N);
|
|
|
|
|
|
+ cout << "Server: total number of FV plaintext = " << total << endl;
|
|
|
+
|
|
|
+ cout << "Server: elements packed into each plaintext " << ele_per_ptxt << endl;
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
uint32_t offset = 0;
|
|
|
|
|
|
for (uint64_t i = 0; i < total; i++) {
|
|
@@ -96,7 +102,6 @@ void PIRServer::set_database(const std::unique_ptr<const std::uint8_t[]> &bytes,
|
|
|
} else {
|
|
|
process_bytes = bytes_per_ptxt;
|
|
|
}
|
|
|
-
|
|
|
// Get the coefficients of the elements that will be packed in plaintext i
|
|
|
vector<uint64_t> coefficients = bytes_to_coeffs(logt, bytes.get() + offset, process_bytes);
|
|
|
offset += process_bytes;
|
|
@@ -112,7 +117,7 @@ void PIRServer::set_database(const std::unique_ptr<const std::uint8_t[]> &bytes,
|
|
|
|
|
|
Plaintext plain;
|
|
|
vector_to_plaintext(coefficients, plain);
|
|
|
- cout << i << "-th encoded plaintext = " << plain.to_string() << endl;
|
|
|
+ // cout << i << "-th encoded plaintext = " << plain.to_string() << endl;
|
|
|
result->push_back(move(plain));
|
|
|
}
|
|
|
|
|
@@ -159,10 +164,39 @@ PirReply PIRServer::generate_reply(PirQuery query, uint32_t client_id, PIRClient
|
|
|
|
|
|
auto pool = MemoryManager::GetPool();
|
|
|
|
|
|
+
|
|
|
+ int N = params_.poly_modulus_degree();
|
|
|
+
|
|
|
+ int logt = floor(log2(params_.plain_modulus().value()));
|
|
|
+
|
|
|
+ cout << "expansion ratio = " << pir_params_.expansion_ratio << endl;
|
|
|
for (uint32_t i = 0; i < nvec.size(); i++) {
|
|
|
+ cout << "Server: " << i + 1 << "-th recursion level started " << endl;
|
|
|
+
|
|
|
+
|
|
|
+ vector<Ciphertext> expanded_query;
|
|
|
+
|
|
|
uint64_t n_i = nvec[i];
|
|
|
- vector<Ciphertext> expanded_query = expand_query(query[i], n_i, client_id, client);
|
|
|
- cout << "Checking expanded query ";
|
|
|
+ cout << "Server: n_i = " << n_i << endl;
|
|
|
+ cout << "Server: expanding " << query[i].size() << "query ctxts" << endl;
|
|
|
+ for (uint32_t j = 0; j < query[i].size(); j++){
|
|
|
+ uint64_t total = N;
|
|
|
+ if (j == query[i].size() - 1){
|
|
|
+ total = n_i % N;
|
|
|
+ }
|
|
|
+ cout << "-- expanding one query ctxt into " << total << " ctxts "<< endl;
|
|
|
+ vector<Ciphertext> expanded_query_part = expand_query(query[i][j], total, client_id, client);
|
|
|
+ expanded_query.insert(expanded_query.end(), std::make_move_iterator(expanded_query_part.begin()),
|
|
|
+ std::make_move_iterator(expanded_query_part.end()));
|
|
|
+ expanded_query_part.clear();
|
|
|
+ }
|
|
|
+ cout << "Server: expansion done " << endl;
|
|
|
+ if (expanded_query.size() != n_i) {
|
|
|
+ cout << " size mismatch!!! " << expanded_query.size() << ", " << n_i << endl;
|
|
|
+ }
|
|
|
+
|
|
|
+ /*
|
|
|
+ cout << "Checking expanded query " << endl;
|
|
|
Plaintext tempPt;
|
|
|
for (int h = 0 ; h < expanded_query.size(); h++){
|
|
|
cout << "noise budget = " << client.decryptor_->invariant_noise_budget(expanded_query[h]) << ", ";
|
|
@@ -170,6 +204,7 @@ PirReply PIRServer::generate_reply(PirQuery query, uint32_t client_id, PIRClient
|
|
|
cout << tempPt.to_string() << endl;
|
|
|
}
|
|
|
cout << endl;
|
|
|
+ */
|
|
|
|
|
|
// Transform expanded query to NTT, and ...
|
|
|
for (uint32_t jj = 0; jj < expanded_query.size(); jj++) {
|
|
@@ -183,26 +218,40 @@ PirReply PIRServer::generate_reply(PirQuery query, uint32_t client_id, PIRClient
|
|
|
}
|
|
|
}
|
|
|
|
|
|
+ for (uint64_t k = 0; k < product; k++) {
|
|
|
+ if ((*cur)[k].is_zero()){
|
|
|
+ cout << k + 1 << "/ " << product << "-th ptxt = 0 " << endl;
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
product /= n_i;
|
|
|
|
|
|
- vector<Ciphertext> intermediate(product);
|
|
|
+ vector<Ciphertext> intermediateCtxts(product);
|
|
|
Ciphertext temp;
|
|
|
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
for (uint64_t k = 0; k < product; k++) {
|
|
|
- evaluator_->multiply_plain(expanded_query[0], (*cur)[k], intermediate[k]);
|
|
|
+
|
|
|
+ evaluator_->multiply_plain(expanded_query[0], (*cur)[k], intermediateCtxts[k]);
|
|
|
|
|
|
for (uint64_t j = 1; j < n_i; j++) {
|
|
|
evaluator_->multiply_plain(expanded_query[j], (*cur)[k + j * product], temp);
|
|
|
- evaluator_->add_inplace(intermediate[k], temp); // Adds to first component.
|
|
|
+ evaluator_->add_inplace(intermediateCtxts[k], temp); // Adds to first component.
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- for (uint32_t jj = 0; jj < intermediate.size(); jj++) {
|
|
|
- evaluator_->transform_from_ntt_inplace(intermediate[jj]);
|
|
|
+ for (uint32_t jj = 0; jj < intermediateCtxts.size(); jj++) {
|
|
|
+ evaluator_->transform_from_ntt_inplace(intermediateCtxts[jj]);
|
|
|
+ // print intermediate ctxts?
|
|
|
+ //cout << "const term of ctxt " << jj << " = " << intermediateCtxts[jj][0] << endl;
|
|
|
}
|
|
|
|
|
|
+
|
|
|
+
|
|
|
if (i == nvec.size() - 1) {
|
|
|
- return intermediate;
|
|
|
+ return intermediateCtxts;
|
|
|
} else {
|
|
|
intermediate_plain.clear();
|
|
|
intermediate_plain.reserve(pir_params_.expansion_ratio * product);
|
|
@@ -214,19 +263,20 @@ PirReply PIRServer::generate_reply(PirQuery query, uint32_t client_id, PIRClient
|
|
|
|
|
|
for (uint64_t rr = 0; rr < product; rr++) {
|
|
|
|
|
|
- decompose_to_plaintexts_ptr(intermediate[rr],
|
|
|
- tempplain.get() + rr * pir_params_.expansion_ratio);
|
|
|
+ decompose_to_plaintexts_ptr(intermediateCtxts[rr],
|
|
|
+ tempplain.get() + rr * pir_params_.expansion_ratio, logt);
|
|
|
|
|
|
for (uint32_t jj = 0; jj < pir_params_.expansion_ratio; jj++) {
|
|
|
auto offset = rr * pir_params_.expansion_ratio + jj;
|
|
|
intermediate_plain.emplace_back(tempplain[offset]);
|
|
|
}
|
|
|
}
|
|
|
-
|
|
|
product *= pir_params_.expansion_ratio; // multiply by expansion rate.
|
|
|
}
|
|
|
+ cout << "Server: " << i + 1 << "-th recursion level finished " << endl;
|
|
|
+ cout << endl;
|
|
|
}
|
|
|
-
|
|
|
+ cout << "reply generated! " << endl;
|
|
|
// This should never get here
|
|
|
assert(0);
|
|
|
vector<Ciphertext> fail(1);
|
|
@@ -252,13 +302,9 @@ inline vector<Ciphertext> PIRServer::expand_query(const Ciphertext &encrypted, u
|
|
|
if (logm > ceil(log2(n))){
|
|
|
throw logic_error("m > n is not allowed.");
|
|
|
}
|
|
|
-
|
|
|
- cout << "galois elts at server: ";
|
|
|
- for (uint32_t i = 0; i < logm; i++) {
|
|
|
+ for (int i = 0; i < ceil(log2(n)); i++) {
|
|
|
galois_elts.push_back((n + exponentiate_uint64(2, i)) / exponentiate_uint64(2, i));
|
|
|
- cout << galois_elts.back() << ", ";
|
|
|
}
|
|
|
- cout << endl;
|
|
|
|
|
|
vector<Ciphertext> temp;
|
|
|
temp.push_back(encrypted);
|
|
@@ -267,44 +313,45 @@ inline vector<Ciphertext> PIRServer::expand_query(const Ciphertext &encrypted, u
|
|
|
Ciphertext tempctxt_shifted;
|
|
|
Ciphertext tempctxt_rotatedshifted;
|
|
|
|
|
|
+
|
|
|
for (uint32_t i = 0; i < logm - 1; i++) {
|
|
|
vector<Ciphertext> newtemp(temp.size() << 1);
|
|
|
// temp[a] = (j0 = a (mod 2**i) ? ) : Enc(x^{j0 - a}) else Enc(0). With
|
|
|
// some scaling....
|
|
|
int index_raw = (n << 1) - (1 << i);
|
|
|
+ // TODO: galois elements.
|
|
|
int index = (index_raw * galois_elts[i]) % (n << 1);
|
|
|
- cout << i << "-th expansion round, noise budget = " << endl;
|
|
|
|
|
|
for (uint32_t a = 0; a < temp.size(); a++) {
|
|
|
|
|
|
evaluator_->apply_galois(temp[a], galois_elts[i], galkey, tempctxt_rotated);
|
|
|
|
|
|
- cout << "rotate " << client.decryptor_->invariant_noise_budget(tempctxt_rotated) << ", ";
|
|
|
+ //cout << "rotate " << client.decryptor_->invariant_noise_budget(tempctxt_rotated) << ", ";
|
|
|
|
|
|
|
|
|
evaluator_->add(temp[a], tempctxt_rotated, newtemp[a]);
|
|
|
multiply_power_of_X(temp[a], tempctxt_shifted, index_raw);
|
|
|
|
|
|
- cout << "mul by x^pow: " << client.decryptor_->invariant_noise_budget(tempctxt_shifted) << ", ";
|
|
|
+ //cout << "mul by x^pow: " << client.decryptor_->invariant_noise_budget(tempctxt_shifted) << ", ";
|
|
|
|
|
|
|
|
|
multiply_power_of_X(tempctxt_rotated, tempctxt_rotatedshifted, index);
|
|
|
|
|
|
- cout << "mul by x^pow: " << client.decryptor_->invariant_noise_budget(tempctxt_rotatedshifted) << ", ";
|
|
|
+ // cout << "mul by x^pow: " << client.decryptor_->invariant_noise_budget(tempctxt_rotatedshifted) << ", ";
|
|
|
|
|
|
|
|
|
// Enc(2^i x^j) if j = 0 (mod 2**i).
|
|
|
evaluator_->add(tempctxt_shifted, tempctxt_rotatedshifted, newtemp[a + temp.size()]);
|
|
|
}
|
|
|
temp = newtemp;
|
|
|
-
|
|
|
+ /*
|
|
|
cout << "end: ";
|
|
|
for (int h = 0; h < temp.size();h++){
|
|
|
cout << client.decryptor_->invariant_noise_budget(temp[h]) << ", ";
|
|
|
}
|
|
|
cout << endl;
|
|
|
+ */
|
|
|
}
|
|
|
-
|
|
|
// Last step of the loop
|
|
|
vector<Ciphertext> newtemp(temp.size() << 1);
|
|
|
int index_raw = (n << 1) - (1 << (logm - 1));
|
|
@@ -312,7 +359,7 @@ inline vector<Ciphertext> PIRServer::expand_query(const Ciphertext &encrypted, u
|
|
|
for (uint32_t a = 0; a < temp.size(); a++) {
|
|
|
if (a >= (m - (1 << (logm - 1)))) { // corner case.
|
|
|
evaluator_->multiply_plain(temp[a], two, newtemp[a]); // plain multiplication by 2.
|
|
|
- cout << client.decryptor_->invariant_noise_budget(newtemp[a]) << ", ";
|
|
|
+ // cout << client.decryptor_->invariant_noise_budget(newtemp[a]) << ", ";
|
|
|
} else {
|
|
|
evaluator_->apply_galois(temp[a], galois_elts[logm - 1], galkey, tempctxt_rotated);
|
|
|
evaluator_->add(temp[a], tempctxt_rotated, newtemp[a]);
|
|
@@ -353,17 +400,16 @@ inline void PIRServer::multiply_power_of_X(const Ciphertext &encrypted, Cipherte
|
|
|
}
|
|
|
}
|
|
|
|
|
|
-inline void PIRServer::decompose_to_plaintexts_ptr(const Ciphertext &encrypted, Plaintext *plain_ptr) {
|
|
|
+inline void PIRServer::decompose_to_plaintexts_ptr(const Ciphertext &encrypted, Plaintext *plain_ptr, int logt) {
|
|
|
|
|
|
vector<Plaintext> result;
|
|
|
auto coeff_count = params_.poly_modulus_degree();
|
|
|
auto coeff_mod_count = params_.coeff_modulus().size();
|
|
|
auto encrypted_count = encrypted.size();
|
|
|
|
|
|
- // Generate powers of t.
|
|
|
- uint64_t plainModMinusOne = params_.plain_modulus().value() - 1;
|
|
|
- int exp = ceil(log2(plainModMinusOne + 1));
|
|
|
+ uint64_t t1 = 1 << logt; // t1 <= t.
|
|
|
|
|
|
+ uint64_t t1minusone = t1 -1;
|
|
|
// A triple for loop. Going over polys, moduli, and decomposed index.
|
|
|
|
|
|
for (int i = 0; i < encrypted_count; i++) {
|
|
@@ -373,19 +419,19 @@ inline void PIRServer::decompose_to_plaintexts_ptr(const Ciphertext &encrypted,
|
|
|
// create a polynomial to store the current decomposition value
|
|
|
// which will be copied into the array to populate it at the current
|
|
|
// index.
|
|
|
- int logqj = log2(params_.coeff_modulus()[j].value());
|
|
|
- int expansion_ratio = ceil(logqj + exp - 1) / exp;
|
|
|
-
|
|
|
- // cout << "expansion ratio = " << expansion_ratio << endl;
|
|
|
+ double logqj = log2(params_.coeff_modulus()[j].value());
|
|
|
+ //int expansion_ratio = ceil(logqj + exponent - 1) / exponent;
|
|
|
+ int expansion_ratio = ceil(logqj / logt);
|
|
|
+ // cout << "local expansion ratio = " << expansion_ratio << endl;
|
|
|
uint64_t curexp = 0;
|
|
|
for (int k = 0; k < expansion_ratio; k++) {
|
|
|
// Decompose here
|
|
|
for (int m = 0; m < coeff_count; m++) {
|
|
|
plain_ptr[i * coeff_mod_count * expansion_ratio
|
|
|
+ j * expansion_ratio + k][m] =
|
|
|
- (*(encrypted_pointer + m + (j * coeff_count)) >> curexp) & plainModMinusOne;
|
|
|
+ (*(encrypted_pointer + m + (j * coeff_count)) >> curexp) & t1minusone;
|
|
|
}
|
|
|
- curexp += exp;
|
|
|
+ curexp += logt;
|
|
|
}
|
|
|
}
|
|
|
}
|