preproc.cpp 7.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196
  1. #include <vector>
  2. #include "types.hpp"
  3. #include "coroutine.hpp"
  4. #include "preproc.hpp"
  5. #include "rdpf.hpp"
  // Open a file for writing whose name is the given prefix followed by a
  // ".pX.tY" suffix, where X is the (one-digit, i.e. player % 10) player
  // number and Y is the thread number.  Any existing file of that name is
  // truncated.  The stream is opened in binary mode because callers write
  // raw object bytes into it; in text mode some platforms would translate
  // newline bytes and corrupt the data.  If the file cannot be opened, an
  // error is printed to stderr and the process exits with status 1.
  static std::ofstream openfile(const char *prefix, unsigned player,
      unsigned thread_num)
  {
      std::string filename(prefix);
      filename += ".p";
      filename += std::to_string(player % 10);
      filename += ".t";
      filename += std::to_string(thread_num);
      std::ofstream f(filename, std::ios::out | std::ios::binary);
      if (f.fail()) {
          std::cerr << "Failed to open " << filename << "\n";
          exit(1);
      }
      return f;
  }
  24. // The server-to-computational-peer protocol for sending precomputed
  25. // data is:
  26. //
  27. // One byte: type
  28. // 0x80: Multiplication triple
  29. // 0x81: Multiplication half-triple
  30. // 0x01 to 0x30: RAM DPF of that depth
  31. // 0x40: Comparison DPF
  32. // 0x00: End of preprocessing
  33. //
  34. // Four bytes: number of objects of that type (not sent for type == 0x00)
  35. //
  36. // Then that number of objects
  37. //
  38. // Repeat the whole thing until type == 0x00 is received
  39. void preprocessing_comp(MPCIO &mpcio, int num_threads, char **args)
  40. {
  41. boost::asio::thread_pool pool(num_threads);
  42. for (int thread_num = 0; thread_num < num_threads; ++thread_num) {
  43. boost::asio::post(pool, [&mpcio, thread_num] {
  44. MPCTIO tio(mpcio, thread_num);
  45. std::vector<coro_t> coroutines;
  46. while(1) {
  47. unsigned char type = 0;
  48. unsigned int num = 0;
  49. size_t res = tio.recv_server(&type, 1);
  50. if (res < 1 || type == 0) break;
  51. tio.recv_server(&num, 4);
  52. if (type == 0x80) {
  53. // Multiplication triples
  54. std::ofstream tripfile = openfile("triples",
  55. mpcio.player, thread_num);
  56. MultTriple T;
  57. for (unsigned int i=0; i<num; ++i) {
  58. T = tio.triple();
  59. tripfile.write((const char *)&T, sizeof(T));
  60. }
  61. tripfile.close();
  62. } else if (type == 0x81) {
  63. // Multiplication half triples
  64. std::ofstream halffile = openfile("halves",
  65. mpcio.player, thread_num);
  66. HalfTriple H;
  67. for (unsigned int i=0; i<num; ++i) {
  68. H = tio.halftriple();
  69. halffile.write((const char *)&H, sizeof(H));
  70. }
  71. halffile.close();
  72. } else if (type >= 0x01 && type <= 0x30) {
  73. // RAM DPFs
  74. for (unsigned int i=0; i<num; ++i) {
  75. coroutines.emplace_back(
  76. [&](yield_t &yield) {
  77. RDPFTriple rdpftrip(tio, yield, type);
  78. printf("usi0 = %016lx\n", rdpftrip.dpf[0].unit_sum_inverse);
  79. printf("sxr0 = %016lx\n", rdpftrip.dpf[0].scaled_xor.xshare);
  80. printf("usi1 = %016lx\n", rdpftrip.dpf[1].unit_sum_inverse);
  81. printf("sxr1 = %016lx\n", rdpftrip.dpf[1].scaled_xor.xshare);
  82. printf("usi2 = %016lx\n", rdpftrip.dpf[2].unit_sum_inverse);
  83. printf("sxr2 = %016lx\n", rdpftrip.dpf[2].scaled_xor.xshare);
  84. tio.iostream_server() <<
  85. rdpftrip.dpf[(mpcio.player == 0) ? 1 : 2];
  86. });
  87. }
  88. }
  89. }
  90. run_coroutines(tio, coroutines);
  91. });
  92. }
  93. pool.join();
  94. }
  95. void preprocessing_server(MPCServerIO &mpcsrvio, int num_threads, char **args)
  96. {
  97. boost::asio::thread_pool pool(num_threads);
  98. for (int thread_num = 0; thread_num < num_threads; ++thread_num) {
  99. boost::asio::post(pool, [&mpcsrvio, thread_num, args] {
  100. char **threadargs = args;
  101. MPCTIO stio(mpcsrvio, thread_num);
  102. std::vector<coro_t> coroutines;
  103. if (*threadargs && threadargs[0][0] == 'T') {
  104. // Per-thread initialization. The args look like:
  105. // T0 t:50 h:10 T1 t:20 h:30 T2 h:20
  106. // Skip to the arg marking our thread
  107. char us[20];
  108. sprintf(us, "T%u", thread_num);
  109. while (*threadargs && strcmp(*threadargs, us)) {
  110. ++threadargs;
  111. }
  112. // Now skip to the next arg if there is one
  113. if (*threadargs) {
  114. ++threadargs;
  115. }
  116. }
  117. // Stop scanning for args when we get to the end or when we
  118. // get to another per-thread initialization marker
  119. while (*threadargs && threadargs[0][0] != 'T') {
  120. char *arg = strdup(*threadargs);
  121. char *colon = strchr(arg, ':');
  122. if (!colon) {
  123. std::cerr << "Args must be type:num\n";
  124. ++threadargs;
  125. free(arg);
  126. continue;
  127. }
  128. unsigned num = atoi(colon+1);
  129. *colon = '\0';
  130. char *type = arg;
  131. if (!strcmp(type, "t")) {
  132. unsigned char typetag = 0x80;
  133. stio.queue_p0(&typetag, 1);
  134. stio.queue_p0(&num, 4);
  135. stio.queue_p1(&typetag, 1);
  136. stio.queue_p1(&num, 4);
  137. for (unsigned int i=0; i<num; ++i) {
  138. stio.triple();
  139. }
  140. } else if (!strcmp(type, "h")) {
  141. unsigned char typetag = 0x81;
  142. stio.queue_p0(&typetag, 1);
  143. stio.queue_p0(&num, 4);
  144. stio.queue_p1(&typetag, 1);
  145. stio.queue_p1(&num, 4);
  146. for (unsigned int i=0; i<num; ++i) {
  147. stio.halftriple();
  148. }
  149. } else if (type[0] == 'r') {
  150. int depth = atoi(type+1);
  151. if (depth < 1 || depth > 48) {
  152. std::cerr << "Invalid DPF depth\n";
  153. } else {
  154. unsigned char typetag = depth;
  155. stio.queue_p0(&typetag, 1);
  156. stio.queue_p0(&num, 4);
  157. stio.queue_p1(&typetag, 1);
  158. stio.queue_p1(&num, 4);
  159. for (unsigned int i=0; i<num; ++i) {
  160. coroutines.emplace_back(
  161. [&](yield_t &yield) {
  162. RDPFTriple rdpftrip(stio, yield, depth);
  163. RDPFPair rdpfpair;
  164. stio.iostream_p0() >> rdpfpair.dpf[0];
  165. stio.iostream_p1() >> rdpfpair.dpf[1];
  166. printf("usi0 = %016lx\n", rdpfpair.dpf[0].unit_sum_inverse);
  167. printf("sxr0 = %016lx\n", rdpfpair.dpf[0].scaled_xor.xshare);
  168. printf("usi1 = %016lx\n", rdpfpair.dpf[1].unit_sum_inverse);
  169. printf("sxr1 = %016lx\n", rdpfpair.dpf[1].scaled_xor.xshare);
  170. });
  171. }
  172. }
  173. }
  174. free(arg);
  175. ++threadargs;
  176. }
  177. // That's all
  178. unsigned char typetag = 0x00;
  179. stio.queue_p0(&typetag, 1);
  180. stio.queue_p1(&typetag, 1);
  181. run_coroutines(stio, coroutines);
  182. });
  183. }
  184. pool.join();
  185. }