Browse Source

serialization with a bug

Sebastian Angel 2 years ago
parent
commit
8b2d6a2f95
7 changed files with 104 additions and 69 deletions
  1. 21 6
      src/main.cpp
  2. 13 50
      src/pir.cpp
  3. 2 10
      src/pir.hpp
  4. 41 1
      src/pir_client.cpp
  5. 2 0
      src/pir_client.hpp
  6. 24 2
      src/pir_server.cpp
  7. 1 0
      src/pir_server.hpp

+ 21 - 6
src/main.cpp

@@ -21,7 +21,7 @@ int main(int argc, char *argv[]) {
     // Recommended values: (logt, d) = (20, 2).
     uint32_t logt = 20;
     uint32_t d = 2;
-    bool use_symmetric = true; // use symmetric encryption instead of public key (recommended)
+    bool use_symmetric = true; // use symmetric encryption instead of public key (recommended for smaller query)
     bool use_batching = true; // pack as many elements as possible into a BFV plaintext (recommended)
 
     EncryptionParameters enc_params(scheme_type::bfv);
@@ -86,7 +86,6 @@ int main(int argc, char *argv[]) {
     auto time_pre_e = high_resolution_clock::now();
     auto time_pre_us = duration_cast<microseconds>(time_pre_e - time_pre_s).count();
 
-
     // Choose an index of an element in the DB
     uint64_t ele_index = rd() % number_of_items; // element in DB at random position
     uint64_t index = client.get_fv_index(ele_index);   // index of FV plaintext
@@ -101,13 +100,27 @@ int main(int argc, char *argv[]) {
     auto time_query_us = duration_cast<microseconds>(time_query_e - time_query_s).count();
     cout << "Main: query generated" << endl;
 
-    //To marshall query to send over the network, you can use serialize/deserialize:
-    //std::string query_ser = serialize_query(query);
-    //PirQuery query2 = deserialize_query(d, 1, query_ser, CIPHER_SIZE);
+    // Measure serialized query generation (useful for sending over the network)
+    stringstream stream;
+    auto time_s_query_s = high_resolution_clock::now();
+    int query_size = client.generate_serialized_query(index, stream);
+    auto time_s_query_e = high_resolution_clock::now();
+    auto time_s_query_us = duration_cast<microseconds>(time_s_query_e - time_s_query_s).count();
+    cout << "Main: serialized query generated" << endl;
+
+    // Measure query deserialization (useful for receiving over the network)
+    auto time_deserial_s = high_resolution_clock::now();
+    PirQuery query2 = server.deserialize_query(stream);
+    auto time_deserial_e = high_resolution_clock::now();
+    auto time_deserial_us = duration_cast<microseconds>(time_deserial_e - time_deserial_s).count();
+    cout << "Main: query deserialized" << endl;
+
+    //XXX: deserialization is not working correctly at the moment. There is likely a bug in either
+    //serialize or deserialize.
 
     // Measure query processing (including expansion)
     auto time_server_s = high_resolution_clock::now();
-    // Answer PIR query form client 0. If there are multiple clients, 
+    // Answer PIR query from client 0. If there are multiple clients, 
     // enter the id of the client (to use the associated galois key).
     PirReply reply = server.generate_reply(query, 0); 
     auto time_server_e = high_resolution_clock::now();
@@ -141,9 +154,11 @@ int main(int argc, char *argv[]) {
     cout << "Main: PIR result correct!" << endl;
     cout << "Main: PIRServer pre-processing time: " << time_pre_us / 1000 << " ms" << endl;
     cout << "Main: PIRClient query generation time: " << time_query_us / 1000 << " ms" << endl;
+    cout << "Main: PIRClient serialized query generation time: " << time_s_query_us / 1000 << " ms" << endl;
     cout << "Main: PIRServer reply generation time: " << time_server_us / 1000 << " ms" << endl;
     cout << "Main: PIRClient answer decode time: " << time_decode_us / 1000 << " ms" << endl;
     cout << "Main: Reply num ciphertexts: " << reply.size() << endl;
+    cout << "Main: Query size: " << query_size << " bytes" << endl;
 
     return 0;
 }

+ 13 - 50
src/pir.cpp

@@ -248,62 +248,25 @@ uint64_t invert_mod(uint64_t m, const seal::Modulus& mod) {
   return inverse;
 }
 
-inline Ciphertext deserialize_ciphertext(string s, shared_ptr<SEALContext> context) {
-    Ciphertext c;
-    std::istringstream input(s);
-    c.unsafe_load(*context, input);
-    return c;
-}
-
-
-vector<Ciphertext> deserialize_ciphertexts(uint32_t count, string s, uint32_t len_ciphertext, 
-shared_ptr<SEALContext> context) {
-    vector<Ciphertext> c;
-    for (uint32_t i = 0; i < count; i++) {
-        c.push_back(deserialize_ciphertext(s.substr(i * len_ciphertext, len_ciphertext), context));
-    }
-    return c;
-}
 
 PirQuery deserialize_query(uint32_t d, uint32_t count, string s, uint32_t len_ciphertext,
 shared_ptr<SEALContext> context) {
-    vector<vector<Ciphertext>> c;
-    for (uint32_t i = 0; i < d; i++) {
-        c.push_back(deserialize_ciphertexts(
-              count, 
-              s.substr(i * count * len_ciphertext, count * len_ciphertext),
-              len_ciphertext, context)
-        );
-    }
-    return c;
-}
-
-
-inline string serialize_ciphertext(Ciphertext c) {
-    std::ostringstream output;
-    c.save(output);
-    return output.str();
-}
-
-string serialize_ciphertexts(vector<Ciphertext> c) {
-    string s;
-    for (uint32_t i = 0; i < c.size(); i++) {
-        s.append(serialize_ciphertext(c[i]));
-    }
-    return s;
-}
+    vector<vector<Ciphertext>> q;
+    std::istringstream input(s);
 
-string serialize_query(vector<vector<Ciphertext>> c) {
-    string s;
-    for (uint32_t i = 0; i < c.size(); i++) {
-      for (uint32_t j = 0; j < c[i].size(); j++) {
-        s.append(serialize_ciphertext(c[i][j]));
-      }
+    for (uint32_t i = 0; i < d; i++) {
+        vector<Ciphertext> cs;
+        for (uint32_t i = 0; i < count; i++) {
+          Ciphertext c;
+          c.load(*context, input);
+          cs.push_back(c);
+        }
+        q.push_back(cs);
     }
-    return s;
+    return q;
 }
 
-string serialize_galoiskeys(GaloisKeys g) {
+string serialize_galoiskeys(Serializable<GaloisKeys> g) {
     std::ostringstream output;
     g.save(output);
     return output.str();
@@ -312,6 +275,6 @@ string serialize_galoiskeys(GaloisKeys g) {
 GaloisKeys *deserialize_galoiskeys(string s, shared_ptr<SEALContext> context) {
     GaloisKeys *g = new GaloisKeys();
     std::istringstream input(s);
-    g->unsafe_load(*context, input);
+    g->load(*context, input);
     return g;
 }

+ 2 - 10
src/pir.hpp

@@ -18,7 +18,7 @@ struct PirParams {
     std::uint64_t ele_size;
     std::uint64_t elements_per_plaintext;
     std::uint64_t num_of_plaintexts;         // number of plaintexts in database
-    std::uint32_t d;                 // number of dimensions for the database
+    std::uint32_t d;                         // number of dimensions for the database
     std::uint32_t expansion_ratio;           // ratio of ciphertext to plaintext
     std::vector<std::uint64_t> nvec;         // size of each of the d dimensions
     std::uint32_t slot_count;
@@ -79,14 +79,6 @@ std::vector<std::uint64_t> compute_indices(std::uint64_t desiredIndex,
 
 uint64_t invert_mod(uint64_t m, const seal::Modulus& mod);
 
-// Serialize and deserialize ciphertexts to send them over the network
-PirQuery deserialize_query(std::uint32_t d, uint32_t count, std::string s, std::uint32_t len_ciphertext,
-                                                        std::shared_ptr<seal::SEALContext> context);
-std::vector<seal::Ciphertext> deserialize_ciphertexts(std::uint32_t count, std::string s,
-                                                      std::uint32_t len_ciphertext, std::shared_ptr<seal::SEALContext> context);
-std::string serialize_ciphertexts(std::vector<seal::Ciphertext> c);
-std::string serialize_query(std::vector<std::vector<seal::Ciphertext>> c);
-
 // Serialize and deserialize galois keys to send them over the network
-std::string serialize_galoiskeys(seal::GaloisKeys g);
+std::string serialize_galoiskeys(seal::Serializable<seal::GaloisKeys> g);
 seal::GaloisKeys *deserialize_galoiskeys(std::string s, std::shared_ptr<seal::SEALContext> context);

+ 41 - 1
src/pir_client.cpp

@@ -29,12 +29,52 @@ PIRClient::PIRClient(const EncryptionParameters &enc_params,
     encoder_ = make_unique<BatchEncoder>(*context_);
 }
 
+int PIRClient::generate_serialized_query(uint64_t desiredIndex, std::stringstream &stream) {
+
+    int N = enc_params_.poly_modulus_degree(); 
+    int output_size = 0;
+    indices_ = compute_indices(desiredIndex, pir_params_.nvec);
+    Plaintext pt(enc_params_.poly_modulus_degree());
+
+    for (uint32_t i = 0; i < indices_.size(); i++) {
+        uint32_t num_ptxts = ceil( (pir_params_.nvec[i] + 0.0) / N);
+        // initialize result. 
+        cout << "Client: index " << i + 1  <<  "/ " <<  indices_.size() << " = " << indices_[i] << endl; 
+        cout << "Client: number of ctxts needed for query = " << num_ptxts << endl;
+        
+        for (uint32_t j =0; j < num_ptxts; j++){
+            pt.set_zero();
+            if (indices_[i] >= N*j && indices_[i] <= N*(j+1)){
+                uint64_t real_index = indices_[i] - N*j; 
+                uint64_t n_i = pir_params_.nvec[i];
+                uint64_t total = N; 
+                if (j == num_ptxts - 1){
+                    total = n_i % N; 
+                }
+                uint64_t log_total = ceil(log2(total));
+
+                cout << "Client: Inverting " << pow(2, log_total) << endl;
+                pt[real_index] = invert_mod(pow(2, log_total), enc_params_.plain_modulus());
+            }
+
+            if(pir_params_.enable_symmetric){
+                output_size += encryptor_->encrypt_symmetric(pt).save(stream);
+            }
+            else{
+                output_size += encryptor_->encrypt(pt).save(stream);
+            }
+        }   
+    }
+
+    return output_size;
+}
+
 
 PirQuery PIRClient::generate_query(uint64_t desiredIndex) {
 
     indices_ = compute_indices(desiredIndex, pir_params_.nvec);
 
-    vector<vector<Ciphertext> > result(pir_params_.d);
+    PirQuery result(pir_params_.d);
     int N = enc_params_.poly_modulus_degree(); 
 
     Plaintext pt(enc_params_.poly_modulus_degree());

+ 2 - 0
src/pir_client.hpp

@@ -12,6 +12,8 @@ class PIRClient {
                const PirParams &pirparams);
 
     PirQuery generate_query(std::uint64_t desiredIndex);
+    // Serializes the query into the provided stream and returns number of bytes written
+    int generate_serialized_query(std::uint64_t desiredIndex, std::stringstream &stream);
     seal::Plaintext decode_reply(PirReply reply);
     std::vector<uint8_t> decode_reply(PirReply reply, uint64_t offset);
 

+ 24 - 2
src/pir_server.cpp

@@ -37,7 +37,7 @@ void PIRServer::set_database(unique_ptr<vector<Plaintext>> &&db) {
     is_db_preprocessed_ = false;
 }
 
-void PIRServer::set_database(const std::unique_ptr<const std::uint8_t[]> &bytes, 
+void PIRServer::set_database(const std::unique_ptr<const uint8_t[]> &bytes, 
     uint64_t ele_num, uint64_t ele_size) {
 
     uint32_t logt = floor(log2(enc_params_.plain_modulus().value()));
@@ -121,10 +121,32 @@ void PIRServer::set_database(const std::unique_ptr<const std::uint8_t[]> &bytes,
     set_database(move(result));
 }
 
-void PIRServer::set_galois_key(std::uint32_t client_id, seal::GaloisKeys galkey) {
+void PIRServer::set_galois_key(uint32_t client_id, seal::GaloisKeys galkey) {
     galoisKeys_[client_id] = galkey;
 }
 
+PirQuery PIRServer::deserialize_query(stringstream &stream) {
+    PirQuery q;
+
+    for (uint32_t i; i < pir_params_.d; i++) {
+        // number of ciphertexts needed to encode the index for dimension i
+        // keeping into account that each ciphertext can encode up to poly_modulus_degree indexes
+        // In most cases this is usually 1.
+        uint32_t ctx_per_dimension = ceil((pir_params_.nvec[i] + 0.0) / enc_params_.poly_modulus_degree());
+
+        vector<Ciphertext> cs;
+        for (uint32_t j = 0; j < ctx_per_dimension; j++) {
+          Ciphertext c;
+          c.load(*context_, stream);
+          cs.push_back(c);
+        }
+
+        q.push_back(cs);
+    }
+
+    return q;
+}
+
 PirReply PIRServer::generate_reply(PirQuery query, uint32_t client_id) {
 
     vector<uint64_t> nvec = pir_params_.nvec;

+ 1 - 0
src/pir_server.hpp

@@ -19,6 +19,7 @@ class PIRServer {
     std::vector<seal::Ciphertext> expand_query(
             const seal::Ciphertext &encrypted, std::uint32_t m, uint32_t client_id);
 
+    PirQuery deserialize_query(std::stringstream &stream);
     PirReply generate_reply(PirQuery query, std::uint32_t client_id);
 
     void set_galois_key(std::uint32_t client_id, seal::GaloisKeys galkey);