// dpfgen.h (6.5 KB)

  1. void compute_CW(tcp::socket& sout, tcp::socket& sin, __m128i L, __m128i R, uint8_t bit, __m128i & CW)
  2. {
  3. struct cw_construction
  4. {
  5. __m128i rand_b, gamma_b;
  6. uint8_t bit_b;
  7. };
  8. cw_construction computecw;
  9. read(sin, boost::asio::buffer(&computecw, sizeof(computecw)));
  10. __m128i rand_b = computecw.rand_b;
  11. __m128i gamma_b = computecw.gamma_b;
  12. uint8_t bit_b = computecw.bit_b;
  13. #ifdef DEBUG
  14. __m128i rand_b2, gamma_b2;
  15. uint8_t bit_b2;
  16. read(sin, boost::asio::buffer(&rand_b2, sizeof(rand_b)));
  17. read(sin, boost::asio::buffer(&gamma_b2, sizeof(gamma_b)));
  18. read(sin, boost::asio::buffer(&bit_b2, sizeof(bit_b)));
  19. assert(rand_b2[0] == rand_b[0]);
  20. assert(rand_b2[1] == rand_b[1]);
  21. assert(gamma_b2[0] == gamma_b[0]);
  22. assert(gamma_b2[1] == gamma_b[1]);
  23. assert(bit_b2 == bit_b);
  24. #endif
  25. uint8_t blinded_bit, blinded_bit_read;
  26. blinded_bit = bit ^ bit_b;
  27. __m128i blinded_L = L ^ R ^ rand_b;
  28. __m128i blinded_L_read;
  29. struct BlindsCW
  30. {
  31. __m128i blinded_message;
  32. uint8_t blinded_bit;
  33. };
  34. BlindsCW blinds_sent, blinds_recv;
  35. blinds_sent.blinded_bit = blinded_bit;
  36. blinds_sent.blinded_message = blinded_L;
  37. boost::asio::write(sout, boost::asio::buffer(&blinds_sent, sizeof(blinds_sent)));
  38. boost::asio::read(sout, boost::asio::buffer(&blinds_recv, sizeof(blinds_recv)));
  39. blinded_bit_read = blinds_recv.blinded_bit;
  40. blinded_L_read = blinds_recv.blinded_message;
  41. __m128i out_ = R ^ gamma_b;//_mm_setzero_si128;
  42. if(bit)
  43. {
  44. out_ ^= (L ^ R ^ blinded_L_read);
  45. }
  46. if(blinded_bit_read)
  47. {
  48. out_ ^= rand_b;
  49. }
  50. __m128i out_reconstruction;
  51. boost::asio::write(sout, boost::asio::buffer(&out_, sizeof(out_)));
  52. boost::asio::read(sout, boost::asio::buffer(&out_reconstruction, sizeof(out_reconstruction)));
  53. out_reconstruction = out_ ^ out_reconstruction;
  54. CW = out_reconstruction;
  55. #ifdef DEBUG
  56. uint8_t bit_reconstruction;
  57. boost::asio::write(sout, boost::asio::buffer(&bit, sizeof(bit)));
  58. boost::asio::read(sout, boost::asio::buffer(&bit_reconstruction, sizeof(bit_reconstruction)));
  59. bit_reconstruction = bit ^ bit_reconstruction;
  60. __m128i L_reconstruction;
  61. boost::asio::write(sout, boost::asio::buffer(&L, sizeof(L)));
  62. boost::asio::read(sout, boost::asio::buffer(&L_reconstruction, sizeof(L_reconstruction)));
  63. L_reconstruction = L ^ L_reconstruction;
  64. __m128i R_reconstruction;
  65. boost::asio::write(sout, boost::asio::buffer(&R, sizeof(R)));
  66. boost::asio::read(sout, boost::asio::buffer(&R_reconstruction, sizeof(R_reconstruction)));
  67. R_reconstruction = R ^ R_reconstruction;
  68. __m128i CW_debug;
  69. if(bit_reconstruction != 0)
  70. {
  71. CW_debug = L_reconstruction;
  72. }
  73. else
  74. {
  75. CW_debug = R_reconstruction;
  76. }
  77. assert(CW_debug[0] == CW[0]);
  78. assert(CW_debug[1] == CW[1]);
  79. #endif
  80. }
  81. template<typename node_t, typename prgkey_t>
  82. static inline void traverse(const prgkey_t & prgkey, const node_t & seed, node_t s[2])
  83. {
  84. dpf::PRG(prgkey, clear_lsb(seed, 0b11), s, 2);
  85. } // dpf::expand
  86. inline void evalfull_mpc(const size_t& nodes_per_leaf, const size_t& depth, const size_t& nbits, const size_t& nodes_in_interval,
  87. const AES_KEY& prgkey, uint8_t target_share[64], std::vector<socket_t>& socketsPb, std::vector<socket_t>& socketsP2,
  88. const size_t from, const size_t to, __m128i * output, int8_t * _t, __m128i& final_correction_word, bool party, size_t socket_no = 0)
  89. {
  90. __m128i root;
  91. arc4random_buf(&root, sizeof(root));
  92. root = set_lsb(root, party);
  93. const size_t from_node = std::floor(static_cast<double>(from) / nodes_per_leaf);
  94. __m128i * s[2] = {
  95. reinterpret_cast<__m128i *>(output) + nodes_in_interval * (nodes_per_leaf - 1),
  96. s[0] + nodes_in_interval / 2
  97. };
  98. int8_t * t[2] = { _t, _t + nodes_in_interval / 2};
  99. int curlayer = depth % 2;
  100. s[curlayer][0] = root;
  101. t[curlayer][0] = get_lsb(root, 0b01);
  102. __m128i * CW = (__m128i *) std::aligned_alloc(sizeof(__m256i), depth * sizeof(__m128i));
  103. for (size_t layer = 0; layer < depth; ++layer)
  104. {
  105. #ifdef VERBOSE
  106. printf("layer = %zu\n", layer);
  107. #endif
  108. curlayer = 1-curlayer;
  109. size_t i=0, j=0;
  110. auto nextbit = (from_node >> (nbits-layer-1)) & 1;
  111. size_t nodes_in_prev_layer = std::ceil(static_cast<double>(nodes_in_interval) / (1ULL << (depth-layer)));
  112. size_t nodes_in_cur_layer = std::ceil(static_cast<double>(nodes_in_interval) / (1ULL << (depth-layer-1)));
  113. __m128i L = _mm_setzero_si128();
  114. __m128i R = _mm_setzero_si128();
  115. for (i = nextbit, j = nextbit; j < nodes_in_prev_layer-1; ++j, i+=2)
  116. {
  117. traverse(prgkey, s[1-curlayer][j], &s[curlayer][i]);
  118. L ^= s[curlayer][i];
  119. R ^= s[curlayer][i+1];
  120. }
  121. if (nodes_in_prev_layer > j)
  122. {
  123. if (i < nodes_in_cur_layer - 1)
  124. {
  125. traverse(prgkey, s[1-curlayer][j], &s[curlayer][i]);
  126. L ^= s[curlayer][i];
  127. R ^= s[curlayer][i+1];
  128. }
  129. }
  130. compute_CW(socketsPb[socket_no], socketsP2[socket_no], L, R, target_share[layer], CW[layer]);
  131. uint8_t advice_L = get_lsb(L) ^ target_share[layer];
  132. uint8_t advice_R = get_lsb(R) ^ target_share[layer];
  133. uint8_t cwt_L, cwt_R;
  134. uint8_t advice[2];
  135. uint8_t cwts[2];
  136. advice[0] = advice_L;
  137. advice[1] = advice_R;
  138. boost::asio::write(socketsPb[socket_no+1], boost::asio::buffer(&advice, sizeof(advice)));
  139. boost::asio::read(socketsPb[socket_no+1], boost::asio::buffer(&cwts, sizeof(cwts)));
  140. cwt_L = cwts[0];
  141. cwt_R = cwts[1];
  142. cwt_L = cwt_L ^ advice_L ^ 1;
  143. cwt_R = cwt_R ^ advice_R;
  144. for(size_t j = 0; j < nodes_in_prev_layer; ++j)
  145. {
  146. t[curlayer][2*j] = get_lsb(s[curlayer][2*j]) ^ (cwt_L & t[1-curlayer][j]);
  147. s[curlayer][2*j] = clear_lsb(xor_if(s[curlayer][2*j], CW[layer], !t[1-curlayer][j]), 0b11);
  148. t[curlayer][(2*j)+1] = get_lsb(s[curlayer][(2*j)+1]) ^ (cwt_R & t[1-curlayer][j]);
  149. s[curlayer][(2*j)+1] = clear_lsb(xor_if(s[curlayer][(2*j)+1], CW[layer], !t[1-curlayer][j]), 0b11);
  150. }
  151. }
  152. __m128i Gamma = _mm_setzero_si128();
  153. for (size_t i = 0; i < to + 1; ++i)
  154. {
  155. Gamma[0] += output[i][0];
  156. Gamma[1] += output[i][1];
  157. }
  158. if(party)
  159. {
  160. Gamma[0] = -Gamma[0];
  161. Gamma[1] = -Gamma[1];
  162. }
  163. boost::asio::write(socketsPb[socket_no + 3], boost::asio::buffer(&Gamma, sizeof(Gamma)));
  164. boost::asio::read(socketsPb[socket_no + 3], boost::asio::buffer(&final_correction_word, sizeof(final_correction_word)));
  165. final_correction_word = final_correction_word + Gamma;
  166. } // dpf::__evalinterval