Browse Source

cleaning example and pointing README to the right SEAL branch

Sebastian Angel 4 years ago
parent
commit
0e62d25777
8 changed files with 39 additions and 189 deletions
  1. 9 4
      README.md
  2. 27 51
      main.cpp
  3. 0 38
      pir.cpp
  4. 0 5
      pir.hpp
  5. 1 53
      pir_client.cpp
  6. 0 4
      pir_client.hpp
  7. 1 28
      pir_server.cpp
  8. 1 6
      pir_server.hpp

+ 9 - 4
README.md

@@ -1,11 +1,15 @@
 # SealPIR: A computational PIR library that achieves low communication costs and high performance.
 
-SealPIR is a (research) library and should not be used in production systems. SealPIR allows a client to download an element from a database stored by a server without revealing which element was downloaded. SealPIR was introduced in our [paper](https://eprint.iacr.org/2017/1142.pdf).
-
+SealPIR is a research library and should not be used in production systems. 
+SealPIR allows a client to download an element from a database stored by a server without
+revealing which element was downloaded. SealPIR was introduced at 
+the Symposium on Security and Privacy (Oakland) in 2018. You can find
+a copy of the paper [here](https://eprint.iacr.org/2017/1142.pdf).
 
 # Compiling SEAL
 
-SealPIR depends on Microsoft SEAL version 3.2.0 ([link](https://www.microsoft.com/en-us/research/project/microsoft-seal)). Download Microsoft SEAL from [GitHub](https://GitHub.com/Microsoft/SEAL), and follow the instructions in README.md to install it system-wide.
+SealPIR depends on [Microsoft SEAL version 3.2.0](https://github.com/microsoft/SEAL/tree/3.2.0).
+Install SEAL before compiling SealPIR.
 
 # Compiling SealPIR
 
@@ -19,7 +23,8 @@ This should produce a binary file ``bin/sealpir``.
 # Using SealPIR
 
 Take a look at the example in main.cpp for how to use SealPIR. 
-Note: the parameter "d" stands for recursion levels, and for the current configuration, the server-to-client reply has size (pow(10, d-1) * 32) KB. Therefore we recommend using d <= 3.  
+Note: the parameter "d" stands for recursion levels, and for the current configuration, the 
+server-to-client reply has size (pow(10, d-1) * 32) KB. Therefore we recommend using d <= 3.  
 
 # Contributing
 

+ 27 - 51
main.cpp

@@ -14,90 +14,78 @@ using namespace seal;
 
 int main(int argc, char *argv[]) {
 
-    //uint64_t number_of_items = 1 << 11;
-    //uint64_t number_of_items = 2048;
     uint64_t number_of_items = 1 << 12;
-
     uint64_t size_per_item = 288; // in bytes
-    // uint64_t size_per_item = 1 << 10; // 1 KB.
-    // uint64_t size_per_item = 10 << 10; // 10 KB.
-
     uint32_t N = 2048;
+
     // Recommended values: (logt, d) = (12, 2) or (8, 1). 
     uint32_t logt = 12; 
-    uint32_t d = 5;
+    uint32_t d = 2;
 
     EncryptionParameters params(scheme_type::BFV);
     PirParams pir_params;
 
     // Generates all parameters
-    cout << "Generating all parameters" << endl;
+    cout << "Main: Generating all parameters" << endl;
     gen_params(number_of_items, size_per_item, N, logt, d, params, pir_params);
 
-    cout << "Initializing the database (this may take some time) ..." << endl;
+    cout << "Main: Initializing the database (this may take some time) ..." << endl;
 
     // Create test database
     auto db(make_unique<uint8_t[]>(number_of_items * size_per_item));
 
-    // For testing purposes only
-    auto check_db(make_unique<uint8_t[]>(number_of_items * size_per_item));
+    // Copy of the database. We use this at the end to make sure we retrieved
+    // the correct element.
+    auto db_copy(make_unique<uint8_t[]>(number_of_items * size_per_item));
 
     random_device rd;
     for (uint64_t i = 0; i < number_of_items; i++) {
         for (uint64_t j = 0; j < size_per_item; j++) {
             auto val = rd() % 256;
             db.get()[(i * size_per_item) + j] = val;
-            check_db.get()[(i * size_per_item) + j] = val;
+            db_copy.get()[(i * size_per_item) + j] = val;
         }
     }
 
     // Initialize PIR Server
-    cout << "Initializing server and client" << endl;
+    cout << "Main: Initializing server and client" << endl;
     PIRServer server(params, pir_params);
 
     // Initialize PIR client....
     PIRClient client(params, pir_params);
     GaloisKeys galois_keys = client.generate_galois_keys();
 
-    // Set galois key
+    // Set galois key for client with id 0
     cout << "Main: Setting Galois keys...";
     server.set_galois_key(0, galois_keys);
 
-
-    // The following can be used to update parameters rather than creating new instances
-    // (here it doesn't do anything).
-    // cout << "Updating database size to: " << number_of_items << " elements" << endl;
-    // update_params(number_of_items, size_per_item, d, params, expanded_params, pir_params);
-
-    cout << "done" << endl;
-
-
     // Measure database setup
     auto time_pre_s = high_resolution_clock::now();
     server.set_database(move(db), number_of_items, size_per_item);
     server.preprocess_database();
-    cout << "database pre processed " << endl;
+    cout << "Main: database pre processed " << endl;
     auto time_pre_e = high_resolution_clock::now();
     auto time_pre_us = duration_cast<microseconds>(time_pre_e - time_pre_s).count();
 
     // Choose an index of an element in the DB
     uint64_t ele_index = rd() % number_of_items; // element in DB at random position
-    //uint64_t ele_index = 35; 
-    cout << "Main: element index = " << ele_index << " from [0, " << number_of_items -1 << "]" << endl;
     uint64_t index = client.get_fv_index(ele_index, size_per_item);   // index of FV plaintext
     uint64_t offset = client.get_fv_offset(ele_index, size_per_item); // offset in FV plaintext
-    // Measure query generation
+    cout << "Main: element index = " << ele_index << " from [0, " << number_of_items -1 << "]" << endl;
     cout << "Main: FV index = " << index << ", FV offset = " << offset << endl; 
 
+    // Measure query generation
     auto time_query_s = high_resolution_clock::now();
     PirQuery query = client.generate_query(index);
     auto time_query_e = high_resolution_clock::now();
     auto time_query_us = duration_cast<microseconds>(time_query_e - time_query_s).count();
     cout << "Main: query generated" << endl;
 
+    //To marshall query to send over the network, you can use serialize/deserialize:
+    //PirQuery query = deserialize_ciphertexts(d, serialize_ciphertexts(query), CIPHER_SIZE);
+
     // Measure query processing (including expansion)
     auto time_server_s = high_resolution_clock::now();
-    //PirQuery query_ser = deserialize_ciphertexts(d, serialize_ciphertexts(query), CIPHER_SIZE);
     PirReply reply = server.generate_reply(query, 0, client);
     auto time_server_e = high_resolution_clock::now();
     auto time_server_us = duration_cast<microseconds>(time_server_e - time_server_s).count();
@@ -108,40 +96,28 @@ int main(int argc, char *argv[]) {
     auto time_decode_e = chrono::high_resolution_clock::now();
     auto time_decode_us = duration_cast<microseconds>(time_decode_e - time_decode_s).count();
 
-    // Convert to elements
+    // Convert from FV plaintext (polynomial) to database element at the client
     vector<uint8_t> elems(N * logt / 8);
     coeffs_to_bytes(logt, result, elems.data(), (N * logt) / 8);
-    // cout << "printing the bytes...of the supposed item: "; 
-    // for (int i = 0; i < size_per_item; i++){
-    //     cout << (int) elems[offset*size_per_item + i] << ", "; 
-    // }
-    // cout << endl; 
-
-    // // cout << "offset = " << offset << endl; 
-
-    // cout << "printing the bytes of real item: "; 
-    // for (int i = 0; i < size_per_item; i++){
-    //     cout << (int) check_db.get()[ele_index *size_per_item + i] << ", "; 
-    // }
 
     // Check that we retrieved the correct element
     for (uint32_t i = 0; i < size_per_item; i++) {
-        if (elems[(offset * size_per_item) + i] != check_db.get()[(ele_index * size_per_item) + i]) {
-            cout << "elems " << (int)elems[(offset * size_per_item) + i] << ", db "
-                 << (int) check_db.get()[(ele_index * size_per_item) + i] << endl;
-            cout << "PIR result wrong!" << endl;
+        if (elems[(offset * size_per_item) + i] != db_copy.get()[(ele_index * size_per_item) + i]) {
+            cout << "Main: elems " << (int)elems[(offset * size_per_item) + i] << ", db "
+                 << (int) db_copy.get()[(ele_index * size_per_item) + i] << endl;
+            cout << "Main: PIR result wrong!" << endl;
             return -1;
         }
     }
 
     // Output results
-    cout << "PIR reseult correct!" << endl;
-    cout << "PIRServer pre-processing time: " << time_pre_us / 1000 << " ms" << endl;
-    cout << "PIRServer reply generation time: " << time_server_us / 1000 << " ms"
+    cout << "Main: PIR result correct!" << endl;
+    cout << "Main: PIRServer pre-processing time: " << time_pre_us / 1000 << " ms" << endl;
+    cout << "Main: PIRClient query generation time: " << time_query_us / 1000 << " ms" << endl;
+    cout << "Main: PIRServer reply generation time: " << time_server_us / 1000 << " ms"
          << endl;
-    cout << "PIRClient query generation time: " << time_query_us / 1000 << " ms" << endl;
-    cout << "PIRClient answer decode time: " << time_decode_us / 1000 << " ms" << endl;
-    cout << "Reply num ciphertexts: " << reply.size() << endl;
+    cout << "Main: PIRClient answer decode time: " << time_decode_us / 1000 << " ms" << endl;
+    cout << "Main: Reply num ciphertexts: " << reply.size() << endl;
 
     return 0;
 }

+ 0 - 38
pir.cpp

@@ -76,44 +76,6 @@ void gen_params(uint64_t ele_num, uint64_t ele_size, uint32_t N, uint32_t logt,
     pir_params.expansion_ratio = expansion_ratio << 1; // because one ciphertext = two polys
 }
 
-void update_params(uint64_t ele_num, uint64_t ele_size, uint32_t d, 
-                   const EncryptionParameters &old_params, EncryptionParameters &expanded_params, 
-                   PirParams &pir_params) {
-
-    uint32_t logt = ceil(log2(old_params.plain_modulus().value()));
-    uint32_t N = old_params.poly_modulus_degree();
-
-    // Determine the maximum size of each dimension
-    uint32_t logtp = plainmod_after_expansion(logt, N, d, ele_num, ele_size);
-
-    uint64_t expanded_plain_mod = static_cast<uint64_t>(1) << logtp;
-    uint64_t plaintext_num = plaintexts_per_db(logtp, N, ele_num, ele_size);
-
-#ifdef DEBUG
-    cout << "log(plain mod) before expand = " << logt << endl;
-    cout << "log(plain mod) after expand = " << logtp << endl;
-    cout << "number of FV plaintexts = " << plaintext_num << endl;
-#endif
-
-    expanded_params.set_poly_modulus_degree(old_params.poly_modulus_degree());
-    expanded_params.set_coeff_modulus(old_params.coeff_modulus());
-    expanded_params.set_plain_modulus(expanded_plain_mod);
-
-    // Assumes dimension of database is 2
-    vector<uint64_t> nvec = get_dimensions(plaintext_num, d);
-
-    uint32_t expansion_ratio = 0;
-    for (uint32_t i = 0; i < old_params.coeff_modulus().size(); ++i) {
-        double logqi = log2(old_params.coeff_modulus()[i].value());
-        expansion_ratio += ceil(logqi / logtp);
-    }
-
-    pir_params.d = d;
-    pir_params.dbc = 6;
-    pir_params.n = plaintext_num;
-    pir_params.nvec = nvec;
-    pir_params.expansion_ratio = expansion_ratio << 1;
-}
 
 uint32_t plainmod_after_expansion(uint32_t logt, uint32_t N, uint32_t d, 
         uint64_t ele_num, uint64_t ele_size) {

+ 0 - 5
pir.hpp

@@ -29,11 +29,6 @@ void gen_params(std::uint64_t ele_num,  // number of elements (not FV plaintexts
                 seal::EncryptionParameters &params,
                 PirParams &pir_params);
 
-void update_params(std::uint64_t ele_num, 
-                   std::uint64_t ele_size,
-                   std::uint32_t d,
-                   const seal::EncryptionParameters &old_params, PirParams &pir_params);
-
 // returns the plaintext modulus after expansion
 std::uint32_t plainmod_after_expansion(std::uint32_t logt, std::uint32_t N, 
                                        std::uint32_t d, std::uint64_t ele_num,

+ 1 - 53
pir_client.cpp

@@ -19,54 +19,8 @@ PIRClient::PIRClient(const EncryptionParameters &params,
 
     decryptor_ = make_unique<Decryptor>(newcontext_, secret_key);
     evaluator_ = make_unique<Evaluator>(newcontext_);
-
-    uint64_t t = params_.plain_modulus().value(); 
-
-    uint64_t N = params_.poly_modulus_degree(); 
-
-
-    // 
-    // int logt = floor(log2(params_.plain_modulus().value())); 
-
-    // for(int i = 0; i < pir_params_.nvec.size(); i++){
-    //     uint64_t inverse_scale; 
-    //     // 
-    //     // If the number of items are less than N, then 
-    //     // we use logm.
-    //     int logm = ceil(log2(min(N, pir_params_.nvec[i])));  // if nvec > n what do we do?
-
-    //     int quo = logm / logt; 
-    //     int mod = logm % logt; 
-    //     inverse_scale = pow(2, logt - mod); 
-    //     if ((quo +1) %2 != 0){
-    //         inverse_scale =  params_.plain_modulus().value() - pow(2, logt - mod); 
-    //     }
-    //     inverse_scales_.push_back(inverse_scale); 
-    //     if ( (inverse_scale << logm)  % t != 1){
-    //         throw logic_error("something wrong"); 
-    //     }
-    //     cout << "logm, inverse scale, t = " << logm << ", " << inverse_scale << ", " << t << endl; 
-    // }
 }
 
-// void PIRClient::update_parameters(const EncryptionParameters &expanded_params,
-//                                   const PirParams &pir_params) 
-// {
-
-//     // The only thing that can change is the plaintext modulus and pir_params
-//     assert(expanded_params.poly_modulus_degree() == params_.poly_modulus_degree());
-//     assert(expanded_params.coeff_modulus() == params_.coeff_modulus());
-
-//     params_ = expanded_params;
-//     pir_params_ = pir_params;
-//     auto newcontext = SEALContext::Create(expanded_params);
-
-//     SecretKey secret_key = keygen_->secret_key();
-//     secret_key.parms_id() = expanded_params.parms_id();
-
-//     decryptor_ = make_unique<Decryptor>(newcontext, secret_key);
-//     evaluator_ = make_unique<Evaluator>(newcontext);
-// }
 
 PirQuery PIRClient::generate_query(uint64_t desiredIndex) {
 
@@ -79,7 +33,7 @@ PirQuery PIRClient::generate_query(uint64_t desiredIndex) {
 
     Plaintext pt(params_.poly_modulus_degree());
     for (uint32_t i = 0; i < indices_.size(); i++) {
-        uint32_t num_ptxts = ceil( (pir_params_.nvec[i] +0.0) / N);
+        uint32_t num_ptxts = ceil( (pir_params_.nvec[i] + 0.0) / N);
         // initialize result. 
         cout << "Client: index " << i + 1  <<  "/ " <<  indices_.size() << " = " << indices_[i] << endl; 
         cout << "Client: number of ctxts needed for query = " << num_ptxts << endl;
@@ -102,8 +56,6 @@ PirQuery PIRClient::generate_query(uint64_t desiredIndex) {
             dest.parms_id() = params_.parms_id();
             result[i].push_back(dest);
         }   
-        
-
     }
 
     return result;
@@ -292,9 +244,5 @@ void PIRClient::compute_inverse_scales(){
         }
         cout << "Client: logm, inverse scale, t = " << logm << ", " << inverse_scale << ", " << t << endl; 
     }
-
-
-
-
 }
 

+ 0 - 4
pir_client.hpp

@@ -11,9 +11,6 @@ class PIRClient {
     PIRClient(const seal::EncryptionParameters &parms,
                const PirParams &pirparms);
 
-    //void update_parameters(const seal::EncryptionParameters &expandedParams,
-    //                       const PirParams &pirparms);
-
     PirQuery generate_query(std::uint64_t desiredIndex);
     seal::Plaintext decode_reply(PirReply reply);
 
@@ -26,7 +23,6 @@ class PIRClient {
     void compute_inverse_scales(); 
 
   private:
-    // Should we store a decryptor and an encryptor?
     seal::EncryptionParameters params_;
     PirParams pir_params_;
 

+ 1 - 28
pir_server.cpp

@@ -14,25 +14,6 @@ PIRServer::PIRServer(const EncryptionParameters &params, const PirParams &pir_pa
     evaluator_ = make_unique<Evaluator>(context);
 }
 
-// void PIRServer::update_parameters(const EncryptionParameters &expanded_params,
-//                                   const PirParams &pir_params) {
-
-//     // The only thing that can change is the plaintext modulus and pir_params
-//     assert(expanded_params.poly_modulus_degree() == params_.poly_modulus_degree());
-//     assert(expanded_params.coeff_modulus() == params_.coeff_modulus());
-
-//     params_ = expanded_params;
-//     pir_params_ = pir_params;
-//     auto context = SEALContext::Create(expanded_params);
-//     evaluator_ = make_unique<Evaluator>(context);
-//     is_db_preprocessed_ = false;
-
-//     // Update all the galois keys
-//     for (std::pair<const int, GaloisKeys> &key : galoisKeys_) {
-//         key.second.parms_id() = params_.parms_id();
-//     }
-// }
-
 void PIRServer::preprocess_database() {
     if (!is_db_preprocessed_) {
 
@@ -84,11 +65,8 @@ void PIRServer::set_database(const std::unique_ptr<const std::uint8_t[]> &bytes,
     assert(coeff_per_ptxt <= N);
 
     cout << "Server: total number of FV plaintext = " << total << endl;
-
     cout << "Server: elements packed into each plaintext " << ele_per_ptxt << endl; 
 
-
-
     uint32_t offset = 0;
 
     for (uint64_t i = 0; i < total; i++) {
@@ -102,6 +80,7 @@ void PIRServer::set_database(const std::unique_ptr<const std::uint8_t[]> &bytes,
         } else {
             process_bytes = bytes_per_ptxt;
         }
+
         // Get the coefficients of the elements that will be packed in plaintext i
         vector<uint64_t> coefficients = bytes_to_coeffs(logt, bytes.get() + offset, process_bytes);
         offset += process_bytes;
@@ -229,9 +208,6 @@ PirReply PIRServer::generate_reply(PirQuery query, uint32_t client_id, PIRClient
         vector<Ciphertext> intermediateCtxts(product);
         Ciphertext temp;
 
-
-
-
         for (uint64_t k = 0; k < product; k++) {
 
             evaluator_->multiply_plain(expanded_query[0], (*cur)[k], intermediateCtxts[k]);
@@ -248,8 +224,6 @@ PirReply PIRServer::generate_reply(PirQuery query, uint32_t client_id, PIRClient
             //cout << "const term of ctxt " << jj << " = " << intermediateCtxts[jj][0] << endl; 
         }
 
-
-
         if (i == nvec.size() - 1) {
             return intermediateCtxts;
         } else {
@@ -319,7 +293,6 @@ inline vector<Ciphertext> PIRServer::expand_query(const Ciphertext &encrypted, u
         // temp[a] = (j0 = a (mod 2**i) ? ) : Enc(x^{j0 - a}) else Enc(0).  With
         // some scaling....
         int index_raw = (n << 1) - (1 << i);
-        // TODO: galois elements. 
         int index = (index_raw * galois_elts[i]) % (n << 1);
 
         for (uint32_t a = 0; a < temp.size(); a++) {

+ 1 - 6
pir_server.hpp

@@ -10,9 +10,6 @@ class PIRServer {
   public:
     PIRServer(const seal::EncryptionParameters &params, const PirParams &pir_params);
 
-    //void update_parameters(const seal::EncryptionParameters &expanded_params,
-    //                       const PirParams &pir_params);
-
     // NOTE: server takes over ownership of db and frees it when it exits.
     // Caller cannot free db
     void set_database(std::unique_ptr<std::vector<seal::Plaintext>> &&db);
@@ -28,9 +25,7 @@ class PIRServer {
 
   private:
     seal::EncryptionParameters params_; // SEAL parameters
-
-    //seal::EncryptionParameters expanded_params_; // SEAL parameters
-    PirParams pir_params_;                       // PIR parameters
+    PirParams pir_params_;              // PIR parameters
     std::unique_ptr<Database> db_;
     bool is_db_preprocessed_;
     std::map<int, seal::GaloisKeys> galoisKeys_;