// Templated method implementations for rdpf.hpp

#include "mpcops.hpp"

// Compute the multiplicative inverse of x mod 2^{VALUE_BITS}
// This is the same as computing x to the power of
// 2^{VALUE_BITS-1}-1.
static value_t inverse_value_t(value_t x)
{
    int expon = 1;
    value_t xe = x;
    // Invariant: xe = x^(2^expon - 1) mod 2^{VALUE_BITS}
    // Goal: compute x^(2^{VALUE_BITS-1} - 1)
    while (expon < VALUE_BITS-1) {
        xe = xe * xe * x;
        ++expon;
    }
    return xe;
}
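
// A quick sanity check of the invariant above, in a hypothetical 8-bit
// analogue of VALUE_BITS (not part of the original code): for x = 3,
// the loop computes 3^(2^7 - 1) = 3^127 ≡ 171 (mod 2^8), and indeed
// 3 * 171 = 513 ≡ 1 (mod 2^8), so 171 is 3's inverse.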

// Create a StreamEval object that will start its output at index start.
// It will wrap around to 0 when it hits 2^depth.  If use_expansion
// is true, then if the DPF has been expanded, just output values
// from that.  If use_expansion=false or if the DPF has not been
// expanded, compute the values on the fly.  If xor_offset is non-zero,
// then the outputs are actually
// DPF(start XOR xor_offset)
// DPF((start+1) XOR xor_offset)
// DPF((start+2) XOR xor_offset)
// etc.
template <typename T>
StreamEval<T>::StreamEval(const T &rdpf, address_t start,
    address_t xor_offset, size_t &aes_ops,
    bool use_expansion) : rdpf(rdpf), aes_ops(aes_ops),
    use_expansion(use_expansion), counter_xor_offset(xor_offset)
{
    depth = rdpf.depth();
    // Prevent overflow of 1<<depth
    if (depth < ADDRESS_MAX_BITS) {
        indexmask = (address_t(1)<<depth)-1;
    } else {
        indexmask = ~0;
    }
    start &= indexmask;
    counter_xor_offset &= indexmask;
    // Record that we haven't actually output the leaf for index start
    // itself yet
    nextindex = start;
    if (use_expansion && rdpf.has_expansion()) {
        // We just need to keep the counter, not compute anything
        return;
    }
    path.resize(depth);
    pathindex = start;
    path[0] = rdpf.get_seed();
    for (nbits_t i=1;i<depth;++i) {
        bool dir = !!(pathindex & (address_t(1)<<(depth-i)));
        bool xor_offset_bit =
            !!(counter_xor_offset & (address_t(1)<<(depth-i)));
        path[i] = rdpf.descend(path[i-1], i-1,
            dir ^ xor_offset_bit, aes_ops);
    }
}

template <typename T>
typename T::LeafNode StreamEval<T>::next()
{
    if (use_expansion && rdpf.has_expansion()) {
        // Just use the precomputed values
        typename T::LeafNode leaf =
            rdpf.get_expansion(nextindex ^ counter_xor_offset);
        nextindex = (nextindex + 1) & indexmask;
        return leaf;
    }
    // Invariant: in the first call to next(), nextindex = pathindex.
    // Otherwise, nextindex = pathindex+1.

    // Get the XOR of nextindex and pathindex, and strip the low bit.
    // If nextindex and pathindex are equal, or pathindex is even
    // and nextindex is the consecutive odd number, index_xor will be 0,
    // indicating that we don't have to update the path, but just
    // compute the appropriate leaf given by the low bit of nextindex.
    //
    // Otherwise, say for example pathindex is 010010111 and nextindex
    // is 010011000.  Then their XOR is 000001111, and stripping the low
    // bit yields 000001110, so how_many_1_bits will be 3.
    // That indicates (typically) that path[depth-3] was a left child,
    // and now we need to change it to a right child by descending right
    // from path[depth-4], and then filling the path after that with
    // left children.
    //
    // When we wrap around, however, index_xor will be 111111110 (after
    // we strip the low bit), and how_many_1_bits will be depth-1, but
    // the new top child (of the root seed) we have to compute will be a
    // left, not a right, child.
    uint64_t index_xor = (nextindex ^ pathindex) & ~1;
    // Use the 64-bit popcount, since index_xor is a uint64_t
    nbits_t how_many_1_bits = __builtin_popcountll(index_xor);
    if (how_many_1_bits > 0) {
        // This will almost always be 1, unless we've just wrapped
        // around from the right subtree back to the left, in which case
        // it will be 0.
        bool top_changed_bit =
            !!(nextindex & (address_t(1) << how_many_1_bits));
        bool xor_offset_bit =
            !!(counter_xor_offset & (address_t(1) << how_many_1_bits));
        path[depth-how_many_1_bits] =
            rdpf.descend(path[depth-how_many_1_bits-1],
                depth-how_many_1_bits-1,
                top_changed_bit ^ xor_offset_bit, aes_ops);
        for (nbits_t i = depth-how_many_1_bits; i < depth-1; ++i) {
            bool xor_offset_bit =
                !!(counter_xor_offset & (address_t(1) << (depth-i-1)));
            path[i+1] = rdpf.descend(path[i], i, xor_offset_bit, aes_ops);
        }
    }
    bool xor_offset_bit = counter_xor_offset & 1;
    typename T::LeafNode leaf = rdpf.descend_to_leaf(path[depth-1], depth-1,
        (nextindex & 1) ^ xor_offset_bit, aes_ops);
    pathindex = nextindex;
    nextindex = (nextindex + 1) & indexmask;
    return leaf;
}
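
// A minimal usage sketch (hypothetical; my_rdpf and n are assumed
// names, not from this file): stream the first n leaves of a DPF in
// order, starting at index 0, with no XOR offset.
//
//     size_t aes_ops = 0;
//     auto ev = StreamEval(my_rdpf, 0, 0, aes_ops);
//     for (address_t i = 0; i < n; ++i) {
//         auto leaf = ev.next();
//         // ... use the leaf for index i ...
//     }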

// Run the parallel evaluator.  The type V is the type of the
// accumulator; init should be the "zero" value of the accumulator.
// The type W (process) is a lambda type with the signature
// (int, address_t, const T::LeafNode &) -> V
// which will be called like this for each i from 0 to num_evals-1,
// across num_threads threads:
// value_i = process(t, i, DPF((start+i) XOR xor_offset))
// t is the thread number (0 <= t < num_threads).
// The resulting num_evals values will be combined using V's +=
// operator, first accumulating the values within each thread
// (starting with the init value), and then accumulating the totals
// from each thread together (again starting with the init value):
//
// total = init
// for each thread t:
//     accum_t = init
//     for each value_i generated by thread t:
//         accum_t += value_i
//     total += accum_t
template <typename T> template <typename V, typename W>
inline V ParallelEval<T>::reduce(V init, W process)
{
    size_t thread_aes_ops[num_threads];
    V accums[num_threads];
    boost::asio::thread_pool pool(num_threads);
    // threadstart is each thread's offset relative to start; the DPF
    // index of a thread's first evaluation is (start+threadstart),
    // reduced mod 2^depth
    address_t threadstart = 0;
    address_t threadchunk = num_evals / num_threads;
    address_t threadextra = num_evals % num_threads;
    nbits_t depth = rdpf.depth();
    address_t indexmask = (depth < ADDRESS_MAX_BITS ?
        ((address_t(1)<<depth)-1) : ~0);
    for (int thread_num = 0; thread_num < num_threads; ++thread_num) {
        address_t threadsize = threadchunk + (address_t(thread_num) < threadextra);
        boost::asio::post(pool,
            [this, &init, &thread_aes_ops, &accums, &process,
                thread_num, threadstart, threadsize, indexmask] {
                size_t local_aes_ops = 0;
                auto ev = StreamEval(rdpf, (start+threadstart)&indexmask,
                    xor_offset, local_aes_ops);
                V accum = init;
                for (address_t x=0;x<threadsize;++x) {
                    typename T::LeafNode leaf = ev.next();
                    accum += process(thread_num,
                        (threadstart+x)&indexmask, leaf);
                }
                accums[thread_num] = accum;
                thread_aes_ops[thread_num] = local_aes_ops;
            });
        threadstart = (threadstart + threadsize) & indexmask;
    }
    pool.join();
    V total = init;
    for (int thread_num = 0; thread_num < num_threads; ++thread_num) {
        total += accums[thread_num];
        aes_ops += thread_aes_ops[thread_num];
    }
    return total;
}
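
// A usage sketch (hypothetical: assumes a ParallelEval constructor
// taking (rdpf, start, xor_offset, num_evals, num_threads, aes_ops),
// matching the members used in reduce() above): count how many of
// num_evals consecutive leaves have their low bit set.
//
//     ParallelEval pe(my_rdpf, start, 0, num_evals, 4, aes_ops);
//     size_t count = pe.reduce(size_t(0),
//         [] (int t, address_t i, const RDPF<1>::LeafNode &leaf) {
//             return size_t(get_lsb(leaf[0]));
//         });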

// Descend from a node at depth parentdepth to one of its leaf children
// whichchild = 0: left child
// whichchild = 1: right child
//
// Cost: 1 AES operation
template <nbits_t WIDTH>
inline typename RDPF<WIDTH>::LeafNode RDPF<WIDTH>::descend_to_leaf(
    const DPFnode &parent, nbits_t parentdepth, bit_t whichchild,
    size_t &aes_ops) const
{
    typename RDPF<WIDTH>::LeafNode prgout;
    bool flag = get_lsb(parent);
    prgleaf(prgout, parent, whichchild, aes_ops);
    if (flag) {
        LeafNode CW = li[0].leaf_cw;
        LeafNode CWR = CW;
        for (nbits_t j=0;j<LWIDTH;++j) {
            bit_t cfbit = !!(leaf_cfbits[j] &
                (value_t(1)<<(maxdepth-parentdepth)));
            CWR[j] ^= lsb128_mask[cfbit];
        }
        prgout ^= (whichchild ? CWR : CW);
    }
    return prgout;
}

// I/O for RDPFs

template <typename T, nbits_t WIDTH>
T& operator>>(T &is, RDPF<WIDTH> &rdpf)
{
    is.read((char *)&rdpf.seed, sizeof(rdpf.seed));
    rdpf.whichhalf = get_lsb(rdpf.seed);
    uint8_t depth;
    // 64 is added to depth on the wire to indicate an expanded RDPF
    is.read((char *)&depth, sizeof(depth));
    bool read_expanded = false;
    if (depth > 64) {
        read_expanded = true;
        depth -= 64;
    }
    assert(depth <= ADDRESS_MAX_BITS);
    rdpf.cw.clear();
    for (uint8_t i=0; i<depth; ++i) {
        DPFnode cw;
        is.read((char *)&cw, sizeof(cw));
        rdpf.cw.push_back(cw);
    }
    if (read_expanded) {
        rdpf.expansion.resize(size_t(1)<<depth);
        is.read((char *)rdpf.expansion.data(),
            sizeof(rdpf.expansion[0])<<depth);
    }
    value_t cfbits = 0;
    is.read((char *)&cfbits, BITBYTES(depth));
    rdpf.cfbits = cfbits;
    rdpf.li.resize(1);
    is.read((char *)&rdpf.li[0].unit_sum_inverse,
        sizeof(rdpf.li[0].unit_sum_inverse));
    is.read((char *)&rdpf.li[0].scaled_sum,
        sizeof(rdpf.li[0].scaled_sum));
    is.read((char *)&rdpf.li[0].scaled_xor,
        sizeof(rdpf.li[0].scaled_xor));
    return is;
}

// Write the DPF to the output stream.  If expanded=true, then include
// the expansion _if_ the DPF is itself already expanded.  You can use
// this to write DPFs to files.
template <typename T, nbits_t WIDTH>
T& write_maybe_expanded(T &os, const RDPF<WIDTH> &rdpf,
    bool expanded = true)
{
    os.write((const char *)&rdpf.seed, sizeof(rdpf.seed));
    uint8_t depth = rdpf.cw.size();
    assert(depth <= ADDRESS_MAX_BITS);
    // If we're writing an expansion, add 64 to depth
    uint8_t expanded_depth = depth;
    bool write_expansion = false;
    if (expanded && rdpf.expansion.size() == (size_t(1)<<depth)) {
        write_expansion = true;
        expanded_depth += 64;
    }
    os.write((const char *)&expanded_depth, sizeof(expanded_depth));
    for (uint8_t i=0; i<depth; ++i) {
        os.write((const char *)&rdpf.cw[i], sizeof(rdpf.cw[i]));
    }
    if (write_expansion) {
        os.write((const char *)rdpf.expansion.data(),
            sizeof(rdpf.expansion[0])<<depth);
    }
    os.write((const char *)&rdpf.cfbits, BITBYTES(depth));
    os.write((const char *)&rdpf.li[0].unit_sum_inverse,
        sizeof(rdpf.li[0].unit_sum_inverse));
    os.write((const char *)&rdpf.li[0].scaled_sum,
        sizeof(rdpf.li[0].scaled_sum));
    os.write((const char *)&rdpf.li[0].scaled_xor,
        sizeof(rdpf.li[0].scaled_xor));
    return os;
}

// The ordinary << version never writes the expansion, since this is
// what we use to send DPFs over the network.
template <typename T, nbits_t WIDTH>
T& operator<<(T &os, const RDPF<WIDTH> &rdpf)
{
    return write_maybe_expanded(os, rdpf, false);
}
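
// A round-trip sketch (hypothetical; assumes RDPF<WIDTH> is
// default-constructible and that std::ofstream/std::ifstream satisfy
// the stream parameter T): persist an expanded RDPF to a file and
// read it back, expansion included.
//
//     std::ofstream out("dpf.bin", std::ios::binary);
//     write_maybe_expanded(out, my_rdpf, true);
//     out.close();
//     std::ifstream in("dpf.bin", std::ios::binary);
//     RDPF<1> back;
//     in >> back;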

// I/O for RDPF Triples

// We never write RDPFTriples over the network, so always write
// the DPF expansions if they're available.
template <typename T, nbits_t WIDTH>
T& operator<<(T &os, const RDPFTriple<WIDTH> &rdpftrip)
{
    write_maybe_expanded(os, rdpftrip.dpf[0], true);
    write_maybe_expanded(os, rdpftrip.dpf[1], true);
    write_maybe_expanded(os, rdpftrip.dpf[2], true);
    nbits_t depth = rdpftrip.dpf[0].depth();
    os.write((const char *)&rdpftrip.as_target.ashare, BITBYTES(depth));
    os.write((const char *)&rdpftrip.xs_target.xshare, BITBYTES(depth));
    return os;
}

template <typename T, nbits_t WIDTH>
T& operator>>(T &is, RDPFTriple<WIDTH> &rdpftrip)
{
    is >> rdpftrip.dpf[0] >> rdpftrip.dpf[1] >> rdpftrip.dpf[2];
    nbits_t depth = rdpftrip.dpf[0].depth();
    rdpftrip.as_target.ashare = 0;
    is.read((char *)&rdpftrip.as_target.ashare, BITBYTES(depth));
    rdpftrip.xs_target.xshare = 0;
    is.read((char *)&rdpftrip.xs_target.xshare, BITBYTES(depth));
    return is;
}

// I/O for RDPF Pairs

// We never write RDPFPairs over the network, so always write
// the DPF expansions if they're available.
template <typename T, nbits_t WIDTH>
T& operator<<(T &os, const RDPFPair<WIDTH> &rdpfpair)
{
    write_maybe_expanded(os, rdpfpair.dpf[0], true);
    write_maybe_expanded(os, rdpfpair.dpf[1], true);
    return os;
}

template <typename T, nbits_t WIDTH>
T& operator>>(T &is, RDPFPair<WIDTH> &rdpfpair)
{
    is >> rdpfpair.dpf[0] >> rdpfpair.dpf[1];
    return is;
}

// Construct a DPF with the given (XOR-shared) target location, and
// of the given depth, to be used for random-access memory reads and
// writes.  The DPF is constructed collaboratively by P0 and P1,
// with the server P2 helping by providing various kinds of
// correlated randomness, such as MultTriples and AndTriples.
//
// This algorithm is based on Appendix C from the Duoram paper, with a
// small optimization noted below.
template <nbits_t WIDTH>
RDPF<WIDTH>::RDPF(MPCTIO &tio, yield_t &yield,
    RegXS target, nbits_t depth, bool save_expansion)
{
    int player = tio.player();
    size_t &aes_ops = tio.aes_ops();

    // Choose a random seed
    arc4random_buf(&seed, sizeof(seed));
    // Ensure the flag bits (the lsb of each node) are different
    seed = set_lsb(seed, !!player);
    cfbits = 0;
    whichhalf = (player == 1);
    maxdepth = depth;
    curdepth = depth;

    // The root level is just the seed
    nbits_t level = 0;
    DPFnode *curlevel = NULL;
    DPFnode *nextlevel = new DPFnode[1];
    nextlevel[0] = seed;

    li.resize(1);

    // Construct each intermediate level
    while(level < depth) {
        if (player < 2) {
            delete[] curlevel;
            curlevel = nextlevel;
            if (save_expansion && level == depth-1) {
                expansion.resize(size_t(1)<<depth);
                nextlevel = (DPFnode *)expansion.data();
            } else {
                nextlevel = new DPFnode[size_t(1)<<(level+1)];
            }
        }
        // Invariant: curlevel has 2^level elements; nextlevel has
        // 2^{level+1} elements

        // The bit-shared choice bit is bit (depth-level-1) of the
        // XOR-shared target index
        RegBS bs_choice = target.bit(depth-level-1);
        size_t curlevel_size = (size_t(1)<<level);
        DPFnode L = _mm_setzero_si128();
        DPFnode R = _mm_setzero_si128();
        // The server doesn't need to do this computation, but it does
        // need to execute mpc_reconstruct_choice so that it sends
        // the AndTriples at the appropriate time.
        if (player < 2) {
#ifdef RDPF_MTGEN_TIMING_1
            if (player == 0) {
                mtgen_timetest_1(level, 0, (1<<23)>>level, curlevel,
                    nextlevel, aes_ops);
                size_t niters = 2048;
                if (level > 8) niters = (1<<20)>>level;
                for(int t=1;t<=8;++t) {
                    mtgen_timetest_1(level, t, niters, curlevel,
                        nextlevel, aes_ops);
                }
                mtgen_timetest_1(level, 0, (1<<23)>>level, curlevel,
                    nextlevel, aes_ops);
            }
#endif
            // Using the timing results gathered above, decide whether
            // to multithread, and if so, how many threads to use.
            // tio.cpu_nthreads() is the maximum number we have
            // available.
            int max_nthreads = tio.cpu_nthreads();
            if (max_nthreads == 1 || level < 19) {
                // No threading
                size_t laes_ops = 0;
                for(size_t i=0;i<curlevel_size;++i) {
                    DPFnode lchild, rchild;
                    prgboth(lchild, rchild, curlevel[i], laes_ops);
                    L = (L ^ lchild);
                    R = (R ^ rchild);
                    nextlevel[2*i] = lchild;
                    nextlevel[2*i+1] = rchild;
                }
                aes_ops += laes_ops;
            } else {
                size_t curlevel_size = size_t(1)<<level;
                int nthreads =
                    int(ceil(sqrt(double(curlevel_size/6000))));
                if (nthreads > max_nthreads) {
                    nthreads = max_nthreads;
                }
                DPFnode tL[nthreads];
                DPFnode tR[nthreads];
                size_t taes_ops[nthreads];
                size_t threadstart = 0;
                size_t threadchunk = curlevel_size / nthreads;
                size_t threadextra = curlevel_size % nthreads;
                boost::asio::thread_pool pool(nthreads);
                for (int t=0;t<nthreads;++t) {
                    size_t threadsize = threadchunk + (size_t(t) < threadextra);
                    size_t threadend = threadstart + threadsize;
                    boost::asio::post(pool,
                        [t, &tL, &tR, &taes_ops, threadstart, threadend,
                            &curlevel, &nextlevel] {
                            DPFnode L = _mm_setzero_si128();
                            DPFnode R = _mm_setzero_si128();
                            size_t aes_ops = 0;
                            for(size_t i=threadstart;i<threadend;++i) {
                                DPFnode lchild, rchild;
                                prgboth(lchild, rchild, curlevel[i], aes_ops);
                                L = (L ^ lchild);
                                R = (R ^ rchild);
                                nextlevel[2*i] = lchild;
                                nextlevel[2*i+1] = rchild;
                            }
                            tL[t] = L;
                            tR[t] = R;
                            taes_ops[t] = aes_ops;
                        });
                    threadstart = threadend;
                }
                pool.join();
                for (int t=0;t<nthreads;++t) {
                    L ^= tL[t];
                    R ^= tR[t];
                    aes_ops += taes_ops[t];
                }
            }
        }
        // If we're going left (bs_choice = 0), we want the correction
        // word to be the XOR of our right side and our peer's right
        // side; if bs_choice = 1, it should be the XOR of our left side
        // and our peer's left side.

        // We also have to ensure that the flag bits (the lsb) of the
        // side that will end up the same are of course the same, but
        // also that the flag bits (the lsb) of the side that will end
        // up different _must_ be different.  That is, it's not enough
        // for the nodes of the child selected by choice to be different
        // as 128-bit values; they also have to be different in their
        // lsb.

        // This is where we make a small optimization over Appendix C of
        // the Duoram paper: instead of keeping separate correction flag
        // bits for the left and right children, we observe that the low
        // bit of the overall correction word effectively serves as one
        // of those bits, so we just need to store one extra bit per
        // level, not two.  (We arbitrarily choose the one for the right
        // child.)

        // Note that the XOR of our left and right child before and
        // after applying the correction word won't change, since the
        // correction word is applied to either both children or
        // neither, depending on the value of the parent's flag.  So in
        // particular, the XOR of the flag bits won't change, and if our
        // children's flags' XOR equals our peer's children's flags'
        // XOR, then we won't have different flag bits even for the
        // children that have different 128-bit values.

        // So we compute our_parity = lsb(L^R)^player, and we XOR that
        // into the R value in the correction word computation.  At the
        // same time, we exchange these parity values to compute the
        // combined parity, which we store in the DPF.  Then when the
        // DPF is evaluated, if the parent's flag is set, not only apply
        // the correction word to both children, but also apply the
        // (combined) parity bit to just the right child.  Then for
        // unequal nodes (where the flag bit is different), exactly one
        // of the four children (two for P0 and two for P1) will have
        // the parity bit applied, which will set the XOR of the lsb of
        // those four nodes to just L0^R0^L1^R1^our_parity^peer_parity
        // = 1 because everything cancels out except player (for which
        // one player is 0 and the other is 1).
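        // A concrete check of that cancellation (a hypothetical
        // example, not from the Duoram paper): say lsb(L0)=1,
        // lsb(R0)=1, lsb(L1)=1, lsb(R1)=0.  Then our_parity is
        // (1^1)^0 = 0 for P0 and (1^0)^1 = 0 for P1, and
        // 1^1^1^0^0^0 = 1, as required.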
        bool our_parity_bit = get_lsb(L ^ R) ^ !!player;
        DPFnode our_parity = lsb128_mask[our_parity_bit];

        DPFnode CW;
        bool peer_parity_bit;
        // Exchange the parities and do mpc_reconstruct_choice at the
        // same time (bundled into the same rounds)
        run_coroutines(yield,
            [this, &tio, &our_parity_bit, &peer_parity_bit](yield_t &yield) {
                tio.queue_peer(&our_parity_bit, 1);
                yield();
                uint8_t peer_parity_byte;
                tio.recv_peer(&peer_parity_byte, 1);
                peer_parity_bit = peer_parity_byte & 1;
            },
            [this, &tio, &CW, &L, &R, &bs_choice, &our_parity](yield_t &yield) {
                mpc_reconstruct_choice(tio, yield, CW, bs_choice,
                    (R ^ our_parity), L);
            });
        bool parity_bit = our_parity_bit ^ peer_parity_bit;
        cfbits |= (value_t(parity_bit)<<level);
        DPFnode CWR = CW ^ lsb128_mask[parity_bit];
        if (player < 2) {
            // The timing of each iteration of the inner loop is
            // comparable to the above, so just use the same
            // computations.  All of this could be tuned, of course.

            if (level < depth-1) {
                // Using the timing results gathered above, decide whether
                // to multithread, and if so, how many threads to use.
                // tio.cpu_nthreads() is the maximum number we have
                // available.
                int max_nthreads = tio.cpu_nthreads();
                if (max_nthreads == 1 || level < 19) {
                    // No threading
                    for(size_t i=0;i<curlevel_size;++i) {
                        bool flag = get_lsb(curlevel[i]);
                        nextlevel[2*i] = xor_if(nextlevel[2*i], CW, flag);
                        nextlevel[2*i+1] = xor_if(nextlevel[2*i+1], CWR, flag);
                    }
                } else {
                    int nthreads =
                        int(ceil(sqrt(double(curlevel_size/6000))));
                    if (nthreads > max_nthreads) {
                        nthreads = max_nthreads;
                    }
                    size_t threadstart = 0;
                    size_t threadchunk = curlevel_size / nthreads;
                    size_t threadextra = curlevel_size % nthreads;
                    boost::asio::thread_pool pool(nthreads);
                    for (int t=0;t<nthreads;++t) {
                        size_t threadsize = threadchunk + (size_t(t) < threadextra);
                        size_t threadend = threadstart + threadsize;
                        boost::asio::post(pool, [CW, CWR, threadstart, threadend,
                            &curlevel, &nextlevel] {
                            for(size_t i=threadstart;i<threadend;++i) {
                                bool flag = get_lsb(curlevel[i]);
                                nextlevel[2*i] = xor_if(nextlevel[2*i], CW, flag);
                                nextlevel[2*i+1] = xor_if(nextlevel[2*i+1], CWR, flag);
                            }
                        });
                        threadstart = threadend;
                    }
                    pool.join();
                }
            } else {
                // Recall there are four potentially useful vectors that
                // can come out of a DPF:
                // - (single-bit) bitwise unit vector
                // - additive-shared unit vector
                // - XOR-shared scaled unit vector
                // - additive-shared scaled unit vector
                //
                // (No single DPF should be used for both of the first
                // two or both of the last two, though, since they're
                // correlated; you _can_ use one of the first two and
                // one of the last two.)
                //
                // For each 128-bit leaf, the low bit is the flag bit,
                // and we're guaranteed that the flag bits (and indeed
                // the whole 128-bit value) for P0 and P1 are the same
                // for every leaf except the target, and that the flag
                // bits definitely differ for the target (and the other
                // 127 bits are independently random on each side).
                //
                // We divide the 128-bit leaf into a low 64-bit word and
                // a high 64-bit word.  We use the low word for the unit
                // vector and the high word for the scaled vector; this
                // choice is not arbitrary: the flag bit in the low word
                // means that the sum of all the low words (with P1's
                // low words negated) across both P0 and P1 is
                // definitely odd, so we can compute that sum's inverse
                // mod 2^64, and store it now during precomputation.  At
                // evaluation time for the additive-shared unit vector,
                // we will output this global inverse times the low word
                // of each leaf, which will make the sum of all of those
                // values 1.  (This technique replaces the protocol in
                // Appendix D of the Duoram paper.)
                //
                // For the scaled vector, we just have to compute shares
                // of what the scaled vector is a sharing _of_, but
                // that's just XORing or adding all of each party's
                // local high words; no communication needed.
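                // A tiny worked example (hypothetical numbers): if the
                // signed sum of all the low words across P0 and P1 is
                // S = 5, we store unit_sum_inverse = 5^{-1} mod 2^64;
                // each party's output unit_sum_inverse * low_word then
                // sums, over all leaves and both parties, to
                // unit_sum_inverse * S = 1, with each non-target leaf
                // contributing 0 (P0's and P1's values for it cancel).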
                value_t low_sum = 0;
                value_t high_sum = 0;
                value_t high_xor = 0;
                // Using the timing results gathered above, decide whether
                // to multithread, and if so, how many threads to use.
                // tio.cpu_nthreads() is the maximum number we have
                // available.
                int max_nthreads = tio.cpu_nthreads();
                if (max_nthreads == 1 || level < 19) {
                    // No threading
                    for(size_t i=0;i<curlevel_size;++i) {
                        bool flag = get_lsb(curlevel[i]);
                        DPFnode leftchild = xor_if(nextlevel[2*i], CW, flag);
                        DPFnode rightchild = xor_if(nextlevel[2*i+1], CWR, flag);
                        if (save_expansion) {
                            nextlevel[2*i] = leftchild;
                            nextlevel[2*i+1] = rightchild;
                        }
                        value_t leftlow = value_t(_mm_cvtsi128_si64x(leftchild));
                        value_t rightlow = value_t(_mm_cvtsi128_si64x(rightchild));
                        value_t lefthigh =
                            value_t(_mm_cvtsi128_si64x(_mm_srli_si128(leftchild,8)));
                        value_t righthigh =
                            value_t(_mm_cvtsi128_si64x(_mm_srli_si128(rightchild,8)));
                        low_sum += (leftlow + rightlow);
                        high_sum += (lefthigh + righthigh);
                        high_xor ^= (lefthigh ^ righthigh);
                    }
                } else {
                    int nthreads =
                        int(ceil(sqrt(double(curlevel_size/6000))));
                    if (nthreads > max_nthreads) {
                        nthreads = max_nthreads;
                    }
                    value_t tlow_sum[nthreads];
                    value_t thigh_sum[nthreads];
                    value_t thigh_xor[nthreads];
                    size_t threadstart = 0;
                    size_t threadchunk = curlevel_size / nthreads;
                    size_t threadextra = curlevel_size % nthreads;
                    boost::asio::thread_pool pool(nthreads);
                    for (int t=0;t<nthreads;++t) {
                        size_t threadsize = threadchunk + (size_t(t) < threadextra);
                        size_t threadend = threadstart + threadsize;
                        boost::asio::post(pool,
                            [t, &tlow_sum, &thigh_sum, &thigh_xor,
                                threadstart, threadend,
                                &curlevel, &nextlevel, CW, CWR,
                                save_expansion] {
                                value_t low_sum = 0;
                                value_t high_sum = 0;
                                value_t high_xor = 0;
                                for(size_t i=threadstart;i<threadend;++i) {
                                    bool flag = get_lsb(curlevel[i]);
                                    DPFnode leftchild = xor_if(nextlevel[2*i], CW, flag);
                                    DPFnode rightchild = xor_if(nextlevel[2*i+1], CWR, flag);
                                    if (save_expansion) {
                                        nextlevel[2*i] = leftchild;
                                        nextlevel[2*i+1] = rightchild;
                                    }
                                    value_t leftlow = value_t(_mm_cvtsi128_si64x(leftchild));
                                    value_t rightlow = value_t(_mm_cvtsi128_si64x(rightchild));
                                    value_t lefthigh =
                                        value_t(_mm_cvtsi128_si64x(_mm_srli_si128(leftchild,8)));
                                    value_t righthigh =
                                        value_t(_mm_cvtsi128_si64x(_mm_srli_si128(rightchild,8)));
                                    low_sum += (leftlow + rightlow);
                                    high_sum += (lefthigh + righthigh);
                                    high_xor ^= (lefthigh ^ righthigh);
                                }
                                tlow_sum[t] = low_sum;
                                thigh_sum[t] = high_sum;
                                thigh_xor[t] = high_xor;
                            });
                        threadstart = threadend;
                    }
                    pool.join();
                    for (int t=0;t<nthreads;++t) {
                        low_sum += tlow_sum[t];
                        high_sum += thigh_sum[t];
                        high_xor ^= thigh_xor[t];
                    }
                }
                // P1 negates its shares so that the parties' values
                // sum, rather than XOR, to the intended totals
                if (player == 1) {
                    low_sum = -low_sum;
                    high_sum = -high_sum;
                }
                li[0].scaled_sum[0].ashare = high_sum;
                li[0].scaled_xor[0].xshare = high_xor;
                // Exchange the low_sums and add them up
                tio.queue_peer(&low_sum, sizeof(low_sum));
                yield();
                value_t peer_low_sum;
                tio.recv_peer(&peer_low_sum, sizeof(peer_low_sum));
                low_sum += peer_low_sum;
                // The low_sum had better be odd
                assert(low_sum & 1);
                li[0].unit_sum_inverse = inverse_value_t(low_sum);
            }
            cw.push_back(CW);
        } else if (level == depth-1) {
            // The server yields here too, to stay in sync with the
            // yield P0 and P1 do above when exchanging low_sums
            yield();
        }
        ++level;
    }
    delete[] curlevel;
    if (!save_expansion || player == 2) {
        delete[] nextlevel;
    }
}

// Get the leaf node for the given input
template <nbits_t WIDTH>
typename RDPF<WIDTH>::LeafNode
RDPF<WIDTH>::leaf(address_t input, size_t &aes_ops) const
{
    // If we have a precomputed expansion, just use it
    if (expansion.size()) {
        return expansion[input];
    }
    DPFnode node = seed;
    for (nbits_t d=0;d<curdepth-1;++d) {
        bit_t dir = !!(input & (address_t(1)<<(curdepth-d-1)));
        node = descend(node, d, dir, aes_ops);
    }
    bit_t dir = (input & 1);
    // The last node on the path is at depth curdepth-1, which is the
    // parentdepth to pass to descend_to_leaf
    return descend_to_leaf(node, curdepth-1, dir, aes_ops);
}

// Expand the DPF if it's not already expanded
//
// This routine is slightly more efficient than repeatedly calling
// StreamEval::next(), but it uses a lot more memory.
template <nbits_t WIDTH>
void RDPF<WIDTH>::expand(size_t &aes_ops)
{
    nbits_t depth = this->depth();
    size_t num_leaves = size_t(1)<<depth;
    if (expansion.size() == num_leaves) return;
    expansion.resize(num_leaves);
    address_t index = 0;
    address_t lastindex = 0;
    DPFnode *path = new DPFnode[depth];
    path[0] = seed;
    for (nbits_t i=1;i<depth;++i) {
        path[i] = descend(path[i-1], i-1, 0, aes_ops);
    }
    // The last step on each path uses descend_to_leaf, so that the
    // leaf correction word gets applied, matching leaf() above
    expansion[index++] = descend_to_leaf(path[depth-1], depth-1, 0, aes_ops);
    expansion[index++] = descend_to_leaf(path[depth-1], depth-1, 1, aes_ops);
    while(index < num_leaves) {
        // Invariant: lastindex and index will both be even, and
        // index=lastindex+2
        uint64_t index_xor = index ^ lastindex;
        nbits_t how_many_1_bits = __builtin_popcountll(index_xor);
        // If lastindex -> index goes for example from (in binary)
        // 010010110 -> 010011000, then index_xor will be
        // 000001110 and how_many_1_bits will be 3.
        // That indicates that path[depth-3] was a left child, and now
        // we need to change it to a right child by descending right
        // from path[depth-4], and then filling the path after that with
        // left children.
        path[depth-how_many_1_bits] =
            descend(path[depth-how_many_1_bits-1],
                depth-how_many_1_bits-1, 1, aes_ops);
        for (nbits_t i = depth-how_many_1_bits; i < depth-1; ++i) {
            path[i+1] = descend(path[i], i, 0, aes_ops);
        }
        lastindex = index;
        expansion[index++] = descend_to_leaf(path[depth-1], depth-1, 0, aes_ops);
        expansion[index++] = descend_to_leaf(path[depth-1], depth-1, 1, aes_ops);
    }
    delete[] path;
}
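
// A usage sketch (hypothetical names): expand once, after which each
// leaf() call is just an array lookup with no further AES operations.
//
//     size_t aes_ops = 0;
//     my_rdpf.expand(aes_ops);
//     auto leaf = my_rdpf.leaf(i, aes_ops);  // reads the expansion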

// Construct three RDPFs of the given depth all with the same randomly
// generated target index.
template <nbits_t WIDTH>
RDPFTriple<WIDTH>::RDPFTriple(MPCTIO &tio, yield_t &yield,
    nbits_t depth, bool save_expansion)
{
    // Pick a random XOR share of the target
    xs_target.randomize(depth);

    // Now create three RDPFs with that target, and also convert the XOR
    // shares of the target to additive shares
    std::vector<coro_t> coroutines;
    for (int i=0;i<3;++i) {
        coroutines.emplace_back(
            [this, &tio, depth, i, save_expansion](yield_t &yield) {
                dpf[i] = RDPF<WIDTH>(tio, yield, xs_target, depth,
                    save_expansion);
            });
    }
    coroutines.emplace_back(
        [this, &tio, depth](yield_t &yield) {
            mpc_xs_to_as(tio, yield, as_target, xs_target, depth, false);
        });
    run_coroutines(yield, coroutines);
}

template <nbits_t WIDTH>
typename RDPFTriple<WIDTH>::node RDPFTriple<WIDTH>::descend(
    const RDPFTriple<WIDTH>::node &parent,
    nbits_t parentdepth, bit_t whichchild,
    size_t &aes_ops) const
{
    auto [P0, P1, P2] = parent;
    DPFnode C0, C1, C2;
    C0 = dpf[0].descend(P0, parentdepth, whichchild, aes_ops);
    C1 = dpf[1].descend(P1, parentdepth, whichchild, aes_ops);
    C2 = dpf[2].descend(P2, parentdepth, whichchild, aes_ops);
    return std::make_tuple(C0,C1,C2);
}

template <nbits_t WIDTH>
typename RDPFTriple<WIDTH>::LeafNode RDPFTriple<WIDTH>::descend_to_leaf(
    const RDPFTriple<WIDTH>::node &parent,
    nbits_t parentdepth, bit_t whichchild,
    size_t &aes_ops) const
{
    auto [P0, P1, P2] = parent;
    typename RDPF<WIDTH>::LeafNode C0, C1, C2;
    C0 = dpf[0].descend_to_leaf(P0, parentdepth, whichchild, aes_ops);
    C1 = dpf[1].descend_to_leaf(P1, parentdepth, whichchild, aes_ops);
    C2 = dpf[2].descend_to_leaf(P2, parentdepth, whichchild, aes_ops);
    return std::make_tuple(C0,C1,C2);
}

template <nbits_t WIDTH>
typename RDPFPair<WIDTH>::node RDPFPair<WIDTH>::descend(
    const RDPFPair<WIDTH>::node &parent,
    nbits_t parentdepth, bit_t whichchild,
    size_t &aes_ops) const
{
    auto [P0, P1] = parent;
    DPFnode C0, C1;
    C0 = dpf[0].descend(P0, parentdepth, whichchild, aes_ops);
    C1 = dpf[1].descend(P1, parentdepth, whichchild, aes_ops);
    return std::make_tuple(C0,C1);
}

template <nbits_t WIDTH>
typename RDPFPair<WIDTH>::LeafNode RDPFPair<WIDTH>::descend_to_leaf(
    const RDPFPair<WIDTH>::node &parent,
    nbits_t parentdepth, bit_t whichchild,
    size_t &aes_ops) const
{
    auto [P0, P1] = parent;
    typename RDPF<WIDTH>::LeafNode C0, C1;
    C0 = dpf[0].descend_to_leaf(P0, parentdepth, whichchild, aes_ops);
    C1 = dpf[1].descend_to_leaf(P1, parentdepth, whichchild, aes_ops);
    return std::make_tuple(C0,C1);
}