// rdpf.cpp
  1. #include <bsd/stdlib.h> // arc4random_buf
  2. #include "rdpf.hpp"
  3. #include "bitutils.hpp"
  4. #include "mpcops.hpp"
  5. // Compute the multiplicative inverse of x mod 2^{VALUE_BITS}
  6. // This is the same as computing x to the power of
  7. // 2^{VALUE_BITS-1}-1.
  8. static value_t inverse_value_t(value_t x)
  9. {
  10. int expon = 1;
  11. value_t xe = x;
  12. // Invariant: xe = x^(2^expon - 1) mod 2^{VALUE_BITS}
  13. // Goal: compute x^(2^{VALUE_BITS-1} - 1)
  14. while (expon < VALUE_BITS-1) {
  15. xe = xe * xe * x;
  16. ++expon;
  17. }
  18. return xe;
  19. }
// Construct a DPF with the given (XOR-shared) target location, and
// of the given depth, to be used for random-access memory reads and
// writes.  The DPF is construction collaboratively by P0 and P1,
// with the server P2 helping by providing various kinds of
// correlated randomness, such as MultTriples and AndTriples.
//
// This algorithm is based on Appendix C from the Duoram paper, with a
// small optimization noted below.
//
// Parameters:
//   tio: the MPC transport/IO object for this player
//   yield: coroutine yield handle used to batch communication rounds
//   target: XOR-shared target index of the DPF's nonzero leaf
//   depth: number of levels in the DPF tree (2^depth leaves)
//   save_expansion: if true, keep the fully expanded leaf level in
//     the expansion member to trade memory for later evaluation time
RDPF::RDPF(MPCTIO &tio, yield_t &yield,
    RegXS target, nbits_t depth, bool save_expansion)
{
    int player = tio.player();
    size_t &aes_ops = tio.aes_ops();

    // Choose a random seed
    arc4random_buf(&seed, sizeof(seed));
    // Ensure the flag bits (the lsb of each node) are different
    seed = set_lsb(seed, !!player);
    cfbits = 0;
    whichhalf = (player == 1);

    // The root level is just the seed
    nbits_t level = 0;
    DPFnode *curlevel = NULL;
    DPFnode *nextlevel = new DPFnode[1];
    nextlevel[0] = seed;

    // Construct each intermediate level
    while(level < depth) {
        // The previous iteration's nextlevel becomes this
        // iteration's curlevel
        delete[] curlevel;
        curlevel = nextlevel;
        if (save_expansion && level == depth-1) {
            // On the last level, if we're saving the expansion, write
            // the leaves directly into the expansion member rather
            // than a scratch array
            expansion.resize(1<<depth);
            nextlevel = expansion.data();
        } else {
            nextlevel = new DPFnode[1<<(level+1)];
        }
        // Invariant: curlevel has 2^level elements; nextlevel has
        // 2^{level+1} elements

        // The bit-shared choice bit is bit (depth-level-1) of the
        // XOR-shared target index
        RegBS bs_choice = target.bit(depth-level-1);
        size_t curlevel_size = (size_t(1)<<level);
        DPFnode L = _mm_setzero_si128();
        DPFnode R = _mm_setzero_si128();
        // The server doesn't need to do this computation, but it does
        // need to execute mpc_reconstruct_choice so that it sends
        // the AndTriples at the appropriate time.
        if (player < 2) {
            // Expand every current-level node into its two children,
            // accumulating the XOR of all left children in L and of
            // all right children in R
            for(size_t i=0;i<curlevel_size;++i) {
                DPFnode lchild, rchild;
                prgboth(lchild, rchild, curlevel[i], aes_ops);
                L = (L ^ lchild);
                R = (R ^ rchild);
                if (nextlevel) {
                    nextlevel[2*i] = lchild;
                    nextlevel[2*i+1] = rchild;
                }
            }
        }
        // If we're going left (bs_choice = 0), we want the correction
        // word to be the XOR of our right side and our peer's right
        // side; if bs_choice = 1, it should be the XOR or our left side
        // and our peer's left side.

        // We also have to ensure that the flag bits (the lsb) of the
        // side that will end up the same be of course the same, but
        // also that the flag bits (the lsb) of the side that will end
        // up different _must_ be different.  That is, it's not enough
        // for the nodes of the child selected by choice to be different
        // as 128-bit values; they also have to be different in their
        // lsb.

        // This is where we make a small optimization over Appendix C of
        // the Duoram paper: instead of keeping separate correction flag
        // bits for the left and right children, we observe that the low
        // bit of the overall correction word effectively serves as one
        // of those bits, so we just need to store one extra bit per
        // level, not two.  (We arbitrarily choose the one for the right
        // child.)

        // Note that the XOR of our left and right child before and
        // after applying the correction word won't change, since the
        // correction word is applied to either both children or
        // neither, depending on the value of the parent's flag.  So in
        // particular, the XOR of the flag bits won't change, and if our
        // children's flag's XOR equals our peer's children's flag's
        // XOR, then we won't have different flag bits even for the
        // children that have different 128-bit values.

        // So we compute our_parity = lsb(L^R)^player, and we XOR that
        // into the R value in the correction word computation.  At the
        // same time, we exchange these parity values to compute the
        // combined parity, which we store in the DPF.  Then when the
        // DPF is evaluated, if the parent's flag is set, not only apply
        // the correction work to both children, but also apply the
        // (combined) parity bit to just the right child.  Then for
        // unequal nodes (where the flag bit is different), exactly one
        // of the four children (two for P0 and two for P1) will have
        // the parity bit applied, which will set the XOR of the lsb of
        // those four nodes to just L0^R0^L1^R1^our_parity^peer_parity
        // = 1 because everything cancels out except player (for which
        // one player is 0 and the other is 1).
        bool our_parity_bit = get_lsb(L ^ R) ^ !!player;
        DPFnode our_parity = lsb128_mask[our_parity_bit];

        DPFnode CW;
        bool peer_parity_bit;
        // Exchange the parities and do mpc_reconstruct_choice at the
        // same time (bundled into the same rounds)
        std::vector<coro_t> coroutines;
        coroutines.emplace_back(
            [&](yield_t &yield) {
                // Send our parity bit to the peer and receive theirs;
                // the yield() marks the round boundary
                tio.queue_peer(&our_parity_bit, 1);
                yield();
                uint8_t peer_parity_byte;
                tio.recv_peer(&peer_parity_byte, 1);
                peer_parity_bit = peer_parity_byte & 1;
            });
        coroutines.emplace_back(
            [&](yield_t &yield) {
                // Obliviously reconstruct CW = (R0^our_parity^R1) if
                // bs_choice = 0, or (L0^L1) if bs_choice = 1
                mpc_reconstruct_choice(tio, yield, CW, bs_choice,
                    (R ^ our_parity), L);
            });
        run_coroutines(yield, coroutines);
        bool parity_bit = our_parity_bit ^ peer_parity_bit;
        // Record the combined parity bit for this level in the packed
        // correction flag bits
        cfbits |= (value_t(parity_bit)<<level);
        // The correction word for the right child also includes the
        // combined parity bit in its lsb (see the optimization note
        // above)
        DPFnode CWR = CW ^ lsb128_mask[parity_bit];
        if (player < 2) {
            if (level < depth-1) {
                // Apply the correction word to the children of every
                // current-level node whose flag bit is set
                for(size_t i=0;i<curlevel_size;++i) {
                    bool flag = get_lsb(curlevel[i]);
                    nextlevel[2*i] = xor_if(nextlevel[2*i], CW, flag);
                    nextlevel[2*i+1] = xor_if(nextlevel[2*i+1], CWR, flag);
                }
            } else {
                // Recall there are four potentially useful vectors that
                // can come out of a DPF:
                // - (single-bit) bitwise unit vector
                // - additive-shared unit vector
                // - XOR-shared scaled unit vector
                // - additive-shared scaled unit vector
                //
                // (No single DPF should be used for both of the first
                // two or both of the last two, though, since they're
                // correlated; you _can_ use one of the first two and
                // one of the last two.)
                //
                // For each 128-bit leaf, the low bit is the flag bit,
                // and we're guaranteed that the flag bits (and indeed
                // the whole 128-bit value) for P0 and P1 are the same
                // for every leaf except the target, and that the flag
                // bits definitely differ for the target (and the other
                // 127 bits are independently random on each side).
                //
                // We divide the 128-bit leaf into a low 64-bit word and
                // a high 64-bit word.  We use the low word for the unit
                // vector and the high word for the scaled vector; this
                // choice is not arbitrary: the flag bit in the low word
                // means that the sum of all the low words (with P1's
                // low words negated) across both P0 and P1 is
                // definitely odd, so we can compute that sum's inverse
                // mod 2^64, and store it now during precomputation.  At
                // evaluation time for the additive-shared unit vector,
                // we will output this global inverse times the low word
                // of each leaf, which will make the sum of all of those
                // values 1.  (This technique replaces the protocol in
                // Appendix D of the Duoram paper.)
                //
                // For the scaled vector, we just have to compute shares
                // of what the scaled vector is a sharing _of_, but
                // that's just XORing or adding all of each party's
                // local high words; no communication needed.
                value_t low_sum = 0;
                value_t high_sum = 0;
                value_t high_xor = 0;
                for(size_t i=0;i<curlevel_size;++i) {
                    bool flag = get_lsb(curlevel[i]);
                    DPFnode leftchild = xor_if(nextlevel[2*i], CW, flag);
                    DPFnode rightchild = xor_if(nextlevel[2*i+1], CWR, flag);
                    if (save_expansion) {
                        nextlevel[2*i] = leftchild;
                        nextlevel[2*i+1] = rightchild;
                    }
                    // Split each corrected 128-bit leaf into its low
                    // and high 64-bit words
                    value_t leftlow = value_t(_mm_cvtsi128_si64x(leftchild));
                    value_t rightlow = value_t(_mm_cvtsi128_si64x(rightchild));
                    value_t lefthigh =
                        value_t(_mm_cvtsi128_si64x(_mm_srli_si128(leftchild,8)));
                    value_t righthigh =
                        value_t(_mm_cvtsi128_si64x(_mm_srli_si128(rightchild,8)));
                    low_sum += (leftlow + rightlow);
                    high_sum += (lefthigh + righthigh);
                    high_xor ^= (lefthigh ^ righthigh);
                }
                if (player == 1) {
                    // P1's additive shares are negated so that the
                    // two players' contributions sum (mod 2^64) to
                    // the true value
                    low_sum = -low_sum;
                    high_sum = -high_sum;
                }
                scaled_sum.ashare = high_sum;
                scaled_xor.xshare = high_xor;
                // Exchange low_sum and add them up
                tio.queue_peer(&low_sum, sizeof(low_sum));
                yield();
                value_t peer_low_sum;
                tio.recv_peer(&peer_low_sum, sizeof(peer_low_sum));
                low_sum += peer_low_sum;
                // The low_sum had better be odd
                assert(low_sum & 1);
                unit_sum_inverse = inverse_value_t(low_sum);
            }
            cw.push_back(CW);
        }
        ++level;
    }

    delete[] curlevel;
    if (!save_expansion) {
        // When saving the expansion, nextlevel aliases
        // expansion.data(), which must not be freed here
        delete[] nextlevel;
    }
}
  231. // The number of bytes it will take to store a RDPF of the given depth
  232. size_t RDPF::size(nbits_t depth)
  233. {
  234. return sizeof(seed) + depth*sizeof(DPFnode) + BITBYTES(depth) +
  235. sizeof(unit_sum_inverse) + sizeof(scaled_sum) +
  236. sizeof(scaled_xor);
  237. }
  238. // The number of bytes it will take to store this RDPF
  239. size_t RDPF::size() const
  240. {
  241. uint8_t depth = cw.size();
  242. return size(depth);
  243. }
  244. // Get the leaf node for the given input
  245. DPFnode RDPF::leaf(address_t input, size_t &aes_ops) const
  246. {
  247. // If we have a precomputed expansion, just use it
  248. if (expansion.size()) {
  249. return expansion[input];
  250. }
  251. nbits_t totdepth = depth();
  252. DPFnode node = seed;
  253. for (nbits_t d=0;d<totdepth;++d) {
  254. bit_t dir = !!(input & (address_t(1)<<(totdepth-d-1)));
  255. node = descend(node, d, dir, aes_ops);
  256. }
  257. return node;
  258. }
  259. // Expand the DPF if it's not already expanded
  260. //
  261. // This routine is slightly more efficient than repeatedly calling
  262. // Eval::next(), but it uses a lot more memory.
  263. void RDPF::expand(size_t &aes_ops)
  264. {
  265. nbits_t depth = this->depth();
  266. size_t num_leaves = size_t(1)<<depth;
  267. if (expansion.size() == num_leaves) return;
  268. expansion.resize(num_leaves);
  269. address_t index = 0;
  270. address_t lastindex = 0;
  271. DPFnode *path = new DPFnode[depth];
  272. path[0] = seed;
  273. for (nbits_t i=1;i<depth;++i) {
  274. path[i] = descend(path[i-1], i-1, 0, aes_ops);
  275. }
  276. expansion[index++] = descend(path[depth-1], depth-1, 0, aes_ops);
  277. expansion[index++] = descend(path[depth-1], depth-1, 1, aes_ops);
  278. while(index < num_leaves) {
  279. // Invariant: lastindex and index will both be even, and
  280. // index=lastindex+2
  281. uint64_t index_xor = index ^ lastindex;
  282. nbits_t how_many_1_bits = __builtin_popcount(index_xor);
  283. // If lastindex -> index goes for example from (in binary)
  284. // 010010110 -> 010011000, then index_xor will be
  285. // 000001110 and how_many_1_bits will be 3.
  286. // That indicates that path[depth-3] was a left child, and now
  287. // we need to change it to a right child by descending right
  288. // from path[depth-4], and then filling the path after that with
  289. // left children.
  290. path[depth-how_many_1_bits] =
  291. descend(path[depth-how_many_1_bits-1],
  292. depth-how_many_1_bits-1, 1, aes_ops);
  293. for (nbits_t i = depth-how_many_1_bits; i < depth-1; ++i) {
  294. path[i+1] = descend(path[i], i, 0, aes_ops);
  295. }
  296. lastindex = index;
  297. expansion[index++] = descend(path[depth-1], depth-1, 0, aes_ops);
  298. expansion[index++] = descend(path[depth-1], depth-1, 1, aes_ops);
  299. }
  300. delete[] path;
  301. }
  302. // Construct three RDPFs of the given depth all with the same randomly
  303. // generated target index.
  304. RDPFTriple::RDPFTriple(MPCTIO &tio, yield_t &yield,
  305. nbits_t depth, bool save_expansion)
  306. {
  307. // Pick a random XOR share of the target
  308. xs_target.randomize(depth);
  309. // Now create three RDPFs with that target, and also convert the XOR
  310. // shares of the target to additive shares
  311. std::vector<coro_t> coroutines;
  312. for (int i=0;i<3;++i) {
  313. coroutines.emplace_back(
  314. [&, i](yield_t &yield) {
  315. dpf[i] = RDPF(tio, yield, xs_target, depth,
  316. save_expansion);
  317. });
  318. }
  319. coroutines.emplace_back(
  320. [&](yield_t &yield) {
  321. mpc_xs_to_as(tio, yield, as_target, xs_target, depth);
  322. });
  323. run_coroutines(yield, coroutines);
  324. }
  325. RDPFTriple::node RDPFTriple::descend(const RDPFTriple::node &parent,
  326. nbits_t parentdepth, bit_t whichchild,
  327. size_t &aes_ops) const
  328. {
  329. auto [P0, P1, P2] = parent;
  330. DPFnode C0, C1, C2;
  331. C0 = dpf[0].descend(P0, parentdepth, whichchild, aes_ops);
  332. C1 = dpf[1].descend(P1, parentdepth, whichchild, aes_ops);
  333. C2 = dpf[2].descend(P2, parentdepth, whichchild, aes_ops);
  334. return std::make_tuple(C0,C1,C2);
  335. }
  336. RDPFPair::node RDPFPair::descend(const RDPFPair::node &parent,
  337. nbits_t parentdepth, bit_t whichchild,
  338. size_t &aes_ops) const
  339. {
  340. auto [P0, P1] = parent;
  341. DPFnode C0, C1;
  342. C0 = dpf[0].descend(P0, parentdepth, whichchild, aes_ops);
  343. C1 = dpf[1].descend(P1, parentdepth, whichchild, aes_ops);
  344. return std::make_tuple(C0,C1);
  345. }