瀏覽代碼

Binary search

Ian Goldberg 2 年之前
父節點
當前提交
bcb7ae1263
共有 3 個文件被更改,包括 148 次插入2 次删除
  1. 7 0
      duoram.hpp
  2. 53 1
      duoram.tcc
  3. 88 1
      online.cpp

+ 7 - 0
duoram.hpp

@@ -235,6 +235,13 @@ public:
     // the elements must be at most 63 bits long each for the notion of
     // the elements must be at most 63 bits long each for the notion of
     // ">" to make consistent sense.
     // ">" to make consistent sense.
     void bitonic_sort(address_t start, nbits_t depth, bool dir=0);
     void bitonic_sort(address_t start, nbits_t depth, bool dir=0);
+
+    // Assuming the memory is already sorted, do an oblivious binary
+    // search for the largest index containing the value at most the
+    // given one.  (The answer will be 0 if all of the memory elements
+    // are greate than the target.) This Flat must be a power of 2 size.
+    // Only available for additive shared databases for now.
+    RegAS obliv_binary_search(RegAS &target);
 };
 };
 
 
 // The parent class of shared memory references
 // The parent class of shared memory references

+ 53 - 1
duoram.tcc

@@ -177,6 +177,59 @@ void Duoram<T>::Flat::butterfly(address_t start, nbits_t depth, bool dir)
     run_coroutines(this->yield, coroutines);
     run_coroutines(this->yield, coroutines);
 }
 }
 
 
+// Assuming the memory is already sorted, do an oblivious binary
+// search for the largest index containing the value at most the
+// given one.  (The answer will be 0 if all of the memory elements
+// are greate than the target.) This Flat must be a power of 2 size.
+// Only available for additive shared databases for now.
+template <>
+RegAS Duoram<RegAS>::Flat::obliv_binary_search(RegAS &target)
+{
+    nbits_t depth = this->addr_size;
+    // Start in the middle
+    RegAS index;
+    index.set(this->tio.player() ? 0 : 1<<(depth-1));
+    // Invariant: index points to the first element of the right half of
+    // the remaining possible range
+    while (depth > 0) {
+        // Obliviously read the value there
+        RegAS val = operator[](index);
+        // Compare it to the target
+        CDPF cdpf = tio.cdpf();
+        auto [lt, eq, gt] = cdpf.compare(this->tio, this->yield,
+            val-target, tio.aes_ops());
+        if (depth > 1) {
+            // If val > target, the answer is strictly to the left
+            // and we should subtract 2^{depth-2} from index
+            // If val <= target, the answer is here or to the right
+            // and we should add 2^{depth-2} to index
+            // So we unconditionally subtract 2^{depth-2} from index, and
+            // add (lt+eq)*2^{depth-1}.
+            RegAS uncond;
+            uncond.set(tio.player() ? 0 : address_t(1)<<(depth-2));
+            RegAS cond;
+            cond.set(tio.player() ? 0 : address_t(1)<<(depth-1));
+            RegAS condprod;
+            RegBS le = lt ^ eq;
+            mpc_flagmult(this->tio, this->yield, condprod, le, cond);
+            index -= uncond;
+            index += condprod;
+        } else {
+            // If val > target, the answer is strictly to the left
+            // If val <= target, the answer is here or to the right
+            // so subtract gt from index
+            RegAS cond;
+            cond.set(tio.player() ? 0 : 1);
+            RegAS condprod;
+            mpc_flagmult(this->tio, this->yield, condprod, gt, cond);
+            index -= condprod;
+        }
+        --depth;
+    }
+
+    return index;
+}
+
 // Oblivious read from an additively shared index of Duoram memory
 // Oblivious read from an additively shared index of Duoram memory
 template <typename T>
 template <typename T>
 Duoram<T>::Shape::MemRefAS::operator T()
 Duoram<T>::Shape::MemRefAS::operator T()
@@ -367,7 +420,6 @@ typename Duoram<T>::Shape::MemRefAS
 template <> template <typename U,typename V>
 template <> template <typename U,typename V>
 void Duoram<RegAS>::Flat::osort(const U &idx1, const V &idx2, bool dir)
 void Duoram<RegAS>::Flat::osort(const U &idx1, const V &idx2, bool dir)
 {
 {
-    printf("osort %u %u %d\n", idx1, idx2, dir);
     // Load the values in parallel
     // Load the values in parallel
     std::vector<coro_t> coroutines;
     std::vector<coro_t> coroutines;
     RegAS val1, val2;
     RegAS val1, val2;

+ 88 - 1
online.cpp

@@ -704,7 +704,6 @@ static void sort_test(MPCIO &mpcio, yield_t &yield,
                     });
                     });
             }
             }
             run_coroutines(yield, coroutines);
             run_coroutines(yield, coroutines);
