Browse Source

ORAM reads and explicit reads and writes for wide data types

Ian Goldberg 1 year ago
parent
commit
f8e2a126cf
9 changed files with 355 additions and 66 deletions
  1. 8 2
      Makefile
  2. 148 0
      baltree.cpp
  3. 10 0
      baltree.hpp
  4. 55 0
      duoram.cpp
  5. 8 0
      duoram.hpp
  6. 40 64
      duoram.tcc
  7. 4 0
      online.cpp
  8. 21 0
      rdpf.tcc
  9. 61 0
      types.hpp

+ 8 - 2
Makefile

@@ -9,7 +9,7 @@ LDLIBS=-lbsd -lboost_system -lboost_context -lboost_chrono -lboost_thread -lpthr
 
 BIN=prac
 SRCS=prac.cpp mpcio.cpp preproc.cpp online.cpp mpcops.cpp rdpf.cpp \
-    cdpf.cpp
+    cdpf.cpp duoram.cpp baltree.cpp
 OBJS=$(SRCS:.cpp=.o)
 ASMS=$(SRCS:.cpp=.s)
 
@@ -39,10 +39,16 @@ preproc.o: options.hpp rdpf.hpp bitutils.hpp dpf.hpp prg.hpp aes.hpp rdpf.tcc
 preproc.o: cdpf.hpp cdpf.tcc
 online.o: online.hpp mpcio.hpp types.hpp corotypes.hpp options.hpp mpcops.hpp
 online.o: coroutine.hpp rdpf.hpp bitutils.hpp dpf.hpp prg.hpp aes.hpp
-online.o: rdpf.tcc duoram.hpp duoram.tcc cdpf.hpp cdpf.tcc
+online.o: rdpf.tcc duoram.hpp duoram.tcc cdpf.hpp cdpf.tcc baltree.hpp
 mpcops.o: mpcops.hpp types.hpp mpcio.hpp corotypes.hpp coroutine.hpp
 mpcops.o: bitutils.hpp
 rdpf.o: rdpf.hpp mpcio.hpp types.hpp corotypes.hpp coroutine.hpp bitutils.hpp
 rdpf.o: dpf.hpp prg.hpp aes.hpp rdpf.tcc mpcops.hpp
 cdpf.o: bitutils.hpp cdpf.hpp mpcio.hpp types.hpp corotypes.hpp coroutine.hpp
 cdpf.o: dpf.hpp prg.hpp aes.hpp cdpf.tcc
+duoram.o: duoram.hpp types.hpp mpcio.hpp corotypes.hpp coroutine.hpp
+duoram.o: duoram.tcc mpcops.hpp cdpf.hpp dpf.hpp prg.hpp bitutils.hpp aes.hpp
+duoram.o: cdpf.tcc rdpf.hpp rdpf.tcc
+baltree.o: types.hpp duoram.hpp mpcio.hpp corotypes.hpp coroutine.hpp
+baltree.o: duoram.tcc mpcops.hpp cdpf.hpp dpf.hpp prg.hpp bitutils.hpp
+baltree.o: aes.hpp cdpf.tcc rdpf.hpp rdpf.tcc baltree.hpp options.hpp

+ 148 - 0
baltree.cpp

