mpcops.cpp 6.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202
  1. #include "mpcops.hpp"
  2. // P0 and P1 both hold additive shares of x (shares are x0 and x1) and y
  3. // (shares are y0 and y1); compute additive shares of z = x*y =
  4. // (x0+x1)*(y0+y1). x, y, and z are each at most nbits bits long.
  5. //
  6. // Cost:
  7. // 2 words sent in 1 message
  8. // consumes 1 MultTriple
  9. void mpc_mul(MPCTIO &tio, yield_t &yield,
  10. RegAS &z, RegAS x, RegAS y,
  11. nbits_t nbits)
  12. {
  13. const value_t mask = MASKBITS(nbits);
  14. // Compute z to be an additive share of (x0*y1+y0*x1)
  15. mpc_cross(tio, yield, z, x, y, nbits);
  16. // Add x0*y0 (the peer will add x1*y1)
  17. z.ashare = (z.ashare + x.ashare * y.ashare) & mask;
  18. }
  19. // P0 and P1 both hold additive shares of x (shares are x0 and x1) and y
  20. // (shares are y0 and y1); compute additive shares of z = x0*y1 + y0*x1.
  21. // x, y, and z are each at most nbits bits long.
  22. //
  23. // Cost:
  24. // 2 words sent in 1 message
  25. // consumes 1 MultTriple
  26. void mpc_cross(MPCTIO &tio, yield_t &yield,
  27. RegAS &z, RegAS x, RegAS y,
  28. nbits_t nbits)
  29. {
  30. const value_t mask = MASKBITS(nbits);
  31. size_t nbytes = BITBYTES(nbits);
  32. auto [X, Y, Z] = tio.triple();
  33. // Send x+X and y+Y
  34. value_t blind_x = (x.ashare + X) & mask;
  35. value_t blind_y = (y.ashare + Y) & mask;
  36. tio.queue_peer(&blind_x, nbytes);
  37. tio.queue_peer(&blind_y, nbytes);
  38. yield();
  39. // Read the peer's x+X and y+Y
  40. value_t peer_blind_x, peer_blind_y;
  41. tio.recv_peer(&peer_blind_x, nbytes);
  42. tio.recv_peer(&peer_blind_y, nbytes);
  43. z.ashare = ((x.ashare * peer_blind_y) - (Y * peer_blind_x) + Z) & mask;
  44. }
  45. // P0 holds the (complete) value x, P1 holds the (complete) value y;
  46. // compute additive shares of z = x*y. x, y, and z are each at most
  47. // nbits bits long. The parameter is called x, but P1 will pass y
  48. // there.
  49. //
  50. // Cost:
  51. // 1 word sent in 1 message
  52. // consumes 1 HalfTriple
  53. void mpc_valuemul(MPCTIO &tio, yield_t &yield,
  54. RegAS &z, value_t x,
  55. nbits_t nbits)
  56. {
  57. const value_t mask = MASKBITS(nbits);
  58. size_t nbytes = BITBYTES(nbits);
  59. auto [X, Z] = tio.halftriple();
  60. // Send x+X
  61. value_t blind_x = (x + X) & mask;
  62. tio.queue_peer(&blind_x, nbytes);
  63. yield();
  64. // Read the peer's y+Y
  65. value_t peer_blind_y;
  66. tio.recv_peer(&peer_blind_y, nbytes);
  67. if (tio.player() == 0) {
  68. z.ashare = ((x * peer_blind_y) + Z) & mask;
  69. } else if (tio.player() == 1) {
  70. z.ashare = ((-X * peer_blind_y) + Z) & mask;
  71. }
  72. }
  73. // P0 and P1 hold bit shares f0 and f1 of the single bit f, and additive
  74. // shares y0 and y1 of the value y; compute additive shares of
  75. // z = f * y = (f0 XOR f1) * (y0 + y1). y and z are each at most nbits
  76. // bits long.
  77. //
  78. // Cost:
  79. // 2 words sent in 1 message
  80. // consumes 1 MultTriple
  81. void mpc_flagmult(MPCTIO &tio, yield_t &yield,
  82. RegAS &z, RegBS f, RegAS y,
  83. nbits_t nbits)
  84. {
  85. const value_t mask = MASKBITS(nbits);
  86. // Compute additive shares of [(1-2*f0)*y0]*f1 + [(1-2*f1)*y1]*f0
  87. value_t bs_fval = value_t(f.bshare);
  88. RegAS fval;
  89. fval.ashare = bs_fval;
  90. mpc_cross(tio, yield, z, y*(1-2*bs_fval), fval, nbits);
  91. // Add f0*y0 (and the peer will add f1*y1)
  92. z.ashare = (z.ashare + bs_fval*y.ashare) & mask;
  93. // Now the shares add up to:
  94. // [(1-2*f0)*y0]*f1 + [(1-2*f1)*y1]*f0 + f0*y0 + f1*y1
  95. // which you can rearrange to see that it's equal to the desired
  96. // (f0 + f1 - 2*f0*f1)*(y0+y1), since f0 XOR f1 = (f0 + f1 - 2*f0*f1).
  97. }
  98. // P0 and P1 hold bit shares f0 and f1 of the single bit f, and additive
  99. // shares of the values x and y; compute additive shares of z, where
  100. // z = x if f=0 and z = y if f=1. x, y, and z are each at most nbits
  101. // bits long.
  102. //
  103. // Cost:
  104. // 2 words sent in 1 message
  105. // consumes 1 MultTriple
  106. void mpc_select(MPCTIO &tio, yield_t &yield,
  107. RegAS &z, RegBS f, RegAS x, RegAS y,
  108. nbits_t nbits)
  109. {
  110. const value_t mask = MASKBITS(nbits);
  111. // The desired result is z = x + f * (y-x)
  112. mpc_flagmult(tio, yield, z, f, y-x, nbits);
  113. z.ashare = (z.ashare + x.ashare) & mask;
  114. }
  115. // P0 and P1 hold bit shares f0 and f1 of the single bit f, and additive
  116. // shares of the values x and y. Obliviously swap x and y; that is,
  117. // replace x and y with new additive sharings of x and y respectively
  118. // (if f=0) or y and x respectively (if f=1). x and y are each at most
  119. // nbits bits long.
  120. //
  121. // Cost:
  122. // 2 words sent in 1 message
  123. // consumes 1 MultTriple
  124. void mpc_oswap(MPCTIO &tio, yield_t &yield,
  125. RegAS &x, RegAS &y, RegBS f,
  126. nbits_t nbits)
  127. {
  128. const value_t mask = MASKBITS(nbits);
  129. // Let s = f*(y-x). Then the desired result is
  130. // x <- x + s, y <- y - s.
  131. RegAS s;
  132. mpc_flagmult(tio, yield, s, f, y-x, nbits);
  133. x.ashare = (x.ashare + s.ashare) & mask;
  134. y.ashare = (y.ashare - s.ashare) & mask;
  135. }
  136. // P0 and P1 hold XOR shares of x. Compute additive shares of the same
  137. // x. x is at most nbits bits long.
  138. //
  139. // Cost:
  140. // nbits-1 words sent in 1 message
  141. // consumes nbits-1 HalfTriples
  142. void mpc_xs_to_as(MPCTIO &tio, yield_t &yield,
  143. RegAS &as_x, RegXS xs_x,
  144. nbits_t nbits)
  145. {
  146. const value_t mask = MASKBITS(nbits);
  147. // We use the fact that for any nbits-bit A and B,
  148. // A+B = (A XOR B) + 2*(A AND B) mod 2^nbits
  149. // so if we have additive shares C0 and C1 of 2*(A AND B)
  150. // (so C0 + C1 = 2*(A AND B)), then (A-C0) and (B-C1) are
  151. // additive shares of (A XOR B).
  152. // To get additive shares of 2*(A AND B) (mod 2^nbits), we first
  153. // note that we can ignore the top bits of A and B, since the
  154. // multiplication by 2 will shift it out of the nbits-bit range.
  155. // For the other bits, use valuemult to get the product of the
  156. // corresponding bit i of A and B (i=0..nbits-2), and compute
  157. // C = \sum_i 2^{i+1} * (A_i * B_i).
  158. // This can all be done in a single message, using the coroutine
  159. // mechanism to have all nbits-1 instances of valuemult queue their
  160. // message, then yield, so that all of their messages get sent at
  161. // once, then each will read their results.
  162. RegAS as_bitand[nbits-1];
  163. std::vector<coro_t> coroutines;
  164. for (nbits_t i=0; i<nbits-1; ++i) {
  165. coroutines.emplace_back(
  166. [&](yield_t &yield) {
  167. mpc_valuemul(tio, yield, as_bitand[i], (xs_x.xshare>>i)&1, nbits);
  168. });
  169. }
  170. run_coroutines(yield, coroutines);
  171. value_t as_C = 0;
  172. for (nbits_t i=0; i<nbits-1; ++i) {
  173. as_C += (as_bitand[i].ashare<<(i+1));
  174. }
  175. as_x.ashare = (xs_x.xshare - as_C) & mask;
  176. }