Ian Goldberg 2 éve
szülő
commit
022aae16c3
5 módosított fájl, 247 hozzáadás és 24 törlés
  1. 71 24
      duoram.hpp
  2. 113 0
      duoram.tcc
  3. 1 0
      mpcio.hpp
  4. 51 0
      online.cpp
  5. 11 0
      types.hpp

+ 71 - 24
duoram.hpp

@@ -87,6 +87,8 @@ class Duoram<T>::Shape {
     // additive-shared index (x) into an XOR-shared database (T), for
     // example.
 
+    // The parent class of the MemRef* classes
+    class MemRef;
     // When x is unshared explicit value
     class MemRefExpl;
     // When x is additively shared
@@ -165,15 +167,6 @@ public:
     // Get the size
     inline size_t size() { return shape_size; }
 
-    // Update the context (MPCTIO and yield if you've started a new
-    // thread, or just yield if you've started a new coroutine in the
-    // same thread)
-    void context(MPCTIO &new_tio, yield_t &new_yield) {
-        tio = new_tio;
-        yield = new_yield;
-    }
-    void context(yield_t &new_yield) { yield = new_yield; }
-
     // Index into this Shape in various ways
     MemRefAS operator[](const RegAS &idx) { return MemRefAS(*this, idx); }
     MemRefXS operator[](const RegXS &idx) { return MemRefXS(*this, idx); }
@@ -197,6 +190,7 @@ template <typename T>
 class Duoram<T>::Flat : public Duoram<T>::Shape {
     // If this is a subrange, start may be non-0, but it's usually 0
     size_t start;
+    size_t len;
 
     inline size_t indexmap(size_t idx) const {
         size_t paridx = idx + start;
@@ -207,11 +201,61 @@ class Duoram<T>::Flat : public Duoram<T>::Shape {
         }
     }
 
+    // Internal function to aid bitonic_sort
+    void butterfly(address_t start, nbits_t depth, bool dir);
+
 public:
     // Constructor.  len=0 means the maximum size (the parent's size
     // minus start).
     Flat(Duoram &duoram, MPCTIO &tio, yield_t &yield, size_t start = 0,
         size_t len = 0);
+
+    // Update the context (MPCTIO and yield if you've started a new
+    // thread, or just yield if you've started a new coroutine in the
+    // same thread).  Returns a new Shape with an updated context.
+    Flat context(MPCTIO &new_tio, yield_t &new_yield) const {
+        return Flat(this->duoram, new_tio, new_yield, start, len);
+    }
+    Flat context(yield_t &new_yield) const {
+        return Flat(this->duoram, this->tio, new_yield, start, len);
+    }
+
+    // Oblivious sort the elements indexed by the two given indices.
+    // Without reconstructing the values, if dir=0, this[idx1] will
+    // become a share of the smaller of the reconstructed values, and
+    // this[idx2] will become a share of the larger.  If dir=1, it's the
+    // other way around.
+    //
+    // Note: this only works for additively shared databases
+    template<typename U,typename V>
+    void osort(const U &idx1, const V &idx2, bool dir=0);
+
+    // Bitonic sort the elements from start to start+(1<<depth)-1, in
+    // increasing order if dir=0 or decreasing order if dir=1. Note that
+    // the elements must be at most 63 bits long each for the notion of
+    // ">" to make consistent sense.
+    void bitonic_sort(address_t start, nbits_t depth, bool dir=0);
+};
+
+// The parent class of shared memory references
+template <typename T>
+class Duoram<T>::Shape::MemRef {
+protected:
+    const Shape &shape;
+
+    MemRef(const Shape &shape): shape(shape) {}
+
+public:
+
+    // Oblivious read from an additively shared index of Duoram memory
+    virtual operator T() = 0;
+
+    // Oblivious update to an additively shared index of Duoram memory
+    virtual MemRef &operator+=(const T& M) = 0;
+
+    // Convenience function
+    MemRef &operator-=(const T& M) { *this += (-M); return *this; }
+
 };
 
 // An additively shared memory reference.  You get one of these from a
@@ -219,19 +263,18 @@ public:
 // perform operations on this object, which do the Duoram operations.
 
 template <typename T>
-class Duoram<T>::Shape::MemRefAS {
-    const Shape &shape;
+class Duoram<T>::Shape::MemRefAS : public Duoram<T>::Shape::MemRef {
     RegAS idx;
 
 public:
     MemRefAS(const Shape &shape, const RegAS &idx) :
-        shape(shape), idx(idx) {}
+        MemRef(shape), idx(idx) {}
 
     // Oblivious read from an additively shared index of Duoram memory
-    operator T();
+    operator T() override;
 
     // Oblivious update to an additively shared index of Duoram memory
-    MemRefAS &operator+=(const T& M);
+    MemRefAS &operator+=(const T& M) override;
 };
 
 // An XOR shared memory reference.  You get one of these from a Shape A
@@ -239,19 +282,21 @@ public:
 // operations on this object, which do the Duoram operations.
 
 template <typename T>
-class Duoram<T>::Shape::MemRefXS {
-    const Shape &shape;
+class Duoram<T>::Shape::MemRefXS : public Duoram<T>::Shape::MemRef {
     RegXS idx;
 
 public:
     MemRefXS(const Shape &shape, const RegXS &idx) :
-        shape(shape), idx(idx) {}
+        MemRef(shape), idx(idx) {}
 
     // Oblivious read from an XOR shared index of Duoram memory
-    operator T();
+    operator T() override;
 
     // Oblivious update to an XOR shared index of Duoram memory
-    MemRefXS &operator+=(const T& M);
+    MemRefXS &operator+=(const T& M) override;
+
+    // Convenience function
+    MemRefXS &operator-=(const T& M) { *this += (-M); return *this; }
 };
 
 // An explicit memory reference.  You get one of these from a Shape A
@@ -260,19 +305,21 @@ public:
 // operations.
 
 template <typename T>
-class Duoram<T>::Shape::MemRefExpl {
-    const Shape &shape;
+class Duoram<T>::Shape::MemRefExpl : public Duoram<T>::Shape::MemRef {
     address_t idx;
 
 public:
     MemRefExpl(const Shape &shape, address_t idx) :
-        shape(shape), idx(idx) {}
+        MemRef(shape), idx(idx) {}
 
     // Explicit read from a given index of Duoram memory
-    operator T();
+    operator T() override;
 
     // Explicit update to a given index of Duoram memory
-    MemRefExpl &operator+=(const T& M);
+    MemRefExpl &operator+=(const T& M) override;
+
+    // Convenience function
+    MemRefExpl &operator-=(const T& M) { *this += (-M); return *this; }
 };
 
 #include "duoram.tcc"

+ 113 - 0
duoram.tcc

@@ -2,6 +2,8 @@
 
 #include <stdio.h>
 
+#include "cdpf.hpp"
+
 // Pass the player number and desired size
 template <typename T>
 Duoram<T>::Duoram(int player, size_t size) : player(player),
@@ -111,14 +113,76 @@ Duoram<T>::Flat::Flat(Duoram &duoram, MPCTIO &tio, yield_t &yield,
     if (len > maxshapesize || len == 0) {
         len = maxshapesize;
     }
+    this->len = len;
     this->set_shape_size(len);
 }
 
+// Bitonic sort the elements from start to start+(1<<depth)-1, in
+// increasing order if dir=0 or decreasing order if dir=1. Note that
+// the elements must be at most 63 bits long each for the notion of
+// ">" to make consistent sense.
+template <typename T>
+void Duoram<T>::Flat::bitonic_sort(address_t start, nbits_t depth, bool dir)
+{
+    if (depth == 0) return;
+    if (depth == 1) {
+        osort(start, start+1, dir);
+        return;
+    }
+    // Recurse on the first half (increasing order) and the second half
+    // (decreasing order) in parallel
+    std::vector<coro_t> coroutines;
+    coroutines.emplace_back([&](yield_t &yield) {
+        Flat Acoro = context(yield);
+        Acoro.bitonic_sort(start, depth-1, 0);
+    });
+    coroutines.emplace_back([&](yield_t &yield) {
+        Flat Acoro = context(yield);
+        Acoro.bitonic_sort(start+(1<<(depth-1)), depth-1, 1);
+    });
+    run_coroutines(this->yield, coroutines);
+    // Merge the two into the desired order
+    butterfly(start, depth, dir);
+}
+
+// Internal function to aid bitonic_sort
+template <typename T>
+void Duoram<T>::Flat::butterfly(address_t start, nbits_t depth, bool dir)
+{
+    if (depth == 0) return;
+    if (depth == 1) {
+        osort(start, start+1, dir);
+        return;
+    }
+    // Sort pairs of elements half the width apart in parallel
+    address_t halfwidth = address_t(1)<<(depth-1);
+    std::vector<coro_t> coroutines;
+    for (address_t i=0; i<halfwidth;++i) {
+        coroutines.emplace_back([&](yield_t &yield) {
+            Flat Acoro = context(yield);
+            Acoro.osort(start+i, start+i+halfwidth, dir);
+        });
+    }
+    run_coroutines(this->yield, coroutines);
+    // Recurse on each half in parallel
+    coroutines.clear();
+    coroutines.emplace_back([&](yield_t &yield) {
+        Flat Acoro = context(yield);
+        Acoro.butterfly(start, depth-1, dir);
+    });
+    coroutines.emplace_back([&](yield_t &yield) {
+        Flat Acoro = context(yield);
+        Acoro.butterfly(start+halfwidth, depth-1, dir);
+    });
+    run_coroutines(this->yield, coroutines);
+}
+
 // Oblivious read from an additively shared index of Duoram memory
 template <typename T>
 Duoram<T>::Shape::MemRefAS::operator T()
 {
     T res;
+    const Shape &shape = this->shape;
     int player = shape.tio.player();
     if (player < 2) {
         // Computational players do this
@@ -212,6 +276,7 @@ template <typename T>
 typename Duoram<T>::Shape::MemRefAS
     &Duoram<T>::Shape::MemRefAS::operator+=(const T& M)
 {
+    const Shape &shape = this->shape;
     int player = shape.tio.player();
     if (player < 2) {
         // Computational players do this
@@ -293,6 +358,50 @@ typename Duoram<T>::Shape::MemRefAS
     return *this;
 }
 
+// Oblivious sort with the provided other element.  Without
+// reconstructing the values, *this will become a share of the
+// smaller of the reconstructed values, and other will become a
+// share of the larger.
+//
+// Note: this only works for additively shared databases
+template <> template <typename U,typename V>
+void Duoram<RegAS>::Flat::osort(const U &idx1, const V &idx2, bool dir)
+{
+    printf("osort %u %u %d\n", idx1, idx2, dir);
+    // Load the values in parallel
+    std::vector<coro_t> coroutines;
+    RegAS val1, val2;
+    coroutines.emplace_back([&](yield_t &yield) {
+        Flat Acoro = context(yield);
+        val1 = Acoro[idx1];
+    });
+    coroutines.emplace_back([&](yield_t &yield) {
+        Flat Acoro = context(yield);
+        val2 = Acoro[idx2];
+    });
+    run_coroutines(yield, coroutines);
+    // Get a CDPF
+    CDPF cdpf = tio.cdpf();
+    // Use it to compare the values
+    RegAS diff = val1-val2;
+    auto [lt, eq, gt] = cdpf.compare(tio, yield, diff, tio.aes_ops());
+    RegBS cmp = dir ? lt : gt;
+    // Get additive shares of cmp*diff
+    RegAS cmp_diff;
+    mpc_flagmult(tio, yield, cmp_diff, cmp, diff);
+    // Update the two locations in parallel
+    coroutines.clear();
+    coroutines.emplace_back([&](yield_t &yield) {
+        Flat Acoro = context(yield);
+        Acoro[idx1] -= cmp_diff;
+    });
+    coroutines.emplace_back([&](yield_t &yield) {
+        Flat Acoro = context(yield);
+        Acoro[idx2] += cmp_diff;
+    });
+    run_coroutines(yield, coroutines);
+}
+
 // The MemRefXS routines are almost identical to the MemRefAS routines,
 // but I couldn't figure out how to get them to be two instances of a
 // template.  Sorry for the code duplication.
@@ -301,6 +410,7 @@ typename Duoram<T>::Shape::MemRefAS
 template <typename T>
 Duoram<T>::Shape::MemRefXS::operator T()
 {
+    const Shape &shape = this->shape;
     T res;
     int player = shape.tio.player();
     if (player < 2) {
@@ -393,6 +503,7 @@ template <typename T>
 typename Duoram<T>::Shape::MemRefXS
     &Duoram<T>::Shape::MemRefXS::operator+=(const T& M)
 {
+    const Shape &shape = this->shape;
     int player = shape.tio.player();
     if (player < 2) {
         // Computational players do this
@@ -478,6 +589,7 @@ typename Duoram<T>::Shape::MemRefXS
 template <typename T>
 Duoram<T>::Shape::MemRefExpl::operator T()
 {
+    const Shape &shape = this->shape;
     T res;
     int player = shape.tio.player();
     if (player < 2) {
@@ -491,6 +603,7 @@ template <typename T>
 typename Duoram<T>::Shape::MemRefExpl
     &Duoram<T>::Shape::MemRefExpl::operator+=(const T& M)
 {
+    const Shape &shape = this->shape;
     int player = shape.tio.player();
     if (player < 2) {
         // Computational players do this

+ 1 - 0
mpcio.hpp

@@ -351,6 +351,7 @@ public:
     inline bool preprocessing() { return mpcio.preprocessing; }
     inline bool is_server() { return mpcio.player == 2; }
     inline size_t& aes_ops() { return mpcio.aes_ops[thread_num]; }
+    inline size_t msgs_sent() { return mpcio.msgs_sent[thread_num]; }
 };
 
 // Set up the socket connections between the two computational parties

+ 51 - 0
online.cpp

@@ -673,6 +673,54 @@ static void compare_test(MPCIO &mpcio, yield_t &yield,
     pool.join();
 }
 
+static void sort_test(MPCIO &mpcio, yield_t &yield,
+    const PRACOptions &opts, char **args)
+{
+    nbits_t depth=6;
+
+    if (*args) {
+        depth = atoi(*args);
+        ++args;
+    }
+
+    int num_threads = opts.num_threads;
+    boost::asio::thread_pool pool(num_threads);
+    for (int thread_num = 0; thread_num < num_threads; ++thread_num) {
+        boost::asio::post(pool, [&mpcio, &yield, thread_num, depth] {
+            address_t size = address_t(1)<<depth;
+            MPCTIO tio(mpcio, thread_num);
+            // size_t &aes_ops = tio.aes_ops();
+            Duoram<RegAS> oram(mpcio.player, size);
+            auto A = oram.flat(tio, yield);
+            // Initialize the memory to random values in parallel
+            std::vector<coro_t> coroutines;
+            for (address_t i=0; i<size; ++i) {
+                coroutines.emplace_back(
+                    [&A, i](yield_t &yield) {
+                        auto Acoro = A.context(yield);
+                        RegAS v;
+                        v.randomize(62);
+                        Acoro[i] += v;
+                    });
+            }
+            run_coroutines(yield, coroutines);
+            //A.osort(0,1);
+            A.bitonic_sort(0, depth);
+            if (depth <= 10) {
+                oram.dump();
+                auto check = A.reconstruct();
+                if (tio.player() == 0) {
+                    for (address_t i=0;i<size;++i) {
+                        printf("%04x %016lx\n", i, check[i].share());
+                    }
+                }
+            }
+            tio.send();
+        });
+    }
+    pool.join();
+}
+
 void online_main(MPCIO &mpcio, const PRACOptions &opts, char **args)
 {
     // Run everything inside a coroutine so that simple tests don't have
@@ -715,6 +763,9 @@ void online_main(MPCIO &mpcio, const PRACOptions &opts, char **args)
             } else if (!strcmp(*args, "cmptest")) {
                 ++args;
                 compare_test(mpcio, yield, opts, args);
+            } else if (!strcmp(*args, "sorttest")) {
+                ++args;
+                sort_test(mpcio, yield, opts, args);
             } else {
                 std::cerr << "Unknown mode " << *args << "\n";
             }

+ 11 - 0
types.hpp

@@ -80,6 +80,12 @@ struct RegAS {
         return res;
     }
 
+    inline RegAS operator-() const {
+        RegAS res = *this;
+        res.ashare = -res.ashare;
+        return res;
+    }
+
     inline RegAS &operator*=(value_t rhs) {
         this->ashare *= rhs;
         return *this;
@@ -195,6 +201,11 @@ struct RegXS {
         return res;
     }
 
+    inline RegXS operator-() const {
+        RegXS res = *this;
+        return res;
+    }
+
     inline RegXS &operator*=(value_t rhs) {
         this->xshare &= rhs;
         return *this;