@@ -0,0 +1,148 @@
+#include <functional>
+
+#include "types.hpp"
+#include "duoram.hpp"
+#include "baltree.hpp"
+
+struct Cell {
+    RegAS key;
+    RegXS pointers;
+    RegXS value;
+
+    // The width (the number of RegAS and RegXS entries) of this type
+    static const size_t WIDTH = 3;
+
+    void dump() const {
+        printf("[%016lx %016lx %016lx]", key.share(), pointers.share(),
+            value.share());
+    }
+
+    // You'll need to be able to create a random element, and do the
+    // operations +=, +, -=, - (binary and unary).  Note that for
+    // XOR-shared fields, + and - are both really XOR.
+
+    inline void randomize() {
+        key.randomize();
+        pointers.randomize();
+        value.randomize();
+    }
+
+    inline Cell &operator+=(const Cell &rhs) {
+        this->key += rhs.key;
+        this->pointers += rhs.pointers;
+        this->value += rhs.value;
+        return *this;
+    }
+
+    inline Cell operator+(const Cell &rhs) const {
+        Cell res = *this;
+        res += rhs;
+        return res;
+    }
+
+    inline Cell &operator-=(const Cell &rhs) {
+        this->key -= rhs.key;
+        this->pointers -= rhs.pointers;
+        this->value -= rhs.value;
+        return *this;
+    }
+
+    inline Cell operator-(const Cell &rhs) const {
+        Cell res = *this;
+        res -= rhs;
+        return res;
+    }
+
+    inline Cell operator-() const {
+        Cell res;
+        res.key = -this->key;
+        res.pointers = -this->pointers;
+        res.value = -this->value;
+        return res;
+    }
+
+    // Multiply each field by the local share of the corresponding field
+    // in the argument
+    inline Cell mulshare(const Cell &rhs) const {
+        Cell res = *this;
+        res.key.mulshareeq(rhs.key);
+        res.pointers.mulshareeq(rhs.pointers);
+        res.value.mulshareeq(rhs.value);
+        return res;
+    }
+
+    // You need a method to turn a leaf node of a DPF into a share of a
+    // unit element of your type.  Typically set each RegAS to
+    // dpf.unit_as(leaf) and each RegXS or RegBS to dpf.unit_bs(leaf).
+    // Note that RegXS will extend a RegBS of 1 to the all-1s word, not
+    // the word with value 1.  This is used for ORAM reads, where the
+    // same DPF is used for all the fields.
+    inline void unit(const RDPF &dpf, DPFnode leaf) {
+        key = dpf.unit_as(leaf);
+        pointers = dpf.unit_bs(leaf);
+        value = dpf.unit_bs(leaf);
+    }
+
+    // You need a method to turn WIDTH DPFs into an element of your
+    // type, using each DPF's scaled_sum or scaled_xor value as
+    // appropriate for each field.  We need WIDTH of them because
+    // reusing a scaled value from a DPF leaks information.  The dpfs
+    // argument is a function that given 0 <= f < WIDTH, returns a
+    // reference to DPF number f.
+    inline void scaled_value(std::function<const RDPF &(size_t)> dpfs) {
+        key = dpfs(0).scaled_sum;
+        pointers = dpfs(1).scaled_xor;
+        value = dpfs(2).scaled_xor;
+    }
+};
+
+template <typename T>
+T& operator>>(T& is, Cell &x)
+{
+    is >> x.key >> x.pointers >> x.value;
+    return is;
+}
+
+template <typename T>
+T& operator<<(T& os, const Cell &x)
+{
+    os << x.key << x.pointers << x.value;
+    return os;
+}
+
+DEFAULT_TUPLE_IO(Cell)
+
+void baltree (MPCIO &mpcio,
+    const PRACOptions &opts, char **args)
+{
+    nbits_t depth=4;
+
+    if (*args) {
+        depth = atoi(*args);
+        ++args;
+    }
+
+    MPCTIO tio(mpcio, 0, opts.num_threads);
+    run_coroutines(tio, [&tio, depth] (yield_t &yield) {
+        size_t size = size_t(1)<<depth;
+        Duoram<Cell> oram(tio.player(), size);
+        auto A = oram.flat(tio, yield);
+        Cell c;
+        c.key.set(0x0102030405060708);
+        c.pointers.set(0x1112131415161718);
+        c.value.set(0x2122232425262728);
+        A[0] = c;
+        RegAS idx;
+        Cell expl_read_c = A[0];
+        printf("expl_read_c = ");
+        expl_read_c.dump();
+        printf("\n");
+        Cell oram_read_c = A[idx];
+        printf ("oram_read_c = ");
+        oram_read_c.dump();
+        printf("\n");
+
+        printf("\n");
+        oram.dump();
+    });
+}

