encoding_test.cpp 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107
  1. #include "pir.hpp"
  2. #include "pir_client.hpp"
  3. #include "pir_server.hpp"
  4. #include <seal/seal.h>
  5. #include <memory>
  6. #include <random>
  7. #include <cstdint>
  8. #include <cstddef>
  9. using namespace std;
  10. using namespace seal;
  11. int main(int argc, char *argv[]) {
  12. uint64_t number_of_items = 1 << 12;
  13. uint64_t size_per_item = 288; // in bytes
  14. uint32_t N = 4096;
  15. // Recommended values: (logt, d) = (12, 2) or (8, 1).
  16. uint32_t logt = 20;
  17. uint32_t d = 1;
  18. EncryptionParameters params(scheme_type::bfv);
  19. PirParams pir_params;
  20. // Generates all parameters
  21. cout << "Main: Generating all parameters" << endl;
  22. gen_params(number_of_items, size_per_item, N, logt, d, params, pir_params);
  23. logt = floor(log2(params.plain_modulus().value()));
  24. cout << "Main: Initializing the database (this may take some time) ..." << endl;
  25. // Create test database
  26. auto db(make_unique<uint8_t[]>(number_of_items * size_per_item));
  27. // Copy of the database. We use this at the end to make sure we retrieved
  28. // the correct element.
  29. auto db_copy(make_unique<uint8_t[]>(number_of_items * size_per_item));
  30. random_device rd;
  31. for (uint64_t i = 0; i < number_of_items; i++) {
  32. for (uint64_t j = 0; j < size_per_item; j++) {
  33. uint8_t val = rd() % 256;
  34. db.get()[(i * size_per_item) + j] = val;
  35. db_copy.get()[(i * size_per_item) + j] = val;
  36. }
  37. }
  38. shared_ptr<SEALContext> context = make_shared<SEALContext>(params, true);
  39. unique_ptr<KeyGenerator> keygen = make_unique<KeyGenerator>(*context);
  40. PublicKey public_key;
  41. keygen->create_public_key(public_key);
  42. unique_ptr<Encryptor> encryptor = make_unique<Encryptor>(*context, public_key);
  43. SecretKey secret_key = keygen->secret_key();
  44. unique_ptr<Decryptor> decryptor = make_unique<Decryptor>(*context, secret_key);
  45. unique_ptr<Evaluator> evaluator = make_unique<Evaluator>(*context);
  46. uint64_t ele_per_ptxt = elements_per_ptxt(logt, N, size_per_item);
  47. uint64_t bytes_per_ptxt = ele_per_ptxt * size_per_item;
  48. uint64_t db_size = number_of_items * size_per_item;
  49. uint64_t coeff_per_ptxt = ele_per_ptxt * coefficients_per_element(logt, size_per_item);
  50. assert(coeff_per_ptxt <= N);
  51. vector<uint64_t> coefficients = bytes_to_coeffs(logt, db.get(), size_per_item);
  52. uint64_t used = coefficients.size();
  53. assert(used <= coeff_per_ptxt);
  54. // Pad the rest with 1s
  55. for (uint64_t j = 0; j < (N - used); j++) {
  56. coefficients.push_back(1);
  57. }
  58. Plaintext plain;
  59. vector_to_plaintext(coefficients, plain);
  60. //cout << "Plaintext: " << plain.to_string() << endl;
  61. vector<uint8_t> elems(N * logt / 8);
  62. coeffs_to_bytes(logt, plain, elems.data(), (N * logt) / 8);
  63. bool failed = false;
  64. // Check that we retrieved the correct element
  65. for (uint32_t i = 0; i < size_per_item; i++) {
  66. if (elems[i] != db_copy.get()[i]) {
  67. cout << "Main: elems " << (int)elems[i] << ", db "
  68. << (int) db_copy.get()[i] << endl;
  69. cout << "Main: PIR result wrong at " << i << endl;
  70. failed = true;
  71. }
  72. }
  73. if(failed){
  74. return -1;
  75. }
  76. else{
  77. cout << "succeeded" << endl;
  78. }
  79. }