bst.cpp 31 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875
  1. #include <functional>
  2. #include "bst.hpp"
  3. #ifdef BST_DEBUG
  4. void BST::print_oram(MPCTIO &tio, yield_t &yield) {
  5. auto A = oram.flat(tio, yield);
  6. auto R = A.reconstruct();
  7. for(size_t i=0;i<R.size();++i) {
  8. printf("\n%04lx ", i);
  9. R[i].dump();
  10. }
  11. printf("\n");
  12. }
  13. #endif
  14. // Helper function to reconstruct shared RegBS
  15. bool reconstruct_RegBS(MPCTIO &tio, yield_t &yield, RegBS flag) {
  16. RegBS reconstructed_flag;
  17. if (tio.player() < 2) {
  18. RegBS peer_flag;
  19. tio.queue_peer(&flag, 1);
  20. tio.queue_server(&flag, 1);
  21. yield();
  22. tio.recv_peer(&peer_flag, 1);
  23. reconstructed_flag = flag;
  24. reconstructed_flag ^= peer_flag;
  25. } else {
  26. RegBS p0_flag, p1_flag;
  27. yield();
  28. tio.recv_p0(&p0_flag, 1);
  29. tio.recv_p1(&p1_flag, 1);
  30. reconstructed_flag = p0_flag;
  31. reconstructed_flag ^= p1_flag;
  32. }
  33. return reconstructed_flag.bshare;
  34. }
  35. /*
  36. A function to assign a new random 8-bit key to a node, and resets its
  37. pointers to zeroes. The node is assigned a new random 64-bit value.
  38. */
  39. static void randomize_node(Node &a) {
  40. a.key.randomize(8);
  41. a.pointers.set(0);
  42. a.value.randomize();
  43. }
  44. /*
  45. A function to perform key comparsions for BST traversal.
  46. Inputs: k1 = key of node in the tree, k2 = insertion/deletion/lookup key.
  47. Evaluates (k2-k1), and combines the lt and eq flag into one (flag to go
  48. left), and keeps the gt flag as is (flag to go right) during traversal.
  49. */
  50. std::tuple<RegBS, RegBS> compare_keys(MPCTIO tio, yield_t &yield, RegAS k1,
  51. RegAS k2) {
  52. CDPF cdpf = tio.cdpf(yield);
  53. auto [lt, eq, gt] = cdpf.compare(tio, yield, k2 - k1, tio.aes_ops());
  54. RegBS lteq = lt^eq;
  55. return {lteq, gt};
  56. }
  57. // Assuming pointer of 64 bits is split as:
  58. // - 32 bits Left ptr (L)
  59. // - 32 bits Right ptr (R)
  60. // The pointers are stored as: (L << 32) | R
  61. inline RegXS extractLeftPtr(RegXS pointer){
  62. return ((pointer&(0xFFFFFFFF00000000))>>32);
  63. }
  64. inline RegXS extractRightPtr(RegXS pointer){
  65. return (pointer&(0x00000000FFFFFFFF));
  66. }
  67. inline void setLeftPtr(RegXS &pointer, RegXS new_ptr){
  68. pointer&=(0x00000000FFFFFFFF);
  69. pointer+=(new_ptr<<32);
  70. }
  71. inline void setRightPtr(RegXS &pointer, RegXS new_ptr){
  72. pointer&=(0xFFFFFFFF00000000);
  73. pointer+=(new_ptr);
  74. }
  75. // Pretty-print a reconstructed BST, rooted at node. is_left_child and
  76. // is_right_child indicate whether node is a left or right child of its
  77. // parent. They cannot both be true, but the root of the tree has both
  78. // of them false.
  79. void BST::pretty_print(const std::vector<Node> &R, value_t node,
  80. const std::string &prefix = "", bool is_left_child = false,
  81. bool is_right_child = false)
  82. {
  83. if (node == 0) {
  84. // NULL pointer
  85. if (is_left_child) {
  86. printf("%s\xE2\x95\xA7\n", prefix.c_str()); // ╧
  87. } else if (is_right_child) {
  88. printf("%s\xE2\x95\xA4\n", prefix.c_str()); // ╤
  89. } else {
  90. printf("%s\xE2\x95\xA2\n", prefix.c_str()); // ╢
  91. }
  92. return;
  93. }
  94. const Node &n = R[node];
  95. value_t left_ptr = extractLeftPtr(n.pointers).xshare;
  96. value_t right_ptr = extractRightPtr(n.pointers).xshare;
  97. std::string rightprefix(prefix), leftprefix(prefix),
  98. nodeprefix(prefix);
  99. if (is_left_child) {
  100. rightprefix.append("\xE2\x94\x82"); // │
  101. leftprefix.append(" ");
  102. nodeprefix.append("\xE2\x94\x94"); // └
  103. } else if (is_right_child) {
  104. rightprefix.append(" ");
  105. leftprefix.append("\xE2\x94\x82"); // │
  106. nodeprefix.append("\xE2\x94\x8C"); // ┌
  107. } else {
  108. rightprefix.append(" ");
  109. leftprefix.append(" ");
  110. nodeprefix.append("\xE2\x94\x80"); // ─
  111. }
  112. pretty_print(R, right_ptr, rightprefix, false, true);
  113. printf("%s\xE2\x94\xA4", nodeprefix.c_str()); // ┤
  114. n.dump();
  115. printf("\n");
  116. pretty_print(R, left_ptr, leftprefix, true, false);
  117. }
  118. void BST::pretty_print(MPCTIO &tio, yield_t &yield) {
  119. RegXS peer_root;
  120. RegXS reconstructed_root = root;
  121. if (tio.player() == 1) {
  122. tio.queue_peer(&root, sizeof(root));
  123. yield();
  124. } else {
  125. RegXS peer_root;
  126. yield();
  127. if(tio.player()==0) {
  128. tio.recv_peer(&peer_root, sizeof(peer_root));
  129. }
  130. reconstructed_root += peer_root;
  131. }
  132. auto A = oram.flat(tio, yield);
  133. auto R = A.reconstruct();
  134. if(tio.player()==0) {
  135. pretty_print(R, reconstructed_root.xshare);
  136. }
  137. }
  138. // Check the BST invariant of the tree (that all keys to the left are
  139. // less than or equal to this key, all keys to the right are strictly
  140. // greater, and this is true recursively). Returns a
  141. // tuple<bool,address_t>, where the bool says whether the BST invariant
  142. // holds, and the address_t is the height of the tree (which will be
  143. // useful later when we check AVL trees).
  144. std::tuple<bool, address_t> BST::check_bst(const std::vector<Node> &R,
  145. value_t node, value_t min_key = 0, value_t max_key = ~0)
  146. {
  147. //printf("node = %ld\n", node);
  148. if (node == 0) {
  149. return { true, 0 };
  150. }
  151. const Node &n = R[node];
  152. value_t key = n.key.ashare;
  153. value_t left_ptr = extractLeftPtr(n.pointers).xshare;
  154. value_t right_ptr = extractRightPtr(n.pointers).xshare;
  155. auto [leftok, leftheight ] = check_bst(R, left_ptr, min_key, key);
  156. auto [rightok, rightheight ] = check_bst(R, right_ptr, key+1, max_key);
  157. address_t height = leftheight;
  158. if (rightheight > height) {
  159. height = rightheight;
  160. }
  161. height += 1;
  162. //printf("node = %ld, leftok = %d, rightok = %d\n", node, leftok, rightok);
  163. return { leftok && rightok && key >= min_key && key <= max_key,
  164. height };
  165. }
  166. void BST::check_bst(MPCTIO &tio, yield_t &yield) {
  167. auto A = oram.flat(tio, yield);
  168. auto R = A.reconstruct();
  169. RegXS rec_root = this->root;
  170. if (tio.player() == 1) {
  171. tio.queue_peer(&(this->root), sizeof(this->root));
  172. yield();
  173. } else {
  174. RegXS peer_root;
  175. yield();
  176. if(tio.player()==0) {
  177. tio.recv_peer(&peer_root, sizeof(peer_root));
  178. }
  179. rec_root+= peer_root;
  180. }
  181. if (tio.player() == 0) {
  182. auto [ ok, height ] = check_bst(R, rec_root.xshare);
  183. printf("BST structure %s\nBST height = %u\n",
  184. ok ? "ok" : "NOT OK", height);
  185. }
  186. }
  187. /*
  188. The recursive insert() call, invoked by the wrapper insert() function.
  189. Takes as input the pointer to the current node in tree traversal (ptr),
  190. the new node to be inserted (new_node), the underlying Duoram as a
  191. flat (A), and the Time-To_live TTL, and a shared flag (isDummy) which
  192. tracks if the operation is dummy/real.
  193. */
  194. std::tuple<RegXS, RegBS> BST::insert(MPCTIO &tio, yield_t &yield, RegXS ptr,
  195. const Node &new_node, Duoram<Node>::Flat &A, int TTL, RegBS isDummy) {
  196. if(TTL==0) {
  197. RegBS zero;
  198. return {ptr, zero};
  199. }
  200. RegBS isNotDummy = isDummy ^ (!tio.player());
  201. Node cnode = A[ptr];
  202. // Compare key
  203. auto [lteq, gt] = compare_keys(tio, yield, cnode.key, new_node.key);
  204. // Depending on [lteq, gt] select the next ptr/index as
  205. // upper 32 bits of cnode.pointers if lteq
  206. // lower 32 bits of cnode.pointers if gt
  207. RegXS left = extractLeftPtr(cnode.pointers);
  208. RegXS right = extractRightPtr(cnode.pointers);
  209. RegXS next_ptr;
  210. mpc_select(tio, yield, next_ptr, gt, left, right, 32);
  211. CDPF dpf = tio.cdpf(yield);
  212. size_t &aes_ops = tio.aes_ops();
  213. // F_z: Check if this is last node on path
  214. RegBS F_z = dpf.is_zero(tio, yield, next_ptr, aes_ops);
  215. RegBS F_i;
  216. // F_i: If this was last node on path (F_z) && isNotDummy:
  217. // insert new_node here.
  218. mpc_and(tio, yield, F_i, (isNotDummy), F_z);
  219. isDummy^=F_i;
  220. auto [wptr, direction] = insert(tio, yield, next_ptr, new_node, A, TTL-1, isDummy);
  221. RegXS ret_ptr;
  222. RegBS ret_direction;
  223. // If we insert here (F_i), return the ptr to this node as wptr
  224. // and update direction to the direction taken by compare_keys
  225. run_coroutines(tio, [&tio, &ret_ptr, F_i, wptr, ptr](yield_t &yield)
  226. { mpc_select(tio, yield, ret_ptr, F_i, wptr, ptr);},
  227. [&tio, &ret_direction, F_i, direction, gt](yield_t &yield)
  228. //ret_direction = direction + F_i (direction - gt)
  229. { mpc_and(tio, yield, ret_direction, F_i, direction^gt);});
  230. ret_direction^=direction;
  231. return {ret_ptr, ret_direction};
  232. }
  233. /*
  234. The wrapper insert() operation invoked by the main insert call
  235. BST::insert(tio, yield, Node& new_node);
  236. Takes as input the new node (node), the underlying Duoram as a flat (A).
  237. */
  238. void BST::insert(MPCTIO &tio, yield_t &yield, const Node &node, Duoram<Node>::Flat &A) {
  239. bool player0 = tio.player()==0;
  240. // If there are no items in tree. Make this new item the root.
  241. if (num_items==0) {
  242. A[1] = node;
  243. // Set root to a secret sharing of the constant value 1
  244. root.set(1*tio.player());
  245. num_items++;
  246. //printf("num_items == %ld!\n", num_items);
  247. return;
  248. } else {
  249. // Insert node into next free slot in the ORAM
  250. int new_id;
  251. RegXS insert_address;
  252. int TTL = num_items++;
  253. bool insertAtEmptyLocation = (empty_locations.size() > 0);
  254. if(insertAtEmptyLocation) {
  255. insert_address = empty_locations.back();
  256. empty_locations.pop_back();
  257. A[insert_address] = node;
  258. } else {
  259. new_id = 1 + num_items;
  260. A[new_id] = node;
  261. insert_address.set(new_id * tio.player());
  262. }
  263. RegBS isDummy;
  264. //Do a recursive insert
  265. auto [wptr, direction] = insert(tio, yield, root, node, A, TTL, isDummy);
  266. //Complete the insertion by reading wptr and updating its pointers
  267. RegXS pointers = A[wptr].NODE_POINTERS;
  268. RegXS left_ptr = extractLeftPtr(pointers);
  269. RegXS right_ptr = extractRightPtr(pointers);
  270. RegXS new_right_ptr, new_left_ptr;
  271. RegBS not_direction = direction;
  272. if (player0) {
  273. not_direction^=1;
  274. }
  275. run_coroutines(tio,
  276. [&tio, &new_right_ptr, direction, right_ptr, insert_address](yield_t &yield)
  277. { mpc_select(tio, yield, new_right_ptr, direction, right_ptr, insert_address);},
  278. [&tio, &new_left_ptr, not_direction, left_ptr, insert_address](yield_t &yield)
  279. { mpc_select(tio, yield, new_left_ptr, not_direction, left_ptr, insert_address);});
  280. setLeftPtr(pointers, new_left_ptr);
  281. setRightPtr(pointers, new_right_ptr);
  282. A[wptr].NODE_POINTERS = pointers;
  283. }
  284. }
  285. /*
  286. Insert a new node into the BST.
  287. Takes as input the new node (node).
  288. */
  289. void BST::insert(MPCTIO &tio, yield_t &yield, Node &node) {
  290. auto A = oram.flat(tio, yield);
  291. insert(tio, yield, node, A);
  292. /*
  293. // To visualize database and tree after each insert:
  294. auto R = A.reconstruct();
  295. if (tio.player() == 0) {
  296. for(size_t i=0;i<R.size();++i) {
  297. printf("\n%04lx ", i);
  298. R[i].dump();
  299. }
  300. printf("\n");
  301. }
  302. pretty_print(R, 1);
  303. */
  304. }
  305. RegBS BST::lookup(MPCTIO &tio, yield_t &yield, RegXS ptr, RegAS key, Duoram<Node>::Flat &A,
  306. int TTL, RegBS isDummy, Node *ret_node) {
  307. if(TTL==0) {
  308. // If we found the key, then isDummy will be true
  309. return isDummy;
  310. }
  311. Node cnode = A[ptr];
  312. // Compare key
  313. CDPF cdpf = tio.cdpf(yield);
  314. auto [lt, eq, gt] = cdpf.compare(tio, yield, key - cnode.key, tio.aes_ops());
  315. // Depending on [lteq, gt] select the next ptr/index as
  316. // upper 32 bits of cnode.pointers if lteq
  317. // lower 32 bits of cnode.pointers if gt
  318. RegXS left = extractLeftPtr(cnode.pointers);
  319. RegXS right = extractRightPtr(cnode.pointers);
  320. RegXS next_ptr;
  321. RegBS F_found;
  322. // If we haven't found the key yet, and the lookup matches the current node key,
  323. // then we found the node to return
  324. RegBS isNotDummy = isDummy ^ (!tio.player());
  325. // Note: This logic returns the last matched key and value.
  326. // Returning the first one incurs an additional round.
  327. std::vector<coro_t> coroutines;
  328. coroutines.emplace_back(
  329. [&tio, &next_ptr, gt, left, right](yield_t &yield)
  330. { mpc_select(tio, yield, next_ptr, gt, left, right, 32);});
  331. coroutines.emplace_back(
  332. [&tio, &F_found, isNotDummy, eq](yield_t &yield)
  333. { mpc_and(tio, yield, F_found, isNotDummy, eq);});
  334. coroutines.emplace_back(
  335. [&tio, &ret_node, eq, cnode](yield_t &yield)
  336. { mpc_select(tio, yield, ret_node->key, eq, ret_node->key, cnode.key);});
  337. coroutines.emplace_back(
  338. [&tio, &ret_node, eq, cnode](yield_t &yield)
  339. { mpc_select(tio, yield, ret_node->value, eq, ret_node->value, cnode.value);});
  340. coroutines.emplace_back(
  341. [&tio, &isDummy, eq](yield_t &yield)
  342. { mpc_or(tio, yield, isDummy, isDummy, eq);});
  343. run_coroutines(tio, coroutines);
  344. #ifdef BST_DEBUG
  345. size_t ckey = mpc_reconstruct(tio, yield, cnode.key, 64);
  346. size_t lkey = mpc_reconstruct(tio, yield, key, 64);
  347. bool rec_lt = mpc_reconstruct(tio, yield, lt);
  348. bool rec_eq = mpc_reconstruct(tio, yield, eq);
  349. bool rec_gt = mpc_reconstruct(tio, yield, gt);
  350. bool rec_found = mpc_reconstruct(tio, yield, isDummy);
  351. bool rec_f_found = mpc_reconstruct(tio, yield, F_found);
  352. printf("rec_lt = %d, rec_eq = %d, rec_gt = %d\n", rec_lt, rec_eq, rec_gt);
  353. printf("rec_isDummy/found = %d ,rec_f_found = %d, cnode.key = %ld, lookup key = %ld\n", rec_found, rec_f_found, ckey, lkey);
  354. #endif
  355. RegBS found = lookup(tio, yield, next_ptr, key, A, TTL-1, isDummy, ret_node);
  356. return found;
  357. }
  358. RegBS BST::lookup(MPCTIO &tio, yield_t &yield, RegAS key, Node *ret_node) {
  359. auto A = oram.flat(tio, yield);
  360. RegBS isDummy;
  361. RegBS found = lookup(tio, yield, root, key, A, num_items, isDummy, ret_node);
  362. /*
  363. // To visualize database and tree after each lookup:
  364. auto R = A.reconstruct();
  365. if (tio.player() == 0) {
  366. for(size_t i=0;i<R.size();++i) {
  367. printf("\n%04lx ", i);
  368. R[i].dump();
  369. }
  370. printf("\n");
  371. }
  372. pretty_print(R, 1);
  373. */
  374. return found;
  375. }
  376. /*
  377. The recursive del() call, invoked by the wrapper del() function.
  378. Takes as input the pointer to the current node in tree traversal (ptr),
  379. the key to be deleted (del_key), the underlying Duoram as a
  380. flat (A), Flags af (already found) and fs (find successor), the
  381. Time-To_live TTL. Finally, a return structure ret_struct that tracks
  382. the location of the successor node and the node to delete to perform
  383. the actual deletion after the recursive traversal; which is required in
  384. the case of a deletion that requires a successor swap (,i.e., when node
  385. to delete has both children).
  386. Returns success/fail bit.
  387. */
  388. bool BST::del(MPCTIO &tio, yield_t &yield, RegXS ptr, RegAS del_key,
  389. Duoram<Node>::Flat &A, RegBS af, RegBS fs, int TTL,
  390. del_return &ret_struct) {
  391. bool player0 = tio.player()==0;
  392. //printf("TTL = %d\n", TTL);
  393. if(TTL==0) {
  394. //Reconstruct and return af
  395. bool success = reconstruct_RegBS(tio, yield, af);
  396. //printf("Reconstructed flag = %d\n", success);
  397. if(player0) {
  398. ret_struct.F_r^=1;
  399. }
  400. return success;
  401. } else {
  402. // s1: shares of 1 bit, s0: shares of 0 bit
  403. RegBS s1, s0;
  404. s1.set(tio.player()==1);
  405. Node node = A[ptr];
  406. RegXS left = extractLeftPtr(node.pointers);
  407. RegXS right = extractRightPtr(node.pointers);
  408. CDPF cdpf = tio.cdpf(yield);
  409. size_t &aes_ops = tio.aes_ops();
  410. RegBS l0, r0, lt, eq, gt;
  411. // l0: Is left child 0
  412. // r0: Is right child 0
  413. run_coroutines(tio,
  414. [&tio, &l0, left, &aes_ops, &cdpf](yield_t &yield)
  415. { l0 = cdpf.is_zero(tio, yield, left, aes_ops);},
  416. [&tio, &r0, right, &aes_ops, &cdpf](yield_t &yield)
  417. { r0 = cdpf.is_zero(tio, yield, right, aes_ops);},
  418. [&tio, &lt, &eq, &gt, del_key, node, &cdpf](yield_t &yield)
  419. // Compare Key
  420. { auto [a, b, c] = cdpf.compare(tio, yield, del_key - node.key, tio.aes_ops());
  421. lt = a; eq = b; gt = c;});
  422. /*
  423. // Reconstruct and Debug Block 0
  424. bool lt_rec, eq_rec, gt_rec;
  425. lt_rec = mpc_reconstruct(tio, yield, lt);
  426. eq_rec = mpc_reconstruct(tio, yield, eq);
  427. gt_rec = mpc_reconstruct(tio, yield, gt);
  428. size_t del_key_rec, node_key_rec;
  429. del_key_rec = mpc_reconstruct(tio, yield, del_key, 64);
  430. node_key_rec = mpc_reconstruct(tio, yield, node.key, 64);
  431. printf("node.key = %ld, del_key= %ld\n", node_key_rec, del_key_rec);
  432. printf("cdpf.compare results: lt = %d, eq = %d, gt = %d\n", lt_rec, eq_rec, gt_rec);
  433. */
  434. // c is the direction bit for next_ptr
  435. // (c=0: go left or c=1: go right)
  436. RegBS c = gt;
  437. // lf = local found. We found the key to delete in this level.
  438. RegBS lf = eq;
  439. // F_{X}: Flags that indicate the number of children this node has
  440. // F_0: no children, F_1: one child, F_2: both children
  441. RegBS F_0, F_1, F_2;
  442. // F_1 = l0 \xor r0
  443. F_1 = l0 ^ r0;
  444. // We set next ptr based on c, but we need to handle three
  445. // edge cases where we do not go by just the comparison result
  446. RegXS next_ptr;
  447. RegBS c_prime;
  448. // Case 1: found the node here (lf): we traverse down the lone child path.
  449. // or we are finding successor (fs) and there is no left child.
  450. RegBS F_c1, F_c2, F_c3, F_c4;
  451. // Case 1: lf & F_1
  452. run_coroutines(tio,
  453. [&tio, &F_c1, lf, F_1](yield_t &yield)
  454. { mpc_and(tio, yield, F_c1, lf, F_1);},
  455. [&tio, &F_0, l0, r0] (yield_t &yield)
  456. // F_0 = l0 & r0
  457. { mpc_and(tio, yield, F_0, l0, r0);});
  458. // F_2 = !(F_0 + F_1) (Only 1 of F_0, F_1, and F_2 can be true)
  459. F_2 = F_0 ^ F_1;
  460. if(player0)
  461. F_2^=1;
  462. /*
  463. // Reconstruct and Debug Block 1
  464. bool F_0_rec, F_1_rec, F_2_rec, c_prime_rec;
  465. F_0_rec = mpc_reconstruct(tio, yield, F_0);
  466. F_1_rec = mpc_reconstruct(tio, yield, F_1);
  467. F_2_rec = mpc_reconstruct(tio, yield, F_2);
  468. c_prime_rec = mpc_reconstruct(tio, yield, c_prime);
  469. printf("F_0 = %d, F_1 = %d, F_2 = %d, c_prime = %d\n", F_0_rec, F_1_rec, F_2_rec, c_prime_rec);
  470. */
  471. run_coroutines(tio,
  472. [&tio, &c_prime, F_c1, c, l0](yield_t &yield)
  473. // Set c_prime for Case 1
  474. { mpc_select(tio, yield, c_prime, F_c1, c, l0);},
  475. [&tio, &F_c2, lf, F_2](yield_t &yield)
  476. // Case 2: found the node here (lf) and node has both children (F_2)
  477. // In find successor case, so find inorder successor
  478. // (Go right and then find leftmost child.)
  479. { mpc_and(tio, yield, F_c2, lf, F_2);});
  480. /*
  481. // Reconstruct and Debug Block 2
  482. bool F_c2_rec, s1_rec;
  483. F_c2_rec = mpc_reconstruct(tio, yield, F_c2);
  484. s1_rec = mpc_reconstruct(tio, yield, s1);
  485. c_prime_rec = mpc_reconstruct(tio, yield, c_prime);
  486. printf("c_prime = %d, F_c2 = %d, s1 = %d\n", c_prime_rec, F_c2_rec, s1_rec);
  487. */
  488. run_coroutines(tio,
  489. [&tio, &c_prime, F_c2, s1](yield_t &yield)
  490. { mpc_select(tio, yield, c_prime, F_c2, c_prime, s1);},
  491. [&tio, &F_c3, fs, F_2](yield_t &yield)
  492. // Case 3: finding successor (fs) and node has both children (F_2)
  493. // Go left.
  494. { mpc_and(tio, yield, F_c3, fs, F_2);});
  495. run_coroutines(tio,
  496. [&tio, &c_prime, F_c3, s0](yield_t &yield)
  497. { mpc_select(tio, yield, c_prime, F_c3, c_prime, s0);},
  498. // Case 4: finding successor (fs) and node has no more left children (l0)
  499. // This is the successor node then.
  500. // Go right (since no more left)
  501. [&tio, &F_c4, fs, l0] (yield_t &yield)
  502. { mpc_and(tio, yield, F_c4, fs, l0);});
  503. mpc_select(tio, yield, c_prime, F_c4, c_prime, l0);
  504. RegBS af_prime, fs_prime;
  505. run_coroutines(tio,
  506. [&tio, &next_ptr, c_prime, left, right](yield_t &yield)
  507. // Set next_ptr
  508. { mpc_select(tio, yield, next_ptr, c_prime, left, right, 32);},
  509. [&tio, &af_prime, af, lf](yield_t &yield)
  510. { mpc_or(tio, yield, af_prime, af, lf);},
  511. [&tio, &fs_prime, fs, F_c2](yield_t &yield)
  512. // If in Case 2, set fs. We are now finding successor
  513. { mpc_or(tio, yield, fs_prime, fs, F_c2);});
  514. // If in Case 4. Successor found here already. Toggle fs off
  515. fs_prime=fs_prime^F_c4;
  516. bool key_found = del(tio, yield, next_ptr, del_key, A, af_prime, fs_prime, TTL-1, ret_struct);
  517. // If we didn't find the key, we can end here.
  518. if(!key_found) {
  519. return 0;
  520. }
  521. //printf("TTL = %d\n", TTL);
  522. RegBS F_rs_right, F_rs_left, not_c_prime=c_prime;
  523. if(player0) {
  524. not_c_prime^=1;
  525. }
  526. // Flag here should be direction (c_prime) and F_r i.e. we need to swap return ptr in,
  527. // F_r needs to be returned in ret_struct
  528. run_coroutines(tio,
  529. [&tio, &F_rs_right, c_prime, ret_struct](yield_t &yield)
  530. { mpc_and(tio, yield, F_rs_right, c_prime, ret_struct.F_r);},
  531. [&tio, &F_rs_left, not_c_prime, left, ret_struct](yield_t &yield)
  532. { mpc_and(tio, yield, F_rs_left, not_c_prime, ret_struct.F_r);});
  533. run_coroutines(tio,
  534. [&tio, &right, F_rs_right, ret_struct](yield_t &yield)
  535. { mpc_select(tio, yield, right, F_rs_right, right, ret_struct.ret_ptr);},
  536. [&tio, &left, F_rs_left, ret_struct](yield_t &yield)
  537. { mpc_select(tio, yield, left, F_rs_left, left, ret_struct.ret_ptr);});
  538. /*
  539. // Reconstruct and Debug Block 3
  540. bool F_rs_rec, F_ls_rec;
  541. size_t ret_ptr_rec;
  542. F_rs_rec = mpc_reconstruct(tio, yield, F_rs);
  543. F_ls_rec = mpc_reconstruct(tio, yield, F_rs);
  544. ret_ptr_rec = mpc_reconstruct(tio, yield, ret_struct.ret_ptr, 64);
  545. printf("F_rs_rec = %d, F_ls_rec = %d, ret_ptr_rec = %ld\n", F_rs_rec, F_ls_rec, ret_ptr_rec);
  546. */
  547. RegXS new_ptr;
  548. setLeftPtr(new_ptr, left);
  549. setRightPtr(new_ptr, right);
  550. A[ptr].NODE_POINTERS = new_ptr;
  551. // Update the return structure
  552. RegBS F_nd, F_ns, F_r, not_af = af, not_F_2 = F_2;
  553. if(player0) {
  554. not_af^=1;
  555. not_F_2^=1;
  556. }
  557. // F_ns = fs & l0
  558. // Finding successor flag & no more left child = F_c4
  559. F_ns = F_c4;
  560. run_coroutines(tio,
  561. [&tio, &ret_struct, F_c2](yield_t &yield)
  562. { mpc_or(tio, yield, ret_struct.F_ss, ret_struct.F_ss, F_c2);},
  563. [&tio, &F_nd, lf, not_af](yield_t &yield)
  564. { mpc_and(tio, yield, F_nd, lf, not_af);});
  565. // F_r = F_d.(!F_2)
  566. // If we have to delete here, and it doesn't have two children we have to
  567. // update child pointer in parent with the returned pointer
  568. mpc_and(tio, yield, F_r, F_nd, not_F_2);
  569. mpc_or(tio, yield, F_r, F_r, F_ns);
  570. ret_struct.F_r = F_r;
  571. run_coroutines(tio,
  572. [&tio, &ret_struct, F_nd, ptr](yield_t &yield)
  573. { mpc_select(tio, yield, ret_struct.N_d, F_nd, ret_struct.N_d, ptr);},
  574. [&tio, &ret_struct, F_ns, ptr](yield_t &yield)
  575. { mpc_select(tio, yield, ret_struct.N_s, F_ns, ret_struct.N_s, ptr);},
  576. [&tio, &ret_struct, F_r, ptr](yield_t &yield)
  577. { mpc_select(tio, yield, ret_struct.ret_ptr, F_r, ptr, ret_struct.ret_ptr);});
  578. //We don't empty the key and value of the node with del_key in the ORAM
  579. return 1;
  580. }
  581. }
  582. /*
  583. The main del() function.
  584. Trying to delete an item that does not exist in the tree will result in
  585. an explicit (non-oblivious) failure.
  586. Takes as input the key to delete (del_key).
  587. Returns success/fail bit.
  588. */
  589. bool BST::del(MPCTIO &tio, yield_t &yield, RegAS del_key) {
  590. if(num_items==0)
  591. return 0;
  592. if(num_items==1) {
  593. //Delete root
  594. auto A = oram.flat(tio, yield);
  595. Node zero;
  596. empty_locations.emplace_back(root);
  597. A[root] = zero;
  598. num_items--;
  599. return 1;
  600. } else {
  601. int TTL = num_items;
  602. // Flags for already found (af) item to delete and find successor (fs)
  603. // if this deletion requires a successor swap
  604. RegBS af;
  605. RegBS fs;
  606. del_return ret_struct;
  607. auto A = oram.flat(tio, yield);
  608. int success = del(tio, yield, root, del_key, A, af, fs, TTL, ret_struct);
  609. printf ("Success = %d\n", success);
  610. if(!success){
  611. return 0;
  612. }
  613. else{
  614. num_items--;
  615. /*
  616. printf("In delete's swap portion\n");
  617. Node del_node = A.reconstruct(A[ret_struct.N_d]);
  618. Node suc_node = A.reconstruct(A[ret_struct.N_s]);
  619. printf("del_node key = %ld, suc_node key = %ld\n",
  620. del_node.key.ashare, suc_node.key.ashare);
  621. printf("flag_s = %d\n", ret_struct.F_ss.bshare);
  622. */
  623. Node del_node = A[ret_struct.N_d];
  624. Node suc_node = A[ret_struct.N_s];
  625. RegAS zero_as; RegXS zero_xs;
  626. RegXS empty_loc, temp_root = root;
  627. run_coroutines(tio,
  628. [&tio, &temp_root, ret_struct](yield_t &yield)
  629. { mpc_select(tio, yield, temp_root, ret_struct.F_r, temp_root, ret_struct.ret_ptr);},
  630. [&tio, &del_node, ret_struct, suc_node](yield_t &yield)
  631. { mpc_select(tio, yield, del_node.key, ret_struct.F_ss, del_node.key, suc_node.key);},
  632. [&tio, &del_node, ret_struct, suc_node](yield_t & yield)
  633. { mpc_select(tio, yield, del_node.value, ret_struct.F_ss, del_node.value, suc_node.value);},
  634. [&tio, &empty_loc, ret_struct](yield_t &yield)
  635. { mpc_select(tio, yield, empty_loc, ret_struct.F_ss, ret_struct.N_d, ret_struct.N_s);});
  636. root = temp_root;
  637. run_coroutines(tio,
  638. [&tio, &A, ret_struct, del_node](yield_t &yield)
  639. { auto acont = A.context(yield);
  640. acont[ret_struct.N_d].NODE_KEY = del_node.key;},
  641. [&tio, &A, ret_struct, del_node](yield_t &yield)
  642. { auto acont = A.context(yield);
  643. acont[ret_struct.N_d].NODE_VALUE = del_node.value;},
  644. [&tio, &A, ret_struct, zero_as](yield_t &yield)
  645. { auto acont = A.context(yield);
  646. acont[ret_struct.N_s].NODE_KEY = zero_as;},
  647. [&tio, &A, ret_struct, zero_xs](yield_t &yield)
  648. { auto acont = A.context(yield);
  649. acont[ret_struct.N_s].NODE_VALUE = zero_xs;});
  650. //Add deleted (empty) location into the empty_locations vector for reuse in next insert()
  651. empty_locations.emplace_back(empty_loc);
  652. }
  653. return 1;
  654. }
  655. }
  656. // Now we use the node in various ways. This function is called by
  657. // online.cpp.
  658. void bst(MPCIO &mpcio,
  659. const PRACOptions &opts, char **args)
  660. {
  661. nbits_t depth=4;
  662. if (*args) {
  663. depth = atoi(*args);
  664. ++args;
  665. }
  666. size_t items = (size_t(1)<<depth)-1;
  667. if (*args) {
  668. items = atoi(*args);
  669. ++args;
  670. }
  671. MPCTIO tio(mpcio, 0, opts.num_threads);
  672. run_coroutines(tio, [&tio, depth, items] (yield_t &yield) {
  673. size_t size = size_t(1)<<depth;
  674. BST tree(tio.player(), size);
  675. int insert_array[] = {10, 10, 13, 11, 14, 8, 15, 20, 17, 19, 7, 12};
  676. //int insert_array[] = {1, 2, 3, 4, 5, 6};
  677. size_t insert_array_size = sizeof(insert_array)/sizeof(int);
  678. Node node;
  679. for(size_t i = 0; i<insert_array_size; i++) {
  680. randomize_node(node);
  681. node.key.set(insert_array[i] * tio.player());
  682. tree.insert(tio, yield, node);
  683. }
  684. tree.pretty_print(tio, yield);
  685. RegAS del_key;
  686. printf("\n\nDelete %x\n", 20);
  687. del_key.set(20 * tio.player());
  688. tree.del(tio, yield, del_key);
  689. tree.pretty_print(tio, yield);
  690. tree.check_bst(tio, yield);
  691. printf("\n\nDelete %x\n", 10);
  692. del_key.set(10 * tio.player());
  693. tree.del(tio, yield, del_key);
  694. tree.pretty_print(tio, yield);
  695. tree.check_bst(tio, yield);
  696. /*
  697. printf("\n\nDelete %x\n", 8);
  698. del_key.set(8 * tio.player());
  699. tree.del(tio, yield, del_key);
  700. tree.pretty_print(tio, yield);
  701. tree.check_bst(tio, yield);
  702. */
  703. printf("\n\nDelete %x\n", 7);
  704. del_key.set(7 * tio.player());
  705. tree.del(tio, yield, del_key);
  706. tree.pretty_print(tio, yield);
  707. tree.check_bst(tio, yield);
  708. printf("\n\nDelete %x\n", 17);
  709. del_key.set(17 * tio.player());
  710. tree.del(tio, yield, del_key);
  711. tree.pretty_print(tio, yield);
  712. tree.check_bst(tio, yield);
  713. printf("\n\nDelete %x\n", 15);
  714. del_key.set(15 * tio.player());
  715. tree.del(tio, yield, del_key);
  716. tree.pretty_print(tio, yield);
  717. tree.check_bst(tio, yield);
  718. printf("\n\nDelete %x\n", 5);
  719. del_key.set(5 * tio.player());
  720. tree.del(tio, yield, del_key);
  721. tree.pretty_print(tio, yield);
  722. tree.check_bst(tio, yield);
  723. printf("\n\nInsert %x\n", 14);
  724. randomize_node(node);
  725. node.key.set(14 * tio.player());
  726. tree.insert(tio, yield, node);
  727. tree.pretty_print(tio, yield);
  728. tree.check_bst(tio, yield);
  729. printf("\n\nLookup %x\n", 8);
  730. randomize_node(node);
  731. RegAS lookup_key;
  732. RegBS found;
  733. bool rec_found;
  734. lookup_key.set(8 * tio.player());
  735. found = tree.lookup(tio, yield, lookup_key, &node);
  736. rec_found = mpc_reconstruct(tio, yield, found);
  737. tree.pretty_print(tio, yield);
  738. if(tio.player()!=2) {
  739. if(rec_found) {
  740. printf("Lookup Success\n");
  741. size_t value = mpc_reconstruct(tio, yield, node.value, 64);
  742. printf("value = %lx\n", value);
  743. } else {
  744. printf("Lookup Failed\n");
  745. }
  746. }
  747. printf("\n\nLookup %x\n", 63);
  748. randomize_node(node);
  749. lookup_key.set(63 * tio.player());
  750. found = tree.lookup(tio, yield, lookup_key, &node);
  751. rec_found = mpc_reconstruct(tio, yield, found);
  752. //rec_found = reconstruct_RegBS(tio, yield, found);
  753. tree.pretty_print(tio, yield);
  754. if(tio.player()!=2) {
  755. if(rec_found) {
  756. printf("Lookup Success\n");
  757. size_t value = mpc_reconstruct(tio, yield, node.value, 64);
  758. printf("value = %lx\n", value);
  759. } else {
  760. printf("Lookup Failed\n");
  761. }
  762. }
  763. });
  764. }