Browse Source

Binary search for Flat of any size

No longer restricted to a power of 2 size

Binary search now also slightly changes what it returns: it's now the
index where you would insert the target.
Ian Goldberg 1 year ago
parent
commit
81bf0f7ddb
6 changed files with 166 additions and 138 deletions
  1. 1 1
      Makefile
  2. 35 19
      duoram.cpp
  3. 4 4
      duoram.hpp
  4. 119 111
      online.cpp
  5. 5 2
      shapes.hpp
  6. 2 1
      shapes.tcc

+ 1 - 1
Makefile

@@ -51,7 +51,7 @@ cdpf.o: bitutils.hpp cdpf.hpp mpcio.hpp types.hpp corotypes.hpp mpcio.tcc
 cdpf.o: coroutine.hpp dpf.hpp prg.hpp aes.hpp cdpf.tcc
 duoram.o: duoram.hpp types.hpp bitutils.hpp mpcio.hpp corotypes.hpp mpcio.tcc
 duoram.o: coroutine.hpp duoram.tcc mpcops.hpp mpcops.tcc cdpf.hpp dpf.hpp
-duoram.o: prg.hpp aes.hpp cdpf.tcc rdpf.hpp rdpf.tcc
+duoram.o: prg.hpp aes.hpp cdpf.tcc rdpf.hpp rdpf.tcc shapes.hpp shapes.tcc
 cell.o: types.hpp bitutils.hpp duoram.hpp mpcio.hpp corotypes.hpp mpcio.tcc
 cell.o: coroutine.hpp duoram.tcc mpcops.hpp mpcops.tcc cdpf.hpp dpf.hpp
 cell.o: prg.hpp aes.hpp cdpf.tcc rdpf.hpp rdpf.tcc cell.hpp options.hpp

+ 35 - 19
duoram.cpp

@@ -1,51 +1,67 @@
 #include "duoram.hpp"
+#include "shapes.hpp"
 
 // Assuming the memory is already sorted, do an oblivious binary
-// search for the largest index containing the value at most the
-// given one.  (The answer will be 0 if all of the memory elements
-// are greate than the target.) This Flat must be a power of 2 size.
-// Only available for additive shared databases for now.
+// search for the smallest index containing the value at least the
+// given one.  (The answer will be the length of the Flat if all
+// elements are smaller than the target.) Only available for additive
+// shared databases for now.
 template <>
 RegAS Duoram<RegAS>::Flat::obliv_binary_search(RegAS &target)
 {
-    nbits_t depth = this->addr_size;
+    if (this->shape_size == 0) {
+        RegAS zero;
+        return zero;
+    }
+    // Create a Pad of the smallest power of 2 size strictly greater
+    // than the Flat size
+    address_t padsize = 1;
+    nbits_t depth = 0;
+    while (padsize <= this->shape_size) {
+        padsize *= 2;
+        ++depth;
+    }
+    Duoram<RegAS>::Pad P(*this, tio, yield, padsize);
+
     // Start in the middle
     RegAS index;
-    index.set(this->tio.player() ? 0 : 1<<(depth-1));
-    // Invariant: index points to the first element of the right half of
-    // the remaining possible range
+    index.set(this->tio.player() ? 0 : (1<<(depth-1))-1);
+    // Invariant: index points to the last element of the left half of
+    // the remaining possible range, which is of width (1<<depth).
     while (depth > 0) {
         // Obliviously read the value there
-        RegAS val = operator[](index);
+        RegAS val = P[index];
         // Compare it to the target
         CDPF cdpf = tio.cdpf(this->yield);
         auto [lt, eq, gt] = cdpf.compare(this->tio, this->yield,
             val-target, tio.aes_ops());
         if (depth > 1) {
-            // If val > target, the answer is strictly to the left
+            // If val >= target, the answer is here or to the left
             // and we should subtract 2^{depth-2} from index
-            // If val <= target, the answer is here or to the right
+            // If val < target, the answer is strictly to the right
             // and we should add 2^{depth-2} to index
             // So we unconditionally subtract 2^{depth-2} from index, and
-            // add (lt+eq)*2^{depth-1}.
+            // add (lt)*2^{depth-1}.
             RegAS uncond;
             uncond.set(tio.player() ? 0 : address_t(1)<<(depth-2));
             RegAS cond;
             cond.set(tio.player() ? 0 : address_t(1)<<(depth-1));
             RegAS condprod;
-            RegBS le = lt ^ eq;
-            mpc_flagmult(this->tio, this->yield, condprod, le, cond);
+            mpc_flagmult(this->tio, this->yield, condprod, lt, cond);
             index -= uncond;
             index += condprod;
         } else {
-            // If val > target, the answer is strictly to the left
-            // If val <= target, the answer is here or to the right
-            // so subtract gt from index
+            // The possible range is of width 2, and we're pointing to
+            // the first element of it.
+            // If val >= target, the answer is here or to the left, so
+            // it's here.
+            // If val < target, the answer is strictly to the right
+            // so add lt to index
             RegAS cond;
             cond.set(tio.player() ? 0 : 1);
             RegAS condprod;
-            mpc_flagmult(this->tio, this->yield, condprod, gt, cond);
-            index -= condprod;
+            mpc_flagmult(this->tio, this->yield, condprod, lt, cond);
+            index += condprod;
         }
         --depth;
     }

+ 4 - 4
duoram.hpp

@@ -349,10 +349,10 @@ public:
     void bitonic_sort(address_t start, address_t len, bool dir=0);
 
     // Assuming the memory is already sorted, do an oblivious binary
-    // search for the largest index containing the value at most the
-    // given one.  (The answer will be 0 if all of the memory elements
-    // are greate than the target.) This Flat must be a power of 2 size.
-    // Only available for additive shared databases for now.
+    // search for the smallest index containing the value at least the
+    // given one.  (The answer will be the length of the Flat if all
+    // elements are smaller than the target.) Only available for additive
+    // shared databases for now.
     RegAS obliv_binary_search(RegAS &target);
 };
 

