Browse Source

Enable coroutines to be used in the preprocessing phase

Ian Goldberg 1 year ago
parent
commit
83282f88ab
3 changed files with 37 additions and 6 deletions
  1. 1 1
      Makefile
  2. 2 0
      coroutine.hpp
  3. 34 5
      preproc.cpp

+ 1 - 1
Makefile

@@ -25,6 +25,6 @@ depend:
 
 prac.o: mpcio.hpp types.hpp preproc.hpp online.hpp
 mpcio.o: mpcio.hpp types.hpp
-preproc.o: types.hpp preproc.hpp mpcio.hpp
+preproc.o: types.hpp coroutine.hpp mpcio.hpp preproc.hpp
 online.o: online.hpp mpcio.hpp types.hpp mpcops.hpp coroutine.hpp
 mpcops.o: mpcops.hpp types.hpp mpcio.hpp coroutine.hpp

+ 2 - 0
coroutine.hpp

@@ -4,6 +4,8 @@
 #include <vector>
 #include <boost/coroutine2/coroutine.hpp>
 
+#include "mpcio.hpp"
+
 typedef boost::coroutines2::coroutine<void>::pull_type  coro_t;
 typedef boost::coroutines2::coroutine<void>::push_type  yield_t;
 

+ 34 - 5
preproc.cpp

@@ -1,6 +1,7 @@
 #include <vector>
 
 #include "types.hpp"
+#include "coroutine.hpp"
 #include "preproc.hpp"
 
 // Open a file for writing with name the given prefix, and ".pX.tY"
@@ -28,7 +29,8 @@ static std::ofstream openfile(const char *prefix, unsigned player,
 // One byte: type
 //   0x80: Multiplication triple
 //   0x81: Multiplication half-triple
-//   0x01 to 0x40: DPF of that depth
+//   0x01 to 0x30: RAM DPF of that depth
+//   0x40: Comparison DPF
 //   0x00: End of preprocessing
 //
 // Four bytes: number of objects of that type (not sent for type == 0x00)
@@ -43,6 +45,7 @@ void preprocessing_comp(MPCIO &mpcio, int num_threads, char **args)
     for (int thread_num = 0; thread_num < num_threads; ++thread_num) {
         boost::asio::post(pool, [&mpcio, thread_num] {
             MPCTIO tio(mpcio, thread_num);
+            std::vector<coro_t> coroutines;
             while(1) {
                 unsigned char type = 0;
                 unsigned int num = 0;
@@ -72,12 +75,20 @@ void preprocessing_comp(MPCIO &mpcio, int num_threads, char **args)
                         halffile.write((const char *)&H, sizeof(H));
                     }
                     halffile.close();
+                } else if (type >= 0x01 && type <= 0x30) {
+                    // RAM DPFs
+                    for (unsigned int i=0; i<num; ++i) {
+                        coroutines.emplace_back(
+                            [&](yield_t &yield) {
+                                //rdpf_gen(stio, yield, depth);
+                            });
+                    }
                 }
             }
+            run_coroutines(tio, coroutines);
         });
     }
     pool.join();
-    std::cout << "Lamport clock = " << mpcio.lamport << "\n";
 }
 
 void preprocessing_server(MPCServerIO &mpcsrvio, int num_threads, char **args)
@@ -87,6 +98,7 @@ void preprocessing_server(MPCServerIO &mpcsrvio, int num_threads, char **args)
         boost::asio::post(pool, [&mpcsrvio, thread_num, args] {
             char **threadargs = args;
             MPCTIO stio(mpcsrvio, thread_num);
+            std::vector<coro_t> coroutines;
             if (*threadargs && threadargs[0][0] == 'T') {
                 // Per-thread initialization.  The args look like:
                 // T0 t:50 h:10 T1 t:20 h:30 T2 h:20
@@ -136,7 +148,25 @@ void preprocessing_server(MPCServerIO &mpcsrvio, int num_threads, char **args)
                     for (unsigned int i=0; i<num; ++i) {
                         stio.halftriple();
                     }
-                }
+                } else if (type[0] == 'r') {
+                    int depth = atoi(type+1);
+                    if (depth < 1 || depth > 48) {
+                        std::cerr << "Invalid DPF depth\n";
+                    } else {
+                        unsigned char typetag = depth;
+                        stio.queue_p0(&typetag, 1);
+                        stio.queue_p0(&num, 4);
+                        stio.queue_p1(&typetag, 1);
+                        stio.queue_p1(&num, 4);
+
+                        for (unsigned int i=0; i<num; ++i) {
+                            coroutines.emplace_back(
+                                [&](yield_t &yield) {
+                                    //rdpf_gen(stio, yield, depth);
+                                });
+                        }
+                    }
+		}
                 free(arg);
                 ++threadargs;
             }
@@ -144,9 +174,8 @@ void preprocessing_server(MPCServerIO &mpcsrvio, int num_threads, char **args)
             unsigned char typetag = 0x00;
             stio.queue_p0(&typetag, 1);
             stio.queue_p1(&typetag, 1);
-            stio.send();
+            run_coroutines(stio, coroutines);
         });
     }
     pool.join();
-    std::cout << "Lamport clock = " << mpcsrvio.lamport << "\n";
 }