@@ -4,6 +4,7 @@
 #include "utils.hpp"
 #include "sort.hpp"
 #include "comms.hpp"
+#include "obliv.hpp"
 #include "route.hpp"

 #define PROFILE_ROUTING
@@ -75,7 +76,7 @@ static struct RouteState {
     MsgBuffer round2;
     RouteStep step;
     uint32_t tot_msg_per_ing;
-    uint32_t max_msg_to_each_str;
+    uint32_t max_msg_to_each_stg;
     uint32_t max_round2_msgs;
     void *cbpointer;
 } route_state;
@@ -112,25 +113,25 @@ bool route_init()
     // Compute the maximum number of messages we could send in round 2

     // Each storage node has at most this many users
-    uint32_t users_per_str = CEILDIV(g_teems_config.user_count,
+    uint32_t users_per_stg = CEILDIV(g_teems_config.user_count,
         g_teems_config.num_storage_nodes);

     // And so can receive at most this many messages
-    uint32_t tot_msg_per_str = users_per_str *
+    uint32_t tot_msg_per_stg = users_per_stg *
         g_teems_config.m_priv_in;

     // Which will be at most this many from us
-    uint32_t max_msg_to_each_str = CEILDIV(tot_msg_per_str,
+    uint32_t max_msg_to_each_stg = CEILDIV(tot_msg_per_stg,
         g_teems_config.tot_weight) * g_teems_config.my_weight;

     // But we can't send more messages to each storage server than we
     // could receive in total
-    if (max_msg_to_each_str > max_round1_msgs) {
-        max_msg_to_each_str = max_round1_msgs;
+    if (max_msg_to_each_stg > max_round1_msgs) {
+        max_msg_to_each_stg = max_round1_msgs;
     }

     // And the max total number of outgoing messages in round 2 is then
-    uint32_t max_round2_msgs = max_msg_to_each_str *
+    uint32_t max_round2_msgs = max_msg_to_each_stg *
         g_teems_config.num_storage_nodes;

     // In case we have a weird configuration where users can send more
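For concreteness, here is a minimal standalone sketch of the capacity
arithmetic in the hunk above. All configuration values are hypothetical,
CEILDIV is assumed to be the usual round-up integer division macro, and
m_priv_in is taken to be the per-user cap on incoming messages; the clamp
against max_round1_msgs is omitted for brevity. Only the formulas mirror
the source.

#include <cstdint>
#include <cstdio>

// Assumed definition: round-up integer division
#define CEILDIV(x, y) (((x) + (y) - 1) / (y))

int main()
{
    // Hypothetical configuration, not from the source
    uint32_t user_count = 1000;
    uint32_t num_storage_nodes = 4;
    uint32_t m_priv_in = 2;
    uint32_t tot_weight = 5, my_weight = 2;

    uint32_t users_per_stg = CEILDIV(user_count, num_storage_nodes);    // 250
    uint32_t tot_msg_per_stg = users_per_stg * m_priv_in;               // 500
    uint32_t max_msg_to_each_stg =
        CEILDIV(tot_msg_per_stg, tot_weight) * my_weight;               // 200
    uint32_t max_round2_msgs = max_msg_to_each_stg * num_storage_nodes; // 800

    printf("%u %u %u %u\n", users_per_stg, tot_msg_per_stg,
        max_msg_to_each_stg, max_round2_msgs);
    return 0;
}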
@@ -155,7 +156,7 @@ bool route_init()
         route_state.round1.alloc(max_round2_msgs);
     }
     if (my_roles & ROLE_STORAGE) {
-        route_state.round2.alloc(tot_msg_per_str +
+        route_state.round2.alloc(tot_msg_per_stg +
             g_teems_config.tot_weight);
     }
 } catch (std::bad_alloc&) {
@@ -164,7 +165,7 @@ bool route_init()
     }
     route_state.step = ROUTE_NOT_STARTED;
     route_state.tot_msg_per_ing = tot_msg_per_ing;
-    route_state.max_msg_to_each_str = max_msg_to_each_str;
+    route_state.max_msg_to_each_stg = max_msg_to_each_stg;
     route_state.max_round2_msgs = max_round2_msgs;
     route_state.cbpointer = NULL;

@@ -489,8 +490,61 @@ void ecall_routing_proceed(void *cbpointer)
         pthread_mutex_lock(&round1.mutex);
     }

+    // If the _total_ number of messages we received in round 1
+    // is less than the max number of messages we could send to
+    // _each_ storage node, then cap the number of messages we
+    // will send to each storage node to that number.
+    uint32_t msgs_per_stg = route_state.max_msg_to_each_stg;
+    if (round1.inserted < msgs_per_stg) {
+        msgs_per_stg = round1.inserted;
+    }
+
+    // Note: at this point, it is required that each message in
+    // the round1 buffer have a _valid_ storage node id field.
+
+    // Obliviously tally the number of messages we received in
+    // round1 destined for each storage node
     uint32_t msg_size = g_teems_config.msg_size;
-    for(uint32_t i=0;i<round1.inserted;++i) {
+    nodenum_t num_storage_nodes = g_teems_config.num_storage_nodes;
+    std::vector<uint32_t> tally = obliv_tally_stg(
+        round1.buf, msg_size, round1.inserted, num_storage_nodes);
+
+    // Note: tally contains private values! It's OK to
+    // non-obliviously check for an error condition, though.
+    // While we're at it, obliviously change the tally of
+    // messages received to a tally of padding messages
+    // required.
+    uint32_t tot_padding = 0;
+    for (nodenum_t i=0; i<num_storage_nodes; ++i) {
+        if (tally[i] > msgs_per_stg) {
+            printf("Received too many messages for storage node %u\n", i);
+            assert(tally[i] <= msgs_per_stg);
+        }
+        tally[i] = msgs_per_stg - tally[i];
+        tot_padding += tally[i];
+    }
+
+    round1.reserved += tot_padding;
+    assert(round1.reserved <= round1.bufsize);
+
+    // Obliviously add padding for each storage node according
+    // to the (private) padding tally.
+    obliv_pad_stg(round1.buf + round1.inserted * msg_size,
+        msg_size, tally, tot_padding);
+
+    round1.inserted += tot_padding;
+
+    // Obliviously shuffle the messages
+    uint32_t num_shuffled = shuffle_mtobliv(g_teems_config.nthreads,
+        round1.buf, msg_size, round1.inserted, round1.bufsize);
+
+    // Now we can handle the messages non-obliviously, since we
+    // know there will be exactly msgs_per_stg messages to each
+    // storage node, and the oblivious shuffle broke the
+    // connection between where each message came from and where
+    // it's going.
+
+    for(uint32_t i=0;i<num_shuffled;++i) {
         uint32_t destaddr = *(uint32_t*)(round1.buf+i*msg_size);
         printf("%08x\n", destaddr);
     }
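The oblivious helpers called above (obliv_tally_stg, obliv_pad_stg) are
declared in the newly included obliv.hpp and are not shown in this diff.
As a rough illustration of the property the tally step needs, here is a
minimal sketch, not the real implementation: it assumes the storage node
id occupies the first 4 bytes of each message and that nodenum_t is a
small unsigned integer type; the _sketch suffix marks it as hypothetical.

#include <cstdint>
#include <cstring>
#include <vector>

typedef uint16_t nodenum_t;   // assumed width

// Count messages destined for each storage node without letting the
// memory access pattern depend on the private node ids: every tally
// slot is touched once per message.
static std::vector<uint32_t> obliv_tally_stg_sketch(const uint8_t *buf,
    uint32_t msg_size, uint32_t num_msgs, nodenum_t num_storage_nodes)
{
    std::vector<uint32_t> tally(num_storage_nodes, 0);
    for (uint32_t i = 0; i < num_msgs; ++i) {
        uint32_t node;
        memcpy(&node, buf + i * msg_size, sizeof(node));
        for (nodenum_t s = 0; s < num_storage_nodes; ++s) {
            // 0 or 1; production code would use a constant-time
            // equality primitive rather than trusting the compiler
            uint32_t match = (uint32_t)(node == s);
            tally[s] += match;
        }
    }
    return tally;
}

The same access-pattern discipline is why the padding and shuffle steps
matter: once every storage node has exactly msgs_per_stg messages and the
buffer has been obliviously shuffled, the per-destination handling that
follows can safely run in the clear.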