#include "Enclave_t.h"

#include "config.hpp"
#include "utils.hpp"
#include "sort.hpp"
#include "comms.hpp"
#include "obliv.hpp"
#include "storage.hpp"
#include "route.hpp"

#define PROFILE_ROUTING

// The phase of the routing protocol we are currently in
enum RouteStep {
    ROUTE_NOT_STARTED,
    ROUTE_ROUND_1,
    ROUTE_ROUND_2
};

// Per-node routing state.  The ingbuf MsgBuffer holds messages an
// ingestion node accepts while waiting for round 1 to start; they are
// sorted and sent out in round 1.  The round1 MsgBuffer holds messages
// a routing node receives in round 1; they are padded, sorted, and sent
// out in round 2.  The round2 MsgBuffer holds messages a storage node
// receives in round 2.
static struct RouteState {
    MsgBuffer ingbuf;
    MsgBuffer round1;
    MsgBuffer round2;
    RouteStep step;               // current protocol phase
    uint32_t tot_msg_per_ing;     // max messages one ingestion node can hold
    uint32_t max_msg_to_each_stg; // max messages we send to each storage node
    uint32_t max_round2_msgs;     // max outgoing round 2 messages (also the
                                  // round1 buffer capacity)
    void *cbpointer;              // opaque pointer handed back through
                                  // ocall_routing_round_complete
} route_state;

// Computes ceil(x/y) where x and y are integers, x>=0, y>0.
#define CEILDIV(x,y) (((x)+(y)-1)/(y))

// Size and allocate the routing buffers from the global configuration.
// Call this near the end of ecall_config_load, but before
// comms_init_nodestate.  Returns true on success, false on failure.
bool route_init()
{
    auto &cfg = g_teems_config;

    // Direct-ingestion bound: each ingestion node serves at most
    // ceil(user_count/num_ingestion_nodes) users, and each user sends
    // at most m_priv_out messages per epoch.
    uint32_t ing_users = CEILDIV(cfg.user_count, cfg.num_ingestion_nodes);
    uint32_t ing_total = ing_users * cfg.m_priv_out;

    // Round 1 inflow bound: each ingestion node forwards us a
    // my_weight/tot_weight share of what it holds, and there are
    // num_ingestion_nodes of them.
    uint32_t per_ing_share = CEILDIV(ing_total, cfg.tot_weight) * cfg.my_weight;
    uint32_t round1_cap = per_ing_share * cfg.num_ingestion_nodes;

    // Round 2 outflow bound: each storage node serves at most
    // ceil(user_count/num_storage_nodes) users, each of whom can
    // receive at most m_priv_in messages; our share of that is
    // my_weight/tot_weight.
    uint32_t stg_users = CEILDIV(cfg.user_count, cfg.num_storage_nodes);
    uint32_t stg_total = stg_users * cfg.m_priv_in;
    uint32_t per_stg_cap = CEILDIV(stg_total, cfg.tot_weight) * cfg.my_weight;
    // We can never send one storage node more than everything we could
    // possibly have received in round 1.
    if (per_stg_cap > round1_cap) {
        per_stg_cap = round1_cap;
    }
    // Max total number of outgoing messages in round 2
    uint32_t round2_cap = per_stg_cap * cfg.num_storage_nodes;
    // In a configuration where users can send more messages per epoch
    // than they can receive, the round 2 staging buffer must still be
    // able to hold all incoming round 1 messages.
    if (round2_cap < round1_cap) {
        round2_cap = round1_cap;
    }

    // Allocate only the buffers our roles require
    uint8_t my_roles = cfg.roles[cfg.my_node_num];
    try {
        if (my_roles & ROLE_INGESTION) {
            route_state.ingbuf.alloc(ing_total);
        }
        if (my_roles & ROLE_ROUTING) {
            // Sized for round 2 output, since padding is added in place
            route_state.round1.alloc(round2_cap);
        }
        if (my_roles & ROLE_STORAGE) {
            route_state.round2.alloc(stg_total + cfg.tot_weight);
        }
    } catch (std::bad_alloc&) {
        printf("Memory allocation failed in route_init\n");
        return false;
    }

    route_state.step = ROUTE_NOT_STARTED;
    route_state.tot_msg_per_ing = ing_total;
    route_state.max_msg_to_each_stg = per_stg_cap;
    route_state.max_round2_msgs = round2_cap;
    route_state.cbpointer = NULL;

    // Precompute the sort evaluation plans for the two sizes we sort
    threadid_t nthreads = cfg.nthreads;
#ifdef PROFILE_ROUTING
    unsigned long start = printf_with_rtclock(
        "begin precompute evalplans (%u,%hu) (%u,%hu)\n",
        ing_total, nthreads, round2_cap, nthreads);
#endif
    sort_precompute_evalplan(ing_total, nthreads);
    sort_precompute_evalplan(round2_cap, nthreads);
#ifdef PROFILE_ROUTING
    printf_with_rtclock_diff(start, "end precompute evalplans\n");
#endif
    return true;
}

