heapsampler.cpp 6.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222
  1. #include "heapsampler.hpp"
  // In each 64-bit RegAS in the heap, the top bit (of the reconstructed
  // value) is 0, the next 42 bits are the random tag, and the low 21 bits
  // are the element value.  Because the tag occupies the high-order bits,
  // ordering two tagged values orders them by tag first.
  static constexpr int HEAPSAMPLE_TAG_BITS = 42;
  static constexpr int HEAPSAMPLE_ELT_BITS = 21;
  // Keep the layout described above consistent: zero top bit + tag bits +
  // element bits must exactly fill the 64-bit value.
  static_assert(1 + HEAPSAMPLE_TAG_BITS + HEAPSAMPLE_ELT_BITS == 64,
      "zero top bit + tag bits + element bits must fill 64 bits");
  7. // Make the next random tag
  8. void HeapSampler::make_randtag(MPCTIO &tio, yield_t &yield)
  9. {
  10. // Make a uniformly random HEAPSAMPLE_TAG_BITS-bit tag. This needs
  11. // to be RegXS in order for the sum (XOR) of P0 and P1's independent
  12. // values to be uniform.
  13. RegXS tagx;
  14. tagx.randomize(HEAPSAMPLE_TAG_BITS);
  15. mpc_xs_to_as(tio, yield, randtag, tagx);
  16. }
  // Heap capacity required to hold k elements: the smallest power of two
  // strictly greater than k (e.g. 0 -> 1, 4 -> 8, 10 -> 16).
  static size_t heapsize(size_t k)
  {
      size_t capacity = 1;
      for (; capacity <= k; capacity *= 2) {
          // keep doubling until capacity exceeds k
      }
      return capacity;
  }
  27. // Return a random bit that reconstructs to 1 with probability k/m
  28. static RegBS weighted_coin(MPCTIO &tio, yield_t &yield, size_t k,
  29. size_t m)
  30. {
  31. RegAS limit;
  32. limit.ashare = size_t((__uint128_t(k)<<63)/m) * !tio.player();
  33. RegXS randxs;
  34. randxs.randomize(63);
  35. RegAS randas;
  36. mpc_xs_to_as(tio, yield, randas, randxs);
  37. CDPF cdpf = tio.cdpf(yield);
  38. auto[lt, eq, gt] = cdpf.compare(tio, yield, randas-limit, tio.aes_ops());
  39. return lt;
  40. }
  41. // Constructor for a HeapSampler that samples k items from a stream
  42. // of abritrary and unknown size, using O(k) memory
  43. HeapSampler::HeapSampler(MPCTIO &tio, yield_t &yield, size_t k)
  44. : k(k), m(0), heap(tio.player(), heapsize(k))
  45. {
  46. run_coroutines(tio, [&tio, this](yield_t &yield) {
  47. heap.init(tio, yield);
  48. }, [&tio, this](yield_t &yield) {
  49. make_randtag(tio, yield);
  50. });
  51. }
  52. // An element has arrived
  53. void HeapSampler::ingest(MPCTIO &tio, yield_t &yield, RegAS elt)
  54. {
  55. ++m;
  56. RegAS tagged_elt = (randtag << HEAPSAMPLE_ELT_BITS) + elt;
  57. RegAS elt_to_insert = tagged_elt;
  58. if (m > k) {
  59. RegAS extracted_elt;
  60. RegBS selection_bit;
  61. run_coroutines(tio, [&tio, this, &extracted_elt](yield_t &yield) {
  62. extracted_elt = heap.extract_min(tio, yield);
  63. }, [&tio, this, &selection_bit](yield_t &yield) {
  64. selection_bit = weighted_coin(tio, yield, k, m);
  65. });
  66. mpc_select(tio, yield, elt_to_insert, selection_bit,
  67. extracted_elt, tagged_elt);
  68. }
  69. run_coroutines(tio, [&tio, this, elt_to_insert](yield_t &yield) {
  70. heap.insert_optimized(tio, yield, elt_to_insert);
  71. }, [&tio, this](yield_t &yield) {
  72. make_randtag(tio, yield);
  73. });
  74. }
  75. // The stream has ended; output min(k,m) randomly sampled elements.
  76. // After calling this function, the HeapSampler is reset to its
  77. // initial m=0 state.
  78. std::vector<RegAS> HeapSampler::close(MPCTIO &tio, yield_t &yield)
  79. {
  80. size_t retsize = k;
  81. if (m < k) {
  82. retsize = m;
  83. }
  84. std::vector<RegAS> ret(retsize);
  85. for (size_t i=0; i<retsize; ++i) {
  86. ret[i] = heap.extract_min(tio, yield);
  87. ret[i] &= ((size_t(1)<<HEAPSAMPLE_ELT_BITS)-1);
  88. }
  89. // Compare each output to (size_t(1)<<HEAPSAMPLE_ELT_BITS), since
  90. // there may be a carry; if the output is greater than or equal to
  91. // that value, fix the carry
  92. RegAS limit;
  93. limit.ashare = (size_t(1)<<HEAPSAMPLE_ELT_BITS) * !tio.player();
  94. std::vector<coro_t> coroutines;
  95. for (size_t i=0; i<retsize; ++i) {
  96. coroutines.emplace_back([&tio, &ret, i, limit](yield_t &yield) {
  97. CDPF cdpf = tio.cdpf(yield);
  98. auto[lt, eq, gt] = cdpf.compare(tio, yield,
  99. ret[i]-limit, tio.aes_ops());
  100. RegAS fix, zero;
  101. mpc_select(tio, yield, fix, gt^eq, zero, limit);
  102. ret[i] -= fix;
  103. });
  104. }
  105. run_coroutines(tio, coroutines);
  106. heap.init(tio, yield);
  107. return ret;
  108. }
  109. void heapsampler_test(MPCIO &mpcio, const PRACOptions &opts, char **args)
  110. {
  111. size_t n = 100;
  112. size_t k = 10;
  113. // The number of elements to stream
  114. if (*args) {
  115. n = atoi(*args);
  116. ++args;
  117. }
  118. // The size of the random sample
  119. if (*args) {
  120. k = atoi(*args);
  121. ++args;
  122. }
  123. MPCTIO tio(mpcio, 0, opts.num_cpu_threads);
  124. run_coroutines(tio, [&mpcio, &tio, n, k] (yield_t &yield) {
  125. std::cout << "\n===== STREAMING =====\n";
  126. HeapSampler sampler(tio, yield, k);
  127. for (size_t i=0; i<n; ++i) {
  128. // For ease of checking, just have the elements be in a
  129. // simple sequence
  130. RegAS elt;
  131. elt.ashare = (i+1) * (1 + 0xfff*tio.player());
  132. sampler.ingest(tio, yield, elt);
  133. }
  134. std::vector<RegAS> sample = sampler.close(tio, yield);
  135. tio.sync_lamport();
  136. mpcio.dump_stats(std::cout);
  137. mpcio.reset_stats();
  138. tio.reset_lamport();
  139. std::cout << "\n===== CHECKING =====\n";
  140. size_t expected_size = k;
  141. if (n < k) {
  142. expected_size = n;
  143. }
  144. assert(sample.size() == expected_size);
  145. std::vector<value_t> reconstructed_sample(expected_size);
  146. std::vector<coro_t> coroutines;
  147. for (size_t i=0; i<expected_size; ++i) {
  148. coroutines.emplace_back(
  149. [&tio, &sample, i, &reconstructed_sample](yield_t &yield) {
  150. reconstructed_sample[i] = mpc_reconstruct(
  151. tio, yield, sample[i]);
  152. });
  153. }
  154. run_coroutines(tio, coroutines);
  155. if (tio.player() == 0) {
  156. for (size_t i=0; i<expected_size; ++i) {
  157. printf("%06lx\n", reconstructed_sample[i]);
  158. }
  159. }
  160. });
  161. }
  162. void weighted_coin_test(MPCIO &mpcio, const PRACOptions &opts, char **args)
  163. {
  164. size_t iters = 100;
  165. size_t m = 100;
  166. size_t k = 10;
  167. // The number of iterations
  168. if (*args) {
  169. iters = atoi(*args);
  170. ++args;
  171. }
  172. // The denominator
  173. if (*args) {
  174. m = atoi(*args);
  175. ++args;
  176. }
  177. // The numerator
  178. if (*args) {
  179. k = atoi(*args);
  180. ++args;
  181. }
  182. MPCTIO tio(mpcio, 0, opts.num_cpu_threads);
  183. run_coroutines(tio, [&mpcio, &tio, iters, m, k] (yield_t &yield) {
  184. size_t heads = 0, tails = 0;
  185. for (size_t i=0; i<iters; ++i) {
  186. RegBS coin = weighted_coin(tio, yield, k, m);
  187. bool coin_rec = mpc_reconstruct(tio, yield, coin);
  188. if (coin_rec) {
  189. heads++;
  190. } else {
  191. tails++;
  192. }
  193. printf("%lu flips %lu heads %lu tails\n", i+1, heads, tails);
  194. }
  195. });
  196. }