route.cpp 28 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716
  1. #include "Enclave_t.h"
  2. #include "config.hpp"
  3. #include "utils.hpp"
  4. #include "sort.hpp"
  5. #include "comms.hpp"
  6. #include "obliv.hpp"
  7. #include "storage.hpp"
  8. #include "route.hpp"
  9. #define PROFILE_ROUTING
  // Which routing round this node has most recently completed. The
  // value is advanced by the code that reports a round as complete via
  // ocall_routing_round_complete, and returns to ROUTE_NOT_STARTED once
  // the final step has been processed.
  enum RouteStep {
      ROUTE_NOT_STARTED = 0,  // no round completed yet in this epoch
      ROUTE_ROUND_1 = 1,      // round 1 has completed
      ROUTE_ROUND_2 = 2       // round 2 has completed
  };
  15. // The ingbuf MsgBuffer stores messages an ingestion node ingests while
  16. // waiting for round 1 to start, which will be sorted and sent out in
  17. // round 1. The round1 MsgBuffer stores messages a routing node
  18. // receives in round 1, which will be padded, sorted, and sent out in
  19. // round 2. The round2 MsgBuffer stores messages a storage node
  20. // receives in round 2.
  21. static struct RouteState {
  22. MsgBuffer ingbuf;
  23. MsgBuffer round1;
  24. MsgBuffer round2;
  25. RouteStep step;
  26. uint32_t tot_msg_per_ing;
  27. uint32_t max_msg_to_each_stg;
  28. uint32_t max_round2_msgs;
  29. uint32_t max_stg_msgs;
  30. void *cbpointer;
  31. } route_state;
  // Computes ceil(x/y) where x and y are integers, x>=0, y>0.
  // The result is formed as the truncated quotient plus one if there is
  // a remainder, so it cannot wrap around when x is close to the
  // maximum value of its type (x+y-1 could). Both arguments are
  // evaluated more than once, so they must not have side effects.
  #define CEILDIV(x,y) (((x)/(y)) + (((x)%(y)) != 0))
  34. // Call this near the end of ecall_config_load, but before
  35. // comms_init_nodestate. Returns true on success, false on failure.
  36. bool route_init()
  37. {
  38. // Compute the maximum number of messages we could receive by direct
  39. // ingestion
  40. // Each ingestion node will have at most
  41. // ceil(user_count/num_ingestion_nodes) users, and each user will
  42. // send at most m_priv_out messages.
  43. uint32_t users_per_ing = CEILDIV(g_teems_config.user_count,
  44. g_teems_config.num_ingestion_nodes);
  45. uint32_t tot_msg_per_ing = users_per_ing * g_teems_config.m_priv_out;
  46. // Compute the maximum number of messages we could receive in round 1
  47. // Each ingestion node will send us an our_weight/tot_weight
  48. // fraction of the messages they hold
  49. uint32_t max_msg_from_each_ing = CEILDIV(tot_msg_per_ing,
  50. g_teems_config.tot_weight) * g_teems_config.my_weight;
  51. // And the maximum number we can receive in total is that times the
  52. // number of ingestion nodes
  53. uint32_t max_round1_msgs = max_msg_from_each_ing *
  54. g_teems_config.num_ingestion_nodes;
  55. // Compute the maximum number of messages we could send in round 2
  56. // Each storage node has at most this many users
  57. uint32_t users_per_stg = CEILDIV(g_teems_config.user_count,
  58. g_teems_config.num_storage_nodes);
  59. // And so can receive at most this many messages
  60. uint32_t tot_msg_per_stg = users_per_stg *
  61. g_teems_config.m_priv_in;
  62. // Which will be at most this many from us
  63. uint32_t max_msg_to_each_stg = CEILDIV(tot_msg_per_stg,
  64. g_teems_config.tot_weight) * g_teems_config.my_weight;
  65. // But we can't send more messages to each storage server than we
  66. // could receive in total
  67. if (max_msg_to_each_stg > max_round1_msgs) {
  68. max_msg_to_each_stg = max_round1_msgs;
  69. }
  70. // And the max total number of outgoing messages in round 2 is then
  71. uint32_t max_round2_msgs = max_msg_to_each_stg *
  72. g_teems_config.num_storage_nodes;
  73. // In case we have a weird configuration where users can send more
  74. // messages per epoch than they can receive, ensure the round 2
  75. // buffer is large enough to hold the incoming messages as well
  76. if (max_round2_msgs < max_round1_msgs) {
  77. max_round2_msgs = max_round1_msgs;
  78. }
  79. // The max number of messages that can arrive at a storage server
  80. uint32_t max_stg_msgs = tot_msg_per_stg + g_teems_config.tot_weight;
  81. /*
  82. printf("users_per_ing=%u, tot_msg_per_ing=%u, max_msg_from_each_ing=%u, max_round1_msgs=%u, users_per_stg=%u, tot_msg_per_stg=%u, max_msg_to_each_stg=%u, max_round2_msgs=%u, max_stg_msgs=%u\n", users_per_ing, tot_msg_per_ing, max_msg_from_each_ing, max_round1_msgs, users_per_stg, tot_msg_per_stg, max_msg_to_each_stg, max_round2_msgs, max_stg_msgs);
  83. */
  84. // Create the route state
  85. uint8_t my_roles = g_teems_config.roles[g_teems_config.my_node_num];
  86. try {
  87. if (my_roles & ROLE_INGESTION) {
  88. route_state.ingbuf.alloc(tot_msg_per_ing);
  89. }
  90. if (my_roles & ROLE_ROUTING) {
  91. route_state.round1.alloc(max_round2_msgs);
  92. }
  93. if (my_roles & ROLE_STORAGE) {
  94. route_state.round2.alloc(max_stg_msgs);
  95. if (!storage_init(users_per_stg, max_stg_msgs)) {
  96. return false;
  97. }
  98. }
  99. } catch (std::bad_alloc&) {
  100. printf("Memory allocation failed in route_init\n");
  101. return false;
  102. }
  103. route_state.step = ROUTE_NOT_STARTED;
  104. route_state.tot_msg_per_ing = tot_msg_per_ing;
  105. route_state.max_msg_to_each_stg = max_msg_to_each_stg;
  106. route_state.max_round2_msgs = max_round2_msgs;
  107. route_state.max_stg_msgs = max_stg_msgs;
  108. route_state.cbpointer = NULL;
  109. threadid_t nthreads = g_teems_config.nthreads;
  110. #ifdef PROFILE_ROUTING
  111. unsigned long start = printf_with_rtclock("begin precompute evalplans (%u,%hu) (%u,%hu)\n", tot_msg_per_ing, nthreads, max_round2_msgs, nthreads);
  112. #endif
  113. if (my_roles & ROLE_INGESTION) {
  114. sort_precompute_evalplan(tot_msg_per_ing, nthreads);
  115. }
  116. if (my_roles & ROLE_ROUTING) {
  117. sort_precompute_evalplan(max_round2_msgs, nthreads);
  118. }
  119. if (my_roles & ROLE_STORAGE) {
  120. sort_precompute_evalplan(max_stg_msgs, nthreads);
  121. }
  122. #ifdef PROFILE_ROUTING
  123. printf_with_rtclock_diff(start, "end precompute evalplans\n");
  124. #endif
  125. return true;
  126. }
  127. // Precompute the WaksmanNetworks needed for the sorts. If you pass -1,
  128. // it will return the number of different sizes it needs to regenerate.
  129. // If you pass [0,sizes-1], it will compute one WaksmanNetwork with that
  130. // size index and return the number of available WaksmanNetworks of that
  131. // size. If you pass anything else, it will return the number of
  132. // different sizes it needs at all.
  133. // The list of sizes that need refilling, updated when you pass -1
  134. static std::vector<uint32_t> used_sizes;
  135. size_t ecall_precompute_sort(int sizeidx)
  136. {
  137. size_t ret = 0;
  138. if (sizeidx == -1) {
  139. used_sizes = sort_get_used();
  140. ret = used_sizes.size();
  141. } else if (sizeidx >= 0 && sizeidx < used_sizes.size()) {
  142. uint32_t size = used_sizes[sizeidx];
  143. #ifdef PROFILE_ROUTING
  144. unsigned long start = printf_with_rtclock("begin precompute WaksmanNetwork (%u)\n", size);
  145. #endif
  146. ret = sort_precompute(size);
  147. #ifdef PROFILE_ROUTING
  148. printf_with_rtclock_diff(start, "end precompute Waksman Network (%u)\n", size);
  149. #endif
  150. } else {
  151. uint8_t my_roles = g_teems_config.roles[g_teems_config.my_node_num];
  152. if (my_roles & ROLE_INGESTION) {
  153. used_sizes.push_back(route_state.tot_msg_per_ing);
  154. }
  155. if (my_roles & ROLE_ROUTING) {
  156. used_sizes.push_back(route_state.max_round2_msgs);
  157. }
  158. if (my_roles & ROLE_STORAGE) {
  159. used_sizes.push_back(route_state.max_stg_msgs);
  160. }
  161. ret = used_sizes.size();
  162. }
  163. return ret;
  164. }
  165. static uint8_t* msgbuffer_get_buf(MsgBuffer &msgbuf,
  166. NodeCommState &, uint32_t tot_enc_chunk_size)
  167. {
  168. uint16_t msg_size = g_teems_config.msg_size;
  169. // Chunks will be encrypted and have a MAC tag attached which will
  170. // not correspond to plaintext bytes, so we can trim them.
  171. // The minimum number of chunks needed to transmit this message
  172. uint32_t min_num_chunks =
  173. (tot_enc_chunk_size + (FRAME_SIZE-1)) / FRAME_SIZE;
  174. // The number of plaintext bytes this message could contain
  175. uint32_t plaintext_bytes = tot_enc_chunk_size -
  176. SGX_AESGCM_MAC_SIZE * min_num_chunks;
  177. assert ((plaintext_bytes % uint32_t(msg_size)) == 0);
  178. uint32_t num_msgs = plaintext_bytes/uint32_t(msg_size);
  179. pthread_mutex_lock(&msgbuf.mutex);
  180. uint32_t start = msgbuf.reserved;
  181. if (start + num_msgs > msgbuf.bufsize) {
  182. pthread_mutex_unlock(&msgbuf.mutex);
  183. printf("Max %u messages exceeded\n", msgbuf.bufsize);
  184. return NULL;
  185. }
  186. msgbuf.reserved += num_msgs;
  187. pthread_mutex_unlock(&msgbuf.mutex);
  188. return msgbuf.buf + start * msg_size;
  189. }
  190. static void round2_received(NodeCommState &nodest,
  191. uint8_t *data, uint32_t plaintext_len, uint32_t);
  192. // A round 1 message was received by a routing node from an ingestion
  193. // node; we put it into the round 2 buffer for processing in round 2
  194. static void round1_received(NodeCommState &nodest,
  195. uint8_t *data, uint32_t plaintext_len, uint32_t)
  196. {
  197. uint16_t msg_size = g_teems_config.msg_size;
  198. assert((plaintext_len % uint32_t(msg_size)) == 0);
  199. uint32_t num_msgs = plaintext_len / uint32_t(msg_size);
  200. uint8_t our_roles = g_teems_config.roles[g_teems_config.my_node_num];
  201. uint8_t their_roles = g_teems_config.roles[nodest.node_num];
  202. pthread_mutex_lock(&route_state.round1.mutex);
  203. route_state.round1.inserted += num_msgs;
  204. route_state.round1.nodes_received += 1;
  205. nodenum_t nodes_received = route_state.round1.nodes_received;
  206. bool completed_prev_round = route_state.round1.completed_prev_round;
  207. pthread_mutex_unlock(&route_state.round1.mutex);
  208. // What is the next message we expect from this node?
  209. if ((our_roles & ROLE_STORAGE) && (their_roles & ROLE_ROUTING)) {
  210. nodest.in_msg_get_buf = [&](NodeCommState &commst,
  211. uint32_t tot_enc_chunk_size) {
  212. return msgbuffer_get_buf(route_state.round2, commst,
  213. tot_enc_chunk_size);
  214. };
  215. nodest.in_msg_received = round2_received;
  216. }
  217. // Otherwise, it's just the next round 1 message, so don't change
  218. // the handlers.
  219. if (nodes_received == g_teems_config.num_ingestion_nodes &&
  220. completed_prev_round) {
  221. route_state.step = ROUTE_ROUND_1;
  222. void *cbpointer = route_state.cbpointer;
  223. route_state.cbpointer = NULL;
  224. ocall_routing_round_complete(cbpointer, 1);
  225. }
  226. }
  227. // A round 2 message was received by a storage node from a routing node
  228. static void round2_received(NodeCommState &nodest,
  229. uint8_t *data, uint32_t plaintext_len, uint32_t)
  230. {
  231. uint16_t msg_size = g_teems_config.msg_size;
  232. assert((plaintext_len % uint32_t(msg_size)) == 0);
  233. uint32_t num_msgs = plaintext_len / uint32_t(msg_size);
  234. uint8_t our_roles = g_teems_config.roles[g_teems_config.my_node_num];
  235. uint8_t their_roles = g_teems_config.roles[nodest.node_num];
  236. pthread_mutex_lock(&route_state.round2.mutex);
  237. route_state.round2.inserted += num_msgs;
  238. route_state.round2.nodes_received += 1;
  239. nodenum_t nodes_received = route_state.round2.nodes_received;
  240. bool completed_prev_round = route_state.round2.completed_prev_round;
  241. pthread_mutex_unlock(&route_state.round2.mutex);
  242. // What is the next message we expect from this node?
  243. if ((our_roles & ROLE_ROUTING) && (their_roles & ROLE_INGESTION)) {
  244. nodest.in_msg_get_buf = [&](NodeCommState &commst,
  245. uint32_t tot_enc_chunk_size) {
  246. return msgbuffer_get_buf(route_state.round1, commst,
  247. tot_enc_chunk_size);
  248. };
  249. nodest.in_msg_received = round1_received;
  250. }
  251. // Otherwise, it's just the next round 2 message, so don't change
  252. // the handlers.
  253. if (nodes_received == g_teems_config.num_routing_nodes &&
  254. completed_prev_round) {
  255. route_state.step = ROUTE_ROUND_2;
  256. void *cbpointer = route_state.cbpointer;
  257. route_state.cbpointer = NULL;
  258. ocall_routing_round_complete(cbpointer, 2);
  259. }
  260. }
  261. // For a given other node, set the received message handler to the first
  262. // message we would expect from them, given their roles and our roles.
  263. void route_init_msg_handler(nodenum_t node_num)
  264. {
  265. // Our roles and their roles
  266. uint8_t our_roles = g_teems_config.roles[g_teems_config.my_node_num];
  267. uint8_t their_roles = g_teems_config.roles[node_num];
  268. // The node communication state
  269. NodeCommState &nodest = g_commstates[node_num];
  270. // If we are a routing node (possibly among other roles) and they
  271. // are an ingestion node (possibly among other roles), a round 1
  272. // routing message is the first thing we expect from them. We put
  273. // these messages into the round1 buffer for processing.
  274. if ((our_roles & ROLE_ROUTING) && (their_roles & ROLE_INGESTION)) {
  275. nodest.in_msg_get_buf = [&](NodeCommState &commst,
  276. uint32_t tot_enc_chunk_size) {
  277. return msgbuffer_get_buf(route_state.round1, commst,
  278. tot_enc_chunk_size);
  279. };
  280. nodest.in_msg_received = round1_received;
  281. }
  282. // Otherwise, if we are a storage node (possibly among other roles)
  283. // and they are a routing node (possibly among other roles), a round
  284. // 2 routing message is the first thing we expect from them
  285. else if ((our_roles & ROLE_STORAGE) && (their_roles & ROLE_ROUTING)) {
  286. nodest.in_msg_get_buf = [&](NodeCommState &commst,
  287. uint32_t tot_enc_chunk_size) {
  288. return msgbuffer_get_buf(route_state.round2, commst,
  289. tot_enc_chunk_size);
  290. };
  291. nodest.in_msg_received = round2_received;
  292. }
  293. // Otherwise, we don't expect a message from this node. Set the
  294. // unknown message handler.
  295. else {
  296. nodest.in_msg_get_buf = default_in_msg_get_buf;
  297. nodest.in_msg_received = unknown_in_msg_received;
  298. }
  299. }
  300. // Directly ingest a buffer of num_msgs messages into the ingbuf buffer.
  301. // Return true on success, false on failure.
  302. bool ecall_ingest_raw(uint8_t *msgs, uint32_t num_msgs)
  303. {
  304. uint16_t msg_size = g_teems_config.msg_size;
  305. MsgBuffer &ingbuf = route_state.ingbuf;
  306. pthread_mutex_lock(&ingbuf.mutex);
  307. uint32_t start = ingbuf.reserved;
  308. if (start + num_msgs > route_state.tot_msg_per_ing) {
  309. pthread_mutex_unlock(&ingbuf.mutex);
  310. printf("Max %u messages exceeded\n",
  311. route_state.tot_msg_per_ing);
  312. return false;
  313. }
  314. ingbuf.reserved += num_msgs;
  315. pthread_mutex_unlock(&ingbuf.mutex);
  316. memmove(ingbuf.buf + start * msg_size,
  317. msgs, num_msgs * msg_size);
  318. pthread_mutex_lock(&ingbuf.mutex);
  319. ingbuf.inserted += num_msgs;
  320. pthread_mutex_unlock(&ingbuf.mutex);
  321. return true;
  322. }
  323. // Send the round 1 messages. Note that N here is not private.
  324. static void send_round1_msgs(const uint8_t *msgs, const UidKey *indices,
  325. uint32_t N)
  326. {
  327. uint16_t msg_size = g_teems_config.msg_size;
  328. uint16_t tot_weight = g_teems_config.tot_weight;
  329. nodenum_t my_node_num = g_teems_config.my_node_num;
  330. uint32_t full_rows = N / uint32_t(tot_weight);
  331. uint32_t last_row = N % uint32_t(tot_weight);
  332. for (auto &routing_node: g_teems_config.routing_nodes) {
  333. uint8_t weight =
  334. g_teems_config.weights[routing_node].weight;
  335. if (weight == 0) {
  336. // This shouldn't happen, but just in case
  337. continue;
  338. }
  339. uint16_t start_weight =
  340. g_teems_config.weights[routing_node].startweight;
  341. // The number of messages headed for this routing node from the
  342. // full rows
  343. uint32_t num_msgs_full_rows = full_rows * uint32_t(weight);
  344. // The number of messages headed for this routing node from the
  345. // incomplete last row is:
  346. // 0 if last_row < start_weight
  347. // last_row-start_weight if start_weight <= last_row < start_weight + weight
  348. // weight if start_weight + weight <= last_row
  349. uint32_t num_msgs_last_row = 0;
  350. if (start_weight <= last_row && last_row < start_weight + weight) {
  351. num_msgs_last_row = last_row-start_weight;
  352. } else if (start_weight + weight <= last_row) {
  353. num_msgs_last_row = weight;
  354. }
  355. // The total number of messages headed for this routing node
  356. uint32_t num_msgs = num_msgs_full_rows + num_msgs_last_row;
  357. if (routing_node == my_node_num) {
  358. // Special case: we're sending to ourselves; just put the
  359. // messages in our own round1 buffer
  360. MsgBuffer &round1 = route_state.round1;
  361. pthread_mutex_lock(&round1.mutex);
  362. uint32_t start = round1.reserved;
  363. if (start + num_msgs > round1.bufsize) {
  364. pthread_mutex_unlock(&round1.mutex);
  365. printf("Max %u messages exceeded\n", round1.bufsize);
  366. return;
  367. }
  368. round1.reserved += num_msgs;
  369. pthread_mutex_unlock(&round1.mutex);
  370. uint8_t *buf = round1.buf + start * msg_size;
  371. for (uint32_t i=0; i<full_rows; ++i) {
  372. const UidKey *idxp = indices + i*tot_weight + start_weight;
  373. for (uint32_t j=0; j<weight; ++j) {
  374. memmove(buf, msgs + idxp[j].index()*msg_size, msg_size);
  375. buf += msg_size;
  376. }
  377. }
  378. const UidKey *idxp = indices + full_rows*tot_weight + start_weight;
  379. for (uint32_t j=0; j<num_msgs_last_row; ++j) {
  380. memmove(buf, msgs + idxp[j].index()*msg_size, msg_size);
  381. buf += msg_size;
  382. }
  383. pthread_mutex_lock(&round1.mutex);
  384. round1.inserted += num_msgs;
  385. round1.nodes_received += 1;
  386. pthread_mutex_unlock(&round1.mutex);
  387. } else {
  388. NodeCommState &nodecom = g_commstates[routing_node];
  389. nodecom.message_start(num_msgs * msg_size);
  390. for (uint32_t i=0; i<full_rows; ++i) {
  391. const UidKey *idxp = indices + i*tot_weight + start_weight;
  392. for (uint32_t j=0; j<weight; ++j) {
  393. nodecom.message_data(msgs + idxp[j].index()*msg_size, msg_size);
  394. }
  395. }
  396. const UidKey *idxp = indices + full_rows*tot_weight + start_weight;
  397. for (uint32_t j=0; j<num_msgs_last_row; ++j) {
  398. nodecom.message_data(msgs + idxp[j].index()*msg_size, msg_size);
  399. }
  400. }
  401. }
  402. }
  403. // Send the round 2 messages from the round 1 buffer, which are already
  404. // padded and shuffled, so this can be done non-obliviously. tot_msgs
  405. // is the total number of messages in the input buffer, which may
  406. // include padding messages added by the shuffle. Those messages are
  407. // not sent anywhere. There are num_msgs_per_stg messages for each
  408. // storage node labelled for that node.
  409. static void send_round2_msgs(uint32_t tot_msgs, uint32_t num_msgs_per_stg)
  410. {
  411. uint16_t msg_size = g_teems_config.msg_size;
  412. MsgBuffer &round1 = route_state.round1;
  413. const uint8_t* buf = round1.buf;
  414. nodenum_t num_storage_nodes = g_teems_config.num_storage_nodes;
  415. nodenum_t my_node_num = g_teems_config.my_node_num;
  416. uint8_t *myself_buf = NULL;
  417. for (nodenum_t i=0; i<num_storage_nodes; ++i) {
  418. nodenum_t node = g_teems_config.storage_nodes[i];
  419. if (node != my_node_num) {
  420. g_commstates[node].message_start(msg_size * num_msgs_per_stg);
  421. } else {
  422. MsgBuffer &round2 = route_state.round2;
  423. pthread_mutex_lock(&round2.mutex);
  424. uint32_t start = round2.reserved;
  425. if (start + num_msgs_per_stg > round2.bufsize) {
  426. pthread_mutex_unlock(&round2.mutex);
  427. printf("Max %u messages exceeded\n", round2.bufsize);
  428. return;
  429. }
  430. round2.reserved += num_msgs_per_stg;
  431. pthread_mutex_unlock(&round2.mutex);
  432. myself_buf = round2.buf + start * msg_size;
  433. }
  434. }
  435. while (tot_msgs) {
  436. nodenum_t storage_node_id =
  437. nodenum_t((*(const uint32_t *)buf)>>DEST_UID_BITS);
  438. if (storage_node_id < num_storage_nodes) {
  439. nodenum_t node = g_teems_config.storage_map[storage_node_id];
  440. if (node == my_node_num) {
  441. memmove(myself_buf, buf, msg_size);
  442. myself_buf += msg_size;
  443. } else {
  444. g_commstates[node].message_data(buf, msg_size);
  445. }
  446. }
  447. buf += msg_size;
  448. --tot_msgs;
  449. }
  450. if (myself_buf) {
  451. MsgBuffer &round2 = route_state.round2;
  452. pthread_mutex_lock(&round2.mutex);
  453. round2.inserted += num_msgs_per_stg;
  454. round2.nodes_received += 1;
  455. pthread_mutex_unlock(&round2.mutex);
  456. }
  457. }
  458. // Perform the next round of routing. The callback pointer will be
  459. // passed to ocall_routing_round_complete when the round is complete.
  460. void ecall_routing_proceed(void *cbpointer)
  461. {
  462. uint8_t my_roles = g_teems_config.roles[g_teems_config.my_node_num];
  463. if (route_state.step == ROUTE_NOT_STARTED) {
  464. if (my_roles & ROLE_INGESTION) {
  465. route_state.cbpointer = cbpointer;
  466. MsgBuffer &ingbuf = route_state.ingbuf;
  467. pthread_mutex_lock(&ingbuf.mutex);
  468. // Ensure there are no pending messages currently being inserted
  469. // into the buffer
  470. while (ingbuf.reserved != ingbuf.inserted) {
  471. pthread_mutex_unlock(&ingbuf.mutex);
  472. pthread_mutex_lock(&ingbuf.mutex);
  473. }
  474. // Sort the messages we've received
  475. #ifdef PROFILE_ROUTING
  476. uint32_t inserted = ingbuf.inserted;
  477. unsigned long start_round1 = printf_with_rtclock("begin round1 processing (%u)\n", inserted);
  478. unsigned long start_sort = printf_with_rtclock("begin oblivious sort (%u,%u)\n", inserted, route_state.tot_msg_per_ing);
  479. #endif
  480. sort_mtobliv<UidKey>(g_teems_config.nthreads, ingbuf.buf,
  481. g_teems_config.msg_size, ingbuf.inserted,
  482. route_state.tot_msg_per_ing, send_round1_msgs);
  483. #ifdef PROFILE_ROUTING
  484. printf_with_rtclock_diff(start_sort, "end oblivious sort (%u,%u)\n", inserted, route_state.tot_msg_per_ing);
  485. printf_with_rtclock_diff(start_round1, "end round1 processing (%u)\n", inserted);
  486. #endif
  487. ingbuf.reset();
  488. pthread_mutex_unlock(&ingbuf.mutex);
  489. }
  490. if (my_roles & ROLE_ROUTING) {
  491. MsgBuffer &round1 = route_state.round1;
  492. pthread_mutex_lock(&round1.mutex);
  493. round1.completed_prev_round = true;
  494. nodenum_t nodes_received = round1.nodes_received;
  495. pthread_mutex_unlock(&round1.mutex);
  496. if (nodes_received == g_teems_config.num_ingestion_nodes) {
  497. route_state.step = ROUTE_ROUND_1;
  498. route_state.cbpointer = NULL;
  499. ocall_routing_round_complete(cbpointer, 1);
  500. }
  501. } else {
  502. route_state.step = ROUTE_ROUND_1;
  503. route_state.round1.completed_prev_round = true;
  504. ocall_routing_round_complete(cbpointer, 1);
  505. }
  506. } else if (route_state.step == ROUTE_ROUND_1) {
  507. if (my_roles & ROLE_ROUTING) {
  508. route_state.cbpointer = cbpointer;
  509. MsgBuffer &round1 = route_state.round1;
  510. pthread_mutex_lock(&round1.mutex);
  511. // Ensure there are no pending messages currently being inserted
  512. // into the buffer
  513. while (round1.reserved != round1.inserted) {
  514. pthread_mutex_unlock(&round1.mutex);
  515. pthread_mutex_lock(&round1.mutex);
  516. }
  517. // If the _total_ number of messages we received in round 1
  518. // is less than the max number of messages we could send to
  519. // _each_ storage node, then cap the number of messages we
  520. // will send to each storage node to that number.
  521. uint32_t msgs_per_stg = route_state.max_msg_to_each_stg;
  522. if (round1.inserted < msgs_per_stg) {
  523. msgs_per_stg = round1.inserted;
  524. }
  525. // Note: at this point, it is required that each message in
  526. // the round1 buffer have a _valid_ storage node id field.
  527. // Obliviously tally the number of messages we received in
  528. // round1 destined for each storage node
  529. #ifdef PROFILE_ROUTING
  530. uint32_t inserted = round1.inserted;
  531. unsigned long start_round2 = printf_with_rtclock("begin round2 processing (%u,%u)\n", inserted, round1.bufsize);
  532. unsigned long start_tally = printf_with_rtclock("begin tally (%u)\n", inserted);
  533. #endif
  534. uint16_t msg_size = g_teems_config.msg_size;
  535. nodenum_t num_storage_nodes = g_teems_config.num_storage_nodes;
  536. std::vector<uint32_t> tally = obliv_tally_stg(
  537. round1.buf, msg_size, round1.inserted, num_storage_nodes);
  538. #ifdef PROFILE_ROUTING
  539. printf_with_rtclock_diff(start_tally, "end tally (%u)\n", inserted);
  540. #endif
  541. // Note: tally contains private values! It's OK to
  542. // non-obliviously check for an error condition, though.
  543. // While we're at it, obliviously change the tally of
  544. // messages received to a tally of padding messages
  545. // required.
  546. uint32_t tot_padding = 0;
  547. for (nodenum_t i=0; i<num_storage_nodes; ++i) {
  548. if (tally[i] > msgs_per_stg) {
  549. printf("Received too many messages for storage node %u\n", i);
  550. assert(tally[i] <= msgs_per_stg);
  551. }
  552. tally[i] = msgs_per_stg - tally[i];
  553. tot_padding += tally[i];
  554. }
  555. round1.reserved += tot_padding;
  556. assert(round1.reserved <= round1.bufsize);
  557. // Obliviously add padding for each storage node according
  558. // to the (private) padding tally.
  559. #ifdef PROFILE_ROUTING
  560. unsigned long start_pad = printf_with_rtclock("begin pad (%u)\n", tot_padding);
  561. #endif
  562. obliv_pad_stg(round1.buf + round1.inserted * msg_size,
  563. msg_size, tally, tot_padding);
  564. #ifdef PROFILE_ROUTING
  565. printf_with_rtclock_diff(start_pad, "end pad (%u)\n", tot_padding);
  566. #endif
  567. round1.inserted += tot_padding;
  568. // Obliviously shuffle the messages
  569. #ifdef PROFILE_ROUTING
  570. unsigned long start_shuffle = printf_with_rtclock("begin shuffle (%u,%u)\n", round1.inserted, round1.bufsize);
  571. #endif
  572. uint32_t num_shuffled = shuffle_mtobliv(g_teems_config.nthreads,
  573. round1.buf, msg_size, round1.inserted, round1.bufsize);
  574. #ifdef PROFILE_ROUTING
  575. printf_with_rtclock_diff(start_shuffle, "end shuffle (%u,%u)\n", round1.inserted, round1.bufsize);
  576. printf_with_rtclock_diff(start_round2, "end round2 processing (%u,%u)\n", inserted, round1.bufsize);
  577. #endif
  578. // Now we can handle the messages non-obliviously, since we
  579. // know there will be exactly msgs_per_stg messages to each
  580. // storage node, and the oblivious shuffle broke the
  581. // connection between where each message came from and where
  582. // it's going.
  583. send_round2_msgs(num_shuffled, msgs_per_stg);
  584. round1.reset();
  585. pthread_mutex_unlock(&round1.mutex);
  586. }
  587. if (my_roles & ROLE_STORAGE) {
  588. route_state.cbpointer = cbpointer;
  589. MsgBuffer &round2 = route_state.round2;
  590. pthread_mutex_lock(&round2.mutex);
  591. round2.completed_prev_round = true;
  592. nodenum_t nodes_received = round2.nodes_received;
  593. pthread_mutex_unlock(&round2.mutex);
  594. if (nodes_received == g_teems_config.num_routing_nodes) {
  595. route_state.step = ROUTE_ROUND_2;
  596. route_state.cbpointer = NULL;
  597. ocall_routing_round_complete(cbpointer, 2);
  598. }
  599. } else {
  600. route_state.step = ROUTE_ROUND_2;
  601. route_state.round2.completed_prev_round = true;
  602. ocall_routing_round_complete(cbpointer, 2);
  603. }
  604. } else if (route_state.step == ROUTE_ROUND_2) {
  605. if (my_roles & ROLE_STORAGE) {
  606. MsgBuffer &round2 = route_state.round2;
  607. pthread_mutex_lock(&round2.mutex);
  608. // Ensure there are no pending messages currently being inserted
  609. // into the buffer
  610. while (round2.reserved != round2.inserted) {
  611. pthread_mutex_unlock(&round2.mutex);
  612. pthread_mutex_lock(&round2.mutex);
  613. }
  614. #ifdef PROFILE_ROUTING
  615. unsigned long start = printf_with_rtclock("begin storage processing (%u)\n", round2.inserted);
  616. #endif
  617. storage_received(round2);
  618. #ifdef PROFILE_ROUTING
  619. printf_with_rtclock_diff(start, "end storage processing (%u)\n", round2.inserted);
  620. #endif
  621. // We're done
  622. route_state.step = ROUTE_NOT_STARTED;
  623. ocall_routing_round_complete(cbpointer, 0);
  624. } else {
  625. // We're done
  626. route_state.step = ROUTE_NOT_STARTED;
  627. ocall_routing_round_complete(cbpointer, 0);
  628. }
  629. }
  630. }