Sfoglia il codice sorgente

Patch for our microbenchmarker scripts

sshsshy 2 anni fa
parent
commit
2f4991e9bf
4 ha cambiato i file con 146 aggiunte e 71 eliminazioni
  1. 126 66
      main.cpp
  2. 5 1
      pir.cpp
  3. 13 3
      pir_server.cpp
  4. 2 1
      pir_server.hpp

+ 126 - 66
main.cpp

@@ -12,20 +12,52 @@ using namespace std::chrono;
 using namespace std;
 using namespace seal;
 
+//TODO: Accept the following as CMD parameters
+// 1) number_of_items
+// 2) size_per_item
+// 3) d?
+// 4) num_requests 
+
+uint8_t no_of_expected_parameters = 4;
+uint64_t number_of_items;
+uint64_t size_per_item;
+uint32_t d;
+uint64_t num_requests;
+
+int getCMDArgs(int argc, char * argv[]) {
+  if(argc < no_of_expected_parameters){
+    printf("Command line parameters error, expected :\n");
+    printf("./sealpir <number_of_items> <size_per_item> <d> <num_requests>\n");
+    exit(0);
+  }
+
+  std::string str = argv[1];
+  number_of_items = std::stoi(str);
+  str = argv[2]; 
+  size_per_item = std::stoi(str); 
+  str = argv[3]; 
+  d = std::stoi(str); 
+  str = argv[4]; 
+  num_requests = std::stoi(str);    
+  return 1;
+}
+
 int main(int argc, char *argv[]) {
 
+    getCMDArgs(argc, argv);
     //uint64_t number_of_items = 1 << 11;
-    //uint64_t number_of_items = 2048;
-    uint64_t number_of_items = 1 << 12;
-
-    uint64_t size_per_item = 288; // in bytes
+    //uint64_t number_of_items = 262144;
+    //uint64_t number_of_items = 1 << 12;
+    
+    //uint64_t size_per_item = 288; // in bytes
     // uint64_t size_per_item = 1 << 10; // 1 KB.
     // uint64_t size_per_item = 10 << 10; // 10 KB.
 
     uint32_t N = 2048;
     // Recommended values: (logt, d) = (12, 2) or (8, 1). 
     uint32_t logt = 12; 
-    uint32_t d = 5;
+    //uint32_t logt = 19;
+    //uint32_t d = 2;
 
     EncryptionParameters params(scheme_type::BFV);
     PirParams pir_params;
@@ -80,68 +112,96 @@ int main(int argc, char *argv[]) {
     auto time_pre_e = high_resolution_clock::now();
     auto time_pre_us = duration_cast<microseconds>(time_pre_e - time_pre_s).count();
 
-    // Choose an index of an element in the DB
-    uint64_t ele_index = rd() % number_of_items; // element in DB at random position
-    //uint64_t ele_index = 35; 
-    cout << "Main: element index = " << ele_index << " from [0, " << number_of_items -1 << "]" << endl;
-    uint64_t index = client.get_fv_index(ele_index, size_per_item);   // index of FV plaintext
-    uint64_t offset = client.get_fv_offset(ele_index, size_per_item); // offset in FV plaintext
-    // Measure query generation
-    cout << "Main: FV index = " << index << ", FV offset = " << offset << endl; 
-
-    auto time_query_s = high_resolution_clock::now();
-    PirQuery query = client.generate_query(index);
-    auto time_query_e = high_resolution_clock::now();
-    auto time_query_us = duration_cast<microseconds>(time_query_e - time_query_s).count();
-    cout << "Main: query generated" << endl;
-
-    // Measure query processing (including expansion)
-    auto time_server_s = high_resolution_clock::now();
-    //PirQuery query_ser = deserialize_ciphertexts(d, serialize_ciphertexts(query), CIPHER_SIZE);
-    PirReply reply = server.generate_reply(query, 0, client);
-    auto time_server_e = high_resolution_clock::now();
-    auto time_server_us = duration_cast<microseconds>(time_server_e - time_server_s).count();
-
-    // Measure response extraction
-    auto time_decode_s = chrono::high_resolution_clock::now();
-    Plaintext result = client.decode_reply(reply);
-    auto time_decode_e = chrono::high_resolution_clock::now();
-    auto time_decode_us = duration_cast<microseconds>(time_decode_e - time_decode_s).count();
-
-    // Convert to elements
-    vector<uint8_t> elems(N * logt / 8);
-    coeffs_to_bytes(logt, result, elems.data(), (N * logt) / 8);
-    // cout << "printing the bytes...of the supposed item: "; 
-    // for (int i = 0; i < size_per_item; i++){
-    //     cout << (int) elems[offset*size_per_item + i] << ", "; 
-    // }
-    // cout << endl; 
-
-    // // cout << "offset = " << offset << endl; 
-
-    // cout << "printing the bytes of real item: "; 
-    // for (int i = 0; i < size_per_item; i++){
-    //     cout << (int) check_db.get()[ele_index *size_per_item + i] << ", "; 
-    // }
-
-    // Check that we retrieved the correct element
-    for (uint32_t i = 0; i < size_per_item; i++) {
-        if (elems[(offset * size_per_item) + i] != check_db.get()[(ele_index * size_per_item) + i]) {
-            cout << "elems " << (int)elems[(offset * size_per_item) + i] << ", db "
-                 << (int) check_db.get()[(ele_index * size_per_item) + i] << endl;
-            cout << "PIR result wrong!" << endl;
-            return -1;
-        }
+    //TODO: Loop query generation, processing and reply extraction num_requests times
+    for (uint64_t l = 0; l< num_requests; l++) {
+	
+	// Choose an index of an element in the DB
+	uint64_t ele_index = rd() % number_of_items; // element in DB at random position
+	//uint64_t ele_index = 35; 
+	cout << "Main: element index = " << ele_index << " from [0, " << number_of_items -1 << "]" << endl;
+	uint64_t index = client.get_fv_index(ele_index, size_per_item);   // index of FV plaintext
+	uint64_t offset = client.get_fv_offset(ele_index, size_per_item); // offset in FV plaintext
+	// Measure query generation
+	cout << "Main: FV index = " << index << ", FV offset = " << offset << endl; 
+
+	auto time_query_s = high_resolution_clock::now();
+	PirQuery query = client.generate_query(index);
+	// Note: PirQuery is std::vector<seal::Ciphertext>, so size of a PirQuery is: 
+	// sum of sizes of each of these Ciphertexts
+	uint64_t query_size = 0;
+	for (vector<vector<Ciphertext>>::iterator iter = query.begin(); iter != query.end(); iter++) { 
+	  uint64_t vector_size = (*iter).size(); 
+	  //cout<<"d = "<< d_ctr <<", vector_size = " << (*iter).size() << endl;
+	  //cout<<"Ciphertext N = "<< ((*iter).front()).poly_modulus_degree();
+	  //cout<<", k = "<<  ((*iter).front()).coeff_mod_count();
+	  //cout<< "query size of each element in "<<d_ctr<<"-th level ciphertext vector="<< ((*iter).front()).uint64_count() << endl;
+	  //d_ctr++;
+	  //uint64_count() returns no of 64-bit words used for storing that ciphertext
+	  query_size+=(vector_size*(((*iter).front()).uint64_count())*8);
+	} 
+	auto time_query_e = high_resolution_clock::now();
+	auto time_query_us = duration_cast<microseconds>(time_query_e - time_query_s).count();
+	cout << "Main: query generated" << endl;
+
+	// Measure query processing (including expansion)
+	auto time_server_s = high_resolution_clock::now();
+        double time_expand_us = 0;
+	//PirQuery query_ser = deserialize_ciphertexts(d, serialize_ciphertexts(query), CIPHER_SIZE);
+	PirReply reply = server.generate_reply(query, 0, client, &time_expand_us);
+	uint64_t reply_size, reply_vector_size;
+	vector<Ciphertext>::iterator iter=reply.begin();
+	reply_vector_size = reply.size();
+	reply_size=reply_vector_size * ((*iter).uint64_count()*8);
+	cout<<"reply_vector_size = "<<reply_vector_size<<", reply_size = "<< reply_size<<endl;
+	auto time_server_e = high_resolution_clock::now();
+	auto time_server_us = duration_cast<microseconds>(time_server_e - time_server_s).count();
+
+	// Measure response extraction
+	auto time_decode_s = chrono::high_resolution_clock::now();
+	Plaintext result = client.decode_reply(reply);
+	auto time_decode_e = chrono::high_resolution_clock::now();
+	auto time_decode_us = duration_cast<microseconds>(time_decode_e - time_decode_s).count();
+
+	// Convert to elements
+	vector<uint8_t> elems(N * logt / 8);
+	coeffs_to_bytes(logt, result, elems.data(), (N * logt) / 8);
+	// cout << "printing the bytes...of the supposed item: "; 
+	// for (int i = 0; i < size_per_item; i++){
+	//     cout << (int) elems[offset*size_per_item + i] << ", "; 
+	// }
+	// cout << endl; 
+
+	// // cout << "offset = " << offset << endl; 
+
+	// cout << "printing the bytes of real item: "; 
+	// for (int i = 0; i < size_per_item; i++){
+	//     cout << (int) check_db.get()[ele_index *size_per_item + i] << ", "; 
+	// }
+
+	// Check that we retrieved the correct element
+	for (uint32_t i = 0; i < size_per_item; i++) {
+	    if (elems[(offset * size_per_item) + i] != check_db.get()[(ele_index * size_per_item) + i]) {
+		cout << "elems " << (int)elems[(offset * size_per_item) + i] << ", db "
+		     << (int) check_db.get()[(ele_index * size_per_item) + i] << endl;
+		cout << "PIR result wrong!" << endl;
+		return -1;
+	    }
+	}
+
+	// Output results
+	cout << "PIR reseult correct!" << endl;
+	cout << "_QGT_: PIRClient query generation time: " << double(time_query_us) / 1000 << " ms" << endl;
+	cout << "_QPT_: PIRServer query total processing time: " << double(time_server_us) / 1000 << " ms"
+	     << endl;
+	cout << "_QET_: PIRServer query expansion time: " << double(time_expand_us) / 1000 << " ms"
+	     << endl;
+	cout << "_RET_: PIRClient answer decode time: " << double(time_decode_us) / 1000 << " ms" << endl;
+	
+	cout << "_TQS_: Total query size: "<< query_size <<" bytes"<< endl;
+	cout << "_TRS_: Total reply size: "<< reply_size <<" bytes"<< endl;
+	cout << "_PPT_: PIRServer pre-processing time: " << double(time_pre_us) / 1000 << " ms" << endl;
+	cout << "Reply num ciphertexts: " << reply.size() << endl;
     }
 
-    // Output results
-    cout << "PIR reseult correct!" << endl;
-    cout << "PIRServer pre-processing time: " << time_pre_us / 1000 << " ms" << endl;
-    cout << "PIRServer reply generation time: " << time_server_us / 1000 << " ms"
-         << endl;
-    cout << "PIRClient query generation time: " << time_query_us / 1000 << " ms" << endl;
-    cout << "PIRClient answer decode time: " << time_decode_us / 1000 << " ms" << endl;
-    cout << "Reply num ciphertexts: " << reply.size() << endl;
-
     return 0;
 }

+ 5 - 1
pir.cpp

@@ -42,6 +42,10 @@ void gen_params(uint64_t ele_num, uint64_t ele_size, uint32_t N, uint32_t logt,
     uint64_t plain_mod = (static_cast<uint64_t>(1) << logt) + 1;
     uint64_t plaintext_num = plaintexts_per_db(logt, N, ele_num, ele_size);
 
+    uint32_t logtp = plainmod_after_expansion(logt, N, d, ele_num, ele_size);
+    cout << "log(plain mod) before expand = " << logt << endl;
+    cout << "log(plain mod) after expand = " << logtp << endl;
+
 #ifdef DEBUG
     cout << "log(plain mod) before expand = " << logt << endl;
     cout << "number of FV plaintexts = " << plaintext_num << endl;
@@ -85,10 +89,10 @@ void update_params(uint64_t ele_num, uint64_t ele_size, uint32_t d,
 
     // Determine the maximum size of each dimension
     uint32_t logtp = plainmod_after_expansion(logt, N, d, ele_num, ele_size);
-
     uint64_t expanded_plain_mod = static_cast<uint64_t>(1) << logtp;
     uint64_t plaintext_num = plaintexts_per_db(logtp, N, ele_num, ele_size);
 
+
 #ifdef DEBUG
     cout << "log(plain mod) before expand = " << logt << endl;
     cout << "log(plain mod) after expand = " << logtp << endl;

+ 13 - 3
pir_server.cpp

@@ -2,6 +2,7 @@
 #include "pir_client.hpp"
 
 using namespace std;
+using namespace std::chrono;
 using namespace seal;
 using namespace seal::util;
 
@@ -60,7 +61,10 @@ void PIRServer::set_database(const std::unique_ptr<const std::uint8_t[]> &bytes,
 
     uint32_t logt = floor(log2(params_.plain_modulus().value()));
     uint32_t N = params_.poly_modulus_degree();
-
+    
+    cout << "In PIRServer::set_database: \n";
+    cout << "logt = "<< logt <<", N = " << N <<endl;
+    
     // number of FV plaintexts needed to represent all elements
     uint64_t total = plaintexts_per_db(logt, N, ele_num, ele_size);
 
@@ -148,7 +152,7 @@ void PIRServer::set_galois_key(std::uint32_t client_id, seal::GaloisKeys galkey)
     galoisKeys_[client_id] = galkey;
 }
 
-PirReply PIRServer::generate_reply(PirQuery query, uint32_t client_id, PIRClient &client) {
+PirReply PIRServer::generate_reply(PirQuery query, uint32_t client_id, PIRClient &client, double *expand_time_us) {
 
     vector<uint64_t> nvec = pir_params_.nvec;
     uint64_t product = 1;
@@ -168,6 +172,8 @@ PirReply PIRServer::generate_reply(PirQuery query, uint32_t client_id, PIRClient
     int N = params_.poly_modulus_degree();
 
     int logt = floor(log2(params_.plain_modulus().value()));
+    cout<<"In PIR_SERVER.CPP : logt = "<< logt << endl;
+    cout<<"In PIR_SERVER.CPP : log2(params_.plain_modulus().value())) = "<<log2(params_.plain_modulus().value())<< endl;
 
     cout << "expansion ratio = " << pir_params_.expansion_ratio << endl; 
     for (uint32_t i = 0; i < nvec.size(); i++) {
@@ -179,6 +185,8 @@ PirReply PIRServer::generate_reply(PirQuery query, uint32_t client_id, PIRClient
         uint64_t n_i = nvec[i];
         cout << "Server: n_i = " << n_i << endl; 
         cout << "Server: expanding " << query[i].size() << " query ctxts" << endl;
+
+	auto time_expand_s = high_resolution_clock::now();
         for (uint32_t j = 0; j < query[i].size(); j++){
             uint64_t total = N; 
             if (j == query[i].size() - 1){
@@ -194,7 +202,9 @@ PirReply PIRServer::generate_reply(PirQuery query, uint32_t client_id, PIRClient
         if (expanded_query.size() != n_i) {
             cout << " size mismatch!!! " << expanded_query.size() << ", " << n_i << endl; 
         }    
-
+	auto time_expand_e = high_resolution_clock::now();
+	auto time_expand_us = duration_cast<microseconds>(time_expand_e - time_expand_s).count();
+        *expand_time_us+=time_expand_us;
         /*
         cout << "Checking expanded query " << endl; 
         Plaintext tempPt; 

+ 2 - 1
pir_server.hpp

@@ -2,6 +2,7 @@
 
 #include "pir.hpp"
 #include <map>
+#include <chrono>
 #include <memory>
 #include <vector>
 #include "pir_client.hpp"
@@ -22,7 +23,7 @@ class PIRServer {
     std::vector<seal::Ciphertext> expand_query(
             const seal::Ciphertext &encrypted, std::uint32_t m, uint32_t client_id, PIRClient &client);
 
-    PirReply generate_reply(PirQuery query, std::uint32_t client_id, PIRClient &client);
+    PirReply generate_reply(PirQuery query, std::uint32_t client_id, PIRClient &client, double *expand_time_us);
 
     void set_galois_key(std::uint32_t client_id, seal::GaloisKeys galkey);