+ 10 - 0
baltree.hpp

@@ -0,0 +1,10 @@
+#ifndef __BALTREE_HPP__
+#define __BALTREE_HPP__
+
+#include "mpcio.hpp"
+#include "options.hpp"
+
+void baltree (MPCIO &mpcio,
+    const PRACOptions &opts, char **args);
+
+#endif

+ 55 - 0
duoram.cpp

@@ -0,0 +1,55 @@
+#include "duoram.hpp"
+
+// Assuming the memory is already sorted, do an oblivious binary
+// search for the largest index containing the value at most the
+// given one.  (The answer will be 0 if all of the memory elements
+// are greate than the target.) This Flat must be a power of 2 size.
+// Only available for additive shared databases for now.
+template <>
+RegAS Duoram<RegAS>::Flat::obliv_binary_search(RegAS &target)
+{
+    nbits_t depth = this->addr_size;
+    // Start in the middle
+    RegAS index;
+    index.set(this->tio.player() ? 0 : 1<<(depth-1));
+    // Invariant: index points to the first element of the right half of
+    // the remaining possible range
+    while (depth > 0) {
+        // Obliviously read the value there
+        RegAS val = operator[](index);
+        // Compare it to the target
+        CDPF cdpf = tio.cdpf(this->yield);
+        auto [lt, eq, gt] = cdpf.compare(this->tio, this->yield,
+            val-target, tio.aes_ops());
+        if (depth > 1) {
+            // If val > target, the answer is strictly to the left
+            // and we should subtract 2^{depth-2} from index
+            // If val <= target, the answer is here or to the right
+            // and we should add 2^{depth-2} to index
+            // So we unconditionally subtract 2^{depth-2} from index, and
+            // add (lt+eq)*2^{depth-1}.
+            RegAS uncond;
+            uncond.set(tio.player() ? 0 : address_t(1)<<(depth-2));
+            RegAS cond;
+            cond.set(tio.player() ? 0 : address_t(1)<<(depth-1));
+            RegAS condprod;
+            RegBS le = lt ^ eq;
+            mpc_flagmult(this->tio, this->yield, condprod, le, cond);
+            index -= uncond;
+            index += condprod;
+        } else {
+            // If val > target, the answer is strictly to the left
+            // If val <= target, the answer is here or to the right
+            // so subtract gt from index
+            RegAS cond;
+            cond.set(tio.player() ? 0 : 1);
+            RegAS condprod;
+            mpc_flagmult(this->tio, this->yield, condprod, gt, cond);
+            index -= condprod;
+        }
+        --depth;
+    }
+
+    return index;
+}
+

+ 8 - 0
duoram.hpp

@@ -2,6 +2,8 @@
 #define __DUORAM_HPP__
 
 #include "types.hpp"
+#include "mpcio.hpp"
+#include "coroutine.hpp"
 
 // Implementation of the 3-party protocols described in:
 // Adithya Vadapalli, Ryan Henry, Ian Goldberg, "Duoram: A
