expand_test.cpp 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107
  1. #include "pir.hpp"
  2. #include "pir_client.hpp"
  3. #include "pir_server.hpp"
  4. #include <chrono>
  5. #include <cstddef>
  6. #include <cstdint>
  7. #include <memory>
  8. #include <random>
  9. #include <seal/seal.h>
  10. using namespace std::chrono;
  11. using namespace std;
  12. using namespace seal;
  13. // For this test, we need the parameters to be such that the number of
  14. // compressed ciphertexts needed is 1.
  15. int main(int argc, char *argv[]) {
  16. uint64_t number_of_items = 2048;
  17. uint64_t size_per_item = 288; // in bytes
  18. uint32_t N = 4096;
  19. // Recommended values: (logt, d) = (12, 2) or (8, 1).
  20. uint32_t logt = 20;
  21. uint32_t d = 1;
  22. EncryptionParameters enc_params(scheme_type::bfv);
  23. PirParams pir_params;
  24. // Generates all parameters
  25. cout << "Main: Generating SEAL parameters" << endl;
  26. gen_encryption_params(N, logt, enc_params);
  27. cout << "Main: Verifying SEAL parameters" << endl;
  28. verify_encryption_params(enc_params);
  29. cout << "Main: SEAL parameters are good" << endl;
  30. cout << "Main: Generating PIR parameters" << endl;
  31. gen_pir_params(number_of_items, size_per_item, d, enc_params, pir_params);
  32. // gen_params(number_of_items, size_per_item, N, logt, d, enc_params,
  33. // pir_params);
  34. print_pir_params(pir_params);
  35. // Initialize PIR Server
  36. cout << "Main: Initializing server and client" << endl;
  37. PIRServer server(enc_params, pir_params);
  38. // Initialize PIR client....
  39. PIRClient client(enc_params, pir_params);
  40. GaloisKeys galois_keys = client.generate_galois_keys();
  41. // Set galois key for client with id 0
  42. cout << "Main: Setting Galois keys...";
  43. server.set_galois_key(0, galois_keys);
  44. random_device rd;
  45. // Choose an index of an element in the DB
  46. uint64_t ele_index =
  47. rd() % number_of_items; // element in DB at random position
  48. uint64_t index = client.get_fv_index(ele_index); // index of FV plaintext
  49. uint64_t offset = client.get_fv_offset(ele_index); // offset in FV plaintext
  50. cout << "Main: element index = " << ele_index << " from [0, "
  51. << number_of_items - 1 << "]" << endl;
  52. cout << "Main: FV index = " << index << ", FV offset = " << offset << endl;
  53. // Measure query generation
  54. auto time_query_s = high_resolution_clock::now();
  55. PirQuery query = client.generate_query(index);
  56. auto time_query_e = high_resolution_clock::now();
  57. auto time_query_us =
  58. duration_cast<microseconds>(time_query_e - time_query_s).count();
  59. cout << "Main: query generated" << endl;
  60. // Measure query processing (including expansion)
  61. auto time_server_s = high_resolution_clock::now();
  62. uint64_t n_i = pir_params.nvec[0];
  63. vector<Ciphertext> expanded_query = server.expand_query(query[0][0], n_i, 0);
  64. auto time_server_e = high_resolution_clock::now();
  65. auto time_server_us =
  66. duration_cast<microseconds>(time_server_e - time_server_s).count();
  67. cout << "Main: query expanded" << endl;
  68. assert(expanded_query.size() == n_i);
  69. cout << "Main: checking expansion" << endl;
  70. for (size_t i = 0; i < expanded_query.size(); i++) {
  71. Plaintext decryption = client.decrypt(expanded_query.at(i));
  72. if (decryption.is_zero() && index != i) {
  73. continue;
  74. } else if (decryption.is_zero()) {
  75. cout << "Found zero where index should be" << endl;
  76. return -1;
  77. } else if (std::stoi(decryption.to_string()) != 1) {
  78. cout << "Query vector at index " << index
  79. << " should be 1 but is instead " << decryption.to_string() << endl;
  80. return -1;
  81. } else {
  82. cout << "Query vector at index " << index << " is "
  83. << decryption.to_string() << endl;
  84. }
  85. }
  86. return 0;
  87. }