Browse Source

Keep separate track of the number of threads we can use for computation and for communication

First keep track of the number of threads a coroutine using this MPCTIO
can use for local computation (no communication and no yielding).
Multiple coroutines with the same MPCTIO can have this value larger than
1, since they will not be able to use multiple threads at the same time.

Also keep track of the number of threads a coroutine using this MPCTIO
can launch into separate MPCTIOs with their own communication.  It is
important that at most one coroutine using this MPCTIO can have this
value set larger than 1, since all MPCTIOs with the same thread_num (and
so using the same sockets) have to be controlled by the same
run_coroutines(tio, ...) call.
Ian Goldberg 1 year ago
parent
commit
e09f4e3f3b
3 changed files with 53 additions and 4 deletions
  1. 25 0
      coroutine.hpp
  2. 4 3
      mpcio.cpp
  3. 24 1
      mpcio.hpp

+ 25 - 0
coroutine.hpp

@@ -15,10 +15,30 @@
 static inline void send_or_yield(MPCTIO &tio) { tio.send(); }
 static inline void send_or_yield(yield_t &yield) { yield(); }
 
+// Get and set communication_nthreads for an MPCTIO; for a yield_t, this
+// is a no-op.
+static inline int getset_communication_nthreads(MPCTIO &tio, int nthreads = 0) {
+    return tio.comm_nthreads(nthreads);
+}
+static inline int getset_communication_nthreads(yield_t &yield, int nthreads = 0) {
+    return 0;
+}
+
 // Use this version if you have a variable number of coroutines (or a
 // larger constant number than is supported below).
 template <typename T>
 inline void run_coroutines(T &mpctio_or_yield, std::vector<coro_t> &coroutines) {
+    // If there's more than one coroutine, at most one of them can have
+    // communication_nthreads larger than 1 (see mpcio.hpp for details).
+    // For now, we set them _all_ to 1 (if there's more than one of
+    // them), and restore communication_nthreads when they're all done.
+
+    int saved_communication_nthreads = 0;
+    if (coroutines.size() > 1) {
+        saved_communication_nthreads =
+            getset_communication_nthreads(mpctio_or_yield, 1);
+    }
+
     // Loop until all the coroutines are finished
     bool finished = false;
     while(!finished) {
@@ -38,6 +58,11 @@ inline void run_coroutines(T &mpctio_or_yield, std::vector<coro_t> &coroutines)
             }
         }
     }
+
+    if (saved_communication_nthreads > 0) {
+        getset_communication_nthreads(mpctio_or_yield,
+            saved_communication_nthreads);
+    }
 }
 
 // Use one of these versions if you have a small fixed number of

+ 4 - 3
mpcio.cpp

@@ -402,9 +402,10 @@ void MPCServerIO::dump_stats(std::ostream &os)
     dump_precomp_stats(os);
 }
 
-MPCTIO::MPCTIO(MPCIO &mpcio, int thread_num) :
-        thread_num(thread_num), thread_lamport(mpcio.lamport),
-        mpcio(mpcio)
+MPCTIO::MPCTIO(MPCIO &mpcio, int thread_num, int num_threads) :
+        thread_num(thread_num), local_cpu_nthreads(num_threads),
+        communication_nthreads(num_threads),
+        thread_lamport(mpcio.lamport), mpcio(mpcio)
 #ifdef VERBOSE_COMMS
         , round_num(0)
 #endif

+ 24 - 1
mpcio.hpp

@@ -275,6 +275,22 @@ public:
 
 class MPCTIO {
     int thread_num;
+
+    // The number of threads a coroutine using this MPCTIO can use for
+    // local computation (no communication and no yielding).  Multiple
+    // coroutines with the same MPCTIO can have this value larger than
+    // 1, since they will not be able to use multiple threads at the
+    // same time.
+    int local_cpu_nthreads;
+
+    // The number of threads a coroutine using this MPCTIO can launch
+    // into separate MPCTIOs with their own communication.  It is
+    // important that at most one coroutine using this MPCTIO can have
+    // this value set larger than 1, since all MPCTIOs with the same
+    // thread_num (and so using the same sockets) have to be controlled
+    // by the same run_coroutines(tio, ...) call.
+    int communication_nthreads;
+
     lamport_t thread_lamport;
     MPCIO &mpcio;
     std::optional<MPCSingleIOStream> peer_iostream;
@@ -286,7 +302,7 @@ class MPCTIO {
 #endif
 
 public:
-    MPCTIO(MPCIO &mpcio, int thread_num);
+    MPCTIO(MPCIO &mpcio, int thread_num, int num_threads = 1);
 
     // Sync our per-thread lamport clock with the master one in the
     // mpcio.  You only need to call this explicitly if your MPCTIO
@@ -365,6 +381,13 @@ public:
     inline bool is_server() { return mpcio.player == 2; }
     inline size_t& aes_ops() { return mpcio.aes_ops[thread_num]; }
     inline size_t msgs_sent() { return mpcio.msgs_sent[thread_num]; }
+    inline int comm_nthreads(int nthreads=0) {
+        int res = communication_nthreads;
+        if (nthreads > 0) {
+            communication_nthreads = nthreads;
+        }
+        return res;
+    }
 };
 
 // Set up the socket connections between the two computational parties