preproc.cpp 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <fstream>
#include <iostream>
#include <map>
#include <string>
#include <vector>

#include "types.hpp"
#include "coroutine.hpp"
#include "preproc.hpp"
#include "rdpf.hpp"
#include "cdpf.hpp"
  7. // Keep track of open files that coroutines might be writing into
  8. class Openfiles {
  9. bool append_mode;
  10. std::vector<std::ofstream> files;
  11. public:
  12. Openfiles(bool append_mode = false) : append_mode(append_mode) {}
  13. class Handle {
  14. Openfiles &parent;
  15. size_t idx;
  16. public:
  17. Handle(Openfiles &parent, size_t idx) :
  18. parent(parent), idx(idx) {}
  19. // Retrieve the ofstream from this Handle
  20. std::ofstream &os() const { return parent.files[idx]; }
  21. };
  22. Handle open(const char *prefix, unsigned player,
  23. unsigned thread_num, nbits_t depth = 0);
  24. void closeall();
  25. };
  26. // Open a file for writing with name the given prefix, and ".pX.tY"
  27. // suffix, where X is the (one-digit) player number and Y is the thread
  28. // number. If depth D is given, use "D.pX.tY" as the suffix.
  29. Openfiles::Handle Openfiles::open(const char *prefix, unsigned player,
  30. unsigned thread_num, nbits_t depth)
  31. {
  32. std::string filename(prefix);
  33. char suffix[20];
  34. if (depth > 0) {
  35. sprintf(suffix, "%02d.p%d.t%u", depth, player%10, thread_num);
  36. } else {
  37. sprintf(suffix, ".p%d.t%u", player%10, thread_num);
  38. }
  39. filename.append(suffix);
  40. std::ofstream &f = files.emplace_back(filename,
  41. append_mode ? std::ios_base::app : std::ios_base::out);
  42. if (f.fail()) {
  43. std::cerr << "Failed to open " << filename << "\n";
  44. exit(1);
  45. }
  46. return Handle(*this, files.size()-1);
  47. }
  48. // Close all the open files
  49. void Openfiles::closeall()
  50. {
  51. for (auto& f: files) {
  52. f.close();
  53. }
  54. files.clear();
  55. }
  56. // The server-to-computational-peer protocol for sending precomputed
  57. // data is:
  58. //
  59. // One byte: type
  60. // 0x80: Multiplication triple
  61. // 0x81: Multiplication half-triple
  62. // 0x01 to 0x30: RAM DPF of that depth
  63. // 0x40: Comparison DPF
  64. // 0x82: Counter (for testing)
  65. // 0x83: Set number of CPU threads for this communication thread
  66. // 0x00: End of preprocessing
  67. //
  68. // Four bytes: number of objects of that type (not sent for type == 0x00)
  69. //
  70. // Then that number of objects
  71. //
  72. // Repeat the whole thing until type == 0x00 is received
  73. void preprocessing_comp(MPCIO &mpcio, const PRACOptions &opts, char **args)
  74. {
  75. int num_threads = opts.num_threads;
  76. boost::asio::thread_pool pool(num_threads);
  77. for (int thread_num = 0; thread_num < num_threads; ++thread_num) {
  78. boost::asio::post(pool, [&mpcio, &opts, thread_num] {
  79. MPCTIO tio(mpcio, thread_num);
  80. Openfiles ofiles(opts.append_to_files);
  81. std::vector<coro_t> coroutines;
  82. while(1) {
  83. unsigned char type = 0;
  84. unsigned int num = 0;
  85. size_t res = tio.recv_server(&type, 1);
  86. if (res < 1 || type == 0) break;
  87. tio.recv_server(&num, 4);
  88. if (type == 0x80) {
  89. // Multiplication triples
  90. auto tripfile = ofiles.open("triples",
  91. mpcio.player, thread_num);
  92. for (unsigned int i=0; i<num; ++i) {
  93. coroutines.emplace_back(
  94. [&tio, tripfile](yield_t &yield) {
  95. yield();
  96. MultTriple T = tio.triple(yield);
  97. tripfile.os() << T;
  98. });
  99. }
  100. } else if (type == 0x81) {
  101. // Multiplication half triples
  102. auto halffile = ofiles.open("halves",
  103. mpcio.player, thread_num);
  104. for (unsigned int i=0; i<num; ++i) {
  105. coroutines.emplace_back(
  106. [&tio, halffile](yield_t &yield) {
  107. yield();
  108. HalfTriple H = tio.halftriple(yield);
  109. halffile.os() << H;
  110. });
  111. }
  112. } else if (type >= 0x01 && type <= 0x30) {
  113. // RAM DPFs
  114. auto tripfile = ofiles.open("rdpf",
  115. mpcio.player, thread_num, type);
  116. for (unsigned int i=0; i<num; ++i) {
  117. coroutines.emplace_back(
  118. [&tio, &opts, tripfile, type](yield_t &yield) {
  119. yield();
  120. RDPFTriple rdpftrip =
  121. tio.rdpftriple(yield, type, opts.expand_rdpfs);
  122. printf("dep = %d\n", type);
  123. printf("usi0 = %016lx\n", rdpftrip.dpf[0].unit_sum_inverse);
  124. printf("sxr0 = %016lx\n", rdpftrip.dpf[0].scaled_xor.xshare);
  125. printf("usi1 = %016lx\n", rdpftrip.dpf[1].unit_sum_inverse);
  126. printf("sxr1 = %016lx\n", rdpftrip.dpf[1].scaled_xor.xshare);
  127. printf("usi2 = %016lx\n", rdpftrip.dpf[2].unit_sum_inverse);
  128. printf("sxr2 = %016lx\n", rdpftrip.dpf[2].scaled_xor.xshare);
  129. tripfile.os() << rdpftrip;
  130. });
  131. }
  132. } else if (type == 0x40) {
  133. // Comparison DPFs
  134. auto cdpffile = ofiles.open("cdpf",
  135. mpcio.player, thread_num);
  136. for (unsigned int i=0; i<num; ++i) {
  137. coroutines.emplace_back(
  138. [&tio, cdpffile](yield_t &yield) {
  139. yield();
  140. CDPF C = tio.cdpf(yield);
  141. cdpffile.os() << C;
  142. });
  143. }
  144. } else if (type == 0x82) {
  145. coroutines.emplace_back(
  146. [&tio, num](yield_t &yield) {
  147. yield();
  148. unsigned int istart = 0x31415080;
  149. for (unsigned int i=istart; i<istart+num; ++i) {
  150. tio.queue_peer(&i, sizeof(i));
  151. tio.queue_server(&i, sizeof(i));
  152. yield();
  153. unsigned int peeri, srvi;
  154. tio.recv_peer(&peeri, sizeof(peeri));
  155. tio.recv_server(&srvi, sizeof(srvi));
  156. if (peeri != i || srvi != i) {
  157. printf("Incorrect counter received: "
  158. "peer=%08x srv=%08x\n", peeri,
  159. srvi);
  160. }
  161. }
  162. });
  163. } else if (type == 0x83) {
  164. tio.cpu_nthreads(num);
  165. }
  166. }
  167. run_coroutines(tio, coroutines);
  168. ofiles.closeall();
  169. });
  170. }
  171. pool.join();
  172. }
  173. void preprocessing_server(MPCServerIO &mpcsrvio, const PRACOptions &opts, char **args)
  174. {
  175. int num_threads = opts.num_threads;
  176. boost::asio::thread_pool pool(num_threads);
  177. for (int thread_num = 0; thread_num < num_threads; ++thread_num) {
  178. boost::asio::post(pool, [&mpcsrvio, &opts, thread_num, args] {
  179. char **threadargs = args;
  180. MPCTIO stio(mpcsrvio, thread_num);
  181. Openfiles ofiles(opts.append_to_files);
  182. std::vector<coro_t> coroutines;
  183. if (*threadargs && threadargs[0][0] == 'T') {
  184. // Per-thread initialization. The args look like:
  185. // T0 t:50 h:10 T1 t:20 h:30 T2 h:20
  186. // Skip to the arg marking our thread
  187. char us[20];
  188. sprintf(us, "T%u", thread_num);
  189. while (*threadargs && strcmp(*threadargs, us)) {
  190. ++threadargs;
  191. }
  192. // Now skip to the next arg if there is one
  193. if (*threadargs) {
  194. ++threadargs;
  195. }
  196. }
  197. // Stop scanning for args when we get to the end or when we
  198. // get to another per-thread initialization marker
  199. while (*threadargs && threadargs[0][0] != 'T') {
  200. char *arg = strdup(*threadargs);
  201. char *colon = strchr(arg, ':');
  202. if (!colon) {
  203. std::cerr << "Args must be type:num\n";
  204. ++threadargs;
  205. free(arg);
  206. continue;
  207. }
  208. unsigned num = atoi(colon+1);
  209. *colon = '\0';
  210. char *type = arg;
  211. if (!strcmp(type, "t")) {
  212. unsigned char typetag = 0x80;
  213. stio.queue_p0(&typetag, 1);
  214. stio.queue_p0(&num, 4);
  215. stio.queue_p1(&typetag, 1);
  216. stio.queue_p1(&num, 4);
  217. for (unsigned int i=0; i<num; ++i) {
  218. coroutines.emplace_back(
  219. [&stio](yield_t &yield) {
  220. yield();
  221. stio.triple(yield);
  222. });
  223. }
  224. } else if (!strcmp(type, "h")) {
  225. unsigned char typetag = 0x81;
  226. stio.queue_p0(&typetag, 1);
  227. stio.queue_p0(&num, 4);
  228. stio.queue_p1(&typetag, 1);
  229. stio.queue_p1(&num, 4);
  230. for (unsigned int i=0; i<num; ++i) {
  231. coroutines.emplace_back(
  232. [&stio](yield_t &yield) {
  233. yield();
  234. stio.halftriple(yield);
  235. });
  236. }
  237. } else if (type[0] == 'r') {
  238. int depth = atoi(type+1);
  239. if (depth < 1 || depth > 48) {
  240. std::cerr << "Invalid DPF depth\n";
  241. } else {
  242. unsigned char typetag = depth;
  243. stio.queue_p0(&typetag, 1);
  244. stio.queue_p0(&num, 4);
  245. stio.queue_p1(&typetag, 1);
  246. stio.queue_p1(&num, 4);
  247. auto pairfile = ofiles.open("rdpf",
  248. mpcsrvio.player, thread_num, depth);
  249. for (unsigned int i=0; i<num; ++i) {
  250. coroutines.emplace_back(
  251. [&stio, &opts, pairfile, depth](yield_t &yield) {
  252. yield();
  253. RDPFPair rdpfpair = stio.rdpfpair(yield, depth);
  254. printf("usi0 = %016lx\n", rdpfpair.dpf[0].unit_sum_inverse);
  255. printf("sxr0 = %016lx\n", rdpfpair.dpf[0].scaled_xor.xshare);
  256. printf("dep0 = %d\n", rdpfpair.dpf[0].depth());
  257. printf("usi1 = %016lx\n", rdpfpair.dpf[1].unit_sum_inverse);
  258. printf("sxr1 = %016lx\n", rdpfpair.dpf[1].scaled_xor.xshare);
  259. printf("dep1 = %d\n", rdpfpair.dpf[1].depth());
  260. if (opts.expand_rdpfs) {
  261. rdpfpair.dpf[0].expand(stio.aes_ops());
  262. rdpfpair.dpf[1].expand(stio.aes_ops());
  263. }
  264. pairfile.os() << rdpfpair;
  265. });
  266. }
  267. }
  268. } else if (!strcmp(type, "c")) {
  269. unsigned char typetag = 0x40;
  270. stio.queue_p0(&typetag, 1);
  271. stio.queue_p0(&num, 4);
  272. stio.queue_p1(&typetag, 1);
  273. stio.queue_p1(&num, 4);
  274. for (unsigned int i=0; i<num; ++i) {
  275. coroutines.emplace_back(
  276. [&stio](yield_t &yield) {
  277. yield();
  278. stio.cdpf(yield);
  279. });
  280. }
  281. } else if (!strcmp(type, "i")) {
  282. unsigned char typetag = 0x82;
  283. stio.queue_p0(&typetag, 1);
  284. stio.queue_p0(&num, 4);
  285. stio.queue_p1(&typetag, 1);
  286. stio.queue_p1(&num, 4);
  287. coroutines.emplace_back(
  288. [&stio, num] (yield_t &yield) {
  289. unsigned int istart = 0x31415080;
  290. yield();
  291. for (unsigned int i=istart; i<istart+num; ++i) {
  292. stio.queue_p0(&i, sizeof(i));
  293. stio.queue_p1(&i, sizeof(i));
  294. yield();
  295. unsigned int p0i, p1i;
  296. stio.recv_p0(&p0i, sizeof(p0i));
  297. stio.recv_p1(&p1i, sizeof(p1i));
  298. if (p0i != i || p1i != i) {
  299. printf("Incorrect counter received: "
  300. "p0=%08x p1=%08x\n", p0i,
  301. p1i);
  302. }
  303. }
  304. });
  305. } else if (!strcmp(type, "p")) {
  306. unsigned char typetag = 0x83;
  307. stio.queue_p0(&typetag, 1);
  308. stio.queue_p0(&num, 4);
  309. stio.queue_p1(&typetag, 1);
  310. stio.queue_p1(&num, 4);
  311. stio.cpu_nthreads(num);
  312. }
  313. free(arg);
  314. ++threadargs;
  315. }
  316. // That's all
  317. unsigned char typetag = 0x00;
  318. stio.queue_p0(&typetag, 1);
  319. stio.queue_p1(&typetag, 1);
  320. run_coroutines(stio, coroutines);
  321. ofiles.closeall();
  322. });
  323. }
  324. pool.join();
  325. }