preproc.cpp 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333
#include <map>
#include <vector>
#include "types.hpp"
#include "coroutine.hpp"
#include "preproc.hpp"
#include "rdpf.hpp"
#include "cdpf.hpp"
  7. // Keep track of open files that coroutines might be writing into
  8. class Openfiles {
  9. std::vector<std::ofstream> files;
  10. public:
  11. class Handle {
  12. Openfiles &parent;
  13. size_t idx;
  14. public:
  15. Handle(Openfiles &parent, size_t idx) :
  16. parent(parent), idx(idx) {}
  17. // Retrieve the ofstream from this Handle
  18. std::ofstream &os() const { return parent.files[idx]; }
  19. };
  20. Handle open(const char *prefix, unsigned player,
  21. unsigned thread_num, nbits_t depth = 0);
  22. void closeall();
  23. };
  24. // Open a file for writing with name the given prefix, and ".pX.tY"
  25. // suffix, where X is the (one-digit) player number and Y is the thread
  26. // number. If depth D is given, use "D.pX.tY" as the suffix.
  27. Openfiles::Handle Openfiles::open(const char *prefix, unsigned player,
  28. unsigned thread_num, nbits_t depth)
  29. {
  30. std::string filename(prefix);
  31. char suffix[20];
  32. if (depth > 0) {
  33. sprintf(suffix, "%02d.p%d.t%u", depth, player%10, thread_num);
  34. } else {
  35. sprintf(suffix, ".p%d.t%u", player%10, thread_num);
  36. }
  37. filename.append(suffix);
  38. std::ofstream &f = files.emplace_back(filename);
  39. if (f.fail()) {
  40. std::cerr << "Failed to open " << filename << "\n";
  41. exit(1);
  42. }
  43. return Handle(*this, files.size()-1);
  44. }
  45. // Close all the open files
  46. void Openfiles::closeall()
  47. {
  48. for (auto& f: files) {
  49. f.close();
  50. }
  51. files.clear();
  52. }
  53. // The server-to-computational-peer protocol for sending precomputed
  54. // data is:
  55. //
  56. // One byte: type
  57. // 0x80: Multiplication triple
  58. // 0x81: Multiplication half-triple
  59. // 0x01 to 0x30: RAM DPF of that depth
  60. // 0x40: Comparison DPF
  61. // 0x82: Counter (for testing)
  62. // 0x00: End of preprocessing
  63. //
  64. // Four bytes: number of objects of that type (not sent for type == 0x00)
  65. //
  66. // Then that number of objects
  67. //
  68. // Repeat the whole thing until type == 0x00 is received
  69. void preprocessing_comp(MPCIO &mpcio, const PRACOptions &opts, char **args)
  70. {
  71. int num_threads = opts.num_threads;
  72. boost::asio::thread_pool pool(num_threads);
  73. for (int thread_num = 0; thread_num < num_threads; ++thread_num) {
  74. boost::asio::post(pool, [&mpcio, &opts, thread_num] {
  75. MPCTIO tio(mpcio, thread_num);
  76. Openfiles ofiles;
  77. std::vector<coro_t> coroutines;
  78. while(1) {
  79. unsigned char type = 0;
  80. unsigned int num = 0;
  81. size_t res = tio.recv_server(&type, 1);
  82. if (res < 1 || type == 0) break;
  83. tio.recv_server(&num, 4);
  84. if (type == 0x80) {
  85. // Multiplication triples
  86. auto tripfile = ofiles.open("triples",
  87. mpcio.player, thread_num);
  88. for (unsigned int i=0; i<num; ++i) {
  89. coroutines.emplace_back(
  90. [&tio, tripfile](yield_t &yield) {
  91. yield();
  92. MultTriple T = tio.triple(yield);
  93. tripfile.os() << T;
  94. });
  95. }
  96. } else if (type == 0x81) {
  97. // Multiplication half triples
  98. auto halffile = ofiles.open("halves",
  99. mpcio.player, thread_num);
  100. for (unsigned int i=0; i<num; ++i) {
  101. coroutines.emplace_back(
  102. [&tio, halffile](yield_t &yield) {
  103. yield();
  104. HalfTriple H = tio.halftriple(yield);
  105. halffile.os() << H;
  106. });
  107. }
  108. } else if (type >= 0x01 && type <= 0x30) {
  109. // RAM DPFs
  110. auto tripfile = ofiles.open("rdpf",
  111. mpcio.player, thread_num, type);
  112. for (unsigned int i=0; i<num; ++i) {
  113. coroutines.emplace_back(
  114. [&tio, &opts, tripfile, type](yield_t &yield) {
  115. yield();
  116. RDPFTriple rdpftrip =
  117. tio.rdpftriple(yield, type, opts.expand_rdpfs);
  118. printf("dep = %d\n", type);
  119. printf("usi0 = %016lx\n", rdpftrip.dpf[0].unit_sum_inverse);
  120. printf("sxr0 = %016lx\n", rdpftrip.dpf[0].scaled_xor.xshare);
  121. printf("usi1 = %016lx\n", rdpftrip.dpf[1].unit_sum_inverse);
  122. printf("sxr1 = %016lx\n", rdpftrip.dpf[1].scaled_xor.xshare);
  123. printf("usi2 = %016lx\n", rdpftrip.dpf[2].unit_sum_inverse);
  124. printf("sxr2 = %016lx\n", rdpftrip.dpf[2].scaled_xor.xshare);
  125. tripfile.os() << rdpftrip;
  126. });
  127. }
  128. } else if (type == 0x40) {
  129. // Comparison DPFs
  130. auto cdpffile = ofiles.open("cdpf",
  131. mpcio.player, thread_num);
  132. for (unsigned int i=0; i<num; ++i) {
  133. coroutines.emplace_back(
  134. [&tio, cdpffile](yield_t &yield) {
  135. yield();
  136. CDPF C = tio.cdpf(yield);
  137. cdpffile.os() << C;
  138. });
  139. }
  140. } else if (type == 0x82) {
  141. coroutines.emplace_back(
  142. [&tio, num](yield_t &yield) {
  143. yield();
  144. unsigned int istart = 0x31415080;
  145. for (unsigned int i=istart; i<istart+num; ++i) {
  146. tio.queue_peer(&i, sizeof(i));
  147. tio.queue_server(&i, sizeof(i));
  148. yield();
  149. unsigned int peeri, srvi;
  150. tio.recv_peer(&peeri, sizeof(peeri));
  151. tio.recv_server(&srvi, sizeof(srvi));
  152. if (peeri != i || srvi != i) {
  153. printf("Incorrect counter received: "
  154. "peer=%08x srv=%08x\n", peeri,
  155. srvi);
  156. }
  157. }
  158. });
  159. }
  160. }
  161. run_coroutines(tio, coroutines);
  162. ofiles.closeall();
  163. });
  164. }
  165. pool.join();
  166. }
  167. void preprocessing_server(MPCServerIO &mpcsrvio, const PRACOptions &opts, char **args)
  168. {
  169. int num_threads = opts.num_threads;
  170. boost::asio::thread_pool pool(num_threads);
  171. for (int thread_num = 0; thread_num < num_threads; ++thread_num) {
  172. boost::asio::post(pool, [&mpcsrvio, &opts, thread_num, args] {
  173. char **threadargs = args;
  174. MPCTIO stio(mpcsrvio, thread_num);
  175. Openfiles ofiles;
  176. std::vector<coro_t> coroutines;
  177. if (*threadargs && threadargs[0][0] == 'T') {
  178. // Per-thread initialization. The args look like:
  179. // T0 t:50 h:10 T1 t:20 h:30 T2 h:20
  180. // Skip to the arg marking our thread
  181. char us[20];
  182. sprintf(us, "T%u", thread_num);
  183. while (*threadargs && strcmp(*threadargs, us)) {
  184. ++threadargs;
  185. }
  186. // Now skip to the next arg if there is one
  187. if (*threadargs) {
  188. ++threadargs;
  189. }
  190. }
  191. // Stop scanning for args when we get to the end or when we
  192. // get to another per-thread initialization marker
  193. while (*threadargs && threadargs[0][0] != 'T') {
  194. char *arg = strdup(*threadargs);
  195. char *colon = strchr(arg, ':');
  196. if (!colon) {
  197. std::cerr << "Args must be type:num\n";
  198. ++threadargs;
  199. free(arg);
  200. continue;
  201. }
  202. unsigned num = atoi(colon+1);
  203. *colon = '\0';
  204. char *type = arg;
  205. if (!strcmp(type, "t")) {
  206. unsigned char typetag = 0x80;
  207. stio.queue_p0(&typetag, 1);
  208. stio.queue_p0(&num, 4);
  209. stio.queue_p1(&typetag, 1);
  210. stio.queue_p1(&num, 4);
  211. for (unsigned int i=0; i<num; ++i) {
  212. coroutines.emplace_back(
  213. [&stio](yield_t &yield) {
  214. yield();
  215. stio.triple(yield);
  216. });
  217. }
  218. } else if (!strcmp(type, "h")) {
  219. unsigned char typetag = 0x81;
  220. stio.queue_p0(&typetag, 1);
  221. stio.queue_p0(&num, 4);
  222. stio.queue_p1(&typetag, 1);
  223. stio.queue_p1(&num, 4);
  224. for (unsigned int i=0; i<num; ++i) {
  225. coroutines.emplace_back(
  226. [&stio](yield_t &yield) {
  227. yield();
  228. stio.halftriple(yield);
  229. });
  230. }
  231. } else if (type[0] == 'r') {
  232. int depth = atoi(type+1);
  233. if (depth < 1 || depth > 48) {
  234. std::cerr << "Invalid DPF depth\n";
  235. } else {
  236. unsigned char typetag = depth;
  237. stio.queue_p0(&typetag, 1);
  238. stio.queue_p0(&num, 4);
  239. stio.queue_p1(&typetag, 1);
  240. stio.queue_p1(&num, 4);
  241. auto pairfile = ofiles.open("rdpf",
  242. mpcsrvio.player, thread_num, depth);
  243. for (unsigned int i=0; i<num; ++i) {
  244. coroutines.emplace_back(
  245. [&stio, &opts, pairfile, depth](yield_t &yield) {
  246. yield();
  247. RDPFPair rdpfpair = stio.rdpfpair(yield, depth);
  248. printf("usi0 = %016lx\n", rdpfpair.dpf[0].unit_sum_inverse);
  249. printf("sxr0 = %016lx\n", rdpfpair.dpf[0].scaled_xor.xshare);
  250. printf("dep0 = %d\n", rdpfpair.dpf[0].depth());
  251. printf("usi1 = %016lx\n", rdpfpair.dpf[1].unit_sum_inverse);
  252. printf("sxr1 = %016lx\n", rdpfpair.dpf[1].scaled_xor.xshare);
  253. printf("dep1 = %d\n", rdpfpair.dpf[1].depth());
  254. if (opts.expand_rdpfs) {
  255. rdpfpair.dpf[0].expand(stio.aes_ops());
  256. rdpfpair.dpf[1].expand(stio.aes_ops());
  257. }
  258. pairfile.os() << rdpfpair;
  259. });
  260. }
  261. }
  262. } else if (!strcmp(type, "c")) {
  263. unsigned char typetag = 0x40;
  264. stio.queue_p0(&typetag, 1);
  265. stio.queue_p0(&num, 4);
  266. stio.queue_p1(&typetag, 1);
  267. stio.queue_p1(&num, 4);
  268. for (unsigned int i=0; i<num; ++i) {
  269. coroutines.emplace_back(
  270. [&stio](yield_t &yield) {
  271. yield();
  272. stio.cdpf(yield);
  273. });
  274. }
  275. } else if (!strcmp(type, "i")) {
  276. unsigned char typetag = 0x82;
  277. stio.queue_p0(&typetag, 1);
  278. stio.queue_p0(&num, 4);
  279. stio.queue_p1(&typetag, 1);
  280. stio.queue_p1(&num, 4);
  281. coroutines.emplace_back(
  282. [&stio, num] (yield_t &yield) {
  283. unsigned int istart = 0x31415080;
  284. yield();
  285. for (unsigned int i=istart; i<istart+num; ++i) {
  286. stio.queue_p0(&i, sizeof(i));
  287. stio.queue_p1(&i, sizeof(i));
  288. yield();
  289. unsigned int p0i, p1i;
  290. stio.recv_p0(&p0i, sizeof(p0i));
  291. stio.recv_p1(&p1i, sizeof(p1i));
  292. if (p0i != i || p1i != i) {
  293. printf("Incorrect counter received: "
  294. "p0=%08x p1=%08x\n", p0i,
  295. p1i);
  296. }
  297. }
  298. });
  299. }
  300. free(arg);
  301. ++threadargs;
  302. }
  303. // That's all
  304. unsigned char typetag = 0x00;
  305. stio.queue_p0(&typetag, 1);
  306. stio.queue_p1(&typetag, 1);
  307. run_coroutines(stio, coroutines);
  308. ofiles.closeall();
  309. });
  310. }
  311. pool.join();
  312. }