Prechádzať zdrojové kódy

Improvements to the binary search unit test

- Don't reconstruct the result in the BINARY SEARCH stanza; wait for the
  CHECK ANSWER stanza
- If using a pre-sorted database, initialize it locally with no
  communication
Ian Goldberg 6 mesiacov pred
rodič
commit
5f7df96f86
1 zmenil súbory, kde vykonal 23 pridanie a 21 odobranie
  1. 23 21
      online.cpp

+ 23 - 21
online.cpp

@@ -1245,9 +1245,10 @@ static void pad_test(MPCIO &mpcio,
     });
 }
 
-
+// T is RegAS for basic bsearch, or RegXS for optimized bsearch
+template<typename T,bool basic>
 static void bsearch_test(MPCIO &mpcio,
-    const PRACOptions &opts, char **args, bool basic)
+    const PRACOptions &opts, char **args)
 {
     value_t target;
     arc4random_buf(&target, sizeof(target));
@@ -1281,7 +1282,7 @@ static void bsearch_test(MPCIO &mpcio,
     }
 
     MPCTIO tio(mpcio, 0, opts.num_threads);
-    run_coroutines(tio, [&tio, &mpcio, depth, len, iters, target, basic, is_presorted] (yield_t &yield) {
+    run_coroutines(tio, [&tio, &mpcio, depth, len, iters, target, is_presorted] (yield_t &yield) {
         RegAS tshare;
         std::cout << "\n===== SETUP =====\n";
 
@@ -1297,6 +1298,7 @@ static void bsearch_test(MPCIO &mpcio,
             yield();
         } else {
             // Get the share of the target
+            yield();
             tio.iostream_server() >> tshare;
         }
 
@@ -1312,20 +1314,21 @@ static void bsearch_test(MPCIO &mpcio,
         // random values and explicitly sort it.
         Duoram<RegAS> oram(tio.player(), len);
         auto A = oram.flat(tio, yield);
-        A.explicitonly(true);
+
         // Initialize the memory to sorted or random values, depending
         // on the is_presorted flag
-        for (address_t i=0; i<len; ++i) {
-            RegAS v;
-            if (!is_presorted) {
+        if (is_presorted) {
+            A.init([](size_t i) {
+                return value_t(i) << 16;
+            });
+        } else {
+            A.explicitonly(true);
+            for (address_t i=0; i<len; ++i) {
+                RegAS v;
                 v.randomize(62);
-            } else {
-                v.ashare = (tio.player() * value_t(i)) << 16;
+                A[i] = v;
             }
-            A[i] = v;
-        }
-        A.explicitonly(false);
-        if (!is_presorted) {
+            A.explicitonly(false);
             A.bitonic_sort(0, len);
         }
 
@@ -1336,14 +1339,12 @@ static void bsearch_test(MPCIO &mpcio,
         mpcio.reset_stats();
         tio.reset_lamport();
         // Binary search for the target
-        value_t checkindex = 0;
+        T tindex;
         for (int i=0; i<iters; ++i) {
-            if (basic) {
-                RegAS tindex = A.basic_binary_search(tshare);
-                checkindex = mpc_reconstruct(tio, yield, tindex);
+            if constexpr (basic) {
+                tindex = A.basic_binary_search(tshare);
             } else {
-                RegXS tindex = A.binary_search(tshare);
-                checkindex = mpc_reconstruct(tio, yield, tindex);
+                tindex = A.binary_search(tshare);
             }
         }
 
@@ -1355,6 +1356,7 @@ static void bsearch_test(MPCIO &mpcio,
         tio.reset_lamport();
         // Check the answer
         size_t size = size_t(1) << depth;
+        value_t checkindex = mpc_reconstruct(tio, yield, tindex);
         value_t checktarget = mpc_reconstruct(tio, yield, tshare);
         auto check = A.reconstruct();
         bool fail = false;
@@ -1653,10 +1655,10 @@ void online_main(MPCIO &mpcio, const PRACOptions &opts, char **args)
         pad_test(mpcio, opts, args);
     } else if (!strcmp(*args, "bbsearch")) {
         ++args;
-        bsearch_test(mpcio, opts, args, true);
+        bsearch_test<RegAS,true>(mpcio, opts, args);
     } else if (!strcmp(*args, "bsearch")) {
         ++args;
-        bsearch_test(mpcio, opts, args, false);
+        bsearch_test<RegXS,false>(mpcio, opts, args);
     } else if (!strcmp(*args, "duoram")) {
         ++args;
         if (opts.use_xor_db) {