mpcops.cpp 6.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219
  1. #include "mpcops.hpp"
  2. // as_ denotes additive shares
  3. // xs_ denotes xor shares
  4. // bs_ denotes a share of a single bit (which is effectively both an xor
  5. // share and an additive share mod 2)
  6. // P0 and P1 both hold additive shares of x (shares are x0 and x1) and y
  7. // (shares are y0 and y1); compute additive shares of z = x*y =
  8. // (x0+x1)*(y0+y1). x, y, and z are each at most nbits bits long.
  9. //
  10. // Cost:
  11. // 2 words sent in 1 message
  12. // consumes 1 MultTriple
  13. void mpc_mul(MPCTIO &tio, yield_t &yield,
  14. value_t &as_z, value_t as_x, value_t as_y,
  15. nbits_t nbits)
  16. {
  17. if (tio.is_server()) return;
  18. const value_t mask = MASKBITS(nbits);
  19. // Compute as_z to be an additive share of (x0*y1+y0*x1)
  20. mpc_cross(tio, yield, as_z, as_x, as_y, nbits);
  21. // Add x0*y0 (the peer will add x1*y1)
  22. as_z = (as_z + as_x * as_y) & mask;
  23. }
  24. // P0 and P1 both hold additive shares of x (shares are x0 and x1) and y
  25. // (shares are y0 and y1); compute additive shares of z = x0*y1 + y0*x1.
  26. // x, y, and z are each at most nbits bits long.
  27. //
  28. // Cost:
  29. // 2 words sent in 1 message
  30. // consumes 1 MultTriple
  31. void mpc_cross(MPCTIO &tio, yield_t &yield,
  32. value_t &as_z, value_t as_x, value_t as_y,
  33. nbits_t nbits)
  34. {
  35. if (tio.is_server()) return;
  36. const value_t mask = MASKBITS(nbits);
  37. size_t nbytes = BITBYTES(nbits);
  38. auto [X, Y, Z] = tio.triple();
  39. // Send x+X and y+Y
  40. value_t blind_x = (as_x + X) & mask;
  41. value_t blind_y = (as_y + Y) & mask;
  42. tio.queue_peer(&blind_x, nbytes);
  43. tio.queue_peer(&blind_y, nbytes);
  44. yield();
  45. // Read the peer's x+X and y+Y
  46. value_t peer_blind_x, peer_blind_y;
  47. tio.recv_peer(&peer_blind_x, nbytes);
  48. tio.recv_peer(&peer_blind_y, nbytes);
  49. as_z = ((as_x * peer_blind_y) - (Y * peer_blind_x) + Z) & mask;
  50. }
  51. // P0 holds the (complete) value x, P1 holds the (complete) value y;
  52. // compute additive shares of z = x*y. x, y, and z are each at most
  53. // nbits bits long. The parameter is called x, but P1 will pass y
  54. // there.
  55. //
  56. // Cost:
  57. // 1 word sent in 1 message
  58. // consumes 1 HalfTriple
  59. void mpc_valuemul(MPCTIO &tio, yield_t &yield,
  60. value_t &as_z, value_t x,
  61. nbits_t nbits)
  62. {
  63. if (tio.is_server()) return;
  64. const value_t mask = MASKBITS(nbits);
  65. size_t nbytes = BITBYTES(nbits);
  66. auto [X, Z] = tio.halftriple();
  67. // Send x+X
  68. value_t blind_x = (x + X) & mask;
  69. tio.queue_peer(&blind_x, nbytes);
  70. yield();
  71. // Read the peer's y+Y
  72. value_t peer_blind_y;
  73. tio.recv_peer(&peer_blind_y, nbytes);
  74. if (tio.player() == 0) {
  75. as_z = ((x * peer_blind_y) + Z) & mask;
  76. } else if (tio.player() == 1) {
  77. as_z = ((-X * peer_blind_y) + Z) & mask;
  78. }
  79. }
  80. // P0 and P1 hold bit shares f0 and f1 of the single bit f, and additive
  81. // shares y0 and y1 of the value y; compute additive shares of
  82. // z = f * y = (f0 XOR f1) * (y0 + y1). y and z are each at most nbits
  83. // bits long.
  84. //
  85. // Cost:
  86. // 2 words sent in 1 message
  87. // consumes 1 MultTriple
  88. void mpc_flagmult(MPCTIO &tio, yield_t &yield,
  89. value_t &as_z, bit_t bs_f, value_t as_y,
  90. nbits_t nbits)
  91. {
  92. if (tio.is_server()) return;
  93. const value_t mask = MASKBITS(nbits);
  94. // Compute additive shares of [(1-2*f0)*y0]*f1 + [(1-2*f1)*y1]*f0
  95. value_t bs_fval = value_t(bs_f);
  96. mpc_cross(tio, yield, as_z, (1-2*bs_fval)*as_y, bs_fval, nbits);
  97. // Add f0*y0 (and the peer will add f1*y1)
  98. as_z = (as_z + bs_fval*as_y) & mask;
  99. // Now the shares add up to:
  100. // [(1-2*f0)*y0]*f1 + [(1-2*f1)*y1]*f0 + f0*y0 + f1*y1
  101. // which you can rearrange to see that it's equal to the desired
  102. // (f0 + f1 - 2*f0*f1)*(y0+y1), since f0 XOR f1 = (f0 + f1 - 2*f0*f1).
  103. }
  104. // P0 and P1 hold bit shares f0 and f1 of the single bit f, and additive
  105. // shares of the values x and y; compute additive shares of z, where
  106. // z = x if f=0 and z = y if f=1. x, y, and z are each at most nbits
  107. // bits long.
  108. //
  109. // Cost:
  110. // 2 words sent in 1 message
  111. // consumes 1 MultTriple
  112. void mpc_select(MPCTIO &tio, yield_t &yield,
  113. value_t &as_z, bit_t bs_f, value_t as_x, value_t as_y,
  114. nbits_t nbits)
  115. {
  116. if (tio.is_server()) return;
  117. const value_t mask = MASKBITS(nbits);
  118. // The desired result is z = x + f * (y-x)
  119. mpc_flagmult(tio, yield, as_z, bs_f, as_y-as_x, nbits);
  120. as_z = (as_z + as_x) & mask;
  121. }
  122. // P0 and P1 hold bit shares f0 and f1 of the single bit f, and additive
  123. // shares of the values x and y. Obliviously swap x and y; that is,
  124. // replace x and y with new additive sharings of x and y respectively
  125. // (if f=0) or y and x respectively (if f=1). x and y are each at most
  126. // nbits bits long.
  127. //
  128. // Cost:
  129. // 2 words sent in 1 message
  130. // consumes 1 MultTriple
  131. void mpc_oswap(MPCTIO &tio, yield_t &yield,
  132. value_t &as_x, value_t &as_y, bit_t bs_f,
  133. nbits_t nbits)
  134. {
  135. if (tio.is_server()) return;
  136. const value_t mask = MASKBITS(nbits);
  137. // Let s = f*(y-x). Then the desired result is
  138. // x <- x + s, y <- y - s.
  139. value_t as_s;
  140. mpc_flagmult(tio, yield, as_s, bs_f, as_y-as_x, nbits);
  141. as_x = (as_x + as_s) & mask;
  142. as_y = (as_y - as_s) & mask;
  143. }
  144. // P0 and P1 hold XOR shares of x. Compute additive shares of the same
  145. // x. x is at most nbits bits long.
  146. //
  147. // Cost:
  148. // nbits-1 words sent in 1 message
  149. // consumes nbits-1 HalfTriples
  150. void mpc_xs_to_as(MPCTIO &tio, yield_t &yield,
  151. value_t &as_x, value_t xs_x,
  152. nbits_t nbits)
  153. {
  154. if (tio.is_server()) return;
  155. const value_t mask = MASKBITS(nbits);
  156. // We use the fact that for any nbits-bit A and B,
  157. // A+B = (A XOR B) + 2*(A AND B) mod 2^nbits
  158. // so if we have additive shares C0 and C1 of 2*(A AND B)
  159. // (so C0 + C1 = 2*(A AND B)), then (A-C0) and (B-C1) are
  160. // additive shares of (A XOR B).
  161. // To get additive shares of 2*(A AND B) (mod 2^nbits), we first
  162. // note that we can ignore the top bits of A and B, since the
  163. // multiplication by 2 will shift it out of the nbits-bit range.
  164. // For the other bits, use valuemult to get the product of the
  165. // corresponding bit i of A and B (i=0..nbits-2), and compute
  166. // C = \sum_i 2^{i+1} * (A_i * B_i).
  167. // This can all be done in a single message, using the coroutine
  168. // mechanism to have all nbits-1 instances of valuemult queue their
  169. // message, then yield, so that all of their messages get sent at
  170. // once, then each will read their results.
  171. value_t as_bitand[nbits-1];
  172. std::vector<coro_t> coroutines;
  173. for (nbits_t i=0; i<nbits-1; ++i) {
  174. coroutines.emplace_back(
  175. [&](yield_t &yield) {
  176. mpc_valuemul(tio, yield, as_bitand[i], (xs_x>>i)&1, nbits);
  177. });
  178. }
  179. run_coroutines(yield, coroutines);
  180. value_t as_C = 0;
  181. for (nbits_t i=0; i<nbits-1; ++i) {
  182. as_C += (as_bitand[i]<<(i+1));
  183. }
  184. as_x = (xs_x - as_C) & mask;
  185. }