preproc.cpp 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515
  1. #include <vector>
  2. #include "types.hpp"
  3. #include "coroutine.hpp"
  4. #include "preproc.hpp"
  5. #include "rdpf.hpp"
  6. #include "cdpf.hpp"
  7. // Keep track of open files that coroutines might be writing into
  8. class Openfiles {
  9. bool append_mode;
  10. std::vector<std::ofstream> files;
  11. public:
  12. Openfiles(bool append_mode = false) : append_mode(append_mode) {}
  13. class Handle {
  14. Openfiles &parent;
  15. size_t idx;
  16. public:
  17. Handle(Openfiles &parent, size_t idx) :
  18. parent(parent), idx(idx) {}
  19. // Retrieve the ofstream from this Handle
  20. std::ofstream &os() const { return parent.files[idx]; }
  21. };
  22. Handle open(const char *prefix, unsigned player,
  23. unsigned thread_num, nbits_t depth = 0);
  24. void closeall();
  25. };
  26. // Open a file for writing with name the given prefix, and ".pX.tY"
  27. // suffix, where X is the (one-digit) player number and Y is the thread
  28. // number. If depth D is given, use "D.pX.tY" as the suffix.
  29. Openfiles::Handle Openfiles::open(const char *prefix, unsigned player,
  30. unsigned thread_num, nbits_t depth)
  31. {
  32. std::string filename(prefix);
  33. char suffix[20];
  34. if (depth > 0) {
  35. sprintf(suffix, "%02d.p%d.t%u", depth, player%10, thread_num);
  36. } else {
  37. sprintf(suffix, ".p%d.t%u", player%10, thread_num);
  38. }
  39. filename.append(suffix);
  40. std::ofstream &f = files.emplace_back(filename,
  41. append_mode ? std::ios_base::app : std::ios_base::out);
  42. if (f.fail()) {
  43. std::cerr << "Failed to open " << filename << "\n";
  44. exit(1);
  45. }
  46. return Handle(*this, files.size()-1);
  47. }
  48. // Close all the open files
  49. void Openfiles::closeall()
  50. {
  51. for (auto& f: files) {
  52. f.close();
  53. }
  54. files.clear();
  55. }
  56. // The server-to-computational-peer protocol for sending precomputed
  57. // data is:
  58. //
  59. // One byte: type
  60. // 0x01 to 0x30: RAM DPF of that depth
  61. // 0x40: Comparison DPF
  62. // 0x80: Multiplication triple
  63. // 0x81: Multiplication half-triple
  64. // 0x82: AND triple
  65. // 0x83: Select triple
  66. // 0x8e: Counter (for testing)
  67. // 0x8f: Set number of CPU threads for this communication thread
  68. // 0x00: End of preprocessing
  69. //
  70. // One byte: subtype (not sent for type == 0x00)
  71. // For RAM DPFs, the subtype is the width (0x01 to 0x05)
  72. // Otherwise, it is 0x00
  73. //
  74. // Four bytes: number of objects of that type (not sent for type == 0x00)
  75. //
  76. // Then that number of objects
  77. //
  78. // Repeat the whole thing until type == 0x00 is received
  79. void preprocessing_comp(MPCIO &mpcio, const PRACOptions &opts, char **args)
  80. {
  81. int num_threads = opts.num_threads;
  82. boost::asio::thread_pool pool(num_threads);
  83. for (int thread_num = 0; thread_num < num_threads; ++thread_num) {
  84. boost::asio::post(pool, [&mpcio, &opts, thread_num] {
  85. MPCTIO tio(mpcio, thread_num);
  86. Openfiles ofiles(opts.append_to_files);
  87. std::vector<coro_t> coroutines;
  88. while(1) {
  89. unsigned char type = 0;
  90. unsigned char subtype = 0;
  91. unsigned int num = 0;
  92. size_t res = tio.recv_server(&type, 1);
  93. if (res < 1 || type == 0) break;
  94. tio.recv_server(&subtype, 1);
  95. tio.recv_server(&num, 4);
  96. if (type == 0x80) {
  97. // Multiplication triples
  98. auto tripfile = ofiles.open("mults",
  99. mpcio.player, thread_num);
  100. for (unsigned int i=0; i<num; ++i) {
  101. coroutines.emplace_back(
  102. [&tio, tripfile](yield_t &yield) {
  103. yield();
  104. MultTriple T = tio.multtriple(yield);
  105. tripfile.os() << T;
  106. });
  107. }
  108. } else if (type == 0x81) {
  109. // Multiplication half triples
  110. auto halffile = ofiles.open("halves",
  111. mpcio.player, thread_num);
  112. for (unsigned int i=0; i<num; ++i) {
  113. coroutines.emplace_back(
  114. [&tio, halffile](yield_t &yield) {
  115. yield();
  116. HalfTriple H = tio.halftriple(yield);
  117. halffile.os() << H;
  118. });
  119. }
  120. } else if (type == 0x82) {
  121. // AND triples
  122. auto andfile = ofiles.open("ands",
  123. mpcio.player, thread_num);
  124. for (unsigned int i=0; i<num; ++i) {
  125. coroutines.emplace_back(
  126. [&tio, andfile](yield_t &yield) {
  127. yield();
  128. AndTriple A = tio.andtriple(yield);
  129. andfile.os() << A;
  130. });
  131. }
  132. } else if (type == 0x83) {
  133. // Select triples
  134. auto selfile = ofiles.open("selects",
  135. mpcio.player, thread_num);
  136. for (unsigned int i=0; i<num; ++i) {
  137. coroutines.emplace_back(
  138. [&tio, selfile](yield_t &yield) {
  139. yield();
  140. SelectTriple<value_t> S =
  141. tio.valselecttriple(yield);
  142. selfile.os() << S;
  143. });
  144. }
  145. } else if (type >= 0x01 && type <= 0x30) {
  146. // RAM DPFs
  147. assert(subtype >= 0x01 && subtype <= 0x05);
  148. char prefix[11];
  149. strcpy(prefix, "rdpf");
  150. if (subtype > 1) {
  151. sprintf(prefix+strlen(prefix), "%d_", subtype);
  152. }
  153. auto tripfile = ofiles.open(prefix,
  154. mpcio.player, thread_num, type);
  155. for (unsigned int i=0; i<num; ++i) {
  156. coroutines.emplace_back(
  157. [&tio, &opts, tripfile, type, subtype](yield_t &yield) {
  158. yield();
  159. switch(subtype) {
  160. case 1: {
  161. RDPFTriple<1> rdpftrip =
  162. tio.rdpftriple<1>(yield, type, opts.expand_rdpfs);
  163. tripfile.os() << rdpftrip;
  164. break;
  165. }
  166. case 2: {
  167. RDPFTriple<2> rdpftrip =
  168. tio.rdpftriple<2>(yield, type, opts.expand_rdpfs);
  169. tripfile.os() << rdpftrip;
  170. break;
  171. }
  172. case 3: {
  173. RDPFTriple<3> rdpftrip =
  174. tio.rdpftriple<3>(yield, type, opts.expand_rdpfs);
  175. tripfile.os() << rdpftrip;
  176. break;
  177. }
  178. case 4: {
  179. RDPFTriple<4> rdpftrip =
  180. tio.rdpftriple<4>(yield, type, opts.expand_rdpfs);
  181. tripfile.os() << rdpftrip;
  182. break;
  183. }
  184. case 5: {
  185. RDPFTriple<5> rdpftrip =
  186. tio.rdpftriple<5>(yield, type, opts.expand_rdpfs);
  187. tripfile.os() << rdpftrip;
  188. break;
  189. }
  190. }
  191. });
  192. }
  193. } else if (type == 0x40) {
  194. // Comparison DPFs
  195. auto cdpffile = ofiles.open("cdpf",
  196. mpcio.player, thread_num);
  197. for (unsigned int i=0; i<num; ++i) {
  198. coroutines.emplace_back(
  199. [&tio, cdpffile](yield_t &yield) {
  200. yield();
  201. CDPF C = tio.cdpf(yield);
  202. cdpffile.os() << C;
  203. });
  204. }
  205. } else if (type == 0x8e) {
  206. coroutines.emplace_back(
  207. [&tio, num](yield_t &yield) {
  208. yield();
  209. unsigned int istart = 0x31415080;
  210. for (unsigned int i=istart; i<istart+num; ++i) {
  211. tio.queue_peer(&i, sizeof(i));
  212. tio.queue_server(&i, sizeof(i));
  213. yield();
  214. unsigned int peeri, srvi;
  215. tio.recv_peer(&peeri, sizeof(peeri));
  216. tio.recv_server(&srvi, sizeof(srvi));
  217. if (peeri != i || srvi != i) {
  218. printf("Incorrect counter received: "
  219. "peer=%08x srv=%08x\n", peeri,
  220. srvi);
  221. }
  222. }
  223. });
  224. } else if (type == 0x8f) {
  225. tio.cpu_nthreads(num);
  226. }
  227. }
  228. run_coroutines(tio, coroutines);
  229. ofiles.closeall();
  230. });
  231. }
  232. pool.join();
  233. }
  234. void preprocessing_server(MPCServerIO &mpcsrvio, const PRACOptions &opts, char **args)
  235. {
  236. int num_threads = opts.num_threads;
  237. boost::asio::thread_pool pool(num_threads);
  238. for (int thread_num = 0; thread_num < num_threads; ++thread_num) {
  239. boost::asio::post(pool, [&mpcsrvio, &opts, thread_num, args] {
  240. char **threadargs = args;
  241. MPCTIO stio(mpcsrvio, thread_num);
  242. Openfiles ofiles(opts.append_to_files);
  243. std::vector<coro_t> coroutines;
  244. if (*threadargs && threadargs[0][0] == 'T') {
  245. // Per-thread initialization. The args look like:
  246. // T0 t:50 h:10 T1 t:20 h:30 T2 h:20
  247. // Skip to the arg marking our thread
  248. char us[20];
  249. sprintf(us, "T%u", thread_num);
  250. while (*threadargs && strcmp(*threadargs, us)) {
  251. ++threadargs;
  252. }
  253. // Now skip to the next arg if there is one
  254. if (*threadargs) {
  255. ++threadargs;
  256. }
  257. }
  258. // Stop scanning for args when we get to the end or when we
  259. // get to another per-thread initialization marker
  260. while (*threadargs && threadargs[0][0] != 'T') {
  261. char *arg = strdup(*threadargs);
  262. char *colon = strchr(arg, ':');
  263. if (!colon) {
  264. std::cerr << "Args must be type:num\n";
  265. ++threadargs;
  266. free(arg);
  267. continue;
  268. }
  269. unsigned num = atoi(colon+1);
  270. *colon = '\0';
  271. char *type = arg;
  272. if (!strcmp(type, "m")) {
  273. unsigned char typetag = 0x80;
  274. unsigned char subtypetag = 0x00;
  275. stio.queue_p0(&typetag, 1);
  276. stio.queue_p0(&subtypetag, 1);
  277. stio.queue_p0(&num, 4);
  278. stio.queue_p1(&typetag, 1);
  279. stio.queue_p1(&subtypetag, 1);
  280. stio.queue_p1(&num, 4);
  281. for (unsigned int i=0; i<num; ++i) {
  282. coroutines.emplace_back(
  283. [&stio](yield_t &yield) {
  284. yield();
  285. stio.multtriple(yield);
  286. });
  287. }
  288. } else if (!strcmp(type, "h")) {
  289. unsigned char typetag = 0x81;
  290. unsigned char subtypetag = 0x00;
  291. stio.queue_p0(&typetag, 1);
  292. stio.queue_p0(&subtypetag, 1);
  293. stio.queue_p0(&num, 4);
  294. stio.queue_p1(&typetag, 1);
  295. stio.queue_p1(&subtypetag, 1);
  296. stio.queue_p1(&num, 4);
  297. for (unsigned int i=0; i<num; ++i) {
  298. coroutines.emplace_back(
  299. [&stio](yield_t &yield) {
  300. yield();
  301. stio.halftriple(yield);
  302. });
  303. }
  304. } else if (!strcmp(type, "a")) {
  305. unsigned char typetag = 0x82;
  306. unsigned char subtypetag = 0x00;
  307. stio.queue_p0(&typetag, 1);
  308. stio.queue_p0(&subtypetag, 1);
  309. stio.queue_p0(&num, 4);
  310. stio.queue_p1(&typetag, 1);
  311. stio.queue_p1(&subtypetag, 1);
  312. stio.queue_p1(&num, 4);
  313. for (unsigned int i=0; i<num; ++i) {
  314. coroutines.emplace_back(
  315. [&stio](yield_t &yield) {
  316. yield();
  317. stio.andtriple(yield);
  318. });
  319. }
  320. } else if (!strcmp(type, "s")) {
  321. unsigned char typetag = 0x83;
  322. unsigned char subtypetag = 0x00;
  323. stio.queue_p0(&typetag, 1);
  324. stio.queue_p0(&subtypetag, 1);
  325. stio.queue_p0(&num, 4);
  326. stio.queue_p1(&typetag, 1);
  327. stio.queue_p1(&subtypetag, 1);
  328. stio.queue_p1(&num, 4);
  329. for (unsigned int i=0; i<num; ++i) {
  330. coroutines.emplace_back(
  331. [&stio](yield_t &yield) {
  332. yield();
  333. stio.valselecttriple(yield);
  334. });
  335. }
  336. } else if (type[0] == 'r') {
  337. char *widthstr = strchr(type, '.');
  338. unsigned char width = 1;
  339. if (widthstr) {
  340. *widthstr = '\0';
  341. ++widthstr;
  342. width = atoi(widthstr);
  343. }
  344. int depth = atoi(type+1);
  345. if (depth < 1 || depth > 48) {
  346. std::cerr << "Invalid DPF depth\n";
  347. } else {
  348. unsigned char typetag = depth;
  349. unsigned char subtypetag = width;
  350. stio.queue_p0(&typetag, 1);
  351. stio.queue_p0(&subtypetag, 1);
  352. stio.queue_p0(&num, 4);
  353. stio.queue_p1(&typetag, 1);
  354. stio.queue_p1(&subtypetag, 1);
  355. stio.queue_p1(&num, 4);
  356. char prefix[11];
  357. strcpy(prefix, "rdpf");
  358. if (width > 1) {
  359. sprintf(prefix+strlen(prefix), "%d_", width);
  360. }
  361. auto pairfile = ofiles.open(prefix,
  362. mpcsrvio.player, thread_num, depth);
  363. for (unsigned int i=0; i<num; ++i) {
  364. coroutines.emplace_back(
  365. [&stio, &opts, pairfile, depth, width](yield_t &yield) {
  366. yield();
  367. switch (width) {
  368. case 1: {
  369. RDPFPair<1> rdpfpair =
  370. stio.rdpfpair<1>(yield, depth);
  371. if (opts.expand_rdpfs) {
  372. rdpfpair.dpf[0].expand(stio.aes_ops());
  373. rdpfpair.dpf[1].expand(stio.aes_ops());
  374. }
  375. pairfile.os() << rdpfpair;
  376. break;
  377. }
  378. case 2: {
  379. RDPFPair<2> rdpfpair =
  380. stio.rdpfpair<2>(yield, depth);
  381. if (opts.expand_rdpfs) {
  382. rdpfpair.dpf[0].expand(stio.aes_ops());
  383. rdpfpair.dpf[1].expand(stio.aes_ops());
  384. }
  385. pairfile.os() << rdpfpair;
  386. break;
  387. }
  388. case 3: {
  389. RDPFPair<3> rdpfpair =
  390. stio.rdpfpair<3>(yield, depth);
  391. if (opts.expand_rdpfs) {
  392. rdpfpair.dpf[0].expand(stio.aes_ops());
  393. rdpfpair.dpf[1].expand(stio.aes_ops());
  394. }
  395. pairfile.os() << rdpfpair;
  396. break;
  397. }
  398. case 4: {
  399. RDPFPair<4> rdpfpair =
  400. stio.rdpfpair<4>(yield, depth);
  401. if (opts.expand_rdpfs) {
  402. rdpfpair.dpf[0].expand(stio.aes_ops());
  403. rdpfpair.dpf[1].expand(stio.aes_ops());
  404. }
  405. pairfile.os() << rdpfpair;
  406. break;
  407. }
  408. case 5: {
  409. RDPFPair<5> rdpfpair =
  410. stio.rdpfpair<5>(yield, depth);
  411. if (opts.expand_rdpfs) {
  412. rdpfpair.dpf[0].expand(stio.aes_ops());
  413. rdpfpair.dpf[1].expand(stio.aes_ops());
  414. }
  415. pairfile.os() << rdpfpair;
  416. break;
  417. }
  418. }
  419. });
  420. }
  421. }
  422. } else if (!strcmp(type, "c")) {
  423. unsigned char typetag = 0x40;
  424. unsigned char subtypetag = 0x00;
  425. stio.queue_p0(&typetag, 1);
  426. stio.queue_p0(&subtypetag, 1);
  427. stio.queue_p0(&num, 4);
  428. stio.queue_p1(&typetag, 1);
  429. stio.queue_p1(&subtypetag, 1);
  430. stio.queue_p1(&num, 4);
  431. for (unsigned int i=0; i<num; ++i) {
  432. coroutines.emplace_back(
  433. [&stio](yield_t &yield) {
  434. yield();
  435. stio.cdpf(yield);
  436. });
  437. }
  438. } else if (!strcmp(type, "i")) {
  439. unsigned char typetag = 0x8e;
  440. unsigned char subtypetag = 0x00;
  441. stio.queue_p0(&typetag, 1);
  442. stio.queue_p0(&subtypetag, 1);
  443. stio.queue_p0(&num, 4);
  444. stio.queue_p1(&typetag, 1);
  445. stio.queue_p1(&subtypetag, 1);
  446. stio.queue_p1(&num, 4);
  447. coroutines.emplace_back(
  448. [&stio, num] (yield_t &yield) {
  449. unsigned int istart = 0x31415080;
  450. yield();
  451. for (unsigned int i=istart; i<istart+num; ++i) {
  452. stio.queue_p0(&i, sizeof(i));
  453. stio.queue_p1(&i, sizeof(i));
  454. yield();
  455. unsigned int p0i, p1i;
  456. stio.recv_p0(&p0i, sizeof(p0i));
  457. stio.recv_p1(&p1i, sizeof(p1i));
  458. if (p0i != i || p1i != i) {
  459. printf("Incorrect counter received: "
  460. "p0=%08x p1=%08x\n", p0i,
  461. p1i);
  462. }
  463. }
  464. });
  465. } else if (!strcmp(type, "p")) {
  466. unsigned char typetag = 0x8f;
  467. unsigned char subtypetag = 0x00;
  468. stio.queue_p0(&typetag, 1);
  469. stio.queue_p0(&subtypetag, 1);
  470. stio.queue_p0(&num, 4);
  471. stio.queue_p1(&typetag, 1);
  472. stio.queue_p1(&subtypetag, 1);
  473. stio.queue_p1(&num, 4);
  474. stio.cpu_nthreads(num);
  475. }
  476. free(arg);
  477. ++threadargs;
  478. }
  479. // That's all
  480. unsigned char typetag = 0x00;
  481. stio.queue_p0(&typetag, 1);
  482. stio.queue_p1(&typetag, 1);
  483. run_coroutines(stio, coroutines);
  484. ofiles.closeall();
  485. });
  486. }
  487. pool.join();
  488. }