online.cpp 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166
  1. #include <bsd/stdlib.h> // arc4random_buf
  2. #include "online.hpp"
  3. #include "mpcops.hpp"
  4. static void online_test(MPCIO &mpcio, int num_threads, char **args)
  5. {
  6. nbits_t nbits = VALUE_BITS;
  7. if (*args) {
  8. nbits = atoi(*args);
  9. }
  10. size_t memsize = 9;
  11. MPCTIO tio(mpcio, 0);
  12. bool is_server = (mpcio.player == 2);
  13. RegAS *A = new RegAS[memsize];
  14. value_t V;
  15. RegBS F0, F1;
  16. RegXS X;
  17. if (!is_server) {
  18. A[0].randomize();
  19. A[1].randomize();
  20. F0.randomize();
  21. A[4].randomize();
  22. F1.randomize();
  23. A[6].randomize();
  24. A[7].randomize();
  25. X.randomize();
  26. arc4random_buf(&V, sizeof(V));
  27. printf("A:\n"); for (size_t i=0; i<memsize; ++i) printf("%3lu: %016lX\n", i, A[i].ashare);
  28. printf("V : %016lX\n", V);
  29. printf("F0 : %01X\n", F0.bshare);
  30. printf("F1 : %01X\n", F1.bshare);
  31. printf("X : %016lX\n", X.xshare);
  32. }
  33. std::vector<coro_t> coroutines;
  34. coroutines.emplace_back(
  35. [&](yield_t &yield) {
  36. mpc_mul(tio, yield, A[2], A[0], A[1], nbits);
  37. });
  38. coroutines.emplace_back(
  39. [&](yield_t &yield) {
  40. mpc_valuemul(tio, yield, A[3], V, nbits);
  41. });
  42. coroutines.emplace_back(
  43. [&](yield_t &yield) {
  44. mpc_flagmult(tio, yield, A[5], F0, A[4], nbits);
  45. });
  46. coroutines.emplace_back(
  47. [&](yield_t &yield) {
  48. mpc_oswap(tio, yield, A[6], A[7], F1, nbits);
  49. });
  50. coroutines.emplace_back(
  51. [&](yield_t &yield) {
  52. mpc_xs_to_as(tio, yield, A[8], X, nbits);
  53. });
  54. run_coroutines(tio, coroutines);
  55. if (!is_server) {
  56. printf("\n");
  57. printf("A:\n"); for (size_t i=0; i<memsize; ++i) printf("%3lu: %016lX\n", i, A[i].ashare);
  58. }
  59. // Check the answers
  60. if (mpcio.player == 1) {
  61. tio.queue_peer(A, memsize*sizeof(RegAS));
  62. tio.queue_peer(&V, sizeof(V));
  63. tio.queue_peer(&F0, sizeof(RegBS));
  64. tio.queue_peer(&F1, sizeof(RegBS));
  65. tio.queue_peer(&X, sizeof(RegXS));
  66. tio.send();
  67. } else if (mpcio.player == 0) {
  68. RegAS *B = new RegAS[memsize];
  69. RegBS BF0, BF1;
  70. RegXS BX;
  71. value_t BV;
  72. value_t *S = new value_t[memsize];
  73. bit_t SF0, SF1;
  74. value_t SX;
  75. tio.recv_peer(B, memsize*sizeof(RegAS));
  76. tio.recv_peer(&BV, sizeof(BV));
  77. tio.recv_peer(&BF0, sizeof(RegBS));
  78. tio.recv_peer(&BF1, sizeof(RegBS));
  79. tio.recv_peer(&BX, sizeof(RegXS));
  80. for(size_t i=0; i<memsize; ++i) S[i] = A[i].ashare+B[i].ashare;
  81. SF0 = F0.bshare ^ BF0.bshare;
  82. SF1 = F1.bshare ^ BF1.bshare;
  83. SX = X.xshare ^ BX.xshare;
  84. printf("S:\n"); for (size_t i=0; i<memsize; ++i) printf("%3lu: %016lX\n", i, S[i]);
  85. printf("SF0: %01X\n", SF0);
  86. printf("SF1: %01X\n", SF1);
  87. printf("SX : %016lX\n", SX);
  88. printf("\n%016lx\n", S[0]*S[1]-S[2]);
  89. printf("%016lx\n", (V*BV)-S[3]);
  90. printf("%016lx\n", (SF0*S[4])-S[5]);
  91. printf("%016lx\n", S[8]-SX);
  92. delete[] B;
  93. delete[] S;
  94. }
  95. delete[] A;
  96. }
  97. static void lamport_test(MPCIO &mpcio, int num_threads, char **args)
  98. {
  99. // Create a bunch of threads and send a bunch of data to the other
  100. // peer, and receive their data. If an arg is specified, repeat
  101. // that many times. The Lamport clock at the end should be just the
  102. // number of repetitions. Subsequent args are the chunk size and
  103. // the number of chunks per message
  104. size_t niters = 1;
  105. size_t chunksize = 1<<20;
  106. size_t numchunks = 1;
  107. if (*args) {
  108. niters = atoi(*args);
  109. ++args;
  110. }
  111. if (*args) {
  112. chunksize = atoi(*args);
  113. ++args;
  114. }
  115. if (*args) {
  116. numchunks = atoi(*args);
  117. ++args;
  118. }
  119. boost::asio::thread_pool pool(num_threads);
  120. for (int thread_num = 0; thread_num < num_threads; ++thread_num) {
  121. boost::asio::post(pool, [&mpcio, thread_num, niters, chunksize, numchunks] {
  122. MPCTIO tio(mpcio, thread_num);
  123. char *sendbuf = new char[chunksize];
  124. char *recvbuf = new char[chunksize*numchunks];
  125. for (size_t i=0; i<niters; ++i) {
  126. for (size_t chunk=0; chunk<numchunks; ++chunk) {
  127. arc4random_buf(sendbuf, chunksize);
  128. tio.queue_peer(sendbuf, chunksize);
  129. }
  130. tio.send();
  131. tio.recv_peer(recvbuf, chunksize*numchunks);
  132. }
  133. delete[] recvbuf;
  134. delete[] sendbuf;
  135. });
  136. }
  137. pool.join();
  138. }
  139. void online_main(MPCIO &mpcio, int num_threads, char **args)
  140. {
  141. if (!*args) {
  142. std::cerr << "Mode is required as the first argument when not preprocessing.\n";
  143. return;
  144. } else if (!strcmp(*args, "test")) {
  145. ++args;
  146. online_test(mpcio, num_threads, args);
  147. } else if (!strcmp(*args, "lamporttest")) {
  148. ++args;
  149. lamport_test(mpcio, num_threads, args);
  150. } else {
  151. std::cerr << "Unknown mode " << *args << "\n";
  152. }
  153. }