rdpf.cpp 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157
  1. #include <bsd/stdlib.h> // arc4random_buf
  2. #include "rdpf.hpp"
  3. #include "bitutils.hpp"
  4. #include "mpcops.hpp"
  5. #include "aes.hpp"
  6. #include "prg.hpp"
  #ifdef DPF_DEBUG
// Debug helper: print one DPF node as 32 hex digits, starting from its
// highest-addressed byte (most significant first on a little-endian
// target). When a label is supplied it is printed first as "label: ".
static void dump_node(DPFnode node, const char *label = NULL)
{
    if (label != NULL) {
        printf("%s: ", label);
    }
    const unsigned char *bytes = reinterpret_cast<const unsigned char *>(&node);
    for (int pos = 15; pos >= 0; --pos) {
        printf("%02x", bytes[pos]);
    }
    printf("\n");
}

// Debug helper: print the first `num` nodes of `nodes`, one per line
// (unlabelled), under an optional "label:" heading line, and finish
// with a blank line.
static void dump_level(DPFnode *nodes, size_t num, const char *label = NULL)
{
    if (label != NULL) {
        printf("%s:\n", label);
    }
    for (const DPFnode *cur = nodes; cur != nodes + num; ++cur) {
        dump_node(*cur);
    }
    printf("\n");
}
#endif
  22. // Construct a DPF with the given (XOR-shared) target location, and
  23. // of the given depth, to be used for random-access memory reads and
  24. // writes. The DPF is construction collaboratively by P0 and P1,
  25. // with the server P2 helping by providing various kinds of
  26. // correlated randomness, such as MultTriples and AndTriples.
  27. RDPF::RDPF(MPCTIO &tio, yield_t &yield,
  28. RegXS target, nbits_t depth)
  29. {
  30. int player = tio.player();
  31. size_t &aesops = tio.aes_ops();
  32. // Choose a random seed
  33. arc4random_buf(&seed, sizeof(seed));
  34. // Ensure the flag bits (the lsb of each node) are different
  35. seed = set_lsb(seed, !!player);
  36. cfbits = 0;
  37. // The root level is just the seed
  38. nbits_t level = 0;
  39. DPFnode *curlevel = NULL;
  40. DPFnode *nextlevel = new DPFnode[1];
  41. nextlevel[0] = seed;
  42. // Construct each intermediate level
  43. while(level < depth) {
  44. delete[] curlevel;
  45. curlevel = nextlevel;
  46. // We don't need to store the last level
  47. if (level < depth-1) {
  48. nextlevel = new DPFnode[1<<(level+1)];
  49. } else {
  50. nextlevel = NULL;
  51. }
  52. // Invariant: curlevel has 2^level elements; nextlevel has
  53. // 2^{level+1} elements
  54. // The bit-shared choice bit is bit (depth-level-1) of the
  55. // XOR-shared target index
  56. RegBS bs_choice = target.bit(depth-level-1);
  57. size_t curlevel_size = (size_t(1)<<level);
  58. DPFnode L = _mm_setzero_si128();
  59. DPFnode R = _mm_setzero_si128();
  60. // The server doesn't need to do this computation, but it does
  61. // need to execute mpc_reconstruct_choice so that it sends
  62. // the AndTriples at the appropriate time.
  63. if (player < 2) {
  64. for(size_t i=0;i<curlevel_size;++i) {
  65. DPFnode lchild, rchild;
  66. prgboth(lchild, rchild, curlevel[i], aesops);
  67. L = _mm_xor_si128(L, lchild);
  68. R = _mm_xor_si128(R, rchild);
  69. if (nextlevel) {
  70. nextlevel[2*i] = lchild;
  71. nextlevel[2*i+1] = rchild;
  72. }
  73. }
  74. }
  75. // If we're going left (bs_choice = 0), we want the correction
  76. // work to be the XOR of our right side and our peer's right
  77. // side; if bs_choice = 1, it should be the XOR or our left side
  78. // and our peer's left side.
  79. // We have to ensure that the flag bits (the lsb) of the side
  80. // that will end up the same be of course the same, but also
  81. // that the flag bits (the lsb) of the side that will end up
  82. // different _must_ be different. That is, it's not enough for
  83. // the nodes of the child selected by choice to be different as
  84. // 128-bit values; they also have to be different in their lsb.
  85. // Note that the XOR of our left and right child before and
  86. // after applying the correction word won't change, since the
  87. // correction word is applied to either both children or
  88. // neither, depending on the value of the parent's flag. So in
  89. // particular, the XOR of the flag bits won't change, and if our
  90. // children's flag's XOR equals our peer's children's flag's
  91. // XOR, then we won't have different flag bits even for the
  92. // children that have different 128-bit values.
  93. // So we compute our_parity = lsb(L^R)^player, and we XOR that
  94. // into the R value in the correction word computation. At the
  95. // same time, we exchange these parity values to compute the
  96. // combined parity, which we store in the DPF. Then when the
  97. // DPF is evaluated, if the parent's flag is set, not only apply
  98. // the correction work to both children, but also apply the
  99. // (combined) parity bit to just the right child. Then for
  100. // unequal nodes (where the flag bit is different), exactly one
  101. // of the four children (two for P0 and two for P1) will have
  102. // the parity bit applied, which will set the XOR of the lsb of
  103. // those four nodes to just L0^R0^L1^R1^our_parity^peer_parity
  104. // = 1 because everything cancels out except player (for which
  105. // one player is 0 and the other is 1).
  106. bool our_parity_bit = get_lsb(_mm_xor_si128(L,R)) ^ !!player;
  107. DPFnode our_parity = lsb128_mask[our_parity_bit];
  108. DPFnode CW;
  109. bool peer_parity_bit;
  110. // Exchange the parities and do mpc_reconstruct_choice at the
  111. // same time (bundled into the same rounds)
  112. std::vector<coro_t> coroutines;
  113. coroutines.emplace_back(
  114. [&](yield_t &yield) {
  115. tio.queue_peer(&our_parity_bit, 1);
  116. yield();
  117. tio.recv_peer(&peer_parity_bit, 1);
  118. });
  119. coroutines.emplace_back(
  120. [&](yield_t &yield) {
  121. mpc_reconstruct_choice(tio, yield, CW, bs_choice,
  122. _mm_xor_si128(R,our_parity), L);
  123. });
  124. run_coroutines(yield, coroutines);
  125. bool parity_bit = our_parity_bit ^ peer_parity_bit;
  126. cfbits |= (size_t(parity_bit)<<level);
  127. DPFnode CWR = _mm_xor_si128(CW,lsb128_mask[parity_bit]);
  128. if (player < 2) {
  129. if (nextlevel) {
  130. for(size_t i=0;i<curlevel_size;++i) {
  131. bool flag = get_lsb(curlevel[i]);
  132. nextlevel[2*i] = xor_if(nextlevel[2*i], CW, flag);
  133. nextlevel[2*i+1] = xor_if(nextlevel[2*i+1], CWR, flag);
  134. }
  135. }
  136. cw.push_back(CW);
  137. }
  138. ++level;
  139. }
  140. delete[] curlevel;
  141. delete[] nextlevel;
  142. }