Browse Source

Now all database elements are encoded into plaintexts using the batch encoder.

Andrew Beams 3 years ago
parent
commit
cf6d3c1682
7 changed files with 38 additions and 13 deletions
  1. 4 8
      main.cpp
  2. 11 2
      pir.cpp
  3. 2 1
      pir.hpp
  4. 15 0
      pir_client.cpp
  5. 2 0
      pir_client.hpp
  6. 3 2
      pir_server.cpp
  7. 1 0
      pir_server.hpp

+ 4 - 8
main.cpp

@@ -107,21 +107,17 @@ int main(int argc, char *argv[]) {
 
     // Measure response extraction
     auto time_decode_s = chrono::high_resolution_clock::now();
-    Plaintext result = client.decode_reply(reply);
+    vector<uint8_t> elems = client.decode_reply(reply, offset);
     auto time_decode_e = chrono::high_resolution_clock::now();
     auto time_decode_us = duration_cast<microseconds>(time_decode_e - time_decode_s).count();
 
-    logt = floor(log2(enc_params.plain_modulus().value()));
-
-    // Convert from FV plaintext (polynomial) to database element at the client
-    vector<uint8_t> elems(N * logt / 8);
-    coeffs_to_bytes(logt, result, elems.data(), (N * logt) / 8);
+    assert(elems.size() == size_per_item);
 
     bool failed = false;
     // Check that we retrieved the correct element
     for (uint32_t i = 0; i < size_per_item; i++) {
-        if (elems[(offset * size_per_item) + i] != db_copy.get()[(ele_index * size_per_item) + i]) {
-            cout << "Main: elems " << (int)elems[(offset * size_per_item) + i] << ", db "
+        if (elems[i] != db_copy.get()[(ele_index * size_per_item) + i]) {
+            cout << "Main: elems " << (int)elems[i] << ", db "
                 << (int) db_copy.get()[(ele_index * size_per_item) + i] << endl;
             cout << "Main: PIR result wrong at " << i <<  endl;
             failed = true;

+ 11 - 2
pir.cpp

@@ -46,6 +46,13 @@ void verify_encryption_params(const seal::EncryptionParameters &enc_params){
     if(!context.first_context_data()->qualifiers().using_batching){
         throw invalid_argument("SEAL parameters do not support batching.");
     }
+
+    BatchEncoder batch_encoder(context);
+    size_t slot_count = batch_encoder.slot_count();
+    if(slot_count != enc_params.poly_modulus_degree()){
+        throw invalid_argument("Slot count not equal to poly modulus degree - this will cause issues downstream.");
+    }
+
     return;
 }
 
@@ -90,6 +97,7 @@ void gen_pir_params(uint64_t ele_num, uint64_t ele_size, uint32_t d,
     pir_params.expansion_ratio = expansion_ratio << 1;           
     pir_params.nvec = nvec;
     pir_params.n = num_of_plaintexts;
+    pir_params.slot_count = N;
 }
 
 
@@ -102,6 +110,7 @@ void print_pir_params(const PirParams &pir_params){
     cout << "dimension: " << pir_params.d << endl;
     cout << "expansion ratio: " << pir_params.expansion_ratio << endl;
     cout << "n: " << pir_params.n << endl;
+    cout << "slot count: " << pir_params.slot_count << endl;
 }
 
 uint32_t plainmod_after_expansion(uint32_t logt, uint32_t N, uint32_t d, 
@@ -177,12 +186,12 @@ vector<uint64_t> bytes_to_coeffs(uint32_t limit, const uint8_t *bytes, uint64_t
     return output;
 }
 
-void coeffs_to_bytes(uint32_t limit, const Plaintext &coeffs, uint8_t *output, uint32_t size_out) {
+void coeffs_to_bytes(uint32_t limit, const vector<uint64_t> &coeffs, uint8_t *output, uint32_t size_out) {
     uint32_t room = 8;
     uint32_t j = 0;
     uint8_t *target = output;
 
-    for (uint32_t i = 0; i < coeffs.coeff_count(); i++) {
+    for (uint32_t i = 0; i < coeffs.size(); i++) {
         uint64_t src = coeffs[i];
         uint32_t rest = limit;
         while (rest && j < size_out) {

+ 2 - 1
pir.hpp

@@ -22,6 +22,7 @@ struct PirParams {
     std::uint32_t expansion_ratio;           // ratio of ciphertext to plaintext
     std::vector<std::uint64_t> nvec;         // size of each of the d dimensions
     std::uint32_t n;
+    std::uint32_t slot_count;
 };
 
 void gen_encryption_params(std::uint32_t N,        // degree of polynomial
@@ -69,7 +70,7 @@ std::vector<std::uint64_t> bytes_to_coeffs(std::uint32_t limit, const std::uint8
                                            std::uint64_t size);
 
 // Converts an array of coefficients into an array of bytes
-void coeffs_to_bytes(std::uint32_t logtp, const seal::Plaintext &coeffs, std::uint8_t *output,
+void coeffs_to_bytes(std::uint32_t logtp, const std::vector<std::uint64_t> &coeffs, std::uint8_t *output,
                      std::uint32_t size_out);
 
 // Takes a vector of coefficients and returns the corresponding FV plaintext

+ 15 - 0
pir_client.cpp

@@ -26,6 +26,7 @@ PIRClient::PIRClient(const EncryptionParameters &enc_params,
     
     decryptor_ = make_unique<Decryptor>(*context_, secret_key);
     evaluator_ = make_unique<Evaluator>(*context_);
+    encoder_ = make_unique<BatchEncoder>(*context_);
 }
 
 
@@ -79,6 +80,20 @@ uint64_t PIRClient::get_fv_offset(uint64_t element_index) {
     return element_index % pir_params_.elements_per_plaintext;
 }
 
+vector<uint8_t> PIRClient::decode_reply(PirReply reply, uint64_t offset){
+    Plaintext result = decode_reply(reply);
+    
+    uint32_t N = enc_params_.poly_modulus_degree(); 
+    uint32_t logt = floor(log2(enc_params_.plain_modulus().value()));
+
+    // Convert from FV plaintext (polynomial) to database element at the client
+    vector<uint8_t> elems(N * logt / 8);
+    vector<uint64_t> coeffs;
+    encoder_->decode(result, coeffs);
+    coeffs_to_bytes(logt, coeffs, elems.data(), (N * logt) / 8);
+    return std::vector<uint8_t>(elems.begin() + offset * pir_params_.ele_size, elems.begin() + (offset + 1) * pir_params_.ele_size);
+}
+
 Plaintext PIRClient::decode_reply(PirReply reply) {
     uint32_t exp_ratio = pir_params_.expansion_ratio;
     uint32_t recursion_level = pir_params_.d;

+ 2 - 0
pir_client.hpp

@@ -13,6 +13,7 @@ class PIRClient {
 
     PirQuery generate_query(std::uint64_t desiredIndex);
     seal::Plaintext decode_reply(PirReply reply);
+    std::vector<uint8_t> decode_reply(PirReply reply, uint64_t offset);
 
     seal::GaloisKeys generate_galois_keys();
 
@@ -29,6 +30,7 @@ class PIRClient {
     std::unique_ptr<seal::Decryptor> decryptor_;
     std::unique_ptr<seal::Evaluator> evaluator_;
     std::unique_ptr<seal::KeyGenerator> keygen_;
+    std::unique_ptr<seal::BatchEncoder> encoder_;
     std::shared_ptr<seal::SEALContext> context_;
 
     vector<uint64_t> indices_; // the indices for retrieval. 

+ 3 - 2
pir_server.cpp

@@ -12,6 +12,7 @@ PIRServer::PIRServer(const EncryptionParameters &enc_params, const PirParams &pi
 {
     context_ = make_shared<SEALContext>(enc_params, true);
     evaluator_ = make_unique<Evaluator>(*context_);
+    encoder_ = make_unique<BatchEncoder>(*context_);
 }
 
 void PIRServer::preprocess_database() {
@@ -94,12 +95,12 @@ void PIRServer::set_database(const std::unique_ptr<const std::uint8_t[]> &bytes,
         assert(used <= coeff_per_ptxt);
 
         // Pad the rest with 1s
-        for (uint64_t j = 0; j < (N - used); j++) {
+        for (uint64_t j = 0; j < (pir_params_.slot_count - used); j++) {
             coefficients.push_back(1);
         }
 
         Plaintext plain;
-        vector_to_plaintext(coefficients, plain);
+        encoder_->encode(coefficients, plain);
         // cout << i << "-th encoded plaintext = " << plain.to_string() << endl; 
         result->push_back(move(plain));
     }

+ 1 - 0
pir_server.hpp

@@ -30,6 +30,7 @@ class PIRServer {
     bool is_db_preprocessed_;
     std::map<int, seal::GaloisKeys> galoisKeys_;
     std::unique_ptr<seal::Evaluator> evaluator_;
+    std::unique_ptr<seal::BatchEncoder> encoder_;
     std::shared_ptr<seal::SEALContext> context_;
 
     void decompose_to_plaintexts_ptr(const seal::Ciphertext &encrypted, seal::Plaintext *plain_ptr, int logt);