replace_test.cpp 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172
  1. #include "pir.hpp"
  2. #include "pir_client.hpp"
  3. #include "pir_server.hpp"
  4. #include <chrono>
  5. #include <cstddef>
  6. #include <cstdint>
  7. #include <memory>
  8. #include <random>
  9. #include <seal/seal.h>
  10. using namespace std::chrono;
  11. using namespace std;
  12. using namespace seal;
  13. int replace_test(uint64_t num_items, uint64_t item_size, uint32_t degree,
  14. uint32_t lt, uint32_t dim);
  15. int main(int argc, char *argv[]) {
  16. // Quick check
  17. assert(replace_test(1 << 13, 1, 4096, 20, 1) == 0);
  18. // Forces ciphertext expansion to be the same as the degree
  19. assert(replace_test(1 << 20, 288, 4096, 20, 1) == 0);
  20. assert(replace_test(1 << 20, 288, 4096, 20, 2) == 0);
  21. }
  22. int replace_test(uint64_t num_items, uint64_t item_size, uint32_t degree,
  23. uint32_t lt, uint32_t dim) {
  24. uint64_t number_of_items = num_items;
  25. uint64_t size_per_item = item_size; // in bytes
  26. uint32_t N = degree;
  27. // Recommended values: (logt, d) = (12, 2) or (8, 1).
  28. uint32_t logt = lt;
  29. uint32_t d = dim;
  30. EncryptionParameters enc_params(scheme_type::bfv);
  31. PirParams pir_params;
  32. // Generates all parameters
  33. cout << "Main: Generating SEAL parameters" << endl;
  34. gen_encryption_params(N, logt, enc_params);
  35. cout << "Main: Verifying SEAL parameters" << endl;
  36. verify_encryption_params(enc_params);
  37. cout << "Main: SEAL parameters are good" << endl;
  38. cout << "Main: Generating PIR parameters" << endl;
  39. gen_pir_params(number_of_items, size_per_item, d, enc_params, pir_params);
  40. // gen_params(number_of_items, size_per_item, N, logt, d, enc_params,
  41. // pir_params);
  42. print_pir_params(pir_params);
  43. cout << "Main: Initializing the database (this may take some time) ..."
  44. << endl;
  45. // Create test database
  46. auto db(make_unique<uint8_t[]>(number_of_items * size_per_item));
  47. // Copy of the database. We use this at the end to make sure we retrieved
  48. // the correct element.
  49. auto db_copy(make_unique<uint8_t[]>(number_of_items * size_per_item));
  50. random_device rd;
  51. for (uint64_t i = 0; i < number_of_items; i++) {
  52. for (uint64_t j = 0; j < size_per_item; j++) {
  53. uint8_t val = rd() % 256;
  54. db.get()[(i * size_per_item) + j] = val;
  55. db_copy.get()[(i * size_per_item) + j] = val;
  56. }
  57. }
  58. // Initialize PIR Server
  59. cout << "Main: Initializing server and client" << endl;
  60. PIRServer server(enc_params, pir_params);
  61. // Initialize PIR client....
  62. PIRClient client(enc_params, pir_params);
  63. Ciphertext one_ct = client.get_one();
  64. GaloisKeys galois_keys = client.generate_galois_keys();
  65. // Set galois key for client with id 0
  66. cout << "Main: Setting Galois keys...";
  67. server.set_galois_key(0, galois_keys);
  68. // Measure database setup
  69. auto time_pre_s = high_resolution_clock::now();
  70. server.set_database(move(db), number_of_items, size_per_item);
  71. server.preprocess_database();
  72. server.set_one_ct(one_ct);
  73. cout << "Main: database pre processed " << endl;
  74. auto time_pre_e = high_resolution_clock::now();
  75. auto time_pre_us =
  76. duration_cast<microseconds>(time_pre_e - time_pre_s).count();
  77. // Choose an index of an element in the DB
  78. uint64_t ele_index =
  79. rd() % number_of_items; // element in DB at random position
  80. uint64_t index = client.get_fv_index(ele_index); // index of FV plaintext
  81. uint64_t offset = client.get_fv_offset(ele_index); // offset in FV plaintext
  82. cout << "Main: element index = " << ele_index << " from [0, "
  83. << number_of_items - 1 << "]" << endl;
  84. cout << "Main: FV index = " << index << ", FV offset = " << offset << endl;
  85. // Generate a new element
  86. vector<uint8_t> new_element(size_per_item);
  87. vector<uint8_t> new_element_copy(size_per_item);
  88. for (uint64_t i = 0; i < size_per_item; i++) {
  89. uint8_t val = rd() % 256;
  90. new_element[i] = val;
  91. new_element_copy[i] = val;
  92. }
  93. // Get element to replace
  94. auto time_server_s = high_resolution_clock::now();
  95. Ciphertext reply = server.simple_query(index);
  96. auto time_server_e = high_resolution_clock::now();
  97. auto time_server_us =
  98. duration_cast<microseconds>(time_server_e - time_server_s).count();
  99. auto time_decode_s = chrono::high_resolution_clock::now();
  100. Plaintext old_pt = client.decrypt(reply);
  101. auto time_decode_e = chrono::high_resolution_clock::now();
  102. auto time_decode_us =
  103. duration_cast<microseconds>(time_decode_e - time_decode_s).count();
  104. // Replace element
  105. Modulus t = enc_params.plain_modulus();
  106. logt = floor(log2(t.value()));
  107. vector<uint64_t> new_coeffs =
  108. bytes_to_coeffs(logt, new_element.data(), size_per_item);
  109. Plaintext new_pt = client.replace_element(old_pt, new_coeffs, offset);
  110. server.simple_set(index, new_pt);
  111. // Get the replaced element
  112. PirQuery query = client.generate_query(index);
  113. PirReply server_reply = server.generate_reply(query, 0);
  114. vector<uint8_t> elems = client.decode_reply(server_reply, offset);
  115. // vector<uint8_t> elems =
  116. // client.extract_bytes(client.decrypt(server.simple_query(index)), offset);
  117. vector<uint8_t> old_elems = client.extract_bytes(old_pt, offset);
  118. assert(elems.size() == size_per_item);
  119. bool failed = false;
  120. // Check that we retrieved the correct element
  121. for (uint32_t i = 0; i < size_per_item; i++) {
  122. if (elems[i] != new_element_copy[i]) {
  123. cout << "Main: elems " << (int)elems[i] << ", new "
  124. << (int)new_element_copy[i] << ", old "
  125. << (int)db_copy.get()[(ele_index * size_per_item) + i] << endl;
  126. cout << "Main: PIR result wrong at " << i << endl;
  127. failed = true;
  128. }
  129. }
  130. if (failed) {
  131. return -1;
  132. }
  133. // Output results
  134. cout << "Main: PIR result correct!" << endl;
  135. cout << "Main: PIRServer pre-processing time: " << time_pre_us / 1000 << " ms"
  136. << endl;
  137. cout << "Main: PIRServer reply generation time: " << time_server_us / 1000
  138. << " ms" << endl;
  139. cout << "Main: PIRClient answer decode time: " << time_decode_us / 1000
  140. << " ms" << endl;
  141. return 0;
  142. }