route.cpp 27 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691
  1. #include "Enclave_t.h"
  2. #include "config.hpp"
  3. #include "utils.hpp"
  4. #include "sort.hpp"
  5. #include "comms.hpp"
  6. #include "obliv.hpp"
  7. #include "storage.hpp"
  8. #include "route.hpp"
  // Compile-time switch: while this is defined, the routing code in this
  // file prints timing lines through printf_with_rtclock and
  // printf_with_rtclock_diff.
  #define PROFILE_ROUTING

  // Position of this node within one routing epoch. ecall_routing_proceed
  // dispatches on this value, and it is advanced whenever a round is
  // reported complete through ocall_routing_round_complete.
  enum RouteStep {
      // Initial state (set by route_init), and the state returned to
      // once the storage step of an epoch has finished.
      ROUTE_NOT_STARTED = 0,
      // Round 1 has been reported complete.
      ROUTE_ROUND_1 = 1,
      // Round 2 has been reported complete.
      ROUTE_ROUND_2 = 2
  };
  15. // The ingbuf MsgBuffer stores messages an ingestion node ingests while
  16. // waiting for round 1 to start, which will be sorted and sent out in
  17. // round 1. The round1 MsgBuffer stores messages a routing node
  18. // receives in round 1, which will be padded, sorted, and sent out in
  19. // round 2. The round2 MsgBuffer stores messages a storage node
  20. // receives in round 2.
  21. static struct RouteState {
  22. MsgBuffer ingbuf;
  23. MsgBuffer round1;
  24. MsgBuffer round2;
  25. RouteStep step;
  26. uint32_t tot_msg_per_ing;
  27. uint32_t max_msg_to_each_stg;
  28. uint32_t max_round2_msgs;
  29. void *cbpointer;
  30. } route_state;
  // Computes ceil(x/y) where x and y are integers, x>=0, y>0.
  //
  // Written as "quotient, plus one if there is a remainder" rather than
  // ((x)+(y)-1)/(y) so that no intermediate sum can wrap around when x is
  // close to the maximum value of its type. Each argument is evaluated
  // twice, so do not pass expressions with side effects.
  #define CEILDIV(x,y) ((x)/(y) + (((x)%(y)) != 0))
  33. // Call this near the end of ecall_config_load, but before
  34. // comms_init_nodestate. Returns true on success, false on failure.
  35. bool route_init()
  36. {
  37. // Compute the maximum number of messages we could receive by direct
  38. // ingestion
  39. // Each ingestion node will have at most
  40. // ceil(user_count/num_ingestion_nodes) users, and each user will
  41. // send at most m_priv_out messages.
  42. uint32_t users_per_ing = CEILDIV(g_teems_config.user_count,
  43. g_teems_config.num_ingestion_nodes);
  44. uint32_t tot_msg_per_ing = users_per_ing * g_teems_config.m_priv_out;
  45. // Compute the maximum number of messages we could receive in round 1
  46. // Each ingestion node will send us an our_weight/tot_weight
  47. // fraction of the messages they hold
  48. uint32_t max_msg_from_each_ing = CEILDIV(tot_msg_per_ing,
  49. g_teems_config.tot_weight) * g_teems_config.my_weight;
  50. // And the maximum number we can receive in total is that times the
  51. // number of ingestion nodes
  52. uint32_t max_round1_msgs = max_msg_from_each_ing *
  53. g_teems_config.num_ingestion_nodes;
  54. // Compute the maximum number of messages we could send in round 2
  55. // Each storage node has at most this many users
  56. uint32_t users_per_stg = CEILDIV(g_teems_config.user_count,
  57. g_teems_config.num_storage_nodes);
  58. // And so can receive at most this many messages
  59. uint32_t tot_msg_per_stg = users_per_stg *
  60. g_teems_config.m_priv_in;
  61. // Which will be at most this many from us
  62. uint32_t max_msg_to_each_stg = CEILDIV(tot_msg_per_stg,
  63. g_teems_config.tot_weight) * g_teems_config.my_weight;
  64. // But we can't send more messages to each storage server than we
  65. // could receive in total
  66. if (max_msg_to_each_stg > max_round1_msgs) {
  67. max_msg_to_each_stg = max_round1_msgs;
  68. }
  69. // And the max total number of outgoing messages in round 2 is then
  70. uint32_t max_round2_msgs = max_msg_to_each_stg *
  71. g_teems_config.num_storage_nodes;
  72. // In case we have a weird configuration where users can send more
  73. // messages per epoch than they can receive, ensure the round 2
  74. // buffer is large enough to hold the incoming messages as well
  75. if (max_round2_msgs < max_round1_msgs) {
  76. max_round2_msgs = max_round1_msgs;
  77. }
  78. /*
  79. printf("round1_msgs = %u, round2_msgs = %u\n",
  80. max_round1_msgs, max_round2_msgs);
  81. */
  82. // Create the route state
  83. uint8_t my_roles = g_teems_config.roles[g_teems_config.my_node_num];
  84. try {
  85. if (my_roles & ROLE_INGESTION) {
  86. route_state.ingbuf.alloc(tot_msg_per_ing);
  87. }
  88. if (my_roles & ROLE_ROUTING) {
  89. route_state.round1.alloc(max_round2_msgs);
  90. }
  91. if (my_roles & ROLE_STORAGE) {
  92. route_state.round2.alloc(tot_msg_per_stg +
  93. g_teems_config.tot_weight);
  94. }
  95. } catch (std::bad_alloc&) {
  96. printf("Memory allocation failed in route_init\n");
  97. return false;
  98. }
  99. route_state.step = ROUTE_NOT_STARTED;
  100. route_state.tot_msg_per_ing = tot_msg_per_ing;
  101. route_state.max_msg_to_each_stg = max_msg_to_each_stg;
  102. route_state.max_round2_msgs = max_round2_msgs;
  103. route_state.cbpointer = NULL;
  104. threadid_t nthreads = g_teems_config.nthreads;
  105. #ifdef PROFILE_ROUTING
  106. unsigned long start = printf_with_rtclock("begin precompute evalplans (%u,%hu) (%u,%hu)\n", tot_msg_per_ing, nthreads, max_round2_msgs, nthreads);
  107. #endif
  108. sort_precompute_evalplan(tot_msg_per_ing, nthreads);
  109. sort_precompute_evalplan(max_round2_msgs, nthreads);
  110. #ifdef PROFILE_ROUTING
  111. printf_with_rtclock_diff(start, "end precompute evalplans\n");
  112. #endif
  113. return true;
  114. }
  115. // Precompute the WaksmanNetworks needed for the sorts. If you pass -1,
  116. // it will return the number of different sizes it needs. If you pass
  117. // [0,sizes-1], it will compute one WaksmanNetwork with that size index
  118. // and return the number of available WaksmanNetworks of that size.
  119. size_t ecall_precompute_sort(int sizeidx)
  120. {
  121. size_t ret = 0;
  122. switch(sizeidx) {
  123. case 0:
  124. #ifdef PROFILE_ROUTING
  125. {unsigned long start = printf_with_rtclock("begin precompute WaksmanNetwork (%u)\n", route_state.tot_msg_per_ing);
  126. #endif
  127. ret = sort_precompute(route_state.tot_msg_per_ing);
  128. #ifdef PROFILE_ROUTING
  129. printf_with_rtclock_diff(start, "end precompute Waksman Network (%u)\n", route_state.tot_msg_per_ing);}
  130. #endif
  131. break;
  132. case 1:
  133. #ifdef PROFILE_ROUTING
  134. {unsigned long start = printf_with_rtclock("begin precompute WaksmanNetwork (%u)\n", route_state.max_round2_msgs);
  135. #endif
  136. ret = sort_precompute(route_state.max_round2_msgs);
  137. #ifdef PROFILE_ROUTING
  138. printf_with_rtclock_diff(start, "end precompute Waksman Network (%u)\n", route_state.max_round2_msgs);}
  139. #endif
  140. break;
  141. default:
  142. ret = 2;
  143. break;
  144. }
  145. return ret;
  146. }
  147. static uint8_t* msgbuffer_get_buf(MsgBuffer &msgbuf,
  148. NodeCommState &, uint32_t tot_enc_chunk_size)
  149. {
  150. uint16_t msg_size = g_teems_config.msg_size;
  151. // Chunks will be encrypted and have a MAC tag attached which will
  152. // not correspond to plaintext bytes, so we can trim them.
  153. // The minimum number of chunks needed to transmit this message
  154. uint32_t min_num_chunks =
  155. (tot_enc_chunk_size + (FRAME_SIZE-1)) / FRAME_SIZE;
  156. // The number of plaintext bytes this message could contain
  157. uint32_t plaintext_bytes = tot_enc_chunk_size -
  158. SGX_AESGCM_MAC_SIZE * min_num_chunks;
  159. assert ((plaintext_bytes % uint32_t(msg_size)) == 0);
  160. uint32_t num_msgs = plaintext_bytes/uint32_t(msg_size);
  161. pthread_mutex_lock(&msgbuf.mutex);
  162. uint32_t start = msgbuf.reserved;
  163. if (start + num_msgs > msgbuf.bufsize) {
  164. pthread_mutex_unlock(&msgbuf.mutex);
  165. printf("Max %u messages exceeded\n", msgbuf.bufsize);
  166. return NULL;
  167. }
  168. msgbuf.reserved += num_msgs;
  169. pthread_mutex_unlock(&msgbuf.mutex);
  170. return msgbuf.buf + start * msg_size;
  171. }
  172. static void round2_received(NodeCommState &nodest,
  173. uint8_t *data, uint32_t plaintext_len, uint32_t);
  174. // A round 1 message was received by a routing node from an ingestion
  175. // node; we put it into the round 2 buffer for processing in round 2
  176. static void round1_received(NodeCommState &nodest,
  177. uint8_t *data, uint32_t plaintext_len, uint32_t)
  178. {
  179. uint16_t msg_size = g_teems_config.msg_size;
  180. assert((plaintext_len % uint32_t(msg_size)) == 0);
  181. uint32_t num_msgs = plaintext_len / uint32_t(msg_size);
  182. uint8_t our_roles = g_teems_config.roles[g_teems_config.my_node_num];
  183. uint8_t their_roles = g_teems_config.roles[nodest.node_num];
  184. pthread_mutex_lock(&route_state.round1.mutex);
  185. route_state.round1.inserted += num_msgs;
  186. route_state.round1.nodes_received += 1;
  187. nodenum_t nodes_received = route_state.round1.nodes_received;
  188. bool completed_prev_round = route_state.round1.completed_prev_round;
  189. pthread_mutex_unlock(&route_state.round1.mutex);
  190. // What is the next message we expect from this node?
  191. if ((our_roles & ROLE_STORAGE) && (their_roles & ROLE_ROUTING)) {
  192. nodest.in_msg_get_buf = [&](NodeCommState &commst,
  193. uint32_t tot_enc_chunk_size) {
  194. return msgbuffer_get_buf(route_state.round2, commst,
  195. tot_enc_chunk_size);
  196. };
  197. nodest.in_msg_received = round2_received;
  198. }
  199. // Otherwise, it's just the next round 1 message, so don't change
  200. // the handlers.
  201. if (nodes_received == g_teems_config.num_ingestion_nodes &&
  202. completed_prev_round) {
  203. route_state.step = ROUTE_ROUND_1;
  204. void *cbpointer = route_state.cbpointer;
  205. route_state.cbpointer = NULL;
  206. ocall_routing_round_complete(cbpointer, 1);
  207. }
  208. }
  209. // A round 2 message was received by a storage node from a routing node
  210. static void round2_received(NodeCommState &nodest,
  211. uint8_t *data, uint32_t plaintext_len, uint32_t)
  212. {
  213. uint16_t msg_size = g_teems_config.msg_size;
  214. assert((plaintext_len % uint32_t(msg_size)) == 0);
  215. uint32_t num_msgs = plaintext_len / uint32_t(msg_size);
  216. uint8_t our_roles = g_teems_config.roles[g_teems_config.my_node_num];
  217. uint8_t their_roles = g_teems_config.roles[nodest.node_num];
  218. pthread_mutex_lock(&route_state.round2.mutex);
  219. route_state.round2.inserted += num_msgs;
  220. route_state.round2.nodes_received += 1;
  221. nodenum_t nodes_received = route_state.round2.nodes_received;
  222. bool completed_prev_round = route_state.round2.completed_prev_round;
  223. pthread_mutex_unlock(&route_state.round2.mutex);
  224. // What is the next message we expect from this node?
  225. if ((our_roles & ROLE_ROUTING) && (their_roles & ROLE_INGESTION)) {
  226. nodest.in_msg_get_buf = [&](NodeCommState &commst,
  227. uint32_t tot_enc_chunk_size) {
  228. return msgbuffer_get_buf(route_state.round1, commst,
  229. tot_enc_chunk_size);
  230. };
  231. nodest.in_msg_received = round1_received;
  232. }
  233. // Otherwise, it's just the next round 2 message, so don't change
  234. // the handlers.
  235. if (nodes_received == g_teems_config.num_routing_nodes &&
  236. completed_prev_round) {
  237. route_state.step = ROUTE_ROUND_2;
  238. void *cbpointer = route_state.cbpointer;
  239. route_state.cbpointer = NULL;
  240. ocall_routing_round_complete(cbpointer, 2);
  241. }
  242. }
  243. // For a given other node, set the received message handler to the first
  244. // message we would expect from them, given their roles and our roles.
  245. void route_init_msg_handler(nodenum_t node_num)
  246. {
  247. // Our roles and their roles
  248. uint8_t our_roles = g_teems_config.roles[g_teems_config.my_node_num];
  249. uint8_t their_roles = g_teems_config.roles[node_num];
  250. // The node communication state
  251. NodeCommState &nodest = g_commstates[node_num];
  252. // If we are a routing node (possibly among other roles) and they
  253. // are an ingestion node (possibly among other roles), a round 1
  254. // routing message is the first thing we expect from them. We put
  255. // these messages into the round1 buffer for processing.
  256. if ((our_roles & ROLE_ROUTING) && (their_roles & ROLE_INGESTION)) {
  257. nodest.in_msg_get_buf = [&](NodeCommState &commst,
  258. uint32_t tot_enc_chunk_size) {
  259. return msgbuffer_get_buf(route_state.round1, commst,
  260. tot_enc_chunk_size);
  261. };
  262. nodest.in_msg_received = round1_received;
  263. }
  264. // Otherwise, if we are a storage node (possibly among other roles)
  265. // and they are a routing node (possibly among other roles), a round
  266. // 2 routing message is the first thing we expect from them
  267. else if ((our_roles & ROLE_STORAGE) && (their_roles & ROLE_ROUTING)) {
  268. nodest.in_msg_get_buf = [&](NodeCommState &commst,
  269. uint32_t tot_enc_chunk_size) {
  270. return msgbuffer_get_buf(route_state.round2, commst,
  271. tot_enc_chunk_size);
  272. };
  273. nodest.in_msg_received = round2_received;
  274. }
  275. // Otherwise, we don't expect a message from this node. Set the
  276. // unknown message handler.
  277. else {
  278. nodest.in_msg_get_buf = default_in_msg_get_buf;
  279. nodest.in_msg_received = unknown_in_msg_received;
  280. }
  281. }
  282. // Directly ingest a buffer of num_msgs messages into the ingbuf buffer.
  283. // Return true on success, false on failure.
  284. bool ecall_ingest_raw(uint8_t *msgs, uint32_t num_msgs)
  285. {
  286. uint16_t msg_size = g_teems_config.msg_size;
  287. MsgBuffer &ingbuf = route_state.ingbuf;
  288. pthread_mutex_lock(&ingbuf.mutex);
  289. uint32_t start = ingbuf.reserved;
  290. if (start + num_msgs > route_state.tot_msg_per_ing) {
  291. pthread_mutex_unlock(&ingbuf.mutex);
  292. printf("Max %u messages exceeded\n",
  293. route_state.tot_msg_per_ing);
  294. return false;
  295. }
  296. ingbuf.reserved += num_msgs;
  297. pthread_mutex_unlock(&ingbuf.mutex);
  298. memmove(ingbuf.buf + start * msg_size,
  299. msgs, num_msgs * msg_size);
  300. pthread_mutex_lock(&ingbuf.mutex);
  301. ingbuf.inserted += num_msgs;
  302. pthread_mutex_unlock(&ingbuf.mutex);
  303. return true;
  304. }
  305. // Send the round 1 messages. Note that N here is not private.
  306. static void send_round1_msgs(const uint8_t *msgs, const uint64_t *indices,
  307. uint32_t N)
  308. {
  309. uint16_t msg_size = g_teems_config.msg_size;
  310. uint16_t tot_weight = g_teems_config.tot_weight;
  311. nodenum_t my_node_num = g_teems_config.my_node_num;
  312. uint32_t full_rows = N / uint32_t(tot_weight);
  313. uint32_t last_row = N % uint32_t(tot_weight);
  314. for (auto &routing_node: g_teems_config.routing_nodes) {
  315. uint8_t weight =
  316. g_teems_config.weights[routing_node].weight;
  317. if (weight == 0) {
  318. // This shouldn't happen, but just in case
  319. continue;
  320. }
  321. uint16_t start_weight =
  322. g_teems_config.weights[routing_node].startweight;
  323. // The number of messages headed for this routing node from the
  324. // full rows
  325. uint32_t num_msgs_full_rows = full_rows * uint32_t(weight);
  326. // The number of messages headed for this routing node from the
  327. // incomplete last row is:
  328. // 0 if last_row < start_weight
  329. // last_row-start_weight if start_weight <= last_row < start_weight + weight
  330. // weight if start_weight + weight <= last_row
  331. uint32_t num_msgs_last_row = 0;
  332. if (start_weight <= last_row && last_row < start_weight + weight) {
  333. num_msgs_last_row = last_row-start_weight;
  334. } else if (start_weight + weight <= last_row) {
  335. num_msgs_last_row = weight;
  336. }
  337. // The total number of messages headed for this routing node
  338. uint32_t num_msgs = num_msgs_full_rows + num_msgs_last_row;
  339. if (routing_node == my_node_num) {
  340. // Special case: we're sending to ourselves; just put the
  341. // messages in our own round1 buffer
  342. MsgBuffer &round1 = route_state.round1;
  343. pthread_mutex_lock(&round1.mutex);
  344. uint32_t start = round1.reserved;
  345. if (start + num_msgs > round1.bufsize) {
  346. pthread_mutex_unlock(&round1.mutex);
  347. printf("Max %u messages exceeded\n", round1.bufsize);
  348. return;
  349. }
  350. round1.reserved += num_msgs;
  351. pthread_mutex_unlock(&round1.mutex);
  352. uint8_t *buf = round1.buf + start * msg_size;
  353. for (uint32_t i=0; i<full_rows; ++i) {
  354. const uint64_t *idxp = indices + i*tot_weight + start_weight;
  355. for (uint32_t j=0; j<weight; ++j) {
  356. memmove(buf, msgs + idxp[j]*msg_size, msg_size);
  357. buf += msg_size;
  358. }
  359. }
  360. const uint64_t *idxp = indices + full_rows*tot_weight + start_weight;
  361. for (uint32_t j=0; j<num_msgs_last_row; ++j) {
  362. memmove(buf, msgs + idxp[j]*msg_size, msg_size);
  363. buf += msg_size;
  364. }
  365. pthread_mutex_lock(&round1.mutex);
  366. round1.inserted += num_msgs;
  367. round1.nodes_received += 1;
  368. pthread_mutex_unlock(&round1.mutex);
  369. } else {
  370. NodeCommState &nodecom = g_commstates[routing_node];
  371. nodecom.message_start(num_msgs * msg_size);
  372. for (uint32_t i=0; i<full_rows; ++i) {
  373. const uint64_t *idxp = indices + i*tot_weight + start_weight;
  374. for (uint32_t j=0; j<weight; ++j) {
  375. nodecom.message_data(msgs + idxp[j]*msg_size, msg_size);
  376. }
  377. }
  378. const uint64_t *idxp = indices + full_rows*tot_weight + start_weight;
  379. for (uint32_t j=0; j<num_msgs_last_row; ++j) {
  380. nodecom.message_data(msgs + idxp[j]*msg_size, msg_size);
  381. }
  382. }
  383. }
  384. }
  385. // Send the round 2 messages from the round 1 buffer, which are already
  386. // padded and shuffled, so this can be done non-obliviously. tot_msgs
  387. // is the total number of messages in the input buffer, which may
  388. // include padding messages added by the shuffle. Those messages are
  389. // not sent anywhere. There are num_msgs_per_stg messages for each
  390. // storage node labelled for that node.
  391. static void send_round2_msgs(uint32_t tot_msgs, uint32_t num_msgs_per_stg)
  392. {
  393. uint16_t msg_size = g_teems_config.msg_size;
  394. MsgBuffer &round1 = route_state.round1;
  395. const uint8_t* buf = round1.buf;
  396. nodenum_t num_storage_nodes = g_teems_config.num_storage_nodes;
  397. nodenum_t my_node_num = g_teems_config.my_node_num;
  398. uint8_t *myself_buf = NULL;
  399. for (nodenum_t i=0; i<num_storage_nodes; ++i) {
  400. nodenum_t node = g_teems_config.storage_nodes[i];
  401. if (node != my_node_num) {
  402. g_commstates[node].message_start(msg_size * num_msgs_per_stg);
  403. } else {
  404. MsgBuffer &round2 = route_state.round2;
  405. pthread_mutex_lock(&round2.mutex);
  406. uint32_t start = round2.reserved;
  407. if (start + num_msgs_per_stg > round2.bufsize) {
  408. pthread_mutex_unlock(&round2.mutex);
  409. printf("Max %u messages exceeded\n", round2.bufsize);
  410. return;
  411. }
  412. round2.reserved += num_msgs_per_stg;
  413. pthread_mutex_unlock(&round2.mutex);
  414. myself_buf = round2.buf + start * msg_size;
  415. }
  416. }
  417. while (tot_msgs) {
  418. nodenum_t storage_node_id =
  419. nodenum_t((*(const uint32_t *)buf)>>DEST_UID_BITS);
  420. if (storage_node_id < num_storage_nodes) {
  421. nodenum_t node = g_teems_config.storage_map[storage_node_id];
  422. if (node == my_node_num) {
  423. memmove(myself_buf, buf, msg_size);
  424. myself_buf += msg_size;
  425. } else {
  426. g_commstates[node].message_data(buf, msg_size);
  427. }
  428. }
  429. buf += msg_size;
  430. --tot_msgs;
  431. }
  432. if (myself_buf) {
  433. MsgBuffer &round2 = route_state.round2;
  434. pthread_mutex_lock(&round2.mutex);
  435. round2.inserted += num_msgs_per_stg;
  436. round2.nodes_received += 1;
  437. pthread_mutex_unlock(&round2.mutex);
  438. }
  439. }
  440. // Perform the next round of routing. The callback pointer will be
  441. // passed to ocall_routing_round_complete when the round is complete.
  442. void ecall_routing_proceed(void *cbpointer)
  443. {
  444. uint8_t my_roles = g_teems_config.roles[g_teems_config.my_node_num];
  445. if (route_state.step == ROUTE_NOT_STARTED) {
  446. if (my_roles & ROLE_INGESTION) {
  447. route_state.cbpointer = cbpointer;
  448. MsgBuffer &ingbuf = route_state.ingbuf;
  449. MsgBuffer &round1 = route_state.round1;
  450. pthread_mutex_lock(&ingbuf.mutex);
  451. // Ensure there are no pending messages currently being inserted
  452. // into the buffer
  453. while (ingbuf.reserved != ingbuf.inserted) {
  454. pthread_mutex_unlock(&ingbuf.mutex);
  455. pthread_mutex_lock(&ingbuf.mutex);
  456. }
  457. // Sort the messages we've received
  458. #ifdef PROFILE_ROUTING
  459. uint32_t inserted = ingbuf.inserted;
  460. unsigned long start_round1 = printf_with_rtclock("begin round1 processing (%u)\n", inserted);
  461. unsigned long start_sort = printf_with_rtclock("begin oblivious sort (%u,%u)\n", inserted, route_state.tot_msg_per_ing);
  462. #endif
  463. sort_mtobliv(g_teems_config.nthreads, ingbuf.buf,
  464. g_teems_config.msg_size, ingbuf.inserted,
  465. route_state.tot_msg_per_ing, send_round1_msgs);
  466. #ifdef PROFILE_ROUTING
  467. printf_with_rtclock_diff(start_sort, "end oblivious sort (%u,%u)\n", inserted, route_state.tot_msg_per_ing);
  468. printf_with_rtclock_diff(start_round1, "end round1 processing (%u)\n", inserted);
  469. #endif
  470. ingbuf.reset();
  471. pthread_mutex_unlock(&ingbuf.mutex);
  472. pthread_mutex_lock(&round1.mutex);
  473. round1.completed_prev_round = true;
  474. nodenum_t nodes_received = round1.nodes_received;
  475. pthread_mutex_unlock(&round1.mutex);
  476. if (nodes_received == g_teems_config.num_ingestion_nodes) {
  477. route_state.step = ROUTE_ROUND_1;
  478. route_state.cbpointer = NULL;
  479. ocall_routing_round_complete(cbpointer, 1);
  480. }
  481. } else {
  482. route_state.step = ROUTE_ROUND_1;
  483. ocall_routing_round_complete(cbpointer, 1);
  484. }
  485. } else if (route_state.step == ROUTE_ROUND_1) {
  486. if (my_roles & ROLE_ROUTING) {
  487. route_state.cbpointer = cbpointer;
  488. MsgBuffer &round1 = route_state.round1;
  489. MsgBuffer &round2 = route_state.round2;
  490. pthread_mutex_lock(&round1.mutex);
  491. // Ensure there are no pending messages currently being inserted
  492. // into the buffer
  493. while (round1.reserved != round1.inserted) {
  494. pthread_mutex_unlock(&round1.mutex);
  495. pthread_mutex_lock(&round1.mutex);
  496. }
  497. // If the _total_ number of messages we received in round 1
  498. // is less than the max number of messages we could send to
  499. // _each_ storage node, then cap the number of messages we
  500. // will send to each storage node to that number.
  501. uint32_t msgs_per_stg = route_state.max_msg_to_each_stg;
  502. if (round1.inserted < msgs_per_stg) {
  503. msgs_per_stg = round1.inserted;
  504. }
  505. // Note: at this point, it is required that each message in
  506. // the round1 buffer have a _valid_ storage node id field.
  507. // Obliviously tally the number of messages we received in
  508. // round1 destined for each storage node
  509. #ifdef PROFILE_ROUTING
  510. uint32_t inserted = round1.inserted;
  511. unsigned long start_round2 = printf_with_rtclock("begin round2 processing (%u,%u)\n", inserted, round1.bufsize);
  512. unsigned long start_tally = printf_with_rtclock("begin tally (%u)\n", inserted);
  513. #endif
  514. uint16_t msg_size = g_teems_config.msg_size;
  515. nodenum_t num_storage_nodes = g_teems_config.num_storage_nodes;
  516. std::vector<uint32_t> tally = obliv_tally_stg(
  517. round1.buf, msg_size, round1.inserted, num_storage_nodes);
  518. #ifdef PROFILE_ROUTING
  519. printf_with_rtclock_diff(start_tally, "end tally (%u)\n", inserted);
  520. #endif
  521. // Note: tally contains private values! It's OK to
  522. // non-obliviously check for an error condition, though.
  523. // While we're at it, obliviously change the tally of
  524. // messages received to a tally of padding messages
  525. // required.
  526. uint32_t tot_padding = 0;
  527. for (nodenum_t i=0; i<num_storage_nodes; ++i) {
  528. if (tally[i] > msgs_per_stg) {
  529. printf("Received too many messages for storage node %u\n", i);
  530. assert(tally[i] <= msgs_per_stg);
  531. }
  532. tally[i] = msgs_per_stg - tally[i];
  533. tot_padding += tally[i];
  534. }
  535. round1.reserved += tot_padding;
  536. assert(round1.reserved <= round1.bufsize);
  537. // Obliviously add padding for each storage node according
  538. // to the (private) padding tally.
  539. #ifdef PROFILE_ROUTING
  540. unsigned long start_pad = printf_with_rtclock("begin pad (%u)\n", tot_padding);
  541. #endif
  542. obliv_pad_stg(round1.buf + round1.inserted * msg_size,
  543. msg_size, tally, tot_padding);
  544. #ifdef PROFILE_ROUTING
  545. printf_with_rtclock_diff(start_pad, "end pad (%u)\n", tot_padding);
  546. #endif
  547. round1.inserted += tot_padding;
  548. // Obliviously shuffle the messages
  549. #ifdef PROFILE_ROUTING
  550. unsigned long start_shuffle = printf_with_rtclock("begin shuffle (%u,%u)\n", round1.inserted, round1.bufsize);
  551. #endif
  552. uint32_t num_shuffled = shuffle_mtobliv(g_teems_config.nthreads,
  553. round1.buf, msg_size, round1.inserted, round1.bufsize);
  554. #ifdef PROFILE_ROUTING
  555. printf_with_rtclock_diff(start_shuffle, "end shuffle (%u,%u)\n", round1.inserted, round1.bufsize);
  556. printf_with_rtclock_diff(start_round2, "end round2 processing (%u,%u)\n", inserted, round1.bufsize);
  557. #endif
  558. // Now we can handle the messages non-obliviously, since we
  559. // know there will be exactly msgs_per_stg messages to each
  560. // storage node, and the oblivious shuffle broke the
  561. // connection between where each message came from and where
  562. // it's going.
  563. send_round2_msgs(num_shuffled, msgs_per_stg);
  564. round1.reset();
  565. pthread_mutex_unlock(&round1.mutex);
  566. pthread_mutex_lock(&round2.mutex);
  567. round2.completed_prev_round = true;
  568. nodenum_t nodes_received = round2.nodes_received;
  569. pthread_mutex_unlock(&round2.mutex);
  570. if (nodes_received == g_teems_config.num_routing_nodes) {
  571. route_state.step = ROUTE_ROUND_2;
  572. route_state.cbpointer = NULL;
  573. ocall_routing_round_complete(cbpointer, 2);
  574. }
  575. } else {
  576. route_state.step = ROUTE_ROUND_2;
  577. ocall_routing_round_complete(cbpointer, 2);
  578. }
  579. } else if (route_state.step == ROUTE_ROUND_2) {
  580. if (my_roles & ROLE_STORAGE) {
  581. MsgBuffer &round2 = route_state.round2;
  582. pthread_mutex_lock(&round2.mutex);
  583. // Ensure there are no pending messages currently being inserted
  584. // into the buffer
  585. while (round2.reserved != round2.inserted) {
  586. pthread_mutex_unlock(&round2.mutex);
  587. pthread_mutex_lock(&round2.mutex);
  588. }
  589. #ifdef PROFILE_ROUTING
  590. unsigned long start = printf_with_rtclock("begin storage processing (%u)\n", round2.inserted);
  591. #endif
  592. storage_received(round2.buf, round2.inserted);
  593. #ifdef PROFILE_ROUTING
  594. printf_with_rtclock_diff(start, "end storage processing (%u)\n", round2.inserted);
  595. #endif
  596. round2.reset();
  597. pthread_mutex_unlock(&round2.mutex);
  598. // We're done
  599. route_state.step = ROUTE_NOT_STARTED;
  600. ocall_routing_round_complete(cbpointer, 0);
  601. } else {
  602. // We're done
  603. route_state.step = ROUTE_NOT_STARTED;
  604. ocall_routing_round_complete(cbpointer, 0);
  605. }
  606. }
  607. }