|
@@ -18,8 +18,11 @@ struct MsgBuffer {
|
|
|
uint32_t inserted;
|
|
|
// The number of messages that can fit in buf
|
|
|
uint32_t bufsize;
|
|
|
+ // The number of nodes we've heard from
|
|
|
+ nodenum_t nodes_received;
|
|
|
|
|
|
- MsgBuffer() : buf(NULL), reserved(0), inserted(0), bufsize(0) {
|
|
|
+ MsgBuffer() : buf(NULL), reserved(0), inserted(0), bufsize(0),
|
|
|
+ nodes_received(0) {
|
|
|
pthread_mutex_init(&mutex, NULL);
|
|
|
}
|
|
|
|
|
@@ -36,6 +39,7 @@ struct MsgBuffer {
|
|
|
// This may throw bad_alloc, but we'll catch it higher up
|
|
|
buf = new uint8_t[size_t(msgs) * g_teems_config.msg_size];
|
|
|
bufsize = msgs;
|
|
|
+ nodes_received = 0;
|
|
|
}
|
|
|
|
|
|
// Reset the contents of the buffer
|
|
@@ -43,6 +47,7 @@ struct MsgBuffer {
|
|
|
memset(buf, 0, bufsize * g_teems_config.msg_size);
|
|
|
reserved = 0;
|
|
|
inserted = 0;
|
|
|
+ nodes_received = 0;
|
|
|
}
|
|
|
|
|
|
// You can't copy a MsgBuffer
|
|
@@ -68,6 +73,7 @@ static struct RouteState {
|
|
|
uint32_t tot_msg_per_ing;
|
|
|
uint32_t max_msg_to_each_str;
|
|
|
uint32_t max_round2_msgs;
|
|
|
+ void *cbpointer;
|
|
|
} route_state;
|
|
|
|
|
|
// Computes ceil(x/y) where x and y are integers, x>=0, y>0.
|
|
@@ -147,6 +153,7 @@ bool route_init()
|
|
|
route_state.tot_msg_per_ing = tot_msg_per_ing;
|
|
|
route_state.max_msg_to_each_str = max_msg_to_each_str;
|
|
|
route_state.max_round2_msgs = max_round2_msgs;
|
|
|
+ route_state.cbpointer = NULL;
|
|
|
|
|
|
threadid_t nthreads = g_teems_config.nthreads;
|
|
|
#ifdef PROFILE_ROUTING
|
|
@@ -196,6 +203,107 @@ size_t ecall_precompute_sort(int sizeidx)
|
|
|
return ret;
|
|
|
}
|
|
|
|
|
|
+static uint8_t* msgbuffer_get_buf(MsgBuffer &msgbuf,
|
|
|
+ NodeCommState &, uint32_t tot_enc_chunk_size)
|
|
|
+{
|
|
|
+ uint16_t msg_size = g_teems_config.msg_size;
|
|
|
+
|
|
|
+ // Chunks will be encrypted and have a MAC tag attached which will
|
|
|
+ // not correspond to plaintext bytes, so we can trim them.
|
|
|
+
|
|
|
+ // The minimum number of chunks needed to transmit this message
|
|
|
+ uint32_t min_num_chunks =
|
|
|
+ (tot_enc_chunk_size + (FRAME_SIZE-1)) / FRAME_SIZE;
|
|
|
+ // The number of plaintext bytes this message could contain
|
|
|
+ uint32_t plaintext_bytes = tot_enc_chunk_size -
|
|
|
+ SGX_AESGCM_MAC_SIZE * min_num_chunks;
|
|
|
+
|
|
|
+ assert ((plaintext_bytes % uint32_t(msg_size)) == 0);
|
|
|
+
|
|
|
+ uint32_t num_msgs = plaintext_bytes/uint32_t(msg_size);
|
|
|
+
|
|
|
+ pthread_mutex_lock(&msgbuf.mutex);
|
|
|
+ uint32_t start = msgbuf.reserved;
|
|
|
+ if (start + num_msgs > msgbuf.bufsize) {
|
|
|
+ pthread_mutex_unlock(&msgbuf.mutex);
|
|
|
+ printf("Max %u messages exceeded\n", msgbuf.bufsize);
|
|
|
+ return NULL;
|
|
|
+ }
|
|
|
+ msgbuf.reserved += num_msgs;
|
|
|
+ pthread_mutex_unlock(&msgbuf.mutex);
|
|
|
+
|
|
|
+ return msgbuf.buf + start * msg_size;
|
|
|
+}
|
|
|
+
|
|
|
+// A round 1 message was received by a routing node from an ingestion
|
|
|
+// node; we put it into the round 2 buffer for processing in round 2
|
|
|
+static void round1_received(NodeCommState &nodest,
|
|
|
+ uint8_t *data, uint32_t plaintext_len, uint32_t)
|
|
|
+{
|
|
|
+ uint16_t msg_size = g_teems_config.msg_size;
|
|
|
+ assert((plaintext_len % uint32_t(msg_size)) == 0);
|
|
|
+ uint32_t num_msgs = plaintext_len / uint32_t(msg_size);
|
|
|
+
|
|
|
+ pthread_mutex_lock(&route_state.round2.mutex);
|
|
|
+ route_state.round2.inserted += num_msgs;
|
|
|
+ route_state.round2.nodes_received += 1;
|
|
|
+ nodenum_t nodes_received = route_state.round2.nodes_received;
|
|
|
+ pthread_mutex_unlock(&route_state.round2.mutex);
|
|
|
+
|
|
|
+ if (nodes_received == g_teems_config.num_ingestion_nodes) {
|
|
|
+ route_state.step = ROUTE_ROUND_1;
|
|
|
+ void *cbpointer = route_state.cbpointer;
|
|
|
+ route_state.cbpointer = NULL;
|
|
|
+ ocall_routing_round_complete(cbpointer, 1);
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+// A round 2 message was received by a storage node from a routing node
|
|
|
+static void round2_received(NodeCommState &nodest,
|
|
|
+ uint8_t *data, uint32_t plaintext_len, uint32_t)
|
|
|
+{
|
|
|
+ uint16_t msg_size = g_teems_config.msg_size;
|
|
|
+ assert((plaintext_len % uint32_t(msg_size)) == 0);
|
|
|
+}
|
|
|
+
|
|
|
+// For a given other node, set the received message handler to the first
|
|
|
+// message we would expect from them, given their roles and our roles.
|
|
|
+void route_init_msg_handler(nodenum_t node_num)
|
|
|
+{
|
|
|
+ // Our roles and their roles
|
|
|
+ uint8_t our_roles = g_teems_config.roles[g_teems_config.my_node_num];
|
|
|
+ uint8_t their_roles = g_teems_config.roles[node_num];
|
|
|
+
|
|
|
+ // The node communication state
|
|
|
+ NodeCommState &nodest = g_commstates[node_num];
|
|
|
+
|
|
|
+ // If we are a routing node (possibly among other roles) and they
|
|
|
+ // are an ingestion node (possibly among other roles), a round 1
|
|
|
+ // routing message is the first thing we expect from them. We put
|
|
|
+ // these messages into the round2 buffer for processing.
|
|
|
+ if ((our_roles & ROLE_ROUTING) && (their_roles & ROLE_INGESTION)) {
|
|
|
+ nodest.in_msg_get_buf = [&](NodeCommState &commst,
|
|
|
+ uint32_t tot_enc_chunk_size) {
|
|
|
+ return msgbuffer_get_buf(route_state.round2, commst,
|
|
|
+ tot_enc_chunk_size);
|
|
|
+ };
|
|
|
+ nodest.in_msg_received = round1_received;
|
|
|
+ }
|
|
|
+ // Otherwise, if we are a storage node (possibly among other roles)
|
|
|
+ // and they are a routing node (possibly among other roles), a round
|
|
|
+ // 2 routing message is the first thing we expect from them
|
|
|
+ else if ((our_roles & ROLE_STORAGE) && (their_roles & ROLE_ROUTING)) {
|
|
|
+ nodest.in_msg_get_buf = default_in_msg_get_buf;
|
|
|
+ nodest.in_msg_received = round2_received;
|
|
|
+ }
|
|
|
+ // Otherwise, we don't expect a message from this node. Set the
|
|
|
+ // unknown message handler.
|
|
|
+ else {
|
|
|
+ nodest.in_msg_get_buf = default_in_msg_get_buf;
|
|
|
+ nodest.in_msg_received = unknown_in_msg_received;
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
// Directly ingest a buffer of num_msgs messages into the round1 buffer.
|
|
|
// Return true on success, false on failure.
|
|
|
bool ecall_ingest_raw(uint8_t *msgs, uint32_t num_msgs)
|
|
@@ -293,7 +401,16 @@ static void send_round1_msgs(const uint8_t *msgs, const uint64_t *indices,
|
|
|
|
|
|
pthread_mutex_lock(&round2.mutex);
|
|
|
round2.inserted += num_msgs;
|
|
|
+ round2.nodes_received += 1;
|
|
|
+ nodenum_t nodes_received = round2.nodes_received;
|
|
|
pthread_mutex_unlock(&round2.mutex);
|
|
|
+
|
|
|
+ if (nodes_received == g_teems_config.num_ingestion_nodes) {
|
|
|
+ route_state.step = ROUTE_ROUND_1;
|
|
|
+ void *cbpointer = route_state.cbpointer;
|
|
|
+ route_state.cbpointer = NULL;
|
|
|
+ ocall_routing_round_complete(cbpointer, 1);
|
|
|
+ }
|
|
|
} else {
|
|
|
NodeCommState &nodecom = g_commstates[routing_node];
|
|
|
nodecom.message_start(num_msgs * msg_size);
|
|
@@ -309,16 +426,6 @@ static void send_round1_msgs(const uint8_t *msgs, const uint64_t *indices,
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
-
|
|
|
- /*
|
|
|
- for (uint32_t i=0;i<N;++i) {
|
|
|
- const uint8_t *msg = msgs + indices[i]*msg_size;
|
|
|
- for (uint16_t j=0;j<msg_size/4;++j) {
|
|
|
- printf("%08x ", ((const uint32_t*)msg)[j]);
|
|
|
- }
|
|
|
- printf("\n");
|
|
|
- }
|
|
|
- */
|
|
|
}
|
|
|
|
|
|
// Perform the next round of routing. The callback pointer will be
|
|
@@ -327,6 +434,7 @@ void ecall_routing_proceed(void *cbpointer)
|
|
|
{
|
|
|
if (route_state.step == ROUTE_NOT_STARTED) {
|
|
|
|
|
|
+ route_state.cbpointer = cbpointer;
|
|
|
MsgBuffer &round1 = route_state.round1;
|
|
|
|
|
|
pthread_mutex_lock(&round1.mutex);
|
|
@@ -348,7 +456,5 @@ void ecall_routing_proceed(void *cbpointer)
|
|
|
#ifdef PROFILE_ROUTING
|
|
|
printf_with_rtclock_diff(start, "end oblivious sort (%u,%u)\n", inserted, route_state.tot_msg_per_ing);
|
|
|
#endif
|
|
|
- route_state.step = ROUTE_ROUND_1;
|
|
|
- ocall_routing_round_complete(cbpointer, 1);
|
|
|
}
|
|
|
}
|