avl.cpp 124 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
72278227922802281228222832284228522862287228822892290229122922293229422952296229722982299230023012302230323042305230623072308230923102311231223132314231523162317231823192320232123222323232423252326232723282329233023312332233323342335233623372338233923402341234223432344234523462347234823492350235123522353235423552356235723582359236023612362236323642365236623672368236923702371237223732374237523762377237823792380238123822383238423852386238723882389239023912392239323942395239623972398239924002401240224032404240524062407240824092410241124122413241424152416241724182419242024212422242324242425242624272428242924302431243224332434243524362437243824392440244124422443244424452446244724482449245024512452245324542455245624572458245924602461246224632464246524662467246824692470247124722473247424752476247724782479248024812482248324842485248624872488248924902491249224932494249524962497249824992500250125022503250425052506250725082509251025112512251325142515251625172518251925202521252225232524252525262527252825292530253125322533253425352536253725382539254025412542254325442545254625472548254925502551255225532554255525562557255825592560256125622563256425652566256725682569257025712572257325742575257625772578257925802581258225832584258525862587258825892590259125922593259425952596259725982599260026012602260326042605260626072608260926102611261226132614261526162617261826192620262126222623262426252626262726282629263026312632263326342635263626372638263926402641264226432644264526462647264826492650265126522653265426552656265726582659266026612662266326642665266626672668266926702671267226732674267526762677267826792680268126822683268426852686268726882689269026912692269326942695269626972698269927002701270227032704270527062707270827092710271127122713271427152716271727182719272027212722272327242725272627272728272927302731273227332734273527362737273827392740274127422743274427452746274727482749275027512752275327542755275627572758275927602761276227632764276527662767276827692770277127722773277427752776277
72778277927802781278227832784278527862787278827892790279127922793279427952796279727982799280028012802280328042805280628072808280928102811281228132814281528162817281828192820282128222823282428252826282728282829283028312832283328342835283628372838283928402841284228432844284528462847284828492850285128522853285428552856285728582859286028612862286328642865286628672868286928702871287228732874287528762877287828792880288128822883288428852886288728882889289028912892289328942895289628972898289929002901290229032904290529062907290829092910291129122913291429152916291729182919292029212922292329242925292629272928292929302931293229332934293529362937293829392940294129422943294429452946294729482949295029512952295329542955295629572958295929602961296229632964296529662967296829692970297129722973297429752976297729782979298029812982298329842985298629872988298929902991299229932994299529962997299829993000300130023003300430053006300730083009301030113012301330143015301630173018301930203021302230233024302530263027302830293030303130323033303430353036303730383039304030413042304330443045304630473048304930503051305230533054305530563057305830593060306130623063306430653066306730683069307030713072307330743075307630773078307930803081308230833084308530863087308830893090309130923093309430953096309730983099310031013102310331043105310631073108310931103111311231133114311531163117311831193120312131223123312431253126312731283129313031313132313331343135313631373138313931403141314231433144314531463147314831493150315131523153315431553156315731583159316031613162316331643165316631673168316931703171317231733174317531763177317831793180318131823183318431853186318731883189319031913192319331943195319631973198319932003201320232033204320532063207320832093210321132123213321432153216321732183219322032213222322332243225322632273228322932303231323232333234323532363237323832393240324132423243324432453246324732483249325032513252325332543255325632573258325932603261326232633264326532663267326832693270327132723273327432753276327
73278327932803281328232833284328532863287328832893290
  1. #include <functional>
  2. #include "avl.hpp"
  3. void print_green(std::string line) {
  4. printf("%s%s%s", KGRN, line.c_str(), KNRM);
  5. }
  6. void print_red(std::string line) {
  7. printf("%s%s%s", KRED, line.c_str(), KNRM);
  8. }
  9. // Pretty-print a reconstructed BST, rooted at node. is_left_child and
  10. // is_right_child indicate whether node is a left or right child of its
  11. // parent. They cannot both be true, but the root of the tree has both
  12. // of them false.
  13. void AVL::pretty_print(const std::vector<Node> &R, value_t node,
  14. const std::string &prefix = "", bool is_left_child = false,
  15. bool is_right_child = false)
  16. {
  17. if (node == 0) {
  18. // NULL pointer
  19. if (is_left_child) {
  20. printf("%s\xE2\x95\xA7\n", prefix.c_str()); // ╧
  21. } else if (is_right_child) {
  22. printf("%s\xE2\x95\xA4\n", prefix.c_str()); // ╤
  23. } else {
  24. printf("%s\xE2\x95\xA2\n", prefix.c_str()); // ╢
  25. }
  26. return;
  27. }
  28. const Node &n = R[node];
  29. value_t left_ptr = getAVLLeftPtr(n.pointers).xshare;
  30. value_t right_ptr = getAVLRightPtr(n.pointers).xshare;
  31. std::string rightprefix(prefix), leftprefix(prefix),
  32. nodeprefix(prefix);
  33. if (is_left_child) {
  34. rightprefix.append("\xE2\x94\x82"); // │
  35. leftprefix.append(" ");
  36. nodeprefix.append("\xE2\x94\x94"); // └
  37. } else if (is_right_child) {
  38. rightprefix.append(" ");
  39. leftprefix.append("\xE2\x94\x82"); // │
  40. nodeprefix.append("\xE2\x94\x8C"); // ┌
  41. } else {
  42. rightprefix.append(" ");
  43. leftprefix.append(" ");
  44. nodeprefix.append("\xE2\x94\x80"); // ─
  45. }
  46. pretty_print(R, right_ptr, rightprefix, false, true);
  47. printf("%s\xE2\x94\xA4", nodeprefix.c_str()); // ┤
  48. dumpAVL(n);
  49. printf("\n");
  50. pretty_print(R, left_ptr, leftprefix, true, false);
  51. }
  52. void AVL::print_oram(MPCTIO &tio, yield_t &yield) {
  53. auto A = oram.flat(tio, yield);
  54. auto R = A.reconstruct();
  55. for(size_t i=0;i<R.size();++i) {
  56. printf("\n%04lx ", i);
  57. R[i].dump();
  58. }
  59. printf("\n");
  60. }
  61. void AVL::pretty_print(MPCTIO &tio, yield_t &yield) {
  62. RegXS peer_root;
  63. RegXS reconstructed_root = root;
  64. if (tio.player() == 1) {
  65. tio.queue_peer(&root, sizeof(root));
  66. } else {
  67. RegXS peer_root;
  68. tio.recv_peer(&peer_root, sizeof(peer_root));
  69. reconstructed_root += peer_root;
  70. }
  71. auto A = oram.flat(tio, yield);
  72. auto R = A.reconstruct();
  73. if(tio.player()==0) {
  74. pretty_print(R, reconstructed_root.xshare);
  75. }
  76. }
  77. // Check the BST invariant of the tree (that all keys to the left are
  78. // less than or equal to this key, all keys to the right are strictly
  79. // greater, and this is true recursively). Returns a
  80. // tuple<bool,address_t>, where the bool says whether the BST invariant
  81. // holds, and the address_t is the height of the tree (which will be
  82. // useful later when we check AVL trees).
  83. std::tuple<bool, bool, bool, address_t> AVL::check_avl(const std::vector<Node> &R,
  84. value_t node, value_t min_key = 0, value_t max_key = ~0)
  85. {
  86. if (node == 0) {
  87. return { true, true, true, 0};
  88. }
  89. const Node &n = R[node];
  90. value_t key = n.key.ashare;
  91. value_t left_ptr = getAVLLeftPtr(n.pointers).xshare;
  92. value_t right_ptr = getAVLRightPtr(n.pointers).xshare;
  93. auto [leftok, leftavlok, leftbbok, leftheight ] = check_avl(R, left_ptr, min_key, key);
  94. auto [rightok, rightavlok, rightbbok, rightheight ] = check_avl(R, right_ptr, key, max_key);
  95. address_t height = leftheight;
  96. if (rightheight > height) {
  97. height = rightheight;
  98. }
  99. height += 1;
  100. int heightgap = leftheight - rightheight;
  101. bool leftbal = (getLeftBal(n.pointers)).bshare;
  102. bool rightbal = (getRightBal(n.pointers)).bshare;
  103. bool avlok = (abs(heightgap)<2);
  104. bool bb_ok = false;
  105. if(heightgap==-1) {
  106. if(rightbal==1 && leftbal==0){
  107. bb_ok = true;
  108. }
  109. } else if(heightgap==1){
  110. if(leftbal==1 && rightbal==0){
  111. bb_ok = true;
  112. }
  113. } else if(heightgap==0){
  114. if(rightbal==0 && leftbal==0) {
  115. bb_ok = true;
  116. }
  117. }
  118. #ifdef DEBUG_BB
  119. if(bb_ok == false){
  120. printf("BB check failed at node with key = %ld\n", key);
  121. }
  122. #endif
  123. //printf("node = %ld, leftok = %d, rightok = %d\n", node, leftok, rightok);
  124. return { leftok && rightok && key >= min_key && key <= max_key,
  125. avlok && leftavlok && rightavlok, bb_ok && leftbbok && rightbbok, height};
  126. }
  127. void AVL::check_avl(MPCTIO &tio, yield_t &yield) {
  128. auto A = oram.flat(tio, yield);
  129. auto R = A.reconstruct();
  130. RegXS rec_root = this->root;
  131. if (tio.player() == 1) {
  132. tio.queue_peer(&(this->root), sizeof(this->root));
  133. } else {
  134. RegXS peer_root;
  135. tio.recv_peer(&peer_root, sizeof(peer_root));
  136. rec_root+= peer_root;
  137. }
  138. if (tio.player() == 0) {
  139. auto [ bst_ok, avl_ok, bb_ok, height ] = check_avl(R, rec_root.xshare);
  140. printf("BST structure %s\nAVL structure %s\nBalance Bits %s\nTree height = %u\n",
  141. bst_ok ? "ok" : "NOT OK", avl_ok ? "ok" : "NOT OK", bb_ok? "ok" : "NOT OK", height);
  142. }
  143. }