-            //A.osort(0,1);
             A.bitonic_sort(0, depth);
             A.bitonic_sort(0, depth);
             if (depth <= 10) {
             if (depth <= 10) {
                 oram.dump();
                 oram.dump();
@@ -721,6 +720,91 @@ static void sort_test(MPCIO &mpcio, yield_t &yield,
     pool.join();
     pool.join();
 }
 }
 
 
+static void bsearch_test(MPCIO &mpcio, yield_t &yield,
+    const PRACOptions &opts, char **args)
+{
+    value_t target;
+    arc4random_buf(&target, sizeof(target));
+    target >>= 1;
+    nbits_t depth=6;
+
+    if (*args) {
+        depth = atoi(*args);
+        ++args;
+    }
+    if (*args) {
+        target = strtoull(*args, NULL, 16);
+        ++args;
+    }
+
+    int num_threads = opts.num_threads;
+    boost::asio::thread_pool pool(num_threads);
+    for (int thread_num = 0; thread_num < num_threads; ++thread_num) {
+        boost::asio::post(pool, [&mpcio, &yield, thread_num, depth, target] {
+            address_t size = address_t(1)<<depth;
+            MPCTIO tio(mpcio, thread_num);
+            RegAS tshare;
+            if (tio.player() == 2) {
+                // Send shares of the target to the computational
+                // players
+                RegAS tshare0, tshare1;
+                tshare0.randomize();
+                tshare1.set(target-tshare0.share());
+                tio.iostream_p0() << tshare0;
+                tio.iostream_p1() << tshare1;
+                printf("Using target = %016lx\n", target);
+                yield();
+            } else {
+                // Get the share of the target
+                tio.iostream_server() >> tshare;
+            }
+
+            // Create a random database and sort it
+            // size_t &aes_ops = tio.aes_ops();
+            Duoram<RegAS> oram(mpcio.player, size);
+            auto A = oram.flat(tio, yield);
+            // Initialize the memory to random values in parallel
+            std::vector<coro_t> coroutines;
+            for (address_t i=0; i<size; ++i) {
+                coroutines.emplace_back(
+                    [&A, i](yield_t &yield) {
+                        auto Acoro = A.context(yield);
+                        RegAS v;
+                        v.randomize(62);
+                        Acoro[i] += v;
+                    });
+            }
+            run_coroutines(yield, coroutines);
+            A.bitonic_sort(0, depth);
+
+            // Binary search for the target
+            RegAS tindex = A.obliv_binary_search(tshare);
+
+            // Check the answer
+            if (tio.player() == 1) {
+                tio.iostream_peer() << tindex;
+            } else if (tio.player() == 0) {
+                RegAS peer_tindex;
+                tio.iostream_peer() >> peer_tindex;
+                tindex += peer_tindex;
+            }
+            if (depth <= 10) {
+                auto check = A.reconstruct();
+                if (tio.player() == 0) {
+                    for (address_t i=0;i<size;++i) {
+                        printf("%04x %016lx\n", i, check[i].share());
+                    }
+                }
+            }
+            if (tio.player() == 0) {
+                printf("Found index = %lu\n", tindex.share());
+            }
+            tio.send();
+        });
+    }
+    pool.join();
+}
+
 void online_main(MPCIO &mpcio, const PRACOptions &opts, char **args)
 void online_main(MPCIO &mpcio, const PRACOptions &opts, char **args)
 {
 {
     // Run everything inside a coroutine so that simple tests don't have
     // Run everything inside a coroutine so that simple tests don't have
@@ -766,6 +850,9 @@ void online_main(MPCIO &mpcio, const PRACOptions &opts, char **args)
             } else if (!strcmp(*args, "sorttest")) {
             } else if (!strcmp(*args, "sorttest")) {
                 ++args;
                 ++args;
                 sort_test(mpcio, yield, opts, args);
                 sort_test(mpcio, yield, opts, args);
+            } else if (!strcmp(*args, "bsearch")) {
+                ++args;
+                bsearch_test(mpcio, yield, opts, args);
             } else {
             } else {
                 std::cerr << "Unknown mode " << *args << "\n";
                 std::cerr << "Unknown mode " << *args << "\n";
             }
             }