decomposition_test.cpp 2.2 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879
  #include "pir.hpp"
  #include "pir_client.hpp"
  #include "pir_server.hpp"
  #include <chrono>
  #include <cstddef>
  #include <cstdint>
  #include <iostream>
  #include <memory>
  #include <random>
  #include <vector>
  #include <seal/seal.h>
  using namespace std::chrono;
  using namespace std;
  using namespace seal;
  13. int main(int argc, char *argv[]) {
  14. uint64_t number_of_items = 2048;
  15. uint64_t size_per_item = 288; // in bytes
  16. uint32_t N = 8192;
  17. // Recommended values: (logt, d) = (12, 2) or (8, 1).
  18. uint32_t logt = 20;
  19. EncryptionParameters enc_params(scheme_type::bfv);
  20. // Generates all parameters
  21. cout << "Main: Generating SEAL parameters" << endl;
  22. gen_encryption_params(N, logt, enc_params);
  23. cout << "Main: Verifying SEAL parameters" << endl;
  24. verify_encryption_params(enc_params);
  25. cout << "Main: SEAL parameters are good" << endl;
  26. SEALContext context(enc_params, true);
  27. KeyGenerator keygen(context);
  28. SecretKey secret_key = keygen.secret_key();
  29. Encryptor encryptor(context, secret_key);
  30. Decryptor decryptor(context, secret_key);
  31. Evaluator evaluator(context);
  32. BatchEncoder encoder(context);
  33. logt = floor(log2(enc_params.plain_modulus().value()));
  34. uint32_t plain_modulus = enc_params.plain_modulus().value();
  35. size_t slot_count = encoder.slot_count();
  36. vector<uint64_t> coefficients(slot_count, 0ULL);
  37. for (uint32_t i = 0; i < coefficients.size(); i++) {
  38. coefficients[i] = rand() % plain_modulus;
  39. }
  40. Plaintext pt;
  41. encoder.encode(coefficients, pt);
  42. Ciphertext ct;
  43. encryptor.encrypt_symmetric(pt, ct);
  44. std::cout << "Encrypting" << std::endl;
  45. auto context_data = context.last_context_data();
  46. auto parms_id = context.last_parms_id();
  47. evaluator.mod_switch_to_inplace(ct, parms_id);
  48. EncryptionParameters params = context_data->parms();
  49. std::cout << "Encoding" << std::endl;
  50. vector<Plaintext> encoded = decompose_to_plaintexts(params, ct);
  51. std::cout << "Expansion Factor: " << encoded.size() << std::endl;
  52. std::cout << "Decoding" << std::endl;
  53. Ciphertext decoded(context, parms_id);
  54. compose_to_ciphertext(params, encoded, decoded);
  55. std::cout << "Checking" << std::endl;
  56. Plaintext pt2;
  57. decryptor.decrypt(decoded, pt2);
  58. assert(pt == pt2);
  59. std::cout << "Worked" << std::endl;
  60. return 0;
  61. }