// Precompute the WaksmanNetworks needed for the sorts.  If you pass -1,
// it will return the number of different sizes it needs.  If you pass
// [0,sizes-1], it will compute one WaksmanNetwork with that size index
// and return the number of available WaksmanNetworks of that size.
size_t ecall_precompute_sort(int sizeidx)
{
    size_t ret = 0;
    switch(sizeidx) {
    case 0:
        // Sort size used for directly-ingested messages
#ifdef PROFILE_ROUTING
        {unsigned long start = printf_with_rtclock("begin precompute WaksmanNetwork (%u)\n", route_state.tot_msg_per_ing);
#endif
        ret = sort_precompute(route_state.tot_msg_per_ing);
#ifdef PROFILE_ROUTING
        printf_with_rtclock_diff(start, "end precompute Waksman Network (%u)\n", route_state.tot_msg_per_ing);}
#endif
        break;
    case 1:
        // Sort size used for the round 2 (padded) buffer
#ifdef PROFILE_ROUTING
        {unsigned long start = printf_with_rtclock("begin precompute WaksmanNetwork (%u)\n", route_state.max_round2_msgs);
#endif
        ret = sort_precompute(route_state.max_round2_msgs);
#ifdef PROFILE_ROUTING
        printf_with_rtclock_diff(start, "end precompute Waksman Network (%u)\n", route_state.max_round2_msgs);}
#endif
        break;
    default:
        // Any other index (e.g. -1): report how many distinct sizes we need
        ret = 2;
        break;
    }
    return ret;
}

// Reserve room in the given MsgBuffer for an incoming encrypted batch
// of tot_enc_chunk_size bytes and return a pointer to the reserved
// region of the plaintext buffer, or NULL if the buffer would overflow.
// The (unnamed) NodeCommState parameter is unused here; it exists so
// this function matches the in_msg_get_buf callback signature.
static uint8_t* msgbuffer_get_buf(MsgBuffer &msgbuf, NodeCommState &,
    uint32_t tot_enc_chunk_size)
{
    uint16_t msg_size = g_teems_config.msg_size;
    // Chunks will be encrypted and have a MAC tag attached which will
    // not correspond to plaintext bytes, so we can trim them.
    // The minimum number of chunks needed to transmit this message
    uint32_t min_num_chunks = (tot_enc_chunk_size + (FRAME_SIZE-1)) / FRAME_SIZE;
    // The number of plaintext bytes this message could contain
    uint32_t plaintext_bytes = tot_enc_chunk_size - SGX_AESGCM_MAC_SIZE * min_num_chunks;
    // The payload must be a whole number of fixed-size messages
    assert ((plaintext_bytes % uint32_t(msg_size)) == 0);
    uint32_t num_msgs = plaintext_bytes/uint32_t(msg_size);
    // Reserve the region under the lock; reserved can run ahead of
    // inserted while data is still being copied in
    pthread_mutex_lock(&msgbuf.mutex);
    uint32_t start = msgbuf.reserved;
    if (start + num_msgs > msgbuf.bufsize) {
        pthread_mutex_unlock(&msgbuf.mutex);
        printf("Max %u messages exceeded\n", msgbuf.bufsize);
        return NULL;
    }
    msgbuf.reserved += num_msgs;
    pthread_mutex_unlock(&msgbuf.mutex);
    return msgbuf.buf + start * msg_size;
}

