heap.cpp 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418
#include <cassert>
#include <cmath>
#include <functional>
#include <iostream>
#include <vector>

#include "types.hpp"
#include "duoram.hpp"
#include "cell.hpp"
#include "rdpf.hpp"
#include "shapes.hpp"
#include "heap.hpp"
  8. // The optimized insertion protocol works as follows
  9. // Do a binary search from the root of the rightmost child
  10. // Binary search returns the index in the path where the new element should go to
  11. // We next do a trickle down.
  12. // Consdier a path P = [a1, a2, a3, a4, ..., a_n]
  13. // Consider a standard-basis vector [0, ... 0, 1 , 0, ... , 0] (at i)
  14. // we get a new path P = [a1, ... , ai,ai ..., a_n]
  15. // P[i] = newelement
  16. int MinHeap::insert_optimized(MPCTIO tio, yield_t & yield, RegAS val) {
  17. auto HeapArray = oram.flat(tio, yield);
  18. num_items++;
  19. std::cout << "num_items = " << num_items << std::endl;
  20. uint64_t height = std::log2(num_items) + 1;
  21. size_t childindex = num_items;
  22. typename Duoram<RegAS>::Path P(HeapArray, tio, yield, childindex);
  23. RegXS foundidx = P.binary_search(val);
  24. std::cout << "height = " << height << std::endl;
  25. RDPF<1> dpf2(tio, yield, foundidx, height, false, false);
  26. RegAS * ones_array = new RegAS[height];
  27. for(size_t j = 0; j < height; ++j)
  28. {
  29. if(tio.player() !=2)
  30. {
  31. RDPF<1>::LeafNode tmp = dpf2.leaf(j, tio.aes_ops());
  32. RegAS bitshare = dpf2.unit_as(tmp);
  33. ones_array[j] = bitshare;
  34. if(j > 0) ones_array[j] = ones_array[j] + ones_array[j-1];
  35. }
  36. }
  37. RegAS * z_array2 = new RegAS[height];
  38. for(size_t j = 0; j < height; ++j) z_array2[j] = P[j];
  39. for(size_t j = 1; j < height; ++j)
  40. {
  41. RegAS z2;
  42. mpc_mul(tio, yield, z2, ones_array[j-1], (z_array2[j-1]-z_array2[j]), 64);
  43. P[j] += z2;
  44. }
  45. typename Duoram<RegAS>::template OblivIndex<RegXS,1> oidx(tio, yield, foundidx, height);
  46. P[oidx] = val;
  47. return 1;
  48. }
  49. // The insert protocol works as follows:
  50. // It adds a new element in the last entry of the array
  51. // From the leaf (the element added), compare with its parent (1 oblivious compare)
  52. // If the child is larger, then we do an OSWAP.
  53. int MinHeap::insert(MPCTIO tio, yield_t & yield, RegAS val) {
  54. auto HeapArray = oram.flat(tio, yield);
  55. num_items++;
  56. std::cout << "num_items = " << num_items << std::endl;
  57. // uint64_t val_reconstruct = mpc_reconstruct(tio, yield, val);
  58. // std::cout << "val_reconstruct = " << val_reconstruct << std::endl;
  59. size_t childindex = num_items;
  60. size_t parentindex = childindex / 2;
  61. std::cout << "childindex = " << childindex << std::endl;
  62. std::cout << "parentindex = " << parentindex << std::endl;
  63. HeapArray[num_items] = val;
  64. typename Duoram<RegAS>::Path P(HeapArray, tio, yield, childindex);
  65. //RegXS foundidx = P.binary_search(val);
  66. while (parentindex > 0) {
  67. RegAS sharechild = HeapArray[childindex];
  68. RegAS shareparent = HeapArray[parentindex];
  69. CDPF cdpf = tio.cdpf(yield);
  70. RegAS diff = sharechild - shareparent;
  71. auto[lt, eq, gt] = cdpf.compare(tio, yield, diff, tio.aes_ops());
  72. auto lteq = lt ^ eq;
  73. mpc_oswap(tio, yield, sharechild, shareparent, lteq, 64);
  74. HeapArray[childindex] = sharechild;
  75. HeapArray[parentindex] = shareparent;
  76. childindex = parentindex;
  77. parentindex = parentindex / 2;
  78. }
  79. return 1;
  80. }
  81. int MinHeap::verify_heap_property(MPCTIO tio, yield_t & yield) {
  82. std::cout << std::endl << std::endl << "verify_heap_property is being called " << std::endl;
  83. auto HeapArray = oram.flat(tio, yield);
  84. uint64_t heapreconstruction[num_items];
  85. for (size_t j = 0; j <= num_items; ++j) {
  86. heapreconstruction[j] = mpc_reconstruct(tio, yield, HeapArray[j]);
  87. }
  88. for (size_t j = 1; j < num_items / 2; ++j) {
  89. if (heapreconstruction[j] > heapreconstruction[2 * j]) {
  90. std::cout << "heap property failure\n\n";
  91. std::cout << "j = " << j << std::endl;
  92. std::cout << heapreconstruction[j] << std::endl;
  93. std::cout << "2*j = " << 2 * j << std::endl;
  94. std::cout << heapreconstruction[2 * j] << std::endl;
  95. }
  96. assert(heapreconstruction[j] <= heapreconstruction[2 * j]);
  97. assert(heapreconstruction[j] <= heapreconstruction[2 * j + 1]);
  98. }
  99. return 1;
  100. }
  101. void verify_parent_children_heaps(MPCTIO tio, yield_t & yield, RegAS parent, RegAS leftchild, RegAS rightchild) {
  102. uint64_t parent_reconstruction = mpc_reconstruct(tio, yield, parent);
  103. uint64_t leftchild_reconstruction = mpc_reconstruct(tio, yield, leftchild);
  104. uint64_t rightchild_reconstruction = mpc_reconstruct(tio, yield, rightchild);
  105. assert(parent_reconstruction <= leftchild_reconstruction);
  106. assert(parent_reconstruction <= rightchild_reconstruction);
  107. }
  108. // Let "x" be the root, and let "y" and "z" be the left and right children
  109. // For an array, we have A[i] = x, A[2i] = y, A[2i + 1] = z.
  110. // We want x \le y, and x \le z.
  111. // The steps are as follows:
  112. // Step 1: compare(y,z); (1st call to to MPC Compare)
  113. // Step 2: smaller = min(y,z); This is done with an mpcselect (1st call to mpcselect)
  114. // Step 3: if(smaller == y) then smallerindex = 2i else smalleindex = 2i + 1;
  115. // Step 4: compare(x,smaller); (2nd call to to MPC Compare)
  116. // Step 5: smallest = min(x, smaller); (2nd call to mpcselect)
  117. // Step 6: otherchild = max(x, smaller)
  118. // Step 7: A[i] \gets smallest (1st Duoam Write)
  119. // Step 8: A[smallerindex] \gets otherchild (2nd Duoam Write)
  120. // Overall restore_heap_property takes 2 MPC Comparisons, 2 MPC Selects, and 2 Duoram Writes
  121. RegXS MinHeap::restore_heap_property(MPCTIO tio, yield_t & yield, RegXS index) {
  122. RegAS smallest;
  123. auto HeapArray = oram.flat(tio, yield);
  124. RegAS parent = HeapArray[index];
  125. RegXS leftchildindex = index;
  126. leftchildindex = index << 1;
  127. RegXS rightchildindex;
  128. rightchildindex.xshare = leftchildindex.xshare ^ (tio.player());
  129. RegAS leftchild = HeapArray[leftchildindex];
  130. RegAS rightchild = HeapArray[rightchildindex];
  131. RegAS sum = parent + leftchild + rightchild;
  132. CDPF cdpf = tio.cdpf(yield);
  133. auto[lt, eq, gt] = cdpf.compare(tio, yield, leftchild - rightchild, tio.aes_ops());
  134. RegXS smallerindex;
  135. //mpc_select(tio, yield, smallerindex, lt, rightchildindex, leftchildindex, 64);
  136. smallerindex = leftchildindex ^ lt;
  137. //smallerindex stores either the index of the left or child (whichever has the smaller value)
  138. RegAS smallerchild;
  139. mpc_select(tio, yield, smallerchild, lt, rightchild, leftchild, 64);
  140. // the value smallerchild holds smaller of left and right child
  141. RegAS largerchild = sum - parent - smallerchild;
  142. CDPF cdpf0 = tio.cdpf(yield);
  143. auto[lt1, eq1, gt1] = cdpf0.compare(tio, yield, smallerchild - parent, tio.aes_ops());
  144. //comparison between the smallerchild and the parent
  145. mpc_select(tio, yield, smallest, lt1, parent, smallerchild, 64);
  146. // smallest holds smaller of left/right child and parent
  147. RegAS otherchild;
  148. //mpc_select(tio, yield, otherchild, gt1, parent, smallerchild, 64);
  149. otherchild = sum - smallest - largerchild;
  150. // otherchild holds max(min(leftchild, rightchild), parent)
  151. HeapArray[index] = smallest;
  152. HeapArray[smallerindex] = otherchild;
  153. verify_parent_children_heaps(tio, yield, HeapArray[index], HeapArray[leftchildindex] , HeapArray[rightchildindex]);
  154. return smallerindex;
  155. }
  156. /*
  157. */
  158. auto MinHeap::restore_heap_property_optimized(MPCTIO tio, yield_t & yield, RegXS index, size_t layer, size_t depth, typename Duoram < RegAS > ::template OblivIndex < RegXS, 3 > (oidx)) {
  159. auto HeapArray = oram.flat(tio, yield);
  160. RegXS leftchildindex = index;
  161. leftchildindex = index << 1;
  162. RegXS rightchildindex;
  163. rightchildindex.xshare = leftchildindex.xshare ^ (tio.player());
  164. typename Duoram < RegAS > ::Flat P(HeapArray, tio, yield, 1 << layer, 1 << layer);
  165. typename Duoram < RegAS > ::Flat C(HeapArray, tio, yield, 2 << layer, 2 << layer);
  166. typename Duoram < RegAS > ::Stride L(C, tio, yield, 0, 2);
  167. typename Duoram < RegAS > ::Stride R(C, tio, yield, 1, 2);
  168. RegAS parent_tmp = P[oidx];
  169. RegAS leftchild_tmp = L[oidx];
  170. RegAS rightchild_tmp = R[oidx];
  171. RegAS sum = parent_tmp + leftchild_tmp + rightchild_tmp;
  172. CDPF cdpf = tio.cdpf(yield);
  173. auto[lt, eq, gt] = cdpf.compare(tio, yield, leftchild_tmp - rightchild_tmp, tio.aes_ops());
  174. auto lteq = lt ^ eq;
  175. RegXS smallerindex;
  176. mpc_select(tio, yield, smallerindex, lteq, rightchildindex, leftchildindex, 64);
  177. RegAS smallerchild;
  178. mpc_select(tio, yield, smallerchild, lt, rightchild_tmp, leftchild_tmp, 64);
  179. CDPF cdpf0 = tio.cdpf(yield);
  180. auto[lt1, eq1, gt1] = cdpf0.compare(tio, yield, smallerchild - parent_tmp, tio.aes_ops());
  181. RegAS z;
  182. mpc_flagmult(tio, yield, z, lt1, smallerchild - parent_tmp, 64);
  183. P[oidx] += z;
  184. auto lt1eq1 = lt1 ^ eq1;
  185. RegBS ltlt1;
  186. mpc_and(tio, yield, ltlt1, lteq, lt1eq1);
  187. RegAS zz;
  188. mpc_flagmult(tio, yield, zz, ltlt1, (parent_tmp - leftchild_tmp), 64);
  189. L[oidx] += zz;// - leftchild_tmp;
  190. RegAS leftchildplusparent = RegAS(HeapArray[index]) + RegAS(HeapArray[leftchildindex]);
  191. RegAS tmp = (sum - leftchildplusparent);
  192. R[oidx] += tmp - rightchild_tmp;
  193. return std::make_pair(smallerindex, gt);
  194. }
  195. void MinHeap::initialize(MPCTIO tio, yield_t & yield) {
  196. auto HeapArray = oram.flat(tio, yield);
  197. HeapArray.init(0x7fffffffffffff);
  198. }
  199. void MinHeap::print_heap(MPCTIO tio, yield_t & yield) {
  200. auto HeapArray = oram.flat(tio, yield);
  201. for(size_t j = 0; j < num_items; ++j)
  202. {
  203. uint64_t Pjreconstruction = mpc_reconstruct(tio, yield, HeapArray[j]);
  204. std::cout << j << "-->> HeapArray[" << j << "] = " << std::hex << Pjreconstruction << std::endl;
  205. }
  206. }
  207. auto MinHeap::restore_heap_property_at_root(MPCTIO tio, yield_t & yield) {
  208. size_t index = 1;
  209. auto HeapArray = oram.flat(tio, yield);
  210. RegAS parent = HeapArray[index];
  211. RegAS leftchild = HeapArray[2 * index];
  212. RegAS rightchild = HeapArray[2 * index + 1];
  213. CDPF cdpf = tio.cdpf(yield);
  214. auto[lt, eq, gt] = cdpf.compare(tio, yield, leftchild - rightchild, tio.aes_ops()); // c_1 in the paper
  215. RegAS smallerchild;
  216. mpc_select(tio, yield, smallerchild, lt, rightchild, leftchild, 64); // smallerchild holds smaller of left and right child
  217. CDPF cdpf0 = tio.cdpf(yield);
  218. auto[lt1, eq1, gt1] = cdpf0.compare(tio, yield, smallerchild - parent, tio.aes_ops()); //c_0 in the paper
  219. RegAS smallest;
  220. mpc_select(tio, yield, smallest, lt1, parent, smallerchild, 64); // smallest holds smaller of left/right child and parent
  221. RegAS larger_p;
  222. mpc_select(tio, yield, larger_p, gt1, parent, smallerchild, 64); // smallest holds smaller of left/right child and parent
  223. parent = smallest;
  224. leftchild = larger_p;
  225. HeapArray[index] = smallest;
  226. RegXS smallerindex(lt);
  227. uint64_t leftchildindex = (2 * index);
  228. uint64_t rightchildindex = (2 * index) + 1;
  229. auto lteq = lt^eq;
  230. smallerindex = (RegXS(lteq) & leftchildindex) ^ (RegXS(gt) & rightchildindex);
  231. HeapArray[smallerindex] = larger_p;
  232. return std::make_pair(smallerindex, gt);
  233. }
  234. RegAS MinHeap::extract_min(MPCTIO tio, yield_t & yield, int is_optimized) {
  235. RegAS minval;
  236. auto HeapArray = oram.flat(tio, yield);
  237. minval = HeapArray[1];
  238. HeapArray[1] = RegAS(HeapArray[num_items]);
  239. num_items--;
  240. auto outroot = restore_heap_property_at_root(tio, yield);
  241. RegXS smaller = outroot.first;
  242. // uint64_t smaller_rec = mpc_reconstruct(tio, yield, smaller, 64);
  243. // std::cout << "smaller_rec [root] = " << smaller_rec << std::endl;
  244. std::cout << "num_items = " << num_items << std::endl;
  245. size_t height = std::log2(num_items);
  246. std::cout << "height = " << height << std::endl << "===================" << std::endl;
  247. typename Duoram < RegAS > ::template OblivIndex < RegXS, 3 > oidx(tio, yield, height);
  248. oidx.incr(outroot.second);
  249. for (size_t i = 0; i < height; ++i) {
  250. std::cout << "i = " << i << std::endl;
  251. // typename Duoram<RegAS>::template OblivIndex<RegXS,3> oidx(tio, yield, i+1);
  252. if(is_optimized > 0)
  253. {
  254. auto out = restore_heap_property_optimized(tio, yield, smaller, i + 1, height, typename Duoram < RegAS > ::template OblivIndex < RegXS, 3 > (oidx));;
  255. smaller = out.first;
  256. oidx.incr(out.second);
  257. }
  258. if(is_optimized == 0)
  259. {
  260. smaller = restore_heap_property(tio, yield, smaller);
  261. }
  262. std::cout << "\n-------\n\n";
  263. }
  264. return minval;
  265. }
  266. void Heap(MPCIO & mpcio,
  267. const PRACOptions & opts, char ** args) {
  268. nbits_t depth = atoi(args[0]);
  269. size_t n_inserts = atoi(args[1]);
  270. size_t n_extracts = atoi(args[2]);
  271. int is_optimized = atoi(args[3]);
  272. std::cout << "print arguements " << std::endl;
  273. std::cout << args[0] << std::endl;
  274. if ( * args) {
  275. depth = atoi( * args);
  276. ++args;
  277. }
  278. size_t items = (size_t(1) << depth) - 1;
  279. if ( * args) {
  280. items = atoi( * args);
  281. ++args;
  282. }
  283. std::cout << "items = " << items << std::endl;
  284. MPCTIO tio(mpcio, 0, opts.num_threads);
  285. run_coroutines(tio, [ & tio, depth, items, n_inserts, n_extracts, is_optimized](yield_t & yield) {
  286. size_t size = size_t(1) << depth;
  287. std::cout << "size = " << size << std::endl;
  288. MinHeap tree(tio.player(), size);
  289. tree.initialize(tio, yield);
  290. for (size_t j = 0; j < n_inserts; ++j) {
  291. RegAS inserted_val;
  292. inserted_val.randomize(40);
  293. inserted_val.ashare = inserted_val.ashare;
  294. if(is_optimized > 0) tree.insert_optimized(tio, yield, inserted_val);
  295. if(is_optimized == 0) tree.insert(tio, yield, inserted_val);
  296. //tree.print_heap(tio, yield);
  297. }
  298. std::cout << std::endl << "=============[Insert Done]================ " << std::endl << std::endl;
  299. //tree.verify_heap_property(tio, yield);
  300. for (size_t j = 0; j < n_extracts; ++j) {
  301. std::cout << "j = " << j << std::endl;
  302. tree.extract_min(tio, yield, is_optimized);
  303. //RegAS minval = tree.extract_min(tio, yield, is_optimized);
  304. // uint64_t minval_reconstruction = mpc_reconstruct(tio, yield, minval, 64);
  305. // std::cout << "minval_reconstruction = " << minval_reconstruction << std::endl;
  306. // tree.verify_heap_property(tio, yield);
  307. // tree.print_heap(tio, yield);
  308. }
  309. std::cout << std::endl << "=============[Extract Min Done]================" << std::endl << std::endl;
  310. // tree.verify_heap_property(tio, yield);
  311. });
  312. }