preproc.cpp 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <fstream>
#include <iostream>
#include <string>
#include <vector>
#include "types.hpp"
#include "coroutine.hpp"
#include "preproc.hpp"
#include "rdpf.hpp"
#include "cdpf.hpp"
  7. // Keep track of open files that coroutines might be writing into
  8. class Openfiles {
  9. std::vector<std::ofstream> files;
  10. public:
  11. class Handle {
  12. Openfiles &parent;
  13. size_t idx;
  14. public:
  15. Handle(Openfiles &parent, size_t idx) :
  16. parent(parent), idx(idx) {}
  17. // Retrieve the ofstream from this Handle
  18. std::ofstream &os() const { return parent.files[idx]; }
  19. };
  20. Handle open(const char *prefix, unsigned player,
  21. unsigned thread_num, nbits_t depth = 0);
  22. void closeall();
  23. };
  24. // Open a file for writing with name the given prefix, and ".pX.tY"
  25. // suffix, where X is the (one-digit) player number and Y is the thread
  26. // number. If depth D is given, use "D.pX.tY" as the suffix.
  27. Openfiles::Handle Openfiles::open(const char *prefix, unsigned player,
  28. unsigned thread_num, nbits_t depth)
  29. {
  30. std::string filename(prefix);
  31. char suffix[20];
  32. if (depth > 0) {
  33. sprintf(suffix, "%02d.p%d.t%u", depth, player%10, thread_num);
  34. } else {
  35. sprintf(suffix, ".p%d.t%u", player%10, thread_num);
  36. }
  37. filename.append(suffix);
  38. std::ofstream &f = files.emplace_back(filename);
  39. if (f.fail()) {
  40. std::cerr << "Failed to open " << filename << "\n";
  41. exit(1);
  42. }
  43. return Handle(*this, files.size()-1);
  44. }
  45. // Close all the open files
  46. void Openfiles::closeall()
  47. {
  48. for (auto& f: files) {
  49. f.close();
  50. }
  51. files.clear();
  52. }
  53. // The server-to-computational-peer protocol for sending precomputed
  54. // data is:
  55. //
  56. // One byte: type
  57. // 0x80: Multiplication triple
  58. // 0x81: Multiplication half-triple
  59. // 0x01 to 0x30: RAM DPF of that depth
  60. // 0x40: Comparison DPF
  61. // 0x00: End of preprocessing
  62. //
  63. // Four bytes: number of objects of that type (not sent for type == 0x00)
  64. //
  65. // Then that number of objects
  66. //
  67. // Repeat the whole thing until type == 0x00 is received
  68. void preprocessing_comp(MPCIO &mpcio, const PRACOptions &opts, char **args)
  69. {
  70. int num_threads = opts.num_threads;
  71. boost::asio::thread_pool pool(num_threads);
  72. for (int thread_num = 0; thread_num < num_threads; ++thread_num) {
  73. boost::asio::post(pool, [&mpcio, &opts, thread_num] {
  74. MPCTIO tio(mpcio, thread_num);
  75. Openfiles ofiles;
  76. std::vector<coro_t> coroutines;
  77. while(1) {
  78. unsigned char type = 0;
  79. unsigned int num = 0;
  80. size_t res = tio.recv_server(&type, 1);
  81. if (res < 1 || type == 0) break;
  82. tio.recv_server(&num, 4);
  83. if (type == 0x80) {
  84. // Multiplication triples
  85. auto tripfile = ofiles.open("triples",
  86. mpcio.player, thread_num);
  87. MultTriple T;
  88. for (unsigned int i=0; i<num; ++i) {
  89. T = tio.triple();
  90. tripfile.os() << T;
  91. }
  92. } else if (type == 0x81) {
  93. // Multiplication half triples
  94. auto halffile = ofiles.open("halves",
  95. mpcio.player, thread_num);
  96. HalfTriple H;
  97. for (unsigned int i=0; i<num; ++i) {
  98. H = tio.halftriple();
  99. halffile.os() << H;
  100. }
  101. } else if (type >= 0x01 && type <= 0x30) {
  102. // RAM DPFs
  103. auto tripfile = ofiles.open("rdpf",
  104. mpcio.player, thread_num, type);
  105. for (unsigned int i=0; i<num; ++i) {
  106. coroutines.emplace_back(
  107. [&, tripfile, type](yield_t &yield) {
  108. RDPFTriple rdpftrip =
  109. tio.rdpftriple(yield, type, opts.expand_rdpfs);
  110. printf("dep = %d\n", type);
  111. printf("usi0 = %016lx\n", rdpftrip.dpf[0].unit_sum_inverse);
  112. printf("sxr0 = %016lx\n", rdpftrip.dpf[0].scaled_xor.xshare);
  113. printf("usi1 = %016lx\n", rdpftrip.dpf[1].unit_sum_inverse);
  114. printf("sxr1 = %016lx\n", rdpftrip.dpf[1].scaled_xor.xshare);
  115. printf("usi2 = %016lx\n", rdpftrip.dpf[2].unit_sum_inverse);
  116. printf("sxr2 = %016lx\n", rdpftrip.dpf[2].scaled_xor.xshare);
  117. tripfile.os() << rdpftrip;
  118. });
  119. }
  120. } else if (type == 0x40) {
  121. // Comparison DPFs
  122. auto cdpffile = ofiles.open("cdpf",
  123. mpcio.player, thread_num);
  124. CDPF C;
  125. for (unsigned int i=0; i<num; ++i) {
  126. C = tio.cdpf();
  127. cdpffile.os() << C;
  128. }
  129. }
  130. }
  131. run_coroutines(tio, coroutines);
  132. ofiles.closeall();
  133. });
  134. }
  135. pool.join();
  136. }
  137. void preprocessing_server(MPCServerIO &mpcsrvio, const PRACOptions &opts, char **args)
  138. {
  139. int num_threads = opts.num_threads;
  140. boost::asio::thread_pool pool(num_threads);
  141. for (int thread_num = 0; thread_num < num_threads; ++thread_num) {
  142. boost::asio::post(pool, [&mpcsrvio, &opts, thread_num, args] {
  143. char **threadargs = args;
  144. MPCTIO stio(mpcsrvio, thread_num);
  145. Openfiles ofiles;
  146. std::vector<coro_t> coroutines;
  147. if (*threadargs && threadargs[0][0] == 'T') {
  148. // Per-thread initialization. The args look like:
  149. // T0 t:50 h:10 T1 t:20 h:30 T2 h:20
  150. // Skip to the arg marking our thread
  151. char us[20];
  152. sprintf(us, "T%u", thread_num);
  153. while (*threadargs && strcmp(*threadargs, us)) {
  154. ++threadargs;
  155. }
  156. // Now skip to the next arg if there is one
  157. if (*threadargs) {
  158. ++threadargs;
  159. }
  160. }
  161. // Stop scanning for args when we get to the end or when we
  162. // get to another per-thread initialization marker
  163. while (*threadargs && threadargs[0][0] != 'T') {
  164. char *arg = strdup(*threadargs);
  165. char *colon = strchr(arg, ':');
  166. if (!colon) {
  167. std::cerr << "Args must be type:num\n";
  168. ++threadargs;
  169. free(arg);
  170. continue;
  171. }
  172. unsigned num = atoi(colon+1);
  173. *colon = '\0';
  174. char *type = arg;
  175. if (!strcmp(type, "t")) {
  176. unsigned char typetag = 0x80;
  177. stio.queue_p0(&typetag, 1);
  178. stio.queue_p0(&num, 4);
  179. stio.queue_p1(&typetag, 1);
  180. stio.queue_p1(&num, 4);
  181. for (unsigned int i=0; i<num; ++i) {
  182. stio.triple();
  183. }
  184. } else if (!strcmp(type, "h")) {
  185. unsigned char typetag = 0x81;
  186. stio.queue_p0(&typetag, 1);
  187. stio.queue_p0(&num, 4);
  188. stio.queue_p1(&typetag, 1);
  189. stio.queue_p1(&num, 4);
  190. for (unsigned int i=0; i<num; ++i) {
  191. stio.halftriple();
  192. }
  193. } else if (type[0] == 'r') {
  194. int depth = atoi(type+1);
  195. if (depth < 1 || depth > 48) {
  196. std::cerr << "Invalid DPF depth\n";
  197. } else {
  198. unsigned char typetag = depth;
  199. stio.queue_p0(&typetag, 1);
  200. stio.queue_p0(&num, 4);
  201. stio.queue_p1(&typetag, 1);
  202. stio.queue_p1(&num, 4);
  203. auto pairfile = ofiles.open("rdpf",
  204. mpcsrvio.player, thread_num, depth);
  205. for (unsigned int i=0; i<num; ++i) {
  206. coroutines.emplace_back(
  207. [&, pairfile, depth](yield_t &yield) {
  208. RDPFPair rdpfpair = stio.rdpfpair(yield, depth);
  209. printf("usi0 = %016lx\n", rdpfpair.dpf[0].unit_sum_inverse);
  210. printf("sxr0 = %016lx\n", rdpfpair.dpf[0].scaled_xor.xshare);
  211. printf("dep0 = %d\n", rdpfpair.dpf[0].depth());
  212. printf("usi1 = %016lx\n", rdpfpair.dpf[1].unit_sum_inverse);
  213. printf("sxr1 = %016lx\n", rdpfpair.dpf[1].scaled_xor.xshare);
  214. printf("dep1 = %d\n", rdpfpair.dpf[1].depth());
  215. if (opts.expand_rdpfs) {
  216. rdpfpair.dpf[0].expand(stio.aes_ops());
  217. rdpfpair.dpf[1].expand(stio.aes_ops());
  218. }
  219. pairfile.os() << rdpfpair;
  220. });
  221. }
  222. }
  223. } else if (type[0] == 'c') {
  224. unsigned char typetag = 0x40;
  225. stio.queue_p0(&typetag, 1);
  226. stio.queue_p0(&num, 4);
  227. stio.queue_p1(&typetag, 1);
  228. stio.queue_p1(&num, 4);
  229. for (unsigned int i=0; i<num; ++i) {
  230. stio.cdpf();
  231. }
  232. }
  233. free(arg);
  234. ++threadargs;
  235. }
  236. // That's all
  237. unsigned char typetag = 0x00;
  238. stio.queue_p0(&typetag, 1);
  239. stio.queue_p1(&typetag, 1);
  240. run_coroutines(stio, coroutines);
  241. ofiles.closeall();
  242. });
  243. }
  244. pool.join();
  245. }