rdpf.cpp 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326
  1. #include <bsd/stdlib.h> // arc4random_buf
  2. #include "rdpf.hpp"
  3. #include "bitutils.hpp"
  4. #include "mpcops.hpp"
  5. #include "aes.hpp"
  6. #include "prg.hpp"
  7. // Don't warn if we never actually use these functions
  8. static void dump_node(DPFnode node, const char *label = NULL)
  9. __attribute__ ((unused));
  10. static void dump_level(DPFnode *nodes, size_t num, const char *label = NULL)
  11. __attribute__ ((unused));
  12. static void dump_node(DPFnode node, const char *label)
  13. {
  14. if (label) printf("%s: ", label);
  15. for(int i=0;i<16;++i) { printf("%02x", ((unsigned char *)&node)[15-i]); } printf("\n");
  16. }
  17. static void dump_level(DPFnode *nodes, size_t num, const char *label)
  18. {
  19. if (label) printf("%s:\n", label);
  20. for (size_t i=0;i<num;++i) {
  21. dump_node(nodes[i]);
  22. }
  23. printf("\n");
  24. }
  25. // Compute the multiplicative inverse of x mod 2^{VALUE_BITS}
  26. // This is the same as computing x to the power of
  27. // 2^{VALUE_BITS-1}-1.
  28. static value_t inverse_value_t(value_t x)
  29. {
  30. int expon = 1;
  31. value_t xe = x;
  32. // Invariant: xe = x^(2^expon - 1) mod 2^{VALUE_BITS}
  33. // Goal: compute x^(2^{VALUE_BITS-1} - 1)
  34. while (expon < VALUE_BITS-1) {
  35. xe = xe * xe * x;
  36. ++expon;
  37. }
  38. return xe;
  39. }
  40. // Construct a DPF with the given (XOR-shared) target location, and
  41. // of the given depth, to be used for random-access memory reads and
  42. // writes. The DPF is construction collaboratively by P0 and P1,
  43. // with the server P2 helping by providing various kinds of
  44. // correlated randomness, such as MultTriples and AndTriples.
  45. //
  46. // This algorithm is based on Appendix C from the Duoram paper, with a
  47. // small optimization noted below.
  48. RDPF::RDPF(MPCTIO &tio, yield_t &yield,
  49. RegXS target, nbits_t depth)
  50. {
  51. int player = tio.player();
  52. size_t &aesops = tio.aes_ops();
  53. // Choose a random seed
  54. arc4random_buf(&seed, sizeof(seed));
  55. // Ensure the flag bits (the lsb of each node) are different
  56. seed = set_lsb(seed, !!player);
  57. cfbits = 0;
  58. whichhalf = (player == 1);
  59. // The root level is just the seed
  60. nbits_t level = 0;
  61. DPFnode *curlevel = NULL;
  62. DPFnode *nextlevel = new DPFnode[1];
  63. nextlevel[0] = seed;
  64. // Construct each intermediate level
  65. while(level < depth) {
  66. delete[] curlevel;
  67. curlevel = nextlevel;
  68. nextlevel = new DPFnode[1<<(level+1)];
  69. // Invariant: curlevel has 2^level elements; nextlevel has
  70. // 2^{level+1} elements
  71. // The bit-shared choice bit is bit (depth-level-1) of the
  72. // XOR-shared target index
  73. RegBS bs_choice = target.bit(depth-level-1);
  74. size_t curlevel_size = (size_t(1)<<level);
  75. DPFnode L = _mm_setzero_si128();
  76. DPFnode R = _mm_setzero_si128();
  77. // The server doesn't need to do this computation, but it does
  78. // need to execute mpc_reconstruct_choice so that it sends
  79. // the AndTriples at the appropriate time.
  80. if (player < 2) {
  81. for(size_t i=0;i<curlevel_size;++i) {
  82. DPFnode lchild, rchild;
  83. prgboth(lchild, rchild, curlevel[i], aesops);
  84. L = (L ^ lchild);
  85. R = (R ^ rchild);
  86. if (nextlevel) {
  87. nextlevel[2*i] = lchild;
  88. nextlevel[2*i+1] = rchild;
  89. }
  90. }
  91. }
  92. // If we're going left (bs_choice = 0), we want the correction
  93. // word to be the XOR of our right side and our peer's right
  94. // side; if bs_choice = 1, it should be the XOR or our left side
  95. // and our peer's left side.
  96. // We also have to ensure that the flag bits (the lsb) of the
  97. // side that will end up the same be of course the same, but
  98. // also that the flag bits (the lsb) of the side that will end
  99. // up different _must_ be different. That is, it's not enough
  100. // for the nodes of the child selected by choice to be different
  101. // as 128-bit values; they also have to be different in their
  102. // lsb.
  103. // This is where we make a small optimization over Appendix C of
  104. // the Duoram paper: instead of keeping separate correction flag
  105. // bits for the left and right children, we observe that the low
  106. // bit of the overall correction word effectively serves as one
  107. // of those bits, so we just need to store one extra bit per
  108. // level, not two. (We arbitrarily choose the one for the right
  109. // child.)
  110. // Note that the XOR of our left and right child before and
  111. // after applying the correction word won't change, since the
  112. // correction word is applied to either both children or
  113. // neither, depending on the value of the parent's flag. So in
  114. // particular, the XOR of the flag bits won't change, and if our
  115. // children's flag's XOR equals our peer's children's flag's
  116. // XOR, then we won't have different flag bits even for the
  117. // children that have different 128-bit values.
  118. // So we compute our_parity = lsb(L^R)^player, and we XOR that
  119. // into the R value in the correction word computation. At the
  120. // same time, we exchange these parity values to compute the
  121. // combined parity, which we store in the DPF. Then when the
  122. // DPF is evaluated, if the parent's flag is set, not only apply
  123. // the correction work to both children, but also apply the
  124. // (combined) parity bit to just the right child. Then for
  125. // unequal nodes (where the flag bit is different), exactly one
  126. // of the four children (two for P0 and two for P1) will have
  127. // the parity bit applied, which will set the XOR of the lsb of
  128. // those four nodes to just L0^R0^L1^R1^our_parity^peer_parity
  129. // = 1 because everything cancels out except player (for which
  130. // one player is 0 and the other is 1).
  131. bool our_parity_bit = get_lsb(L ^ R) ^ !!player;
  132. DPFnode our_parity = lsb128_mask[our_parity_bit];
  133. DPFnode CW;
  134. bool peer_parity_bit;
  135. // Exchange the parities and do mpc_reconstruct_choice at the
  136. // same time (bundled into the same rounds)
  137. std::vector<coro_t> coroutines;
  138. coroutines.emplace_back(
  139. [&](yield_t &yield) {
  140. tio.queue_peer(&our_parity_bit, 1);
  141. yield();
  142. uint8_t peer_parity_byte;
  143. tio.recv_peer(&peer_parity_byte, 1);
  144. peer_parity_bit = peer_parity_byte & 1;
  145. });
  146. coroutines.emplace_back(
  147. [&](yield_t &yield) {
  148. mpc_reconstruct_choice(tio, yield, CW, bs_choice,
  149. (R ^ our_parity), L);
  150. });
  151. run_coroutines(yield, coroutines);
  152. bool parity_bit = our_parity_bit ^ peer_parity_bit;
  153. cfbits |= (size_t(parity_bit)<<level);
  154. DPFnode CWR = CW ^ lsb128_mask[parity_bit];
  155. if (player < 2) {
  156. if (level < depth-1) {
  157. for(size_t i=0;i<curlevel_size;++i) {
  158. bool flag = get_lsb(curlevel[i]);
  159. nextlevel[2*i] = xor_if(nextlevel[2*i], CW, flag);
  160. nextlevel[2*i+1] = xor_if(nextlevel[2*i+1], CWR, flag);
  161. }
  162. } else {
  163. // Recall there are four potentially useful vectors that
  164. // can come out of a DPF:
  165. // - (single-bit) bitwise unit vector
  166. // - additive-shared unit vector
  167. // - XOR-shared scaled unit vector
  168. // - additive-shared scaled unit vector
  169. //
  170. // (No single DPF should be used for both of the first
  171. // two or both of the last two, though, since they're
  172. // correlated; you _can_ use one of the first two and
  173. // one of the last two.)
  174. //
  175. // For each 128-bit leaf, the low bit is the flag bit,
  176. // and we're guaranteed that the flag bits (and indeed
  177. // the whole 128-bit value) for P0 and P1 are the same
  178. // for every leaf except the target, and that the flag
  179. // bits definitely differ for the target (and the other
  180. // 127 bits are independently random on each side).
  181. //
  182. // We divide the 128-bit leaf into a low 64-bit word and
  183. // a high 64-bit word. We use the low word for the unit
  184. // vector and the high word for the scaled vector; this
  185. // choice is not arbitrary: the flag bit in the low word
  186. // means that the sum of all the low words (with P1's
  187. // low words negated) across both P0 and P1 is
  188. // definitely odd, so we can compute that sum's inverse
  189. // mod 2^64, and store it now during precomputation. At
  190. // evaluation time for the additive-shared unit vector,
  191. // we will output this global inverse times the low word
  192. // of each leaf, which will make the sum of all of those
  193. // values 1.
  194. //
  195. // For the scaled vector, we just have to compute shares
  196. // of what the scaled vector is a sharing _of_, but
  197. // that's just XORing or adding all of each party's
  198. // local high words; no communication needed.
  199. value_t low_sum = 0;
  200. value_t high_sum = 0;
  201. value_t high_xor = 0;
  202. for(size_t i=0;i<curlevel_size;++i) {
  203. bool flag = get_lsb(curlevel[i]);
  204. DPFnode leftchild = xor_if(nextlevel[2*i], CW, flag);
  205. DPFnode rightchild = xor_if(nextlevel[2*i+1], CWR, flag);
  206. value_t leftlow = value_t(_mm_cvtsi128_si64x(leftchild));
  207. value_t rightlow = value_t(_mm_cvtsi128_si64x(rightchild));
  208. value_t lefthigh =
  209. value_t(_mm_cvtsi128_si64x(_mm_srli_si128(leftchild,8)));
  210. value_t righthigh =
  211. value_t(_mm_cvtsi128_si64x(_mm_srli_si128(rightchild,8)));
  212. low_sum += (leftlow + rightlow);
  213. high_sum += (lefthigh + righthigh);
  214. high_xor ^= (lefthigh ^ righthigh);
  215. }
  216. if (player == 1) {
  217. low_sum = -low_sum;
  218. high_sum = -high_sum;
  219. }
  220. scaled_sum.ashare = high_sum;
  221. scaled_xor.xshare = high_xor;
  222. // Exchange low_sum and add them up
  223. tio.queue_peer(&low_sum, sizeof(low_sum));
  224. yield();
  225. value_t peer_low_sum;
  226. tio.recv_peer(&peer_low_sum, sizeof(peer_low_sum));
  227. low_sum += peer_low_sum;
  228. // The low_sum had better be odd
  229. assert(low_sum & 1);
  230. unit_sum_inverse = inverse_value_t(low_sum);
  231. }
  232. cw.push_back(CW);
  233. }
  234. ++level;
  235. }
  236. delete[] curlevel;
  237. delete[] nextlevel;
  238. }
  239. // The number of bytes it will take to store a RDPF of the given depth
  240. size_t RDPF::size(nbits_t depth)
  241. {
  242. return sizeof(seed) + depth*sizeof(DPFnode) + BITBYTES(depth) +
  243. sizeof(unit_sum_inverse) + sizeof(scaled_sum) +
  244. sizeof(scaled_xor);
  245. }
  246. // The number of bytes it will take to store this RDPF
  247. size_t RDPF::size() const
  248. {
  249. uint8_t depth = cw.size();
  250. return size(depth);
  251. }
  252. // Descend from a node at depth parentdepth to one of its children
  253. // whichchild = 0: left child
  254. // whichchild = 1: right child
  255. DPFnode RDPF::descend(const DPFnode parent, nbits_t parentdepth,
  256. bit_t whichchild, size_t &op_counter) const
  257. {
  258. DPFnode prgout;
  259. bool flag = get_lsb(parent);
  260. prg(prgout, parent, whichchild, op_counter);
  261. if (flag) {
  262. DPFnode CW = cw[parentdepth];
  263. bit_t cfbit = !!(cfbits & (value_t(1)<<parentdepth));
  264. DPFnode CWR = CW ^ lsb128_mask[cfbit];
  265. prgout ^= (whichchild ? CWR : CW);
  266. }
  267. return prgout;
  268. }
  269. // Get the leaf node for the given input
  270. DPFnode RDPF::leaf(address_t input, size_t &op_counter) const
  271. {
  272. nbits_t totdepth = depth();
  273. DPFnode node = seed;
  274. for (nbits_t d=0;d<totdepth;++d) {
  275. bit_t dir = !!(input & (address_t(1)<<(totdepth-d-1)));
  276. node = descend(node, d, dir, op_counter);
  277. }
  278. return node;
  279. }
  280. // Construct three RDPFs of the given depth all with the same randomly
  281. // generated target index.
  282. RDPFTriple::RDPFTriple(MPCTIO &tio, yield_t &yield,
  283. nbits_t depth)
  284. {
  285. // Pick a random XOR share of the target
  286. xs_target.randomize(depth);
  287. // Now create three RDPFs with that target, and also convert the XOR
  288. // shares of the target to additive shares
  289. std::vector<coro_t> coroutines;
  290. for (int i=0;i<3;++i) {
  291. coroutines.emplace_back(
  292. [&, i](yield_t &yield) {
  293. dpf[i] = RDPF(tio, yield, xs_target, depth);
  294. });
  295. }
  296. coroutines.emplace_back(
  297. [&](yield_t &yield) {
  298. mpc_xs_to_as(tio, yield, as_target, xs_target, depth);
  299. });
  300. run_coroutines(yield, coroutines);
  301. }