/*
  Rotate: (gp = grandparent (if exists), p = parent, c = child)
  This rotates the p -> c link.

     gp                            gp
       \                             \
        p    --- Left rotate --->     c
         \                           /
          c                         p

     gp                            gp
       \                             \
        p    --- Right rotate --->    c
       /                               \
      c                                 p

  Obliviously performs, on the secret-shared pointer words:
    i)   gp[dir_gpp] <-- c_ptr       (only when gp exists, i.e. !F_gp)
    ii)  p[dir_pc]   <-- c[!dir_pc]
    iii) c[!dir_pc]  <-- p_ptr
  Every update flag below is ANDed with isReal, so when isReal is 0 the
  pointer words are written back with their existing child pointers.

  gp_pointers, p_pointers, c_pointers: pointer words of gp, p and c; their
      left/right child pointers are rewritten in place.
  p_ptr, c_ptr: pointers to p and c.
  dir_gpp: which child of gp p is (1 = right, 0 = left).
  dir_pc:  which child of p c is (1 = right, 0 = left); dir_pc = 1 gives the
      left rotation in the diagram, dir_pc = 0 the right rotation.
  isReal: gates every update.
  F_gp: 1 when gp does not exist (p is the root); the gp -> p link is then
      left alone. Passed by value, so the flip below is local.
*/
void AVL::rotate(MPCTIO &tio, yield_t &yield, RegXS &gp_pointers, RegXS p_ptr,
    RegXS &p_pointers, RegXS c_ptr, RegXS &c_pointers, RegBS dir_gpp,
    RegBS dir_pc, RegBS isReal, RegBS F_gp) {
    bool player0 = tio.player()==0;
    RegXS gp_left = getAVLLeftPtr(gp_pointers);
    RegXS gp_right = getAVLRightPtr(gp_pointers);
    RegXS p_left = getAVLLeftPtr(p_pointers);
    RegXS p_right = getAVLRightPtr(p_pointers);
    RegXS c_left = getAVLLeftPtr(c_pointers);
    RegXS c_right = getAVLRightPtr(c_pointers);
    // Scratch pointer share, used twice: first as the new gp child (step i),
    // later as the new c child (step iii).
    RegXS ptr_upd;
    // F_gpp: Flag to update gp -> p link, F_pc: Flag to update p -> c link
    RegBS F_gpp, F_pc_l, F_pc_r, F_gppr, F_gppl;
    // We care about !F_gp. If !F_gp, then we do the gp->p link updates.
    // Otherwise, we do NOT do any updates to gp-> p link;
    // since F_gp==1, implies gp does not exist and parent is root.
    // (Flipping only player 0's share of an XOR-shared bit negates it.)
    if(player0)
        F_gp^=1;
    mpc_and(tio, yield, F_gpp, F_gp, isReal);
    // i) gp[dir_gpp] <-- c_ptr
    RegBS nt_dir_gpp = dir_gpp;
    if(player0)
        nt_dir_gpp^=1;
    // ptr_upd = F_gpp ? c_ptr : p_ptr
    mpc_select(tio, yield, ptr_upd, F_gpp, p_ptr, c_ptr);
    // not_dir_pc_l = [!dir_pc is left]  = dir_pc
    // not_dir_pc_r = [!dir_pc is right] = !dir_pc
    RegBS not_dir_pc_l = dir_pc, not_dir_pc_r = dir_pc;
    if(player0)
        not_dir_pc_r^=1;
    RegXS c_not_dir_pc; //c[!dir_pc]
    // ndpc_right: if not_dir_pc is right
    // ndpc_left: if not_dir_pc is left
    RegBS F_ndpc_right, F_ndpc_left;
    RegBS nt_dir_pc = dir_pc;
    if(player0)
        nt_dir_pc^=1;
    // Compute all update flags in one batch of coroutines.
    // NOTE(review): F_pc_l is computed exactly like F_ndpc_left
    // (isReal & dir_pc) and F_pc_r exactly like F_ndpc_right
    // (isReal & !dir_pc); the duplicates could be shared.
    std::vector<coro_t> coroutines;
    coroutines.emplace_back(
        [&tio, &F_gppr, F_gpp, dir_gpp](yield_t &yield) {
            mpc_and(tio, yield, F_gppr, F_gpp, dir_gpp);
        });
    coroutines.emplace_back(
        [&tio, &F_gppl, F_gpp, nt_dir_gpp](yield_t &yield) {
            mpc_and(tio, yield, F_gppl, F_gpp, nt_dir_gpp);
        });
    // ii) p[dir_pc] <-- c[!dir_pc] and iii) c[!dir_pc] <-- p_ptr
    coroutines.emplace_back(
        [&tio, &F_ndpc_right, isReal, not_dir_pc_r](yield_t &yield) {
            mpc_and(tio, yield, F_ndpc_right, isReal, not_dir_pc_r);
        });
    coroutines.emplace_back(
        [&tio, &F_ndpc_left, isReal, not_dir_pc_l](yield_t &yield) {
            mpc_and(tio, yield, F_ndpc_left, isReal, not_dir_pc_l);
        });
    coroutines.emplace_back(
        [&tio, &F_pc_l, dir_pc, isReal](yield_t &yield) {
            mpc_and(tio, yield, F_pc_l, dir_pc, isReal);
        });
    coroutines.emplace_back(
        [&tio, &F_pc_r, nt_dir_pc, isReal](yield_t &yield) {
            mpc_and(tio, yield, F_pc_r, nt_dir_pc, isReal);
        });
    run_coroutines(tio, coroutines);
    // Step i) on whichever gp child pointer is selected, and start
    // extracting c[!dir_pc] (the right-child case).
    run_coroutines(tio, [&tio, &gp_right, F_gppr, ptr_upd](yield_t &yield)
        { mpc_select(tio, yield, gp_right, F_gppr, gp_right, ptr_upd);},
        [&tio, &gp_left, F_gppl, ptr_upd](yield_t &yield)
        { mpc_select(tio, yield, gp_left, F_gppl, gp_left, ptr_upd);},
        [&tio, &c_not_dir_pc, F_ndpc_right, c_right](yield_t &yield)
        { mpc_select(tio, yield, c_not_dir_pc, F_ndpc_right, c_not_dir_pc, c_right, AVL_PTR_SIZE);});
    // The left-child case of c[!dir_pc] runs sequentially, after the batch
    // above, rather than as a fourth coroutine in it.
    mpc_select(tio, yield, c_not_dir_pc, F_ndpc_left, c_not_dir_pc, c_left, AVL_PTR_SIZE);
    // ii) p[dir_pc] <-- c[!dir_pc]
    // iii): c[!dir_pc] <-- p_ptr
    // ptr_upd is reused here: ptr_upd = isReal ? p_ptr : c_not_dir_pc
    run_coroutines(tio, [&tio, &p_left, F_ndpc_right, c_not_dir_pc](yield_t &yield)
        { mpc_select(tio, yield, p_left, F_ndpc_right, p_left, c_not_dir_pc, AVL_PTR_SIZE);},
        [&tio, &p_right, F_ndpc_left, c_not_dir_pc](yield_t &yield)
        { mpc_select(tio, yield, p_right, F_ndpc_left, p_right, c_not_dir_pc, AVL_PTR_SIZE);},
        [&tio, &ptr_upd, isReal, c_not_dir_pc, p_ptr](yield_t &yield)
        { mpc_select(tio, yield, ptr_upd, isReal, c_not_dir_pc, p_ptr, AVL_PTR_SIZE);});
    run_coroutines(tio, [&tio, &c_left, F_pc_l, ptr_upd](yield_t &yield)
        { mpc_select(tio, yield, c_left, F_pc_l, c_left, ptr_upd, AVL_PTR_SIZE);},
        [&tio, &c_right, F_pc_r, ptr_upd](yield_t &yield)
        { mpc_select(tio, yield, c_right, F_pc_r, c_right, ptr_upd, AVL_PTR_SIZE);});
    // Write the (possibly unchanged) child pointers back into the words.
    setAVLLeftPtr(gp_pointers, gp_left);
    setAVLRightPtr(gp_pointers, gp_right);
    setAVLLeftPtr(p_pointers, p_left);
    setAVLRightPtr(p_pointers, p_right);
    setAVLLeftPtr(c_pointers, c_left);
    setAVLRightPtr(c_pointers, c_right);
}
  246. /*
  247. In updateBalanceDel, the position of imbalance, and shift direction for both
  248. cases are inverted, since a bal_upd on a child implies it reduced height.
  249. If F_rs: (bal_upd & right_child)
  250. imbalance, bal_l, balanced, bal_r
  251. And then left shift to get imbalance bit, and new bal_l, bal_r bits
  252. else if F_ls: (bal_upd & left_child)
  253. bal_l, balanced, bal_r, imbalance
  254. And then right shift to get imbalance bit, and new bal_l, bal_r bits
  255. */
  256. std::tuple<RegBS, RegBS, RegBS, RegBS> AVL::updateBalanceDel(MPCTIO &tio, yield_t &yield,
  257. RegBS bal_l, RegBS bal_r, RegBS bal_upd, RegBS child_dir) {
  258. bool player0 = tio.player()==0;
  259. RegBS s0;
  260. RegBS F_rs, F_ls, balanced, imbalance, not_imbalance;
  261. RegBS nt_child_dir = child_dir;
  262. if(player0) {
  263. nt_child_dir^=1;
  264. }
  265. // balanced = is the node currently balanced
  266. balanced = bal_l ^ bal_r;
  267. if(player0) {
  268. balanced^=1;
  269. }
  270. //F_ls (Flag left shift) <- child_dir & bal_upd
  271. //F_rs (Flag right shift) <- !child_dir & bal_upd
  272. run_coroutines(tio, [&tio, &F_ls, child_dir, bal_upd](yield_t &yield)
  273. { mpc_and(tio, yield, F_ls, child_dir, bal_upd);},
  274. [&tio, &F_rs, nt_child_dir, bal_upd](yield_t &yield)
  275. { mpc_and(tio, yield, F_rs, nt_child_dir, bal_upd);});
  276. // Left shift if F_ls
  277. run_coroutines(tio, [&tio, &imbalance, F_ls, bal_l](yield_t &yield)
  278. { mpc_select(tio, yield, imbalance, F_ls, imbalance, bal_l);},
  279. [&tio, &bal_l, F_ls, balanced](yield_t &yield)
  280. { mpc_select(tio, yield, bal_l, F_ls, bal_l, balanced);},
  281. [&tio, &balanced, F_ls, bal_r](yield_t &yield)
  282. { mpc_select(tio, yield, balanced, F_ls, balanced, bal_r);},
  283. [&tio, &bal_r, F_ls, s0](yield_t &yield)
  284. { mpc_select(tio, yield, bal_r, F_ls, bal_r, s0);});
  285. // Right shift if F_rs
  286. run_coroutines(tio, [&tio, &imbalance, F_rs, bal_r](yield_t &yield)
  287. { mpc_select(tio, yield, imbalance, F_rs, imbalance, bal_r);},
  288. [&tio, &bal_r, F_rs, balanced](yield_t &yield)
  289. { mpc_select(tio, yield, bal_r, F_rs, bal_r, balanced);},
  290. [&tio, &balanced, F_rs, bal_l](yield_t &yield)
  291. { mpc_select(tio, yield, balanced, F_rs, balanced, bal_l);},
  292. [&tio, &bal_l, F_rs, s0](yield_t &yield)
  293. { mpc_select(tio, yield, bal_l, F_rs, bal_l, s0);});
  294. // if(bal_upd) and not imbalance bal_upd<-0
  295. /*
  296. RegBS bu0;
  297. not_imbalance = imbalance;
  298. if(player0){
  299. not_imbalance^=1;
  300. }
  301. mpc_and(tio, yield, bu0, bal_upd, not_imbalance);
  302. mpc_select(tio, yield, bal_upd, bu0, bal_upd, s0);
  303. */
  304. // if(bal_upd) and this node turns balanced, the height has decreased, so continue propogating bal_upd.
  305. // if(bal_upd) and node turns imbalanced, fixImbalance will update bal_upd correctly.
  306. // if(bal_upd) and node moves out of balanced to left/right heavy, the height of this subtree has not changed,
  307. // so don't propogate bal_upd.
  308. // if(bal_upd && bal_l ^ bal_r
  309. RegBS LR_heavy, bu0;
  310. LR_heavy = bal_l ^ bal_r;
  311. mpc_and(tio, yield, bu0, bal_upd, LR_heavy);
  312. mpc_select(tio, yield, bal_upd, bu0, bal_upd, s0);
  313. return {bal_l, bal_r, bal_upd, imbalance};
  314. }
  315. /*
  316. If F_rs: (bal_upd & right_child)
  317. bal_l, balanced, bal_r, imbalance
  318. And then right shift to get imbalance bit, and new bal_l, bal_r bits
  319. else if F_ls: (bal_upd & left_child)
  320. imbalance, bal_l, balanced, bal_r
  321. And then left shift to get imbalance bit, and new bal_l, bal_r bits
  322. */
  323. std::tuple<RegBS, RegBS, RegBS, RegBS> AVL::updateBalanceIns(MPCTIO &tio, yield_t &yield,
  324. RegBS bal_l, RegBS bal_r, RegBS bal_upd, RegBS child_dir) {
  325. bool player0 = tio.player()==0;
  326. RegBS s1, s0;
  327. s1.set(tio.player()==1);
  328. RegBS F_rs, F_ls, balanced, imbalance, nt_child_dir;
  329. // balanced = is the node currently balanced
  330. balanced = bal_l ^ bal_r;
  331. nt_child_dir = child_dir;
  332. if(player0){
  333. nt_child_dir^=1;
  334. }
  335. if(player0) {
  336. balanced^=1;
  337. }
  338. run_coroutines(tio, [&tio, &F_rs, child_dir, bal_upd](yield_t &yield)
  339. { //F_rs (Flag right shift) <- child_dir & bal_upd
  340. mpc_and(tio, yield, F_rs, child_dir, bal_upd);},
  341. [&tio, &F_ls, nt_child_dir, bal_upd](yield_t &yield)
  342. { //F_ls (Flag left shift) <- !child_dir & bal_upd
  343. mpc_and(tio, yield, F_ls, nt_child_dir, bal_upd);});
  344. std::vector<coro_t> coroutines;
  345. // Right shift if child_dir = 1 & bal_upd = 1
  346. coroutines.emplace_back(
  347. [&tio, &imbalance, F_rs, bal_r, balanced](yield_t &yield) {
  348. mpc_select(tio, yield, imbalance, F_rs, imbalance, bal_r);
  349. });
  350. coroutines.emplace_back(
  351. [&tio, &bal_r, F_rs, balanced](yield_t &yield) {
  352. mpc_select(tio, yield, bal_r, F_rs, bal_r, balanced);
  353. });
  354. coroutines.emplace_back(
  355. [&tio, &balanced, F_rs, bal_l](yield_t &yield) {
  356. mpc_select(tio, yield, balanced, F_rs, balanced, bal_l);
  357. });
  358. coroutines.emplace_back(
  359. [&tio, &bal_l, F_rs, s0](yield_t &yield) {
  360. mpc_select(tio, yield, bal_l, F_rs, bal_l, s0);
  361. });
  362. run_coroutines(tio, coroutines);
  363. coroutines.clear();
  364. // Left shift if child_dir = 0 & bal_upd = 1
  365. coroutines.emplace_back(
  366. [&tio, &imbalance, F_ls, bal_l] (yield_t &yield) {
  367. mpc_select(tio, yield, imbalance, F_ls, imbalance, bal_l);
  368. });
  369. coroutines.emplace_back(
  370. [&tio, &bal_l, F_ls, balanced] (yield_t &yield) {
  371. mpc_select(tio, yield, bal_l, F_ls, bal_l, balanced);
  372. });
  373. coroutines.emplace_back(
  374. [&tio, &balanced, F_ls, bal_r] (yield_t &yield) {
  375. mpc_select(tio, yield, balanced, F_ls, balanced, bal_r);
  376. });
  377. coroutines.emplace_back(
  378. [&tio, &bal_r, F_ls, s0](yield_t &yield) {
  379. mpc_select(tio, yield, bal_r, F_ls, bal_r, s0);
  380. });
  381. run_coroutines(tio, coroutines);
  382. // bal_upd' <- bal_upd ^ imbalance
  383. RegBS F_bu0;
  384. mpc_and(tio, yield, F_bu0, bal_upd, balanced);
  385. mpc_select(tio, yield, bal_upd, F_bu0, bal_upd, s0);
  386. mpc_select(tio, yield, bal_upd, imbalance, bal_upd, s0);
  387. return {bal_l, bal_r, bal_upd, imbalance};
  388. }
  389. std::tuple<RegBS, RegBS, RegXS, RegBS> AVL::insert(MPCTIO &tio, yield_t &yield, RegXS ptr, RegXS ins_addr,
  390. RegAS insert_key, Duoram<Node>::Flat &A, int TTL, RegBS isDummy, avl_insert_return *ret) {
  391. if(TTL==0) {
  392. RegBS z;
  393. return {z, z, z, z};
  394. }
  395. RegBS isReal = isDummy ^ (tio.player());
  396. Node cnode;
  397. #ifdef OPT_ON
  398. nbits_t width = ceil(log2(cur_max_index+1));
  399. typename Duoram<Node>::template OblivIndex<RegXS,1> oidx(tio, yield, ptr, width);
  400. cnode = A[oidx];
  401. #else
  402. cnode = A[ptr];
  403. #endif
  404. RegXS old_pointers = cnode.pointers;
  405. // Compare key
  406. auto [lteq, gt] = compare_keys(tio, yield, cnode.key, insert_key);
  407. // Depending on [lteq, gt] select the next_ptr
  408. RegXS next_ptr;
  409. RegXS left = getAVLLeftPtr(cnode.pointers);
  410. RegXS right = getAVLRightPtr(cnode.pointers);
  411. RegBS bal_l = getLeftBal(cnode.pointers);
  412. RegBS bal_r = getRightBal(cnode.pointers);
  413. /*
  414. size_t rec_left = reconstruct_RegXS(tio, yield, left);
  415. size_t rec_right = reconstruct_RegXS(tio, yield, right);
  416. size_t rec_key = reconstruct_RegAS(tio, yield, cnode.key);
  417. printf("\n\n(Before recursing) Key = %ld\n", rec_key);
  418. printf("rec_left = %ld, rec_right = %ld\n", rec_left, rec_right);
  419. */
  420. mpc_select(tio, yield, next_ptr, gt, left, right, AVL_PTR_SIZE);
  421. /*
  422. size_t rec_next_ptr = reconstruct_RegXS(tio, yield, next_ptr);
  423. printf("rec_next_ptr = %ld\n", rec_next_ptr);
  424. */
  425. CDPF dpf = tio.cdpf(yield);
  426. size_t &aes_ops = tio.aes_ops();
  427. // F_z: Check if this is last node on path
  428. RegBS F_z = dpf.is_zero(tio, yield, next_ptr, aes_ops);
  429. RegBS F_i;
  430. // F_i: If this was last node on path (F_z), and isReal insert.
  431. mpc_and(tio, yield, F_i, (isReal), F_z);
  432. isDummy^=F_i;
  433. auto [bal_upd, F_gp, prev_node, prev_dir] = insert(tio, yield,
  434. next_ptr, ins_addr, insert_key, A, TTL-1, isDummy, ret);
  435. /*
  436. rec_bal_upd = reconstruct_RegBS(tio, yield, bal_upd);
  437. rec_F_gp = reconstruct_RegBS(tio, yield, F_gp);
  438. printf("Insert returns: rec_bal_upd = %d, rec_F_gp = %d\n",
  439. rec_bal_upd, rec_F_gp);
  440. size_t rec_ptr = reconstruct_RegXS(tio, yield, ptr);
  441. printf("\nrec_ptr = %ld\n", rec_ptr);
  442. */
  443. // Update balance
  444. // If we inserted at this level (F_i), bal_upd is set to 1
  445. mpc_or(tio, yield, bal_upd, bal_upd, F_i);
  446. auto [new_bal_l, new_bal_r, new_bal_upd, imbalance] = updateBalanceIns(tio, yield, bal_l, bal_r, bal_upd, gt);
  447. // Store if this insert triggers an imbalance
  448. ret->imbalance ^= imbalance;
  449. std::vector<coro_t> coroutines;
  450. // Save grandparent pointer
  451. coroutines.emplace_back(
  452. [&tio, &ret, F_gp, ptr](yield_t &yield) {
  453. mpc_select(tio, yield, ret->gp_node, F_gp, ret->gp_node, ptr, AVL_PTR_SIZE);
  454. });
  455. coroutines.emplace_back(
  456. [&tio, &ret, F_gp, gt](yield_t &yield) {
  457. mpc_select(tio, yield, ret->dir_gpp, F_gp, ret->dir_gpp, gt);
  458. });
  459. // Save parent pointer
  460. coroutines.emplace_back(
  461. [&tio, &ret, imbalance, ptr](yield_t &yield) {
  462. mpc_select(tio, yield, ret->p_node, imbalance, ret->p_node, ptr, AVL_PTR_SIZE);
  463. });
  464. coroutines.emplace_back(
  465. [&tio, &ret, imbalance, gt](yield_t &yield) {
  466. mpc_select(tio, yield, ret->dir_pc, imbalance, ret->dir_pc, gt);
  467. });
  468. // Save child pointer
  469. coroutines.emplace_back(
  470. [&tio, &ret, imbalance, prev_node](yield_t &yield) {
  471. mpc_select(tio, yield, ret->c_node, imbalance, ret->c_node, prev_node, AVL_PTR_SIZE);
  472. });
  473. coroutines.emplace_back(
  474. [&tio, &ret, imbalance, prev_dir](yield_t &yield) {
  475. mpc_select(tio, yield, ret->dir_cn, imbalance, ret->dir_cn, prev_dir);
  476. });
  477. run_coroutines(tio, coroutines);
  478. // Store new_bal_l and new_bal_r for this node
  479. setLeftBal(cnode.pointers, new_bal_l);
  480. setRightBal(cnode.pointers, new_bal_r);
  481. // We have to write the node pointers anyway to handle balance updates,
  482. // so we perform insertion along with it by modifying pointers appropriately.
  483. RegBS F_ir, F_il;
  484. run_coroutines(tio, [&tio, &F_ir, F_i, gt](yield_t &yield)
  485. { mpc_and(tio, yield, F_ir, F_i, gt); },
  486. [&tio, &F_il, F_i, lteq](yield_t &yield)
  487. { mpc_and(tio, yield, F_il, F_i, lteq); });
  488. run_coroutines(tio, [&tio, &left, F_il, ins_addr](yield_t &yield)
  489. { mpc_select(tio, yield, left, F_il, left, ins_addr);},
  490. [&tio, &right, F_ir, ins_addr](yield_t &yield)
  491. { mpc_select(tio, yield, right, F_ir, right, ins_addr);});
  492. setAVLLeftPtr(cnode.pointers, left);
  493. setAVLRightPtr(cnode.pointers, right);
  494. /*
  495. bool rec_F_ir, rec_F_il;
  496. rec_F_ir = reconstruct_RegBS(tio, yield, F_ir);
  497. rec_F_il = reconstruct_RegBS(tio, yield, F_il);
  498. rec_left = reconstruct_RegXS(tio, yield, left);
  499. rec_right = reconstruct_RegXS(tio, yield, right);
  500. printf("(After recursing) F_il = %d, left = %ld, F_ir = %d, right = %ld\n",
  501. rec_F_il, rec_left, rec_F_ir, rec_right);
  502. */
  503. #ifdef OPT_ON
  504. A[oidx].NODE_POINTERS+=(cnode.pointers - old_pointers);
  505. #else
  506. A[ptr].NODE_POINTERS = cnode.pointers;
  507. #endif
  508. // s0 = shares of 0
  509. RegBS s0;
  510. // Update F_gp flag: If there was an imbalance then we set this to store
  511. // the grandparent node (node in the level above) into the ret_struct
  512. mpc_select(tio, yield, F_gp, imbalance, s0, imbalance);
  513. return {new_bal_upd, F_gp, ptr, gt};
  514. }
  515. // Insert(root, ptr, key, TTL, isDummy) -> (new_ptr, wptr, wnode, f_p)
  516. void AVL::insert(MPCTIO &tio, yield_t &yield, const Node &node) {
  517. bool player0 = tio.player()==0;
  518. // If there are no items in tree. Make this new item the root.
  519. if(num_items==0) {
  520. auto A = oram.flat(tio, yield);
  521. Node zero;
  522. A[0] = zero;
  523. A[1] = node;
  524. (root).set(1*tio.player());
  525. num_items++;
  526. cur_max_index++;
  527. return;
  528. } else {
  529. // Insert node into next free slot in the ORAM
  530. int new_id;
  531. RegXS insert_address;
  532. num_items++;
  533. int TTL = AVL_TTL(num_items);
  534. bool insertAtEmptyLocation = (numEmptyLocations() > 0);
  535. if(!insertAtEmptyLocation) {
  536. cur_max_index++;
  537. }
  538. auto A = oram.flat(tio, yield, 0, cur_max_index+1);
  539. if(insertAtEmptyLocation) {
  540. insert_address = empty_locations.back();
  541. empty_locations.pop_back();
  542. A[insert_address] = node;
  543. } else {
  544. new_id = num_items;
  545. A[new_id] = node;
  546. insert_address.set(new_id * tio.player());
  547. }
  548. RegBS isDummy;
  549. avl_insert_return ret;
  550. RegAS insert_key = node.key;
  551. // Recursive insert function
  552. auto [bal_upd, F_gp, prev_node, prev_dir] = insert(tio, yield, root,
  553. insert_address, insert_key, A, TTL, isDummy, &ret);
  554. /*
  555. // Debug code
  556. bool rec_bal_upd, rec_F_gp, ret_dir_pc, ret_dir_cn;
  557. rec_bal_upd = reconstruct_RegBS(tio, yield, bal_upd);
  558. rec_F_gp = reconstruct_RegBS(tio, yield, F_gp);
  559. ret_dir_pc = reconstruct_RegBS(tio, yield, ret.dir_pc);
  560. ret_dir_cn = reconstruct_RegBS(tio, yield, ret.dir_cn);
  561. printf("(Top level) Insert returns: rec_bal_upd = %d, rec_F_gp = %d\n",
  562. rec_bal_upd, rec_F_gp);
  563. printf("(Top level) Insert returns: ret.dir_pc = %d, rt.dir_cn = %d\n",
  564. ret_dir_pc, ret_dir_cn);
  565. */
  566. // Perform balance procedure
  567. RegXS gp_pointers, parent_pointers, child_pointers;
  568. std::vector<coro_t> coroutines;
  569. #ifdef OPT_ON
  570. nbits_t width = ceil(log2(cur_max_index+1));
  571. typename Duoram<Node>::template OblivIndex<RegXS,1> oidx_gp(tio, yield, ret.gp_node, width);
  572. typename Duoram<Node>::template OblivIndex<RegXS,1> oidx_p(tio, yield, ret.p_node, width);
  573. typename Duoram<Node>::template OblivIndex<RegXS,1> oidx_c(tio, yield, ret.c_node, width);
  574. coroutines.emplace_back(
  575. [&tio, &A, &oidx_gp, &gp_pointers](yield_t &yield) {
  576. auto acont = A.context(yield);
  577. gp_pointers = acont[oidx_gp].NODE_POINTERS;});
  578. coroutines.emplace_back(
  579. [&tio, &A, &oidx_p, &parent_pointers](yield_t &yield) {
  580. auto acont = A.context(yield);
  581. parent_pointers = acont[oidx_p].NODE_POINTERS;});
  582. coroutines.emplace_back(
  583. [&tio, &A, &oidx_c, &child_pointers](yield_t &yield) {
  584. auto acont = A.context(yield);
  585. child_pointers = acont[oidx_c].NODE_POINTERS;});
  586. run_coroutines(tio, coroutines);
  587. coroutines.clear();
  588. /*
  589. gp_pointers = A[oidx_gp].NODE_POINTERS;
  590. parent_pointers = A[oidx_p].NODE_POINTERS;
  591. child_pointers = A[oidx_c].NODE_POINTERS;
  592. */
  593. /*
  594. size_t rec_gp_key = reconstruct_RegAS(tio, yield, A[oidx_gp].NODE_KEY);
  595. size_t rec_p_key = reconstruct_RegAS(tio, yield, A[oidx_p].NODE_KEY);
  596. size_t rec_c_key = reconstruct_RegAS(tio, yield, A[oidx_c].NODE_KEY);
  597. size_t rec_gp_lptr = reconstruct_RegXS(tio, yield, getAVLLeftPtr(A[oidx_gp].NODE_POINTERS));
  598. size_t rec_gp_rptr = reconstruct_RegXS(tio, yield, getAVLRightPtr(A[oidx_gp].NODE_POINTERS));
  599. size_t rec_p_lptr = reconstruct_RegXS(tio, yield, getAVLLeftPtr(A[oidx_p].NODE_POINTERS));
  600. size_t rec_p_rptr = reconstruct_RegXS(tio, yield, getAVLRightPtr(A[oidx_p].NODE_POINTERS));
  601. size_t rec_c_lptr = reconstruct_RegXS(tio, yield, getAVLLeftPtr(A[oidx_c].NODE_POINTERS));
  602. size_t rec_c_rptr = reconstruct_RegXS(tio, yield, getAVLRightPtr(A[oidx_c].NODE_POINTERS));
  603. printf("Reconstructed:\ngp_key = %ld, gp_left_ptr = %ld, gp_right_ptr = %ld\n",
  604. rec_gp_key, rec_gp_lptr, rec_gp_rptr);
  605. printf("p_key = %ld, p_left_ptr = %ld, p_right_ptr = %ld\n",
  606. rec_p_key, rec_p_lptr, rec_p_rptr);
  607. printf("c_key = %ld, c_left_ptr = %ld, c_right_ptr = %ld\n",
  608. rec_c_key, rec_c_lptr, rec_c_rptr);
  609. */
  610. #else
  611. gp_pointers = A[ret.gp_node].NODE_POINTERS;
  612. parent_pointers = A[ret.p_node].NODE_POINTERS;
  613. child_pointers = A[ret.c_node].NODE_POINTERS;
  614. #endif
  615. // n_node (child's next node)
  616. RegXS child_left = getAVLLeftPtr(child_pointers);
  617. RegXS child_right = getAVLRightPtr(child_pointers);
  618. RegXS n_node, n_pointers;
  619. mpc_select(tio, yield, n_node, ret.dir_cn, child_left, child_right, AVL_PTR_SIZE);
  620. #ifdef OPT_ON
  621. typename Duoram<Node>::template OblivIndex<RegXS,1> oidx_n(tio, yield, n_node, width);
  622. n_pointers = A[oidx_n].NODE_POINTERS;
  623. #else
  624. n_pointers = A[n_node].NODE_POINTERS;
  625. #endif
  626. RegXS old_gp_pointers, old_parent_pointers, old_child_pointers, old_n_pointers;
  627. #ifdef OPT_ON
  628. old_gp_pointers = gp_pointers;
  629. old_parent_pointers = parent_pointers;
  630. old_child_pointers = child_pointers;
  631. old_n_pointers = n_pointers;
  632. #endif
  633. // F_dr = (dir_pc != dir_cn) : i.e., double rotation case if
  634. // (parent->child) and (child->new_node) are not in the same direction
  635. RegBS F_dr = (ret.dir_pc) ^ (ret.dir_cn);
  636. /* Flags: F_cn_rot = child->node rotate
  637. F_ur = update root.
  638. In case of an imbalance we have to always rotate p->c link. (L or R case)
  639. In case of an imbalance where p->c and c->n are in different directions, we have
  640. to perform a double rotation (LR or RL case). In such cases, first rotate
  641. c->n link, and then p->c link
  642. (Note: in the second rotation c is actually n, since the the first rotation
  643. swaps their positions)
  644. */
  645. RegBS F_cn_rot, F_ur, s0;
  646. run_coroutines(tio, [&tio, &F_ur, F_gp, ret](yield_t &yield)
  647. {mpc_and(tio, yield, F_ur, F_gp, ret.imbalance);},
  648. [&tio, &F_cn_rot, ret, F_dr](yield_t &yield)
  649. {mpc_and(tio, yield, F_cn_rot, ret.imbalance, F_dr);});
  650. // Get the n children information for 2nd rotate fix before rotations happen.
  651. RegBS n_bal_l, n_bal_r;
  652. RegXS n_l = getAVLLeftPtr(n_pointers);
  653. RegXS n_r = getAVLRightPtr(n_pointers);
  654. n_bal_l = getLeftBal(n_pointers);
  655. n_bal_r = getRightBal(n_pointers);
  656. // First rotation: c->n link
  657. rotate(tio, yield, parent_pointers, ret.c_node, child_pointers, n_node,
  658. n_pointers, ret.dir_pc, ret.dir_cn, F_cn_rot, s0);
  659. // If F_cn_rot, i.e. we did first rotation. Then c and n need to swap before the second rotate.
  660. RegXS new_child_pointers, new_child;
  661. run_coroutines(tio, [&tio, &new_child_pointers, F_cn_rot, child_pointers, n_pointers] (yield_t &yield)
  662. {mpc_select(tio, yield, new_child_pointers, F_cn_rot, child_pointers, n_pointers);},
  663. [&tio, &new_child, F_cn_rot, ret, n_node](yield_t &yield)
  664. {mpc_select(tio, yield, new_child, F_cn_rot, ret.c_node, n_node, AVL_PTR_SIZE);});
  665. // Second rotation: p->c link
  666. rotate(tio, yield, gp_pointers, ret.p_node, parent_pointers, new_child,
  667. new_child_pointers, ret.dir_gpp, ret.dir_pc, ret.imbalance, F_gp);
  668. // Set parent and child balances to 0 if there was an imbalance.
  669. // parent balances are already set to 0 from updateBalanceIns
  670. RegBS temp_bal, p_bal_l, p_bal_r, p_bal_ndpc;
  671. RegBS c_bal_l, c_bal_r, c_bal_dpc, n_bal_dpc, n_bal_ndpc;
  672. p_bal_l = getLeftBal(parent_pointers);
  673. p_bal_r = getRightBal(parent_pointers);
  674. run_coroutines(tio, [&tio, &child_pointers, F_cn_rot, new_child_pointers] (yield_t &yield)
  675. {mpc_select(tio, yield, child_pointers, F_cn_rot, new_child_pointers, child_pointers);},
  676. [&tio, &n_pointers, F_cn_rot, new_child_pointers] (yield_t &yield)
  677. {mpc_select(tio, yield, n_pointers, F_cn_rot, n_pointers, new_child_pointers);});
  678. c_bal_l = getLeftBal(child_pointers);
  679. c_bal_r = getRightBal(child_pointers);
  680. run_coroutines(tio, [&tio, &c_bal_l, ret, s0] (yield_t &yield)
  681. {mpc_select(tio, yield, c_bal_l, ret.imbalance, c_bal_l, s0);},
  682. [&tio, &c_bal_r, ret, s0] (yield_t &yield)
  683. {mpc_select(tio, yield, c_bal_r, ret.imbalance, c_bal_r, s0);});
  684. /* In the double rotation case: balance of c and p have a tweak
  685. p_bal_ndpc <- !(n_bal_ndpc)
  686. c_bal_dpc <- !(n_bal_dpc) */
  687. CDPF cdpf = tio.cdpf(yield);
  688. size_t &aes_ops = tio.aes_ops();
  689. RegBS n_l0, n_r0;
  690. run_coroutines(tio, [&tio, &n_l0, n_l, &cdpf, &aes_ops] (yield_t &yield)
  691. {n_l0 = cdpf.is_zero(tio, yield, n_l, aes_ops);},
  692. [&tio, &n_r0, n_r, &cdpf, &aes_ops] (yield_t &yield)
  693. {n_r0 = cdpf.is_zero(tio, yield, n_r, aes_ops);});
  694. RegBS p_c_update, n_has_children;
  695. // n_has_children = !(n_l0 & n_r0)
  696. mpc_and(tio, yield, n_has_children, n_l0, n_r0);
  697. if(player0) {
  698. n_has_children^=1;
  699. }
  700. run_coroutines(tio, [&tio, &p_c_update, F_cn_rot, n_has_children] (yield_t &yield)
  701. {mpc_and(tio, yield, p_c_update, F_cn_rot, n_has_children);},
  702. [&tio, &n_bal_ndpc, ret, n_bal_l, n_bal_r] (yield_t &yield)
  703. {mpc_select(tio, yield, n_bal_ndpc, ret.dir_pc, n_bal_r, n_bal_l);},
  704. [&tio, &n_bal_dpc, ret, n_bal_l, n_bal_r] (yield_t &yield)
  705. {mpc_select(tio, yield, n_bal_dpc, ret.dir_pc, n_bal_l, n_bal_r);},
  706. [&tio, &p_bal_ndpc, ret, p_bal_r, p_bal_l] (yield_t &yield)
  707. {mpc_select(tio, yield, p_bal_ndpc, ret.dir_pc, p_bal_r, p_bal_l);});
  708. // !n_bal_ndpc, !n_bal_dpc
  709. if(player0) {
  710. n_bal_ndpc^=1;
  711. n_bal_dpc^=1;
  712. }
  713. run_coroutines(tio, [&tio, &p_bal_ndpc, p_c_update, n_bal_ndpc] (yield_t &yield)
  714. {mpc_select(tio, yield, p_bal_ndpc, p_c_update, p_bal_ndpc, n_bal_ndpc);},
  715. [&tio, &c_bal_dpc, p_c_update, n_bal_dpc] (yield_t &yield)
  716. {mpc_select(tio, yield, c_bal_dpc, p_c_update, c_bal_dpc, n_bal_dpc);});
  717. coroutines.emplace_back([&tio, &p_bal_r, ret, p_bal_ndpc] (yield_t &yield)
  718. {mpc_select(tio, yield, p_bal_r, ret.dir_pc, p_bal_ndpc, p_bal_r);});
  719. coroutines.emplace_back([&tio, &p_bal_l, ret, p_bal_ndpc] (yield_t &yield)
  720. {mpc_select(tio, yield, p_bal_l, ret.dir_pc, p_bal_l, p_bal_ndpc);});
  721. coroutines.emplace_back([&tio, &c_bal_r, ret, c_bal_dpc] (yield_t &yield)
  722. {mpc_select(tio, yield, c_bal_r, ret.dir_pc, c_bal_r, c_bal_dpc);});
  723. coroutines.emplace_back([&tio, &c_bal_l, ret, c_bal_dpc] (yield_t &yield)
  724. {mpc_select(tio, yield, c_bal_l, ret.dir_pc, c_bal_dpc, c_bal_l);});
  725. // If double rotation (LR/RL) case, n ends up with 0 balance.
  726. // In all other cases, n's balance remains unaffected by rotation during insertion.
  727. coroutines.emplace_back([&tio, &n_bal_l, F_cn_rot, s0] (yield_t &yield)
  728. {mpc_select(tio, yield, n_bal_l, F_cn_rot, n_bal_l, s0);});
  729. coroutines.emplace_back([&tio, &n_bal_r, F_cn_rot, s0] (yield_t &yield)
  730. {mpc_select(tio, yield, n_bal_r, F_cn_rot, n_bal_r, s0);});
  731. run_coroutines(tio, coroutines);
  732. setLeftBal(parent_pointers, p_bal_l);
  733. setRightBal(parent_pointers, p_bal_r);
  734. setLeftBal(child_pointers, c_bal_l);
  735. setRightBal(child_pointers, c_bal_r);
  736. setLeftBal(n_pointers, n_bal_l);
  737. setRightBal(n_pointers, n_bal_r);
  738. // Write back update pointers and balances into gp, p, c, and n
  739. #ifdef OPT_ON
  740. A[oidx_n].NODE_POINTERS+=(n_pointers - old_n_pointers);
  741. A[oidx_c].NODE_POINTERS+=(child_pointers - old_child_pointers);
  742. A[oidx_p].NODE_POINTERS+=(parent_pointers - old_parent_pointers);
  743. A[oidx_gp].NODE_POINTERS+=(gp_pointers - old_gp_pointers);
  744. #else
  745. A[ret.c_node].NODE_POINTERS = child_pointers;
  746. A[ret.p_node].NODE_POINTERS = parent_pointers;
  747. A[ret.gp_node].NODE_POINTERS = gp_pointers;
  748. A[n_node].NODE_POINTERS = n_pointers;
  749. #endif
  750. // Handle root pointer update (if F_ur is true)
  751. // If F_ur and we did a double rotation: root <-- new node
  752. // If F_ur and we did a single rotation: root <-- child node
  753. RegXS temp_root = root;
  754. run_coroutines(tio, [&tio, &temp_root, F_ur, ret] (yield_t &yield)
  755. {mpc_select(tio, yield, temp_root, F_ur, temp_root, ret.c_node, AVL_PTR_SIZE);},
  756. [&tio, &F_ur, F_gp, F_dr] (yield_t &yield)
  757. {mpc_and(tio, yield, F_ur, F_gp, F_dr);});
  758. mpc_select(tio, yield, temp_root, F_ur, temp_root, n_node, AVL_PTR_SIZE);
  759. root = temp_root;
  760. }
  761. }
  762. bool AVL::lookup(MPCTIO &tio, yield_t &yield, RegXS ptr, RegAS key, Duoram<Node>::Flat &A,
  763. int TTL, RegBS isDummy, Node *ret_node) {
  764. if(TTL==0) {
  765. // Reconstruct and return isDummy
  766. // If we found the key, then isDummy will be true
  767. bool found = reconstruct_RegBS(tio, yield, isDummy);
  768. return found;
  769. }
  770. RegBS isNotDummy = isDummy ^ (tio.player());
  771. Node cnode = A[ptr];
  772. // Compare key
  773. CDPF cdpf = tio.cdpf(yield);
  774. auto [lt, eq, gt] = cdpf.compare(tio, yield, key - cnode.key, tio.aes_ops());
  775. // Depending on [lteq, gt] select the next ptr/index as
  776. // upper 32 bits of cnode.pointers if lteq
  777. // lower 32 bits of cnode.pointers if gt
  778. RegXS left = getAVLLeftPtr(cnode.pointers);
  779. RegXS right = getAVLRightPtr(cnode.pointers);
  780. RegXS next_ptr;
  781. mpc_select(tio, yield, next_ptr, gt, left, right, 32);
  782. RegBS F_found;
  783. // If we haven't found the key yet, and the lookup matches the current node key,
  784. // then we found the node to return
  785. mpc_and(tio, yield, F_found, isNotDummy, eq);
  786. mpc_select(tio, yield, ret_node->key, eq, ret_node->key, cnode.key);
  787. mpc_select(tio, yield, ret_node->value, eq, ret_node->value, cnode.value);
  788. isDummy^=F_found;
  789. bool found = lookup(tio, yield, next_ptr, key, A, TTL-1, isDummy, ret_node);
  790. return found;
  791. }
  792. bool AVL::lookup(MPCTIO &tio, yield_t &yield, RegAS key, Node *ret_node) {
  793. auto A = oram.flat(tio, yield);
  794. RegBS isDummy;
  795. bool found = lookup(tio, yield, root, key, A, num_items, isDummy, ret_node);
  796. return found;
  797. }
  798. void AVL::updateChildPointers(MPCTIO &tio, yield_t &yield, RegXS &left, RegXS &right,
  799. RegBS c_prime, avl_del_return ret_struct) {
  800. bool player0 = tio.player()==0;
  801. RegBS F_rr; // Flag to resolve F_r by updating right child ptr
  802. RegBS F_rl; // Flag to resolve F_r by updating left child ptr
  803. RegBS nt_c_prime = c_prime;
  804. if(player0)
  805. nt_c_prime^=1;
  806. run_coroutines(tio, [&tio, &F_rr, c_prime, ret_struct](yield_t &yield)
  807. { mpc_and(tio, yield, F_rr, c_prime, ret_struct.F_r);},
  808. [&tio, &F_rl, nt_c_prime, ret_struct](yield_t &yield)
  809. { mpc_and(tio, yield, F_rl, nt_c_prime, ret_struct.F_r);});
  810. run_coroutines(tio, [&tio, &right, F_rr, ret_struct](yield_t &yield)
  811. { mpc_select(tio, yield, right, F_rr, right, ret_struct.ret_ptr);},
  812. [&tio, &left, F_rl, ret_struct](yield_t &yield)
  813. { mpc_select(tio, yield, left, F_rl, left, ret_struct.ret_ptr);});
  814. }
  // Perform rotations if imbalance (else dummy rotations)
  /*
  To handle the symmetric L and R rotation cases with the same code, directions
  are expressed relative to
    dpc  = dir_pc = direction from parent to child, and
    ndpc = not(dir_pc).
  When we travelled down the stack, we went from p->c. But in deletions, to handle
  an imbalance we look at c's sibling cs (child's sibling). The rotation is between
  p and cs if there was an imbalance at p, and possibly also between cs and its
  child (the child in direction dir_pc, as that is the only case that results in a
  double rotation when deleting).
  In case of an imbalance we always have to rotate the p->cs link (L or R case).
  If cs.bal_(dir_pc) is set, we have a double rotation (LR or RL) case. In such
  cases, first rotate the cs->gcs link, and then the p->cs link, where
  gcs = grandchild on the cs path.
  Layout in the R (or LR) case:
          p
         / \
       cs   c
      /  \
     a   gcs
         / \
        x   y
    - One of x or y must exist for it to be an LR case,
      since then cs.bal_(dir_pc) = cs.bal_r = 1.
  Layout in the L (or RL) case:
          p
         / \
        c   cs
           /  \
         gcs   a
         / \
        x   y
    - One of x or y must exist for it to be an RL case,
      since then cs.bal_(dir_pc) = cs.bal_l = 1.
  (Note: in the double rotation case, cs is actually gcs in the second rotation,
  since the first rotation swaps their positions.)
  */
  850. void AVL::fixImbalance(MPCTIO &tio, yield_t &yield, Duoram<Node>::Flat &A,
  851. Duoram<Node>::OblivIndex<RegXS, 1> oidx, RegXS oidx_oldptrs, RegXS ptr,
  852. RegXS nodeptrs, RegBS new_p_bal_l, RegBS new_p_bal_r, RegBS &bal_upd,
  853. RegBS c_prime, RegXS cs_ptr, RegBS imb, RegBS &F_ri,
  854. avl_del_return &ret_struct) {
  855. bool player0 = tio.player()==0;
  856. RegBS s0, s1;
  857. s1.set(tio.player()==1);
  858. RegXS old_cs_ptr;
  859. Node cs_node;
  860. #ifdef OPT_ON
  861. nbits_t width = ceil(log2(cur_max_index+1));
  862. typename Duoram<Node>::template OblivIndex<RegXS,1> oidx_cs(tio, yield, cs_ptr, width);
  863. cs_node = A[oidx_cs];
  864. old_cs_ptr = cs_node.pointers;
  865. #else
  866. cs_node = A[cs_ptr];
  867. #endif
  868. //dirpc = dir_pc = dpc = c_prime
  869. RegBS cs_bal_l, cs_bal_r, cs_bal_dpc, cs_bal_ndpc, F_dr, not_c_prime;
  870. RegXS gcs_ptr, cs_left, cs_right, cs_dpc, cs_ndpc, null;
  871. // child's sibling node's balances in dir_pc (dpc), and not_dir_pc (ndpc)
  872. cs_bal_l = getLeftBal(cs_node.pointers);
  873. cs_bal_r = getRightBal(cs_node.pointers);
  874. cs_left = getAVLLeftPtr(cs_node.pointers);
  875. cs_right = getAVLRightPtr(cs_node.pointers);
  876. run_coroutines(tio, [&tio, &cs_bal_dpc, c_prime, cs_bal_l, cs_bal_r](yield_t &yield)
  877. { mpc_select(tio, yield, cs_bal_dpc, c_prime, cs_bal_l, cs_bal_r);},
  878. [&tio, &cs_bal_ndpc, c_prime, cs_bal_r, cs_bal_l](yield_t &yield)
  879. { mpc_select(tio, yield, cs_bal_ndpc, c_prime, cs_bal_r, cs_bal_l);},
  880. [&tio, &cs_dpc, c_prime, cs_left, cs_right](yield_t &yield)
  881. { mpc_select(tio, yield, cs_dpc, c_prime, cs_left, cs_right);},
  882. [&tio, &cs_ndpc, c_prime, cs_right, cs_left](yield_t &yield)
  883. { mpc_select(tio, yield, cs_ndpc, c_prime, cs_right, cs_left);});
  884. // We need to double rotate (LR or RL case) if cs_bal_dpc is 1
  885. run_coroutines(tio, [&tio, &F_dr, imb, cs_bal_dpc] (yield_t &yield)
  886. { mpc_and(tio, yield, F_dr, imb, cs_bal_dpc);},
  887. [&tio, &gcs_ptr, cs_bal_dpc, cs_ndpc, cs_dpc](yield_t &yield)
  888. { mpc_select(tio, yield, gcs_ptr, cs_bal_dpc, cs_ndpc, cs_dpc, AVL_PTR_SIZE);});
  889. Node gcs_node;
  890. RegXS old_gcs_ptr;
  891. #ifdef OPT_ON
  892. typename Duoram<Node>::template OblivIndex<RegXS,1> oidx_gcs(tio, yield, gcs_ptr, width);
  893. gcs_node = A[oidx_gcs];
  894. old_gcs_ptr = gcs_node.pointers;
  895. #else
  896. gcs_node = A[gcs_ptr];
  897. #endif
  898. RegBS gcs_bal_l = getLeftBal(gcs_node.pointers);
  899. RegBS gcs_bal_r = getRightBal(gcs_node.pointers);
  900. not_c_prime = c_prime;
  901. if(player0) {
  902. not_c_prime^=1;
  903. }
  904. // First rotation: cs->gcs link
  905. rotate(tio, yield, nodeptrs, cs_ptr, cs_node.pointers, gcs_ptr,
  906. gcs_node.pointers, not_c_prime, c_prime, F_dr, s0);
  907. // If F_dr, we did first rotation. Then cs and gcs need to swap before the second rotate.
  908. RegXS new_cs_pointers, new_cs, new_ptr;
  909. run_coroutines(tio, [&tio, &new_cs_pointers, F_dr, cs_node, gcs_node](yield_t &yield)
  910. { mpc_select(tio, yield, new_cs_pointers, F_dr, cs_node.pointers, gcs_node.pointers);},
  911. [&tio, &new_cs, F_dr, cs_ptr, gcs_ptr](yield_t &yield)
  912. { mpc_select(tio, yield, new_cs, F_dr, cs_ptr, gcs_ptr, AVL_PTR_SIZE);},
  913. [&tio, &new_ptr, F_dr, cs_ptr, gcs_ptr](yield_t &yield)
  914. { mpc_select(tio, yield, new_ptr, F_dr, cs_ptr, gcs_ptr);});
  915. // Second rotation: p->cs link
  916. // Since we don't have access to gp node here we just send a null and s0
  917. // for gp_pointers and dir_gpp. Instead this pointer fix is handled by F_r
  918. // and ret_struct.ret_ptr.
  919. rotate(tio, yield, null, ptr, nodeptrs, new_cs,
  920. new_cs_pointers, s0, not_c_prime, imb, s1);
  921. // If imb (we do some rotation), then update F_r, and ret_ptr, to
  922. // fix the gp->p link (There are F_r clauses later, but they are mutually
  923. // exclusive events. They will never trigger together.)
  924. std::vector<coro_t> coroutines;
  925. coroutines.emplace_back([&tio, &F_ri, imb, s0, s1](yield_t &yield) {
  926. mpc_select(tio, yield, F_ri, imb, s0, s1);
  927. });
  928. coroutines.emplace_back([&tio, &ret_struct, imb, new_ptr](yield_t &yield) {
  929. mpc_select(tio, yield, ret_struct.ret_ptr, imb, ret_struct.ret_ptr, new_ptr);
  930. });
  931. // Write back new_cs_pointers correctly to (cs_node/gcs_node).pointers
  932. // and then balance the nodes
  933. coroutines.emplace_back([&tio, &cs_node, F_dr, new_cs_pointers](yield_t &yield) {
  934. mpc_select(tio, yield, cs_node.pointers, F_dr, new_cs_pointers, cs_node.pointers);
  935. });
  936. coroutines.emplace_back([&tio, &gcs_node, F_dr, new_cs_pointers](yield_t &yield) {
  937. mpc_select(tio, yield, gcs_node.pointers, F_dr, gcs_node.pointers, new_cs_pointers);
  938. });
  939. run_coroutines(tio, coroutines);
  940. coroutines.clear();
  941. /*
  942. Update balances based on imbalance and type of rotations that happen.
  943. In the case of an imbalance, updateBalance() sets bal_l and bal_r of p to 0.
  944. */
  945. RegBS IC1, IC2, IC3; // Imbalance Case 1, 2 or 3
  946. RegBS cs_zero_bal = cs_bal_dpc ^ cs_bal_ndpc;
  947. if(player0) {
  948. cs_zero_bal^=1;
  949. }
  950. run_coroutines(tio, [&tio, &IC1, imb, cs_bal_ndpc] (yield_t &yield) {
  951. // IC1 = Single rotation (L/R). L/R = dpc
  952. mpc_and(tio, yield, IC1, imb, cs_bal_ndpc);
  953. },
  954. // IC2 = Single rotation (L/R). L/R = dpc
  955. [&tio, &IC2, imb, cs_zero_bal](yield_t &yield) {
  956. mpc_and(tio, yield, IC2, imb, cs_zero_bal);
  957. },
  958. [&tio, &IC3, imb, cs_bal_dpc](yield_t &yield) {
  959. // IC3 = Double rotation (LR/RL). 1st rotate direction = ndpc, 2nd direction = dpc
  960. mpc_and(tio, yield, IC3, imb, cs_bal_dpc);
  961. });
  962. RegBS p_bal_ndpc, p_bal_dpc;
  963. RegBS ndpc_is_l, ndpc_is_r, dpc_is_l, dpc_is_r;
  964. RegBS IC2_n_cprime;
  965. // First flags to check dpc = L/R, and similarly ndpc = L/R
  966. coroutines.emplace_back([&tio, &ndpc_is_l, c_prime, imb] (yield_t &yield)
  967. { mpc_and(tio, yield, ndpc_is_l, imb, c_prime);});
  968. coroutines.emplace_back([&tio, &ndpc_is_r, imb, not_c_prime](yield_t &yield)
  969. { mpc_and(tio, yield, ndpc_is_r, imb, not_c_prime);});
  970. coroutines.emplace_back([&tio, &dpc_is_l, imb, not_c_prime](yield_t &yield)
  971. { mpc_and(tio, yield, dpc_is_l, imb, not_c_prime);});
  972. coroutines.emplace_back([&tio, &dpc_is_r, imb, c_prime](yield_t &yield)
  973. { mpc_and(tio, yield, dpc_is_r, imb, c_prime);});
  974. run_coroutines(tio, coroutines);
  975. coroutines.clear();
  976. mpc_and(tio, yield, IC2_n_cprime, IC2, c_prime);
  977. mpc_select(tio, yield, p_bal_ndpc, ndpc_is_r, new_p_bal_l, new_p_bal_r);
  978. mpc_select(tio, yield, p_bal_dpc, dpc_is_r, new_p_bal_l, new_p_bal_r);
  979. /*
  980. run_coroutines(tio, [&tio, &IC2, imb] (yield_t &yield)
  981. { mpc_and(tio, yield, IC2, imb, IC2);},
  982. [&tio, &cs_bal_dpc, imb, s0](yield_t &yield)
  983. { // IC1, IC2, IC3: CS.bal = 0 0
  984. mpc_select(tio, yield, cs_bal_dpc, imb, cs_bal_dpc, s0);},
  985. [&tio, &cs_bal_ndpc, c_prime, imb, s0](yield_t &yield) {
  986. mpc_select(tio, yield, cs_bal_ndpc, imb, cs_bal_ndpc, s0);});
  987. */
  988. mpc_select(tio, yield, cs_bal_ndpc, IC1, cs_bal_ndpc, s0);
  989. mpc_select(tio, yield, cs_bal_dpc, IC2, cs_bal_dpc, s1);
  990. mpc_select(tio, yield, cs_bal_dpc, IC3, cs_bal_dpc, s0);
  991. mpc_select(tio, yield, p_bal_ndpc, IC2, p_bal_ndpc, s1);
  992. /* IC3 has 3 subcases:
  993. IC3_S1: gcs_bal_dpc = 0, gcs_bal_ndpc = 1
  994. IC3_S2: gcs_bal_dpc = 1, gc_bal_ndpc = 0
  995. IC3_S3: gcs_bal_dpc = 0, gcs_bal_ndpc = 0
  996. IC3_S1: p_dpc <- 1
  997. cs_dpc <- 0
  998. (gcs_bal stays same)
  999. IC3_S2: Swap cs_dpc and cs_ndpc (1 0 -> - 1).
  1000. cs_dpc <- 0, cs_ndpc <- 1
  1001. gcs_bal_dpc <- 0
  1002. IC3_S3: cs_dpc <- 0
  1003. gcs_bal stays same
  1004. */
  1005. RegBS IC3_S1, IC3_S2, IC3_S3, gcs_balanced, gcs_bal_dpc, gcs_bal_ndpc;
  1006. mpc_select(tio, yield, gcs_bal_dpc, c_prime, gcs_bal_l, gcs_bal_r);
  1007. mpc_select(tio, yield, gcs_bal_ndpc, c_prime, gcs_bal_r, gcs_bal_l);
  1008. mpc_and(tio, yield, IC3_S1, IC3, gcs_bal_ndpc);
  1009. mpc_and(tio, yield, IC3_S2, IC3, gcs_bal_dpc);
  1010. gcs_balanced = gcs_bal_dpc ^ gcs_bal_ndpc;
  1011. if(player0) {
  1012. gcs_balanced^=1;
  1013. }
  1014. mpc_and(tio, yield, IC3_S3, IC3, gcs_balanced);
  1015. //mpc_select(tio, yield, gcs_bal_l, IC3, gcs_bal_l, s0);
  1016. //mpc_select(tio, yield, gcs_bal_r, IC3, gcs_bal_r, s0);
  1017. RegBS imb_n_cprime;
  1018. mpc_and(tio, yield, imb_n_cprime, imb, c_prime);
  1019. mpc_select(tio, yield, p_bal_dpc, IC3_S1, p_bal_dpc, s1);
  1020. mpc_select(tio, yield, cs_bal_dpc, IC3, cs_bal_dpc, s0);
  1021. mpc_select(tio, yield, cs_bal_ndpc, IC3_S2, cs_bal_ndpc, s1);
  1022. mpc_select(tio, yield, gcs_bal_dpc, IC3_S2, gcs_bal_dpc, s0);
  1023. // Updating balance bits of p and cs.
  1024. // Write back updated balance bits
  1025. // Updating gcs_bal_l/r
  1026. run_coroutines(tio, [&tio, &gcs_bal_r, dpc_is_r, gcs_bal_dpc] (yield_t &yield)
  1027. { mpc_select(tio, yield, gcs_bal_r, dpc_is_r, gcs_bal_r, gcs_bal_dpc);},
  1028. [&tio, &gcs_bal_l, dpc_is_l, gcs_bal_dpc](yield_t &yield)
  1029. { mpc_select(tio, yield, gcs_bal_l, dpc_is_l, gcs_bal_l, gcs_bal_dpc);});
  1030. // Updating cs_bal_l/r (cs_bal_dpc effected by IC3, cs_bal_ndpc effected by IC1,2)
  1031. run_coroutines(tio, [&tio, &cs_bal_r, dpc_is_r, cs_bal_dpc] (yield_t &yield)
  1032. { mpc_select(tio, yield, cs_bal_r, dpc_is_r, cs_bal_r, cs_bal_dpc);},
  1033. [&tio, &cs_bal_l, dpc_is_l, cs_bal_dpc](yield_t &yield)
  1034. { mpc_select(tio, yield, cs_bal_l, dpc_is_l, cs_bal_l, cs_bal_dpc);});
  1035. run_coroutines(tio, [&tio, &cs_bal_r, ndpc_is_r, cs_bal_ndpc] (yield_t &yield)
  1036. { mpc_select(tio, yield, cs_bal_r, ndpc_is_r, cs_bal_r, cs_bal_ndpc);},
  1037. [&tio, &cs_bal_l, ndpc_is_l, cs_bal_ndpc](yield_t &yield)
  1038. { mpc_select(tio, yield, cs_bal_l, ndpc_is_l, cs_bal_l, cs_bal_ndpc);});
  1039. // Updating new_p_bal_l/r (p_bal_ndpc effected by IC2)
  1040. run_coroutines(tio, [&tio, &new_p_bal_r, ndpc_is_r, p_bal_ndpc] (yield_t &yield)
  1041. { mpc_select(tio, yield, new_p_bal_r, ndpc_is_r, new_p_bal_r, p_bal_ndpc);},
  1042. [&tio, &new_p_bal_l, ndpc_is_l, p_bal_ndpc](yield_t &yield)
  1043. { mpc_select(tio, yield, new_p_bal_l, ndpc_is_l, new_p_bal_l, p_bal_ndpc);});
  1044. // IC2: p.bal_ndpc = 1, cs.bal_dpc = 1
  1045. // (IC2 & not_c_prime)
  1046. /*
  1047. cs_bal_dpc^=IC2;
  1048. p_bal_ndpc^=IC2;
  1049. coroutines.emplace_back([&tio, &new_p_bal_l, IC2_ndpc_l, p_bal_ndpc](yield_t &yield)
  1050. { mpc_select(tio, yield, new_p_bal_l, IC2_ndpc_l, new_p_bal_l, p_bal_ndpc);});
  1051. coroutines.emplace_back([&tio, &new_p_bal_r, IC2_ndpc_r, p_bal_ndpc](yield_t &yield)
  1052. { mpc_select(tio, yield, new_p_bal_r, IC2_ndpc_r, new_p_bal_r, p_bal_ndpc);});
  1053. coroutines.emplace_back([&tio, &cs_bal_l, IC2_dpc_l, cs_bal_dpc](yield_t &yield)
  1054. { mpc_select(tio, yield, cs_bal_l, IC2_dpc_l, cs_bal_l, cs_bal_dpc);});
  1055. coroutines.emplace_back([&tio, &cs_bal_r, IC2_dpc_r, cs_bal_dpc](yield_t &yield)
  1056. { mpc_select(tio, yield, cs_bal_r, IC2_dpc_r, cs_bal_r, cs_bal_dpc);});
  1057. coroutines.emplace_back([&tio, &bal_upd, IC2, s0](yield_t &yield)
  1058. {
  1059. // In the IC2 case bal_upd = 0 (The rotation doesn't end up
  1060. // decreasing height of this subtree.
  1061. mpc_select(tio, yield, bal_upd, IC2, bal_upd, s0);});
  1062. run_coroutines(tio, coroutines);
  1063. coroutines.clear();
  1064. */
  1065. // In the IC2 case bal_upd = 0 (The rotation doesn't end up
  1066. // decreasing height of this subtree.
  1067. mpc_select(tio, yield, bal_upd, IC2, bal_upd, s0);
  1068. /*
  1069. // IC3:
  1070. // To set balance in this case we need to know if gcs.dpc child exists
  1071. // and similarly if gcs.ndpc child exitst.
  1072. // if(gcs.ndpc child exists): cs.bal_ndpc = 1
  1073. // if(gcs.dpc child exists): p.bal_dpc = 1
  1074. RegBS gcs_dpc_exists, gcs_ndpc_exists;
  1075. RegXS gcs_l = getAVLLeftPtr(gcs_node.pointers);
  1076. RegXS gcs_r = getAVLRightPtr(gcs_node.pointers);
  1077. RegBS gcs_bal_l = getLeftBal(gcs_node.pointers);
  1078. RegBS gcs_bal_r = getRightBal(gcs_node.pointers);
  1079. RegXS gcs_dpc, gcs_ndpc;
  1080. run_coroutines(tio, [&tio, &gcs_dpc, c_prime, gcs_l, gcs_r] (yield_t &yield)
  1081. { mpc_select(tio, yield, gcs_dpc, c_prime, gcs_l, gcs_r);},
  1082. [&tio, &gcs_ndpc, not_c_prime, gcs_l, gcs_r] (yield_t &yield)
  1083. { mpc_select(tio, yield, gcs_ndpc, not_c_prime, gcs_l, gcs_r);});
  1084. CDPF cdpf = tio.cdpf(yield);
  1085. run_coroutines(tio, [&tio, &gcs_dpc_exists, gcs_dpc, &cdpf](yield_t &yield)
  1086. { gcs_dpc_exists = cdpf.is_zero(tio, yield, gcs_dpc, tio.aes_ops());},
  1087. [&tio, &gcs_ndpc_exists, gcs_ndpc, &cdpf](yield_t &yield)
  1088. { gcs_ndpc_exists = cdpf.is_zero(tio, yield, gcs_ndpc, tio.aes_ops());});
  1089. cs_bal_ndpc^=IC3;
  1090. RegBS IC3_ndpc_l, IC3_ndpc_r, IC3_dpc_l, IC3_dpc_r;
  1091. run_coroutines(tio, [&tio, &IC3_ndpc_l, IC3, c_prime](yield_t &yield)
  1092. { mpc_and(tio, yield, IC3_ndpc_l, IC3, c_prime);},
  1093. [&tio, &IC3_ndpc_r, IC3, not_c_prime](yield_t &yield)
  1094. { mpc_and(tio, yield, IC3_ndpc_r, IC3, not_c_prime);},
  1095. [&tio, &IC3_dpc_l, IC3, not_c_prime](yield_t &yield)
  1096. { mpc_and(tio, yield, IC3_dpc_l, IC3, not_c_prime);},
  1097. [&tio, &IC3_dpc_r, IC3, c_prime](yield_t &yield)
  1098. { mpc_and(tio, yield, IC3_dpc_r, IC3, c_prime);});
  1099. RegBS f0, f1, f2, f3;
  1100. run_coroutines(tio, [&tio, &f0, IC3_dpc_l, gcs_dpc_exists] (yield_t &yield)
  1101. { mpc_and(tio, yield, f0, IC3_dpc_l, gcs_dpc_exists);},
  1102. [&tio, &f1, IC3_dpc_r, gcs_dpc_exists] (yield_t &yield)
  1103. { mpc_and(tio, yield, f1, IC3_dpc_r, gcs_dpc_exists);},
  1104. [&tio, &f2, IC3_ndpc_l, gcs_ndpc_exists] (yield_t &yield)
  1105. { mpc_and(tio, yield, f2, IC3_ndpc_l, gcs_ndpc_exists);},
  1106. [&tio, &f3, IC3_ndpc_r, gcs_ndpc_exists] (yield_t &yield)
  1107. { mpc_and(tio, yield, f3, IC3_ndpc_r, gcs_ndpc_exists);});
  1108. coroutines.emplace_back([&tio, &new_p_bal_l, f0, IC3](yield_t &yield) {
  1109. mpc_select(tio, yield, new_p_bal_l, f0, new_p_bal_l, IC3);});
  1110. coroutines.emplace_back([&tio, &new_p_bal_r, f1, IC3](yield_t &yield) {
  1111. mpc_select(tio, yield, new_p_bal_r, f1, new_p_bal_r, IC3);});
  1112. coroutines.emplace_back([&tio, &cs_bal_l, f2, IC3](yield_t &yield) {
  1113. mpc_select(tio, yield, cs_bal_l, f2, cs_bal_l, IC3);});
  1114. coroutines.emplace_back([&tio, &cs_bal_r, f3, IC3](yield_t &yield) {
  1115. mpc_select(tio, yield, cs_bal_r, f3, cs_bal_r, IC3);});
  1116. // In IC3 gcs.bal = 0 0
  1117. coroutines.emplace_back([&tio, &gcs_bal_l, IC3, s0](yield_t &yield) {
  1118. mpc_select(tio, yield, gcs_bal_l, IC3, gcs_bal_l, s0);});
  1119. coroutines.emplace_back([&tio, &gcs_bal_r, IC3, s0](yield_t &yield) {
  1120. mpc_select(tio, yield, gcs_bal_r, IC3, gcs_bal_r, s0);});
  1121. run_coroutines(tio, coroutines);
  1122. */
  1123. // Write back <cs_bal_dpc, cs_bal_ndpc> and <gcs_bal_l, gcs_bal_r>
  1124. setLeftBal(gcs_node.pointers, gcs_bal_l);
  1125. setRightBal(gcs_node.pointers, gcs_bal_r);
  1126. setLeftBal(cs_node.pointers, cs_bal_l);
  1127. setRightBal(cs_node.pointers, cs_bal_r);
  1128. setLeftBal(nodeptrs, new_p_bal_l);
  1129. setRightBal(nodeptrs, new_p_bal_r);
  1130. // Write back updated pointers correctly accounting for rotations
  1131. #ifdef OPT_ON
  1132. coroutines.emplace_back(
  1133. [&tio, &A, &oidx_cs, &cs_node, old_cs_ptr] (yield_t &yield) {
  1134. auto acont = A.context(yield);
  1135. (acont[oidx_cs].NODE_POINTERS)+= (cs_node.pointers - old_cs_ptr);});
  1136. coroutines.emplace_back(
  1137. [&tio, &A, &oidx_gcs, &gcs_node, old_gcs_ptr] (yield_t &yield) {
  1138. auto acont = A.context(yield);
  1139. (acont[oidx_gcs].NODE_POINTERS)+= (gcs_node.pointers - old_gcs_ptr);});
  1140. coroutines.emplace_back(
  1141. [&tio, &A, &oidx, nodeptrs, oidx_oldptrs] (yield_t &yield) {
  1142. auto acont = A.context(yield);
  1143. (acont[oidx].NODE_POINTERS)+=(nodeptrs - oidx_oldptrs);});
  1144. run_coroutines(tio, coroutines);
  1145. coroutines.clear();
  1146. #else
  1147. A[cs_ptr].NODE_POINTERS = cs_node.pointers;
  1148. A[gcs_ptr].NODE_POINTERS = gcs_node.pointers;
  1149. A[ptr].NODE_POINTERS = nodeptrs;
  1150. #endif
  1151. }
  1152. /* Update the return structure
  1153. F_dh = Delete Here flag,
  1154. F_sf = successor found (no more left children while trying to find successor)
  1155. F_rs is a subflag for F_r (update children pointers with ret ptr)
  1156. F_rs: Flag for updating the correct child pointer of this node
  1157. This happens if F_r is set in ret_struct. F_r indicates if we need
  1158. to update a child pointer at this level by skipping the current
  1159. child in the direction of traversal. We do this in two cases:
  1160. i) F_d & (!F_2) : If we delete here, and this node does not have
  1161. 2 children (;i.e., we are not in the finding successor case)
  1162. ii) F_sf: Found the successor (no more left children while
  1163. traversing to find successor)
  1164. In cases i and ii we skip the next node, and make the current node
  1165. point to the node after the next node.
  1166. The third case for F_r:
  1167. iii) We did rotation(s) at the lower level, changing the child in
  1168. that position. So we update it to the correct node in that
  1169. position now.
  1170. Whether skip happens or just update happens is handled by how
  1171. ret_struct.ret_ptr is set.
  1172. */
  1173. void AVL::updateRetStruct(MPCTIO &tio, yield_t &yield, RegXS ptr, RegBS F_2, RegBS F_c2,
  1174. RegBS F_c4, RegBS lf, RegBS F_ri, RegBS &found, RegBS &bal_upd,
  1175. avl_del_return &ret_struct) {
  1176. bool player0 = tio.player()==0;
  1177. RegBS s0, s1;
  1178. s1.set(tio.player()==1);
  1179. RegBS F_dh, F_sf, F_rs;
  1180. RegBS not_found = found;
  1181. if(player0) {
  1182. not_found^=1;
  1183. }
  1184. run_coroutines(tio, [&tio, &ret_struct, F_c2](yield_t &yield)
  1185. { mpc_or(tio, yield, ret_struct.F_ss, ret_struct.F_ss, F_c2);},
  1186. [&tio, &F_dh, lf, not_found] (yield_t &yield)
  1187. { mpc_and(tio, yield, F_dh, lf, not_found);});
  1188. //[&tio, &ret_struct, F_dh, ptr] (yield_t &yield)
  1189. mpc_select(tio, yield, ret_struct.N_d, F_dh, ret_struct.N_d, ptr);
  1190. // F_sf = Successor found = F_c4 = Finding successor & no more left child
  1191. F_sf = F_c4;
  1192. if(player0)
  1193. F_2^=1;
  1194. // If we have to i) delete here, and it doesn't have two children
  1195. // we have to update child pointer in parent with the returned pointer
  1196. mpc_and(tio, yield, F_rs, F_dh, F_2);
  1197. // ii) if we found successor here
  1198. run_coroutines(tio, [&tio, &F_rs, F_sf](yield_t &yield)
  1199. { mpc_or(tio, yield, F_rs, F_rs, F_sf);},
  1200. [&tio, &ret_struct, F_sf, ptr] (yield_t &yield)
  1201. { mpc_select(tio, yield, ret_struct.N_s, F_sf, ret_struct.N_s, ptr);});
  1202. // F_rs and F_ri will never trigger together. So the line below
  1203. // set ret_ptr to the correct pointer to handle either case
  1204. // If neither F_rs nor F_ri, we set the ret_ptr to current ptr.
  1205. RegBS F_nr;
  1206. mpc_or(tio, yield, F_nr, F_rs, F_ri);
  1207. // F_nr = F_rs || F_ri
  1208. ret_struct.F_r = F_nr;
  1209. if(player0) {
  1210. F_nr^=1;
  1211. }
  1212. // F_nr = !(F_rs || F_ri)
  1213. run_coroutines(tio, [&tio, &ret_struct, F_nr, ptr](yield_t &yield)
  1214. { mpc_select(tio, yield, ret_struct.ret_ptr, F_nr, ret_struct.ret_ptr, ptr);},
  1215. [&tio, &bal_upd, F_rs, s1](yield_t &yield)
  1216. { // If F_rs, we skipped a node, so update bal_upd to 1
  1217. mpc_select(tio, yield, bal_upd, F_rs, bal_upd, s1);});
  1218. }