static void round2_received(NodeCommState &nodest, uint8_t *data,
    uint32_t plaintext_len, uint32_t);

// A round 1 message was received by a routing node from an ingestion
// node; we put it into the round 2 buffer for processing in round 2.
// data is unused here: the payload was already written into the buffer
// handed out by msgbuffer_get_buf — presumably by the comms layer;
// TODO(review) confirm against the caller in comms.
static void round1_received(NodeCommState &nodest, uint8_t *data,
    uint32_t plaintext_len, uint32_t)
{
    uint16_t msg_size = g_teems_config.msg_size;
    assert((plaintext_len % uint32_t(msg_size)) == 0);
    uint32_t num_msgs = plaintext_len / uint32_t(msg_size);
    uint8_t our_roles = g_teems_config.roles[g_teems_config.my_node_num];
    uint8_t their_roles = g_teems_config.roles[nodest.node_num];
    // Account for the newly-landed messages and snapshot the completion
    // state under the lock
    pthread_mutex_lock(&route_state.round1.mutex);
    route_state.round1.inserted += num_msgs;
    route_state.round1.nodes_received += 1;
    nodenum_t nodes_received = route_state.round1.nodes_received;
    bool completed_prev_round = route_state.round1.completed_prev_round;
    pthread_mutex_unlock(&route_state.round1.mutex);
    // What is the next message we expect from this node?
    if ((our_roles & ROLE_STORAGE) && (their_roles & ROLE_ROUTING)) {
        // Next epoch they will send us round 2 traffic; swap handlers
        nodest.in_msg_get_buf = [&](NodeCommState &commst,
                uint32_t tot_enc_chunk_size) {
            return msgbuffer_get_buf(route_state.round2, commst,
                tot_enc_chunk_size);
        };
        nodest.in_msg_received = round2_received;
    }
    // Otherwise, it's just the next round 1 message, so don't change
    // the handlers.
    if (nodes_received == g_teems_config.num_ingestion_nodes &&
            completed_prev_round) {
        // All ingestion nodes have reported in and our own round 0 work
        // is done; notify the app that round 1 is complete
        route_state.step = ROUTE_ROUND_1;
        void *cbpointer = route_state.cbpointer;
        route_state.cbpointer = NULL;
        ocall_routing_round_complete(cbpointer, 1);
    }
}

// A round 2 message was received by a storage node from a routing node
static void round2_received(NodeCommState &nodest, uint8_t *data,
    uint32_t plaintext_len, uint32_t)
{
    uint16_t msg_size = g_teems_config.msg_size;
    assert((plaintext_len % uint32_t(msg_size)) == 0);
    uint32_t num_msgs = plaintext_len / uint32_t(msg_size);
    uint8_t our_roles = g_teems_config.roles[g_teems_config.my_node_num];
    uint8_t their_roles = g_teems_config.roles[nodest.node_num];
    // Account for the new messages and snapshot completion state
    pthread_mutex_lock(&route_state.round2.mutex);
    route_state.round2.inserted += num_msgs;
    route_state.round2.nodes_received += 1;
    nodenum_t nodes_received = route_state.round2.nodes_received;
    bool completed_prev_round = route_state.round2.completed_prev_round;
    pthread_mutex_unlock(&route_state.round2.mutex);
    // What is the next message we expect from this node?
    if ((our_roles & ROLE_ROUTING) && (their_roles & ROLE_INGESTION)) {
        // Next epoch they will send us round 1 traffic; swap handlers
        nodest.in_msg_get_buf = [&](NodeCommState &commst,
                uint32_t tot_enc_chunk_size) {
            return msgbuffer_get_buf(route_state.round1, commst,
                tot_enc_chunk_size);
        };
        nodest.in_msg_received = round1_received;
    }
    // Otherwise, it's just the next round 2 message, so don't change
    // the handlers.
    if (nodes_received == g_teems_config.num_routing_nodes &&
            completed_prev_round) {
        route_state.step = ROUTE_ROUND_2;
        void *cbpointer = route_state.cbpointer;
        route_state.cbpointer = NULL;
        ocall_routing_round_complete(cbpointer, 2);
    }
}

