#include "Enclave_t.h" #include "config.hpp" #include "utils.hpp" #include "sort.hpp" #include "comms.hpp" #include "obliv.hpp" #include "storage.hpp" #include "route.hpp" #define PROFILE_ROUTING // #define TRACE_ROUTING RouteState route_state; // Computes ceil(x/y) where x and y are integers, x>=0, y>0. #define CEILDIV(x,y) (((x)+(y)-1)/(y)) #ifdef TRACE_ROUTING // Show (the first 300 of, if there are more) the headers and the first // few bytes of the body of each message in the buffer static void show_messages(const char *label, const unsigned char *buffer, size_t num) { if (label) { printf("%s\n", label); } for (size_t i=0; i max_round1_msgs) { max_msg_to_each_stg = max_round1_msgs; } // And the max total number of outgoing messages in round 2 is then uint32_t max_round2_msgs = max_msg_to_each_stg * g_teems_config.num_storage_nodes; // In case we have a weird configuration where users can send more // messages per epoch than they can receive, ensure the round 2 // buffer is large enough to hold the incoming messages as well if (max_round2_msgs < max_round1_msgs) { max_round2_msgs = max_round1_msgs; } // The max number of messages that can arrive at a storage server uint32_t max_stg_msgs; max_stg_msgs = (tot_msg_per_stg/g_teems_config.tot_weight + g_teems_config.tot_weight) * g_teems_config.tot_weight; // Calculating ID channel buffer sizes // Weights are not used in ID channel routing // Round up to a multiple of num_routing_nodes uint32_t max_round1b_msgs_to_adj_rtr = CEILDIV( (g_teems_config.num_routing_nodes-1)*(g_teems_config.num_routing_nodes-1), g_teems_config.num_routing_nodes) * g_teems_config.num_routing_nodes; // Ensure columnroute constraint that column height is >= 2*(num_routing_nodes-1)^2 // and a multiple of num_routing_nodes uint32_t max_round1a_msgs = CEILDIV( std::max(max_round1_msgs, 2*max_round1b_msgs_to_adj_rtr), g_teems_config.num_routing_nodes) * g_teems_config.num_routing_nodes; uint32_t max_round1c_msgs = std::max(max_round1a_msgs, max_round2_msgs); /* printf("users_per_ing=%u, tot_msg_per_ing=%u, max_msg_from_each_ing=%u, max_round1_msgs=%u, users_per_stg=%u, tot_msg_per_stg=%u, max_msg_to_each_stg=%u, max_round2_msgs=%u, max_stg_msgs=%u\n", users_per_ing, tot_msg_per_ing, max_msg_from_each_ing, max_round1_msgs, users_per_stg, tot_msg_per_stg, max_msg_to_each_stg, max_round2_msgs, max_stg_msgs); */ #ifdef TRACK_HEAP_USAGE printf("route_init H1 heap %u\n", g_peak_heap_used); #endif // Create the route state uint8_t my_roles = g_teems_config.roles[g_teems_config.my_node_num]; try { if (my_roles & ROLE_INGESTION) { route_state.ingbuf.alloc(tot_msg_per_ing); } #ifdef TRACK_HEAP_USAGE printf("route_init alloc %u msgs\n", tot_msg_per_ing); printf("route_init H2 heap %u\n", g_peak_heap_used); #endif if (my_roles & ROLE_ROUTING) { route_state.round1.alloc(max_round2_msgs); #ifdef TRACK_HEAP_USAGE printf("route_init alloc %u msgs\n", max_round2_msgs); printf("route_init H3 heap %u\n", g_peak_heap_used); #endif if (!g_teems_config.token_channel) { route_state.round1a.alloc(max_round1a_msgs); route_state.round1a_sorted.alloc(max_round1a_msgs + max_round1b_msgs_to_adj_rtr); // double round 1b buffers to sort with some round 1a messages route_state.round1b_next.alloc(2*max_round1b_msgs_to_adj_rtr); route_state.round1c.alloc(max_round1c_msgs); #ifdef TRACK_HEAP_USAGE printf("route_init alloc %u msgs\n", max_round1a_msgs); printf("route_init alloc %u msgs\n", max_round1a_msgs + max_round1b_msgs_to_adj_rtr); printf("route_init alloc %u msgs\n", 2*max_round1b_msgs_to_adj_rtr); 
printf("route_init alloc %u msgs\n", max_round1c_msgs); printf("route_init H4 heap %u\n", g_peak_heap_used); #endif } } if (my_roles & ROLE_STORAGE) { route_state.round2.alloc(max_stg_msgs); #ifdef TRACK_HEAP_USAGE printf("route_init alloc %u msgs\n", max_stg_msgs); printf("route_init H5 heap %u\n", g_peak_heap_used); #endif if (!storage_init(users_per_stg, max_stg_msgs)) { return false; } #ifdef TRACK_HEAP_USAGE printf("storage_init(%u,%u)\n", users_per_stg, max_stg_msgs); printf("route_init H6 heap %u\n", g_peak_heap_used); #endif } } catch (std::bad_alloc&) { printf("Memory allocation failed in route_init\n"); return false; } route_state.step = ROUTE_NOT_STARTED; route_state.tot_msg_per_ing = tot_msg_per_ing; route_state.max_round1_msgs = max_round1_msgs; route_state.max_round1a_msgs = max_round1a_msgs; route_state.max_round1b_msgs_to_adj_rtr = max_round1b_msgs_to_adj_rtr; route_state.max_round1c_msgs = max_round1c_msgs; route_state.max_msg_to_each_stg = max_msg_to_each_stg; route_state.max_round2_msgs = max_round2_msgs; route_state.max_stg_msgs = max_stg_msgs; route_state.cbpointer = NULL; threadid_t nthreads = g_teems_config.nthreads; #ifdef PROFILE_ROUTING unsigned long start = printf_with_rtclock("begin precompute evalplans (%u,%hu) (%u,%hu)\n", tot_msg_per_ing, nthreads, max_round2_msgs, nthreads); #endif if (my_roles & ROLE_INGESTION) { #ifdef TRACK_HEAP_USAGE printf("sort_precompute_evalplan(%u) start heap %u\n", tot_msg_per_ing, g_peak_heap_used); #endif sort_precompute_evalplan(tot_msg_per_ing, nthreads); #ifdef TRACK_HEAP_USAGE printf("sort_precompute_evalplan(%u) end heap %u\n", tot_msg_per_ing, g_peak_heap_used); #endif } if (my_roles & ROLE_ROUTING) { #ifdef TRACK_HEAP_USAGE printf("sort_precompute_evalplan(%u) start heap %u\n", max_round2_msgs, g_peak_heap_used); #endif sort_precompute_evalplan(max_round2_msgs, nthreads); #ifdef TRACK_HEAP_USAGE printf("sort_precompute_evalplan(%u) end heap %u\n", max_round2_msgs, g_peak_heap_used); #endif if(!g_teems_config.token_channel) { #ifdef TRACK_HEAP_USAGE printf("sort_precompute_evalplan(%u) start heap %u\n", max_round1a_msgs, g_peak_heap_used); #endif sort_precompute_evalplan(max_round1a_msgs, nthreads); #ifdef TRACK_HEAP_USAGE printf("sort_precompute_evalplan(%u) end heap %u\n", max_round1a_msgs, g_peak_heap_used); printf("sort_precompute_evalplan(%u) start heap %u\n", max_round1b_msgs_to_adj_rtr, g_peak_heap_used); #endif sort_precompute_evalplan(2*max_round1b_msgs_to_adj_rtr, nthreads); #ifdef TRACK_HEAP_USAGE printf("sort_precompute_evalplan(%u) end heap %u\n", max_round1b_msgs_to_adj_rtr, g_peak_heap_used); #endif } } if (my_roles & ROLE_STORAGE) { #ifdef TRACK_HEAP_USAGE printf("sort_precompute_evalplan(%u) start heap %u\n", max_stg_msgs, g_peak_heap_used); #endif sort_precompute_evalplan(max_stg_msgs, nthreads); #ifdef TRACK_HEAP_USAGE printf("sort_precompute_evalplan(%u) end heap %u\n", max_stg_msgs, g_peak_heap_used); #endif } #ifdef PROFILE_ROUTING printf_with_rtclock_diff(start, "end precompute evalplans\n"); #endif #ifdef TRACK_HEAP_USAGE printf("route_init end heap %u\n", g_peak_heap_used); #endif return true; } // Call when shutting system down to deallocate routing state void route_close() { uint8_t my_roles = g_teems_config.roles[g_teems_config.my_node_num]; if (my_roles & ROLE_STORAGE) { storage_close(); } } // Precompute the WaksmanNetworks needed for the sorts. If you pass -1, // it will return the number of different sizes it needs to regenerate. 
// If you pass [0,sizes-1], it will compute one WaksmanNetwork with that
// size index and return the number of available WaksmanNetworks of that
// size.  If you pass anything else, it will rebuild the list of needed
// sizes and return the total number of different sizes it needs.

// The list of sizes that need refilling, updated when you pass -1
static std::vector<uint32_t> used_sizes;

size_t ecall_precompute_sort(int sizeidx)
{
    size_t ret = 0;
    if (sizeidx == -1) {
        used_sizes = sort_get_used();
        ret = used_sizes.size();
    } else if (sizeidx >= 0 && sizeidx < used_sizes.size()) {
        uint32_t size = used_sizes[sizeidx];
#ifdef TRACK_HEAP_USAGE
        printf("ecall_precompute_sort start heap %u\n", g_peak_heap_used);
#endif
#ifdef PROFILE_ROUTING
        unsigned long start = printf_with_rtclock("begin precompute WaksmanNetwork (%u)\n", size);
#endif
        ret = sort_precompute(size);
#ifdef PROFILE_ROUTING
        printf_with_rtclock_diff(start, "end precompute WaksmanNetwork (%u)\n", size);
#endif
#ifdef TRACK_HEAP_USAGE
        printf("ecall_precompute_sort end heap %u\n", g_peak_heap_used);
#endif
    } else {
        uint8_t my_roles = g_teems_config.roles[g_teems_config.my_node_num];
        if (my_roles & ROLE_INGESTION) {
            used_sizes.push_back(route_state.tot_msg_per_ing);
        }
        if (my_roles & ROLE_ROUTING) {
            used_sizes.push_back(route_state.max_round2_msgs);
            if (!g_teems_config.token_channel) {
                used_sizes.push_back(route_state.max_round1a_msgs);
                used_sizes.push_back(2*route_state.max_round1b_msgs_to_adj_rtr);
            }
        }
        if (my_roles & ROLE_STORAGE) {
            used_sizes.push_back(route_state.max_stg_msgs);
            if (!g_teems_config.token_channel) {
                used_sizes.push_back(route_state.max_stg_msgs);
            }
        }
        ret = used_sizes.size();
    }
    return ret;
}

static uint8_t* msgbuffer_get_buf(MsgBuffer &msgbuf, NodeCommState &,
    uint32_t tot_enc_chunk_size)
{
    uint16_t msg_size = g_teems_config.msg_size;
    // Chunks will be encrypted and have a MAC tag attached which will
    // not correspond to plaintext bytes, so we can trim them.
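    // (Illustrative example, not from the original code: a transmission
    // that arrives in two encrypted chunks carries two SGX_AESGCM_MAC_SIZE
    // tags, so those tag bytes are subtracted before computing how many
    // msg_size-byte message slots to reserve below.)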
// The minimum number of chunks needed to transmit this message uint32_t min_num_chunks = (tot_enc_chunk_size + (FRAME_SIZE-1)) / FRAME_SIZE; // The number of plaintext bytes this message could contain uint32_t plaintext_bytes = tot_enc_chunk_size - SGX_AESGCM_MAC_SIZE * min_num_chunks; uint32_t num_msgs = CEILDIV(plaintext_bytes, uint32_t(msg_size)); pthread_mutex_lock(&msgbuf.mutex); uint32_t start = msgbuf.reserved; if (start + num_msgs > msgbuf.bufsize) { pthread_mutex_unlock(&msgbuf.mutex); printf("Max %u messages exceeded (msgbuffer_get_buf)\n", msgbuf.bufsize); return NULL; } msgbuf.reserved += num_msgs; pthread_mutex_unlock(&msgbuf.mutex); return msgbuf.buf + start * msg_size; } static void round1a_received(NodeCommState &nodest, uint8_t *data, uint32_t plaintext_len, uint32_t); static void round1b_next_received(NodeCommState &nodest, uint8_t *data, uint32_t plaintext_len, uint32_t); static void round1c_received(NodeCommState &nodest, uint8_t *data, uint32_t plaintext_len, uint32_t); static void round2_received(NodeCommState &nodest, uint8_t *data, uint32_t plaintext_len, uint32_t); // A round 1 message was received by a routing node from an ingestion // node; we put it into the round 2 buffer for processing in round 2 static void round1_received(NodeCommState &nodest, uint8_t *data, uint32_t plaintext_len, uint32_t) { uint16_t msg_size = g_teems_config.msg_size; assert((plaintext_len % uint32_t(msg_size)) == 0); uint32_t num_msgs = plaintext_len / uint32_t(msg_size); uint8_t our_roles = g_teems_config.roles[g_teems_config.my_node_num]; uint8_t their_roles = g_teems_config.roles[nodest.node_num]; pthread_mutex_lock(&route_state.round1.mutex); route_state.round1.inserted += num_msgs; route_state.round1.nodes_received += 1; nodenum_t nodes_received = route_state.round1.nodes_received; bool completed_prev_round = route_state.round1.completed_prev_round; pthread_mutex_unlock(&route_state.round1.mutex); // What is the next message we expect from this node? if (g_teems_config.token_channel) { if ((our_roles & ROLE_STORAGE) && (their_roles & ROLE_ROUTING)) { nodest.in_msg_get_buf = [&](NodeCommState &commst, uint32_t tot_enc_chunk_size) { return msgbuffer_get_buf(route_state.round2, commst, tot_enc_chunk_size); }; nodest.in_msg_received = round2_received; } // Otherwise, it's just the next round 1 message, so don't change // the handlers. } else { if (their_roles & ROLE_ROUTING) { nodest.in_msg_get_buf = [&](NodeCommState &commst, uint32_t tot_enc_chunk_size) { return msgbuffer_get_buf(route_state.round1a, commst, tot_enc_chunk_size); }; nodest.in_msg_received = round1a_received; } // Otherwise, it's just the next round 1 message, so don't change // the handlers. 
} if (nodes_received == g_teems_config.num_ingestion_nodes && completed_prev_round) { route_state.step = ROUTE_ROUND_1; void *cbpointer = route_state.cbpointer; route_state.cbpointer = NULL; ocall_routing_round_complete(cbpointer, 1); } } // A round 1a message was received by a routing node static void round1a_received(NodeCommState &nodest, uint8_t *data, uint32_t plaintext_len, uint32_t) { uint16_t msg_size = g_teems_config.msg_size; assert((plaintext_len % uint32_t(msg_size)) == 0); uint32_t num_msgs = plaintext_len / uint32_t(msg_size); pthread_mutex_lock(&route_state.round1a.mutex); route_state.round1a.inserted += num_msgs; route_state.round1a.nodes_received += 1; nodenum_t nodes_received = route_state.round1a.nodes_received; bool completed_prev_round = route_state.round1a.completed_prev_round; pthread_mutex_unlock(&route_state.round1a.mutex); // Both are routing nodes // We only expect a message from the next node (if it exists) nodenum_t my_node_num = g_teems_config.my_node_num; nodenum_t num_routing_nodes = g_teems_config.num_routing_nodes; uint16_t prev_nodes = g_teems_config.weights[my_node_num].startweight; if ((prev_nodes < num_routing_nodes-1) && (nodest.node_num == g_teems_config.routing_nodes[1])) { // Node is next routing node nodest.in_msg_get_buf = [&](NodeCommState &commst, uint32_t tot_enc_chunk_size) { return msgbuffer_get_buf(route_state.round1b_next, commst, tot_enc_chunk_size); }; nodest.in_msg_received = round1b_next_received; } else { // other routing nodes will not send to this node until round 1c nodest.in_msg_get_buf = [&](NodeCommState &commst, uint32_t tot_enc_chunk_size) { return msgbuffer_get_buf(route_state.round1c, commst, tot_enc_chunk_size); }; nodest.in_msg_received = round1c_received; } if (nodes_received == g_teems_config.num_routing_nodes && completed_prev_round) { route_state.step = ROUTE_ROUND_1A; void *cbpointer = route_state.cbpointer; route_state.cbpointer = NULL; ocall_routing_round_complete(cbpointer, ROUND_1A); } } // A round 1b message was received from the next routing node static void round1b_next_received(NodeCommState &nodest, uint8_t *data, uint32_t plaintext_len, uint32_t) { uint16_t msg_size = g_teems_config.msg_size; // There are an extra 5 bytes at the end of this message, containing // the next receiver id (4 bytes) and the count of messages with // that receiver id (1 byte) assert((plaintext_len % uint32_t(msg_size)) == 5); uint32_t num_msgs = plaintext_len / uint32_t(msg_size); uint8_t our_roles = g_teems_config.roles[g_teems_config.my_node_num]; uint8_t their_roles = g_teems_config.roles[nodest.node_num]; pthread_mutex_lock(&route_state.round1b_next.mutex); // Add an extra 1 for the message space taken up by the above 5 // bytes route_state.round1b_next.inserted += num_msgs + 1; route_state.round1b_next.nodes_received += 1; nodenum_t nodes_received = route_state.round1b_next.nodes_received; bool completed_prev_round = route_state.round1b_next.completed_prev_round; pthread_mutex_unlock(&route_state.round1b_next.mutex); nodest.in_msg_get_buf = [&](NodeCommState &commst, uint32_t tot_enc_chunk_size) { return msgbuffer_get_buf(route_state.round1c, commst, tot_enc_chunk_size); }; nodest.in_msg_received = round1c_received; nodenum_t my_node_num = g_teems_config.my_node_num; uint16_t prev_nodes = g_teems_config.weights[my_node_num].startweight; nodenum_t num_routing_nodes = g_teems_config.num_routing_nodes; nodenum_t adjacent_nodes; if (num_routing_nodes == 1) { adjacent_nodes = 0; } else if (prev_nodes == num_routing_nodes-1) { 
adjacent_nodes = 0; } else { adjacent_nodes = 1; } if (nodes_received == adjacent_nodes && completed_prev_round) { route_state.step = ROUTE_ROUND_1B; void *cbpointer = route_state.cbpointer; route_state.cbpointer = NULL; ocall_routing_round_complete(cbpointer, ROUND_1B); } } // Message received in round 1c static void round1c_received(NodeCommState &nodest, uint8_t *data, uint32_t plaintext_len, uint32_t) { uint16_t msg_size = g_teems_config.msg_size; assert((plaintext_len % uint32_t(msg_size)) == 0); uint32_t num_msgs = plaintext_len / uint32_t(msg_size); uint8_t our_roles = g_teems_config.roles[g_teems_config.my_node_num]; uint8_t their_roles = g_teems_config.roles[nodest.node_num]; pthread_mutex_lock(&route_state.round1c.mutex); route_state.round1c.inserted += num_msgs; route_state.round1c.nodes_received += 1; nodenum_t nodes_received = route_state.round1c.nodes_received; bool completed_prev_round = route_state.round1c.completed_prev_round; pthread_mutex_unlock(&route_state.round1c.mutex); // What is the next message we expect from this node? if (our_roles & ROLE_STORAGE) { nodest.in_msg_get_buf = [&](NodeCommState &commst, uint32_t tot_enc_chunk_size) { return msgbuffer_get_buf(route_state.round2, commst, tot_enc_chunk_size); }; nodest.in_msg_received = round2_received; } else if (their_roles & ROLE_INGESTION) { nodest.in_msg_get_buf = [&](NodeCommState &commst, uint32_t tot_enc_chunk_size) { return msgbuffer_get_buf(route_state.round1, commst, tot_enc_chunk_size); }; nodest.in_msg_received = round1_received; } else { nodest.in_msg_get_buf = [&](NodeCommState &commst, uint32_t tot_enc_chunk_size) { return msgbuffer_get_buf(route_state.round1a, commst, tot_enc_chunk_size); }; nodest.in_msg_received = round1a_received; } if (nodes_received == g_teems_config.num_routing_nodes && completed_prev_round) { route_state.step = ROUTE_ROUND_1C; void *cbpointer = route_state.cbpointer; route_state.cbpointer = NULL; ocall_routing_round_complete(cbpointer, ROUND_1C); } } // A round 2 message was received by a storage node from a routing node static void round2_received(NodeCommState &nodest, uint8_t *data, uint32_t plaintext_len, uint32_t) { uint16_t msg_size = g_teems_config.msg_size; assert((plaintext_len % uint32_t(msg_size)) == 0); uint32_t num_msgs = plaintext_len / uint32_t(msg_size); uint8_t our_roles = g_teems_config.roles[g_teems_config.my_node_num]; uint8_t their_roles = g_teems_config.roles[nodest.node_num]; pthread_mutex_lock(&route_state.round2.mutex); route_state.round2.inserted += num_msgs; route_state.round2.nodes_received += 1; nodenum_t nodes_received = route_state.round2.nodes_received; bool completed_prev_round = route_state.round2.completed_prev_round; pthread_mutex_unlock(&route_state.round2.mutex); // What is the next message we expect from this node? if ((our_roles & ROLE_ROUTING) && (their_roles & ROLE_INGESTION)) { nodest.in_msg_get_buf = [&](NodeCommState &commst, uint32_t tot_enc_chunk_size) { return msgbuffer_get_buf(route_state.round1, commst, tot_enc_chunk_size); }; nodest.in_msg_received = round1_received; } // Otherwise, it's just the next round 2 message, so don't change // the handlers. if (nodes_received == g_teems_config.num_routing_nodes && completed_prev_round) { route_state.step = ROUTE_ROUND_2; void *cbpointer = route_state.cbpointer; route_state.cbpointer = NULL; ocall_routing_round_complete(cbpointer, 2); } } // For a given other node, set the received message handler to the first // message we would expect from them, given their roles and our roles. 
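// In summary: a routing node expects round 1 messages (into
// route_state.round1) from each ingestion node, a storage node expects
// round 2 messages (into route_state.round2) from each routing node,
// and any other pairing gets the unknown-message handler.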
void route_init_msg_handler(nodenum_t node_num) { // Our roles and their roles uint8_t our_roles = g_teems_config.roles[g_teems_config.my_node_num]; uint8_t their_roles = g_teems_config.roles[node_num]; // The node communication state NodeCommState &nodest = g_commstates[node_num]; // If we are a routing node (possibly among other roles) and they // are an ingestion node (possibly among other roles), a round 1 // routing message is the first thing we expect from them. We put // these messages into the round1 buffer for processing. if ((our_roles & ROLE_ROUTING) && (their_roles & ROLE_INGESTION)) { nodest.in_msg_get_buf = [&](NodeCommState &commst, uint32_t tot_enc_chunk_size) { return msgbuffer_get_buf(route_state.round1, commst, tot_enc_chunk_size); }; nodest.in_msg_received = round1_received; } // Otherwise, if we are a storage node (possibly among other roles) // and they are a routing node (possibly among other roles), a round // 2 routing message is the first thing we expect from them else if ((our_roles & ROLE_STORAGE) && (their_roles & ROLE_ROUTING)) { nodest.in_msg_get_buf = [&](NodeCommState &commst, uint32_t tot_enc_chunk_size) { return msgbuffer_get_buf(route_state.round2, commst, tot_enc_chunk_size); }; nodest.in_msg_received = round2_received; } // Otherwise, we don't expect a message from this node. Set the // unknown message handler. else { nodest.in_msg_get_buf = default_in_msg_get_buf; nodest.in_msg_received = unknown_in_msg_received; } } // Directly ingest a buffer of num_msgs messages into the ingbuf buffer. // Return true on success, false on failure. bool ecall_ingest_raw(uint8_t *msgs, uint32_t num_msgs) { uint16_t msg_size = g_teems_config.msg_size; MsgBuffer &ingbuf = route_state.ingbuf; pthread_mutex_lock(&ingbuf.mutex); uint32_t start = ingbuf.reserved; if (start + num_msgs > route_state.tot_msg_per_ing) { pthread_mutex_unlock(&ingbuf.mutex); printf("Max %u messages exceeded (ecall_ingest_raw)\n", route_state.tot_msg_per_ing); return false; } ingbuf.reserved += num_msgs; pthread_mutex_unlock(&ingbuf.mutex); memmove(ingbuf.buf + start * msg_size, msgs, num_msgs * msg_size); pthread_mutex_lock(&ingbuf.mutex); ingbuf.inserted += num_msgs; pthread_mutex_unlock(&ingbuf.mutex); return true; } // Send messages round-robin, used in rounds 1 and 1c. Note that N here // is not private. Pass indices = nullptr to just send the messages in // the order they appear in the msgs buffer. 
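// Illustrative example (numbers not from the original configuration):
// with three routing nodes of weights 2, 1, 1 (tot_weight = 4) and
// N = 10 messages, full_rows = 2 and last_row = 2, so the three nodes
// are sent 2*2+2 = 6, 2*1+0 = 2, and 2*1+0 = 2 messages respectively.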
template <typename T>
static void send_round_robin_msgs(MsgBuffer &round, const uint8_t *msgs,
    const T *indices, uint32_t N)
{
    uint16_t msg_size = g_teems_config.msg_size;
    uint16_t tot_weight;
    tot_weight = g_teems_config.tot_weight;
    nodenum_t my_node_num = g_teems_config.my_node_num;
    uint32_t full_rows;
    uint32_t last_row;
    full_rows = N / uint32_t(tot_weight);
    last_row = N % uint32_t(tot_weight);
    for (auto &routing_node: g_teems_config.routing_nodes) {
        uint8_t weight = g_teems_config.weights[routing_node].weight;
        if (weight == 0) {
            // This shouldn't happen, but just in case
            continue;
        }
        uint16_t start_weight = g_teems_config.weights[routing_node].startweight;
        // The number of messages headed for this routing node from the
        // full rows
        uint32_t num_msgs_full_rows = full_rows * uint32_t(weight);
        // The number of messages headed for this routing node from the
        // incomplete last row is:
        //   0 if last_row < start_weight
        //   last_row-start_weight if start_weight <= last_row < start_weight + weight
        //   weight if start_weight + weight <= last_row
        uint32_t num_msgs_last_row = 0;
        if (start_weight <= last_row && last_row < start_weight + weight) {
            num_msgs_last_row = last_row-start_weight;
        } else if (start_weight + weight <= last_row) {
            num_msgs_last_row = weight;
        }
        // The total number of messages headed for this routing node
        uint32_t num_msgs = num_msgs_full_rows + num_msgs_last_row;
        if (routing_node == my_node_num) {
            // Special case: we're sending to ourselves; just put the
            // messages in our own buffer
            pthread_mutex_lock(&round.mutex);
            uint32_t start = round.reserved;
            if (start + num_msgs > round.bufsize) {
                pthread_mutex_unlock(&round.mutex);
                printf("Max %u messages exceeded (send_round_robin_msgs)\n",
                    round.bufsize);
                return;
            }
            round.reserved += num_msgs;
            pthread_mutex_unlock(&round.mutex);
            uint8_t *buf = round.buf + start * msg_size;
            for (uint32_t i=0; i

        if (prev_nodes >= extra_msgs) {
            start_msg = min_msgs_per_node * prev_nodes + extra_msgs;
            num_msgs = min_msgs_per_node;
        } else {
            start_msg = min_msgs_per_node * prev_nodes + prev_nodes;
            num_msgs = min_msgs_per_node + 1;
        }
        // take number of messages into account
        if (start_msg >= N) {
            num_msgs = 0;
        } else if (start_msg + num_msgs > N) {
            num_msgs = N - start_msg;
        }
        if (routing_node == my_node_num) {
            // Special case: we're sending to ourselves; just put the
            // messages in our own buffer
            MsgBuffer &round1a = route_state.round1a;
            pthread_mutex_lock(&round1a.mutex);
            uint32_t start = round1a.reserved;
            if (start + num_msgs > round1a.bufsize) {
                pthread_mutex_unlock(&round1a.mutex);
                printf("Max %u messages exceeded in round 1a\n",
                    round1a.bufsize);
                return;
            }
            round1a.reserved += num_msgs;
            pthread_mutex_unlock(&round1a.mutex);
            uint8_t *buf = round1a.buf + start * msg_size;
            for (uint32_t i=0; i<num_msgs; ++i) {
                memmove(buf,
                    msgs + (indices + start_msg + i)->index()*msg_size,
                    msg_size);
                buf += msg_size;
            }
            pthread_mutex_lock(&round1a.mutex);
            round1a.inserted += num_msgs;
            round1a.nodes_received += 1;
            pthread_mutex_unlock(&round1a.mutex);
        } else {
            NodeCommState &nodecom = g_commstates[routing_node];
            nodecom.message_start(num_msgs * msg_size);
            for (uint32_t i=0; i<num_msgs; ++i) {
                nodecom.message_data(
                    msgs + (indices + start_msg + i)->index()*msg_size,
                    msg_size);
            }
        }
    }
}
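// As described in round1b_next_received() above, each round 1b
// transmission ends with a 5-byte trailer: a 4-byte receiver id for the
// first message that was _not_ sent, followed by a 1-byte count of the
// unsent messages at the start of the unsent block that share that
// receiver id.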
// Send the round 1b messages from the round 1a buffer, which only
// occurs in ID-channel routing.  msgs points to the message buffer, and
// N is the number of non-padding items.
// Return the number of messages sent
static uint32_t send_round1b_msgs(const uint8_t *msgs, uint32_t N)
{
    uint16_t msg_size = g_teems_config.msg_size;
    nodenum_t my_node_num = g_teems_config.my_node_num;
    nodenum_t num_routing_nodes = g_teems_config.num_routing_nodes;
    uint16_t prev_nodes = g_teems_config.weights[my_node_num].startweight;
    // send to previous node
    if (prev_nodes > 0) {
        nodenum_t prev_node = g_teems_config.routing_nodes[num_routing_nodes-1];
        NodeCommState &nodecom = g_commstates[prev_node];
        uint32_t num_msgs = std::min(N, route_state.max_round1b_msgs_to_adj_rtr);
        // There are an extra 5 bytes at the end of this message: 4
        // bytes for the receiver id in the next message we _didn't_
        // send, and 1 byte for the number of messages we have at the
        // beginning of the buffer of messages we didn't send (max
        // id_in) with the same receiver id
        nodecom.message_start(num_msgs * msg_size + 5);
        nodecom.message_data(msgs, num_msgs * msg_size);
        uint32_t next_receiver_id = 0xffffffff;
        uint8_t next_rid_count = 0;
        // num_msgs and N are not private, but the contents of the
        // buffer are.
        if (num_msgs < N) {
            next_receiver_id = *(const uint32_t *)(msgs + num_msgs * msg_size);
            next_rid_count = 1;
            // If id_in > 1, obliviously scan messages num_msgs+1 ..
            // num_msgs+(id_in-1) and as long as they have the same
            // receiver id as next_receiver_id, add 1 to next_rid_count
            // (but don't go past message N of course)
            // This count _includes_ the first message already scanned
            // above.  It is not private.
            uint8_t num_to_scan = uint8_t(std::min(N - num_msgs,
                uint32_t(g_teems_config.m_id_in)));
            const unsigned char *scan_msg = msgs + (num_msgs + 1) * msg_size;
            for (uint8_t i=1; i

            if (start + num_msgs_per_stg > round2.bufsize) {
                pthread_mutex_unlock(&round2.mutex);
                printf("Max %u messages exceeded (send_round2_msgs)\n",
                    round2.bufsize);
                return;
            }
            round2.reserved += num_msgs_per_stg;
            pthread_mutex_unlock(&round2.mutex);
            myself_buf = round2.buf + start * msg_size;
        }
    }
    while (tot_msgs) {
        nodenum_t storage_node_id =
            nodenum_t((*(const uint32_t *)buf)>>DEST_UID_BITS);
        if (storage_node_id < num_storage_nodes) {
            nodenum_t node = g_teems_config.storage_map[storage_node_id];
            if (node == my_node_num) {
                memmove(myself_buf, buf, msg_size);
                myself_buf += msg_size;
            } else {
                g_commstates[node].message_data(buf, msg_size);
            }
        }
        buf += msg_size;
        --tot_msgs;
    }
    if (myself_buf) {
        MsgBuffer &round2 = route_state.round2;
        pthread_mutex_lock(&round2.mutex);
        round2.inserted += num_msgs_per_stg;
        round2.nodes_received += 1;
        pthread_mutex_unlock(&round2.mutex);
    }
}

static void round1a_processing(void *cbpointer)
{
    uint8_t my_roles = g_teems_config.roles[g_teems_config.my_node_num];
    MsgBuffer &round1 = route_state.round1;
    if (my_roles & ROLE_ROUTING) {
        route_state.cbpointer = cbpointer;
        pthread_mutex_lock(&round1.mutex);
        // Ensure there are no pending messages currently being inserted
        // into the buffer
        while (round1.reserved != round1.inserted) {
            pthread_mutex_unlock(&round1.mutex);
            pthread_mutex_lock(&round1.mutex);
        }
#ifdef TRACE_ROUTING
        show_messages("Start of round 1a", round1.buf, round1.inserted);
#endif
#ifdef PROFILE_ROUTING
        uint32_t inserted = round1.inserted;
        unsigned long start_round1a = printf_with_rtclock("begin round1a processing (%u)\n", inserted);
        // Sort the messages we've received
        unsigned long start_sort = printf_with_rtclock("begin oblivious sort (%u,%u)\n", inserted, route_state.max_round1_msgs);
#endif
        // Sort received messages by increasing user ID and
        // priority.  Larger priority number indicates higher priority.
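        // (send_round1a_msgs, the callback invoked by the sort below, is
        // handed the message buffer, the sorted key/index array, and the
        // message count, and forwards a contiguous range of the sorted
        // messages to each routing node.)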
        sort_mtobliv<UidPriorityKey>(g_teems_config.nthreads, round1.buf,
            g_teems_config.msg_size, round1.inserted,
            route_state.max_round1_msgs, send_round1a_msgs);
#ifdef PROFILE_ROUTING
        printf_with_rtclock_diff(start_sort, "end oblivious sort (%u,%u)\n", inserted, route_state.max_round1_msgs);
        printf_with_rtclock_diff(start_round1a, "end round1a processing (%u)\n", inserted);
#endif
        round1.reset();
        pthread_mutex_unlock(&round1.mutex);

        MsgBuffer &round1a = route_state.round1a;
        pthread_mutex_lock(&round1a.mutex);
        round1a.completed_prev_round = true;
        nodenum_t nodes_received = round1a.nodes_received;
        pthread_mutex_unlock(&round1a.mutex);
        if (nodes_received == g_teems_config.num_routing_nodes) {
            route_state.step = ROUTE_ROUND_1A;
            route_state.cbpointer = NULL;
            ocall_routing_round_complete(cbpointer, ROUND_1A);
        }
    } else {
        route_state.step = ROUTE_ROUND_1A;
        route_state.round1a.completed_prev_round = true;
        ocall_routing_round_complete(cbpointer, ROUND_1A);
    }
}

static void round1b_processing(void *cbpointer)
{
    uint8_t my_roles = g_teems_config.roles[g_teems_config.my_node_num];
    nodenum_t my_node_num = g_teems_config.my_node_num;
    uint16_t prev_nodes = g_teems_config.weights[my_node_num].startweight;
    nodenum_t num_routing_nodes = g_teems_config.num_routing_nodes;
    MsgBuffer &round1a = route_state.round1a;
    MsgBuffer &round1a_sorted = route_state.round1a_sorted;
    if (my_roles & ROLE_ROUTING) {
        route_state.cbpointer = cbpointer;
        pthread_mutex_lock(&round1a.mutex);
        // Ensure there are no pending messages currently being inserted
        // into the buffer
        while (round1a.reserved != round1a.inserted) {
            pthread_mutex_unlock(&round1a.mutex);
            pthread_mutex_lock(&round1a.mutex);
        }
        pthread_mutex_lock(&round1a_sorted.mutex);
#ifdef TRACE_ROUTING
        show_messages("Start of round 1b", round1a.buf, round1a.inserted);
#endif
        uint32_t inserted = round1a.inserted;
#ifdef PROFILE_ROUTING
        unsigned long start_round1b = printf_with_rtclock("begin round1b processing (%u)\n", inserted);
#endif
        // Sort the messages we've received
        // Sort received messages by increasing user ID and
        // priority.  Larger priority number indicates higher priority.
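        // (The sorted copy is written into round1a_sorted; the block at
        // the start of it may be handed to the previous routing node by
        // send_round1b_msgs, and whatever remains is merged with the
        // block received from the next routing node in round 1c.)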
        if (inserted > 0) {
            // copy items in sorted order into round1a_sorted
#ifdef PROFILE_ROUTING
            unsigned long start_sort = printf_with_rtclock("begin oblivious sort (%u,%u)\n", inserted, route_state.max_round1a_msgs);
#endif
            route_state.round1a_sorted.reserved = round1a.inserted;
            sort_mtobliv<UidPriorityKey>(g_teems_config.nthreads, round1a.buf,
                g_teems_config.msg_size, round1a.inserted,
                route_state.max_round1a_msgs, route_state.round1a_sorted.buf);
            route_state.round1a_sorted.inserted = round1a.inserted;
#ifdef PROFILE_ROUTING
            printf_with_rtclock_diff(start_sort, "end oblivious sort (%u,%u)\n", inserted, route_state.max_round1a_msgs);
#endif
#ifdef TRACE_ROUTING
            show_messages("In round 1b (sorted)", round1a_sorted.buf,
                round1a_sorted.inserted);
#endif
            uint32_t num_sent = send_round1b_msgs(round1a_sorted.buf,
                round1a_sorted.inserted);
            // Remove the sent messages from the buffer
            memmove(round1a_sorted.buf,
                round1a_sorted.buf + num_sent * g_teems_config.msg_size,
                (round1a_sorted.inserted - num_sent) * g_teems_config.msg_size);
            round1a_sorted.inserted -= num_sent;
            round1a_sorted.reserved -= num_sent;
#ifdef TRACE_ROUTING
            show_messages("In round 1b (after sending initial block)",
                round1a_sorted.buf, round1a_sorted.inserted);
#endif
        } else {
            send_round1b_msgs(NULL, 0);
        }
#ifdef PROFILE_ROUTING
        printf_with_rtclock_diff(start_round1b, "end round1b processing (%u)\n", inserted);
#endif
        pthread_mutex_unlock(&round1a_sorted.mutex);
        pthread_mutex_unlock(&round1a.mutex);

        MsgBuffer &round1b_next = route_state.round1b_next;
        pthread_mutex_lock(&round1b_next.mutex);
        round1b_next.completed_prev_round = true;
        nodenum_t nodes_received = round1b_next.nodes_received;
        pthread_mutex_unlock(&round1b_next.mutex);

        nodenum_t adjacent_nodes;
        if (num_routing_nodes == 1) {
            adjacent_nodes = 0;
        } else if (prev_nodes == num_routing_nodes-1) {
            adjacent_nodes = 0;
        } else {
            adjacent_nodes = 1;
        }
        if (nodes_received == adjacent_nodes) {
            route_state.step = ROUTE_ROUND_1B;
            route_state.cbpointer = NULL;
            ocall_routing_round_complete(cbpointer, ROUND_1B);
        }
    } else {
        route_state.step = ROUTE_ROUND_1B;
        route_state.round1b_next.completed_prev_round = true;
        ocall_routing_round_complete(cbpointer, ROUND_1B);
    }
}

static void copy_msgs(uint8_t *dst, uint32_t start_msg, uint32_t num_copy,
    const uint8_t *src, const UidPriorityKey *indices)
{
    uint16_t msg_size = g_teems_config.msg_size;
    const UidPriorityKey *idxp = indices + start_msg;
    uint8_t *buf = dst;
    for (uint32_t i=0; i<num_copy; ++i) {
        memmove(buf, src + idxp[i].index()*msg_size, msg_size);
        buf += msg_size;
    }
}

static void round1c_processing(void *cbpointer)
{

        if (round1b_next.inserted >= 1) {
            next_receiver_id = *(uint32_t *)(round1b_next.buf +
                (round1b_next.inserted-1)*msg_size);
            next_rid_count = *(round1b_next.buf +
                (round1b_next.inserted-1)*msg_size + 4);
            round1b_next.inserted -= 1;
            round1b_next.reserved -= 1;
        }
        pthread_mutex_lock(&round1a.mutex);
        pthread_mutex_lock(&round1a_sorted.mutex);
#ifdef TRACE_ROUTING
        show_messages("Start of round 1c", round1a_sorted.buf,
            round1a_sorted.inserted);
#endif
#ifdef PROFILE_ROUTING
        unsigned long start_round1c = printf_with_rtclock("begin round1c processing (%u)\n", round1a.inserted);
#endif
        // sort round1b_next msgs with final msgs in round1a_sorted
        if ((prev_nodes < num_routing_nodes-1) && (round1b_next.inserted > 0)) {
            // Copy final msgs in round1a_sorted to round1b_next buffer
            // for sorting
            // Note that all inserted values and buffer sizes are non-secret
            uint32_t num_final_round1a = std::min(
                round1a_sorted.inserted, max_round1b_msgs_to_adj_rtr);
            uint32_t round1a_msg_start = round1a_sorted.inserted-num_final_round1a;
            uint32_t num_round1b_next = round1b_next.inserted;
            round1b_next.reserved += num_final_round1a;
            memmove(round1b_next.buf+num_round1b_next*msg_size,
                round1a_sorted.buf + round1a_msg_start*msg_size,
                num_final_round1a*msg_size);
            round1b_next.inserted += num_final_round1a;
            round1a_sorted.reserved -= num_final_round1a;
            round1a_sorted.inserted -= num_final_round1a;
#ifdef TRACE_ROUTING
            show_messages("In round 1c (after setting aside)",
                round1a_sorted.buf, round1a_sorted.inserted);
            show_messages("In round 1c (the aside)", round1b_next.buf,
                round1b_next.inserted);
#endif
            // sort and append to round1a msgs
            uint32_t num_copy = round1b_next.inserted;
#ifdef PROFILE_ROUTING
            unsigned long start_sort = printf_with_rtclock("begin round1b_next oblivious sort (%u,%u)\n", num_copy, 2*max_round1b_msgs_to_adj_rtr);
#endif
            round1a_sorted.reserved += num_copy;
            sort_mtobliv<UidPriorityKey>(g_teems_config.nthreads,
                round1b_next.buf, msg_size, num_copy,
                2*max_round1b_msgs_to_adj_rtr,
                [&](const uint8_t *src, const UidPriorityKey *indices,
                        uint32_t Nr) {
                    return copy_msgs(round1a_sorted.buf +
                        round1a_sorted.inserted * msg_size, 0, num_copy,
                        src, indices);
                }
            );
            round1a_sorted.inserted += num_copy;
#ifdef PROFILE_ROUTING
            printf_with_rtclock_diff(start_sort, "end round1b_next oblivious sort (%u,%u)\n", num_copy, 2*max_round1b_msgs_to_adj_rtr);
#endif
        }

#ifdef TRACE_ROUTING
        printf("round1a_sorted.inserted = %u, reserved = %u, bufsize = %u\n",
            round1a_sorted.inserted, round1a_sorted.reserved,
            round1a_sorted.bufsize);
        show_messages("In round 1c before padding pass", round1a_sorted.buf,
            round1a_sorted.inserted);
#endif
        // The round1a_sorted buffer is now sorted by receiver uid and
        // priority.  Going from the end of the buffer to the beginning
        // (so as to encounter and keep the highest-priority messages
        // for any given receiver first), obliviously turn any messages
        // over the limit of id_in for any given receiver into padding.
        // Also keep track of which messages are not padding for use in
        // later compaction.
        bool *is_not_padding = new bool[round1a_sorted.inserted];
        for (uint32_t i=0; i<round1a_sorted.inserted; ++i) {
            uint8_t *header = round1a_sorted.buf +
                (round1a_sorted.inserted - 1 - i) * msg_size;
            uint32_t receiver_id = *(uint32_t *)header;
            // If receiver_id != next_receiver_id:
            //     next_receiver_id = receiver_id
            //     next_rid_count = 1
            //     become_padding = 0
            // If receiver_id == next_receiver_id and next_rid_count < id_in:
            //     next_receiver_id = receiver_id
            //     next_rid_count = next_rid_count + 1
            //     become_padding = 0
            // If receiver_id == next_receiver_id and next_rid_count >= id_in:
            //     next_receiver_id = receiver_id
            //     next_rid_count = next_rid_count
            //     become_padding = 1
            bool same_receiver_id = (receiver_id == next_receiver_id);
            // If same_receiver_id is 0, reset next_rid_count to 0.
            // If same_receiver_id is 1, don't change next_rid_count.
            // This method (AND with -same_receiver_id) is more likely
            // to be constant time than multiplying by same_receiver_id.
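            // (Concretely: -uint32_t(0) == 0, so a differing receiver id
            // clears the count, while -uint32_t(1) == 0xffffffff leaves it
            // unchanged.)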
            next_rid_count &= (-(uint32_t(same_receiver_id)));
            bool become_padding = (next_rid_count >= id_in);
            next_rid_count += !become_padding;
            next_receiver_id = receiver_id;
            // Obliviously change the receiver id to 0xffffffff
            // (padding) if become_padding is 1
            receiver_id |= (-(uint32_t(become_padding)));
            *(uint32_t*)header = receiver_id;
            is_not_padding[round1a_sorted.inserted - 1 - i] =
                (receiver_id != 0xffffffff);
        }
#ifdef TRACE_ROUTING
        show_messages("In round 1c after padding pass", round1a_sorted.buf,
            round1a_sorted.inserted);
#endif
        // Oblivious compaction to move the padding messages to the end,
        // preserving the (already-sorted) order of the non-padding
        // messages
        TightCompact_parallel((unsigned char *) round1a_sorted.buf,
            round1a_sorted.inserted, msg_size, is_not_padding,
            g_teems_config.nthreads);
        delete[] is_not_padding;
#ifdef TRACE_ROUTING
        show_messages("In round 1c after compaction", round1a_sorted.buf,
            round1a_sorted.inserted);
#endif
        send_round_robin_msgs<UidPriorityKey>(route_state.round1c,
            round1a_sorted.buf, NULL, round1a_sorted.inserted);
#ifdef PROFILE_ROUTING
        printf_with_rtclock_diff(start_round1c, "end round1c processing (%u)\n", round1a.inserted);
#endif
        round1a.reset();
        round1a_sorted.reset();
        round1b_next.reset();
        pthread_mutex_unlock(&round1a_sorted.mutex);
        pthread_mutex_unlock(&round1a.mutex);
        pthread_mutex_unlock(&round1b_next.mutex);

        MsgBuffer &round1c = route_state.round1c;
        pthread_mutex_lock(&round1c.mutex);
        round1c.completed_prev_round = true;
        nodenum_t nodes_received = round1c.nodes_received;
        pthread_mutex_unlock(&round1c.mutex);
        if (nodes_received == num_routing_nodes) {
            route_state.step = ROUTE_ROUND_1C;
            route_state.cbpointer = NULL;
            ocall_routing_round_complete(cbpointer, ROUND_1C);
        }
    } else {
        route_state.step = ROUTE_ROUND_1C;
        route_state.round1c.completed_prev_round = true;
        ocall_routing_round_complete(cbpointer, ROUND_1C);
    }
}

// Process messages in round 2
static void round2_processing(uint8_t my_roles, void *cbpointer,
    MsgBuffer &prevround)
{
    if (my_roles & ROLE_ROUTING) {
        route_state.cbpointer = cbpointer;
        pthread_mutex_lock(&prevround.mutex);
        // Ensure there are no pending messages currently being inserted
        // into the buffer
        while (prevround.reserved != prevround.inserted) {
            pthread_mutex_unlock(&prevround.mutex);
            pthread_mutex_lock(&prevround.mutex);
        }
        // If the _total_ number of messages we received in round 1
        // is less than the max number of messages we could send to
        // _each_ storage node, then cap the number of messages we
        // will send to each storage node to that number.
        uint32_t msgs_per_stg = route_state.max_msg_to_each_stg;
        if (prevround.inserted < msgs_per_stg) {
            msgs_per_stg = prevround.inserted;
        }
        // Note: at this point, it is required that each message in
        // the prevround buffer have a _valid_ storage node id field.
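        // (The storage node id is the value packed above DEST_UID_BITS in
        // each message's 4-byte destination header, exactly as
        // send_round2_msgs extracts it.)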
        // Obliviously tally the number of messages we received in
        // the previous round destined for each storage node
#ifdef TRACE_ROUTING
        show_messages("Start of round 2", prevround.buf, prevround.inserted);
#endif
#ifdef PROFILE_ROUTING
        unsigned long start_round2 = printf_with_rtclock("begin round2 processing (%u,%u)\n", prevround.inserted, prevround.bufsize);
        unsigned long start_tally = printf_with_rtclock("begin tally (%u)\n", prevround.inserted);
#endif
        uint16_t msg_size = g_teems_config.msg_size;
        nodenum_t num_storage_nodes = g_teems_config.num_storage_nodes;
        std::vector<uint32_t> tally = obliv_tally_stg(
            prevround.buf, msg_size, prevround.inserted, num_storage_nodes);
#ifdef PROFILE_ROUTING
        printf_with_rtclock_diff(start_tally, "end tally (%u)\n", prevround.inserted);
#endif
        // Note: tally contains private values!  It's OK to
        // non-obliviously check for an error condition, though.
        // While we're at it, obliviously change the tally of
        // messages received to a tally of padding messages
        // required.
        uint32_t tot_messages = msgs_per_stg * num_storage_nodes;
        for (nodenum_t i=0; i<num_storage_nodes; ++i) {
            if (tally[i] > msgs_per_stg) {
                printf("Received too many messages for storage node "
                    "%u (%u > %u)\n", i, tally[i], msgs_per_stg);
                assert(tally[i] <= msgs_per_stg);
            }
            tally[i] = msgs_per_stg - tally[i];
        }
        // Allocate extra padding messages (not yet destined for a
        // particular storage node)
        assert(prevround.inserted <= tot_messages &&
            tot_messages <= prevround.bufsize);
        prevround.reserved = tot_messages;
        for (uint32_t i=prevround.inserted; i

        if (g_teems_config.token_channel) {
            sort_mtobliv<UidKey>(g_teems_config.nthreads, ingbuf.buf,
                g_teems_config.msg_size, ingbuf.inserted,
                route_state.tot_msg_per_ing,
                [&](const uint8_t *msgs, const UidKey *indices, uint32_t N) {
                    send_round_robin_msgs(route_state.round1, msgs, indices, N);
                });
        } else {
            // Sort received messages by increasing user ID and
            // priority.  Larger priority number indicates higher priority.
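            // (In ID-channel routing the same (user ID, priority) ordering
            // is used again in rounds 1a-1c, where at most id_in messages
            // per receiver are kept and the rest become padding.)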
sort_mtobliv(g_teems_config.nthreads, ingbuf.buf, g_teems_config.msg_size, ingbuf.inserted, route_state.tot_msg_per_ing, [&](const uint8_t *msgs, const UidPriorityKey *indices, uint32_t N) { send_round_robin_msgs(route_state.round1, msgs, indices, N); }); } #ifdef PROFILE_ROUTING printf_with_rtclock_diff(start_sort, "end oblivious sort (%u,%u)\n", inserted, route_state.tot_msg_per_ing); printf_with_rtclock_diff(start_round1, "end round1 processing (%u)\n", inserted); #endif ingbuf.reset(); pthread_mutex_unlock(&ingbuf.mutex); } if (my_roles & ROLE_ROUTING) { MsgBuffer &round1 = route_state.round1; pthread_mutex_lock(&round1.mutex); round1.completed_prev_round = true; nodenum_t nodes_received = round1.nodes_received; pthread_mutex_unlock(&round1.mutex); if (nodes_received == g_teems_config.num_ingestion_nodes) { route_state.step = ROUTE_ROUND_1; route_state.cbpointer = NULL; ocall_routing_round_complete(cbpointer, 1); } } else { route_state.step = ROUTE_ROUND_1; route_state.round1.completed_prev_round = true; ocall_routing_round_complete(cbpointer, 1); } } else if (route_state.step == ROUTE_ROUND_1) { if (g_teems_config.token_channel) { // Token channel routing next round round2_processing(my_roles, cbpointer, route_state.round1); } else { // ID channel routing next round round1a_processing(cbpointer); } } else if (route_state.step == ROUTE_ROUND_1A) { round1b_processing(cbpointer); } else if (route_state.step == ROUTE_ROUND_1B) { round1c_processing(cbpointer); } else if (route_state.step == ROUTE_ROUND_1C) { round2_processing(my_roles, cbpointer, route_state.round1c); } else if (route_state.step == ROUTE_ROUND_2) { if (my_roles & ROLE_STORAGE) { MsgBuffer &round2 = route_state.round2; pthread_mutex_lock(&round2.mutex); // Ensure there are no pending messages currently being inserted // into the buffer while (round2.reserved != round2.inserted) { pthread_mutex_unlock(&round2.mutex); pthread_mutex_lock(&round2.mutex); } unsigned long start = printf_with_rtclock("begin storage processing (%u)\n", round2.inserted); storage_received(round2); printf_with_rtclock_diff(start, "end storage processing (%u)\n", round2.inserted); // We're done route_state.step = ROUTE_NOT_STARTED; ocall_routing_round_complete(cbpointer, 0); } else { // We're done route_state.step = ROUTE_NOT_STARTED; ocall_routing_round_complete(cbpointer, 0); } } #ifdef TRACK_HEAP_USAGE printf("ecall_routing_proceed end heap %u\n", g_peak_heap_used); #endif }