// preproc.cpp (5.3 KB)

  #include <cstdlib>
  #include <fstream>
  #include <iostream>
  #include <string>
  #include <vector>

  #include "types.hpp"
  #include "preproc.hpp"
  // Open a file for writing whose name is the given prefix followed by a
  // ".pX.tY" suffix, where X is the (one-digit) player number and Y is
  // the thread number.
  //
  //   prefix:     leading part of the filename (e.g. "triples")
  //   player:     player number; only its last decimal digit is used
  //   thread_num: thread number, appended in full
  //
  // Returns the open stream.  If the file cannot be opened, a message is
  // printed to stderr and the process exits with status 1.
  static std::ofstream openfile(const char *prefix, unsigned player,
      unsigned thread_num)
  {
      // Assemble the name with std::string instead of sprintf into a
      // fixed-size buffer: that avoids a "%d" conversion applied to an
      // unsigned argument, and there is no buffer to overrun.
      std::string filename(prefix);
      filename += ".p";
      filename += std::to_string(player % 10);
      filename += ".t";
      filename += std::to_string(thread_num);

      // Code in this file stores raw object bytes in these files with
      // write(), so open in binary mode to keep the bytes untranslated
      // on every platform.
      std::ofstream f(filename, std::ios::out | std::ios::binary);
      if (f.fail()) {
          std::cerr << "Failed to open " << filename << "\n";
          exit(1);
      }
      return f;
  }
  22. // The server-to-computational-peer protocol for sending precomputed
  23. // data is:
  24. //
  25. // One byte: type
  26. // 0x80: Multiplication triple
  27. // 0x81: Multiplication half-triple
  28. // 0x01 to 0x40: DPF of that depth
  29. // 0x00: End of preprocessing
  30. //
  31. // Four bytes: number of objects of that type (not sent for type == 0x00)
  32. //
  33. // Then that number of objects
  34. //
  35. // Repeat the whole thing until type == 0x00 is received
  36. void preprocessing_comp(MPCIO &mpcio, int num_threads, char **args)
  37. {
  38. boost::asio::thread_pool pool(num_threads);
  39. for (int thread_num = 0; thread_num < num_threads; ++thread_num) {
  40. boost::asio::post(pool, [&mpcio, thread_num] {
  41. MPCTIO tio(mpcio, thread_num);
  42. while(1) {
  43. unsigned char type = 0;
  44. unsigned int num = 0;
  45. size_t res = tio.recv_server(&type, 1);
  46. if (res < 1 || type == 0) break;
  47. tio.recv_server(&num, 4);
  48. if (type == 0x80) {
  49. // Multiplication triples
  50. std::ofstream tripfile = openfile("triples",
  51. mpcio.player, thread_num);
  52. MultTriple T;
  53. for (unsigned int i=0; i<num; ++i) {
  54. T = tio.triple();
  55. tripfile.write((const char *)&T, sizeof(T));
  56. }
  57. tripfile.close();
  58. } else if (type == 0x81) {
  59. // Multiplication half triples
  60. std::ofstream halffile = openfile("halves",
  61. mpcio.player, thread_num);
  62. HalfTriple H;
  63. for (unsigned int i=0; i<num; ++i) {
  64. res = tio.recv_server(&H, sizeof(H));
  65. if (res < sizeof(H)) break;
  66. halffile.write((const char *)&H, sizeof(H));
  67. }
  68. halffile.close();
  69. }
  70. }
  71. });
  72. }
  73. pool.join();
  74. std::cout << "Lamport clock = " << mpcio.lamport << "\n";
  75. }
  76. void preprocessing_server(MPCServerIO &mpcsrvio, int num_threads, char **args)
  77. {
  78. boost::asio::thread_pool pool(num_threads);
  79. for (int thread_num = 0; thread_num < num_threads; ++thread_num) {
  80. boost::asio::post(pool, [&mpcsrvio, thread_num, args] {
  81. char **threadargs = args;
  82. MPCTIO stio(mpcsrvio, thread_num);
  83. if (*threadargs && threadargs[0][0] == 'T') {
  84. // Per-thread initialization. The args look like:
  85. // T0 t:50 h:10 T1 t:20 h:30 T2 h:20
  86. // Skip to the arg marking our thread
  87. char us[20];
  88. sprintf(us, "T%u", thread_num);
  89. while (*threadargs && strcmp(*threadargs, us)) {
  90. ++threadargs;
  91. }
  92. // Now skip to the next arg if there is one
  93. if (*threadargs) {
  94. ++threadargs;
  95. }
  96. }
  97. // Stop scanning for args when we get to the end or when we
  98. // get to another per-thread initialization marker
  99. while (*threadargs && threadargs[0][0] != 'T') {
  100. char *arg = strdup(*threadargs);
  101. char *colon = strchr(arg, ':');
  102. if (!colon) {
  103. std::cerr << "Args must be type:num\n";
  104. ++threadargs;
  105. free(arg);
  106. continue;
  107. }
  108. unsigned num = atoi(colon+1);
  109. *colon = '\0';
  110. char *type = arg;
  111. if (!strcmp(type, "t")) {
  112. unsigned char typetag = 0x80;
  113. stio.queue_p0(&typetag, 1);
  114. stio.queue_p0(&num, 4);
  115. stio.queue_p1(&typetag, 1);
  116. stio.queue_p1(&num, 4);
  117. for (unsigned int i=0; i<num; ++i) {
  118. stio.triple();
  119. }
  120. } else if (!strcmp(type, "h")) {
  121. unsigned char typetag = 0x81;
  122. stio.queue_p0(&typetag, 1);
  123. stio.queue_p0(&num, 4);
  124. stio.queue_p1(&typetag, 1);
  125. stio.queue_p1(&num, 4);
  126. for (unsigned int i=0; i<num; ++i) {
  127. stio.halftriple();
  128. }
  129. }
  130. free(arg);
  131. ++threadargs;
  132. }
  133. // That's all
  134. unsigned char typetag = 0x00;
  135. stio.queue_p0(&typetag, 1);
  136. stio.queue_p1(&typetag, 1);
  137. stio.send();
  138. });
  139. }
  140. pool.join();
  141. std::cout << "Lamport clock = " << mpcsrvio.lamport << "\n";
  142. }