bst.cpp 25 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742
  1. #include <functional>
  2. #include "bst.hpp"
  #ifdef BST_DEBUG
  // Debug helper: reconstruct the whole ORAM and dump every slot, one per
  // line, each prefixed with its slot index in hex.
  void BST::print_oram(MPCTIO &tio, yield_t &yield) {
      auto flat = oram.flat(tio, yield);
      const auto contents = flat.reconstruct();
      size_t idx = 0;
      for (const auto &entry : contents) {
          printf("\n%04lx ", idx);
          entry.dump();
          ++idx;
      }
      printf("\n");
  }
  #endif
  14. // Helper functions to reconstruct shared RegBS, RegAS or RegXS
  15. bool reconstruct_RegBS(MPCTIO &tio, yield_t &yield, RegBS flag) {
  16. RegBS reconstructed_flag;
  17. if (tio.player() < 2) {
  18. RegBS peer_flag;
  19. tio.queue_peer(&flag, 1);
  20. tio.queue_server(&flag, 1);
  21. yield();
  22. tio.recv_peer(&peer_flag, 1);
  23. reconstructed_flag = flag;
  24. reconstructed_flag ^= peer_flag;
  25. } else {
  26. RegBS p0_flag, p1_flag;
  27. yield();
  28. tio.recv_p0(&p0_flag, 1);
  29. tio.recv_p1(&p1_flag, 1);
  30. reconstructed_flag = p0_flag;
  31. reconstructed_flag ^= p1_flag;
  32. }
  33. return reconstructed_flag.bshare;
  34. }
  35. /* A function to assign a new random 8-bit key to a node, and resets its
  36. pointers to zeroes. The node is assigned a new random 64-bit value.
  37. */
  38. static void randomize_node(Node &a) {
  39. a.key.randomize(8);
  40. a.pointers.set(0);
  41. a.value.randomize();
  42. }
  43. /*
  44. A function to perform key comparsions for BST traversal.
  45. Inputs: k1 = key of node in the tree, k2 = insertion/deletion/lookup key.
  46. Evaluates (k2-k1), and combines the lt and eq flag into one (flag to go
  47. left), and keeps the gt flag as is (flag to go right) during traversal.
  48. */
  49. std::tuple<RegBS, RegBS> compare_keys(MPCTIO tio, yield_t &yield, RegAS k1,
  50. RegAS k2) {
  51. CDPF cdpf = tio.cdpf(yield);
  52. auto [lt, eq, gt] = cdpf.compare(tio, yield, k2 - k1, tio.aes_ops());
  53. RegBS lteq = lt^eq;
  54. return {lteq, gt};
  55. }
  56. // Assuming pointer of 64 bits is split as:
  57. // - 32 bits Left ptr (L)
  58. // - 32 bits Right ptr (R)
  59. // The pointers are stored as: L | R
  60. inline RegXS extractLeftPtr(RegXS pointer){
  61. return ((pointer&(0xFFFFFFFF00000000))>>32);
  62. }
  63. inline RegXS extractRightPtr(RegXS pointer){
  64. return (pointer&(0x00000000FFFFFFFF));
  65. }
  66. inline void setLeftPtr(RegXS &pointer, RegXS new_ptr){
  67. pointer&=(0x00000000FFFFFFFF);
  68. pointer+=(new_ptr<<32);
  69. }
  70. inline void setRightPtr(RegXS &pointer, RegXS new_ptr){
  71. pointer&=(0xFFFFFFFF00000000);
  72. pointer+=(new_ptr);
  73. }
  74. // Pretty-print a reconstructed BST, rooted at node. is_left_child and
  75. // is_right_child indicate whether node is a left or right child of its
  76. // parent. They cannot both be true, but the root of the tree has both
  77. // of them false.
  78. void BST::pretty_print(const std::vector<Node> &R, value_t node,
  79. const std::string &prefix = "", bool is_left_child = false,
  80. bool is_right_child = false)
  81. {
  82. if (node == 0) {
  83. // NULL pointer
  84. if (is_left_child) {
  85. printf("%s\xE2\x95\xA7\n", prefix.c_str()); // ╧
  86. } else if (is_right_child) {
  87. printf("%s\xE2\x95\xA4\n", prefix.c_str()); // ╤
  88. } else {
  89. printf("%s\xE2\x95\xA2\n", prefix.c_str()); // ╢
  90. }
  91. return;
  92. }
  93. const Node &n = R[node];
  94. value_t left_ptr = extractLeftPtr(n.pointers).xshare;
  95. value_t right_ptr = extractRightPtr(n.pointers).xshare;
  96. std::string rightprefix(prefix), leftprefix(prefix),
  97. nodeprefix(prefix);
  98. if (is_left_child) {
  99. rightprefix.append("\xE2\x94\x82"); // │
  100. leftprefix.append(" ");
  101. nodeprefix.append("\xE2\x94\x94"); // └
  102. } else if (is_right_child) {
  103. rightprefix.append(" ");
  104. leftprefix.append("\xE2\x94\x82"); // │
  105. nodeprefix.append("\xE2\x94\x8C"); // ┌
  106. } else {
  107. rightprefix.append(" ");
  108. leftprefix.append(" ");
  109. nodeprefix.append("\xE2\x94\x80"); // ─
  110. }
  111. pretty_print(R, right_ptr, rightprefix, false, true);
  112. printf("%s\xE2\x94\xA4", nodeprefix.c_str()); // ┤
  113. n.dump();
  114. printf("\n");
  115. pretty_print(R, left_ptr, leftprefix, true, false);
  116. }
  117. void BST::pretty_print(MPCTIO &tio, yield_t &yield) {
  118. RegXS peer_root;
  119. RegXS reconstructed_root = root;
  120. if (tio.player() == 1) {
  121. tio.queue_peer(&root, sizeof(root));
  122. yield();
  123. } else {
  124. RegXS peer_root;
  125. yield();
  126. tio.recv_peer(&peer_root, sizeof(peer_root));
  127. reconstructed_root += peer_root;
  128. }
  129. auto A = oram.flat(tio, yield);
  130. auto R = A.reconstruct();
  131. if(tio.player()==0) {
  132. pretty_print(R, reconstructed_root.xshare);
  133. }
  134. }
  135. // Check the BST invariant of the tree (that all keys to the left are
  136. // less than or equal to this key, all keys to the right are strictly
  137. // greater, and this is true recursively). Returns a
  138. // tuple<bool,address_t>, where the bool says whether the BST invariant
  139. // holds, and the address_t is the height of the tree (which will be
  140. // useful later when we check AVL trees).
  141. std::tuple<bool, address_t> BST::check_bst(const std::vector<Node> &R,
  142. value_t node, value_t min_key = 0, value_t max_key = ~0)
  143. {
  144. //printf("node = %ld\n", node);
  145. if (node == 0) {
  146. return { true, 0 };
  147. }
  148. const Node &n = R[node];
  149. value_t key = n.key.ashare;
  150. value_t left_ptr = extractLeftPtr(n.pointers).xshare;
  151. value_t right_ptr = extractRightPtr(n.pointers).xshare;
  152. auto [leftok, leftheight ] = check_bst(R, left_ptr, min_key, key);
  153. auto [rightok, rightheight ] = check_bst(R, right_ptr, key+1, max_key);
  154. address_t height = leftheight;
  155. if (rightheight > height) {
  156. height = rightheight;
  157. }
  158. height += 1;
  159. //printf("node = %ld, leftok = %d, rightok = %d\n", node, leftok, rightok);
  160. return { leftok && rightok && key >= min_key && key <= max_key,
  161. height };
  162. }
  163. void BST::check_bst(MPCTIO &tio, yield_t &yield) {
  164. auto A = oram.flat(tio, yield);
  165. auto R = A.reconstruct();
  166. RegXS rec_root = this->root;
  167. if (tio.player() == 1) {
  168. tio.queue_peer(&(this->root), sizeof(this->root));
  169. } else {
  170. RegXS peer_root;
  171. tio.recv_peer(&peer_root, sizeof(peer_root));
  172. rec_root+= peer_root;
  173. }
  174. if (tio.player() == 0) {
  175. auto [ ok, height ] = check_bst(R, rec_root.xshare);
  176. printf("BST structure %s\nBST height = %u\n",
  177. ok ? "ok" : "NOT OK", height);
  178. }
  179. }
  /*
  The recursive insert() call, invoked by the wrapper insert() function.
  Takes as input the pointer to the current node in tree traversal (ptr),
  the new node to be inserted (new_node), the underlying Duoram as a
  flat (A), the Time-To-Live TTL (remaining recursion depth), and a shared
  flag (isDummy) which tracks if the operation is dummy/real: it starts
  false at the root and is set at the level where the insertion point is
  found, so every deeper level only performs dummy work.
  Returns (wptr, direction): wptr is the node under which new_node must be
  linked, and direction is the comparison result at that node (the wrapper
  links new_node as the right child when direction is 1, else as the left).
  Throughout, mpc_select(tio, yield, out, f, x, y) sets out = f ? y : x.
  */
  std::tuple<RegXS, RegBS> BST::insert(MPCTIO &tio, yield_t &yield, RegXS ptr,
      const Node &new_node, Duoram<Node>::Flat &A, int TTL, RegBS isDummy) {
      // Recursion bottomed out: report this ptr with a zero direction. The
      // level that actually found the insertion point overrides both.
      if(TTL==0) {
          RegBS zero;
          return {ptr, zero};
      }
      // Shared NOT of isDummy: exactly one computational party flips its share
      RegBS isNotDummy = isDummy ^ (!tio.player());
      Node cnode = A[ptr];
      // Compare key
      auto [lteq, gt] = compare_keys(tio, yield, cnode.key, new_node.key);
      // Depending on [lteq, gt] select the next ptr/index as
      // upper 32 bits of cnode.pointers if lteq
      // lower 32 bits of cnode.pointers if gt
      RegXS left = extractLeftPtr(cnode.pointers);
      RegXS right = extractRightPtr(cnode.pointers);
      RegXS next_ptr;
      mpc_select(tio, yield, next_ptr, gt, left, right, 32);
      CDPF dpf = tio.cdpf(yield);
      size_t &aes_ops = tio.aes_ops();
      // F_z: Check if this is last node on path (the chosen child is NULL)
      RegBS F_z = dpf.is_zero(tio, yield, next_ptr, aes_ops);
      RegBS F_i;
      // F_i = F_z AND isNotDummy: this is the last node on the path and the
      // insertion point has not been claimed yet, so insert new_node here.
      mpc_and(tio, yield, F_i, (isNotDummy), F_z);
      // Mark the operation as done for all deeper levels
      isDummy^=F_i;
      auto [wptr, direction] = insert(tio, yield, next_ptr, new_node, A, TTL-1, isDummy);
      RegXS ret_ptr;
      RegBS ret_direction;
      // If we insert here (F_i), return the ptr to this node as wptr
      // and update direction to the direction taken by compare_keys
      mpc_select(tio, yield, ret_ptr, F_i, wptr, ptr);
      //ret_direction = direction + F_i (direction - gt)
      // i.e. ret_direction = F_i ? gt : direction, computed with one AND
      mpc_and(tio, yield, ret_direction, F_i, direction^gt);
      ret_direction^=direction;
      return {ret_ptr, ret_direction};
  }
  224. /*
  225. The wrapper insert() operation invoked by the main insert call
  226. BST::insert(tio, yield, Node& new_node);
  227. Takes as input the new node (node), the underlying Duoram as a flat (A).
  228. */
  229. void BST::insert(MPCTIO &tio, yield_t &yield, const Node &node, Duoram<Node>::Flat &A) {
  230. bool player0 = tio.player()==0;
  231. // If there are no items in tree. Make this new item the root.
  232. if (num_items==0) {
  233. Node zero;
  234. A[1] = node;
  235. // Set root to a secret sharing of the constant value 1
  236. (root).set(1*tio.player());
  237. num_items++;
  238. //printf("num_items == %ld!\n", num_items);
  239. return;
  240. } else {
  241. // Insert node into next free slot in the ORAM
  242. int new_id;
  243. RegXS insert_address;
  244. int TTL = num_items++;
  245. bool insertAtEmptyLocation = (empty_locations.size() > 0);
  246. if(insertAtEmptyLocation) {
  247. insert_address = empty_locations.back();
  248. empty_locations.pop_back();
  249. A[insert_address] = node;
  250. } else {
  251. new_id = 1 + num_items;
  252. A[new_id] = node;
  253. insert_address.set(new_id * tio.player());
  254. }
  255. RegBS isDummy;
  256. //Do a recursive insert
  257. auto [wptr, direction] = insert(tio, yield, root, node, A, TTL, isDummy);
  258. //Complete the insertion by reading wptr and updating its pointers
  259. RegXS pointers = A[wptr].NODE_POINTERS;
  260. RegXS left_ptr = extractLeftPtr(pointers);
  261. RegXS right_ptr = extractRightPtr(pointers);
  262. RegXS new_right_ptr, new_left_ptr;
  263. mpc_select(tio, yield, new_right_ptr, direction, right_ptr, insert_address);
  264. if (player0) {
  265. direction^=1;
  266. }
  267. mpc_select(tio, yield, new_left_ptr, direction, left_ptr, insert_address);
  268. setLeftPtr(pointers, new_left_ptr);
  269. setRightPtr(pointers, new_right_ptr);
  270. A[wptr].NODE_POINTERS = pointers;
  271. //printf("num_items == %ld!\n", num_items);
  272. }
  273. }
  274. /*
  275. Insert a new node into the BST.
  276. Takes as input the new node (node).
  277. */
  278. void BST::insert(MPCTIO &tio, yield_t &yield, Node &node) {
  279. auto A = oram.flat(tio, yield);
  280. insert(tio, yield, node, A);
  281. /*
  282. // To visualize database and tree after each insert:
  283. auto R = A.reconstruct();
  284. if (tio.player() == 0) {
  285. for(size_t i=0;i<R.size();++i) {
  286. printf("\n%04lx ", i);
  287. R[i].dump();
  288. }
  289. printf("\n");
  290. }
  291. pretty_print(R, 1);
  292. */
  293. }
  294. bool BST::lookup(MPCTIO &tio, yield_t &yield, RegXS ptr, RegAS key, Duoram<Node>::Flat &A,
  295. int TTL, RegBS isDummy, Node *ret_node) {
  296. if(TTL==0) {
  297. // Reconstruct and return isDummy
  298. // If we found the key, then isDummy will be true
  299. bool found = mpc_reconstruct(tio, yield, isDummy, 1);
  300. return found;
  301. }
  302. RegBS isNotDummy = isDummy ^ (tio.player());
  303. Node cnode = A[ptr];
  304. // Compare key
  305. CDPF cdpf = tio.cdpf(yield);
  306. auto [lt, eq, gt] = cdpf.compare(tio, yield, key - cnode.key, tio.aes_ops());
  307. // Depending on [lteq, gt] select the next ptr/index as
  308. // upper 32 bits of cnode.pointers if lteq
  309. // lower 32 bits of cnode.pointers if gt
  310. RegXS left = extractLeftPtr(cnode.pointers);
  311. RegXS right = extractRightPtr(cnode.pointers);
  312. RegXS next_ptr;
  313. mpc_select(tio, yield, next_ptr, gt, left, right, 32);
  314. RegBS F_found;
  315. // If we haven't found the key yet, and the lookup matches the current node key,
  316. // then we found the node to return
  317. mpc_and(tio, yield, F_found, isNotDummy, eq);
  318. mpc_select(tio, yield, ret_node->key, eq, ret_node->key, cnode.key);
  319. mpc_select(tio, yield, ret_node->value, eq, ret_node->value, cnode.value);
  320. isDummy^=F_found;
  321. bool found = lookup(tio, yield, next_ptr, key, A, TTL-1, isDummy, ret_node);
  322. return found;
  323. }
  324. bool BST::lookup(MPCTIO &tio, yield_t &yield, RegAS key, Node *ret_node) {
  325. auto A = oram.flat(tio, yield);
  326. RegBS isDummy;
  327. bool found = lookup(tio, yield, root, key, A, num_items, isDummy, ret_node);
  328. /*
  329. // To visualize database and tree after each lookup:
  330. auto R = A.reconstruct();
  331. if (tio.player() == 0) {
  332. for(size_t i=0;i<R.size();++i) {
  333. printf("\n%04lx ", i);
  334. R[i].dump();
  335. }
  336. printf("\n");
  337. }
  338. pretty_print(R, 1);
  339. */
  340. return found;
  341. }
  /*
  Recursive oblivious delete, invoked by the wrapper del(tio, yield, del_key).
  Inputs:
    ptr     - current node on the traversal path
    del_key - key to delete
    A       - flat view of the ORAM
    af      - shared "already found" flag: the node holding del_key was
              found higher up the path
    fs      - shared "finding successor" flag: we are descending to the
              in-order successor of the found node
    TTL     - remaining recursion depth
  Returns the reconstructed af from the bottom of the recursion (whether
  del_key was found on the path). If it was not found, every level returns
  0 before writing anything to A.
  Otherwise each level rewrites its own child pointers and fills ret_struct
  for its caller:
    F_r / ret_ptr - whether the parent must replace its pointer to this
                    node, and with what
    N_d           - the node holding del_key
    N_s           - the in-order successor (when a successor swap is needed)
    F_ss          - whether a successor swap is needed (set in Case 2)
  Throughout, mpc_select(tio, yield, out, f, x, y) sets out = f ? y : x.
  NOTE(review): lf is taken from eq without masking by af, so a duplicate of
  del_key deeper on the path (insert sends equal keys left) also raises
  lf/F_c1/F_c2 at that level; confirm this cannot mis-steer the traversal.
  */
  bool BST::del(MPCTIO &tio, yield_t &yield, RegXS ptr, RegAS del_key,
      Duoram<Node>::Flat &A, RegBS af, RegBS fs, int TTL,
      del_return &ret_struct) {
      bool player0 = tio.player()==0;
      //printf("TTL = %d\n", TTL);
      if(TTL==0) {
          //Reconstruct and return af
          bool success = reconstruct_RegBS(tio, yield, af);
          //printf("Reconstructed flag = %d\n", success);
          // Shared NOT of F_r (only P0 flips its share)
          if(player0)
              ret_struct.F_r^=1;
          return success;
      } else {
          Node node = A[ptr];
          // Compare key
          CDPF cdpf = tio.cdpf(yield);
          auto [lt, eq, gt] = cdpf.compare(tio, yield, del_key - node.key, tio.aes_ops());
          /*
          // Reconstruct and Debug Block 0
          bool lt_rec, eq_rec, gt_rec;
          lt_rec = mpc_reconstruct(tio, yield, lt, 1);
          eq_rec = mpc_reconstruct(tio, yield, eq, 1);
          gt_rec = mpc_reconstruct(tio, yield, gt, 1);
          size_t del_key_rec, node_key_rec;
          del_key_rec = mpc_reconstruct(tio, yield, del_key, 64);
          node_key_rec = mpc_reconstruct(tio, yield, node.key, 64);
          printf("node.key = %ld, del_key= %ld\n", node_key_rec, del_key_rec);
          printf("cdpf.compare results: lt = %d, eq = %d, gt = %d\n", lt_rec, eq_rec, gt_rec);
          */
          // c is the direction bit for next_ptr
          // (c=0: go left or c=1: go right)
          RegBS c = gt;
          // lf = local found. We found the key to delete in this level.
          RegBS lf = eq;
          // Depending on [lteq, gt] select the next ptr/index as
          // upper 32 bits of cnode.pointers if lteq
          // lower 32 bits of cnode.pointers if gt
          RegXS left = extractLeftPtr(node.pointers);
          RegXS right = extractRightPtr(node.pointers);
          CDPF dpf = tio.cdpf(yield);
          size_t &aes_ops = tio.aes_ops();
          // Check if left and right children are 0, and compute F_0, F_1, F_2
          // (F_0: no children, F_1: exactly one child, F_2: two children)
          RegBS l0 = dpf.is_zero(tio, yield, left, aes_ops);
          RegBS r0 = dpf.is_zero(tio, yield, right, aes_ops);
          RegBS F_0, F_1, F_2;
          // F_0 = l0 & r0
          mpc_and(tio, yield, F_0, l0, r0);
          // F_1 = l0 \xor r0
          F_1 = l0 ^ r0;
          // F_2 = !(F_0 + F_1) (Only 1 of F_0, F_1, and F_2 can be true)
          F_2 = F_0 ^ F_1;
          if(player0)
              F_2^=1;
          // We set next ptr based on c, but we need to handle three
          // edge cases where we do not go by just the comparison result
          RegXS next_ptr;
          RegBS c_prime;
          // Case 1: found the node here (lf): we traverse down the lone child path.
          // or we are finding successor (fs) and there is no left child.
          RegBS F_c1, F_c2, F_c3, F_c4;
          // Case 1: lf & F_1
          mpc_and(tio, yield, F_c1, lf, F_1);
          // Set c_prime for Case 1 (go right iff the lone child is the right one)
          mpc_select(tio, yield, c_prime, F_c1, c, l0);
          /*
          // Reconstruct and Debug Block 1
          bool F_0_rec, F_1_rec, F_2_rec, c_prime_rec;
          F_0_rec = mpc_reconstruct(tio, yield, F_0, 1);
          F_1_rec = mpc_reconstruct(tio, yield, F_1, 1);
          F_2_rec = mpc_reconstruct(tio, yield, F_2, 1);
          c_prime_rec = mpc_reconstruct(tio, yield, c_prime, 1);
          printf("F_0 = %d, F_1 = %d, F_2 = %d, c_prime = %d\n", F_0_rec, F_1_rec, F_2_rec, c_prime_rec);
          */
          // s1: shares of 1 bit, s0: shares of 0 bit
          RegBS s1, s0;
          s1.set(tio.player()==1);
          // Case 2: found the node here (lf) and node has both children (F_2)
          // In find successor case, so find inorder successor
          // (Go right and then find leftmost child.)
          mpc_and(tio, yield, F_c2, lf, F_2);
          mpc_select(tio, yield, c_prime, F_c2, c_prime, s1);
          /*
          // Reconstruct and Debug Block 2
          bool F_c2_rec, s1_rec;
          F_c2_rec = mpc_reconstruct(tio, yield, F_c2, 1);
          s1_rec = mpc_reconstruct(tio, yield, s1, 1);
          c_prime_rec = mpc_reconstruct(tio, yield, c_prime, 1);
          printf("c_prime = %d, F_c2 = %d, s1 = %d\n", c_prime_rec, F_c2_rec, s1_rec);
          */
          // Case 3: finding successor (fs) and node has both children (F_2)
          // Go left.
          mpc_and(tio, yield, F_c3, fs, F_2);
          mpc_select(tio, yield, c_prime, F_c3, c_prime, s0);
          // Case 4: finding successor (fs) and node has no more left children (l0)
          // This is the successor node then.
          // Go right (since no more left)
          mpc_and(tio, yield, F_c4, fs, l0);
          mpc_select(tio, yield, c_prime, F_c4, c_prime, l0);
          // Set next_ptr
          mpc_select(tio, yield, next_ptr, c_prime, left, right, 32);
          RegBS af_prime, fs_prime;
          mpc_or(tio, yield, af_prime, af, lf);
          // If in Case 2, set fs. We are now finding successor
          mpc_or(tio, yield, fs_prime, fs, F_c2);
          // If in Case 4. Successor found here already. Toggle fs off
          fs_prime=fs_prime^F_c4;
          bool key_found = del(tio, yield, next_ptr, del_key, A, af_prime, fs_prime, TTL-1, ret_struct);
          // If we didn't find the key, we can end here.
          if(!key_found)
              return 0;
          //printf("TTL = %d\n", TTL);
          RegBS F_rs;
          // Flag here should be direction (c_prime) and F_r i.e. we need to swap return ptr in,
          // F_r needs to be returned in ret_struct
          // Replace the right pointer if we went right (c_prime) and the child asked for it
          mpc_and(tio, yield, F_rs, c_prime, ret_struct.F_r);
          mpc_select(tio, yield, right, F_rs, right, ret_struct.ret_ptr);
          // Flip c_prime (shared NOT) and do the same for the left pointer
          if(player0)
              c_prime^=1;
          mpc_and(tio, yield, F_rs, c_prime, ret_struct.F_r);
          mpc_select(tio, yield, left, F_rs, left, ret_struct.ret_ptr);
          /*
          // Reconstruct and Debug Block 3
          bool F_rs_rec, F_ls_rec;
          size_t ret_ptr_rec;
          F_rs_rec = mpc_reconstruct(tio, yield, F_rs, 1);
          F_ls_rec = mpc_reconstruct(tio, yield, F_rs, 1);
          ret_ptr_rec = mpc_reconstruct(tio, yield, ret_struct.ret_ptr, 64);
          printf("F_rs_rec = %d, F_ls_rec = %d, ret_ptr_rec = %ld\n", F_rs_rec, F_ls_rec, ret_ptr_rec);
          */
          // Write this node's (possibly updated) child pointers back
          RegXS new_ptr;
          setLeftPtr(new_ptr, left);
          setRightPtr(new_ptr, right);
          A[ptr].NODE_POINTERS = new_ptr;
          // Update the return structure
          RegBS F_nd, F_ns, F_r;
          mpc_or(tio, yield, ret_struct.F_ss, ret_struct.F_ss, F_c2);
          // af becomes !af (shared NOT), so F_nd = lf & !af: the first match only
          if(player0)
              af^=1;
          mpc_and(tio, yield, F_nd, lf, af);
          // F_ns = fs & l0
          // Finding successor flag & no more left child
          F_ns = F_c4;
          // F_r = F_d.(!F_2)
          if(player0)
              F_2^=1;
          // If we have to delete here, and it doesn't have two children we have to
          // update child pointer in parent with the returned pointer
          mpc_and(tio, yield, F_r, F_nd, F_2);
          mpc_or(tio, yield, F_r, F_r, F_ns);
          ret_struct.F_r = F_r;
          mpc_select(tio, yield, ret_struct.N_d, F_nd, ret_struct.N_d, ptr);
          mpc_select(tio, yield, ret_struct.N_s, F_ns, ret_struct.N_s, ptr);
          // If the parent keeps pointing here (!F_r), hand it this node's ptr;
          // otherwise pass up the replacement pointer from below.
          mpc_select(tio, yield, ret_struct.ret_ptr, F_r, ptr, ret_struct.ret_ptr);
          //We don't empty the key and value of the node with del_key in the ORAM
          return 1;
      }
  }
  499. bool BST::del(MPCTIO &tio, yield_t &yield, RegAS del_key) {
  500. if(num_items==0)
  501. return 0;
  502. if(num_items==1) {
  503. //Delete root
  504. auto A = oram.flat(tio, yield);
  505. Node zero;
  506. empty_locations.emplace_back(root);
  507. A[root] = zero;
  508. num_items--;
  509. return 1;
  510. } else {
  511. int TTL = num_items;
  512. // Flags for already found (af) item to delete and find successor (fs)
  513. // if this deletion requires a successor swap
  514. RegBS af;
  515. RegBS fs;
  516. del_return ret_struct;
  517. auto A = oram.flat(tio, yield);
  518. int success = del(tio, yield, root, del_key, A, af, fs, TTL, ret_struct);
  519. printf ("Success = %d\n", success);
  520. if(!success){
  521. return 0;
  522. }
  523. else{
  524. num_items--;
  525. /*
  526. printf("In delete's swap portion\n");
  527. Node del_node = A.reconstruct(A[ret_struct.N_d]);
  528. Node suc_node = A.reconstruct(A[ret_struct.N_s]);
  529. printf("del_node key = %ld, suc_node key = %ld\n",
  530. del_node.key.ashare, suc_node.key.ashare);
  531. printf("flag_s = %d\n", ret_struct.F_ss.bshare);
  532. */
  533. Node del_node = A[ret_struct.N_d];
  534. Node suc_node = A[ret_struct.N_s];
  535. RegAS zero_as; RegXS zero_xs;
  536. mpc_select(tio, yield, root, ret_struct.F_r, root, ret_struct.ret_ptr);
  537. mpc_select(tio, yield, del_node.key, ret_struct.F_ss, del_node.key, suc_node.key);
  538. mpc_select(tio, yield, del_node.value, ret_struct.F_ss, del_node.value, suc_node.value);
  539. A[ret_struct.N_d].NODE_KEY = del_node.key;
  540. A[ret_struct.N_d].NODE_VALUE = del_node.value;
  541. A[ret_struct.N_s].NODE_KEY = zero_as;
  542. A[ret_struct.N_s].NODE_VALUE = zero_xs;
  543. RegXS empty_loc;
  544. mpc_select(tio, yield, empty_loc, ret_struct.F_ss, ret_struct.N_d, ret_struct.N_s);
  545. //Add deleted (empty) location into the empty_locations vector for reuse in next insert()
  546. empty_locations.emplace_back(empty_loc);
  547. }
  548. return 1;
  549. }
  550. }
  551. // Now we use the node in various ways. This function is called by
  552. // online.cpp.
  553. void bst(MPCIO &mpcio,
  554. const PRACOptions &opts, char **args)
  555. {
  556. nbits_t depth=4;
  557. if (*args) {
  558. depth = atoi(*args);
  559. ++args;
  560. }
  561. size_t items = (size_t(1)<<depth)-1;
  562. if (*args) {
  563. items = atoi(*args);
  564. ++args;
  565. }
  566. MPCTIO tio(mpcio, 0, opts.num_threads);
  567. run_coroutines(tio, [&tio, depth, items] (yield_t &yield) {
  568. size_t size = size_t(1)<<depth;
  569. BST tree(tio.player(), size);
  570. int insert_array[] = {10, 10, 13, 11, 14, 8, 15, 20, 17, 19, 7, 12};
  571. //int insert_array[] = {1, 2, 3, 4, 5, 6};
  572. size_t insert_array_size = 11;
  573. Node node;
  574. for(size_t i = 0; i<=insert_array_size; i++) {
  575. randomize_node(node);
  576. node.key.set(insert_array[i] * tio.player());
  577. tree.insert(tio, yield, node);
  578. }
  579. tree.pretty_print(tio, yield);
  580. RegAS del_key;
  581. /*
  582. printf("\n\nDelete %x\n", 20);
  583. del_key.set(20 * tio.player());
  584. tree.del(tio, yield, del_key);
  585. tree.pretty_print(tio, yield);
  586. tree.check_bst(tio, yield);
  587. printf("\n\nDelete %x\n", 10);
  588. del_key.set(10 * tio.player());
  589. tree.del(tio, yield, del_key);
  590. tree.pretty_print(tio, yield);
  591. tree.check_bst(tio, yield);
  592. printf("\n\nDelete %x\n", 8);
  593. del_key.set(8 * tio.player());
  594. tree.del(tio, yield, del_key);
  595. tree.pretty_print(tio, yield);
  596. tree.check_bst(tio, yield);
  597. printf("\n\nDelete %x\n", 7);
  598. del_key.set(7 * tio.player());
  599. tree.del(tio, yield, del_key);
  600. tree.pretty_print(tio, yield);
  601. tree.check_bst(tio, yield);
  602. printf("\n\nDelete %x\n", 17);
  603. del_key.set(17 * tio.player());
  604. tree.del(tio, yield, del_key);
  605. tree.pretty_print(tio, yield);
  606. tree.check_bst(tio, yield);
  607. printf("\n\nDelete %x\n", 15);
  608. del_key.set(15 * tio.player());
  609. tree.del(tio, yield, del_key);
  610. tree.pretty_print(tio, yield);
  611. tree.check_bst(tio, yield);
  612. printf("\n\nDelete %x\n", 5);
  613. del_key.set(5 * tio.player());
  614. tree.del(tio, yield, del_key);
  615. tree.pretty_print(tio, yield);
  616. tree.check_bst(tio, yield);
  617. */
  618. printf("\n\nInsert %x\n", 14);
  619. randomize_node(node);
  620. node.key.set(14 * tio.player());
  621. tree.insert(tio, yield, node);
  622. tree.pretty_print(tio, yield);
  623. tree.check_bst(tio, yield);
  624. printf("\n\nLookup %x\n", 8);
  625. randomize_node(node);
  626. RegAS lookup_key;
  627. bool found;
  628. lookup_key.set(8 * tio.player());
  629. found = tree.lookup(tio, yield, lookup_key, &node);
  630. tree.pretty_print(tio, yield);
  631. if(found) {
  632. printf("Lookup Success\n");
  633. size_t value = mpc_reconstruct(tio, yield, node.value, 64);
  634. printf("value = %lx\n", value);
  635. } else {
  636. printf("Lookup Failed\n");
  637. }
  638. printf("\n\nLookup %x\n", 99);
  639. randomize_node(node);
  640. lookup_key.set(99 * tio.player());
  641. found = tree.lookup(tio, yield, lookup_key, &node);
  642. tree.pretty_print(tio, yield);
  643. if(found) {
  644. printf("Lookup Success\n");
  645. size_t value = mpc_reconstruct(tio, yield, node.value, 64);
  646. printf("value = %lx\n", value);
  647. } else {
  648. printf("Lookup Failed\n");
  649. }
  650. });
  651. }