preproc.cpp 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <fstream>
#include <iostream>
#include <string>
#include <vector>
#include "types.hpp"
#include "coroutine.hpp"
#include "preproc.hpp"
#include "rdpf.hpp"
#include "cdpf.hpp"
  7. // Keep track of open files that coroutines might be writing into
  8. class Openfiles {
  9. bool append_mode;
  10. std::vector<std::ofstream> files;
  11. public:
  12. Openfiles(bool append_mode = false) : append_mode(append_mode) {}
  13. class Handle {
  14. Openfiles &parent;
  15. size_t idx;
  16. public:
  17. Handle(Openfiles &parent, size_t idx) :
  18. parent(parent), idx(idx) {}
  19. // Retrieve the ofstream from this Handle
  20. std::ofstream &os() const { return parent.files[idx]; }
  21. };
  22. Handle open(const char *prefix, unsigned player,
  23. unsigned thread_num, nbits_t depth = 0);
  24. void closeall();
  25. };
  26. // Open a file for writing with name the given prefix, and ".pX.tY"
  27. // suffix, where X is the (one-digit) player number and Y is the thread
  28. // number. If depth D is given, use "D.pX.tY" as the suffix.
  29. Openfiles::Handle Openfiles::open(const char *prefix, unsigned player,
  30. unsigned thread_num, nbits_t depth)
  31. {
  32. std::string filename(prefix);
  33. char suffix[20];
  34. if (depth > 0) {
  35. sprintf(suffix, "%02d.p%d.t%u", depth, player%10, thread_num);
  36. } else {
  37. sprintf(suffix, ".p%d.t%u", player%10, thread_num);
  38. }
  39. filename.append(suffix);
  40. std::ofstream &f = files.emplace_back(filename,
  41. append_mode ? std::ios_base::app : std::ios_base::out);
  42. if (f.fail()) {
  43. std::cerr << "Failed to open " << filename << "\n";
  44. exit(1);
  45. }
  46. return Handle(*this, files.size()-1);
  47. }
  48. // Close all the open files
  49. void Openfiles::closeall()
  50. {
  51. for (auto& f: files) {
  52. f.close();
  53. }
  54. files.clear();
  55. }
  56. // The server-to-computational-peer protocol for sending precomputed
  57. // data is:
  58. //
  59. // One byte: type
  60. // 0x01 to 0x30: RAM DPF of that depth
  61. // 0x40: Comparison DPF
  62. // 0x80: Multiplication triple
  63. // 0x81: Multiplication half-triple
  64. // 0x82: AND triple
  65. // 0x83: Select triple
  66. // 0x8e: Counter (for testing)
  67. // 0x8f: Set number of CPU threads for this communication thread
  68. // 0x00: End of preprocessing
  69. //
  70. // One byte: subtype (not sent for type == 0x00)
  71. // For RAM DPFs, the subtype is the width (0x01 to 0x05)
  72. // Otherwise, it is 0x00
  73. //
  74. // Four bytes: number of objects of that type (not sent for type == 0x00)
  75. //
  76. // Then that number of objects
  77. //
  78. // Repeat the whole thing until type == 0x00 is received
  79. void preprocessing_comp(MPCIO &mpcio, const PRACOptions &opts, char **args)
  80. {
  81. int num_threads = opts.num_threads;
  82. boost::asio::thread_pool pool(num_threads);
  83. for (int thread_num = 0; thread_num < num_threads; ++thread_num) {
  84. boost::asio::post(pool, [&mpcio, &opts, thread_num] {
  85. MPCTIO tio(mpcio, thread_num);
  86. Openfiles ofiles(opts.append_to_files);
  87. std::vector<coro_t> coroutines;
  88. while(1) {
  89. unsigned char type = 0;
  90. unsigned char subtype = 0;
  91. unsigned int num = 0;
  92. size_t res = tio.recv_server(&type, 1);
  93. if (res < 1 || type == 0) break;
  94. tio.recv_server(&subtype, 1);
  95. tio.recv_server(&num, 4);
  96. if (type == 0x80) {
  97. // Multiplication triples
  98. auto tripfile = ofiles.open("mults",
  99. mpcio.player, thread_num);
  100. for (unsigned int i=0; i<num; ++i) {
  101. coroutines.emplace_back(
  102. [&tio, tripfile](yield_t &yield) {
  103. yield();
  104. MultTriple T = tio.multtriple(yield);
  105. tripfile.os() << T;
  106. });
  107. }
  108. } else if (type == 0x81) {
  109. // Multiplication half triples
  110. auto halffile = ofiles.open("halves",
  111. mpcio.player, thread_num);
  112. for (unsigned int i=0; i<num; ++i) {
  113. coroutines.emplace_back(
  114. [&tio, halffile](yield_t &yield) {
  115. yield();
  116. HalfTriple H = tio.halftriple(yield);
  117. halffile.os() << H;
  118. });
  119. }
  120. } else if (type == 0x82) {
  121. // AND triples
  122. auto andfile = ofiles.open("ands",
  123. mpcio.player, thread_num);
  124. for (unsigned int i=0; i<num; ++i) {
  125. coroutines.emplace_back(
  126. [&tio, andfile](yield_t &yield) {
  127. yield();
  128. AndTriple A = tio.andtriple(yield);
  129. andfile.os() << A;
  130. });
  131. }
  132. } else if (type == 0x83) {
  133. // Select triples
  134. auto selfile = ofiles.open("selects",
  135. mpcio.player, thread_num);
  136. for (unsigned int i=0; i<num; ++i) {
  137. coroutines.emplace_back(
  138. [&tio, selfile](yield_t &yield) {
  139. yield();
  140. SelectTriple<value_t> S =
  141. tio.valselecttriple(yield);
  142. selfile.os() << S;
  143. });
  144. }
  145. } else if (type >= 0x01 && type <= 0x30) {
  146. // RAM DPFs
  147. assert(subtype >= 0x01 && subtype <= 0x05);
  148. char prefix[11];
  149. strcpy(prefix, "rdpf");
  150. if (subtype > 1) {
  151. sprintf(prefix+strlen(prefix), "%d_", subtype);
  152. }
  153. bool incremental = false;
  154. auto tripfile = ofiles.open(prefix,
  155. mpcio.player, thread_num, type);
  156. for (unsigned int i=0; i<num; ++i) {
  157. coroutines.emplace_back(
  158. [&tio, &opts, incremental, tripfile, type,
  159. subtype](yield_t &yield) {
  160. yield();
  161. switch(subtype) {
  162. case 1: {
  163. RDPFTriple<1> rdpftrip =
  164. tio.rdpftriple<1>(yield, type,
  165. incremental, opts.expand_rdpfs);
  166. tripfile.os() << rdpftrip;
  167. break;
  168. }
  169. case 2: {
  170. RDPFTriple<2> rdpftrip =
  171. tio.rdpftriple<2>(yield, type,
  172. incremental, opts.expand_rdpfs);
  173. tripfile.os() << rdpftrip;
  174. break;
  175. }
  176. case 3: {
  177. RDPFTriple<3> rdpftrip =
  178. tio.rdpftriple<3>(yield, type,
  179. incremental, opts.expand_rdpfs);
  180. tripfile.os() << rdpftrip;
  181. break;
  182. }
  183. case 4: {
  184. RDPFTriple<4> rdpftrip =
  185. tio.rdpftriple<4>(yield, type,
  186. incremental, opts.expand_rdpfs);
  187. tripfile.os() << rdpftrip;
  188. break;
  189. }
  190. case 5: {
  191. RDPFTriple<5> rdpftrip =
  192. tio.rdpftriple<5>(yield, type,
  193. incremental, opts.expand_rdpfs);
  194. tripfile.os() << rdpftrip;
  195. break;
  196. }
  197. }
  198. });
  199. }
  200. } else if (type == 0x40) {
  201. // Comparison DPFs
  202. auto cdpffile = ofiles.open("cdpf",
  203. mpcio.player, thread_num);
  204. for (unsigned int i=0; i<num; ++i) {
  205. coroutines.emplace_back(
  206. [&tio, cdpffile](yield_t &yield) {
  207. yield();
  208. CDPF C = tio.cdpf(yield);
  209. cdpffile.os() << C;
  210. });
  211. }
  212. } else if (type == 0x8e) {
  213. coroutines.emplace_back(
  214. [&tio, num](yield_t &yield) {
  215. yield();
  216. unsigned int istart = 0x31415080;
  217. for (unsigned int i=istart; i<istart+num; ++i) {
  218. tio.queue_peer(&i, sizeof(i));
  219. tio.queue_server(&i, sizeof(i));
  220. yield();
  221. unsigned int peeri, srvi;
  222. tio.recv_peer(&peeri, sizeof(peeri));
  223. tio.recv_server(&srvi, sizeof(srvi));
  224. if (peeri != i || srvi != i) {
  225. printf("Incorrect counter received: "
  226. "peer=%08x srv=%08x\n", peeri,
  227. srvi);
  228. }
  229. }
  230. });
  231. } else if (type == 0x8f) {
  232. tio.cpu_nthreads(num);
  233. }
  234. }
  235. run_coroutines(tio, coroutines);
  236. ofiles.closeall();
  237. });
  238. }
  239. pool.join();
  240. }
  241. void preprocessing_server(MPCServerIO &mpcsrvio, const PRACOptions &opts, char **args)
  242. {
  243. int num_threads = opts.num_threads;
  244. boost::asio::thread_pool pool(num_threads);
  245. for (int thread_num = 0; thread_num < num_threads; ++thread_num) {
  246. boost::asio::post(pool, [&mpcsrvio, &opts, thread_num, args] {
  247. char **threadargs = args;
  248. MPCTIO stio(mpcsrvio, thread_num);
  249. Openfiles ofiles(opts.append_to_files);
  250. std::vector<coro_t> coroutines;
  251. if (*threadargs && threadargs[0][0] == 'T') {
  252. // Per-thread initialization. The args look like:
  253. // T0 t:50 h:10 T1 t:20 h:30 T2 h:20
  254. // Skip to the arg marking our thread
  255. char us[20];
  256. sprintf(us, "T%u", thread_num);
  257. while (*threadargs && strcmp(*threadargs, us)) {
  258. ++threadargs;
  259. }
  260. // Now skip to the next arg if there is one
  261. if (*threadargs) {
  262. ++threadargs;
  263. }
  264. }
  265. // Stop scanning for args when we get to the end or when we
  266. // get to another per-thread initialization marker
  267. while (*threadargs && threadargs[0][0] != 'T') {
  268. char *arg = strdup(*threadargs);
  269. char *colon = strchr(arg, ':');
  270. if (!colon) {
  271. std::cerr << "Args must be type:num\n";
  272. ++threadargs;
  273. free(arg);
  274. continue;
  275. }
  276. unsigned num = atoi(colon+1);
  277. *colon = '\0';
  278. char *type = arg;
  279. if (!strcmp(type, "m")) {
  280. unsigned char typetag = 0x80;
  281. unsigned char subtypetag = 0x00;
  282. stio.queue_p0(&typetag, 1);
  283. stio.queue_p0(&subtypetag, 1);
  284. stio.queue_p0(&num, 4);
  285. stio.queue_p1(&typetag, 1);
  286. stio.queue_p1(&subtypetag, 1);
  287. stio.queue_p1(&num, 4);
  288. for (unsigned int i=0; i<num; ++i) {
  289. coroutines.emplace_back(
  290. [&stio](yield_t &yield) {
  291. yield();
  292. stio.multtriple(yield);
  293. });
  294. }
  295. } else if (!strcmp(type, "h")) {
  296. unsigned char typetag = 0x81;
  297. unsigned char subtypetag = 0x00;
  298. stio.queue_p0(&typetag, 1);
  299. stio.queue_p0(&subtypetag, 1);
  300. stio.queue_p0(&num, 4);
  301. stio.queue_p1(&typetag, 1);
  302. stio.queue_p1(&subtypetag, 1);
  303. stio.queue_p1(&num, 4);
  304. for (unsigned int i=0; i<num; ++i) {
  305. coroutines.emplace_back(
  306. [&stio](yield_t &yield) {
  307. yield();
  308. stio.halftriple(yield);
  309. });
  310. }
  311. } else if (!strcmp(type, "a")) {
  312. unsigned char typetag = 0x82;
  313. unsigned char subtypetag = 0x00;
  314. stio.queue_p0(&typetag, 1);
  315. stio.queue_p0(&subtypetag, 1);
  316. stio.queue_p0(&num, 4);
  317. stio.queue_p1(&typetag, 1);
  318. stio.queue_p1(&subtypetag, 1);
  319. stio.queue_p1(&num, 4);
  320. for (unsigned int i=0; i<num; ++i) {
  321. coroutines.emplace_back(
  322. [&stio](yield_t &yield) {
  323. yield();
  324. stio.andtriple(yield);
  325. });
  326. }
  327. } else if (!strcmp(type, "s")) {
  328. unsigned char typetag = 0x83;
  329. unsigned char subtypetag = 0x00;
  330. stio.queue_p0(&typetag, 1);
  331. stio.queue_p0(&subtypetag, 1);
  332. stio.queue_p0(&num, 4);
  333. stio.queue_p1(&typetag, 1);
  334. stio.queue_p1(&subtypetag, 1);
  335. stio.queue_p1(&num, 4);
  336. for (unsigned int i=0; i<num; ++i) {
  337. coroutines.emplace_back(
  338. [&stio](yield_t &yield) {
  339. yield();
  340. stio.valselecttriple(yield);
  341. });
  342. }
  343. } else if (type[0] == 'r') {
  344. char *widthstr = strchr(type, '.');
  345. unsigned char width = 1;
  346. if (widthstr) {
  347. *widthstr = '\0';
  348. ++widthstr;
  349. width = atoi(widthstr);
  350. }
  351. int depth = atoi(type+1);
  352. if (depth < 1 || depth > 48) {
  353. std::cerr << "Invalid DPF depth\n";
  354. } else {
  355. unsigned char typetag = depth;
  356. unsigned char subtypetag = width;
  357. stio.queue_p0(&typetag, 1);
  358. stio.queue_p0(&subtypetag, 1);
  359. stio.queue_p0(&num, 4);
  360. stio.queue_p1(&typetag, 1);
  361. stio.queue_p1(&subtypetag, 1);
  362. stio.queue_p1(&num, 4);
  363. char prefix[11];
  364. strcpy(prefix, "rdpf");
  365. if (width > 1) {
  366. sprintf(prefix+strlen(prefix), "%d_", width);
  367. }
  368. auto pairfile = ofiles.open(prefix,
  369. mpcsrvio.player, thread_num, depth);
  370. for (unsigned int i=0; i<num; ++i) {
  371. coroutines.emplace_back(
  372. [&stio, &opts, pairfile, depth, width](yield_t &yield) {
  373. yield();
  374. switch (width) {
  375. case 1: {
  376. RDPFPair<1> rdpfpair =
  377. stio.rdpfpair<1>(yield, depth);
  378. if (opts.expand_rdpfs) {
  379. rdpfpair.dpf[0].expand(stio.aes_ops());
  380. rdpfpair.dpf[1].expand(stio.aes_ops());
  381. }
  382. pairfile.os() << rdpfpair;
  383. break;
  384. }
  385. case 2: {
  386. RDPFPair<2> rdpfpair =
  387. stio.rdpfpair<2>(yield, depth);
  388. if (opts.expand_rdpfs) {
  389. rdpfpair.dpf[0].expand(stio.aes_ops());
  390. rdpfpair.dpf[1].expand(stio.aes_ops());
  391. }
  392. pairfile.os() << rdpfpair;
  393. break;
  394. }
  395. case 3: {
  396. RDPFPair<3> rdpfpair =
  397. stio.rdpfpair<3>(yield, depth);
  398. if (opts.expand_rdpfs) {
  399. rdpfpair.dpf[0].expand(stio.aes_ops());
  400. rdpfpair.dpf[1].expand(stio.aes_ops());
  401. }
  402. pairfile.os() << rdpfpair;
  403. break;
  404. }
  405. case 4: {
  406. RDPFPair<4> rdpfpair =
  407. stio.rdpfpair<4>(yield, depth);
  408. if (opts.expand_rdpfs) {
  409. rdpfpair.dpf[0].expand(stio.aes_ops());
  410. rdpfpair.dpf[1].expand(stio.aes_ops());
  411. }
  412. pairfile.os() << rdpfpair;
  413. break;
  414. }
  415. case 5: {
  416. RDPFPair<5> rdpfpair =
  417. stio.rdpfpair<5>(yield, depth);
  418. if (opts.expand_rdpfs) {
  419. rdpfpair.dpf[0].expand(stio.aes_ops());
  420. rdpfpair.dpf[1].expand(stio.aes_ops());
  421. }
  422. pairfile.os() << rdpfpair;
  423. break;
  424. }
  425. }
  426. });
  427. }
  428. }
  429. } else if (!strcmp(type, "c")) {
  430. unsigned char typetag = 0x40;
  431. unsigned char subtypetag = 0x00;
  432. stio.queue_p0(&typetag, 1);
  433. stio.queue_p0(&subtypetag, 1);
  434. stio.queue_p0(&num, 4);
  435. stio.queue_p1(&typetag, 1);
  436. stio.queue_p1(&subtypetag, 1);
  437. stio.queue_p1(&num, 4);
  438. for (unsigned int i=0; i<num; ++i) {
  439. coroutines.emplace_back(
  440. [&stio](yield_t &yield) {
  441. yield();
  442. stio.cdpf(yield);
  443. });
  444. }
  445. } else if (!strcmp(type, "i")) {
  446. unsigned char typetag = 0x8e;
  447. unsigned char subtypetag = 0x00;
  448. stio.queue_p0(&typetag, 1);
  449. stio.queue_p0(&subtypetag, 1);
  450. stio.queue_p0(&num, 4);
  451. stio.queue_p1(&typetag, 1);
  452. stio.queue_p1(&subtypetag, 1);
  453. stio.queue_p1(&num, 4);
  454. coroutines.emplace_back(
  455. [&stio, num] (yield_t &yield) {
  456. unsigned int istart = 0x31415080;
  457. yield();
  458. for (unsigned int i=istart; i<istart+num; ++i) {
  459. stio.queue_p0(&i, sizeof(i));
  460. stio.queue_p1(&i, sizeof(i));
  461. yield();
  462. unsigned int p0i, p1i;
  463. stio.recv_p0(&p0i, sizeof(p0i));
  464. stio.recv_p1(&p1i, sizeof(p1i));
  465. if (p0i != i || p1i != i) {
  466. printf("Incorrect counter received: "
  467. "p0=%08x p1=%08x\n", p0i,
  468. p1i);
  469. }
  470. }
  471. });
  472. } else if (!strcmp(type, "p")) {
  473. unsigned char typetag = 0x8f;
  474. unsigned char subtypetag = 0x00;
  475. stio.queue_p0(&typetag, 1);
  476. stio.queue_p0(&subtypetag, 1);
  477. stio.queue_p0(&num, 4);
  478. stio.queue_p1(&typetag, 1);
  479. stio.queue_p1(&subtypetag, 1);
  480. stio.queue_p1(&num, 4);
  481. stio.cpu_nthreads(num);
  482. }
  483. free(arg);
  484. ++threadargs;
  485. }
  486. // That's all
  487. unsigned char typetag = 0x00;
  488. stio.queue_p0(&typetag, 1);
  489. stio.queue_p1(&typetag, 1);
  490. run_coroutines(stio, coroutines);
  491. ofiles.closeall();
  492. });
  493. }
  494. pool.join();
  495. }