|
@@ -1,6 +1,7 @@
|
|
|
#include <vector>
|
|
|
#include <functional>
|
|
|
#include <cstring>
|
|
|
+#include <pthread.h>
|
|
|
|
|
|
#include "sgx_tcrypto.h"
|
|
|
#include "sgx_tseal.h"
|
|
@@ -109,6 +110,48 @@ struct NodeCommState {
|
|
|
void handshake_start();
|
|
|
};
|
|
|
|
|
|
+// The communication states for all the nodes. There's an entry for
|
|
|
+// ourselves in here, but it is unused.
|
|
|
+static std::vector<NodeCommState> commstates;
|
|
|
+static nodenum_t tot_nodes, my_node_num;
|
|
|
+static class CompletedHandshakeCounter {
|
|
|
+ // Mutex around completed_handshakes
|
|
|
+ pthread_mutex_t mutex;
|
|
|
+ // The number of completed handshakes
|
|
|
+ nodenum_t completed_handshakes;
|
|
|
+ // The callback pointer to use when all handshakes complete
|
|
|
+ void *complete_handshake_cbpointer;
|
|
|
+
|
|
|
+public:
|
|
|
+ CompletedHandshakeCounter() {
|
|
|
+ pthread_mutex_init(&mutex, NULL);
|
|
|
+ completed_handshakes = 0;
|
|
|
+ complete_handshake_cbpointer = NULL;
|
|
|
+ }
|
|
|
+
|
|
|
+ void reset(void *cbpointer) {
|
|
|
+ pthread_mutex_lock(&mutex);
|
|
|
+ completed_handshakes = 0;
|
|
|
+ complete_handshake_cbpointer = cbpointer;
|
|
|
+ pthread_mutex_unlock(&mutex);
|
|
|
+ }
|
|
|
+
|
|
|
+ void inc() {
|
|
|
+ pthread_mutex_lock(&mutex);
|
|
|
+ ++completed_handshakes;
|
|
|
+ nodenum_t num_completed = completed_handshakes;
|
|
|
+ pthread_mutex_unlock(&mutex);
|
|
|
+ if (num_completed == tot_nodes - 1) {
|
|
|
+ pthread_mutex_lock(&mutex);
|
|
|
+ void *cbpointer = complete_handshake_cbpointer;
|
|
|
+ complete_handshake_cbpointer = NULL;
|
|
|
+ completed_handshakes = 0;
|
|
|
+ pthread_mutex_unlock(&mutex);
|
|
|
+ ocall_comms_ready(cbpointer);
|
|
|
+ }
|
|
|
+ }
|
|
|
+} completed_handshake_counter;
|
|
|
+
|
|
|
// A typical default in_msg_get_buf handler. It computes the maximum
|
|
|
// possible size of the decrypted data, allocates that much memory, and
|
|
|
// returns a pointer to it.
|
|
@@ -389,11 +432,8 @@ static void handshake_2_msg_received(NodeCommState &nodest,
|
|
|
nodest.message_data((uint8_t*)&cli_srv_sig, sizeof(cli_srv_sig),
|
|
|
false);
|
|
|
|
|
|
- // Send a test message
|
|
|
- nodest.message_start(12);
|
|
|
- unsigned char buf[13];
|
|
|
- memmove(buf, "Hello, world", 13);
|
|
|
- nodest.message_data(buf, 12);
|
|
|
+ // Mark the handshake as complete
|
|
|
+ completed_handshake_counter.inc();
|
|
|
}
|
|
|
|
|
|
static void handshake_3_msg_received(NodeCommState &nodest,
|
|
@@ -438,11 +478,8 @@ static void handshake_3_msg_received(NodeCommState &nodest,
|
|
|
nodest.in_msg_get_buf = default_in_msg_get_buf;
|
|
|
nodest.in_msg_received = default_in_msg_received;
|
|
|
|
|
|
- // Send a test message
|
|
|
- nodest.message_start(12);
|
|
|
- unsigned char buf[13];
|
|
|
- memmove(buf, "Hello, world", 13);
|
|
|
- nodest.message_data(buf, 12);
|
|
|
+ // Mark the handshake as complete
|
|
|
+ completed_handshake_counter.inc();
|
|
|
}
|
|
|
|
|
|
// Start a new outgoing message. Pass the number of _plaintext_ bytes
|
|
@@ -543,11 +580,6 @@ void NodeCommState::message_data(uint8_t *data, uint32_t len, bool encrypt)
|
|
|
}
|
|
|
}
|
|
|
|
|
|
-// The communication states for all the nodes. There's an entry for
|
|
|
-// ourselves in here, but it is unused.
|
|
|
-static std::vector<NodeCommState> commstates;
|
|
|
-static nodenum_t tot_nodes, my_node_num;
|
|
|
-
|
|
|
// Generate a new identity signature key. Output the public key and the
|
|
|
// sealed private key. outsealedpriv must point to SEALEDPRIVKEY_SIZE =
|
|
|
// sizeof(sgx_sealed_data_t) + sizeof(sgx_ec256_private_t) + 18 bytes of
|
|
@@ -780,9 +812,13 @@ void NodeCommState::handshake_start()
|
|
|
sizeof(handshake_dh_pubkey), false);
|
|
|
}
|
|
|
|
|
|
-// Start all handshakes for which we are the client
|
|
|
-bool ecall_comms_start()
|
|
|
+// Start all handshakes for which we are the client. Call
|
|
|
+// ocall_comms_ready(cbpointer) when the handshakes with all other nodes
|
|
|
+// (for which we are client or server) are complete.
|
|
|
+bool ecall_comms_start(void *cbpointer)
|
|
|
{
|
|
|
+ completed_handshake_counter.reset(cbpointer);
|
|
|
+
|
|
|
for (nodenum_t t = my_node_num+1; t<tot_nodes; ++t) {
|
|
|
commstates[t].handshake_start();
|
|
|
}
|