Pārlūkot izejas kodu

Improved binary search

But the ORAM operations aren't yet reusing DPFs
Ian Goldberg 1 gadu atpakaļ
vecāks
revīzija
c6841ae846
3 mainītis faili ar 96 papildinājumiem un 7 dzēšanām
  1. 64 1
      duoram.cpp
  2. 18 1
      duoram.hpp
  3. 14 5
      online.cpp

+ 64 - 1
duoram.cpp

@@ -6,8 +6,11 @@
 // given one.  (The answer will be the length of the Flat if all
 // elements are smaller than the target.) Only available for additive
 // shared databases for now.
+
+// The basic version uses log(N) ORAM reads of size N, where N is the
+// smallest power of 2 strictly larger than the Flat size
 template <>
-RegAS Duoram<RegAS>::Flat::obliv_binary_search(RegAS &target)
+RegAS Duoram<RegAS>::Flat::basic_binary_search(RegAS &target)
 {
     if (this->shape_size == 0) {
         RegAS zero;
@@ -69,3 +72,63 @@ RegAS Duoram<RegAS>::Flat::obliv_binary_search(RegAS &target)
     return index;
 }
 
+// This version does 1 ORAM read of size 2, 1 of size 4, 1 of size
+// 8, ..., 1 of size N/2, where N is the smallest power of 2 strictly
+// larger than the Flat size
+template <>
+RegXS Duoram<RegAS>::Flat::binary_search(RegAS &target)
+{
+    if (this->shape_size == 0) {
+        RegXS zero;
+        return zero;
+    }
+    // Create a Pad of the smallest power of 2 size strictly greater
+    // than the Flat size
+    address_t padsize = 1;
+    nbits_t depth = 0;
+    while (padsize <= this->shape_size) {
+        padsize *= 2;
+        ++depth;
+    }
+    Duoram<RegAS>::Pad P(*this, tio, yield, padsize);
+    // Explicitly read the middle item
+    address_t mid = (1<<(depth-1))-1;
+    RegAS val = P[mid];
+    // Compare it to the target
+    CDPF cdpf = tio.cdpf(this->yield);
+    auto [lt, eq, gt] = cdpf.compare(this->tio, this->yield,
+        val-target, tio.aes_ops());
+    if (depth == 1) {
+        // There was only one item in the Flat, and mid will equal 0, so
+        // val is (a share of) that item, P[0].  If val >= target, the
+        // answer is here or to the left, so it must be 0.  If val <
+        // target, the answer is strictly to the right, so it must be 1.
+        // So just return lt.
+        return RegXS(lt);
+    }
+    auto oidx = P.oblivindex(depth-1);
+    oidx.incr(lt);
+    --depth;
+    while(depth > 0) {
+        // Create the Stride shape; the ORAM will operate only over
+        // elements of the Stride, which will consist of exactly those
+        // elements of the Pad we could possibly be accessing at this
+        // depth.  Those will be elements start=(1<<(depth-1)-1,
+        // start+(1<<depth), start+(2<<depth), start+(3<<depth), and so
+        // on.  The invariant is that the range of remaining possible
+        // answers is of width (1<<depth), and we will look at the
+        // rightmost element of the left half.  If that value (val) has
+        // val >= target, then the answer is at that position or to the
+        // left, so we append a 0 to the index.  If val < targer, then
+        // the answer is strictly to the right, so we append a 1 to the
+        // index.  That is, always append lt to the index.
+        Duoram<RegAS>::Stride S(P, tio, yield, (1<<(depth-1))-1, 1<<depth);
+        RegAS val = S[oidx];
+        CDPF cdpf = tio.cdpf(this->yield);
+        auto [lt, eq, gt] = cdpf.compare(this->tio, this->yield,
+            val-target, tio.aes_ops());
+        oidx.incr(lt);
+        --depth;
+    }
+    return oidx.index();
+}

+ 18 - 1
duoram.hpp

@@ -407,7 +407,14 @@ public:
     // given one.  (The answer will be the length of the Flat if all
     // elements are smaller than the target.) Only available for additive
     // shared databases for now.
-    RegAS obliv_binary_search(RegAS &target);
+
+    // The basic version uses log(N) ORAM reads of size N, where N is
+    // the smallest power of 2 strictly larger than the Flat size
+    RegAS basic_binary_search(RegAS &target);
+    // This version does 1 ORAM read of size 2, 1 of size 4, 1 of size
+    // 8, ..., 1 of size N/2, where N is the smallest power of 2
+    // strictly larger than the Flat size
+    RegXS binary_search(RegAS &target);
 };
 
 // Oblivious indices for use in related-index ORAM accesses.
@@ -451,6 +458,16 @@ public:
         }
     }
 
+    // Incrementally append a (shared) bit to the oblivious index
+    void incr(RegBS bit)
+    {
+        assert(incremental);
+        idx.xshare = (idx.xshare << 1) | value_t(bit.bshare);
+        ++curdepth;
+    }
+
+    // Get a copy of the index
+    U index() { return idx; }
 };
 
 // An additive or XOR shared memory reference.  You get one of these

+ 14 - 5
online.cpp

@@ -1206,7 +1206,7 @@ static void pad_test(MPCIO &mpcio,
 
 
 static void bsearch_test(MPCIO &mpcio,
-    const PRACOptions &opts, char **args)
+    const PRACOptions &opts, char **args, bool basic)
 {
     value_t target;
     arc4random_buf(&target, sizeof(target));
@@ -1228,7 +1228,7 @@ static void bsearch_test(MPCIO &mpcio,
     }
 
     MPCTIO tio(mpcio, 0, opts.num_threads);
-    run_coroutines(tio, [&tio, depth, len, target] (yield_t &yield) {
+    run_coroutines(tio, [&tio, depth, len, target, basic] (yield_t &yield) {
         RegAS tshare;
         if (tio.player() == 2) {
             // Send shares of the target to the computational
@@ -1266,11 +1266,17 @@ static void bsearch_test(MPCIO &mpcio,
         A.explicitonly(false);
 
         // Binary search for the target
-        RegAS tindex = A.obliv_binary_search(tshare);
+        value_t checkindex;
+        if (basic) {
+            RegAS tindex = A.basic_binary_search(tshare);
+            checkindex = mpc_reconstruct(tio, yield, tindex);
+        } else {
+            RegXS tindex = A.binary_search(tshare);
+            checkindex = mpc_reconstruct(tio, yield, tindex);
+        }
 
         // Check the answer
         size_t size = size_t(1) << depth;
-        value_t checkindex = mpc_reconstruct(tio, yield, tindex);
         value_t checktarget = mpc_reconstruct(tio, yield, tshare);
         auto check = A.reconstruct();
         bool fail = false;
@@ -1384,9 +1390,12 @@ void online_main(MPCIO &mpcio, const PRACOptions &opts, char **args)
     } else if (!strcmp(*args, "padtest")) {
         ++args;
         pad_test(mpcio, opts, args);
+    } else if (!strcmp(*args, "bbsearch")) {
+        ++args;
+        bsearch_test(mpcio, opts, args, true);
     } else if (!strcmp(*args, "bsearch")) {
         ++args;
-        bsearch_test(mpcio, opts, args);
+        bsearch_test(mpcio, opts, args, false);
     } else if (!strcmp(*args, "duoram")) {
         ++args;
         if (opts.use_xor_db) {