heap.cpp 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402
#include <cassert>
#include <cstdint>
#include <functional>
#include <iostream>
#include <vector>
#include "types.hpp"
#include "duoram.hpp"
#include "cell.hpp"
#include "heap.hpp"
  6. void HEAP::initialize(int num_players, size_t size) {
  7. this->MAX_SIZE = size;
  8. this->num_items = 0;
  9. oram = new Duoram<RegAS>(num_players, size);
  10. }
  11. RegAS reconstruct_AS(MPCTIO &tio, yield_t &yield, RegAS AS) {
  12. RegAS peer_AS;
  13. RegAS reconstructed_AS = AS;
  14. if (tio.player() == 1) {
  15. tio.queue_peer(&AS, sizeof(AS));
  16. } else {
  17. RegAS peer_AS;
  18. tio.recv_peer(&peer_AS, sizeof(peer_AS));
  19. reconstructed_AS += peer_AS;
  20. }
  21. yield();
  22. if (tio.player() == 0) {
  23. tio.queue_peer(&AS, sizeof(AS));
  24. } else {
  25. RegAS peer_flag;
  26. tio.recv_peer(&peer_AS, sizeof(peer_AS));
  27. reconstructed_AS += peer_AS;
  28. }
  29. return reconstructed_AS;
  30. }
  31. RegXS reconstruct_XS(MPCTIO &tio, yield_t &yield, RegXS XS) {
  32. RegXS peer_XS;
  33. RegXS reconstructed_XS = XS;
  34. if (tio.player() == 1) {
  35. tio.queue_peer(&XS, sizeof(XS));
  36. } else {
  37. RegXS peer_XS;
  38. tio.recv_peer(&peer_XS, sizeof(peer_XS));
  39. reconstructed_XS ^= peer_XS;
  40. }
  41. yield();
  42. if (tio.player() == 0) {
  43. tio.queue_peer(&XS, sizeof(XS));
  44. } else {
  45. RegXS peer_flag;
  46. tio.recv_peer(&peer_XS, sizeof(peer_XS));
  47. reconstructed_XS ^= peer_XS;
  48. }
  49. return reconstructed_XS;
  50. }
  51. bool reconstruct_flag(MPCTIO &tio, yield_t &yield, RegBS flag) {
  52. RegBS peer_flag;
  53. RegBS reconstructed_flag = flag;
  54. if (tio.player() == 1) {
  55. tio.queue_peer(&flag, sizeof(flag));
  56. } else {
  57. RegBS peer_flag;
  58. tio.recv_peer(&peer_flag, sizeof(peer_flag));
  59. reconstructed_flag ^= peer_flag;
  60. }
  61. yield();
  62. if (tio.player() == 0) {
  63. tio.queue_peer(&flag, sizeof(flag));
  64. } else {
  65. RegBS peer_flag;
  66. tio.recv_peer(&peer_flag, sizeof(peer_flag));
  67. reconstructed_flag ^= peer_flag;
  68. }
  69. return reconstructed_flag.bshare;
  70. }
  71. int HEAP::insert(MPCTIO tio, yield_t &yield, RegAS val) {
  72. auto HeapArray = oram->flat(tio, yield);
  73. num_items++;
  74. std::cout << "num_items = " << num_items << std::endl;
  75. std::cout << "we are adding in: " << std::endl;
  76. reconstruct_AS(tio, yield, val);
  77. val.dump();
  78. yield();
  79. size_t child = num_items;
  80. size_t parent = child/2;
  81. std::cout << "child = " << child << std::endl;
  82. std::cout << "parent = " << parent << std::endl;
  83. HeapArray[num_items] = val;
  84. RegAS tmp = HeapArray[num_items];
  85. reconstruct_AS(tio, yield, tmp);
  86. yield();
  87. while(parent != 0) {
  88. //std::cout << "while loop\n";
  89. RegAS sharechild = HeapArray[child];
  90. RegAS shareparent = HeapArray[parent];
  91. // RegAS sharechildrec = reconstruct_AS(tio, yield, sharechild);
  92. // yield();
  93. // RegAS shareparentrec = reconstruct_AS(tio, yield, shareparent);
  94. // yield();
  95. // std::cout << "\nchild reconstruct_AS = \n";
  96. // sharechildrec.dump();
  97. // std::cout << "\nparent reconstruct_AS = \n";
  98. // shareparentrec.dump();
  99. // std::cout << "\n----\n";
  100. CDPF cdpf = tio.cdpf(yield);
  101. RegAS diff = sharechild-shareparent;
  102. // std::cout << "diff = " << std::endl;
  103. // RegAS diff_rec = reconstruct_AS(tio, yield, diff);
  104. // diff_rec.dump();
  105. // std::cout << std::endl << std::endl;
  106. auto [lt, eq, gt] = cdpf.compare(tio, yield, diff, tio.aes_ops());
  107. //auto lteq = lt ^ eq;
  108. // bool lteq_rec = reconstruct_flag(tio, yield, lteq);
  109. // yield();
  110. // bool lt_rec = reconstruct_flag(tio, yield, lt);
  111. // yield();
  112. // std::cout <<"lt_rec = " << (int) lt_rec << std::endl;
  113. // std::cout << std::endl;
  114. // bool eq_rec = reconstruct_flag(tio, yield, eq);
  115. // yield();
  116. // std::cout <<"eq_rec = " << (int) eq_rec << std::endl;
  117. // std::cout << std::endl;
  118. // yield();
  119. // bool gt_rec = reconstruct_flag(tio, yield, gt);
  120. // yield();
  121. // std::cout <<"gt_rec = " << (int) gt_rec << std::endl;
  122. // std::cout << std::endl;
  123. // if(lteq_rec) {
  124. // if(sharechildrec.ashare > shareparentrec.ashare) {
  125. // std::cout << "\nchild reconstruct_AS = \n";
  126. // sharechildrec.dump();
  127. // std::cout << "\nchild reconstruct_AS = \n";
  128. // shareparentrec.dump();
  129. // std::cout << "\n----\n";
  130. // }
  131. // assert(sharechildrec.ashare <= shareparentrec.ashare);
  132. // }
  133. // if(gt_rec) {
  134. // if(sharechildrec.ashare < shareparentrec.ashare) {
  135. // std::cout << "\nchild reconstruct_AS = \n";
  136. // sharechildrec.dump();
  137. // std::cout << "\nchild reconstruct_AS = \n";
  138. // shareparentrec.dump();
  139. // std::cout << "\n----\n";
  140. // }
  141. // assert(sharechildrec.ashare > shareparentrec.ashare);
  142. // }
  143. // std::cout << "child = " << child << std::endl;
  144. // sharechildrec = reconstruct_AS(tio, yield, sharechild);
  145. // sharechildrec.dump();
  146. // yield();
  147. // std::cout << "parent = " << parent << std::endl;
  148. // shareparentrec = reconstruct_AS(tio, yield, shareparent);
  149. // shareparentrec.dump();
  150. // yield();
  151. // std::cout << "\n^^^ before mpc_oswap\n";
  152. mpc_oswap(tio, yield, sharechild, shareparent, lt, 64);
  153. HeapArray[child] = sharechild;
  154. HeapArray[parent] = shareparent;
  155. // std::cout << "child = " << child << std::endl;
  156. // sharechildrec = reconstruct_AS(tio, yield, sharechild);
  157. // sharechildrec.dump();
  158. // yield();
  159. // std::cout << "parent = " << parent << std::endl;
  160. // shareparentrec = reconstruct_AS(tio, yield, shareparent);
  161. // shareparentrec.dump();
  162. // yield();
  163. // std::cout << "\n^^^after mpc_oswap\n";
  164. // assert(sharechildrec.ashare >= shareparentrec.ashare);
  165. // std::cout << "we asserted that: \n";
  166. // sharechildrec.dump();
  167. // std::cout << std::endl << " < " << std::endl;
  168. // shareparentrec.dump();
  169. // std::cout << "\n ----- \n";
  170. child = parent;
  171. parent = child/2;
  172. }
  173. return 1;
  174. }
  175. int HEAP::verify_heap_property(MPCTIO tio, yield_t &yield) {
  176. std::cout << std::endl << std::endl << "verify_heap_property is being called " << std::endl;
  177. auto HeapArray = oram->flat(tio, yield);
  178. RegAS heapreconstruction[num_items];
  179. for(size_t j = 0; j <= num_items; ++j)
  180. {
  181. RegAS tmp = HeapArray[j];
  182. heapreconstruction[j] = reconstruct_AS(tio, yield, tmp);
  183. yield();
  184. }
  185. for(size_t j = 1; j < num_items/2; ++j)
  186. {
  187. if(heapreconstruction[j].ashare > heapreconstruction[2*j].ashare)
  188. {
  189. std::cout << "heap property failure\n\n";
  190. std::cout << "j = " << j << std::endl;
  191. heapreconstruction[j].dump();
  192. std::cout << std::endl;
  193. std::cout << "2*j = " << 2*j << std::endl;
  194. heapreconstruction[2*j].dump();
  195. std::cout << std::endl;
  196. }
  197. assert(heapreconstruction[j].ashare <= heapreconstruction[2*j].ashare);
  198. }
  199. return 1;
  200. }
  201. void verify_parent_children_heaps(MPCTIO tio, yield_t &yield, RegAS parent, RegAS leftchild, RegAS rightchild)
  202. {
  203. RegAS parent_reconstruction = reconstruct_AS(tio, yield, parent);
  204. yield();
  205. RegAS leftchild_reconstruction = reconstruct_AS(tio, yield, leftchild);
  206. yield();
  207. RegAS rightchild_reconstruction = reconstruct_AS(tio, yield, rightchild);
  208. yield();
  209. assert(parent_reconstruction.ashare <= leftchild_reconstruction.ashare);
  210. assert(parent_reconstruction.ashare <= rightchild_reconstruction.ashare);
  211. }
  212. //
  213. //
  214. //
  215. //
  216. RegXS HEAP::restore_heap_property(MPCTIO tio, yield_t &yield, RegXS index)
  217. {
  218. RegAS smallest;
  219. auto HeapArray = oram->flat(tio, yield);
  220. RegAS parent = HeapArray[index];
  221. RegXS leftchildindex = index;
  222. leftchildindex = index << 1;
  223. RegXS rightchildindex;
  224. rightchildindex.xshare = leftchildindex.xshare ^ (tio.player());
  225. RegAS leftchild = HeapArray[leftchildindex];
  226. RegAS rightchild = HeapArray[rightchildindex];
  227. CDPF cdpf = tio.cdpf(yield);
  228. auto [lt, eq, gt] = cdpf.compare(tio, yield, leftchild - rightchild, tio.aes_ops());
  229. RegXS smallerindex;
  230. mpc_select(tio, yield, smallerindex, lt, rightchildindex, leftchildindex, 64);
  231. //smallerindex stores either the index of the left or child (whichever has the smaller value)
  232. RegAS smallerchild;
  233. mpc_select(tio, yield, smallerchild, lt, rightchild, leftchild, 64);
  234. // the value smallerchild holds smaller of left and right child
  235. CDPF cdpf0 = tio.cdpf(yield);
  236. auto [lt0, eq0, gt0] = cdpf0.compare(tio, yield, smallerchild - parent, tio.aes_ops());
  237. //comparison between the smallerchild and the parent
  238. mpc_select(tio, yield, smallest, lt0, parent, smallerchild, 64);
  239. // smallest holds smaller of left/right child and parent
  240. RegAS otherchild;
  241. mpc_select(tio, yield, otherchild, gt0, parent, smallerchild, 64);
  242. // otherchild holds max(min(leftchild, rightchild), parent)
  243. HeapArray[index] = smallest;
  244. HeapArray[smallerindex] = otherchild;
  245. //verify_parent_children_heaps(tio, yield, HeapArray[index], HeapArray[leftchildindex] , HeapArray[rightchildindex]);
  246. return smallerindex;
  247. }
  248. /*
  249. */
  250. RegXS HEAP::restore_heap_property_at_root(MPCTIO tio, yield_t &yield, size_t index)
  251. {
  252. auto HeapArray = oram->flat(tio, yield);
  253. RegAS parent = HeapArray[index];
  254. RegAS leftchild = HeapArray[2 * index];
  255. RegAS rightchild = HeapArray[2 * index + 1];
  256. CDPF cdpf = tio.cdpf(yield);
  257. auto [lt, eq, gt] = cdpf.compare(tio, yield, leftchild - rightchild, tio.aes_ops()); // c_1 in the paper
  258. RegAS smallerchild;
  259. mpc_select(tio, yield, smallerchild, lt, rightchild, leftchild, 64); // smallerchild holds smaller of left and right child
  260. CDPF cdpf0 = tio.cdpf(yield);
  261. auto [lt0, eq0, gt0] = cdpf0.compare(tio, yield, smallerchild - parent, tio.aes_ops()); //c_0 in the paper
  262. RegAS smallest;
  263. mpc_select(tio, yield, smallest, lt0, parent, smallerchild, 64); // smallest holds smaller of left/right child and parent
  264. RegAS larger_p;
  265. mpc_select(tio, yield, larger_p, gt0, parent, smallerchild, 64); // smallest holds smaller of left/right child and parent
  266. parent = smallest;
  267. leftchild = larger_p;
  268. HeapArray[index] = smallest;
  269. //verify_parent_children_heaps(tio, yield, parent, leftchild, rightchild);
  270. //RegAS smallerindex;
  271. RegXS smallerindex(lt);
  272. uint64_t leftchildindex = (2 * index);
  273. uint64_t rightchildindex = (2 * index) + 1;
  274. // smallerindex2 &= (leftchildindex ^ rightchildindex);
  275. // smallerindex2.xshare ^= leftchildindex;
  276. smallerindex = (RegXS(lt) & leftchildindex) ^ (RegXS(gt) & rightchildindex);
  277. // RegXS smallerindex_reconstruction2 = reconstruct_XS(tio, yield, smallerindex2);
  278. // yield();
  279. HeapArray[smallerindex] = larger_p;
  280. // std::cout << "smallerindex XOR (root) == \n";
  281. // smallerindex_reconstruction2.dump();
  282. // std::cout << "\n\n ---? \n";
  283. return smallerindex;
  284. }
  285. RegAS HEAP::extract_min(MPCTIO tio, yield_t &yield) {
  286. RegAS minval;
  287. auto HeapArray = oram->flat(tio, yield);
  288. minval = HeapArray[1];
  289. HeapArray[1] = RegAS(HeapArray[num_items]);
  290. return minval;
  291. }
  292. void Heap(MPCIO &mpcio, const PRACOptions &opts, char **args)
  293. {
  294. nbits_t depth=3;
  295. if (*args) {
  296. depth = atoi(*args);
  297. ++args;
  298. }
  299. size_t items = (size_t(1)<<depth)-1;
  300. if (*args) {
  301. items = atoi(*args);
  302. ++args;
  303. }
  304. MPCTIO tio(mpcio, 0, opts.num_threads);
  305. run_coroutines(tio, [&tio, depth, items](yield_t &yield) {
  306. size_t size = size_t(1)<<depth;
  307. std::cout << "size = " << size << std::endl;
  308. HEAP tree(tio.player(), size);
  309. for(size_t j = 0; j < 3; ++j)
  310. {
  311. RegAS inserted_val;
  312. inserted_val.randomize();
  313. inserted_val.ashare = inserted_val.ashare/1000;
  314. tree.insert(tio, yield, inserted_val);
  315. }
  316. std::cout << std::endl << "=============[Insert Done]================" << std::endl << std::endl;
  317. //tree.verify_heap_property(tio, yield);
  318. for(size_t j = 0; j < 2; ++j)
  319. {
  320. tree.extract_min(tio, yield);
  321. RegXS smaller = tree.restore_heap_property_at_root(tio, yield, 1);
  322. size_t height = std::log(tree.num_items);
  323. for(size_t i = 0; i < height ; ++i)
  324. {
  325. smaller = tree.restore_heap_property(tio, yield, smaller);
  326. }
  327. }
  328. std::cout << std::endl << "=============[Extract Min Done]================" << std::endl << std::endl;
  329. tree.verify_heap_property(tio, yield);
  330. });
  331. }