preproc.cpp

#include <vector>
#include <fstream>
#include <iostream>
#include <cstdio>
#include <cstring>
#include <cstdlib>
#include "types.hpp"
#include "coroutine.hpp"
#include "preproc.hpp"
#include "rdpf.hpp"
#include "cdpf.hpp"

// Keep track of open files that coroutines might be writing into
class Openfiles {
    bool append_mode;
    std::vector<std::ofstream> files;

public:
    Openfiles(bool append_mode = false) : append_mode(append_mode) {}

    class Handle {
        Openfiles &parent;
        size_t idx;
    public:
        Handle(Openfiles &parent, size_t idx) :
            parent(parent), idx(idx) {}

        // Retrieve the ofstream from this Handle
        std::ofstream &os() const { return parent.files[idx]; }
    };

    Handle open(const char *prefix, unsigned player,
        unsigned thread_num, nbits_t depth = 0);

    void closeall();
};
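
// A minimal usage sketch of Openfiles, mirroring how it is used in
// preprocessing_comp() below (the variable names are illustrative):
//
//     Openfiles ofiles(opts.append_to_files);
//     auto tripfile = ofiles.open("triples", mpcio.player, thread_num);
//     tripfile.os() << T;   // a coroutine writes one precomputed object
//     ofiles.closeall();    // close everything once the coroutines finish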

// Open a file for writing whose name is the given prefix followed by a
// ".pX.tY" suffix, where X is the (one-digit) player number and Y is the
// thread number. If a nonzero depth D is given, use "D.pX.tY" as the
// suffix instead.
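// For example, under this naming scheme open("rdpf", 0, 3, 5) would
// create "rdpf05.p0.t3" (the depth is zero-padded to two digits), while
// open("triples", 1, 0) would create "triples.p1.t0".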
Openfiles::Handle Openfiles::open(const char *prefix, unsigned player,
    unsigned thread_num, nbits_t depth)
{
    std::string filename(prefix);
    char suffix[20];
    if (depth > 0) {
        sprintf(suffix, "%02d.p%d.t%u", depth, player%10, thread_num);
    } else {
        sprintf(suffix, ".p%d.t%u", player%10, thread_num);
    }
    filename.append(suffix);
    std::ofstream &f = files.emplace_back(filename,
        append_mode ? std::ios_base::app : std::ios_base::out);
    if (f.fail()) {
        std::cerr << "Failed to open " << filename << "\n";
        exit(1);
    }
    return Handle(*this, files.size()-1);
}

// Close all the open files
void Openfiles::closeall()
{
    for (auto& f: files) {
        f.close();
    }
    files.clear();
}

// The server-to-computational-peer protocol for sending precomputed
// data is:
//
// One byte: type
//   0x80: Multiplication triple
//   0x81: Multiplication half-triple
//   0x01 to 0x30: RAM DPF of that depth
//   0x40: Comparison DPF
//   0x82: Counter (for testing)
//   0x00: End of preprocessing
//
// Four bytes: number of objects of that type (not sent for type == 0x00)
//
// Then that number of objects
//
// Repeat the whole thing until type == 0x00 is received
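//
// For example (assuming a little-endian host, since the four-byte counts
// are sent in raw machine byte order), a stream carrying 10 triples and
// 5 comparison DPFs before terminating would be:
//
//   0x80  0x0a 0x00 0x00 0x00  <10 multiplication triples>
//   0x40  0x05 0x00 0x00 0x00  <5 comparison DPFs>
//   0x00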
void preprocessing_comp(MPCIO &mpcio, const PRACOptions &opts, char **args)
{
    int num_threads = opts.num_threads;
    boost::asio::thread_pool pool(num_threads);
    for (int thread_num = 0; thread_num < num_threads; ++thread_num) {
        boost::asio::post(pool, [&mpcio, &opts, thread_num] {
            MPCTIO tio(mpcio, thread_num);
            Openfiles ofiles(opts.append_to_files);
            std::vector<coro_t> coroutines;
            while(1) {
                unsigned char type = 0;
                unsigned int num = 0;
                size_t res = tio.recv_server(&type, 1);
                if (res < 1 || type == 0) break;
                tio.recv_server(&num, 4);
                if (type == 0x80) {
                    // Multiplication triples
                    auto tripfile = ofiles.open("triples",
                        mpcio.player, thread_num);
                    for (unsigned int i=0; i<num; ++i) {
                        coroutines.emplace_back(
                            [&tio, tripfile](yield_t &yield) {
                                yield();
                                MultTriple T = tio.triple(yield);
                                tripfile.os() << T;
                            });
                    }
                } else if (type == 0x81) {
                    // Multiplication half triples
                    auto halffile = ofiles.open("halves",
                        mpcio.player, thread_num);
                    for (unsigned int i=0; i<num; ++i) {
                        coroutines.emplace_back(
                            [&tio, halffile](yield_t &yield) {
                                yield();
                                HalfTriple H = tio.halftriple(yield);
                                halffile.os() << H;
                            });
                    }
                } else if (type >= 0x01 && type <= 0x30) {
                    // RAM DPFs
                    auto tripfile = ofiles.open("rdpf",
                        mpcio.player, thread_num, type);
                    for (unsigned int i=0; i<num; ++i) {
                        coroutines.emplace_back(
                            [&tio, &opts, tripfile, type](yield_t &yield) {
                                yield();
                                RDPFTriple rdpftrip =
                                    tio.rdpftriple(yield, type, opts.expand_rdpfs);
                                printf("dep = %d\n", type);
                                printf("usi0 = %016lx\n", rdpftrip.dpf[0].unit_sum_inverse);
                                printf("sxr0 = %016lx\n", rdpftrip.dpf[0].scaled_xor.xshare);
                                printf("usi1 = %016lx\n", rdpftrip.dpf[1].unit_sum_inverse);
                                printf("sxr1 = %016lx\n", rdpftrip.dpf[1].scaled_xor.xshare);
                                printf("usi2 = %016lx\n", rdpftrip.dpf[2].unit_sum_inverse);
                                printf("sxr2 = %016lx\n", rdpftrip.dpf[2].scaled_xor.xshare);
                                tripfile.os() << rdpftrip;
                            });
                    }
                } else if (type == 0x40) {
                    // Comparison DPFs
                    auto cdpffile = ofiles.open("cdpf",
                        mpcio.player, thread_num);
                    for (unsigned int i=0; i<num; ++i) {
                        coroutines.emplace_back(
                            [&tio, cdpffile](yield_t &yield) {
                                yield();
                                CDPF C = tio.cdpf(yield);
                                cdpffile.os() << C;
                            });
                    }
                } else if (type == 0x82) {
                    // Counters (for testing)
                    coroutines.emplace_back(
                        [&tio, num](yield_t &yield) {
                            yield();
                            unsigned int istart = 0x31415080;
                            for (unsigned int i=istart; i<istart+num; ++i) {
                                tio.queue_peer(&i, sizeof(i));
                                tio.queue_server(&i, sizeof(i));
                                yield();
                                unsigned int peeri, srvi;
                                tio.recv_peer(&peeri, sizeof(peeri));
                                tio.recv_server(&srvi, sizeof(srvi));
                                if (peeri != i || srvi != i) {
                                    printf("Incorrect counter received: "
                                        "peer=%08x srv=%08x\n", peeri,
                                        srvi);
                                }
                            }
                        });
                }
            }
            run_coroutines(tio, coroutines);
            ofiles.closeall();
        });
    }
    pool.join();
}
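
// The args to preprocessing_server tell the server which precomputed
// values to create, as a list of type:num tokens, optionally preceded by
// per-thread "Tn" markers (see the parsing below).  The types handled by
// the code that follows are:
//
//   t:num    multiplication triples             (type byte 0x80)
//   h:num    multiplication half-triples        (type byte 0x81)
//   rD:num   RAM DPFs of depth D, 1 <= D <= 48  (type byte D)
//   c:num    comparison DPFs                    (type byte 0x40)
//   i:num    counters, for testing              (type byte 0x82)
//
// For example (illustrative): T0 t:50 h:10 T1 r20:5 c:5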
void preprocessing_server(MPCServerIO &mpcsrvio, const PRACOptions &opts, char **args)
{
    int num_threads = opts.num_threads;
    boost::asio::thread_pool pool(num_threads);
    for (int thread_num = 0; thread_num < num_threads; ++thread_num) {
        boost::asio::post(pool, [&mpcsrvio, &opts, thread_num, args] {
            char **threadargs = args;
            MPCTIO stio(mpcsrvio, thread_num);
            Openfiles ofiles;
            std::vector<coro_t> coroutines;
            if (*threadargs && threadargs[0][0] == 'T') {
                // Per-thread initialization. The args look like:
                // T0 t:50 h:10 T1 t:20 h:30 T2 h:20

                // Skip to the arg marking our thread
                char us[20];
                sprintf(us, "T%u", thread_num);
                while (*threadargs && strcmp(*threadargs, us)) {
                    ++threadargs;
                }
                // Now skip to the next arg if there is one
                if (*threadargs) {
                    ++threadargs;
                }
            }
            // Stop scanning for args when we get to the end or when we
            // get to another per-thread initialization marker
            while (*threadargs && threadargs[0][0] != 'T') {
                char *arg = strdup(*threadargs);
                char *colon = strchr(arg, ':');
                if (!colon) {
                    std::cerr << "Args must be type:num\n";
                    ++threadargs;
                    free(arg);
                    continue;
                }
                unsigned num = atoi(colon+1);
                *colon = '\0';
                char *type = arg;
                if (!strcmp(type, "t")) {
                    // Multiplication triples
                    unsigned char typetag = 0x80;
                    stio.queue_p0(&typetag, 1);
                    stio.queue_p0(&num, 4);
                    stio.queue_p1(&typetag, 1);
                    stio.queue_p1(&num, 4);
                    for (unsigned int i=0; i<num; ++i) {
                        coroutines.emplace_back(
                            [&stio](yield_t &yield) {
                                yield();
                                stio.triple(yield);
                            });
                    }
                } else if (!strcmp(type, "h")) {
                    // Multiplication half triples
                    unsigned char typetag = 0x81;
                    stio.queue_p0(&typetag, 1);
                    stio.queue_p0(&num, 4);
                    stio.queue_p1(&typetag, 1);
                    stio.queue_p1(&num, 4);
                    for (unsigned int i=0; i<num; ++i) {
                        coroutines.emplace_back(
                            [&stio](yield_t &yield) {
                                yield();
                                stio.halftriple(yield);
                            });
                    }
                } else if (type[0] == 'r') {
                    // RAM DPFs of the given depth
                    int depth = atoi(type+1);
                    if (depth < 1 || depth > 48) {
                        std::cerr << "Invalid DPF depth\n";
                    } else {
                        unsigned char typetag = depth;
                        stio.queue_p0(&typetag, 1);
                        stio.queue_p0(&num, 4);
                        stio.queue_p1(&typetag, 1);
                        stio.queue_p1(&num, 4);
                        auto pairfile = ofiles.open("rdpf",
                            mpcsrvio.player, thread_num, depth);
                        for (unsigned int i=0; i<num; ++i) {
                            coroutines.emplace_back(
                                [&stio, &opts, pairfile, depth](yield_t &yield) {
                                    yield();
                                    RDPFPair rdpfpair = stio.rdpfpair(yield, depth);
                                    printf("usi0 = %016lx\n", rdpfpair.dpf[0].unit_sum_inverse);
                                    printf("sxr0 = %016lx\n", rdpfpair.dpf[0].scaled_xor.xshare);
                                    printf("dep0 = %d\n", rdpfpair.dpf[0].depth());
                                    printf("usi1 = %016lx\n", rdpfpair.dpf[1].unit_sum_inverse);
                                    printf("sxr1 = %016lx\n", rdpfpair.dpf[1].scaled_xor.xshare);
                                    printf("dep1 = %d\n", rdpfpair.dpf[1].depth());
                                    if (opts.expand_rdpfs) {
                                        rdpfpair.dpf[0].expand(stio.aes_ops());
                                        rdpfpair.dpf[1].expand(stio.aes_ops());
                                    }
                                    pairfile.os() << rdpfpair;
                                });
                        }
                    }
                } else if (!strcmp(type, "c")) {
                    // Comparison DPFs
                    unsigned char typetag = 0x40;
                    stio.queue_p0(&typetag, 1);
                    stio.queue_p0(&num, 4);
                    stio.queue_p1(&typetag, 1);
                    stio.queue_p1(&num, 4);
                    for (unsigned int i=0; i<num; ++i) {
                        coroutines.emplace_back(
                            [&stio](yield_t &yield) {
                                yield();
                                stio.cdpf(yield);
                            });
                    }
                } else if (!strcmp(type, "i")) {
                    // Counters (for testing)
                    unsigned char typetag = 0x82;
                    stio.queue_p0(&typetag, 1);
                    stio.queue_p0(&num, 4);
                    stio.queue_p1(&typetag, 1);
                    stio.queue_p1(&num, 4);
                    coroutines.emplace_back(
                        [&stio, num] (yield_t &yield) {
                            unsigned int istart = 0x31415080;
                            yield();
                            for (unsigned int i=istart; i<istart+num; ++i) {
                                stio.queue_p0(&i, sizeof(i));
                                stio.queue_p1(&i, sizeof(i));
                                yield();
                                unsigned int p0i, p1i;
                                stio.recv_p0(&p0i, sizeof(p0i));
                                stio.recv_p1(&p1i, sizeof(p1i));
                                if (p0i != i || p1i != i) {
                                    printf("Incorrect counter received: "
                                        "p0=%08x p1=%08x\n", p0i,
                                        p1i);
                                }
                            }
                        });
                }
                free(arg);
                ++threadargs;
            }
            // That's all; tell the computational peers that
            // preprocessing is done
            unsigned char typetag = 0x00;
            stio.queue_p0(&typetag, 1);
            stio.queue_p1(&typetag, 1);
            run_coroutines(stio, coroutines);
            ofiles.closeall();
        });
    }
    pool.join();
}