Simplify ParallelEval API

Ian Goldberg, 1 year ago
parent
commit 33a85d89ca
4 files changed with 39 additions and 74 deletions
  1. duoram.tcc (+8, -23)
  2. online.cpp (+8, -22)
  3. rdpf.hpp (+10, -12)
  4. rdpf.tcc (+13, -17)

duoram.tcc (+8, -23)

@@ -330,17 +330,14 @@ Duoram<T>::Shape::MemRefS<U>::operator T()
             shape.shape_size, shape.tio.cpu_nthreads(),
             shape.tio.aes_ops());
         T init;
-        res = pe.reduce(init, [&dp, &shape] (const ParallelEval<RDPFPair> &pe,
-            int thread_num, address_t i, const RDPFPair::node &leaf) {
+        res = pe.reduce(init, [&dp, &shape] (int thread_num, address_t i,
+                const RDPFPair::node &leaf) {
             // The values from the two DPFs
             auto [V0, V1] = dp.unit<T>(leaf);
             // References to the appropriate cells in our database, our
             // blind, and our copy of the peer's blinded database
             auto [DB, BL, PBD] = shape.get_comp(i);
             return (DB + PBD) * V0.share() - BL * (V1-V0).share();
-        },
-        [] (const ParallelEval<RDPFPair> &pe, T &accum, const T &value) {
-            accum += value;
         });
 
         shape.yield();
@@ -368,8 +365,8 @@ Duoram<T>::Shape::MemRefS<U>::operator T()
         ParallelEval pe(dp, IfRegAS<U>(indshift), IfRegXS<U>(indshift),
             shape.shape_size, shape.tio.cpu_nthreads(),
             shape.tio.aes_ops());
-        gamma = pe.reduce(init, [&dp, &shape] (const ParallelEval<RDPFPair> &pe,
-            int thread_num, address_t i, const RDPFPair::node &leaf) {
+        gamma = pe.reduce(init, [&dp, &shape] (int thread_num, address_t i,
+                const RDPFPair::node &leaf) {
             // The values from the two DPFs
             auto [V0, V1] = dp.unit<T>(leaf);
 
@@ -377,10 +374,6 @@ Duoram<T>::Shape::MemRefS<U>::operator T()
             // appropriate cells in the two blinded databases
             auto [BL0, BL1] = shape.get_server(i);
             return std::make_tuple(-BL0 * V1.share(), -BL1 * V0.share());
-        },
-        [] (const ParallelEval<RDPFPair> &pe, std::tuple<T,T> &accum,
-            const std::tuple<T,T> &value) {
-            accum += value;
         });
 
         // Choose a random blinding factor
@@ -443,9 +436,8 @@ typename Duoram<T>::Shape::MemRefS<U>
             shape.shape_size, shape.tio.cpu_nthreads(),
             shape.tio.aes_ops());
         int init = 0;
-        pe.reduce(init, [&dt, &shape, &Mshift, player]
-            (const ParallelEval<RDPFTriple> &pe, int thread_num,
-            address_t i, const RDPFTriple::node &leaf) {
+        pe.reduce(init, [&dt, &shape, &Mshift, player] (int thread_num,
+                address_t i, const RDPFTriple::node &leaf) {
             // The values from the three DPFs
             auto [V0, V1, V2] = dt.scaled<T>(leaf) + dt.unit<T>(leaf) * Mshift;
             // References to the appropriate cells in our database, our
@@ -460,9 +452,6 @@ typename Duoram<T>::Shape::MemRefS<U>
                 PBD += V1-V0;
             }
             return 0;
-        },
-        // We don't need to return a value
-        [] (const ParallelEval<RDPFTriple> &pe, int &accum, int value) {
         });
     } else {
         // The server does this
@@ -485,9 +474,8 @@ typename Duoram<T>::Shape::MemRefS<U>
             shape.shape_size, shape.tio.cpu_nthreads(),
             shape.tio.aes_ops());
         int init = 0;
-        pe.reduce(init, [&dp, &shape, &Mshift]
-            (const ParallelEval<RDPFPair> &pe, int thread_num,
-            address_t i, const RDPFPair::node &leaf) {
+        pe.reduce(init, [&dp, &shape, &Mshift] (int thread_num,
+                address_t i, const RDPFPair::node &leaf) {
             // The values from the two DPFs
             auto V = dp.scaled<T>(leaf) + dp.unit<T>(leaf) * Mshift;
             // shape.get_server(i) returns a pair of references to the
@@ -495,9 +483,6 @@ typename Duoram<T>::Shape::MemRefS<U>
             // subtract the pair directly.
             shape.get_server(i) -= V;
             return 0;
-        },
-        // We don't need to return a value
-        [] (const ParallelEval<RDPFPair> &pe, int &accum, int value) {
         });
     }
     return *this;
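
Dropping the accumulate lambda works here because the accumulator types already combine with +=: the removed lambdas above, for both T and std::tuple<T,T>, simply forwarded to accum += value. Purely as an illustration (not this repository's actual definition, which the tuple-valued callers here presumably already rely on), an element-wise += for a pair-like tuple could look like this:

template <typename T0, typename T1>
std::tuple<T0,T1> &operator+=(std::tuple<T0,T1> &lhs,
    const std::tuple<T0,T1> &rhs)
{
    // Combine the two components independently, so reduce() can fold
    // per-leaf tuple values straight into the tuple accumulator
    std::get<0>(lhs) += std::get<0>(rhs);
    std::get<1>(lhs) += std::get<1>(rhs);
    return lhs;
}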

online.cpp (+8, -22)

@@ -404,13 +404,9 @@ static void par_rdpfeval_timing(MPCIO &mpcio,
                 auto pe = ParallelEval(dpf, start, 0,
                     address_t(1)<<depth, num_threads, tio.aes_ops());
                 RegXS result, init;
-                result = pe.reduce(init, [&dpf] (const ParallelEval<RDPF> &pe,
-                    int thread_num, address_t i, const RDPF::node &leaf) {
+                result = pe.reduce(init, [&dpf] (int thread_num,
+                        address_t i, const RDPF::node &leaf) {
                     return dpf.scaled_xs(leaf);
-                },
-                [] (const ParallelEval<RDPF> &pe, RegXS &accum,
-                    const RegXS &value) {
-                    accum ^= value;
                 });
                 printf("%016lx\n%016lx\n", result.xshare,
                     dpf.scaled_xor.xshare);
@@ -424,13 +420,9 @@ static void par_rdpfeval_timing(MPCIO &mpcio,
                 auto pe = ParallelEval(dpf, start, 0,
                     address_t(1)<<depth, num_threads, tio.aes_ops());
                 RegXS result, init;
-                result = pe.reduce(init, [&dpf] (const ParallelEval<RDPF> &pe,
-                    int thread_num, address_t i, const RDPF::node &leaf) {
+                result = pe.reduce(init, [&dpf] (int thread_num,
+                        address_t i, const RDPF::node &leaf) {
                     return dpf.scaled_xs(leaf);
-                },
-                [] (const ParallelEval<RDPF> &pe, RegXS &accum,
-                    const RegXS &value) {
-                    accum ^= value;
                 });
                 printf("%016lx\n%016lx\n", result.xshare,
                     dpf.scaled_xor.xshare);
@@ -527,12 +519,9 @@ static void par_tupleeval_timing(MPCIO &mpcio,
                 num_threads, aes_ops);
             using V = std::tuple<RegXS,RegXS>;
             V result, init;
-            result = pe.reduce(init, [&dp] (const ParallelEval<RDPFPair> &pe,
-                int thread_num, address_t i, const RDPFPair::node &leaf) {
+            result = pe.reduce(init, [&dp] (int thread_num, address_t i,
+                    const RDPFPair::node &leaf) {
                 return dp.scaled<RegXS>(leaf);
-            },
-            [] (const ParallelEval<RDPFPair> &pe, V &accum, const V &value) {
-                accum += value;
             });
             printf("%016lx\n%016lx\n", std::get<0>(result).xshare,
                 dp.dpf[0].scaled_xor.xshare);
@@ -546,12 +535,9 @@ static void par_tupleeval_timing(MPCIO &mpcio,
                 num_threads, aes_ops);
             using V = std::tuple<RegXS,RegXS,RegXS>;
             V result, init;
-            result = pe.reduce(init, [&dt] (const ParallelEval<RDPFTriple> &pe,
-                int thread_num, address_t i, const RDPFTriple::node &leaf) {
+            result = pe.reduce(init, [&dt] (int thread_num, address_t i,
+                    const RDPFTriple::node &leaf) {
                 return dt.scaled<RegXS>(leaf);
-            },
-            [] (const ParallelEval<RDPFTriple> &pe, V &accum, const V &value) {
-                accum += value;
             });
             printf("%016lx\n%016lx\n", std::get<0>(result).xshare,
                 dt.dpf[0].scaled_xor.xshare);
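
One subtlety in these hunks: the removed accumulate lambdas combined RegXS values with ^=, while the new API always folds with +=. For XOR-shared registers the two coincide, assuming (as this change implies) that RegXS defines += as bitwise XOR of the shares. A minimal sketch of such a type, for illustration only and not the project's actual RegXS:

struct XorShareSketch {
    uint64_t xshare = 0;
    XorShareSketch &operator+=(const XorShareSketch &rhs) {
        // "Addition" of XOR-shares is bitwise XOR, so folding with +=
        // gives the same result the old accum ^= value lambda did
        xshare ^= rhs.xshare;
        return *this;
    }
};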

rdpf.hpp (+10, -12)

@@ -284,26 +284,24 @@ struct ParallelEval {
     // Run the parallel evaluator.  The type V is the type of the
     // accumulator; init should be the "zero" value of the accumulator.
     // The type W (process) is a lambda type with the signature
-    // (const ParallelEval &, int, address_t, const T::node &) -> V
+    // (int, address_t, const T::node &) -> V
     // which will be called like this for each i from 0 to num_evals-1,
     // across num_thread threads:
-    // value_i = process(*this, t, i, DPF((start+i) XOR xor_offset))
+    // value_i = process(t, i, DPF((start+i) XOR xor_offset))
     // t is the thread number (0 <= t < num_threads).
-    // The type X (accumulate) is a lambda type with the signature
-    // (const ParallelEval &, V &, const V &)
-    // which will be called to combine the num_evals values of accum,
-    // first accumulating the values within each thread (starting with
-    // the init value), and then accumulating the totals from each
-    // thread together (again starting with the init value):
+    // The resulting num_evals values will be combined using V's +=
+    // operator, first accumulating the values within each thread
+    // (starting with the init value), and then accumulating the totals
+    // from each thread together (again starting with the init value):
     //
     // total = init
     // for each thread t:
     //     accum_t = init
     //     for each accum_i generated by thread t:
-    //         accumulate(*this, acccum_t, value_i)
-    //     accumulate(*this, total, accum_t)
-    template <typename V, typename W, typename X>
-    inline V reduce(V init, W process, X accumulate);
+    //         accum_t += value_i
+    //     total += accum_t
+    template <typename V, typename W>
+    inline V reduce(V init, W process);
 };
 
 #include "rdpf.tcc"
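
With this declaration, a caller now supplies only the process lambda; the values it returns are folded into the accumulator with +=. A minimal caller sketch, reusing the names from the online.cpp hunks above (dpf, start, depth, num_threads, tio) and assuming they are in scope:

auto pe = ParallelEval(dpf, start, 0,
    address_t(1)<<depth, num_threads, tio.aes_ops());
RegXS init;
RegXS result = pe.reduce(init, [&dpf] (int thread_num, address_t i,
        const RDPF::node &leaf) {
    // One value per evaluated leaf; no separate accumulate lambda is
    // needed, since RegXS values combine with +=
    return dpf.scaled_xs(leaf);
});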

rdpf.tcc (+13, -17)

@@ -105,26 +105,24 @@ typename T::node StreamEval<T>::next()
 // Run the parallel evaluator.  The type V is the type of the
 // accumulator; init should be the "zero" value of the accumulator.
 // The type W (process) is a lambda type with the signature
-// (const ParallelEval &, int, address_t, T::node) -> V
+// (int, address_t, const T::node &) -> V
 // which will be called like this for each i from 0 to num_evals-1,
 // across num_thread threads:
-// value_i = process(*this, t, i, DPF((start+i) XOR xor_offset))
+// value_i = process(t, i, DPF((start+i) XOR xor_offset))
 // t is the thread number (0 <= t < num_threads).
-// The type X (accumulate) is a lambda type with the signature
-// (const ParallelEval &, V &, const V &)
-// which will be called to combine the num_evals values of accum,
-// first accumulating the values within each thread (starting with
-// the init value), and then accumulating the totals from each
-// thread together (again starting with the init value):
+// The resulting num_evals values will be combined using V's +=
+// operator, first accumulating the values within each thread
+// (starting with the init value), and then accumulating the totals
+// from each thread together (again starting with the init value):
 //
 // total = init
 // for each thread t:
 //     accum_t = init
 //     for each accum_i generated by thread t:
-//         accumulate(*this, acccum_t, value_i)
-//     accumulate(*this, total, accum_t)
-template <typename T> template <typename V, typename W, typename X>
-inline V ParallelEval<T>::reduce(V init, W process, X accumulate)
+//         accum_t += value_i
+//     total += accum_t
+template <typename T> template <typename V, typename W>
+inline V ParallelEval<T>::reduce(V init, W process)
 {
     size_t thread_aes_ops[num_threads];
     V accums[num_threads];
@@ -139,17 +137,15 @@ inline V ParallelEval<T>::reduce(V init, W process, X accumulate)
         address_t threadsize = threadchunk + (address_t(thread_num) < threadextra);
         boost::asio::post(pool,
             [this, &init, &thread_aes_ops, &accums, &process,
-                    &accumulate, thread_num, threadstart, threadsize,
-                    indexmask] {
+                    thread_num, threadstart, threadsize, indexmask] {
                 size_t local_aes_ops = 0;
                 auto ev = StreamEval(rdpf, (start+threadstart)&indexmask,
                     xor_offset, local_aes_ops);
                 V accum = init;
                 for (address_t x=0;x<threadsize;++x) {
                     typename T::node leaf = ev.next();
-                    V value = process(*this, thread_num,
+                    accum += process(thread_num,
                         (threadstart+x)&indexmask, leaf);
-                    accumulate(*this, accum, value);
                 }
                 accums[thread_num] = accum;
                 thread_aes_ops[thread_num] = local_aes_ops;
@@ -159,7 +155,7 @@ inline V ParallelEval<T>::reduce(V init, W process, X accumulate)
     pool.join();
     V total = init;
     for (int thread_num = 0; thread_num < num_threads; ++thread_num) {
-        accumulate(*this, total, accums[thread_num]);
+        total += accums[thread_num];
         aes_ops += thread_aes_ops[thread_num];
     }
     return total;
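
For reference, the contract described in the comment above can be exercised without any of the DPF machinery. The following self-contained sketch (plain std::thread and a uint64_t accumulator; none of these names come from this repository) follows the same shape: split the indices across the threads, accumulate within each thread starting from init, then fold the per-thread totals together with +=. Compile with -pthread.

#include <cstdint>
#include <cstdio>
#include <thread>
#include <vector>

template <typename V, typename W>
V reduce_sketch(uint64_t num_evals, int num_threads, V init, W process)
{
    std::vector<V> accums(num_threads, init);
    std::vector<std::thread> threads;
    uint64_t chunk = num_evals / num_threads;
    uint64_t extra = num_evals % num_threads;
    uint64_t start = 0;
    for (int t = 0; t < num_threads; ++t) {
        // The first `extra` threads each take one additional index
        uint64_t size = chunk + (uint64_t(t) < extra);
        threads.emplace_back([&accums, &init, &process, t, start, size] {
            V accum = init;
            for (uint64_t x = 0; x < size; ++x) {
                accum += process(t, start + x);
            }
            accums[t] = accum;
        });
        start += size;
    }
    for (auto &th : threads) th.join();
    V total = init;
    for (int t = 0; t < num_threads; ++t) {
        total += accums[t];
    }
    return total;
}

int main()
{
    // Sum the indices 0..999 across 4 threads; expect 499500
    uint64_t total = reduce_sketch(uint64_t(1000), 4, uint64_t(0),
        [] (int, uint64_t i) { return i; });
    printf("%llu\n", (unsigned long long)total);
}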