// For a given other node, set the received message handler to the first
// message we would expect from them, given their roles and our roles.
void route_init_msg_handler(nodenum_t node_num)
{
    // Our roles and their roles
    uint8_t our_roles = g_teems_config.roles[g_teems_config.my_node_num];
    uint8_t their_roles = g_teems_config.roles[node_num];
    // The node communication state
    NodeCommState &nodest = g_commstates[node_num];
    // If we are a routing node (possibly among other roles) and they
    // are an ingestion node (possibly among other roles), a round 1
    // routing message is the first thing we expect from them.  We put
    // these messages into the round1 buffer for processing.
    if ((our_roles & ROLE_ROUTING) && (their_roles & ROLE_INGESTION)) {
        nodest.in_msg_get_buf = [&](NodeCommState &commst,
                uint32_t tot_enc_chunk_size) {
            return msgbuffer_get_buf(route_state.round1, commst,
                tot_enc_chunk_size);
        };
        nodest.in_msg_received = round1_received;
    }
    // Otherwise, if we are a storage node (possibly among other roles)
    // and they are a routing node (possibly among other roles), a round
    // 2 routing message is the first thing we expect from them
    else if ((our_roles & ROLE_STORAGE) && (their_roles & ROLE_ROUTING)) {
        nodest.in_msg_get_buf = [&](NodeCommState &commst,
                uint32_t tot_enc_chunk_size) {
            return msgbuffer_get_buf(route_state.round2, commst,
                tot_enc_chunk_size);
        };
        nodest.in_msg_received = round2_received;
    }
    // Otherwise, we don't expect a message from this node.  Set the
    // unknown message handler.
    else {
        nodest.in_msg_get_buf = default_in_msg_get_buf;
        nodest.in_msg_received = unknown_in_msg_received;
    }
}

// Directly ingest a buffer of num_msgs messages into the ingbuf buffer.
// Return true on success, false on failure.
bool ecall_ingest_raw(uint8_t *msgs, uint32_t num_msgs)
{
    uint16_t msg_size = g_teems_config.msg_size;
    MsgBuffer &ingbuf = route_state.ingbuf;
    // Reserve space under the lock
    pthread_mutex_lock(&ingbuf.mutex);
    uint32_t start = ingbuf.reserved;
    if (start + num_msgs > route_state.tot_msg_per_ing) {
        pthread_mutex_unlock(&ingbuf.mutex);
        printf("Max %u messages exceeded\n", route_state.tot_msg_per_ing);
        return false;
    }
    ingbuf.reserved += num_msgs;
    pthread_mutex_unlock(&ingbuf.mutex);
    // Copy outside the lock; the reserved/inserted gap marks the copy
    // as in flight until inserted catches up below
    memmove(ingbuf.buf + start * msg_size, msgs, num_msgs * msg_size);
    pthread_mutex_lock(&ingbuf.mutex);
    ingbuf.inserted += num_msgs;
    pthread_mutex_unlock(&ingbuf.mutex);
    return true;
}

