// mpcio.cpp
#include <sys/time.h> // getrusage
#include <sys/resource.h> // getrusage

#include "mpcio.hpp"
#include "rdpf.hpp"
#include "cdpf.hpp"
#include "bitutils.hpp"
#include "coroutine.hpp"

void MPCSingleIO::async_send_from_msgqueue()
{
#ifdef SEND_LAMPORT_CLOCKS
    std::vector<boost::asio::const_buffer> tosend;
    tosend.push_back(boost::asio::buffer(messagequeue.front().header));
    tosend.push_back(boost::asio::buffer(messagequeue.front().message));
#endif
    boost::asio::async_write(sock,
#ifdef SEND_LAMPORT_CLOCKS
        tosend,
#else
        boost::asio::buffer(messagequeue.front()),
#endif
        [&](boost::system::error_code ec, std::size_t amt){
            messagequeuelock.lock();
            messagequeue.pop();
            if (messagequeue.size() > 0) {
                async_send_from_msgqueue();
            }
            messagequeuelock.unlock();
        });
}

size_t MPCSingleIO::queue(const void *data, size_t len, lamport_t lamport)
{
    // Is this a new message?
    size_t newmsg = 0;

    dataqueue.append((const char *)data, len);

    // If this is the first queue() since the last explicit send(),
    // which we'll know because message_lamport will be nullopt, set
    // message_lamport to the current Lamport clock. Note that the
    // boolean test tests whether message_lamport is nullopt, not
    // whether its value is zero.
    if (!message_lamport) {
        message_lamport = lamport;
        newmsg = 1;
    }

#ifdef VERBOSE_COMMS
    printf("Queue %s.%d len=%lu lamp=%u: ", dest.c_str(), thread_num,
        len, message_lamport.value());
    for (size_t i=0;i<len;++i) {
        printf("%02x", ((const unsigned char*)data)[i]);
    }
    printf("\n");
#endif

    // If we already have some full packets' worth of data, we may as
    // well send it.
    if (dataqueue.size() > 28800) {
        send(true);
    }

    return newmsg;
}

void MPCSingleIO::send(bool implicit_send)
{
    size_t thissize = dataqueue.size();
    // Ignore spurious calls to send(), except for resetting
    // message_lamport if this was an explicit send().
    if (thissize == 0) {
#ifdef SEND_LAMPORT_CLOCKS
        // If this was an explicit send(), reset the message_lamport so
        // that it gets updated at the next queue().
        if (!implicit_send) {
            message_lamport.reset();
        }
#endif
        return;
    }

#ifdef RECORD_IOTRACE
    iotrace.push_back(thissize);
#endif

    messagequeuelock.lock();

    // Move the current message to send into the message queue (this
    // moves a pointer to the data, not copying the data itself)
#ifdef SEND_LAMPORT_CLOCKS
    messagequeue.emplace(std::move(dataqueue),
        message_lamport.value());
    // If this was an explicit send(), reset the message_lamport so
    // that it gets updated at the next queue().
    if (!implicit_send) {
        message_lamport.reset();
    }
#else
    messagequeue.emplace(std::move(dataqueue));
#endif

    // If this is now the first thing in the message queue, launch
    // an async_write to write it
    if (messagequeue.size() == 1) {
        async_send_from_msgqueue();
    }

    messagequeuelock.unlock();
}

size_t MPCSingleIO::recv(void *data, size_t len, lamport_t &lamport)
{
#ifdef VERBOSE_COMMS
    size_t orig_len = len;
    printf("Recv %s.%d len=%lu lamp=%u ", dest.c_str(), thread_num,
        len, lamport);
#endif

#ifdef SEND_LAMPORT_CLOCKS
    char *cdata = (char *)data;
    size_t res = 0;
    while (len > 0) {
        while (recvdataremain == 0) {
            // Read a new header
            char hdr[sizeof(uint32_t) + sizeof(lamport_t)];
            uint32_t datalen;
            lamport_t recv_lamport;
            boost::asio::read(sock, boost::asio::buffer(hdr, sizeof(hdr)));
            memmove(&datalen, hdr, sizeof(datalen));
            memmove(&recv_lamport, hdr+sizeof(datalen), sizeof(lamport_t));
            lamport_t new_lamport = recv_lamport + 1;
            if (lamport < new_lamport) {
                lamport = new_lamport;
            }
            if (datalen > 0) {
                recvdata.resize(datalen, '\0');
                boost::asio::read(sock, boost::asio::buffer(recvdata));
                recvdataremain = datalen;
            }
        }
        size_t amttoread = len;
        if (amttoread > recvdataremain) {
            amttoread = recvdataremain;
        }
        memmove(cdata, recvdata.data()+recvdata.size()-recvdataremain,
            amttoread);
        cdata += amttoread;
        len -= amttoread;
        recvdataremain -= amttoread;
        res += amttoread;
    }
#else
    size_t res = boost::asio::read(sock, boost::asio::buffer(data, len));
#endif

#ifdef VERBOSE_COMMS
    printf("nlamp=%u: ", lamport);
    for (size_t i=0;i<orig_len;++i) {
        printf("%02x", ((const unsigned char*)data)[i]);
    }
    printf("\n");
#endif

#ifdef RECORD_IOTRACE
    iotrace.push_back(-(ssize_t(res)));
#endif

    return res;
}

#ifdef RECORD_IOTRACE
void MPCSingleIO::dumptrace(std::ostream &os, const char *label)
{
    if (label) {
        os << label << " ";
    }
    os << "IO trace:";
    for (auto& s: iotrace) {
        os << " " << s;
    }
    os << "\n";
}
#endif

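// Wire format sketch (as parsed by recv() above, not a separate
// specification): under SEND_LAMPORT_CLOCKS, each message is framed as
// a fixed-size header followed by the payload.  Assuming lamport_t is
// a 32-bit unsigned integer, as the %u format strings above suggest:
//
//   +----------------+----------------+------------------------+
//   | uint32_t       | lamport_t      | payload                |
//   | payload length | sender's clock | (payload-length bytes) |
//   +----------------+----------------+------------------------+
//
// On receipt, the clock is merged with the usual Lamport receive rule:
// lamport = max(lamport, recv_lamport + 1).
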
void MPCIO::reset_stats()
{
    msgs_sent.clear();
    msg_bytes_sent.clear();
    aes_ops.clear();
    for (size_t i=0; i<num_threads; ++i) {
        msgs_sent.push_back(0);
        msg_bytes_sent.push_back(0);
        aes_ops.push_back(0);
    }
    steady_start = boost::chrono::steady_clock::now();
    cpu_start = boost::chrono::process_cpu_clock::now();
}

// Report the memory usage
void MPCIO::dump_memusage(std::ostream &os)
{
    struct rusage ru;
    getrusage(RUSAGE_SELF, &ru);
    os << "Mem: " << ru.ru_maxrss << " KiB\n";
}

void MPCIO::dump_stats(std::ostream &os)
{
    size_t tot_msgs_sent = 0;
    size_t tot_msg_bytes_sent = 0;
    size_t tot_aes_ops = 0;
    for (auto& n : msgs_sent) {
        tot_msgs_sent += n;
    }
    for (auto& n : msg_bytes_sent) {
        tot_msg_bytes_sent += n;
    }
    for (auto& n : aes_ops) {
        tot_aes_ops += n;
    }
    auto steady_elapsed =
        boost::chrono::steady_clock::now() - steady_start;
    auto cpu_elapsed =
        boost::chrono::process_cpu_clock::now() - cpu_start;

    os << tot_msgs_sent << " messages sent\n";
    os << tot_msg_bytes_sent << " message bytes sent\n";
    os << lamport << " Lamport clock (latencies)\n";
    os << tot_aes_ops << " local AES operations\n";
    os << boost::chrono::duration_cast
        <boost::chrono::milliseconds>(steady_elapsed) <<
        " wall clock time\n";
    os << cpu_elapsed << " {real;user;system}\n";
    dump_memusage(os);
}

// TVA is a tuple of vectors of arrays of PreCompStorage
template <nbits_t WIDTH, typename TVA>
static void rdpfstorage_init(TVA &storage, unsigned player,
    ProcessingMode mode, unsigned num_threads)
{
    auto &VA = std::get<WIDTH-1>(storage);
    VA.resize(num_threads);
    char prefix[11];
    strcpy(prefix, "rdpf");
    if (WIDTH > 1) {
        sprintf(prefix+strlen(prefix), "_w%d", WIDTH);
    }
    for (unsigned i=0; i<num_threads; ++i) {
        for (unsigned depth=1; depth<=ADDRESS_MAX_BITS; ++depth) {
            VA[i][depth-1].init(player, mode, prefix, i, depth, WIDTH);
        }
    }
}

// TVA is a tuple of vectors of arrays of PreCompStorage
template <nbits_t WIDTH, typename TVA>
static void rdpfstorage_dumpstats(std::ostream &os, TVA &storage,
    size_t thread_num)
{
    auto &VA = std::get<WIDTH-1>(storage);
    for (nbits_t depth=1; depth<=ADDRESS_MAX_BITS; ++depth) {
        size_t cnt = VA[thread_num][depth-1].get_stats();
        if (cnt > 0) {
            os << " r" << int(depth);
            if (WIDTH > 1) {
                os << "." << int(WIDTH);
            }
            os << ":" << cnt;
        }
    }
}

// TVA is a tuple of vectors of arrays of PreCompStorage
template <nbits_t WIDTH, typename TVA>
static void rdpfstorage_resetstats(TVA &storage, size_t thread_num)
{
    auto &VA = std::get<WIDTH-1>(storage);
    for (nbits_t depth=1; depth<=ADDRESS_MAX_BITS; ++depth) {
        VA[thread_num][depth-1].reset_stats();
    }
}

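// Illustrative sketch of the TVA shape (an assumption about the
// declarations in mpcio.hpp, not repeated from them): per WIDTH, one
// std::vector indexed by thread of std::arrays indexed by depth,
// something along the lines of
//
//   template <nbits_t WIDTH>
//   using RDPFPrecomps = std::vector<std::array<
//       PreCompStorage<RDPFTriple<WIDTH>>, ADDRESS_MAX_BITS>>;
//   using TVA = std::tuple<RDPFPrecomps<1>, RDPFPrecomps<2>,
//       RDPFPrecomps<3>, RDPFPrecomps<4>, RDPFPrecomps<5>>;
//
// so that std::get<WIDTH-1>(storage) selects the per-width vector,
// VA[i] the per-thread array, and VA[i][depth-1] the storage for a
// single DPF depth.
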
MPCPeerIO::MPCPeerIO(unsigned player, ProcessingMode mode,
        std::deque<tcp::socket> &peersocks,
        std::deque<tcp::socket> &serversocks) :
    MPCIO(player, mode, peersocks.size())
{
    unsigned num_threads = unsigned(peersocks.size());
    for (unsigned i=0; i<num_threads; ++i) {
        multtriples.emplace_back(player, mode, "mults", i);
    }
    for (unsigned i=0; i<num_threads; ++i) {
        halftriples.emplace_back(player, mode, "halves", i);
    }
    for (unsigned i=0; i<num_threads; ++i) {
        andtriples.emplace_back(player, mode, "ands", i);
    }
    for (unsigned i=0; i<num_threads; ++i) {
        valselecttriples.emplace_back(player, mode, "selects", i);
    }
    rdpfstorage_init<1>(rdpftriples, player, mode, num_threads);
    rdpfstorage_init<2>(rdpftriples, player, mode, num_threads);
    rdpfstorage_init<3>(rdpftriples, player, mode, num_threads);
    rdpfstorage_init<4>(rdpftriples, player, mode, num_threads);
    rdpfstorage_init<5>(rdpftriples, player, mode, num_threads);
    for (unsigned i=0; i<num_threads; ++i) {
        cdpfs.emplace_back(player, mode, "cdpf", i);
    }
    for (unsigned i=0; i<num_threads; ++i) {
        peerios.emplace_back(std::move(peersocks[i]), "peer", i);
    }
    for (unsigned i=0; i<num_threads; ++i) {
        serverios.emplace_back(std::move(serversocks[i]), "srv", i);
    }
}

void MPCPeerIO::dump_precomp_stats(std::ostream &os)
{
    for (size_t i=0; i<multtriples.size(); ++i) {
        size_t cnt;
        if (i > 0) {
            os << " ";
        }
        os << "T" << i;
        cnt = multtriples[i].get_stats();
        if (cnt > 0) {
            os << " m:" << cnt;
        }
        cnt = halftriples[i].get_stats();
        if (cnt > 0) {
            os << " h:" << cnt;
        }
        cnt = andtriples[i].get_stats();
        if (cnt > 0) {
            os << " a:" << cnt;
        }
        cnt = valselecttriples[i].get_stats();
        if (cnt > 0) {
            os << " s:" << cnt;
        }
        rdpfstorage_dumpstats<1>(os, rdpftriples, i);
        rdpfstorage_dumpstats<2>(os, rdpftriples, i);
        rdpfstorage_dumpstats<3>(os, rdpftriples, i);
        rdpfstorage_dumpstats<4>(os, rdpftriples, i);
        rdpfstorage_dumpstats<5>(os, rdpftriples, i);
        cnt = cdpfs[i].get_stats();
        if (cnt > 0) {
            os << " c:" << cnt;
        }
    }
    os << "\n";
}

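// As an example of the format produced above (hypothetical counts), a
// line such as
//   T0 m:5 h:2 r5.3:1 c:4
// reports that thread 0 consumed 5 multiplication triples, 2
// half-triples, 1 depth-5 width-3 RDPF triple, and 4 CDPFs.
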
void MPCPeerIO::reset_precomp_stats()
{
    for (size_t i=0; i<multtriples.size(); ++i) {
        multtriples[i].reset_stats();
        halftriples[i].reset_stats();
        andtriples[i].reset_stats();
        valselecttriples[i].reset_stats();
        rdpfstorage_resetstats<1>(rdpftriples, i);
        rdpfstorage_resetstats<2>(rdpftriples, i);
        rdpfstorage_resetstats<3>(rdpftriples, i);
        rdpfstorage_resetstats<4>(rdpftriples, i);
        rdpfstorage_resetstats<5>(rdpftriples, i);
    }
}

void MPCPeerIO::dump_stats(std::ostream &os)
{
    MPCIO::dump_stats(os);
    os << "Precomputed values used: ";
    dump_precomp_stats(os);
}

MPCServerIO::MPCServerIO(ProcessingMode mode,
        std::deque<tcp::socket> &p0socks,
        std::deque<tcp::socket> &p1socks) :
    MPCIO(2, mode, p0socks.size())
{
    rdpfstorage_init<1>(rdpfpairs, player, mode, num_threads);
    rdpfstorage_init<2>(rdpfpairs, player, mode, num_threads);
    rdpfstorage_init<3>(rdpfpairs, player, mode, num_threads);
    rdpfstorage_init<4>(rdpfpairs, player, mode, num_threads);
    rdpfstorage_init<5>(rdpfpairs, player, mode, num_threads);
    for (unsigned i=0; i<num_threads; ++i) {
        p0ios.emplace_back(std::move(p0socks[i]), "p0", i);
    }
    for (unsigned i=0; i<num_threads; ++i) {
        p1ios.emplace_back(std::move(p1socks[i]), "p1", i);
    }
}

void MPCServerIO::dump_precomp_stats(std::ostream &os)
{
    for (size_t i=0; i<std::get<0>(rdpfpairs).size(); ++i) {
        if (i > 0) {
            os << " ";
        }
        os << "T" << i;
        rdpfstorage_dumpstats<1>(os, rdpfpairs, i);
        rdpfstorage_dumpstats<2>(os, rdpfpairs, i);
        rdpfstorage_dumpstats<3>(os, rdpfpairs, i);
        rdpfstorage_dumpstats<4>(os, rdpfpairs, i);
        rdpfstorage_dumpstats<5>(os, rdpfpairs, i);
    }
    os << "\n";
}

void MPCServerIO::reset_precomp_stats()
{
    for (size_t i=0; i<std::get<0>(rdpfpairs).size(); ++i) {
        rdpfstorage_resetstats<1>(rdpfpairs, i);
        rdpfstorage_resetstats<2>(rdpfpairs, i);
        rdpfstorage_resetstats<3>(rdpfpairs, i);
        rdpfstorage_resetstats<4>(rdpfpairs, i);
        rdpfstorage_resetstats<5>(rdpfpairs, i);
    }
}

void MPCServerIO::dump_stats(std::ostream &os)
{
    MPCIO::dump_stats(os);
    os << "Precomputed values used: ";
    dump_precomp_stats(os);
}

MPCTIO::MPCTIO(MPCIO &mpcio, int thread_num, int num_threads) :
        thread_num(thread_num), local_cpu_nthreads(num_threads),
        communication_nthreads(num_threads),
        thread_lamport(mpcio.lamport), mpcio(mpcio),
#ifdef VERBOSE_COMMS
        round_num(0),
#endif
        last_andtriple_bits_remaining(0)
{
    if (mpcio.player < 2) {
        MPCPeerIO &mpcpio = static_cast<MPCPeerIO&>(mpcio);
        peer_iostream.emplace(mpcpio.peerios[thread_num],
            thread_lamport, mpcpio.msgs_sent[thread_num],
            mpcpio.msg_bytes_sent[thread_num]);
        server_iostream.emplace(mpcpio.serverios[thread_num],
            thread_lamport, mpcpio.msgs_sent[thread_num],
            mpcpio.msg_bytes_sent[thread_num]);
    } else {
        MPCServerIO &mpcsrvio = static_cast<MPCServerIO&>(mpcio);
        p0_iostream.emplace(mpcsrvio.p0ios[thread_num],
            thread_lamport, mpcsrvio.msgs_sent[thread_num],
            mpcsrvio.msg_bytes_sent[thread_num]);
        p1_iostream.emplace(mpcsrvio.p1ios[thread_num],
            thread_lamport, mpcsrvio.msgs_sent[thread_num],
            mpcsrvio.msg_bytes_sent[thread_num]);
    }
}

// Sync our per-thread Lamport clock with the master one in the
// mpcio. You only need to call this explicitly if your MPCTIO
// outlives your thread (in which case call it after the join), or
// if your threads do interthread communication amongst themselves
// (in which case call it in the sending thread before the send, and
// call it in the receiving thread after the receive).
void MPCTIO::sync_lamport()
{
    // Update the mpcio Lamport time to be max of the thread Lamport
    // time and what we thought it was before. We use this
    // compare_exchange construction in order to atomically
    // do the comparison, computation, and replacement
    lamport_t old_lamport = mpcio.lamport;
    lamport_t new_lamport = thread_lamport;
    do {
        if (new_lamport < old_lamport) {
            new_lamport = old_lamport;
        }
        // The next line atomically checks if lamport still has
        // the value old_lamport; if so, it changes its value to
        // new_lamport and returns true (ending the loop). If
        // not, it sets old_lamport to the current value of
        // lamport, and returns false (continuing the loop so
        // that new_lamport can be recomputed based on this new
        // value).
    } while (!mpcio.lamport.compare_exchange_weak(
        old_lamport, new_lamport));
    thread_lamport = new_lamport;
}

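// The compare_exchange loop above is the standard atomic "fetch-max"
// pattern.  The same idea in isolation (atomic_max is a hypothetical
// helper, not part of this codebase):
//
//   void atomic_max(std::atomic<lamport_t> &a, lamport_t v) {
//       lamport_t cur = a;
//       while (cur < v && !a.compare_exchange_weak(cur, v)) { }
//   }
//
// On failure, compare_exchange_weak reloads cur from a, so the loop
// terminates as soon as a already holds a value >= v.
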
// Only call this if you can be sure that there are no outstanding
// messages in flight, you can call it on all existing MPCTIOs, and
// you really want to reset the Lamport clock in the middle of a
// run.
void MPCTIO::reset_lamport()
{
    // Reset both our own Lamport clock and the parent MPCIO's
    thread_lamport = 0;
    mpcio.lamport = 0;
}

// Queue up data to the peer or to the server
void MPCTIO::queue_peer(const void *data, size_t len)
{
    if (mpcio.player < 2) {
        MPCPeerIO &mpcpio = static_cast<MPCPeerIO&>(mpcio);
        size_t newmsg = mpcpio.peerios[thread_num].queue(data, len, thread_lamport);
        mpcpio.msgs_sent[thread_num] += newmsg;
        mpcpio.msg_bytes_sent[thread_num] += len;
    }
}

void MPCTIO::queue_server(const void *data, size_t len)
{
    if (mpcio.player < 2) {
        MPCPeerIO &mpcpio = static_cast<MPCPeerIO&>(mpcio);
        size_t newmsg = mpcpio.serverios[thread_num].queue(data, len, thread_lamport);
        mpcpio.msgs_sent[thread_num] += newmsg;
        mpcpio.msg_bytes_sent[thread_num] += len;
    }
}

// Receive data from the peer or from the server
size_t MPCTIO::recv_peer(void *data, size_t len)
{
    if (mpcio.player < 2) {
        MPCPeerIO &mpcpio = static_cast<MPCPeerIO&>(mpcio);
        return mpcpio.peerios[thread_num].recv(data, len, thread_lamport);
    }
    return 0;
}

size_t MPCTIO::recv_server(void *data, size_t len)
{
    if (mpcio.player < 2) {
        MPCPeerIO &mpcpio = static_cast<MPCPeerIO&>(mpcio);
        return mpcpio.serverios[thread_num].recv(data, len, thread_lamport);
    }
    return 0;
}

// Queue up data to p0 or p1
void MPCTIO::queue_p0(const void *data, size_t len)
{
    if (mpcio.player == 2) {
        MPCServerIO &mpcsrvio = static_cast<MPCServerIO&>(mpcio);
        size_t newmsg = mpcsrvio.p0ios[thread_num].queue(data, len, thread_lamport);
        mpcsrvio.msgs_sent[thread_num] += newmsg;
        mpcsrvio.msg_bytes_sent[thread_num] += len;
    }
}

void MPCTIO::queue_p1(const void *data, size_t len)
{
    if (mpcio.player == 2) {
        MPCServerIO &mpcsrvio = static_cast<MPCServerIO&>(mpcio);
        size_t newmsg = mpcsrvio.p1ios[thread_num].queue(data, len, thread_lamport);
        mpcsrvio.msgs_sent[thread_num] += newmsg;
        mpcsrvio.msg_bytes_sent[thread_num] += len;
    }
}

// Receive data from p0 or p1
size_t MPCTIO::recv_p0(void *data, size_t len)
{
    if (mpcio.player == 2) {
        MPCServerIO &mpcsrvio = static_cast<MPCServerIO&>(mpcio);
        return mpcsrvio.p0ios[thread_num].recv(data, len, thread_lamport);
    }
    return 0;
}

size_t MPCTIO::recv_p1(void *data, size_t len)
{
    if (mpcio.player == 2) {
        MPCServerIO &mpcsrvio = static_cast<MPCServerIO&>(mpcio);
        return mpcsrvio.p1ios[thread_num].recv(data, len, thread_lamport);
    }
    return 0;
}

// Send all queued data for this thread
void MPCTIO::send()
{
#ifdef VERBOSE_COMMS
    printf("Thread %u sending round %lu\n", thread_num, ++round_num);
#endif
    if (mpcio.player < 2) {
        MPCPeerIO &mpcpio = static_cast<MPCPeerIO&>(mpcio);
        mpcpio.peerios[thread_num].send();
        mpcpio.serverios[thread_num].send();
    } else {
        MPCServerIO &mpcsrvio = static_cast<MPCServerIO&>(mpcio);
        mpcsrvio.p0ios[thread_num].send();
        mpcsrvio.p1ios[thread_num].send();
    }
}

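// Usage sketch (hypothetical values, using only functions defined in
// this file): a computational party typically queues all of its
// outgoing data for a round, flushes it with send(), and then reads
// the other parties' responses:
//
//   value_t share = 0x1234;
//   tio.queue_peer(&share, sizeof(share));
//   tio.send();
//   value_t theirs;
//   tio.recv_peer(&theirs, sizeof(theirs));
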
// Functions to get precomputed values. If we're in the online
// phase, get them from PreCompStorage. If we're in the
// preprocessing or online-only phase, read them from the server.
MultTriple MPCTIO::multtriple(yield_t &yield)
{
    MultTriple val;
    if (mpcio.player < 2) {
        MPCPeerIO &mpcpio = static_cast<MPCPeerIO&>(mpcio);
        if (mpcpio.mode != MODE_ONLINE) {
            yield();
            recv_server(&val, sizeof(val));
            mpcpio.multtriples[thread_num].inc();
        } else {
            mpcpio.multtriples[thread_num].get(val);
        }
    } else if (mpcio.mode != MODE_ONLINE) {
        // Create multiplication triples (X0,Y0,Z0),(X1,Y1,Z1) such that
        // (X0*Y1 + Y0*X1) = (Z0+Z1)
        value_t X0, Y0, Z0, X1, Y1, Z1;
        arc4random_buf(&X0, sizeof(X0));
        arc4random_buf(&Y0, sizeof(Y0));
        arc4random_buf(&Z0, sizeof(Z0));
        arc4random_buf(&X1, sizeof(X1));
        arc4random_buf(&Y1, sizeof(Y1));
        Z1 = X0 * Y1 + X1 * Y0 - Z0;
        MultTriple T0, T1;
        T0 = std::make_tuple(X0, Y0, Z0);
        T1 = std::make_tuple(X1, Y1, Z1);
        queue_p0(&T0, sizeof(T0));
        queue_p1(&T1, sizeof(T1));
        yield();
    }
    return val;
}

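// Worked example of the triple identity (small values chosen purely
// for illustration; the real arithmetic is modulo the word size of
// value_t): with X0=3, Y0=5, Z0=10, X1=4, Y1=7, the server computes
// Z1 = X0*Y1 + X1*Y0 - Z0 = 21 + 20 - 10 = 31, and indeed
// X0*Y1 + Y0*X1 = 41 = Z0 + Z1.
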
// When halftriple() is used internally by another preprocessing
// operation, don't tally it, so that it doesn't appear separately in
// the stats from the preprocessing operation that invoked it
HalfTriple MPCTIO::halftriple(yield_t &yield, bool tally)
{
    HalfTriple val;
    if (mpcio.player < 2) {
        MPCPeerIO &mpcpio = static_cast<MPCPeerIO&>(mpcio);
        if (mpcpio.mode != MODE_ONLINE) {
            yield();
            recv_server(&val, sizeof(val));
            if (tally) {
                mpcpio.halftriples[thread_num].inc();
            }
        } else {
            mpcpio.halftriples[thread_num].get(val);
        }
    } else if (mpcio.mode != MODE_ONLINE) {
        // Create half-triples (X0,Z0),(Y1,Z1) such that
        // X0*Y1 = Z0 + Z1
        value_t X0, Z0, Y1, Z1;
        arc4random_buf(&X0, sizeof(X0));
        arc4random_buf(&Z0, sizeof(Z0));
        arc4random_buf(&Y1, sizeof(Y1));
        Z1 = X0 * Y1 - Z0;
        HalfTriple H0, H1;
        H0 = std::make_tuple(X0, Z0);
        H1 = std::make_tuple(Y1, Z1);
        queue_p0(&H0, sizeof(H0));
        queue_p1(&H1, sizeof(H1));
        yield();
    }
    return val;
}

AndTriple MPCTIO::andtriple(yield_t &yield)
{
    AndTriple val;
    if (mpcio.player < 2) {
        MPCPeerIO &mpcpio = static_cast<MPCPeerIO&>(mpcio);
        if (mpcpio.mode != MODE_ONLINE) {
            yield();
            recv_server(&val, sizeof(val));
            mpcpio.andtriples[thread_num].inc();
        } else {
            mpcpio.andtriples[thread_num].get(val);
        }
    } else if (mpcio.mode != MODE_ONLINE) {
        // Create AND triples (X0,Y0,Z0),(X1,Y1,Z1) such that
        // (X0&Y1 ^ Y0&X1) = (Z0^Z1)
        value_t X0, Y0, Z0, X1, Y1, Z1;
        arc4random_buf(&X0, sizeof(X0));
        arc4random_buf(&Y0, sizeof(Y0));
        arc4random_buf(&Z0, sizeof(Z0));
        arc4random_buf(&X1, sizeof(X1));
        arc4random_buf(&Y1, sizeof(Y1));
        Z1 = (X0 & Y1) ^ (X1 & Y0) ^ Z0;
        AndTriple T0, T1;
        T0 = std::make_tuple(X0, Y0, Z0);
        T1 = std::make_tuple(X1, Y1, Z1);
        queue_p0(&T0, sizeof(T0));
        queue_p1(&T1, sizeof(T1));
        yield();
    }
    return val;
}

SelectTriple<DPFnode> MPCTIO::nodeselecttriple(yield_t &yield)
{
    SelectTriple<DPFnode> val;
    if (mpcio.player < 2) {
        MPCPeerIO &mpcpio = static_cast<MPCPeerIO&>(mpcio);
        if (mpcpio.mode != MODE_ONLINE) {
            uint8_t Xbyte;
            yield();
            recv_server(&Xbyte, sizeof(Xbyte));
            val.X = Xbyte & 1;
            recv_server(&val.Y, sizeof(val.Y));
            recv_server(&val.Z, sizeof(val.Z));
        } else {
            std::cerr << "Attempted to read SelectTriple<DPFnode> in online phase\n";
        }
    } else if (mpcio.mode != MODE_ONLINE) {
        // Create triples (X0,Y0,Z0),(X1,Y1,Z1) such that
        // (X0*Y1 ^ Y0*X1) = (Z0^Z1)
        bit_t X0, X1;
        DPFnode Y0, Z0, Y1, Z1;
        X0 = arc4random() & 1;
        arc4random_buf(&Y0, sizeof(Y0));
        arc4random_buf(&Z0, sizeof(Z0));
        X1 = arc4random() & 1;
        arc4random_buf(&Y1, sizeof(Y1));
        DPFnode X0ext, X1ext;
        // Sign-extend X0 and X1 (so that 0 -> 0000...0 and
        // 1 -> 1111...1)
        X0ext = if128_mask[X0];
        X1ext = if128_mask[X1];
        Z1 = ((X0ext & Y1) ^ (X1ext & Y0)) ^ Z0;
        queue_p0(&X0, sizeof(X0));
        queue_p0(&Y0, sizeof(Y0));
        queue_p0(&Z0, sizeof(Z0));
        queue_p1(&X1, sizeof(X1));
        queue_p1(&Y1, sizeof(Y1));
        queue_p1(&Z1, sizeof(Z1));
        yield();
    }
    return val;
}

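// Note on the masking trick above (if128_mask presumably comes from
// bitutils.hpp, included at the top of this file): it maps a bit b to
// a 128-bit node whose bits are all equal to b, so "X0ext & Y1"
// computes the product of the bit X0 with the node Y1 one bit lane at
// a time, and the triple relation holds bitwise across the node.
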
SelectTriple<value_t> MPCTIO::valselecttriple(yield_t &yield)
{
    SelectTriple<value_t> val;
    if (mpcio.player < 2) {
        MPCPeerIO &mpcpio = static_cast<MPCPeerIO&>(mpcio);
        if (mpcpio.mode != MODE_ONLINE) {
            uint8_t Xbyte;
            yield();
            recv_server(&Xbyte, sizeof(Xbyte));
            val.X = Xbyte & 1;
            recv_server(&val.Y, sizeof(val.Y));
            recv_server(&val.Z, sizeof(val.Z));
            mpcpio.valselecttriples[thread_num].inc();
        } else {
            mpcpio.valselecttriples[thread_num].get(val);
        }
    } else if (mpcio.mode != MODE_ONLINE) {
        // Create triples (X0,Y0,Z0),(X1,Y1,Z1) such that
        // (X0*Y1 ^ Y0*X1) = (Z0^Z1)
        bit_t X0, X1;
        value_t Y0, Z0, Y1, Z1;
        X0 = arc4random() & 1;
        arc4random_buf(&Y0, sizeof(Y0));
        arc4random_buf(&Z0, sizeof(Z0));
        X1 = arc4random() & 1;
        arc4random_buf(&Y1, sizeof(Y1));
        value_t X0ext, X1ext;
        // Sign-extend X0 and X1 (so that 0 -> 0000...0 and
        // 1 -> 1111...1)
        X0ext = -value_t(X0);
        X1ext = -value_t(X1);
        Z1 = ((X0ext & Y1) ^ (X1ext & Y0)) ^ Z0;
        queue_p0(&X0, sizeof(X0));
        queue_p0(&Y0, sizeof(Y0));
        queue_p0(&Z0, sizeof(Z0));
        queue_p1(&X1, sizeof(X1));
        queue_p1(&Y1, sizeof(Y1));
        queue_p1(&Z1, sizeof(Z1));
        yield();
    }
    return val;
}

SelectTriple<bit_t> MPCTIO::bitselecttriple(yield_t &yield)
{
    // Do we need to fetch a new AND triple?
    if (last_andtriple_bits_remaining == 0) {
        last_andtriple = andtriple(yield);
        last_andtriple_bits_remaining = 8*sizeof(value_t);
    }
    --last_andtriple_bits_remaining;
    value_t mask = value_t(1) << last_andtriple_bits_remaining;
    SelectTriple<bit_t> val;
    val.X = !!(std::get<0>(last_andtriple) & mask);
    val.Y = !!(std::get<1>(last_andtriple) & mask);
    val.Z = !!(std::get<2>(last_andtriple) & mask);
    return val;
}

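// Illustrative note: one AND triple over value_t yields 8*sizeof(value_t)
// bit-select triples, consumed most-significant bit first (after a
// refill, the first call uses the top bit's mask, the next call the
// bit below it, and so on).  This works because the AND-triple relation
// (X0&Y1 ^ Y0&X1) = (Z0^Z1) holds bitwise, so it holds for each
// extracted single-bit triple on its own.
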
CDPF MPCTIO::cdpf(yield_t &yield)
{
    CDPF val;
    if (mpcio.player < 2) {
        MPCPeerIO &mpcpio = static_cast<MPCPeerIO&>(mpcio);
        if (mpcpio.mode != MODE_ONLINE) {
            yield();
            iostream_server() >> val;
            mpcpio.cdpfs[thread_num].inc();
        } else {
            mpcpio.cdpfs[thread_num].get(val);
        }
    } else if (mpcio.mode != MODE_ONLINE) {
        auto [ cdpf0, cdpf1 ] = CDPF::generate(aes_ops());
        iostream_p0() << cdpf0;
        iostream_p1() << cdpf1;
        yield();
    }
    return val;
}

// The port number for the P1 -> P0 connection
static const unsigned short port_p1_p0 = 2115;

// The port number for the P2 -> P0 connection
static const unsigned short port_p2_p0 = 2116;

// The port number for the P2 -> P1 connection
static const unsigned short port_p2_p1 = 2117;

void mpcio_setup_computational(unsigned player,
    boost::asio::io_context &io_context,
    const char *p0addr,  // can be NULL when player=0
    int num_threads,
    std::deque<tcp::socket> &peersocks,
    std::deque<tcp::socket> &serversocks)
{
    if (player == 0) {
        // Listen for connections from P1 and from P2
        tcp::acceptor acceptor_p1(io_context,
            tcp::endpoint(tcp::v4(), port_p1_p0));
        tcp::acceptor acceptor_p2(io_context,
            tcp::endpoint(tcp::v4(), port_p2_p0));

        peersocks.clear();
        serversocks.clear();
        for (int i=0;i<num_threads;++i) {
            peersocks.emplace_back(io_context);
            serversocks.emplace_back(io_context);
        }
        for (int i=0;i<num_threads;++i) {
            tcp::socket peersock = acceptor_p1.accept();
            // Read 2 bytes from the socket, which will be the thread
            // number
            unsigned short thread_num;
            boost::asio::read(peersock,
                boost::asio::buffer(&thread_num, sizeof(thread_num)));
            if (thread_num >= num_threads) {
                std::cerr << "Received bad thread number from peer\n";
            } else {
                peersocks[thread_num] = std::move(peersock);
            }
        }
        for (int i=0;i<num_threads;++i) {
            tcp::socket serversock = acceptor_p2.accept();
            // Read 2 bytes from the socket, which will be the thread
            // number
            unsigned short thread_num;
            boost::asio::read(serversock,
                boost::asio::buffer(&thread_num, sizeof(thread_num)));
            if (thread_num >= num_threads) {
                std::cerr << "Received bad thread number from server\n";
            } else {
                serversocks[thread_num] = std::move(serversock);
            }
        }
    } else if (player == 1) {
        // Listen for connections from P2, make num_threads connections to P0
        tcp::acceptor acceptor_p2(io_context,
            tcp::endpoint(tcp::v4(), port_p2_p1));

        tcp::resolver resolver(io_context);
        boost::system::error_code err;
        peersocks.clear();
        serversocks.clear();
        for (int i=0;i<num_threads;++i) {
            serversocks.emplace_back(io_context);
        }
        for (unsigned short thread_num = 0; thread_num < num_threads; ++thread_num) {
            tcp::socket peersock(io_context);
            while(1) {
                boost::asio::connect(peersock,
                    resolver.resolve(p0addr, std::to_string(port_p1_p0)), err);
                if (!err) break;
                std::cerr << "Connection to p0 refused, will retry.\n";
                sleep(1);
            }
            // Write 2 bytes to the socket indicating which thread
            // number this socket is for
            boost::asio::write(peersock,
                boost::asio::buffer(&thread_num, sizeof(thread_num)));
            peersocks.push_back(std::move(peersock));
        }
        for (int i=0;i<num_threads;++i) {
            tcp::socket serversock = acceptor_p2.accept();
            // Read 2 bytes from the socket, which will be the thread
            // number
            unsigned short thread_num;
            boost::asio::read(serversock,
                boost::asio::buffer(&thread_num, sizeof(thread_num)));
            if (thread_num >= num_threads) {
                std::cerr << "Received bad thread number from server\n";
            } else {
                serversocks[thread_num] = std::move(serversock);
            }
        }
    } else {
        std::cerr << "Invalid player number passed to mpcio_setup_computational\n";
    }
}

void mpcio_setup_server(boost::asio::io_context &io_context,
    const char *p0addr, const char *p1addr, int num_threads,
    std::deque<tcp::socket> &p0socks,
    std::deque<tcp::socket> &p1socks)
{
    // Make connections to P0 and P1
    tcp::resolver resolver(io_context);
    boost::system::error_code err;
    p0socks.clear();
    p1socks.clear();
    for (unsigned short thread_num = 0; thread_num < num_threads; ++thread_num) {
        tcp::socket p0sock(io_context);
        while(1) {
            boost::asio::connect(p0sock,
                resolver.resolve(p0addr, std::to_string(port_p2_p0)), err);
            if (!err) break;
            std::cerr << "Connection to p0 refused, will retry.\n";
            sleep(1);
        }
        // Write 2 bytes to the socket indicating which thread
        // number this socket is for
        boost::asio::write(p0sock,
            boost::asio::buffer(&thread_num, sizeof(thread_num)));
        p0socks.push_back(std::move(p0sock));
    }
    for (unsigned short thread_num = 0; thread_num < num_threads; ++thread_num) {
        tcp::socket p1sock(io_context);
        while(1) {
            boost::asio::connect(p1sock,
                resolver.resolve(p1addr, std::to_string(port_p2_p1)), err);
            if (!err) break;
            std::cerr << "Connection to p1 refused, will retry.\n";
            sleep(1);
        }
        // Write 2 bytes to the socket indicating which thread
        // number this socket is for
        boost::asio::write(p1sock,
            boost::asio::buffer(&thread_num, sizeof(thread_num)));
        p1socks.push_back(std::move(p1sock));
    }
}
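
// Usage sketch (a hypothetical driver, using only functions and types
// from this file; "p0.example.com" and num_threads are placeholders):
// computational player 1 might set up its I/O like this:
//
//   boost::asio::io_context io_context;
//   std::deque<tcp::socket> peersocks, serversocks;
//   mpcio_setup_computational(1, io_context, "p0.example.com",
//       num_threads, peersocks, serversocks);
//   MPCPeerIO mpcpio(1, MODE_ONLINE, peersocks, serversocks);
//   // then, per thread t: MPCTIO tio(mpcpio, t, num_threads);
//
// The server (player 2) instead calls mpcio_setup_server() and
// constructs an MPCServerIO from the resulting p0socks and p1socks.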