cdpf.cpp 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349
  1. #include <bsd/stdlib.h> // arc4random_buf
  2. #include "bitutils.hpp"
  3. #include "cdpf.hpp"
  4. // Generate a pair of CDPFs with the given target value
  5. //
  6. // Cost:
  7. // 4*VALUE_BITS - 28 = 228 local AES operations
  8. std::tuple<CDPF,CDPF> CDPF::generate(value_t target, size_t &aes_ops)
  9. {
  10. CDPF dpf0, dpf1;
  11. const nbits_t depth = VALUE_BITS - 7;
  12. // Pick two random seeds
  13. arc4random_buf(&dpf0.seed, sizeof(dpf0.seed));
  14. arc4random_buf(&dpf1.seed, sizeof(dpf1.seed));
  15. // Ensure the flag bits (the lsb of each node) are different
  16. dpf0.seed = set_lsb(dpf0.seed, 0);
  17. dpf1.seed = set_lsb(dpf1.seed, 1);
  18. dpf0.whichhalf = 0;
  19. dpf1.whichhalf = 1;
  20. dpf0.cfbits = 0;
  21. dpf1.cfbits = 0;
  22. dpf0.as_target.randomize();
  23. dpf1.as_target.ashare = target - dpf0.as_target.ashare;
  24. dpf0.xs_target.randomize();
  25. dpf1.xs_target.xshare = target ^ dpf0.xs_target.xshare;
  26. // The current node in each CDPF as we descend the tree. The
  27. // invariant is that cur0 and cur1 are the nodes on the path to the
  28. // target at level curlevel. They will necessarily be different,
  29. // and indeed must have different flag (low) bits.
  30. DPFnode cur0 = dpf0.seed;
  31. DPFnode cur1 = dpf1.seed;
  32. nbits_t curlevel = 0;
  33. while(curlevel < depth) {
  34. // Construct the two (uncorrected) children of each node
  35. DPFnode left0, right0, left1, right1;
  36. prgboth(left0, right0, cur0, aes_ops);
  37. prgboth(left1, right1, cur1, aes_ops);
  38. // Which way lies the target?
  39. bool targetdir = !!(target & (value_t(1)<<((depth+7)-curlevel-1)));
  40. DPFnode CW;
  41. bool cfbit = !get_lsb(left0 ^ left1 ^ right0 ^ right1);
  42. bool flag0 = get_lsb(cur0);
  43. bool flag1 = get_lsb(cur1);
  44. // The last level is special
  45. if (curlevel < depth-1) {
  46. if (targetdir == 0) {
  47. // The target is to the left, so make the correction word
  48. // and bit make the right children the same and the left
  49. // children have different flag bits.
  50. // Recall that descend will apply (only for the party whose
  51. // current node (cur0 or cur1) has the flag bit set, for
  52. // which exactly one of the two will) CW to both children,
  53. // and cfbit to the flag bit of the right child.
  54. CW = right0 ^ right1 ^ lsb128_mask[cfbit];
  55. // Compute the current nodes for the next level
  56. // Exactly one of these two XORs will fire, so afterwards,
  57. // cur0 ^ cur1 = left0 ^ left1 ^ CW, which will have low bit
  58. // 1 by the definition of cfbit.
  59. cur0 = xor_if(left0, CW, flag0);
  60. cur1 = xor_if(left1, CW, flag1);
  61. } else {
  62. // The target is to the right, so make the correction word
  63. // and bit make the left children the same and the right
  64. // children have different flag bits.
  65. CW = left0 ^ left1;
  66. // Compute the current nodes for the next level
  67. // Exactly one of these two XORs will fire, so similar to
  68. // the above, afterwards, cur0 ^ cur1 = right0 ^ right1 ^ CWR,
  69. // which will have low bit 1.
  70. DPFnode CWR = CW ^ lsb128_mask[cfbit];
  71. cur0 = xor_if(right0, CWR, flag0);
  72. cur1 = xor_if(right1, CWR, flag1);
  73. }
  74. } else {
  75. // We're at the last level before the leaves. We still want
  76. // the children not in the direction of targetdir to end up
  77. // the same, but now we want the child in the direction of
  78. // targetdir to also end up the same, except for the single
  79. // target bit. Importantly, the low bit (the flag bit in
  80. // all other nodes) is not special, and will in fact usually
  81. // end up the same for the two DPFs (unless the target bit
  82. // happens to be the low bit of the word; i.e., the low 7
  83. // bits of target are all 0).
  84. // This will be a 128-bit word with a single bit set, in
  85. // position (target & 0x7f).
  86. uint8_t loc = (target & 0x7f);
  87. DPFnode target_set_bit = _mm_set_epi64x(
  88. loc >= 64 ? (uint64_t(1)<<(loc-64)) : 0,
  89. loc >= 64 ? 0 : (uint64_t(1)<<loc));
  90. if (targetdir == 0) {
  91. // We want the right children to be the same, and the
  92. // left children to be the same except for the target
  93. // bit.
  94. // Remember for exactly one of the two parties, CW will
  95. // be applied to the left and CWR will be applied to the
  96. // right.
  97. CW = left0 ^ left1 ^ target_set_bit;
  98. DPFnode CWR = right0 ^ right1;
  99. dpf0.leaf_cwr = CWR;
  100. dpf1.leaf_cwr = CWR;
  101. } else {
  102. // We want the left children to be the same, and the
  103. // right children to be the same except for the target
  104. // bit.
  105. // Remember for exactly one of the two parties, CW will
  106. // be applied to the left and CWR will be applied to the
  107. // right.
  108. CW = left0 ^ left1;
  109. DPFnode CWR = right0 ^ right1 ^ target_set_bit;
  110. dpf0.leaf_cwr = CWR;
  111. dpf1.leaf_cwr = CWR;
  112. }
  113. }
  114. dpf0.cw.push_back(CW);
  115. dpf1.cw.push_back(CW);
  116. dpf0.cfbits |= (value_t(cfbit)<<curlevel);
  117. dpf1.cfbits |= (value_t(cfbit)<<curlevel);
  118. ++curlevel;
  119. }
  120. return std::make_tuple(dpf0, dpf1);
  121. }
  122. // Generate a pair of CDPFs with a random target value
  123. //
  124. // Cost:
  125. // 4*VALUE_BITS - 28 = 228 local AES operations
  126. std::tuple<CDPF,CDPF> CDPF::generate(size_t &aes_ops)
  127. {
  128. value_t target;
  129. arc4random_buf(&target, sizeof(target));
  130. return generate(target, aes_ops);
  131. }
  132. // Get the leaf node for the given input. We don't actually use
  133. // this in the protocol, but it's useful for testing.
  134. DPFnode CDPF::leaf(value_t input, size_t &aes_ops) const
  135. {
  136. nbits_t depth = cw.size();
  137. DPFnode node = seed;
  138. input >>= 7;
  139. for (nbits_t d=0;d<depth-1;++d) {
  140. bit_t dir = !!(input & (value_t(1)<<(depth-d-1)));
  141. node = descend(node, d, dir, aes_ops);
  142. }
  143. // The last layer is special
  144. bit_t dir = input & 1;
  145. node = descend_to_leaf(node, dir, aes_ops);
  146. return node;
  147. }
  148. // Compare the given additively shared element to 0. The output is
  149. // a triple of bit shares; the first is a share of 1 iff the
  150. // reconstruction of the element is negative; the second iff it is
  151. // 0; the third iff it is positive. (All as two's-complement
  152. // VALUE_BIT-bit integers.) Note in particular that exactly one of
  153. // the outputs will be a share of 1, so you can do "greater than or
  154. // equal to" just by adding the greater and equal outputs together.
  155. // Note also that you can compare two RegAS values A and B by
  156. // passing A-B here.
  157. //
  158. // Cost:
  159. // 1 word sent in 1 message
  160. // 2*VALUE_BITS - 14 = 114 local AES operations
  161. std::tuple<RegBS,RegBS,RegBS> CDPF::compare(MPCTIO &tio, yield_t &yield,
  162. RegAS x, size_t &aes_ops)
  163. {
  164. // Reconstruct S = target-x
  165. // The server does nothing in this protocol
  166. if (tio.player() < 2) {
  167. RegAS S_share = as_target - x;
  168. tio.iostream_peer() << S_share;
  169. yield();
  170. RegAS peer_S_share;
  171. tio.iostream_peer() >> peer_S_share;
  172. value_t S = S_share.ashare + peer_S_share.ashare;
  173. // After that one single-word exchange, the rest of this
  174. // algorithm is entirely a local computation.
  175. return compare(S, aes_ops);
  176. } else {
  177. yield();
  178. }
  179. // The server gets three shares of 0 (which is not a valid output
  180. // for the computational players)
  181. RegBS lt, eq, gt;
  182. return std::make_tuple(lt, eq, gt);
  183. }
  184. // You can call this version directly if you already have S = target-x
  185. // reconstructed. This routine is entirely local; no communication
  186. // is needed.
  187. //
  188. // Cost:
  189. // 2*VALUE_BITS - 14 = 114 local AES operations
  190. std::tuple<RegBS,RegBS,RegBS> CDPF::compare(value_t S, size_t &aes_ops)
  191. {
  192. RegBS gt, eq;
  193. // Now we're going to simultaneously descend the DPF tree for the
  194. // values S and T = S + 2^63. Note that the 1 values of V (see the
  195. // explanation of the algorithm in cdpf.hpp) are those values
  196. // _strictly_ larger than S and smaller than T (noting they can
  197. // "wrap around" 2^64). In level 1 of the tree, the paths to S and
  198. // T will necessarily be at the two different children of the root
  199. // seed, but they could be in either order. From then on, they will
  200. // proceed in lockstep, either both going left, or both going right.
  201. // If they both go left, we also compute the flag for the right
  202. // sibling on the S path (which will be the XOR of the left sibling
  203. // and the parent), and add it to the gt flag. If they both go
  204. // right, we also include the left sibling on the T path (which will
  205. // be the XOR of the right sibling and the parent), and add it to
  206. // the gt flag. When we hit the leaves, the gt flag will account
  207. // for all of the complete leaf nodes strictly greater than S and
  208. // strictly less than T. Then we just have to pull out the parity
  209. // of the appropriate bits in the two leaf nodes containing S and T
  210. // respectively to complete the computation of gt, and also to get
  211. // the single bit eq.
  212. nbits_t curlevel = 0;
  213. const nbits_t depth = VALUE_BITS - 7;
  214. DPFnode Sparent = seed;
  215. DPFnode Tparent = seed;
  216. // The top level is the only place where the paths to S and T go
  217. // in different directions.
  218. bool Sdir = !!(S & (value_t(1)<<63));
  219. DPFnode Snode = descend(Sparent, curlevel, Sdir, aes_ops);
  220. DPFnode Tnode = descend(Tparent, curlevel, !Sdir, aes_ops);
  221. curlevel = 1;
  222. // Invariant: Snode is the node on level curlevel on the path to
  223. // S, and Tnode is the node on level curlevel on the path to
  224. // T = S + 2^63. Sparent and Tparent are the parents of Snode and
  225. // Tnode respectively.
  226. // The last level is special, so terminate the loop 1 before the end
  227. while(curlevel < depth-1) {
  228. Sdir = !!(S & (value_t(1)<<((depth+7)-curlevel-1)));
  229. Sparent = Snode;
  230. Tparent = Tnode;
  231. Snode = descend(Sparent, curlevel, Sdir, aes_ops);
  232. Tnode = descend(Tparent, curlevel, Sdir, aes_ops);
  233. ++curlevel;
  234. if (Sdir == 0) {
  235. // They both went left; include the right child of
  236. // Sparent in the gt computation
  237. gt ^= get_lsb(Sparent) ^ get_lsb(Snode);
  238. } else {
  239. // Theye both went right; include the left child of
  240. // Tparent in the gt computation
  241. gt ^= get_lsb(Tparent) ^ get_lsb(Tnode);
  242. }
  243. }
  244. // Now we're at the level just above the leaves. If we go left,
  245. // include *all* the bits (not just the low bit) of the right child
  246. // of Sparent (which will be the flag bit of Sparent XORed with the
  247. // parity of all the bits of Snode), and if we go right, include all
  248. // the bits of the left child of Tnode (which will be the flag bit
  249. // of Tparent XORed with the parity of all the bits of Tnode).
  250. Sdir = !!(S & (value_t(1)<<((depth+7)-curlevel-1)));
  251. Sparent = Snode;
  252. Tparent = Tnode;
  253. Snode = descend_to_leaf(Sparent, Sdir, aes_ops);
  254. Tnode = descend_to_leaf(Tparent, Sdir, aes_ops);
  255. ++curlevel;
  256. if (Sdir == 0) {
  257. // They both went left; include the right child of
  258. // Snode in the gt computation
  259. gt ^= get_lsb(Sparent) ^ parity(Snode);
  260. } else {
  261. // They're both going right; include the left child of
  262. // Tnode in the gt computation
  263. gt ^= get_lsb(Tparent) ^ parity(Tnode);
  264. }
  265. // Now Snode and Tnode are the leaves containing S and T
  266. // respectively. Pull out the bit in Snode for S itself into eq,
  267. // and all the higher bits into gt. Also pull out the bits strictly
  268. // below that for T in Tnode into gt.
  269. nbits_t Spos = S & 0x7f;
  270. eq.bshare = bit_at(Snode, Spos);
  271. gt ^= parity_above(Snode, Spos);
  272. gt ^= parity_below(Tnode, Spos);
  273. // Once we have gt and eq (which cannot both be 1), lt is just 1
  274. // exactly if they're both 0.
  275. RegBS lt;
  276. lt.bshare = whichhalf ^ eq.bshare ^ gt.bshare;
  277. return std::make_tuple(lt, eq, gt);
  278. }
  279. // You can call this version directly if you already have S = target-x
  280. // reconstructed. This routine is entirely local; no communication
  281. // is needed. This function is identical to compare, above, except that
  282. // it only computes what's needed for the eq output.
  283. //
  284. // Cost:
  285. // VALUE_BITS - 7 = 57 local AES operations
  286. RegBS CDPF::is_zero(value_t S, size_t &aes_ops)
  287. {
  288. RegBS eq;
  289. // We' descend the DPF tree for the values S.
  290. // Invariant: Snode is the node on level curlevel on the path to
  291. // S.
  292. nbits_t curlevel = 0;
  293. const nbits_t depth = VALUE_BITS - 7;
  294. DPFnode Snode = seed;
  295. bool Sdir = !!(S & (value_t(1)<<63));
  296. Snode = descend(Snode, curlevel, Sdir, aes_ops);
  297. curlevel = 1;
  298. // The last level is special
  299. while(curlevel < depth-1) {
  300. Sdir = !!(S & (value_t(1)<<((depth+7)-curlevel-1)));
  301. Snode = descend(Snode, curlevel, Sdir, aes_ops);
  302. ++curlevel;
  303. }
  304. // Now we're at the level just above the leaves. If we go left,
  305. // include *all* the bits (not just the low bit) of the right
  306. // child of Snode, and if we go right, include all the bits of
  307. // the left child of Tnode.
  308. Sdir = !!(S & (value_t(1)<<((depth+7)-curlevel-1)));
  309. Snode = descend_to_leaf(Snode, Sdir, aes_ops);
  310. ++curlevel;
  311. // Now Snode is the leaf containing S. Pull out the bit in Snode
  312. // for S itself into eq.
  313. nbits_t Spos = S & 0x7f;
  314. eq.bshare = bit_at(Snode, Spos);
  315. return eq;
  316. }