// Send the round 1 messages.  Note that N here is not private.
static void send_round1_msgs(const uint8_t *msgs, const uint64_t *indices,
    uint32_t N)
{
    uint16_t msg_size = g_teems_config.msg_size;
    uint16_t tot_weight = g_teems_config.tot_weight;
    nodenum_t my_node_num = g_teems_config.my_node_num;
    // Deal the N sorted messages out in rows of tot_weight; each
    // routing node gets `weight` columns of each row
    uint32_t full_rows = N / uint32_t(tot_weight);
    uint32_t last_row = N % uint32_t(tot_weight);
    for (auto &routing_node: g_teems_config.routing_nodes) {
        uint8_t weight = g_teems_config.weights[routing_node].weight;
        if (weight == 0) {
            // This shouldn't happen, but just in case
            continue;
        }
        uint16_t start_weight = g_teems_config.weights[routing_node].startweight;
        // The number of messages headed for this routing node from the
        // full rows
        uint32_t num_msgs_full_rows = full_rows * uint32_t(weight);
        // The number of messages headed for this routing node from the
        // incomplete last row is:
        // 0 if last_row < start_weight
        // last_row-start_weight if start_weight <= last_row < start_weight + weight
        // weight if start_weight + weight <= last_row
        uint32_t num_msgs_last_row = 0;
        if (start_weight <= last_row && last_row < start_weight + weight) {
            num_msgs_last_row = last_row-start_weight;
        } else if (start_weight + weight <= last_row) {
            num_msgs_last_row = weight;
        }
        // The total number of messages headed for this routing node
        uint32_t num_msgs = num_msgs_full_rows + num_msgs_last_row;
        if (routing_node == my_node_num) {
            // Special case: we're sending to ourselves; just put the
            // messages in our own round1 buffer
            MsgBuffer &round1 = route_state.round1;
            pthread_mutex_lock(&round1.mutex);
            uint32_t start = round1.reserved;
            if (start + num_msgs > round1.bufsize) {
                pthread_mutex_unlock(&round1.mutex);
                printf("Max %u messages exceeded\n", round1.bufsize);
                return;
            }
            round1.reserved += num_msgs;
            pthread_mutex_unlock(&round1.mutex);
            uint8_t *buf = round1.buf + start * msg_size;
            // NOTE(review): the source text appears corrupted from here:
            // a span seems to have been lost between "i" and
            // "round2.bufsize" below (it reads as though angle-bracketed
            // text was eaten by an extraction step).  The missing span
            // presumably contained the remainder of send_round1_msgs
            // (the copy/send of each node's selected messages) and the
            // beginning of send_round2_msgs (its signature and locals
            // such as num_msgs_per_stg, tot_msgs, buf, and myself_buf).
            // TODO(review): recover the lost text from version control
            // before building; the tokens below are preserved verbatim.
            for (uint32_t i=0; i round2.bufsize) {
                pthread_mutex_unlock(&round2.mutex);
                printf("Max %u messages exceeded\n", round2.bufsize);
                return;
            }
            round2.reserved += num_msgs_per_stg;
            pthread_mutex_unlock(&round2.mutex);
            myself_buf = round2.buf + start * msg_size;
        }
    }
    // Walk the shuffled buffer, dispatching each message to its storage
    // node: copied locally if it's us, sent over comms otherwise
    while (tot_msgs) {
        // The destination storage node id lives in the high bits of the
        // first 32-bit word of the message
        nodenum_t storage_node_id =
            nodenum_t((*(const uint32_t *)buf)>>DEST_UID_BITS);
        if (storage_node_id < num_storage_nodes) {
            nodenum_t node = g_teems_config.storage_map[storage_node_id];
            if (node == my_node_num) {
                memmove(myself_buf, buf, msg_size);
                myself_buf += msg_size;
            } else {
                g_commstates[node].message_data(buf, msg_size);
            }
        }
        buf += msg_size;
        --tot_msgs;
    }
    if (myself_buf) {
        // Account for the messages we delivered to ourselves
        MsgBuffer &round2 = route_state.round2;
        pthread_mutex_lock(&round2.mutex);
        round2.inserted += num_msgs_per_stg;
        round2.nodes_received += 1;
        pthread_mutex_unlock(&round2.mutex);
    }
}

