Browse Source

Add a callback that fires when all the network handshakes are complete

Ian Goldberg 1 year ago
parent
commit
f10fdca30f
6 changed files with 105 additions and 38 deletions
  1. 0 15
      App/start.cpp
  2. 32 1
      App/teems.cpp
  3. 5 1
      Enclave/Enclave.edl
  4. 53 17
      Enclave/comms.cpp
  5. 13 2
      Untrusted/Untrusted.cpp
  6. 2 2
      Untrusted/Untrusted.hpp

+ 0 - 15
App/start.cpp

@@ -105,19 +105,4 @@ void start(NetIO &netio, char **args)
         route_test(netio, args);
         return;
     }
-    printf("Reading\n");
-    for (nodenum_t node_num = 0; node_num < netio.num_nodes; ++node_num) {
-        if (node_num == netio.me) continue;
-        NodeIO &node = netio.node(node_num);
-        node.recv_commands(
-            // error_cb
-            [](boost::system::error_code) {
-                printf("Error\n");
-            },
-            // epoch_cb
-            [](uint32_t epoch) {
-                printf("Epoch %u\n", epoch);
-            });
-    }
-
 }

+ 32 - 1
App/teems.cpp

@@ -1,6 +1,8 @@
 #include <iostream>
 #include <cstdio>
 #include <cstring>
+#include <condition_variable>
+#include <mutex>
 
 #include <boost/asio.hpp>
 #include <boost/thread.hpp>
@@ -243,8 +245,37 @@ int main(int argc, char **argv)
 
     // Queue up the actual work
     boost::asio::post(io_context, [&]{
+        // Create a condition variable and mutex to block on until
+        // communication with all other nodes is established
+        bool comms_ready = false;
+        std::mutex m;
+        std::condition_variable cv;
+
         // Start enclave-to-enclave communications
-        ecall_comms_start();
+        ecall_comms_start([&]{
+            {
+                std::lock_guard lk(m);
+                comms_ready = true;
+            }
+            cv.notify_one();
+        });
+        printf("Reading\n");
+        for (nodenum_t node_num = 0; node_num < netio.num_nodes; ++node_num) {
+            if (node_num == netio.me) continue;
+            NodeIO &node = netio.node(node_num);
+            node.recv_commands(
+                // error_cb
+                [](boost::system::error_code) {
+                    printf("Error\n");
+                },
+                // epoch_cb
+                [](uint32_t epoch) {
+                    printf("Epoch %u\n", epoch);
+                });
+        }
+        std::unique_lock lk(m);
+        cv.wait(lk, [&]{ return comms_ready; });
+        printf("Starting\n");
         start(netio, argv);
     });
 

+ 5 - 1
Enclave/Enclave.edl

@@ -25,7 +25,8 @@ enclave {
 
         public void ecall_close();
 
-        public bool ecall_comms_start();
+        public bool ecall_comms_start(
+            [user_check]void *cbpointer);
 
         public bool ecall_message(
             nodenum_t node_num, uint32_t message_len);
@@ -61,6 +62,9 @@ enclave {
             [user_check] uint8_t *chunkdata,
             uint32_t chunklen);
 
+        void ocall_comms_ready(
+            [user_check] void *cbpointer);
+
         void ocall_routing_round_complete(
             [user_check] void *cbpointer,
             uint32_t round_num);

+ 53 - 17
Enclave/comms.cpp

@@ -1,6 +1,7 @@
 #include <vector>
 #include <functional>
 #include <cstring>
+#include <pthread.h>
 
 #include "sgx_tcrypto.h"
 #include "sgx_tseal.h"
@@ -109,6 +110,48 @@ struct NodeCommState {
     void handshake_start();
 };
 
+// The communication states for all the nodes.  There's an entry for
+// ourselves in here, but it is unused.
+static std::vector<NodeCommState> commstates;
+static nodenum_t tot_nodes, my_node_num;
+static class CompletedHandshakeCounter {
+    // Mutex around completed_handshakes
+    pthread_mutex_t mutex;
+    // The number of completed handshakes
+    nodenum_t completed_handshakes;
+    // The callback pointer to use when all handshakes complete
+    void *complete_handshake_cbpointer;
+
+public:
+    CompletedHandshakeCounter() {
+        pthread_mutex_init(&mutex, NULL);
+        completed_handshakes = 0;
+        complete_handshake_cbpointer = NULL;
+    }
+
+    void reset(void *cbpointer) {
+        pthread_mutex_lock(&mutex);
+        completed_handshakes = 0;
+        complete_handshake_cbpointer = cbpointer;
+        pthread_mutex_unlock(&mutex);
+    }
+
+    void inc() {
+        pthread_mutex_lock(&mutex);
+        ++completed_handshakes;
+        nodenum_t num_completed = completed_handshakes;
+        pthread_mutex_unlock(&mutex);
+        if (num_completed == tot_nodes - 1) {
+            pthread_mutex_lock(&mutex);
+            void *cbpointer = complete_handshake_cbpointer;
+            complete_handshake_cbpointer = NULL;
+            completed_handshakes = 0;
+            pthread_mutex_unlock(&mutex);
+            ocall_comms_ready(cbpointer);
+        }
+    }
+} completed_handshake_counter;
+
 // A typical default in_msg_get_buf handler.  It computes the maximum
 // possible size of the decrypted data, allocates that much memory, and
 // returns a pointer to it.
@@ -389,11 +432,8 @@ static void handshake_2_msg_received(NodeCommState &nodest,
     nodest.message_data((uint8_t*)&cli_srv_sig, sizeof(cli_srv_sig),
         false);
 
-    // Send a test message
-    nodest.message_start(12);
-    unsigned char buf[13];
-    memmove(buf, "Hello, world", 13);
-    nodest.message_data(buf, 12);
+    // Mark the handshake as complete
+    completed_handshake_counter.inc();
 }
 
 static void handshake_3_msg_received(NodeCommState &nodest,
@@ -438,11 +478,8 @@ static void handshake_3_msg_received(NodeCommState &nodest,
     nodest.in_msg_get_buf = default_in_msg_get_buf;
     nodest.in_msg_received = default_in_msg_received;
 
-    // Send a test message
-    nodest.message_start(12);
-    unsigned char buf[13];
-    memmove(buf, "Hello, world", 13);
-    nodest.message_data(buf, 12);
+    // Mark the handshake as complete
+    completed_handshake_counter.inc();
 }
 
 // Start a new outgoing message.  Pass the number of _plaintext_ bytes
@@ -543,11 +580,6 @@ void NodeCommState::message_data(uint8_t *data, uint32_t len, bool encrypt)
     }
 }
 
