heap.cpp 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381
  #include <cassert>
  #include <cmath>
  #include <cstdlib>
  #include <functional>
  #include <iostream>
  #include <vector>

  #include "types.hpp"
  #include "duoram.hpp"
  #include "cell.hpp"
  #include "heap.hpp"
  6. RegAS reconstruct_AS(MPCTIO & tio, yield_t & yield, RegAS AS) {
  7. RegAS peer_AS;
  8. RegAS reconstructed_AS = AS;
  9. if (tio.player() == 1) {
  10. tio.queue_peer( & AS, sizeof(AS));
  11. } else {
  12. RegAS peer_AS;
  13. tio.recv_peer( & peer_AS, sizeof(peer_AS));
  14. reconstructed_AS += peer_AS;
  15. }
  16. yield();
  17. if (tio.player() == 0) {
  18. tio.queue_peer( & AS, sizeof(AS));
  19. } else {
  20. RegAS peer_flag;
  21. tio.recv_peer( & peer_AS, sizeof(peer_AS));
  22. reconstructed_AS += peer_AS;
  23. }
  24. return reconstructed_AS;
  25. }
  26. RegXS reconstruct_XS(MPCTIO & tio, yield_t & yield, RegXS XS) {
  27. RegXS peer_XS;
  28. RegXS reconstructed_XS = XS;
  29. if (tio.player() == 1) {
  30. tio.queue_peer( & XS, sizeof(XS));
  31. } else {
  32. RegXS peer_XS;
  33. tio.recv_peer( & peer_XS, sizeof(peer_XS));
  34. reconstructed_XS ^= peer_XS;
  35. }
  36. yield();
  37. if (tio.player() == 0) {
  38. tio.queue_peer( & XS, sizeof(XS));
  39. } else {
  40. RegXS peer_flag;
  41. tio.recv_peer( & peer_XS, sizeof(peer_XS));
  42. reconstructed_XS ^= peer_XS;
  43. }
  44. return reconstructed_XS;
  45. }
  46. bool reconstruct_flag(MPCTIO & tio, yield_t & yield, RegBS flag) {
  47. RegBS peer_flag;
  48. RegBS reconstructed_flag = flag;
  49. if (tio.player() == 1) {
  50. tio.queue_peer( & flag, sizeof(flag));
  51. } else {
  52. RegBS peer_flag;
  53. tio.recv_peer( & peer_flag, sizeof(peer_flag));
  54. reconstructed_flag ^= peer_flag;
  55. }
  56. yield();
  57. if (tio.player() == 0) {
  58. tio.queue_peer( & flag, sizeof(flag));
  59. } else {
  60. RegBS peer_flag;
  61. tio.recv_peer( & peer_flag, sizeof(peer_flag));
  62. reconstructed_flag ^= peer_flag;
  63. }
  64. return reconstructed_flag.bshare;
  65. }
  66. // The insert protocol works as follows:
  67. // It adds a new element in the last entry of the array
  68. // From the leaf (the element added), compare with its parent (1 oblivious compare)
  69. // If the child is larger, then we do an OSWAP.
  70. int MinHeap::insert(MPCTIO tio, yield_t & yield, RegAS val) {
  71. auto HeapArray = oram.flat(tio, yield);
  72. num_items++;
  73. std::cout << "num_items = " << num_items << std::endl;
  74. std::cout << "we are adding in: " << std::endl;
  75. reconstruct_AS(tio, yield, val);
  76. val.dump();
  77. yield();
  78. size_t childindex = num_items;
  79. size_t parentindex = childindex / 2;
  80. std::cout << "childindex = " << childindex << std::endl;
  81. std::cout << "parentindex = " << parentindex << std::endl;
  82. HeapArray[num_items] = val;
  83. while (parentindex > 0) {
  84. RegAS sharechild = HeapArray[childindex];
  85. RegAS shareparent = HeapArray[parentindex];
  86. CDPF cdpf = tio.cdpf(yield);
  87. RegAS diff = sharechild - shareparent;
  88. auto[lt, eq, gt] = cdpf.compare(tio, yield, diff, tio.aes_ops());
  89. auto lteq = lt ^ eq;
  90. mpc_oswap(tio, yield, sharechild, shareparent, lteq, 64);
  91. HeapArray[childindex] = sharechild;
  92. HeapArray[parentindex] = shareparent;
  93. childindex = parentindex;
  94. parentindex = parentindex / 2;
  95. }
  96. return 1;
  97. }
  98. int MinHeap::verify_heap_property(MPCTIO tio, yield_t & yield) {
  99. std::cout << std::endl << std::endl << "verify_heap_property is being called " << std::endl;
  100. auto HeapArray = oram.flat(tio, yield);
  101. RegAS heapreconstruction[num_items];
  102. for (size_t j = 0; j <= num_items; ++j) {
  103. RegAS tmp = HeapArray[j];
  104. heapreconstruction[j] = reconstruct_AS(tio, yield, tmp);
  105. yield();
  106. // heapreconstruction[j].dump();
  107. // std::cout << std::endl;
  108. }
  109. for (size_t j = 1; j < num_items / 2; ++j) {
  110. if (heapreconstruction[j].ashare > heapreconstruction[2 * j].ashare) {
  111. std::cout << "heap property failure\n\n";
  112. std::cout << "j = " << j << std::endl;
  113. heapreconstruction[j].dump();
  114. std::cout << std::endl;
  115. std::cout << "2*j = " << 2 * j << std::endl;
  116. heapreconstruction[2 * j].dump();
  117. std::cout << std::endl;
  118. }
  119. assert(heapreconstruction[j].ashare <= heapreconstruction[2 * j].ashare);
  120. assert(heapreconstruction[j].ashare <= heapreconstruction[2 * j + 1].ashare);
  121. }
  122. return 1;
  123. }
  124. void verify_parent_children_heaps(MPCTIO tio, yield_t & yield, RegAS parent, RegAS leftchild, RegAS rightchild) {
  125. RegAS parent_reconstruction = reconstruct_AS(tio, yield, parent);
  126. yield();
  127. RegAS leftchild_reconstruction = reconstruct_AS(tio, yield, leftchild);
  128. yield();
  129. RegAS rightchild_reconstruction = reconstruct_AS(tio, yield, rightchild);
  130. yield();
  131. assert(parent_reconstruction.ashare <= leftchild_reconstruction.ashare);
  132. assert(parent_reconstruction.ashare <= rightchild_reconstruction.ashare);
  133. }
  134. // Let "x" be the root, and let "y" and "z" be the left and right children
  135. // For an array, we have A[i] = x, A[2i] = y, A[2i + 1] = z.
  136. // We want x \le y, and x \le z.
  137. // The steps are as follows:
  138. // Step 1: compare(y,z); (1st call to to MPC Compare)
  139. // Step 2: smaller = min(y,z); This is done with an mpcselect (1st call to mpcselect)
  140. // Step 3: if(smaller == y) then smallerindex = 2i else smalleindex = 2i + 1;
  141. // Step 4: compare(x,smaller); (2nd call to to MPC Compare)
  142. // Step 5: smallest = min(x, smaller); (2nd call to mpcselect)
  143. // Step 6: otherchild = max(x, smaller)
  144. // Step 7: A[i] \gets smallest (1st Duoam Write)
  145. // Step 8: A[smallerindex] \gets otherchild (2nd Duoam Write)
  146. // Overall restore_heap_property takes 2 MPC Comparisons, 2 MPC Selects, and 2 Duoram Writes
  147. RegXS MinHeap::restore_heap_property(MPCTIO tio, yield_t & yield, RegXS index) {
  148. RegAS smallest;
  149. auto HeapArray = oram.flat(tio, yield);
  150. RegAS parent = HeapArray[index];
  151. RegXS leftchildindex = index;
  152. leftchildindex = index << 1;
  153. RegXS rightchildindex;
  154. rightchildindex.xshare = leftchildindex.xshare ^ (tio.player());
  155. RegAS leftchild = HeapArray[leftchildindex];
  156. RegAS rightchild = HeapArray[rightchildindex];
  157. RegAS sum = parent + leftchild + rightchild;
  158. CDPF cdpf = tio.cdpf(yield);
  159. auto[lt, eq, gt] = cdpf.compare(tio, yield, leftchild - rightchild, tio.aes_ops());
  160. RegXS smallerindex;
  161. //mpc_select(tio, yield, smallerindex, lt, rightchildindex, leftchildindex, 64);
  162. smallerindex = leftchildindex ^ lt;
  163. //smallerindex stores either the index of the left or child (whichever has the smaller value)
  164. RegAS smallerchild;
  165. mpc_select(tio, yield, smallerchild, lt, rightchild, leftchild, 64);
  166. // the value smallerchild holds smaller of left and right child
  167. RegAS largerchild = sum - parent - smallerchild;
  168. CDPF cdpf0 = tio.cdpf(yield);
  169. auto[lt0, eq0, gt0] = cdpf0.compare(tio, yield, smallerchild - parent, tio.aes_ops());
  170. //comparison between the smallerchild and the parent
  171. mpc_select(tio, yield, smallest, lt0, parent, smallerchild, 64);
  172. // smallest holds smaller of left/right child and parent
  173. RegAS otherchild;
  174. //mpc_select(tio, yield, otherchild, gt0, parent, smallerchild, 64);
  175. otherchild = sum - smallest - largerchild;
  176. // otherchild holds max(min(leftchild, rightchild), parent)
  177. HeapArray[index] = smallest;
  178. HeapArray[smallerindex] = otherchild;
  179. //verify_parent_children_heaps(tio, yield, HeapArray[index], HeapArray[leftchildindex] , HeapArray[rightchildindex]);
  180. return smallerindex;
  181. }
  182. /*
  183. */
  184. RegXS MinHeap::restore_heap_property_optimization1(MPCTIO tio, yield_t & yield, RegXS index) {
  185. RegAS smallest;
  186. auto HeapArray = oram.flat(tio, yield);
  187. RegAS parent = HeapArray[index];
  188. RegXS leftchildindex = index;
  189. leftchildindex = index << 1;
  190. RegXS rightchildindex;
  191. rightchildindex.xshare = leftchildindex.xshare ^ (tio.player());
  192. RegAS leftchild = HeapArray[leftchildindex];
  193. RegAS rightchild = HeapArray[rightchildindex];
  194. RegAS sum = parent + leftchild + rightchild;
  195. CDPF cdpf = tio.cdpf(yield);
  196. auto[lt, eq, gt] = cdpf.compare(tio, yield, leftchild - rightchild, tio.aes_ops());
  197. RegXS smallerindex;
  198. //mpc_select(tio, yield, smallerindex, lt, rightchildindex, leftchildindex, 64);
  199. smallerindex = leftchildindex ^ lt;
  200. //smallerindex stores either the index of the left or child (whichever has the smaller value)
  201. RegAS smallerchild;
  202. mpc_select(tio, yield, smallerchild, lt, rightchild, leftchild, 64);
  203. // the value smallerchild holds smaller of left and right child
  204. RegAS largerchild = sum - parent - smallerchild;
  205. CDPF cdpf0 = tio.cdpf(yield);
  206. auto[lt0, eq0, gt0] = cdpf0.compare(tio, yield, smallerchild - parent, tio.aes_ops());
  207. //comparison between the smallerchild and the parent
  208. RegBS lt0lt;
  209. mpc_and(tio, yield, lt0lt, lt, lt0);
  210. mpc_select(tio, yield, smallest, lt0, parent, smallerchild, 64);
  211. // smallest holds smaller of left/right child and parent
  212. RegAS otherchild;
  213. //mpc_select(tio, yield, otherchild, gt0, parent, smallerchild, 64);
  214. otherchild = sum - smallest - largerchild;
  215. // otherchild holds max(min(leftchild, rightchild), parent)
  216. HeapArray[index] = smallest;
  217. HeapArray[smallerindex] = otherchild;
  218. //verify_parent_children_heaps(tio, yield, HeapArray[index], HeapArray[leftchildindex] , HeapArray[rightchildindex]);
  219. return smallerindex;
  220. }
  221. RegXS MinHeap::restore_heap_property_at_root(MPCTIO tio, yield_t & yield) {
  222. size_t index = 1;
  223. auto HeapArray = oram.flat(tio, yield);
  224. RegAS parent = HeapArray[index];
  225. RegAS leftchild = HeapArray[2 * index];
  226. RegAS rightchild = HeapArray[2 * index + 1];
  227. CDPF cdpf = tio.cdpf(yield);
  228. auto[lt, eq, gt] = cdpf.compare(tio, yield, leftchild - rightchild, tio.aes_ops()); // c_1 in the paper
  229. RegAS smallerchild;
  230. mpc_select(tio, yield, smallerchild, lt, rightchild, leftchild, 64); // smallerchild holds smaller of left and right child
  231. CDPF cdpf0 = tio.cdpf(yield);
  232. auto[lt0, eq0, gt0] = cdpf0.compare(tio, yield, smallerchild - parent, tio.aes_ops()); //c_0 in the paper
  233. RegAS smallest;
  234. mpc_select(tio, yield, smallest, lt0, parent, smallerchild, 64); // smallest holds smaller of left/right child and parent
  235. RegAS larger_p;
  236. mpc_select(tio, yield, larger_p, gt0, parent, smallerchild, 64); // smallest holds smaller of left/right child and parent
  237. parent = smallest;
  238. leftchild = larger_p;
  239. HeapArray[index] = smallest;
  240. //verify_parent_children_heaps(tio, yield, parent, leftchild, rightchild);
  241. //RegAS smallerindex;
  242. RegXS smallerindex(lt);
  243. uint64_t leftchildindex = (2 * index);
  244. uint64_t rightchildindex = (2 * index) + 1;
  245. // smallerindex2 &= (leftchildindex ^ rightchildindex);
  246. // smallerindex2.xshare ^= leftchildindex;
  247. smallerindex = (RegXS(lt) & leftchildindex) ^ (RegXS(gt) & rightchildindex);
  248. // RegXS smallerindex_reconstruction2 = reconstruct_XS(tio, yield, smallerindex2);
  249. // yield();
  250. HeapArray[smallerindex] = larger_p;
  251. // std::cout << "smallerindex XOR (root) == \n";
  252. // smallerindex_reconstruction2.dump();
  253. // std::cout << "\n\n ---? \n";
  254. return smallerindex;
  255. }
  256. RegAS MinHeap::extract_min(MPCTIO tio, yield_t & yield) {
  257. RegAS minval;
  258. auto HeapArray = oram.flat(tio, yield);
  259. minval = HeapArray[1];
  260. HeapArray[1] = RegAS(HeapArray[num_items]);
  261. num_items--;
  262. RegXS smaller = restore_heap_property_at_root(tio, yield);
  263. std::cout << "num_items = " << num_items << std::endl;
  264. size_t height = std::log2(num_items);
  265. std::cout << "height = " << height << std::endl;
  266. for (size_t i = 0; i < height; ++i) {
  267. smaller = restore_heap_property(tio, yield, smaller);
  268. }
  269. return minval;
  270. }
  271. void Heap(MPCIO & mpcio,
  272. const PRACOptions & opts, char ** args) {
  273. nbits_t depth = atoi(args[0]);
  274. size_t n_inserts = atoi(args[1]);
  275. size_t n_extracts = atoi(args[2]);
  276. std::cout << "print arguements " << std::endl;
  277. std::cout << args[0] << std::endl;
  278. if ( * args) {
  279. depth = atoi( * args);
  280. ++args;
  281. }
  282. size_t items = (size_t(1) << depth) - 1;
  283. if ( * args) {
  284. items = atoi( * args);
  285. ++args;
  286. }
  287. std::cout << "items = " << items << std::endl;
  288. MPCTIO tio(mpcio, 0, opts.num_threads);
  289. run_coroutines(tio, [ & tio, depth, items, n_inserts, n_extracts](yield_t & yield) {
  290. size_t size = size_t(1) << depth;
  291. std::cout << "size = " << size << std::endl;
  292. MinHeap tree(tio.player(), size);
  293. for (size_t j = 0; j < n_inserts; ++j) {
  294. RegAS inserted_val;
  295. inserted_val.randomize(62);
  296. inserted_val.ashare = inserted_val.ashare;
  297. tree.insert(tio, yield, inserted_val);
  298. }
  299. std::cout << std::endl << "=============[Insert Done]================" << std::endl << std::endl;
  300. tree.verify_heap_property(tio, yield);
  301. for (size_t j = 0; j < n_extracts; ++j) {
  302. RegAS minval = tree.extract_min(tio, yield);
  303. tree.verify_heap_property(tio, yield);
  304. RegAS minval_reconstruction = reconstruct_AS(tio, yield, minval);
  305. yield();
  306. std::cout << "minval_reconstruction = ";
  307. minval_reconstruction.dump();
  308. std::cout << std::endl;
  309. }
  310. std::cout << std::endl << "=============[Extract Min Done]================" << std::endl << std::endl;
  311. tree.verify_heap_property(tio, yield);
  312. });
  313. }