Browse Source

Added modulus context switching - greatly reduces response size.

Andrew Beams 2 years ago
parent
commit
eea9b2bcf8
6 changed files with 47 additions and 16 deletions
  1. 3 2
      src/main.cpp
  2. 3 1
      src/pir.cpp
  3. 3 1
      src/pir.hpp
  4. 13 3
      src/pir_client.cpp
  5. 16 7
      src/pir_server.cpp
  6. 9 2
      test/decomposition_test.cpp

+ 3 - 2
src/main.cpp

@@ -23,6 +23,7 @@ int main(int argc, char *argv[]) {
     uint32_t d = 2;
     bool use_symmetric = true; // use symmetric encryption instead of public key (recommended for smaller query)
     bool use_batching = true; // pack as many elements as possible into a BFV plaintext (recommended)
+    bool use_recursive_mod_switching = true;
 
     EncryptionParameters enc_params(scheme_type::bfv);
     PirParams pir_params;
@@ -37,7 +38,7 @@ int main(int argc, char *argv[]) {
     cout << "Main: SEAL parameters are good" << endl;
 
     cout << "Main: Generating PIR parameters" << endl;
-    gen_pir_params(number_of_items, size_per_item, d, enc_params, pir_params, use_symmetric, use_batching);
+    gen_pir_params(number_of_items, size_per_item, d, enc_params, pir_params, use_symmetric, use_batching, use_recursive_mod_switching);
     
     
     print_seal_params(enc_params); 
@@ -79,7 +80,7 @@ int main(int argc, char *argv[]) {
     }
 
     // Measure database setup
-    auto time_pre_s = high_resolution_clock::now();
+    auto time_pre_s = high_resolution_clock::now(); 
     server.set_database(move(db), number_of_items, size_per_item);
     server.preprocess_database();
     cout << "Main: database pre processed " << endl;

+ 3 - 1
src/pir.cpp

@@ -59,7 +59,7 @@ void verify_encryption_params(const seal::EncryptionParameters &enc_params){
 
 void gen_pir_params(uint64_t ele_num, uint64_t ele_size, uint32_t d,
                     const EncryptionParameters &enc_params, PirParams &pir_params,
-                    bool enable_symmetric, bool enable_batching){
+                    bool enable_symmetric, bool enable_batching, bool enable_mswitching){
     std::uint32_t N = enc_params.poly_modulus_degree();
     Modulus t = enc_params.plain_modulus();
     std::uint32_t logt = floor(log2(t.value())); // # of usable bits
@@ -85,6 +85,7 @@ void gen_pir_params(uint64_t ele_num, uint64_t ele_size, uint32_t d,
 
     pir_params.enable_symmetric = enable_symmetric;
     pir_params.enable_batching = enable_batching;
+    pir_params.enable_mswitching = enable_mswitching;
     pir_params.ele_num = ele_num;
     pir_params.ele_size = ele_size;
     pir_params.elements_per_plaintext = elements_per_plaintext;
@@ -108,6 +109,7 @@ void print_pir_params(const PirParams &pir_params){
     cout << "Number of BFV plaintexts after padding (to fill d-dimensional hyperrectangle): " << prod << endl;
     cout << "expansion ratio: " << pir_params.expansion_ratio << endl;
     cout << "Using symmetric encryption: " << pir_params.enable_symmetric << endl;
+    cout << "Using recursive mod switching: " << pir_params.enable_mswitching << endl;
     cout << "slot count: " << pir_params.slot_count << endl;
     cout << "=============================="<< endl;
 }

+ 3 - 1
src/pir.hpp

@@ -14,6 +14,7 @@ typedef std::vector<seal::Ciphertext> PirReply;
 struct PirParams {
     bool enable_symmetric;
     bool enable_batching;
+    bool enable_mswitching;
     std::uint64_t ele_num;
     std::uint64_t ele_size;
     std::uint64_t elements_per_plaintext;
@@ -34,7 +35,8 @@ void gen_pir_params(uint64_t ele_num,
                     const seal::EncryptionParameters &enc_params,
                     PirParams &pir_params,
                     bool enable_symmetric = false,
-                    bool enable_batching = true);
+                    bool enable_batching = true,
+                    bool enable_mswitching = true);
 
 void gen_params(uint64_t ele_num,
                 uint64_t ele_size,

+ 13 - 3
src/pir_client.cpp

@@ -162,7 +162,17 @@ std::vector<uint8_t> PIRClient::extract_bytes(seal::Plaintext pt, uint64_t offse
 }
 
 Plaintext PIRClient::decode_reply(PirReply &reply) {
-    uint32_t exp_ratio = compute_expansion_ratio(context_->first_context_data()->parms());
+    EncryptionParameters parms;
+    parms_id_type parms_id;
+    if(pir_params_.enable_mswitching){
+        parms = context_->last_context_data()->parms();
+        parms_id = context_->last_parms_id();
+    }
+    else{
+        parms = context_->first_context_data()->parms();
+        parms_id = context_->first_parms_id();
+    }
+    uint32_t exp_ratio = compute_expansion_ratio(parms);
     uint32_t recursion_level = pir_params_.d;
 
     vector<Ciphertext> temp = reply;
@@ -192,8 +202,8 @@ Plaintext PIRClient::decode_reply(PirReply &reply) {
 
             if ((j + 1) % (exp_ratio * ciphertext_size) == 0 && j > 0) {
                 // Combine into one ciphertext.
-                Ciphertext combined(*context_); 
-                compose_to_ciphertext(context_->first_context_data()->parms(), tempplain, combined);
+                Ciphertext combined(*context_, parms_id); 
+                compose_to_ciphertext(parms, tempplain, combined);
                 newtemp.push_back(combined);
                 tempplain.clear();
                 // cout << "Client: const term of ciphertext = " << combined[0] << endl; 

+ 16 - 7
src/pir_server.cpp

@@ -159,11 +159,12 @@ PirQuery PIRServer::deserialize_query(stringstream &stream) {
 }
 
 int PIRServer::serialize_reply(PirReply &reply, stringstream &stream) {
-  int output_size = 0;
-  for (int i = 0; i < reply.size(); i++){
-    output_size += reply[i].save(stream);
-  }
-  return output_size;
+    int output_size = 0;
+    for (int i = 0; i < reply.size(); i++){
+        evaluator_->mod_switch_to_inplace(reply[i], context_->last_parms_id());
+        output_size += reply[i].save(stream);
+    }
+    return output_size;
 }
 
 PirReply PIRServer::generate_reply(PirQuery &query, uint32_t client_id) {
@@ -199,7 +200,7 @@ PirReply PIRServer::generate_reply(PirQuery &query, uint32_t client_id) {
         for (uint32_t j = 0; j < query[i].size(); j++){
             uint64_t total = N; 
             if (j == query[i].size() - 1){
-                total = n_i % N; 
+                total = n_i % N;
             }
             cout << "-- expanding one query ctxt into " << total  << " ctxts "<< endl;
             vector<Ciphertext> expanded_query_part = expand_query(query[i][j], total, client_id);
@@ -259,8 +260,16 @@ PirReply PIRServer::generate_reply(PirQuery &query, uint32_t client_id) {
             cur = &intermediate_plain;
 
             for (uint64_t rr = 0; rr < product; rr++) {
+                EncryptionParameters parms;
+                if(pir_params_.enable_mswitching){
+                    evaluator_->mod_switch_to_inplace(intermediateCtxts[rr], context_->last_parms_id());
+                    parms = context_->last_context_data()->parms();
+                }
+                else{
+                    parms = context_->first_context_data()->parms();
+                }
 
-                vector<Plaintext> plains = decompose_to_plaintexts(context_->first_context_data()->parms(),
+                vector<Plaintext> plains = decompose_to_plaintexts(parms,
                     intermediateCtxts[rr]);
 
                 for (uint32_t jj = 0; jj < plains.size(); jj++) {

+ 9 - 2
test/decomposition_test.cpp

@@ -39,6 +39,7 @@ int main(int argc, char *argv[]) {
     SecretKey secret_key = keygen.secret_key();
     Encryptor encryptor(context, secret_key);
     Decryptor decryptor(context, secret_key);
+    Evaluator evaluator(context);
     BatchEncoder encoder(context);
     logt = floor(log2(enc_params.plain_modulus().value()));
 
@@ -55,11 +56,17 @@ int main(int argc, char *argv[]) {
     Ciphertext ct;
     encryptor.encrypt_symmetric(pt, ct);
     std::cout << "Encrypting" << std::endl;
-    EncryptionParameters params = context.first_context_data()->parms();
+    auto context_data = context.last_context_data();
+    auto parms_id = context.last_parms_id();
+
+    evaluator.mod_switch_to_inplace(ct, parms_id);
+
+    EncryptionParameters params = context_data->parms();
     std::cout << "Encoding" << std::endl;
     vector<Plaintext> encoded = decompose_to_plaintexts(params, ct);
+    std::cout << "Expansion Factor: " << encoded.size() << std::endl;
     std::cout << "Decoding" << std::endl;
-    Ciphertext decoded(context);
+    Ciphertext decoded(context, parms_id);
     compose_to_ciphertext(params, encoded, decoded);
     std::cout << "Checking" <<std::endl;
     Plaintext pt2;