main.cpp 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121
  1. #include "pir.hpp"
  2. #include "pir_client.hpp"
  3. #include "pir_server.hpp"
  4. #include <chrono>
  5. #include <random>
  6. using namespace chrono;
  7. using namespace seal;
  8. int main(int argc, char *argv[]) {
  9. // uint64_t number_of_items = 1 << 13;
  10. // uint64_t number_of_items = 4096;
  11. uint64_t number_of_items = 1 << 16;
  12. uint64_t size_per_item = 288; // in bytes
  13. // uint64_t size_per_item = 1 << 10; // 1 KB.
  14. // uint64_t size_per_item = 10 << 10; // 10 KB.
  15. uint32_t N = 2048;
  16. uint32_t logt = 20;
  17. uint32_t d = 2;
  18. EncryptionParameters params;
  19. EncryptionParameters expanded_params;
  20. PirParams pir_params;
  21. // Generates all parameters
  22. cout << "Generating all parameters" << endl;
  23. gen_params(number_of_items, size_per_item, N, logt, d, params, expanded_params, pir_params);
  24. // Create test database
  25. uint8_t *db = (uint8_t *)malloc(number_of_items * size_per_item);
  26. random_device rd;
  27. for (uint64_t i = 0; i < number_of_items; i++) {
  28. for (uint64_t j = 0; j < size_per_item; j++) {
  29. *(db + (i * size_per_item) + j) = rd() % 256;
  30. }
  31. }
  32. // Initialize PIR Server
  33. cout << "Initializing server and client" << endl;
  34. PIRServer server(expanded_params, pir_params);
  35. // Initialize PIR client....
  36. PIRClient client(params, expanded_params, pir_params);
  37. GaloisKeys galois_keys = client.generate_galois_keys();
  38. // Set galois key
  39. cout << "Setting Galois keys" << endl;
  40. server.set_galois_key(0, galois_keys);
  41. // The following can be used to update parameters rather than creating new instances
  42. // (here it doesn't do anything).
  43. cout << "Updating database size to: " << number_of_items << " elements" << endl;
  44. update_params(number_of_items, size_per_item, d, params, expanded_params, pir_params);
  45. uint32_t logtp = ceil(log2(expanded_params.plain_modulus().value()));
  46. cout << "logtp: " << logtp << endl;
  47. client.update_parameters(expanded_params, pir_params);
  48. server.update_parameters(expanded_params, pir_params);
  49. // Measure database setup
  50. auto time_pre_s = high_resolution_clock::now();
  51. server.set_database(db, number_of_items, size_per_item);
  52. server.preprocess_database();
  53. auto time_pre_e = high_resolution_clock::now();
  54. auto time_pre_us = duration_cast<microseconds>(time_pre_e - time_pre_s).count();
  55. // Choose an index of an element in the DB
  56. uint64_t ele_index = rd() % number_of_items; // element in DB at random position
  57. uint64_t index = client.get_fv_index(ele_index, size_per_item); // index of FV plaintext
  58. uint64_t offset = client.get_fv_offset(ele_index, size_per_item); // offset in FV plaintext
  59. // Measure query generation
  60. auto time_query_s = high_resolution_clock::now();
  61. PirQuery query = client.generate_query(index);
  62. auto time_query_e = high_resolution_clock::now();
  63. auto time_query_us = duration_cast<microseconds>(time_query_e - time_query_s).count();
  64. // Measure query processing (including expansion)
  65. auto time_server_s = high_resolution_clock::now();
  66. PirQuery query_ser = deserialize_ciphertexts(d, serialize_ciphertexts(query), CIPHER_SIZE);
  67. PirReply reply = server.generate_reply(query_ser, 0);
  68. auto time_server_e = high_resolution_clock::now();
  69. auto time_server_us = duration_cast<microseconds>(time_server_e - time_server_s).count();
  70. // Measure response extraction
  71. auto time_decode_s = chrono::high_resolution_clock::now();
  72. Plaintext result = client.decode_reply(reply);
  73. auto time_decode_e = chrono::high_resolution_clock::now();
  74. auto time_decode_us = duration_cast<microseconds>(time_decode_e - time_decode_s).count();
  75. // Convert to elements
  76. vector<uint8_t> elems(N * logtp / 8);
  77. coeffs_to_bytes(logtp, result, elems.data(), (N * logtp) / 8);
  78. // Check that we retrieved the correct element
  79. for (uint32_t i = 0; i < size_per_item; i++) {
  80. if (elems[(offset * size_per_item) + i] != db[(ele_index * size_per_item) + i]) {
  81. cout << "elems " << (int)elems[(offset * size_per_item) + i] << ", db "
  82. << (int)db[(ele_index * size_per_item) + i] << endl;
  83. cout << "PIR result wrong!" << endl;
  84. return -1;
  85. }
  86. }
  87. // Output results
  88. cout << "PIRServer pre-processing time: " << time_pre_us / 1000 << " ms" << endl;
  89. cout << "PIRServer query processing generation time: " << time_server_us / 1000 << " ms"
  90. << endl;
  91. cout << "PIRClient query generation time: " << time_query_us / 1000 << " ms" << endl;
  92. cout << "PIRClient answer decode time: " << time_decode_us / 1000 << " ms" << endl;
  93. cout << "Reply num ciphertexts: " << reply.size() << endl;
  94. return 0;
  95. }