spir_test.cpp 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336
#include <stdlib.h>
#include <sys/random.h>
#include <sys/time.h>
#include <unistd.h>
#include <bsd/stdlib.h>

#include <cstddef>
#include <cstdint>
#include <exception>
#include <functional>
#include <iostream>
#include <limits>
#include <string>
#include <thread>
#include <vector>

#include <boost/asio.hpp>
using boost::asio::ip::tcp;

#include "spir.hpp"
using std::cout;
using std::cerr;
  13. static inline size_t elapsed_us(const struct timeval *start)
  14. {
  15. struct timeval end;
  16. gettimeofday(&end, NULL);
  17. return (end.tv_sec-start->tv_sec)*1000000 + end.tv_usec - start->tv_usec;
  18. }
  19. using socket_t = boost::asio::ip::tcp::socket;
  20. void accept_conncections_from_Pb(boost::asio::io_context&io_context, std::vector<socket_t>& sockets_, int port, size_t j)
  21. {
  22. tcp::acceptor acceptor_a(io_context, tcp::endpoint(tcp::v4(), port));
  23. tcp::socket sb_a(acceptor_a.accept());
  24. sockets_[j] = std::move(sb_a);
  25. // sockets_.emplace_back(std::move(sb_a));
  26. }
  27. void write_pub_params(tcp::socket& sout, string pub_params)
  28. {
  29. auto * bytes_to_write = pub_params.data();
  30. auto bytes_remaining = pub_params.length();
  31. while (bytes_remaining )
  32. {
  33. auto bytes_written = sout.write_some(boost::asio::buffer(bytes_to_write, bytes_remaining));
  34. bytes_to_write += bytes_written;
  35. bytes_remaining -= bytes_written;
  36. }
  37. }
  38. void read_pub_params(tcp::socket& sin, string& pub_params_recv, size_t len)
  39. {
  40. pub_params_recv.resize(len);
  41. auto bytes_remaining = len;
  42. char * bytes_to_read = (char*)pub_params_recv.data();
  43. while (bytes_remaining )
  44. {
  45. auto bytes_read = sin.read_some(boost::asio::buffer(bytes_to_read, bytes_remaining));
  46. bytes_to_read += bytes_read;
  47. bytes_remaining -= bytes_read;
  48. }
  49. }
  50. int main(int argc, char **argv)
  51. {
  52. boost::asio::io_context io_context;
  53. tcp::resolver resolver(io_context);
  54. std::string addr = "127.0.0.1";
  55. const std::string host1 = (argc <= 1) ? "127.0.0.1" : argv[1];
  56. const size_t number_of_sockets = 5;
  57. std::vector<socket_t> sockets_;
  58. for(size_t j = 0; j < number_of_sockets + 1; ++j)
  59. {
  60. tcp::socket emptysocket(io_context);
  61. sockets_.emplace_back(std::move(emptysocket));
  62. }
  63. sockets_.reserve(number_of_sockets + 1);
  64. printf("number_of_sockets = %zu\n", number_of_sockets);
  65. std::vector<socket_t> sockets_2;
  66. std::vector<int> ports;
  67. for(size_t j = 0; j < number_of_sockets; ++j)
  68. {
  69. int port = 6000;
  70. ports.push_back(port + j);
  71. }
  72. std::vector<int> ports2_0;
  73. for(size_t j = 0; j < number_of_sockets; ++j)
  74. {
  75. int port = 8000;
  76. ports2_0.push_back(port + j);
  77. }
  78. std::vector<int> ports2_1;
  79. for(size_t j = 0; j < number_of_sockets; ++j)
  80. {
  81. int port = 9000;
  82. ports2_1.push_back(port + j);
  83. }
  84. #if (PARTY == 0)
  85. for(size_t j = 0; j < number_of_sockets; ++j)
  86. {
  87. tcp::socket sb_a(io_context);
  88. boost::asio::connect(sb_a, resolver.resolve({host1, std::to_string(ports[j])}));
  89. sockets_[j] = std::move(sb_a);
  90. }
  91. #else
  92. boost::asio::thread_pool pool2(number_of_sockets);
  93. for(size_t j = 0; j < number_of_sockets; ++j)
  94. {
  95. boost::asio::post(pool2, std::bind(accept_conncections_from_Pb, std::ref(io_context), std::ref(sockets_), ports[j], j));
  96. }
  97. pool2.join();
  98. #endif
  99. #if (PARTY == 0)
  100. std::cout << "PARTY 0" << std::endl;
  101. #endif
  102. #if (PARTY == 1)
  103. std::cout << "PARTY 1" << std::endl;
  104. #endif
  105. // if (argc < 2 || argc > 5) {
  106. // cerr << "Usage: " << argv[0] << " r [num_threads [num_preproc [num_pirs]]]\n";
  107. // cerr << "r = log_2(num_records)\n";
  108. // exit(1);
  109. // }
  110. uint32_t r, num_threads = 1, num_preproc = 1, num_pirs = 1;
  111. r = strtoul(argv[2], NULL, 10);
  112. size_t num_records = ((size_t) 1)<<r;
  113. size_t num_records_mask = num_records - 1;
  114. if (argc > 3) {
  115. num_threads = strtoul(argv[3], NULL, 10);
  116. }
  117. if (argc > 4) {
  118. num_preproc = strtoul(argv[4], NULL, 10);
  119. }
  120. if (argc > 5) {
  121. num_pirs = strtoul(argv[5], NULL, 10);
  122. } else {
  123. num_pirs = num_preproc;
  124. }
  125. cout << "===== ONE-TIME SETUP =====\n\n";
  126. struct timeval otsetup_start;
  127. gettimeofday(&otsetup_start, NULL);
  128. cout << "num_threads = " << num_threads << "\n";
  129. SPIR::init(num_threads);
  130. string pub_params, pub_params_recv;
  131. SPIR_Client client(r, pub_params);
  132. std::thread writer(write_pub_params, std::ref(sockets_[0]), pub_params);
  133. std::thread reader(read_pub_params, std::ref(sockets_[0]), std::ref(pub_params_recv), pub_params.size());
  134. writer.join();
  135. reader.join();
  136. SPIR_Server server(r, pub_params_recv);
  137. size_t otsetup_us = elapsed_us(&otsetup_start);
  138. cout << "One-time setup: " << otsetup_us << " µs\n";
  139. cout << "pub_params len = " << pub_params_recv.length() << "\n";
  140. cout << "\n===== PREPROCESSING =====\n\n";
  141. cout << "num_preproc = " << num_preproc << "\n";
  142. struct timeval preproc_client_start;
  143. gettimeofday(&preproc_client_start, NULL);
  144. string preproc_msg = client.preproc(num_preproc);
  145. string preproc_msg_recv = preproc_msg;
  146. boost::asio::write(sockets_[0], boost::asio::buffer(preproc_msg));
  147. boost::asio::read(sockets_[0], boost::asio::buffer(preproc_msg_recv));
  148. size_t preproc_client_us = elapsed_us(&preproc_client_start);
  149. cout << "Preprocessing client: " << preproc_client_us << " µs\n";
  150. cout << "preproc_msg len = " << preproc_msg.length() << "\n";
  151. struct timeval preproc_server_start;
  152. gettimeofday(&preproc_server_start, NULL);
  153. string preproc_resp = server.preproc_process(preproc_msg_recv);
  154. string preproc_resp_recv = preproc_resp;
  155. boost::asio::write(sockets_[0], boost::asio::buffer(preproc_resp));
  156. boost::asio::read(sockets_[0], boost::asio::buffer(preproc_resp_recv));
  157. size_t preproc_server_us = elapsed_us(&preproc_server_start);
  158. cout << "Preprocessing server: " << preproc_server_us << " µs\n";
  159. cout << "preproc_resp len = " << preproc_resp.length() << "\n";
  160. struct timeval preproc_finish_start;
  161. gettimeofday(&preproc_finish_start, NULL);
  162. client.preproc_finish(preproc_resp_recv);
  163. size_t preproc_finish_us = elapsed_us(&preproc_finish_start);
  164. cout << "Preprocessing client finish: " << preproc_finish_us << " µs\n";
  165. size_t preproc_total_us = elapsed_us(&preproc_client_start);
  166. cout << "\n\nTotal preprocessing time: " << preproc_total_us << " µs\n";
  167. cout << "Total preprocessing bytes: " << (preproc_msg.length() + preproc_resp.length()) << "\n";
  168. // Create the database
  169. SPIR::DBEntry *db = new SPIR::DBEntry[num_records];
  170. for (size_t i=0; i<num_records; ++i) {
  171. db[i] = i;// * 10000001;
  172. #if(PARTY == 0)
  173. db[i] = 0;
  174. #endif
  175. }
  176. SPIR::DBEntry rand_blind = 1221030;
  177. struct timeval all_queries_start;
  178. gettimeofday(&all_queries_start, NULL);
  179. size_t tot_query_bytes = 0;
  180. for (size_t i=0; i<num_pirs; ++i) {
  181. if (i < 2 || i == num_pirs-1) {
  182. cout << "\n===== SPIR QUERY " << i+1 << " =====\n\n";
  183. } else if (i == 2) {
  184. cout << "\n...\n\n";
  185. }
  186. size_t idx;
  187. if (getrandom(&idx, sizeof(idx), 0) != sizeof(idx)) {
  188. cerr << "Failure in getrandom\n";
  189. exit(1);
  190. }
  191. idx &= num_records_mask;
  192. #ifdef CHECK_ANSWERS
  193. boost::asio::write(sockets_[0], boost::asio::buffer(&idx, sizeof(idx)));
  194. size_t idx_recv;
  195. boost::asio::read(sockets_[0], boost::asio::buffer(&idx_recv, sizeof(idx_recv)));
  196. idx_recv += idx;
  197. idx_recv = idx_recv % num_records;
  198. cout << "idx = " << idx << std::endl;
  199. cout << "idx_reconstructed = " << idx_recv << std::endl;
  200. // idx = 100;
  201. // #if(PARTY == 1)
  202. // idx = 40;
  203. // #endif
  204. #endif
  205. struct timeval query_client_start;
  206. gettimeofday(&query_client_start, NULL);
  207. string query_msg = client.query(idx);
  208. boost::asio::write(sockets_[0], boost::asio::buffer(query_msg));
  209. tot_query_bytes += query_msg.length();
  210. string query_msg_recv(query_msg.length(), '\0');
  211. boost::asio::read(sockets_[0], boost::asio::buffer(query_msg_recv));
  212. size_t query_client_us = elapsed_us(&query_client_start);
  213. if (i < 2 || i == num_pirs-1) {
  214. cout << "Query client: " << query_client_us << " µs\n";
  215. cout << "query_msg len = " << query_msg.length() << "\n";
  216. }
  217. struct timeval query_server_start;
  218. gettimeofday(&query_server_start, NULL);
  219. //string query_resp = server.query_process(query_msg_recv, db, 0, 0);
  220. string query_resp = server.query_process(query_msg_recv, db, idx, rand_blind);
  221. boost::asio::write(sockets_[0], boost::asio::buffer(query_resp));
  222. tot_query_bytes += query_resp.length();
  223. string query_resp_recv = query_resp;
  224. boost::asio::read(sockets_[0], boost::asio::buffer(query_resp_recv));
  225. size_t query_server_us = elapsed_us(&query_server_start);
  226. if (i < 2 || i == num_pirs-1) {
  227. cout << "Query server: " << query_server_us << " µs\n";
  228. cout << "query_resp len = " << query_resp.length() << "\n";
  229. }
  230. struct timeval query_finish_start;
  231. gettimeofday(&query_finish_start, NULL);
  232. SPIR::DBEntry entry = client.query_finish(query_resp_recv);
  233. #ifdef CHECK_ANSWERS
  234. boost::asio::write(sockets_[0], boost::asio::buffer(&entry, sizeof(entry)));
  235. SPIR::DBEntry entry_recv;
  236. boost::asio::read(sockets_[0], boost::asio::buffer(&entry_recv, sizeof(entry)));
  237. SPIR::DBEntry read_output = entry_recv - rand_blind;
  238. boost::asio::write(sockets_[0], boost::asio::buffer(&read_output, sizeof(entry)));
  239. SPIR::DBEntry read_output_recv;
  240. boost::asio::read(sockets_[0], boost::asio::buffer(&read_output_recv, sizeof(entry)));
  241. read_output_recv += read_output;
  242. cout << "read_output_recv = " << read_output_recv << std::endl;
  243. #endif
  244. size_t query_finish_us = elapsed_us(&query_finish_start);
  245. if (i < 2 || i == num_pirs-1) {
  246. cout << "Query client finish: " << query_finish_us << " µs\n";
  247. cout << "idx = " << idx << "; entry = " << entry << "\n";
  248. }
  249. }
  250. size_t all_queries_us = elapsed_us(&all_queries_start);
  251. cout << "\n\nTotal query time: " << all_queries_us << " µs\n";
  252. cout << "Total query bytes: " << tot_query_bytes << "\n";
  253. delete[] db;
  254. return 0;
  255. }