rdpf.hpp 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476
  1. #ifndef __RDPF_HPP__
  2. #define __RDPF_HPP__
#include <cassert>
#include <iostream>
#include <vector>

#include "mpcio.hpp"
#include "coroutine.hpp"
#include "types.hpp"
#include "bitutils.hpp"
  9. // Streaming evaluation, to avoid taking up enough memory to store
  10. // an entire evaluation. T can be RDPF, RDPFPair, or RDPFTriple.
  11. template <typename T>
  12. class StreamEval {
  13. const T &rdpf;
  14. size_t &op_counter;
  15. bool use_expansion;
  16. nbits_t depth;
  17. address_t indexmask;
  18. address_t pathindex;
  19. address_t nextindex;
  20. std::vector<typename T::node> path;
  21. public:
  22. // Create an Eval object that will start its output at index start.
  23. // It will wrap around to 0 when it hits 2^depth. If use_expansion
  24. // is true, then if the DPF has been expanded, just output values
  25. // from that. If use_expansion=false or if the DPF has not been
  26. // expanded, compute the values on the fly.
  27. StreamEval(const T &rdpf, address_t start, size_t &op_counter,
  28. bool use_expansion = true);
  29. // Get the next value (or tuple of values) from the evaluator
  30. typename T::node next();
  31. };
  32. // Create a StreamEval object that will start its output at index start.
  33. // It will wrap around to 0 when it hits 2^depth. If use_expansion
  34. // is true, then if the DPF has been expanded, just output values
  35. // from that. If use_expansion=false or if the DPF has not been
  36. // expanded, compute the values on the fly.
  37. template <typename T>
  38. StreamEval<T>::StreamEval(const T &rdpf, address_t start,
  39. size_t &op_counter, bool use_expansion) : rdpf(rdpf),
  40. op_counter(op_counter), use_expansion(use_expansion)
  41. {
  42. depth = rdpf.depth();
  43. // Prevent overflow of 1<<depth
  44. if (depth < ADDRESS_MAX_BITS) {
  45. indexmask = (address_t(1)<<depth)-1;
  46. } else {
  47. indexmask = ~0;
  48. }
  49. // Record that we haven't actually output the leaf for index start
  50. // itself yet
  51. nextindex = start;
  52. if (use_expansion && rdpf.has_expansion()) {
  53. // We just need to keep the counter, not compute anything
  54. return;
  55. }
  56. path.resize(depth);
  57. pathindex = start;
  58. path[0] = rdpf.get_seed();
  59. for (nbits_t i=1;i<depth;++i) {
  60. bool dir = !!(pathindex & (address_t(1)<<(depth-i)));
  61. path[i] = rdpf.descend(path[i-1], i-1, dir, op_counter);
  62. }
  63. }
  64. template <typename T>
  65. typename T::node StreamEval<T>::next()
  66. {
  67. if (use_expansion && rdpf.has_expansion()) {
  68. // Just use the precomputed values
  69. typename T::node leaf = rdpf.get_expansion(nextindex);
  70. nextindex = (nextindex + 1) & indexmask;
  71. return leaf;
  72. }
  73. // Invariant: in the first call to next(), nextindex = pathindex.
  74. // Otherwise, nextindex = pathindex+1.
  75. // Get the XOR of nextindex and pathindex, and strip the low bit.
  76. // If nextindex and pathindex are equal, or pathindex is even
  77. // and nextindex is the consecutive odd number, index_xor will be 0,
  78. // indicating that we don't have to update the path, but just
  79. // compute the appropriate leaf given by the low bit of nextindex.
  80. //
  81. // Otherwise, say for example pathindex is 010010111 and nextindex
  82. // is 010011000. Then their XOR is 000001111, and stripping the low
  83. // bit yields 000001110, so how_many_1_bits will be 3.
  84. // That indicates (typically) that path[depth-3] was a left child,
  85. // and now we need to change it to a right child by descending right
  86. // from path[depth-4], and then filling the path after that with
  87. // left children.
  88. //
  89. // When we wrap around, however, index_xor will be 111111110 (after
  90. // we strip the low bit), and how_many_1_bits will be depth-1, but
  91. // the new top child (of the root seed) we have to compute will be a
  92. // left, not a right, child.
  93. uint64_t index_xor = (nextindex ^ pathindex) & ~1;
  94. nbits_t how_many_1_bits = __builtin_popcount(index_xor);
  95. if (how_many_1_bits > 0) {
  96. // This will almost always be 1, unless we've just wrapped
  97. // around from the right subtree back to the left, in which case
  98. // it will be 0.
  99. bool top_changed_bit =
  100. nextindex & (address_t(1) << how_many_1_bits);
  101. path[depth-how_many_1_bits] =
  102. rdpf.descend(path[depth-how_many_1_bits-1],
  103. depth-how_many_1_bits-1, top_changed_bit, op_counter);
  104. for (nbits_t i = depth-how_many_1_bits; i < depth-1; ++i) {
  105. path[i+1] = rdpf.descend(path[i], i, 0, op_counter);
  106. }
  107. }
  108. typename T::node leaf = rdpf.descend(path[depth-1], depth-1,
  109. nextindex & 1, op_counter);
  110. pathindex = nextindex;
  111. nextindex = (nextindex + 1) & indexmask;
  112. return leaf;
  113. }
  114. struct RDPF {
  115. // The type of nodes
  116. using node = DPFnode;
  117. // The 128-bit seed
  118. DPFnode seed;
  119. // Which half of the DPF are we?
  120. bit_t whichhalf;
  121. // correction words; the depth of the DPF is the length of this
  122. // vector
  123. std::vector<DPFnode> cw;
  124. // correction flag bits: the one for level i is bit i of this word
  125. value_t cfbits;
  126. // The amount we have to scale the low words of the leaf values by
  127. // to get additive shares of a unit vector
  128. value_t unit_sum_inverse;
  129. // Additive share of the scaling value M_as such that the high words
  130. // of the leaf values for P0 and P1 add to M_as * e_{target}
  131. RegAS scaled_sum;
  132. // XOR share of the scaling value M_xs such that the high words
  133. // of the leaf values for P0 and P1 XOR to M_xs * e_{target}
  134. RegXS scaled_xor;
  135. // If we're saving the expansion, put it here
  136. std::vector<DPFnode> expansion;
  137. RDPF() {}
  138. // Construct a DPF with the given (XOR-shared) target location, and
  139. // of the given depth, to be used for random-access memory reads and
  140. // writes. The DPF is constructed collaboratively by P0 and P1,
  141. // with the server P2 helping by providing correlated randomness,
  142. // such as SelectTriples.
  143. //
  144. // Cost:
  145. // (2 DPFnode + 2 bytes)*depth + 1 word communication in
  146. // 2*depth + 1 messages
  147. // (2 DPFnode + 1 byte)*depth communication from P2 to each party
  148. // 2^{depth+1}-2 local AES operations for P0,P1
  149. // 0 local AES operations for P2
  150. RDPF(MPCTIO &tio, yield_t &yield,
  151. RegXS target, nbits_t depth, bool save_expansion = false);
  152. // The number of bytes it will take to store this RDPF
  153. size_t size() const;
  154. // The number of bytes it will take to store a RDPF of the given
  155. // depth
  156. static size_t size(nbits_t depth);
  157. // The depth
  158. inline nbits_t depth() const { return cw.size(); }
  159. // The seed
  160. inline node get_seed() const { return seed; }
  161. // Do we have a precomputed expansion?
  162. inline bool has_expansion() const { return expansion.size() > 0; }
  163. // Get an element of the expansion
  164. inline node get_expansion(address_t index) const {
  165. return expansion[index];
  166. }
  167. // Descend from a node at depth parentdepth to one of its children
  168. // whichchild = 0: left child
  169. // whichchild = 1: right child
  170. //
  171. // Cost: 1 AES operation
  172. DPFnode descend(const DPFnode &parent, nbits_t parentdepth,
  173. bit_t whichchild, size_t &op_counter) const;
  174. // Get the leaf node for the given input
  175. //
  176. // Cost: depth AES operations
  177. DPFnode leaf(address_t input, size_t &op_counter) const;
  178. // Expand the DPF if it's not already expanded
  179. void expand(size_t &op_counter);
  180. #if 0
  181. // Streaming evaluation, to avoid taking up enough memory to store
  182. // an entire evaluation
  183. class Eval {
  184. friend class RDPF; // So eval() can call the Eval constructor
  185. const RDPF &rdpf;
  186. size_t &op_counter;
  187. bool use_expansion;
  188. nbits_t depth;
  189. address_t indexmask;
  190. address_t pathindex;
  191. address_t nextindex;
  192. std::vector<DPFnode> path;
  193. Eval(const RDPF &rdpf, size_t &op_counter, address_t start,
  194. bool use_expansion);
  195. public:
  196. DPFnode next();
  197. };
  198. // Create an Eval object that will start its output at index start.
  199. // It will wrap around to 0 when it hits 2^depth. If use_expansion
  200. // is true, then if the DPF has been expanded, just output values
  201. // from that. If use_expansion=false or if the DPF has not been
  202. // expanded, compute the values on the fly.
  203. StreamEval<RDPF> eval(address_t start, size_t &op_counter,
  204. bool use_expansion=true) const;
  205. #endif
  206. // Get the bit-shared unit vector entry from the leaf node
  207. inline RegBS unit_bs(DPFnode leaf) const {
  208. RegBS b;
  209. b.bshare = get_lsb(leaf);
  210. return b;
  211. }
  212. // Get the additive-shared unit vector entry from the leaf node
  213. inline RegAS unit_as(DPFnode leaf) const {
  214. RegAS a;
  215. value_t lowword = value_t(_mm_cvtsi128_si64x(leaf));
  216. if (whichhalf == 1) {
  217. lowword = -lowword;
  218. }
  219. a.ashare = lowword * unit_sum_inverse;
  220. return a;
  221. }
  222. // Get the XOR-shared scaled vector entry from the leaf ndoe
  223. inline RegXS scaled_xs(DPFnode leaf) const {
  224. RegXS x;
  225. value_t highword =
  226. value_t(_mm_cvtsi128_si64x(_mm_srli_si128(leaf,8)));
  227. x.xshare = highword;
  228. return x;
  229. }
  230. // Get the additive-shared scaled vector entry from the leaf ndoe
  231. inline RegAS scaled_as(DPFnode leaf) const {
  232. RegAS a;
  233. value_t highword =
  234. value_t(_mm_cvtsi128_si64x(_mm_srli_si128(leaf,8)));
  235. if (whichhalf == 1) {
  236. highword = -highword;
  237. }
  238. a.ashare = highword;
  239. return a;
  240. }
  241. };
  242. // I/O for RDPFs
  243. template <typename T>
  244. T& operator>>(T &is, RDPF &rdpf)
  245. {
  246. is.read((char *)&rdpf.seed, sizeof(rdpf.seed));
  247. uint8_t depth;
  248. // The whichhalf bit is the high bit of depth
  249. is.read((char *)&depth, sizeof(depth));
  250. rdpf.whichhalf = !!(depth & 0x80);
  251. depth &= 0x7f;
  252. bool read_expanded = false;
  253. if (depth > 64) {
  254. read_expanded = true;
  255. depth -= 64;
  256. }
  257. assert(depth <= ADDRESS_MAX_BITS);
  258. rdpf.cw.clear();
  259. for (uint8_t i=0; i<depth; ++i) {
  260. DPFnode cw;
  261. is.read((char *)&cw, sizeof(cw));
  262. rdpf.cw.push_back(cw);
  263. }
  264. if (read_expanded) {
  265. rdpf.expansion.resize(1<<depth);
  266. is.read((char *)rdpf.expansion.data(),
  267. sizeof(rdpf.expansion[0])<<depth);
  268. }
  269. value_t cfbits = 0;
  270. is.read((char *)&cfbits, BITBYTES(depth));
  271. rdpf.cfbits = cfbits;
  272. is.read((char *)&rdpf.unit_sum_inverse, sizeof(rdpf.unit_sum_inverse));
  273. is.read((char *)&rdpf.scaled_sum, sizeof(rdpf.scaled_sum));
  274. is.read((char *)&rdpf.scaled_xor, sizeof(rdpf.scaled_xor));
  275. return is;
  276. }
  277. // Write the DPF to the output stream. If expanded=true, then include
  278. // the expansion _if_ the DPF is itself already expanded. You can use
  279. // this to write DPFs to files.
  280. template <typename T>
  281. T& write_maybe_expanded(T &os, const RDPF &rdpf,
  282. bool expanded = true)
  283. {
  284. os.write((const char *)&rdpf.seed, sizeof(rdpf.seed));
  285. uint8_t depth = rdpf.cw.size();
  286. assert(depth <= ADDRESS_MAX_BITS);
  287. // The whichhalf bit is the high bit of depth
  288. // If we're writing an expansion, add 64 to depth as well
  289. uint8_t whichhalf_and_depth = depth |
  290. (uint8_t(rdpf.whichhalf)<<7);
  291. bool write_expansion = false;
  292. if (expanded && rdpf.expansion.size() == (size_t(1)<<depth)) {
  293. write_expansion = true;
  294. whichhalf_and_depth += 64;
  295. }
  296. os.write((const char *)&whichhalf_and_depth,
  297. sizeof(whichhalf_and_depth));
  298. for (uint8_t i=0; i<depth; ++i) {
  299. os.write((const char *)&rdpf.cw[i], sizeof(rdpf.cw[i]));
  300. }
  301. if (write_expansion) {
  302. os.write((const char *)rdpf.expansion.data(),
  303. sizeof(rdpf.expansion[0])<<depth);
  304. }
  305. os.write((const char *)&rdpf.cfbits, BITBYTES(depth));
  306. os.write((const char *)&rdpf.unit_sum_inverse, sizeof(rdpf.unit_sum_inverse));
  307. os.write((const char *)&rdpf.scaled_sum, sizeof(rdpf.scaled_sum));
  308. os.write((const char *)&rdpf.scaled_xor, sizeof(rdpf.scaled_xor));
  309. return os;
  310. }
  311. // The ordinary << version never writes the expansion, since this is
  312. // what we use to send DPFs over the network.
  313. template <typename T>
  314. T& operator<<(T &os, const RDPF &rdpf)
  315. {
  316. return write_maybe_expanded(os, rdpf, false);
  317. }
  318. // Computational peers will generate triples of RDPFs with the _same_
  319. // random target for use in Duoram. They will each hold a share of the
  320. // target (neither knowing the complete target index). They will each
  321. // give one of the DPFs (not a matching pair) to the server, but not the
  322. // shares of the target index. So computational peers will hold a
  323. // RDPFTriple (which includes both an additive and an XOR share of the
  324. // target index), while the server will hold a RDPFPair (which does
  325. // not).
  326. struct RDPFTriple {
  327. // The type of node triples
  328. using node = std::tuple<DPFnode, DPFnode, DPFnode>;
  329. RegAS as_target;
  330. RegXS xs_target;
  331. RDPF dpf[3];
  332. // The depth
  333. inline nbits_t depth() const { return dpf[0].depth(); }
  334. // The seed
  335. inline node get_seed() const {
  336. return std::make_tuple(dpf[0].get_seed(), dpf[1].get_seed(),
  337. dpf[2].get_seed());
  338. }
  339. // Do we have a precomputed expansion?
  340. inline bool has_expansion() const {
  341. return dpf[0].expansion.size() > 0;
  342. }
  343. // Get an element of the expansion
  344. inline node get_expansion(address_t index) const {
  345. return std::make_tuple(dpf[0].get_expansion(index),
  346. dpf[1].get_expansion(index), dpf[2].get_expansion(index));
  347. }
  348. RDPFTriple() {}
  349. // Construct three RDPFs of the given depth all with the same
  350. // randomly generated target index.
  351. RDPFTriple(MPCTIO &tio, yield_t &yield,
  352. nbits_t depth, bool save_expansion = false);
  353. // Descend the three RDPFs in lock step
  354. node descend(const node &parent, nbits_t parentdepth,
  355. bit_t whichchild, size_t &op_counter) const;
  356. };
  357. // I/O for RDPF Triples
  358. // We never write RDPFTriples over the network, so always write
  359. // the DPF expansions if they're available.
  360. template <typename T>
  361. T& operator<<(T &os, const RDPFTriple &rdpftrip)
  362. {
  363. write_maybe_expanded(os, rdpftrip.dpf[0], true);
  364. write_maybe_expanded(os, rdpftrip.dpf[1], true);
  365. write_maybe_expanded(os, rdpftrip.dpf[2], true);
  366. nbits_t depth = rdpftrip.dpf[0].depth();
  367. os.write((const char *)&rdpftrip.as_target.ashare, BITBYTES(depth));
  368. os.write((const char *)&rdpftrip.xs_target.xshare, BITBYTES(depth));
  369. return os;
  370. }
  371. template <typename T>
  372. T& operator>>(T &is, RDPFTriple &rdpftrip)
  373. {
  374. is >> rdpftrip.dpf[0] >> rdpftrip.dpf[1] >> rdpftrip.dpf[2];
  375. nbits_t depth = rdpftrip.dpf[0].depth();
  376. rdpftrip.as_target.ashare = 0;
  377. is.read((char *)&rdpftrip.as_target.ashare, BITBYTES(depth));
  378. rdpftrip.xs_target.xshare = 0;
  379. is.read((char *)&rdpftrip.xs_target.xshare, BITBYTES(depth));
  380. return is;
  381. }
  382. struct RDPFPair {
  383. // The type of node pairs
  384. using node = std::tuple<DPFnode, DPFnode>;
  385. RDPF dpf[2];
  386. // The depth
  387. inline nbits_t depth() const { return dpf[0].depth(); }
  388. // The seed
  389. inline node get_seed() const {
  390. return std::make_tuple(dpf[0].get_seed(), dpf[1].get_seed());
  391. }
  392. // Do we have a precomputed expansion?
  393. inline bool has_expansion() const {
  394. return dpf[0].expansion.size() > 0;
  395. }
  396. // Get an element of the expansion
  397. inline node get_expansion(address_t index) const {
  398. return std::make_tuple(dpf[0].get_expansion(index),
  399. dpf[1].get_expansion(index));
  400. }
  401. // Descend the two RDPFs in lock step
  402. node descend(const node &parent, nbits_t parentdepth,
  403. bit_t whichchild, size_t &op_counter) const;
  404. };
  405. // I/O for RDPF Pairs
  406. // We never write RDPFPairs over the network, so always write
  407. // the DPF expansions if they're available.
  408. template <typename T>
  409. T& operator<<(T &os, const RDPFPair &rdpfpair)
  410. {
  411. write_maybe_expanded(os, rdpfpair.dpf[0], true);
  412. write_maybe_expanded(os, rdpfpair.dpf[1], true);
  413. return os;
  414. }
  415. template <typename T>
  416. T& operator>>(T &is, RDPFPair &rdpfpair)
  417. {
  418. is >> rdpfpair.dpf[0] >> rdpfpair.dpf[1];
  419. return is;
  420. }
  421. #endif