main.cpp 7.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185
#include "pir.hpp"
#include "pir_client.hpp"
#include "pir_server.hpp"

#include <algorithm>
#include <chrono>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <iostream>
#include <memory>
#include <random>
#include <sstream>
#include <vector>

#include <seal/seal.h>

using namespace std::chrono;
using namespace std;
using namespace seal;
  13. int main(int argc, char *argv[]) {
  14. uint64_t number_of_items = 1 << 16;
  15. uint64_t size_per_item = 1024; // in bytes
  16. uint32_t N = 4096;
  17. // Recommended values: (logt, d) = (20, 2).
  18. uint32_t logt = 20;
  19. uint32_t d = 2;
  20. bool use_symmetric = true; // use symmetric encryption instead of public key
  21. // (recommended for smaller query)
  22. bool use_batching = true; // pack as many elements as possible into a BFV
  23. // plaintext (recommended)
  24. bool use_recursive_mod_switching = true;
  25. EncryptionParameters enc_params(scheme_type::bfv);
  26. PirParams pir_params;
  27. // Generates all parameters
  28. cout << "Main: Generating SEAL parameters" << endl;
  29. gen_encryption_params(N, logt, enc_params);
  30. cout << "Main: Verifying SEAL parameters" << endl;
  31. verify_encryption_params(enc_params);
  32. cout << "Main: SEAL parameters are good" << endl;
  33. cout << "Main: Generating PIR parameters" << endl;
  34. gen_pir_params(number_of_items, size_per_item, d, enc_params, pir_params,
  35. use_symmetric, use_batching, use_recursive_mod_switching);
  36. print_seal_params(enc_params);
  37. print_pir_params(pir_params);
  38. // Initialize PIR client....
  39. PIRClient client(enc_params, pir_params);
  40. cout << "Main: Generating galois keys for client" << endl;
  41. GaloisKeys galois_keys = client.generate_galois_keys();
  42. // Initialize PIR Server
  43. cout << "Main: Initializing server" << endl;
  44. PIRServer server(enc_params, pir_params);
  45. // Server maps the galois key to client 0. We only have 1 client,
  46. // which is why we associate it with 0. If there are multiple PIR
  47. // clients, you should have each client generate a galois key,
  48. // and assign each client an index or id, then call the procedure below.
  49. server.set_galois_key(0, galois_keys);
  50. cout << "Main: Creating the database with random data (this may take some "
  51. "time) ..."
  52. << endl;
  53. // Create test database
  54. auto db(make_unique<uint8_t[]>(number_of_items * size_per_item));
  55. // Copy of the database. We use this at the end to make sure we retrieved
  56. // the correct element.
  57. auto db_copy(make_unique<uint8_t[]>(number_of_items * size_per_item));
  58. random_device rd;
  59. for (uint64_t i = 0; i < number_of_items; i++) {
  60. for (uint64_t j = 0; j < size_per_item; j++) {
  61. uint8_t val = rd() % 256;
  62. db.get()[(i * size_per_item) + j] = val;
  63. db_copy.get()[(i * size_per_item) + j] = val;
  64. }
  65. }
  66. // Measure database setup
  67. auto time_pre_s = high_resolution_clock::now();
  68. server.set_database(move(db), number_of_items, size_per_item);
  69. server.preprocess_database();
  70. auto time_pre_e = high_resolution_clock::now();
  71. auto time_pre_us =
  72. duration_cast<microseconds>(time_pre_e - time_pre_s).count();
  73. cout << "Main: database pre processed " << endl;
  74. // Choose an index of an element in the DB
  75. uint64_t ele_index =
  76. rd() % number_of_items; // element in DB at random position
  77. uint64_t index = client.get_fv_index(ele_index); // index of FV plaintext
  78. uint64_t offset = client.get_fv_offset(ele_index); // offset in FV plaintext
  79. cout << "Main: element index = " << ele_index << " from [0, "
  80. << number_of_items - 1 << "]" << endl;
  81. cout << "Main: FV index = " << index << ", FV offset = " << offset << endl;
  82. // Measure query generation
  83. auto time_query_s = high_resolution_clock::now();
  84. PirQuery query = client.generate_query(index);
  85. auto time_query_e = high_resolution_clock::now();
  86. auto time_query_us =
  87. duration_cast<microseconds>(time_query_e - time_query_s).count();
  88. cout << "Main: query generated" << endl;
  89. // Measure serialized query generation (useful for sending over the network)
  90. stringstream client_stream;
  91. stringstream server_stream;
  92. auto time_s_query_s = high_resolution_clock::now();
  93. int query_size = client.generate_serialized_query(index, client_stream);
  94. auto time_s_query_e = high_resolution_clock::now();
  95. auto time_s_query_us =
  96. duration_cast<microseconds>(time_s_query_e - time_s_query_s).count();
  97. cout << "Main: serialized query generated" << endl;
  98. // Measure query deserialization (useful for receiving over the network)
  99. auto time_deserial_s = high_resolution_clock::now();
  100. PirQuery query2 = server.deserialize_query(client_stream);
  101. auto time_deserial_e = high_resolution_clock::now();
  102. auto time_deserial_us =
  103. duration_cast<microseconds>(time_deserial_e - time_deserial_s).count();
  104. cout << "Main: query deserialized" << endl;
  105. // Measure query processing (including expansion)
  106. auto time_server_s = high_resolution_clock::now();
  107. // Answer PIR query from client 0. If there are multiple clients,
  108. // enter the id of the client (to use the associated galois key).
  109. PirReply reply = server.generate_reply(query2, 0);
  110. auto time_server_e = high_resolution_clock::now();
  111. auto time_server_us =
  112. duration_cast<microseconds>(time_server_e - time_server_s).count();
  113. cout << "Main: reply generated" << endl;
  114. // Serialize reply (useful for sending over the network)
  115. int reply_size = server.serialize_reply(reply, server_stream);
  116. // Measure response extraction
  117. auto time_decode_s = chrono::high_resolution_clock::now();
  118. vector<uint8_t> elems = client.decode_reply(reply, offset);
  119. auto time_decode_e = chrono::high_resolution_clock::now();
  120. auto time_decode_us =
  121. duration_cast<microseconds>(time_decode_e - time_decode_s).count();
  122. cout << "Main: reply decoded" << endl;
  123. assert(elems.size() == size_per_item);
  124. bool failed = false;
  125. // Check that we retrieved the correct element
  126. for (uint32_t i = 0; i < size_per_item; i++) {
  127. if (elems[i] != db_copy.get()[(ele_index * size_per_item) + i]) {
  128. cout << "Main: elems " << (int)elems[i] << ", db "
  129. << (int)db_copy.get()[(ele_index * size_per_item) + i] << endl;
  130. cout << "Main: PIR result wrong at " << i << endl;
  131. failed = true;
  132. }
  133. }
  134. if (failed) {
  135. return -1;
  136. }
  137. // Output results
  138. cout << "Main: PIR result correct!" << endl;
  139. cout << "Main: PIRServer pre-processing time: " << time_pre_us / 1000 << " ms"
  140. << endl;
  141. cout << "Main: PIRClient query generation time: " << time_query_us / 1000
  142. << " ms" << endl;
  143. cout << "Main: PIRClient serialized query generation time: "
  144. << time_s_query_us / 1000 << " ms" << endl;
  145. cout << "Main: PIRServer query deserialization time: " << time_deserial_us
  146. << " us" << endl;
  147. cout << "Main: PIRServer reply generation time: " << time_server_us / 1000
  148. << " ms" << endl;
  149. cout << "Main: PIRClient answer decode time: " << time_decode_us / 1000
  150. << " ms" << endl;
  151. cout << "Main: Query size: " << query_size << " bytes" << endl;
  152. cout << "Main: Reply num ciphertexts: " << reply.size() << endl;
  153. cout << "Main: Reply size: " << reply_size << " bytes" << endl;
  154. return 0;
  155. }