-// The communication states for all the nodes.  There's an entry for
-// ourselves in here, but it is unused.
-static std::vector<NodeCommState> commstates;
-static nodenum_t tot_nodes, my_node_num;
-
 // Generate a new identity signature key.  Output the public key and the
 // sealed private key.  outsealedpriv must point to SEALEDPRIVKEY_SIZE =
 // sizeof(sgx_sealed_data_t) + sizeof(sgx_ec256_private_t) + 18 bytes of
@@ -780,9 +812,13 @@ void NodeCommState::handshake_start()
         sizeof(handshake_dh_pubkey), false);
 }
 
-// Start all handshakes for which we are the client
-bool ecall_comms_start()
+// Start all handshakes for which we are the client.  Call
+// ocall_comms_ready(cbpointer) when the handshakes with all other nodes
+// (for which we are client or server) are complete.
+bool ecall_comms_start(void *cbpointer)
 {
+    completed_handshake_counter.reset(cbpointer);
+
     for (nodenum_t t = my_node_num+1; t<tot_nodes; ++t) {
         commstates[t].handshake_start();
     }

+ 13 - 2
Untrusted/Untrusted.cpp

@@ -242,10 +242,12 @@ void ecall_close()
     ecall_close(global_eid);
 }
 
-bool ecall_comms_start()
+bool ecall_comms_start(std::function<void(void)> cb)
 {
+    std::function<void(void)> *p = new std::function<void(void)>;
+    *p = cb;
     bool ret;
-    ecall_comms_start(global_eid, &ret);
+    ecall_comms_start(global_eid, &ret, p);
     return ret;
 }
 
@@ -285,6 +287,15 @@ void ecall_routing_proceed(std::function<void(uint32_t)> cb)
     ecall_routing_proceed(global_eid, p);
 }
 
+void ocall_comms_ready(void *cbpointer)
+{
+    std::function<void(void)> *p =
+        (std::function<void(void)> *)cbpointer;
+    std::function<void(void)> f = *p;
+    delete p;
+    f();
+}
+
 void ocall_routing_round_complete(void *cbpointer, uint32_t round_num)
 {
     std::function<void(uint32_t)> *p =

+ 2 - 2
Untrusted/Untrusted.hpp

@@ -27,7 +27,7 @@ bool ecall_config_load(threadid_t nthreads,
 
 void ecall_close();
 
-bool ecall_comms_start();
+bool ecall_comms_start(std::function<void(void)> cb);
 
 bool ecall_message(nodenum_t node_num, uint32_t message_len);
 
@@ -38,6 +38,6 @@ size_t ecall_precompute_sort(int size);
 
 bool ecall_ingest_raw(uint8_t *msgs, uint32_t num_msgs);
 
-void ecall_routing_proceed(std::function<void(uint32_t)>);
+void ecall_routing_proceed(std::function<void(uint32_t)> cb);
 
 #endif