@@ -88,10 +88,10 @@ bool route_init()
// Calculating public-routing buffer sizes
// Weights are not used in public routing
- uint32_t max_round1a_msgs = max_round1_msgs;
uint32_t max_round1b_msgs_to_adj_rtr =
(g_teems_config.num_routing_nodes-1)*(g_teems_config.num_routing_nodes-1);
- uint32_t max_round1b_msgs = 2*max_round1b_msgs_to_adj_rtr;
+ // Ensure columnroute constraint that column height is >= 2*(num_routing_nodes-1)^2
+ uint32_t max_round1a_msgs = std::max(max_round1_msgs, 2*max_round1b_msgs_to_adj_rtr);
uint32_t max_round1c_msgs = max_round1a_msgs;
/*
@@ -109,7 +109,9 @@ bool route_init()
if (!g_teems_config.private_routing) {
route_state.round1a.alloc(max_round1a_msgs);
route_state.round1a_sorted.alloc(max_round1a_msgs);
- route_state.round1b.alloc(2*max_round1b_msgs); // double to sort with 1a msgs
+ // double round 1b buffers to sort with some round 1a messages
+ route_state.round1b_prev.alloc(2*max_round1b_msgs_to_adj_rtr);
+ route_state.round1b_next.alloc(2*max_round1b_msgs_to_adj_rtr);
route_state.round1c.alloc(max_round1c_msgs);
}
}
@@ -128,7 +130,6 @@ bool route_init()
route_state.max_round1_msgs = max_round1_msgs;
route_state.max_round1a_msgs = max_round1a_msgs;
route_state.max_round1b_msgs_to_adj_rtr = max_round1b_msgs_to_adj_rtr;
- route_state.max_round1b_msgs = max_round1b_msgs;
route_state.max_round1c_msgs = max_round1c_msgs;
route_state.max_msg_to_each_stg = max_msg_to_each_stg;
route_state.max_round2_msgs = max_round2_msgs;
@@ -146,7 +147,7 @@ bool route_init()
sort_precompute_evalplan(max_round2_msgs, nthreads);
if(!g_teems_config.private_routing) {
sort_precompute_evalplan(max_round1a_msgs, nthreads);
- sort_precompute_evalplan(max_round1b_msgs, nthreads);
+ sort_precompute_evalplan(2*max_round1b_msgs_to_adj_rtr, nthreads);
}
}
if (my_roles & ROLE_STORAGE) {
@@ -202,8 +203,8 @@ size_t ecall_precompute_sort(int sizeidx)
used_sizes.push_back(route_state.max_round2_msgs);
if(!g_teems_config.private_routing) {
used_sizes.push_back(route_state.max_round1a_msgs);
- used_sizes.push_back(route_state.max_round1b_msgs);
- used_sizes.push_back(route_state.max_round1b_msgs);
+ used_sizes.push_back(2*route_state.max_round1b_msgs_to_adj_rtr);
+ used_sizes.push_back(2*route_state.max_round1b_msgs_to_adj_rtr);
used_sizes.push_back(route_state.max_round1c_msgs);
}
}
@@ -253,7 +254,10 @@ static uint8_t* msgbuffer_get_buf(MsgBuffer &msgbuf,
static void round1a_received(NodeCommState &nodest,
uint8_t *data, uint32_t plaintext_len, uint32_t);

-static void round1b_received(NodeCommState &nodest,
+static void round1b_prev_received(NodeCommState &nodest,
+ uint8_t *data, uint32_t plaintext_len, uint32_t);
+
+static void round1b_next_received(NodeCommState &nodest,
uint8_t *data, uint32_t plaintext_len, uint32_t);

static void round1c_received(NodeCommState &nodest, uint8_t *data,
@@ -330,7 +334,6 @@ static void round1a_received(NodeCommState &nodest,

// Both are routing nodes
// We only expect a message from the previous and next nodes (if they exist)
-//FIX: replace handlers for previous and next nodes, to put messages in correct locations
nodenum_t my_node_num = g_teems_config.my_node_num;
nodenum_t num_routing_nodes = g_teems_config.num_routing_nodes;
uint16_t prev_nodes = g_teems_config.weights[my_node_num].startweight;
@@ -339,19 +342,19 @@ static void round1a_received(NodeCommState &nodest,
// Node is previous routing node
nodest.in_msg_get_buf = [&](NodeCommState &commst,
uint32_t tot_enc_chunk_size) {
- return msgbuffer_get_buf(route_state.round1b, commst,
+ return msgbuffer_get_buf(route_state.round1b_prev, commst,
tot_enc_chunk_size);
};
- nodest.in_msg_received = round1b_received;
+ nodest.in_msg_received = round1b_prev_received;
} else if ((prev_nodes < num_routing_nodes-1) &&
(nodest.node_num == g_teems_config.routing_nodes[1])) {
// Node is next routing node
nodest.in_msg_get_buf = [&](NodeCommState &commst,
uint32_t tot_enc_chunk_size) {
- return msgbuffer_get_buf(route_state.round1b, commst,
+ return msgbuffer_get_buf(route_state.round1b_next, commst,
tot_enc_chunk_size);
- };
- nodest.in_msg_received = round1b_received;
+ };
+ nodest.in_msg_received = round1b_next_received;
} else {
// other routing nodes will not send to this node until round 1c
nodest.in_msg_get_buf = [&](NodeCommState &commst,
@@ -371,8 +374,51 @@ static void round1a_received(NodeCommState &nodest,
}
}

-// A round 1b message was received by a routing node
-static void round1b_received(NodeCommState &nodest,
+// A round 1b message was received from the previous routing node
+static void round1b_prev_received(NodeCommState &nodest,
+ uint8_t *data, uint32_t plaintext_len, uint32_t)
+{
+ uint16_t msg_size = g_teems_config.msg_size;
+ assert((plaintext_len % uint32_t(msg_size)) == 0);
+ uint32_t num_msgs = plaintext_len / uint32_t(msg_size);
+ uint8_t our_roles = g_teems_config.roles[g_teems_config.my_node_num];
+ uint8_t their_roles = g_teems_config.roles[nodest.node_num];
+ pthread_mutex_lock(&route_state.round1b_prev.mutex);
+ route_state.round1b_prev.inserted += num_msgs;
+ route_state.round1b_prev.nodes_received += 1;
+ nodenum_t nodes_received = route_state.round1b_prev.nodes_received;
+ bool completed_prev_round = route_state.round1b_prev.completed_prev_round;
+ pthread_mutex_unlock(&route_state.round1b_prev.mutex);
+ pthread_mutex_lock(&route_state.round1b_next.mutex);
+ nodes_received += route_state.round1b_next.nodes_received;
+ pthread_mutex_unlock(&route_state.round1b_next.mutex);
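+ // The next message we expect from this node is a round 1c message,
+ // so switch this connection's handlers to the round 1c buffer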
+ nodest.in_msg_get_buf = [&](NodeCommState &commst,
+ uint32_t tot_enc_chunk_size) {
+ return msgbuffer_get_buf(route_state.round1c, commst,
+ tot_enc_chunk_size);
+ };
+ nodest.in_msg_received = round1c_received;
+ nodenum_t my_node_num = g_teems_config.my_node_num;
+ uint16_t prev_nodes = g_teems_config.weights[my_node_num].startweight;
+ nodenum_t num_routing_nodes = g_teems_config.num_routing_nodes;
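+ // Number of adjacent routing nodes we expect round 1b messages from:
+ // none if we are the only routing node, one if we are the first or
+ // last routing node, two otherwise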
+ nodenum_t adjacent_nodes;
+ if (num_routing_nodes == 1) {
+ adjacent_nodes = 0;
+ } else if ((prev_nodes == 0) || (prev_nodes == num_routing_nodes-1)) {
+ adjacent_nodes = 1;
+ } else {
+ adjacent_nodes = 2;
+ }
+ if (nodes_received == adjacent_nodes && completed_prev_round) {
+ route_state.step = ROUTE_ROUND_1B;
+ void *cbpointer = route_state.cbpointer;
+ route_state.cbpointer = NULL;
+ ocall_routing_round_complete(cbpointer, ROUND_1B);
+ }
+}
+
+// A round 1b message was received from the next routing node
+static void round1b_next_received(NodeCommState &nodest,
uint8_t *data, uint32_t plaintext_len, uint32_t)
{
uint16_t msg_size = g_teems_config.msg_size;
@@ -380,13 +426,15 @@ static void round1b_received(NodeCommState &nodest,
uint32_t num_msgs = plaintext_len / uint32_t(msg_size);
uint8_t our_roles = g_teems_config.roles[g_teems_config.my_node_num];
uint8_t their_roles = g_teems_config.roles[nodest.node_num];
- pthread_mutex_lock(&route_state.round1b.mutex);
- route_state.round1b.inserted += num_msgs;
- route_state.round1b.nodes_received += 1;
- nodenum_t nodes_received = route_state.round1b.nodes_received;
- bool completed_prev_round = route_state.round1b.completed_prev_round;
- pthread_mutex_unlock(&route_state.round1b.mutex);
- // Set handler back to standard encrypted message handler
+ pthread_mutex_lock(&route_state.round1b_next.mutex);
+ route_state.round1b_next.inserted += num_msgs;
+ route_state.round1b_next.nodes_received += 1;
+ nodenum_t nodes_received = route_state.round1b_next.nodes_received;
+ bool completed_prev_round = route_state.round1b_next.completed_prev_round;
+ pthread_mutex_unlock(&route_state.round1b_next.mutex);
+ pthread_mutex_lock(&route_state.round1b_prev.mutex);
+ nodes_received += route_state.round1b_prev.nodes_received;
+ pthread_mutex_unlock(&route_state.round1b_prev.mutex);
nodest.in_msg_get_buf = [&](NodeCommState &commst,
uint32_t tot_enc_chunk_size) {
return msgbuffer_get_buf(route_state.round1c, commst,
@@ -396,10 +444,13 @@ static void round1b_received(NodeCommState &nodest,
nodenum_t my_node_num = g_teems_config.my_node_num;
uint16_t prev_nodes = g_teems_config.weights[my_node_num].startweight;
nodenum_t num_routing_nodes = g_teems_config.num_routing_nodes;
- nodenum_t adjacent_nodes =
- (((prev_nodes == 0) || (prev_nodes == num_routing_nodes-1)) ? 1 : 2);
- if(num_routing_nodes==1) {
+ nodenum_t adjacent_nodes;
+ if (num_routing_nodes == 1) {
adjacent_nodes = 0;
+ } else if ((prev_nodes == 0) || (prev_nodes == num_routing_nodes-1)) {
+ adjacent_nodes = 1;
+ } else {
+ adjacent_nodes = 2;
}
if (nodes_received == adjacent_nodes && completed_prev_round) {
route_state.step = ROUTE_ROUND_1B;
@@ -565,9 +616,9 @@ bool ecall_ingest_raw(uint8_t *msgs, uint32_t num_msgs)
return true;
}

-// Send the round 1 messages. Note that N here is not private.
+// Send messages round-robin, used in rounds 1 and 1c. Note that N here is not private.
template<typename T>
-static void send_round1_msgs(const uint8_t *msgs, const T *indices,
+static void send_round_robin_msgs(MsgBuffer &round, const uint8_t *msgs, const T *indices,
uint32_t N)
{
uint16_t msg_size = g_teems_config.msg_size;
@@ -609,19 +660,18 @@ static void send_round1_msgs(const uint8_t *msgs, const T *indices,

if (routing_node == my_node_num) {
// Special case: we're sending to ourselves; just put the
- // messages in our own round1 buffer
- MsgBuffer &round1 = route_state.round1;
+ // messages in our own buffer

- pthread_mutex_lock(&round1.mutex);
- uint32_t start = round1.reserved;
- if (start + num_msgs > round1.bufsize) {
- pthread_mutex_unlock(&round1.mutex);
- printf("Max %u messages exceeded\n", round1.bufsize);
+ pthread_mutex_lock(&round.mutex);
+ uint32_t start = round.reserved;
+ if (start + num_msgs > round.bufsize) {
+ pthread_mutex_unlock(&round.mutex);
+ printf("Max %u messages exceeded\n", round.bufsize);
return;
}
- round1.reserved += num_msgs;
- pthread_mutex_unlock(&round1.mutex);
- uint8_t *buf = round1.buf + start * msg_size;
+ round.reserved += num_msgs;
+ pthread_mutex_unlock(&round.mutex);
+ uint8_t *buf = round.buf + start * msg_size;

for (uint32_t i=0; i<full_rows; ++i) {
const T *idxp = indices + i*tot_weight + start_weight;
@@ -636,10 +686,10 @@ static void send_round1_msgs(const uint8_t *msgs, const T *indices,
buf += msg_size;
}

- pthread_mutex_lock(&round1.mutex);
- round1.inserted += num_msgs;
- round1.nodes_received += 1;
- pthread_mutex_unlock(&round1.mutex);
+ pthread_mutex_lock(&round.mutex);
+ round.inserted += num_msgs;
+ round.nodes_received += 1;
+ pthread_mutex_unlock(&round.mutex);

} else {
NodeCommState &nodecom = g_commstates[routing_node];
@@ -758,99 +808,6 @@ static void send_round1b_msgs(const uint8_t *msgs, uint32_t N) {
}
}

-// Send the round 1c messages. Note that N here is not private.
-// FIX: combine with send_round1_msgs(), which is similar
-static void send_round1c_msgs(const uint8_t *msgs, const UidPriorityKey *indices,
- uint32_t N)
-{
- uint16_t msg_size = g_teems_config.msg_size;
- uint16_t tot_weight;
- tot_weight = g_teems_config.tot_weight;
- nodenum_t my_node_num = g_teems_config.my_node_num;
-
- uint32_t full_rows;
- uint32_t last_row;
- full_rows = N / uint32_t(tot_weight);
- last_row = N % uint32_t(tot_weight);
-
- for (auto &routing_node: g_teems_config.routing_nodes) {
- uint8_t weight = g_teems_config.weights[routing_node].weight;
-
- if (weight == 0) {
- // This shouldn't happen, but just in case
- continue;
- }
- uint16_t start_weight =
- g_teems_config.weights[routing_node].startweight;
-
- // The number of messages headed for this routing node from the
- // full rows
- uint32_t num_msgs_full_rows = full_rows * uint32_t(weight);
- // The number of messages headed for this routing node from the
- // incomplete last row is:
- // 0 if last_row < start_weight
- // last_row-start_weight if start_weight <= last_row < start_weight + weight
- // weight if start_weight + weight <= last_row
- uint32_t num_msgs_last_row = 0;
- if (start_weight <= last_row && last_row < start_weight + weight) {
- num_msgs_last_row = last_row-start_weight;
- } else if (start_weight + weight <= last_row) {
- num_msgs_last_row = weight;
- }
- // The total number of messages headed for this routing node
- uint32_t num_msgs = num_msgs_full_rows + num_msgs_last_row;
-
- if (routing_node == my_node_num) {
- // Special case: we're sending to ourselves; just put the
- // messages in our own round1 buffer
- MsgBuffer &round1c = route_state.round1c;
-
- pthread_mutex_lock(&round1c.mutex);
- uint32_t start = round1c.reserved;
- if (start + num_msgs > round1c.bufsize) {
- pthread_mutex_unlock(&round1c.mutex);
- printf("Max %u messages exceeded\n", round1c.bufsize);
- return;
- }
- round1c.reserved += num_msgs;
- pthread_mutex_unlock(&round1c.mutex);
- uint8_t *buf = round1c.buf + start * msg_size;
-
- for (uint32_t i=0; i<full_rows; ++i) {
- const UidPriorityKey *idxp = indices + i*tot_weight + start_weight;
- for (uint32_t j=0; j<weight; ++j) {
- memmove(buf, msgs + idxp[j].index()*msg_size, msg_size);
- buf += msg_size;
- }
- }
- const UidPriorityKey *idxp = indices + full_rows*tot_weight + start_weight;
- for (uint32_t j=0; j<num_msgs_last_row; ++j) {
- memmove(buf, msgs + idxp[j].index()*msg_size, msg_size);
- buf += msg_size;
- }
-
- pthread_mutex_lock(&round1c.mutex);
- round1c.inserted += num_msgs;
- round1c.nodes_received += 1;
- pthread_mutex_unlock(&round1c.mutex);
-
- } else {
- NodeCommState &nodecom = g_commstates[routing_node];
- nodecom.message_start(num_msgs * msg_size);
- for (uint32_t i=0; i<full_rows; ++i) {
- const UidPriorityKey *idxp = indices + i*tot_weight + start_weight;
- for (uint32_t j=0; j<weight; ++j) {
- nodecom.message_data(msgs + idxp[j].index()*msg_size, msg_size);
- }
- }
- const UidPriorityKey *idxp = indices + full_rows*tot_weight + start_weight;
- for (uint32_t j=0; j<num_msgs_last_row; ++j) {
- nodecom.message_data(msgs + idxp[j].index()*msg_size, msg_size);
- }
- }
- }
-}
-
// Send the round 2 messages from the previous-round buffer, which are already
// padded and shuffled, so this can be done non-obliviously. tot_msgs
// is the total number of messages in the input buffer, which may
@@ -992,7 +949,7 @@ static void round1b_processing(void *cbpointer) {
sort_mtobliv<UidPriorityKey>(g_teems_config.nthreads, round1a.buf,
g_teems_config.msg_size, round1a.inserted, route_state.max_round1a_msgs,
route_state.round1a_sorted.buf);
- send_round1b_msgs(round1a.buf, round1a.inserted);
+ send_round1b_msgs(round1a_sorted.buf, round1a.inserted);
} else {
send_round1b_msgs(NULL, 0);
}
@@ -1004,16 +961,24 @@ static void round1b_processing(void *cbpointer) {
pthread_mutex_unlock(&round1a_sorted.mutex);
pthread_mutex_unlock(&round1a.mutex);
- MsgBuffer &round1b = route_state.round1b;
- pthread_mutex_lock(&round1b.mutex);
- round1b.completed_prev_round = true;
- nodenum_t nodes_received = round1b.nodes_received;
- nodenum_t adjacent_nodes =
- (((prev_nodes == 0) || (prev_nodes == num_routing_nodes-1)) ? 1 : 2);
- if(num_routing_nodes==1) {
+ MsgBuffer &round1b_prev = route_state.round1b_prev;
+ pthread_mutex_lock(&round1b_prev.mutex);
+ round1b_prev.completed_prev_round = true;
+ nodenum_t nodes_received = round1b_prev.nodes_received;
+ pthread_mutex_unlock(&round1b_prev.mutex);
+ MsgBuffer &round1b_next = route_state.round1b_next;
+ pthread_mutex_lock(&round1b_next.mutex);
+ round1b_next.completed_prev_round = true;
+ nodes_received += round1b_next.nodes_received;
+ pthread_mutex_unlock(&round1b_next.mutex);
+ nodenum_t adjacent_nodes;
+ if (num_routing_nodes == 1) {
adjacent_nodes = 0;
+ } else if ((prev_nodes == 0) || (prev_nodes == num_routing_nodes-1)) {
+ adjacent_nodes = 1;
+ } else {
+ adjacent_nodes = 2;
}
- pthread_mutex_unlock(&round1b.mutex);
if (nodes_received == adjacent_nodes) {
route_state.step = ROUTE_ROUND_1B;
route_state.cbpointer = NULL;
@@ -1021,59 +986,148 @@ static void round1b_processing(void *cbpointer) {
}
} else {
route_state.step = ROUTE_ROUND_1B;
- route_state.round1b.completed_prev_round = true;
+ route_state.round1b_prev.completed_prev_round = true;
+ route_state.round1b_next.completed_prev_round = true;
ocall_routing_round_complete(cbpointer, ROUND_1B);
}
}

-//FIX: adjust the message totals based on the sorts
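+// Copy num_copy messages, selected by the sorted indices starting at
+// position start_msg, from src into consecutive slots of dst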
+static void copy_msgs(uint8_t *dst, uint32_t start_msg, uint32_t num_copy, const uint8_t *src,
+ const UidPriorityKey *indices)
+{
+ uint16_t msg_size = g_teems_config.msg_size;
+ const UidPriorityKey *idxp = indices + start_msg;
+ uint8_t *buf = dst;
+ for (uint32_t i=0; i<num_copy; i++) {
+ memmove(buf, src + idxp[i].index()*msg_size, msg_size);
+ buf += msg_size;
+ }
+}
+
static void round1c_processing(void *cbpointer) {
uint8_t my_roles = g_teems_config.roles[g_teems_config.my_node_num];
nodenum_t my_node_num = g_teems_config.my_node_num;
nodenum_t num_routing_nodes = g_teems_config.num_routing_nodes;
+ uint16_t prev_nodes = g_teems_config.weights[my_node_num].startweight;
+ uint16_t msg_size = g_teems_config.msg_size;
+ uint32_t max_round1b_msgs_to_adj_rtr = route_state.max_round1b_msgs_to_adj_rtr;
+ uint32_t max_round1a_msgs = route_state.max_round1a_msgs;
MsgBuffer &round1a = route_state.round1a;
MsgBuffer &round1a_sorted = route_state.round1a_sorted;
- MsgBuffer &round1b = route_state.round1b;
+ MsgBuffer &round1b_prev = route_state.round1b_prev;
+ MsgBuffer &round1b_next = route_state.round1b_next;

if (my_roles & ROLE_ROUTING) {
route_state.cbpointer = cbpointer;
- pthread_mutex_lock(&round1b.mutex);
+ pthread_mutex_lock(&round1b_prev.mutex);
+ pthread_mutex_lock(&round1b_next.mutex);
// Ensure there are no pending messages currently being inserted
- // into the buffer
- while (round1b.reserved != round1b.inserted) {
- pthread_mutex_unlock(&round1b.mutex);
- pthread_mutex_lock(&round1b.mutex);
+ // into the round 1b buffers
+ while (round1b_prev.reserved != round1b_prev.inserted) {
+ pthread_mutex_unlock(&round1b_prev.mutex);
+ pthread_mutex_lock(&round1b_prev.mutex);
+ }
+ while (round1b_next.reserved != round1b_next.inserted) {
+ pthread_mutex_unlock(&round1b_next.mutex);
+ pthread_mutex_lock(&round1b_next.mutex);
}
pthread_mutex_lock(&round1a.mutex);
pthread_mutex_lock(&round1a_sorted.mutex);
+
#ifdef PROFILE_ROUTING
- uint32_t inserted = round1a.inserted;
- unsigned long start_round1c = printf_with_rtclock("begin round1c processing (%u)\n", inserted);
- // Sort the messages we've received
- unsigned long start_sort = printf_with_rtclock("begin oblivious sort (%u,%u)\n", inserted, route_state.max_round1a_msgs);
+ unsigned long start_round1c = printf_with_rtclock("begin round1c processing (%u)\n", round1a.inserted);
#endif

+ // sort round1b_prev msgs with initial msgs in round1a_sorted
+ if (prev_nodes > 0) {
+ // Copy initial msgs in round1a_sorted to round1b_prev buffer for sorting
+ // Note that all inserted values and buffer sizes are non-secret
+ uint32_t num_init_round1a = min(round1a.inserted,
+ max_round1b_msgs_to_adj_rtr);
+ uint32_t num_round1b_prev = round1b_prev.inserted;
+ if (num_round1b_prev + num_init_round1a <= max_round1b_msgs_to_adj_rtr) {
+ // all our round 1a messages "belong" to previous router and can be removed here
+ round1a.inserted = 0;
+ } else {
+ // copy initial round1a msgs after round1b_prev msgs
+ memmove(round1b_prev.buf+num_round1b_prev*msg_size, round1a_sorted.buf,
+ num_init_round1a*msg_size);
+ // sort and take final msgs as initial round1a msgs
+#ifdef PROFILE_ROUTING
+ unsigned long start_sort = printf_with_rtclock("begin round1b_prev oblivious sort (%u,%u)\n", num_round1b_prev + num_init_round1a, 2*max_round1b_msgs_to_adj_rtr);
+#endif
+ uint32_t num_copy = num_round1b_prev+num_init_round1a-max_round1b_msgs_to_adj_rtr;
+ sort_mtobliv<UidPriorityKey>(g_teems_config.nthreads, round1b_prev.buf,
+ msg_size, num_round1b_prev + num_init_round1a,
+ 2*max_round1b_msgs_to_adj_rtr,
+ [&](const uint8_t *src, const UidPriorityKey *indices, uint32_t Nr) {
+ return copy_msgs(round1a_sorted.buf, max_round1b_msgs_to_adj_rtr,
+ num_copy, src, indices);
+ }
+ );
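+ // The smallest max_round1b_msgs_to_adj_rtr of the merged messages belong
+ // to the previous router; only the remaining num_copy stay ours, so the
+ // round 1a count shrinks by the difference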
+ round1a.inserted -= (max_round1b_msgs_to_adj_rtr-num_round1b_prev);
+#ifdef PROFILE_ROUTING
+ printf_with_rtclock_diff(start_sort, "end round1b_prev oblivious sort (%u,%u)\n", num_round1b_prev + num_init_round1a, 2*max_round1b_msgs_to_adj_rtr);
+#endif
+ }
+ }
+ // sort round1b_next msgs with final msgs in round1a_sorted
+ if ((prev_nodes < num_routing_nodes-1) && (round1b_next.inserted > 0)) {
+ // Copy final msgs in round1a_sorted to round1b_next buffer for sorting
+ // Note that all inserted values and buffer sizes are non-secret
+ // round1b_next.inserted>0, so round1a >= max_round1a_msgs-max_round1b_msgs_to_adj_rtr
+ uint32_t round1a_msg_start = max_round1a_msgs-max_round1b_msgs_to_adj_rtr;
+ uint32_t num_final_round1a = round1a.inserted - round1a_msg_start;
+ uint32_t num_round1b_next = round1b_next.inserted;
+ memmove(round1b_next.buf+num_round1b_next*msg_size,
+ round1a_sorted.buf + round1a_msg_start*msg_size,
+ num_final_round1a*msg_size);
+ // sort and take initial msgs as final round1a msgs
+#ifdef PROFILE_ROUTING
+ unsigned long start_sort = printf_with_rtclock("begin round1b_next oblivious sort (%u,%u)\n", num_round1b_next + num_final_round1a, 2*max_round1b_msgs_to_adj_rtr);
+#endif
+ uint32_t num_copy = min(num_final_round1a+num_round1b_next,
+ max_round1b_msgs_to_adj_rtr);
+ sort_mtobliv<UidPriorityKey>(g_teems_config.nthreads, round1b_next.buf,
+ msg_size, num_round1b_next + num_final_round1a, 2*max_round1b_msgs_to_adj_rtr,
+ [&](const uint8_t *src, const UidPriorityKey *indices, uint32_t Nr) {
+ return copy_msgs(round1a_sorted.buf + round1a_msg_start*msg_size, 0,
+ num_copy, src, indices);
+ }
+ );
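+ // Only the smallest num_copy of the merged messages stay ours (in the
+ // tail of round1a_sorted); the rest belong to the next router, so adjust
+ // the round 1a count by the difference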
+ round1a.inserted += (num_copy - num_final_round1a);
+#ifdef PROFILE_ROUTING
+ printf_with_rtclock_diff(start_sort, "end round1b_next oblivious sort (%u,%u)\n", num_round1b_next + num_final_round1a, 2*max_round1b_msgs_to_adj_rtr);
+#endif
+ }
+#ifdef PROFILE_ROUTING
+ unsigned long start_sort = printf_with_rtclock("begin full oblivious sort (%u,%u)\n", round1a.inserted, route_state.max_round1a_msgs);
+#endif
// Sort received messages by increasing user ID and
// priority. Smaller priority number indicates higher priority.
- if (inserted > 0) {
+ if (round1a.inserted > 0) {
sort_mtobliv<UidPriorityKey>(g_teems_config.nthreads, round1a_sorted.buf,
- g_teems_config.msg_size, round1a.inserted, route_state.max_round1a_msgs,
- send_round1c_msgs);
+ msg_size, round1a.inserted, route_state.max_round1a_msgs,
+ [&](const uint8_t *msgs, const UidPriorityKey *indices, uint32_t N) {
+ send_round_robin_msgs<UidPriorityKey>(route_state.round1c, msgs, indices, N);
+ });
} else {
- send_round1c_msgs(NULL, NULL, 0);
+ send_round_robin_msgs<UidPriorityKey>(route_state.round1c, NULL, NULL, 0);
}

#ifdef PROFILE_ROUTING
- printf_with_rtclock_diff(start_sort, "end oblivious sort (%u,%u)\n", inserted, route_state.max_round1a_msgs);
- printf_with_rtclock_diff(start_round1c, "end round1c processing (%u)\n", inserted);
+ printf_with_rtclock_diff(start_sort, "end full oblivious sort (%u,%u)\n", round1a.inserted, route_state.max_round1a_msgs);
+ printf_with_rtclock_diff(start_round1c, "end round1c processing (%u)\n", round1a.inserted);
#endif

round1a.reset();
round1a_sorted.reset();
- round1b.reset();
+ round1b_prev.reset();
+ round1b_next.reset();
pthread_mutex_unlock(&round1a_sorted.mutex);
pthread_mutex_unlock(&round1a.mutex);
- pthread_mutex_unlock(&round1b.mutex);
+ pthread_mutex_unlock(&round1b_next.mutex);
+ pthread_mutex_unlock(&round1b_prev.mutex);

MsgBuffer &round1c = route_state.round1c;
pthread_mutex_lock(&round1c.mutex);
@@ -1087,7 +1141,7 @@ static void round1c_processing(void *cbpointer) {
}
} else {
route_state.step = ROUTE_ROUND_1C;
- route_state.round1b.completed_prev_round = true;
+ route_state.round1c.completed_prev_round = true;
ocall_routing_round_complete(cbpointer, ROUND_1C);
}
}
@@ -1232,13 +1286,18 @@ void ecall_routing_proceed(void *cbpointer)
if (g_teems_config.private_routing) {
sort_mtobliv<UidKey>(g_teems_config.nthreads, ingbuf.buf,
g_teems_config.msg_size, ingbuf.inserted,
- route_state.tot_msg_per_ing, send_round1_msgs<UidKey>);
+ route_state.tot_msg_per_ing,
+ [&](const uint8_t *msgs, const UidKey *indices, uint32_t N) {
+ send_round_robin_msgs<UidKey>(route_state.round1, msgs, indices, N);
+ });
} else {
// Sort received messages by increasing user ID and
// priority. Smaller priority number indicates higher priority.
sort_mtobliv<UidPriorityKey>(g_teems_config.nthreads, ingbuf.buf,
g_teems_config.msg_size, ingbuf.inserted, route_state.tot_msg_per_ing,
- send_round1_msgs<UidPriorityKey>);
+ [&](const uint8_t *msgs, const UidPriorityKey *indices, uint32_t N) {
+ send_round_robin_msgs<UidPriorityKey>(route_state.round1, msgs, indices, N);
+ });
}
#ifdef PROFILE_ROUTING
printf_with_rtclock_diff(start_sort, "end oblivious sort (%u,%u)\n", inserted, route_state.tot_msg_per_ing);