Browse Source

Add ecall_precompute_sort to precompute a single WaksmanNetwork

You can call this multiple times from multiple background threads.
Ian Goldberg 1 year ago
parent
commit
919ef799cc
7 changed files with 83 additions and 7 deletions
  1. 6 0
      App/start.cpp
  2. 2 0
      Enclave/Enclave.edl
  3. 51 1
      Enclave/route.cpp
  4. 9 1
      Enclave/sort.cpp
  5. 6 5
      Enclave/sort.hpp
  6. 7 0
      Untrusted/Untrusted.cpp
  7. 2 0
      Untrusted/Untrusted.hpp

+ 6 - 0
App/start.cpp

@@ -68,6 +68,12 @@ static void route_test(NetIO &netio, char **args)
         printf("\n");
     }
     */
+
+    // Precompute some WaksmanNetworks
+    size_t num_sizes = ecall_precompute_sort(-1);
+    for (int i=0;i<int(num_sizes);++i) {
+        ecall_precompute_sort(i);
+    }
     if (!ecall_ingest_raw(msgs, tot_tokens)) {
         printf("Ingestion failed\n");
         return;

+ 2 - 0
Enclave/Enclave.edl

@@ -35,6 +35,8 @@ enclave {
             [user_check] const uint8_t *chunkdata,
             uint32_t chunklen);
 
+        public size_t ecall_precompute_sort(int sizeidx);
+
         public bool ecall_ingest_raw(
             [user_check] uint8_t *msgs,
             uint32_t num_msgs);

+ 51 - 1
Enclave/route.cpp

@@ -2,8 +2,11 @@
 #include "Enclave_t.h"
 #include "config.hpp"
 #include "utils.hpp"
+#include "sort.hpp"
 #include "route.hpp"
 
+#define PROFILE_ROUTING
+
 struct MsgBuffer {
     pthread_mutex_t mutex;
     uint8_t *buf;
@@ -13,7 +16,9 @@ struct MsgBuffer {
     // The number of messages definitely in the buffer
     uint32_t inserted;
 
-    MsgBuffer() : buf(NULL), reserved(0), inserted(0) {}
+    MsgBuffer() : buf(NULL), reserved(0), inserted(0) {
+        pthread_mutex_init(&mutex, NULL);
+    }
 
     ~MsgBuffer() {
         delete[] buf;
@@ -132,9 +137,54 @@ bool route_init()
     route_state.max_msg_to_each_str = max_msg_to_each_str;
     route_state.max_round2_msgs = max_round2_msgs;
 
+    threadid_t nthreads = g_teems_config.nthreads;
+#ifdef PROFILE_ROUTING
+    unsigned long start = printf_with_rtclock("begin precompute evalplans (%u,%hu) (%u,%hu)\n", tot_msg_per_ing, nthreads, max_round2_msgs, nthreads);
+#endif
+    sort_precompute_evalplan(tot_msg_per_ing, nthreads);
+    sort_precompute_evalplan(max_round2_msgs, nthreads);
+#ifdef PROFILE_ROUTING
+    printf_with_rtclock_diff(start, "end precompute evalplans\n");
+#endif
     return true;
 }
 
+// Precompute the WaksmanNetworks needed for the sorts.  If you pass -1,
+// it will return the number of different sizes it needs.  If you pass
+// [0,sizes-1], it will compute one WaksmanNetwork with that size index
+// and return the number of available WaksmanNetworks of that size.
+
+size_t ecall_precompute_sort(int sizeidx)
+{
+    size_t ret = 0;
+
+    switch(sizeidx) {
+    case 0:
+#ifdef PROFILE_ROUTING
+    {unsigned long start = printf_with_rtclock("begin precompute WaksmanNetwork (%u)\n", route_state.tot_msg_per_ing);
+#endif
+        ret = sort_precompute(route_state.tot_msg_per_ing);
+#ifdef PROFILE_ROUTING
+    printf_with_rtclock_diff(start, "end precompute Waksman Network (%u)\n", route_state.tot_msg_per_ing);}
+#endif
+        break;
+    case 1:
+#ifdef PROFILE_ROUTING
+    {unsigned long start = printf_with_rtclock("begin precompute WaksmanNetwork (%u)\n", route_state.max_round2_msgs);
+#endif
+        ret = sort_precompute(route_state.max_round2_msgs);
+#ifdef PROFILE_ROUTING
+    printf_with_rtclock_diff(start, "end precompute Waksman Network (%u)\n", route_state.max_round2_msgs);}
+#endif
+        break;
+    default:
+        ret = 2;
+        break;
+    }
+
+    return ret;
+}
+
 // Directly ingest a buffer of num_msgs messages into the round1 buffer.
 // Return true on success, false on failure.
 bool ecall_ingest_raw(uint8_t *msgs, uint32_t num_msgs)

+ 9 - 1
Enclave/sort.cpp

@@ -7,12 +7,16 @@
 struct SizedWNs {
     pthread_mutex_t mutex;
     std::deque<WaksmanNetwork> wns;
+
+    SizedWNs() { pthread_mutex_init(&mutex, NULL); }
 };
 
 // A (mutexed) map mapping sizes to SizedWNs
 struct PrecompWNs {
     pthread_mutex_t mutex;
     std::map<uint32_t,SizedWNs> sized_wns;
+
+    PrecompWNs() { pthread_mutex_init(&mutex, NULL); }
 };
 
 static PrecompWNs precomp_wns;
@@ -21,11 +25,13 @@ static PrecompWNs precomp_wns;
 struct EvalPlans {
     pthread_mutex_t mutex;
     std::map<std::pair<uint32_t,threadid_t>,WNEvalPlan> eval_plans;
+
+    EvalPlans() { pthread_mutex_init(&mutex, NULL); }
 };
 
 static EvalPlans precomp_eps;
 
-void sort_precompute(uint32_t N)
+size_t sort_precompute(uint32_t N)
 {
     uint32_t *random_permutation = NULL;
     try {
@@ -48,7 +54,9 @@ void sort_precompute(uint32_t N)
     pthread_mutex_unlock(&precomp_wns.mutex);
     pthread_mutex_lock(&szwn.mutex);
     szwn.wns.push_back(std::move(wnet));
+    size_t ret = szwn.wns.size();
     pthread_mutex_unlock(&szwn.mutex);
+    return ret;
 }
 
 void sort_precompute_evalplan(uint32_t N, threadid_t nthreads)

+ 6 - 5
Enclave/sort.hpp

@@ -23,14 +23,15 @@
 // Precompute a WaksmanNetwork of size N for a random permutation.  This
 // call does not itself use threads, but may be called from a background
 // thread.  These are consumed as they are used, so you need to keep
-// making more.
-void sort_precompute(uint32_t N);
+// making more.  Returns the number of WaksmanNetworks available at that
+// size.
+size_t sort_precompute(uint32_t N);
 
 // Precompute a WNEvalPlan for a given size and number of threads.
 // These are not consumed as they are used, so you only need to call
-// this once for each size/nthreads you need.  The precomputation itself
-// only uses a single thread, but also may be called from a background
-// thread.
+// this once for each (size,nthreads) pair you need.  The precomputation
+// itself only uses a single thread, but also may be called from a
+// background thread.
 void sort_precompute_evalplan(uint32_t N, threadid_t nthreads);
 
 // Perform the sort using up to nthreads threads.  The items to sort are

+ 7 - 0
Untrusted/Untrusted.cpp

@@ -264,6 +264,13 @@ bool ecall_chunk(nodenum_t node_num, const uint8_t *chunkdata,
     return ret;
 }
 
+size_t ecall_precompute_sort(int sizeidx)
+{
+    size_t ret;
+    ecall_precompute_sort(global_eid, &ret, sizeidx);
+    return ret;
+}
+
 bool ecall_ingest_raw(uint8_t *msgs, uint32_t num_msgs)
 {
     bool ret;

+ 2 - 0
Untrusted/Untrusted.hpp

@@ -33,6 +33,8 @@ bool ecall_message(nodenum_t node_num, uint32_t message_len);
 bool ecall_chunk(nodenum_t node_num, const uint8_t *chunkdata,
     uint32_t chunklen);
 
+size_t ecall_precompute_sort(int size);
+
 bool ecall_ingest_raw(uint8_t *msgs, uint32_t num_msgs);
 
 #endif