main.cpp 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185
  1. #include "pir.hpp"
  2. #include <time.h>
  3. #define BILLION 1000000000L
  4. #define MILLION (1.0*1000000L)
  5. #define KILO (1.0*1024L)
  6. #include <fstream>
  7. #include <vector>
  8. #include <sstream>
  9. #include <algorithm>
  10. #include <chrono>
  11. #include <random>
  12. #define PBSTR "||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||"
  13. #define PBWIDTH 60
  14. #define NUM_SLOT 64
  15. #define NUM_THREAD 2
  16. int main(int argc, char *argv[]) {
  17. uint64_t number_of_items = 1 << 22;
  18. uint64_t size_per_item = 288 << 3; // 288 B.
  19. int n = 2048;
  20. int logt = 21;
  21. uint64_t plainMod = static_cast<uint64_t> (1) << logt;
  22. double hao_const = 0.5 * log2(number_of_items *size_per_item) - 0.5 * log2(n);
  23. int logtprime = logt;
  24. while(true){
  25. if (logtprime + ceil(hao_const - 0.5*log2(logtprime)) == logt) break;
  26. logtprime--;
  27. }
  28. int number_of_plaintexts = ceil (((double)(number_of_items)* size_per_item / n) / logtprime );
  29. EncryptionParameters parms;
  30. parms.set_poly_modulus("1x^" + std::to_string(n) + " + 1");
  31. vector<SmallModulus> coeff_mod_array;
  32. int logq = 0;
  33. for (int i = 0; i < 1; ++i)
  34. {
  35. coeff_mod_array.emplace_back(SmallModulus());
  36. coeff_mod_array[i] = small_mods_60bit(i);
  37. logq += coeff_mod_array[i].bit_count();
  38. }
  39. parms.set_coeff_modulus(coeff_mod_array);
  40. parms.set_plain_modulus(plainMod);
  41. pirParams pirparms;
  42. uint64_t newplainMod = 1 << logtprime;
  43. int item_per_plaintext = floor((double)get_power_of_two(newplainMod) *n / size_per_item);
  44. pirparms.d = 2;
  45. pirparms.alpha = 1;
  46. pirparms.dbc = 8;
  47. pirparms.N = number_of_plaintexts;
  48. int sqrt_items = ceil(sqrt(number_of_plaintexts));
  49. int bound1 = ceil((double) number_of_plaintexts / sqrt_items);
  50. int bound2 = sqrt_items;
  51. vector<int> Nvec = { bound1, bound2 };
  52. pirparms.Nvec = Nvec;
  53. // Initialize PIR client....
  54. PIRClient client(parms, pirparms);
  55. GaloisKeys galois_keys = client.generate_galois_keys();
  56. EncryptionParameters newparms = client.get_new_parms();
  57. galois_keys.mutable_hash_block() = newparms.hash_block();
  58. PIRServer server(client.get_new_parms(), client.get_pir_parms());
  59. server.set_galois_key(0, galois_keys);
  60. int index = 3; // we want to obtain the 3rd item.
  61. random_device rd;
  62. vector<uint64_t> no_choose(n+1);
  63. vector<uint64_t> choose(n+1);
  64. for (int i = 0; i < n+1; i++) {
  65. no_choose[i] = rd() % newplainMod;
  66. choose[i] = rd() % newplainMod;
  67. if (i == n) {
  68. choose[i] = 0;
  69. no_choose[i] = 0;
  70. }
  71. }
  72. unique_ptr<uint64_t> items_anchor(new uint64_t[bound1*bound2*(n + 1)]);
  73. vector<Plaintext> items;
  74. uint64_t * items_ptr = items_anchor.get();
  75. for (int i = 0; i < bound1*bound2; i++) {
  76. items.emplace_back(n + 1, items_ptr);
  77. if (i != index) {
  78. util::set_uint_uint(no_choose.data(), n+1, items_ptr);
  79. } else {
  80. util::set_uint_uint(choose.data(), n+1, items_ptr);
  81. }
  82. items_ptr += n + 1;
  83. }
  84. server.set_database(&items);
  85. auto time_querygen_start = chrono::high_resolution_clock::now();
  86. pirQuery query = client.generate_query(index);
  87. for (int i = 0; i < query.size(); i++) {
  88. query[i].mutable_hash_block() = newparms.hash_block();
  89. }
  90. auto time_querygen_end = chrono::high_resolution_clock::now();
  91. cout << "PIRClient query generation time : " << chrono::duration_cast<chrono::microseconds>(time_querygen_end - time_querygen_start).count() / 1000
  92. << " ms" << endl;
  93. cout << "Query size = " << (double) n * 2 * logq * pirparms.d / (1024 * 8) << "KB" << endl;
  94. auto time_pre_start = chrono::high_resolution_clock::now();
  95. server.preprocess_database();
  96. auto time_pre_end = chrono::high_resolution_clock::now();
  97. cout << "pre-processing time = " << chrono::duration_cast<chrono::microseconds>(time_pre_end - time_pre_start).count() / 1000
  98. << " ms" << endl;
  99. pirQuery query_ser = deserialize_ciphertexts(2, serialize_ciphertexts(query), 32828);
  100. auto time_server_start = chrono::high_resolution_clock::now();
  101. pirReply reply = server.generate_reply(query_ser, 0);
  102. auto time_server_end = chrono::high_resolution_clock::now();
  103. cout << "Server reply generation time : " << chrono::duration_cast<chrono::microseconds>(time_server_end - time_server_start).count() / 1000
  104. << " ms" << endl;
  105. cout<<"Reply ciphertexts"<<reply.size()<<endl;
  106. cout << "Reply size = " << (double) reply.size() * n * 2 * logq / (1024 * 8) << "KB" << endl;
  107. auto time_decode_start = chrono::high_resolution_clock::now();
  108. Plaintext result = client.decode_reply(reply);
  109. auto time_decode_end = chrono::high_resolution_clock::now();
  110. cout << "PIRClient decoding time : " << chrono::duration_cast<chrono::microseconds>(time_decode_end - time_decode_start).count() / 1000
  111. << " ms" << endl;
  112. cout << "Result = ";
  113. bool pircorrect = true;
  114. for (int i = 0; i < n; i++) {
  115. if (result[i] != choose[i]) {
  116. pircorrect = false;
  117. break;
  118. }
  119. }
  120. if (pircorrect) {
  121. cout << "PIR result correct!!" << endl;
  122. }
  123. else {
  124. cout << "PIR result wrong!" << endl;
  125. }
  126. return 0;
  127. }