main.cpp 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147
#include "pir.hpp"
#include "pir_client.hpp"
#include "pir_server.hpp"
#include <seal/seal.h>
#include <chrono>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <memory>
#include <random>
#include <vector>
using namespace std::chrono;
using namespace std;
using namespace seal;
  13. int main(int argc, char *argv[]) {
  14. uint64_t number_of_items = 1 << 12;
  15. uint64_t size_per_item = 288; // in bytes
  16. uint32_t N = 4096;
  17. // Recommended values: (logt, d) = (12, 2) or (8, 1).
  18. uint32_t logt = 20;
  19. uint32_t d = 2;
  20. EncryptionParameters params(scheme_type::bfv);
  21. PirParams pir_params;
  22. // Generates all parameters
  23. cout << "Main: Generating all parameters" << endl;
  24. gen_params(number_of_items, size_per_item, N, logt, d, params, pir_params);
  25. cout << "Main: Initializing the database (this may take some time) ..." << endl;
  26. // Create test database
  27. auto db(make_unique<uint8_t[]>(number_of_items * size_per_item));
  28. // Copy of the database. We use this at the end to make sure we retrieved
  29. // the correct element.
  30. auto db_copy(make_unique<uint8_t[]>(number_of_items * size_per_item));
  31. random_device rd;
  32. for (uint64_t i = 0; i < number_of_items; i++) {
  33. for (uint64_t j = 0; j < size_per_item; j++) {
  34. uint8_t val = rd() % 256;
  35. db.get()[(i * size_per_item) + j] = val;
  36. db_copy.get()[(i * size_per_item) + j] = val;
  37. }
  38. }
  39. // Initialize PIR Server
  40. cout << "Main: Initializing server and client" << endl;
  41. PIRServer server(params, pir_params);
  42. // Initialize PIR client....
  43. PIRClient client(params, pir_params);
  44. GaloisKeys galois_keys = client.generate_galois_keys();
  45. // Set galois key for client with id 0
  46. cout << "Main: Setting Galois keys...";
  47. server.set_galois_key(0, galois_keys);
  48. // Measure database setup
  49. auto time_pre_s = high_resolution_clock::now();
  50. server.set_database(move(db), number_of_items, size_per_item);
  51. server.preprocess_database();
  52. cout << "Main: database pre processed " << endl;
  53. auto time_pre_e = high_resolution_clock::now();
  54. auto time_pre_us = duration_cast<microseconds>(time_pre_e - time_pre_s).count();
  55. // Choose an index of an element in the DB
  56. uint64_t ele_index = rd() % number_of_items; // element in DB at random position
  57. uint64_t index = client.get_fv_index(ele_index, size_per_item); // index of FV plaintext
  58. uint64_t offset = client.get_fv_offset(ele_index, size_per_item); // offset in FV plaintext
  59. cout << "Main: element index = " << ele_index << " from [0, " << number_of_items -1 << "]" << endl;
  60. cout << "Main: FV index = " << index << ", FV offset = " << offset << endl;
  61. // Measure query generation
  62. auto time_query_s = high_resolution_clock::now();
  63. PirQuery query = client.generate_query(index);
  64. auto time_query_e = high_resolution_clock::now();
  65. auto time_query_us = duration_cast<microseconds>(time_query_e - time_query_s).count();
  66. cout << "Main: query generated" << endl;
  67. //To marshall query to send over the network, you can use serialize/deserialize:
  68. //std::string query_ser = serialize_query(query);
  69. //PirQuery query2 = deserialize_query(d, 1, query_ser, CIPHER_SIZE);
  70. // Measure query processing (including expansion)
  71. auto time_server_s = high_resolution_clock::now();
  72. PirReply reply = server.generate_reply(query, 0, client);
  73. auto time_server_e = high_resolution_clock::now();
  74. auto time_server_us = duration_cast<microseconds>(time_server_e - time_server_s).count();
  75. // Measure response extraction
  76. auto time_decode_s = chrono::high_resolution_clock::now();
  77. Plaintext result = client.decode_reply(reply);
  78. auto time_decode_e = chrono::high_resolution_clock::now();
  79. auto time_decode_us = duration_cast<microseconds>(time_decode_e - time_decode_s).count();
  80. Ciphertext one_ct = client.get_encrypted_one();
  81. Ciphertext reply2 = server.generate_public_reply(one_ct, index);
  82. Plaintext result2 = client.decrypt(reply2);
  83. logt = floor(log2(params.plain_modulus().value()));
  84. // Convert from FV plaintext (polynomial) to database element at the client
  85. vector<uint8_t> elems(N * logt / 8);
  86. coeffs_to_bytes(logt, result, elems.data(), (N * logt) / 8);
  87. vector<uint8_t> elems2(N * logt / 8);
  88. coeffs_to_bytes(logt, result2, elems2.data(), (N * logt) / 8);
  89. // Check that we retrieved the correct element
  90. for (uint32_t i = 0; i < size_per_item; i++) {
  91. if (elems[(offset * size_per_item) + i] != elems2[(offset * size_per_item) + i]) {
  92. cout << "Main: elems " << (int)elems[(offset * size_per_item) + i] << ", elems2 "
  93. << (int)elems[(offset * size_per_item) + i] << endl;
  94. cout << "Main: PIR results inconsistent at" << i << endl;
  95. return -1;
  96. }
  97. }
  98. bool failed = false;
  99. // Check that we retrieved the correct element
  100. for (uint32_t i = 0; i < size_per_item; i++) {
  101. if (elems[(offset * size_per_item) + i] != db_copy.get()[(ele_index * size_per_item) + i]) {
  102. cout << "Main: elems " << (int)elems[(offset * size_per_item) + i] << ", db "
  103. << (int) db_copy.get()[(ele_index * size_per_item) + i] << endl;
  104. cout << "Main: PIR result wrong at " << i << endl;
  105. failed = true;
  106. }
  107. }
  108. if(failed){
  109. return -1;
  110. }
  111. // Output results
  112. cout << "Main: PIR result correct!" << endl;
  113. cout << "Main: PIRServer pre-processing time: " << time_pre_us / 1000 << " ms" << endl;
  114. cout << "Main: PIRClient query generation time: " << time_query_us / 1000 << " ms" << endl;
  115. cout << "Main: PIRServer reply generation time: " << time_server_us / 1000 << " ms"
  116. << endl;
  117. cout << "Main: PIRClient answer decode time: " << time_decode_us / 1000 << " ms" << endl;
  118. cout << "Main: Reply num ciphertexts: " << reply.size() << endl;
  119. return 0;
  120. }