mpcops.cpp 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422
  1. #include "mpcops.hpp"
  2. #include "bitutils.hpp"
  3. // P0 and P1 both hold additive shares of x (shares are x0 and x1) and y
  4. // (shares are y0 and y1); compute additive shares of z = x*y =
  5. // (x0+x1)*(y0+y1). x, y, and z are each at most nbits bits long.
  6. //
  7. // Cost:
  8. // 2 words sent in 1 message
  9. // consumes 1 MultTriple
  10. void mpc_mul(MPCTIO &tio, yield_t &yield,
  11. RegAS &z, RegAS x, RegAS y,
  12. nbits_t nbits)
  13. {
  14. const value_t mask = MASKBITS(nbits);
  15. // Compute z to be an additive share of (x0*y1+y0*x1)
  16. mpc_cross(tio, yield, z, x, y, nbits);
  17. // Add x0*y0 (the peer will add x1*y1)
  18. z.ashare = (z.ashare + x.ashare * y.ashare) & mask;
  19. }
  20. // P0 and P1 both hold additive shares of x (shares are x0 and x1) and y
  21. // (shares are y0 and y1); compute additive shares of z = x0*y1 + y0*x1.
  22. // x, y, and z are each at most nbits bits long.
  23. //
  24. // Cost:
  25. // 2 words sent in 1 message
  26. // consumes 1 MultTriple
  27. void mpc_cross(MPCTIO &tio, yield_t &yield,
  28. RegAS &z, RegAS x, RegAS y,
  29. nbits_t nbits)
  30. {
  31. const value_t mask = MASKBITS(nbits);
  32. size_t nbytes = BITBYTES(nbits);
  33. auto [X, Y, Z] = tio.multtriple(yield);
  34. // Send x+X and y+Y
  35. value_t blind_x = (x.ashare + X) & mask;
  36. value_t blind_y = (y.ashare + Y) & mask;
  37. tio.queue_peer(&blind_x, nbytes);
  38. tio.queue_peer(&blind_y, nbytes);
  39. yield();
  40. // Read the peer's x+X and y+Y
  41. value_t peer_blind_x=0, peer_blind_y=0;
  42. tio.recv_peer(&peer_blind_x, nbytes);
  43. tio.recv_peer(&peer_blind_y, nbytes);
  44. z.ashare = ((x.ashare * peer_blind_y) - (Y * peer_blind_x) + Z) & mask;
  45. }
  46. // P0 holds the (complete) value x, P1 holds the (complete) value y;
  47. // compute additive shares of z = x*y. x, y, and z are each at most
  48. // nbits bits long. The parameter is called x, but P1 will pass y
  49. // there. When called by another task during preprocessing, set tally
  50. // to false so that the required halftriples aren't accounted for
  51. // separately from the main preprocessing task.
  52. //
  53. // Cost:
  54. // 1 word sent in 1 message
  55. // consumes 1 HalfTriple
  56. void mpc_valuemul(MPCTIO &tio, yield_t &yield,
  57. RegAS &z, value_t x,
  58. nbits_t nbits, bool tally)
  59. {
  60. const value_t mask = MASKBITS(nbits);
  61. size_t nbytes = BITBYTES(nbits);
  62. auto [X, Z] = tio.halftriple(yield, tally);
  63. // Send x+X
  64. value_t blind_x = (x + X) & mask;
  65. tio.queue_peer(&blind_x, nbytes);
  66. yield();
  67. // Read the peer's y+Y
  68. value_t peer_blind_y=0;
  69. tio.recv_peer(&peer_blind_y, nbytes);
  70. if (tio.player() == 0) {
  71. z.ashare = ((x * peer_blind_y) + Z) & mask;
  72. } else if (tio.player() == 1) {
  73. z.ashare = ((-X * peer_blind_y) + Z) & mask;
  74. }
  75. }
  76. // P0 and P1 hold bit shares f0 and f1 of the single bit f, and additive
  77. // shares y0 and y1 of the value y; compute additive shares of
  78. // z = f * y = (f0 XOR f1) * (y0 + y1). y and z are each at most nbits
  79. // bits long.
  80. //
  81. // Cost:
  82. // 2 words sent in 1 message
  83. // consumes 1 MultTriple
  84. void mpc_flagmult(MPCTIO &tio, yield_t &yield,
  85. RegAS &z, RegBS f, RegAS y,
  86. nbits_t nbits)
  87. {
  88. const value_t mask = MASKBITS(nbits);
  89. // Compute additive shares of [(1-2*f0)*y0]*f1 + [(1-2*f1)*y1]*f0
  90. value_t bs_fval = value_t(f.bshare);
  91. RegAS fval;
  92. fval.ashare = bs_fval;
  93. mpc_cross(tio, yield, z, y*(1-2*bs_fval), fval, nbits);
  94. // Add f0*y0 (and the peer will add f1*y1)
  95. z.ashare = (z.ashare + bs_fval*y.ashare) & mask;
  96. // Now the shares add up to:
  97. // [(1-2*f0)*y0]*f1 + [(1-2*f1)*y1]*f0 + f0*y0 + f1*y1
  98. // which you can rearrange to see that it's equal to the desired
  99. // (f0 + f1 - 2*f0*f1)*(y0+y1), since f0 XOR f1 = (f0 + f1 - 2*f0*f1).
  100. }
  101. // P0 and P1 hold bit shares f0 and f1 of the single bit f, and additive
  102. // shares of the values x and y; compute additive shares of z, where
  103. // z = x if f=0 and z = y if f=1. x, y, and z are each at most nbits
  104. // bits long.
  105. //
  106. // Cost:
  107. // 2 words sent in 1 message
  108. // consumes 1 MultTriple
  109. void mpc_select(MPCTIO &tio, yield_t &yield,
  110. RegAS &z, RegBS f, RegAS x, RegAS y,
  111. nbits_t nbits)
  112. {
  113. const value_t mask = MASKBITS(nbits);
  114. // The desired result is z = x + f * (y-x)
  115. mpc_flagmult(tio, yield, z, f, y-x, nbits);
  116. z.ashare = (z.ashare + x.ashare) & mask;
  117. }
  118. // P0 and P1 hold bit shares f0 and f1 of the single bit f, and XOR
  119. // shares of the values x and y; compute XOR shares of z, where z = x if
  120. // f=0 and z = y if f=1. x, y, and z are each at most nbits bits long.
  121. //
  122. // Cost:
  123. // 2 words sent in 1 message
  124. // consumes 1 SelectTriple
  125. void mpc_select(MPCTIO &tio, yield_t &yield,
  126. RegXS &z, RegBS f, RegXS x, RegXS y,
  127. nbits_t nbits)
  128. {
  129. const value_t mask = MASKBITS(nbits);
  130. size_t nbytes = BITBYTES(nbits);
  131. // Sign-extend f (so 0 -> 0000...0; 1 -> 1111...1)
  132. value_t fext = (-value_t(f.bshare)) & mask;
  133. // Compute XOR shares of f & (x ^ y)
  134. auto [X, Y, Z] = tio.valselecttriple(yield);
  135. bit_t blind_f = f.bshare ^ X;
  136. value_t d = (x.xshare ^ y.xshare) & mask;
  137. value_t blind_d = (d ^ Y) & mask;
  138. // Send the blinded values
  139. tio.queue_peer(&blind_f, sizeof(blind_f));
  140. tio.queue_peer(&blind_d, nbytes);
  141. yield();
  142. // Read the peer's values
  143. bit_t peer_blind_f = 0;
  144. value_t peer_blind_d;
  145. tio.recv_peer(&peer_blind_f, sizeof(peer_blind_f));
  146. peer_blind_f &= 1;
  147. tio.recv_peer(&peer_blind_d, nbytes);
  148. peer_blind_d &= mask;
  149. // Compute our share of f ? x : y = (f * (x ^ y))^x
  150. value_t peer_blind_fext = -value_t(peer_blind_f);
  151. z.xshare = ((fext & peer_blind_d) ^ (Y & peer_blind_fext) ^
  152. (fext & d) ^ (Z ^ x.xshare)) & mask;
  153. }
  154. // P0 and P1 hold bit shares f0 and f1 of the single bit f, and bit
  155. // shares of the values x and y; compute bit shares of z, where
  156. // z = x if f=0 and z = y if f=1.
  157. //
  158. // Cost:
  159. // 1 byte sent in 1 message
  160. // consumes 1/64 AndTriple
  161. void mpc_select(MPCTIO &tio, yield_t &yield,
  162. RegBS &z, RegBS f, RegBS x, RegBS y)
  163. {
  164. // The desired result is z = x ^ (f & (y^x))
  165. mpc_and(tio, yield, z, f, y^x);
  166. z ^= x;
  167. }
  168. // P0 and P1 hold bit shares f0 and f1 of the single bit f, and additive
  169. // shares of the values x and y. Obliviously swap x and y; that is,
  170. // replace x and y with new additive sharings of x and y respectively
  171. // (if f=0) or y and x respectively (if f=1). x and y are each at most
  172. // nbits bits long.
  173. //
  174. // Cost:
  175. // 2 words sent in 1 message
  176. // consumes 1 MultTriple
  177. void mpc_oswap(MPCTIO &tio, yield_t &yield,
  178. RegAS &x, RegAS &y, RegBS f,
  179. nbits_t nbits)
  180. {
  181. const value_t mask = MASKBITS(nbits);
  182. // Let s = f*(y-x). Then the desired result is
  183. // x <- x + s, y <- y - s.
  184. RegAS s;
  185. mpc_flagmult(tio, yield, s, f, y-x, nbits);
  186. x.ashare = (x.ashare + s.ashare) & mask;
  187. y.ashare = (y.ashare - s.ashare) & mask;
  188. }
  189. // P0 and P1 hold XOR shares of x. Compute additive shares of the same
  190. // x. x is at most nbits bits long. When called by another task during
  191. // preprocessing, set tally to false so that the required halftriples
  192. // aren't accounted for separately from the main preprocessing task.
  193. //
  194. // Cost:
  195. // nbits-1 words sent in 1 message
  196. // consumes nbits-1 HalfTriples
  197. void mpc_xs_to_as(MPCTIO &tio, yield_t &yield,
  198. RegAS &as_x, RegXS xs_x,
  199. nbits_t nbits, bool tally)
  200. {
  201. const value_t mask = MASKBITS(nbits);
  202. // We use the fact that for any nbits-bit A and B,
  203. // A+B = (A XOR B) + 2*(A AND B) mod 2^nbits
  204. // so if we have additive shares C0 and C1 of 2*(A AND B)
  205. // (so C0 + C1 = 2*(A AND B)), then (A-C0) and (B-C1) are
  206. // additive shares of (A XOR B).
  207. // To get additive shares of 2*(A AND B) (mod 2^nbits), we first
  208. // note that we can ignore the top bits of A and B, since the
  209. // multiplication by 2 will shift it out of the nbits-bit range.
  210. // For the other bits, use valuemult to get the product of the
  211. // corresponding bit i of A and B (i=0..nbits-2), and compute
  212. // C = \sum_i 2^{i+1} * (A_i * B_i).
  213. // This can all be done in a single message, using the coroutine
  214. // mechanism to have all nbits-1 instances of valuemult queue their
  215. // message, then yield, so that all of their messages get sent at
  216. // once, then each will read their results.
  217. RegAS as_bitand[nbits-1];
  218. std::vector<coro_t> coroutines;
  219. for (nbits_t i=0; i<nbits-1; ++i) {
  220. coroutines.emplace_back(
  221. [&tio, &as_bitand, &xs_x, i, nbits, tally](yield_t &yield) {
  222. mpc_valuemul(tio, yield, as_bitand[i],
  223. (xs_x.xshare>>i)&1, nbits, tally);
  224. });
  225. }
  226. run_coroutines(yield, coroutines);
  227. value_t as_C = 0;
  228. for (nbits_t i=0; i<nbits-1; ++i) {
  229. as_C += (as_bitand[i].ashare<<(i+1));
  230. }
  231. as_x.ashare = (xs_x.xshare - as_C) & mask;
  232. }
  233. // P0 and P1 hold XOR shares x0 and x1 of x. x is at most nbits bits
  234. // long. Return x to P0 and P1 (and 0 to P2).
  235. //
  236. // Cost: 1 word sent in 1 message
  237. value_t mpc_reconstruct(MPCTIO &tio, yield_t &yield,
  238. RegXS x, nbits_t nbits)
  239. {
  240. RegXS res;
  241. size_t nbytes = BITBYTES(nbits);
  242. if (tio.player() < 2) {
  243. tio.queue_peer(&x, nbytes);
  244. yield();
  245. tio.recv_peer(&res, nbytes);
  246. res ^= x;
  247. } else {
  248. yield();
  249. }
  250. return res.xshare;
  251. }
  252. // P0 and P1 hold additive shares x0 and x1 of x. x is at most nbits
  253. // bits long. Return x to P0 and P1 (and 0 to P2).
  254. //
  255. // Cost: 1 word sent in 1 message
  256. value_t mpc_reconstruct(MPCTIO &tio, yield_t &yield,
  257. RegAS x, nbits_t nbits)
  258. {
  259. RegAS res;
  260. size_t nbytes = BITBYTES(nbits);
  261. if (tio.player() < 2) {
  262. tio.queue_peer(&x, nbytes);
  263. yield();
  264. tio.recv_peer(&res, nbytes);
  265. res += x;
  266. } else {
  267. yield();
  268. }
  269. return res.ashare;
  270. }
  271. // P0 and P1 hold bit shares f0 and f1 of f. Return f to P0 and P1 (and
  272. // 0 to P2).
  273. //
  274. // Cost: 1 word sent in 1 message
  275. bool mpc_reconstruct(MPCTIO &tio, yield_t &yield, RegBS f)
  276. {
  277. RegBS res;
  278. if (tio.player() < 2) {
  279. tio.queue_peer(&f, 1);
  280. yield();
  281. tio.recv_peer(&res, 1);
  282. res ^= f;
  283. } else {
  284. yield();
  285. }
  286. return res.bshare;
  287. }
  288. // P0 and P1 hold bit shares of f, and DPFnode XOR shares x0,y0 and
  289. // x1,y1 of x and y. Set z to x=x0^x1 if f=0 and to y=y0^y1 if f=1.
  290. //
  291. // Cost:
  292. // 6 64-bit words sent in 2 messages
  293. // consumes one AndTriple
  294. void mpc_reconstruct_choice(MPCTIO &tio, yield_t &yield,
  295. DPFnode &z, RegBS f, DPFnode x, DPFnode y)
  296. {
  297. // Sign-extend f (so 0 -> 0000...0; 1 -> 1111...1)
  298. DPFnode fext = if128_mask[f.bshare];
  299. // Compute XOR shares of f & (x ^ y)
  300. auto [X, Y, Z] = tio.nodeselecttriple(yield);
  301. bit_t blind_f = f.bshare ^ X;
  302. DPFnode d = x ^ y;
  303. DPFnode blind_d = d ^ Y;
  304. // Send the blinded values
  305. tio.queue_peer(&blind_f, sizeof(blind_f));
  306. tio.queue_peer(&blind_d, sizeof(blind_d));
  307. yield();
  308. // Read the peer's values
  309. bit_t peer_blind_f = 0;
  310. DPFnode peer_blind_d;
  311. tio.recv_peer(&peer_blind_f, sizeof(peer_blind_f));
  312. tio.recv_peer(&peer_blind_d, sizeof(peer_blind_d));
  313. // Compute _our share_ of f ? x : y = (f * (x ^ y))^x
  314. DPFnode peer_blind_fext = if128_mask[peer_blind_f];
  315. DPFnode zshare =
  316. (fext & peer_blind_d) ^ (Y & peer_blind_fext) ^
  317. (fext & d) ^ (Z ^ x);
  318. // Now exchange shares
  319. tio.queue_peer(&zshare, sizeof(zshare));
  320. yield();
  321. DPFnode peer_zshare;
  322. tio.recv_peer(&peer_zshare, sizeof(peer_zshare));
  323. z = zshare ^ peer_zshare;
  324. }
  325. // P0 and P1 hold bit shares of x and y. Set z to bit shares of x & y.
  326. //
  327. // Cost:
  328. // 1 byte sent in 1 message
  329. // consumes 1/64 AndTriple
  330. void mpc_and(MPCTIO &tio, yield_t &yield,
  331. RegBS &z, RegBS x, RegBS y)
  332. {
  333. // Compute XOR shares of x & y
  334. auto T = tio.bitselecttriple(yield);
  335. bit_t blind_x = x.bshare ^ T.X;
  336. bit_t blind_y = y.bshare ^ T.Y;
  337. // Send the blinded values
  338. uint8_t v = (blind_x << 1) | blind_y;
  339. tio.queue_peer(&v, sizeof(v));
  340. yield();
  341. // Read the peer's values
  342. bit_t peer_blind_x = 0;
  343. bit_t peer_blind_y = 0;
  344. uint8_t peer_v = 0;
  345. tio.recv_peer(&peer_v, sizeof(peer_v));
  346. peer_blind_x = (peer_v >> 1) & 1;
  347. peer_blind_y = peer_v & 1;
  348. // Compute our share of x & y
  349. z.bshare = (x.bshare & peer_blind_y) ^ (T.Y & peer_blind_x) ^
  350. (x.bshare & y.bshare) ^ T.Z;
  351. }
  352. // P0 and P1 hold bit shares of x and y. Set z to bit shares of x | y.
  353. //
  354. // Cost:
  355. // 1 byte sent in 1 message
  356. // consumes 1/64 AndTriple
  357. void mpc_or(MPCTIO &tio, yield_t &yield,
  358. RegBS &z, RegBS x, RegBS y)
  359. {
  360. if (tio.player() == 0) {
  361. x.bshare = !x.bshare;
  362. y.bshare = !y.bshare;
  363. }
  364. mpc_and(tio, yield, z, x, y);
  365. if (tio.player() == 0) {
  366. z.bshare = !z.bshare;
  367. }
  368. }