Browse Source

Fixed two bugs - one where encode_db would read slightly beyond the end (not sure of ramifications) and one in encoding ciphertexts as plaintexts (so d >= 3 works now)

Andrew Beams 2 years ago
parent
commit
0df8e32bcc
9 changed files with 197 additions and 144 deletions
  1. 80 0
      src/pir.cpp
  2. 11 0
      src/pir.hpp
  3. 5 48
      src/pir_client.cpp
  4. 1 2
      src/pir_client.hpp
  5. 11 92
      src/pir_server.cpp
  6. 3 2
      src/pir_server.hpp
  7. 4 0
      test/CMakeLists.txt
  8. 73 0
      test/decomposition_test.cpp
  9. 9 0
      test/query_test.cpp

+ 80 - 0
src/pir.cpp

@@ -254,6 +254,86 @@ uint64_t invert_mod(uint64_t m, const seal::Modulus& mod) {
 }
 
 
+uint32_t compute_expansion_ratio(EncryptionParameters params) {
+  uint32_t expansion_ratio = 0;
+  uint32_t pt_bits_per_coeff = log2(params.plain_modulus().value());
+  for (size_t i = 0; i < params.coeff_modulus().size(); ++i) {
+    double coeff_bit_size = log2(params.coeff_modulus()[i].value());
+    expansion_ratio += ceil(coeff_bit_size / pt_bits_per_coeff);
+  }
+  return expansion_ratio;
+}
+
+vector<Plaintext> decompose_to_plaintexts(EncryptionParameters params, const Ciphertext& ct) {
+  const uint32_t pt_bits_per_coeff = log2(params.plain_modulus().value());
+  const auto coeff_count = params.poly_modulus_degree();
+  const auto coeff_mod_count = params.coeff_modulus().size();
+  const uint64_t pt_bitmask = (1 << pt_bits_per_coeff) - 1;
+
+  vector<Plaintext> result(compute_expansion_ratio(params) * ct.size());
+  auto pt_iter = result.begin();
+  for (size_t poly_index = 0; poly_index < ct.size(); ++poly_index) {
+    for (size_t coeff_mod_index = 0; coeff_mod_index < coeff_mod_count;
+         ++coeff_mod_index) {
+      const double coeff_bit_size =
+          log2(params.coeff_modulus()[coeff_mod_index].value());
+      const size_t local_expansion_ratio =
+          ceil(coeff_bit_size / pt_bits_per_coeff);
+      size_t shift = 0;
+      for (size_t i = 0; i < local_expansion_ratio; ++i) {
+        pt_iter->resize(coeff_count);
+        for (size_t c = 0; c < coeff_count; ++c) {
+          (*pt_iter)[c] =
+              (ct.data(poly_index)[coeff_mod_index * coeff_count + c] >>
+               shift) &
+              pt_bitmask;
+        }
+        ++pt_iter;
+        shift += pt_bits_per_coeff;
+      }
+    }
+  }
+  return result;
+}
+
+void compose_to_ciphertext(EncryptionParameters params, vector<Plaintext>::const_iterator pt_iter,
+  const size_t ct_poly_count, Ciphertext& ct) {
+  const uint32_t pt_bits_per_coeff = log2(params.plain_modulus().value());
+  const auto coeff_count = params.poly_modulus_degree();
+  const auto coeff_mod_count = params.coeff_modulus().size();
+
+  ct.resize(ct_poly_count);
+  for (size_t poly_index = 0; poly_index < ct_poly_count; ++poly_index) {
+    for (size_t coeff_mod_index = 0; coeff_mod_index < coeff_mod_count;
+         ++coeff_mod_index) {
+      const double coeff_bit_size =
+          log2(params.coeff_modulus()[coeff_mod_index].value());
+      const size_t local_expansion_ratio =
+          ceil(coeff_bit_size / pt_bits_per_coeff);
+      size_t shift = 0;
+      for (size_t i = 0; i < local_expansion_ratio; ++i) {
+        for (size_t c = 0; c < pt_iter->coeff_count(); ++c) {
+          if (shift == 0) {
+            ct.data(poly_index)[coeff_mod_index * coeff_count + c] =
+                (*pt_iter)[c];
+          } else {
+            ct.data(poly_index)[coeff_mod_index * coeff_count + c] +=
+                ((*pt_iter)[c] << shift);
+          }
+        }
+        ++pt_iter;
+        shift += pt_bits_per_coeff;
+      }
+    }
+  }
+}
+
+void compose_to_ciphertext(EncryptionParameters params, const vector<Plaintext>& pts, Ciphertext& ct) {
+  return compose_to_ciphertext(params, pts.begin(), pts.size() / compute_expansion_ratio(params), ct);
+}
+
+
+
 PirQuery deserialize_query(uint32_t d, uint32_t count, string s, uint32_t len_ciphertext,
 shared_ptr<SEALContext> context) {
     vector<vector<Ciphertext>> q;

+ 11 - 0
src/pir.hpp

@@ -79,6 +79,17 @@ std::vector<std::uint64_t> compute_indices(std::uint64_t desiredIndex,
 
 uint64_t invert_mod(uint64_t m, const seal::Modulus& mod);
 
+uint32_t compute_expansion_ratio(seal::EncryptionParameters params);
+std::vector<seal::Plaintext> decompose_to_plaintexts(seal::EncryptionParameters params,
+    const seal::Ciphertext& ct);
+
+//We need the returned ciphertext to be initialized by Context so the caller will pass it in
+void compose_to_ciphertext(seal::EncryptionParameters params, 
+    const std::vector<seal::Plaintext>& pts, seal::Ciphertext& ct);
+void compose_to_ciphertext(seal::EncryptionParameters params, 
+    std::vector<seal::Plaintext>::const_iterator pt_iter, seal::Ciphertext& ct);
+
+
 // Serialize and deserialize galois keys to send them over the network
 std::string serialize_galoiskeys(seal::Serializable<seal::GaloisKeys> g);
 seal::GaloisKeys *deserialize_galoiskeys(std::string s, std::shared_ptr<seal::SEALContext> context);

+ 5 - 48
src/pir_client.cpp

@@ -162,10 +162,11 @@ std::vector<uint8_t> PIRClient::extract_bytes(seal::Plaintext pt, uint64_t offse
 }
 
 Plaintext PIRClient::decode_reply(PirReply &reply) {
-    uint32_t exp_ratio = pir_params_.expansion_ratio;
+    uint32_t exp_ratio = compute_expansion_ratio(context_->first_context_data()->parms());
     uint32_t recursion_level = pir_params_.d;
 
     vector<Ciphertext> temp = reply;
+    uint32_t ciphertext_size = temp[0].size();
 
     uint64_t t = enc_params_.plain_modulus().value();
 
@@ -189,9 +190,10 @@ Plaintext PIRClient::decode_reply(PirReply &reply) {
             cout << decryptor_->invariant_noise_budget(temp[j]) << endl;
 #endif
 
-            if ((j + 1) % exp_ratio == 0 && j > 0) {
+            if ((j + 1) % (exp_ratio * ciphertext_size) == 0 && j > 0) {
                 // Combine into one ciphertext.
-                Ciphertext combined = compose_to_ciphertext(tempplain);
+                Ciphertext combined(*context_); 
+                compose_to_ciphertext(context_->first_context_data()->parms(), tempplain, combined);
                 newtemp.push_back(combined);
                 tempplain.clear();
                 // cout << "Client: const term of ciphertext = " << combined[0] << endl; 
@@ -232,51 +234,6 @@ GaloisKeys PIRClient::generate_galois_keys() {
     return gal_keys;
 }
 
-Ciphertext PIRClient::compose_to_ciphertext(vector<Plaintext> plains) {
-    size_t encrypted_count = 2;
-    auto coeff_count = enc_params_.poly_modulus_degree();
-    auto coeff_mod_count = enc_params_.coeff_modulus().size();
-    uint64_t plainMod = enc_params_.plain_modulus().value();
-    int logt = floor(log2(plainMod)); 
-
-    Ciphertext result(*context_);
-    result.resize(encrypted_count);
-
-    // A triple for loop. Going over polys, moduli, and decomposed index.
-    for (int i = 0; i < encrypted_count; i++) {
-        uint64_t *encrypted_pointer = result.data(i);
-
-        for (int j = 0; j < coeff_mod_count; j++) {
-            // populate one poly at a time.
-            // create a polynomial to store the current decomposition value
-            // which will be copied into the array to populate it at the current
-            // index.
-            double logqj = log2(enc_params_.coeff_modulus()[j].value());
-            int expansion_ratio = ceil(logqj / logt);
-            uint64_t cur = 1;
-            // cout << "Client: expansion_ratio = " << expansion_ratio << endl; 
-
-            for (int k = 0; k < expansion_ratio; k++) {
-                // Compose here
-                const uint64_t *plain_coeff =
-                    plains[k + j * (expansion_ratio) + i * (coeff_mod_count * expansion_ratio)]
-                        .data();
-
-                for (int m = 0; m < coeff_count; m++) {
-                    if (k == 0) {
-                        *(encrypted_pointer + m + j * coeff_count) = *(plain_coeff + m) * cur;
-                    } else {
-                        *(encrypted_pointer + m + j * coeff_count) += *(plain_coeff + m) * cur;
-                    }
-                }
-                cur <<= logt;
-            }
-        }
-    }
-
-    return result;
-}
-
 Plaintext PIRClient::replace_element(Plaintext pt, vector<uint64_t> new_element, uint64_t offset){
     vector<uint64_t> coeffs = extract_coeffs(pt);
     

+ 1 - 2
src/pir_client.hpp

@@ -35,7 +35,7 @@ class PIRClient {
     seal::Ciphertext get_one();
 
     seal::Plaintext replace_element(seal::Plaintext pt, std::vector<std::uint64_t> new_element, std::uint64_t offset);
-
+   
 
   private:
     seal::EncryptionParameters enc_params_;
@@ -51,7 +51,6 @@ class PIRClient {
     vector<uint64_t> indices_; // the indices for retrieval. 
     vector<uint64_t> inverse_scales_; 
 
-    seal::Ciphertext compose_to_ciphertext(std::vector<seal::Plaintext> plains);
 
     friend class PIRServer;
 };

+ 11 - 92
src/pir_server.cpp

@@ -83,10 +83,12 @@ void PIRServer::set_database(const std::unique_ptr<const uint8_t[]> &bytes,
         } else {
             process_bytes = bytes_per_ptxt;
         }
+        assert(process_bytes % ele_size == 0);
+        uint64_t ele_in_chunk = process_bytes / ele_size;
 
         // Get the coefficients of the elements that will be packed in plaintext i
         vector<uint64_t> coefficients(coeff_per_ptxt);
-        for(uint64_t ele = 0; ele < ele_per_ptxt; ele++){
+        for(uint64_t ele = 0; ele < ele_in_chunk; ele++){
             vector<uint64_t> element_coeffs = bytes_to_coeffs(logt, bytes.get() + offset + (ele_size*ele), ele_size);
             std::copy(element_coeffs.begin(), element_coeffs.end(), coefficients.begin() + (coefficients_per_element(logt, ele_size) * ele));
         }
@@ -189,7 +191,7 @@ PirReply PIRServer::generate_reply(PirQuery &query, uint32_t client_id) {
         cout << "Server: " << i + 1 << "-th recursion level started " << endl; 
 
 
-        vector<Ciphertext> expanded_query; 
+        vector<Ciphertext> expanded_query;
 
         uint64_t n_i = nvec[i];
         cout << "Server: n_i = " << n_i << endl; 
@@ -197,7 +199,7 @@ PirReply PIRServer::generate_reply(PirQuery &query, uint32_t client_id) {
         for (uint32_t j = 0; j < query[i].size(); j++){
             uint64_t total = N; 
             if (j == query[i].size() - 1){
-                total = n_i % N;
+                total = n_i % N; 
             }
             cout << "-- expanding one query ctxt into " << total  << " ctxts "<< endl;
             vector<Ciphertext> expanded_query_part = expand_query(query[i][j], total, client_id);
@@ -224,7 +226,7 @@ PirReply PIRServer::generate_reply(PirQuery &query, uint32_t client_id) {
 
         for (uint64_t k = 0; k < product; k++) {
             if ((*cur)[k].is_zero()){
-                cout << k + 1 << "/ " << product <<  "-th ptxt = 0 " << endl; 
+                cout << k + 1 << "/ " << product <<  "-th ptxt = 0 " << endl;
             }
         }
 
@@ -256,21 +258,16 @@ PirReply PIRServer::generate_reply(PirQuery &query, uint32_t client_id) {
             intermediate_plain.reserve(pir_params_.expansion_ratio * product);
             cur = &intermediate_plain;
 
-            auto tempplain = util::allocate<Plaintext>(
-                pir_params_.expansion_ratio * product,
-                pool, coeff_count);
-
             for (uint64_t rr = 0; rr < product; rr++) {
 
-                decompose_to_plaintexts_ptr(intermediateCtxts[rr],
-                    tempplain.get() + rr * pir_params_.expansion_ratio, logt);
+                vector<Plaintext> plains = decompose_to_plaintexts(context_->first_context_data()->parms(),
+                    intermediateCtxts[rr]);
 
-                for (uint32_t jj = 0; jj < pir_params_.expansion_ratio; jj++) {
-                    auto offset = rr * pir_params_.expansion_ratio + jj;
-                    intermediate_plain.emplace_back(tempplain[offset]);
+                for (uint32_t jj = 0; jj < plains.size(); jj++) {
+                    intermediate_plain.emplace_back(plains[jj]);
                 }
             }
-            product *= pir_params_.expansion_ratio; // multiply by expansion rate.
+            product = intermediate_plain.size(); // multiply by expansion rate.
         }
         cout << "Server: " << i + 1 << "-th recursion level finished " << endl; 
         cout << endl;
@@ -393,84 +390,6 @@ inline void PIRServer::multiply_power_of_X(const Ciphertext &encrypted, Cipherte
     }
 }
 
-inline void PIRServer::decompose_to_plaintexts_ptr(const Ciphertext &encrypted, Plaintext *plain_ptr, int logt) {
-
-    vector<Plaintext> result;
-    auto coeff_count = enc_params_.poly_modulus_degree();
-    auto coeff_mod_count = enc_params_.coeff_modulus().size();
-    auto encrypted_count = encrypted.size();
-
-    uint64_t t1 = 1 << logt;  //  t1 <= t. 
-
-    uint64_t t1minusone =  t1 -1; 
-    // A triple for loop. Going over polys, moduli, and decomposed index.
-
-    for (int i = 0; i < encrypted_count; i++) {
-        const uint64_t *encrypted_pointer = encrypted.data(i);
-        for (int j = 0; j < coeff_mod_count; j++) {
-            // populate one poly at a time.
-            // create a polynomial to store the current decomposition value
-            // which will be copied into the array to populate it at the current
-            // index.
-            double logqj = log2(enc_params_.coeff_modulus()[j].value());
-            //int expansion_ratio = ceil(logqj + exponent - 1) / exponent;
-            int expansion_ratio =  ceil(logqj / logt); 
-            // cout << "local expansion ratio = " << expansion_ratio << endl;
-            uint64_t curexp = 0;
-            for (int k = 0; k < expansion_ratio; k++) {
-                // Decompose here
-                for (int m = 0; m < coeff_count; m++) {
-                    plain_ptr[i * coeff_mod_count * expansion_ratio
-                        + j * expansion_ratio + k][m] =
-                        (*(encrypted_pointer + m + (j * coeff_count)) >> curexp) & t1minusone;
-                }
-                curexp += logt;
-            }
-        }
-    }
-}
-
-vector<Plaintext> PIRServer::decompose_to_plaintexts(const Ciphertext &encrypted) {
-    vector<Plaintext> result;
-    auto coeff_count = enc_params_.poly_modulus_degree();
-    auto coeff_mod_count = enc_params_.coeff_modulus().size();
-    auto plain_bit_count = enc_params_.plain_modulus().bit_count();
-    auto encrypted_count = encrypted.size();
-
-    // Generate powers of t.
-    uint64_t plainMod = enc_params_.plain_modulus().value();
-
-    // A triple for loop. Going over polys, moduli, and decomposed index.
-    for (int i = 0; i < encrypted_count; i++) {
-        const uint64_t *encrypted_pointer = encrypted.data(i);
-        for (int j = 0; j < coeff_mod_count; j++) {
-            // populate one poly at a time.
-            // create a polynomial to store the current decomposition value
-            // which will be copied into the array to populate it at the current
-            // index.
-            int logqj = log2(enc_params_.coeff_modulus()[j].value());
-            int expansion_ratio = ceil(logqj / log2(plainMod));
-
-            // cout << "expansion ratio = " << expansion_ratio << endl;
-            uint64_t cur = 1;
-            for (int k = 0; k < expansion_ratio; k++) {
-                // Decompose here
-                Plaintext temp(coeff_count);
-                transform(encrypted_pointer + (j * coeff_count), 
-                        encrypted_pointer + ((j + 1) * coeff_count), 
-                        temp.data(),
-                        [cur, &plainMod](auto &in) { return (in / cur) % plainMod; }
-                );
-
-                result.emplace_back(move(temp));
-                cur *= plainMod;
-            }
-        }
-    }
-
-    return result;
-}
-
 void PIRServer::simple_set(uint64_t index, Plaintext pt){
     if(is_db_preprocessed_){
         evaluator_->transform_to_ntt_inplace(

+ 3 - 2
src/pir_server.hpp

@@ -27,9 +27,12 @@ class PIRServer {
     void set_galois_key(std::uint32_t client_id, seal::GaloisKeys galkey);
 
     void simple_set(std::uint64_t index, seal::Plaintext pt);
+    // This is used for querying an element of the database WITHOUT PIR.
     seal::Ciphertext simple_query(std::uint64_t index);
     //This is only used for simple_query
     void set_one_ct(seal::Ciphertext one);
+   
+
 
   private:
     seal::EncryptionParameters enc_params_; // SEAL parameters
@@ -44,8 +47,6 @@ class PIRServer {
     //This is only uesd for simple_query
     seal::Ciphertext one_;
 
-    void decompose_to_plaintexts_ptr(const seal::Ciphertext &encrypted, seal::Plaintext *plain_ptr, int logt);
-    std::vector<seal::Plaintext> decompose_to_plaintexts(const seal::Ciphertext &encrypted);
     void multiply_power_of_X(const seal::Ciphertext &encrypted, seal::Ciphertext &destination,
                              std::uint32_t index);
 };

+ 4 - 0
test/CMakeLists.txt

@@ -19,3 +19,7 @@ add_test(NAME simple_query_test COMMAND simple_query_test)
 add_executable(replace_test replace_test.cpp)
 target_link_libraries(replace_test sealpir)
 add_test(NAME replace_test COMMAND replace_test)
+
+add_executable(decomposition_test decomposition_test.cpp)
+target_link_libraries(decomposition_test sealpir)
+add_test(NAME decomposition_test COMMAND decomposition_test)

+ 73 - 0
test/decomposition_test.cpp

@@ -0,0 +1,73 @@
+#include "pir.hpp"
+#include "pir_client.hpp"
+#include "pir_server.hpp"
+#include <seal/seal.h>
+#include <chrono>
+#include <memory>
+#include <random>
+#include <cstdint>
+#include <cstddef>
+
+using namespace std::chrono;
+using namespace std;
+using namespace seal;
+
+
+int main(int argc, char *argv[]) {
+
+    uint64_t number_of_items = 2048;
+    uint64_t size_per_item = 288; // in bytes
+    uint32_t N = 8192;
+
+    // Recommended values: (logt, d) = (12, 2) or (8, 1). 
+    uint32_t logt = 20; 
+
+    EncryptionParameters enc_params(scheme_type::bfv);
+
+    // Generates all parameters
+    
+    cout << "Main: Generating SEAL parameters" << endl;
+    gen_encryption_params(N, logt, enc_params);
+    
+    cout << "Main: Verifying SEAL parameters" << endl;
+    verify_encryption_params(enc_params);
+    cout << "Main: SEAL parameters are good" << endl;
+    
+    SEALContext context(enc_params, true);
+    KeyGenerator keygen(context);
+    
+    SecretKey secret_key = keygen.secret_key();
+    Encryptor encryptor(context, secret_key);
+    Decryptor decryptor(context, secret_key);
+    BatchEncoder encoder(context);
+    logt = floor(log2(enc_params.plain_modulus().value()));
+
+    uint32_t plain_modulus = enc_params.plain_modulus().value();
+
+    size_t slot_count = encoder.slot_count();
+
+    vector<uint64_t> coefficients(slot_count, 0ULL);
+    for(uint32_t i = 0; i < coefficients.size(); i++){
+        coefficients[i] = rand() % plain_modulus;
+    }
+    Plaintext pt; 
+    encoder.encode(coefficients, pt);
+    Ciphertext ct;
+    encryptor.encrypt_symmetric(pt, ct);
+    std::cout << "Encrypting" << std::endl;
+    EncryptionParameters params = context.first_context_data()->parms();
+    std::cout << "Encoding" << std::endl;
+    vector<Plaintext> encoded = decompose_to_plaintexts(params, ct);
+    std::cout << "Decoding" << std::endl;
+    Ciphertext decoded(context);
+    compose_to_ciphertext(params, encoded, decoded);
+    std::cout << "Checking" <<std::endl;
+    Plaintext pt2;
+    decryptor.decrypt(decoded, pt2);
+
+    assert(pt == pt2);
+
+    std::cout << "Worked" << std::endl;
+
+    return 0;
+}

+ 9 - 0
test/query_test.cpp

@@ -18,11 +18,20 @@ int main(int argc, char *argv[]) {
     // Quick check
     assert(query_test(1 << 10, 288, 4096, 20, 1) == 0);
 
+    assert(query_test(1 << 10, 288, 4096, 20, 2) == 0);
+
+    assert(query_test(1 << 10, 288, 4096, 20, 3) == 0);
+
+    assert(query_test(1 << 10, 288, 8192, 20, 2) == 0);
+
+
     // Forces ciphertext expansion to be the same as the degree
     assert(query_test(1 << 20, 288, 4096, 20, 1) == 0);
 
     
     assert(query_test(1 << 20, 288, 4096, 20, 2) == 0);
+
+    
 }