@@ -334,6 +336,12 @@ template <typename T> template <typename U>
 class Duoram<T>::Shape::MemRefS : public Duoram<T>::Shape::MemRef {
     U idx;
 
+private:
+    // Oblivious update to a shared index of Duoram memory, only for
+    // T = RegAS or RegXS
+    MemRefS<U> &oram_update(const T& M, const prac_template_true&);
+    MemRefS<U> &oram_update(const T& M, const prac_template_false&);
+
 public:
     MemRefS<U>(Shape &shape, const U &idx) :
         MemRef(shape), idx(idx) {}

+ 40 - 64
duoram.tcc

@@ -2,7 +2,9 @@
 
 #include <stdio.h>
 
+#include "mpcops.hpp"
 #include "cdpf.hpp"
+#include "rdpf.hpp"
 
 // Pass the player number and desired size
 template <typename T>
@@ -24,12 +26,19 @@ void Duoram<T>::dump() const
 {
     for (size_t i=0; i<oram_size; ++i) {
         if (player < 2) {
-            printf("%04lx %016lx %016lx %016lx\n",
-                i, database[i].share(), blind[i].share(),
-                peer_blinded_db[i].share());
+            printf("%04lx ", i);
+            database[i].dump();
+            printf(" ");
+            blind[i].dump();
+            printf(" ");
+            peer_blinded_db[i].dump();
+            printf("\n");
         } else {
-            printf("%04lx %016lx %016lx\n",
-                i, p0_blind[i].share(), p1_blind[i].share());
+            printf("%04lx ", i);
+            p0_blind[i].dump();
+            printf(" ");
+            p1_blind[i].dump();
+            printf("\n");
         }
     }
     printf("\n");
@@ -227,59 +236,6 @@ void Duoram<T>::Flat::butterfly(address_t start, nbits_t depth, bool dir)
         });
 }
 
-// Assuming the memory is already sorted, do an oblivious binary
-// search for the largest index containing the value at most the
-// given one.  (The answer will be 0 if all of the memory elements
-// are greate than the target.) This Flat must be a power of 2 size.
-// Only available for additive shared databases for now.
-template <>
-RegAS Duoram<RegAS>::Flat::obliv_binary_search(RegAS &target)
-{
-    nbits_t depth = this->addr_size;
-    // Start in the middle
-    RegAS index;
-    index.set(this->tio.player() ? 0 : 1<<(depth-1));
-    // Invariant: index points to the first element of the right half of
-    // the remaining possible range
-    while (depth > 0) {
-        // Obliviously read the value there
-        RegAS val = operator[](index);
-        // Compare it to the target
-        CDPF cdpf = tio.cdpf(this->yield);
-        auto [lt, eq, gt] = cdpf.compare(this->tio, this->yield,
-            val-target, tio.aes_ops());
-        if (depth > 1) {
-            // If val > target, the answer is strictly to the left
-            // and we should subtract 2^{depth-2} from index
-            // If val <= target, the answer is here or to the right
-            // and we should add 2^{depth-2} to index
-            // So we unconditionally subtract 2^{depth-2} from index, and
-            // add (lt+eq)*2^{depth-1}.
-            RegAS uncond;
-            uncond.set(tio.player() ? 0 : address_t(1)<<(depth-2));
-            RegAS cond;
-            cond.set(tio.player() ? 0 : address_t(1)<<(depth-1));
-            RegAS condprod;
-            RegBS le = lt ^ eq;
-            mpc_flagmult(this->tio, this->yield, condprod, le, cond);
-            index -= uncond;
-            index += condprod;
-        } else {
-            // If val > target, the answer is strictly to the left
-            // If val <= target, the answer is here or to the right
-            // so subtract gt from index
-            RegAS cond;
-            cond.set(tio.player() ? 0 : 1);
-            RegAS condprod;
-            mpc_flagmult(this->tio, this->yield, condprod, gt, cond);
-            index -= condprod;
-        }
-        --depth;
-    }
-
-    return index;
-}
-
 // Helper functions to specialize the read and update operations for
 // RegAS and RegXS shared indices
 template <typename U>
@@ -340,12 +296,12 @@ Duoram<T>::Shape::MemRefS<U>::operator T()
         T init;
         res = pe.reduce(init, [&dp, &shape] (int thread_num, address_t i,
                 const RDPFPair::node &leaf) {
-            // The values from the two DPFs
+            // The values from the two DPFs, which will each be of type T
             auto [V0, V1] = dp.unit<T>(leaf);
             // References to the appropriate cells in our database, our
             // blind, and our copy of the peer's blinded database
             auto [DB, BL, PBD] = shape.get_comp(i);
-            return (DB + PBD) * V0.share() - BL * (V1-V0).share();
+            return (DB + PBD).mulshare(V0) - BL.mulshare(V1-V0);
         });
 
         shape.yield();
