online.cpp 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140
  1. #include <bsd/stdlib.h> // arc4random_buf
  2. #include "online.hpp"
  3. #include "mpcops.hpp"
  4. static void online_test(MPCIO &mpcio, int num_threads, char **args)
  5. {
  6. nbits_t nbits = VALUE_BITS;
  7. if (*args) {
  8. nbits = atoi(*args);
  9. }
  10. size_t memsize = 13;
  11. MPCTIO tio(mpcio, 0);
  12. bool is_server = (mpcio.player == 2);
  13. value_t *A = new value_t[memsize];
  14. if (!is_server) {
  15. arc4random_buf(A, memsize*sizeof(value_t));
  16. A[5] &= 1;
  17. A[8] &= 1;
  18. printf("A:\n"); for (size_t i=0; i<memsize; ++i) printf("%3lu: %016lX\n", i, A[i]);
  19. }
  20. std::vector<coro_t> coroutines;
  21. coroutines.emplace_back(
  22. [&](yield_t &yield) {
  23. mpc_mul(tio, yield, A[2], A[0], A[1], nbits);
  24. });
  25. coroutines.emplace_back(
  26. [&](yield_t &yield) {
  27. mpc_valuemul(tio, yield, A[4], A[3], nbits);
  28. });
  29. coroutines.emplace_back(
  30. [&](yield_t &yield) {
  31. mpc_flagmult(tio, yield, A[7], A[5], A[6], nbits);
  32. });
  33. coroutines.emplace_back(
  34. [&](yield_t &yield) {
  35. mpc_oswap(tio, yield, A[9], A[10], A[8], nbits);
  36. });
  37. coroutines.emplace_back(
  38. [&](yield_t &yield) {
  39. mpc_xs_to_as(tio, yield, A[12], A[11], nbits);
  40. });
  41. run_coroutines(tio, coroutines);
  42. if (!is_server) {
  43. printf("\n");
  44. printf("A:\n"); for (size_t i=0; i<memsize; ++i) printf("%3lu: %016lX\n", i, A[i]);
  45. }
  46. // Check the answers
  47. if (mpcio.player == 1) {
  48. tio.queue_peer(A, memsize*sizeof(value_t));
  49. tio.send();
  50. } else if (mpcio.player == 0) {
  51. value_t *B = new value_t[memsize];
  52. value_t *S = new value_t[memsize];
  53. tio.recv_peer(B, memsize*sizeof(value_t));
  54. for(size_t i=0; i<memsize; ++i) S[i] = A[i]+B[i];
  55. printf("S:\n"); for (size_t i=0; i<memsize; ++i) printf("%3lu: %016lX\n", i, S[i]);
  56. printf("\n%016lx\n", S[0]*S[1]-S[2]);
  57. printf("%016lx\n", (A[3]*B[3])-S[4]);
  58. delete[] B;
  59. delete[] S;
  60. }
  61. if (!is_server) {
  62. MPCPeerIO &mpcpio = static_cast<MPCPeerIO &>(mpcio);
  63. mpcpio.dump_precomp_stats(std::cout);
  64. }
  65. std::cout << "Lamport clock = " << mpcio.lamport << "\n";
  66. delete[] A;
  67. }
  68. static void lamport_test(MPCIO &mpcio, int num_threads, char **args)
  69. {
  70. // Create a bunch of threads and send a bunch of data to the other
  71. // peer, and receive their data. If an arg is specified, repeat
  72. // that many times. The Lamport clock at the end should be just the
  73. // number of repetitions. Subsequent args are the chunk size and
  74. // the number of chunks per message
  75. size_t niters = 1;
  76. size_t chunksize = 1<<20;
  77. size_t numchunks = 1;
  78. if (*args) {
  79. niters = atoi(*args);
  80. ++args;
  81. }
  82. if (*args) {
  83. chunksize = atoi(*args);
  84. ++args;
  85. }
  86. if (*args) {
  87. numchunks = atoi(*args);
  88. ++args;
  89. }
  90. boost::asio::thread_pool pool(num_threads);
  91. for (int thread_num = 0; thread_num < num_threads; ++thread_num) {
  92. boost::asio::post(pool, [&mpcio, thread_num, niters, chunksize, numchunks] {
  93. MPCTIO tio(mpcio, thread_num);
  94. char *sendbuf = new char[chunksize];
  95. char *recvbuf = new char[chunksize*numchunks];
  96. for (size_t i=0; i<niters; ++i) {
  97. for (size_t chunk=0; chunk<numchunks; ++chunk) {
  98. arc4random_buf(sendbuf, chunksize);
  99. tio.queue_peer(sendbuf, chunksize);
  100. }
  101. tio.send();
  102. tio.recv_peer(recvbuf, chunksize*numchunks);
  103. }
  104. delete[] recvbuf;
  105. delete[] sendbuf;
  106. });
  107. }
  108. pool.join();
  109. std::cout << "Lamport clock = " << mpcio.lamport << "\n";
  110. }
  111. void online_main(MPCIO &mpcio, int num_threads, char **args)
  112. {
  113. if (!*args) {
  114. std::cerr << "Mode is required as the first argument when not preprocessing.\n";
  115. return;
  116. } else if (!strcmp(*args, "test")) {
  117. ++args;
  118. online_test(mpcio, num_threads, args);
  119. } else if (!strcmp(*args, "lamporttest")) {
  120. ++args;
  121. lamport_test(mpcio, num_threads, args);
  122. } else {
  123. std::cerr << "Unknown mode " << *args << "\n";
  124. }
  125. }