|
@@ -35,10 +35,11 @@ bool route_init()
|
|
|
}
|
|
|
|
|
|
// Compute the maximum number of messages we could receive in round 1
|
|
|
- // Each ingestion node will send us an our_weight/tot_weight
|
|
|
- // fraction of the messages they hold
|
|
|
- uint32_t max_msg_from_each_ing = CEILDIV(tot_msg_per_ing,
|
|
|
- g_teems_config.tot_weight) * g_teems_config.my_weight;
|
|
|
+ // In private routing, each ingestion node will send us an
|
|
|
+ // our_weight/tot_weight fraction of the messages they hold
|
|
|
+ uint32_t max_msg_from_each_ing;
|
|
|
+ max_msg_from_each_ing = CEILDIV(tot_msg_per_ing, g_teems_config.tot_weight) *
|
|
|
+ g_teems_config.my_weight;
|
|
|
|
|
|
// And the maximum number we can receive in total is that times the
|
|
|
// number of ingestion nodes
|
|
@@ -60,8 +61,9 @@ bool route_init()
|
|
|
}
|
|
|
|
|
|
// Which will be at most this many from us
|
|
|
- uint32_t max_msg_to_each_stg = CEILDIV(tot_msg_per_stg,
|
|
|
- g_teems_config.tot_weight) * g_teems_config.my_weight;
|
|
|
+ uint32_t max_msg_to_each_stg;
|
|
|
+ max_msg_to_each_stg = CEILDIV(tot_msg_per_stg, g_teems_config.tot_weight) *
|
|
|
+ g_teems_config.my_weight;
|
|
|
|
|
|
// But we can't send more messages to each storage server than we
|
|
|
// could receive in total
|
|
@@ -81,7 +83,16 @@ bool route_init()
|
|
|
}
|
|
|
|
|
|
// The max number of messages that can arrive at a storage server
|
|
|
- uint32_t max_stg_msgs = tot_msg_per_stg + g_teems_config.tot_weight;
|
|
|
+ uint32_t max_stg_msgs;
|
|
|
+ max_stg_msgs = tot_msg_per_stg + g_teems_config.tot_weight;
|
|
|
+
|
|
|
+ // Calculating public-routing buffer sizes
|
|
|
+ // Weights are not used in public routing
|
|
|
+ uint32_t max_round1a_msgs = max_round1_msgs;
|
|
|
+ uint32_t max_round1b_msgs_to_adj_rtr =
|
|
|
+ (g_teems_config.num_routing_nodes-1)*(g_teems_config.num_routing_nodes-1);
|
|
|
+ uint32_t max_round1b_msgs = 2*max_round1b_msgs_to_adj_rtr;
|
|
|
+ uint32_t max_round1c_msgs = max_round1a_msgs;
|
|
|
|
|
|
/*
|
|
|
printf("users_per_ing=%u, tot_msg_per_ing=%u, max_msg_from_each_ing=%u, max_round1_msgs=%u, users_per_stg=%u, tot_msg_per_stg=%u, max_msg_to_each_stg=%u, max_round2_msgs=%u, max_stg_msgs=%u\n", users_per_ing, tot_msg_per_ing, max_msg_from_each_ing, max_round1_msgs, users_per_stg, tot_msg_per_stg, max_msg_to_each_stg, max_round2_msgs, max_stg_msgs);
|
|
@@ -95,6 +106,9 @@ bool route_init()
|
|
|
}
|
|
|
if (my_roles & ROLE_ROUTING) {
|
|
|
route_state.round1.alloc(max_round2_msgs);
|
|
|
+ route_state.round1a.alloc(max_round1a_msgs);
|
|
|
+ route_state.round1b.alloc(2*max_round1b_msgs); // double space for sorting with 1a msgs
|
|
|
+ route_state.round1c.alloc(max_round1c_msgs);
|
|
|
}
|
|
|
if (my_roles & ROLE_STORAGE) {
|
|
|
route_state.round2.alloc(max_stg_msgs);
|
|
@@ -108,6 +122,10 @@ bool route_init()
|
|
|
}
|
|
|
route_state.step = ROUTE_NOT_STARTED;
|
|
|
route_state.tot_msg_per_ing = tot_msg_per_ing;
|
|
|
+ route_state.max_round1_msgs = max_round1_msgs;
|
|
|
+ route_state.max_round1a_msgs = max_round1a_msgs;
|
|
|
+ route_state.max_round1b_msgs_to_adj_rtr = max_round1b_msgs_to_adj_rtr;
|
|
|
+ route_state.max_round1c_msgs = max_round1c_msgs;
|
|
|
route_state.max_msg_to_each_stg = max_msg_to_each_stg;
|
|
|
route_state.max_round2_msgs = max_round2_msgs;
|
|
|
route_state.max_stg_msgs = max_stg_msgs;
|
|
@@ -123,14 +141,12 @@ bool route_init()
|
|
|
if (my_roles & ROLE_ROUTING) {
|
|
|
sort_precompute_evalplan(max_round2_msgs, nthreads);
|
|
|
if(!g_teems_config.private_routing) {
|
|
|
- sort_precompute_evalplan(max_round2_msgs, nthreads);
|
|
|
+ sort_precompute_evalplan(max_round1a_msgs, nthreads);
|
|
|
+ sort_precompute_evalplan(max_round1b_msgs, nthreads);
|
|
|
}
|
|
|
}
|
|
|
if (my_roles & ROLE_STORAGE) {
|
|
|
sort_precompute_evalplan(max_stg_msgs, nthreads);
|
|
|
- if(!g_teems_config.private_routing) {
|
|
|
- sort_precompute_evalplan(max_stg_msgs, nthreads);
|
|
|
- }
|
|
|
}
|
|
|
#ifdef PROFILE_ROUTING
|
|
|
printf_with_rtclock_diff(start, "end precompute evalplans\n");
|
|
@@ -181,7 +197,10 @@ size_t ecall_precompute_sort(int sizeidx)
|
|
|
if (my_roles & ROLE_ROUTING) {
|
|
|
used_sizes.push_back(route_state.max_round2_msgs);
|
|
|
if(!g_teems_config.private_routing) {
|
|
|
- used_sizes.push_back(route_state.max_round2_msgs);
|
|
|
+ used_sizes.push_back(route_state.max_round1a_msgs);
|
|
|
+ used_sizes.push_back(route_state.max_round1b_msgs);
|
|
|
+ used_sizes.push_back(route_state.max_round1b_msgs);
|
|
|
+ used_sizes.push_back(route_state.max_round1c_msgs);
|
|
|
}
|
|
|
}
|
|
|
if (my_roles & ROLE_STORAGE) {
|
|
@@ -227,6 +246,15 @@ static uint8_t* msgbuffer_get_buf(MsgBuffer &msgbuf,
|
|
|
return msgbuf.buf + start * msg_size;
|
|
|
}
|
|
|
|
|
|
+static void round1a_received(NodeCommState &nodest,
|
|
|
+ uint8_t *data, uint32_t plaintext_len, uint32_t);
|
|
|
+
|
|
|
+static void round1b_received(NodeCommState &nodest,
|
|
|
+ uint8_t *data, uint32_t plaintext_len, uint32_t);
|
|
|
+
|
|
|
+static void round1c_received(NodeCommState &nodest, uint8_t *data,
|
|
|
+ uint32_t plaintext_len, uint32_t);
|
|
|
+
|
|
|
static void round2_received(NodeCommState &nodest,
|
|
|
uint8_t *data, uint32_t plaintext_len, uint32_t);
|
|
|
|
|
@@ -249,23 +277,169 @@ static void round1_received(NodeCommState &nodest,
|
|
|
pthread_mutex_unlock(&route_state.round1.mutex);
|
|
|
|
|
|
// What is the next message we expect from this node?
|
|
|
- if ((our_roles & ROLE_STORAGE) && (their_roles & ROLE_ROUTING)) {
|
|
|
+ if (g_teems_config.private_routing) {
|
|
|
+ if ((our_roles & ROLE_STORAGE) && (their_roles & ROLE_ROUTING)) {
|
|
|
+ nodest.in_msg_get_buf = [&](NodeCommState &commst,
|
|
|
+ uint32_t tot_enc_chunk_size) {
|
|
|
+ return msgbuffer_get_buf(route_state.round2, commst,
|
|
|
+ tot_enc_chunk_size);
|
|
|
+ };
|
|
|
+ nodest.in_msg_received = round2_received;
|
|
|
+ }
|
|
|
+ // Otherwise, it's just the next round 1 message, so don't change
|
|
|
+ // the handlers.
|
|
|
+ } else {
|
|
|
+ if (their_roles & ROLE_ROUTING) {
|
|
|
+ nodest.in_msg_get_buf = [&](NodeCommState &commst,
|
|
|
+ uint32_t tot_enc_chunk_size) {
|
|
|
+ return msgbuffer_get_buf(route_state.round1a, commst,
|
|
|
+ tot_enc_chunk_size);
|
|
|
+ };
|
|
|
+ nodest.in_msg_received = round1a_received;
|
|
|
+ }
|
|
|
+ // Otherwise, it's just the next round 1 message, so don't change
|
|
|
+ // the handlers.
|
|
|
+ }
|
|
|
+
|
|
|
+ if (nodes_received == g_teems_config.num_ingestion_nodes &&
|
|
|
+ completed_prev_round) {
|
|
|
+ route_state.step = ROUTE_ROUND_1;
|
|
|
+ void *cbpointer = route_state.cbpointer;
|
|
|
+ route_state.cbpointer = NULL;
|
|
|
+ ocall_routing_round_complete(cbpointer, 1);
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+// A round 1a message was received by a routing node
|
|
|
+static void round1a_received(NodeCommState &nodest,
|
|
|
+ uint8_t *data, uint32_t plaintext_len, uint32_t)
|
|
|
+{
|
|
|
+ uint16_t msg_size = g_teems_config.msg_size;
|
|
|
+ assert((plaintext_len % uint32_t(msg_size)) == 0);
|
|
|
+ uint32_t num_msgs = plaintext_len / uint32_t(msg_size);
|
|
|
+ pthread_mutex_lock(&route_state.round1a.mutex);
|
|
|
+ route_state.round1a.inserted += num_msgs;
|
|
|
+ route_state.round1a.nodes_received += 1;
|
|
|
+ nodenum_t nodes_received = route_state.round1a.nodes_received;
|
|
|
+ bool completed_prev_round = route_state.round1a.completed_prev_round;
|
|
|
+ pthread_mutex_unlock(&route_state.round1a.mutex);
|
|
|
+
|
|
|
+ // Both are routing nodes
|
|
|
+ // We only expect a message from the previous and next nodes (if they exist)
|
|
|
+ //FIX: replace handlers for previous and next nodes, to put messages in correct locations
|
|
|
+ nodenum_t my_node_num = g_teems_config.my_node_num;
|
|
|
+ nodenum_t num_routing_nodes = g_teems_config.num_routing_nodes;
|
|
|
+ uint16_t prev_nodes = g_teems_config.weights[my_node_num].startweight;
|
|
|
+ if ((prev_nodes > 0) &&
|
|
|
+ (nodest.node_num == g_teems_config.routing_nodes[num_routing_nodes-1])) {
|
|
|
+ // Node is previous routing node
|
|
|
+ nodest.in_msg_get_buf = [&](NodeCommState &commst,
|
|
|
+ uint32_t tot_enc_chunk_size) {
|
|
|
+ return msgbuffer_get_buf(route_state.round1b, commst,
|
|
|
+ tot_enc_chunk_size);
|
|
|
+ };
|
|
|
+ nodest.in_msg_received = round1b_received;
|
|
|
+ } else if ((prev_nodes < num_routing_nodes-1) &&
|
|
|
+ (nodest.node_num == g_teems_config.routing_nodes[1])) {
|
|
|
+ // Node is next routing node
|
|
|
+ nodest.in_msg_get_buf = [&](NodeCommState &commst,
|
|
|
+ uint32_t tot_enc_chunk_size) {
|
|
|
+ return msgbuffer_get_buf(route_state.round1b, commst,
|
|
|
+ tot_enc_chunk_size);
|
|
|
+ };
|
|
|
+ nodest.in_msg_received = round1b_received;
|
|
|
+ } else {
|
|
|
+ // other routing nodes will not send to this node until round 1c
|
|
|
+ nodest.in_msg_get_buf = [&](NodeCommState &commst,
|
|
|
+ uint32_t tot_enc_chunk_size) {
|
|
|
+ return msgbuffer_get_buf(route_state.round1c, commst,
|
|
|
+ tot_enc_chunk_size);
|
|
|
+ };
|
|
|
+ nodest.in_msg_received = round1c_received;
|
|
|
+ }
|
|
|
+
|
|
|
+ if (nodes_received == g_teems_config.num_routing_nodes &&
|
|
|
+ completed_prev_round) {
|
|
|
+ route_state.step = ROUTE_ROUND_1A;
|
|
|
+ void *cbpointer = route_state.cbpointer;
|
|
|
+ route_state.cbpointer = NULL;
|
|
|
+ ocall_routing_round_complete(cbpointer, ROUND_1A);
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+// A round 1b message was received by a routing node
|
|
|
+static void round1b_received(NodeCommState &nodest,
|
|
|
+ uint8_t *data, uint32_t plaintext_len, uint32_t)
|
|
|
+{
|
|
|
+ uint16_t msg_size = g_teems_config.msg_size;
|
|
|
+ assert((plaintext_len % uint32_t(msg_size)) == 0);
|
|
|
+ uint32_t num_msgs = plaintext_len / uint32_t(msg_size);
|
|
|
+ uint8_t our_roles = g_teems_config.roles[g_teems_config.my_node_num];
|
|
|
+ uint8_t their_roles = g_teems_config.roles[nodest.node_num];
|
|
|
+ pthread_mutex_lock(&route_state.round1b.mutex);
|
|
|
+ route_state.round1b.inserted += num_msgs;
|
|
|
+ route_state.round1b.nodes_received += 1;
|
|
|
+ nodenum_t nodes_received = route_state.round1b.nodes_received;
|
|
|
+ bool completed_prev_round = route_state.round1b.completed_prev_round;
|
|
|
+ pthread_mutex_unlock(&route_state.round1b.mutex);
|
|
|
+ // Switch handlers so subsequent messages from this node go into the round 1c buffer
|
|
|
+ nodest.in_msg_get_buf = [&](NodeCommState &commst,
|
|
|
+ uint32_t tot_enc_chunk_size) {
|
|
|
+ return msgbuffer_get_buf(route_state.round1c, commst,
|
|
|
+ tot_enc_chunk_size);
|
|
|
+ };
|
|
|
+ nodest.in_msg_received = round1c_received;
|
|
|
+ nodenum_t my_node_num = g_teems_config.my_node_num;
|
|
|
+ uint16_t prev_nodes = g_teems_config.weights[my_node_num].startweight;
|
|
|
+ nodenum_t num_routing_nodes = g_teems_config.num_routing_nodes;
|
|
|
+ nodenum_t adjacent_nodes =
|
|
|
+ (((prev_nodes == 0) || (prev_nodes == num_routing_nodes-1)) ? 1 : 2);
|
|
|
+ if (nodes_received == adjacent_nodes && completed_prev_round) {
|
|
|
+ route_state.step = ROUTE_ROUND_1B;
|
|
|
+ void *cbpointer = route_state.cbpointer;
|
|
|
+ route_state.cbpointer = NULL;
|
|
|
+ ocall_routing_round_complete(cbpointer, ROUND_1B);
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+// Message received in round 1c
|
|
|
+static void round1c_received(NodeCommState &nodest, uint8_t *data,
|
|
|
+ uint32_t plaintext_len, uint32_t)
|
|
|
+{
|
|
|
+ uint16_t msg_size = g_teems_config.msg_size;
|
|
|
+ assert((plaintext_len % uint32_t(msg_size)) == 0);
|
|
|
+ uint32_t num_msgs = plaintext_len / uint32_t(msg_size);
|
|
|
+ uint8_t our_roles = g_teems_config.roles[g_teems_config.my_node_num];
|
|
|
+ uint8_t their_roles = g_teems_config.roles[nodest.node_num];
|
|
|
+ pthread_mutex_lock(&route_state.round1c.mutex);
|
|
|
+ route_state.round1c.inserted += num_msgs;
|
|
|
+ route_state.round1c.nodes_received += 1;
|
|
|
+ nodenum_t nodes_received = route_state.round1c.nodes_received;
|
|
|
+ bool completed_prev_round = route_state.round1c.completed_prev_round;
|
|
|
+ pthread_mutex_unlock(&route_state.round1c.mutex);
|
|
|
+
|
|
|
+ // What is the next message we expect from this node?
|
|
|
+ if (our_roles & ROLE_STORAGE) {
|
|
|
nodest.in_msg_get_buf = [&](NodeCommState &commst,
|
|
|
uint32_t tot_enc_chunk_size) {
|
|
|
return msgbuffer_get_buf(route_state.round2, commst,
|
|
|
tot_enc_chunk_size);
|
|
|
};
|
|
|
nodest.in_msg_received = round2_received;
|
|
|
+ } else if (their_roles & ROLE_INGESTION) {
|
|
|
+ nodest.in_msg_get_buf = [&](NodeCommState &commst,
|
|
|
+ uint32_t tot_enc_chunk_size) {
|
|
|
+ return msgbuffer_get_buf(route_state.round1, commst,
|
|
|
+ tot_enc_chunk_size);
|
|
|
+ };
|
|
|
+ nodest.in_msg_received = round1_received;
|
|
|
}
|
|
|
- // Otherwise, it's just the next round 1 message, so don't change
|
|
|
- // the handlers.
|
|
|
-
|
|
|
- if (nodes_received == g_teems_config.num_ingestion_nodes &&
|
|
|
+ if (nodes_received == g_teems_config.num_routing_nodes &&
|
|
|
completed_prev_round) {
|
|
|
- route_state.step = ROUTE_ROUND_1;
|
|
|
+ route_state.step = ROUTE_ROUND_1C;
|
|
|
void *cbpointer = route_state.cbpointer;
|
|
|
route_state.cbpointer = NULL;
|
|
|
- ocall_routing_round_complete(cbpointer, 1);
|
|
|
+ ocall_routing_round_complete(cbpointer, ROUND_1C);
|
|
|
}
|
|
|
}
|
|
|
|
|
@@ -378,19 +552,23 @@ bool ecall_ingest_raw(uint8_t *msgs, uint32_t num_msgs)
|
|
|
}
|
|
|
|
|
|
// Send the round 1 messages. Note that N here is not private.
|
|
|
-static void send_round1_msgs(const uint8_t *msgs, const UidKey *indices,
|
|
|
+template<typename T>
|
|
|
+static void send_round1_msgs(const uint8_t *msgs, const T *indices,
|
|
|
uint32_t N)
|
|
|
{
|
|
|
uint16_t msg_size = g_teems_config.msg_size;
|
|
|
- uint16_t tot_weight = g_teems_config.tot_weight;
|
|
|
+ uint16_t tot_weight;
|
|
|
+ tot_weight = g_teems_config.tot_weight;
|
|
|
nodenum_t my_node_num = g_teems_config.my_node_num;
|
|
|
|
|
|
- uint32_t full_rows = N / uint32_t(tot_weight);
|
|
|
- uint32_t last_row = N % uint32_t(tot_weight);
|
|
|
+ uint32_t full_rows;
|
|
|
+ uint32_t last_row;
|
|
|
+ full_rows = N / uint32_t(tot_weight);
|
|
|
+ last_row = N % uint32_t(tot_weight);
|
|
|
|
|
|
for (auto &routing_node: g_teems_config.routing_nodes) {
|
|
|
- uint8_t weight =
|
|
|
- g_teems_config.weights[routing_node].weight;
|
|
|
+ uint8_t weight = g_teems_config.weights[routing_node].weight;
|
|
|
+
|
|
|
if (weight == 0) {
|
|
|
// This shouldn't happen, but just in case
|
|
|
continue;
|
|
@@ -432,13 +610,13 @@ static void send_round1_msgs(const uint8_t *msgs, const UidKey *indices,
|
|
|
uint8_t *buf = round1.buf + start * msg_size;
|
|
|
|
|
|
for (uint32_t i=0; i<full_rows; ++i) {
|
|
|
- const UidKey *idxp = indices + i*tot_weight + start_weight;
|
|
|
+ const T *idxp = indices + i*tot_weight + start_weight;
|
|
|
for (uint32_t j=0; j<weight; ++j) {
|
|
|
memmove(buf, msgs + idxp[j].index()*msg_size, msg_size);
|
|
|
buf += msg_size;
|
|
|
}
|
|
|
}
|
|
|
- const UidKey *idxp = indices + full_rows*tot_weight + start_weight;
|
|
|
+ const T *idxp = indices + full_rows*tot_weight + start_weight;
|
|
|
for (uint32_t j=0; j<num_msgs_last_row; ++j) {
|
|
|
memmove(buf, msgs + idxp[j].index()*msg_size, msg_size);
|
|
|
buf += msg_size;
|
|
@@ -453,12 +631,209 @@ static void send_round1_msgs(const uint8_t *msgs, const UidKey *indices,
|
|
|
NodeCommState &nodecom = g_commstates[routing_node];
|
|
|
nodecom.message_start(num_msgs * msg_size);
|
|
|
for (uint32_t i=0; i<full_rows; ++i) {
|
|
|
- const UidKey *idxp = indices + i*tot_weight + start_weight;
|
|
|
+ const T *idxp = indices + i*tot_weight + start_weight;
|
|
|
for (uint32_t j=0; j<weight; ++j) {
|
|
|
nodecom.message_data(msgs + idxp[j].index()*msg_size, msg_size);
|
|
|
}
|
|
|
}
|
|
|
- const UidKey *idxp = indices + full_rows*tot_weight + start_weight;
|
|
|
+ const T *idxp = indices + full_rows*tot_weight + start_weight;
|
|
|
+ for (uint32_t j=0; j<num_msgs_last_row; ++j) {
|
|
|
+ nodecom.message_data(msgs + idxp[j].index()*msg_size, msg_size);
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+// Send the round 1a messages from the round 1 buffer, which only occurs in public-channel routing.
|
|
|
+// msgs points to the message buffer, indices points to the sorted indices, and N is the number
|
|
|
+// of non-padding items.
|
|
|
+static void send_round1a_msgs(const uint8_t *msgs, const UidPriorityKey *indices, uint32_t N) {
|
|
|
+ uint16_t msg_size = g_teems_config.msg_size;
|
|
|
+ nodenum_t my_node_num = g_teems_config.my_node_num;
|
|
|
+ nodenum_t num_routing_nodes = g_teems_config.num_routing_nodes;
|
|
|
+
|
|
|
+ uint32_t min_msgs_per_node = route_state.max_round1a_msgs / num_routing_nodes;
|
|
|
+ uint32_t extra_msgs = route_state.max_round1a_msgs % num_routing_nodes;
|
|
|
+ for (auto &routing_node: g_teems_config.routing_nodes) {
|
|
|
+ // In this unweighted setting, start_weight represents the position among routing nodes
|
|
|
+ uint16_t prev_nodes = g_teems_config.weights[routing_node].startweight;
|
|
|
+ uint32_t start_msg, num_msgs;
|
|
|
+ if (prev_nodes >= extra_msgs) {
|
|
|
+ start_msg = min_msgs_per_node * prev_nodes + extra_msgs;
|
|
|
+ num_msgs = min_msgs_per_node;
|
|
|
+ } else {
|
|
|
+ start_msg = min_msgs_per_node * prev_nodes + prev_nodes;
|
|
|
+ num_msgs = min_msgs_per_node + 1;
|
|
|
+ }
|
|
|
+ // take number of messages into account
|
|
|
+ if (start_msg >= N) {
|
|
|
+ num_msgs = 0;
|
|
|
+ } else if (start_msg + num_msgs > N) {
|
|
|
+ num_msgs = N - start_msg;
|
|
|
+ }
|
|
|
+
|
|
|
+ if (routing_node == my_node_num) {
|
|
|
+ // Special case: we're sending to ourselves; just put the
|
|
|
+ // messages in our own buffer
|
|
|
+ MsgBuffer &round1a = route_state.round1a;
|
|
|
+ pthread_mutex_lock(&round1a.mutex);
|
|
|
+ uint32_t start = round1a.reserved;
|
|
|
+ if (start + num_msgs > round1a.bufsize) {
|
|
|
+ pthread_mutex_unlock(&round1a.mutex);
|
|
|
+ printf("Max %u messages exceeded in round 1a\n", round1a.bufsize);
|
|
|
+ return;
|
|
|
+ }
|
|
|
+ round1a.reserved += num_msgs;
|
|
|
+ pthread_mutex_unlock(&round1a.mutex);
|
|
|
+ uint8_t *buf = round1a.buf + start * msg_size;
|
|
|
+ for (uint32_t i=0; i<num_msgs; ++i) {
|
|
|
+ const UidPriorityKey *idxp = indices + start_msg + i;
|
|
|
+ memmove(buf, msgs + idxp->index()*msg_size, msg_size);
|
|
|
+ buf += msg_size;
|
|
|
+ }
|
|
|
+ pthread_mutex_lock(&round1a.mutex);
|
|
|
+ round1a.inserted += num_msgs;
|
|
|
+ round1a.nodes_received += 1;
|
|
|
+ pthread_mutex_unlock(&round1a.mutex);
|
|
|
+ } else {
|
|
|
+ NodeCommState &nodecom = g_commstates[routing_node];
|
|
|
+ nodecom.message_start(num_msgs * msg_size);
|
|
|
+ for (uint32_t i=0; i<num_msgs; ++i) {
|
|
|
+ const UidPriorityKey *idxp = indices + start_msg + i;
|
|
|
+ nodecom.message_data(msgs + idxp->index()*msg_size, msg_size);
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+// Send the round 1b messages from the round 1a buffer, which only occurs in public-channel routing.
|
|
|
+// msgs points to the message buffer, indices points to the sorted indices, and N is the number
|
|
|
+// of non-padding items.
|
|
|
+static void send_round1b_msgs(const uint8_t *msgs, const UidPriorityKey *indices, uint32_t N) {
|
|
|
+ uint16_t msg_size = g_teems_config.msg_size;
|
|
|
+ nodenum_t my_node_num = g_teems_config.my_node_num;
|
|
|
+ nodenum_t num_routing_nodes = g_teems_config.num_routing_nodes;
|
|
|
+ uint16_t prev_nodes = g_teems_config.weights[my_node_num].startweight;
|
|
|
+
|
|
|
+ // send to previous node
|
|
|
+ if (prev_nodes > 0) {
|
|
|
+ nodenum_t prev_node = g_teems_config.routing_nodes[num_routing_nodes-1];
|
|
|
+ NodeCommState &nodecom = g_commstates[prev_node];
|
|
|
+ uint32_t num_msgs = min(route_state.max_round1a_msgs,
|
|
|
+ route_state.max_round1b_msgs_to_adj_rtr);
|
|
|
+ nodecom.message_start(num_msgs * msg_size);
|
|
|
+ for (uint32_t i=0; i<num_msgs; ++i) {
|
|
|
+ const UidPriorityKey *idxp = indices + i;
|
|
|
+ nodecom.message_data(msgs + idxp->index()*msg_size, msg_size);
|
|
|
+ }
|
|
|
+ }
|
|
|
+ // send to next node
|
|
|
+ if (prev_nodes < num_routing_nodes-1) {
|
|
|
+ nodenum_t next_node = g_teems_config.routing_nodes[1];
|
|
|
+ NodeCommState &nodecom = g_commstates[next_node];
|
|
|
+ if (N <= route_state.max_round1a_msgs - route_state.max_round1b_msgs_to_adj_rtr) {
|
|
|
+ // No messages to exchange with next node
|
|
|
+ nodecom.message_start(0);
|
|
|
+ // No need to call message_data()
|
|
|
+ } else {
|
|
|
+ uint32_t start_msg =
|
|
|
+ route_state.max_round1a_msgs - route_state.max_round1b_msgs_to_adj_rtr;
|
|
|
+ uint32_t num_msgs = N - start_msg;
|
|
|
+ nodecom.message_start(num_msgs * msg_size);
|
|
|
+ for (uint32_t i=0; i<num_msgs; ++i) {
|
|
|
+ const UidPriorityKey *idxp = indices + start_msg + i;
|
|
|
+ nodecom.message_data(msgs + idxp->index()*msg_size, msg_size);
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+// Send the round 1c messages. Note that N here is not private.
|
|
|
+// FIX: combine with send_round1_msgs(), which is similar
|
|
|
+static void send_round1c_msgs(const uint8_t *msgs, const UidPriorityKey *indices,
|
|
|
+ uint32_t N)
|
|
|
+{
|
|
|
+ uint16_t msg_size = g_teems_config.msg_size;
|
|
|
+ uint16_t tot_weight;
|
|
|
+ tot_weight = g_teems_config.tot_weight;
|
|
|
+ nodenum_t my_node_num = g_teems_config.my_node_num;
|
|
|
+
|
|
|
+ uint32_t full_rows;
|
|
|
+ uint32_t last_row;
|
|
|
+ full_rows = N / uint32_t(tot_weight);
|
|
|
+ last_row = N % uint32_t(tot_weight);
|
|
|
+
|
|
|
+ for (auto &routing_node: g_teems_config.routing_nodes) {
|
|
|
+ uint8_t weight = g_teems_config.weights[routing_node].weight;
|
|
|
+
|
|
|
+ if (weight == 0) {
|
|
|
+ // This shouldn't happen, but just in case
|
|
|
+ continue;
|
|
|
+ }
|
|
|
+ uint16_t start_weight =
|
|
|
+ g_teems_config.weights[routing_node].startweight;
|
|
|
+
|
|
|
+ // The number of messages headed for this routing node from the
|
|
|
+ // full rows
|
|
|
+ uint32_t num_msgs_full_rows = full_rows * uint32_t(weight);
|
|
|
+ // The number of messages headed for this routing node from the
|
|
|
+ // incomplete last row is:
|
|
|
+ // 0 if last_row < start_weight
|
|
|
+ // last_row-start_weight if start_weight <= last_row < start_weight + weight
|
|
|
+ // weight if start_weight + weight <= last_row
|
|
|
+ uint32_t num_msgs_last_row = 0;
|
|
|
+ if (start_weight <= last_row && last_row < start_weight + weight) {
|
|
|
+ num_msgs_last_row = last_row-start_weight;
|
|
|
+ } else if (start_weight + weight <= last_row) {
|
|
|
+ num_msgs_last_row = weight;
|
|
|
+ }
|
|
|
+ // The total number of messages headed for this routing node
|
|
|
+ uint32_t num_msgs = num_msgs_full_rows + num_msgs_last_row;
|
|
|
+
|
|
|
+ if (routing_node == my_node_num) {
|
|
|
+ // Special case: we're sending to ourselves; just put the
|
|
|
+ // messages in our own round1c buffer
|
|
|
+ MsgBuffer &round1c = route_state.round1c;
|
|
|
+
|
|
|
+ pthread_mutex_lock(&round1c.mutex);
|
|
|
+ uint32_t start = round1c.reserved;
|
|
|
+ if (start + num_msgs > round1c.bufsize) {
|
|
|
+ pthread_mutex_unlock(&round1c.mutex);
|
|
|
+ printf("Max %u messages exceeded\n", round1c.bufsize);
|
|
|
+ return;
|
|
|
+ }
|
|
|
+ round1c.reserved += num_msgs;
|
|
|
+ pthread_mutex_unlock(&round1c.mutex);
|
|
|
+ uint8_t *buf = round1c.buf + start * msg_size;
|
|
|
+
|
|
|
+ for (uint32_t i=0; i<full_rows; ++i) {
|
|
|
+ const UidPriorityKey *idxp = indices + i*tot_weight + start_weight;
|
|
|
+ for (uint32_t j=0; j<weight; ++j) {
|
|
|
+ memmove(buf, msgs + idxp[j].index()*msg_size, msg_size);
|
|
|
+ buf += msg_size;
|
|
|
+ }
|
|
|
+ }
|
|
|
+ const UidPriorityKey *idxp = indices + full_rows*tot_weight + start_weight;
|
|
|
+ for (uint32_t j=0; j<num_msgs_last_row; ++j) {
|
|
|
+ memmove(buf, msgs + idxp[j].index()*msg_size, msg_size);
|
|
|
+ buf += msg_size;
|
|
|
+ }
|
|
|
+
|
|
|
+ pthread_mutex_lock(&round1c.mutex);
|
|
|
+ round1c.inserted += num_msgs;
|
|
|
+ round1c.nodes_received += 1;
|
|
|
+ pthread_mutex_unlock(&round1c.mutex);
|
|
|
+
|
|
|
+ } else {
|
|
|
+ NodeCommState &nodecom = g_commstates[routing_node];
|
|
|
+ nodecom.message_start(num_msgs * msg_size);
|
|
|
+ for (uint32_t i=0; i<full_rows; ++i) {
|
|
|
+ const UidPriorityKey *idxp = indices + i*tot_weight + start_weight;
|
|
|
+ for (uint32_t j=0; j<weight; ++j) {
|
|
|
+ nodecom.message_data(msgs + idxp[j].index()*msg_size, msg_size);
|
|
|
+ }
|
|
|
+ }
|
|
|
+ const UidPriorityKey *idxp = indices + full_rows*tot_weight + start_weight;
|
|
|
for (uint32_t j=0; j<num_msgs_last_row; ++j) {
|
|
|
nodecom.message_data(msgs + idxp[j].index()*msg_size, msg_size);
|
|
|
}
|
|
@@ -526,6 +901,283 @@ static void send_round2_msgs(uint32_t tot_msgs, uint32_t num_msgs_per_stg)
|
|
|
}
|
|
|
}
|
|
|
|
|
|
+
|
|
|
+static void round1a_processing(void *cbpointer) {
|
|
|
+ uint8_t my_roles = g_teems_config.roles[g_teems_config.my_node_num];
|
|
|
+ MsgBuffer &round1 = route_state.round1;
|
|
|
+
|
|
|
+ if (my_roles & ROLE_ROUTING) {
|
|
|
+ route_state.cbpointer = cbpointer;
|
|
|
+ pthread_mutex_lock(&round1.mutex);
|
|
|
+ // Ensure there are no pending messages currently being inserted
|
|
|
+ // into the buffer
|
|
|
+ while (round1.reserved != round1.inserted) {
|
|
|
+ pthread_mutex_unlock(&round1.mutex);
|
|
|
+ pthread_mutex_lock(&round1.mutex);
|
|
|
+ }
|
|
|
+#ifdef PROFILE_ROUTING
|
|
|
+ uint32_t inserted = round1.inserted;
|
|
|
+ unsigned long start_round1a = printf_with_rtclock("begin round1a processing (%u)\n", inserted);
|
|
|
+ // Sort the messages we've received
|
|
|
+ unsigned long start_sort = printf_with_rtclock("begin oblivious sort (%u,%u)\n", inserted, route_state.max_round1_msgs);
|
|
|
+#endif
|
|
|
+ // Sort received messages by increasing user ID and
|
|
|
+ // priority. Smaller priority number indicates higher priority.
|
|
|
+ sort_mtobliv<UidPriorityKey>(g_teems_config.nthreads, round1.buf,
|
|
|
+ g_teems_config.msg_size, round1.inserted, route_state.max_round1_msgs,
|
|
|
+ send_round1a_msgs);
|
|
|
+
|
|
|
+#ifdef PROFILE_ROUTING
|
|
|
+ printf_with_rtclock_diff(start_sort, "end oblivious sort (%u,%u)\n", inserted, route_state.max_round1_msgs);
|
|
|
+ printf_with_rtclock_diff(start_round1a, "end round1a processing (%u)\n", inserted);
|
|
|
+#endif
|
|
|
+ round1.reset();
|
|
|
+ pthread_mutex_unlock(&round1.mutex);
|
|
|
+
|
|
|
+ MsgBuffer &round1a = route_state.round1a;
|
|
|
+ pthread_mutex_lock(&round1a.mutex);
|
|
|
+ round1a.completed_prev_round = true;
|
|
|
+ nodenum_t nodes_received = round1a.nodes_received;
|
|
|
+ pthread_mutex_unlock(&round1a.mutex);
|
|
|
+ if (nodes_received == g_teems_config.num_routing_nodes) {
|
|
|
+ route_state.step = ROUTE_ROUND_1A;
|
|
|
+ route_state.cbpointer = NULL;
|
|
|
+ ocall_routing_round_complete(cbpointer, ROUND_1A);
|
|
|
+ }
|
|
|
+ } else {
|
|
|
+ route_state.step = ROUTE_ROUND_1A;
|
|
|
+ route_state.round1a.completed_prev_round = true;
|
|
|
+ ocall_routing_round_complete(cbpointer, ROUND_1A);
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+static void round1b_processing(void *cbpointer) {
|
|
|
+ uint8_t my_roles = g_teems_config.roles[g_teems_config.my_node_num];
|
|
|
+ nodenum_t my_node_num = g_teems_config.my_node_num;
|
|
|
+ uint16_t prev_nodes = g_teems_config.weights[my_node_num].startweight;
|
|
|
+ nodenum_t num_routing_nodes = g_teems_config.num_routing_nodes;
|
|
|
+ MsgBuffer &round1a = route_state.round1a;
|
|
|
+
|
|
|
+ if (my_roles & ROLE_ROUTING) {
|
|
|
+ route_state.cbpointer = cbpointer;
|
|
|
+ pthread_mutex_lock(&round1a.mutex);
|
|
|
+ // Ensure there are no pending messages currently being inserted
|
|
|
+ // into the buffer
|
|
|
+ while (round1a.reserved != round1a.inserted) {
|
|
|
+ pthread_mutex_unlock(&round1a.mutex);
|
|
|
+ pthread_mutex_lock(&round1a.mutex);
|
|
|
+ }
|
|
|
+
|
|
|
+#ifdef PROFILE_ROUTING
|
|
|
+ uint32_t inserted = round1a.inserted;
|
|
|
+ unsigned long start_round1b = printf_with_rtclock("begin round1b processing (%u)\n", inserted);
|
|
|
+ // Sort the messages we've received
|
|
|
+ unsigned long start_sort = printf_with_rtclock("begin oblivious sort (%u,%u)\n", inserted, route_state.max_round1a_msgs);
|
|
|
+#endif
|
|
|
+ // Sort received messages by increasing user ID and
|
|
|
+ // priority. Smaller priority number indicates higher priority.
|
|
|
+ sort_mtobliv<UidPriorityKey>(g_teems_config.nthreads, round1a.buf,
|
|
|
+ g_teems_config.msg_size, round1a.inserted, route_state.max_round1a_msgs,
|
|
|
+ send_round1b_msgs);
|
|
|
+
|
|
|
+#ifdef PROFILE_ROUTING
|
|
|
+ printf_with_rtclock_diff(start_sort, "end oblivious sort (%u,%u)\n", inserted, route_state.max_round1a_msgs);
|
|
|
+ printf_with_rtclock_diff(start_round1b, "end round1b processing (%u)\n", inserted);
|
|
|
+#endif
|
|
|
+
|
|
|
+ //round1a.reset(); // Don't reset until end of round 1c
|
|
|
+ pthread_mutex_unlock(&round1a.mutex);
|
|
|
+ MsgBuffer &round1b = route_state.round1b;
|
|
|
+ pthread_mutex_lock(&round1b.mutex);
|
|
|
+ round1b.completed_prev_round = true;
|
|
|
+ nodenum_t nodes_received = round1b.nodes_received;
|
|
|
+ nodenum_t adjacent_nodes =
|
|
|
+ (((prev_nodes == 0) || (prev_nodes == num_routing_nodes-1)) ? 1 : 2);
|
|
|
+ pthread_mutex_unlock(&round1b.mutex);
|
|
|
+ if (nodes_received == adjacent_nodes) {
|
|
|
+ route_state.step = ROUTE_ROUND_1B;
|
|
|
+ route_state.cbpointer = NULL;
|
|
|
+ ocall_routing_round_complete(cbpointer, ROUND_1B);
|
|
|
+ }
|
|
|
+ } else {
|
|
|
+ route_state.step = ROUTE_ROUND_1B;
|
|
|
+ route_state.round1b.completed_prev_round = true;
|
|
|
+ ocall_routing_round_complete(cbpointer, ROUND_1B);
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+//FIX: adjust the message totals based on the sorts
|
|
|
+static void round1c_processing(void *cbpointer) {
|
|
|
+ uint8_t my_roles = g_teems_config.roles[g_teems_config.my_node_num];
|
|
|
+ nodenum_t my_node_num = g_teems_config.my_node_num;
|
|
|
+ nodenum_t num_routing_nodes = g_teems_config.num_routing_nodes;
|
|
|
+ MsgBuffer &round1a = route_state.round1a;
|
|
|
+ MsgBuffer &round1b = route_state.round1b;
|
|
|
+
|
|
|
+ if (my_roles & ROLE_ROUTING) {
|
|
|
+ route_state.cbpointer = cbpointer;
|
|
|
+ pthread_mutex_lock(&round1b.mutex);
|
|
|
+ // Ensure there are no pending messages currently being inserted
|
|
|
+ // into the buffer
|
|
|
+ while (round1b.reserved != round1b.inserted) {
|
|
|
+ pthread_mutex_unlock(&round1b.mutex);
|
|
|
+ pthread_mutex_lock(&round1b.mutex);
|
|
|
+ }
|
|
|
+
|
|
|
+#ifdef PROFILE_ROUTING
|
|
|
+ uint32_t inserted = round1a.inserted;
|
|
|
+ unsigned long start_round1c = printf_with_rtclock("begin round1c processing (%u)\n", inserted);
|
|
|
+ // Sort the messages we've received
|
|
|
+ unsigned long start_sort = printf_with_rtclock("begin oblivious sort (%u,%u)\n", inserted, route_state.max_round1a_msgs);
|
|
|
+#endif
|
|
|
+
|
|
|
+ // Sort received messages by increasing user ID and
|
|
|
+ // priority. Smaller priority number indicates higher priority.
|
|
|
+ sort_mtobliv<UidPriorityKey>(g_teems_config.nthreads, round1a.buf,
|
|
|
+ g_teems_config.msg_size, round1a.inserted, route_state.max_round1a_msgs,
|
|
|
+ send_round1c_msgs);
|
|
|
+
|
|
|
+#ifdef PROFILE_ROUTING
|
|
|
+ printf_with_rtclock_diff(start_sort, "end oblivious sort (%u,%u)\n", inserted, route_state.max_round1a_msgs);
|
|
|
+ printf_with_rtclock_diff(start_round1c, "end round1c processing (%u)\n", inserted);
|
|
|
+#endif
|
|
|
+
|
|
|
+ round1a.reset();
|
|
|
+ pthread_mutex_unlock(&round1a.mutex);
|
|
|
+ pthread_mutex_lock(&round1b.mutex);
|
|
|
+ round1b.reset();
|
|
|
+ pthread_mutex_unlock(&round1b.mutex);
|
|
|
+
|
|
|
+ MsgBuffer &round1c = route_state.round1c;
|
|
|
+ pthread_mutex_lock(&round1c.mutex);
|
|
|
+ round1c.completed_prev_round = true;
|
|
|
+ nodenum_t nodes_received = round1c.nodes_received;
|
|
|
+ pthread_mutex_unlock(&round1c.mutex);
|
|
|
+ if (nodes_received == num_routing_nodes) {
|
|
|
+ route_state.step = ROUTE_ROUND_1C;
|
|
|
+ route_state.cbpointer = NULL;
|
|
|
+ ocall_routing_round_complete(cbpointer, ROUND_1C);
|
|
|
+ }
|
|
|
+ } else {
|
|
|
+ route_state.step = ROUTE_ROUND_1C;
|
|
|
+ route_state.round1b.completed_prev_round = true;
|
|
|
+ ocall_routing_round_complete(cbpointer, ROUND_1C);
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
// Process messages in round 2
//
// my_roles:  bitmask of this node's roles (ROLE_ROUTING and/or ROLE_STORAGE)
// cbpointer: opaque pointer passed back to ocall_routing_round_complete
//            when this round finishes
// prevround: the MsgBuffer holding the messages received in the previous
//            round; required to contain a valid storage-node id in every
//            message (see note below)
//
// As a routing node: pad the previous round's messages so that exactly
// msgs_per_stg messages are destined for each storage node, obliviously
// shuffle them, and send them out.  As a storage node: mark round 2's
// buffer as ready and fire the round-complete ocall once all routing
// nodes have delivered.
static void round2_processing(uint8_t my_roles, void *cbpointer, MsgBuffer &prevround) {
    if (my_roles & ROLE_ROUTING) {
        route_state.cbpointer = cbpointer;

        pthread_mutex_lock(&prevround.mutex);
        // Ensure there are no pending messages currently being inserted
        // into the buffer
        // NOTE(review): this spin-waits by releasing and re-acquiring the
        // mutex until concurrent inserters catch up (reserved == inserted);
        // it relies on inserters making progress between our lock cycles.
        while (prevround.reserved != prevround.inserted) {
            pthread_mutex_unlock(&prevround.mutex);
            pthread_mutex_lock(&prevround.mutex);
        }

        // If the _total_ number of messages we received in round 1
        // is less than the max number of messages we could send to
        // _each_ storage node, then cap the number of messages we
        // will send to each storage node to that number.
        uint32_t msgs_per_stg = route_state.max_msg_to_each_stg;
        if (prevround.inserted < msgs_per_stg) {
            msgs_per_stg = prevround.inserted;
        }

        // Note: at this point, it is required that each message in
        // the prevround buffer have a _valid_ storage node id field.

        // Obliviously tally the number of messages we received in
        // the previous round destined for each storage node
#ifdef PROFILE_ROUTING
        uint32_t inserted = prevround.inserted;
        unsigned long start_round2 = printf_with_rtclock("begin round2 processing (%u,%u)\n", inserted, prevround.bufsize);
        unsigned long start_tally = printf_with_rtclock("begin tally (%u)\n", inserted);
#endif
        uint16_t msg_size = g_teems_config.msg_size;
        nodenum_t num_storage_nodes = g_teems_config.num_storage_nodes;
        std::vector<uint32_t> tally = obliv_tally_stg(
            prevround.buf, msg_size, prevround.inserted, num_storage_nodes);
#ifdef PROFILE_ROUTING
        printf_with_rtclock_diff(start_tally, "end tally (%u)\n", inserted);
#endif
        // Note: tally contains private values! It's OK to
        // non-obliviously check for an error condition, though.
        // While we're at it, obliviously change the tally of
        // messages received to a tally of padding messages
        // required.
        uint32_t tot_padding = 0;
        for (nodenum_t i=0; i<num_storage_nodes; ++i) {
            if (tally[i] > msgs_per_stg) {
                // More real messages for node i than the per-node cap:
                // this is a protocol violation, so abort.
                printf("Received too many messages for storage node %u\n", i);
                assert(tally[i] <= msgs_per_stg);
            }
            // tally[i] becomes the number of padding messages needed to
            // bring storage node i's count up to exactly msgs_per_stg
            tally[i] = msgs_per_stg - tally[i];
            tot_padding += tally[i];
        }

        // Reserve buffer space for the padding messages; the buffer was
        // sized at init time to make this fit
        prevround.reserved += tot_padding;
        assert(prevround.reserved <= prevround.bufsize);

        // Obliviously add padding for each storage node according
        // to the (private) padding tally.
#ifdef PROFILE_ROUTING
        unsigned long start_pad = printf_with_rtclock("begin pad (%u)\n", tot_padding);
#endif
        obliv_pad_stg(prevround.buf + prevround.inserted * msg_size,
            msg_size, tally, tot_padding);
#ifdef PROFILE_ROUTING
        printf_with_rtclock_diff(start_pad, "end pad (%u)\n", tot_padding);
#endif

        prevround.inserted += tot_padding;

        // Obliviously shuffle the messages
#ifdef PROFILE_ROUTING
        unsigned long start_shuffle = printf_with_rtclock("begin shuffle (%u,%u)\n", prevround.inserted, prevround.bufsize);
#endif
        uint32_t num_shuffled = shuffle_mtobliv(g_teems_config.nthreads,
            prevround.buf, msg_size, prevround.inserted, prevround.bufsize);
#ifdef PROFILE_ROUTING
        printf_with_rtclock_diff(start_shuffle, "end shuffle (%u,%u)\n", prevround.inserted, prevround.bufsize);
        printf_with_rtclock_diff(start_round2, "end round2 processing (%u,%u)\n", inserted, prevround.bufsize);
#endif

        // Now we can handle the messages non-obliviously, since we
        // know there will be exactly msgs_per_stg messages to each
        // storage node, and the oblivious shuffle broke the
        // connection between where each message came from and where
        // it's going.
        send_round2_msgs(num_shuffled, msgs_per_stg);

        // Clear the buffer for reuse before releasing it
        prevround.reset();
        pthread_mutex_unlock(&prevround.mutex);
    }

    if (my_roles & ROLE_STORAGE) {
        route_state.cbpointer = cbpointer;
        MsgBuffer &round2 = route_state.round2;

        // Mark that our (routing-side) part of the previous round is
        // done, and snapshot how many routing nodes have already
        // delivered their round-2 messages to us
        pthread_mutex_lock(&round2.mutex);
        round2.completed_prev_round = true;
        nodenum_t nodes_received = round2.nodes_received;
        pthread_mutex_unlock(&round2.mutex);

        // If every routing node has already delivered, round 2 is
        // complete right now; otherwise the last delivery triggers the
        // completion callback elsewhere
        if (nodes_received == g_teems_config.num_routing_nodes) {
            route_state.step = ROUTE_ROUND_2;
            route_state.cbpointer = NULL;
            ocall_routing_round_complete(cbpointer, 2);
        }
    } else {
        // Not a storage node: nothing to wait for, so round 2 is
        // immediately complete from our perspective
        route_state.step = ROUTE_ROUND_2;
        route_state.round2.completed_prev_round = true;
        ocall_routing_round_complete(cbpointer, 2);
    }
}
|
|
|
+
|
|
|
// Perform the next round of routing. The callback pointer will be
|
|
|
// passed to ocall_routing_round_complete when the round is complete.
|
|
|
void ecall_routing_proceed(void *cbpointer)
|
|
@@ -551,9 +1203,17 @@ void ecall_routing_proceed(void *cbpointer)
|
|
|
unsigned long start_round1 = printf_with_rtclock("begin round1 processing (%u)\n", inserted);
|
|
|
unsigned long start_sort = printf_with_rtclock("begin oblivious sort (%u,%u)\n", inserted, route_state.tot_msg_per_ing);
|
|
|
#endif
|
|
|
- sort_mtobliv<UidKey>(g_teems_config.nthreads, ingbuf.buf,
|
|
|
- g_teems_config.msg_size, ingbuf.inserted,
|
|
|
- route_state.tot_msg_per_ing, send_round1_msgs);
|
|
|
+ if (g_teems_config.private_routing) {
|
|
|
+ sort_mtobliv<UidKey>(g_teems_config.nthreads, ingbuf.buf,
|
|
|
+ g_teems_config.msg_size, ingbuf.inserted,
|
|
|
+ route_state.tot_msg_per_ing, send_round1_msgs<UidKey>);
|
|
|
+ } else {
|
|
|
+ // Sort received messages by increasing user ID and
|
|
|
+ // priority. Smaller priority number indicates higher priority.
|
|
|
+ sort_mtobliv<UidPriorityKey>(g_teems_config.nthreads, ingbuf.buf,
|
|
|
+ g_teems_config.msg_size, ingbuf.inserted, route_state.tot_msg_per_ing,
|
|
|
+ send_round1_msgs<UidPriorityKey>);
|
|
|
+ }
|
|
|
#ifdef PROFILE_ROUTING
|
|
|
printf_with_rtclock_diff(start_sort, "end oblivious sort (%u,%u)\n", inserted, route_state.tot_msg_per_ing);
|
|
|
printf_with_rtclock_diff(start_round1, "end round1 processing (%u)\n", inserted);
|
|
@@ -580,137 +1240,15 @@ void ecall_routing_proceed(void *cbpointer)
|
|
|
ocall_routing_round_complete(cbpointer, 1);
|
|
|
}
|
|
|
} else if (route_state.step == ROUTE_ROUND_1) {
|
|
|
- if (my_roles & ROLE_ROUTING) {
|
|
|
- route_state.cbpointer = cbpointer;
|
|
|
- MsgBuffer &round1 = route_state.round1;
|
|
|
-
|
|
|
- pthread_mutex_lock(&round1.mutex);
|
|
|
- // Ensure there are no pending messages currently being inserted
|
|
|
- // into the buffer
|
|
|
- while (round1.reserved != round1.inserted) {
|
|
|
- pthread_mutex_unlock(&round1.mutex);
|
|
|
- pthread_mutex_lock(&round1.mutex);
|
|
|
- }
|
|
|
-
|
|
|
- // If the _total_ number of messages we received in round 1
|
|
|
- // is less than the max number of messages we could send to
|
|
|
- // _each_ storage node, then cap the number of messages we
|
|
|
- // will send to each storage node to that number.
|
|
|
- uint32_t msgs_per_stg = route_state.max_msg_to_each_stg;
|
|
|
- if (round1.inserted < msgs_per_stg) {
|
|
|
- msgs_per_stg = round1.inserted;
|
|
|
- }
|
|
|
-
|
|
|
- // Note: at this point, it is required that each message in
|
|
|
- // the round1 buffer have a _valid_ storage node id field.
|
|
|
-
|
|
|
- // Obliviously tally the number of messages we received in
|
|
|
- // round1 destined for each storage node
|
|
|
-#ifdef PROFILE_ROUTING
|
|
|
- uint32_t inserted = round1.inserted;
|
|
|
- unsigned long start_round2 = printf_with_rtclock("begin round2 processing (%u,%u)\n", inserted, round1.bufsize);
|
|
|
- unsigned long start_tally = printf_with_rtclock("begin tally (%u)\n", inserted);
|
|
|
-#endif
|
|
|
- uint16_t msg_size = g_teems_config.msg_size;
|
|
|
- nodenum_t num_storage_nodes = g_teems_config.num_storage_nodes;
|
|
|
- std::vector<uint32_t> tally = obliv_tally_stg(
|
|
|
- round1.buf, msg_size, round1.inserted, num_storage_nodes);
|
|
|
-#ifdef PROFILE_ROUTING
|
|
|
- printf_with_rtclock_diff(start_tally, "end tally (%u)\n", inserted);
|
|
|
-#endif
|
|
|
-
|
|
|
- // For public routing, convert excess messages to padding destined
|
|
|
- // for other storage nodes with fewer messages than the maximum.
|
|
|
- if (!g_teems_config.private_routing) {
|
|
|
-#ifdef PROFILE_ROUTING
|
|
|
- unsigned long start_convert_excess = printf_with_rtclock("begin converting excess messages (%u)\n", round1.inserted);
|
|
|
-#endif
|
|
|
- // Sort received messages by increasing storage node and
|
|
|
- // priority. Smaller priority number indicates higher priority.
|
|
|
- // Sorted messages are put back into source buffer.
|
|
|
- sort_mtobliv<NidPriorityKey>(g_teems_config.nthreads,
|
|
|
- round1.buf, g_teems_config.msg_size, round1.inserted,
|
|
|
- round1.bufsize);
|
|
|
- // Convert excess messages into padding
|
|
|
- obliv_excess_to_padding(round1.buf, msg_size, round1.inserted,
|
|
|
- tally, msgs_per_stg);
|
|
|
-#ifdef PROFILE_ROUTING
|
|
|
- printf_with_rtclock_diff(start_convert_excess, "end converting excess messages (%u)\n", round1.inserted);
|
|
|
-#endif
|
|
|
- }
|
|
|
-
|
|
|
- // Note: tally contains private values! It's OK to
|
|
|
- // non-obliviously check for an error condition, though.
|
|
|
- // While we're at it, obliviously change the tally of
|
|
|
- // messages received to a tally of padding messages
|
|
|
- // required.
|
|
|
- uint32_t tot_padding = 0;
|
|
|
- for (nodenum_t i=0; i<num_storage_nodes; ++i) {
|
|
|
- if (tally[i] > msgs_per_stg) {
|
|
|
- printf("Received too many messages for storage node %u\n", i);
|
|
|
- assert(tally[i] <= msgs_per_stg);
|
|
|
- }
|
|
|
- tally[i] = msgs_per_stg - tally[i];
|
|
|
- tot_padding += tally[i];
|
|
|
- }
|
|
|
-
|
|
|
- round1.reserved += tot_padding;
|
|
|
- assert(round1.reserved <= round1.bufsize);
|
|
|
-
|
|
|
- // Obliviously add padding for each storage node according
|
|
|
- // to the (private) padding tally.
|
|
|
-#ifdef PROFILE_ROUTING
|
|
|
- unsigned long start_pad = printf_with_rtclock("begin pad (%u)\n", tot_padding);
|
|
|
-#endif
|
|
|
- obliv_pad_stg(round1.buf + round1.inserted * msg_size,
|
|
|
- msg_size, tally, tot_padding);
|
|
|
-#ifdef PROFILE_ROUTING
|
|
|
- printf_with_rtclock_diff(start_pad, "end pad (%u)\n", tot_padding);
|
|
|
-#endif
|
|
|
-
|
|
|
- round1.inserted += tot_padding;
|
|
|
-
|
|
|
- // Obliviously shuffle the messages
|
|
|
-#ifdef PROFILE_ROUTING
|
|
|
- unsigned long start_shuffle = printf_with_rtclock("begin shuffle (%u,%u)\n", round1.inserted, round1.bufsize);
|
|
|
-#endif
|
|
|
- uint32_t num_shuffled = shuffle_mtobliv(g_teems_config.nthreads,
|
|
|
- round1.buf, msg_size, round1.inserted, round1.bufsize);
|
|
|
-#ifdef PROFILE_ROUTING
|
|
|
- printf_with_rtclock_diff(start_shuffle, "end shuffle (%u,%u)\n", round1.inserted, round1.bufsize);
|
|
|
- printf_with_rtclock_diff(start_round2, "end round2 processing (%u,%u)\n", inserted, round1.bufsize);
|
|
|
-#endif
|
|
|
-
|
|
|
- // Now we can handle the messages non-obliviously, since we
|
|
|
- // know there will be exactly msgs_per_stg messages to each
|
|
|
- // storage node, and the oblivious shuffle broke the
|
|
|
- // connection between where each message came from and where
|
|
|
- // it's going.
|
|
|
- send_round2_msgs(num_shuffled, msgs_per_stg);
|
|
|
-
|
|
|
- round1.reset();
|
|
|
- pthread_mutex_unlock(&round1.mutex);
|
|
|
-
|
|
|
- }
|
|
|
- if (my_roles & ROLE_STORAGE) {
|
|
|
- route_state.cbpointer = cbpointer;
|
|
|
- MsgBuffer &round2 = route_state.round2;
|
|
|
-
|
|
|
- pthread_mutex_lock(&round2.mutex);
|
|
|
- round2.completed_prev_round = true;
|
|
|
- nodenum_t nodes_received = round2.nodes_received;
|
|
|
- pthread_mutex_unlock(&round2.mutex);
|
|
|
-
|
|
|
- if (nodes_received == g_teems_config.num_routing_nodes) {
|
|
|
- route_state.step = ROUTE_ROUND_2;
|
|
|
- route_state.cbpointer = NULL;
|
|
|
- ocall_routing_round_complete(cbpointer, 2);
|
|
|
- }
|
|
|
- } else {
|
|
|
- route_state.step = ROUTE_ROUND_2;
|
|
|
- route_state.round2.completed_prev_round = true;
|
|
|
- ocall_routing_round_complete(cbpointer, 2);
|
|
|
+ if (g_teems_config.private_routing) { // private routing next round
|
|
|
+ round2_processing(my_roles, cbpointer, route_state.round1);
|
|
|
+ } else { // public routing next round
|
|
|
+ round1a_processing(cbpointer);
|
|
|
}
|
|
|
+ } else if (route_state.step == ROUTE_ROUND_1A) {
|
|
|
+ round1b_processing(cbpointer);
|
|
|
+ } else if (route_state.step == ROUTE_ROUND_1B) {
|
|
|
+ round1c_processing(cbpointer);
|
|
|
} else if (route_state.step == ROUTE_ROUND_2) {
|
|
|
if (my_roles & ROLE_STORAGE) {
|
|
|
MsgBuffer &round2 = route_state.round2;
|