@@ -375,13 +331,13 @@ Duoram<T>::Shape::MemRefS<U>::operator T()
             shape.tio.aes_ops());
         gamma = pe.reduce(init, [&dp, &shape] (int thread_num, address_t i,
                 const RDPFPair::node &leaf) {
-            // The values from the two DPFs
+            // The values from the two DPFs, each of type T
             auto [V0, V1] = dp.unit<T>(leaf);
 
             // shape.get_server(i) returns a pair of references to the
             // appropriate cells in the two blinded databases
             auto [BL0, BL1] = shape.get_server(i);
-            return std::make_tuple(-BL0 * V1.share(), -BL1 * V0.share());
+            return std::make_tuple(-BL0.mulshare(V1), -BL1.mulshare(V0));
         });
 
         // Choose a random blinding factor
@@ -400,10 +356,12 @@ Duoram<T>::Shape::MemRefS<U>::operator T()
     return res;  // The server will always get 0
 }
 
-// Oblivious update to an additively or XOR shared index of Duoram memory
+// Oblivious update to a shared index of Duoram memory, only for
+// T = RegAS or RegXS
 template <typename T> template <typename U>
 typename Duoram<T>::Shape::template MemRefS<U>
-    &Duoram<T>::Shape::MemRefS<U>::operator+=(const T& M)
+    &Duoram<T>::Shape::MemRefS<U>::oram_update(const T& M,
+        const prac_template_true &)
 {
     Shape &shape = this->shape;
     shape.explicitonly(false);
@@ -498,6 +456,24 @@ typename Duoram<T>::Shape::template MemRefS<U>
     return *this;
 }
 
+// Oblivious update to a shared index of Duoram memory, only for
+// T not equal to RegAS or RegXS
+template <typename T> template <typename U>
+typename Duoram<T>::Shape::template MemRefS<U>
+    &Duoram<T>::Shape::MemRefS<U>::oram_update(const T& M,
+        const prac_template_false &)
+{
+    return *this;
+}
+
+// Oblivious update to an additively or XOR shared index of Duoram memory
+template <typename T> template <typename U>
+typename Duoram<T>::Shape::template MemRefS<U>
+    &Duoram<T>::Shape::MemRefS<U>::operator+=(const T& M)
+{
+    return oram_update(M, prac_basic_Reg_S<T>());
+}
+
 // Oblivious write to an additively or XOR shared index of Duoram memory
 template <typename T> template <typename U>
 typename Duoram<T>::Shape::template MemRefS<U>

+ 4 - 0
online.cpp

@@ -5,6 +5,7 @@
 #include "rdpf.hpp"
 #include "duoram.hpp"
 #include "cdpf.hpp"
+#include "baltree.hpp"
 
 
 static void online_test(MPCIO &mpcio,
@@ -1164,6 +1165,9 @@ void online_main(MPCIO &mpcio, const PRACOptions &opts, char **args)
         } else {
             duoram<RegAS>(mpcio, opts, args);
         }
+    } else if (!strcmp(*args, "baltree")) {
+        ++args;
+        baltree(mpcio, opts, args);
     } else {
         std::cerr << "Unknown mode " << *args << "\n";
     }

+ 21 - 0
rdpf.tcc

@@ -207,6 +207,17 @@ inline std::tuple<RegAS,RegAS,RegAS> RDPFTriple::unit<RegAS>(node leaf) const {
         dpf[2].unit_as(std::get<2>(leaf)));
 }
 
+// For any more complex entry type, that type will handle the conversion
+// for each DPF
+template <typename T>
+inline std::tuple<T,T,T> RDPFTriple::unit(node leaf) const {
+    T v0, v1, v2;
+    v0.unit(dpf[0], std::get<0>(leaf));
+    v1.unit(dpf[1], std::get<1>(leaf));
+    v2.unit(dpf[2], std::get<2>(leaf));
+    return std::make_tuple(v0,v1,v2);
+}
+
 // Get the XOR-shared scaled vector entry from the leaf ndoe
 template <>
 inline std::tuple<RegXS,RegXS,RegXS> RDPFTriple::scaled<RegXS>(node leaf) const {
@@ -255,6 +266,16 @@ inline std::tuple<RegAS,RegAS> RDPFPair::unit<RegAS>(node leaf) const {
         dpf[1].unit_as(std::get<1>(leaf)));
 }
 