/* One level of the recursive oblivious deletion.

   Every call reads exactly one node (at ptr) and recurses exactly once with
   TTL-1 until TTL reaches 0. The number of levels visited therefore does not
   depend on where (or whether) del_key is found.

   Parameters:
     ptr            - shared pointer (ORAM index) of the node at this level
     del_key        - shared key to delete
     A              - flat Duoram view holding the tree nodes
     found          - shared bit: del_key was already matched at a higher level
     find_successor - shared bit: we are currently descending to find the
                      in-order successor of the matched node
     TTL            - number of levels still to visit
     ret_struct     - results accumulated across levels (N_d, N_s, F_ss, F_r,
                      ret_ptr). Consumed by updateChildPointers() after the
                      recursive call, and updated by fixImbalance() and
                      updateRetStruct() on the way back up.

   Returns {key_found, bal_upd}:
     key_found - public result of reconstructing `found` at TTL == 0. If it is
                 false, every level returns right after its recursive call
                 without writing anything back.
     bal_upd   - shared bit handed to the parent's updateBalanceDel()
                 (presumably "this subtree's height decreased" -- confirm
                 against updateBalanceDel).
*/
std::tuple<bool, RegBS> AVL::del(MPCTIO &tio, yield_t &yield, RegXS ptr, RegAS del_key,
        Duoram<Node>::Flat &A, RegBS found, RegBS find_successor, int TTL,
        avl_del_return &ret_struct) {
    bool player0 = tio.player()==0;
    if(TTL==0) {
        //Reconstruct and return found
        // (bottom of the recursion; bal_upd returned as shares of 0)
        bool success = reconstruct_RegBS(tio, yield, found);
        RegBS zero;
        return {success, zero};
    } else {
        Node node;
        RegXS oldptrs;
        // This OblivIndex creation is not required if !OPT_ON, but for convenience we leave it in
        // so that fixImbalance has an oidx to be supplied when in the !OPT_ON setting.
        nbits_t width = ceil(log2(cur_max_index+1));
        typename Duoram<Node>::template OblivIndex<RegXS,1> oidx(tio, yield, ptr, width);
#ifdef OPT_ON
        node = A[oidx];
        // Keep the pointers as read so fixImbalance can write back only the
        // difference (nodeptrs - oldptrs).
        oldptrs = node.pointers;
#else
        node = A[ptr];
#endif
        // Compare key
        CDPF cdpf = tio.cdpf(yield);
        auto [lt, eq, gt] = cdpf.compare(tio, yield, del_key - node.key, tio.aes_ops());
        // c is the direction bit for next_ptr
        // (c=0: go left or c=1: go right)
        RegBS c = gt;
        // lf = local found. We found the key to delete in this level.
        RegBS lf = eq;
        // Select the next ptr
        RegXS left = getAVLLeftPtr(node.pointers);
        RegXS right = getAVLRightPtr(node.pointers);
        size_t &aes_ops = tio.aes_ops();
        RegBS l0, r0;
        // Check if left and right children are 0, and compute F_0, F_1, F_2
        // l0 (resp. r0) is 1 iff the left (resp. right) pointer is 0, i.e.
        // there is no child on that side.
        run_coroutines(tio, [&tio, &l0, left, &aes_ops, &cdpf](yield_t &yield)
            { l0 = cdpf.is_zero(tio, yield, left, aes_ops);},
            [&tio, &r0, right, &aes_ops, &cdpf](yield_t &yield)
            { r0 = cdpf.is_zero(tio, yield, right, aes_ops);});
        RegBS F_0, F_1, F_2;
        RegBS F_c1, F_c2, F_c3, F_c4;
        RegXS next_ptr, cs_ptr;
        RegBS c_prime;
        // F_1 = l0 XOR r0: exactly one child is present
        F_1 = l0 ^ r0;
        // F_0 = l0 & r0: no children (leaf)
        // Case 1: F_c1 = lf & F_1
        run_coroutines(tio, [&tio, &F_0, l0, r0](yield_t &yield)
            { mpc_and(tio, yield, F_0, l0, r0);},
            [&tio, &F_c1, lf, F_1](yield_t &yield)
            { mpc_and(tio, yield, F_c1, lf, F_1);});
        // F_2 = !(F_0 + F_1) (Only 1 of F_0, F_1, and F_2 can be true)
        // Since F_0 and F_1 are mutually exclusive, XOR equals OR here;
        // the negation is a local flip of player 0's share.
        F_2 = F_0 ^ F_1;
        if(player0)
            F_2^=1;
        // s1: shares of 1 bit, s0: shares of 0 bit
        RegBS s1, s0;
        s1.set(tio.player()==1);
        // We set next ptr based on c, but we need to handle four
        // edge cases where we do not pick next_ptr by just the comparison result.
        // Note: mpc_select(dest, f, a, b) sets dest = f ? b : a.
        // Case 1: found the node here (lf), and node has only one child.
        // Then we iterate down the only child.
        // Set c_prime for Case 1: c_prime = F_c1 ? l0 : c
        // (l0 = 1 means the only child is on the right).
        run_coroutines(tio, [&tio, &c_prime, F_c1, c, l0](yield_t &yield)
            { mpc_select(tio, yield, c_prime, F_c1, c, l0);},
            [&tio, &F_c2, lf, F_2](yield_t &yield)
            { mpc_and(tio, yield, F_c2, lf, F_2);});
        // Case 2: found the node here (lf) and node has both children (F_2)
        // In find successor case, so we find inorder successor for node to be deleted
        // (inorder successor = go right and then find leftmost child.)
        // Case 3: finding successor (find_successor) and node has both children (F_2)
        // Go left.
        run_coroutines(tio, [&tio, &c_prime, F_c2, s1](yield_t &yield)
            { mpc_select(tio, yield, c_prime, F_c2, c_prime, s1);},
            [&tio, &F_c3, find_successor, F_2](yield_t &yield)
            { mpc_and(tio, yield, F_c3, find_successor, F_2);});
        // Case 4: finding successor (find_successor) and node has no more left children (l0)
        // This is the successor node then.
        // Go right (since no more left)
        run_coroutines(tio, [&tio, &c_prime, F_c3, s0](yield_t &yield)
            { mpc_select(tio, yield, c_prime, F_c3, c_prime, s0);},
            [&tio, &F_c4, find_successor, l0](yield_t &yield)
            { mpc_and(tio, yield, F_c4, find_successor, l0);});
        RegBS found_prime, find_successor_prime;
        // F_c4 implies l0 = 1, so selecting l0 here means "go right".
        mpc_select(tio, yield, c_prime, F_c4, c_prime, l0);
        // Set next_ptr
        mpc_select(tio, yield, next_ptr, c_prime, left, right, AVL_PTR_SIZE);
        // cs_ptr: child's sibling pointer (the child we do not descend into)
        run_coroutines(tio, [&tio, &cs_ptr, c_prime, right, left](yield_t &yield)
            { mpc_select(tio, yield, cs_ptr, c_prime, right, left, AVL_PTR_SIZE);},
            [&tio, &found_prime, found, lf](yield_t &yield)
            { mpc_or(tio, yield, found_prime, found, lf);},
            // If in Case 2, set find_successor. We are now finding successor
            [&tio, &find_successor_prime, find_successor, F_c2](yield_t &yield)
            { mpc_or(tio, yield, find_successor_prime, find_successor, F_c2);});
        // If in Case 4. Successor found here already. Toggle find_successor off
        // (F_c4 implies find_successor = 1, so the XOR clears it.)
        find_successor_prime=find_successor_prime^F_c4;
        TTL-=1;
        auto [key_found, bal_upd] = del(tio, yield, next_ptr, del_key, A, found_prime, find_successor_prime, TTL, ret_struct);
        // If we didn't find the key, we can end here.
        if(!key_found) {
            return {0, s0};
        }
        // Resolve the F_r flag returned from below: update the child pointer
        // in the direction of traversal (c_prime) with ret_struct.ret_ptr.
        updateChildPointers(tio, yield, left, right, c_prime, ret_struct);
        setAVLLeftPtr(node.pointers, left);
        setAVLRightPtr(node.pointers, right);
        // Delay storing pointers back until balance updates are done as well.
        // Since we resolved the F_r flag returned with updateChildPointers(),
        // we set it back to 0.
        ret_struct.F_r = s0;
        RegBS p_bal_l, p_bal_r;
        p_bal_l = getLeftBal(node.pointers);
        p_bal_r = getRightBal(node.pointers);
#ifdef DEBUG
        size_t rec_key = reconstruct_RegAS(tio, yield, node.key);
        bool rec_bal_upd = reconstruct_RegBS(tio, yield, bal_upd);
        printf("current_key = %ld, bal_upd (before updateBalanceDel) = %d\n", rec_key, rec_bal_upd);
#endif
        // Recompute this node's balance bits given bal_upd from the child in
        // direction c_prime. imb is presumably the imbalance flag that
        // fixImbalance() repairs below.
        auto [new_p_bal_l, new_p_bal_r, new_bal_upd, imb] =
            updateBalanceDel(tio, yield, p_bal_l, p_bal_r, bal_upd, c_prime);
        bal_upd = new_bal_upd;
#ifdef DEBUG
        bool rec_imb = reconstruct_RegBS(tio, yield, imb);
        bool rec_new_bal_upd = reconstruct_RegBS(tio, yield, new_bal_upd);
        printf("new_bal_upd (after updateBalanceDel) = %d, imb = %d\n", rec_new_bal_upd, rec_imb);
#endif
        // F_ri: subflag for F_r. F_ri = returned flag set to 1 from imbalance fix.
        // fixImbalance also writes this node's pointers back to the ORAM.
        RegBS F_ri;
        fixImbalance(tio, yield, A, oidx, oldptrs, ptr, node.pointers, new_p_bal_l, new_p_bal_r, bal_upd,
            c_prime, cs_ptr, imb, F_ri, ret_struct);
#ifdef DEBUG
        rec_imb = reconstruct_RegBS(tio, yield, imb);
        rec_bal_upd = reconstruct_RegBS(tio, yield, bal_upd);
        printf("imb (after fixImbalance) = %d, bal_upd = %d\n", rec_imb, rec_bal_upd);
#endif
        updateRetStruct(tio, yield, ptr, F_2, F_c2, F_c4, lf, F_ri, found, bal_upd, ret_struct);
#ifdef DEBUG
        rec_bal_upd = reconstruct_RegBS(tio, yield, bal_upd);
        printf("bal_upd (after updateRetStruct) = %d\n", rec_bal_upd);
#endif
        return {key_found, bal_upd};
    }
}
  1363. bool AVL::del(MPCTIO &tio, yield_t &yield, RegAS del_key) {
  1364. if(num_items==0)
  1365. return 0;
  1366. auto A = oram.flat(tio, yield, 0, cur_max_index+1);
  1367. if(num_items==1) {
  1368. //Delete root if root's key = del_key
  1369. Node zero;
  1370. nbits_t width = ceil(log2(cur_max_index+1));
  1371. typename Duoram<Node>::template OblivIndex<RegXS,1> oidx(tio, yield, root, width);
  1372. Node node = A[oidx];
  1373. // Compare key
  1374. CDPF cdpf = tio.cdpf(yield);
  1375. auto [lt, eq, gt] = cdpf.compare(tio, yield, del_key - node.key, tio.aes_ops());
  1376. bool success = reconstruct_RegBS(tio, yield, eq);
  1377. if(success) {
  1378. empty_locations.emplace_back(root);
  1379. A[oidx] = zero;
  1380. num_items--;
  1381. return 1;
  1382. } else {
  1383. return 0;
  1384. }
  1385. } else {
  1386. int TTL = AVL_TTL(num_items);
  1387. // Flags for already found (found) item to delete and find successor (find_successor)
  1388. // if this deletion requires a successor swap
  1389. RegBS found, find_successor;
  1390. avl_del_return ret_struct;
  1391. auto [success, bal_upd] = del(tio, yield, root, del_key, A, found, find_successor, TTL, ret_struct);
  1392. printf ("Success = %d\n", success);
  1393. if(!success){
  1394. return 0;
  1395. }
  1396. else{
  1397. num_items--;
  1398. /*
  1399. printf("In delete's swap portion\n");
  1400. Node rec_del_node = A.reconstruct(A[ret_struct.N_d]);
  1401. Node rec_suc_node = A.reconstruct(A[ret_struct.N_s]);
  1402. printf("del_node key = %ld, suc_node key = %ld\n",
  1403. rec_del_node.key.ashare, rec_suc_node.key.ashare);
  1404. printf("flag_s = %d\n", ret_struct.F_ss.bshare);
  1405. */
  1406. Node del_node, suc_node;
  1407. nbits_t width = ceil(log2(cur_max_index+1));
  1408. typename Duoram<Node>::template OblivIndex<RegXS,2> oidx_nd(tio, yield, ret_struct.N_d, width);
  1409. typename Duoram<Node>::template OblivIndex<RegXS,2> oidx_ns(tio, yield, ret_struct.N_s, width);
  1410. std::vector<coro_t> coroutines;
  1411. #ifdef OPT_ON
  1412. coroutines.emplace_back(
  1413. [&tio, &A, &oidx_nd, &del_node](yield_t &yield) {
  1414. auto acont = A.context(yield);
  1415. del_node = acont[oidx_nd];});
  1416. coroutines.emplace_back(
  1417. [&tio, &A, &oidx_ns, &suc_node](yield_t &yield) {
  1418. auto acont = A.context(yield);
  1419. suc_node = acont[oidx_ns];});
  1420. run_coroutines(tio, coroutines);
  1421. coroutines.clear();
  1422. #else
  1423. del_node = A[ret_struct.N_d];
  1424. suc_node = A[ret_struct.N_s];
  1425. #endif
  1426. RegAS zero_as; RegXS zero_xs;
  1427. // Update root if needed
  1428. mpc_select(tio, yield, root, ret_struct.F_r, root, ret_struct.ret_ptr);
  1429. /*
  1430. bool rec_F_ss = reconstruct_RegBS(tio, yield, ret_struct.F_ss);
  1431. size_t rec_del_key = reconstruct_RegAS(tio, yield, del_node.key);
  1432. size_t rec_suc_key = reconstruct_RegAS(tio, yield, suc_node.key);
  1433. printf("rec_F_ss = %d, del_node.key = %lu, suc_nod.key = %lu\n",
  1434. rec_F_ss, rec_del_key, rec_suc_key);
  1435. */
  1436. RegXS old_del_value;
  1437. RegAS old_del_key;
  1438. #ifdef OPT_ON
  1439. old_del_value = del_node.value;
  1440. old_del_key = del_node.key;
  1441. #endif
  1442. RegXS empty_loc;
  1443. run_coroutines(tio, [&tio, &del_node, ret_struct, suc_node](yield_t &yield)
  1444. { mpc_select(tio, yield, del_node.key, ret_struct.F_ss, del_node.key, suc_node.key);},
  1445. [&tio, &del_node, ret_struct, suc_node] (yield_t &yield)
  1446. { mpc_select(tio, yield, del_node.value, ret_struct.F_ss, del_node.value, suc_node.value);},
  1447. [&tio, &empty_loc, ret_struct](yield_t &yield)
  1448. { mpc_select(tio, yield, empty_loc, ret_struct.F_ss, ret_struct.N_d, ret_struct.N_s);});
  1449. #ifdef OPT_ON
  1450. coroutines.emplace_back(
  1451. [&tio, &A, &oidx_nd, &del_node, old_del_key] (yield_t &yield) {
  1452. auto acont = A.context(yield);
  1453. acont[oidx_nd].NODE_KEY+=(del_node.key - old_del_key);
  1454. });
  1455. coroutines.emplace_back(
  1456. [&tio, &A, &oidx_nd, &del_node, old_del_value] (yield_t &yield) {
  1457. auto acont = A.context(yield);
  1458. acont[oidx_nd].NODE_VALUE+=(del_node.value - old_del_value);
  1459. });
  1460. coroutines.emplace_back(
  1461. [&tio, &A, &oidx_ns, &suc_node] (yield_t &yield) {
  1462. auto acont = A.context(yield);
  1463. acont[oidx_ns].NODE_KEY+=(-suc_node.key);
  1464. });
  1465. coroutines.emplace_back(
  1466. [&tio, &A, &oidx_ns, &suc_node] (yield_t &yield) {
  1467. auto acont = A.context(yield);
  1468. acont[oidx_ns].NODE_VALUE+=(-suc_node.value);
  1469. });
  1470. run_coroutines(tio, coroutines);
  1471. coroutines.clear();
  1472. #else
  1473. A[ret_struct.N_d].NODE_KEY = del_node.key;
  1474. A[ret_struct.N_d].NODE_VALUE = del_node.value;
  1475. A[ret_struct.N_s].NODE_KEY = zero_as;
  1476. A[ret_struct.N_s].NODE_VALUE = zero_xs;
  1477. #endif
  1478. //Add deleted (empty) location into the empty_locations vector for reuse in next insert()
  1479. empty_locations.emplace_back(empty_loc);
  1480. }
  1481. return 1;
  1482. }
  1483. }
  1484. void AVL::initialize(MPCTIO &tio, yield_t &yield, size_t depth) {
  1485. size_t init_size = (size_t(1)<<depth) - 1;
  1486. auto A = oram.flat(tio, yield);
  1487. std::vector<coro_t> coroutines;
  1488. for(size_t i=1; i<=depth; i++) {
  1489. size_t start = size_t(1)<<(i-1);
  1490. size_t gap = size_t(1)<<i;
  1491. size_t current = start;
  1492. for(size_t j=1; j<=(size_t(1)<<(depth-i)); j++) {
  1493. //printf("current = %ld ", current);
  1494. Node node;
  1495. node.key.set(current * tio.player());
  1496. if(i!=1) {
  1497. //Set left and right child pointers and balance bits
  1498. size_t ptr_gap = start/2;
  1499. RegXS lptr, rptr;
  1500. lptr.set(tio.player() * (current-(ptr_gap)));
  1501. rptr.set(tio.player() * (current+(ptr_gap)));
  1502. setAVLLeftPtr(node.pointers, lptr);
  1503. setAVLRightPtr(node.pointers, rptr);
  1504. }
  1505. coroutines.emplace_back(
  1506. [&tio, &A, current, node](yield_t &yield) {
  1507. auto acont = A.context(yield);
  1508. acont[current] = node;
  1509. });
  1510. current+=gap;
  1511. }
  1512. run_coroutines(tio, coroutines);
  1513. coroutines.clear();
  1514. }
  1515. // Set num_items to init_size after they have been initialized;
  1516. num_items = init_size;
  1517. cur_max_index = num_items;
  1518. // Set root correctly
  1519. root.set(tio.player() * size_t(1)<<(depth-1));
  1520. }
  1521. // Now we use the AVL class in various ways. This function is called by
  1522. // online.cpp.
  1523. void avl(MPCIO &mpcio,
  1524. const PRACOptions &opts, char **args)
  1525. {
  1526. size_t depth=4, n_inserts=0, n_deletes=0;
  1527. if (*args) {
  1528. depth = atoi(args[0]);
  1529. n_inserts = atoi(args[1]);
  1530. n_deletes = atoi(args[2]);
  1531. }
  1532. /* The ORAM will be initialized with 2^depth-1 items, but the 0 slot is reserved.
  1533. So we initialize (initial inserts) with 2^depth-2 items.
  1534. The ORAM size is set to 2^depth-1 + n_insert.
  1535. */
  1536. size_t init_size = (size_t(1)<<(depth));
  1537. size_t oram_size = init_size + n_inserts;
  1538. // +1 because init_size does not account for slot at 0.
  1539. MPCTIO tio(mpcio, 0, opts.num_threads);
  1540. run_coroutines(tio, [&tio, &mpcio, depth, oram_size, init_size, n_inserts, n_deletes] (yield_t &yield) {
  1541. //printf("ORAM init_size = %ld, oram_size = %ld\n", init_size, oram_size);
  1542. std::cout << "\n===== SETUP =====\n";
  1543. AVL tree(tio.player(), oram_size);
  1544. tree.initialize(tio, yield, depth);
  1545. //tree.pretty_print(tio, yield);
  1546. tio.sync_lamport();
  1547. Node node;
  1548. mpcio.dump_stats(std::cout);
  1549. std::cout << "\n===== INSERTS =====\n";
  1550. mpcio.reset_stats();
  1551. tio.reset_lamport();
  1552. for(size_t i = 1; i<=n_inserts; i++) {
  1553. newnode(node);
  1554. size_t ikey;
  1555. #ifdef RANDOMIZE
  1556. ikey = (1+(rand()%oram_size));
  1557. #else
  1558. ikey = (i+init_size);
  1559. #endif
  1560. printf("Insert key = %ld\n", ikey);
  1561. node.key.set(ikey * tio.player());
  1562. tree.insert(tio, yield, node);
  1563. #ifdef SANITY_TEST
  1564. tree.pretty_print(tio, yield);
  1565. tree.check_avl(tio, yield);
  1566. #endif
  1567. //tree.print_oram(tio, yield);
  1568. }
  1569. tio.sync_lamport();
  1570. mpcio.dump_stats(std::cout);
  1571. std::cout << "\n===== DELETES =====\n";
  1572. mpcio.reset_stats();
  1573. tio.reset_lamport();
  1574. for(size_t i = 1; i<=n_deletes; i++) {
  1575. RegAS del_key;
  1576. size_t dkey;
  1577. #ifdef RANDOMIZE
  1578. dkey = 1 + (rand()%init_size);
  1579. #else
  1580. dkey = i + init_size;
  1581. #endif
  1582. del_key.set(dkey * tio.player());
  1583. printf("Deletion key = %ld\n", dkey);
  1584. tree.del(tio, yield, del_key);
  1585. #ifdef SANITY_TEST
  1586. tree.pretty_print(tio, yield);
  1587. tree.check_avl(tio, yield);
  1588. #endif
  1589. }
  1590. });
  1591. }
  1592. void avl_tests(MPCIO &mpcio,
  1593. const PRACOptions &opts, char **args)
  1594. {
  1595. // Not taking arguments for tests
  1596. nbits_t depth=4;
  1597. size_t items = (size_t(1)<<depth)-1;
  1598. MPCTIO tio(mpcio, 0, opts.num_threads);
  1599. run_coroutines(tio, [&tio, depth, items] (yield_t &yield) {
  1600. size_t size = size_t(1)<<depth;
  1601. bool player0 = tio.player()==0;
  1602. AVL tree(tio.player(), size);
  1603. // (T1) : Test 1 : L rotation (root modified)
  1604. /*
  1605. Operation:
  1606. 5 7
  1607. \ / \
  1608. 7 ---> 5 9
  1609. \
  1610. 9
  1611. T1 checks:
  1612. - root is 7
  1613. - 5,7,9 in correct positions
  1614. - 5 and 9 have no children and 0 balances
  1615. */
  1616. {
  1617. bool success = 1;
  1618. int insert_array[] = {5, 7, 9};
  1619. size_t insert_array_size = 2;
  1620. Node node;
  1621. for(size_t i = 0; i<=insert_array_size; i++) {
  1622. newnode(node);
  1623. node.key.set(insert_array[i] * tio.player());
  1624. tree.insert(tio, yield, node);
  1625. tree.check_avl(tio, yield);
  1626. }
  1627. Duoram<Node>* oram = tree.get_oram();
  1628. RegXS root_xs = tree.get_root();
  1629. size_t root = reconstruct_RegXS(tio, yield, root_xs);
  1630. auto A = oram->flat(tio, yield);
  1631. auto R = A.reconstruct();
  1632. Node root_node, left_node, right_node;
  1633. size_t left_index, right_index;
  1634. root_node = R[root];
  1635. if((root_node.key).share()!=7) {
  1636. success = false;
  1637. }
  1638. left_index = (getAVLLeftPtr(root_node.pointers)).share();
  1639. right_index = (getAVLRightPtr(root_node.pointers)).share();
  1640. left_node = R[left_index];
  1641. right_node = R[right_index];
  1642. if(left_node.key.share()!=5 || right_node.key.share()!=9) {
  1643. success = false;
  1644. }
  1645. //To check that left and right have no children and 0 balances
  1646. size_t sum = left_node.pointers.share() + right_node.pointers.share();
  1647. if(sum!=0) {
  1648. success = false;
  1649. }
  1650. if(player0) {
  1651. if(success) {
  1652. print_green("T1 : SUCCESS\n");
  1653. } else {
  1654. print_red("T1 : FAIL\n");
  1655. }
  1656. }
  1657. /*
  1658. //MY_tests:
  1659. // OblivIndex read on the ORAM:
  1660. RegXS mptr;
  1661. mptr.set(tio.player() * 1);
  1662. nbits_t width_bits = 2;
  1663. typename Duoram<Node>::template OblivIndex<RegXS,1> oidx(tio, yield, mptr, width_bits);
  1664. size_t rec_key = reconstruct_RegAS(tio, yield, A[oidx].NODE_KEY);
  1665. printf("RI: Retrieved key (3) = %ld\n", rec_key);
  1666. */
  1667. A.init();
  1668. tree.init();
  1669. }
  1670. // (T2) : Test 2 : L rotation (root unmodified)
  1671. /*
  1672. Operation:
  1673. 5 5
  1674. / \ / \
  1675. 3 7 3 9
  1676. \ ---> / \
  1677. 9 7 7 12
  1678. \
  1679. 12
  1680. T2 checks:
  1681. - root is 5
  1682. - 3, 7, 9, 12 in expected positions
  1683. - Nodes 3, 7, 12 have 0 balance and no children
  1684. - 5's bal = 0 1
  1685. */
  1686. {
  1687. bool success = 1;
  1688. int insert_array[] = {5, 3, 7, 9, 12};
  1689. size_t insert_array_size = 4;
  1690. Node node;
  1691. for(size_t i = 0; i<=insert_array_size; i++) {
  1692. newnode(node);
  1693. node.key.set(insert_array[i] * tio.player());
  1694. tree.insert(tio, yield, node);
  1695. tree.check_avl(tio, yield);
  1696. }
  1697. Duoram<Node>* oram = tree.get_oram();
  1698. RegXS root_xs = tree.get_root();
  1699. size_t root = reconstruct_RegXS(tio, yield, root_xs);
  1700. auto A = oram->flat(tio, yield);
  1701. auto R = A.reconstruct();
  1702. Node root_node, n3, n7, n9, n12;
  1703. size_t n3_index, n7_index, n9_index, n12_index;
  1704. root_node = R[root];
  1705. if((root_node.key).share()!=5) {
  1706. success = false;
  1707. }
  1708. n3_index = (getAVLLeftPtr(root_node.pointers)).share();
  1709. n9_index = (getAVLRightPtr(root_node.pointers)).share();
  1710. n3 = R[n3_index];
  1711. n9 = R[n9_index];
  1712. n7_index = getAVLLeftPtr(n9.pointers).share();
  1713. n12_index = getAVLRightPtr(n9.pointers).share();
  1714. n7 = R[n7_index];
  1715. n12 = R[n12_index];
  1716. // Node value checks
  1717. if(n3.key.share()!=3 || n9.key.share()!=9) {
  1718. success = false;
  1719. }
  1720. if(n7.key.share()!=7 || n12.key.share()!=12) {
  1721. success = false;
  1722. }
  1723. // Node children and balance checks
  1724. size_t zero = 0;
  1725. zero+=(n3.pointers.share());
  1726. zero+=(n7.pointers.share());
  1727. zero+=(n12.pointers.share());
  1728. zero+=(getLeftBal(root_node.pointers).share());
  1729. zero+=(getLeftBal(n9.pointers).share());
  1730. zero+=(getRightBal(n9.pointers).share());
  1731. if(zero!=0) {
  1732. success = false;
  1733. }
  1734. int one = (getRightBal(root_node.pointers).share());
  1735. if(one!=1) {
  1736. success = false;
  1737. }
  1738. if(player0) {
  1739. if(success) {
  1740. print_green("T2 : SUCCESS\n");
  1741. } else {
  1742. print_red("T2 : FAIL\n");
  1743. }
  1744. }
  1745. A.init();
  1746. tree.init();
  1747. }
  1748. // (T3) : Test 3 : R rotation (root modified)
  1749. /*
  1750. Operation:
  1751. 9 7
  1752. / / \
  1753. 7 ---> 5 9
  1754. /
  1755. 5
  1756. T3 checks:
  1757. - root is 7
  1758. - 5,7,9 in correct positions
  1759. - 5 and 9 have no children
  1760. */
  1761. {
  1762. bool success = 1;
  1763. int insert_array[] = {9, 7, 5};
  1764. size_t insert_array_size = 2;
  1765. Node node;
  1766. for(size_t i = 0; i<=insert_array_size; i++) {
  1767. newnode(node);
  1768. node.key.set(insert_array[i] * tio.player());
  1769. tree.insert(tio, yield, node);
  1770. tree.check_avl(tio, yield);
  1771. }
  1772. Duoram<Node>* oram = tree.get_oram();
  1773. RegXS root_xs = tree.get_root();
  1774. size_t root = reconstruct_RegXS(tio, yield, root_xs);
  1775. auto A = oram->flat(tio, yield);
  1776. auto R = A.reconstruct();
  1777. Node root_node, left_node, right_node;
  1778. size_t left_index, right_index;
  1779. root_node = R[root];
  1780. if((root_node.key).share()!=7) {
  1781. success = false;
  1782. }
  1783. left_index = (getAVLLeftPtr(root_node.pointers)).share();
  1784. right_index = (getAVLRightPtr(root_node.pointers)).share();
  1785. left_node = R[left_index];
  1786. right_node = R[right_index];
  1787. if(left_node.key.share()!=5 || right_node.key.share()!=9) {
  1788. success = false;
  1789. }
  1790. //To check that left and right have no children and 0 balances
  1791. size_t sum = left_node.pointers.share() + right_node.pointers.share();
  1792. if(sum!=0) {
  1793. success = false;
  1794. }
  1795. if(player0) {
  1796. if(success) {
  1797. print_green("T3 : SUCCESS\n");
  1798. } else{
  1799. print_red("T3 : FAIL\n");
  1800. }
  1801. }
  1802. A.init();
  1803. tree.init();
  1804. }
  1805. // (T4) : Test 4 : R rotation (root unmodified)
  1806. /*
  1807. Operation:
  1808. 9 9
  1809. / \ / \
  1810. 7 12 5 12
  1811. / ---> / \
  1812. 5 7 3 7
  1813. /
  1814. 3
  1815. T4 checks:
  1816. - root is 9
  1817. - 3,5,7,12 are in correct positions
  1818. - Nodes 3,7,12 have 0 balance
  1819. - Nodes 3,7,12 have no children
  1820. - 9's bal = 1 0
  1821. */
  1822. {
  1823. bool success = 1;
  1824. int insert_array[] = {9, 12, 7, 5, 3};
  1825. size_t insert_array_size = 4;
  1826. Node node;
  1827. for(size_t i = 0; i<=insert_array_size; i++) {
  1828. newnode(node);
  1829. node.key.set(insert_array[i] * tio.player());
  1830. tree.insert(tio, yield, node);
  1831. tree.check_avl(tio, yield);
  1832. }
  1833. Duoram<Node>* oram = tree.get_oram();
  1834. RegXS root_xs = tree.get_root();
  1835. size_t root = reconstruct_RegXS(tio, yield, root_xs);
  1836. auto A = oram->flat(tio, yield);
  1837. auto R = A.reconstruct();
  1838. Node root_node, n3, n7, n5, n12;
  1839. size_t n3_index, n7_index, n5_index, n12_index;
  1840. root_node = R[root];
  1841. if((root_node.key).share()!=9) {
  1842. success = false;
  1843. }
  1844. n5_index = (getAVLLeftPtr(root_node.pointers)).share();
  1845. n12_index = (getAVLRightPtr(root_node.pointers)).share();
  1846. n5 = R[n5_index];
  1847. n12 = R[n12_index];
  1848. n3_index = getAVLLeftPtr(n5.pointers).share();
  1849. n7_index = getAVLRightPtr(n5.pointers).share();
  1850. n7 = R[n7_index];
  1851. n3 = R[n3_index];
  1852. // Node value checks
  1853. if(n12.key.share()!=12 || n5.key.share()!=5) {
  1854. success = false;
  1855. }
  1856. if(n3.key.share()!=3 || n7.key.share()!=7) {
  1857. success = false;
  1858. }
  1859. // Node balance checks
  1860. size_t zero = 0;
  1861. zero+=(n3.pointers.share());
  1862. zero+=(n7.pointers.share());
  1863. zero+=(n12.pointers.share());
  1864. zero+=(getRightBal(root_node.pointers).share());
  1865. zero+=(getLeftBal(n5.pointers).share());
  1866. zero+=(getRightBal(n5.pointers).share());
  1867. if(zero!=0) {
  1868. success = false;
  1869. }
  1870. int one = (getLeftBal(root_node.pointers).share());
  1871. if(one!=1) {
  1872. success = false;
  1873. }
  1874. if(player0) {
  1875. if(success) {
  1876. print_green("T4 : SUCCESS\n");
  1877. } else {
  1878. print_red("T4 : FAIL\n");
  1879. }
  1880. }
  1881. A.init();
  1882. tree.init();
  1883. }
  1884. // (T5) : Test 5 : LR rotation (root modified)
  1885. /*
  1886. Operation:
  1887. 9 9 7
  1888. / / / \
  1889. 5 --> 7 --> 5 9
  1890. \ /
  1891. 7 5
  1892. T5 checks:
  1893. - root is 7
  1894. - 9,5,7 are in correct positions
  1895. - Nodes 5,7,9 have 0 balance
  1896. - Nodes 5,9 have no children
  1897. */
  1898. {
  1899. bool success = 1;
  1900. int insert_array[] = {9, 5, 7};
  1901. size_t insert_array_size = 2;
  1902. Node node;
  1903. for(size_t i = 0; i<=insert_array_size; i++) {
  1904. newnode(node);
  1905. node.key.set(insert_array[i] * tio.player());
  1906. tree.insert(tio, yield, node);
  1907. tree.check_avl(tio, yield);
  1908. }
  1909. Duoram<Node>* oram = tree.get_oram();
  1910. RegXS root_xs = tree.get_root();
  1911. size_t root = reconstruct_RegXS(tio, yield, root_xs);
  1912. auto A = oram->flat(tio, yield);
  1913. auto R = A.reconstruct();
  1914. Node root_node, n9, n5;
  1915. size_t n9_index, n5_index;
  1916. root_node = R[root];
  1917. if((root_node.key).share()!=7) {
  1918. success = false;
  1919. }
  1920. n5_index = (getAVLLeftPtr(root_node.pointers)).share();
  1921. n9_index = (getAVLRightPtr(root_node.pointers)).share();
  1922. n5 = R[n5_index];
  1923. n9 = R[n9_index];
  1924. // Node value checks
  1925. if(n9.key.share()!=9 || n5.key.share()!=5) {
  1926. success = false;
  1927. }
  1928. // Node balance checks
  1929. size_t zero = 0;
  1930. zero+=(n5.pointers.share());
  1931. zero+=(n9.pointers.share());
  1932. zero+=(getRightBal(root_node.pointers).share());
  1933. zero+=(getLeftBal(n5.pointers).share());
  1934. zero+=(getRightBal(n5.pointers).share());
  1935. zero+=(getLeftBal(n5.pointers).share());
  1936. zero+=(getRightBal(n9.pointers).share());
  1937. zero+=(getLeftBal(n9.pointers).share());
  1938. if(zero!=0) {
  1939. success = false;
  1940. }
  1941. if(player0) {
  1942. if(success) {
  1943. print_green("T5 : SUCCESS\n");
  1944. } else {
  1945. print_red("T5 : FAIL\n");
  1946. }
  1947. }
  1948. A.init();
  1949. tree.init();
  1950. }
  1951. // (T6) : Test 6 : LR rotation (root unmodified)
  1952. /*
  1953. Operation:
  1954. 9 9 9
  1955. / \ / \ / \
  1956. 7 12 7 12 5 12
  1957. / ---> / ---> / \
  1958. 3 5 3 7
  1959. \ /
  1960. 5 3
  1961. T6 checks:
  1962. - root is 9
  1963. - 3,5,7,12 are in correct positions
  1964. - Nodes 3,7,12 have 0 balance
  1965. - Nodes 3,7,12 have no children
  1966. - 9's bal = 1 0
  1967. */
  1968. {
  1969. bool success = 1;
  1970. int insert_array[] = {9, 12, 7, 3, 5};
  1971. size_t insert_array_size = 4;
  1972. Node node;
  1973. for(size_t i = 0; i<=insert_array_size; i++) {
  1974. newnode(node);
  1975. node.key.set(insert_array[i] * tio.player());
  1976. tree.insert(tio, yield, node);
  1977. tree.check_avl(tio, yield);
  1978. }
  1979. Duoram<Node>* oram = tree.get_oram();
  1980. RegXS root_xs = tree.get_root();
  1981. size_t root = reconstruct_RegXS(tio, yield, root_xs);
  1982. auto A = oram->flat(tio, yield);
  1983. auto R = A.reconstruct();
  1984. Node root_node, n3, n7, n5, n12;
  1985. size_t n3_index, n7_index, n5_index, n12_index;
  1986. root_node = R[root];
  1987. if((root_node.key).share()!=9) {
  1988. success = false;
  1989. }
  1990. n5_index = (getAVLLeftPtr(root_node.pointers)).share();
  1991. n12_index = (getAVLRightPtr(root_node.pointers)).share();
  1992. n5 = R[n5_index];
  1993. n12 = R[n12_index];
  1994. n3_index = getAVLLeftPtr(n5.pointers).share();
  1995. n7_index = getAVLRightPtr(n5.pointers).share();
  1996. n7 = R[n7_index];
  1997. n3 = R[n3_index];
  1998. // Node value checks
  1999. if(n5.key.share()!=5 || n12.key.share()!=12) {
  2000. success = false;
  2001. }
  2002. if(n3.key.share()!=3 || n7.key.share()!=7) {
  2003. success = false;
  2004. }
  2005. // Node balance checks
  2006. size_t zero = 0;
  2007. zero+=(n3.pointers.share());
  2008. zero+=(n7.pointers.share());
  2009. zero+=(n12.pointers.share());
  2010. zero+=(getRightBal(root_node.pointers).share());
  2011. zero+=(getLeftBal(n5.pointers).share());
  2012. zero+=(getRightBal(n5.pointers).share());
  2013. if(zero!=0) {
  2014. success = false;
  2015. }
  2016. int one = (getLeftBal(root_node.pointers).share());
  2017. if(one!=1) {
  2018. success = false;
  2019. }
  2020. if(player0) {
  2021. if(success) {
  2022. print_green("T6 : SUCCESS\n");
  2023. } else {
  2024. print_red("T6 : FAIL\n");
  2025. }
  2026. }
  2027. A.init();
  2028. tree.init();
  2029. }
  2030. // (T7) : Test 7 : RL rotation (root modified)
  2031. /*
  2032. Operation:
  2033. 5 5 7
  2034. \ \ / \
  2035. 9 --> 7 --> 5 9
  2036. / \
  2037. 7 9
  2038. T7 checks:
  2039. - root is 7
  2040. - 9,5,7 are in correct positions
  2041. - Nodes 5,7,9 have 0 balance
  2042. - Nodes 5,9 have no children
  2043. */
  2044. {
  2045. bool success = 1;
  2046. int insert_array[] = {5, 9, 7};
  2047. size_t insert_array_size = 2;
  2048. Node node;
  2049. for(size_t i = 0; i<=insert_array_size; i++) {
  2050. newnode(node);
  2051. node.key.set(insert_array[i] * tio.player());
  2052. tree.insert(tio, yield, node);
  2053. tree.check_avl(tio, yield);
  2054. }
  2055. Duoram<Node>* oram = tree.get_oram();
  2056. RegXS root_xs = tree.get_root();
  2057. size_t root = reconstruct_RegXS(tio, yield, root_xs);
  2058. auto A = oram->flat(tio, yield);
  2059. auto R = A.reconstruct();
  2060. Node root_node, n9, n5;
  2061. size_t n9_index, n5_index;
  2062. root_node = R[root];
  2063. if((root_node.key).share()!=7) {
  2064. success = false;
  2065. }
  2066. n5_index = (getAVLLeftPtr(root_node.pointers)).share();
  2067. n9_index = (getAVLRightPtr(root_node.pointers)).share();
  2068. n5 = R[n5_index];
  2069. n9 = R[n9_index];
  2070. // Node value checks
  2071. if(n9.key.share()!=9 || n5.key.share()!=5) {
  2072. success = false;
  2073. }
  2074. // Node balance checks
  2075. size_t zero = 0;
  2076. zero+=(n5.pointers.share());
  2077. zero+=(n9.pointers.share());
  2078. zero+=(getRightBal(root_node.pointers).share());
  2079. zero+=(getLeftBal(n5.pointers).share());
  2080. zero+=(getRightBal(n5.pointers).share());
  2081. zero+=(getLeftBal(n5.pointers).share());
  2082. zero+=(getRightBal(n9.pointers).share());
  2083. zero+=(getLeftBal(n9.pointers).share());
  2084. if(zero!=0) {
  2085. success = false;
  2086. }
  2087. if(player0) {
  2088. if(success) {
  2089. print_green("T7 : SUCCESS\n");
  2090. } else {
  2091. print_red("T7 : FAIL\n");
  2092. }
  2093. }
  2094. A.init();
  2095. tree.init();
  2096. }
  2097. // (T8) : Test 8 : RL rotation (root unmodified)
  2098. /*
  2099. Operation:
  2100. 5 5 5
  2101. / \ / \ / \
  2102. 3 12 3 12 3 9
  2103. / ---> / ---> / \
  2104. 7 9 7 12
  2105. \ /
  2106. 9 7
  2107. T8 checks:
  2108. - root is 5
  2109. - 3,9,7,12 are in correct positions
  2110. - Nodes 3,7,12 have 0 balance
  2111. - Nodes 3,7,12 have no children
  2112. - 5's bal = 0 1
  2113. */
  2114. {
  2115. bool success = 1;
  2116. int insert_array[] = {5, 3, 12, 7, 9};
  2117. size_t insert_array_size = 4;
  2118. Node node;
  2119. for(size_t i = 0; i<=insert_array_size; i++) {
  2120. newnode(node);
  2121. node.key.set(insert_array[i] * tio.player());
  2122. tree.insert(tio, yield, node);
  2123. tree.check_avl(tio, yield);
  2124. }
  2125. Duoram<Node>* oram = tree.get_oram();
  2126. RegXS root_xs = tree.get_root();
  2127. size_t root = reconstruct_RegXS(tio, yield, root_xs);
  2128. auto A = oram->flat(tio, yield);
  2129. auto R = A.reconstruct();
  2130. Node root_node, n3, n7, n9, n12;
  2131. size_t n3_index, n7_index, n9_index, n12_index;
  2132. root_node = R[root];
  2133. if((root_node.key).share()!=5) {
  2134. success = false;
  2135. }
  2136. n3_index = (getAVLLeftPtr(root_node.pointers)).share();
  2137. n9_index = (getAVLRightPtr(root_node.pointers)).share();
  2138. n3 = R[n3_index];
  2139. n9 = R[n9_index];
  2140. n7_index = getAVLLeftPtr(n9.pointers).share();
  2141. n12_index = getAVLRightPtr(n9.pointers).share();
  2142. n7 = R[n7_index];
  2143. n12 = R[n12_index];
  2144. // Node value checks
  2145. if(n3.key.share()!=3 || n9.key.share()!=9) {
  2146. success = false;
  2147. }
  2148. if(n7.key.share()!=7 || n12.key.share()!=12) {
  2149. success = false;
  2150. }
  2151. // Node balance checks
  2152. size_t zero = 0;
  2153. zero+=(n3.pointers.share());
  2154. zero+=(n7.pointers.share());
  2155. zero+=(n12.pointers.share());
  2156. zero+=(getLeftBal(root_node.pointers).share());
  2157. zero+=(getLeftBal(n9.pointers).share());
  2158. zero+=(getRightBal(n9.pointers).share());
  2159. if(zero!=0) {
  2160. success = false;
  2161. }
  2162. int one = (getRightBal(root_node.pointers).share());
  2163. if(one!=1) {
  2164. success = false;
  2165. }
  2166. if(player0) {
  2167. if(success) {
  2168. print_green("T8 : SUCCESS\n");
  2169. } else {
  2170. print_red("T8 : FAIL\n");
  2171. }
  2172. }
  2173. A.init();
  2174. tree.init();
  2175. }
  2176. // Deletion Tests:
  2177. // (T9) : Test 9 : L rotation (root modified)
  2178. /*
  2179. Operation:
  2180. 5 7
  2181. / \ Del 3 / \
  2182. 3 7 ------> 5 9
  2183. \
  2184. 9
  2185. T9 checks:
  2186. - root is 7
  2187. - 5,7,9 in correct positions
  2188. - 5 and 9 have no children and 0 balances
  2189. - 7 has 0 balances
  2190. */
  2191. {
  2192. bool success = 1;
  2193. int insert_array[] = {5, 3, 7, 9};
  2194. size_t insert_array_size = 3;
  2195. Node node;
  2196. for(size_t i = 0; i<=insert_array_size; i++) {
  2197. newnode(node);
  2198. node.key.set(insert_array[i] * tio.player());
  2199. tree.insert(tio, yield, node);
  2200. tree.check_avl(tio, yield);
  2201. }
  2202. RegAS del_key;
  2203. del_key.set(3 * tio.player());
  2204. tree.del(tio, yield, del_key);
  2205. tree.check_avl(tio, yield);
  2206. Duoram<Node>* oram = tree.get_oram();
  2207. RegXS root_xs = tree.get_root();
  2208. size_t root = reconstruct_RegXS(tio, yield, root_xs);
  2209. auto A = oram->flat(tio, yield);
  2210. auto R = A.reconstruct();
  2211. Node root_node, left_node, right_node;
  2212. size_t left_index, right_index;
  2213. root_node = R[root];
  2214. if((root_node.key).share()!=7) {
  2215. success = false;
  2216. }
  2217. left_index = (getAVLLeftPtr(root_node.pointers)).share();
  2218. right_index = (getAVLRightPtr(root_node.pointers)).share();
  2219. left_node = R[left_index];
  2220. right_node = R[right_index];
  2221. if(left_node.key.share()!=5 || right_node.key.share()!=9) {
  2222. success = false;
  2223. }
  2224. //To check that left and right have no children and 0 balances
  2225. size_t sum = left_node.pointers.share() + right_node.pointers.share();
  2226. if(sum!=0) {
  2227. success = false;
  2228. }
  2229. if(player0) {
  2230. if(success) {
  2231. print_green("T9 : SUCCESS\n");
  2232. } else {
  2233. print_red("T9 : FAIL\n");
  2234. }
  2235. }
  2236. A.init();
  2237. tree.init();
  2238. }
  2239. // (T10) : Test 10 : L rotation (root unmodified)
  2240. /*
  2241. Operation:
  2242. 5 5
  2243. / \ / \
  2244. 3 7 Del 6 3 9
  2245. / / \ ------> / / \
  2246. 1 6 9 1 7 12
  2247. \
  2248. 12
  2249. T10 checks:
  2250. - root is 5
  2251. - 3, 7, 9, 12 in expected positions
  2252. - Nodes 3, 7, 12 have 0 balance and no children
  2253. - 5's bal = 0 1
  2254. */
  2255. {
  2256. bool success = 1;
  2257. int insert_array[] = {5, 3, 7, 9, 6, 1, 12};
  2258. size_t insert_array_size = 6;
  2259. Node node;
  2260. for(size_t i = 0; i<=insert_array_size; i++) {
  2261. newnode(node);
  2262. node.key.set(insert_array[i] * tio.player());
  2263. tree.insert(tio, yield, node);
  2264. tree.check_avl(tio, yield);
  2265. }
  2266. RegAS del_key;
  2267. del_key.set(6 * tio.player());
  2268. tree.del(tio, yield, del_key);
  2269. tree.check_avl(tio, yield);
  2270. Duoram<Node>* oram = tree.get_oram();
  2271. RegXS root_xs = tree.get_root();
  2272. size_t root = reconstruct_RegXS(tio, yield, root_xs);
  2273. auto A = oram->flat(tio, yield);
  2274. auto R = A.reconstruct();
  2275. Node root_node, n1, n3, n7, n9, n12;
  2276. size_t n1_index, n3_index, n7_index, n9_index, n12_index;
  2277. root_node = R[root];
  2278. if((root_node.key).share()!=5) {
  2279. success = false;
  2280. }
  2281. n3_index = (getAVLLeftPtr(root_node.pointers)).share();
  2282. n9_index = (getAVLRightPtr(root_node.pointers)).share();
  2283. n3 = R[n3_index];
  2284. n9 = R[n9_index];
  2285. n7_index = getAVLLeftPtr(n9.pointers).share();
  2286. n12_index = getAVLRightPtr(n9.pointers).share();
  2287. n7 = R[n7_index];
  2288. n12 = R[n12_index];
  2289. n1_index = getAVLLeftPtr(n3.pointers).share();
  2290. n1 = R[n1_index];
  2291. // Node value checks
  2292. if(n3.key.share()!=3 || n9.key.share()!=9) {
  2293. success = false;
  2294. }
  2295. if(n7.key.share()!=7 || n12.key.share()!=12 || n1.key.share()!=1) {
  2296. success = false;
  2297. }
  2298. // Node children and balance checks
  2299. size_t zero = 0;
  2300. zero+=(n1.pointers.share());
  2301. zero+=(n7.pointers.share());
  2302. zero+=(n12.pointers.share());
  2303. zero+=(getLeftBal(root_node.pointers).share());
  2304. zero+=(getRightBal(root_node.pointers).share());
  2305. zero+=(getLeftBal(n9.pointers).share());
  2306. zero+=(getRightBal(n9.pointers).share());
  2307. zero+=(getRightBal(n3.pointers).share());
  2308. if(zero!=0) {
  2309. success = false;
  2310. }
  2311. int one = (getLeftBal(n3.pointers).share());
  2312. if(one!=1) {
  2313. success = false;
  2314. }
  2315. if(player0) {
  2316. if(success) {
  2317. print_green("T10 : SUCCESS\n");
  2318. } else {
  2319. print_red("T10 : FAIL\n");
  2320. }
  2321. }
  2322. A.init();
  2323. tree.init();
  2324. }
  2325. // (T11) : Test 11 : R rotation (root modified)
  2326. /*
  2327. Operation:
  2328. 9 7
  2329. / \ Del 12 / \
  2330. 7 12 -------> 5 9
  2331. /
  2332. 5
  2333. T11 checks:
  2334. - root is 7
  2335. - 5,7,9 in correct positions and balances to 0
  2336. - 5 and 9 have no children
  2337. */
  2338. {
  2339. bool success = 1;
  2340. int insert_array[] = {9, 7, 12, 5};
  2341. size_t insert_array_size = 3;
  2342. Node node;
  2343. for(size_t i = 0; i<=insert_array_size; i++) {
  2344. newnode(node);
  2345. node.key.set(insert_array[i] * tio.player());
  2346. tree.insert(tio, yield, node);
  2347. tree.check_avl(tio, yield);
  2348. }
  2349. RegAS del_key;
  2350. del_key.set(12 * tio.player());
  2351. tree.del(tio, yield, del_key);
  2352. tree.check_avl(tio, yield);
  2353. Duoram<Node>* oram = tree.get_oram();
  2354. RegXS root_xs = tree.get_root();
  2355. size_t root = reconstruct_RegXS(tio, yield, root_xs);
  2356. auto A = oram->flat(tio, yield);
  2357. auto R = A.reconstruct();
  2358. Node root_node, left_node, right_node;
  2359. size_t left_index, right_index;
  2360. root_node = R[root];
  2361. if((root_node.key).share()!=7) {
  2362. success = false;
  2363. }
  2364. left_index = (getAVLLeftPtr(root_node.pointers)).share();
  2365. right_index = (getAVLRightPtr(root_node.pointers)).share();
  2366. left_node = R[left_index];
  2367. right_node = R[right_index];
  2368. if(left_node.key.share()!=5 || right_node.key.share()!=9) {
  2369. success = false;
  2370. }
  2371. //To check that left and right have no children and 0 balances
  2372. size_t zero = left_node.pointers.share() + right_node.pointers.share();
  2373. zero+=(getLeftBal(left_node.pointers).share());
  2374. zero+=(getRightBal(left_node.pointers).share());
  2375. zero+=(getLeftBal(right_node.pointers).share());
  2376. zero+=(getRightBal(right_node.pointers).share());
  2377. if(zero!=0) {
  2378. success = false;
  2379. }
  2380. if(player0) {
  2381. if(success) {
  2382. print_green("T11 : SUCCESS\n");
  2383. } else{
  2384. print_red("T11 : FAIL\n");
  2385. }
  2386. }
  2387. A.init();
  2388. tree.init();
  2389. }
  2390. // (T12) : Test 12 : R rotation (root unmodified)
  2391. /*
  2392. Operation:
  2393. 9 9
  2394. / \ / \
  2395. 7 12 Del 8 5 12
  2396. / \ \ ------> / \ \
  2397. 5 8 15 3 7 15
  2398. /
  2399. 3
  2400. T4 checks:
  2401. - root is 9
  2402. - 3,5,7,12,15 are in correct positions
  2403. - Nodes 3,7,15 have 0 balance
  2404. - Nodes 3,7,15 have no children
  2405. - 9,5 bal = 0 0
  2406. - 12 bal = 0 1
  2407. */
  2408. {
  2409. bool success = 1;
  2410. int insert_array[] = {9, 12, 7, 5, 8, 15, 3};
  2411. size_t insert_array_size = 6;
  2412. Node node;
  2413. for(size_t i = 0; i<=insert_array_size; i++) {
  2414. newnode(node);
  2415. node.key.set(insert_array[i] * tio.player());
  2416. tree.insert(tio, yield, node);
  2417. tree.check_avl(tio, yield);
  2418. }
  2419. RegAS del_key;
  2420. del_key.set(8 * tio.player());
  2421. tree.del(tio, yield, del_key);
  2422. tree.check_avl(tio, yield);
  2423. Duoram<Node>* oram = tree.get_oram();
  2424. RegXS root_xs = tree.get_root();
  2425. size_t root = reconstruct_RegXS(tio, yield, root_xs);
  2426. auto A = oram->flat(tio, yield);
  2427. auto R = A.reconstruct();
  2428. Node root_node, n3, n7, n5, n12, n15;
  2429. size_t n3_index, n7_index, n5_index, n12_index, n15_index;
  2430. root_node = R[root];
  2431. if((root_node.key).share()!=9) {
  2432. success = false;
  2433. }
  2434. n5_index = (getAVLLeftPtr(root_node.pointers)).share();
  2435. n12_index = (getAVLRightPtr(root_node.pointers)).share();
  2436. n5 = R[n5_index];
  2437. n12 = R[n12_index];
  2438. n3_index = getAVLLeftPtr(n5.pointers).share();
  2439. n7_index = getAVLRightPtr(n5.pointers).share();
  2440. n7 = R[n7_index];
  2441. n3 = R[n3_index];
  2442. n15_index = getAVLRightPtr(n12.pointers).share();
  2443. n15 = R[n15_index];
  2444. // Node value checks
  2445. if(n12.key.share()!=12 || n5.key.share()!=5) {
  2446. success = false;
  2447. }
  2448. if(n3.key.share()!=3 || n7.key.share()!=7 || n15.key.share()!=15) {
  2449. success = false;
  2450. }
  2451. // Node balance checks
  2452. size_t zero = 0;
  2453. zero+=(n3.pointers.share());
  2454. zero+=(n7.pointers.share());
  2455. zero+=(n15.pointers.share());
  2456. zero+=(getRightBal(root_node.pointers).share());
  2457. zero+=(getLeftBal(root_node.pointers).share());
  2458. zero+=(getLeftBal(n5.pointers).share());
  2459. zero+=(getRightBal(n5.pointers).share());
  2460. if(zero!=0) {
  2461. success = false;
  2462. }
  2463. int one = (getRightBal(n12.pointers).share());
  2464. if(one!=1) {
  2465. success = false;
  2466. }
  2467. if(player0) {
  2468. if(success) {
  2469. print_green("T12 : SUCCESS\n");
  2470. } else {
  2471. print_red("T12 : FAIL\n");
  2472. }
  2473. }
  2474. A.init();
  2475. tree.init();
  2476. }
  2477. // (T13) : Test 13 : LR rotation (root modified)
  2478. /*
  2479. Operation:
  2480. 9 9 7
  2481. / \ Del 12 / / \
  2482. 5 12 -------> 7 --> 5 9
  2483. \ /
  2484. 7 5
  2485. T5 checks:
  2486. - root is 7
  2487. - 9,5,7 are in correct positions
  2488. - Nodes 5,7,9 have 0 balance
  2489. - Nodes 5,9 have no children
  2490. */
  2491. {
  2492. bool success = 1;
  2493. int insert_array[] = {9, 5, 12, 7};
  2494. size_t insert_array_size = 3;
  2495. Node node;
  2496. for(size_t i = 0; i<=insert_array_size; i++) {
  2497. newnode(node);
  2498. node.key.set(insert_array[i] * tio.player());
  2499. tree.insert(tio, yield, node);
  2500. tree.check_avl(tio, yield);
  2501. }
  2502. RegAS del_key;
  2503. del_key.set(12 * tio.player());
  2504. tree.del(tio, yield, del_key);
  2505. tree.check_avl(tio, yield);
  2506. Duoram<Node>* oram = tree.get_oram();
  2507. RegXS root_xs = tree.get_root();
  2508. size_t root = reconstruct_RegXS(tio, yield, root_xs);
  2509. auto A = oram->flat(tio, yield);
  2510. auto R = A.reconstruct();
  2511. Node root_node, n9, n5;
  2512. size_t n9_index, n5_index;
  2513. root_node = R[root];
  2514. if((root_node.key).share()!=7) {
  2515. success = false;
  2516. }
  2517. n5_index = (getAVLLeftPtr(root_node.pointers)).share();
  2518. n9_index = (getAVLRightPtr(root_node.pointers)).share();
  2519. n5 = R[n5_index];
  2520. n9 = R[n9_index];
  2521. // Node value checks
  2522. if(n9.key.share()!=9 || n5.key.share()!=5) {
  2523. success = false;
  2524. }
  2525. // Node balance checks
  2526. size_t zero = 0;
  2527. zero+=(n5.pointers.share());
  2528. zero+=(n9.pointers.share());
  2529. zero+=(getRightBal(root_node.pointers).share());
  2530. zero+=(getLeftBal(n5.pointers).share());
  2531. zero+=(getRightBal(n5.pointers).share());
  2532. zero+=(getLeftBal(n5.pointers).share());
  2533. zero+=(getRightBal(n9.pointers).share());
  2534. zero+=(getLeftBal(n9.pointers).share());
  2535. if(zero!=0) {
  2536. success = false;
  2537. }
  2538. if(player0) {
  2539. if(success) {
  2540. print_green("T13 : SUCCESS\n");
  2541. } else {
  2542. print_red("T13 : FAIL\n");
  2543. }
  2544. }
  2545. A.init();
  2546. tree.init();
  2547. }
  2548. // (T14) : Test 14 : LR rotation (root unmodified)
  2549. /*
  2550. Operation:
  2551. 9 9 9
  2552. / \ / \ / \
  2553. 7 12 Del 8 7 12 5 12
  2554. / \ ------> / ---> / \
  2555. 3 8 5 3 7
  2556. \ /
  2557. 5 3
  2558. T6 checks:
  2559. - root is 9
  2560. - 3,5,7,12 are in correct positions
  2561. - Nodes 3,7,12 have 0 balance
  2562. - Nodes 3,7,12 have no children
  2563. - 9's bal = 1 0
  2564. */
  2565. {
  2566. bool success = 1;
  2567. int insert_array[] = {9, 12, 7, 3, 5};
  2568. size_t insert_array_size = 4;
  2569. Node node;
  2570. for(size_t i = 0; i<=insert_array_size; i++) {
  2571. newnode(node);
  2572. node.key.set(insert_array[i] * tio.player());
  2573. tree.insert(tio, yield, node);
  2574. tree.check_avl(tio, yield);
  2575. }
  2576. RegAS del_key;
  2577. del_key.set(8 * tio.player());
  2578. tree.del(tio, yield, del_key);
  2579. tree.check_avl(tio, yield);
  2580. Duoram<Node>* oram = tree.get_oram();
  2581. RegXS root_xs = tree.get_root();
  2582. size_t root = reconstruct_RegXS(tio, yield, root_xs);
  2583. auto A = oram->flat(tio, yield);
  2584. auto R = A.reconstruct();
  2585. Node root_node, n3, n7, n5, n12;
  2586. size_t n3_index, n7_index, n5_index, n12_index;
  2587. root_node = R[root];
  2588. if((root_node.key).share()!=9) {
  2589. success = false;
  2590. }
  2591. n5_index = (getAVLLeftPtr(root_node.pointers)).share();
  2592. n12_index = (getAVLRightPtr(root_node.pointers)).share();
  2593. n5 = R[n5_index];
  2594. n12 = R[n12_index];
  2595. n3_index = getAVLLeftPtr(n5.pointers).share();
  2596. n7_index = getAVLRightPtr(n5.pointers).share();
  2597. n7 = R[n7_index];
  2598. n3 = R[n3_index];
  2599. // Node value checks
  2600. if(n5.key.share()!=5 || n12.key.share()!=12) {
  2601. success = false;
  2602. }
  2603. if(n3.key.share()!=3 || n7.key.share()!=7) {
  2604. success = false;
  2605. }
  2606. // Node balance checks
  2607. size_t zero = 0;
  2608. zero+=(n3.pointers.share());
  2609. zero+=(n7.pointers.share());
  2610. zero+=(n12.pointers.share());
  2611. zero+=(getRightBal(root_node.pointers).share());
  2612. zero+=(getLeftBal(n5.pointers).share());
  2613. zero+=(getRightBal(n5.pointers).share());
  2614. if(zero!=0) {
  2615. success = false;
  2616. }
  2617. int one = (getLeftBal(root_node.pointers).share());
  2618. if(one!=1) {
  2619. success = false;
  2620. }
  2621. if(player0) {
  2622. if(success) {
  2623. print_green("T14 : SUCCESS\n");
  2624. } else {
  2625. print_red("T14 : FAIL\n");
  2626. }
  2627. }
  2628. A.init();
  2629. tree.init();
  2630. }
  2631. // (T15) : Test 15 : RL rotation (root modified)
  2632. /*
  2633. Operation:
  2634. 5 5 7
  2635. / \ Del 3 \ / \
  2636. 3 9 -------> 7 --> 5 9
  2637. / \
  2638. 7 9
  2639. T15 checks:
  2640. - root is 7
  2641. - 9,5,7 are in correct positions
  2642. - Nodes 5,7,9 have 0 balance
  2643. - Nodes 5,9 have no children
  2644. */
  2645. {
  2646. bool success = 1;
  2647. int insert_array[] = {5, 9, 3, 7};
  2648. size_t insert_array_size = 3;
  2649. Node node;
  2650. for(size_t i = 0; i<=insert_array_size; i++) {
  2651. newnode(node);
  2652. node.key.set(insert_array[i] * tio.player());
  2653. tree.insert(tio, yield, node);
  2654. tree.check_avl(tio, yield);
  2655. }
  2656. RegAS del_key;
  2657. del_key.set(3 * tio.player());
  2658. tree.del(tio, yield, del_key);
  2659. tree.check_avl(tio, yield);
  2660. Duoram<Node>* oram = tree.get_oram();
  2661. RegXS root_xs = tree.get_root();
  2662. size_t root = reconstruct_RegXS(tio, yield, root_xs);
  2663. auto A = oram->flat(tio, yield);
  2664. auto R = A.reconstruct();
  2665. Node root_node, n9, n5;
  2666. size_t n9_index, n5_index;
  2667. root_node = R[root];
  2668. if((root_node.key).share()!=7) {
  2669. success = false;
  2670. }
  2671. n5_index = (getAVLLeftPtr(root_node.pointers)).share();
  2672. n9_index = (getAVLRightPtr(root_node.pointers)).share();
  2673. n5 = R[n5_index];
  2674. n9 = R[n9_index];
  2675. // Node value checks
  2676. if(n9.key.share()!=9 || n5.key.share()!=5) {
  2677. success = false;
  2678. }
  2679. // Node balance checks
  2680. size_t zero = 0;
  2681. zero+=(n5.pointers.share());
  2682. zero+=(n9.pointers.share());
  2683. zero+=(getRightBal(root_node.pointers).share());
  2684. zero+=(getLeftBal(n5.pointers).share());
  2685. zero+=(getRightBal(n5.pointers).share());
  2686. zero+=(getLeftBal(n5.pointers).share());
  2687. zero+=(getRightBal(n9.pointers).share());
  2688. zero+=(getLeftBal(n9.pointers).share());
  2689. if(zero!=0) {
  2690. success = false;
  2691. }
  2692. if(player0) {
  2693. if(success) {
  2694. print_green("T15 : SUCCESS\n");
  2695. } else {
  2696. print_red("T15 : FAIL\n");
  2697. }
  2698. }
  2699. A.init();
  2700. tree.init();
  2701. }
  2702. // (T16) : Test 16 : RL rotation (root unmodified)
  2703. /*
  2704. Operation:
  2705. 5 5 5
  2706. / \ / \ / \
  2707. 3 8 Del 7 3 8 3 9
  2708. / / \ ------> / \ ---> / / \
  2709. 1 7 12 1 9 1 8 12
  2710. / \
  2711. 9 12
  2712. T8 checks:
  2713. - root is 5
  2714. - 3,9,8,12 are in correct positions
  2715. - Nodes 1,5,8,9,12 have 0 balance
  2716. - Nodes 1,5,8,9,12 have no children
  2717. - Node 3 has 1 0 balance
  2718. */
  2719. {
  2720. bool success = 1;
  2721. int insert_array[] = {5, 3, 8, 7, 1, 12, 9};
  2722. size_t insert_array_size = 6;
  2723. Node node;
  2724. for(size_t i = 0; i<=insert_array_size; i++) {
  2725. newnode(node);
  2726. node.key.set(insert_array[i] * tio.player());
  2727. tree.insert(tio, yield, node);
  2728. tree.check_avl(tio, yield);
  2729. }
  2730. RegAS del_key;
  2731. del_key.set(7 * tio.player());
  2732. tree.del(tio, yield, del_key);
  2733. tree.check_avl(tio, yield);
  2734. Duoram<Node>* oram = tree.get_oram();
  2735. RegXS root_xs = tree.get_root();
  2736. size_t root = reconstruct_RegXS(tio, yield, root_xs);
  2737. auto A = oram->flat(tio, yield);
  2738. auto R = A.reconstruct();
  2739. Node root_node, n1, n3, n8, n9, n12;
  2740. size_t n1_index, n3_index, n8_index, n9_index, n12_index;
  2741. root_node = R[root];
  2742. if((root_node.key).share()!=5) {
  2743. success = false;
  2744. }
  2745. n3_index = (getAVLLeftPtr(root_node.pointers)).share();
  2746. n9_index = (getAVLRightPtr(root_node.pointers)).share();
  2747. n3 = R[n3_index];
  2748. n9 = R[n9_index];
  2749. n1_index = getAVLLeftPtr(n3.pointers).share();
  2750. n8_index = getAVLLeftPtr(n9.pointers).share();
  2751. n12_index = getAVLRightPtr(n9.pointers).share();
  2752. n1 = R[n1_index];
  2753. n8 = R[n8_index];
  2754. n12 = R[n12_index];
  2755. // Node value checks
  2756. if(n1.key.share()!=1) {
  2757. success = false;
  2758. }
  2759. if(n3.key.share()!=3 || n9.key.share()!=9) {
  2760. success = false;
  2761. }
  2762. if(n8.key.share()!=8 || n12.key.share()!=12) {
  2763. success = false;
  2764. }
  2765. // Node balance checks
  2766. size_t zero = 0;
  2767. zero+=(n1.pointers.share());
  2768. zero+=(getRightBal(n3.pointers).share());
  2769. zero+=(n8.pointers.share());
  2770. zero+=(n12.pointers.share());
  2771. zero+=(getLeftBal(root_node.pointers).share());
  2772. zero+=(getRightBal(root_node.pointers).share());
  2773. zero+=(getLeftBal(n9.pointers).share());
  2774. zero+=(getRightBal(n9.pointers).share());
  2775. if(zero!=0) {
  2776. success = false;
  2777. }
  2778. if(player0) {
  2779. if(success) {
  2780. print_green("T16 : SUCCESS\n");
  2781. } else {
  2782. print_red("T16 : FAIL\n");
  2783. }
  2784. }
  2785. A.init();
  2786. tree.init();
  2787. }
  2788. // (T17) : Test 17 : Double imbalance (root modified)
  2789. /*
  2790. Operation:
  2791. 9 9
  2792. / \ / \
  2793. 5 12 Del 10 5 15
  2794. / \ / \ --------> / \ / \
  2795. 3 7 10 15 3 7 12 20
  2796. / \ / \ \ / \ / \
  2797. 2 4 6 8 20 2 4 6 8
  2798. / /
  2799. 1 1
  2800. 5
  2801. / \
  2802. 3 9
  2803. -----> / \ / \
  2804. 2 4 7 15
  2805. / / \ / \
  2806. 1 6 8 10 20
  2807. T17 checks:
  2808. - root is 5
  2809. - all other nodes are in correct positions
  2810. - balances and children are correct
  2811. */
  2812. {
  2813. bool success = 1;
  2814. int insert_array[] = {9, 5, 12, 7, 3, 10, 15, 2, 4, 6, 8, 20, 1};
  2815. size_t insert_array_size = 12;
  2816. Node node;
  2817. for(size_t i = 0; i<=insert_array_size; i++) {
  2818. newnode(node);
  2819. node.key.set(insert_array[i] * tio.player());
  2820. tree.insert(tio, yield, node);
  2821. tree.check_avl(tio, yield);
  2822. }
  2823. RegAS del_key;
  2824. del_key.set(10 * tio.player());
  2825. tree.del(tio, yield, del_key);
  2826. tree.check_avl(tio, yield);
  2827. Duoram<Node>* oram = tree.get_oram();
  2828. RegXS root_xs = tree.get_root();
  2829. size_t root = reconstruct_RegXS(tio, yield, root_xs);
  2830. auto A = oram->flat(tio, yield);
  2831. auto R = A.reconstruct();
  2832. Node root_node, n3, n7, n9;
  2833. Node n1, n2, n4, n6, n8, n12, n15, n20;
  2834. size_t n3_index, n7_index, n9_index;
  2835. size_t n1_index, n2_index, n4_index, n6_index;
  2836. size_t n8_index, n12_index, n15_index, n20_index;
  2837. root_node = R[root];
  2838. if((root_node.key).share()!=5) {
  2839. success = false;
  2840. }
  2841. n3_index = (getAVLLeftPtr(root_node.pointers)).share();
  2842. n9_index = (getAVLRightPtr(root_node.pointers)).share();
  2843. n3 = R[n3_index];
  2844. n9 = R[n9_index];
  2845. n2_index = getAVLLeftPtr(n3.pointers).share();
  2846. n4_index = getAVLRightPtr(n3.pointers).share();
  2847. n7_index = getAVLLeftPtr(n9.pointers).share();
  2848. n15_index = getAVLRightPtr(n9.pointers).share();
  2849. n2 = R[n2_index];
  2850. n4 = R[n4_index];
  2851. n7 = R[n7_index];
  2852. n15 = R[n15_index];
  2853. n1_index = getAVLLeftPtr(n2.pointers).share();
  2854. n6_index = getAVLLeftPtr(n7.pointers).share();
  2855. n8_index = getAVLRightPtr(n7.pointers).share();
  2856. n12_index = getAVLLeftPtr(n15.pointers).share();
  2857. n20_index = getAVLRightPtr(n15.pointers).share();
  2858. n1 = R[n1_index];
  2859. n6 = R[n6_index];
  2860. n8 = R[n8_index];
  2861. n12 = R[n12_index];
  2862. n20 = R[n20_index];
  2863. // Node value checks
  2864. if(n3.key.share()!=3 || n9.key.share()!=9) {
  2865. success = false;
  2866. }
  2867. if(n2.key.share()!=2 || n4.key.share()!=4) {
  2868. success = false;
  2869. }
  2870. if(n7.key.share()!=7 || n15.key.share()!=15) {
  2871. success = false;
  2872. }
  2873. if(n1.key.share()!=1 || n6.key.share()!=6 || n8.key.share()!=8) {
  2874. success = false;
  2875. }
  2876. if(n12.key.share()!=12 || n20.key.share()!=20) {
  2877. success = false;
  2878. }
  2879. // Node balance checks
  2880. size_t zero = 0;
  2881. zero+=(n1.pointers.share());
  2882. zero+=(n4.pointers.share());
  2883. zero+=(n6.pointers.share());
  2884. zero+=(n8.pointers.share());
  2885. zero+=(n12.pointers.share());
  2886. zero+=(n20.pointers.share());
  2887. zero+=(getLeftBal(n7.pointers).share());
  2888. zero+=(getRightBal(n7.pointers).share());
  2889. zero+=(getLeftBal(n9.pointers).share());
  2890. zero+=(getRightBal(n9.pointers).share());
  2891. zero+=(getLeftBal(n15.pointers).share());
  2892. zero+=(getRightBal(n15.pointers).share());
  2893. zero+=(getRightBal(n3.pointers).share());
  2894. zero+=(getLeftBal(root_node.pointers).share());
  2895. zero+=(getRightBal(root_node.pointers).share());
  2896. if(zero!=0) {
  2897. success = false;
  2898. }
  2899. int one = (getLeftBal(n3.pointers).share());
  2900. if(one!=1) {
  2901. success = false;
  2902. }
  2903. if(player0) {
  2904. if(success) {
  2905. print_green("T17 : SUCCESS\n");
  2906. } else {
  2907. print_red("T17 : FAIL\n");
  2908. }
  2909. }
  2910. A.init();
  2911. tree.init();
  2912. }
  2913. });
  2914. }