// Perform the next round of routing.  The callback pointer will be
// passed to ocall_routing_round_complete when the round is complete.
// Advance the routing state machine by one step.  Depending on our
// roles, each step does this node's work for that round (or nothing)
// and eventually reports completion via ocall_routing_round_complete.
void ecall_routing_proceed(void *cbpointer)
{
    uint8_t my_roles = g_teems_config.roles[g_teems_config.my_node_num];
    if (route_state.step == ROUTE_NOT_STARTED) {
        if (my_roles & ROLE_INGESTION) {
            route_state.cbpointer = cbpointer;
            MsgBuffer &ingbuf = route_state.ingbuf;
            MsgBuffer &round1 = route_state.round1;
            pthread_mutex_lock(&ingbuf.mutex);
            // Ensure there are no pending messages currently being inserted
            // into the buffer
            // (spin until reserved == inserted; the lock is dropped and
            // retaken each iteration so the writers can make progress)
            while (ingbuf.reserved != ingbuf.inserted) {
                pthread_mutex_unlock(&ingbuf.mutex);
                pthread_mutex_lock(&ingbuf.mutex);
            }
            // Sort the messages we've received
#ifdef PROFILE_ROUTING
            uint32_t inserted = ingbuf.inserted;
            unsigned long start_round1 = printf_with_rtclock("begin round1 processing (%u)\n", inserted);
            unsigned long start_sort = printf_with_rtclock("begin oblivious sort (%u,%u)\n", inserted, route_state.tot_msg_per_ing);
#endif
            // send_round1_msgs is invoked by the sort with the sorted order
            sort_mtobliv(g_teems_config.nthreads, ingbuf.buf,
                g_teems_config.msg_size, ingbuf.inserted,
                route_state.tot_msg_per_ing, send_round1_msgs);
#ifdef PROFILE_ROUTING
            printf_with_rtclock_diff(start_sort, "end oblivious sort (%u,%u)\n", inserted, route_state.tot_msg_per_ing);
            printf_with_rtclock_diff(start_round1, "end round1 processing (%u)\n", inserted);
#endif
            ingbuf.reset();
            pthread_mutex_unlock(&ingbuf.mutex);
            // Mark our sending work done; if all ingestion nodes have
            // also reported in, round 1 is complete
            pthread_mutex_lock(&round1.mutex);
            round1.completed_prev_round = true;
            nodenum_t nodes_received = round1.nodes_received;
            pthread_mutex_unlock(&round1.mutex);
            if (nodes_received == g_teems_config.num_ingestion_nodes) {
                route_state.step = ROUTE_ROUND_1;
                route_state.cbpointer = NULL;
                ocall_routing_round_complete(cbpointer, 1);
            }
        } else {
            // Not an ingestion node: nothing to do this step
            route_state.step = ROUTE_ROUND_1;
            ocall_routing_round_complete(cbpointer, 1);
        }
    } else if (route_state.step == ROUTE_ROUND_1) {
        if (my_roles & ROLE_ROUTING) {
            route_state.cbpointer = cbpointer;
            MsgBuffer &round1 = route_state.round1;
            MsgBuffer &round2 = route_state.round2;
            pthread_mutex_lock(&round1.mutex);
            // Ensure there are no pending messages currently being inserted
            // into the buffer
            while (round1.reserved != round1.inserted) {
                pthread_mutex_unlock(&round1.mutex);
                pthread_mutex_lock(&round1.mutex);
            }
            // If the _total_ number of messages we received in round 1
            // is less than the max number of messages we could send to
            // _each_ storage node, then cap the number of messages we
            // will send to each storage node to that number.
            uint32_t msgs_per_stg = route_state.max_msg_to_each_stg;
            if (round1.inserted < msgs_per_stg) {
                msgs_per_stg = round1.inserted;
            }
            // Note: at this point, it is required that each message in
            // the round1 buffer have a _valid_ storage node id field.
            // Obliviously tally the number of messages we received in
            // round1 destined for each storage node
#ifdef PROFILE_ROUTING
            uint32_t inserted = round1.inserted;
            unsigned long start_round2 = printf_with_rtclock("begin round2 processing (%u,%u)\n", inserted, round1.bufsize);
            unsigned long start_tally = printf_with_rtclock("begin tally (%u)\n", inserted);
#endif
            uint16_t msg_size = g_teems_config.msg_size;
            nodenum_t num_storage_nodes = g_teems_config.num_storage_nodes;
            // NOTE(review): the source text appears corrupted here: the
            // template argument of std::vector (presumably the tally
            // element type) was lost, likely to angle-bracket stripping.
            // TODO(review): recover from version control; tokens kept
            // verbatim.
            std::vector tally = obliv_tally_stg(
                round1.buf, msg_size, round1.inserted, num_storage_nodes);
#ifdef PROFILE_ROUTING
            printf_with_rtclock_diff(start_tally, "end tally (%u)\n", inserted);
#endif
            // Note: tally contains private values!  It's OK to
            // non-obliviously check for an error condition, though.
            // While we're at it, obliviously change the tally of
            // messages received to a tally of padding messages
            // required.
            uint32_t tot_padding = 0;
            // NOTE(review): corrupted span here as well -- the loop
            // header's condition/increment and the start of the inner
            // `if` (comparing tally[i] against msgs_per_stg, per the
            // assert below) appear to have been lost.  Tokens kept
            // verbatim; do not build until recovered.
            for (nodenum_t i=0; i msgs_per_stg) {
                printf("Received too many messages for storage node %u\n", i);
                assert(tally[i] <= msgs_per_stg);
            }
            tally[i] = msgs_per_stg - tally[i];
            tot_padding += tally[i];
            }
            round1.reserved += tot_padding;
            assert(round1.reserved <= round1.bufsize);
            // Obliviously add padding for each storage node according
            // to the (private) padding tally.
#ifdef PROFILE_ROUTING
            unsigned long start_pad = printf_with_rtclock("begin pad (%u)\n", tot_padding);
#endif
            obliv_pad_stg(round1.buf + round1.inserted * msg_size,
                msg_size, tally, tot_padding);
#ifdef PROFILE_ROUTING
            printf_with_rtclock_diff(start_pad, "end pad (%u)\n", tot_padding);
#endif
            round1.inserted += tot_padding;
            // Obliviously shuffle the messages
#ifdef PROFILE_ROUTING
            unsigned long start_shuffle = printf_with_rtclock("begin shuffle (%u,%u)\n", round1.inserted, round1.bufsize);
#endif
            uint32_t num_shuffled = shuffle_mtobliv(g_teems_config.nthreads,
                round1.buf, msg_size, round1.inserted, round1.bufsize);
#ifdef PROFILE_ROUTING
            printf_with_rtclock_diff(start_shuffle, "end shuffle (%u,%u)\n", round1.inserted, round1.bufsize);
            printf_with_rtclock_diff(start_round2, "end round2 processing (%u,%u)\n", inserted, round1.bufsize);
#endif
            // Now we can handle the messages non-obliviously, since we
            // know there will be exactly msgs_per_stg messages to each
            // storage node, and the oblivious shuffle broke the
            // connection between where each message came from and where
            // it's going.
            send_round2_msgs(num_shuffled, msgs_per_stg);
            round1.reset();
            pthread_mutex_unlock(&round1.mutex);
            // Mark our sending work done; if all routing nodes have
            // also reported in, round 2 is complete
            pthread_mutex_lock(&round2.mutex);
            round2.completed_prev_round = true;
            nodenum_t nodes_received = round2.nodes_received;
            pthread_mutex_unlock(&round2.mutex);
            if (nodes_received == g_teems_config.num_routing_nodes) {
                route_state.step = ROUTE_ROUND_2;
                route_state.cbpointer = NULL;
                ocall_routing_round_complete(cbpointer, 2);
            }
        } else {
            // Not a routing node: nothing to do this step
            route_state.step = ROUTE_ROUND_2;
            ocall_routing_round_complete(cbpointer, 2);
        }
    } else if (route_state.step == ROUTE_ROUND_2) {
        if (my_roles & ROLE_STORAGE) {
            MsgBuffer &round2 = route_state.round2;
            pthread_mutex_lock(&round2.mutex);
            // Ensure there are no pending messages currently being inserted
            // into the buffer
            while (round2.reserved != round2.inserted) {
                pthread_mutex_unlock(&round2.mutex);
                pthread_mutex_lock(&round2.mutex);
            }
#ifdef PROFILE_ROUTING
            unsigned long start = printf_with_rtclock("begin storage processing (%u)\n", round2.inserted);
#endif
            // Hand the delivered messages to the storage layer
            storage_received(round2.buf, round2.inserted);
#ifdef PROFILE_ROUTING
            printf_with_rtclock_diff(start, "end storage processing (%u)\n", round2.inserted);
#endif
            round2.reset();
            pthread_mutex_unlock(&round2.mutex);
            // We're done
            route_state.step = ROUTE_NOT_STARTED;
            ocall_routing_round_complete(cbpointer, 0);
        } else {
            // We're done
            route_state.step = ROUTE_NOT_STARTED;
            ocall_routing_round_complete(cbpointer, 0);
        }
    }
}