|
@@ -2,6 +2,9 @@
|
|
|
#include "config.hpp"
|
|
|
#include "utils.hpp"
|
|
|
#include "sort.hpp"
|
|
|
+#include "comms.hpp"
|
|
|
+#include "obliv.hpp"
|
|
|
+#include "storage.hpp"
|
|
|
#include "route.hpp"
|
|
|
|
|
|
#define PROFILE_ROUTING
|
|
@@ -12,18 +15,22 @@ enum RouteStep {
|
|
|
ROUTE_ROUND_2
|
|
|
};
|
|
|
|
|
|
-// The round1 MsgBuffer stores messages we ingest while waiting for
|
|
|
-// round 1 to start, which will be sorted and sent out in round 1. The
|
|
|
-// round2 MsgBuffer stores messages we receive in round 1, which will be
|
|
|
-// padded, sorted, and sent out in round 2.
|
|
|
+// The ingbuf MsgBuffer stores messages an ingestion node ingests while
|
|
|
+// waiting for round 1 to start, which will be sorted and sent out in
|
|
|
+// round 1. The round1 MsgBuffer stores messages a routing node
|
|
|
+// receives in round 1, which will be padded, sorted, and sent out in
|
|
|
+// round 2. The round2 MsgBuffer stores messages a storage node
|
|
|
+// receives in round 2.
|
|
|
|
|
|
static struct RouteState {
|
|
|
+ MsgBuffer ingbuf;
|
|
|
MsgBuffer round1;
|
|
|
MsgBuffer round2;
|
|
|
RouteStep step;
|
|
|
uint32_t tot_msg_per_ing;
|
|
|
- uint32_t max_msg_to_each_str;
|
|
|
+ uint32_t max_msg_to_each_stg;
|
|
|
uint32_t max_round2_msgs;
|
|
|
+ void *cbpointer;
|
|
|
} route_state;
|
|
|
|
|
|
// Computes ceil(x/y) where x and y are integers, x>=0, y>0.
|
|
@@ -58,25 +65,25 @@ bool route_init()
|
|
|
// Compute the maximum number of messages we could send in round 2
|
|
|
|
|
|
// Each storage node has at most this many users
|
|
|
- uint32_t users_per_str = CEILDIV(g_teems_config.user_count,
|
|
|
+ uint32_t users_per_stg = CEILDIV(g_teems_config.user_count,
|
|
|
g_teems_config.num_storage_nodes);
|
|
|
|
|
|
// And so can receive at most this many messages
|
|
|
- uint32_t tot_msg_per_str = users_per_str *
|
|
|
+ uint32_t tot_msg_per_stg = users_per_stg *
|
|
|
g_teems_config.m_priv_in;
|
|
|
|
|
|
// Which will be at most this many from us
|
|
|
- uint32_t max_msg_to_each_str = CEILDIV(tot_msg_per_str,
|
|
|
+ uint32_t max_msg_to_each_stg = CEILDIV(tot_msg_per_stg,
|
|
|
g_teems_config.tot_weight) * g_teems_config.my_weight;
|
|
|
|
|
|
// But we can't send more messages to each storage server than we
|
|
|
// could receive in total
|
|
|
- if (max_msg_to_each_str > max_round1_msgs) {
|
|
|
- max_msg_to_each_str = max_round1_msgs;
|
|
|
+ if (max_msg_to_each_stg > max_round1_msgs) {
|
|
|
+ max_msg_to_each_stg = max_round1_msgs;
|
|
|
}
|
|
|
|
|
|
// And the max total number of outgoing messages in round 2 is then
|
|
|
- uint32_t max_round2_msgs = max_msg_to_each_str *
|
|
|
+ uint32_t max_round2_msgs = max_msg_to_each_stg *
|
|
|
g_teems_config.num_storage_nodes;
|
|
|
|
|
|
// In case we have a weird configuration where users can send more
|
|
@@ -92,17 +99,27 @@ bool route_init()
|
|
|
*/
|
|
|
|
|
|
// Create the route state
|
|
|
+ uint8_t my_roles = g_teems_config.roles[g_teems_config.my_node_num];
|
|
|
try {
|
|
|
- route_state.round1.alloc(tot_msg_per_ing);
|
|
|
- route_state.round2.alloc(max_round2_msgs);
|
|
|
+ if (my_roles & ROLE_INGESTION) {
|
|
|
+ route_state.ingbuf.alloc(tot_msg_per_ing);
|
|
|
+ }
|
|
|
+ if (my_roles & ROLE_ROUTING) {
|
|
|
+ route_state.round1.alloc(max_round2_msgs);
|
|
|
+ }
|
|
|
+ if (my_roles & ROLE_STORAGE) {
|
|
|
+ route_state.round2.alloc(tot_msg_per_stg +
|
|
|
+ g_teems_config.tot_weight);
|
|
|
+ }
|
|
|
} catch (std::bad_alloc&) {
|
|
|
printf("Memory allocation failed in route_init\n");
|
|
|
return false;
|
|
|
}
|
|
|
route_state.step = ROUTE_NOT_STARTED;
|
|
|
route_state.tot_msg_per_ing = tot_msg_per_ing;
|
|
|
- route_state.max_msg_to_each_str = max_msg_to_each_str;
|
|
|
+ route_state.max_msg_to_each_stg = max_msg_to_each_stg;
|
|
|
route_state.max_round2_msgs = max_round2_msgs;
|
|
|
+ route_state.cbpointer = NULL;
|
|
|
|
|
|
threadid_t nthreads = g_teems_config.nthreads;
|
|
|
#ifdef PROFILE_ROUTING
|
|
@@ -152,80 +169,523 @@ size_t ecall_precompute_sort(int sizeidx)
|
|
|
return ret;
|
|
|
}
|
|
|
|
|
|
-// Directly ingest a buffer of num_msgs messages into the round1 buffer.
|
|
|
+static uint8_t* msgbuffer_get_buf(MsgBuffer &msgbuf,
|
|
|
+ NodeCommState &, uint32_t tot_enc_chunk_size)
|
|
|
+{
|
|
|
+ uint16_t msg_size = g_teems_config.msg_size;
|
|
|
+
|
|
|
+ // Chunks will be encrypted and have a MAC tag attached which will
|
|
|
+ // not correspond to plaintext bytes, so we can trim them.
|
|
|
+
|
|
|
+ // The minimum number of chunks needed to transmit this message
|
|
|
+ uint32_t min_num_chunks =
|
|
|
+ (tot_enc_chunk_size + (FRAME_SIZE-1)) / FRAME_SIZE;
|
|
|
+ // The number of plaintext bytes this message could contain
|
|
|
+ uint32_t plaintext_bytes = tot_enc_chunk_size -
|
|
|
+ SGX_AESGCM_MAC_SIZE * min_num_chunks;
|
|
|
+
|
|
|
+ assert ((plaintext_bytes % uint32_t(msg_size)) == 0);
|
|
|
+
|
|
|
+ uint32_t num_msgs = plaintext_bytes/uint32_t(msg_size);
|
|
|
+
|
|
|
+ pthread_mutex_lock(&msgbuf.mutex);
|
|
|
+ uint32_t start = msgbuf.reserved;
|
|
|
+ if (start + num_msgs > msgbuf.bufsize) {
|
|
|
+ pthread_mutex_unlock(&msgbuf.mutex);
|
|
|
+ printf("Max %u messages exceeded\n", msgbuf.bufsize);
|
|
|
+ return NULL;
|
|
|
+ }
|
|
|
+ msgbuf.reserved += num_msgs;
|
|
|
+ pthread_mutex_unlock(&msgbuf.mutex);
|
|
|
+
|
|
|
+ return msgbuf.buf + start * msg_size;
|
|
|
+}
|
|
|
+
|
|
|
+static void round2_received(NodeCommState &nodest,
|
|
|
+ uint8_t *data, uint32_t plaintext_len, uint32_t);
|
|
|
+
|
|
|
+// A round 1 message was received by a routing node from an ingestion
|
|
|
+// node; we put it into the round1 buffer for processing in round 2
|
|
|
+static void round1_received(NodeCommState &nodest,
|
|
|
+ uint8_t *data, uint32_t plaintext_len, uint32_t)
|
|
|
+{
|
|
|
+ uint16_t msg_size = g_teems_config.msg_size;
|
|
|
+ assert((plaintext_len % uint32_t(msg_size)) == 0);
|
|
|
+ uint32_t num_msgs = plaintext_len / uint32_t(msg_size);
|
|
|
+ uint8_t our_roles = g_teems_config.roles[g_teems_config.my_node_num];
|
|
|
+ uint8_t their_roles = g_teems_config.roles[nodest.node_num];
|
|
|
+
|
|
|
+ pthread_mutex_lock(&route_state.round1.mutex);
|
|
|
+ route_state.round1.inserted += num_msgs;
|
|
|
+ route_state.round1.nodes_received += 1;
|
|
|
+ nodenum_t nodes_received = route_state.round1.nodes_received;
|
|
|
+ bool completed_prev_round = route_state.round1.completed_prev_round;
|
|
|
+ pthread_mutex_unlock(&route_state.round1.mutex);
|
|
|
+
|
|
|
+ // What is the next message we expect from this node?
|
|
|
+ if ((our_roles & ROLE_STORAGE) && (their_roles & ROLE_ROUTING)) {
|
|
|
+ nodest.in_msg_get_buf = [&](NodeCommState &commst,
|
|
|
+ uint32_t tot_enc_chunk_size) {
|
|
|
+ return msgbuffer_get_buf(route_state.round2, commst,
|
|
|
+ tot_enc_chunk_size);
|
|
|
+ };
|
|
|
+ nodest.in_msg_received = round2_received;
|
|
|
+ }
|
|
|
+ // Otherwise, it's just the next round 1 message, so don't change
|
|
|
+ // the handlers.
|
|
|
+
|
|
|
+ if (nodes_received == g_teems_config.num_ingestion_nodes &&
|
|
|
+ completed_prev_round) {
|
|
|
+ route_state.step = ROUTE_ROUND_1;
|
|
|
+ void *cbpointer = route_state.cbpointer;
|
|
|
+ route_state.cbpointer = NULL;
|
|
|
+ ocall_routing_round_complete(cbpointer, 1);
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+// A round 2 message was received by a storage node from a routing node
|
|
|
+static void round2_received(NodeCommState &nodest,
|
|
|
+ uint8_t *data, uint32_t plaintext_len, uint32_t)
|
|
|
+{
|
|
|
+ uint16_t msg_size = g_teems_config.msg_size;
|
|
|
+ assert((plaintext_len % uint32_t(msg_size)) == 0);
|
|
|
+ uint32_t num_msgs = plaintext_len / uint32_t(msg_size);
|
|
|
+ uint8_t our_roles = g_teems_config.roles[g_teems_config.my_node_num];
|
|
|
+ uint8_t their_roles = g_teems_config.roles[nodest.node_num];
|
|
|
+
|
|
|
+ pthread_mutex_lock(&route_state.round2.mutex);
|
|
|
+ route_state.round2.inserted += num_msgs;
|
|
|
+ route_state.round2.nodes_received += 1;
|
|
|
+ nodenum_t nodes_received = route_state.round2.nodes_received;
|
|
|
+ bool completed_prev_round = route_state.round2.completed_prev_round;
|
|
|
+ pthread_mutex_unlock(&route_state.round2.mutex);
|
|
|
+
|
|
|
+ // What is the next message we expect from this node?
|
|
|
+ if ((our_roles & ROLE_ROUTING) && (their_roles & ROLE_INGESTION)) {
|
|
|
+ nodest.in_msg_get_buf = [&](NodeCommState &commst,
|
|
|
+ uint32_t tot_enc_chunk_size) {
|
|
|
+ return msgbuffer_get_buf(route_state.round1, commst,
|
|
|
+ tot_enc_chunk_size);
|
|
|
+ };
|
|
|
+ nodest.in_msg_received = round1_received;
|
|
|
+ }
|
|
|
+ // Otherwise, it's just the next round 2 message, so don't change
|
|
|
+ // the handlers.
|
|
|
+
|
|
|
+ if (nodes_received == g_teems_config.num_routing_nodes &&
|
|
|
+ completed_prev_round) {
|
|
|
+ route_state.step = ROUTE_ROUND_2;
|
|
|
+ void *cbpointer = route_state.cbpointer;
|
|
|
+ route_state.cbpointer = NULL;
|
|
|
+ ocall_routing_round_complete(cbpointer, 2);
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+// For a given other node, set the received message handler to the first
|
|
|
+// message we would expect from them, given their roles and our roles.
|
|
|
+void route_init_msg_handler(nodenum_t node_num)
|
|
|
+{
|
|
|
+ // Our roles and their roles
|
|
|
+ uint8_t our_roles = g_teems_config.roles[g_teems_config.my_node_num];
|
|
|
+ uint8_t their_roles = g_teems_config.roles[node_num];
|
|
|
+
|
|
|
+ // The node communication state
|
|
|
+ NodeCommState &nodest = g_commstates[node_num];
|
|
|
+
|
|
|
+ // If we are a routing node (possibly among other roles) and they
|
|
|
+ // are an ingestion node (possibly among other roles), a round 1
|
|
|
+ // routing message is the first thing we expect from them. We put
|
|
|
+ // these messages into the round1 buffer for processing.
|
|
|
+ if ((our_roles & ROLE_ROUTING) && (their_roles & ROLE_INGESTION)) {
|
|
|
+ nodest.in_msg_get_buf = [&](NodeCommState &commst,
|
|
|
+ uint32_t tot_enc_chunk_size) {
|
|
|
+ return msgbuffer_get_buf(route_state.round1, commst,
|
|
|
+ tot_enc_chunk_size);
|
|
|
+ };
|
|
|
+ nodest.in_msg_received = round1_received;
|
|
|
+ }
|
|
|
+ // Otherwise, if we are a storage node (possibly among other roles)
|
|
|
+ // and they are a routing node (possibly among other roles), a round
|
|
|
+ // 2 routing message is the first thing we expect from them
|
|
|
+ else if ((our_roles & ROLE_STORAGE) && (their_roles & ROLE_ROUTING)) {
|
|
|
+ nodest.in_msg_get_buf = [&](NodeCommState &commst,
|
|
|
+ uint32_t tot_enc_chunk_size) {
|
|
|
+ return msgbuffer_get_buf(route_state.round2, commst,
|
|
|
+ tot_enc_chunk_size);
|
|
|
+ };
|
|
|
+ nodest.in_msg_received = round2_received;
|
|
|
+ }
|
|
|
+ // Otherwise, we don't expect a message from this node. Set the
|
|
|
+ // unknown message handler.
|
|
|
+ else {
|
|
|
+ nodest.in_msg_get_buf = default_in_msg_get_buf;
|
|
|
+ nodest.in_msg_received = unknown_in_msg_received;
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+// Directly ingest a buffer of num_msgs messages into the ingbuf buffer.
|
|
|
// Return true on success, false on failure.
|
|
|
bool ecall_ingest_raw(uint8_t *msgs, uint32_t num_msgs)
|
|
|
{
|
|
|
uint16_t msg_size = g_teems_config.msg_size;
|
|
|
- MsgBuffer &round1 = route_state.round1;
|
|
|
+ MsgBuffer &ingbuf = route_state.ingbuf;
|
|
|
|
|
|
- pthread_mutex_lock(&round1.mutex);
|
|
|
- uint32_t start = round1.reserved;
|
|
|
+ pthread_mutex_lock(&ingbuf.mutex);
|
|
|
+ uint32_t start = ingbuf.reserved;
|
|
|
if (start + num_msgs > route_state.tot_msg_per_ing) {
|
|
|
- pthread_mutex_unlock(&round1.mutex);
|
|
|
+ pthread_mutex_unlock(&ingbuf.mutex);
|
|
|
printf("Max %u messages exceeded\n",
|
|
|
route_state.tot_msg_per_ing);
|
|
|
return false;
|
|
|
}
|
|
|
- round1.reserved += num_msgs;
|
|
|
- pthread_mutex_unlock(&round1.mutex);
|
|
|
+ ingbuf.reserved += num_msgs;
|
|
|
+ pthread_mutex_unlock(&ingbuf.mutex);
|
|
|
|
|
|
- memmove(round1.buf + start * msg_size,
|
|
|
+ memmove(ingbuf.buf + start * msg_size,
|
|
|
msgs, num_msgs * msg_size);
|
|
|
|
|
|
- pthread_mutex_lock(&round1.mutex);
|
|
|
- round1.inserted += num_msgs;
|
|
|
- pthread_mutex_unlock(&round1.mutex);
|
|
|
+ pthread_mutex_lock(&ingbuf.mutex);
|
|
|
+ ingbuf.inserted += num_msgs;
|
|
|
+ pthread_mutex_unlock(&ingbuf.mutex);
|
|
|
|
|
|
return true;
|
|
|
}
|
|
|
|
|
|
-// Send the round 1 messages
|
|
|
+// Send the round 1 messages. Note that N here is not private.
|
|
|
static void send_round1_msgs(const uint8_t *msgs, const uint64_t *indices,
|
|
|
uint32_t N)
|
|
|
{
|
|
|
uint16_t msg_size = g_teems_config.msg_size;
|
|
|
uint16_t tot_weight = g_teems_config.tot_weight;
|
|
|
+ nodenum_t my_node_num = g_teems_config.my_node_num;
|
|
|
|
|
|
- /*
|
|
|
- for (uint32_t i=0;i<N;++i) {
|
|
|
- const uint8_t *msg = msgs + indices[i]*msg_size;
|
|
|
- for (uint16_t j=0;j<msg_size/4;++j) {
|
|
|
- printf("%08x ", ((const uint32_t*)msg)[j]);
|
|
|
+ uint32_t full_rows = N / uint32_t(tot_weight);
|
|
|
+ uint32_t last_row = N % uint32_t(tot_weight);
|
|
|
+
|
|
|
+ for (auto &routing_node: g_teems_config.routing_nodes) {
|
|
|
+ uint8_t weight =
|
|
|
+ g_teems_config.weights[routing_node].weight;
|
|
|
+ if (weight == 0) {
|
|
|
+ // This shouldn't happen, but just in case
|
|
|
+ continue;
|
|
|
+ }
|
|
|
+ uint16_t start_weight =
|
|
|
+ g_teems_config.weights[routing_node].startweight;
|
|
|
+
|
|
|
+ // The number of messages headed for this routing node from the
|
|
|
+ // full rows
|
|
|
+ uint32_t num_msgs_full_rows = full_rows * uint32_t(weight);
|
|
|
+ // The number of messages headed for this routing node from the
|
|
|
+ // incomplete last row is:
|
|
|
+ // 0 if last_row < start_weight
|
|
|
+ // last_row-start_weight if start_weight <= last_row < start_weight + weight
|
|
|
+ // weight if start_weight + weight <= last_row
|
|
|
+ uint32_t num_msgs_last_row = 0;
|
|
|
+ if (start_weight <= last_row && last_row < start_weight + weight) {
|
|
|
+ num_msgs_last_row = last_row-start_weight;
|
|
|
+ } else if (start_weight + weight <= last_row) {
|
|
|
+ num_msgs_last_row = weight;
|
|
|
+ }
|
|
|
+ // The total number of messages headed for this routing node
|
|
|
+ uint32_t num_msgs = num_msgs_full_rows + num_msgs_last_row;
|
|
|
+
|
|
|
+ if (routing_node == my_node_num) {
|
|
|
+ // Special case: we're sending to ourselves; just put the
|
|
|
+ // messages in our own round1 buffer
|
|
|
+ MsgBuffer &round1 = route_state.round1;
|
|
|
+
|
|
|
+ pthread_mutex_lock(&round1.mutex);
|
|
|
+ uint32_t start = round1.reserved;
|
|
|
+ if (start + num_msgs > round1.bufsize) {
|
|
|
+ pthread_mutex_unlock(&round1.mutex);
|
|
|
+ printf("Max %u messages exceeded\n", round1.bufsize);
|
|
|
+ return;
|
|
|
+ }
|
|
|
+ round1.reserved += num_msgs;
|
|
|
+ pthread_mutex_unlock(&round1.mutex);
|
|
|
+ uint8_t *buf = round1.buf + start * msg_size;
|
|
|
+
|
|
|
+ for (uint32_t i=0; i<full_rows; ++i) {
|
|
|
+ const uint64_t *idxp = indices + i*tot_weight + start_weight;
|
|
|
+ for (uint32_t j=0; j<weight; ++j) {
|
|
|
+ memmove(buf, msgs + idxp[j]*msg_size, msg_size);
|
|
|
+ buf += msg_size;
|
|
|
+ }
|
|
|
+ }
|
|
|
+ const uint64_t *idxp = indices + full_rows*tot_weight + start_weight;
|
|
|
+ for (uint32_t j=0; j<num_msgs_last_row; ++j) {
|
|
|
+ memmove(buf, msgs + idxp[j]*msg_size, msg_size);
|
|
|
+ buf += msg_size;
|
|
|
+ }
|
|
|
+
|
|
|
+ pthread_mutex_lock(&round1.mutex);
|
|
|
+ round1.inserted += num_msgs;
|
|
|
+ round1.nodes_received += 1;
|
|
|
+ pthread_mutex_unlock(&round1.mutex);
|
|
|
+
|
|
|
+ } else {
|
|
|
+ NodeCommState &nodecom = g_commstates[routing_node];
|
|
|
+ nodecom.message_start(num_msgs * msg_size);
|
|
|
+ for (uint32_t i=0; i<full_rows; ++i) {
|
|
|
+ const uint64_t *idxp = indices + i*tot_weight + start_weight;
|
|
|
+ for (uint32_t j=0; j<weight; ++j) {
|
|
|
+ nodecom.message_data(msgs + idxp[j]*msg_size, msg_size);
|
|
|
+ }
|
|
|
+ }
|
|
|
+ const uint64_t *idxp = indices + full_rows*tot_weight + start_weight;
|
|
|
+ for (uint32_t j=0; j<num_msgs_last_row; ++j) {
|
|
|
+ nodecom.message_data(msgs + idxp[j]*msg_size, msg_size);
|
|
|
+ }
|
|
|
}
|
|
|
- printf("\n");
|
|
|
}
|
|
|
- */
|
|
|
+}
|
|
|
+
|
|
|
+// Send the round 2 messages from the round 1 buffer, which are already
|
|
|
+// padded and shuffled, so this can be done non-obliviously. tot_msgs
|
|
|
+// is the total number of messages in the input buffer, which may
|
|
|
+// include padding messages added by the shuffle. Those messages are
|
|
|
+// not sent anywhere. There are num_msgs_per_stg messages for each
|
|
|
+// storage node labelled for that node.
|
|
|
+static void send_round2_msgs(uint32_t tot_msgs, uint32_t num_msgs_per_stg)
|
|
|
+{
|
|
|
+ uint16_t msg_size = g_teems_config.msg_size;
|
|
|
+ MsgBuffer &round1 = route_state.round1;
|
|
|
+ const uint8_t* buf = round1.buf;
|
|
|
+ nodenum_t num_storage_nodes = g_teems_config.num_storage_nodes;
|
|
|
+ nodenum_t my_node_num = g_teems_config.my_node_num;
|
|
|
+ uint8_t *myself_buf = NULL;
|
|
|
+
|
|
|
+ for (nodenum_t i=0; i<num_storage_nodes; ++i) {
|
|
|
+ nodenum_t node = g_teems_config.storage_nodes[i];
|
|
|
+ if (node != my_node_num) {
|
|
|
+ g_commstates[node].message_start(msg_size * num_msgs_per_stg);
|
|
|
+ } else {
|
|
|
+ MsgBuffer &round2 = route_state.round2;
|
|
|
+ pthread_mutex_lock(&round2.mutex);
|
|
|
+ uint32_t start = round2.reserved;
|
|
|
+ if (start + num_msgs_per_stg > round2.bufsize) {
|
|
|
+ pthread_mutex_unlock(&round2.mutex);
|
|
|
+ printf("Max %u messages exceeded\n", round2.bufsize);
|
|
|
+ return;
|
|
|
+ }
|
|
|
+ round2.reserved += num_msgs_per_stg;
|
|
|
+ pthread_mutex_unlock(&round2.mutex);
|
|
|
+ myself_buf = round2.buf + start * msg_size;
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ while (tot_msgs) {
|
|
|
+ nodenum_t storage_node_id =
|
|
|
+ nodenum_t((*(const uint32_t *)buf)>>DEST_UID_BITS);
|
|
|
+ if (storage_node_id < num_storage_nodes) {
|
|
|
+ nodenum_t node = g_teems_config.storage_map[storage_node_id];
|
|
|
+ if (node == my_node_num) {
|
|
|
+ memmove(myself_buf, buf, msg_size);
|
|
|
+ myself_buf += msg_size;
|
|
|
+ } else {
|
|
|
+ g_commstates[node].message_data(buf, msg_size);
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ buf += msg_size;
|
|
|
+ --tot_msgs;
|
|
|
+ }
|
|
|
+
|
|
|
+ if (myself_buf) {
|
|
|
+ MsgBuffer &round2 = route_state.round2;
|
|
|
+ pthread_mutex_lock(&round2.mutex);
|
|
|
+ round2.inserted += num_msgs_per_stg;
|
|
|
+ round2.nodes_received += 1;
|
|
|
+ pthread_mutex_unlock(&round2.mutex);
|
|
|
+ }
|
|
|
}
|
|
|
|
|
|
// Perform the next round of routing. The callback pointer will be
|
|
|
// passed to ocall_routing_round_complete when the round is complete.
|
|
|
void ecall_routing_proceed(void *cbpointer)
|
|
|
{
|
|
|
- if (route_state.step == ROUTE_NOT_STARTED) {
|
|
|
+ uint8_t my_roles = g_teems_config.roles[g_teems_config.my_node_num];
|
|
|
|
|
|
- MsgBuffer &round1 = route_state.round1;
|
|
|
+ if (route_state.step == ROUTE_NOT_STARTED) {
|
|
|
+ if (my_roles & ROLE_INGESTION) {
|
|
|
+ route_state.cbpointer = cbpointer;
|
|
|
+ MsgBuffer &ingbuf = route_state.ingbuf;
|
|
|
+ MsgBuffer &round1 = route_state.round1;
|
|
|
+
|
|
|
+ pthread_mutex_lock(&ingbuf.mutex);
|
|
|
+ // Ensure there are no pending messages currently being inserted
|
|
|
+ // into the buffer
|
|
|
+ while (ingbuf.reserved != ingbuf.inserted) {
|
|
|
+ pthread_mutex_unlock(&ingbuf.mutex);
|
|
|
+ pthread_mutex_lock(&ingbuf.mutex);
|
|
|
+ }
|
|
|
+ // Sort the messages we've received
|
|
|
+#ifdef PROFILE_ROUTING
|
|
|
+ uint32_t inserted = ingbuf.inserted;
|
|
|
+ unsigned long start_round1 = printf_with_rtclock("begin round1 processing (%u)\n", inserted);
|
|
|
+ unsigned long start_sort = printf_with_rtclock("begin oblivious sort (%u,%u)\n", inserted, route_state.tot_msg_per_ing);
|
|
|
+#endif
|
|
|
+ sort_mtobliv(g_teems_config.nthreads, ingbuf.buf,
|
|
|
+ g_teems_config.msg_size, ingbuf.inserted,
|
|
|
+ route_state.tot_msg_per_ing, send_round1_msgs);
|
|
|
+#ifdef PROFILE_ROUTING
|
|
|
+ printf_with_rtclock_diff(start_sort, "end oblivious sort (%u,%u)\n", inserted, route_state.tot_msg_per_ing);
|
|
|
+ printf_with_rtclock_diff(start_round1, "end round1 processing (%u)\n", inserted);
|
|
|
+#endif
|
|
|
+ ingbuf.reset();
|
|
|
+ pthread_mutex_unlock(&ingbuf.mutex);
|
|
|
|
|
|
- pthread_mutex_lock(&round1.mutex);
|
|
|
- // Ensure there are no pending messages currently being inserted
|
|
|
- // into the buffer
|
|
|
- while (round1.reserved != round1.inserted) {
|
|
|
+ pthread_mutex_lock(&round1.mutex);
|
|
|
+ round1.completed_prev_round = true;
|
|
|
+ nodenum_t nodes_received = round1.nodes_received;
|
|
|
pthread_mutex_unlock(&round1.mutex);
|
|
|
+
|
|
|
+ if (nodes_received == g_teems_config.num_ingestion_nodes) {
|
|
|
+ route_state.step = ROUTE_ROUND_1;
|
|
|
+ route_state.cbpointer = NULL;
|
|
|
+ ocall_routing_round_complete(cbpointer, 1);
|
|
|
+ }
|
|
|
+ } else {
|
|
|
+ route_state.step = ROUTE_ROUND_1;
|
|
|
+ ocall_routing_round_complete(cbpointer, 1);
|
|
|
+ }
|
|
|
+ } else if (route_state.step == ROUTE_ROUND_1) {
|
|
|
+ if (my_roles & ROLE_ROUTING) {
|
|
|
+ route_state.cbpointer = cbpointer;
|
|
|
+ MsgBuffer &round1 = route_state.round1;
|
|
|
+ MsgBuffer &round2 = route_state.round2;
|
|
|
+
|
|
|
pthread_mutex_lock(&round1.mutex);
|
|
|
+ // Ensure there are no pending messages currently being inserted
|
|
|
+ // into the buffer
|
|
|
+ while (round1.reserved != round1.inserted) {
|
|
|
+ pthread_mutex_unlock(&round1.mutex);
|
|
|
+ pthread_mutex_lock(&round1.mutex);
|
|
|
+ }
|
|
|
+
|
|
|
+ // If the _total_ number of messages we received in round 1
|
|
|
+ // is less than the max number of messages we could send to
|
|
|
+ // _each_ storage node, then cap the number of messages we
|
|
|
+ // will send to each storage node to that number.
|
|
|
+ uint32_t msgs_per_stg = route_state.max_msg_to_each_stg;
|
|
|
+ if (round1.inserted < msgs_per_stg) {
|
|
|
+ msgs_per_stg = round1.inserted;
|
|
|
+ }
|
|
|
+
|
|
|
+ // Note: at this point, it is required that each message in
|
|
|
+ // the round1 buffer have a _valid_ storage node id field.
|
|
|
+
|
|
|
+ // Obliviously tally the number of messages we received in
|
|
|
+ // round1 destined for each storage node
|
|
|
+#ifdef PROFILE_ROUTING
|
|
|
+ uint32_t inserted = round1.inserted;
|
|
|
+ unsigned long start_round2 = printf_with_rtclock("begin round2 processing (%u,%u)\n", inserted, round1.bufsize);
|
|
|
+ unsigned long start_tally = printf_with_rtclock("begin tally (%u)\n", inserted);
|
|
|
+#endif
|
|
|
+ uint16_t msg_size = g_teems_config.msg_size;
|
|
|
+ nodenum_t num_storage_nodes = g_teems_config.num_storage_nodes;
|
|
|
+ std::vector<uint32_t> tally = obliv_tally_stg(
|
|
|
+ round1.buf, msg_size, round1.inserted, num_storage_nodes);
|
|
|
+#ifdef PROFILE_ROUTING
|
|
|
+ printf_with_rtclock_diff(start_tally, "end tally (%u)\n", inserted);
|
|
|
+#endif
|
|
|
+
|
|
|
+ // Note: tally contains private values! It's OK to
|
|
|
+ // non-obliviously check for an error condition, though.
|
|
|
+ // While we're at it, obliviously change the tally of
|
|
|
+ // messages received to a tally of padding messages
|
|
|
+ // required.
|
|
|
+ uint32_t tot_padding = 0;
|
|
|
+ for (nodenum_t i=0; i<num_storage_nodes; ++i) {
|
|
|
+ if (tally[i] > msgs_per_stg) {
|
|
|
+ printf("Received too many messages for storage node %u\n", i);
|
|
|
+ assert(tally[i] <= msgs_per_stg);
|
|
|
+ }
|
|
|
+ tally[i] = msgs_per_stg - tally[i];
|
|
|
+ tot_padding += tally[i];
|
|
|
+ }
|
|
|
+
|
|
|
+ round1.reserved += tot_padding;
|
|
|
+ assert(round1.reserved <= round1.bufsize);
|
|
|
+
|
|
|
+ // Obliviously add padding for each storage node according
|
|
|
+ // to the (private) padding tally.
|
|
|
+#ifdef PROFILE_ROUTING
|
|
|
+ unsigned long start_pad = printf_with_rtclock("begin pad (%u)\n", tot_padding);
|
|
|
+#endif
|
|
|
+ obliv_pad_stg(round1.buf + round1.inserted * msg_size,
|
|
|
+ msg_size, tally, tot_padding);
|
|
|
+#ifdef PROFILE_ROUTING
|
|
|
+ printf_with_rtclock_diff(start_pad, "end pad (%u)\n", tot_padding);
|
|
|
+#endif
|
|
|
+
|
|
|
+ round1.inserted += tot_padding;
|
|
|
+
|
|
|
+ // Obliviously shuffle the messages
|
|
|
+#ifdef PROFILE_ROUTING
|
|
|
+ unsigned long start_shuffle = printf_with_rtclock("begin shuffle (%u,%u)\n", round1.inserted, round1.bufsize);
|
|
|
+#endif
|
|
|
+ uint32_t num_shuffled = shuffle_mtobliv(g_teems_config.nthreads,
|
|
|
+ round1.buf, msg_size, round1.inserted, round1.bufsize);
|
|
|
+#ifdef PROFILE_ROUTING
|
|
|
+ printf_with_rtclock_diff(start_shuffle, "end shuffle (%u,%u)\n", round1.inserted, round1.bufsize);
|
|
|
+ printf_with_rtclock_diff(start_round2, "end round2 processing (%u,%u)\n", inserted, round1.bufsize);
|
|
|
+#endif
|
|
|
+
|
|
|
+ // Now we can handle the messages non-obliviously, since we
|
|
|
+ // know there will be exactly msgs_per_stg messages to each
|
|
|
+ // storage node, and the oblivious shuffle broke the
|
|
|
+ // connection between where each message came from and where
|
|
|
+ // it's going.
|
|
|
+ send_round2_msgs(num_shuffled, msgs_per_stg);
|
|
|
+
|
|
|
+ round1.reset();
|
|
|
+ pthread_mutex_unlock(&round1.mutex);
|
|
|
+
|
|
|
+ pthread_mutex_lock(&round2.mutex);
|
|
|
+ round2.completed_prev_round = true;
|
|
|
+ nodenum_t nodes_received = round2.nodes_received;
|
|
|
+ pthread_mutex_unlock(&round2.mutex);
|
|
|
+
|
|
|
+ if (nodes_received == g_teems_config.num_routing_nodes) {
|
|
|
+ route_state.step = ROUTE_ROUND_2;
|
|
|
+ route_state.cbpointer = NULL;
|
|
|
+ ocall_routing_round_complete(cbpointer, 2);
|
|
|
+ }
|
|
|
+ } else {
|
|
|
+ route_state.step = ROUTE_ROUND_2;
|
|
|
+ ocall_routing_round_complete(cbpointer, 2);
|
|
|
}
|
|
|
- // Sort the messages we've received
|
|
|
+ } else if (route_state.step == ROUTE_ROUND_2) {
|
|
|
+ if (my_roles & ROLE_STORAGE) {
|
|
|
+ MsgBuffer &round2 = route_state.round2;
|
|
|
+
|
|
|
+ pthread_mutex_lock(&round2.mutex);
|
|
|
+ // Ensure there are no pending messages currently being inserted
|
|
|
+ // into the buffer
|
|
|
+ while (round2.reserved != round2.inserted) {
|
|
|
+ pthread_mutex_unlock(&round2.mutex);
|
|
|
+ pthread_mutex_lock(&round2.mutex);
|
|
|
+ }
|
|
|
+
|
|
|
#ifdef PROFILE_ROUTING
|
|
|
- uint32_t inserted = round1.inserted;
|
|
|
- unsigned long start = printf_with_rtclock("begin oblivious sort (%u,%u)\n", inserted, route_state.tot_msg_per_ing);
|
|
|
+ unsigned long start = printf_with_rtclock("begin storage processing (%u)\n", round2.inserted);
|
|
|
#endif
|
|
|
- sort_mtobliv(g_teems_config.nthreads, round1.buf,
|
|
|
- g_teems_config.msg_size, round1.inserted,
|
|
|
- route_state.tot_msg_per_ing, send_round1_msgs);
|
|
|
- round1.reset();
|
|
|
+ storage_received(round2.buf, round2.inserted);
|
|
|
#ifdef PROFILE_ROUTING
|
|
|
- printf_with_rtclock_diff(start, "end oblivious sort (%u,%u)\n", inserted, route_state.tot_msg_per_ing);
|
|
|
+ printf_with_rtclock_diff(start, "end storage processing (%u)\n", round2.inserted);
|
|
|
#endif
|
|
|
- route_state.step = ROUTE_ROUND_1;
|
|
|
- ocall_routing_round_complete(cbpointer, 1);
|
|
|
+
|
|
|
+ round2.reset();
|
|
|
+ pthread_mutex_unlock(&round2.mutex);
|
|
|
+
|
|
|
+ // We're done
|
|
|
+ route_state.step = ROUTE_NOT_STARTED;
|
|
|
+ ocall_routing_round_complete(cbpointer, 0);
|
|
|
+ } else {
|
|
|
+ // We're done
|
|
|
+ route_state.step = ROUTE_NOT_STARTED;
|
|
|
+ ocall_routing_round_complete(cbpointer, 0);
|
|
|
+ }
|
|
|
}
|
|
|
}
|