heap.cpp 32 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923
  1. #include <functional>
  2. #include "types.hpp"
  3. #include "duoram.hpp"
  4. #include "cell.hpp"
  5. #include "rdpf.hpp"
  6. #include "shapes.hpp"
  7. #include "heap.hpp"
  8. // The optimized insertion protocol works as follows
  9. // Do a binary search from the root of the rightmost child
  10. // Binary search returns the index in the path where the new element should go to
  11. // We next do a trickle down.
  12. // Consdier a path P = [a1, a2, a3, a4, ..., a_n]
  13. // Consider a standard-basis vector [0, ... 0, 1 , 0, ... , 0] (at i)
  14. // we get a new path P = [a1, ... , ai,ai ..., a_n]
  15. // P[i] = newelement
  16. int MinHeap::insert_optimized(MPCTIO tio, yield_t & yield, RegAS val) {
  17. auto HeapArray = oram.flat(tio, yield);
  18. num_items++;
  19. //std::cout << "num_items = " << num_items << std::endl;
  20. uint64_t height = std::ceil(std::log2(num_items)) + 1;
  21. //std::cout << "height = " << height << std::endl;
  22. size_t childindex = num_items;
  23. RegAS zero;
  24. zero.ashare = 0;
  25. HeapArray[childindex] = zero;
  26. typename Duoram<RegAS>::Path P(HeapArray, tio, yield, childindex);
  27. const RegXS foundidx = P.binary_search(val);
  28. uint64_t logheight = std::ceil(double(std::log2(height))) + 1;
  29. // std::cout << "logheight = " << logheight << std::endl;
  30. // RDPF<1> dpf2(tio, yield, foundidx, logheight, false, false);
  31. // RegBS * flags_array = new RegBS[height];
  32. std::vector<RegBS> standard_basis_vector(height+1);
  33. typename Duoram<RegAS>::template OblivIndex<RegXS,1> oidx(tio, yield, foundidx, logheight);
  34. auto flags_array = oidx.unit_vector(tio, yield, 1 << logheight, foundidx);
  35. // for(size_t j = 0; j < height; ++j)
  36. // {
  37. // uint64_t reconstruction = mpc_reconstruct(tio, yield, new_stand[j], 64);
  38. // if(reconstruction != 0) std::cout << j << " --->> reconstruction from OblivIndex [new_stand] = " << reconstruction << std::endl;
  39. // }
  40. for(size_t j = 0; j < height; ++j)
  41. {
  42. if(tio.player() !=2)
  43. {
  44. // RDPF<1>::LeafNode leafval = dpf2.leaf(j, tio.aes_ops());
  45. // flags_array[j] = dpf2.unit_bs(leafval);
  46. standard_basis_vector[j] = flags_array[j];
  47. if(j > 0) flags_array[j] = flags_array[j] ^ flags_array[j-1];
  48. }
  49. }
  50. // // #ifdef VERBOSE
  51. // for(size_t j = 0; j < height; ++j)
  52. // {
  53. // uint64_t reconstruction = mpc_reconstruct(tio, yield, standard_basis_vector[j], 64);
  54. // if(reconstruction != 0) std::cout << j << " --->> reconstruction Explicitly Calling the DPF [standard_basis_vector] = " << reconstruction << std::endl;
  55. // }
  56. // // #endif
  57. RegAS * z_array2 = new RegAS[height];
  58. RegAS * z2_tmp = new RegAS[height];
  59. RegAS * standard_basis_vector_time_value = new RegAS[height];
  60. for(size_t j = 0; j < height; ++j) z_array2[j] = P[j];
  61. //print_heap(tio, yield);
  62. std::vector<coro_t> coroutines;
  63. for(size_t j = 1; j < height; ++j)
  64. {
  65. coroutines.emplace_back(
  66. [&tio, z2_tmp, flags_array, z_array2, j](yield_t &yield) {
  67. mpc_flagmult(tio, yield, z2_tmp[j], flags_array[j-1], (z_array2[j-1]-z_array2[j]), 64);
  68. });
  69. coroutines.emplace_back(
  70. [&tio, standard_basis_vector_time_value, standard_basis_vector, val, z_array2, j](yield_t &yield) {
  71. mpc_flagmult(tio, yield, standard_basis_vector_time_value[j-1], standard_basis_vector[j-1], (val - z_array2[j-1]) , 64);
  72. });
  73. }
  74. run_coroutines(tio, coroutines);
  75. // //#ifdef VERBOSE
  76. // for(size_t j = 0; j < height; ++j)
  77. // {
  78. // int64_t reconstruction = mpc_reconstruct(tio, yield, z2_tmp[j], 64);
  79. // std::cout << j << " --->> reconstruction [z2_tmp] = " << reconstruction << std::endl;
  80. // }
  81. // std::cout << std::endl << " =============== " << std::endl;
  82. // for(size_t j = 0; j < height; ++j)
  83. // {
  84. // int64_t reconstruction = mpc_reconstruct(tio, yield, flags_array[j], 64);
  85. // std::cout << j << " --->> reconstruction [flags_array] = " << reconstruction << std::endl;
  86. // }
  87. // std::cout << std::endl << " =============== " << std::endl;
  88. // for(size_t j = 0; j < height; ++j)
  89. // {
  90. // int64_t reconstruction = mpc_reconstruct(tio, yield, standard_basis_vector[j], 64);
  91. // std::cout << j << " --->> reconstruction [standard_basis_vector] = " << reconstruction << std::endl;
  92. // }
  93. // std::cout << std::endl << " =============== " << std::endl;
  94. // for(size_t j = 0; j < height; ++j)
  95. // {
  96. // int64_t reconstruction = mpc_reconstruct(tio, yield, z_array2[j], 64);
  97. // std::cout << j << " --->> reconstruction [z_array2] = " << reconstruction << std::endl;
  98. // }
  99. // //#endif
  100. // for(size_t j = 0; j < height; ++j) P[j] += (z2_tmp[j] + standard_basis_vector_time_value[j]);
  101. // //#ifdef VERBOSE
  102. // std::cout << std::endl << " =============== " << std::endl;
  103. // for(size_t j = 0; j < height; ++j)
  104. // {
  105. // int64_t reconstruction = mpc_reconstruct(tio, yield, P[j], 64);
  106. // std::cout << j << " --->> reconstruction [P] = " << reconstruction << std::endl;
  107. // }
  108. // print_heap(tio, yield);
  109. // //#endif
  110. // for(size_t j = 1; j < height; ++j) P[j] += z2_tmp[j];
  111. // typename Duoram<RegAS>::template OblivIndex<RegXS,1> oidx(tio, yield, foundidx, height);
  112. //P[oidx] = val;
  113. return 1;
  114. }
  115. // The insert protocol works as follows:
  116. // It adds a new element in the last entry of the array
  117. // From the leaf (the element added), compare with its parent (1 oblivious compare)
  118. // If the child is larger, then we do an OSWAP.
  119. int MinHeap::insert(MPCTIO tio, yield_t & yield, RegAS val) {
  120. auto HeapArray = oram.flat(tio, yield);
  121. num_items++;
  122. //std::cout << "num_items = " << num_items << std::endl;
  123. // uint64_t val_reconstruct = mpc_reconstruct(tio, yield, val);
  124. // std::cout << "val_reconstruct = " << val_reconstruct << std::endl;
  125. size_t childindex = num_items;
  126. size_t parentindex = childindex / 2;
  127. #ifdef VERBOSE
  128. std::cout << "childindex = " << childindex << std::endl;
  129. std::cout << "parentindex = " << parentindex << std::endl;
  130. #endif
  131. HeapArray[num_items] = val;
  132. typename Duoram<RegAS>::Path P(HeapArray, tio, yield, childindex);
  133. //RegXS foundidx = P.binary_search(val);
  134. while (parentindex > 0) {
  135. RegAS sharechild = HeapArray[childindex];
  136. RegAS shareparent = HeapArray[parentindex];
  137. CDPF cdpf = tio.cdpf(yield);
  138. RegAS diff = sharechild - shareparent;
  139. auto[lt, eq, gt] = cdpf.compare(tio, yield, diff, tio.aes_ops());
  140. auto lteq = lt ^ eq;
  141. mpc_oswap(tio, yield, sharechild, shareparent, lteq, 64);
  142. HeapArray[childindex] = sharechild;
  143. HeapArray[parentindex] = shareparent;
  144. childindex = parentindex;
  145. parentindex = parentindex / 2;
  146. }
  147. return 1;
  148. }
  149. int MinHeap::verify_heap_property(MPCTIO tio, yield_t & yield) {
  150. std::cout << std::endl << std::endl << "verify_heap_property is being called " << std::endl;
  151. auto HeapArray = oram.flat(tio, yield);
  152. uint64_t heapreconstruction[num_items];
  153. for (size_t j = 0; j <= num_items; ++j) {
  154. heapreconstruction[j] = mpc_reconstruct(tio, yield, HeapArray[j]);
  155. }
  156. for (size_t j = 1; j < num_items / 2; ++j) {
  157. if (heapreconstruction[j] > heapreconstruction[2 * j]) {
  158. std::cout << "heap property failure\n\n";
  159. std::cout << "j = " << j << std::endl;
  160. std::cout << heapreconstruction[j] << std::endl;
  161. std::cout << "2*j = " << 2 * j << std::endl;
  162. std::cout << heapreconstruction[2 * j] << std::endl;
  163. }
  164. if (heapreconstruction[j] > heapreconstruction[2 * j + 1]) {
  165. std::cout << "heap property failure\n\n";
  166. std::cout << "j = " << j << std::endl;
  167. std::cout << heapreconstruction[j] << std::endl;
  168. std::cout << "2*j + 1 = " << 2 * j + 1<< std::endl;
  169. std::cout << heapreconstruction[2 * j + 1] << std::endl;
  170. }
  171. //assert(heapreconstruction[j] <= heapreconstruction[2 * j]);
  172. //assert(heapreconstruction[j] <= heapreconstruction[2 * j + 1]);
  173. }
  174. return 1;
  175. }
  176. void verify_parent_children_heaps(MPCTIO tio, yield_t & yield, RegAS parent, RegAS leftchild, RegAS rightchild) {
  177. std::cout << "calling this ... \n";
  178. uint64_t parent_reconstruction = mpc_reconstruct(tio, yield, parent);
  179. uint64_t leftchild_reconstruction = mpc_reconstruct(tio, yield, leftchild);
  180. uint64_t rightchild_reconstruction = mpc_reconstruct(tio, yield, rightchild);
  181. assert(parent_reconstruction <= leftchild_reconstruction);
  182. assert(parent_reconstruction <= rightchild_reconstruction);
  183. }
  184. RegXS MinHeap::restore_heap_property(MPCIO & mpcio, MPCTIO tio, yield_t & yield, RegXS index) {
  185. RegAS smallest;
  186. auto HeapArray = oram.flat(tio, yield);
  187. mpcio.reset_stats();
  188. tio.reset_lamport();
  189. RegXS leftchildindex = index;
  190. leftchildindex = index << 1;
  191. RegXS rightchildindex;
  192. rightchildindex.xshare = leftchildindex.xshare ^ (tio.player());
  193. RegAS parent; // = HeapArray[index];
  194. RegAS leftchild; // = HeapArray[leftchildindex];
  195. RegAS rightchild; // = HeapArray[rightchildindex];
  196. std::time_t currentTime = std::time(nullptr);
  197. std::string timeString = std::ctime(&currentTime);
  198. std::cout << "Current time (before read): " << timeString;
  199. std::vector<coro_t> coroutines_read;
  200. coroutines_read.emplace_back(
  201. [&tio, &parent, &HeapArray, index](yield_t &yield) {
  202. auto Acoro = HeapArray.context(yield);
  203. parent = Acoro[index]; //inserted_val;
  204. });
  205. coroutines_read.emplace_back(
  206. [&tio, &HeapArray, &leftchild, leftchildindex](yield_t &yield) {
  207. auto Acoro = HeapArray.context(yield);
  208. leftchild = Acoro[leftchildindex]; //inserted_val;
  209. });
  210. coroutines_read.emplace_back(
  211. [&tio, &rightchild, &HeapArray, rightchildindex](yield_t &yield) {
  212. auto Acoro = HeapArray.context(yield);
  213. rightchild = Acoro[rightchildindex];
  214. });
  215. run_coroutines(tio, coroutines_read);
  216. std::cout << "=========== READS DONE =========== \n";
  217. currentTime = std::time(nullptr);
  218. timeString = std::ctime(&currentTime);
  219. std::cout << "Current time (after read): " << timeString;
  220. tio.sync_lamport();
  221. mpcio.dump_stats(std::cout);
  222. //RegAS sum = parent + leftchild + rightchild;
  223. currentTime = std::time(nullptr);
  224. timeString = std::ctime(&currentTime);
  225. std::cout << "Current time (before compare): " << timeString;
  226. CDPF cdpf = tio.cdpf(yield);
  227. auto[lt_c, eq_c, gt_c] = cdpf.compare(tio, yield, leftchild - rightchild, tio.aes_ops());
  228. currentTime = std::time(nullptr);
  229. timeString = std::ctime(&currentTime);
  230. std::cout << "Current time (after compare): " << timeString;
  231. auto lteq = lt_c ^ eq_c;
  232. RegXS smallerindex;
  233. RegAS smallerchild;
  234. #ifdef VERBOSE
  235. uint64_t LC_rec = mpc_reconstruct(tio, yield, leftchildindex);
  236. std::cout << "LC_rec = " << LC_rec << std::endl;
  237. #endif
  238. std::cout << "=========== Compare DONE =========== \n";
  239. tio.sync_lamport();
  240. mpcio.dump_stats(std::cout);
  241. // mpc_select(tio, yield, smallerindex, lteq, rightchildindex, leftchildindex, 64);
  242. // mpc_select(tio, yield, smallerchild, lt_c, rightchild, leftchild, 64);
  243. currentTime = std::time(nullptr);
  244. timeString = std::ctime(&currentTime);
  245. std::cout << "Current time (before mpc_select): " << timeString;
  246. run_coroutines(tio, [&tio, &smallerindex, lteq, rightchildindex, leftchildindex](yield_t &yield)
  247. { mpc_select(tio, yield, smallerindex, lteq, rightchildindex, leftchildindex, 64);},
  248. [&tio, &smallerchild, lteq, rightchild, leftchild](yield_t &yield)
  249. { mpc_select(tio, yield, smallerchild, lteq, rightchild, leftchild, 64);});
  250. currentTime = std::time(nullptr);
  251. timeString = std::ctime(&currentTime);
  252. std::cout << "Current time (after mpc_select): " << timeString;
  253. #ifdef VERBOSE
  254. uint64_t smallerindex_rec = mpc_reconstruct(tio, yield, smallerindex);
  255. std::cout << "smallerindex_rec = " << smallerindex_rec << std::endl;
  256. #endif
  257. std::cout << "=========== mpc_select DONE =========== \n";
  258. tio.sync_lamport();
  259. mpcio.dump_stats(std::cout);
  260. currentTime = std::time(nullptr);
  261. timeString = std::ctime(&currentTime);
  262. std::cout << "Current time (before compare): " << timeString;
  263. CDPF cdpf0 = tio.cdpf(yield);
  264. auto[lt_p, eq_p, gt_p] = cdpf0.compare(tio, yield, smallerchild - parent, tio.aes_ops());
  265. currentTime = std::time(nullptr);
  266. timeString = std::ctime(&currentTime);
  267. std::cout << "Current time (after compare): " << timeString;
  268. std::cout << "=========== Compare DONE =========== \n";
  269. tio.sync_lamport();
  270. mpcio.dump_stats(std::cout);
  271. auto lt_p_eq_p = lt_p ^ eq_p;
  272. RegBS ltlt1;
  273. currentTime = std::time(nullptr);
  274. timeString = std::ctime(&currentTime);
  275. std::cout << "Current time (before mpc_and): " << timeString;
  276. mpc_and(tio, yield, ltlt1, lteq, lt_p_eq_p);
  277. currentTime = std::time(nullptr);
  278. timeString = std::ctime(&currentTime);
  279. std::cout << "Current time (after mpc_and): " << timeString;
  280. std::cout << "=========== mpc_and DONE =========== \n";
  281. tio.sync_lamport();
  282. mpcio.dump_stats(std::cout);
  283. RegAS update_index_by, update_leftindex_by;
  284. currentTime = std::time(nullptr);
  285. timeString = std::ctime(&currentTime);
  286. std::cout << "Current time (before mpc_flagmult): " << timeString;
  287. run_coroutines(tio, [&tio, &update_leftindex_by, ltlt1, parent, leftchild](yield_t &yield)
  288. { mpc_flagmult(tio, yield, update_leftindex_by, ltlt1, (parent - leftchild), 64);},
  289. [&tio, &update_index_by, lt_p, parent, smallerchild](yield_t &yield)
  290. {mpc_flagmult(tio, yield, update_index_by, lt_p, smallerchild - parent, 64);}
  291. );
  292. std::cout << "=========== flag mults =========== \n";
  293. currentTime = std::time(nullptr);
  294. timeString = std::ctime(&currentTime);
  295. std::cout << "Current time (after mpc_flagmult): " << timeString;
  296. tio.sync_lamport();
  297. mpcio.dump_stats(std::cout);
  298. std::vector<coro_t> coroutines;
  299. // HeapArray[index] += update_index_by;
  300. // HeapArray[leftchildindex] += update_leftindex_by;
  301. // HeapArray[rightchildindex] += -(update_index_by + update_leftindex_by);
  302. currentTime = std::time(nullptr);
  303. timeString = std::ctime(&currentTime);
  304. std::cout << "Current time (before updates): " << timeString;
  305. coroutines.emplace_back(
  306. [&tio, &HeapArray, index, update_index_by](yield_t &yield) {
  307. auto Acoro = HeapArray.context(yield);
  308. Acoro[index] += update_index_by; //inserted_val;
  309. });
  310. coroutines.emplace_back(
  311. [&tio, &HeapArray, leftchildindex, update_leftindex_by](yield_t &yield) {
  312. auto Acoro = HeapArray.context(yield);
  313. Acoro[leftchildindex] += update_leftindex_by; //inserted_val;
  314. });
  315. coroutines.emplace_back(
  316. [&tio, &HeapArray, rightchildindex, update_index_by, update_leftindex_by](yield_t &yield) {
  317. auto Acoro = HeapArray.context(yield);
  318. Acoro[rightchildindex] += -(update_index_by + update_leftindex_by);
  319. });
  320. run_coroutines(tio, coroutines);
  321. currentTime = std::time(nullptr);
  322. timeString = std::ctime(&currentTime);
  323. std::cout << "Current time (after updates): " << timeString;
  324. std::cout << "=========== updates done =========== \n";
  325. tio.sync_lamport();
  326. mpcio.dump_stats(std::cout);
  327. // verify_parent_children_heaps(tio, yield, HeapArray[index], HeapArray[leftchildindex] , HeapArray[rightchildindex]);
  328. return smallerindex;
  329. }
  330. auto MinHeap::restore_heap_property_optimized(MPCTIO tio, yield_t & yield, RegXS index, size_t layer, size_t depth, typename Duoram < RegAS > ::template OblivIndex < RegXS, 3 > (oidx)) {
  331. auto HeapArray = oram.flat(tio, yield);
  332. RegXS leftchildindex = index;
  333. leftchildindex = index << 1;
  334. RegXS rightchildindex;
  335. rightchildindex.xshare = leftchildindex.xshare ^ (tio.player());
  336. typename Duoram < RegAS > ::Flat P(HeapArray, tio, yield, 1 << layer, 1 << layer);
  337. typename Duoram < RegAS > ::Flat C(HeapArray, tio, yield, 2 << layer, 2 << layer);
  338. typename Duoram < RegAS > ::Stride L(C, tio, yield, 0, 2);
  339. typename Duoram < RegAS > ::Stride R(C, tio, yield, 1, 2);
  340. // RegAS parent_tmp = P[oidx];
  341. // RegAS leftchild_tmp = L[oidx];
  342. // RegAS rightchild_tmp = R[oidx];
  343. RegAS parent_tmp; // = HeapArray[index];
  344. RegAS leftchild_tmp; // = HeapArray[leftchildindex];
  345. RegAS rightchild_tmp; // = HeapArray[rightchildindex];
  346. std::vector<coro_t> coroutines_read;
  347. coroutines_read.emplace_back(
  348. [&tio, &parent_tmp, &P, &oidx](yield_t &yield) {
  349. // auto Acoro = P.context(yield);
  350. parent_tmp = P[oidx]; //inserted_val;
  351. });
  352. coroutines_read.emplace_back(
  353. [&tio, &L, &leftchild_tmp, &oidx](yield_t &yield) {
  354. //auto Acoro = L.context(yield);
  355. leftchild_tmp = L[oidx]; //inserted_val;
  356. });
  357. coroutines_read.emplace_back(
  358. [&tio, &R, &rightchild_tmp, &oidx](yield_t &yield) {
  359. // auto Acoro = R.context(yield);
  360. rightchild_tmp = R[oidx];
  361. });
  362. run_coroutines(tio, coroutines_read);
  363. //RegAS sum = parent_tmp + leftchild_tmp + rightchild_tmp;
  364. CDPF cdpf = tio.cdpf(yield);
  365. auto[lt, eq, gt] = cdpf.compare(tio, yield, leftchild_tmp - rightchild_tmp, tio.aes_ops());
  366. auto lteq = lt ^ eq;
  367. RegXS smallerindex;
  368. RegAS smallerchild;
  369. // mpc_select(tio, yield, smallerindex, lteq, rightchildindex, leftchildindex, 64);
  370. // mpc_select(tio, yield, smallerchild, lt, rightchild_tmp, leftchild_tmp, 64);
  371. run_coroutines(tio, [&tio, &smallerindex, lteq, rightchildindex, leftchildindex](yield_t &yield)
  372. { mpc_select(tio, yield, smallerindex, lteq, rightchildindex, leftchildindex, 64);;},
  373. [&tio, &smallerchild, lt, rightchild_tmp, leftchild_tmp](yield_t &yield)
  374. { mpc_select(tio, yield, smallerchild, lt, rightchild_tmp, leftchild_tmp, 64);;});
  375. CDPF cdpf0 = tio.cdpf(yield);
  376. auto[lt1, eq1, gt1] = cdpf0.compare(tio, yield, smallerchild - parent_tmp, tio.aes_ops());
  377. // RegAS z;
  378. // mpc_flagmult(tio, yield, z, lt1, smallerchild - parent_tmp, 64);
  379. auto lt1eq1 = lt1 ^ eq1;
  380. RegBS ltlt1;
  381. // RegAS zz;
  382. mpc_and(tio, yield, ltlt1, lteq, lt1eq1);
  383. // // mpc_flagmult(tio, yield, zz, ltlt1, (parent_tmp - leftchild_tmp), 64);
  384. // run_coroutines(tio, [&tio, &ltlt1, lteq, lt1eq1](yield_t &yield)
  385. // { mpc_and(tio, yield, ltlt1, lteq, lt1eq1);},
  386. // [&tio, &zz, ltlt1, parent_tmp, leftchild_tmp](yield_t &yield)
  387. // { mpc_flagmult(tio, yield, zz, ltlt1, (parent_tmp - leftchild_tmp), 64);});
  388. RegAS update_index_by, update_leftindex_by;
  389. run_coroutines(tio, [&tio, &update_leftindex_by, ltlt1, parent_tmp, leftchild_tmp](yield_t &yield)
  390. { mpc_flagmult(tio, yield, update_leftindex_by, ltlt1, (parent_tmp - leftchild_tmp), 64);},
  391. [&tio, &update_index_by, lt1eq1, parent_tmp, smallerchild](yield_t &yield)
  392. {mpc_flagmult(tio, yield, update_index_by, lt1eq1, smallerchild - parent_tmp, 64);}
  393. );
  394. // RegAS leftchildplusparent = RegAS(HeapArray[index]) + RegAS(HeapArray[leftchildindex]);
  395. // RegAS tmp = (sum - leftchildplusparent);
  396. std::vector<coro_t> coroutines;
  397. coroutines.emplace_back(
  398. [&tio, &P, &oidx, update_index_by](yield_t &yield) {
  399. auto Acoro = P.context(yield);
  400. Acoro[oidx] += update_index_by; //inserted_val;
  401. });
  402. coroutines.emplace_back(
  403. [&tio, &L, &oidx, update_leftindex_by](yield_t &yield) {
  404. auto Acoro = L.context(yield);
  405. Acoro[oidx] += update_leftindex_by; //inserted_val;
  406. });
  407. coroutines.emplace_back(
  408. [&tio, &R, &oidx, update_leftindex_by, update_index_by](yield_t &yield) {
  409. auto Acoro = R.context(yield);
  410. Acoro[oidx] += -(update_leftindex_by + update_index_by);
  411. });
  412. run_coroutines(tio, coroutines);
  413. // P[oidx] += z;
  414. // L[oidx] += zz;
  415. // R[oidx] += zzz;
  416. return std::make_pair(smallerindex, gt);
  417. }
  418. void MinHeap::initialize(MPCTIO tio, yield_t & yield) {
  419. auto HeapArray = oram.flat(tio, yield);
  420. HeapArray.init(0x7fffffffffffff);
  421. }
  422. void MinHeap::initialize_random(MPCTIO tio, yield_t & yield) {
  423. auto HeapArray = oram.flat(tio, yield);
  424. std::cout << "initialize_random " << num_items << std::endl;
  425. std::vector<coro_t> coroutines;
  426. // RegAS v[num_items+1];
  427. // for(size_t j = 1; j < num_items; ++j) v[j].ashare = j * tio.player();
  428. for(size_t j = 1; j <= num_items; ++j)
  429. {
  430. coroutines.emplace_back(
  431. [&tio, &HeapArray, j](yield_t &yield) {
  432. auto Acoro = HeapArray.context(yield);
  433. RegAS v;
  434. v.ashare = j * tio.player();
  435. Acoro[j] = v; //inserted_val;
  436. });
  437. }
  438. run_coroutines(tio, coroutines);
  439. // for(size_t j = 1; j <= num_items; ++j)
  440. // {
  441. // RegAS v;
  442. // RegAS inserted_val;
  443. // inserted_val.randomize(6);
  444. // v.ashare = j * tio.player();
  445. // HeapArray[j] = v; //inserted_val;
  446. // // HeapArray.init([v] (size_t j) { return v; });
  447. // }
  448. // HeapArray.init(0x7fffffffffffff);
  449. }
  450. void MinHeap::print_heap(MPCTIO tio, yield_t & yield) {
  451. auto HeapArray = oram.flat(tio, yield);
  452. uint64_t * Pjreconstruction = new uint64_t[num_items + 1];
  453. for(size_t j = 0; j <= num_items; ++j) Pjreconstruction[j] = mpc_reconstruct(tio, yield, HeapArray[j]);
  454. for(size_t j = 0; j <= num_items; ++j)
  455. {
  456. if(2 * j < num_items) {
  457. std::cout << j << "-->> HeapArray[" << j << "] = " << std::dec << Pjreconstruction[j] << ", children are: " << Pjreconstruction[2 * j] << " and " << Pjreconstruction[2 * j + 1] << std::endl;
  458. }
  459. else
  460. {
  461. std::cout << j << "-->> HeapArray[" << j << "] = " << std::dec << Pjreconstruction[j] << " is a LEAF " << std::endl;
  462. }
  463. }
  464. }
  465. auto MinHeap::restore_heap_property_at_root(MPCTIO tio, yield_t & yield, size_t index = 1) {
  466. //size_t index = 1;
  467. //std::cout << "index = " << index << std::endl;
  468. auto HeapArray = oram.flat(tio, yield);
  469. RegAS parent = HeapArray[index];
  470. RegAS leftchild = HeapArray[2 * index];
  471. RegAS rightchild = HeapArray[2 * index + 1];
  472. RegAS sum = parent + leftchild + rightchild;
  473. CDPF cdpf = tio.cdpf(yield);
  474. auto[lt, eq, gt] = cdpf.compare(tio, yield, leftchild - rightchild, tio.aes_ops()); // c_1 in the paper
  475. RegAS smallerchild;
  476. mpc_select(tio, yield, smallerchild, lt, rightchild, leftchild, 64); // smallerchild holds smaller of left and right child
  477. auto lteq = lt ^ eq;
  478. RegXS smallerindex(lt);
  479. uint64_t leftchildindex = (2 * index);
  480. uint64_t rightchildindex = (2 * index) + 1;
  481. smallerindex = (RegXS(lteq) & leftchildindex) ^ (RegXS(gt) & rightchildindex);
  482. CDPF cdpf0 = tio.cdpf(yield);
  483. auto[lt1, eq1, gt1] = cdpf0.compare(tio, yield, smallerchild - parent, tio.aes_ops());
  484. auto lt_p_eq_p = lt1 ^ eq1;
  485. RegBS ltlt1;
  486. mpc_and(tio, yield, ltlt1, lteq, lt_p_eq_p);
  487. RegAS update_index_by, update_leftindex_by;
  488. run_coroutines(tio, [&tio, &update_leftindex_by, ltlt1, parent, leftchild](yield_t &yield)
  489. { mpc_flagmult(tio, yield, update_leftindex_by, ltlt1, (parent - leftchild), 64);},
  490. [&tio, &update_index_by, lt1, parent, smallerchild](yield_t &yield)
  491. {mpc_flagmult(tio, yield, update_index_by, lt1, smallerchild - parent, 64);}
  492. );
  493. // HeapArray[index] += update_index_by;
  494. // HeapArray[leftchildindex] += update_leftindex_by;
  495. std::vector<coro_t> coroutines;
  496. coroutines.emplace_back(
  497. [&tio, &HeapArray, index, update_index_by](yield_t &yield) {
  498. auto Acoro = HeapArray.context(yield);
  499. Acoro[index] += update_index_by; //inserted_val;
  500. });
  501. coroutines.emplace_back(
  502. [&tio, &HeapArray, leftchildindex, update_leftindex_by](yield_t &yield) {
  503. auto Acoro = HeapArray.context(yield);
  504. Acoro[leftchildindex] += update_leftindex_by; //inserted_val;
  505. });
  506. coroutines.emplace_back(
  507. [&tio, &HeapArray, rightchildindex, update_index_by, update_leftindex_by](yield_t &yield) {
  508. auto Acoro = HeapArray.context(yield);
  509. Acoro[rightchildindex] += -(update_index_by + update_leftindex_by);
  510. });
  511. run_coroutines(tio, coroutines);
  512. RegAS leftchildplusparent = RegAS(HeapArray[index]) + RegAS(HeapArray[leftchildindex]);
  513. RegAS tmp = (sum - leftchildplusparent);
  514. HeapArray[rightchildindex] += tmp - rightchild;
  515. #ifdef VERBOSE
  516. RegAS new_parent = HeapArray[index];
  517. RegAS new_left = HeapArray[leftchildindex];
  518. RegAS new_right = HeapArray[rightchildindex];
  519. uint64_t parent_R = mpc_reconstruct(tio, yield, new_parent);
  520. uint64_t left_R = mpc_reconstruct(tio, yield, new_left);
  521. uint64_t right_R = mpc_reconstruct(tio, yield, new_right);
  522. std::cout << "parent_R = " << parent_R << std::endl;
  523. std::cout << "left_R = " << left_R << std::endl;
  524. std::cout << "right_R = " << right_R << std::endl;
  525. #endif
  526. //verify_parent_children_heaps(tio, yield, HeapArray[index], HeapArray[leftchildindex] , HeapArray[rightchildindex]);
  527. return std::make_pair(smallerindex, gt);
  528. }
  529. RegAS MinHeap::extract_min(MPCIO & mpcio, MPCTIO tio, yield_t & yield, int is_optimized) {
  530. RegAS minval;
  531. auto HeapArray = oram.flat(tio, yield);
  532. minval = HeapArray[1];
  533. HeapArray[1] = RegAS(HeapArray[num_items]);
  534. num_items--;
  535. auto outroot = restore_heap_property_at_root(tio, yield);
  536. RegXS smaller = outroot.first;
  537. size_t height = std::log2(num_items);
  538. if(is_optimized > 0)
  539. {
  540. typename Duoram < RegAS > ::template OblivIndex < RegXS, 3 > oidx(tio, yield, height);
  541. oidx.incr(outroot.second);
  542. for (size_t i = 0; i < height; ++i) {
  543. auto out = restore_heap_property_optimized(tio, yield, smaller, i + 1, height, typename Duoram < RegAS > ::template OblivIndex < RegXS, 3 > (oidx));;
  544. smaller = out.first;
  545. oidx.incr(out.second);
  546. }
  547. }
  548. if(is_optimized == 0)
  549. {
  550. for (size_t i = 0; i < height; ++i) {
  551. smaller = restore_heap_property(mpcio, tio, yield, smaller);
  552. std::cout << "one iter done ... \n \n \n";
  553. }
  554. }
  555. return minval;
  556. }
  // Appears intended to sift the element at public `index` down into place
  // (it is called per internal node by heapify()).
  // NOTE(review): the entire body is commented out, so this is currently a
  // no-op and heapify() therefore does nothing. The sketch below also calls
  // restore_heap_property(tio, yield, smaller). That no longer matches the
  // current signature restore_heap_property(mpcio, tio, yield, index), which
  // needs an MPCIO& this function does not receive.
  void MinHeap::heapify2(MPCTIO tio, yield_t & yield, size_t index = 1) {
  // auto outroot = restore_heap_property_at_root(tio, yield, index);
  // RegXS smaller = outroot.first;
  // #ifdef VERBOSE
  // uint64_t smaller_rec = mpc_reconstruct(tio, yield, smaller, 64);
  // std::cout << "smaller_rec = " << smaller_rec << std::endl;
  // std::cout << "num_items = " << num_items << std::endl;
  // std::cout << "index = " << index << std::endl;
  // #endif
  // size_t height = std::log2(num_items) - std::floor(log2(index)) ;
  // #ifdef VERBOSE
  // std::cout << "height = " << height << std::endl << "===================" << std::endl;
  // #endif
  // for (size_t i = 0; i < height - 1; ++i) {
  // #ifdef VERBOSE
  // std::cout << "index = " << index << ", i = " << i << std::endl;
  // uint64_t smaller_rec = mpc_reconstruct(tio, yield, smaller, 64);
  // std::cout << "[inside loop] smaller_rec = " << smaller_rec << std::endl;
  // #endif
  // smaller = restore_heap_property(tio, yield, smaller);
  // }
  }
  579. void MinHeap::heapify(MPCTIO tio, yield_t & yield) {
  580. size_t startIdx = ((num_items + 1) / 2) - 1;
  581. //std::cout << "startIdx " << startIdx << std::endl;
  582. for (size_t i = startIdx; i >= 1; i--) {
  583. heapify2(tio, yield, i);
  584. //print_heap(tio, yield);
  585. }
  586. }
  587. void Heap(MPCIO & mpcio,
  588. const PRACOptions & opts, char ** args) {
  589. nbits_t depth = atoi(args[0]);
  590. nbits_t depth2 = atoi(args[1]);
  591. size_t n_inserts = atoi(args[2]);
  592. size_t n_extracts = atoi(args[3]);
  593. int is_optimized = atoi(args[4]);
  594. std::cout << "print arguements " << std::endl;
  595. std::cout << args[0] << std::endl;
  596. if ( * args) {
  597. depth = atoi( * args);
  598. ++args;
  599. }
  600. size_t items = (size_t(1) << depth) - 1;
  601. if ( * args) {
  602. items = atoi( * args);
  603. ++args;
  604. }
  605. //
  606. std::cout << "items = " << items << std::endl;
  607. MPCTIO tio(mpcio, 0, opts.num_threads);
  608. run_coroutines(tio, [ & tio, depth, depth2, items, n_inserts, n_extracts, is_optimized, &mpcio](yield_t & yield) {
  609. size_t size = size_t(1) << depth;
  610. // std::cout << "size = " << size << std::endl;
  611. MinHeap tree(tio.player(), size);
  612. tree.initialize(tio, yield);
  613. tree.num_items = (size_t(1) << depth2) - 1;
  614. // std::cout << "num_items " << tree.num_items << std::endl;
  615. tree.initialize_random(tio, yield);
  616. std::cout << "\n===== Init Stats =====\n";
  617. tio.sync_lamport();
  618. mpcio.dump_stats(std::cout);
  619. mpcio.reset_stats();
  620. tio.reset_lamport();
  621. // tree.heapify(tio, yield);
  622. // tree.print_heap(tio, yield);
  623. for (size_t j = 0; j < n_inserts; ++j) {
  624. RegAS inserted_val;
  625. inserted_val.randomize(6);
  626. #ifdef VERBOSE
  627. inserted_val.ashare = inserted_val.ashare;
  628. uint64_t inserted_val_rec = mpc_reconstruct(tio, yield, inserted_val, 64);
  629. std::cout << "inserted_val_rec = " << inserted_val_rec << std::endl << std::endl;
  630. #endif
  631. if(is_optimized > 0) tree.insert_optimized(tio, yield, inserted_val);
  632. if(is_optimized == 0) tree.insert(tio, yield, inserted_val);
  633. //tree.print_heap(tio, yield);
  634. }
  635. std::cout << "\n===== Insert Stats =====\n";
  636. tio.sync_lamport();
  637. mpcio.dump_stats(std::cout);
  638. mpcio.reset_stats();
  639. tio.reset_lamport();
  640. // tree.verify_heap_property(tio, yield);
  641. // tree.print_heap(tio, yield);
  642. for (size_t j = 0; j < n_extracts; ++j) {
  643. tree.extract_min(mpcio, tio, yield, is_optimized);
  644. //RegAS minval = tree.extract_min(tio, yield, is_optimized);
  645. // uint64_t minval_reconstruction = mpc_reconstruct(tio, yield, minval, 64);
  646. // std::cout << "minval_reconstruction = " << minval_reconstruction << std::endl;
  647. // tree.verify_heap_property(tio, yield);
  648. // tree.print_heap(tio, yield);
  649. }
  650. std::cout << "\n===== Extract Min Stats =====\n";
  651. tio.sync_lamport();
  652. mpcio.dump_stats(std::cout);
  653. //tree.print_heap(tio, yield);
  654. //tree.verify_heap_property(tio, yield);
  655. });
  656. }