preproc.cpp 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532
#include <cassert>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <fstream>
#include <iostream>
#include <string>
#include <vector>
#include "types.hpp"
#include "coroutine.hpp"
#include "preproc.hpp"
#include "rdpf.hpp"
#include "cdpf.hpp"
  7. // Keep track of open files that coroutines might be writing into
  8. class Openfiles {
  9. bool append_mode;
  10. std::vector<std::ofstream> files;
  11. public:
  12. Openfiles(bool append_mode = false) : append_mode(append_mode) {}
  13. class Handle {
  14. Openfiles &parent;
  15. size_t idx;
  16. public:
  17. Handle(Openfiles &parent, size_t idx) :
  18. parent(parent), idx(idx) {}
  19. // Retrieve the ofstream from this Handle
  20. std::ofstream &os() const { return parent.files[idx]; }
  21. };
  22. Handle open(const char *prefix, unsigned player,
  23. unsigned thread_num, nbits_t depth = 0);
  24. void closeall();
  25. };
  26. // Open a file for writing with name the given prefix, and ".pX.tY"
  27. // suffix, where X is the (one-digit) player number and Y is the thread
  28. // number. If depth D is given, use "D.pX.tY" as the suffix.
  29. Openfiles::Handle Openfiles::open(const char *prefix, unsigned player,
  30. unsigned thread_num, nbits_t depth)
  31. {
  32. std::string filename(prefix);
  33. char suffix[20];
  34. if (depth > 0) {
  35. sprintf(suffix, "%02d.p%d.t%u", depth, player%10, thread_num);
  36. } else {
  37. sprintf(suffix, ".p%d.t%u", player%10, thread_num);
  38. }
  39. filename.append(suffix);
  40. std::ofstream &f = files.emplace_back(filename,
  41. append_mode ? std::ios_base::app : std::ios_base::out);
  42. if (f.fail()) {
  43. std::cerr << "Failed to open " << filename << "\n";
  44. exit(1);
  45. }
  46. return Handle(*this, files.size()-1);
  47. }
  48. // Close all the open files
  49. void Openfiles::closeall()
  50. {
  51. for (auto& f: files) {
  52. f.close();
  53. }
  54. files.clear();
  55. }
  56. // The server-to-computational-peer protocol for sending precomputed
  57. // data is:
  58. //
  59. // One byte: type
  60. // 0x01 to 0x30: RAM DPF of that depth
  61. // 0x40: Comparison DPF
  62. // 0x80: Multiplication triple
  63. // 0x81: Multiplication half-triple
  64. // 0x82: AND triple
  65. // 0x83: Select triple
  66. // 0x8e: Counter (for testing)
  67. // 0x8f: Set number of CPU threads for this communication thread
  68. // 0x00: End of preprocessing
  69. //
  70. // One byte: subtype (not sent for type == 0x00)
  71. // For RAM DPFs, the subtype is the width (0x01 to 0x05), OR'd with
  72. // 0x80 if it is an incremental RDPF
  73. // Otherwise, it is 0x00
  74. //
  75. // Four bytes: number of objects of that type (not sent for type == 0x00)
  76. //
  77. // Then that number of objects
  78. //
  79. // Repeat the whole thing until type == 0x00 is received
  80. void preprocessing_comp(MPCIO &mpcio, const PRACOptions &opts, char **args)
  81. {
  82. int num_threads = opts.num_threads;
  83. boost::asio::thread_pool pool(num_threads);
  84. for (int thread_num = 0; thread_num < num_threads; ++thread_num) {
  85. boost::asio::post(pool, [&mpcio, &opts, thread_num] {
  86. MPCTIO tio(mpcio, thread_num);
  87. Openfiles ofiles(opts.append_to_files);
  88. std::vector<coro_t> coroutines;
  89. while(1) {
  90. unsigned char type = 0;
  91. unsigned char subtype = 0;
  92. unsigned int num = 0;
  93. size_t res = tio.recv_server(&type, 1);
  94. if (res < 1 || type == 0) break;
  95. tio.recv_server(&subtype, 1);
  96. tio.recv_server(&num, 4);
  97. if (type == 0x80) {
  98. // Multiplication triples
  99. auto tripfile = ofiles.open("mults",
  100. mpcio.player, thread_num);
  101. for (unsigned int i=0; i<num; ++i) {
  102. coroutines.emplace_back(
  103. [&tio, tripfile](yield_t &yield) {
  104. yield();
  105. MultTriple T = tio.multtriple(yield);
  106. tripfile.os() << T;
  107. });
  108. }
  109. } else if (type == 0x81) {
  110. // Multiplication half triples
  111. auto halffile = ofiles.open("halves",
  112. mpcio.player, thread_num);
  113. for (unsigned int i=0; i<num; ++i) {
  114. coroutines.emplace_back(
  115. [&tio, halffile](yield_t &yield) {
  116. yield();
  117. HalfTriple H = tio.halftriple(yield);
  118. halffile.os() << H;
  119. });
  120. }
  121. } else if (type == 0x82) {
  122. // AND triples
  123. auto andfile = ofiles.open("ands",
  124. mpcio.player, thread_num);
  125. for (unsigned int i=0; i<num; ++i) {
  126. coroutines.emplace_back(
  127. [&tio, andfile](yield_t &yield) {
  128. yield();
  129. AndTriple A = tio.andtriple(yield);
  130. andfile.os() << A;
  131. });
  132. }
  133. } else if (type == 0x83) {
  134. // Select triples
  135. auto selfile = ofiles.open("selects",
  136. mpcio.player, thread_num);
  137. for (unsigned int i=0; i<num; ++i) {
  138. coroutines.emplace_back(
  139. [&tio, selfile](yield_t &yield) {
  140. yield();
  141. SelectTriple<value_t> S =
  142. tio.valselecttriple(yield);
  143. selfile.os() << S;
  144. });
  145. }
  146. } else if (type >= 0x01 && type <= 0x30) {
  147. // RAM DPFs
  148. bool incremental = false;
  149. if (subtype >= 0x80) {
  150. incremental = true;
  151. subtype -= 0x80;
  152. }
  153. assert(subtype >= 0x01 && subtype <= 0x05);
  154. char prefix[12];
  155. strcpy(prefix, incremental ? "irdpf" : "rdpf");
  156. if (subtype > 1) {
  157. sprintf(prefix+strlen(prefix), "%d_", subtype);
  158. }
  159. auto tripfile = ofiles.open(prefix,
  160. mpcio.player, thread_num, type);
  161. for (unsigned int i=0; i<num; ++i) {
  162. coroutines.emplace_back(
  163. [&tio, &opts, incremental, tripfile, type,
  164. subtype](yield_t &yield) {
  165. yield();
  166. switch(subtype) {
  167. case 1: {
  168. RDPFTriple<1> rdpftrip =
  169. tio.rdpftriple<1>(yield, type,
  170. incremental, opts.expand_rdpfs);
  171. tripfile.os() << rdpftrip;
  172. break;
  173. }
  174. case 2: {
  175. RDPFTriple<2> rdpftrip =
  176. tio.rdpftriple<2>(yield, type,
  177. incremental, opts.expand_rdpfs);
  178. tripfile.os() << rdpftrip;
  179. break;
  180. }
  181. case 3: {
  182. RDPFTriple<3> rdpftrip =
  183. tio.rdpftriple<3>(yield, type,
  184. incremental, opts.expand_rdpfs);
  185. tripfile.os() << rdpftrip;
  186. break;
  187. }
  188. case 4: {
  189. RDPFTriple<4> rdpftrip =
  190. tio.rdpftriple<4>(yield, type,
  191. incremental, opts.expand_rdpfs);
  192. tripfile.os() << rdpftrip;
  193. break;
  194. }
  195. case 5: {
  196. RDPFTriple<5> rdpftrip =
  197. tio.rdpftriple<5>(yield, type,
  198. incremental, opts.expand_rdpfs);
  199. tripfile.os() << rdpftrip;
  200. break;
  201. }
  202. }
  203. });
  204. }
  205. } else if (type == 0x40) {
  206. // Comparison DPFs
  207. auto cdpffile = ofiles.open("cdpf",
  208. mpcio.player, thread_num);
  209. for (unsigned int i=0; i<num; ++i) {
  210. coroutines.emplace_back(
  211. [&tio, cdpffile](yield_t &yield) {
  212. yield();
  213. CDPF C = tio.cdpf(yield);
  214. cdpffile.os() << C;
  215. });
  216. }
  217. } else if (type == 0x8e) {
  218. coroutines.emplace_back(
  219. [&tio, num](yield_t &yield) {
  220. yield();
  221. unsigned int istart = 0x31415080;
  222. for (unsigned int i=istart; i<istart+num; ++i) {
  223. tio.queue_peer(&i, sizeof(i));
  224. tio.queue_server(&i, sizeof(i));
  225. yield();
  226. unsigned int peeri, srvi;
  227. tio.recv_peer(&peeri, sizeof(peeri));
  228. tio.recv_server(&srvi, sizeof(srvi));
  229. if (peeri != i || srvi != i) {
  230. printf("Incorrect counter received: "
  231. "peer=%08x srv=%08x\n", peeri,
  232. srvi);
  233. }
  234. }
  235. });
  236. } else if (type == 0x8f) {
  237. tio.cpu_nthreads(num);
  238. }
  239. }
  240. run_coroutines(tio, coroutines);
  241. ofiles.closeall();
  242. });
  243. }
  244. pool.join();
  245. }
  246. void preprocessing_server(MPCServerIO &mpcsrvio, const PRACOptions &opts, char **args)
  247. {
  248. int num_threads = opts.num_threads;
  249. boost::asio::thread_pool pool(num_threads);
  250. for (int thread_num = 0; thread_num < num_threads; ++thread_num) {
  251. boost::asio::post(pool, [&mpcsrvio, &opts, thread_num, args] {
  252. char **threadargs = args;
  253. MPCTIO stio(mpcsrvio, thread_num);
  254. Openfiles ofiles(opts.append_to_files);
  255. std::vector<coro_t> coroutines;
  256. if (*threadargs && threadargs[0][0] == 'T') {
  257. // Per-thread initialization. The args look like:
  258. // T0 t:50 h:10 T1 t:20 h:30 T2 h:20
  259. // Skip to the arg marking our thread
  260. char us[20];
  261. sprintf(us, "T%u", thread_num);
  262. while (*threadargs && strcmp(*threadargs, us)) {
  263. ++threadargs;
  264. }
  265. // Now skip to the next arg if there is one
  266. if (*threadargs) {
  267. ++threadargs;
  268. }
  269. }
  270. // Stop scanning for args when we get to the end or when we
  271. // get to another per-thread initialization marker
  272. while (*threadargs && threadargs[0][0] != 'T') {
  273. char *arg = strdup(*threadargs);
  274. char *colon = strchr(arg, ':');
  275. if (!colon) {
  276. std::cerr << "Args must be type:num\n";
  277. ++threadargs;
  278. free(arg);
  279. continue;
  280. }
  281. unsigned num = atoi(colon+1);
  282. *colon = '\0';
  283. char *type = arg;
  284. if (!strcmp(type, "m")) {
  285. unsigned char typetag = 0x80;
  286. unsigned char subtypetag = 0x00;
  287. stio.queue_p0(&typetag, 1);
  288. stio.queue_p0(&subtypetag, 1);
  289. stio.queue_p0(&num, 4);
  290. stio.queue_p1(&typetag, 1);
  291. stio.queue_p1(&subtypetag, 1);
  292. stio.queue_p1(&num, 4);
  293. for (unsigned int i=0; i<num; ++i) {
  294. coroutines.emplace_back(
  295. [&stio](yield_t &yield) {
  296. yield();
  297. stio.multtriple(yield);
  298. });
  299. }
  300. } else if (!strcmp(type, "h")) {
  301. unsigned char typetag = 0x81;
  302. unsigned char subtypetag = 0x00;
  303. stio.queue_p0(&typetag, 1);
  304. stio.queue_p0(&subtypetag, 1);
  305. stio.queue_p0(&num, 4);
  306. stio.queue_p1(&typetag, 1);
  307. stio.queue_p1(&subtypetag, 1);
  308. stio.queue_p1(&num, 4);
  309. for (unsigned int i=0; i<num; ++i) {
  310. coroutines.emplace_back(
  311. [&stio](yield_t &yield) {
  312. yield();
  313. stio.halftriple(yield);
  314. });
  315. }
  316. } else if (!strcmp(type, "a")) {
  317. unsigned char typetag = 0x82;
  318. unsigned char subtypetag = 0x00;
  319. stio.queue_p0(&typetag, 1);
  320. stio.queue_p0(&subtypetag, 1);
  321. stio.queue_p0(&num, 4);
  322. stio.queue_p1(&typetag, 1);
  323. stio.queue_p1(&subtypetag, 1);
  324. stio.queue_p1(&num, 4);
  325. for (unsigned int i=0; i<num; ++i) {
  326. coroutines.emplace_back(
  327. [&stio](yield_t &yield) {
  328. yield();
  329. stio.andtriple(yield);
  330. });
  331. }
  332. } else if (!strcmp(type, "s")) {
  333. unsigned char typetag = 0x83;
  334. unsigned char subtypetag = 0x00;
  335. stio.queue_p0(&typetag, 1);
  336. stio.queue_p0(&subtypetag, 1);
  337. stio.queue_p0(&num, 4);
  338. stio.queue_p1(&typetag, 1);
  339. stio.queue_p1(&subtypetag, 1);
  340. stio.queue_p1(&num, 4);
  341. for (unsigned int i=0; i<num; ++i) {
  342. coroutines.emplace_back(
  343. [&stio](yield_t &yield) {
  344. yield();
  345. stio.valselecttriple(yield);
  346. });
  347. }
  348. } else if (type[0] == 'r' || type[0] == 'i') {
  349. bool incremental = (type[0] == 'i');
  350. char *widthstr = strchr(type, '.');
  351. unsigned char width = 1;
  352. if (widthstr) {
  353. *widthstr = '\0';
  354. ++widthstr;
  355. width = atoi(widthstr);
  356. }
  357. int depth = atoi(type+1);
  358. if (depth < 1 || depth > 48) {
  359. std::cerr << "Invalid DPF depth\n";
  360. } else {
  361. unsigned char typetag = depth;
  362. unsigned char subtypetag = width;
  363. if (incremental) {
  364. subtypetag += 0x80;
  365. }
  366. stio.queue_p0(&typetag, 1);
  367. stio.queue_p0(&subtypetag, 1);
  368. stio.queue_p0(&num, 4);
  369. stio.queue_p1(&typetag, 1);
  370. stio.queue_p1(&subtypetag, 1);
  371. stio.queue_p1(&num, 4);
  372. char prefix[12];
  373. strcpy(prefix, incremental ? "irdpf" : "rdpf");
  374. if (width > 1) {
  375. sprintf(prefix+strlen(prefix), "%d_", width);
  376. }
  377. auto pairfile = ofiles.open(prefix,
  378. mpcsrvio.player, thread_num, depth);
  379. for (unsigned int i=0; i<num; ++i) {
  380. coroutines.emplace_back(
  381. [&stio, &opts, pairfile, depth,
  382. incremental, width](yield_t &yield) {
  383. yield();
  384. switch (width) {
  385. case 1: {
  386. RDPFPair<1> rdpfpair =
  387. stio.rdpfpair<1>(yield, depth, incremental);
  388. if (opts.expand_rdpfs) {
  389. rdpfpair.dpf[0].expand(stio.aes_ops());
  390. rdpfpair.dpf[1].expand(stio.aes_ops());
  391. }
  392. pairfile.os() << rdpfpair;
  393. break;
  394. }
  395. case 2: {
  396. RDPFPair<2> rdpfpair =
  397. stio.rdpfpair<2>(yield, depth, incremental);
  398. if (opts.expand_rdpfs) {
  399. rdpfpair.dpf[0].expand(stio.aes_ops());
  400. rdpfpair.dpf[1].expand(stio.aes_ops());
  401. }
  402. pairfile.os() << rdpfpair;
  403. break;
  404. }
  405. case 3: {
  406. RDPFPair<3> rdpfpair =
  407. stio.rdpfpair<3>(yield, depth, incremental);
  408. if (opts.expand_rdpfs) {
  409. rdpfpair.dpf[0].expand(stio.aes_ops());
  410. rdpfpair.dpf[1].expand(stio.aes_ops());
  411. }
  412. pairfile.os() << rdpfpair;
  413. break;
  414. }
  415. case 4: {
  416. RDPFPair<4> rdpfpair =
  417. stio.rdpfpair<4>(yield, depth, incremental);
  418. if (opts.expand_rdpfs) {
  419. rdpfpair.dpf[0].expand(stio.aes_ops());
  420. rdpfpair.dpf[1].expand(stio.aes_ops());
  421. }
  422. pairfile.os() << rdpfpair;
  423. break;
  424. }
  425. case 5: {
  426. RDPFPair<5> rdpfpair =
  427. stio.rdpfpair<5>(yield, depth, incremental);
  428. if (opts.expand_rdpfs) {
  429. rdpfpair.dpf[0].expand(stio.aes_ops());
  430. rdpfpair.dpf[1].expand(stio.aes_ops());
  431. }
  432. pairfile.os() << rdpfpair;
  433. break;
  434. }
  435. }
  436. });
  437. }
  438. }
  439. } else if (!strcmp(type, "c")) {
  440. unsigned char typetag = 0x40;
  441. unsigned char subtypetag = 0x00;
  442. stio.queue_p0(&typetag, 1);
  443. stio.queue_p0(&subtypetag, 1);
  444. stio.queue_p0(&num, 4);
  445. stio.queue_p1(&typetag, 1);
  446. stio.queue_p1(&subtypetag, 1);
  447. stio.queue_p1(&num, 4);
  448. for (unsigned int i=0; i<num; ++i) {
  449. coroutines.emplace_back(
  450. [&stio](yield_t &yield) {
  451. yield();
  452. stio.cdpf(yield);
  453. });
  454. }
  455. } else if (!strcmp(type, "i")) {
  456. unsigned char typetag = 0x8e;
  457. unsigned char subtypetag = 0x00;
  458. stio.queue_p0(&typetag, 1);
  459. stio.queue_p0(&subtypetag, 1);
  460. stio.queue_p0(&num, 4);
  461. stio.queue_p1(&typetag, 1);
  462. stio.queue_p1(&subtypetag, 1);
  463. stio.queue_p1(&num, 4);
  464. coroutines.emplace_back(
  465. [&stio, num] (yield_t &yield) {
  466. unsigned int istart = 0x31415080;
  467. yield();
  468. for (unsigned int i=istart; i<istart+num; ++i) {
  469. stio.queue_p0(&i, sizeof(i));
  470. stio.queue_p1(&i, sizeof(i));
  471. yield();
  472. unsigned int p0i, p1i;
  473. stio.recv_p0(&p0i, sizeof(p0i));
  474. stio.recv_p1(&p1i, sizeof(p1i));
  475. if (p0i != i || p1i != i) {
  476. printf("Incorrect counter received: "
  477. "p0=%08x p1=%08x\n", p0i,
  478. p1i);
  479. }
  480. }
  481. });
  482. } else if (!strcmp(type, "p")) {
  483. unsigned char typetag = 0x8f;
  484. unsigned char subtypetag = 0x00;
  485. stio.queue_p0(&typetag, 1);
  486. stio.queue_p0(&subtypetag, 1);
  487. stio.queue_p0(&num, 4);
  488. stio.queue_p1(&typetag, 1);
  489. stio.queue_p1(&subtypetag, 1);
  490. stio.queue_p1(&num, 4);
  491. stio.cpu_nthreads(num);
  492. }
  493. free(arg);
  494. ++threadargs;
  495. }
  496. // That's all
  497. unsigned char typetag = 0x00;
  498. stio.queue_p0(&typetag, 1);
  499. stio.queue_p1(&typetag, 1);
  500. run_coroutines(stio, coroutines);
  501. ofiles.closeall();
  502. });
  503. }
  504. pool.join();
  505. }