bst.cpp 30 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848
  1. #include <functional>
  2. #include "bst.hpp"
  3. #ifdef BST_DEBUG
  4. void BST::print_oram(MPCTIO &tio, yield_t &yield) {
  5. auto A = oram.flat(tio, yield);
  6. auto R = A.reconstruct();
  7. for(size_t i=0;i<R.size();++i) {
  8. printf("\n%04lx ", i);
  9. R[i].dump();
  10. }
  11. printf("\n");
  12. }
  13. #endif
  14. // Helper function to reconstruct shared RegBS
  15. bool reconstruct_RegBS(MPCTIO &tio, yield_t &yield, RegBS flag) {
  16. RegBS reconstructed_flag;
  17. if (tio.player() < 2) {
  18. RegBS peer_flag;
  19. tio.queue_peer(&flag, 1);
  20. tio.queue_server(&flag, 1);
  21. yield();
  22. tio.recv_peer(&peer_flag, 1);
  23. reconstructed_flag = flag;
  24. reconstructed_flag ^= peer_flag;
  25. } else {
  26. RegBS p0_flag, p1_flag;
  27. yield();
  28. tio.recv_p0(&p0_flag, 1);
  29. tio.recv_p1(&p1_flag, 1);
  30. reconstructed_flag = p0_flag;
  31. reconstructed_flag ^= p1_flag;
  32. }
  33. return reconstructed_flag.bshare;
  34. }
  35. /*
  36. A function to assign a new random 8-bit key to a node, and resets its
  37. pointers to zeroes. The node is assigned a new random 64-bit value.
  38. */
  39. static void randomize_node(Node &a) {
  40. a.key.randomize(8);
  41. a.pointers.set(0);
  42. a.value.randomize();
  43. }
  44. /*
  45. A function to perform key comparsions for BST traversal.
  46. Inputs: k1 = key of node in the tree, k2 = insertion/deletion/lookup key.
  47. Evaluates (k2-k1), and combines the lt and eq flag into one (flag to go
  48. left), and keeps the gt flag as is (flag to go right) during traversal.
  49. */
  50. std::tuple<RegBS, RegBS> compare_keys(MPCTIO &tio, yield_t &yield, RegAS k1,
  51. RegAS k2) {
  52. CDPF cdpf = tio.cdpf(yield);
  53. auto [lt, eq, gt] = cdpf.compare(tio, yield, k2 - k1, tio.aes_ops());
  54. RegBS lteq = lt^eq;
  55. return {lteq, gt};
  56. }
  57. // Assuming pointer of 64 bits is split as:
  58. // - 32 bits Left ptr (L)
  59. // - 32 bits Right ptr (R)
  60. // The pointers are stored as: (L << 32) | R
  61. inline RegXS extractLeftPtr(RegXS pointer){
  62. return ((pointer&(0xFFFFFFFF00000000))>>32);
  63. }
  64. inline RegXS extractRightPtr(RegXS pointer){
  65. return (pointer&(0x00000000FFFFFFFF));
  66. }
  67. inline void setLeftPtr(RegXS &pointer, RegXS new_ptr){
  68. pointer&=(0x00000000FFFFFFFF);
  69. pointer+=(new_ptr<<32);
  70. }
  71. inline void setRightPtr(RegXS &pointer, RegXS new_ptr){
  72. pointer&=(0xFFFFFFFF00000000);
  73. pointer+=(new_ptr);
  74. }
  75. // Pretty-print a reconstructed BST, rooted at node. is_left_child and
  76. // is_right_child indicate whether node is a left or right child of its
  77. // parent. They cannot both be true, but the root of the tree has both
  78. // of them false.
  79. void BST::pretty_print(const std::vector<Node> &R, value_t node,
  80. const std::string &prefix = "", bool is_left_child = false,
  81. bool is_right_child = false)
  82. {
  83. if (node == 0) {
  84. // NULL pointer
  85. if (is_left_child) {
  86. printf("%s\xE2\x95\xA7\n", prefix.c_str()); // ╧
  87. } else if (is_right_child) {
  88. printf("%s\xE2\x95\xA4\n", prefix.c_str()); // ╤
  89. } else {
  90. printf("%s\xE2\x95\xA2\n", prefix.c_str()); // ╢
  91. }
  92. return;
  93. }
  94. const Node &n = R[node];
  95. value_t left_ptr = extractLeftPtr(n.pointers).xshare;
  96. value_t right_ptr = extractRightPtr(n.pointers).xshare;
  97. std::string rightprefix(prefix), leftprefix(prefix),
  98. nodeprefix(prefix);
  99. if (is_left_child) {
  100. rightprefix.append("\xE2\x94\x82"); // │
  101. leftprefix.append(" ");
  102. nodeprefix.append("\xE2\x94\x94"); // └
  103. } else if (is_right_child) {
  104. rightprefix.append(" ");
  105. leftprefix.append("\xE2\x94\x82"); // │
  106. nodeprefix.append("\xE2\x94\x8C"); // ┌
  107. } else {
  108. rightprefix.append(" ");
  109. leftprefix.append(" ");
  110. nodeprefix.append("\xE2\x94\x80"); // ─
  111. }
  112. pretty_print(R, right_ptr, rightprefix, false, true);
  113. printf("%s\xE2\x94\xA4", nodeprefix.c_str()); // ┤
  114. n.dump();
  115. printf("\n");
  116. pretty_print(R, left_ptr, leftprefix, true, false);
  117. }
  118. void BST::pretty_print(MPCTIO &tio, yield_t &yield) {
  119. RegXS peer_root;
  120. RegXS reconstructed_root = root;
  121. if (tio.player() == 1) {
  122. tio.queue_peer(&root, sizeof(root));
  123. yield();
  124. } else {
  125. RegXS peer_root;
  126. yield();
  127. if(tio.player()==0) {
  128. tio.recv_peer(&peer_root, sizeof(peer_root));
  129. }
  130. reconstructed_root += peer_root;
  131. }
  132. auto A = oram.flat(tio, yield);
  133. auto R = A.reconstruct();
  134. if(tio.player()==0) {
  135. pretty_print(R, reconstructed_root.xshare);
  136. }
  137. }
  138. // Check the BST invariant of the tree (that all keys to the left are
  139. // less than or equal to this key, all keys to the right are strictly
  140. // greater, and this is true recursively). Returns a
  141. // tuple<bool,address_t>, where the bool says whether the BST invariant
  142. // holds, and the address_t is the height of the tree (which will be
  143. // useful later when we check AVL trees).
  144. std::tuple<bool, address_t> BST::check_bst(const std::vector<Node> &R,
  145. value_t node, value_t min_key = 0, value_t max_key = ~0)
  146. {
  147. //printf("node = %ld\n", node);
  148. if (node == 0) {
  149. return { true, 0 };
  150. }
  151. const Node &n = R[node];
  152. value_t key = n.key.ashare;
  153. value_t left_ptr = extractLeftPtr(n.pointers).xshare;
  154. value_t right_ptr = extractRightPtr(n.pointers).xshare;
  155. auto [leftok, leftheight ] = check_bst(R, left_ptr, min_key, key);
  156. auto [rightok, rightheight ] = check_bst(R, right_ptr, key+1, max_key);
  157. address_t height = leftheight;
  158. if (rightheight > height) {
  159. height = rightheight;
  160. }
  161. height += 1;
  162. //printf("node = %ld, leftok = %d, rightok = %d\n", node, leftok, rightok);
  163. return { leftok && rightok && key >= min_key && key <= max_key,
  164. height };
  165. }
  166. void BST::check_bst(MPCTIO &tio, yield_t &yield) {
  167. auto A = oram.flat(tio, yield);
  168. auto R = A.reconstruct();
  169. RegXS rec_root = this->root;
  170. if (tio.player() == 1) {
  171. tio.queue_peer(&(this->root), sizeof(this->root));
  172. yield();
  173. } else {
  174. RegXS peer_root;
  175. yield();
  176. if(tio.player()==0) {
  177. tio.recv_peer(&peer_root, sizeof(peer_root));
  178. }
  179. rec_root+= peer_root;
  180. }
  181. if (tio.player() == 0) {
  182. auto [ ok, height ] = check_bst(R, rec_root.xshare);
  183. printf("BST structure %s\nBST height = %u\n",
  184. ok ? "ok" : "NOT OK", height);
  185. }
  186. }
  187. /*
  188. The recursive insert() call, invoked by the wrapper insert() function.
  189. Takes as input the pointer to the current node in tree traversal (ptr),
  190. the key to be inserted (insertion_key), the underlying Duoram as a
  191. flat (A), and the Time-To_live TTL, and a shared flag (isDummy) which
  192. tracks if the operation is dummy/real.
  193. Returns a tuple <ptr, dir> where
  194. ptr: the pointer to the node where the insertion should happen
  195. dir: the bit indicating whether the new node should be inserted as the
  196. left/right child.
  197. */
  198. std::tuple<RegXS, RegBS> BST::insert(MPCTIO &tio, yield_t &yield, RegXS ptr,
  199. RegAS insertion_key, Duoram<Node>::Flat &A, int TTL, RegBS isDummy) {
  200. if(TTL==0) {
  201. RegBS zero;
  202. return {ptr, zero};
  203. }
  204. RegBS isNotDummy = isDummy ^ (!tio.player());
  205. Node cnode = A[ptr];
  206. // Compare key
  207. auto [lteq, gt] = compare_keys(tio, yield, cnode.key, insertion_key);
  208. // Depending on [lteq, gt] select the next ptr/index as
  209. // upper 32 bits of cnode.pointers if lteq
  210. // lower 32 bits of cnode.pointers if gt
  211. RegXS left = extractLeftPtr(cnode.pointers);
  212. RegXS right = extractRightPtr(cnode.pointers);
  213. RegXS next_ptr;
  214. mpc_select(tio, yield, next_ptr, gt, left, right, 32);
  215. CDPF dpf = tio.cdpf(yield);
  216. size_t &aes_ops = tio.aes_ops();
  217. // F_z: Check if this is last node on path
  218. RegBS F_z = dpf.is_zero(tio, yield, next_ptr, aes_ops);
  219. RegBS F_i;
  220. // F_i: If this was last node on path (F_z) && isNotDummy:
  221. // insert new_node here.
  222. mpc_and(tio, yield, F_i, (isNotDummy), F_z);
  223. isDummy^=F_i;
  224. auto [wptr, direction] = insert(tio, yield, next_ptr, insertion_key, A, TTL-1, isDummy);
  225. RegXS ret_ptr;
  226. RegBS ret_direction;
  227. // If we insert here (F_i), return the ptr to this node as wptr
  228. // and update direction to the direction taken by compare_keys
  229. run_coroutines(tio, [&tio, &ret_ptr, F_i, wptr, ptr](yield_t &yield)
  230. { mpc_select(tio, yield, ret_ptr, F_i, wptr, ptr);},
  231. [&tio, &ret_direction, F_i, direction, gt](yield_t &yield)
  232. //ret_direction = direction + F_i (direction - gt)
  233. { mpc_and(tio, yield, ret_direction, F_i, direction^gt);});
  234. ret_direction^=direction;
  235. return {ret_ptr, ret_direction};
  236. }
  237. /*
  238. The wrapper insert() operation invoked by the main insert call
  239. BST::insert(tio, yield, Node& new_node);
  240. Takes as input the new node (node), the underlying Duoram as a flat (A).
  241. */
  242. void BST::insert(MPCTIO &tio, yield_t &yield, const Node &node, Duoram<Node>::Flat &A) {
  243. bool player0 = tio.player()==0;
  244. // If there are no items in tree. Make this new item the root.
  245. if (num_items==0) {
  246. A[1] = node;
  247. // Set root to a secret sharing of the constant value 1
  248. root.set(1*tio.player());
  249. num_items++;
  250. //printf("num_items == %ld!\n", num_items);
  251. return;
  252. } else {
  253. // Insert node into next free slot in the ORAM
  254. int new_id;
  255. RegXS insert_address;
  256. int TTL = num_items++;
  257. bool insertAtEmptyLocation = (empty_locations.size() > 0);
  258. if(insertAtEmptyLocation) {
  259. insert_address = empty_locations.back();
  260. empty_locations.pop_back();
  261. A[insert_address] = node;
  262. } else {
  263. new_id = 1 + num_items;
  264. A[new_id] = node;
  265. insert_address.set(new_id * tio.player());
  266. }
  267. RegBS isDummy;
  268. //Do a recursive insert
  269. auto [wptr, direction] = insert(tio, yield, root, node.key, A, TTL, isDummy);
  270. //Complete the insertion by reading wptr and updating its pointers
  271. RegXS pointers = A[wptr].NODE_POINTERS;
  272. RegXS left_ptr = extractLeftPtr(pointers);
  273. RegXS right_ptr = extractRightPtr(pointers);
  274. RegXS new_right_ptr, new_left_ptr;
  275. RegBS not_direction = direction;
  276. if (player0) {
  277. not_direction^=1;
  278. }
  279. run_coroutines(tio,
  280. [&tio, &new_right_ptr, direction, right_ptr, insert_address](yield_t &yield)
  281. { mpc_select(tio, yield, new_right_ptr, direction, right_ptr, insert_address);},
  282. [&tio, &new_left_ptr, not_direction, left_ptr, insert_address](yield_t &yield)
  283. { mpc_select(tio, yield, new_left_ptr, not_direction, left_ptr, insert_address);});
  284. setLeftPtr(pointers, new_left_ptr);
  285. setRightPtr(pointers, new_right_ptr);
  286. A[wptr].NODE_POINTERS = pointers;
  287. }
  288. }
  289. /*
  290. Insert a new node into the BST.
  291. Takes as input the new node (node).
  292. */
  293. void BST::insert(MPCTIO &tio, yield_t &yield, Node &node) {
  294. auto A = oram.flat(tio, yield);
  295. insert(tio, yield, node, A);
  296. /*
  297. // To visualize database and tree after each insert:
  298. auto R = A.reconstruct();
  299. if (tio.player() == 0) {
  300. for(size_t i=0;i<R.size();++i) {
  301. printf("\n%04lx ", i);
  302. R[i].dump();
  303. }
  304. printf("\n");
  305. }
  306. pretty_print(R, 1);
  307. */
  308. }
  309. RegBS BST::lookup(MPCTIO &tio, yield_t &yield, RegXS ptr, RegAS key, Duoram<Node>::Flat &A,
  310. int TTL, RegBS isDummy, Node *ret_node) {
  311. if(TTL==0) {
  312. // If we found the key, then isDummy will be true
  313. return isDummy;
  314. }
  315. Node cnode = A[ptr];
  316. // Compare key
  317. CDPF cdpf = tio.cdpf(yield);
  318. auto [lt, eq, gt] = cdpf.compare(tio, yield, key - cnode.key, tio.aes_ops());
  319. // Depending on [lteq, gt] select the next ptr/index as
  320. // upper 32 bits of cnode.pointers if lteq
  321. // lower 32 bits of cnode.pointers if gt
  322. RegXS left = extractLeftPtr(cnode.pointers);
  323. RegXS right = extractRightPtr(cnode.pointers);
  324. RegXS next_ptr;
  325. RegBS F_found;
  326. // If we haven't found the key yet, and the lookup matches the current node key,
  327. // then we found the node to return
  328. RegBS isNotDummy = isDummy ^ (!tio.player());
  329. // Note: This logic returns the last matched key and value.
  330. // Returning the first one incurs an additional round.
  331. std::vector<coro_t> coroutines;
  332. coroutines.emplace_back(
  333. [&tio, &next_ptr, gt, left, right](yield_t &yield)
  334. { mpc_select(tio, yield, next_ptr, gt, left, right, 32);});
  335. coroutines.emplace_back(
  336. [&tio, &F_found, isNotDummy, eq](yield_t &yield)
  337. { mpc_and(tio, yield, F_found, isNotDummy, eq);});
  338. coroutines.emplace_back(
  339. [&tio, &ret_node, eq, cnode](yield_t &yield)
  340. { mpc_select(tio, yield, ret_node->key, eq, ret_node->key, cnode.key);});
  341. coroutines.emplace_back(
  342. [&tio, &ret_node, eq, cnode](yield_t &yield)
  343. { mpc_select(tio, yield, ret_node->value, eq, ret_node->value, cnode.value);});
  344. coroutines.emplace_back(
  345. [&tio, &isDummy, eq](yield_t &yield)
  346. { mpc_or(tio, yield, isDummy, isDummy, eq);});
  347. run_coroutines(tio, coroutines);
  348. #ifdef BST_DEBUG
  349. size_t ckey = mpc_reconstruct(tio, yield, cnode.key);
  350. size_t lkey = mpc_reconstruct(tio, yield, key);
  351. bool rec_lt = mpc_reconstruct(tio, yield, lt);
  352. bool rec_eq = mpc_reconstruct(tio, yield, eq);
  353. bool rec_gt = mpc_reconstruct(tio, yield, gt);
  354. bool rec_found = mpc_reconstruct(tio, yield, isDummy);
  355. bool rec_f_found = mpc_reconstruct(tio, yield, F_found);
  356. printf("rec_lt = %d, rec_eq = %d, rec_gt = %d\n", rec_lt, rec_eq, rec_gt);
  357. printf("rec_isDummy/found = %d ,rec_f_found = %d, cnode.key = %ld, lookup key = %ld\n", rec_found, rec_f_found, ckey, lkey);
  358. #endif
  359. RegBS found = lookup(tio, yield, next_ptr, key, A, TTL-1, isDummy, ret_node);
  360. return found;
  361. }
  362. RegBS BST::lookup(MPCTIO &tio, yield_t &yield, RegAS key, Node *ret_node) {
  363. auto A = oram.flat(tio, yield);
  364. RegBS isDummy;
  365. RegBS found = lookup(tio, yield, root, key, A, num_items, isDummy, ret_node);
  366. /*
  367. // To visualize database and tree after each lookup:
  368. auto R = A.reconstruct();
  369. if (tio.player() == 0) {
  370. for(size_t i=0;i<R.size();++i) {
  371. printf("\n%04lx ", i);
  372. R[i].dump();
  373. }
  374. printf("\n");
  375. }
  376. pretty_print(R, 1);
  377. */
  378. return found;
  379. }
  380. /*
  381. The recursive del() call, invoked by the wrapper del() function.
  382. Takes as input the pointer to the current node in tree traversal (ptr),
  383. the key to be deleted (del_key), the underlying Duoram as a
  384. flat (A), Flags af (already found) and fs (find successor), and the
  385. Time-To_live TTL. Finally, a return structure ret_struct that tracks
  386. the location of the successor node and the node to delete, in order
  387. to perform the actual deletion after the recursive traversal. This
  388. is required in the case of a deletion that requires a successor swap
  389. (i.e., when the node to delete has both children).
  390. Returns success/fail bit.
  391. */
  392. bool BST::del(MPCTIO &tio, yield_t &yield, RegXS ptr, RegAS del_key,
  393. Duoram<Node>::Flat &A, RegBS af, RegBS fs, int TTL,
  394. del_return &ret_struct) {
  395. bool player0 = tio.player()==0;
  396. //printf("TTL = %d\n", TTL);
  397. if(TTL==0) {
  398. //Reconstruct and return af
  399. bool success = reconstruct_RegBS(tio, yield, af);
  400. //printf("Reconstructed flag = %d\n", success);
  401. if(player0) {
  402. ret_struct.F_r^=1;
  403. }
  404. return success;
  405. } else {
  406. // s1: shares of 1 bit, s0: shares of 0 bit
  407. RegBS s1, s0;
  408. s1.set(tio.player()==1);
  409. Node node = A[ptr];
  410. RegXS left = extractLeftPtr(node.pointers);
  411. RegXS right = extractRightPtr(node.pointers);
  412. CDPF cdpf = tio.cdpf(yield);
  413. size_t &aes_ops = tio.aes_ops();
  414. RegBS l0, r0, lt, eq, gt;
  415. // l0: Is left child 0
  416. // r0: Is right child 0
  417. run_coroutines(tio,
  418. [&tio, &l0, left, &aes_ops, &cdpf](yield_t &yield)
  419. { l0 = cdpf.is_zero(tio, yield, left, aes_ops);},
  420. [&tio, &r0, right, &aes_ops, &cdpf](yield_t &yield)
  421. { r0 = cdpf.is_zero(tio, yield, right, aes_ops);},
  422. [&tio, &lt, &eq, &gt, del_key, node, &cdpf](yield_t &yield)
  423. // Compare Key
  424. { auto [a, b, c] = cdpf.compare(tio, yield, del_key - node.key, tio.aes_ops());
  425. lt = a; eq = b; gt = c;});
  426. /*
  427. // Reconstruct and Debug Block 0
  428. bool lt_rec, eq_rec, gt_rec;
  429. lt_rec = mpc_reconstruct(tio, yield, lt);
  430. eq_rec = mpc_reconstruct(tio, yield, eq);
  431. gt_rec = mpc_reconstruct(tio, yield, gt);
  432. size_t del_key_rec, node_key_rec;
  433. del_key_rec = mpc_reconstruct(tio, yield, del_key);
  434. node_key_rec = mpc_reconstruct(tio, yield, node.key);
  435. printf("node.key = %ld, del_key= %ld\n", node_key_rec, del_key_rec);
  436. printf("cdpf.compare results: lt = %d, eq = %d, gt = %d\n", lt_rec, eq_rec, gt_rec);
  437. */
  438. // c is the direction bit for next_ptr
  439. // (c=0: go left or c=1: go right)
  440. RegBS c = gt;
  441. // lf = local found. We found the key to delete in this level.
  442. RegBS lf = eq;
  443. // F_{X}: Flags that indicate the number of children this node has
  444. // F_0: no children, F_1: one child, F_2: both children
  445. RegBS F_0, F_1, F_2;
  446. // F_1 = l0 \xor r0
  447. F_1 = l0 ^ r0;
  448. // We set next ptr based on c, but we need to handle three
  449. // edge cases where we do not go by just the comparison result
  450. RegXS next_ptr;
  451. RegBS c_prime;
  452. // Case 1: found the node here (lf): we traverse down the lone child path.
  453. // or we are finding successor (fs) and there is no left child.
  454. RegBS F_c1, F_c2, F_c3, F_c4;
  455. // Case 1: lf & F_1
  456. run_coroutines(tio,
  457. [&tio, &F_c1, lf, F_1](yield_t &yield)
  458. { mpc_and(tio, yield, F_c1, lf, F_1);},
  459. [&tio, &F_0, l0, r0] (yield_t &yield)
  460. // F_0 = l0 & r0
  461. { mpc_and(tio, yield, F_0, l0, r0);});
  462. // F_2 = !(F_0 + F_1) (Only 1 of F_0, F_1, and F_2 can be true)
  463. F_2 = F_0 ^ F_1;
  464. if(player0)
  465. F_2^=1;
  466. /*
  467. // Reconstruct and Debug Block 1
  468. bool F_0_rec, F_1_rec, F_2_rec, c_prime_rec;
  469. F_0_rec = mpc_reconstruct(tio, yield, F_0);
  470. F_1_rec = mpc_reconstruct(tio, yield, F_1);
  471. F_2_rec = mpc_reconstruct(tio, yield, F_2);
  472. c_prime_rec = mpc_reconstruct(tio, yield, c_prime);
  473. printf("F_0 = %d, F_1 = %d, F_2 = %d, c_prime = %d\n", F_0_rec, F_1_rec, F_2_rec, c_prime_rec);
  474. */
  475. run_coroutines(tio,
  476. [&tio, &c_prime, F_c1, c, l0](yield_t &yield)
  477. // Set c_prime for Case 1
  478. { mpc_select(tio, yield, c_prime, F_c1, c, l0);},
  479. [&tio, &F_c2, lf, F_2](yield_t &yield)
  480. // Case 2: found the node here (lf) and node has both children (F_2)
  481. // In find successor case, so find inorder successor
  482. // (Go right and then find leftmost child.)
  483. { mpc_and(tio, yield, F_c2, lf, F_2);});
  484. /*
  485. // Reconstruct and Debug Block 2
  486. bool F_c2_rec, s1_rec;
  487. F_c2_rec = mpc_reconstruct(tio, yield, F_c2);
  488. s1_rec = mpc_reconstruct(tio, yield, s1);
  489. c_prime_rec = mpc_reconstruct(tio, yield, c_prime);
  490. printf("c_prime = %d, F_c2 = %d, s1 = %d\n", c_prime_rec, F_c2_rec, s1_rec);
  491. */
  492. run_coroutines(tio,
  493. [&tio, &c_prime, F_c2, s1](yield_t &yield)
  494. { mpc_select(tio, yield, c_prime, F_c2, c_prime, s1);},
  495. [&tio, &F_c3, fs, F_2](yield_t &yield)
  496. // Case 3: finding successor (fs) and node has both children (F_2)
  497. // Go left.
  498. { mpc_and(tio, yield, F_c3, fs, F_2);});
  499. run_coroutines(tio,
  500. [&tio, &c_prime, F_c3, s0](yield_t &yield)
  501. { mpc_select(tio, yield, c_prime, F_c3, c_prime, s0);},
  502. // Case 4: finding successor (fs) and node has no more left children (l0)
  503. // This is the successor node then.
  504. // Go right (since no more left)
  505. [&tio, &F_c4, fs, l0] (yield_t &yield)
  506. { mpc_and(tio, yield, F_c4, fs, l0);});
  507. mpc_select(tio, yield, c_prime, F_c4, c_prime, l0);
  508. RegBS af_prime, fs_prime;
  509. run_coroutines(tio,
  510. [&tio, &next_ptr, c_prime, left, right](yield_t &yield)
  511. // Set next_ptr
  512. { mpc_select(tio, yield, next_ptr, c_prime, left, right, 32);},
  513. [&tio, &af_prime, af, lf](yield_t &yield)
  514. { mpc_or(tio, yield, af_prime, af, lf);},
  515. [&tio, &fs_prime, fs, F_c2](yield_t &yield)
  516. // If in Case 2, set fs. We are now finding successor
  517. { mpc_or(tio, yield, fs_prime, fs, F_c2);});
  518. // If in Case 4. Successor found here already. Toggle fs off
  519. fs_prime=fs_prime^F_c4;
  520. bool key_found = del(tio, yield, next_ptr, del_key, A, af_prime, fs_prime, TTL-1, ret_struct);
  521. // If we didn't find the key, we can end here.
  522. if(!key_found) {
  523. return 0;
  524. }
  525. //printf("TTL = %d\n", TTL);
  526. RegBS F_rs_right, F_rs_left, not_c_prime=c_prime;
  527. if(player0) {
  528. not_c_prime^=1;
  529. }
  530. // Flag here should be direction (c_prime) and F_r i.e. we need to swap return ptr in,
  531. // F_r needs to be returned in ret_struct
  532. run_coroutines(tio,
  533. [&tio, &F_rs_right, c_prime, ret_struct](yield_t &yield)
  534. { mpc_and(tio, yield, F_rs_right, c_prime, ret_struct.F_r);},
  535. [&tio, &F_rs_left, not_c_prime, left, ret_struct](yield_t &yield)
  536. { mpc_and(tio, yield, F_rs_left, not_c_prime, ret_struct.F_r);});
  537. run_coroutines(tio,
  538. [&tio, &right, F_rs_right, ret_struct](yield_t &yield)
  539. { mpc_select(tio, yield, right, F_rs_right, right, ret_struct.ret_ptr);},
  540. [&tio, &left, F_rs_left, ret_struct](yield_t &yield)
  541. { mpc_select(tio, yield, left, F_rs_left, left, ret_struct.ret_ptr);});
  542. /*
  543. // Reconstruct and Debug Block 3
  544. bool F_rs_rec, F_ls_rec;
  545. size_t ret_ptr_rec;
  546. F_rs_rec = mpc_reconstruct(tio, yield, F_rs);
  547. F_ls_rec = mpc_reconstruct(tio, yield, F_rs);
  548. ret_ptr_rec = mpc_reconstruct(tio, yield, ret_struct.ret_ptr);
  549. printf("F_rs_rec = %d, F_ls_rec = %d, ret_ptr_rec = %ld\n", F_rs_rec, F_ls_rec, ret_ptr_rec);
  550. */
  551. RegXS new_ptr;
  552. setLeftPtr(new_ptr, left);
  553. setRightPtr(new_ptr, right);
  554. A[ptr].NODE_POINTERS = new_ptr;
  555. // Update the return structure
  556. RegBS F_nd, F_ns, F_r, not_af = af, not_F_2 = F_2;
  557. if(player0) {
  558. not_af^=1;
  559. not_F_2^=1;
  560. }
  561. // F_ns = fs & l0
  562. // Finding successor flag & no more left child = F_c4
  563. F_ns = F_c4;
  564. run_coroutines(tio,
  565. [&tio, &ret_struct, F_c2](yield_t &yield)
  566. { mpc_or(tio, yield, ret_struct.F_ss, ret_struct.F_ss, F_c2);},
  567. [&tio, &F_nd, lf, not_af](yield_t &yield)
  568. { mpc_and(tio, yield, F_nd, lf, not_af);});
  569. // F_r = F_d.(!F_2)
  570. // If we have to delete here, and it doesn't have two children we have to
  571. // update child pointer in parent with the returned pointer
  572. mpc_and(tio, yield, F_r, F_nd, not_F_2);
  573. mpc_or(tio, yield, F_r, F_r, F_ns);
  574. ret_struct.F_r = F_r;
  575. run_coroutines(tio,
  576. [&tio, &ret_struct, F_nd, ptr](yield_t &yield)
  577. { mpc_select(tio, yield, ret_struct.N_d, F_nd, ret_struct.N_d, ptr);},
  578. [&tio, &ret_struct, F_ns, ptr](yield_t &yield)
  579. { mpc_select(tio, yield, ret_struct.N_s, F_ns, ret_struct.N_s, ptr);},
  580. [&tio, &ret_struct, F_r, ptr](yield_t &yield)
  581. { mpc_select(tio, yield, ret_struct.ret_ptr, F_r, ptr, ret_struct.ret_ptr);});
  582. //We don't empty the key and value of the node with del_key in the ORAM
  583. return 1;
  584. }
  585. }
  586. /*
  587. The main del() function.
  588. Trying to delete an item that does not exist in the tree will result in
  589. an explicit (non-oblivious) failure.
  590. Takes as input the key to delete (del_key).
  591. Returns success/fail bit.
  592. */
  593. bool BST::del(MPCTIO &tio, yield_t &yield, RegAS del_key) {
  594. if(num_items==0)
  595. return 0;
  596. if(num_items==1) {
  597. //Delete root
  598. auto A = oram.flat(tio, yield);
  599. Node zero;
  600. empty_locations.emplace_back(root);
  601. A[root] = zero;
  602. num_items--;
  603. return 1;
  604. } else {
  605. int TTL = num_items;
  606. // Flags for already found (af) item to delete and find successor (fs)
  607. // if this deletion requires a successor swap
  608. RegBS af;
  609. RegBS fs;
  610. del_return ret_struct;
  611. auto A = oram.flat(tio, yield);
  612. int success = del(tio, yield, root, del_key, A, af, fs, TTL, ret_struct);
  613. if(!success){
  614. return 0;
  615. }
  616. else{
  617. num_items--;
  618. /*
  619. printf("In delete's swap portion\n");
  620. Node del_node = A.reconstruct(A[ret_struct.N_d]);
  621. Node suc_node = A.reconstruct(A[ret_struct.N_s]);
  622. printf("del_node key = %ld, suc_node key = %ld\n",
  623. del_node.key.ashare, suc_node.key.ashare);
  624. printf("flag_s = %d\n", ret_struct.F_ss.bshare);
  625. */
  626. Node del_node = A[ret_struct.N_d];
  627. Node suc_node = A[ret_struct.N_s];
  628. RegAS zero_as; RegXS zero_xs;
  629. RegXS empty_loc, temp_root = root;
  630. run_coroutines(tio,
  631. [&tio, &temp_root, ret_struct](yield_t &yield)
  632. { mpc_select(tio, yield, temp_root, ret_struct.F_r, temp_root, ret_struct.ret_ptr);},
  633. [&tio, &del_node, ret_struct, suc_node](yield_t &yield)
  634. { mpc_select(tio, yield, del_node.key, ret_struct.F_ss, del_node.key, suc_node.key);},
  635. [&tio, &del_node, ret_struct, suc_node](yield_t & yield)
  636. { mpc_select(tio, yield, del_node.value, ret_struct.F_ss, del_node.value, suc_node.value);},
  637. [&tio, &empty_loc, ret_struct](yield_t &yield)
  638. { mpc_select(tio, yield, empty_loc, ret_struct.F_ss, ret_struct.N_d, ret_struct.N_s);});
  639. root = temp_root;
  640. run_coroutines(tio,
  641. [&tio, &A, ret_struct, del_node](yield_t &yield)
  642. { auto acont = A.context(yield);
  643. acont[ret_struct.N_d].NODE_KEY = del_node.key;},
  644. [&tio, &A, ret_struct, del_node](yield_t &yield)
  645. { auto acont = A.context(yield);
  646. acont[ret_struct.N_d].NODE_VALUE = del_node.value;},
  647. [&tio, &A, ret_struct, zero_as](yield_t &yield)
  648. { auto acont = A.context(yield);
  649. acont[ret_struct.N_s].NODE_KEY = zero_as;},
  650. [&tio, &A, ret_struct, zero_xs](yield_t &yield)
  651. { auto acont = A.context(yield);
  652. acont[ret_struct.N_s].NODE_VALUE = zero_xs;});
  653. //Add deleted (empty) location into the empty_locations vector for reuse in next insert()
  654. empty_locations.emplace_back(empty_loc);
  655. }
  656. return 1;
  657. }
  658. }
  659. // Now we use the node in various ways. This function is called by
  660. // online.cpp.
  661. void bst(MPCIO &mpcio,
  662. const PRACOptions &opts, char **args)
  663. {
  664. nbits_t depth=4;
  665. if (*args) {
  666. depth = atoi(*args);
  667. ++args;
  668. }
  669. MPCTIO tio(mpcio, 0, opts.num_cpu_threads);
  670. run_coroutines(tio, [&tio, depth] (yield_t &yield) {
  671. size_t size = size_t(1)<<depth;
  672. BST tree(tio.player(), size);
  673. int insert_array[] = {10, 10, 13, 11, 14, 8, 15, 20, 17, 19, 7, 12};
  674. //int insert_array[] = {1, 2, 3, 4, 5, 6};
  675. size_t insert_array_size = sizeof(insert_array)/sizeof(int);
  676. Node node;
  677. for(size_t i = 0; i<insert_array_size; i++) {
  678. randomize_node(node);
  679. node.key.set(insert_array[i] * tio.player());
  680. tree.insert(tio, yield, node);
  681. }
  682. tree.pretty_print(tio, yield);
  683. RegAS del_key;
  684. printf("\n\nDelete %x\n", 20);
  685. del_key.set(20 * tio.player());
  686. tree.del(tio, yield, del_key);
  687. tree.pretty_print(tio, yield);
  688. tree.check_bst(tio, yield);
  689. printf("\n\nDelete %x\n", 10);
  690. del_key.set(10 * tio.player());
  691. tree.del(tio, yield, del_key);
  692. tree.pretty_print(tio, yield);
  693. tree.check_bst(tio, yield);
  694. printf("\n\nDelete %x\n", 7);
  695. del_key.set(7 * tio.player());
  696. tree.del(tio, yield, del_key);
  697. tree.pretty_print(tio, yield);
  698. tree.check_bst(tio, yield);
  699. printf("\n\nDelete %x\n", 17);
  700. del_key.set(17 * tio.player());
  701. tree.del(tio, yield, del_key);
  702. tree.pretty_print(tio, yield);
  703. tree.check_bst(tio, yield);
  704. printf("\n\nDelete %x\n", 15);
  705. del_key.set(15 * tio.player());
  706. tree.del(tio, yield, del_key);
  707. tree.pretty_print(tio, yield);
  708. tree.check_bst(tio, yield);
  709. printf("\n\nDelete %x\n", 5);
  710. del_key.set(5 * tio.player());
  711. tree.del(tio, yield, del_key);
  712. tree.pretty_print(tio, yield);
  713. tree.check_bst(tio, yield);
  714. printf("\n\nInsert %x\n", 14);
  715. randomize_node(node);
  716. node.key.set(14 * tio.player());
  717. tree.insert(tio, yield, node);
  718. tree.pretty_print(tio, yield);
  719. tree.check_bst(tio, yield);
  720. printf("\n\nLookup %x\n", 8);
  721. randomize_node(node);
  722. RegAS lookup_key;
  723. RegBS found;
  724. bool rec_found;
  725. lookup_key.set(8 * tio.player());
  726. found = tree.lookup(tio, yield, lookup_key, &node);
  727. rec_found = mpc_reconstruct(tio, yield, found);
  728. tree.pretty_print(tio, yield);
  729. if(tio.player()!=2) {
  730. if(rec_found) {
  731. printf("Lookup Success\n");
  732. size_t value = mpc_reconstruct(tio, yield, node.value);
  733. printf("value = %lx\n", value);
  734. } else {
  735. printf("Lookup Failed\n");
  736. }
  737. }
  738. });
  739. }