+// For any more complex entry type, that type will handle the conversion
+// for each DPF
+template <typename T>
+inline std::tuple<T,T> RDPFPair::unit(node leaf) const {
+    T v0, v1;
+    v0.unit(dpf[0], std::get<0>(leaf));
+    v1.unit(dpf[1], std::get<1>(leaf));
+    return std::make_tuple(v0,v1);
+}
+
 // Get the XOR-shared scaled vector entry from the leaf ndoe
 template <>
 inline std::tuple<RegXS,RegXS> RDPFPair::scaled<RegXS>(node leaf) const {

+ 61 - 0
types.hpp

@@ -48,6 +48,9 @@ using nbits_t = uint8_t;
 struct RegAS {
     value_t ashare;
 
+    // The basic types just have one value
+    static const size_t WIDTH = 1;
+
     RegAS() : ashare(0) {}
 
     inline value_t share() const { return ashare; }
@@ -109,6 +112,22 @@ struct RegAS {
         res &= mask;
         return res;
     }
+
+    // Multiply by the local share of the argument, not multiplcation of
+    // two shared values (two versions)
+    inline RegAS &mulshareeq(const RegAS &rhs) {
+        *this *= rhs.ashare;
+        return *this;
+    }
+    inline RegAS mulshare(const RegAS &rhs) const {
+        RegAS res = *this;
+        res *= rhs.ashare;
+        return res;
+    }
+
+    inline void dump() const {
+        printf("%016lx", ashare);
+    }
 };
 
 inline value_t combine(const RegAS &A, const RegAS &B,
@@ -163,6 +182,9 @@ struct RegBS {
 struct RegXS {
     value_t xshare;
 
+    // The basic types just have one value
+    static const size_t WIDTH = 1;
+
     RegXS() : xshare(0) {}
     RegXS(const RegBS &b) { xshare = b.bshare ? ~0 : 0; }
 
@@ -241,6 +263,22 @@ struct RegXS {
         return res;
     }
 
+    // Multiply by the local share of the argument, not multiplcation of
+    // two shared values (two versions)
+    inline RegXS &mulshareeq(const RegXS &rhs) {
+        *this *= rhs.xshare;
+        return *this;
+    }
+    inline RegXS mulshare(const RegXS &rhs) const {
+        RegXS res = *this;
+        res *= rhs.xshare;
+        return res;
+    }
+
+    inline void dump() const {
+        printf("%016lx", xshare);
+    }
+
     // Extract a bit share of bit bitnum of the XOR-shared register
     inline RegBS bit(nbits_t bitnum) const {
         RegBS bs;
@@ -258,6 +296,29 @@ inline value_t combine(const RegXS &A, const RegXS &B,
     return (A.xshare ^ B.xshare) & mask;
 }
 
+// Enable templates to specialize on just the basic types RegAS and
+// RegXS.  Technique from
+// https://stackoverflow.com/questions/2430039/one-template-specialization-for-multiple-classes
+
+template <bool B> struct prac_template_bool_type {};
+using prac_template_true = prac_template_bool_type<true>;
+using prac_template_false = prac_template_bool_type<false>;
+template <typename T>
+struct prac_basic_Reg_S : prac_template_false
+{
+    static const bool value = false;
+};
+template<>
+struct prac_basic_Reg_S<RegAS>: prac_template_true
+{
+    static const bool value = true;
+};
+template<>
+struct prac_basic_Reg_S<RegXS>: prac_template_true
+{
+    static const bool value = true;
+};
+
 // Some useful operations on tuples, vectors, and arrays of the above
 // types