main.cpp 7.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207
#include <chrono>
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <exception>
#include <iostream>
#include <limits>
#include <memory>
#include <random>
#include <string>
#include <utility>
#include <vector>

#include <seal/seal.h>

#include "pir.hpp"
#include "pir_client.hpp"
#include "pir_server.hpp"

using namespace std::chrono;
using namespace std;
using namespace seal;
  13. //TODO: Accept the following as CMD parameters
  14. // 1) number_of_items
  15. // 2) size_per_item
  16. // 3) d?
  17. // 4) num_requests
  18. uint8_t no_of_expected_parameters = 4;
  19. uint64_t number_of_items;
  20. uint64_t size_per_item;
  21. uint32_t d;
  22. uint64_t num_requests;
  23. int getCMDArgs(int argc, char * argv[]) {
  24. if(argc < no_of_expected_parameters){
  25. printf("Command line parameters error, expected :\n");
  26. printf("./sealpir <number_of_items> <size_per_item> <d> <num_requests>\n");
  27. exit(0);
  28. }
  29. std::string str = argv[1];
  30. number_of_items = std::stoi(str);
  31. str = argv[2];
  32. size_per_item = std::stoi(str);
  33. str = argv[3];
  34. d = std::stoi(str);
  35. str = argv[4];
  36. num_requests = std::stoi(str);
  37. return 1;
  38. }
  39. int main(int argc, char *argv[]) {
  40. getCMDArgs(argc, argv);
  41. //uint64_t number_of_items = 1 << 11;
  42. //uint64_t number_of_items = 262144;
  43. //uint64_t number_of_items = 1 << 12;
  44. //uint64_t size_per_item = 288; // in bytes
  45. // uint64_t size_per_item = 1 << 10; // 1 KB.
  46. // uint64_t size_per_item = 10 << 10; // 10 KB.
  47. uint32_t N = 2048;
  48. // Recommended values: (logt, d) = (12, 2) or (8, 1).
  49. uint32_t logt = 12;
  50. //uint32_t logt = 19;
  51. //uint32_t d = 2;
  52. EncryptionParameters params(scheme_type::BFV);
  53. PirParams pir_params;
  54. // Generates all parameters
  55. cout << "Generating all parameters" << endl;
  56. gen_params(number_of_items, size_per_item, N, logt, d, params, pir_params);
  57. cout << "Initializing the database (this may take some time) ..." << endl;
  58. // Create test database
  59. auto db(make_unique<uint8_t[]>(number_of_items * size_per_item));
  60. // For testing purposes only
  61. auto check_db(make_unique<uint8_t[]>(number_of_items * size_per_item));
  62. random_device rd;
  63. for (uint64_t i = 0; i < number_of_items; i++) {
  64. for (uint64_t j = 0; j < size_per_item; j++) {
  65. auto val = rd() % 256;
  66. db.get()[(i * size_per_item) + j] = val;
  67. check_db.get()[(i * size_per_item) + j] = val;
  68. }
  69. }
  70. // Initialize PIR Server
  71. cout << "Initializing server and client" << endl;
  72. PIRServer server(params, pir_params);
  73. // Initialize PIR client....
  74. PIRClient client(params, pir_params);
  75. GaloisKeys galois_keys = client.generate_galois_keys();
  76. // Set galois key
  77. cout << "Main: Setting Galois keys...";
  78. server.set_galois_key(0, galois_keys);
  79. // The following can be used to update parameters rather than creating new instances
  80. // (here it doesn't do anything).
  81. // cout << "Updating database size to: " << number_of_items << " elements" << endl;
  82. // update_params(number_of_items, size_per_item, d, params, expanded_params, pir_params);
  83. cout << "done" << endl;
  84. // Measure database setup
  85. auto time_pre_s = high_resolution_clock::now();
  86. server.set_database(move(db), number_of_items, size_per_item);
  87. server.preprocess_database();
  88. cout << "database pre processed " << endl;
  89. auto time_pre_e = high_resolution_clock::now();
  90. auto time_pre_us = duration_cast<microseconds>(time_pre_e - time_pre_s).count();
  91. //TODO: Loop query generation, processing and reply extraction num_requests times
  92. for (uint64_t l = 0; l< num_requests; l++) {
  93. // Choose an index of an element in the DB
  94. uint64_t ele_index = rd() % number_of_items; // element in DB at random position
  95. //uint64_t ele_index = 35;
  96. cout << "Main: element index = " << ele_index << " from [0, " << number_of_items -1 << "]" << endl;
  97. uint64_t index = client.get_fv_index(ele_index, size_per_item); // index of FV plaintext
  98. uint64_t offset = client.get_fv_offset(ele_index, size_per_item); // offset in FV plaintext
  99. // Measure query generation
  100. cout << "Main: FV index = " << index << ", FV offset = " << offset << endl;
  101. auto time_query_s = high_resolution_clock::now();
  102. PirQuery query = client.generate_query(index);
  103. // Note: PirQuery is std::vector<seal::Ciphertext>, so size of a PirQuery is:
  104. // sum of sizes of each of these Ciphertexts
  105. uint64_t query_size = 0;
  106. for (vector<vector<Ciphertext>>::iterator iter = query.begin(); iter != query.end(); iter++) {
  107. uint64_t vector_size = (*iter).size();
  108. //cout<<"d = "<< d_ctr <<", vector_size = " << (*iter).size() << endl;
  109. //cout<<"Ciphertext N = "<< ((*iter).front()).poly_modulus_degree();
  110. //cout<<", k = "<< ((*iter).front()).coeff_mod_count();
  111. //cout<< "query size of each element in "<<d_ctr<<"-th level ciphertext vector="<< ((*iter).front()).uint64_count() << endl;
  112. //d_ctr++;
  113. //uint64_count() returns no of 64-bit words used for storing that ciphertext
  114. query_size+=(vector_size*(((*iter).front()).uint64_count())*8);
  115. }
  116. auto time_query_e = high_resolution_clock::now();
  117. auto time_query_us = duration_cast<microseconds>(time_query_e - time_query_s).count();
  118. cout << "Main: query generated" << endl;
  119. // Measure query processing (including expansion)
  120. auto time_server_s = high_resolution_clock::now();
  121. double time_expand_us = 0;
  122. //PirQuery query_ser = deserialize_ciphertexts(d, serialize_ciphertexts(query), CIPHER_SIZE);
  123. PirReply reply = server.generate_reply(query, 0, client, &time_expand_us);
  124. uint64_t reply_size, reply_vector_size;
  125. vector<Ciphertext>::iterator iter=reply.begin();
  126. reply_vector_size = reply.size();
  127. reply_size=reply_vector_size * ((*iter).uint64_count()*8);
  128. cout<<"reply_vector_size = "<<reply_vector_size<<", reply_size = "<< reply_size<<endl;
  129. auto time_server_e = high_resolution_clock::now();
  130. auto time_server_us = duration_cast<microseconds>(time_server_e - time_server_s).count();
  131. // Measure response extraction
  132. auto time_decode_s = chrono::high_resolution_clock::now();
  133. Plaintext result = client.decode_reply(reply);
  134. auto time_decode_e = chrono::high_resolution_clock::now();
  135. auto time_decode_us = duration_cast<microseconds>(time_decode_e - time_decode_s).count();
  136. // Convert to elements
  137. vector<uint8_t> elems(N * logt / 8);
  138. coeffs_to_bytes(logt, result, elems.data(), (N * logt) / 8);
  139. // cout << "printing the bytes...of the supposed item: ";
  140. // for (int i = 0; i < size_per_item; i++){
  141. // cout << (int) elems[offset*size_per_item + i] << ", ";
  142. // }
  143. // cout << endl;
  144. // // cout << "offset = " << offset << endl;
  145. // cout << "printing the bytes of real item: ";
  146. // for (int i = 0; i < size_per_item; i++){
  147. // cout << (int) check_db.get()[ele_index *size_per_item + i] << ", ";
  148. // }
  149. // Check that we retrieved the correct element
  150. for (uint32_t i = 0; i < size_per_item; i++) {
  151. if (elems[(offset * size_per_item) + i] != check_db.get()[(ele_index * size_per_item) + i]) {
  152. cout << "elems " << (int)elems[(offset * size_per_item) + i] << ", db "
  153. << (int) check_db.get()[(ele_index * size_per_item) + i] << endl;
  154. cout << "PIR result wrong!" << endl;
  155. return -1;
  156. }
  157. }
  158. // Output results
  159. cout << "PIR reseult correct!" << endl;
  160. cout << "_QGT_: PIRClient query generation time: " << double(time_query_us) / 1000 << " ms" << endl;
  161. cout << "_QPT_: PIRServer query total processing time: " << double(time_server_us) / 1000 << " ms"
  162. << endl;
  163. cout << "_QET_: PIRServer query expansion time: " << double(time_expand_us) / 1000 << " ms"
  164. << endl;
  165. cout << "_RET_: PIRClient answer decode time: " << double(time_decode_us) / 1000 << " ms" << endl;
  166. cout << "_TQS_: Total query size: "<< query_size <<" bytes"<< endl;
  167. cout << "_TRS_: Total reply size: "<< reply_size <<" bytes"<< endl;
  168. cout << "_PPT_: PIRServer pre-processing time: " << double(time_pre_us) / 1000 << " ms" << endl;
  169. cout << "Reply num ciphertexts: " << reply.size() << endl;
  170. }
  171. return 0;
  172. }