preproc.cpp 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518
  1. #include <vector>
  2. #include "types.hpp"
  3. #include "coroutine.hpp"
  4. #include "preproc.hpp"
  5. #include "rdpf.hpp"
  6. #include "cdpf.hpp"
  7. // Keep track of open files that coroutines might be writing into
  8. class Openfiles {
  9. bool append_mode;
  10. std::vector<std::ofstream> files;
  11. public:
  12. Openfiles(bool append_mode = false) : append_mode(append_mode) {}
  13. class Handle {
  14. Openfiles &parent;
  15. size_t idx;
  16. public:
  17. Handle(Openfiles &parent, size_t idx) :
  18. parent(parent), idx(idx) {}
  19. // Retrieve the ofstream from this Handle
  20. std::ofstream &os() const { return parent.files[idx]; }
  21. };
  22. Handle open(const char *prefix, unsigned player,
  23. unsigned thread_num, nbits_t depth = 0);
  24. void closeall();
  25. };
  26. // Open a file for writing with name the given prefix, and ".pX.tY"
  27. // suffix, where X is the (one-digit) player number and Y is the thread
  28. // number. If depth D is given, use "D.pX.tY" as the suffix.
  29. Openfiles::Handle Openfiles::open(const char *prefix, unsigned player,
  30. unsigned thread_num, nbits_t depth)
  31. {
  32. std::string filename(prefix);
  33. char suffix[20];
  34. if (depth > 0) {
  35. sprintf(suffix, "%02d.p%d.t%u", depth, player%10, thread_num);
  36. } else {
  37. sprintf(suffix, ".p%d.t%u", player%10, thread_num);
  38. }
  39. filename.append(suffix);
  40. std::ofstream &f = files.emplace_back(filename,
  41. append_mode ? std::ios_base::app : std::ios_base::out);
  42. if (f.fail()) {
  43. std::cerr << "Failed to open " << filename << "\n";
  44. exit(1);
  45. }
  46. return Handle(*this, files.size()-1);
  47. }
  48. // Close all the open files
  49. void Openfiles::closeall()
  50. {
  51. for (auto& f: files) {
  52. f.close();
  53. }
  54. files.clear();
  55. }
  56. // The server-to-computational-peer protocol for sending precomputed
  57. // data is:
  58. //
  59. // One byte: type
  60. // 0x01 to 0x30: RAM DPF of that depth
  61. // 0x40: Comparison DPF
  62. // 0x80: Multiplication triple
  63. // 0x81: Multiplication half-triple
  64. // 0x82: AND triple
  65. // 0x83: Select triple
  66. // 0x8e: Counter (for testing)
  67. // 0x00: End of preprocessing
  68. //
  69. // One byte: subtype (not sent for type == 0x00)
  70. // For RAM DPFs, the subtype is the width (0x01 to 0x05), OR'd with
  71. // 0x80 if it is an incremental RDPF
  72. // Otherwise, it is 0x00
  73. //
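// Four bytes: number of objects of that type, sent as a raw unsigned
// int in the sender's native byte order (not sent for type == 0x00)
//
// The objects themselves are not sent inline after the header: for
// each header, the peer creates that number of coroutines, and those
// coroutines are run (generating the objects) only after the end
// marker has been received
//
// Repeat the header until the end marker (type == 0x00) is received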
  79. void preprocessing_comp(MPCIO &mpcio, const PRACOptions &opts, char **args)
  80. {
  81. int num_threads = opts.num_comm_threads;
  82. boost::asio::thread_pool pool(num_threads);
  83. for (int thread_num = 0; thread_num < num_threads; ++thread_num) {
  84. boost::asio::post(pool, [&mpcio, &opts, thread_num] {
  85. MPCTIO tio(mpcio, thread_num, opts.num_cpu_threads);
  86. Openfiles ofiles(opts.append_to_files);
  87. std::vector<coro_t> coroutines;
  88. while(1) {
  89. unsigned char type = 0;
  90. unsigned char subtype = 0;
  91. unsigned int num = 0;
  92. size_t res = tio.recv_server(&type, 1);
  93. if (res < 1 || type == 0) break;
  94. tio.recv_server(&subtype, 1);
  95. tio.recv_server(&num, 4);
  96. if (type == 0x80) {
  97. // Multiplication triples
  98. auto tripfile = ofiles.open("mults",
  99. mpcio.player, thread_num);
  100. for (unsigned int i=0; i<num; ++i) {
  101. coroutines.emplace_back(
  102. [&tio, tripfile](yield_t &yield) {
  103. yield();
  104. MultTriple T = tio.multtriple(yield);
  105. tripfile.os() << T;
  106. });
  107. }
  108. } else if (type == 0x81) {
  109. // Multiplication half triples
  110. auto halffile = ofiles.open("halves",
  111. mpcio.player, thread_num);
  112. for (unsigned int i=0; i<num; ++i) {
  113. coroutines.emplace_back(
  114. [&tio, halffile](yield_t &yield) {
  115. yield();
  116. HalfTriple H = tio.halftriple(yield);
  117. halffile.os() << H;
  118. });
  119. }
  120. } else if (type == 0x82) {
  121. // AND triples
  122. auto andfile = ofiles.open("ands",
  123. mpcio.player, thread_num);
  124. for (unsigned int i=0; i<num; ++i) {
  125. coroutines.emplace_back(
  126. [&tio, andfile](yield_t &yield) {
  127. yield();
  128. AndTriple A = tio.andtriple(yield);
  129. andfile.os() << A;
  130. });
  131. }
  132. } else if (type == 0x83) {
  133. // Select triples
  134. auto selfile = ofiles.open("selects",
  135. mpcio.player, thread_num);
  136. for (unsigned int i=0; i<num; ++i) {
  137. coroutines.emplace_back(
  138. [&tio, selfile](yield_t &yield) {
  139. yield();
  140. SelectTriple<value_t> S =
  141. tio.valselecttriple(yield);
  142. selfile.os() << S;
  143. });
  144. }
  145. } else if (type >= 0x01 && type <= 0x30) {
  146. // RAM DPFs
  147. bool incremental = false;
  148. if (subtype >= 0x80) {
  149. incremental = true;
  150. subtype -= 0x80;
  151. }
  152. assert(subtype >= 0x01 && subtype <= 0x05);
  153. char prefix[12];
  154. strcpy(prefix, incremental ? "irdpf" : "rdpf");
  155. if (subtype > 1) {
  156. sprintf(prefix+strlen(prefix), "%d_", subtype);
  157. }
  158. auto tripfile = ofiles.open(prefix,
  159. mpcio.player, thread_num, type);
  160. for (unsigned int i=0; i<num; ++i) {
  161. coroutines.emplace_back(
  162. [&tio, &opts, incremental, tripfile, type,
  163. subtype](yield_t &yield) {
  164. yield();
  165. switch(subtype) {
  166. case 1: {
  167. RDPFTriple<1> rdpftrip =
  168. tio.rdpftriple<1>(yield, type,
  169. incremental, opts.expand_rdpfs);
  170. tripfile.os() << rdpftrip;
  171. break;
  172. }
  173. case 2: {
  174. RDPFTriple<2> rdpftrip =
  175. tio.rdpftriple<2>(yield, type,
  176. incremental, opts.expand_rdpfs);
  177. tripfile.os() << rdpftrip;
  178. break;
  179. }
  180. case 3: {
  181. RDPFTriple<3> rdpftrip =
  182. tio.rdpftriple<3>(yield, type,
  183. incremental, opts.expand_rdpfs);
  184. tripfile.os() << rdpftrip;
  185. break;
  186. }
  187. case 4: {
  188. RDPFTriple<4> rdpftrip =
  189. tio.rdpftriple<4>(yield, type,
  190. incremental, opts.expand_rdpfs);
  191. tripfile.os() << rdpftrip;
  192. break;
  193. }
  194. case 5: {
  195. RDPFTriple<5> rdpftrip =
  196. tio.rdpftriple<5>(yield, type,
  197. incremental, opts.expand_rdpfs);
  198. tripfile.os() << rdpftrip;
  199. break;
  200. }
  201. }
  202. });
  203. }
  204. } else if (type == 0x40) {
  205. // Comparison DPFs
  206. auto cdpffile = ofiles.open("cdpf",
  207. mpcio.player, thread_num);
  208. for (unsigned int i=0; i<num; ++i) {
  209. coroutines.emplace_back(
  210. [&tio, cdpffile](yield_t &yield) {
  211. yield();
  212. CDPF C = tio.cdpf(yield);
  213. cdpffile.os() << C;
  214. });
  215. }
  216. } else if (type == 0x8e) {
  217. coroutines.emplace_back(
  218. [&tio, num](yield_t &yield) {
  219. yield();
  220. unsigned int istart = 0x31415080;
  221. for (unsigned int i=istart; i<istart+num; ++i) {
  222. tio.queue_peer(&i, sizeof(i));
  223. tio.queue_server(&i, sizeof(i));
  224. yield();
  225. unsigned int peeri, srvi;
  226. tio.recv_peer(&peeri, sizeof(peeri));
  227. tio.recv_server(&srvi, sizeof(srvi));
  228. if (peeri != i || srvi != i) {
  229. printf("Incorrect counter received: "
  230. "peer=%08x srv=%08x\n", peeri,
  231. srvi);
  232. }
  233. }
  234. });
  235. }
  236. }
  237. run_coroutines(tio, coroutines);
  238. ofiles.closeall();
  239. });
  240. }
  241. pool.join();
  242. }
  243. void preprocessing_server(MPCServerIO &mpcsrvio, const PRACOptions &opts, char **args)
  244. {
  245. int num_threads = opts.num_comm_threads;
  246. boost::asio::thread_pool pool(num_threads);
  247. for (int thread_num = 0; thread_num < num_threads; ++thread_num) {
  248. boost::asio::post(pool, [&mpcsrvio, &opts, thread_num, args] {
  249. char **threadargs = args;
  250. MPCTIO stio(mpcsrvio, thread_num, opts.num_cpu_threads);
  251. Openfiles ofiles(opts.append_to_files);
  252. std::vector<coro_t> coroutines;
  253. if (*threadargs && threadargs[0][0] == 'T') {
  254. // Per-thread initialization. The args look like:
  255. // T0 t:50 h:10 T1 t:20 h:30 T2 h:20
  256. // Skip to the arg marking our thread
  257. char us[20];
  258. sprintf(us, "T%u", thread_num);
  259. while (*threadargs && strcmp(*threadargs, us)) {
  260. ++threadargs;
  261. }
  262. // Now skip to the next arg if there is one
  263. if (*threadargs) {
  264. ++threadargs;
  265. }
  266. }
  267. // Stop scanning for args when we get to the end or when we
  268. // get to another per-thread initialization marker
  269. while (*threadargs && threadargs[0][0] != 'T') {
  270. char *arg = strdup(*threadargs);
  271. char *colon = strchr(arg, ':');
  272. if (!colon) {
  273. std::cerr << "Args must be type:num\n";
  274. ++threadargs;
  275. free(arg);
  276. continue;
  277. }
  278. unsigned num = atoi(colon+1);
  279. *colon = '\0';
  280. char *type = arg;
  281. if (!strcmp(type, "m")) {
  282. unsigned char typetag = 0x80;
  283. unsigned char subtypetag = 0x00;
  284. stio.queue_p0(&typetag, 1);
  285. stio.queue_p0(&subtypetag, 1);
  286. stio.queue_p0(&num, 4);
  287. stio.queue_p1(&typetag, 1);
  288. stio.queue_p1(&subtypetag, 1);
  289. stio.queue_p1(&num, 4);
  290. for (unsigned int i=0; i<num; ++i) {
  291. coroutines.emplace_back(
  292. [&stio](yield_t &yield) {
  293. yield();
  294. stio.multtriple(yield);
  295. });
  296. }
  297. } else if (!strcmp(type, "h")) {
  298. unsigned char typetag = 0x81;
  299. unsigned char subtypetag = 0x00;
  300. stio.queue_p0(&typetag, 1);
  301. stio.queue_p0(&subtypetag, 1);
  302. stio.queue_p0(&num, 4);
  303. stio.queue_p1(&typetag, 1);
  304. stio.queue_p1(&subtypetag, 1);
  305. stio.queue_p1(&num, 4);
  306. for (unsigned int i=0; i<num; ++i) {
  307. coroutines.emplace_back(
  308. [&stio](yield_t &yield) {
  309. yield();
  310. stio.halftriple(yield);
  311. });
  312. }
  313. } else if (!strcmp(type, "a")) {
  314. unsigned char typetag = 0x82;
  315. unsigned char subtypetag = 0x00;
  316. stio.queue_p0(&typetag, 1);
  317. stio.queue_p0(&subtypetag, 1);
  318. stio.queue_p0(&num, 4);
  319. stio.queue_p1(&typetag, 1);
  320. stio.queue_p1(&subtypetag, 1);
  321. stio.queue_p1(&num, 4);
  322. for (unsigned int i=0; i<num; ++i) {
  323. coroutines.emplace_back(
  324. [&stio](yield_t &yield) {
  325. yield();
  326. stio.andtriple(yield);
  327. });
  328. }
  329. } else if (!strcmp(type, "s")) {
  330. unsigned char typetag = 0x83;
  331. unsigned char subtypetag = 0x00;
  332. stio.queue_p0(&typetag, 1);
  333. stio.queue_p0(&subtypetag, 1);
  334. stio.queue_p0(&num, 4);
  335. stio.queue_p1(&typetag, 1);
  336. stio.queue_p1(&subtypetag, 1);
  337. stio.queue_p1(&num, 4);
  338. for (unsigned int i=0; i<num; ++i) {
  339. coroutines.emplace_back(
  340. [&stio](yield_t &yield) {
  341. yield();
  342. stio.valselecttriple(yield);
  343. });
  344. }
  345. } else if (type[0] == 'r' || type[0] == 'i') {
  346. bool incremental = (type[0] == 'i');
  347. char *widthstr = strchr(type, '.');
  348. unsigned char width = 1;
  349. if (widthstr) {
  350. *widthstr = '\0';
  351. ++widthstr;
  352. width = atoi(widthstr);
  353. }
  354. int depth = atoi(type+1);
  355. if (depth < 1 || depth > 48) {
  356. std::cerr << "Invalid DPF depth\n";
  357. } else {
  358. unsigned char typetag = depth;
  359. unsigned char subtypetag = width;
  360. if (incremental) {
  361. subtypetag += 0x80;
  362. }
  363. stio.queue_p0(&typetag, 1);
  364. stio.queue_p0(&subtypetag, 1);
  365. stio.queue_p0(&num, 4);
  366. stio.queue_p1(&typetag, 1);
  367. stio.queue_p1(&subtypetag, 1);
  368. stio.queue_p1(&num, 4);
  369. char prefix[12];
  370. strcpy(prefix, incremental ? "irdpf" : "rdpf");
  371. if (width > 1) {
  372. sprintf(prefix+strlen(prefix), "%d_", width);
  373. }
  374. auto pairfile = ofiles.open(prefix,
  375. mpcsrvio.player, thread_num, depth);
  376. for (unsigned int i=0; i<num; ++i) {
  377. coroutines.emplace_back(
  378. [&stio, &opts, pairfile, depth,
  379. incremental, width](yield_t &yield) {
  380. yield();
  381. switch (width) {
  382. case 1: {
  383. RDPFPair<1> rdpfpair =
  384. stio.rdpfpair<1>(yield, depth, incremental);
  385. if (opts.expand_rdpfs) {
  386. rdpfpair.dpf[0].expand(stio.aes_ops());
  387. rdpfpair.dpf[1].expand(stio.aes_ops());
  388. }
  389. pairfile.os() << rdpfpair;
  390. break;
  391. }
  392. case 2: {
  393. RDPFPair<2> rdpfpair =
  394. stio.rdpfpair<2>(yield, depth, incremental);
  395. if (opts.expand_rdpfs) {
  396. rdpfpair.dpf[0].expand(stio.aes_ops());
  397. rdpfpair.dpf[1].expand(stio.aes_ops());
  398. }
  399. pairfile.os() << rdpfpair;
  400. break;
  401. }
  402. case 3: {
  403. RDPFPair<3> rdpfpair =
  404. stio.rdpfpair<3>(yield, depth, incremental);
  405. if (opts.expand_rdpfs) {
  406. rdpfpair.dpf[0].expand(stio.aes_ops());
  407. rdpfpair.dpf[1].expand(stio.aes_ops());
  408. }
  409. pairfile.os() << rdpfpair;
  410. break;
  411. }
  412. case 4: {
  413. RDPFPair<4> rdpfpair =
  414. stio.rdpfpair<4>(yield, depth, incremental);
  415. if (opts.expand_rdpfs) {
  416. rdpfpair.dpf[0].expand(stio.aes_ops());
  417. rdpfpair.dpf[1].expand(stio.aes_ops());
  418. }
  419. pairfile.os() << rdpfpair;
  420. break;
  421. }
  422. case 5: {
  423. RDPFPair<5> rdpfpair =
  424. stio.rdpfpair<5>(yield, depth, incremental);
  425. if (opts.expand_rdpfs) {
  426. rdpfpair.dpf[0].expand(stio.aes_ops());
  427. rdpfpair.dpf[1].expand(stio.aes_ops());
  428. }
  429. pairfile.os() << rdpfpair;
  430. break;
  431. }
  432. }
  433. });
  434. }
  435. }
  436. } else if (!strcmp(type, "c")) {
  437. unsigned char typetag = 0x40;
  438. unsigned char subtypetag = 0x00;
  439. stio.queue_p0(&typetag, 1);
  440. stio.queue_p0(&subtypetag, 1);
  441. stio.queue_p0(&num, 4);
  442. stio.queue_p1(&typetag, 1);
  443. stio.queue_p1(&subtypetag, 1);
  444. stio.queue_p1(&num, 4);
  445. for (unsigned int i=0; i<num; ++i) {
  446. coroutines.emplace_back(
  447. [&stio](yield_t &yield) {
  448. yield();
  449. stio.cdpf(yield);
  450. });
  451. }
  452. } else if (!strcmp(type, "k")) {
  453. unsigned char typetag = 0x8e;
  454. unsigned char subtypetag = 0x00;
  455. stio.queue_p0(&typetag, 1);
  456. stio.queue_p0(&subtypetag, 1);
  457. stio.queue_p0(&num, 4);
  458. stio.queue_p1(&typetag, 1);
  459. stio.queue_p1(&subtypetag, 1);
  460. stio.queue_p1(&num, 4);
  461. coroutines.emplace_back(
  462. [&stio, num] (yield_t &yield) {
  463. unsigned int istart = 0x31415080;
  464. yield();
  465. for (unsigned int i=istart; i<istart+num; ++i) {
  466. stio.queue_p0(&i, sizeof(i));
  467. stio.queue_p1(&i, sizeof(i));
  468. yield();
  469. unsigned int p0i, p1i;
  470. stio.recv_p0(&p0i, sizeof(p0i));
  471. stio.recv_p1(&p1i, sizeof(p1i));
  472. if (p0i != i || p1i != i) {
  473. printf("Incorrect counter received: "
  474. "p0=%08x p1=%08x\n", p0i,
  475. p1i);
  476. }
  477. }
  478. });
  479. }
  480. free(arg);
  481. ++threadargs;
  482. }
  483. // That's all
  484. unsigned char typetag = 0x00;
  485. stio.queue_p0(&typetag, 1);
  486. stio.queue_p1(&typetag, 1);
  487. run_coroutines(stio, coroutines);
  488. ofiles.closeall();
  489. });
  490. }
  491. pool.join();
  492. }