// rdpf.cpp

#include <bsd/stdlib.h> // arc4random_buf

#include "rdpf.hpp"
#include "bitutils.hpp"
#include "mpcops.hpp"

// Compute the multiplicative inverse of x mod 2^{VALUE_BITS}
// This is the same as computing x to the power of
// 2^{VALUE_BITS-1}-1.
static value_t inverse_value_t(value_t x)
{
    int expon = 1;
    value_t xe = x;
    // Invariant: xe = x^(2^expon - 1) mod 2^{VALUE_BITS}
    // Goal: compute x^(2^{VALUE_BITS-1} - 1)
    while (expon < VALUE_BITS-1) {
        xe = xe * xe * x;
        ++expon;
    }
    return xe;
}
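
// Illustrative only (a sketch, not part of the original build, and not
// called anywhere in this file): a quick sanity check of the identity
// above.  For odd x, the group of odd residues mod 2^k has order
// 2^{k-1}, so x^(2^{k-1}) = 1 and hence x^(2^{k-1}-1) = x^{-1}.  This
// assumes value_t is an unsigned type of exactly VALUE_BITS bits (as
// its use elsewhere in this file suggests), so multiplication wraps
// mod 2^{VALUE_BITS} automatically.
#if 0
static void selfcheck_inverse_value_t()
{
    value_t x;
    arc4random_buf(&x, sizeof(x));
    x |= 1;     // make x odd so that the inverse exists
    assert(x * inverse_value_t(x) == value_t(1));
}
#endif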

#undef RDPF_MTGEN_TIMING_1

#ifdef RDPF_MTGEN_TIMING_1

// Timing tests for multithreaded generation of RDPFs
// nthreads = 0 means don't launch threads at all
// run for num_iters iterations, and output the total number of
// milliseconds over all of the iterations
//
// Results: roughly 50 µs to launch the thread pool with 1 thread, and
// roughly 30 additional µs for each additional thread.  Each iteration
// of the inner loop takes about 4 to 5 ns.  This works out to
// multithreading becoming worthwhile at around level 19, and you
// should use at most sqrt(2^{level}/6000) threads.
static void mtgen_timetest_1(nbits_t level, int nthreads,
    size_t num_iters, const DPFnode *curlevel,
    DPFnode *nextlevel, size_t &aes_ops)
{
    if (num_iters == 0) {
        num_iters = 1;
    }
    size_t prev_aes_ops = aes_ops;
    DPFnode L = _mm_setzero_si128();
    DPFnode R = _mm_setzero_si128();
    // The tweak causes us to compute something slightly different every
    // iteration of the loop, so that the compiler doesn't notice we're
    // doing the same thing num_iters times and optimize it away
    DPFnode tweak = _mm_setzero_si128();
    auto start = boost::chrono::steady_clock::now();
    for(size_t iter=0;iter<num_iters;++iter) {
        tweak += 1; // This actually adds the 128-bit value whose high
                    // and low 64-bit words are both 1, but that's
                    // fine.
        size_t curlevel_size = size_t(1)<<level;
        if (nthreads == 0) {
            size_t laes_ops = 0;
            for(size_t i=0;i<curlevel_size;++i) {
                DPFnode lchild, rchild;
                prgboth(lchild, rchild, curlevel[i]^tweak, laes_ops);
                L = (L ^ lchild);
                R = (R ^ rchild);
                nextlevel[2*i] = lchild;
                nextlevel[2*i+1] = rchild;
            }
            aes_ops += laes_ops;
        } else {
            DPFnode tL[nthreads];
            DPFnode tR[nthreads];
            size_t taes_ops[nthreads];
            size_t threadstart = 0;
            size_t threadchunk = curlevel_size / nthreads;
            size_t threadextra = curlevel_size % nthreads;
            boost::asio::thread_pool pool(nthreads);
            for (int t=0;t<nthreads;++t) {
                size_t threadsize = threadchunk + (size_t(t) < threadextra);
                size_t threadend = threadstart + threadsize;
                boost::asio::post(pool,
                    [t, &tL, &tR, &taes_ops, threadstart, threadend,
                        &curlevel, &nextlevel, tweak] {
                        DPFnode L = _mm_setzero_si128();
                        DPFnode R = _mm_setzero_si128();
                        size_t aes_ops = 0;
                        for(size_t i=threadstart;i<threadend;++i) {
                            DPFnode lchild, rchild;
                            prgboth(lchild, rchild, curlevel[i]^tweak, aes_ops);
                            L = (L ^ lchild);
                            R = (R ^ rchild);
                            nextlevel[2*i] = lchild;
                            nextlevel[2*i+1] = rchild;
                        }
                        tL[t] = L;
                        tR[t] = R;
                        taes_ops[t] = aes_ops;
                    });
                threadstart = threadend;
            }
            pool.join();
            for (int t=0;t<nthreads;++t) {
                L ^= tL[t];
                R ^= tR[t];
                aes_ops += taes_ops[t];
            }
        }
    }
    auto elapsed =
        boost::chrono::steady_clock::now() - start;
    std::cout << "timetest_1 " << int(level) << " " << nthreads << " "
        << num_iters << " " << boost::chrono::duration_cast
        <boost::chrono::milliseconds>(elapsed) << " " <<
        (aes_ops-prev_aes_ops) << " AES\n";
    dump_node(L);
    dump_node(R);
}

#endif
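
// A minimal sketch (hypothetical helper; the constructor below inlines
// this logic rather than calling a function like this) of the
// thread-count heuristic derived from the timing results above: with
// tens of µs of thread-launch overhead and 4-5 ns per node,
// multithreading only pays off at around 2^19 nodes, using at most
// sqrt(nodes/6000) threads, capped at the number of available cores.
#if 0
static int mtgen_nthreads(nbits_t level, int max_nthreads)
{
    if (max_nthreads == 1 || level < 19) return 1;
    size_t curlevel_size = size_t(1)<<level;
    int nthreads = int(ceil(sqrt(double(curlevel_size/6000))));
    return nthreads > max_nthreads ? max_nthreads : nthreads;
}
#endif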

// Construct a DPF with the given (XOR-shared) target location, and
// of the given depth, to be used for random-access memory reads and
// writes.  The DPF is constructed collaboratively by P0 and P1,
// with the server P2 helping by providing various kinds of
// correlated randomness, such as MultTriples and AndTriples.
//
// This algorithm is based on Appendix C from the Duoram paper, with a
// small optimization noted below.
RDPF::RDPF(MPCTIO &tio, yield_t &yield,
    RegXS target, nbits_t depth, bool save_expansion)
{
    int player = tio.player();
    size_t &aes_ops = tio.aes_ops();

    // Choose a random seed
    arc4random_buf(&seed, sizeof(seed));
    // Ensure the flag bits (the lsb of each node) are different
    seed = set_lsb(seed, !!player);
    cfbits = 0;
    whichhalf = (player == 1);

    // The root level is just the seed
    nbits_t level = 0;
    DPFnode *curlevel = NULL;
    DPFnode *nextlevel = new DPFnode[1];
    nextlevel[0] = seed;

    // Construct each intermediate level
    while(level < depth) {
        if (player < 2) {
            delete[] curlevel;
            curlevel = nextlevel;
            if (save_expansion && level == depth-1) {
                expansion.resize(1<<depth);
                nextlevel = expansion.data();
            } else {
                nextlevel = new DPFnode[1<<(level+1)];
            }
        }
        // Invariant: curlevel has 2^level elements; nextlevel has
        // 2^{level+1} elements

        // The bit-shared choice bit is bit (depth-level-1) of the
        // XOR-shared target index
        RegBS bs_choice = target.bit(depth-level-1);
        size_t curlevel_size = (size_t(1)<<level);
        DPFnode L = _mm_setzero_si128();
        DPFnode R = _mm_setzero_si128();
        // The server doesn't need to do this computation, but it does
        // need to execute mpc_reconstruct_choice so that it sends
        // the AndTriples at the appropriate time.
        if (player < 2) {
#ifdef RDPF_MTGEN_TIMING_1
            if (player == 0) {
                mtgen_timetest_1(level, 0, (1<<23)>>level, curlevel,
                    nextlevel, aes_ops);
                size_t niters = 2048;
                if (level > 8) niters = (1<<20)>>level;
                for(int t=1;t<=8;++t) {
                    mtgen_timetest_1(level, t, niters, curlevel,
                        nextlevel, aes_ops);
                }
                mtgen_timetest_1(level, 0, (1<<23)>>level, curlevel,
                    nextlevel, aes_ops);
            }
#endif
            // Using the timing results gathered above, decide whether
            // to multithread, and if so, how many threads to use.
            // tio.cpu_nthreads() is the maximum number we have
            // available.
            int max_nthreads = tio.cpu_nthreads();
            if (max_nthreads == 1 || level < 19) {
                // No threading
                size_t laes_ops = 0;
                for(size_t i=0;i<curlevel_size;++i) {
                    DPFnode lchild, rchild;
                    prgboth(lchild, rchild, curlevel[i], laes_ops);
                    L = (L ^ lchild);
                    R = (R ^ rchild);
                    nextlevel[2*i] = lchild;
                    nextlevel[2*i+1] = rchild;
                }
                aes_ops += laes_ops;
            } else {
                size_t curlevel_size = size_t(1)<<level;
                int nthreads =
                    int(ceil(sqrt(double(curlevel_size/6000))));
                if (nthreads > max_nthreads) {
                    nthreads = max_nthreads;
                }
                DPFnode tL[nthreads];
                DPFnode tR[nthreads];
                size_t taes_ops[nthreads];
                size_t threadstart = 0;
                size_t threadchunk = curlevel_size / nthreads;
                size_t threadextra = curlevel_size % nthreads;
                boost::asio::thread_pool pool(nthreads);
                for (int t=0;t<nthreads;++t) {
                    size_t threadsize = threadchunk + (size_t(t) < threadextra);
                    size_t threadend = threadstart + threadsize;
                    boost::asio::post(pool,
                        [t, &tL, &tR, &taes_ops, threadstart, threadend,
                            &curlevel, &nextlevel] {
                            DPFnode L = _mm_setzero_si128();
                            DPFnode R = _mm_setzero_si128();
                            size_t aes_ops = 0;
                            for(size_t i=threadstart;i<threadend;++i) {
                                DPFnode lchild, rchild;
                                prgboth(lchild, rchild, curlevel[i], aes_ops);
                                L = (L ^ lchild);
                                R = (R ^ rchild);
                                nextlevel[2*i] = lchild;
                                nextlevel[2*i+1] = rchild;
                            }
                            tL[t] = L;
                            tR[t] = R;
                            taes_ops[t] = aes_ops;
                        });
                    threadstart = threadend;
                }
                pool.join();
                for (int t=0;t<nthreads;++t) {
                    L ^= tL[t];
                    R ^= tR[t];
                    aes_ops += taes_ops[t];
                }
            }
        }
        // If we're going left (bs_choice = 0), we want the correction
        // word to be the XOR of our right side and our peer's right
        // side; if bs_choice = 1, it should be the XOR of our left side
        // and our peer's left side.

        // We also have to ensure that the flag bits (the lsb) of the
        // sides that will end up the same are of course the same, but
        // also that the flag bits (the lsb) of the sides that will end
        // up different _must_ be different.  That is, it's not enough
        // for the nodes of the child selected by choice to be different
        // as 128-bit values; they also have to be different in their
        // lsb.

        // This is where we make a small optimization over Appendix C of
        // the Duoram paper: instead of keeping separate correction flag
        // bits for the left and right children, we observe that the low
        // bit of the overall correction word effectively serves as one
        // of those bits, so we just need to store one extra bit per
        // level, not two.  (We arbitrarily choose the one for the right
        // child.)

        // Note that the XOR of our left and right child before and
        // after applying the correction word won't change, since the
        // correction word is applied to either both children or
        // neither, depending on the value of the parent's flag.  So in
        // particular, the XOR of the flag bits won't change, and if the
        // XOR of our children's flags equals the XOR of our peer's
        // children's flags, then we won't have different flag bits even
        // for the children that have different 128-bit values.

        // So we compute our_parity = lsb(L^R)^player, and we XOR that
        // into the R value in the correction word computation.  At the
        // same time, we exchange these parity values to compute the
        // combined parity, which we store in the DPF.  Then when the
        // DPF is evaluated, if the parent's flag is set, we not only
        // apply the correction word to both children, but also apply
        // the (combined) parity bit to just the right child.  Then for
        // unequal nodes (where the flag bit is different), exactly one
        // of the four children (two for P0 and two for P1) will have
        // the parity bit applied, which will set the XOR of the lsb of
        // those four nodes to just L0^R0^L1^R1^our_parity^peer_parity
        // = 1, because everything cancels out except player (for which
        // one player is 0 and the other is 1).
        bool our_parity_bit = get_lsb(L ^ R) ^ !!player;
        DPFnode our_parity = lsb128_mask[our_parity_bit];

        DPFnode CW;
        bool peer_parity_bit;
        // Exchange the parities and do mpc_reconstruct_choice at the
        // same time (bundled into the same rounds)
        run_coroutines(yield,
            [this, &tio, &our_parity_bit, &peer_parity_bit](yield_t &yield) {
                tio.queue_peer(&our_parity_bit, 1);
                yield();
                uint8_t peer_parity_byte;
                tio.recv_peer(&peer_parity_byte, 1);
                peer_parity_bit = peer_parity_byte & 1;
            },
            [this, &tio, &CW, &L, &R, &bs_choice, &our_parity](yield_t &yield) {
                mpc_reconstruct_choice(tio, yield, CW, bs_choice,
                    (R ^ our_parity), L);
            });

        bool parity_bit = our_parity_bit ^ peer_parity_bit;
        cfbits |= (value_t(parity_bit)<<level);
        DPFnode CWR = CW ^ lsb128_mask[parity_bit];
        if (player < 2) {
            // The timing of each iteration of the inner loop is
            // comparable to the above, so just use the same
            // computations.  All of this could be tuned, of course.

            if (level < depth-1) {
                // Using the timing results gathered above, decide whether
                // to multithread, and if so, how many threads to use.
                // tio.cpu_nthreads() is the maximum number we have
                // available.
                int max_nthreads = tio.cpu_nthreads();
                if (max_nthreads == 1 || level < 19) {
                    // No threading
                    for(size_t i=0;i<curlevel_size;++i) {
                        bool flag = get_lsb(curlevel[i]);
                        nextlevel[2*i] = xor_if(nextlevel[2*i], CW, flag);
                        nextlevel[2*i+1] = xor_if(nextlevel[2*i+1], CWR, flag);
                    }
                } else {
                    int nthreads =
                        int(ceil(sqrt(double(curlevel_size/6000))));
                    if (nthreads > max_nthreads) {
                        nthreads = max_nthreads;
                    }
                    size_t threadstart = 0;
                    size_t threadchunk = curlevel_size / nthreads;
                    size_t threadextra = curlevel_size % nthreads;
                    boost::asio::thread_pool pool(nthreads);
                    for (int t=0;t<nthreads;++t) {
                        size_t threadsize = threadchunk + (size_t(t) < threadextra);
                        size_t threadend = threadstart + threadsize;
                        boost::asio::post(pool, [CW, CWR, threadstart, threadend,
                            &curlevel, &nextlevel] {
                                for(size_t i=threadstart;i<threadend;++i) {
                                    bool flag = get_lsb(curlevel[i]);
                                    nextlevel[2*i] = xor_if(nextlevel[2*i], CW, flag);
                                    nextlevel[2*i+1] = xor_if(nextlevel[2*i+1], CWR, flag);
                                }
                            });
                        threadstart = threadend;
                    }
                    pool.join();
                }
            } else {
                // Recall there are four potentially useful vectors that
                // can come out of a DPF:
                // - (single-bit) bitwise unit vector
                // - additive-shared unit vector
                // - XOR-shared scaled unit vector
                // - additive-shared scaled unit vector
                //
                // (No single DPF should be used for both of the first
                // two or both of the last two, though, since they're
                // correlated; you _can_ use one of the first two and
                // one of the last two.)
                //
                // For each 128-bit leaf, the low bit is the flag bit,
                // and we're guaranteed that the flag bits (and indeed
                // the whole 128-bit value) for P0 and P1 are the same
                // for every leaf except the target, and that the flag
                // bits definitely differ for the target (and the other
                // 127 bits are independently random on each side).
                //
                // We divide the 128-bit leaf into a low 64-bit word and
                // a high 64-bit word.  We use the low word for the unit
                // vector and the high word for the scaled vector; this
                // choice is not arbitrary: the flag bit in the low word
                // means that the sum of all the low words (with P1's
                // low words negated) across both P0 and P1 is
                // definitely odd, so we can compute that sum's inverse
                // mod 2^64, and store it now during precomputation.  At
                // evaluation time for the additive-shared unit vector,
                // we will output this global inverse times the low word
                // of each leaf, which will make the sum of all of those
                // values 1.  (This technique replaces the protocol in
                // Appendix D of the Duoram paper.)
                //
                // For the scaled vector, we just have to compute shares
                // of what the scaled vector is a sharing _of_, but
                // that's just XORing or adding all of each party's
                // local high words; no communication needed.

                value_t low_sum = 0;
                value_t high_sum = 0;
                value_t high_xor = 0;
                // Using the timing results gathered above, decide whether
                // to multithread, and if so, how many threads to use.
                // tio.cpu_nthreads() is the maximum number we have
                // available.
                int max_nthreads = tio.cpu_nthreads();
                if (max_nthreads == 1 || level < 19) {
                    // No threading
                    for(size_t i=0;i<curlevel_size;++i) {
                        bool flag = get_lsb(curlevel[i]);
                        DPFnode leftchild = xor_if(nextlevel[2*i], CW, flag);
                        DPFnode rightchild = xor_if(nextlevel[2*i+1], CWR, flag);
                        if (save_expansion) {
                            nextlevel[2*i] = leftchild;
                            nextlevel[2*i+1] = rightchild;
                        }
                        value_t leftlow = value_t(_mm_cvtsi128_si64x(leftchild));
                        value_t rightlow = value_t(_mm_cvtsi128_si64x(rightchild));
                        value_t lefthigh =
                            value_t(_mm_cvtsi128_si64x(_mm_srli_si128(leftchild,8)));
                        value_t righthigh =
                            value_t(_mm_cvtsi128_si64x(_mm_srli_si128(rightchild,8)));
                        low_sum += (leftlow + rightlow);
                        high_sum += (lefthigh + righthigh);
                        high_xor ^= (lefthigh ^ righthigh);
                    }
                } else {
                    int nthreads =
                        int(ceil(sqrt(double(curlevel_size/6000))));
                    if (nthreads > max_nthreads) {
                        nthreads = max_nthreads;
                    }
                    value_t tlow_sum[nthreads];
                    value_t thigh_sum[nthreads];
                    value_t thigh_xor[nthreads];
                    size_t threadstart = 0;
                    size_t threadchunk = curlevel_size / nthreads;
                    size_t threadextra = curlevel_size % nthreads;
                    boost::asio::thread_pool pool(nthreads);
                    for (int t=0;t<nthreads;++t) {
                        size_t threadsize = threadchunk + (size_t(t) < threadextra);
                        size_t threadend = threadstart + threadsize;
                        boost::asio::post(pool,
                            [t, &tlow_sum, &thigh_sum, &thigh_xor, threadstart,
                                threadend, &curlevel, &nextlevel, CW, CWR,
                                save_expansion] {
                                value_t low_sum = 0;
                                value_t high_sum = 0;
                                value_t high_xor = 0;
                                for(size_t i=threadstart;i<threadend;++i) {
                                    bool flag = get_lsb(curlevel[i]);
                                    DPFnode leftchild = xor_if(nextlevel[2*i], CW, flag);
                                    DPFnode rightchild = xor_if(nextlevel[2*i+1], CWR, flag);
                                    if (save_expansion) {
                                        nextlevel[2*i] = leftchild;
                                        nextlevel[2*i+1] = rightchild;
                                    }
                                    value_t leftlow = value_t(_mm_cvtsi128_si64x(leftchild));
                                    value_t rightlow = value_t(_mm_cvtsi128_si64x(rightchild));
                                    value_t lefthigh =
                                        value_t(_mm_cvtsi128_si64x(_mm_srli_si128(leftchild,8)));
                                    value_t righthigh =
                                        value_t(_mm_cvtsi128_si64x(_mm_srli_si128(rightchild,8)));
                                    low_sum += (leftlow + rightlow);
                                    high_sum += (lefthigh + righthigh);
                                    high_xor ^= (lefthigh ^ righthigh);
                                }
                                tlow_sum[t] = low_sum;
                                thigh_sum[t] = high_sum;
                                thigh_xor[t] = high_xor;
                            });
                        threadstart = threadend;
                    }
                    pool.join();
                    for (int t=0;t<nthreads;++t) {
                        low_sum += tlow_sum[t];
                        high_sum += thigh_sum[t];
                        high_xor ^= thigh_xor[t];
                    }
                }
                if (player == 1) {
                    low_sum = -low_sum;
                    high_sum = -high_sum;
                }
                scaled_sum.ashare = high_sum;
                scaled_xor.xshare = high_xor;
                // Exchange low_sum and add them up
                tio.queue_peer(&low_sum, sizeof(low_sum));
                yield();
                value_t peer_low_sum;
                tio.recv_peer(&peer_low_sum, sizeof(peer_low_sum));
                low_sum += peer_low_sum;
                // The low_sum had better be odd
                assert(low_sum & 1);
                unit_sum_inverse = inverse_value_t(low_sum);
            }
            cw.push_back(CW);
        } else if (level == depth-1) {
            yield();
        }

        ++level;
    }

    delete[] curlevel;
    if (!save_expansion || player == 2) {
        delete[] nextlevel;
    }
}
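
// Illustrative usage sketch only (hypothetical function, not part of
// this file): each of P0, P1, and the server P2 runs the constructor
// above inside a coroutine with its own MPCTIO, passing the same depth
// and its own XOR share of the target index.
#if 0
static void example_rdpf_construction(MPCTIO &tio, yield_t &yield)
{
    const nbits_t depth = 16;
    RegXS target;
    target.randomize(depth);     // this party's XOR share of the target
    RDPF dpf(tio, yield, target, depth, false);
    size_t aes_ops = 0;
    DPFnode leaf0 = dpf.leaf(0, aes_ops);   // evaluate a single leaf
    (void)leaf0;
}
#endif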

// Get the leaf node for the given input
DPFnode RDPF::leaf(address_t input, size_t &aes_ops) const
{
    // If we have a precomputed expansion, just use it
    if (expansion.size()) {
        return expansion[input];
    }

    nbits_t totdepth = depth();
    DPFnode node = seed;
    for (nbits_t d=0;d<totdepth;++d) {
        bit_t dir = !!(input & (address_t(1)<<(totdepth-d-1)));
        node = descend(node, d, dir, aes_ops);
    }
    return node;
}
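
// For reference, a minimal sketch (illustrative only; the real descend
// logic lives in rdpf.hpp) of how the stored correction word and the
// parity bit recorded in cfbits are applied when walking down one
// level, as described in the constructor above: if the parent's flag
// (lsb) is set, the correction word is XORed into the child, and the
// right child additionally gets the parity bit.
#if 0
static DPFnode sketch_correct_child(DPFnode parent, DPFnode child,
    DPFnode CW, bool parity_bit, bit_t whichchild)
{
    bool flag = get_lsb(parent);
    DPFnode CWR = CW ^ lsb128_mask[parity_bit];
    return xor_if(child, whichchild ? CWR : CW, flag);
}
#endif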

// Expand the DPF if it's not already expanded
//
// This routine is slightly more efficient than repeatedly calling
// StreamEval::next(), but it uses a lot more memory.
void RDPF::expand(size_t &aes_ops)
{
    nbits_t depth = this->depth();
    size_t num_leaves = size_t(1)<<depth;
    if (expansion.size() == num_leaves) return;
    expansion.resize(num_leaves);
    address_t index = 0;
    address_t lastindex = 0;
    DPFnode *path = new DPFnode[depth];
    path[0] = seed;
    for (nbits_t i=1;i<depth;++i) {
        path[i] = descend(path[i-1], i-1, 0, aes_ops);
    }
    expansion[index++] = descend(path[depth-1], depth-1, 0, aes_ops);
    expansion[index++] = descend(path[depth-1], depth-1, 1, aes_ops);
    while(index < num_leaves) {
        // Invariant: lastindex and index will both be even, and
        // index=lastindex+2
        uint64_t index_xor = index ^ lastindex;
        nbits_t how_many_1_bits = __builtin_popcount(index_xor);
        // If lastindex -> index goes for example from (in binary)
        // 010010110 -> 010011000, then index_xor will be
        // 000001110 and how_many_1_bits will be 3.
        // That indicates that path[depth-3] was a left child, and now
        // we need to change it to a right child by descending right
        // from path[depth-4], and then filling the path after that with
        // left children.
        path[depth-how_many_1_bits] =
            descend(path[depth-how_many_1_bits-1],
                depth-how_many_1_bits-1, 1, aes_ops);
        for (nbits_t i = depth-how_many_1_bits; i < depth-1; ++i) {
            path[i+1] = descend(path[i], i, 0, aes_ops);
        }
        lastindex = index;
        expansion[index++] = descend(path[depth-1], depth-1, 0, aes_ops);
        expansion[index++] = descend(path[depth-1], depth-1, 1, aes_ops);
    }
    delete[] path;
}
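
// Illustrative usage sketch only (hypothetical function, not part of
// this file): expand() computes each node of the tree exactly once
// (roughly one PRG evaluation per node, or about 2*2^depth in total)
// and stores all 2^depth 16-byte leaves, after which leaf() becomes a
// simple table lookup instead of a depth-length walk from the seed.
#if 0
static void example_expand(RDPF &dpf)
{
    size_t aes_ops = 0;
    dpf.expand(aes_ops);                    // one-time cost, lots of memory
    DPFnode leaf7 = dpf.leaf(7, aes_ops);   // now just a lookup in the expansion
    (void)leaf7;
}
#endif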

// Construct three RDPFs of the given depth all with the same randomly
// generated target index.
RDPFTriple::RDPFTriple(MPCTIO &tio, yield_t &yield,
    nbits_t depth, bool save_expansion)
{
    // Pick a random XOR share of the target
    xs_target.randomize(depth);

    // Now create three RDPFs with that target, and also convert the XOR
    // shares of the target to additive shares
    std::vector<coro_t> coroutines;
    for (int i=0;i<3;++i) {
        coroutines.emplace_back(
            [this, &tio, depth, i, save_expansion](yield_t &yield) {
                dpf[i] = RDPF(tio, yield, xs_target, depth,
                    save_expansion);
            });
    }
    coroutines.emplace_back(
        [this, &tio, depth](yield_t &yield) {
            mpc_xs_to_as(tio, yield, as_target, xs_target, depth, false);
        });
    run_coroutines(yield, coroutines);
}

RDPFTriple::node RDPFTriple::descend(const RDPFTriple::node &parent,
    nbits_t parentdepth, bit_t whichchild,
    size_t &aes_ops) const
{
    auto [P0, P1, P2] = parent;
    DPFnode C0, C1, C2;
    C0 = dpf[0].descend(P0, parentdepth, whichchild, aes_ops);
    C1 = dpf[1].descend(P1, parentdepth, whichchild, aes_ops);
    C2 = dpf[2].descend(P2, parentdepth, whichchild, aes_ops);
    return std::make_tuple(C0,C1,C2);
}

RDPFPair::node RDPFPair::descend(const RDPFPair::node &parent,
    nbits_t parentdepth, bit_t whichchild,
    size_t &aes_ops) const
{
    auto [P0, P1] = parent;
    DPFnode C0, C1;
    C0 = dpf[0].descend(P0, parentdepth, whichchild, aes_ops);
    C1 = dpf[1].descend(P1, parentdepth, whichchild, aes_ops);
    return std::make_tuple(C0,C1);
}