avl.cpp 127 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
72278227922802281228222832284228522862287228822892290229122922293229422952296229722982299230023012302230323042305230623072308230923102311231223132314231523162317231823192320232123222323232423252326232723282329233023312332233323342335233623372338233923402341234223432344234523462347234823492350235123522353235423552356235723582359236023612362236323642365236623672368236923702371237223732374237523762377237823792380238123822383238423852386238723882389239023912392239323942395239623972398239924002401240224032404240524062407240824092410241124122413241424152416241724182419242024212422242324242425242624272428242924302431243224332434243524362437243824392440244124422443244424452446244724482449245024512452245324542455245624572458245924602461246224632464246524662467246824692470247124722473247424752476247724782479248024812482248324842485248624872488248924902491249224932494249524962497249824992500250125022503250425052506250725082509251025112512251325142515251625172518251925202521252225232524252525262527252825292530253125322533253425352536253725382539254025412542254325442545254625472548254925502551255225532554255525562557255825592560256125622563256425652566256725682569257025712572257325742575257625772578257925802581258225832584258525862587258825892590259125922593259425952596259725982599260026012602260326042605260626072608260926102611261226132614261526162617261826192620262126222623262426252626262726282629263026312632263326342635263626372638263926402641264226432644264526462647264826492650265126522653265426552656265726582659266026612662266326642665266626672668266926702671267226732674267526762677267826792680268126822683268426852686268726882689269026912692269326942695269626972698269927002701270227032704270527062707270827092710271127122713271427152716271727182719272027212722272327242725272627272728272927302731273227332734273527362737273827392740274127422743274427452746274727482749275027512752275327542755275627572758275927602761276227632764276527662767276827692770277127722773277427752776277
72778277927802781278227832784278527862787278827892790279127922793279427952796279727982799280028012802280328042805280628072808280928102811281228132814281528162817281828192820282128222823282428252826282728282829283028312832283328342835283628372838283928402841284228432844284528462847284828492850285128522853285428552856285728582859286028612862286328642865286628672868286928702871287228732874287528762877287828792880288128822883288428852886288728882889289028912892289328942895289628972898289929002901290229032904290529062907290829092910291129122913291429152916291729182919292029212922292329242925292629272928292929302931293229332934293529362937293829392940294129422943294429452946294729482949295029512952295329542955295629572958295929602961296229632964296529662967296829692970297129722973297429752976297729782979298029812982298329842985298629872988298929902991299229932994299529962997299829993000300130023003300430053006300730083009301030113012301330143015301630173018301930203021302230233024302530263027302830293030303130323033303430353036303730383039304030413042304330443045304630473048304930503051305230533054305530563057305830593060306130623063306430653066306730683069307030713072307330743075307630773078307930803081308230833084308530863087308830893090309130923093309430953096309730983099310031013102310331043105310631073108310931103111311231133114311531163117311831193120312131223123312431253126312731283129313031313132313331343135313631373138313931403141314231433144314531463147314831493150315131523153315431553156315731583159316031613162316331643165316631673168316931703171317231733174317531763177317831793180318131823183318431853186318731883189319031913192319331943195319631973198319932003201320232033204320532063207320832093210321132123213321432153216321732183219322032213222322332243225322632273228322932303231323232333234323532363237323832393240324132423243324432453246324732483249325032513252325332543255325632573258325932603261326232633264326532663267326832693270327132723273327432753276327
732783279328032813282328332843285328632873288328932903291329232933294329532963297329832993300330133023303330433053306330733083309331033113312331333143315331633173318331933203321332233233324332533263327332833293330333133323333333433353336333733383339334033413342334333443345334633473348334933503351335233533354335533563357335833593360336133623363336433653366336733683369337033713372337333743375337633773378
  1. #include <functional>
  2. #include "avl.hpp"
  /* ANSI terminal escape sequences used to colour console output. */
  #define KNRM  "\x1B[0m"   /* reset attributes */
  #define KRED  "\x1B[31m"  /* red foreground */
  #define KGRN  "\x1B[32m"  /* green foreground */
  #define KYEL  "\x1B[33m"  /* yellow foreground */
  #define KBLU  "\x1B[34m"  /* blue foreground */
  #define KMAG  "\x1B[35m"  /* magenta foreground */
  #define KCYN  "\x1B[36m"  /* cyan foreground */
  #define KWHT  "\x1B[37m"  /* white foreground */
  11. static void randomize_node(Node &a) {
  12. a.key.randomize(8);
  13. a.pointers.set(0);
  14. a.value.randomize();
  15. }
  16. void print_green(std::string line) {
  17. printf("%s%s%s", KGRN, line.c_str(), KNRM);
  18. }
  19. void print_red(std::string line) {
  20. printf("%s%s%s", KRED, line.c_str(), KNRM);
  21. }
  22. /*
  23. Pretty-print a reconstructed BST, rooted at node. is_left_child and
  24. is_right_child indicate whether node is a left or right child of its
  25. parent. They cannot both be true, but the root of the tree has both
  26. of them false.
  27. */
  28. void AVL::pretty_print(const std::vector<Node> &R, value_t node,
  29. const std::string &prefix = "", bool is_left_child = false,
  30. bool is_right_child = false)
  31. {
  32. if (node == 0) {
  33. // NULL pointer
  34. if (is_left_child) {
  35. printf("%s\xE2\x95\xA7\n", prefix.c_str()); // ╧
  36. } else if (is_right_child) {
  37. printf("%s\xE2\x95\xA4\n", prefix.c_str()); // ╤
  38. } else {
  39. printf("%s\xE2\x95\xA2\n", prefix.c_str()); // ╢
  40. }
  41. return;
  42. }
  43. const Node &n = R[node];
  44. value_t left_ptr = getAVLLeftPtr(n.pointers).xshare;
  45. value_t right_ptr = getAVLRightPtr(n.pointers).xshare;
  46. std::string rightprefix(prefix), leftprefix(prefix),
  47. nodeprefix(prefix);
  48. if (is_left_child) {
  49. rightprefix.append("\xE2\x94\x82"); // │
  50. leftprefix.append(" ");
  51. nodeprefix.append("\xE2\x94\x94"); // └
  52. } else if (is_right_child) {
  53. rightprefix.append(" ");
  54. leftprefix.append("\xE2\x94\x82"); // │
  55. nodeprefix.append("\xE2\x94\x8C"); // ┌
  56. } else {
  57. rightprefix.append(" ");
  58. leftprefix.append(" ");
  59. nodeprefix.append("\xE2\x94\x80"); // ─
  60. }
  61. pretty_print(R, right_ptr, rightprefix, false, true);
  62. printf("%s\xE2\x94\xA4", nodeprefix.c_str()); // ┤
  63. dumpAVL(n);
  64. printf("\n");
  65. pretty_print(R, left_ptr, leftprefix, true, false);
  66. }
  67. void AVL::print_oram(MPCTIO &tio, yield_t &yield) {
  68. auto A = oram.flat(tio, yield);
  69. auto R = A.reconstruct();
  70. for(size_t i=0;i<R.size();++i) {
  71. printf("\n%04lx ", i);
  72. R[i].dump();
  73. }
  74. printf("\n");
  75. }
  76. void AVL::pretty_print(MPCTIO &tio, yield_t &yield) {
  77. RegXS peer_root;
  78. RegXS reconstructed_root = root;
  79. if (tio.player() == 1) {
  80. tio.queue_peer(&root, sizeof(root));
  81. yield();
  82. } else {
  83. RegXS peer_root;
  84. yield();
  85. tio.recv_peer(&peer_root, sizeof(peer_root));
  86. reconstructed_root += peer_root;
  87. }
  88. auto A = oram.flat(tio, yield);
  89. auto R = A.reconstruct();
  90. if(tio.player()==0) {
  91. pretty_print(R, reconstructed_root.xshare);
  92. }
  93. }
  94. /*
  95. Check the AVL invariants of the tree are recursively true:
  96. (i) all keys to the left are less than or equal to parent (BST invariant)
  97. (ii) all keys to the right are strictly greater than parent (BST invariant)
  98. (iii) difference in height between sibling subtrees <= 1
  99. (iv) the balance bits of each node are correct
  100. Returns a tuple<bool a, bool b, bool c, address_t height>, where:
  101. - bool a indicates if BST invariants are true
  102. - bool b indicates if (iii) is true
  103. - bool c indicates if (iv) is true
  104. - height returns the height of the current subtree
  105. */
  106. std::tuple<bool, bool, bool, address_t> AVL::check_avl(const std::vector<Node> &R,
  107. value_t node, value_t min_key = 0, value_t max_key = ~0)
  108. {
  109. if (node == 0) {
  110. return { true, true, true, 0};
  111. }
  112. const Node &n = R[node];
  113. value_t key = n.key.ashare;
  114. value_t left_ptr = getAVLLeftPtr(n.pointers).xshare;
  115. value_t right_ptr = getAVLRightPtr(n.pointers).xshare;
  116. auto [leftok, leftavlok, leftbbok, leftheight ] = check_avl(R, left_ptr, min_key, key);
  117. auto [rightok, rightavlok, rightbbok, rightheight ] = check_avl(R, right_ptr, key, max_key);
  118. address_t height = leftheight;
  119. if (rightheight > height) {
  120. height = rightheight;
  121. }
  122. height += 1;
  123. int heightgap = leftheight - rightheight;
  124. bool leftbal = (getLeftBal(n.pointers)).bshare;
  125. bool rightbal = (getRightBal(n.pointers)).bshare;
  126. bool avlok = (abs(heightgap)<2);
  127. bool bb_ok = false;
  128. if(heightgap==-1) {
  129. if(rightbal==1 && leftbal==0){
  130. bb_ok = true;
  131. }
  132. } else if(heightgap==1){
  133. if(leftbal==1 && rightbal==0){
  134. bb_ok = true;
  135. }
  136. } else if(heightgap==0){
  137. if(rightbal==0 && leftbal==0) {
  138. bb_ok = true;
  139. }
  140. }
  141. #ifdef AVL_DEBUG_BB
  142. if(bb_ok == false){
  143. printf("BB check failed at node with key = %ld\n", key);
  144. }
  145. #endif
  146. //printf("node = %ld, leftok = %d, rightok = %d\n", node, leftok, rightok);
  147. return { leftok && rightok && key >= min_key && key <= max_key,
  148. avlok && leftavlok && rightavlok, bb_ok && leftbbok && rightbbok, height};
  149. }
  150. // Note only P0 gets the correct result of check_AVL.
  151. // That's fine since P0 outputs all the correctness outputs for the test suite.
  152. bool AVL::check_avl(MPCTIO &tio, yield_t &yield) {
  153. auto A = oram.flat(tio, yield);
  154. auto R = A.reconstruct();
  155. RegXS rec_root = this->root;
  156. if (tio.player() == 1) {
  157. tio.queue_peer(&(this->root), sizeof(this->root));
  158. yield();
  159. } else {
  160. RegXS peer_root;
  161. yield();
  162. tio.recv_peer(&peer_root, sizeof(peer_root));
  163. rec_root+= peer_root;
  164. }
  165. if (tio.player() == 0) {
  166. auto [ bst_ok, avl_ok, bb_ok, height ] = check_avl(R, rec_root.xshare);
  167. printf("BST structure %s\nAVL structure %s\nBalance Bits %s\nTree height = %u\n",
  168. bst_ok ? "ok" : "NOT OK", avl_ok ? "ok" : "NOT OK", bb_ok? "ok" : "NOT OK", height);
  169. return (bst_ok && avl_ok && bb_ok);
  170. }
  171. else {
  172. return false;
  173. }
  174. }
  175. /*
  176. Rotate: (gp = grandparent (if exists), p = parent, c = child)
  177. This rotates the p -> c link.
  178. gp gp
  179. \ \
  180. p --- Left rotate ---> c
  181. \ /
  182. c p
  183. gp gp
  184. \ \
  185. p --- Right rotate ---> c
  186. / \
  187. c p
  188. */
  189. void AVL::rotate(MPCTIO &tio, yield_t &yield, RegXS &gp_pointers, RegXS p_ptr,
  190. RegXS &p_pointers, RegXS c_ptr, RegXS &c_pointers, RegBS dir_gpp,
  191. RegBS dir_pc, RegBS isReal, RegBS F_gp) {
  192. bool player0 = tio.player()==0;
  193. RegXS gp_left = getAVLLeftPtr(gp_pointers);
  194. RegXS gp_right = getAVLRightPtr(gp_pointers);
  195. RegXS p_left = getAVLLeftPtr(p_pointers);
  196. RegXS p_right = getAVLRightPtr(p_pointers);
  197. RegXS c_left = getAVLLeftPtr(c_pointers);
  198. RegXS c_right = getAVLRightPtr(c_pointers);
  199. RegXS ptr_upd;
  200. // F_gpp: Flag to update gp -> p link, F_pc: Flag to update p -> c link
  201. // F_pc_l/F_pc_r: indicates whether p -> c link is in the l/r direction
  202. // F_gpp_l/F_gpp_r: indicates whether gp -> p link is in the l/r direction
  203. RegBS F_gpp, F_pc_l, F_pc_r, F_gppr, F_gppl;
  204. // We care about !F_gp. If !F_gp, then we do the gp->p link updates.
  205. // Otherwise, we do NOT do any updates to gp-> p link;
  206. // since F_gp==1, implies gp does not exist and parent is root.
  207. if(player0) {
  208. F_gp^=1;
  209. }
  210. mpc_and(tio, yield, F_gpp, F_gp, isReal);
  211. // i) gp[dir_gpp] <-- c_ptr
  212. RegBS not_dir_gpp = dir_gpp;
  213. if(player0) {
  214. not_dir_gpp^=1;
  215. }
  216. mpc_select(tio, yield, ptr_upd, F_gpp, p_ptr, c_ptr);
  217. RegBS not_dir_pc_l = dir_pc, not_dir_pc_r = dir_pc;
  218. if(player0) {
  219. not_dir_pc_r^=1;
  220. }
  221. RegXS c_not_dir_pc; //c[!dir_pc]
  222. // ndpc_right: if not_dir_pc is right
  223. // ndpc_left: if not_dir_pc is left
  224. RegBS F_ndpc_right, F_ndpc_left;
  225. RegBS nt_dir_pc = dir_pc;
  226. if(player0) {
  227. nt_dir_pc^=1;
  228. }
  229. std::vector<coro_t> coroutines;
  230. coroutines.emplace_back(
  231. [&tio, &F_gppr, F_gpp, dir_gpp](yield_t &yield) {
  232. mpc_and(tio, yield, F_gppr, F_gpp, dir_gpp);
  233. });
  234. coroutines.emplace_back(
  235. [&tio, &F_gppl, F_gpp, not_dir_gpp](yield_t &yield) {
  236. mpc_and(tio, yield, F_gppl, F_gpp, not_dir_gpp);
  237. });
  238. // ii) p[dir_pc] <-- c[!dir_pc] and iii) c[!dir_pc] <-- p_ptr
  239. coroutines.emplace_back(
  240. [&tio, &F_ndpc_right, isReal, not_dir_pc_r](yield_t &yield) {
  241. mpc_and(tio, yield, F_ndpc_right, isReal, not_dir_pc_r);
  242. });
  243. coroutines.emplace_back(
  244. [&tio, &F_ndpc_left, isReal, not_dir_pc_l](yield_t &yield) {
  245. mpc_and(tio, yield, F_ndpc_left, isReal, not_dir_pc_l);
  246. });
  247. coroutines.emplace_back(
  248. [&tio, &F_pc_l, dir_pc, isReal](yield_t &yield) {
  249. mpc_and(tio, yield, F_pc_l, dir_pc, isReal);
  250. });
  251. coroutines.emplace_back(
  252. [&tio, &F_pc_r, nt_dir_pc, isReal](yield_t &yield) {
  253. mpc_and(tio, yield, F_pc_r, nt_dir_pc, isReal);
  254. });
  255. run_coroutines(tio, coroutines);
  256. run_coroutines(tio, [&tio, &gp_right, F_gppr, ptr_upd](yield_t &yield)
  257. { mpc_select(tio, yield, gp_right, F_gppr, gp_right, ptr_upd);},
  258. [&tio, &gp_left, F_gppl, ptr_upd](yield_t &yield)
  259. { mpc_select(tio, yield, gp_left, F_gppl, gp_left, ptr_upd);},
  260. [&tio, &c_not_dir_pc, F_ndpc_right, c_right](yield_t &yield)
  261. { mpc_select(tio, yield, c_not_dir_pc, F_ndpc_right, c_not_dir_pc, c_right, AVL_PTR_SIZE);});
  262. //[&tio, &c_not_dir_pc, F_ndpc_left, c_left](yield_t &yield)
  263. mpc_select(tio, yield, c_not_dir_pc, F_ndpc_left, c_not_dir_pc, c_left, AVL_PTR_SIZE);
  264. // ii) p[dir_pc] <-- c[!dir_pc]
  265. // iii): c[!dir_pc] <-- p_ptr
  266. run_coroutines(tio, [&tio, &p_left, F_ndpc_right, c_not_dir_pc](yield_t &yield)
  267. { mpc_select(tio, yield, p_left, F_ndpc_right, p_left, c_not_dir_pc, AVL_PTR_SIZE);},
  268. [&tio, &p_right, F_ndpc_left, c_not_dir_pc](yield_t &yield)
  269. { mpc_select(tio, yield, p_right, F_ndpc_left, p_right, c_not_dir_pc, AVL_PTR_SIZE);},
  270. [&tio, &ptr_upd, isReal, c_not_dir_pc, p_ptr](yield_t &yield)
  271. { mpc_select(tio, yield, ptr_upd, isReal, c_not_dir_pc, p_ptr, AVL_PTR_SIZE);});
  272. run_coroutines(tio, [&tio, &c_left, F_pc_l, ptr_upd](yield_t &yield)
  273. { mpc_select(tio, yield, c_left, F_pc_l, c_left, ptr_upd, AVL_PTR_SIZE);},
  274. [&tio, &c_right, F_pc_r, ptr_upd](yield_t &yield)
  275. { mpc_select(tio, yield, c_right, F_pc_r, c_right, ptr_upd, AVL_PTR_SIZE);});
  276. setAVLLeftPtr(gp_pointers, gp_left);
  277. setAVLRightPtr(gp_pointers, gp_right);
  278. setAVLLeftPtr(p_pointers, p_left);
  279. setAVLRightPtr(p_pointers, p_right);
  280. setAVLLeftPtr(c_pointers, c_left);
  281. setAVLRightPtr(c_pointers, c_right);
  282. }
  283. /*
  284. If F_rs: (bal_upd & right_child)
  285. bal_l, balanced, bal_r, imbalance
  286. And then right shift to get imbalance bit, and new bal_l, bal_r bits
  287. else if F_ls: (bal_upd & left_child)
  288. imbalance, bal_l, balanced, bal_r
  289. And then left shift to get imbalance bit, and new bal_l, bal_r bits
  290. */
  291. std::tuple<RegBS, RegBS, RegBS, RegBS> AVL::updateBalanceIns(MPCTIO &tio, yield_t &yield,
  292. RegBS bal_l, RegBS bal_r, RegBS bal_upd, RegBS child_dir) {
  293. bool player0 = tio.player()==0;
  294. RegBS s0;
  295. RegBS F_rs, F_ls, balanced, imbalance, nt_child_dir;
  296. // balanced = is the node currently balanced
  297. balanced = bal_l ^ bal_r;
  298. nt_child_dir = child_dir;
  299. if(player0){
  300. nt_child_dir^=1;
  301. }
  302. if(player0) {
  303. balanced^=1;
  304. }
  305. run_coroutines(tio, [&tio, &F_rs, child_dir, bal_upd](yield_t &yield)
  306. { //F_rs (Flag right shift) <- child_dir & bal_upd
  307. mpc_and(tio, yield, F_rs, child_dir, bal_upd);},
  308. [&tio, &F_ls, nt_child_dir, bal_upd](yield_t &yield)
  309. { //F_ls (Flag left shift) <- !child_dir & bal_upd
  310. mpc_and(tio, yield, F_ls, nt_child_dir, bal_upd);});
  311. std::vector<coro_t> coroutines;
  312. // Right shift if child_dir = 1 & bal_upd = 1
  313. coroutines.emplace_back(
  314. [&tio, &imbalance, F_rs, bal_r, balanced](yield_t &yield) {
  315. mpc_select(tio, yield, imbalance, F_rs, imbalance, bal_r);
  316. });
  317. coroutines.emplace_back(
  318. [&tio, &bal_r, F_rs, balanced](yield_t &yield) {
  319. mpc_select(tio, yield, bal_r, F_rs, bal_r, balanced);
  320. });
  321. coroutines.emplace_back(
  322. [&tio, &balanced, F_rs, bal_l](yield_t &yield) {
  323. mpc_select(tio, yield, balanced, F_rs, balanced, bal_l);
  324. });
  325. coroutines.emplace_back(
  326. [&tio, &bal_l, F_rs, s0](yield_t &yield) {
  327. mpc_select(tio, yield, bal_l, F_rs, bal_l, s0);
  328. });
  329. run_coroutines(tio, coroutines);
  330. coroutines.clear();
  331. // Left shift if child_dir = 0 & bal_upd = 1
  332. coroutines.emplace_back(
  333. [&tio, &imbalance, F_ls, bal_l] (yield_t &yield) {
  334. mpc_select(tio, yield, imbalance, F_ls, imbalance, bal_l);
  335. });
  336. coroutines.emplace_back(
  337. [&tio, &bal_l, F_ls, balanced] (yield_t &yield) {
  338. mpc_select(tio, yield, bal_l, F_ls, bal_l, balanced);
  339. });
  340. coroutines.emplace_back(
  341. [&tio, &balanced, F_ls, bal_r] (yield_t &yield) {
  342. mpc_select(tio, yield, balanced, F_ls, balanced, bal_r);
  343. });
  344. coroutines.emplace_back(
  345. [&tio, &bal_r, F_ls, s0](yield_t &yield) {
  346. mpc_select(tio, yield, bal_r, F_ls, bal_r, s0);
  347. });
  348. run_coroutines(tio, coroutines);
  349. // bal_upd <- bal_upd ^ imbalance
  350. RegBS F_bu0;
  351. mpc_and(tio, yield, F_bu0, bal_upd, balanced);
  352. mpc_select(tio, yield, bal_upd, F_bu0, bal_upd, s0);
  353. mpc_select(tio, yield, bal_upd, imbalance, bal_upd, s0);
  354. return {bal_l, bal_r, bal_upd, imbalance};
  355. }
  356. /*
  357. In updateBalanceDel, the position of imbalance, and shift direction for both
  358. cases are inverted (from that of updateBalanceIns()), since a bal_upd on a child
  359. implies it reduced height.
  360. If F_rs: (bal_upd & right_child)
  361. imbalance, bal_l, balanced, bal_r
  362. And then left shift to get imbalance bit, and new bal_l, bal_r bits
  363. else if F_ls: (bal_upd & left_child)
  364. bal_l, balanced, bal_r, imbalance
  365. And then right shift to get imbalance bit, and new bal_l, bal_r bits
  366. */
  367. std::tuple<RegBS, RegBS, RegBS, RegBS> AVL::updateBalanceDel(MPCTIO &tio, yield_t &yield,
  368. RegBS bal_l, RegBS bal_r, RegBS bal_upd, RegBS child_dir) {
  369. bool player0 = tio.player()==0;
  370. RegBS s0;
  371. RegBS F_rs, F_ls, balanced, imbalance, not_imbalance;
  372. RegBS nt_child_dir = child_dir;
  373. if(player0) {
  374. nt_child_dir^=1;
  375. }
  376. // balanced = is the node currently balanced
  377. balanced = bal_l ^ bal_r;
  378. if(player0) {
  379. balanced^=1;
  380. }
  381. //F_ls (Flag left shift) <- child_dir & bal_upd
  382. //F_rs (Flag right shift) <- !child_dir & bal_upd
  383. run_coroutines(tio, [&tio, &F_ls, child_dir, bal_upd](yield_t &yield)
  384. { mpc_and(tio, yield, F_ls, child_dir, bal_upd);},
  385. [&tio, &F_rs, nt_child_dir, bal_upd](yield_t &yield)
  386. { mpc_and(tio, yield, F_rs, nt_child_dir, bal_upd);});
  387. // Left shift if F_ls
  388. run_coroutines(tio, [&tio, &imbalance, F_ls, bal_l](yield_t &yield)
  389. { mpc_select(tio, yield, imbalance, F_ls, imbalance, bal_l);},
  390. [&tio, &bal_l, F_ls, balanced](yield_t &yield)
  391. { mpc_select(tio, yield, bal_l, F_ls, bal_l, balanced);},
  392. [&tio, &balanced, F_ls, bal_r](yield_t &yield)
  393. { mpc_select(tio, yield, balanced, F_ls, balanced, bal_r);},
  394. [&tio, &bal_r, F_ls, s0](yield_t &yield)
  395. { mpc_select(tio, yield, bal_r, F_ls, bal_r, s0);});
  396. // Right shift if F_rs
  397. run_coroutines(tio, [&tio, &imbalance, F_rs, bal_r](yield_t &yield)
  398. { mpc_select(tio, yield, imbalance, F_rs, imbalance, bal_r);},
  399. [&tio, &bal_r, F_rs, balanced](yield_t &yield)
  400. { mpc_select(tio, yield, bal_r, F_rs, bal_r, balanced);},
  401. [&tio, &balanced, F_rs, bal_l](yield_t &yield)
  402. { mpc_select(tio, yield, balanced, F_rs, balanced, bal_l);},
  403. [&tio, &bal_l, F_rs, s0](yield_t &yield)
  404. { mpc_select(tio, yield, bal_l, F_rs, bal_l, s0);});
  405. /*
  406. if(bal_upd) and this node:
  407. (i) becomes balanced: the height has decreased, so continue propogating bal_upd.
  408. (ii) becomes imbalanced: fixImbalance will update bal_upd correctly.
  409. (iii) updates from balanced to left/right heavy: the height of this subtree has not changed,
  410. so don't propogate bal_upd.
  411. We handle (iii) below.
  412. */
  413. RegBS LR_heavy, bu0;
  414. LR_heavy = bal_l ^ bal_r;
  415. mpc_and(tio, yield, bu0, bal_upd, LR_heavy);
  416. mpc_select(tio, yield, bal_upd, bu0, bal_upd, s0);
  417. return {bal_l, bal_r, bal_upd, imbalance};
  418. }
  419. /*
  420. The recursive AVL insert function.
  421. Takes as input: the current node pointer of the tree traversal (ptr),
  422. the address of the newly inserted node (ins_addr), the insertion key
  423. (insert_key), the underlying DORAM as a flat (A), the time-to-live (TTL),
  424. a flag indicating if this is a dummy operation (isDummy), and a return
  425. structure (ret) that stores the imbalance state (if any) found during
  426. the insertion, so it can be resolved with a one-time imbalance fix operation.
  427. Returns a tuple<RegBS a, RegBS b, RegBS c, RegBS d>
  428. where (a) is the balance update bit from the subtree, (b) is a flag to
  429. indicate when to store the grandparent node in the return structure,
  430. (c) is the pointer to the recursive subtree, and (d) is the direction of
  431. the subtree from the parent.
  432. */
  433. std::tuple<RegBS, RegBS, RegXS, RegBS> AVL::insert(MPCTIO &tio, yield_t &yield, RegXS ptr, RegXS ins_addr,
  434. RegAS insert_key, Duoram<Node>::Flat &A, int TTL, RegBS isDummy, avl_insert_return &ret) {
  435. if(TTL==0) {
  436. RegBS z;
  437. return {z, z, z, z};
  438. }
  439. RegBS isReal = isDummy ^ (!tio.player());
  440. Node cnode;
  441. std::optional<Duoram<Node>::OblivIndex<RegXS, 1>> oidx;
  442. RegXS old_pointers;
  443. nbits_t width = ceil(log2(cur_max_index+1));
  444. if(OPTIMIZED) {
  445. oidx.emplace(tio, yield, ptr, width);
  446. cnode = A[oidx.value()];
  447. old_pointers = cnode.pointers;
  448. } else {
  449. cnode = A[ptr];
  450. }
  451. // Compare key
  452. auto [lteq, gt] = compare_keys(tio, yield, cnode.key, insert_key);
  453. // Depending on [lteq, gt] select the next_ptr
  454. RegXS next_ptr;
  455. RegXS left = getAVLLeftPtr(cnode.pointers);
  456. RegXS right = getAVLRightPtr(cnode.pointers);
  457. RegBS bal_l = getLeftBal(cnode.pointers);
  458. RegBS bal_r = getRightBal(cnode.pointers);
  459. /*
  460. size_t rec_left = mpc_reconstruct(tio, yield, left, AVL_PTR_SIZE);
  461. size_t rec_right = mpc_reconstruct(tio, yield, right, AVL_PTR_SIZE);
  462. size_t rec_key = mpc_reconstruct(tio, yield, cnode.key);
  463. printf("\n\n(Before recursing) Key = %ld\n", rec_key);
  464. printf("rec_left = %ld, rec_right = %ld\n", rec_left, rec_right);
  465. */
  466. mpc_select(tio, yield, next_ptr, gt, left, right, AVL_PTR_SIZE);
  467. /*
  468. size_t rec_next_ptr = mpc_reconstruct(tio, yield, next_ptr, AVL_PTR_SIZE);
  469. printf("rec_next_ptr = %ld\n", rec_next_ptr);
  470. */
  471. CDPF dpf = tio.cdpf(yield);
  472. size_t &aes_ops = tio.aes_ops();
  473. // F_z: Check if this is last node on path
  474. RegBS F_z = dpf.is_zero(tio, yield, next_ptr, aes_ops);
  475. RegBS F_i;
  476. // F_i: If this was last node on path (F_z), and isReal insert.
  477. mpc_and(tio, yield, F_i, (isReal), F_z);
  478. isDummy^=F_i;
  479. auto [bal_upd, F_gp, prev_node, prev_dir] = insert(tio, yield,
  480. next_ptr, ins_addr, insert_key, A, TTL-1, isDummy, ret);
  481. /*
  482. rec_bal_upd = mpc_reconstruct(tio, yield, bal_upd);
  483. rec_F_gp = mpc_reconstruct(tio, yield, F_gp);
  484. printf("Insert returns: rec_bal_upd = %d, rec_F_gp = %d\n",
  485. rec_bal_upd, rec_F_gp);
  486. size_t rec_ptr = mpc_reconstruct(tio, yield, pt);
  487. printf("\nrec_ptr = %ld\n", rec_ptr);
  488. */
  489. // Update balance
  490. // If we inserted at this level (F_i), bal_upd is set to 1
  491. mpc_or(tio, yield, bal_upd, bal_upd, F_i);
  492. auto [new_bal_l, new_bal_r, new_bal_upd, imbalance] = updateBalanceIns(tio, yield, bal_l, bal_r, bal_upd, gt);
  493. // Store if this insert triggers an imbalance
  494. ret.imbalance ^= imbalance;
  495. std::vector<coro_t> coroutines;
  496. // Save grandparent pointer
  497. coroutines.emplace_back(
  498. [&tio, &ret, F_gp, ptr](yield_t &yield) {
  499. mpc_select(tio, yield, ret.gp_node, F_gp, ret.gp_node, ptr, AVL_PTR_SIZE);
  500. });
  501. coroutines.emplace_back(
  502. [&tio, &ret, F_gp, gt](yield_t &yield) {
  503. mpc_select(tio, yield, ret.dir_gpp, F_gp, ret.dir_gpp, gt);
  504. });
  505. // Save parent pointer
  506. coroutines.emplace_back(
  507. [&tio, &ret, imbalance, ptr](yield_t &yield) {
  508. mpc_select(tio, yield, ret.p_node, imbalance, ret.p_node, ptr, AVL_PTR_SIZE);
  509. });
  510. coroutines.emplace_back(
  511. [&tio, &ret, imbalance, gt](yield_t &yield) {
  512. mpc_select(tio, yield, ret.dir_pc, imbalance, ret.dir_pc, gt);
  513. });
  514. // Save child pointer
  515. coroutines.emplace_back(
  516. [&tio, &ret, imbalance, prev_node](yield_t &yield) {
  517. mpc_select(tio, yield, ret.c_node, imbalance, ret.c_node, prev_node, AVL_PTR_SIZE);
  518. });
  519. coroutines.emplace_back(
  520. [&tio, &ret, imbalance, prev_dir](yield_t &yield) {
  521. mpc_select(tio, yield, ret.dir_cn, imbalance, ret.dir_cn, prev_dir);
  522. });
  523. run_coroutines(tio, coroutines);
  524. // Store new_bal_l and new_bal_r for this node
  525. setLeftBal(cnode.pointers, new_bal_l);
  526. setRightBal(cnode.pointers, new_bal_r);
  527. // We have to write the node pointers anyway to handle balance updates,
  528. // so we perform insertion along with it by modifying pointers appropriately.
  529. RegBS F_ir, F_il;
  530. run_coroutines(tio, [&tio, &F_ir, F_i, gt](yield_t &yield)
  531. { mpc_and(tio, yield, F_ir, F_i, gt); },
  532. [&tio, &F_il, F_i, lteq](yield_t &yield)
  533. { mpc_and(tio, yield, F_il, F_i, lteq); });
  534. run_coroutines(tio, [&tio, &left, F_il, ins_addr](yield_t &yield)
  535. { mpc_select(tio, yield, left, F_il, left, ins_addr);},
  536. [&tio, &right, F_ir, ins_addr](yield_t &yield)
  537. { mpc_select(tio, yield, right, F_ir, right, ins_addr);});
  538. setAVLLeftPtr(cnode.pointers, left);
  539. setAVLRightPtr(cnode.pointers, right);
  540. /*
  541. bool rec_F_ir, rec_F_il;
  542. rec_F_ir = mpc_reconstruct(tio, yield, F_ir);
  543. rec_F_il = mpc_reconstruct(tio, yield, F_il);
  544. rec_left = mpc_reconstruct(tio, yield, left, AVL_PTR_SIZE);
  545. rec_right = mpc_reconstruct(tio, yield, right, AVL_PTR_SIZE);
  546. printf("(After recursing) F_il = %d, left = %ld, F_ir = %d, right = %ld\n",
  547. rec_F_il, rec_left, rec_F_ir, rec_right);
  548. */
  549. if(OPTIMIZED) {
  550. A[oidx.value()].NODE_POINTERS+=(cnode.pointers - old_pointers);
  551. } else {
  552. A[ptr].NODE_POINTERS = cnode.pointers;
  553. }
  554. // s0 = shares of 0
  555. RegBS s0;
  556. // If there was an imbalance then we need to store the grandparent node
  557. // (node in the level above) into the ret_struct. So we return imbalance.
  558. return {new_bal_upd, imbalance, ptr, gt};
  559. }
  560. /*
  561. Main AVL insert function.
  562. Takes as input the new node to insert.
  563. */
  564. // Insert(root, ptr, key, TTL, isDummy) -> (new_ptr, wptr, wnode, f_p)
  565. void AVL::insert(MPCTIO &tio, yield_t &yield, const Node &node) {
  566. bool player0 = tio.player()==0;
  567. // If there are no items in tree. Make this new item the root.
  568. if(num_items==0) {
  569. auto A = oram.flat(tio, yield);
  570. Node zero;
  571. A[0] = zero;
  572. A[1] = node;
  573. // Set root to a secret sharing of the constant value 1
  574. root.set(1*tio.player());
  575. num_items++;
  576. cur_max_index++;
  577. return;
  578. } else {
  579. // Insert node into next free slot in the ORAM
  580. int new_id;
  581. RegXS insert_address;
  582. num_items++;
  583. int TTL = AVL_TTL(num_items);
  584. bool insertAtEmptyLocation = (empty_locations.size() > 0);
  585. if(!insertAtEmptyLocation) {
  586. cur_max_index++;
  587. }
  588. auto A = oram.flat(tio, yield, 0, cur_max_index+1);
  589. if(insertAtEmptyLocation) {
  590. insert_address = empty_locations.back();
  591. empty_locations.pop_back();
  592. A[insert_address] = node;
  593. } else {
  594. new_id = num_items;
  595. A[new_id] = node;
  596. insert_address.set(new_id * tio.player());
  597. }
  598. RegBS isDummy;
  599. avl_insert_return ret;
  600. RegAS insert_key = node.key;
  601. // Recursive insert function
  602. auto [bal_upd, F_gp, prev_node, prev_dir] = insert(tio, yield, root,
  603. insert_address, insert_key, A, TTL, isDummy, ret);
  604. /*
  605. // Debug code
  606. bool rec_bal_upd, rec_F_gp, ret_dir_pc, ret_dir_cn;
  607. rec_bal_upd = mpc_reconstruct(tio, yield, bal_upd);
  608. rec_F_gp = mpc_reconstruct(tio, yield, F_gp);
  609. ret_dir_pc = mpc_reconstruct(tio, yield, ret.dir_pc);
  610. ret_dir_cn = mpc_reconstruct(tio, yield, ret.dir_cn);
  611. printf("(Top level) Insert returns: rec_bal_upd = %d, rec_F_gp = %d\n",
  612. rec_bal_upd, rec_F_gp);
  613. printf("(Top level) Insert returns: ret.dir_pc = %d, rt.dir_cn = %d\n",
  614. ret_dir_pc, ret_dir_cn);
  615. */
  616. // Perform balance procedure
  617. RegXS gp_pointers, parent_pointers, child_pointers;
  618. std::vector<coro_t> coroutines;
  619. std::optional<Duoram<Node>::template OblivIndex<RegXS, 1>> oidx_gp;
  620. std::optional<Duoram<Node>::template OblivIndex<RegXS, 1>> oidx_p;
  621. std::optional<Duoram<Node>::template OblivIndex<RegXS, 1>> oidx_c;
  622. nbits_t width = ceil(log2(cur_max_index+1));
  623. if(OPTIMIZED) {
  624. oidx_gp.emplace(tio, yield, ret.gp_node, width);
  625. oidx_p.emplace(tio, yield, ret.p_node, width);
  626. oidx_c.emplace(tio, yield, ret.c_node, width);
  627. coroutines.emplace_back(
  628. [&tio, &A, &oidx_gp, &gp_pointers](yield_t &yield) {
  629. auto acont = A.context(yield);
  630. gp_pointers = acont[oidx_gp.value()].NODE_POINTERS;});
  631. coroutines.emplace_back(
  632. [&tio, &A, &oidx_p, &parent_pointers](yield_t &yield) {
  633. auto acont = A.context(yield);
  634. parent_pointers = acont[oidx_p.value()].NODE_POINTERS;});
  635. coroutines.emplace_back(
  636. [&tio, &A, &oidx_c, &child_pointers](yield_t &yield) {
  637. auto acont = A.context(yield);
  638. child_pointers = acont[oidx_c.value()].NODE_POINTERS;});
  639. run_coroutines(tio, coroutines);
  640. coroutines.clear();
  641. /*
  642. gp_pointers = A[oidx_gp].NODE_POINTERS;
  643. parent_pointers = A[oidx_p].NODE_POINTERS;
  644. child_pointers = A[oidx_c].NODE_POINTERS;
  645. */
  646. /*
  647. size_t rec_gp_key = mpc_reconstruct(tio, yield, A[oidx_gp].NODE_KEY);
  648. size_t rec_p_key = mpc_reconstruct(tio, yield, A[oidx_p].NODE_KEY);
  649. size_t rec_c_key = mpc_reconstruct(tio, yield, A[oidx_c].NODE_KEY);
  650. size_t rec_gp_lptr = mpc_reconstruct(tio, iyield, getAVLLeftPtr(A[oidx_gp].NODE_POINTERS), AVL_PTR_SIZE);
  651. size_t rec_gp_rptr = mpc_reconstruct(tio, yield, getAVLRightPtr(A[oidx_gp].NODE_POINTERS), AVL_PTR_SIZE);
  652. size_t rec_p_lptr = mpc_reconstruct(tio, yield, getAVLLeftPtr(A[oidx_p].NODE_POINTERS), AVL_PTR_SIZE);
  653. size_t rec_p_rptr = mpc_reconstruct(tio, yield, getAVLRightPtr(A[oidx_p].NODE_POINTERS), AVL_PTR_SIZE);
  654. size_t rec_c_lptr = mpc_reconstruct(tio, yield, getAVLLeftPtr(A[oidx_c].NODE_POINTERS), AVL_PTR_SIZE);
  655. size_t rec_c_rptr = mpc_reconstruct(tio, yield, getAVLRightPtr(A[oidx_c].NODE_POINTERS), AVL_PTR_SIZE);
  656. printf("Reconstructed:\ngp_key = %ld, gp_left_ptr = %ld, gp_right_ptr = %ld\n",
  657. rec_gp_key, rec_gp_lptr, rec_gp_rptr);
  658. printf("p_key = %ld, p_left_ptr = %ld, p_right_ptr = %ld\n",
  659. rec_p_key, rec_p_lptr, rec_p_rptr);
  660. printf("c_key = %ld, c_left_ptr = %ld, c_right_ptr = %ld\n",
  661. rec_c_key, rec_c_lptr, rec_c_rptr);
  662. */
  663. } else {
  664. gp_pointers = A[ret.gp_node].NODE_POINTERS;
  665. parent_pointers = A[ret.p_node].NODE_POINTERS;
  666. child_pointers = A[ret.c_node].NODE_POINTERS;
  667. }
  668. // n_node (child's next node)
  669. RegXS child_left = getAVLLeftPtr(child_pointers);
  670. RegXS child_right = getAVLRightPtr(child_pointers);
  671. RegXS n_node, n_pointers;
  672. mpc_select(tio, yield, n_node, ret.dir_cn, child_left, child_right, AVL_PTR_SIZE);
  673. std::optional <Duoram<Node>::template OblivIndex<RegXS,1>> oidx_n;
  674. if(OPTIMIZED) {
  675. oidx_n.emplace(tio, yield, n_node, width);
  676. n_pointers = A[oidx_n.value()].NODE_POINTERS;
  677. } else {
  678. n_pointers = A[n_node].NODE_POINTERS;
  679. }
  680. RegXS old_gp_pointers, old_parent_pointers, old_child_pointers, old_n_pointers;
  681. if(OPTIMIZED) {
  682. old_gp_pointers = gp_pointers;
  683. old_parent_pointers = parent_pointers;
  684. old_child_pointers = child_pointers;
  685. old_n_pointers = n_pointers;
  686. }
  687. // F_dr = (dir_pc != dir_cn) : i.e., double rotation case if
  688. // (parent->child) and (child->new_node) are not in the same direction
  689. RegBS F_dr = (ret.dir_pc) ^ (ret.dir_cn);
  690. /* Flags: F_cn_rot = child->node rotate
  691. F_ur = update root.
  692. In case of an imbalance we have to always rotate p->c link. (L or R case)
  693. In case of an imbalance where p->c and c->n are in different directions, we have
  694. to perform a double rotation (LR or RL case). In such cases, first rotate
  695. c->n link, and then p->c link
  696. (Note: in the second rotation c is actually n, since the the first rotation
  697. swaps their positions)
  698. */
  699. RegBS F_cn_rot, F_ur, s0;
  700. run_coroutines(tio, [&tio, &F_ur, F_gp, ret](yield_t &yield)
  701. {mpc_and(tio, yield, F_ur, F_gp, ret.imbalance);},
  702. [&tio, &F_cn_rot, ret, F_dr](yield_t &yield)
  703. {mpc_and(tio, yield, F_cn_rot, ret.imbalance, F_dr);});
  704. // Get the n children information for 2nd rotate fix before rotations happen.
  705. RegBS n_bal_l, n_bal_r;
  706. RegXS n_l = getAVLLeftPtr(n_pointers);
  707. RegXS n_r = getAVLRightPtr(n_pointers);
  708. n_bal_l = getLeftBal(n_pointers);
  709. n_bal_r = getRightBal(n_pointers);
  710. // First rotation: c->n link
  711. rotate(tio, yield, parent_pointers, ret.c_node, child_pointers, n_node,
  712. n_pointers, ret.dir_pc, ret.dir_cn, F_cn_rot, s0);
  713. // If F_cn_rot, i.e. we did first rotation. Then c and n need to swap before the second rotate.
  714. RegXS new_child_pointers, new_child;
  715. run_coroutines(tio, [&tio, &new_child_pointers, F_cn_rot, child_pointers, n_pointers] (yield_t &yield)
  716. {mpc_select(tio, yield, new_child_pointers, F_cn_rot, child_pointers, n_pointers);},
  717. [&tio, &new_child, F_cn_rot, ret, n_node](yield_t &yield)
  718. {mpc_select(tio, yield, new_child, F_cn_rot, ret.c_node, n_node, AVL_PTR_SIZE);});
  719. // Second rotation: p->c link
  720. rotate(tio, yield, gp_pointers, ret.p_node, parent_pointers, new_child,
  721. new_child_pointers, ret.dir_gpp, ret.dir_pc, ret.imbalance, F_gp);
  722. // Set parent and child balances to 0 if there was an imbalance.
  723. // parent balances are already set to 0 from updateBalanceIns
  724. RegBS temp_bal, p_bal_l, p_bal_r, p_bal_ndpc;
  725. RegBS c_bal_l, c_bal_r, c_bal_dpc, n_bal_dpc, n_bal_ndpc;
  726. p_bal_l = getLeftBal(parent_pointers);
  727. p_bal_r = getRightBal(parent_pointers);
  728. run_coroutines(tio, [&tio, &child_pointers, F_cn_rot, new_child_pointers] (yield_t &yield)
  729. {mpc_select(tio, yield, child_pointers, F_cn_rot, new_child_pointers, child_pointers);},
  730. [&tio, &n_pointers, F_cn_rot, new_child_pointers] (yield_t &yield)
  731. {mpc_select(tio, yield, n_pointers, F_cn_rot, n_pointers, new_child_pointers);});
  732. c_bal_l = getLeftBal(child_pointers);
  733. c_bal_r = getRightBal(child_pointers);
  734. run_coroutines(tio, [&tio, &c_bal_l, ret, s0] (yield_t &yield)
  735. {mpc_select(tio, yield, c_bal_l, ret.imbalance, c_bal_l, s0);},
  736. [&tio, &c_bal_r, ret, s0] (yield_t &yield)
  737. {mpc_select(tio, yield, c_bal_r, ret.imbalance, c_bal_r, s0);});
  738. /* In the double rotation case: balance of c and p have a tweak
  739. p_bal_ndpc <- !(n_bal_ndpc)
  740. c_bal_dpc <- !(n_bal_dpc) */
  741. size_t &aes_ops = tio.aes_ops();
  742. RegBS n_l0, n_r0;
  743. run_coroutines(tio, [&tio, &n_l0, n_l, &aes_ops] (yield_t &yield)
  744. { CDPF cdpf = tio.cdpf(yield);
  745. n_l0 = cdpf.is_zero(tio, yield, n_l, aes_ops);},
  746. [&tio, &n_r0, n_r, &aes_ops] (yield_t &yield)
  747. { CDPF cdpf = tio.cdpf(yield);
  748. n_r0 = cdpf.is_zero(tio, yield, n_r, aes_ops);});
  749. RegBS p_c_update, n_has_children;
  750. // n_has_children = !(n_l0 & n_r0)
  751. mpc_and(tio, yield, n_has_children, n_l0, n_r0);
  752. if(player0) {
  753. n_has_children^=1;
  754. }
  755. run_coroutines(tio, [&tio, &p_c_update, F_cn_rot, n_has_children] (yield_t &yield)
  756. {mpc_and(tio, yield, p_c_update, F_cn_rot, n_has_children);},
  757. [&tio, &n_bal_ndpc, ret, n_bal_l, n_bal_r] (yield_t &yield)
  758. {mpc_select(tio, yield, n_bal_ndpc, ret.dir_pc, n_bal_r, n_bal_l);},
  759. [&tio, &n_bal_dpc, ret, n_bal_l, n_bal_r] (yield_t &yield)
  760. {mpc_select(tio, yield, n_bal_dpc, ret.dir_pc, n_bal_l, n_bal_r);},
  761. [&tio, &p_bal_ndpc, ret, p_bal_r, p_bal_l] (yield_t &yield)
  762. {mpc_select(tio, yield, p_bal_ndpc, ret.dir_pc, p_bal_r, p_bal_l);});
  763. // !n_bal_ndpc, !n_bal_dpc
  764. if(player0) {
  765. n_bal_ndpc^=1;
  766. n_bal_dpc^=1;
  767. }
  768. run_coroutines(tio, [&tio, &p_bal_ndpc, p_c_update, n_bal_ndpc] (yield_t &yield)
  769. {mpc_select(tio, yield, p_bal_ndpc, p_c_update, p_bal_ndpc, n_bal_ndpc);},
  770. [&tio, &c_bal_dpc, p_c_update, n_bal_dpc] (yield_t &yield)
  771. {mpc_select(tio, yield, c_bal_dpc, p_c_update, c_bal_dpc, n_bal_dpc);});
  772. coroutines.emplace_back([&tio, &p_bal_r, ret, p_bal_ndpc] (yield_t &yield)
  773. {mpc_select(tio, yield, p_bal_r, ret.dir_pc, p_bal_ndpc, p_bal_r);});
  774. coroutines.emplace_back([&tio, &p_bal_l, ret, p_bal_ndpc] (yield_t &yield)
  775. {mpc_select(tio, yield, p_bal_l, ret.dir_pc, p_bal_l, p_bal_ndpc);});
  776. coroutines.emplace_back([&tio, &c_bal_r, ret, c_bal_dpc] (yield_t &yield)
  777. {mpc_select(tio, yield, c_bal_r, ret.dir_pc, c_bal_r, c_bal_dpc);});
  778. coroutines.emplace_back([&tio, &c_bal_l, ret, c_bal_dpc] (yield_t &yield)
  779. {mpc_select(tio, yield, c_bal_l, ret.dir_pc, c_bal_dpc, c_bal_l);});
  780. // If double rotation (LR/RL) case, n ends up with 0 balance.
  781. // In all other cases, n's balance remains unaffected by rotation during insertion.
  782. coroutines.emplace_back([&tio, &n_bal_l, F_cn_rot, s0] (yield_t &yield)
  783. {mpc_select(tio, yield, n_bal_l, F_cn_rot, n_bal_l, s0);});
  784. coroutines.emplace_back([&tio, &n_bal_r, F_cn_rot, s0] (yield_t &yield)
  785. {mpc_select(tio, yield, n_bal_r, F_cn_rot, n_bal_r, s0);});
  786. run_coroutines(tio, coroutines);
  787. setLeftBal(parent_pointers, p_bal_l);
  788. setRightBal(parent_pointers, p_bal_r);
  789. setLeftBal(child_pointers, c_bal_l);
  790. setRightBal(child_pointers, c_bal_r);
  791. setLeftBal(n_pointers, n_bal_l);
  792. setRightBal(n_pointers, n_bal_r);
  793. // Write back update pointers and balances into gp, p, c, and n
  794. if(OPTIMIZED) {
  795. run_coroutines(tio,
  796. [&tio, &A, &oidx_n, n_pointers, old_n_pointers]
  797. (yield_t &yield) {
  798. auto Acont = A.context(yield);
  799. Acont[oidx_n.value()].NODE_POINTERS+=(n_pointers - old_n_pointers);
  800. },
  801. [&tio, &A, &oidx_c, child_pointers, old_child_pointers]
  802. (yield_t &yield) {
  803. auto Acont = A.context(yield);
  804. Acont[oidx_c.value()].NODE_POINTERS+=(child_pointers - old_child_pointers);
  805. },
  806. [&tio, &A, &oidx_p, parent_pointers, old_parent_pointers]
  807. (yield_t &yield) {
  808. auto Acont = A.context(yield);
  809. Acont[oidx_p.value()].NODE_POINTERS+=(parent_pointers - old_parent_pointers);
  810. },
  811. [&tio, &A, &oidx_gp, gp_pointers, old_gp_pointers]
  812. (yield_t &yield) {
  813. auto Acont = A.context(yield);
  814. Acont[oidx_gp.value()].NODE_POINTERS+=(gp_pointers - old_gp_pointers);
  815. });
  816. } else {
  817. A[ret.c_node].NODE_POINTERS = child_pointers;
  818. A[ret.p_node].NODE_POINTERS = parent_pointers;
  819. A[ret.gp_node].NODE_POINTERS = gp_pointers;
  820. A[n_node].NODE_POINTERS = n_pointers;
  821. }
  822. // Handle root pointer update (if F_ur is true)
  823. // If F_ur and we did a double rotation: root <-- new node
  824. // If F_ur and we did a single rotation: root <-- child node
  825. RegXS temp_root = root;
  826. run_coroutines(tio, [&tio, &temp_root, F_ur, ret] (yield_t &yield)
  827. {mpc_select(tio, yield, temp_root, F_ur, temp_root, ret.c_node, AVL_PTR_SIZE);},
  828. [&tio, &F_ur, F_gp, F_dr] (yield_t &yield)
  829. {mpc_and(tio, yield, F_ur, F_gp, F_dr);});
  830. mpc_select(tio, yield, temp_root, F_ur, temp_root, n_node, AVL_PTR_SIZE);
  831. root = temp_root;
  832. }
  833. }
  834. bool AVL::lookup(MPCTIO &tio, yield_t &yield, RegXS ptr, RegAS key, Duoram<Node>::Flat &A,
  835. int TTL, RegBS isDummy, Node *ret_node) {
  836. if(TTL==0) {
  837. // Reconstruct and return isDummy
  838. // If we found the key, then isDummy will be true
  839. bool found = reconstruct_RegBS(tio, yield, isDummy);
  840. return found;
  841. }
  842. RegBS isNotDummy = isDummy ^ (!tio.player());
  843. Node cnode = A[ptr];
  844. // Compare key
  845. CDPF cdpf = tio.cdpf(yield);
  846. auto [lt, eq, gt] = cdpf.compare(tio, yield, key - cnode.key, tio.aes_ops());
  847. // Depending on [lteq, gt] select the next ptr/index as
  848. // first AVL_PTR_SIZE bits of cnode.pointers if lteq
  849. // next AVL_PTR_SIZE bits of cnode.pointers if gt
  850. // (the last 2 bits are balance bits)
  851. RegXS left = getAVLLeftPtr(cnode.pointers);
  852. RegXS right = getAVLRightPtr(cnode.pointers);
  853. RegXS next_ptr;
  854. mpc_select(tio, yield, next_ptr, gt, left, right, 32);
  855. RegBS F_found;
  856. // If we haven't found the key yet, and the lookup matches the current node key,
  857. // then we found the node to return
  858. // If multiple keys in the tree match the lookup key, this returns the last match.
  859. // Extracting the first match would add an extra round here, since the
  860. // F_found flag will have to be computed first, then the next two based on F_found
  861. // instead of eq
  862. run_coroutines(tio,
  863. [&tio, &F_found, isNotDummy, eq](yield_t &yield)
  864. { mpc_and(tio, yield, F_found, isNotDummy, eq);},
  865. [&tio, &ret_node, F_found, &cnode](yield_t &yield)
  866. { mpc_select(tio, yield, ret_node->key, F_found, ret_node->key, cnode.key);},
  867. [&tio, &ret_node, F_found, &cnode](yield_t &yield)
  868. { mpc_select(tio, yield, ret_node->value, F_found, ret_node->value, cnode.value);});
  869. isDummy^=F_found;
  870. bool found = lookup(tio, yield, next_ptr, key, A, TTL-1, isDummy, ret_node);
  871. return found;
  872. }
  873. bool AVL::lookup(MPCTIO &tio, yield_t &yield, RegAS key, Node *ret_node) {
  874. auto A = oram.flat(tio, yield);
  875. RegBS isDummy;
  876. bool found = lookup(tio, yield, root, key, A, num_items, isDummy, ret_node);
  877. return found;
  878. }
  879. void AVL::updateChildPointers(MPCTIO &tio, yield_t &yield, RegXS &left, RegXS &right,
  880. RegBS c_prime, const avl_del_return &ret_struct) {
  881. bool player0 = tio.player()==0;
  882. RegBS F_rr; // Flag to resolve F_r by updating right child ptr
  883. RegBS F_rl; // Flag to resolve F_r by updating left child ptr
  884. RegBS nt_c_prime = c_prime;
  885. if(player0) {
  886. nt_c_prime^=1;
  887. }
  888. run_coroutines(tio, [&tio, &F_rr, c_prime, ret_struct](yield_t &yield)
  889. { mpc_and(tio, yield, F_rr, c_prime, ret_struct.F_r);},
  890. [&tio, &F_rl, nt_c_prime, ret_struct](yield_t &yield)
  891. { mpc_and(tio, yield, F_rl, nt_c_prime, ret_struct.F_r);});
  892. run_coroutines(tio, [&tio, &right, F_rr, ret_struct](yield_t &yield)
  893. { mpc_select(tio, yield, right, F_rr, right, ret_struct.ret_ptr);},
  894. [&tio, &left, F_rl, ret_struct](yield_t &yield)
  895. { mpc_select(tio, yield, left, F_rl, left, ret_struct.ret_ptr);});
  896. }
  897. // Perform rotations if imbalance (else dummy rotations)
  898. /*
  899. For capturing both the symmetric L and R cases of rotations, we'll capture directions with
  900. dpc = dir_pc = direction from parent to child, and
  901. ndpc = not(dir_pc)
  902. When we travelled down the stack, we went from p->c. But in deletions to handle any imbalance
  903. we look at c's sibling cs (child's sibling). And the rotation is between p and cs if there
  904. was an imbalance at p, and perhaps even cs and its child (the child in dir_pc, as that's the
  905. only case that results in a double rotation when deleting).
  906. In case of an imbalance we have to always rotate p->cs link. (L or R case)
  907. If cs.bal_(dir_pc), then we have a double rotation (LR or RL) case.
  908. In such cases, first rotate cs->gcs link, and then p->cs link. gcs = grandchild on cs path
  909. Layout: In the R (or LR) case:
910.          p
911.         / \
912.       cs   c
913.      /  \
914.     a   gcs
915.         /  \
916.        x    y
  917. - One of x or y must exist for it to be an LR case,
  918. since then cs.bal_(dir_pc) = cs.bal_r = 1
  919. Layout: In the L (or RL) case:
920.          p
921.         / \
922.        c   cs
923.           /  \
924.         gcs   a
925.         /  \
926.        x    y
  927. - One of x or y must exist for it to be an RL case,
  928. since then cs.bal_(dir_pc) = cs.bal_l = 1
929. (Note: in the double rotation case, cs is actually gcs during the second rotation,
930. since the first rotation swaps their positions.)
  931. */
  932. void AVL::fixImbalance(MPCTIO &tio, yield_t &yield, Duoram<Node>::Flat &A,
  933. Duoram<Node>::OblivIndex<RegXS, 1> oidx, RegXS oidx_oldptrs, RegXS ptr,
  934. RegXS nodeptrs, RegBS new_p_bal_l, RegBS new_p_bal_r, RegBS &bal_upd,
  935. RegBS c_prime, RegXS cs_ptr, RegBS imb, RegBS &F_ri,
  936. avl_del_return &ret_struct) {
  937. bool player0 = tio.player()==0;
  938. RegBS s0, s1;
  939. s1.set(tio.player()==1);
  940. Node cs_node, gcs_node;
  941. std::optional<Duoram<Node>::OblivIndex<RegXS,1>> oidx_cs;
  942. RegXS old_cs_ptr, old_gcs_ptr;
  943. nbits_t width = ceil(log2(cur_max_index+1));
  944. if(OPTIMIZED) {
  945. oidx_cs.emplace(tio, yield, cs_ptr, width);
  946. cs_node = A[oidx_cs.value()];
  947. old_cs_ptr = cs_node.pointers;
  948. } else {
  949. cs_node = A[cs_ptr];
  950. }
  951. //dirpc = dir_pc = dpc = c_prime
  952. RegBS cs_bal_l, cs_bal_r, cs_bal_dpc, cs_bal_ndpc, p_bal_ndpc, p_bal_dpc;
  953. RegBS F_dr, not_c_prime;
  954. RegXS gcs_ptr, cs_left, cs_right, cs_dpc, cs_ndpc, null;
  955. not_c_prime = c_prime;
  956. if(player0) {
  957. not_c_prime^=1;
  958. }
  959. // child's sibling node's balances in dir_pc (dpc), and not_dir_pc (ndpc)
  960. cs_bal_l = getLeftBal(cs_node.pointers);
  961. cs_bal_r = getRightBal(cs_node.pointers);
  962. cs_left = getAVLLeftPtr(cs_node.pointers);
  963. cs_right = getAVLRightPtr(cs_node.pointers);
  964. std::vector<coro_t> coroutines;
  965. RegBS gcs_balanced, gcs_bal_dpc, gcs_bal_ndpc;
  966. RegBS ndpc_is_l, ndpc_is_r, dpc_is_l, dpc_is_r;
  967. // First flags to check dpc = L/R, and similarly ndpc = L/R
  968. // If it's not an imbalance all of these are zeroes, resulting in no updates
  969. // to the pointers and balances in the end when we write back post imbalance
  970. // fix pointers and balances.
  971. coroutines.emplace_back([&tio, &ndpc_is_l, c_prime, imb] (yield_t &yield)
  972. { mpc_and(tio, yield, ndpc_is_l, imb, c_prime);});
  973. coroutines.emplace_back([&tio, &ndpc_is_r, imb, not_c_prime](yield_t &yield)
  974. { mpc_and(tio, yield, ndpc_is_r, imb, not_c_prime);});
  975. coroutines.emplace_back([&tio, &dpc_is_l, imb, not_c_prime](yield_t &yield)
  976. { mpc_and(tio, yield, dpc_is_l, imb, not_c_prime);});
  977. coroutines.emplace_back([&tio, &dpc_is_r, imb, c_prime](yield_t &yield)
  978. { mpc_and(tio, yield, dpc_is_r, imb, c_prime);});
  979. run_coroutines(tio, coroutines);
  980. coroutines.clear();
  981. coroutines.emplace_back(
  982. [&tio, &cs_bal_dpc, dpc_is_r, cs_bal_l, cs_bal_r] (yield_t &yield)
  983. { mpc_select(tio, yield, cs_bal_dpc, dpc_is_r, cs_bal_l, cs_bal_r);});
  984. coroutines.emplace_back(
  985. [&tio, &cs_bal_ndpc, ndpc_is_l, cs_bal_r, cs_bal_l](yield_t &yield)
  986. { mpc_select(tio, yield, cs_bal_ndpc, ndpc_is_l, cs_bal_r, cs_bal_l);});
  987. coroutines.emplace_back(
  988. [&tio, &cs_dpc, dpc_is_r, cs_left, cs_right](yield_t &yield)
  989. { mpc_select(tio, yield, cs_dpc, dpc_is_r, cs_left, cs_right);});
  990. coroutines.emplace_back(
  991. [&tio, &cs_ndpc, ndpc_is_l, cs_right, cs_left](yield_t &yield)
  992. { mpc_select(tio, yield, cs_ndpc, ndpc_is_l, cs_right, cs_left);});
  993. coroutines.emplace_back(
  994. [&tio, &p_bal_ndpc, ndpc_is_r, new_p_bal_l, new_p_bal_r](yield_t &yield)
  995. { mpc_select(tio, yield, p_bal_ndpc, ndpc_is_r, new_p_bal_l, new_p_bal_r);});
  996. coroutines.emplace_back(
  997. [&tio, &p_bal_dpc, dpc_is_r, new_p_bal_l, new_p_bal_r] (yield_t &yield)
  998. { mpc_select(tio, yield, p_bal_dpc, dpc_is_r, new_p_bal_l, new_p_bal_r);});
  999. run_coroutines(tio, coroutines);
  1000. coroutines.clear();
  1001. // We need to double rotate (LR or RL case) if cs_bal_dpc is 1
  1002. run_coroutines(tio, [&tio, &F_dr, imb, cs_bal_dpc] (yield_t &yield)
  1003. { mpc_and(tio, yield, F_dr, imb, cs_bal_dpc);},
  1004. [&tio, &gcs_ptr, cs_bal_dpc, cs_ndpc, cs_dpc](yield_t &yield)
  1005. { mpc_select(tio, yield, gcs_ptr, cs_bal_dpc, cs_ndpc, cs_dpc, AVL_PTR_SIZE);});
  1006. std::optional<Duoram<Node>::template OblivIndex<RegXS,1>> oidx_gcs;
  1007. if(OPTIMIZED) {
  1008. oidx_gcs.emplace(tio, yield, gcs_ptr, width);
  1009. gcs_node = A[oidx_gcs.value()];
  1010. old_gcs_ptr = gcs_node.pointers;
  1011. } else {
  1012. gcs_node = A[gcs_ptr];
  1013. }
  1014. RegBS gcs_bal_l = getLeftBal(gcs_node.pointers);
  1015. RegBS gcs_bal_r = getRightBal(gcs_node.pointers);
  1016. run_coroutines(tio, [&tio, &gcs_bal_dpc, dpc_is_r, gcs_bal_l, gcs_bal_r](yield_t &yield)
  1017. { mpc_select(tio, yield, gcs_bal_dpc, dpc_is_r, gcs_bal_l, gcs_bal_r);},
  1018. [&tio, &gcs_bal_ndpc, ndpc_is_r, gcs_bal_l, gcs_bal_r](yield_t &yield)
  1019. { mpc_select(tio, yield, gcs_bal_ndpc, ndpc_is_r, gcs_bal_l, gcs_bal_r);});
  1020. // First rotation: cs->gcs link
  1021. rotate(tio, yield, nodeptrs, cs_ptr, cs_node.pointers, gcs_ptr,
  1022. gcs_node.pointers, not_c_prime, c_prime, F_dr, s0);
  1023. // If F_dr, we did first rotation. Then cs and gcs need to swap before the second rotate.
  1024. RegXS new_cs_pointers, new_cs, new_ptr;
  1025. run_coroutines(tio, [&tio, &new_cs_pointers, F_dr, cs_node, gcs_node](yield_t &yield)
  1026. { mpc_select(tio, yield, new_cs_pointers, F_dr, cs_node.pointers, gcs_node.pointers);},
  1027. [&tio, &new_cs, F_dr, cs_ptr, gcs_ptr](yield_t &yield)
  1028. { mpc_select(tio, yield, new_cs, F_dr, cs_ptr, gcs_ptr, AVL_PTR_SIZE);},
  1029. [&tio, &new_ptr, F_dr, cs_ptr, gcs_ptr](yield_t &yield)
  1030. { mpc_select(tio, yield, new_ptr, F_dr, cs_ptr, gcs_ptr);});
  1031. // Second rotation: p->cs link
  1032. // Since we don't have access to gp node here we just send a null and s0
  1033. // for gp_pointers and dir_gpp. Instead this pointer fix is handled by F_r
  1034. // and ret_struct.ret_ptr.
  1035. rotate(tio, yield, null, ptr, nodeptrs, new_cs,
  1036. new_cs_pointers, s0, not_c_prime, imb, s1);
  1037. // If imb (we do some rotation), then update F_r, and ret_ptr, to
  1038. // fix the gp->p link (There are F_r clauses later, but they are mutually
  1039. // exclusive events. They will never trigger together.)
  1040. F_ri = imb;
  1041. coroutines.emplace_back([&tio, &ret_struct, imb, new_ptr](yield_t &yield) {
  1042. mpc_select(tio, yield, ret_struct.ret_ptr, imb, ret_struct.ret_ptr, new_ptr);
  1043. });
  1044. // Write back new_cs_pointers correctly to (cs_node/gcs_node).pointers
  1045. // and then balance the nodes
  1046. coroutines.emplace_back([&tio, &cs_node, F_dr, new_cs_pointers](yield_t &yield) {
  1047. mpc_select(tio, yield, cs_node.pointers, F_dr, new_cs_pointers, cs_node.pointers);
  1048. });
  1049. coroutines.emplace_back([&tio, &gcs_node, F_dr, new_cs_pointers](yield_t &yield) {
  1050. mpc_select(tio, yield, gcs_node.pointers, F_dr, gcs_node.pointers, new_cs_pointers);
  1051. });
  1052. run_coroutines(tio, coroutines);
  1053. coroutines.clear();
  1054. /*
  1055. Update balances based on imbalance and type of rotations that happen.
  1056. In the case of an imbalance, updateBalance() sets bal_l and bal_r of p to 0.
  1057. */
  1058. RegBS IC1, IC2, IC3; // Imbalance Case 1, 2 or 3
  1059. RegBS cs_zero_bal = cs_bal_dpc ^ cs_bal_ndpc;
  1060. if(player0) {
  1061. cs_zero_bal^=1;
  1062. }
  1063. run_coroutines(tio, [&tio, &IC1, imb, cs_bal_ndpc] (yield_t &yield) {
  1064. // IC1 = Single rotation (L/R). L/R = dpc
  1065. mpc_and(tio, yield, IC1, imb, cs_bal_ndpc);
  1066. },
  1067. // IC2 = Single rotation (L/R). L/R = dpc
  1068. [&tio, &IC2, imb, cs_zero_bal](yield_t &yield) {
  1069. mpc_and(tio, yield, IC2, imb, cs_zero_bal);
  1070. },
  1071. [&tio, &IC3, imb, cs_bal_dpc](yield_t &yield) {
  1072. // IC3 = Double rotation (LR/RL). 1st rotate direction = ndpc, 2nd direction = dpc
  1073. mpc_and(tio, yield, IC3, imb, cs_bal_dpc);
  1074. });
  1075. /* IC3 has 3 subcases:
  1076. IC3_S1: gcs_bal_dpc = 0, gcs_bal_ndpc = 1
  1077. IC3_S2: gcs_bal_dpc = 1, gc_bal_ndpc = 0
  1078. IC3_S3: gcs_bal_dpc = 0, gcs_bal_ndpc = 0
  1079. IC3_S1: p_dpc <- 1
  1080. cs_dpc <- 0
  1081. (gcs_bal stays same)
  1082. IC3_S2: Swap cs_dpc and cs_ndpc (1 0 -> - 1).
  1083. cs_dpc <- 0, cs_ndpc <- 1
  1084. gcs_bal_dpc <- 0
  1085. IC3_S3: cs_dpc <- 0
  1086. gcs_bal stays same
  1087. */
  1088. RegBS IC3_S1, IC3_S2, IC3_S3;
  1089. gcs_balanced = gcs_bal_dpc ^ gcs_bal_ndpc;
  1090. if(player0) {
  1091. gcs_balanced^=1;
  1092. }
  1093. // Updating balance bits of p, cs, and gcs.
  1094. // Parallel Ops 1
  1095. coroutines.emplace_back([&tio, &cs_bal_ndpc, IC1, s0](yield_t &yield)
  1096. { mpc_select(tio, yield, cs_bal_ndpc, IC1, cs_bal_ndpc, s0);});
  1097. coroutines.emplace_back([&tio, &cs_bal_dpc, IC2, s1](yield_t &yield)
  1098. { mpc_select(tio, yield, cs_bal_dpc, IC2, cs_bal_dpc, s1);});
  1099. coroutines.emplace_back([&tio, &p_bal_ndpc, IC2, s1](yield_t &yield)
  1100. { mpc_select(tio, yield, p_bal_ndpc, IC2, p_bal_ndpc, s1);});
  1101. coroutines.emplace_back([&tio, &IC3_S1, IC3, gcs_bal_ndpc](yield_t &yield)
  1102. { mpc_and(tio, yield, IC3_S1, IC3, gcs_bal_ndpc);});
  1103. coroutines.emplace_back([&tio, &IC3_S2, IC3, gcs_bal_dpc](yield_t &yield)
  1104. { mpc_and(tio, yield, IC3_S2, IC3, gcs_bal_dpc);});
  1105. coroutines.emplace_back([&tio, &IC3_S3, IC3, gcs_balanced](yield_t &yield)
  1106. { mpc_and(tio, yield, IC3_S3, IC3, gcs_balanced);});
  1107. // In the IC2 case bal_upd = 0 (The rotation doesn't end up
  1108. // decreasing height of this subtree.)
  1109. coroutines.emplace_back([&tio, &bal_upd, IC2, s0](yield_t &yield)
  1110. { mpc_select(tio, yield, bal_upd, IC2, bal_upd, s0);});
  1111. run_coroutines(tio, coroutines);
  1112. coroutines.clear();
  1113. // Parallel Ops 2
  1114. coroutines.emplace_back([&tio, &cs_bal_dpc, IC3, s0](yield_t &yield)
  1115. { mpc_select(tio, yield, cs_bal_dpc, IC3, cs_bal_dpc, s0);});
  1116. coroutines.emplace_back([&tio, &p_bal_dpc, IC3_S1, s1](yield_t &yield)
  1117. { mpc_select(tio, yield, p_bal_dpc, IC3_S1, p_bal_dpc, s1);});
  1118. coroutines.emplace_back([&tio, &cs_bal_ndpc, IC3_S2, s1](yield_t &yield)
  1119. { mpc_select(tio, yield, cs_bal_ndpc, IC3_S2, cs_bal_ndpc, s1);});
  1120. coroutines.emplace_back([&tio, &gcs_bal_dpc, IC3_S2, s0](yield_t &yield)
  1121. { mpc_select(tio, yield, gcs_bal_dpc, IC3_S2, gcs_bal_dpc, s0);});
  1122. run_coroutines(tio, coroutines);
  1123. coroutines.clear();
  1124. // Write back updated balance bits (Parallel batch 1)
  1125. // Updating gcs_bal_l/r
  1126. coroutines.emplace_back([&tio, &gcs_bal_r, dpc_is_r, gcs_bal_dpc](yield_t &yield)
  1127. { mpc_select(tio, yield, gcs_bal_r, dpc_is_r, gcs_bal_r, gcs_bal_dpc);});
  1128. coroutines.emplace_back([&tio, &gcs_bal_l, dpc_is_l, gcs_bal_dpc](yield_t &yield)
  1129. { mpc_select(tio, yield, gcs_bal_l, dpc_is_l, gcs_bal_l, gcs_bal_dpc);});
  1130. // Updating cs_bal_l/r (cs_bal_dpc effected by IC3, cs_bal_ndpc effected by IC1,2)
  1131. coroutines.emplace_back([&tio, &cs_bal_r, dpc_is_r, cs_bal_dpc](yield_t &yield)
  1132. { mpc_select(tio, yield, cs_bal_r, dpc_is_r, cs_bal_r, cs_bal_dpc);});
  1133. coroutines.emplace_back([&tio, &cs_bal_l, dpc_is_l, cs_bal_dpc](yield_t &yield)
  1134. { mpc_select(tio, yield, cs_bal_l, dpc_is_l, cs_bal_l, cs_bal_dpc);});
  1135. // Updating new_p_bal_l/r (p_bal_ndpc effected by IC2)
  1136. coroutines.emplace_back([&tio, &new_p_bal_r, ndpc_is_r, p_bal_ndpc] (yield_t &yield)
  1137. { mpc_select(tio, yield, new_p_bal_r, ndpc_is_r, new_p_bal_r, p_bal_ndpc);});
  1138. coroutines.emplace_back([&tio, &new_p_bal_l, ndpc_is_l, p_bal_ndpc](yield_t &yield)
  1139. { mpc_select(tio, yield, new_p_bal_l, ndpc_is_l, new_p_bal_l, p_bal_ndpc);});
  1140. run_coroutines(tio, coroutines);
  1141. coroutines.clear();
  1142. // Write back updated balance bits (Parallel batch 2)
  1143. coroutines.emplace_back([&tio, &cs_bal_r, ndpc_is_r, cs_bal_ndpc] (yield_t &yield)
  1144. { mpc_select(tio, yield, cs_bal_r, ndpc_is_r, cs_bal_r, cs_bal_ndpc);});
  1145. coroutines.emplace_back([&tio, &cs_bal_l, ndpc_is_l, cs_bal_ndpc](yield_t &yield)
  1146. { mpc_select(tio, yield, cs_bal_l, ndpc_is_l, cs_bal_l, cs_bal_ndpc);});
  1147. run_coroutines(tio, coroutines);
  1148. coroutines.clear();
  1149. // Write back <cs_bal_dpc, cs_bal_ndpc> and <gcs_bal_l, gcs_bal_r>
  1150. setLeftBal(gcs_node.pointers, gcs_bal_l);
  1151. setRightBal(gcs_node.pointers, gcs_bal_r);
  1152. setLeftBal(cs_node.pointers, cs_bal_l);
  1153. setRightBal(cs_node.pointers, cs_bal_r);
  1154. setLeftBal(nodeptrs, new_p_bal_l);
  1155. setRightBal(nodeptrs, new_p_bal_r);
  1156. // Write back updated pointers correctly accounting for rotations
  1157. if(OPTIMIZED) {
  1158. coroutines.emplace_back(
  1159. [&tio, &A, &oidx_cs, &cs_node, old_cs_ptr] (yield_t &yield) {
  1160. auto acont = A.context(yield);
  1161. (acont[oidx_cs.value()].NODE_POINTERS)+= (cs_node.pointers - old_cs_ptr);});
  1162. coroutines.emplace_back(
  1163. [&tio, &A, &oidx_gcs, &gcs_node, old_gcs_ptr] (yield_t &yield) {
  1164. auto acont = A.context(yield);
  1165. (acont[oidx_gcs.value()].NODE_POINTERS)+= (gcs_node.pointers - old_gcs_ptr);});
  1166. coroutines.emplace_back(
  1167. [&tio, &A, &oidx, nodeptrs, oidx_oldptrs] (yield_t &yield) {
  1168. auto acont = A.context(yield);
  1169. (acont[oidx].NODE_POINTERS)+=(nodeptrs - oidx_oldptrs);});
  1170. run_coroutines(tio, coroutines);
  1171. coroutines.clear();
  1172. } else {
  1173. A[cs_ptr].NODE_POINTERS = cs_node.pointers;
  1174. A[gcs_ptr].NODE_POINTERS = gcs_node.pointers;
  1175. A[ptr].NODE_POINTERS = nodeptrs;
  1176. }
  1177. }
  1178. /*
  1179. Update the return structure
  1180. F_dh = Delete Here flag,
  1181. F_sf = successor found (no more left children while trying to find successor)
  1182. F_r = Flag for updating with ret_struct.ret_ptr. F_r happens in 3 cases.
  1183. It's subflag F_rs, handles cases (i) and (ii).
  1184. F_rs = Subflag of F_r. F_rs indicates if we need to update a child pointer
  1185. at this level by skipping the current child in the direction of
  1186. traversal. We do this in two cases (i) and (ii).
  1187. F_r cases:
  1188. (i) F_d & (!F_2) : If we delete here, and this node does not have
  1189. 2 children (;i.e., we are not in the finding successor case)
  1190. (ii) F_sf: Found the successor (no more left children while
  1191. traversing to find successor)
  1192. In cases i and ii we skip the next node, and make the current node
  1193. point to the node after the next node on the path.
  1194. (iii) We did rotation(s) at the lower level, changing the child in
  1195. that position. So we update it to the correct node in that
  1196. position now.
  1197. Whether skip happens or just update happens is handled by F_r and
  1198. the ret_struct.ret_ptr that is set.
  1199. */
  1200. void AVL::updateRetStruct(MPCTIO &tio, yield_t &yield, RegXS ptr, RegBS F_rs, RegBS F_dh,
  1201. RegBS F_ri, RegBS &bal_upd, avl_del_return &ret_struct) {
  1202. bool player0 = tio.player()==0;
  1203. RegBS s0, s1;
  1204. s1.set(tio.player()==1);
  1205. // F_rs and F_ri will never trigger together. So the line below
  1206. // set ret_ptr to the correct pointer to handle either case
  1207. // If neither F_rs nor F_ri, we set the ret_ptr to current ptr.
  1208. RegBS F_nr;
  1209. mpc_or(tio, yield, F_nr, F_rs, F_ri);
  1210. // F_nr = F_rs || F_ri
  1211. ret_struct.F_r = F_nr;
  1212. if(player0) {
  1213. F_nr^=1;
  1214. }
  1215. // F_nr = !(F_rs || F_ri)
  1216. run_coroutines(tio, [&tio, &ret_struct, F_nr, ptr](yield_t &yield)
  1217. { mpc_select(tio, yield, ret_struct.ret_ptr, F_nr, ret_struct.ret_ptr, ptr);},
  1218. [&tio, &bal_upd, F_rs, s1](yield_t &yield)
  1219. { // If F_rs, we skipped a node, so update bal_upd to 1
  1220. mpc_select(tio, yield, bal_upd, F_rs, bal_upd, s1);});
  1221. }
  1222. std::tuple<bool, RegBS> AVL::del(MPCTIO &tio, yield_t &yield, RegXS ptr, RegAS del_key,
  1223. Duoram<Node>::Flat &A, RegBS found, RegBS find_successor, int TTL,
  1224. avl_del_return &ret_struct) {
  1225. bool player0 = tio.player()==0;
  1226. if(TTL==0) {
  1227. //Reconstruct and return found
  1228. bool success = reconstruct_RegBS(tio, yield, found);
  1229. RegBS zero;
  1230. return {success, zero};
  1231. } else {
  1232. Node node;
  1233. RegXS oldptrs;
  1234. // This OblivIndex creation is not required if we are not running optimized version,
  1235. // but for convenience we leave it in, so that fixImbalance has an oidx to be supplied
  1236. // when we are in the non-optimized setting.
  1237. nbits_t width = ceil(log2(cur_max_index+1));
  1238. typename Duoram<Node>::template OblivIndex<RegXS,1> oidx(tio, yield, ptr, width);
  1239. if(OPTIMIZED) {
  1240. node = A[oidx];
  1241. oldptrs = node.pointers;
  1242. } else {
  1243. node = A[ptr];
  1244. }
  1245. RegXS left = getAVLLeftPtr(node.pointers);
  1246. RegXS right = getAVLRightPtr(node.pointers);
  1247. size_t &aes_ops = tio.aes_ops();
  1248. RegBS l0, r0, lt, eq, gt;
  1249. // Check if left and right children are 0
  1250. // l0: Is left child 0
  1251. // r0: Is right child 0
  1252. run_coroutines(tio, [&tio, &l0, left, &aes_ops](yield_t &yield)
  1253. { CDPF cdpf = tio.cdpf(yield);
  1254. l0 = cdpf.is_zero(tio, yield, left, aes_ops);},
  1255. [&tio, &r0, right, &aes_ops](yield_t &yield)
  1256. { CDPF cdpf = tio.cdpf(yield);
  1257. r0 = cdpf.is_zero(tio, yield, right, aes_ops);},
  1258. // Compare key
  1259. [&tio, &lt, &eq, &gt, del_key, node, aes_ops](yield_t &yield)
  1260. { CDPF cdpf = tio.cdpf(yield);
  1261. auto [a, b, c] = cdpf.compare(tio, yield, del_key - node.key, tio.aes_ops());
  1262. lt = a; eq = b; gt = c;});
  1263. // c is the direction bit for next_ptr
  1264. // (c=0: go left or c=1: go right)
  1265. RegBS c = gt;
  1266. // lf = local found. We found the key to delete in this level.
  1267. RegBS lf = eq;
  1268. // F_{X}: Flags that indicate the number of children this node has
  1269. // F_0: no children, F_1: one child, F_2: both children
  1270. // F_n2: either F_0 or F_1
  1271. RegBS F_0, F_1, F_2, F_n2;
  1272. RegBS F_c1, F_c2, F_c3, F_c4, c_prime, F_dh, F_rs;
  1273. RegXS next_ptr, cs_ptr;
  1274. RegBS not_found = found;
  1275. if(player0) {
  1276. not_found^=1;
  1277. }
  1278. // F_1 = l0 \xor r0
  1279. F_1 = l0 ^ r0;
  1280. // F_0 = l0 & r0
  1281. // Case 1: lf & F_1
  1282. run_coroutines(tio, [&tio, &F_0, l0, r0](yield_t &yield)
  1283. { mpc_and(tio, yield, F_0, l0, r0);},
  1284. [&tio, &F_c1, lf, F_1](yield_t &yield)
  1285. { mpc_and(tio, yield, F_c1, lf, F_1);},
  1286. // Premptively computing flags for updateRetStruct in parallel
  1287. // with above operations.
  1288. [&tio, &F_dh, not_found, lf](yield_t &yield)
  1289. { mpc_and(tio, yield, F_dh, not_found, lf);});
  1290. // F_2 = !(F_0 ^ F_1) (Exactly 1 of F_0, F_1, and F_2 is true)
  1291. F_n2 = F_0 ^ F_1;
  1292. F_2 = F_n2;
  1293. if(player0) {
  1294. F_2^=1;
  1295. }
  1296. // s1: shares of 1 bit, s0: shares of 0 bit
  1297. RegBS s1, s0;
  1298. s1.set(tio.player()==1);
  1299. // We set next ptr based on c, but we need to handle three
  1300. // edge cases where we do not pick next_ptr by just the comparison result
  1301. // Case 1: found the node here (lf), and node has only one child.
  1302. // Then we iterate down the only child.
  1303. // Set c_prime for Case 1
  1304. run_coroutines(tio, [&tio, &c_prime, F_c1, c, l0](yield_t &yield)
  1305. { mpc_select(tio, yield, c_prime, F_c1, c, l0);},
  1306. [&tio, &F_c2, lf, F_2](yield_t &yield)
  1307. { mpc_and(tio, yield, F_c2, lf, F_2);},
  1308. // Premptively computing flags for updateRetStruct in parallel
  1309. // with above operations.
  1310. // If we have to i) delete here, and it doesn't have two children
  1311. // we have to update child pointer in parent with the returned pointer
  1312. [&tio, &F_rs, F_dh, F_n2](yield_t &yield)
  1313. { mpc_and(tio, yield, F_rs, F_dh, F_n2);});
  1314. // Case 2: found the node here (lf) and node has both children (F_2)
  1315. // In find successor case, so we find inorder successor for node to be deleted
  1316. // (inorder successor = go right and then find leftmost child.)
  1317. // Case 3: finding successor (find_successor) and node has both children (F_2)
  1318. // Go left.
  1319. run_coroutines(tio, [&tio, &c_prime, F_c2, s1](yield_t &yield)
  1320. { mpc_select(tio, yield, c_prime, F_c2, c_prime, s1);},
  1321. [&tio, &F_c3, find_successor, F_2](yield_t &yield)
  1322. { mpc_and(tio, yield, F_c3, find_successor, F_2);});
  1323. // Case 4: finding successor (find_successor) and node has no more left children (l0)
  1324. // This is the successor node then.
  1325. // Go right (since no more left)
  1326. run_coroutines(tio, [&tio, &c_prime, F_c3, s0](yield_t &yield)
  1327. { mpc_select(tio, yield, c_prime, F_c3, c_prime, s0);},
  1328. [&tio, &F_c4, find_successor, l0](yield_t &yield)
  1329. { mpc_and(tio, yield, F_c4, find_successor, l0);},
  1330. // Premptively computing flags for updateRetStruct in parallel
  1331. // with above operations.
  1332. [&tio, &ret_struct, F_c2](yield_t &yield)
  1333. { mpc_or(tio, yield, ret_struct.F_ss, ret_struct.F_ss, F_c2);},
  1334. [&tio, &ret_struct, F_dh, ptr](yield_t &yield)
  1335. { mpc_select(tio, yield, ret_struct.N_d, F_dh, ret_struct.N_d, ptr);});
  1336. RegBS found_prime, find_successor_prime;
  1337. // F_sf = Flag for successor found.
  1338. RegBS F_sf = F_c4;
  1339. run_coroutines(tio, [&tio, &c_prime, F_c4, l0](yield_t &yield)
  1340. { mpc_select(tio, yield, c_prime, F_c4, c_prime, l0);},
  1341. [&tio, &F_rs, F_sf](yield_t &yield)
  1342. { mpc_or(tio, yield, F_rs, F_rs, F_sf);},
  1343. [&tio, &ret_struct, F_sf, ptr](yield_t &yield)
  1344. { mpc_select(tio, yield, ret_struct.N_s, F_sf, ret_struct.N_s, ptr);});
  1345. // Set next_ptr
  1346. mpc_select(tio, yield, next_ptr, c_prime, left, right, AVL_PTR_SIZE);
  1347. // cs_ptr: child's sibling pointer
  1348. run_coroutines(tio, [&tio, &cs_ptr, c_prime, right, left](yield_t &yield)
  1349. { mpc_select(tio, yield, cs_ptr, c_prime, right, left, AVL_PTR_SIZE);},
  1350. [&tio, &found_prime, found, lf](yield_t &yield)
  1351. { mpc_or(tio, yield, found_prime, found, lf);},
  1352. // If in Case 2, set find_successor. We are now finding successor
  1353. [&tio, &find_successor_prime, find_successor, F_c2](yield_t &yield)
  1354. { mpc_or(tio, yield, find_successor_prime, find_successor, F_c2);});
  1355. // If in Case 4. Successor found here already. Toggle find_successor off
  1356. find_successor_prime=find_successor_prime^F_c4;
  1357. TTL-=1;
  1358. auto [key_found, bal_upd] = del(tio, yield, next_ptr, del_key, A, found_prime, find_successor_prime, TTL, ret_struct);
  1359. // If we didn't find the key, we can end here.
  1360. if(!key_found) {
  1361. return {false, s0};
  1362. }
  1363. updateChildPointers(tio, yield, left, right, c_prime, ret_struct);
  1364. setAVLLeftPtr(node.pointers, left);
  1365. setAVLRightPtr(node.pointers, right);
  1366. // Delay storing pointers back until balance updates are done as well.
  1367. // Since we resolved the F_r flag returned with updateChildPointers(),
  1368. // we set it back to 0.
  1369. ret_struct.F_r = s0;
  1370. RegBS p_bal_l, p_bal_r;
  1371. p_bal_l = getLeftBal(node.pointers);
  1372. p_bal_r = getRightBal(node.pointers);
  1373. #ifdef AVL_DEBUG
  1374. size_t rec_key = mpc_reconstruct(tio, yield, node.key);
  1375. bool rec_bal_upd = mpc_reconstruct(tio, yield, bal_upd);
  1376. printf("current_key = %ld, bal_upd (before updateBalanceDel) = %d\n", rec_key, rec_bal_upd);
  1377. #endif
  1378. auto [new_p_bal_l, new_p_bal_r, new_bal_upd, imb] =
  1379. updateBalanceDel(tio, yield, p_bal_l, p_bal_r, bal_upd, c_prime);
  1380. bal_upd = new_bal_upd;
  1381. #ifdef AVL_DEBUG
  1382. bool rec_imb = mpc_reconstruct(tio, yield, imb);
  1383. bool rec_new_bal_upd = mpc_reconstruct(tio, yield, new_bal_upd);
  1384. printf("new_bal_upd (after updateBalanceDel) = %d, imb = %d\n", rec_new_bal_upd, rec_imb);
  1385. #endif
  1386. // F_ri: subflag for F_r. F_ri = returned flag set to 1 from imbalance fix.
  1387. RegBS F_ri;
  1388. fixImbalance(tio, yield, A, oidx, oldptrs, ptr, node.pointers, new_p_bal_l, new_p_bal_r, bal_upd,
  1389. c_prime, cs_ptr, imb, F_ri, ret_struct);
  1390. #ifdef AVL_DEBUG
  1391. rec_imb = mpc_reconstruct(tio, yield, imb);
  1392. rec_bal_upd = mpc_reconstruct(tio, yield, bal_upd);
  1393. printf("imb (after fixImbalance) = %d, bal_upd = %d\n", rec_imb, rec_bal_upd);
  1394. #endif
  1395. updateRetStruct(tio, yield, ptr, F_rs, F_dh, F_ri, bal_upd, ret_struct);
  1396. #ifdef AVL_DEBUG
  1397. rec_bal_upd = mpc_reconstruct(tio, yield, bal_upd);
  1398. printf("bal_upd (after updateRetStruct) = %d\n", rec_bal_upd);
  1399. #endif
  1400. return {key_found, bal_upd};
  1401. }
  1402. }
  1403. /*
  1404. The main AVL delete function.
  1405. Trying to delete an item that does not exist in the tree will result in
  1406. an explicit (non-oblivious) failure.
  1407. */
  1408. bool AVL::del(MPCTIO &tio, yield_t &yield, RegAS del_key) {
  1409. if(num_items==0) {
  1410. return false;
  1411. }
  1412. auto A = oram.flat(tio, yield, 0, cur_max_index+1);
  1413. if(num_items==1) {
  1414. //Delete root if root's key = del_key
  1415. Node zero;
  1416. nbits_t width = ceil(log2(cur_max_index+1));
  1417. typename Duoram<Node>::template OblivIndex<RegXS,1> oidx(tio, yield, root, width);
  1418. Node node = A[oidx];
  1419. // Compare key
  1420. CDPF cdpf = tio.cdpf(yield);
  1421. auto [lt, eq, gt] = cdpf.compare(tio, yield, del_key - node.key, tio.aes_ops());
  1422. bool success = reconstruct_RegBS(tio, yield, eq);
  1423. if(success) {
  1424. empty_locations.emplace_back(root);
  1425. A[oidx] = zero;
  1426. num_items--;
  1427. return true;
  1428. } else {
  1429. return false;
  1430. }
  1431. } else {
  1432. int TTL = AVL_TTL(num_items);
  1433. // Flags for already found (found) item to delete and find successor (find_successor)
  1434. // if this deletion requires a successor swap
  1435. RegBS found, find_successor;
  1436. avl_del_return ret_struct;
  1437. auto [success, bal_upd] = del(tio, yield, root, del_key, A, found, find_successor, TTL, ret_struct);
  1438. //printf ("Success = %d\n", success);
  1439. if(!success){
  1440. return false;
  1441. }
  1442. else{
  1443. num_items--;
  1444. /*
  1445. printf("In delete's swap portion\n");
  1446. Node rec_del_node = A.reconstruct(A[ret_struct.N_d]);
  1447. Node rec_suc_node = A.reconstruct(A[ret_struct.N_s]);
  1448. printf("del_node key = %ld, suc_node key = %ld\n",
  1449. rec_del_node.key.ashare, rec_suc_node.key.ashare);
  1450. printf("flag_s = %d\n", ret_struct.F_ss.bshare);
  1451. */
  1452. Node del_node, suc_node;
  1453. nbits_t width = ceil(log2(cur_max_index+1));
  1454. std::optional<Duoram<Node>::template OblivIndex<RegXS,2>> oidx_nd;
  1455. std::optional<Duoram<Node>::template OblivIndex<RegXS,2>> oidx_ns;
  1456. std::vector<coro_t> coroutines;
  1457. if(OPTIMIZED) {
  1458. oidx_nd.emplace(tio, yield, ret_struct.N_d, width);
  1459. oidx_ns.emplace(tio, yield, ret_struct.N_s, width);
  1460. coroutines.emplace_back(
  1461. [&tio, &A, &oidx_nd, &del_node](yield_t &yield) {
  1462. auto acont = A.context(yield);
  1463. del_node = acont[oidx_nd.value()];});
  1464. coroutines.emplace_back(
  1465. [&tio, &A, &oidx_ns, &suc_node](yield_t &yield) {
  1466. auto acont = A.context(yield);
  1467. suc_node = acont[oidx_ns.value()];});
  1468. run_coroutines(tio, coroutines);
  1469. coroutines.clear();
  1470. } else{
  1471. del_node = A[ret_struct.N_d];
  1472. suc_node = A[ret_struct.N_s];
  1473. }
  1474. RegAS zero_as; RegXS zero_xs;
  1475. // Update root if needed
  1476. mpc_select(tio, yield, root, ret_struct.F_r, root, ret_struct.ret_ptr);
  1477. /*
  1478. bool rec_F_ss = mpc_reconstruct(tio, yield, ret_struct.F_ss);
  1479. size_t rec_del_key = mpc_reconstruct(tio, yield, del_node.key);
  1480. size_t rec_suc_key = mpc_reconstruct(tio, yield, suc_node.key);
  1481. printf("rec_F_ss = %d, del_node.key = %lu, suc_nod.key = %lu\n",
  1482. rec_F_ss, rec_del_key, rec_suc_key);
  1483. */
  1484. RegXS old_del_value;
  1485. RegAS old_del_key;
  1486. RegXS empty_loc;
  1487. if(OPTIMIZED) {
  1488. old_del_value = del_node.value;
  1489. old_del_key = del_node.key;
  1490. }
  1491. run_coroutines(tio, [&tio, &del_node, ret_struct, suc_node](yield_t &yield)
  1492. { mpc_select(tio, yield, del_node.key, ret_struct.F_ss, del_node.key, suc_node.key);},
  1493. [&tio, &del_node, ret_struct, suc_node] (yield_t &yield)
  1494. { mpc_select(tio, yield, del_node.value, ret_struct.F_ss, del_node.value, suc_node.value);},
  1495. [&tio, &empty_loc, ret_struct](yield_t &yield)
  1496. { mpc_select(tio, yield, empty_loc, ret_struct.F_ss, ret_struct.N_d, ret_struct.N_s);});
  1497. if(OPTIMIZED) {
  1498. coroutines.emplace_back(
  1499. [&tio, &A, &oidx_nd, &del_node, old_del_key] (yield_t &yield) {
  1500. auto acont = A.context(yield);
  1501. acont[oidx_nd.value()].NODE_KEY+=(del_node.key - old_del_key);
  1502. });
  1503. coroutines.emplace_back(
  1504. [&tio, &A, &oidx_nd, &del_node, old_del_value] (yield_t &yield) {
  1505. auto acont = A.context(yield);
  1506. acont[oidx_nd.value()].NODE_VALUE+=(del_node.value - old_del_value);
  1507. });
  1508. coroutines.emplace_back(
  1509. [&tio, &A, &oidx_ns, &suc_node] (yield_t &yield) {
  1510. auto acont = A.context(yield);
  1511. acont[oidx_ns.value()].NODE_KEY+=(-suc_node.key);
  1512. });
  1513. coroutines.emplace_back(
  1514. [&tio, &A, &oidx_ns, &suc_node] (yield_t &yield) {
  1515. auto acont = A.context(yield);
  1516. acont[oidx_ns.value()].NODE_VALUE+=(-suc_node.value);
  1517. });
  1518. run_coroutines(tio, coroutines);
  1519. coroutines.clear();
  1520. } else {
  1521. A[ret_struct.N_d].NODE_KEY = del_node.key;
  1522. A[ret_struct.N_d].NODE_VALUE = del_node.value;
  1523. A[ret_struct.N_s].NODE_KEY = zero_as;
  1524. A[ret_struct.N_s].NODE_VALUE = zero_xs;
  1525. }
  1526. //Add deleted (empty) location into the empty_locations vector for reuse in next insert()
  1527. empty_locations.emplace_back(empty_loc);
  1528. }
  1529. return true;
  1530. }
  1531. }
  1532. /*
  1533. Initializes a complete tree of size 2^depth
  1534. */
  1535. void AVL::initialize(MPCTIO &tio, yield_t &yield, size_t depth) {
  1536. size_t init_size = (size_t(1)<<depth) - 1;
  1537. auto A = oram.flat(tio, yield);
  1538. A.explicitonly(true);
  1539. for(size_t i=1; i<=depth; i++) {
  1540. size_t start = size_t(1)<<(i-1);
  1541. size_t gap = size_t(1)<<i;
  1542. size_t current = start;
  1543. for(size_t j=1; j<=(size_t(1)<<(depth-i)); j++) {
  1544. //printf("current = %ld ", current);
  1545. Node node;
  1546. node.key.set(current * tio.player());
  1547. if(i!=1) {
  1548. //Set left and right child pointers and balance bits
  1549. size_t ptr_gap = start/2;
  1550. RegXS lptr, rptr;
  1551. lptr.set(tio.player() * (current-(ptr_gap)));
  1552. rptr.set(tio.player() * (current+(ptr_gap)));
  1553. setAVLLeftPtr(node.pointers, lptr);
  1554. setAVLRightPtr(node.pointers, rptr);
  1555. }
  1556. A[current] = node;
  1557. current+=gap;
  1558. }
  1559. }
  1560. A.explicitonly(false);
  1561. // Set num_items to init_size after they have been initialized;
  1562. num_items = init_size;
  1563. cur_max_index = num_items;
  1564. // Set root correctly
  1565. root.set(tio.player() * size_t(1)<<(depth-1));
  1566. }
  1567. // Now we use the AVL class in various ways. This function is called by
  1568. // online.cpp.
  1569. void avl(MPCIO &mpcio,
  1570. const PRACOptions &opts, char **args)
  1571. {
  1572. int nargs = 0;
  1573. while(args[nargs]!=nullptr) {
  1574. ++nargs;
  1575. }
  1576. int depth = 0; // Initialization depth
  1577. size_t n_inserts = 0; // Max ORAM_SIZE = 2^depth + n_inserts
  1578. size_t n_deletes = 0;
  1579. bool run_sanity = 0;
  1580. bool optimized = false;
  1581. // Process command line arguments
  1582. for (int i = 0; i < nargs; i += 2) {
  1583. std::string option = args[i];
  1584. if (option == "-m" && i + 1 < nargs) {
  1585. depth = std::atoi(args[i + 1]);
  1586. } else if (option == "-i" && i + 1 < nargs) {
  1587. n_inserts = std::atoi(args[i + 1]);
  1588. } else if (option == "-e" && i + 1 < nargs) {
  1589. n_deletes = std::atoi(args[i + 1]);
  1590. } else if (option == "-opt" && i + 1 < nargs) {
  1591. optimized = std::atoi(args[i + 1]);
  1592. } else if (option == "-s" && i + 1 < nargs) {
  1593. run_sanity = std::atoi(args[i + 1]);
  1594. }
  1595. }
  1596. /* The ORAM will be initialized with 2^depth-1 items, but the 0 slot is reserved.
  1597. So we initialize (initial inserts) with 2^depth-2 items.
  1598. The ORAM size is set to 2^depth-1 + n_insert.
  1599. */
  1600. size_t init_size = (size_t(1)<<(depth));
  1601. size_t oram_size = init_size + n_inserts;
  1602. MPCTIO tio(mpcio, 0, opts.num_cpu_threads);
  1603. run_coroutines(tio, [&tio, &mpcio, depth, oram_size, init_size, n_inserts, n_deletes, run_sanity, optimized] (yield_t &yield) {
  1604. //printf("ORAM init_size = %ld, oram_size = %ld\n", init_size, oram_size);
  1605. std::cout << "\n===== SETUP =====\n";
  1606. AVL tree(tio.player(), oram_size, optimized);
  1607. tree.initialize(tio, yield, depth);
  1608. //tree.pretty_print(tio, yield);
  1609. tio.sync_lamport();
  1610. Node node;
  1611. mpcio.dump_stats(std::cout);
  1612. std::cout << "\n===== INSERTS =====\n";
  1613. mpcio.reset_stats();
  1614. tio.reset_lamport();
  1615. for(size_t i = 1; i<=n_inserts; i++) {
  1616. randomize_node(node);
  1617. size_t ikey;
  1618. #ifdef AVL_RANDOMIZE_INSERTS
  1619. ikey = (1+(rand()%oram_size));
  1620. #else
  1621. ikey = (i+init_size);
  1622. #endif
  1623. printf("Insert key = %ld\n", ikey);
  1624. node.key.set(ikey * tio.player());
  1625. tree.insert(tio, yield, node);
  1626. if(run_sanity) {
  1627. tree.pretty_print(tio, yield);
  1628. if(tio.player()==0) {
  1629. assert(tree.check_avl(tio, yield));
  1630. } else {
  1631. tree.check_avl(tio, yield);
  1632. }
  1633. }
  1634. //tree.print_oram(tio, yield);
  1635. }
  1636. tio.sync_lamport();
  1637. mpcio.dump_stats(std::cout);
  1638. std::cout << "\n===== DELETES =====\n";
  1639. mpcio.reset_stats();
  1640. tio.reset_lamport();
  1641. for(size_t i = 1; i<=n_deletes; i++) {
  1642. RegAS del_key;
  1643. size_t dkey;
  1644. #ifdef AVL_RANDOMIZE_INSERTS
  1645. dkey = 1 + (rand()%init_size);
  1646. #else
  1647. dkey = i + 0;
  1648. #endif
  1649. del_key.set(dkey * tio.player());
  1650. printf("Deletion key = %ld\n", dkey);
  1651. tree.del(tio, yield, del_key);
  1652. if(run_sanity) {
  1653. tree.pretty_print(tio, yield);
  1654. if(tio.player()==0) {
  1655. assert(tree.check_avl(tio, yield));
  1656. } else {
  1657. tree.check_avl(tio, yield);
  1658. }
  1659. }
  1660. }
  1661. });
  1662. }
  1663. /*
  1664. AVL tests by default run the optimized AVL tree protocols.
  1665. */
  1666. void avl_tests(MPCIO &mpcio,
  1667. const PRACOptions &opts, char **args)
  1668. {
  1669. // Not taking arguments for tests
  1670. nbits_t depth=4;
  1671. size_t items = (size_t(1)<<depth)-1;
  1672. MPCTIO tio(mpcio, 0, opts.num_cpu_threads);
  1673. run_coroutines(tio, [&tio, depth, items] (yield_t &yield) {
  1674. size_t size = size_t(1)<<depth;
  1675. bool player0 = tio.player()==0;
  1676. AVL tree(tio.player(), size);
  1677. // (T1) : Test 1 : L rotation (root modified)
  1678. /*
  1679. Operation:
  1680. 5 7
  1681. \ / \
  1682. 7 ---> 5 9
  1683. \
  1684. 9
  1685. T1 checks:
  1686. - root is 7
  1687. - 5,7,9 in correct positions
  1688. - 5 and 9 have no children and 0 balances
  1689. */
  1690. {
  1691. bool success = true;
  1692. int insert_array[] = {5, 7, 9};
  1693. size_t insert_array_size = sizeof(insert_array)/sizeof(int);
  1694. Node node;
  1695. for(size_t i = 0; i<insert_array_size; i++) {
  1696. randomize_node(node);
  1697. node.key.set(insert_array[i] * tio.player());
  1698. tree.insert(tio, yield, node);
  1699. success &= tree.check_avl(tio, yield);
  1700. }
  1701. Duoram<Node>* oram = tree.get_oram();
  1702. RegXS root_xs = tree.get_root();
  1703. size_t root = mpc_reconstruct(tio, yield, root_xs);
  1704. auto A = oram->flat(tio, yield);
  1705. auto R = A.reconstruct();
  1706. Node root_node, left_node, right_node;
  1707. size_t left_index, right_index;
  1708. root_node = R[root];
  1709. if((root_node.key).share()!=7) {
  1710. success = false;
  1711. }
  1712. left_index = (getAVLLeftPtr(root_node.pointers)).share();
  1713. right_index = (getAVLRightPtr(root_node.pointers)).share();
  1714. left_node = R[left_index];
  1715. right_node = R[right_index];
  1716. if(left_node.key.share()!=5 || right_node.key.share()!=9) {
  1717. success = false;
  1718. }
  1719. //To check that left and right have no children and 0 balances
  1720. size_t sum = left_node.pointers.share() + right_node.pointers.share();
  1721. if(sum!=0) {
  1722. success = false;
  1723. }
  1724. if(player0) {
  1725. if(success) {
  1726. print_green("T1 : SUCCESS\n");
  1727. } else {
  1728. print_red("T1 : FAIL\n");
  1729. }
  1730. }
  1731. A.init();
  1732. tree.init();
  1733. }
  1734. // (T2) : Test 2 : L rotation (root unmodified)
  1735. /*
  1736. Operation:
  1737. 5 5
  1738. / \ / \
  1739. 3 7 3 9
  1740. \ ---> / \
  1741. 9 7 7 12
  1742. \
  1743. 12
  1744. T2 checks:
  1745. - root is 5
  1746. - 3, 7, 9, 12 in expected positions
  1747. - Nodes 3, 7, 12 have 0 balance and no children
  1748. - 5's bal = 0 1
  1749. */
  1750. {
  1751. bool success = true;
  1752. int insert_array[] = {5, 3, 7, 9, 12};
  1753. size_t insert_array_size = sizeof(insert_array)/sizeof(int);
  1754. Node node;
  1755. for(size_t i = 0; i<insert_array_size; i++) {
  1756. randomize_node(node);
  1757. node.key.set(insert_array[i] * tio.player());
  1758. tree.insert(tio, yield, node);
  1759. success &= tree.check_avl(tio, yield);
  1760. }
  1761. Duoram<Node>* oram = tree.get_oram();
  1762. RegXS root_xs = tree.get_root();
  1763. size_t root = mpc_reconstruct(tio, yield, root_xs);
  1764. auto A = oram->flat(tio, yield);
  1765. auto R = A.reconstruct();
  1766. Node root_node, n3, n7, n9, n12;
  1767. size_t n3_index, n7_index, n9_index, n12_index;
  1768. root_node = R[root];
  1769. if((root_node.key).share()!=5) {
  1770. success = false;
  1771. }
  1772. n3_index = (getAVLLeftPtr(root_node.pointers)).share();
  1773. n9_index = (getAVLRightPtr(root_node.pointers)).share();
  1774. n3 = R[n3_index];
  1775. n9 = R[n9_index];
  1776. n7_index = getAVLLeftPtr(n9.pointers).share();
  1777. n12_index = getAVLRightPtr(n9.pointers).share();
  1778. n7 = R[n7_index];
  1779. n12 = R[n12_index];
  1780. // Node value checks
  1781. if(n3.key.share()!=3 || n9.key.share()!=9) {
  1782. success = false;
  1783. }
  1784. if(n7.key.share()!=7 || n12.key.share()!=12) {
  1785. success = false;
  1786. }
  1787. // Node children and balance checks
  1788. size_t zero = 0;
  1789. zero+=(n3.pointers.share());
  1790. zero+=(n7.pointers.share());
  1791. zero+=(n12.pointers.share());
  1792. zero+=(getLeftBal(root_node.pointers).share());
  1793. zero+=(getLeftBal(n9.pointers).share());
  1794. zero+=(getRightBal(n9.pointers).share());
  1795. if(zero!=0) {
  1796. success = false;
  1797. }
  1798. int one = (getRightBal(root_node.pointers).share());
  1799. if(one!=1) {
  1800. success = false;
  1801. }
  1802. if(player0) {
  1803. if(success) {
  1804. print_green("T2 : SUCCESS\n");
  1805. } else {
  1806. print_red("T2 : FAIL\n");
  1807. }
  1808. }
  1809. A.init();
  1810. tree.init();
  1811. }
  1812. // (T3) : Test 3 : R rotation (root modified)
  1813. /*
  1814. Operation:
  1815. 9 7
  1816. / / \
  1817. 7 ---> 5 9
  1818. /
  1819. 5
  1820. T3 checks:
  1821. - root is 7
  1822. - 5,7,9 in correct positions
  1823. - 5 and 9 have no children
  1824. */
  1825. {
  1826. bool success = true;
  1827. int insert_array[] = {9, 7, 5};
  1828. size_t insert_array_size = sizeof(insert_array)/sizeof(int);
  1829. Node node;
  1830. for(size_t i = 0; i<insert_array_size; i++) {
  1831. randomize_node(node);
  1832. node.key.set(insert_array[i] * tio.player());
  1833. tree.insert(tio, yield, node);
  1834. success &= tree.check_avl(tio, yield);
  1835. }
  1836. Duoram<Node>* oram = tree.get_oram();
  1837. RegXS root_xs = tree.get_root();
  1838. size_t root = mpc_reconstruct(tio, yield, root_xs);
  1839. auto A = oram->flat(tio, yield);
  1840. auto R = A.reconstruct();
  1841. Node root_node, left_node, right_node;
  1842. size_t left_index, right_index;
  1843. root_node = R[root];
  1844. if((root_node.key).share()!=7) {
  1845. success = false;
  1846. }
  1847. left_index = (getAVLLeftPtr(root_node.pointers)).share();
  1848. right_index = (getAVLRightPtr(root_node.pointers)).share();
  1849. left_node = R[left_index];
  1850. right_node = R[right_index];
  1851. if(left_node.key.share()!=5 || right_node.key.share()!=9) {
  1852. success = false;
  1853. }
  1854. //To check that left and right have no children and 0 balances
  1855. size_t sum = left_node.pointers.share() + right_node.pointers.share();
  1856. if(sum!=0) {
  1857. success = false;
  1858. }
  1859. if(player0) {
  1860. if(success) {
  1861. print_green("T3 : SUCCESS\n");
  1862. } else{
  1863. print_red("T3 : FAIL\n");
  1864. }
  1865. }
  1866. A.init();
  1867. tree.init();
  1868. }
  1869. // (T4) : Test 4 : R rotation (root unmodified)
  1870. /*
  1871. Operation:
  1872. 9 9
  1873. / \ / \
  1874. 7 12 5 12
  1875. / ---> / \
  1876. 5 7 3 7
  1877. /
  1878. 3
  1879. T4 checks:
  1880. - root is 9
  1881. - 3,5,7,12 are in correct positions
  1882. - Nodes 3,7,12 have 0 balance
  1883. - Nodes 3,7,12 have no children
  1884. - 9's bal = 1 0
  1885. */
  1886. {
  1887. bool success = true;
  1888. int insert_array[] = {9, 12, 7, 5, 3};
  1889. size_t insert_array_size = sizeof(insert_array)/sizeof(int);
  1890. Node node;
  1891. for(size_t i = 0; i<insert_array_size; i++) {
  1892. randomize_node(node);
  1893. node.key.set(insert_array[i] * tio.player());
  1894. tree.insert(tio, yield, node);
  1895. success &= tree.check_avl(tio, yield);
  1896. }
  1897. Duoram<Node>* oram = tree.get_oram();
  1898. RegXS root_xs = tree.get_root();
  1899. size_t root = mpc_reconstruct(tio, yield, root_xs);
  1900. auto A = oram->flat(tio, yield);
  1901. auto R = A.reconstruct();
  1902. Node root_node, n3, n7, n5, n12;
  1903. size_t n3_index, n7_index, n5_index, n12_index;
  1904. root_node = R[root];
  1905. if((root_node.key).share()!=9) {
  1906. success = false;
  1907. }
  1908. n5_index = (getAVLLeftPtr(root_node.pointers)).share();
  1909. n12_index = (getAVLRightPtr(root_node.pointers)).share();
  1910. n5 = R[n5_index];
  1911. n12 = R[n12_index];
  1912. n3_index = getAVLLeftPtr(n5.pointers).share();
  1913. n7_index = getAVLRightPtr(n5.pointers).share();
  1914. n7 = R[n7_index];
  1915. n3 = R[n3_index];
  1916. // Node value checks
  1917. if(n12.key.share()!=12 || n5.key.share()!=5) {
  1918. success = false;
  1919. }
  1920. if(n3.key.share()!=3 || n7.key.share()!=7) {
  1921. success = false;
  1922. }
  1923. // Node balance checks
  1924. size_t zero = 0;
  1925. zero+=(n3.pointers.share());
  1926. zero+=(n7.pointers.share());
  1927. zero+=(n12.pointers.share());
  1928. zero+=(getRightBal(root_node.pointers).share());
  1929. zero+=(getLeftBal(n5.pointers).share());
  1930. zero+=(getRightBal(n5.pointers).share());
  1931. if(zero!=0) {
  1932. success = false;
  1933. }
  1934. int one = (getLeftBal(root_node.pointers).share());
  1935. if(one!=1) {
  1936. success = false;
  1937. }
  1938. if(player0) {
  1939. if(success) {
  1940. print_green("T4 : SUCCESS\n");
  1941. } else {
  1942. print_red("T4 : FAIL\n");
  1943. }
  1944. }
  1945. A.init();
  1946. tree.init();
  1947. }
  1948. // (T5) : Test 5 : LR rotation (root modified)
  1949. /*
  1950. Operation:
  1951. 9 9 7
  1952. / / / \
  1953. 5 --> 7 --> 5 9
  1954. \ /
  1955. 7 5
  1956. T5 checks:
  1957. - root is 7
  1958. - 9,5,7 are in correct positions
  1959. - Nodes 5,7,9 have 0 balance
  1960. - Nodes 5,9 have no children
  1961. */
  1962. {
  1963. bool success = true;
  1964. int insert_array[] = {9, 5, 7};
  1965. size_t insert_array_size = sizeof(insert_array)/sizeof(int);
  1966. Node node;
  1967. for(size_t i = 0; i<insert_array_size; i++) {
  1968. randomize_node(node);
  1969. node.key.set(insert_array[i] * tio.player());
  1970. tree.insert(tio, yield, node);
  1971. success &= tree.check_avl(tio, yield);
  1972. }
  1973. Duoram<Node>* oram = tree.get_oram();
  1974. RegXS root_xs = tree.get_root();
  1975. size_t root = mpc_reconstruct(tio, yield, root_xs);
  1976. auto A = oram->flat(tio, yield);
  1977. auto R = A.reconstruct();
  1978. Node root_node, n9, n5;
  1979. size_t n9_index, n5_index;
  1980. root_node = R[root];
  1981. if((root_node.key).share()!=7) {
  1982. success = false;
  1983. }
  1984. n5_index = (getAVLLeftPtr(root_node.pointers)).share();
  1985. n9_index = (getAVLRightPtr(root_node.pointers)).share();
  1986. n5 = R[n5_index];
  1987. n9 = R[n9_index];
  1988. // Node value checks
  1989. if(n9.key.share()!=9 || n5.key.share()!=5) {
  1990. success = false;
  1991. }
  1992. // Node balance checks
  1993. size_t zero = 0;
  1994. zero+=(n5.pointers.share());
  1995. zero+=(n9.pointers.share());
  1996. zero+=(getRightBal(root_node.pointers).share());
  1997. zero+=(getLeftBal(n5.pointers).share());
  1998. zero+=(getRightBal(n5.pointers).share());
  1999. zero+=(getLeftBal(n5.pointers).share());
  2000. zero+=(getRightBal(n9.pointers).share());
  2001. zero+=(getLeftBal(n9.pointers).share());
  2002. if(zero!=0) {
  2003. success = false;
  2004. }
  2005. if(player0) {
  2006. if(success) {
  2007. print_green("T5 : SUCCESS\n");
  2008. } else {
  2009. print_red("T5 : FAIL\n");
  2010. }
  2011. }
  2012. A.init();
  2013. tree.init();
  2014. }
  2015. // (T6) : Test 6 : LR rotation (root unmodified)
  2016. /*
  2017. Operation:
  2018. 9 9 9
  2019. / \ / \ / \
  2020. 7 12 7 12 5 12
  2021. / ---> / ---> / \
  2022. 3 5 3 7
  2023. \ /
  2024. 5 3
  2025. T6 checks:
  2026. - root is 9
  2027. - 3,5,7,12 are in correct positions
  2028. - Nodes 3,7,12 have 0 balance
  2029. - Nodes 3,7,12 have no children
  2030. - 9's bal = 1 0
  2031. */
  2032. {
  2033. bool success = true;
  2034. int insert_array[] = {9, 12, 7, 3, 5};
  2035. size_t insert_array_size = sizeof(insert_array)/sizeof(int);
  2036. Node node;
  2037. for(size_t i = 0; i<insert_array_size; i++) {
  2038. randomize_node(node);
  2039. node.key.set(insert_array[i] * tio.player());
  2040. tree.insert(tio, yield, node);
  2041. success &= tree.check_avl(tio, yield);
  2042. }
  2043. Duoram<Node>* oram = tree.get_oram();
  2044. RegXS root_xs = tree.get_root();
  2045. size_t root = mpc_reconstruct(tio, yield, root_xs);
  2046. auto A = oram->flat(tio, yield);
  2047. auto R = A.reconstruct();
  2048. Node root_node, n3, n7, n5, n12;
  2049. size_t n3_index, n7_index, n5_index, n12_index;
  2050. root_node = R[root];
  2051. if((root_node.key).share()!=9) {
  2052. success = false;
  2053. }
  2054. n5_index = (getAVLLeftPtr(root_node.pointers)).share();
  2055. n12_index = (getAVLRightPtr(root_node.pointers)).share();
  2056. n5 = R[n5_index];
  2057. n12 = R[n12_index];
  2058. n3_index = getAVLLeftPtr(n5.pointers).share();
  2059. n7_index = getAVLRightPtr(n5.pointers).share();
  2060. n7 = R[n7_index];
  2061. n3 = R[n3_index];
  2062. // Node value checks
  2063. if(n5.key.share()!=5 || n12.key.share()!=12) {
  2064. success = false;
  2065. }
  2066. if(n3.key.share()!=3 || n7.key.share()!=7) {
  2067. success = false;
  2068. }
  2069. // Node balance checks
  2070. size_t zero = 0;
  2071. zero+=(n3.pointers.share());
  2072. zero+=(n7.pointers.share());
  2073. zero+=(n12.pointers.share());
  2074. zero+=(getRightBal(root_node.pointers).share());
  2075. zero+=(getLeftBal(n5.pointers).share());
  2076. zero+=(getRightBal(n5.pointers).share());
  2077. if(zero!=0) {
  2078. success = false;
  2079. }
  2080. int one = (getLeftBal(root_node.pointers).share());
  2081. if(one!=1) {
  2082. success = false;
  2083. }
  2084. if(player0) {
  2085. if(success) {
  2086. print_green("T6 : SUCCESS\n");
  2087. } else {
  2088. print_red("T6 : FAIL\n");
  2089. }
  2090. }
  2091. A.init();
  2092. tree.init();
  2093. }
  2094. // (T7) : Test 7 : RL rotation (root modified)
  2095. /*
  2096. Operation:
  2097. 5 5 7
  2098. \ \ / \
  2099. 9 --> 7 --> 5 9
  2100. / \
  2101. 7 9
  2102. T7 checks:
  2103. - root is 7
  2104. - 9,5,7 are in correct positions
  2105. - Nodes 5,7,9 have 0 balance
  2106. - Nodes 5,9 have no children
  2107. */
  2108. {
  2109. bool success = true;
  2110. int insert_array[] = {5, 9, 7};
  2111. size_t insert_array_size = sizeof(insert_array)/sizeof(int);
  2112. Node node;
  2113. for(size_t i = 0; i<insert_array_size; i++) {
  2114. randomize_node(node);
  2115. node.key.set(insert_array[i] * tio.player());
  2116. tree.insert(tio, yield, node);
  2117. success &= tree.check_avl(tio, yield);
  2118. }
  2119. Duoram<Node>* oram = tree.get_oram();
  2120. RegXS root_xs = tree.get_root();
  2121. size_t root = mpc_reconstruct(tio, yield, root_xs);
  2122. auto A = oram->flat(tio, yield);
  2123. auto R = A.reconstruct();
  2124. Node root_node, n9, n5;
  2125. size_t n9_index, n5_index;
  2126. root_node = R[root];
  2127. if((root_node.key).share()!=7) {
  2128. success = false;
  2129. }
  2130. n5_index = (getAVLLeftPtr(root_node.pointers)).share();
  2131. n9_index = (getAVLRightPtr(root_node.pointers)).share();
  2132. n5 = R[n5_index];
  2133. n9 = R[n9_index];
  2134. // Node value checks
  2135. if(n9.key.share()!=9 || n5.key.share()!=5) {
  2136. success = false;
  2137. }
  2138. // Node balance checks
  2139. size_t zero = 0;
  2140. zero+=(n5.pointers.share());
  2141. zero+=(n9.pointers.share());
  2142. zero+=(getRightBal(root_node.pointers).share());
  2143. zero+=(getLeftBal(n5.pointers).share());
  2144. zero+=(getRightBal(n5.pointers).share());
  2145. zero+=(getLeftBal(n5.pointers).share());
  2146. zero+=(getRightBal(n9.pointers).share());
  2147. zero+=(getLeftBal(n9.pointers).share());
  2148. if(zero!=0) {
  2149. success = false;
  2150. }
  2151. if(player0) {
  2152. if(success) {
  2153. print_green("T7 : SUCCESS\n");
  2154. } else {
  2155. print_red("T7 : FAIL\n");
  2156. }
  2157. }
  2158. A.init();
  2159. tree.init();
  2160. }
  2161. // (T8) : Test 8 : RL rotation (root unmodified)
  2162. /*
  2163. Operation:
  2164. 5 5 5
  2165. / \ / \ / \
  2166. 3 12 3 12 3 9
  2167. / ---> / ---> / \
  2168. 7 9 7 12
  2169. \ /
  2170. 9 7
  2171. T8 checks:
  2172. - root is 5
  2173. - 3,9,7,12 are in correct positions
  2174. - Nodes 3,7,12 have 0 balance
  2175. - Nodes 3,7,12 have no children
  2176. - 5's bal = 0 1
  2177. */
  2178. {
  2179. bool success = true;
  2180. int insert_array[] = {5, 3, 12, 7, 9};
  2181. size_t insert_array_size = sizeof(insert_array)/sizeof(int);
  2182. Node node;
  2183. for(size_t i = 0; i<insert_array_size; i++) {
  2184. randomize_node(node);
  2185. node.key.set(insert_array[i] * tio.player());
  2186. tree.insert(tio, yield, node);
  2187. success &= tree.check_avl(tio, yield);
  2188. }
  2189. Duoram<Node>* oram = tree.get_oram();
  2190. RegXS root_xs = tree.get_root();
  2191. size_t root = mpc_reconstruct(tio, yield, root_xs);
  2192. auto A = oram->flat(tio, yield);
  2193. auto R = A.reconstruct();
  2194. Node root_node, n3, n7, n9, n12;
  2195. size_t n3_index, n7_index, n9_index, n12_index;
  2196. root_node = R[root];
  2197. if((root_node.key).share()!=5) {
  2198. success = false;
  2199. }
  2200. n3_index = (getAVLLeftPtr(root_node.pointers)).share();
  2201. n9_index = (getAVLRightPtr(root_node.pointers)).share();
  2202. n3 = R[n3_index];
  2203. n9 = R[n9_index];
  2204. n7_index = getAVLLeftPtr(n9.pointers).share();
  2205. n12_index = getAVLRightPtr(n9.pointers).share();
  2206. n7 = R[n7_index];
  2207. n12 = R[n12_index];
  2208. // Node value checks
  2209. if(n3.key.share()!=3 || n9.key.share()!=9) {
  2210. success = false;
  2211. }
  2212. if(n7.key.share()!=7 || n12.key.share()!=12) {
  2213. success = false;
  2214. }
  2215. // Node balance checks
  2216. size_t zero = 0;
  2217. zero+=(n3.pointers.share());
  2218. zero+=(n7.pointers.share());
  2219. zero+=(n12.pointers.share());
  2220. zero+=(getLeftBal(root_node.pointers).share());
  2221. zero+=(getLeftBal(n9.pointers).share());
  2222. zero+=(getRightBal(n9.pointers).share());
  2223. if(zero!=0) {
  2224. success = false;
  2225. }
  2226. int one = (getRightBal(root_node.pointers).share());
  2227. if(one!=1) {
  2228. success = false;
  2229. }
  2230. if(player0) {
  2231. if(success) {
  2232. print_green("T8 : SUCCESS\n");
  2233. } else {
  2234. print_red("T8 : FAIL\n");
  2235. }
  2236. }
  2237. A.init();
  2238. tree.init();
  2239. }
  2240. // Deletion Tests:
  2241. // (T9) : Test 9 : L rotation (root modified)
  2242. /*
  2243. Operation:
  2244. 5 7
  2245. / \ Del 3 / \
  2246. 3 7 ------> 5 9
  2247. \
  2248. 9
  2249. T9 checks:
  2250. - root is 7
  2251. - 5,7,9 in correct positions
  2252. - 5 and 9 have no children and 0 balances
  2253. - 7 has 0 balances
  2254. */
  2255. {
  2256. bool success = true;
  2257. int insert_array[] = {5, 3, 7, 9};
  2258. size_t insert_array_size = sizeof(insert_array)/sizeof(int);
  2259. Node node;
  2260. for(size_t i = 0; i<insert_array_size; i++) {
  2261. randomize_node(node);
  2262. node.key.set(insert_array[i] * tio.player());
  2263. tree.insert(tio, yield, node);
  2264. success &= tree.check_avl(tio, yield);
  2265. }
  2266. RegAS del_key;
  2267. del_key.set(3 * tio.player());
  2268. bool del_ret;
  2269. del_ret = tree.del(tio, yield, del_key);
  2270. success &= tree.check_avl(tio, yield);
  2271. Duoram<Node>* oram = tree.get_oram();
  2272. RegXS root_xs = tree.get_root();
  2273. size_t root = mpc_reconstruct(tio, yield, root_xs);
  2274. auto A = oram->flat(tio, yield);
  2275. auto R = A.reconstruct();
  2276. Node root_node, left_node, right_node;
  2277. size_t left_index, right_index;
  2278. root_node = R[root];
  2279. if((root_node.key).share()!=7) {
  2280. success = false;
  2281. }
  2282. left_index = (getAVLLeftPtr(root_node.pointers)).share();
  2283. right_index = (getAVLRightPtr(root_node.pointers)).share();
  2284. left_node = R[left_index];
  2285. right_node = R[right_index];
  2286. if(left_node.key.share()!=5 || right_node.key.share()!=9) {
  2287. success = false;
  2288. }
  2289. //To check that left and right have no children and 0 balances
  2290. size_t sum = left_node.pointers.share() + right_node.pointers.share();
  2291. if(sum!=0) {
  2292. success = false;
  2293. }
  2294. success &= del_ret;
  2295. if(player0) {
  2296. if(success) {
  2297. print_green("T9 : SUCCESS\n");
  2298. } else {
  2299. print_red("T9 : FAIL\n");
  2300. }
  2301. }
  2302. A.init();
  2303. tree.init();
  2304. }
  2305. // (T10) : Test 10 : L rotation (root unmodified)
  2306. /*
  2307. Operation:
  2308. 5 5
  2309. / \ / \
  2310. 3 7 Del 6 3 9
  2311. / / \ ------> / / \
  2312. 1 6 9 1 7 12
  2313. \
  2314. 12
  2315. T10 checks:
  2316. - root is 5
  2317. - 3, 7, 9, 12 in expected positions
  2318. - Nodes 5, 7, 12 have 0 balance and no children
  2319. - 3's bal = 1 0
  2320. */
  2321. {
  2322. bool success = true;
  2323. int insert_array[] = {5, 3, 7, 9, 6, 1, 12};
  2324. size_t insert_array_size = sizeof(insert_array)/sizeof(int);
  2325. Node node;
  2326. for(size_t i = 0; i<insert_array_size; i++) {
  2327. randomize_node(node);
  2328. node.key.set(insert_array[i] * tio.player());
  2329. tree.insert(tio, yield, node);
  2330. success &= tree.check_avl(tio, yield);
  2331. }
  2332. RegAS del_key;
  2333. del_key.set(6 * tio.player());
  2334. bool del_ret;
  2335. del_ret = tree.del(tio, yield, del_key);
  2336. success &= tree.check_avl(tio, yield);
  2337. Duoram<Node>* oram = tree.get_oram();
  2338. RegXS root_xs = tree.get_root();
  2339. size_t root = mpc_reconstruct(tio, yield, root_xs);
  2340. auto A = oram->flat(tio, yield);
  2341. auto R = A.reconstruct();
  2342. Node root_node, n1, n3, n7, n9, n12;
  2343. size_t n1_index, n3_index, n7_index, n9_index, n12_index;
  2344. root_node = R[root];
  2345. if((root_node.key).share()!=5) {
  2346. success = false;
  2347. }
  2348. n3_index = (getAVLLeftPtr(root_node.pointers)).share();
  2349. n9_index = (getAVLRightPtr(root_node.pointers)).share();
  2350. n3 = R[n3_index];
  2351. n9 = R[n9_index];
  2352. n7_index = getAVLLeftPtr(n9.pointers).share();
  2353. n12_index = getAVLRightPtr(n9.pointers).share();
  2354. n7 = R[n7_index];
  2355. n12 = R[n12_index];
  2356. n1_index = getAVLLeftPtr(n3.pointers).share();
  2357. n1 = R[n1_index];
  2358. // Node value checks
  2359. if(n3.key.share()!=3 || n9.key.share()!=9) {
  2360. success = false;
  2361. }
  2362. if(n7.key.share()!=7 || n12.key.share()!=12 || n1.key.share()!=1) {
  2363. success = false;
  2364. }
  2365. // Node children and balance checks
  2366. size_t zero = 0;
  2367. zero+=(n1.pointers.share());
  2368. zero+=(n7.pointers.share());
  2369. zero+=(n12.pointers.share());
  2370. zero+=(getLeftBal(root_node.pointers).share());
  2371. zero+=(getRightBal(root_node.pointers).share());
  2372. zero+=(getLeftBal(n9.pointers).share());
  2373. zero+=(getRightBal(n9.pointers).share());
  2374. zero+=(getRightBal(n3.pointers).share());
  2375. if(zero!=0) {
  2376. success = false;
  2377. }
  2378. int one = (getLeftBal(n3.pointers).share());
  2379. if(one!=1) {
  2380. success = false;
  2381. }
  2382. success &= del_ret;
  2383. if(player0) {
  2384. if(success) {
  2385. print_green("T10 : SUCCESS\n");
  2386. } else {
  2387. print_red("T10 : FAIL\n");
  2388. }
  2389. }
  2390. A.init();
  2391. tree.init();
  2392. }
  2393. // (T11) : Test 11 : R rotation (root modified)
  2394. /*
  2395. Operation:
  2396. 9 7
  2397. / \ Del 12 / \
  2398. 7 12 -------> 5 9
  2399. /
  2400. 5
  2401. T11 checks:
  2402. - root is 7
  2403. - 5,7,9 in correct positions and balances to 0
  2404. - 5 and 9 have no children
  2405. */
  2406. {
  2407. bool success = true;
  2408. int insert_array[] = {9, 7, 12, 5};
  2409. size_t insert_array_size = sizeof(insert_array)/sizeof(int);
  2410. Node node;
  2411. for(size_t i = 0; i<insert_array_size; i++) {
  2412. randomize_node(node);
  2413. node.key.set(insert_array[i] * tio.player());
  2414. tree.insert(tio, yield, node);
  2415. success &= tree.check_avl(tio, yield);
  2416. }
  2417. RegAS del_key;
  2418. del_key.set(12 * tio.player());
  2419. bool del_ret;
  2420. del_ret = tree.del(tio, yield, del_key);
  2421. success &= tree.check_avl(tio, yield);
  2422. Duoram<Node>* oram = tree.get_oram();
  2423. RegXS root_xs = tree.get_root();
  2424. size_t root = mpc_reconstruct(tio, yield, root_xs);
  2425. auto A = oram->flat(tio, yield);
  2426. auto R = A.reconstruct();
  2427. Node root_node, left_node, right_node;
  2428. size_t left_index, right_index;
  2429. root_node = R[root];
  2430. if((root_node.key).share()!=7) {
  2431. success = false;
  2432. }
  2433. left_index = (getAVLLeftPtr(root_node.pointers)).share();
  2434. right_index = (getAVLRightPtr(root_node.pointers)).share();
  2435. left_node = R[left_index];
  2436. right_node = R[right_index];
  2437. if(left_node.key.share()!=5 || right_node.key.share()!=9) {
  2438. success = false;
  2439. }
  2440. //To check that left and right have no children and 0 balances
  2441. size_t zero = left_node.pointers.share() + right_node.pointers.share();
  2442. zero+=(getLeftBal(left_node.pointers).share());
  2443. zero+=(getRightBal(left_node.pointers).share());
  2444. zero+=(getLeftBal(right_node.pointers).share());
  2445. zero+=(getRightBal(right_node.pointers).share());
  2446. if(zero!=0) {
  2447. success = false;
  2448. }
  2449. success &= del_ret;
  2450. if(player0) {
  2451. if(success) {
  2452. print_green("T11 : SUCCESS\n");
  2453. } else{
  2454. print_red("T11 : FAIL\n");
  2455. }
  2456. }
  2457. A.init();
  2458. tree.init();
  2459. }
  2460. // (T12) : Test 12 : R rotation (root unmodified)
  2461. /*
  2462. Operation:
  2463. 9 9
  2464. / \ / \
  2465. 7 12 Del 8 5 12
  2466. / \ \ ------> / \ \
  2467. 5 8 15 3 7 15
  2468. /
  2469. 3
  2470. T12 checks:
  2471. - root is 9
  2472. - 3,5,7,12,15 are in correct positions
  2473. - Nodes 3,7,15 have 0 balance
  2474. - Nodes 3,7,15 have no children
  2475. - 9,5 bal = 0 0
  2476. - 12 bal = 0 1
  2477. */
  2478. {
  2479. bool success = true;
  2480. int insert_array[] = {9, 12, 7, 5, 8, 15, 3};
  2481. size_t insert_array_size = sizeof(insert_array)/sizeof(int);
  2482. Node node;
  2483. for(size_t i = 0; i<insert_array_size; i++) {
  2484. randomize_node(node);
  2485. node.key.set(insert_array[i] * tio.player());
  2486. tree.insert(tio, yield, node);
  2487. success &= tree.check_avl(tio, yield);
  2488. }
  2489. RegAS del_key;
  2490. del_key.set(8 * tio.player());
  2491. bool del_ret;
  2492. del_ret = tree.del(tio, yield, del_key);
  2493. success &= tree.check_avl(tio, yield);
  2494. Duoram<Node>* oram = tree.get_oram();
  2495. RegXS root_xs = tree.get_root();
  2496. size_t root = mpc_reconstruct(tio, yield, root_xs);
  2497. auto A = oram->flat(tio, yield);
  2498. auto R = A.reconstruct();
  2499. Node root_node, n3, n7, n5, n12, n15;
  2500. size_t n3_index, n7_index, n5_index, n12_index, n15_index;
  2501. root_node = R[root];
  2502. if((root_node.key).share()!=9) {
  2503. success = false;
  2504. }
  2505. n5_index = (getAVLLeftPtr(root_node.pointers)).share();
  2506. n12_index = (getAVLRightPtr(root_node.pointers)).share();
  2507. n5 = R[n5_index];
  2508. n12 = R[n12_index];
  2509. n3_index = getAVLLeftPtr(n5.pointers).share();
  2510. n7_index = getAVLRightPtr(n5.pointers).share();
  2511. n7 = R[n7_index];
  2512. n3 = R[n3_index];
  2513. n15_index = getAVLRightPtr(n12.pointers).share();
  2514. n15 = R[n15_index];
  2515. // Node value checks
  2516. if(n12.key.share()!=12 || n5.key.share()!=5) {
  2517. success = false;
  2518. }
  2519. if(n3.key.share()!=3 || n7.key.share()!=7 || n15.key.share()!=15) {
  2520. success = false;
  2521. }
  2522. // Node balance checks
  2523. size_t zero = 0;
  2524. zero+=(n3.pointers.share());
  2525. zero+=(n7.pointers.share());
  2526. zero+=(n15.pointers.share());
  2527. zero+=(getRightBal(root_node.pointers).share());
  2528. zero+=(getLeftBal(root_node.pointers).share());
  2529. zero+=(getLeftBal(n5.pointers).share());
  2530. zero+=(getRightBal(n5.pointers).share());
  2531. if(zero!=0) {
  2532. success = false;
  2533. }
  2534. int one = (getRightBal(n12.pointers).share());
  2535. if(one!=1) {
  2536. success = false;
  2537. }
  2538. success &= del_ret;
  2539. if(player0) {
  2540. if(success) {
  2541. print_green("T12 : SUCCESS\n");
  2542. } else {
  2543. print_red("T12 : FAIL\n");
  2544. }
  2545. }
  2546. A.init();
  2547. tree.init();
  2548. }
  2549. // (T13) : Test 13 : LR rotation (root modified)
  2550. /*
  2551. Operation:
  2552. 9 9 7
  2553. / \ Del 12 / / \
  2554. 5 12 -------> 7 --> 5 9
  2555. \ /
  2556. 7 5
  2557. T13 checks:
  2558. - root is 7
  2559. - 9,5,7 are in correct positions
  2560. - Nodes 5,7,9 have 0 balance
  2561. - Nodes 5,9 have no children
  2562. */
  2563. {
  2564. bool success = true;
  2565. int insert_array[] = {9, 5, 12, 7};
  2566. size_t insert_array_size = sizeof(insert_array)/sizeof(int);
  2567. Node node;
  2568. for(size_t i = 0; i<insert_array_size; i++) {
  2569. randomize_node(node);
  2570. node.key.set(insert_array[i] * tio.player());
  2571. tree.insert(tio, yield, node);
  2572. success &= tree.check_avl(tio, yield);
  2573. }
  2574. RegAS del_key;
  2575. del_key.set(12 * tio.player());
  2576. bool del_ret;
  2577. del_ret = tree.del(tio, yield, del_key);
  2578. success &= tree.check_avl(tio, yield);
  2579. Duoram<Node>* oram = tree.get_oram();
  2580. RegXS root_xs = tree.get_root();
  2581. size_t root = mpc_reconstruct(tio, yield, root_xs);
  2582. auto A = oram->flat(tio, yield);
  2583. auto R = A.reconstruct();
  2584. Node root_node, n9, n5;
  2585. size_t n9_index, n5_index;
  2586. root_node = R[root];
  2587. if((root_node.key).share()!=7) {
  2588. success = false;
  2589. }
  2590. n5_index = (getAVLLeftPtr(root_node.pointers)).share();
  2591. n9_index = (getAVLRightPtr(root_node.pointers)).share();
  2592. n5 = R[n5_index];
  2593. n9 = R[n9_index];
  2594. // Node value checks
  2595. if(n9.key.share()!=9 || n5.key.share()!=5) {
  2596. success = false;
  2597. }
  2598. // Node balance checks
  2599. size_t zero = 0;
  2600. zero+=(n5.pointers.share());
  2601. zero+=(n9.pointers.share());
  2602. zero+=(getRightBal(root_node.pointers).share());
  2603. zero+=(getLeftBal(n5.pointers).share());
  2604. zero+=(getRightBal(n5.pointers).share());
  2605. zero+=(getLeftBal(n5.pointers).share());
  2606. zero+=(getRightBal(n9.pointers).share());
  2607. zero+=(getLeftBal(n9.pointers).share());
  2608. if(zero!=0) {
  2609. success = false;
  2610. }
  2611. success &= del_ret;
  2612. if(player0) {
  2613. if(success) {
  2614. print_green("T13 : SUCCESS\n");
  2615. } else {
  2616. print_red("T13 : FAIL\n");
  2617. }
  2618. }
  2619. A.init();
  2620. tree.init();
  2621. }
  2622. // (T14) : Test 14 : LR rotation (root unmodified)
  2623. /*
  2624. Operation:
  2625. 9 9
  2626. / \ / \
  2627. 5 12 Del 8 5 12
  2628. / \ ------> / \
  2629. 3 7 (No-op) 3 7
  2630. T14 checks:
  2631. - root is 9
  2632. - 3,5,7,12 are in correct positions
  2633. - Nodes 3,7,12 have 0 balance
  2634. - Nodes 3,7,12 have no children
  2635. - 9's bal = 1 0
  2636. */
  2637. {
  2638. bool success = true;
  2639. int insert_array[] = {9, 12, 7, 3, 5};
  2640. size_t insert_array_size = sizeof(insert_array)/sizeof(int);
  2641. Node node;
  2642. for(size_t i = 0; i<insert_array_size; i++) {
  2643. randomize_node(node);
  2644. node.key.set(insert_array[i] * tio.player());
  2645. tree.insert(tio, yield, node);
  2646. success &= tree.check_avl(tio, yield);
  2647. }
  2648. RegAS del_key;
  2649. del_key.set(8 * tio.player());
  2650. bool del_ret;
  2651. del_ret = tree.del(tio, yield, del_key);
  2652. success &= tree.check_avl(tio, yield);
  2653. Duoram<Node>* oram = tree.get_oram();
  2654. RegXS root_xs = tree.get_root();
  2655. size_t root = mpc_reconstruct(tio, yield, root_xs);
  2656. auto A = oram->flat(tio, yield);
  2657. auto R = A.reconstruct();
  2658. Node root_node, n3, n7, n5, n12;
  2659. size_t n3_index, n7_index, n5_index, n12_index;
  2660. root_node = R[root];
  2661. if((root_node.key).share()!=9) {
  2662. success = false;
  2663. }
  2664. n5_index = (getAVLLeftPtr(root_node.pointers)).share();
  2665. n12_index = (getAVLRightPtr(root_node.pointers)).share();
  2666. n5 = R[n5_index];
  2667. n12 = R[n12_index];
  2668. n3_index = getAVLLeftPtr(n5.pointers).share();
  2669. n7_index = getAVLRightPtr(n5.pointers).share();
  2670. n7 = R[n7_index];
  2671. n3 = R[n3_index];
  2672. // Node value checks
  2673. if(n5.key.share()!=5 || n12.key.share()!=12) {
  2674. success = false;
  2675. }
  2676. if(n3.key.share()!=3 || n7.key.share()!=7) {
  2677. success = false;
  2678. }
  2679. // Node balance checks
  2680. size_t zero = 0;
  2681. zero+=(n3.pointers.share());
  2682. zero+=(n7.pointers.share());
  2683. zero+=(n12.pointers.share());
  2684. zero+=(getRightBal(root_node.pointers).share());
  2685. zero+=(getLeftBal(n5.pointers).share());
  2686. zero+=(getRightBal(n5.pointers).share());
  2687. if(zero!=0) {
  2688. success = false;
  2689. }
  2690. int one = (getLeftBal(root_node.pointers).share());
  2691. if(one!=1) {
  2692. success = false;
  2693. }
  2694. success &=(!del_ret);
  2695. if(player0) {
  2696. if(success) {
  2697. print_green("T14 : SUCCESS\n");
  2698. } else {
  2699. print_red("T14 : FAIL\n");
  2700. }
  2701. }
  2702. A.init();
  2703. tree.init();
  2704. }
  2705. // (T15) : Test 15 : RL rotation (root modified)
  2706. /*
  2707. Operation:
  2708. 5 5 7
  2709. / \ Del 3 \ / \
  2710. 3 9 -------> 7 --> 5 9
  2711. / \
  2712. 7 9
  2713. T15 checks:
  2714. - root is 7
  2715. - 9,5,7 are in correct positions
  2716. - Nodes 5,7,9 have 0 balance
  2717. - Nodes 5,9 have no children
  2718. */
  2719. {
  2720. bool success = true;
  2721. int insert_array[] = {5, 9, 3, 7};
  2722. size_t insert_array_size = sizeof(insert_array)/sizeof(int);
  2723. Node node;
  2724. for(size_t i = 0; i<insert_array_size; i++) {
  2725. randomize_node(node);
  2726. node.key.set(insert_array[i] * tio.player());
  2727. tree.insert(tio, yield, node);
  2728. success &= tree.check_avl(tio, yield);
  2729. }
  2730. RegAS del_key;
  2731. del_key.set(3 * tio.player());
  2732. bool del_ret;
  2733. del_ret = tree.del(tio, yield, del_key);
  2734. success &= tree.check_avl(tio, yield);
  2735. Duoram<Node>* oram = tree.get_oram();
  2736. RegXS root_xs = tree.get_root();
  2737. size_t root = mpc_reconstruct(tio, yield, root_xs);
  2738. auto A = oram->flat(tio, yield);
  2739. auto R = A.reconstruct();
  2740. Node root_node, n9, n5;
  2741. size_t n9_index, n5_index;
  2742. root_node = R[root];
  2743. if((root_node.key).share()!=7) {
  2744. success = false;
  2745. }
  2746. n5_index = (getAVLLeftPtr(root_node.pointers)).share();
  2747. n9_index = (getAVLRightPtr(root_node.pointers)).share();
  2748. n5 = R[n5_index];
  2749. n9 = R[n9_index];
  2750. // Node value checks
  2751. if(n9.key.share()!=9 || n5.key.share()!=5) {
  2752. success = false;
  2753. }
  2754. // Node balance checks
  2755. size_t zero = 0;
  2756. zero+=(n5.pointers.share());
  2757. zero+=(n9.pointers.share());
  2758. zero+=(getRightBal(root_node.pointers).share());
  2759. zero+=(getLeftBal(n5.pointers).share());
  2760. zero+=(getRightBal(n5.pointers).share());
  2761. zero+=(getLeftBal(n5.pointers).share());
  2762. zero+=(getRightBal(n9.pointers).share());
  2763. zero+=(getLeftBal(n9.pointers).share());
  2764. if(zero!=0) {
  2765. success = false;
  2766. }
  2767. success &= del_ret;
  2768. if(player0) {
  2769. if(success) {
  2770. print_green("T15 : SUCCESS\n");
  2771. } else {
  2772. print_red("T15 : FAIL\n");
  2773. }
  2774. }
  2775. A.init();
  2776. tree.init();
  2777. }
  2778. // (T16) : Test 16 : RL rotation (root unmodified)
  2779. /*
  2780. Operation:
  2781. 5 5 5
  2782. / \ / \ / \
  2783. 3 8 Del 7 3 8 3 9
  2784. / / \ ------> / \ ---> / / \
  2785. 1 7 12 1 9 1 8 12
  2786. / \
  2787. 9 12
  2788. T16 checks:
  2789. - root is 5
  2790. - 3,9,8,12 are in correct positions
  2791. - Nodes 1,5,8,9,12 have 0 balance
  2792. - Nodes 1,5,8,9,12 have no children
  2793. - Node 3 has 1 0 balance
  2794. */
  2795. {
  2796. bool success = true;
  2797. int insert_array[] = {5, 3, 8, 7, 1, 12, 9};
  2798. size_t insert_array_size = sizeof(insert_array)/sizeof(int);
  2799. Node node;
  2800. for(size_t i = 0; i<insert_array_size; i++) {
  2801. randomize_node(node);
  2802. node.key.set(insert_array[i] * tio.player());
  2803. tree.insert(tio, yield, node);
  2804. success &= tree.check_avl(tio, yield);
  2805. }
  2806. RegAS del_key;
  2807. del_key.set(7 * tio.player());
  2808. bool del_ret;
  2809. del_ret = tree.del(tio, yield, del_key);
  2810. success &= tree.check_avl(tio, yield);
  2811. Duoram<Node>* oram = tree.get_oram();
  2812. RegXS root_xs = tree.get_root();
  2813. size_t root = mpc_reconstruct(tio, yield, root_xs);
  2814. auto A = oram->flat(tio, yield);
  2815. auto R = A.reconstruct();
  2816. Node root_node, n1, n3, n8, n9, n12;
  2817. size_t n1_index, n3_index, n8_index, n9_index, n12_index;
  2818. root_node = R[root];
  2819. if((root_node.key).share()!=5) {
  2820. success = false;
  2821. }
  2822. n3_index = (getAVLLeftPtr(root_node.pointers)).share();
  2823. n9_index = (getAVLRightPtr(root_node.pointers)).share();
  2824. n3 = R[n3_index];
  2825. n9 = R[n9_index];
  2826. n1_index = getAVLLeftPtr(n3.pointers).share();
  2827. n8_index = getAVLLeftPtr(n9.pointers).share();
  2828. n12_index = getAVLRightPtr(n9.pointers).share();
  2829. n1 = R[n1_index];
  2830. n8 = R[n8_index];
  2831. n12 = R[n12_index];
  2832. // Node value checks
  2833. if(n1.key.share()!=1) {
  2834. success = false;
  2835. }
  2836. if(n3.key.share()!=3 || n9.key.share()!=9) {
  2837. success = false;
  2838. }
  2839. if(n8.key.share()!=8 || n12.key.share()!=12) {
  2840. success = false;
  2841. }
  2842. // Node balance checks
  2843. size_t zero = 0;
  2844. zero+=(n1.pointers.share());
  2845. zero+=(getRightBal(n3.pointers).share());
  2846. zero+=(n8.pointers.share());
  2847. zero+=(n12.pointers.share());
  2848. zero+=(getLeftBal(root_node.pointers).share());
  2849. zero+=(getRightBal(root_node.pointers).share());
  2850. zero+=(getLeftBal(n9.pointers).share());
  2851. zero+=(getRightBal(n9.pointers).share());
  2852. if(zero!=0) {
  2853. success = false;
  2854. }
  2855. success &= del_ret;
  2856. if(player0) {
  2857. if(success) {
  2858. print_green("T16 : SUCCESS\n");
  2859. } else {
  2860. print_red("T16 : FAIL\n");
  2861. }
  2862. }
  2863. A.init();
  2864. tree.init();
  2865. }
  // (T17) : Test 17 : Double imbalance (root modified)
  /*
  Operation:
           9                             9
         /   \                         /   \
        5     12      Del 10          5     15
       / \    / \   --------->       / \    / \
      3   7  10  15                 3   7  12  20
     / \ / \       \               / \ / \
    2  4 6  8      20             2  4 6  8
   /                             /
  1                             1

                   5
                 /   \
               3       9
    ----->    / \     /   \
             2   4   7     15
            /       / \   /  \
           1       6   8 12  20

  Deleting 10 leaves 12 right-heavy, so it is rotated left and 15 takes its
  place. That shortens 9's right subtree and leaves the root 9 left-heavy by
  two; a right rotation at 9 then makes 5 the new root.

  T17 checks:
  - root is 5
  - all other nodes are in correct positions
  - balances and children are correct
  */
  2890. {
  2891. bool success = true;
  2892. int insert_array[] = {9, 5, 12, 7, 3, 10, 15, 2, 4, 6, 8, 20, 1};
  2893. size_t insert_array_size = sizeof(insert_array)/sizeof(int);
  2894. Node node;
  2895. for(size_t i = 0; i<insert_array_size; i++) {
  2896. randomize_node(node);
  2897. node.key.set(insert_array[i] * tio.player());
  2898. tree.insert(tio, yield, node);
  2899. success &= tree.check_avl(tio, yield);
  2900. }
  2901. RegAS del_key;
  2902. del_key.set(10 * tio.player());
  2903. bool del_ret;
  2904. del_ret = tree.del(tio, yield, del_key);
  2905. success &= tree.check_avl(tio, yield);
  2906. Duoram<Node>* oram = tree.get_oram();
  2907. RegXS root_xs = tree.get_root();
  2908. size_t root = mpc_reconstruct(tio, yield, root_xs);
  2909. auto A = oram->flat(tio, yield);
  2910. auto R = A.reconstruct();
  2911. Node root_node, n3, n7, n9;
  2912. Node n1, n2, n4, n6, n8, n12, n15, n20;
  2913. size_t n3_index, n7_index, n9_index;
  2914. size_t n1_index, n2_index, n4_index, n6_index;
  2915. size_t n8_index, n12_index, n15_index, n20_index;
  2916. root_node = R[root];
  2917. if((root_node.key).share()!=5) {
  2918. success = false;
  2919. }
  2920. n3_index = (getAVLLeftPtr(root_node.pointers)).share();
  2921. n9_index = (getAVLRightPtr(root_node.pointers)).share();
  2922. n3 = R[n3_index];
  2923. n9 = R[n9_index];
  2924. n2_index = getAVLLeftPtr(n3.pointers).share();
  2925. n4_index = getAVLRightPtr(n3.pointers).share();
  2926. n7_index = getAVLLeftPtr(n9.pointers).share();
  2927. n15_index = getAVLRightPtr(n9.pointers).share();
  2928. n2 = R[n2_index];
  2929. n4 = R[n4_index];
  2930. n7 = R[n7_index];
  2931. n15 = R[n15_index];
  2932. n1_index = getAVLLeftPtr(n2.pointers).share();
  2933. n6_index = getAVLLeftPtr(n7.pointers).share();
  2934. n8_index = getAVLRightPtr(n7.pointers).share();
  2935. n12_index = getAVLLeftPtr(n15.pointers).share();
  2936. n20_index = getAVLRightPtr(n15.pointers).share();
  2937. n1 = R[n1_index];
  2938. n6 = R[n6_index];
  2939. n8 = R[n8_index];
  2940. n12 = R[n12_index];
  2941. n20 = R[n20_index];
  2942. // Node value checks
  2943. if(n3.key.share()!=3 || n9.key.share()!=9) {
  2944. success = false;
  2945. }
  2946. if(n2.key.share()!=2 || n4.key.share()!=4) {
  2947. success = false;
  2948. }
  2949. if(n7.key.share()!=7 || n15.key.share()!=15) {
  2950. success = false;
  2951. }
  2952. if(n1.key.share()!=1 || n6.key.share()!=6 || n8.key.share()!=8) {
  2953. success = false;
  2954. }
  2955. if(n12.key.share()!=12 || n20.key.share()!=20) {
  2956. success = false;
  2957. }
  2958. // Node balance checks
  2959. size_t zero = 0;
  2960. zero+=(n1.pointers.share());
  2961. zero+=(n4.pointers.share());
  2962. zero+=(n6.pointers.share());
  2963. zero+=(n8.pointers.share());
  2964. zero+=(n12.pointers.share());
  2965. zero+=(n20.pointers.share());
  2966. zero+=(getLeftBal(n7.pointers).share());
  2967. zero+=(getRightBal(n7.pointers).share());
  2968. zero+=(getLeftBal(n9.pointers).share());
  2969. zero+=(getRightBal(n9.pointers).share());
  2970. zero+=(getLeftBal(n15.pointers).share());
  2971. zero+=(getRightBal(n15.pointers).share());
  2972. zero+=(getRightBal(n3.pointers).share());
  2973. zero+=(getLeftBal(root_node.pointers).share());
  2974. zero+=(getRightBal(root_node.pointers).share());
  2975. if(zero!=0) {
  2976. success = false;
  2977. }
  2978. int one = (getLeftBal(n3.pointers).share());
  2979. if(one!=1) {
  2980. success = false;
  2981. }
  2982. success &= del_ret;
  2983. if(player0) {
  2984. if(success) {
  2985. print_green("T17 : SUCCESS\n");
  2986. } else {
  2987. print_red("T17 : FAIL\n");
  2988. }
  2989. }
  2990. A.init();
  2991. tree.init();
  2992. }
  2993. });
  2994. }