|
@@ -0,0 +1,330 @@
|
|
|
+#include <functional>
|
|
|
+
|
|
|
+#include "types.hpp"
|
|
|
+#include "duoram.hpp"
|
|
|
+#include "node.hpp"
|
|
|
+
|
|
|
+// This file demonstrates how to implement custom ORAM wide cell types.
|
|
|
+// Such types can be structures of arbitrary numbers of RegAS and RegXS
|
|
|
+// fields. The example here imagines a cell of a binary search tree,
|
|
|
+// where you would want the key to be additively shared (so that you can
|
|
|
+// easily do comparisons), the pointers field to be XOR shared (so that
|
|
|
+// you can easily do bit operations to pack two pointers and maybe some
|
|
|
+// tree balancing information into one field) and the value doesn't
|
|
|
+// really matter, but XOR shared is usually slightly more efficient.
|
|
|
+
|
|
|
+struct Node {
|
|
|
+ RegAS key;
|
|
|
+ RegXS pointers;
|
|
|
+ RegXS value;
|
|
|
+
|
|
|
+// Field-access macros so we can write A[i].NODE_KEY instead of
|
|
|
+// A[i].field(&Node::key)
|
|
|
+
|
|
|
+#define NODE_KEY field(&Node::key)
|
|
|
+#define NODE_POINTERS field(&Node::pointers)
|
|
|
+#define NODE_VALUE field(&Node::value)
|
|
|
+
|
|
|
+ // For debugging and checking answers
|
|
|
+ void dump() const {
|
|
|
+ printf("[%016lx %016lx %016lx]", key.share(), pointers.share(),
|
|
|
+ value.share());
|
|
|
+ }
|
|
|
+
|
|
|
+ // You'll need to be able to create a random element, and do the
|
|
|
+ // operations +=, +, -=, - (binary and unary). Note that for
|
|
|
+ // XOR-shared fields, + and - are both really XOR.
|
|
|
+
|
|
|
+
|
|
|
+ inline void zeronode() {
|
|
|
+ key.set(0);
|
|
|
+ pointers.set(0);
|
|
|
+ value.set(0);
|
|
|
+ }
|
|
|
+
|
|
|
+ inline void newnode() {
|
|
|
+ key.randomize(8);
|
|
|
+ pointers.set(0);
|
|
|
+ value.randomize();
|
|
|
+ }
|
|
|
+
|
|
|
+ inline void randomize() {
|
|
|
+ key.randomize();
|
|
|
+ pointers.randomize();
|
|
|
+ value.randomize();
|
|
|
+ }
|
|
|
+
|
|
|
+ inline Node &operator+=(const Node &rhs) {
|
|
|
+ this->key += rhs.key;
|
|
|
+ this->pointers += rhs.pointers;
|
|
|
+ this->value += rhs.value;
|
|
|
+ return *this;
|
|
|
+ }
|
|
|
+
|
|
|
+ inline Node operator+(const Node &rhs) const {
|
|
|
+ Node res = *this;
|
|
|
+ res += rhs;
|
|
|
+ return res;
|
|
|
+ }
|
|
|
+
|
|
|
+ inline Node &operator-=(const Node &rhs) {
|
|
|
+ this->key -= rhs.key;
|
|
|
+ this->pointers -= rhs.pointers;
|
|
|
+ this->value -= rhs.value;
|
|
|
+ return *this;
|
|
|
+ }
|
|
|
+
|
|
|
+ inline Node operator-(const Node &rhs) const {
|
|
|
+ Node res = *this;
|
|
|
+ res -= rhs;
|
|
|
+ return res;
|
|
|
+ }
|
|
|
+
|
|
|
+ inline Node operator-() const {
|
|
|
+ Node res;
|
|
|
+ res.key = -this->key;
|
|
|
+ res.pointers = -this->pointers;
|
|
|
+ res.value = -this->value;
|
|
|
+ return res;
|
|
|
+ }
|
|
|
+
|
|
|
+ // Multiply each field by the local share of the corresponding field
|
|
|
+ // in the argument
|
|
|
+ inline Node mulshare(const Node &rhs) const {
|
|
|
+ Node res = *this;
|
|
|
+ res.key.mulshareeq(rhs.key);
|
|
|
+ res.pointers.mulshareeq(rhs.pointers);
|
|
|
+ res.value.mulshareeq(rhs.value);
|
|
|
+ return res;
|
|
|
+ }
|
|
|
+
|
|
|
+ // You need a method to turn a leaf node of a DPF into a share of a
|
|
|
+ // unit element of your type. Typically set each RegAS to
|
|
|
+ // dpf.unit_as(leaf) and each RegXS or RegBS to dpf.unit_bs(leaf).
|
|
|
+ // Note that RegXS will extend a RegBS of 1 to the all-1s word, not
|
|
|
+ // the word with value 1. This is used for ORAM reads, where the
|
|
|
+ // same DPF is used for all the fields.
|
|
|
+ inline void unit(const RDPF &dpf, DPFnode leaf) {
|
|
|
+ key = dpf.unit_as(leaf);
|
|
|
+ pointers = dpf.unit_bs(leaf);
|
|
|
+ value = dpf.unit_bs(leaf);
|
|
|
+ }
|
|
|
+
|
|
|
+ // Perform an update on each of the fields, using field-specific
|
|
|
+ // MemRefs constructed from the Shape shape and the index idx
|
|
|
+ template <typename Sh, typename U>
|
|
|
+ inline static void update(Sh &shape, yield_t ­ield, U idx,
|
|
|
+ const Node &M) {
|
|
|
+ run_coroutines(shyield,
|
|
|
+ [&shape, &idx, &M] (yield_t &yield) {
|
|
|
+ Sh Sh_coro = shape.context(yield);
|
|
|
+ Sh_coro[idx].NODE_KEY += M.key;
|
|
|
+ },
|
|
|
+ [&shape, &idx, &M] (yield_t &yield) {
|
|
|
+ Sh Sh_coro = shape.context(yield);
|
|
|
+ Sh_coro[idx].NODE_POINTERS += M.pointers;
|
|
|
+ },
|
|
|
+ [&shape, &idx, &M] (yield_t &yield) {
|
|
|
+ Sh Sh_coro = shape.context(yield);
|
|
|
+ Sh_coro[idx].NODE_VALUE += M.value;
|
|
|
+ });
|
|
|
+ }
|
|
|
+};
|
|
|
+
|
|
|
+// I/O operations (for sending over the network)
|
|
|
+
|
|
|
+template <typename T>
|
|
|
+T& operator>>(T& is, Node &x)
|
|
|
+{
|
|
|
+ is >> x.key >> x.pointers >> x.value;
|
|
|
+ return is;
|
|
|
+}
|
|
|
+
|
|
|
+template <typename T>
|
|
|
+T& operator<<(T& os, const Node &x)
|
|
|
+{
|
|
|
+ os << x.key << x.pointers << x.value;
|
|
|
+ return os;
|
|
|
+}
|
|
|
+
|
|
|
+// This macro will define I/O on tuples of two or three of the cell type
|
|
|
+
|
|
|
+DEFAULT_TUPLE_IO(Node)
|
|
|
+
|
|
|
+int num_items = 0;
|
|
|
+RegAS root;
|
|
|
+
|
|
|
+std::tuple<RegBS, RegBS> compare_keys(Node n1, Node n2, MPCTIO tio, yield_t &yield) {
|
|
|
+ CDPF cdpf = tio.cdpf(yield);
|
|
|
+ auto [lt, eq, gt] = cdpf.compare(tio, yield, n2.key - n1.key, tio.aes_ops());
|
|
|
+ RegBS lteq = lt^eq;
|
|
|
+ return {lteq, gt};
|
|
|
+}
|
|
|
+
|
|
|
+RegBS check_key_zero(Node n1, MPCTIO tio, yield_t &yield) {
|
|
|
+ CDPF cdpf = tio.cdpf(yield);
|
|
|
+ RegAS zero;
|
|
|
+ auto [lt, eq, gt] = cdpf.compare(tio, yield, n1.key - zero, tio.aes_ops());
|
|
|
+ return eq;
|
|
|
+}
|
|
|
+
|
|
|
+RegBS check_ptr_zero(RegXS ptr, MPCTIO tio, yield_t &yield) {
|
|
|
+ CDPF cdpf = tio.cdpf(yield);
|
|
|
+ RegAS ptr_as;
|
|
|
+ mpc_xs_to_as(tio, yield, ptr_as, ptr);
|
|
|
+ RegAS zero;
|
|
|
+ auto [lt, eq, gt] = cdpf.compare(tio, yield, ptr_as - zero, tio.aes_ops());
|
|
|
+ return eq;
|
|
|
+}
|
|
|
+
|
|
|
+// Assuming pointer of 64 bits is split as:
|
|
|
+// - 32 bits Left ptr
|
|
|
+// - 32 bits Right ptr
|
|
|
+// < Left, Right>
|
|
|
+
|
|
|
+inline RegXS extractLeftPtr(RegXS pointer){
|
|
|
+ return ((pointer&(0xFFFFFFFF00000000))>>32);
|
|
|
+}
|
|
|
+
|
|
|
+inline RegXS extractRightPtr(RegXS pointer){
|
|
|
+ return (pointer&(0x00000000FFFFFFFF));
|
|
|
+}
|
|
|
+
|
|
|
+inline void setLeftPtr(RegXS &pointer, RegXS new_ptr){
|
|
|
+ pointer&=(0x00000000FFFFFFFF);
|
|
|
+ pointer+=(new_ptr<<32);
|
|
|
+}
|
|
|
+
|
|
|
+inline void setRightPtr(RegXS &pointer, RegXS new_ptr){
|
|
|
+ pointer&=(0xFFFFFFFF00000000);
|
|
|
+ pointer+=(new_ptr);
|
|
|
+}
|
|
|
+
|
|
|
+std::tuple<RegAS, RegBS> insert(RegAS &ptr, Node &new_node, auto A, int TTL, RegBS isDummy, MPCTIO &tio, yield_t &yield) {
|
|
|
+ if(TTL==0) {
|
|
|
+ RegBS zero;
|
|
|
+ return {ptr, zero};
|
|
|
+ }
|
|
|
+
|
|
|
+ Node cnode = A[ptr];
|
|
|
+ //Compare key
|
|
|
+ auto [lteq, gt] = compare_keys(cnode, new_node, tio, yield);
|
|
|
+
|
|
|
+ // Depending on [lteq, gt] select the next ptr/index as
|
|
|
+ // upper 32 bits of cnode.pointers if lteq
|
|
|
+ // lower 32 bits of cnode.pointers if gt
|
|
|
+ RegXS left = extractLeftPtr(cnode.pointers);
|
|
|
+ RegXS right = extractRightPtr(cnode.pointers);
|
|
|
+
|
|
|
+ RegXS next_ptr;
|
|
|
+ mpc_select(tio, yield, next_ptr, gt, left, right, 32);
|
|
|
+
|
|
|
+ RegBS F_z = check_ptr_zero(next_ptr, tio, yield);
|
|
|
+ RegBS F_i;
|
|
|
+
|
|
|
+ if(tio.player()==0) {
|
|
|
+ isDummy^=1;
|
|
|
+ }
|
|
|
+ mpc_and(tio, yield, F_i, (isDummy), F_z);
|
|
|
+ if(tio.player()==0) {
|
|
|
+ isDummy^=1;
|
|
|
+ }
|
|
|
+
|
|
|
+ RegAS next_ptr_as;
|
|
|
+ mpc_xs_to_as(tio, yield, next_ptr_as, next_ptr, 32);
|
|
|
+ isDummy^=F_i;
|
|
|
+ auto [wptr, direction] = insert(next_ptr_as, new_node, A, TTL-1, isDummy, tio, yield);
|
|
|
+
|
|
|
+ RegAS ret_ptr;
|
|
|
+ RegBS ret_direction;
|
|
|
+ mpc_select(tio, yield, ret_ptr, F_i, wptr, ptr);
|
|
|
+ //ret_direction = direction + F_p(direction - gt)
|
|
|
+ mpc_and(tio, yield, ret_direction, F_i, direction^gt);
|
|
|
+ ret_direction^=direction;
|
|
|
+
|
|
|
+ return {ret_ptr, ret_direction};
|
|
|
+}
|
|
|
+
|
|
|
+
|
|
|
+// Insert(root, ptr, key, TTL, isDummy) -> (new_ptr, wptr, wnode, f_p)
|
|
|
+void insert(RegAS &root, Node &node, auto A, MPCTIO &tio, yield_t &yield) {
|
|
|
+ if(num_items==0) {
|
|
|
+ Node zero;
|
|
|
+ zero.zeronode();
|
|
|
+ A[0] = zero;
|
|
|
+ A[1] = node;
|
|
|
+ (root).set(1*tio.player());
|
|
|
+ num_items++;
|
|
|
+ return;
|
|
|
+ }
|
|
|
+ else {
|
|
|
+ // Insert node into next free slot in the ORAM
|
|
|
+ int new_id = 1 + num_items;
|
|
|
+ int TTL = num_items++;
|
|
|
+ A[new_id] = node;
|
|
|
+ RegXS new_addr;
|
|
|
+ new_addr.set(new_id * tio.player());
|
|
|
+ RegBS isDummy;
|
|
|
+ isDummy.set(0);
|
|
|
+
|
|
|
+ //Do a recursive insert
|
|
|
+ auto [wptr, direction] = insert(root, node, A, TTL, isDummy, tio, yield);
|
|
|
+
|
|
|
+ //Complete the insertion by reading wptr and updating its pointers
|
|
|
+ RegXS pointers = A[wptr].NODE_POINTERS;
|
|
|
+ RegXS left_ptr = extractLeftPtr(pointers);
|
|
|
+ RegXS right_ptr = extractRightPtr(pointers);
|
|
|
+ RegXS new_right_ptr, new_left_ptr;
|
|
|
+ mpc_select(tio, yield, new_right_ptr, direction, right_ptr, new_addr);
|
|
|
+ if(tio.player()==0) {
|
|
|
+ direction^=1;
|
|
|
+ }
|
|
|
+ mpc_select(tio, yield, new_left_ptr, direction, left_ptr, new_addr);
|
|
|
+ setLeftPtr(pointers, new_left_ptr);
|
|
|
+ setRightPtr(pointers, new_right_ptr);
|
|
|
+ A[wptr].NODE_POINTERS = pointers;
|
|
|
+ }
|
|
|
+
|
|
|
+}
|
|
|
+
|
|
|
+// Now we use the cell in various ways. This function is called by
|
|
|
+// online.cpp.
|
|
|
+
|
|
|
+void bst(MPCIO &mpcio,
|
|
|
+ const PRACOptions &opts, char **args)
|
|
|
+{
|
|
|
+ nbits_t depth=5;
|
|
|
+
|
|
|
+ if (*args) {
|
|
|
+ depth = atoi(*args);
|
|
|
+ ++args;
|
|
|
+ }
|
|
|
+
|
|
|
+ MPCTIO tio(mpcio, 0, opts.num_threads);
|
|
|
+ run_coroutines(tio, [&tio, depth] (yield_t &yield) {
|
|
|
+ size_t size = size_t(1)<<depth;
|
|
|
+ Duoram<Node> oram(tio.player(), size);
|
|
|
+ auto A = oram.flat(tio, yield);
|
|
|
+
|
|
|
+ Node c;
|
|
|
+ for(int i = 0; i<30; i++) {
|
|
|
+ c.newnode();
|
|
|
+ insert(root, c, A, tio, yield);
|
|
|
+ }
|
|
|
+
|
|
|
+
|
|
|
+ if (depth < 10) {
|
|
|
+ oram.dump();
|
|
|
+ auto R = A.reconstruct();
|
|
|
+ if (tio.player() == 0) {
|
|
|
+ for(size_t i=0;i<R.size();++i) {
|
|
|
+ printf("\n%04lx ", i);
|
|
|
+ R[i].dump();
|
|
|
+ }
|
|
|
+ printf("\n");
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+
|
|
|
+ });
|
|
|
+}
|