
Simplify the run_coroutines API when there are a small constant number of coroutines to run

Ian Goldberg · 1 year ago · commit 704ade665a
4 changed files with 91 additions and 50 deletions:

  1. coroutine.hpp (+52 −2)
  2. duoram.tcc (+36 −40)
  3. online.cpp (+1 −3)
  4. rdpf.cpp (+2 −5)

coroutine.hpp (+52 −2)

@@ -2,12 +2,14 @@
 #define __COROUTINE_HPP__
 
 #include <vector>
+#include <functional>
 #include <boost/coroutine2/coroutine.hpp>
 
 #include "mpcio.hpp"
 
-typedef boost::coroutines2::coroutine<void>::pull_type  coro_t;
-typedef boost::coroutines2::coroutine<void>::push_type  yield_t;
+using coro_t = boost::coroutines2::coroutine<void>::pull_type;
+using yield_t = boost::coroutines2::coroutine<void>::push_type;
+using coro_lambda_t = std::function<void(yield_t&)>;
 
 // The top-level coroutine runner will call run_coroutines with
 // a MPCTIO, and we should call its send() method.  Subcoroutines that
@@ -17,6 +19,8 @@ typedef boost::coroutines2::coroutine<void>::push_type  yield_t;
 static inline void send_or_yield(MPCTIO &tio) { tio.send(); }
 static inline void send_or_yield(yield_t &yield) { yield(); }
 
+// Use this version if you have a variable number of coroutines (or a
+// larger constant number than is supported below).
 template <typename T>
 inline void run_coroutines(T &mpctio_or_yield, std::vector<coro_t> &coroutines) {
     // Loop until all the coroutines are finished
@@ -40,4 +44,50 @@ inline void run_coroutines(T &mpctio_or_yield, std::vector<coro_t> &coroutines)
     }
 }
 
+// Use one of these versions if you have a small fixed number of
+// coroutines.  You can of course also use the above, but the API for
+// this version is simpler.
+
+template <typename T>
+inline void run_coroutines(T &mpctio_or_yield, const coro_lambda_t &l1)
+{
+    std::vector<coro_t> coroutines;
+    coroutines.emplace_back(l1);
+    run_coroutines(mpctio_or_yield, coroutines);
+}
+
+template <typename T>
+inline void run_coroutines(T &mpctio_or_yield, const coro_lambda_t &l1,
+    const coro_lambda_t &l2)
+{
+    std::vector<coro_t> coroutines;
+    coroutines.emplace_back(l1);
+    coroutines.emplace_back(l2);
+    run_coroutines(mpctio_or_yield, coroutines);
+}
+
+template <typename T>
+inline void run_coroutines(T &mpctio_or_yield, const coro_lambda_t &l1,
+    const coro_lambda_t &l2, const coro_lambda_t &l3)
+{
+    std::vector<coro_t> coroutines;
+    coroutines.emplace_back(l1);
+    coroutines.emplace_back(l2);
+    coroutines.emplace_back(l3);
+    run_coroutines(mpctio_or_yield, coroutines);
+}
+
+template <typename T>
+inline void run_coroutines(T &mpctio_or_yield, const coro_lambda_t &l1,
+    const coro_lambda_t &l2, const coro_lambda_t &l3,
+    const coro_lambda_t &l4)
+{
+    std::vector<coro_t> coroutines;
+    coroutines.emplace_back(l1);
+    coroutines.emplace_back(l2);
+    coroutines.emplace_back(l3);
+    coroutines.emplace_back(l4);
+    run_coroutines(mpctio_or_yield, coroutines);
+}
+
 #endif
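
To see how the new overloads behave at a call site, here is a minimal, self-contained sketch (not part of the commit). The run_all helper stands in for the vector-based run_coroutines shown above; the real runner also calls send_or_yield between passes to flush queued network messages, which is omitted here along with all MPCTIO plumbing. Only Boost.Coroutine2 is needed, so the snippet compiles and runs on its own:

    #include <cstdio>
    #include <functional>
    #include <vector>
    #include <boost/coroutine2/coroutine.hpp>

    using coro_t = boost::coroutines2::coroutine<void>::pull_type;
    using yield_t = boost::coroutines2::coroutine<void>::push_type;
    using coro_lambda_t = std::function<void(yield_t&)>;

    // Stand-in for the vector-based runner above: resume every
    // unfinished coroutine once per pass until all have returned.
    static void run_all(std::vector<coro_t> &coroutines) {
        bool finished = false;
        while (!finished) {
            finished = true;
            for (auto &c : coroutines) {
                if (c) { finished = false; c(); }
            }
        }
    }

    // The new two-lambda overload, minus the MPCTIO/yield plumbing
    static void run_coroutines(const coro_lambda_t &l1,
        const coro_lambda_t &l2)
    {
        std::vector<coro_t> coroutines;
        coroutines.emplace_back(l1);
        coroutines.emplace_back(l2);
        run_all(coroutines);
    }

    int main() {
        // Before this commit, a call site had to declare the vector
        // and emplace each lambda by hand; now the two parallel tasks
        // are simply two arguments.
        run_coroutines(
            [](yield_t &yield) {
                printf("task 1: first half\n");
                yield();   // pause so the other task can proceed
                printf("task 1: second half\n");
            },
            [](yield_t &yield) {
                printf("task 2: first half\n");
                yield();
                printf("task 2: second half\n");
            });
        return 0;
    }

Note that emplacing a pull_type runs the lambda immediately up to its first yield; the runner then interleaves the remainders, which is what lets independent tasks share network rounds in the real code.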

duoram.tcc (+36 −40)

@@ -131,16 +131,15 @@ void Duoram<T>::Flat::bitonic_sort(address_t start, nbits_t depth, bool dir)
     }
     // Recurse on the first half (increasing order) and the second half
     // (decreasing order) in parallel
-    std::vector<coro_t> coroutines;
-    coroutines.emplace_back([&](yield_t &yield) {
-        Flat Acoro = context(yield);
-        Acoro.bitonic_sort(start, depth-1, 0);
-    });
-    coroutines.emplace_back([&](yield_t &yield) {
-        Flat Acoro = context(yield);
-        Acoro.bitonic_sort(start+(1<<(depth-1)), depth-1, 1);
-    });
-    run_coroutines(this->yield, coroutines);
+    run_coroutines(this->yield,
+        [&](yield_t &yield) {
+            Flat Acoro = context(yield);
+            Acoro.bitonic_sort(start, depth-1, 0);
+        },
+        [&](yield_t &yield) {
+            Flat Acoro = context(yield);
+            Acoro.bitonic_sort(start+(1<<(depth-1)), depth-1, 1);
+        });
     // Merge the two into the desired order
     butterfly(start, depth, dir);
 }
@@ -165,16 +164,15 @@ void Duoram<T>::Flat::butterfly(address_t start, nbits_t depth, bool dir)
     }
     run_coroutines(this->yield, coroutines);
     // Recurse on each half in parallel
-    coroutines.clear();
-    coroutines.emplace_back([&](yield_t &yield) {
-        Flat Acoro = context(yield);
-        Acoro.butterfly(start, depth-1, dir);
-    });
-    coroutines.emplace_back([&](yield_t &yield) {
-        Flat Acoro = context(yield);
-        Acoro.butterfly(start+halfwidth, depth-1, dir);
-    });
-    run_coroutines(this->yield, coroutines);
+    run_coroutines(this->yield,
+        [&](yield_t &yield) {
+            Flat Acoro = context(yield);
+            Acoro.butterfly(start, depth-1, dir);
+        },
+        [&](yield_t &yield) {
+            Flat Acoro = context(yield);
+            Acoro.butterfly(start+halfwidth, depth-1, dir);
+        });
 }
 
 // Assuming the memory is already sorted, do an oblivious binary
@@ -421,17 +419,16 @@ template <> template <typename U,typename V>
 void Duoram<RegAS>::Flat::osort(const U &idx1, const V &idx2, bool dir)
 {
     // Load the values in parallel
-    std::vector<coro_t> coroutines;
     RegAS val1, val2;
-    coroutines.emplace_back([&](yield_t &yield) {
-        Flat Acoro = context(yield);
-        val1 = Acoro[idx1];
-    });
-    coroutines.emplace_back([&](yield_t &yield) {
-        Flat Acoro = context(yield);
-        val2 = Acoro[idx2];
-    });
-    run_coroutines(yield, coroutines);
+    run_coroutines(yield,
+        [&](yield_t &yield) {
+            Flat Acoro = context(yield);
+            val1 = Acoro[idx1];
+        },
+        [&](yield_t &yield) {
+            Flat Acoro = context(yield);
+            val2 = Acoro[idx2];
+        });
     // Get a CDPF
     CDPF cdpf = tio.cdpf();
     // Use it to compare the values
@@ -442,16 +439,15 @@ void Duoram<RegAS>::Flat::osort(const U &idx1, const V &idx2, bool dir)
     RegAS cmp_diff;
     mpc_flagmult(tio, yield, cmp_diff, cmp, diff);
     // Update the two locations in parallel
-    coroutines.clear();
-    coroutines.emplace_back([&](yield_t &yield) {
-        Flat Acoro = context(yield);
-        Acoro[idx1] -= cmp_diff;
-    });
-    coroutines.emplace_back([&](yield_t &yield) {
-        Flat Acoro = context(yield);
-        Acoro[idx2] += cmp_diff;
-    });
-    run_coroutines(yield, coroutines);
+    run_coroutines(yield,
+        [&](yield_t &yield) {
+            Flat Acoro = context(yield);
+            Acoro[idx1] -= cmp_diff;
+        },
+        [&](yield_t &yield) {
+            Flat Acoro = context(yield);
+            Acoro[idx2] += cmp_diff;
+        });
 }
 
 // The MemRefXS routines are almost identical to the MemRefAS routines,
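
As a quick sanity check on the osort update above (hedged, since diff and cmp are computed in context this diff does not show): if diff holds val1 − val2 and mpc_flagmult sets cmp_diff to cmp · diff, then cmp_diff is 0 when cmp = 0 and val1 − val2 when cmp = 1. The final pair of updates, Acoro[idx1] -= cmp_diff and Acoro[idx2] += cmp_diff, therefore either leaves both values in place or swaps them (val1 − (val1 − val2) = val2 and val2 + (val1 − val2) = val1), with neither party learning which case occurred.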

online.cpp (+1 −3)

@@ -810,8 +810,7 @@ void online_main(MPCIO &mpcio, const PRACOptions &opts, char **args)
     // Run everything inside a coroutine so that simple tests don't have
     // to start one themselves
     MPCTIO tio(mpcio, 0);
-    std::vector<coro_t> coroutines;
-    coroutines.emplace_back(
+    run_coroutines(tio,
         [&](yield_t &yield) {
             if (!*args) {
                 std::cerr << "Mode is required as the first argument when not preprocessing.\n";
@@ -857,5 +856,4 @@ void online_main(MPCIO &mpcio, const PRACOptions &opts, char **args)
                 std::cerr << "Unknown mode " << *args << "\n";
             }
         });
-    run_coroutines(tio, coroutines);
 }

rdpf.cpp (+2 −5)

@@ -132,21 +132,18 @@ RDPF::RDPF(MPCTIO &tio, yield_t &yield,
         bool peer_parity_bit;
         // Exchange the parities and do mpc_reconstruct_choice at the
         // same time (bundled into the same rounds)
-        std::vector<coro_t> coroutines;
-        coroutines.emplace_back(
+        run_coroutines(yield,
             [&](yield_t &yield) {
                 tio.queue_peer(&our_parity_bit, 1);
                 yield();
                 uint8_t peer_parity_byte;
                 tio.recv_peer(&peer_parity_byte, 1);
                 peer_parity_bit = peer_parity_byte & 1;
-            });
-        coroutines.emplace_back(
+            },
             [&](yield_t &yield) {
                 mpc_reconstruct_choice(tio, yield, CW, bs_choice,
                     (R ^ our_parity), L);
             });
-        run_coroutines(yield, coroutines);
         bool parity_bit = our_parity_bit ^ peer_parity_bit;
         cfbits |= (value_t(parity_bit)<<level);
         DPFnode CWR = CW ^ lsb128_mask[parity_bit];
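
The two lambdas above show the round-bundling idea from the comment: each coroutine queues its outgoing data, yields, and the runner flushes everything in one combined network round before resuming, so the parity-bit exchange travels with the messages of mpc_reconstruct_choice. A toy, self-contained illustration of that scheduling (the flush line is a stand-in for tio.send(); none of the real protocol logic is reproduced):

    #include <cstdio>
    #include <vector>
    #include <boost/coroutine2/coroutine.hpp>

    using coro_t = boost::coroutines2::coroutine<void>::pull_type;
    using yield_t = boost::coroutines2::coroutine<void>::push_type;

    int main() {
        std::vector<coro_t> coroutines;
        // Each task queues an outgoing message, yields until the round
        // completes, then consumes the reply.
        coroutines.emplace_back([](yield_t &yield) {
            printf("A: queue our parity bit\n");
            yield();
            printf("A: read peer's parity bit\n");
        });
        coroutines.emplace_back([](yield_t &yield) {
            printf("B: queue reconstruction message\n");
            yield();
            printf("B: finish reconstruction\n");
        });
        // Emplacing a pull_type runs it to its first yield, so both
        // messages are already queued when the loop starts; each pass
        // flushes them together, then resumes the waiting coroutines.
        auto any_alive = [&] {
            for (auto &c : coroutines) { if (c) return true; }
            return false;
        };
        while (any_alive()) {
            printf("--- flush network: one combined round ---\n");
            for (auto &c : coroutines) {
                if (c) c();
            }
        }
        return 0;
    }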