+ 119 - 111
online.cpp

@@ -1077,55 +1077,48 @@ static void sort_test(MPCIO &mpcio,
         ++args;
     }
 
-    int num_threads = opts.num_threads;
-    boost::asio::thread_pool pool(num_threads);
-    for (int thread_num = 0; thread_num < num_threads; ++thread_num) {
-        boost::asio::post(pool, [&mpcio, thread_num, depth, len] {
-            MPCTIO tio(mpcio, thread_num);
-            run_coroutines(tio, [&tio, depth, len] (yield_t &yield) {
-                address_t size = address_t(1)<<depth;
-                // size_t &aes_ops = tio.aes_ops();
-                Duoram<RegAS> oram(tio.player(), size);
-                auto A = oram.flat(tio, yield);
-                A.explicitonly(true);
-                // Initialize the memory to random values in parallel
-                std::vector<coro_t> coroutines;
-                for (address_t i=0; i<size; ++i) {
-                    coroutines.emplace_back(
-                        [&A, i](yield_t &yield) {
-                            auto Acoro = A.context(yield);
-                            RegAS v;
-                            v.randomize(62);
-                            Acoro[i] += v;
-                        });
-                }
-                run_coroutines(yield, coroutines);
-                A.bitonic_sort(0, len);
+    MPCTIO tio(mpcio, 0, opts.num_threads);
+    run_coroutines(tio, [&tio, depth, len] (yield_t &yield) {
+        address_t size = address_t(1)<<depth;
+        // size_t &aes_ops = tio.aes_ops();
+        Duoram<RegAS> oram(tio.player(), size);
+        auto A = oram.flat(tio, yield);
+        A.explicitonly(true);
+        // Initialize the memory to random values in parallel
+        std::vector<coro_t> coroutines;
+        for (address_t i=0; i<size; ++i) {
+            coroutines.emplace_back(
+                [&A, i](yield_t &yield) {
+                    auto Acoro = A.context(yield);
+                    RegAS v;
+                    v.randomize(62);
+                    Acoro[i] += v;
+                });
+        }
+        run_coroutines(yield, coroutines);
+        A.bitonic_sort(0, len);
+        if (depth <= 10) {
+            oram.dump();
+        }
+        auto check = A.reconstruct();
+        bool fail = false;
+        if (tio.player() == 0) {
+            for (address_t i=0;i<size;++i) {
                 if (depth <= 10) {
-                    oram.dump();
+                    printf("%04x %016lx\n", i, check[i].share());
                 }
-                auto check = A.reconstruct();
-                bool fail = false;
-                if (tio.player() == 0) {
-                    for (address_t i=0;i<size;++i) {
-                        if (depth <= 10) {
-                            printf("%04x %016lx\n", i, check[i].share());
-                        }
-                        if (i>0 && i<len &&
-                            check[i].share() < check[i-1].share()) {
-                            fail = true;
-                        }
-                    }
-                    if (fail) {
-                        printf("FAIL\n");
-                    } else {
-                        printf("PASS\n");
-                    }
+                if (i>0 && i<len &&
+                    check[i].share() < check[i-1].share()) {
+                    fail = true;
                 }
-            });
-        });
-    }
-    pool.join();
+            }
+            if (fail) {
+                printf("FAIL\n");
+            } else {
+                printf("PASS\n");
+            }
+        }
+    });
 }
 
 static void pad_test(MPCIO &mpcio,
@@ -1156,7 +1149,7 @@ static void pad_test(MPCIO &mpcio,
             A[i] = v;
         }
         A.explicitonly(false);
-        // Add 0 to A[0], which reblinds the whole database
+        // Obliviously add 0 to A[0], which reblinds the whole database
         RegAS z;
         A[z] += z;
         auto check = A.reconstruct();
@@ -1179,8 +1172,12 @@ static void pad_test(MPCIO &mpcio,
         }
         printf("\n");
         for (address_t i=0; i<maxsize; ++i) {
+            value_t offset = 0xdeadbeef;
+            if (player) {
+                offset = -offset;
+            }
             RegAS ind;
-            ind.set(player*i);
+            ind.set(player*i+offset);
             RegAS v = P[ind];
             if (depth <= 10) {
                 value_t vval = mpc_reconstruct(tio, yield, v);
@@ -1214,74 +1211,85 @@ static void bsearch_test(MPCIO &mpcio,
         ++args;
     }
 
-    int num_threads = opts.num_threads;
-    boost::asio::thread_pool pool(num_threads);
-    for (int thread_num = 0; thread_num < num_threads; ++thread_num) {
-        boost::asio::post(pool, [&mpcio, thread_num, depth, len, target] {
-            MPCTIO tio(mpcio, thread_num);
-            run_coroutines(tio, [&tio, depth, len, target] (yield_t &yield) {
-                address_t size = address_t(1)<<depth;
-                RegAS tshare;
-                if (tio.player() == 2) {
-                    // Send shares of the target to the computational
-                    // players
-                    RegAS tshare0, tshare1;
-                    tshare0.randomize();
-                    tshare1.set(target-tshare0.share());
-                    tio.iostream_p0() << tshare0;
-                    tio.iostream_p1() << tshare1;
-                    printf("Using target = %016lx\n", target);
-                    yield();
-                } else {
-                    // Get the share of the target
-                    tio.iostream_server() >> tshare;
-                }
+    MPCTIO tio(mpcio, 0, opts.num_threads);
+    run_coroutines(tio, [&tio, depth, len, target] (yield_t &yield) {
+        RegAS tshare;
+        if (tio.player() == 2) {
+            // Send shares of the target to the computational
+            // players
+            RegAS tshare0, tshare1;
+            tshare0.randomize();
+            tshare1.set(target-tshare0.share());
+            tio.iostream_p0() << tshare0;
+            tio.iostream_p1() << tshare1;
+            printf("Using target = %016lx\n", target);
+            yield();
+        } else {
+            // Get the share of the target
+            tio.iostream_server() >> tshare;
+        }
+
+        // Create a random database and sort it
+        // size_t &aes_ops = tio.aes_ops();
+        Duoram<RegAS> oram(tio.player(), len);
+        auto A = oram.flat(tio, yield);
+        A.explicitonly(true);
+        // Initialize the memory to random values in parallel
+        std::vector<coro_t> coroutines;
+        for (address_t i=0; i<len; ++i) {
+            coroutines.emplace_back(
+                [&A, i](yield_t &yield) {
+                    auto Acoro = A.context(yield);
+                    RegAS v;
+                    v.randomize(62);
+                    Acoro[i] += v;
+                });
+        }
+        run_coroutines(yield, coroutines);
+        A.bitonic_sort(0, len);
+        A.explicitonly(false);
 
-                // Create a random database and sort it
-                // size_t &aes_ops = tio.aes_ops();
-                Duoram<RegAS> oram(tio.player(), size);
-                auto A = oram.flat(tio, yield);
-                A.explicitonly(true);
-                // Initialize the memory to random values in parallel
-                std::vector<coro_t> coroutines;
-                for (address_t i=0; i<size; ++i) {
-                    coroutines.emplace_back(
-                        [&A, i](yield_t &yield) {
-                            auto Acoro = A.context(yield);
-                            RegAS v;
-                            v.randomize(62);
-                            Acoro[i] += v;
-                        });
+        // Binary search for the target
+        RegAS tindex = A.obliv_binary_search(tshare);
+
+        // Check the answer
+        size_t size = size_t(1) << depth;
+        value_t checkindex = mpc_reconstruct(tio, yield, tindex);
+        value_t checktarget = mpc_reconstruct(tio, yield, tshare);
+        auto check = A.reconstruct();
+        bool fail = false;
+        if (tio.player() == 0) {
+            for (address_t i=0;i<len;++i) {
+                if (depth <= 10) {
+                    printf("%c%04x %016lx\n",
+                        (i == checkindex ? '*' : ' '),
+                        i, check[i].share());
                 }
-                run_coroutines(yield, coroutines);
-                A.bitonic_sort(0, len);
-
-                // Binary search for the target
-                RegAS tindex = A.obliv_binary_search(tshare);
-
-                // Check the answer
-                if (tio.player() == 1) {
-                    tio.iostream_peer() << tindex;
-                } else if (tio.player() == 0) {
-                    RegAS peer_tindex;
-                    tio.iostream_peer() >> peer_tindex;
-                    tindex += peer_tindex;
+                if (i>0 && i<len &&
+                    check[i].share() < check[i-1].share()) {
+                    fail = true;
                 }
-                if (depth <= 10) {
-                    auto check = A.reconstruct();
-                    if (tio.player() == 0) {
-                        for (address_t i=0;i<size;++i) {
-                            printf("%04x %016lx\n", i, check[i].share());
-                        }
+                if (i == checkindex) {
+                    // check[i] should be >= target, and check[i-1]
+                    // should be < target
+                    if ((i < len && check[i].share() < checktarget) ||
+                        (i > 0 && check[i-1].share() >= checktarget)) {
+                        fail = true;
                     }
                 }
-                if (tio.player() == 0) {
-                    printf("Found index = %lx\n", tindex.share());
-                }
-            });
-        });
-    }
-    pool.join();
+            }
+            printf("Target = %016lx\n", checktarget);
+            printf("Found index = %02lx\n", checkindex);
+            if (checkindex > size) {
+                fail = true;
+            }
+            if (fail) {
+                printf("FAIL\n");
+            } else {
+                printf("PASS\n");
+            }
+        }
+    });
 }
 
 void online_main(MPCIO &mpcio, const PRACOptions &opts, char **args)

+ 5 - 2
shapes.hpp

@@ -12,6 +12,8 @@
 
 template <typename T>
 class Duoram<T>::Pad : public Duoram<T>::Shape {
+    // These are pointers because we need to be able to return a
+    // (non-const) T& even from a const Pad.
     T *padvalp;
     T *peerpadvalp;
     T *zerop;
@@ -24,7 +26,8 @@ class Duoram<T>::Pad : public Duoram<T>::Shape {
     Pad &operator=(const Pad &) = delete;
 
 public:
-    // Constructor
+    // Constructor for the Pad shape. The parent must _not_ be in
+    // explicit-only mode.
     Pad(Shape &parent, MPCTIO &tio, yield_t &yield,
         address_t padded_size, value_t padval = 0x7fffffffffffffff);
 
@@ -65,7 +68,7 @@ public:
     inline std::tuple<T&,T&,T&> get_comp(size_t idx,
         std::nullopt_t null = std::nullopt) const override {
         if (idx < this->parent.shape_size) {
-        size_t physaddr = indexmap(idx);
+            size_t physaddr = indexmap(idx);
             return std::tie(
                 this->duoram.database[physaddr],
                 this->duoram.blind[physaddr],

+ 2 - 1
shapes.tcc

@@ -1,7 +1,8 @@
 #ifndef __SHAPES_TCC__
 #define __SHAPES_TCC__
 
-// Constructor for the Pad shape.
+// Constructor for the Pad shape. The parent must _not_ be in
+// explicit-only mode.
 template <typename T>
 Duoram<T>::Pad::Pad(Shape &parent, MPCTIO &tio, yield_t &yield,
     address_t padded_size, size_t padval) :