route.cpp 65 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439144014411442144314441445144614471448144914501451145214531454145514561457145814591460146114621463146414651466146714681469147014711472147314741475147614771478147914801481148214831484148514861487148814891490149114921493149414951496149714981499150015011502150315041505150615071508150915101511151215131514151515161517151815191520152115221523152415251526152715281529153015311532153315341535153615371538153915401541154215431544154515461547154815491550155115521553155415551556155715581559156015611562156315641565156615671568156915701571
  1. #include "Enclave_t.h"
  2. #include "config.hpp"
  3. #include "utils.hpp"
  4. #include "sort.hpp"
  5. #include "comms.hpp"
  6. #include "obliv.hpp"
  7. #include "storage.hpp"
  8. #include "route.hpp"
  9. #define PROFILE_ROUTING
  10. // #define TRACE_ROUTING
  11. RouteState route_state;
  12. // Computes ceil(x/y) where x and y are integers, x>=0, y>0.
  13. #define CEILDIV(x,y) (((x)+(y)-1)/(y))
  14. #ifdef TRACE_ROUTING
  15. // Show (the first 300 of, if there are more) the headers and the first
  16. // few bytes of the body of each message in the buffer
  17. static void show_messages(const char *label, const unsigned char *buffer,
  18. size_t num)
  19. {
  20. if (label) {
  21. printf("%s\n", label);
  22. }
  23. for (size_t i=0; i<num && i<300; ++i) {
  24. const uint32_t *ibuf = (const uint32_t *)buffer;
  25. if (g_teems_config.token_channel) {
  26. printf("%3d R:%08x S:%08x [%08x]\n", i, ibuf[0], ibuf[1],
  27. ibuf[2]);
  28. } else {
  29. printf("%3d R:%08x P:%08x S:%08x [%08x]\n", i, ibuf[0], ibuf[1],
  30. ibuf[2], ibuf[3]);
  31. }
  32. buffer += g_teems_config.msg_size;
  33. }
  34. printf("\n");
  35. }
  36. #endif
// Call this near the end of ecall_config_load, but before
// comms_init_nodestate. Returns true on success, false on failure.
//
// Computes upper bounds on the number of messages this node can handle
// in each routing round (based on its roles), allocates the message
// buffers sized to those bounds, initializes storage state if we are a
// storage node, and precomputes the sort evaluation plans for each
// buffer size this node will sort.
bool route_init()
{
    // Compute the maximum number of messages we could receive by direct
    // ingestion
    // Each ingestion node will have at most
    // ceil(user_count/num_ingestion_nodes) users, and each user will
    // send at most m_token_out messages.
    uint32_t users_per_ing = CEILDIV(g_teems_config.user_count,
        g_teems_config.num_ingestion_nodes);
    uint32_t tot_msg_per_ing;
    if (g_teems_config.token_channel) {
        tot_msg_per_ing = users_per_ing * g_teems_config.m_token_out;
    } else {
        // ID channel: per-user outgoing cap is m_id_out instead
        tot_msg_per_ing = users_per_ing * g_teems_config.m_id_out;
    }
    // Compute the maximum number of messages we could receive in round 1
    // In token channel routing, each ingestion node will send us an
    // our_weight/tot_weight fraction of the messages they hold
    uint32_t max_msg_from_each_ing;
    max_msg_from_each_ing = CEILDIV(tot_msg_per_ing, g_teems_config.tot_weight) *
        g_teems_config.my_weight;
    // And the maximum number we can receive in total is that times the
    // number of ingestion nodes
    uint32_t max_round1_msgs = max_msg_from_each_ing *
        g_teems_config.num_ingestion_nodes;
    // Compute the maximum number of messages we could send in round 2
    // Each storage node has at most this many users
    uint32_t users_per_stg = CEILDIV(g_teems_config.user_count,
        g_teems_config.num_storage_nodes);
    // And so can receive at most this many messages
    uint32_t tot_msg_per_stg;
    if (g_teems_config.token_channel) {
        tot_msg_per_stg = users_per_stg * g_teems_config.m_token_in;
    } else {
        tot_msg_per_stg = users_per_stg * g_teems_config.m_id_in;
    }
    // Which will be at most this many from us
    // NOTE(review): the "+ tot_weight" term appears to be slack to
    // over-approximate the per-node rounding of the weighted split —
    // confirm against the round 2 distribution logic
    uint32_t max_msg_to_each_stg;
    max_msg_to_each_stg = (tot_msg_per_stg/g_teems_config.tot_weight
        + g_teems_config.tot_weight) * g_teems_config.my_weight;
    // But we can't send more messages to each storage server than we
    // could receive in total
    if (max_msg_to_each_stg > max_round1_msgs) {
        max_msg_to_each_stg = max_round1_msgs;
    }
    // And the max total number of outgoing messages in round 2 is then
    uint32_t max_round2_msgs = max_msg_to_each_stg *
        g_teems_config.num_storage_nodes;
    // In case we have a weird configuration where users can send more
    // messages per epoch than they can receive, ensure the round 2
    // buffer is large enough to hold the incoming messages as well
    if (max_round2_msgs < max_round1_msgs) {
        max_round2_msgs = max_round1_msgs;
    }
    // The max number of messages that can arrive at a storage server
    // (same slack pattern as above, scaled by tot_weight: all senders)
    uint32_t max_stg_msgs;
    max_stg_msgs = (tot_msg_per_stg/g_teems_config.tot_weight
        + g_teems_config.tot_weight) * g_teems_config.tot_weight;
    // Calculating ID channel buffer sizes
    // Weights are not used in ID channel routing
    // Round up to a multiple of num_routing_nodes
    uint32_t max_round1b_msgs_to_adj_rtr = CEILDIV(
        (g_teems_config.num_routing_nodes-1)*(g_teems_config.num_routing_nodes-1),
        g_teems_config.num_routing_nodes) *
        g_teems_config.num_routing_nodes;
    // Ensure columnroute constraint that column height is >= 2*(num_routing_nodes-1)^2
    // and a multiple of num_routing_nodes
    uint32_t max_round1a_msgs = CEILDIV(
        std::max(max_round1_msgs, 2*max_round1b_msgs_to_adj_rtr),
        g_teems_config.num_routing_nodes) *
        g_teems_config.num_routing_nodes;
    // Round 1c buffer must be able to hold either side's worth
    uint32_t max_round1c_msgs = std::max(max_round1a_msgs,
        max_round2_msgs);
    /*
    printf("users_per_ing=%u, tot_msg_per_ing=%u, max_msg_from_each_ing=%u, max_round1_msgs=%u, users_per_stg=%u, tot_msg_per_stg=%u, max_msg_to_each_stg=%u, max_round2_msgs=%u, max_stg_msgs=%u\n", users_per_ing, tot_msg_per_ing, max_msg_from_each_ing, max_round1_msgs, users_per_stg, tot_msg_per_stg, max_msg_to_each_stg, max_round2_msgs, max_stg_msgs);
    */
#ifdef TRACK_HEAP_USAGE
    printf("route_init H1 heap %u\n", g_peak_heap_used);
#endif
    // Create the route state: allocate only the buffers this node's
    // roles require.  MsgBuffer::alloc throws std::bad_alloc on
    // failure, which we convert to a false return.
    uint8_t my_roles = g_teems_config.roles[g_teems_config.my_node_num];
    try {
        if (my_roles & ROLE_INGESTION) {
            route_state.ingbuf.alloc(tot_msg_per_ing);
        }
#ifdef TRACK_HEAP_USAGE
        printf("route_init alloc %u msgs\n", tot_msg_per_ing);
        printf("route_init H2 heap %u\n", g_peak_heap_used);
#endif
        if (my_roles & ROLE_ROUTING) {
            route_state.round1.alloc(max_round2_msgs);
#ifdef TRACK_HEAP_USAGE
            printf("route_init alloc %u msgs\n", max_round2_msgs);
            printf("route_init H3 heap %u\n", g_peak_heap_used);
#endif
            // The 1a/1b/1c buffers are only used by ID channel routing
            if (!g_teems_config.token_channel) {
                route_state.round1a.alloc(max_round1a_msgs);
                route_state.round1a_sorted.alloc(max_round1a_msgs +
                    max_round1b_msgs_to_adj_rtr);
                // double round 1b buffers to sort with some round 1a messages
                route_state.round1b_next.alloc(2*max_round1b_msgs_to_adj_rtr);
                route_state.round1c.alloc(max_round1c_msgs);
#ifdef TRACK_HEAP_USAGE
                printf("route_init alloc %u msgs\n", max_round1a_msgs);
                printf("route_init alloc %u msgs\n", max_round1a_msgs +
                    max_round1b_msgs_to_adj_rtr);
                printf("route_init alloc %u msgs\n", 2*max_round1b_msgs_to_adj_rtr);
                printf("route_init alloc %u msgs\n", max_round1c_msgs);
                printf("route_init H4 heap %u\n", g_peak_heap_used);
#endif
            }
        }
        if (my_roles & ROLE_STORAGE) {
            route_state.round2.alloc(max_stg_msgs);
#ifdef TRACK_HEAP_USAGE
            printf("route_init alloc %u msgs\n", max_stg_msgs);
            printf("route_init H5 heap %u\n", g_peak_heap_used);
#endif
            if (!storage_init(users_per_stg, max_stg_msgs)) {
                return false;
            }
#ifdef TRACK_HEAP_USAGE
            printf("storage_init(%u,%u)\n", users_per_stg, max_stg_msgs);
            printf("route_init H6 heap %u\n", g_peak_heap_used);
#endif
        }
    } catch (std::bad_alloc&) {
        printf("Memory allocation failed in route_init\n");
        return false;
    }
    // Record the computed bounds in the global route state for use by
    // the per-round handlers
    route_state.step = ROUTE_NOT_STARTED;
    route_state.tot_msg_per_ing = tot_msg_per_ing;
    route_state.max_round1_msgs = max_round1_msgs;
    route_state.max_round1a_msgs = max_round1a_msgs;
    route_state.max_round1b_msgs_to_adj_rtr = max_round1b_msgs_to_adj_rtr;
    route_state.max_round1c_msgs = max_round1c_msgs;
    route_state.max_msg_to_each_stg = max_msg_to_each_stg;
    route_state.max_round2_msgs = max_round2_msgs;
    route_state.max_stg_msgs = max_stg_msgs;
    route_state.cbpointer = NULL;
    threadid_t nthreads = g_teems_config.nthreads;
#ifdef PROFILE_ROUTING
    unsigned long start = printf_with_rtclock("begin precompute evalplans (%u,%hu) (%u,%hu)\n", tot_msg_per_ing, nthreads, max_round2_msgs, nthreads);
#endif
    // Precompute a sort evaluation plan for each distinct buffer size
    // this node will sort, per role
    if (my_roles & ROLE_INGESTION) {
#ifdef TRACK_HEAP_USAGE
        printf("sort_precompute_evalplan(%u) start heap %u\n", tot_msg_per_ing, g_peak_heap_used);
#endif
        sort_precompute_evalplan(tot_msg_per_ing, nthreads);
#ifdef TRACK_HEAP_USAGE
        printf("sort_precompute_evalplan(%u) end heap %u\n", tot_msg_per_ing, g_peak_heap_used);
#endif
    }
    if (my_roles & ROLE_ROUTING) {
#ifdef TRACK_HEAP_USAGE
        printf("sort_precompute_evalplan(%u) start heap %u\n", max_round2_msgs, g_peak_heap_used);
#endif
        sort_precompute_evalplan(max_round2_msgs, nthreads);
#ifdef TRACK_HEAP_USAGE
        printf("sort_precompute_evalplan(%u) end heap %u\n", max_round2_msgs, g_peak_heap_used);
#endif
        if(!g_teems_config.token_channel) {
#ifdef TRACK_HEAP_USAGE
            printf("sort_precompute_evalplan(%u) start heap %u\n", max_round1a_msgs, g_peak_heap_used);
#endif
            sort_precompute_evalplan(max_round1a_msgs, nthreads);
#ifdef TRACK_HEAP_USAGE
            printf("sort_precompute_evalplan(%u) end heap %u\n", max_round1a_msgs, g_peak_heap_used);
            // NOTE(review): these log lines print
            // max_round1b_msgs_to_adj_rtr but the evalplan below is
            // built for twice that size
            printf("sort_precompute_evalplan(%u) start heap %u\n", max_round1b_msgs_to_adj_rtr, g_peak_heap_used);
#endif
            sort_precompute_evalplan(2*max_round1b_msgs_to_adj_rtr, nthreads);
#ifdef TRACK_HEAP_USAGE
            printf("sort_precompute_evalplan(%u) end heap %u\n", max_round1b_msgs_to_adj_rtr, g_peak_heap_used);
#endif
        }
    }
    if (my_roles & ROLE_STORAGE) {
#ifdef TRACK_HEAP_USAGE
        printf("sort_precompute_evalplan(%u) start heap %u\n", max_stg_msgs, g_peak_heap_used);
#endif
        sort_precompute_evalplan(max_stg_msgs, nthreads);
#ifdef TRACK_HEAP_USAGE
        printf("sort_precompute_evalplan(%u) end heap %u\n", max_stg_msgs, g_peak_heap_used);
#endif
    }
#ifdef PROFILE_ROUTING
    printf_with_rtclock_diff(start, "end precompute evalplans\n");
#endif
#ifdef TRACK_HEAP_USAGE
    printf("route_init end heap %u\n", g_peak_heap_used);
#endif
    return true;
}
  232. // Call when shutting system down to deallocate routing state
  233. void route_close() {
  234. uint8_t my_roles = g_teems_config.roles[g_teems_config.my_node_num];
  235. if (my_roles & ROLE_STORAGE) {
  236. storage_close();
  237. }
  238. }
  239. // Precompute the WaksmanNetworks needed for the sorts. If you pass -1,
  240. // it will return the number of different sizes it needs to regenerate.
  241. // If you pass [0,sizes-1], it will compute one WaksmanNetwork with that
  242. // size index and return the number of available WaksmanNetworks of that
  243. // size. If you pass anything else, it will return the number of
  244. // different sizes it needs at all.
  245. // The list of sizes that need refilling, updated when you pass -1
  246. static std::vector<uint32_t> used_sizes;
  247. size_t ecall_precompute_sort(int sizeidx)
  248. {
  249. size_t ret = 0;
  250. if (sizeidx == -1) {
  251. used_sizes = sort_get_used();
  252. ret = used_sizes.size();
  253. } else if (sizeidx >= 0 && sizeidx < used_sizes.size()) {
  254. uint32_t size = used_sizes[sizeidx];
  255. #ifdef TRACK_HEAP_USAGE
  256. printf("ecall_precompute_sort start heap %u\n", g_peak_heap_used);
  257. #endif
  258. #ifdef PROFILE_ROUTING
  259. unsigned long start = printf_with_rtclock("begin precompute WaksmanNetwork (%u)\n", size);
  260. #endif
  261. ret = sort_precompute(size);
  262. #ifdef PROFILE_ROUTING
  263. printf_with_rtclock_diff(start, "end precompute Waksman Network (%u)\n", size);
  264. #endif
  265. #ifdef TRACK_HEAP_USAGE
  266. printf("ecall_precompute_sort end heap %u\n", g_peak_heap_used);
  267. #endif
  268. } else {
  269. uint8_t my_roles = g_teems_config.roles[g_teems_config.my_node_num];
  270. if (my_roles & ROLE_INGESTION) {
  271. used_sizes.push_back(route_state.tot_msg_per_ing);
  272. }
  273. if (my_roles & ROLE_ROUTING) {
  274. used_sizes.push_back(route_state.max_round2_msgs);
  275. if(!g_teems_config.token_channel) {
  276. used_sizes.push_back(route_state.max_round1a_msgs);
  277. used_sizes.push_back(2*route_state.max_round1b_msgs_to_adj_rtr);
  278. }
  279. }
  280. if (my_roles & ROLE_STORAGE) {
  281. used_sizes.push_back(route_state.max_stg_msgs);
  282. if(!g_teems_config.token_channel) {
  283. used_sizes.push_back(route_state.max_stg_msgs);
  284. }
  285. }
  286. ret = used_sizes.size();
  287. }
  288. return ret;
  289. }
  290. static uint8_t* msgbuffer_get_buf(MsgBuffer &msgbuf,
  291. NodeCommState &, uint32_t tot_enc_chunk_size)
  292. {
  293. uint16_t msg_size = g_teems_config.msg_size;
  294. // Chunks will be encrypted and have a MAC tag attached which will
  295. // not correspond to plaintext bytes, so we can trim them.
  296. // The minimum number of chunks needed to transmit this message
  297. uint32_t min_num_chunks =
  298. (tot_enc_chunk_size + (FRAME_SIZE-1)) / FRAME_SIZE;
  299. // The number of plaintext bytes this message could contain
  300. uint32_t plaintext_bytes = tot_enc_chunk_size -
  301. SGX_AESGCM_MAC_SIZE * min_num_chunks;
  302. uint32_t num_msgs = CEILDIV(plaintext_bytes, uint32_t(msg_size));
  303. pthread_mutex_lock(&msgbuf.mutex);
  304. uint32_t start = msgbuf.reserved;
  305. if (start + num_msgs > msgbuf.bufsize) {
  306. pthread_mutex_unlock(&msgbuf.mutex);
  307. printf("Max %u messages exceeded (msgbuffer_get_buf)\n",
  308. msgbuf.bufsize);
  309. return NULL;
  310. }
  311. msgbuf.reserved += num_msgs;
  312. pthread_mutex_unlock(&msgbuf.mutex);
  313. return msgbuf.buf + start * msg_size;
  314. }
  315. static void round1a_received(NodeCommState &nodest,
  316. uint8_t *data, uint32_t plaintext_len, uint32_t);
  317. static void round1b_next_received(NodeCommState &nodest,
  318. uint8_t *data, uint32_t plaintext_len, uint32_t);
  319. static void round1c_received(NodeCommState &nodest, uint8_t *data,
  320. uint32_t plaintext_len, uint32_t);
  321. static void round2_received(NodeCommState &nodest,
  322. uint8_t *data, uint32_t plaintext_len, uint32_t);
// A round 1 message was received by a routing node from an ingestion
// node; we put it into the round 2 buffer for processing in round 2
// (the messages land in the route_state.round1 MsgBuffer, which was
// sized for round 2 processing in route_init).
static void round1_received(NodeCommState &nodest,
    uint8_t *data, uint32_t plaintext_len, uint32_t)
{
    uint16_t msg_size = g_teems_config.msg_size;
    // Round 1 payloads are a whole number of fixed-size messages
    assert((plaintext_len % uint32_t(msg_size)) == 0);
    uint32_t num_msgs = plaintext_len / uint32_t(msg_size);
    uint8_t our_roles = g_teems_config.roles[g_teems_config.my_node_num];
    uint8_t their_roles = g_teems_config.roles[nodest.node_num];
    // Record the arrival under the buffer lock and snapshot the
    // counters we need; decisions are made after unlocking
    pthread_mutex_lock(&route_state.round1.mutex);
    route_state.round1.inserted += num_msgs;
    route_state.round1.nodes_received += 1;
    nodenum_t nodes_received = route_state.round1.nodes_received;
    bool completed_prev_round = route_state.round1.completed_prev_round;
    pthread_mutex_unlock(&route_state.round1.mutex);
    // What is the next message we expect from this node?
    // (The [&] lambdas capture no locals — route_state is a global —
    // so storing them in nodest does not dangle.)
    if (g_teems_config.token_channel) {
        if ((our_roles & ROLE_STORAGE) && (their_roles & ROLE_ROUTING)) {
            nodest.in_msg_get_buf = [&](NodeCommState &commst,
                uint32_t tot_enc_chunk_size) {
                return msgbuffer_get_buf(route_state.round2, commst,
                    tot_enc_chunk_size);
            };
            nodest.in_msg_received = round2_received;
        }
        // Otherwise, it's just the next round 1 message, so don't change
        // the handlers.
    } else {
        if (their_roles & ROLE_ROUTING) {
            nodest.in_msg_get_buf = [&](NodeCommState &commst,
                uint32_t tot_enc_chunk_size) {
                return msgbuffer_get_buf(route_state.round1a, commst,
                    tot_enc_chunk_size);
            };
            nodest.in_msg_received = round1a_received;
        }
        // Otherwise, it's just the next round 1 message, so don't change
        // the handlers.
    }
    // Once every ingestion node has delivered and the previous round
    // is done, report round 1 complete via the saved callback pointer
    if (nodes_received == g_teems_config.num_ingestion_nodes &&
            completed_prev_round) {
        route_state.step = ROUTE_ROUND_1;
        void *cbpointer = route_state.cbpointer;
        route_state.cbpointer = NULL;
        ocall_routing_round_complete(cbpointer, 1);
    }
}
// A round 1a message was received by a routing node
static void round1a_received(NodeCommState &nodest,
    uint8_t *data, uint32_t plaintext_len, uint32_t)
{
    uint16_t msg_size = g_teems_config.msg_size;
    // Round 1a payloads are a whole number of fixed-size messages
    assert((plaintext_len % uint32_t(msg_size)) == 0);
    uint32_t num_msgs = plaintext_len / uint32_t(msg_size);
    // Record the arrival under the buffer lock; snapshot the counters
    pthread_mutex_lock(&route_state.round1a.mutex);
    route_state.round1a.inserted += num_msgs;
    route_state.round1a.nodes_received += 1;
    nodenum_t nodes_received = route_state.round1a.nodes_received;
    bool completed_prev_round = route_state.round1a.completed_prev_round;
    pthread_mutex_unlock(&route_state.round1a.mutex);
    // Both are routing nodes
    // We only expect a message from the next node (if it exists)
    nodenum_t my_node_num = g_teems_config.my_node_num;
    nodenum_t num_routing_nodes = g_teems_config.num_routing_nodes;
    uint16_t prev_nodes = g_teems_config.weights[my_node_num].startweight;
    // NOTE(review): routing_nodes[1] is presumed to be the routing
    // node adjacent to us in the column order — confirm the
    // routing_nodes layout against config.hpp
    if ((prev_nodes < num_routing_nodes-1) &&
            (nodest.node_num == g_teems_config.routing_nodes[1])) {
        // Node is next routing node
        nodest.in_msg_get_buf = [&](NodeCommState &commst,
            uint32_t tot_enc_chunk_size) {
            return msgbuffer_get_buf(route_state.round1b_next, commst,
                tot_enc_chunk_size);
        };
        nodest.in_msg_received = round1b_next_received;
    } else {
        // other routing nodes will not send to this node until round 1c
        nodest.in_msg_get_buf = [&](NodeCommState &commst,
            uint32_t tot_enc_chunk_size) {
            return msgbuffer_get_buf(route_state.round1c, commst,
                tot_enc_chunk_size);
        };
        nodest.in_msg_received = round1c_received;
    }
    // All routing nodes reported and the previous round finished:
    // signal round 1a completion
    if (nodes_received == g_teems_config.num_routing_nodes &&
            completed_prev_round) {
        route_state.step = ROUTE_ROUND_1A;
        void *cbpointer = route_state.cbpointer;
        route_state.cbpointer = NULL;
        ocall_routing_round_complete(cbpointer, ROUND_1A);
    }
}
  415. // A round 1b message was received from the next routing node
  416. static void round1b_next_received(NodeCommState &nodest,
  417. uint8_t *data, uint32_t plaintext_len, uint32_t)
  418. {
  419. uint16_t msg_size = g_teems_config.msg_size;
  420. // There are an extra 5 bytes at the end of this message, containing
  421. // the next receiver id (4 bytes) and the count of messages with
  422. // that receiver id (1 byte)
  423. assert((plaintext_len % uint32_t(msg_size)) == 5);
  424. uint32_t num_msgs = plaintext_len / uint32_t(msg_size);
  425. uint8_t our_roles = g_teems_config.roles[g_teems_config.my_node_num];
  426. uint8_t their_roles = g_teems_config.roles[nodest.node_num];
  427. pthread_mutex_lock(&route_state.round1b_next.mutex);
  428. // Add an extra 1 for the message space taken up by the above 5
  429. // bytes
  430. route_state.round1b_next.inserted += num_msgs + 1;
  431. route_state.round1b_next.nodes_received += 1;
  432. nodenum_t nodes_received = route_state.round1b_next.nodes_received;
  433. bool completed_prev_round = route_state.round1b_next.completed_prev_round;
  434. pthread_mutex_unlock(&route_state.round1b_next.mutex);
  435. nodest.in_msg_get_buf = [&](NodeCommState &commst,
  436. uint32_t tot_enc_chunk_size) {
  437. return msgbuffer_get_buf(route_state.round1c, commst,
  438. tot_enc_chunk_size);
  439. };
  440. nodest.in_msg_received = round1c_received;
  441. nodenum_t my_node_num = g_teems_config.my_node_num;
  442. uint16_t prev_nodes = g_teems_config.weights[my_node_num].startweight;
  443. nodenum_t num_routing_nodes = g_teems_config.num_routing_nodes;
  444. nodenum_t adjacent_nodes;
  445. if (num_routing_nodes == 1) {
  446. adjacent_nodes = 0;
  447. } else if (prev_nodes == num_routing_nodes-1) {
  448. adjacent_nodes = 0;
  449. } else {
  450. adjacent_nodes = 1;
  451. }
  452. if (nodes_received == adjacent_nodes && completed_prev_round) {
  453. route_state.step = ROUTE_ROUND_1B;
  454. void *cbpointer = route_state.cbpointer;
  455. route_state.cbpointer = NULL;
  456. ocall_routing_round_complete(cbpointer, ROUND_1B);
  457. }
  458. }
// Message received in round 1c
static void round1c_received(NodeCommState &nodest, uint8_t *data,
    uint32_t plaintext_len, uint32_t)
{
    uint16_t msg_size = g_teems_config.msg_size;
    // Round 1c payloads are a whole number of fixed-size messages
    assert((plaintext_len % uint32_t(msg_size)) == 0);
    uint32_t num_msgs = plaintext_len / uint32_t(msg_size);
    uint8_t our_roles = g_teems_config.roles[g_teems_config.my_node_num];
    uint8_t their_roles = g_teems_config.roles[nodest.node_num];
    // Record the arrival under the buffer lock; snapshot the counters
    pthread_mutex_lock(&route_state.round1c.mutex);
    route_state.round1c.inserted += num_msgs;
    route_state.round1c.nodes_received += 1;
    nodenum_t nodes_received = route_state.round1c.nodes_received;
    bool completed_prev_round = route_state.round1c.completed_prev_round;
    pthread_mutex_unlock(&route_state.round1c.mutex);
    // What is the next message we expect from this node?
    // Priority: if we store, it is a round 2 message; else if they
    // ingest, a round 1 message; otherwise a round 1a message.
    if (our_roles & ROLE_STORAGE) {
        nodest.in_msg_get_buf = [&](NodeCommState &commst,
            uint32_t tot_enc_chunk_size) {
            return msgbuffer_get_buf(route_state.round2, commst,
                tot_enc_chunk_size);
        };
        nodest.in_msg_received = round2_received;
    } else if (their_roles & ROLE_INGESTION) {
        nodest.in_msg_get_buf = [&](NodeCommState &commst,
            uint32_t tot_enc_chunk_size) {
            return msgbuffer_get_buf(route_state.round1, commst,
                tot_enc_chunk_size);
        };
        nodest.in_msg_received = round1_received;
    } else {
        nodest.in_msg_get_buf = [&](NodeCommState &commst,
            uint32_t tot_enc_chunk_size) {
            return msgbuffer_get_buf(route_state.round1a, commst,
                tot_enc_chunk_size);
        };
        nodest.in_msg_received = round1a_received;
    }
    // All routing nodes reported and the previous round finished:
    // signal round 1c completion
    if (nodes_received == g_teems_config.num_routing_nodes &&
            completed_prev_round) {
        route_state.step = ROUTE_ROUND_1C;
        void *cbpointer = route_state.cbpointer;
        route_state.cbpointer = NULL;
        ocall_routing_round_complete(cbpointer, ROUND_1C);
    }
}
  505. // A round 2 message was received by a storage node from a routing node
  506. static void round2_received(NodeCommState &nodest,
  507. uint8_t *data, uint32_t plaintext_len, uint32_t)
  508. {
  509. uint16_t msg_size = g_teems_config.msg_size;
  510. assert((plaintext_len % uint32_t(msg_size)) == 0);
  511. uint32_t num_msgs = plaintext_len / uint32_t(msg_size);
  512. uint8_t our_roles = g_teems_config.roles[g_teems_config.my_node_num];
  513. uint8_t their_roles = g_teems_config.roles[nodest.node_num];
  514. pthread_mutex_lock(&route_state.round2.mutex);
  515. route_state.round2.inserted += num_msgs;
  516. route_state.round2.nodes_received += 1;
  517. nodenum_t nodes_received = route_state.round2.nodes_received;
  518. bool completed_prev_round = route_state.round2.completed_prev_round;
  519. pthread_mutex_unlock(&route_state.round2.mutex);
  520. // What is the next message we expect from this node?
  521. if ((our_roles & ROLE_ROUTING) && (their_roles & ROLE_INGESTION)) {
  522. nodest.in_msg_get_buf = [&](NodeCommState &commst,
  523. uint32_t tot_enc_chunk_size) {
  524. return msgbuffer_get_buf(route_state.round1, commst,
  525. tot_enc_chunk_size);
  526. };
  527. nodest.in_msg_received = round1_received;
  528. }
  529. // Otherwise, it's just the next round 2 message, so don't change
  530. // the handlers.
  531. if (nodes_received == g_teems_config.num_routing_nodes &&
  532. completed_prev_round) {
  533. route_state.step = ROUTE_ROUND_2;
  534. void *cbpointer = route_state.cbpointer;
  535. route_state.cbpointer = NULL;
  536. ocall_routing_round_complete(cbpointer, 2);
  537. }
  538. }
  539. // For a given other node, set the received message handler to the first
  540. // message we would expect from them, given their roles and our roles.
  541. void route_init_msg_handler(nodenum_t node_num)
  542. {
  543. // Our roles and their roles
  544. uint8_t our_roles = g_teems_config.roles[g_teems_config.my_node_num];
  545. uint8_t their_roles = g_teems_config.roles[node_num];
  546. // The node communication state
  547. NodeCommState &nodest = g_commstates[node_num];
  548. // If we are a routing node (possibly among other roles) and they
  549. // are an ingestion node (possibly among other roles), a round 1
  550. // routing message is the first thing we expect from them. We put
  551. // these messages into the round1 buffer for processing.
  552. if ((our_roles & ROLE_ROUTING) && (their_roles & ROLE_INGESTION)) {
  553. nodest.in_msg_get_buf = [&](NodeCommState &commst,
  554. uint32_t tot_enc_chunk_size) {
  555. return msgbuffer_get_buf(route_state.round1, commst,
  556. tot_enc_chunk_size);
  557. };
  558. nodest.in_msg_received = round1_received;
  559. }
  560. // Otherwise, if we are a storage node (possibly among other roles)
  561. // and they are a routing node (possibly among other roles), a round
  562. // 2 routing message is the first thing we expect from them
  563. else if ((our_roles & ROLE_STORAGE) && (their_roles & ROLE_ROUTING)) {
  564. nodest.in_msg_get_buf = [&](NodeCommState &commst,
  565. uint32_t tot_enc_chunk_size) {
  566. return msgbuffer_get_buf(route_state.round2, commst,
  567. tot_enc_chunk_size);
  568. };
  569. nodest.in_msg_received = round2_received;
  570. }
  571. // Otherwise, we don't expect a message from this node. Set the
  572. // unknown message handler.
  573. else {
  574. nodest.in_msg_get_buf = default_in_msg_get_buf;
  575. nodest.in_msg_received = unknown_in_msg_received;
  576. }
  577. }
  578. // Directly ingest a buffer of num_msgs messages into the ingbuf buffer.
  579. // Return true on success, false on failure.
  580. bool ecall_ingest_raw(uint8_t *msgs, uint32_t num_msgs)
  581. {
  582. uint16_t msg_size = g_teems_config.msg_size;
  583. MsgBuffer &ingbuf = route_state.ingbuf;
  584. pthread_mutex_lock(&ingbuf.mutex);
  585. uint32_t start = ingbuf.reserved;
  586. if (start + num_msgs > route_state.tot_msg_per_ing) {
  587. pthread_mutex_unlock(&ingbuf.mutex);
  588. printf("Max %u messages exceeded (ecall_ingest_raw)\n",
  589. route_state.tot_msg_per_ing);
  590. return false;
  591. }
  592. ingbuf.reserved += num_msgs;
  593. pthread_mutex_unlock(&ingbuf.mutex);
  594. memmove(ingbuf.buf + start * msg_size,
  595. msgs, num_msgs * msg_size);
  596. pthread_mutex_lock(&ingbuf.mutex);
  597. ingbuf.inserted += num_msgs;
  598. pthread_mutex_unlock(&ingbuf.mutex);
  599. return true;
  600. }
// Send messages round-robin, used in rounds 1 and 1c. Note that N here
// is not private. Pass indices = nullptr to just send the messages in
// the order they appear in the msgs buffer.
//
// Layout: every tot_weight consecutive messages form one "row"; a
// routing node with weight w and start offset startweight owns the w
// messages at [startweight, startweight+w) of each row. T must expose
// an index() method giving a message's position in msgs (a sort-key
// type such as UidPriorityKey).
template<typename T>
static void send_round_robin_msgs(MsgBuffer &round, const uint8_t *msgs,
    const T *indices, uint32_t N)
{
    uint16_t msg_size = g_teems_config.msg_size;
    uint16_t tot_weight;
    tot_weight = g_teems_config.tot_weight;
    nodenum_t my_node_num = g_teems_config.my_node_num;
    // N messages = full_rows complete rows of tot_weight messages,
    // plus one partial row of last_row messages.
    uint32_t full_rows;
    uint32_t last_row;
    full_rows = N / uint32_t(tot_weight);
    last_row = N % uint32_t(tot_weight);
    for (auto &routing_node: g_teems_config.routing_nodes) {
        uint8_t weight = g_teems_config.weights[routing_node].weight;
        if (weight == 0) {
            // This shouldn't happen, but just in case
            continue;
        }
        uint16_t start_weight =
            g_teems_config.weights[routing_node].startweight;
        // The number of messages headed for this routing node from the
        // full rows
        uint32_t num_msgs_full_rows = full_rows * uint32_t(weight);
        // The number of messages headed for this routing node from the
        // incomplete last row is:
        // 0 if last_row < start_weight
        // last_row-start_weight if start_weight <= last_row < start_weight + weight
        // weight if start_weight + weight <= last_row
        uint32_t num_msgs_last_row = 0;
        if (start_weight <= last_row && last_row < start_weight + weight) {
            num_msgs_last_row = last_row-start_weight;
        } else if (start_weight + weight <= last_row) {
            num_msgs_last_row = weight;
        }
        // The total number of messages headed for this routing node
        uint32_t num_msgs = num_msgs_full_rows + num_msgs_last_row;
        if (routing_node == my_node_num) {
            // Special case: we're sending to ourselves; just put the
            // messages in our own buffer
            pthread_mutex_lock(&round.mutex);
            uint32_t start = round.reserved;
            if (start + num_msgs > round.bufsize) {
                pthread_mutex_unlock(&round.mutex);
                printf("Max %u messages exceeded (send_round_robin_msgs)\n",
                    round.bufsize);
                return;
            }
            round.reserved += num_msgs;
            pthread_mutex_unlock(&round.mutex);
            uint8_t *buf = round.buf + start * msg_size;
            for (uint32_t i=0; i<full_rows; ++i) {
                if (indices) {
                    // Gather our slice of row i through the index array
                    const T *idxp = indices + i*tot_weight + start_weight;
                    for (uint32_t j=0; j<weight; ++j) {
                        memmove(buf, msgs + idxp[j].index()*msg_size, msg_size);
                        buf += msg_size;
                    }
                } else {
                    // No index array: our slice of row i is contiguous
                    size_t idx = i*tot_weight + start_weight;
                    memmove(buf, msgs + idx*msg_size, weight*msg_size);
                    buf += weight*msg_size;
                }
            }
            // Our (possibly empty) slice of the partial last row
            if (indices) {
                const T *idxp = indices + full_rows*tot_weight + start_weight;
                for (uint32_t j=0; j<num_msgs_last_row; ++j) {
                    memmove(buf, msgs + idxp[j].index()*msg_size, msg_size);
                    buf += msg_size;
                }
            } else {
                size_t idx = full_rows*tot_weight + start_weight;
                memmove(buf, msgs + idx*msg_size, num_msgs_last_row*msg_size);
                buf += num_msgs_last_row*msg_size;
            }
            pthread_mutex_lock(&round.mutex);
            round.inserted += num_msgs;
            round.nodes_received += 1;
            pthread_mutex_unlock(&round.mutex);
        } else {
            // Remote node: stream exactly the same selection over the
            // node's communication channel
            NodeCommState &nodecom = g_commstates[routing_node];
            nodecom.message_start(num_msgs * msg_size);
            for (uint32_t i=0; i<full_rows; ++i) {
                if (indices) {
                    const T *idxp = indices + i*tot_weight + start_weight;
                    for (uint32_t j=0; j<weight; ++j) {
                        nodecom.message_data(msgs + idxp[j].index()*msg_size,
                            msg_size);
                    }
                } else {
                    size_t idx = i*tot_weight + start_weight;
                    nodecom.message_data(msgs + idx*msg_size, weight*msg_size);
                }
            }
            if (indices) {
                const T *idxp = indices + full_rows*tot_weight + start_weight;
                for (uint32_t j=0; j<num_msgs_last_row; ++j) {
                    nodecom.message_data(msgs + idxp[j].index()*msg_size,
                        msg_size);
                }
            } else {
                size_t idx = full_rows*tot_weight + start_weight;
                nodecom.message_data(msgs + idx*msg_size,
                    num_msgs_last_row*msg_size);
            }
        }
    }
}
// Send the round 1a messages from the round 1 buffer, which only occurs in ID-channel routing.
// msgs points to the message buffer, indices points to the sorted indices, and N is the number
// of non-padding items.
//
// The max_round1a_msgs slots are split as evenly as possible over the
// routing nodes: the first (max % nodes) nodes get one extra message.
// Each node then receives its contiguous share of the sorted order.
static void send_round1a_msgs(const uint8_t *msgs, const UidPriorityKey *indices, uint32_t N) {
    uint16_t msg_size = g_teems_config.msg_size;
    nodenum_t my_node_num = g_teems_config.my_node_num;
    nodenum_t num_routing_nodes = g_teems_config.num_routing_nodes;
    uint32_t min_msgs_per_node = route_state.max_round1a_msgs / num_routing_nodes;
    uint32_t extra_msgs = route_state.max_round1a_msgs % num_routing_nodes;
    for (auto &routing_node: g_teems_config.routing_nodes) {
        // In this unweighted setting, start_weight represents the position among routing nodes
        uint16_t prev_nodes = g_teems_config.weights[routing_node].startweight;
        uint32_t start_msg, num_msgs;
        if (prev_nodes >= extra_msgs) {
            // All extra messages were absorbed by earlier nodes
            start_msg = min_msgs_per_node * prev_nodes + extra_msgs;
            num_msgs = min_msgs_per_node;
        } else {
            // Each earlier node (and this one) carries one extra message
            start_msg = min_msgs_per_node * prev_nodes + prev_nodes;
            num_msgs = min_msgs_per_node + 1;
        }
        // take number of messages into account
        if (start_msg >= N) {
            num_msgs = 0;
        } else if (start_msg + num_msgs > N) {
            num_msgs = N - start_msg;
        }
        if (routing_node == my_node_num) {
            // Special case: we're sending to ourselves; just put the
            // messages in our own buffer
            MsgBuffer &round1a = route_state.round1a;
            pthread_mutex_lock(&round1a.mutex);
            uint32_t start = round1a.reserved;
            if (start + num_msgs > round1a.bufsize) {
                pthread_mutex_unlock(&round1a.mutex);
                printf("Max %u messages exceeded in round 1a\n", round1a.bufsize);
                return;
            }
            round1a.reserved += num_msgs;
            pthread_mutex_unlock(&round1a.mutex);
            uint8_t *buf = round1a.buf + start * msg_size;
            // Gather messages in sorted order via the index array
            for (uint32_t i=0; i<num_msgs; ++i) {
                const UidPriorityKey *idxp = indices + start_msg + i;
                memmove(buf, msgs + idxp->index()*msg_size, msg_size);
                buf += msg_size;
            }
            pthread_mutex_lock(&round1a.mutex);
            round1a.inserted += num_msgs;
            round1a.nodes_received += 1;
            pthread_mutex_unlock(&round1a.mutex);
        } else {
            // Remote node: stream the same selection over its channel
            NodeCommState &nodecom = g_commstates[routing_node];
            nodecom.message_start(num_msgs * msg_size);
            for (uint32_t i=0; i<num_msgs; ++i) {
                const UidPriorityKey *idxp = indices + start_msg + i;
                nodecom.message_data(msgs + idxp->index()*msg_size, msg_size);
            }
        }
    }
}
// Send the round 1b messages from the round 1a buffer, which only occurs in ID-channel routing.
// msgs points to the message buffer, and N is the number of non-padding items.
// Return the number of messages sent
static uint32_t send_round1b_msgs(const uint8_t *msgs, uint32_t N) {
    uint16_t msg_size = g_teems_config.msg_size;
    nodenum_t my_node_num = g_teems_config.my_node_num;
    nodenum_t num_routing_nodes = g_teems_config.num_routing_nodes;
    // Our position among the routing nodes (unweighted setting)
    uint16_t prev_nodes = g_teems_config.weights[my_node_num].startweight;
    // send to previous node
    if (prev_nodes > 0) {
        // NOTE(review): this picks the *last* routing node rather than
        // the one at position prev_nodes-1 — confirm the intended
        // ordering/adjacency of routing_nodes here.
        nodenum_t prev_node = g_teems_config.routing_nodes[num_routing_nodes-1];
        NodeCommState &nodecom = g_commstates[prev_node];
        uint32_t num_msgs = min(N, route_state.max_round1b_msgs_to_adj_rtr);
        // There are an extra 5 bytes at the end of this message: 4
        // bytes for the receiver id in the next message we _didn't_
        // send, and 1 byte for the number of messages we have at the
        // beginning of the buffer of messages we didn't send (max
        // id_in) with the same receiver id
        nodecom.message_start(num_msgs * msg_size + 5);
        nodecom.message_data(msgs, num_msgs * msg_size);
        // 0xffffffff / 0 means "nothing held back" (padding id)
        uint32_t next_receiver_id = 0xffffffff;
        uint8_t next_rid_count = 0;
        // num_msgs and N are not private, but the contents of the
        // buffer are.
        if (num_msgs < N) {
            // First unsent message's receiver id (header word)
            next_receiver_id = *(const uint32_t *)(msgs +
                num_msgs * msg_size);
            next_rid_count = 1;
            // If id_in > 1, obliviously scan messages num_msgs+1 ..
            // num_msgs+(id_in-1) and as long as they have the same
            // receiver id as next_receiver_id, add 1 to next_rid_count (but
            // don't go past message N of course)
            // This count _includes_ the first message already scanned
            // above. It is not private.
            uint8_t num_to_scan = uint8_t(std::min(N - num_msgs,
                uint32_t(g_teems_config.m_id_in)));
            const unsigned char *scan_msg = msgs +
                (num_msgs + 1) * msg_size;
            for (uint8_t i=1; i<num_to_scan; ++i) {
                uint32_t receiver_id = *(const uint32_t *)scan_msg;
                // Branch-free accumulation: no data-dependent control flow
                next_rid_count += (receiver_id == next_receiver_id);
                scan_msg += msg_size;
            }
        }
        nodecom.message_data((const unsigned char *)&next_receiver_id, 4);
        nodecom.message_data(&next_rid_count, 1);
        return num_msgs;
    }
    // First-positioned node has no previous node; nothing sent
    return 0;
}
// Send the round 2 messages from the previous-round buffer, which are already
// padded and shuffled, so this can be done non-obliviously. tot_msgs
// is the total number of messages in the input buffer, which may
// include padding messages added by the shuffle. Those messages are
// not sent anywhere. There are num_msgs_per_stg messages for each
// storage node labelled for that node.
static void send_round2_msgs(uint32_t tot_msgs, uint32_t num_msgs_per_stg, MsgBuffer &prevround)
{
    uint16_t msg_size = g_teems_config.msg_size;
    const uint8_t* buf = prevround.buf;
    nodenum_t num_storage_nodes = g_teems_config.num_storage_nodes;
    nodenum_t my_node_num = g_teems_config.my_node_num;
    // Non-NULL iff we are ourselves a storage node; points into our
    // own round2 buffer where our messages are copied directly.
    uint8_t *myself_buf = NULL;
    // Open a fixed-size outgoing message to every storage node (or
    // reserve space in our own round2 buffer).
    for (nodenum_t i=0; i<num_storage_nodes; ++i) {
        nodenum_t node = g_teems_config.storage_nodes[i];
        if (node != my_node_num) {
            g_commstates[node].message_start(msg_size * num_msgs_per_stg);
        } else {
            MsgBuffer &round2 = route_state.round2;
            pthread_mutex_lock(&round2.mutex);
            uint32_t start = round2.reserved;
            if (start + num_msgs_per_stg > round2.bufsize) {
                pthread_mutex_unlock(&round2.mutex);
                printf("Max %u messages exceeded (send_round2_msgs)\n",
                    round2.bufsize);
                return;
            }
            round2.reserved += num_msgs_per_stg;
            pthread_mutex_unlock(&round2.mutex);
            myself_buf = round2.buf + start * msg_size;
        }
    }
    // Dispatch each message by the storage node id in its header word;
    // ids >= num_storage_nodes (shuffle padding) are dropped.
    while (tot_msgs) {
        nodenum_t storage_node_id =
            nodenum_t((*(const uint32_t *)buf)>>DEST_UID_BITS);
        if (storage_node_id < num_storage_nodes) {
            nodenum_t node = g_teems_config.storage_map[storage_node_id];
            if (node == my_node_num) {
                memmove(myself_buf, buf, msg_size);
                myself_buf += msg_size;
            } else {
                g_commstates[node].message_data(buf, msg_size);
            }
        }
        buf += msg_size;
        --tot_msgs;
    }
    // If we delivered to ourselves, account for it as one node's
    // worth of received round 2 messages.
    if (myself_buf) {
        MsgBuffer &round2 = route_state.round2;
        pthread_mutex_lock(&round2.mutex);
        round2.inserted += num_msgs_per_stg;
        round2.nodes_received += 1;
        pthread_mutex_unlock(&round2.mutex);
    }
}
// Round 1a (ID-channel routing): obliviously sort the round 1 buffer
// by receiver uid and priority and distribute the sorted messages to
// the routing nodes via send_round1a_msgs. Signals completion of
// ROUND_1A once all routing nodes' round 1a messages have arrived.
static void round1a_processing(void *cbpointer) {
    uint8_t my_roles = g_teems_config.roles[g_teems_config.my_node_num];
    MsgBuffer &round1 = route_state.round1;
    if (my_roles & ROLE_ROUTING) {
        route_state.cbpointer = cbpointer;
        pthread_mutex_lock(&round1.mutex);
        // Ensure there are no pending messages currently being inserted
        // into the buffer
        while (round1.reserved != round1.inserted) {
            pthread_mutex_unlock(&round1.mutex);
            pthread_mutex_lock(&round1.mutex);
        }
#ifdef TRACE_ROUTING
        show_messages("Start of round 1a", round1.buf, round1.inserted);
#endif
#ifdef PROFILE_ROUTING
        uint32_t inserted = round1.inserted;
        unsigned long start_round1a =
            printf_with_rtclock("begin round1a processing (%u)\n", inserted);
        // Sort the messages we've received
        unsigned long start_sort =
            printf_with_rtclock("begin oblivious sort (%u,%u)\n", inserted,
                route_state.max_round1_msgs);
#endif
        // Sort received messages by increasing user ID and
        // priority. Larger priority number indicates higher priority.
        // send_round1a_msgs is invoked with the resulting sorted index
        // array to distribute the messages.
        sort_mtobliv<UidPriorityKey>(g_teems_config.nthreads, round1.buf,
            g_teems_config.msg_size, round1.inserted, route_state.max_round1_msgs,
            send_round1a_msgs);
#ifdef PROFILE_ROUTING
        printf_with_rtclock_diff(start_sort, "end oblivious sort (%u,%u)\n",
            inserted, route_state.max_round1_msgs);
        printf_with_rtclock_diff(start_round1a, "end round1a processing (%u)\n",
            inserted);
#endif
        round1.reset();
        pthread_mutex_unlock(&round1.mutex);
        // If every routing node's round 1a batch (including our own)
        // has already arrived, the round is complete.
        MsgBuffer &round1a = route_state.round1a;
        pthread_mutex_lock(&round1a.mutex);
        round1a.completed_prev_round = true;
        nodenum_t nodes_received = round1a.nodes_received;
        pthread_mutex_unlock(&round1a.mutex);
        if (nodes_received == g_teems_config.num_routing_nodes) {
            route_state.step = ROUTE_ROUND_1A;
            route_state.cbpointer = NULL;
            ocall_routing_round_complete(cbpointer, ROUND_1A);
        }
    } else {
        // Non-routing nodes have no round 1a work to do
        route_state.step = ROUTE_ROUND_1A;
        route_state.round1a.completed_prev_round = true;
        ocall_routing_round_complete(cbpointer, ROUND_1A);
    }
}
  928. static void round1b_processing(void *cbpointer) {
  929. uint8_t my_roles = g_teems_config.roles[g_teems_config.my_node_num];
  930. nodenum_t my_node_num = g_teems_config.my_node_num;
  931. uint16_t prev_nodes = g_teems_config.weights[my_node_num].startweight;
  932. nodenum_t num_routing_nodes = g_teems_config.num_routing_nodes;
  933. MsgBuffer &round1a = route_state.round1a;
  934. MsgBuffer &round1a_sorted = route_state.round1a_sorted;
  935. if (my_roles & ROLE_ROUTING) {
  936. route_state.cbpointer = cbpointer;
  937. pthread_mutex_lock(&round1a.mutex);
  938. // Ensure there are no pending messages currently being inserted
  939. // into the buffer
  940. while (round1a.reserved != round1a.inserted) {
  941. pthread_mutex_unlock(&round1a.mutex);
  942. pthread_mutex_lock(&round1a.mutex);
  943. }
  944. pthread_mutex_lock(&round1a_sorted.mutex);
  945. #ifdef TRACE_ROUTING
  946. show_messages("Start of round 1b", round1a.buf, round1a.inserted);
  947. #endif
  948. #ifdef PROFILE_ROUTING
  949. uint32_t inserted = round1a.inserted;
  950. unsigned long start_round1b =
  951. printf_with_rtclock("begin round1b processing (%u)\n", inserted);
  952. #endif
  953. // Sort the messages we've received
  954. // Sort received messages by increasing user ID and
  955. // priority. Larger priority number indicates higher priority.
  956. if (inserted > 0) {
  957. // copy items in sorted order into round1a_sorted
  958. #ifdef PROFILE_ROUTING
  959. unsigned long start_sort =
  960. printf_with_rtclock("begin oblivious sort (%u,%u)\n", inserted,
  961. route_state.max_round1a_msgs);
  962. #endif
  963. route_state.round1a_sorted.reserved = round1a.inserted;
  964. sort_mtobliv<UidPriorityKey>(g_teems_config.nthreads, round1a.buf,
  965. g_teems_config.msg_size, round1a.inserted,
  966. route_state.max_round1a_msgs,
  967. route_state.round1a_sorted.buf);
  968. route_state.round1a_sorted.inserted = round1a.inserted;
  969. #ifdef PROFILE_ROUTING
  970. printf_with_rtclock_diff(start_sort, "end oblivious sort (%u,%u)\n",
  971. inserted, route_state.max_round1a_msgs);
  972. #endif
  973. #ifdef TRACE_ROUTING
  974. show_messages("In round 1b (sorted)", round1a_sorted.buf,
  975. round1a_sorted.inserted);
  976. #endif
  977. uint32_t num_sent = send_round1b_msgs(round1a_sorted.buf,
  978. round1a_sorted.inserted);
  979. // Remove the sent messages from the buffer
  980. memmove(round1a_sorted.buf,
  981. round1a_sorted.buf + num_sent * g_teems_config.msg_size,
  982. (round1a_sorted.inserted - num_sent) * g_teems_config.msg_size);
  983. round1a_sorted.inserted -= num_sent;
  984. round1a_sorted.reserved -= num_sent;
  985. #ifdef TRACE_ROUTING
  986. show_messages("In round 1b (after sending initial block)",
  987. round1a_sorted.buf, round1a_sorted.inserted);
  988. #endif
  989. } else {
  990. send_round1b_msgs(NULL, 0);
  991. }
  992. #ifdef PROFILE_ROUTING
  993. printf_with_rtclock_diff(start_round1b, "end round1b processing (%u)\n",
  994. inserted);
  995. #endif
  996. pthread_mutex_unlock(&round1a_sorted.mutex);
  997. pthread_mutex_unlock(&round1a.mutex);
  998. MsgBuffer &round1b_next = route_state.round1b_next;
  999. pthread_mutex_lock(&round1b_next.mutex);
  1000. round1b_next.completed_prev_round = true;
  1001. nodenum_t nodes_received = round1b_next.nodes_received;
  1002. pthread_mutex_unlock(&round1b_next.mutex);
  1003. nodenum_t adjacent_nodes;
  1004. if (num_routing_nodes == 1) {
  1005. adjacent_nodes = 0;
  1006. } else if (prev_nodes == num_routing_nodes-1) {
  1007. adjacent_nodes = 0;
  1008. } else {
  1009. adjacent_nodes = 1;
  1010. }
  1011. if (nodes_received == adjacent_nodes) {
  1012. route_state.step = ROUTE_ROUND_1B;
  1013. route_state.cbpointer = NULL;
  1014. ocall_routing_round_complete(cbpointer, ROUND_1B);
  1015. }
  1016. } else {
  1017. route_state.step = ROUTE_ROUND_1B;
  1018. route_state.round1b_next.completed_prev_round = true;
  1019. ocall_routing_round_complete(cbpointer, ROUND_1B);
  1020. }
  1021. }
  1022. static void copy_msgs(uint8_t *dst, uint32_t start_msg, uint32_t num_copy, const uint8_t *src,
  1023. const UidPriorityKey *indices)
  1024. {
  1025. uint16_t msg_size = g_teems_config.msg_size;
  1026. const UidPriorityKey *idxp = indices + start_msg;
  1027. uint8_t *buf = dst;
  1028. for (uint32_t i=0; i<num_copy; i++) {
  1029. memmove(buf, src + idxp[i].index()*msg_size, msg_size);
  1030. buf += msg_size;
  1031. }
  1032. }
// Round 1c (ID-channel routing): merge the messages received in round
// 1b from the adjacent routing node with our remaining sorted round 1a
// messages, obliviously enforce the per-receiver id_in limit by
// turning excess messages into padding, compact the padding away, and
// distribute the surviving messages round-robin for round 2.
static void round1c_processing(void *cbpointer) {
    uint8_t my_roles = g_teems_config.roles[g_teems_config.my_node_num];
    nodenum_t my_node_num = g_teems_config.my_node_num;
    nodenum_t num_routing_nodes = g_teems_config.num_routing_nodes;
    // Our position among the routing nodes (unweighted setting)
    uint16_t prev_nodes = g_teems_config.weights[my_node_num].startweight;
    uint16_t msg_size = g_teems_config.msg_size;
    uint32_t max_round1b_msgs_to_adj_rtr = route_state.max_round1b_msgs_to_adj_rtr;
    uint32_t max_round1a_msgs = route_state.max_round1a_msgs;
    MsgBuffer &round1a = route_state.round1a;
    MsgBuffer &round1a_sorted = route_state.round1a_sorted;
    MsgBuffer &round1b_next = route_state.round1b_next;
    if (my_roles & ROLE_ROUTING) {
        route_state.cbpointer = cbpointer;
        pthread_mutex_lock(&round1b_next.mutex);
        // Ensure there are no pending messages currently being inserted
        // into the round 1b buffer
        while (round1b_next.reserved != round1b_next.inserted) {
            pthread_mutex_unlock(&round1b_next.mutex);
            pthread_mutex_lock(&round1b_next.mutex);
        }
        uint32_t next_receiver_id = 0xffffffff;
        uint8_t next_rid_count = 0;
        // Extract the trailing 5 bytes if we received a round1b message:
        // 4 bytes of the next (unsent) receiver id and 1 byte counting
        // leading unsent messages with that id (see send_round1b_msgs).
        // The 5 bytes occupy the final "message" slot of the buffer.
        if (round1b_next.inserted >= 1) {
            next_receiver_id = *(uint32_t *)(round1b_next.buf +
                (round1b_next.inserted-1)*msg_size);
            next_rid_count = *(round1b_next.buf +
                (round1b_next.inserted-1)*msg_size + 4);
            round1b_next.inserted -= 1;
            round1b_next.reserved -= 1;
        }
        pthread_mutex_lock(&round1a.mutex);
        pthread_mutex_lock(&round1a_sorted.mutex);
#ifdef TRACE_ROUTING
        show_messages("Start of round 1c", round1a_sorted.buf,
            round1a_sorted.inserted);
#endif
#ifdef PROFILE_ROUTING
        unsigned long start_round1c = printf_with_rtclock("begin round1c processing (%u)\n", round1a.inserted);
#endif
        // sort round1b_next msgs with final msgs in round1a_sorted
        if ((prev_nodes < num_routing_nodes-1) && (round1b_next.inserted > 0)) {
            // Copy final msgs in round1a_sorted to round1b_next buffer for sorting
            // Note that all inserted values and buffer sizes are non-secret
            uint32_t num_final_round1a = std::min(
                round1a_sorted.inserted, max_round1b_msgs_to_adj_rtr);
            uint32_t round1a_msg_start =
                round1a_sorted.inserted-num_final_round1a;
            uint32_t num_round1b_next = round1b_next.inserted;
            round1b_next.reserved += num_final_round1a;
            memmove(round1b_next.buf+num_round1b_next*msg_size,
                round1a_sorted.buf + round1a_msg_start*msg_size,
                num_final_round1a*msg_size);
            round1b_next.inserted += num_final_round1a;
            round1a_sorted.reserved -= num_final_round1a;
            round1a_sorted.inserted -= num_final_round1a;
#ifdef TRACE_ROUTING
            show_messages("In round 1c (after setting aside)",
                round1a_sorted.buf, round1a_sorted.inserted);
            show_messages("In round 1c (the aside)", round1b_next.buf,
                round1b_next.inserted);
#endif
            // sort and append to round1a msgs
            uint32_t num_copy = round1b_next.inserted;
#ifdef PROFILE_ROUTING
            unsigned long start_sort =
                printf_with_rtclock("begin round1b_next oblivious sort (%u,%u)\n",
                    num_copy, 2*max_round1b_msgs_to_adj_rtr);
#endif
            round1a_sorted.reserved += num_copy;
            // The sort callback appends the merged block, in sorted
            // order, to the tail of round1a_sorted via copy_msgs.
            sort_mtobliv<UidPriorityKey>(g_teems_config.nthreads,
                round1b_next.buf, msg_size, num_copy,
                2*max_round1b_msgs_to_adj_rtr,
                [&](const uint8_t *src, const UidPriorityKey *indices, uint32_t Nr) {
                    return copy_msgs(round1a_sorted.buf +
                        round1a_sorted.inserted * msg_size, 0, num_copy,
                        src, indices);
                }
            );
            round1a_sorted.inserted += num_copy;
#ifdef PROFILE_ROUTING
            printf_with_rtclock_diff(start_sort,
                "end round1b_next oblivious sort (%u,%u)\n",
                num_copy, 2*max_round1b_msgs_to_adj_rtr);
#endif
        }
#ifdef TRACE_ROUTING
        // NOTE(review): the %lu specifiers below are paired with what
        // appear to be uint32_t fields; %u (or PRIu32) would be the
        // matching specifier — confirm the MsgBuffer field types.
        printf("round1a_sorted.inserted = %lu, reserved = %lu, bufsize = %lu\n",
            round1a_sorted.inserted, round1a_sorted.reserved,
            round1a_sorted.bufsize);
        show_messages("In round 1c before padding pass", round1a_sorted.buf,
            round1a_sorted.inserted);
#endif
        // The round1a_sorted buffer is now sorted by receiver uid and
        // priority. Going from the end of the buffer to the beginning
        // (so as to encounter and keep the highest-priority messages
        // for any given receiver first), obliviously turn any messages
        // over the limit of id_in for any given receiver into padding.
        // Also keep track of which messages are not padding for use in
        // later compaction.
        bool *is_not_padding = new bool[round1a_sorted.inserted];
        for (uint32_t i=0; i<round1a_sorted.inserted; ++i) {
            uint8_t *header = round1a_sorted.buf +
                msg_size * (round1a_sorted.inserted - 1 - i);
            uint32_t receiver_id = *(uint32_t*)header;
            uint32_t id_in = uint32_t(g_teems_config.m_id_in);
            // These are the possible cases and what we need to do in
            // each case, but we have to evaluate them obliviously
            // receiver_id != next_receiver_id:
            //     next_receiver_id = receiver_id
            //     next_rid_count = 1
            //     become_padding = 0
            // receiver_id == next_receiver_id && next_rid_count < id_in:
            //     next_receiver_id = receiver_id
            //     next_rid_count = next_rid_count + 1
            //     become_padding = 0
            // receiver_id == next_receiver_id && next_rid_count >= id_in:
            //     next_receiver_id = receiver_id
            //     next_rid_count = next_rid_count
            //     become_padding = 1
            bool same_receiver_id = (receiver_id == next_receiver_id);
            // If same_receiver_id is 0, reset next_rid_count to 0.
            // If same_receiver_id is 1, don't change next_rid_count.
            // This method (AND with -same_receiver_id) is more likely
            // to be constant time than multiplying by same_receiver_id.
            next_rid_count &= (-(uint32_t(same_receiver_id)));
            bool become_padding = (next_rid_count >= id_in);
            next_rid_count += !become_padding;
            next_receiver_id = receiver_id;
            // Obliviously change the receiver id to 0xffffffff
            // (padding) if become-padding is 1
            receiver_id |= (-(uint32_t(become_padding)));
            *(uint32_t*)header = receiver_id;
            is_not_padding[round1a_sorted.inserted - 1 - i] =
                (receiver_id != 0xffffffff);
        }
#ifdef TRACE_ROUTING
        show_messages("In round 1c after padding pass", round1a_sorted.buf,
            round1a_sorted.inserted);
#endif
        // Oblivious compaction to move the padding messages to the end,
        // preserving the (already-sorted) order of the non-padding
        // messages
        TightCompact_parallel<OSWAP_16X>(
            (unsigned char *) round1a_sorted.buf,
            round1a_sorted.inserted, msg_size, is_not_padding,
            g_teems_config.nthreads);
        delete[] is_not_padding;
#ifdef TRACE_ROUTING
        show_messages("In round 1c after compaction", round1a_sorted.buf,
            round1a_sorted.inserted);
#endif
        // Distribute the surviving messages round-robin to the routing
        // nodes (indices = NULL: send in current, already-sorted order)
        send_round_robin_msgs<UidPriorityKey>(route_state.round1c,
            round1a_sorted.buf, NULL, round1a_sorted.inserted);
#ifdef PROFILE_ROUTING
        printf_with_rtclock_diff(start_round1c, "end round1c processing (%u)\n",
            round1a.inserted);
#endif
        round1a.reset();
        round1a_sorted.reset();
        round1b_next.reset();
        pthread_mutex_unlock(&round1a_sorted.mutex);
        pthread_mutex_unlock(&round1a.mutex);
        pthread_mutex_unlock(&round1b_next.mutex);
        // If every routing node's round 1c batch has already arrived,
        // the round is complete.
        MsgBuffer &round1c = route_state.round1c;
        pthread_mutex_lock(&round1c.mutex);
        round1c.completed_prev_round = true;
        nodenum_t nodes_received = round1c.nodes_received;
        pthread_mutex_unlock(&round1c.mutex);
        if (nodes_received == num_routing_nodes) {
            route_state.step = ROUTE_ROUND_1C;
            route_state.cbpointer = NULL;
            ocall_routing_round_complete(cbpointer, ROUND_1C);
        }
    } else {
        // Non-routing nodes have no round 1c work to do
        route_state.step = ROUTE_ROUND_1C;
        route_state.round1c.completed_prev_round = true;
        ocall_routing_round_complete(cbpointer, ROUND_1C);
    }
}
// Process messages in round 2: obliviously tally the previous round's
// messages per storage node, pad each storage node's batch up to a
// uniform msgs_per_stg, obliviously shuffle, and then (non-obliviously)
// deliver each storage node its batch via send_round2_msgs.
static void round2_processing(uint8_t my_roles, void *cbpointer, MsgBuffer &prevround) {
    if (my_roles & ROLE_ROUTING) {
        route_state.cbpointer = cbpointer;
        pthread_mutex_lock(&prevround.mutex);
        // Ensure there are no pending messages currently being inserted
        // into the buffer
        while (prevround.reserved != prevround.inserted) {
            pthread_mutex_unlock(&prevround.mutex);
            pthread_mutex_lock(&prevround.mutex);
        }
        // If the _total_ number of messages we received in round 1
        // is less than the max number of messages we could send to
        // _each_ storage node, then cap the number of messages we
        // will send to each storage node to that number.
        uint32_t msgs_per_stg = route_state.max_msg_to_each_stg;
        if (prevround.inserted < msgs_per_stg) {
            msgs_per_stg = prevround.inserted;
        }
        // Note: at this point, it is required that each message in
        // the prevround buffer have a _valid_ storage node id field.
        // Obliviously tally the number of messages we received in
        // the previous round destined for each storage node
#ifdef TRACE_ROUTING
        show_messages("Start of round 2", prevround.buf, prevround.inserted);
#endif
#ifdef PROFILE_ROUTING
        unsigned long start_round2 = printf_with_rtclock("begin round2 processing (%u,%u)\n", prevround.inserted, prevround.bufsize);
        unsigned long start_tally = printf_with_rtclock("begin tally (%u)\n", prevround.inserted);
#endif
        uint16_t msg_size = g_teems_config.msg_size;
        nodenum_t num_storage_nodes = g_teems_config.num_storage_nodes;
        std::vector<uint32_t> tally = obliv_tally_stg(
            prevround.buf, msg_size, prevround.inserted, num_storage_nodes);
#ifdef PROFILE_ROUTING
        printf_with_rtclock_diff(start_tally, "end tally (%u)\n", prevround.inserted);
#endif
        // Note: tally contains private values! It's OK to
        // non-obliviously check for an error condition, though.
        // While we're at it, obliviously change the tally of
        // messages received to a tally of padding messages
        // required.
        uint32_t tot_messages = msgs_per_stg * num_storage_nodes;
        for (nodenum_t i=0; i<num_storage_nodes; ++i) {
            if (tally[i] > msgs_per_stg) {
                printf("Received too many messages for storage node "
                    "%u (%u > %u)\n", i, tally[i], msgs_per_stg);
                assert(tally[i] <= msgs_per_stg);
            }
            // tally[i] becomes the number of padding messages needed
            // to bring storage node i's batch up to msgs_per_stg
            tally[i] = msgs_per_stg - tally[i];
        }
        // Allocate extra padding messages (not yet destined for a
        // particular storage node); their header word 0xffffffff
        // marks them as global padding
        assert(prevround.inserted <= tot_messages &&
            tot_messages <= prevround.bufsize);
        prevround.reserved = tot_messages;
        for (uint32_t i=prevround.inserted; i<tot_messages; ++i) {
            uint8_t *header = prevround.buf + i*msg_size;
            *(uint32_t*)header = 0xffffffff;
        }
        // Obliviously convert global padding (receiver id 0xffffffff)
        // into padding for each storage node according to the (private)
        // padding tally.
#ifdef PROFILE_ROUTING
        unsigned long start_pad =
            printf_with_rtclock("begin pad (%u)\n", tot_messages);
#endif
        obliv_stg_padding(prevround.buf, msg_size, tally, tot_messages);
#ifdef PROFILE_ROUTING
        printf_with_rtclock_diff(start_pad, "end pad (%u)\n",
            tot_messages);
#endif
        prevround.inserted = tot_messages;
#ifdef TRACE_ROUTING
        show_messages("In round 2 after padding", prevround.buf, prevround.inserted);
#endif
        // Obliviously shuffle the messages
#ifdef PROFILE_ROUTING
        unsigned long start_shuffle = printf_with_rtclock("begin shuffle (%u,%u)\n", prevround.inserted, prevround.bufsize);
#endif
        uint32_t num_shuffled = shuffle_mtobliv(g_teems_config.nthreads,
            prevround.buf, msg_size, prevround.inserted, prevround.bufsize);
#ifdef PROFILE_ROUTING
        printf_with_rtclock_diff(start_shuffle, "end shuffle (%u,%u)\n", prevround.inserted, prevround.bufsize);
        printf_with_rtclock_diff(start_round2, "end round2 processing (%u,%u)\n", prevround.inserted, prevround.bufsize);
#endif
        // Now we can handle the messages non-obliviously, since we
        // know there will be exactly msgs_per_stg messages to each
        // storage node, and the oblivious shuffle broke the
        // connection between where each message came from and where
        // it's going.
        send_round2_msgs(num_shuffled, msgs_per_stg, prevround);
        prevround.reset();
        pthread_mutex_unlock(&prevround.mutex);
    }
    if (my_roles & ROLE_STORAGE) {
        route_state.cbpointer = cbpointer;
        MsgBuffer &round2 = route_state.round2;
        pthread_mutex_lock(&round2.mutex);
        round2.completed_prev_round = true;
        nodenum_t nodes_received = round2.nodes_received;
        pthread_mutex_unlock(&round2.mutex);
        // Complete the round once every routing node has delivered
        if (nodes_received == g_teems_config.num_routing_nodes) {
            route_state.step = ROUTE_ROUND_2;
            route_state.cbpointer = NULL;
            ocall_routing_round_complete(cbpointer, 2);
        }
    } else {
        // Neither routing-completion gate applies; report round 2 done
        route_state.step = ROUTE_ROUND_2;
        route_state.round2.completed_prev_round = true;
        ocall_routing_round_complete(cbpointer, 2);
    }
}
  1326. // Perform the next round of routing. The callback pointer will be
  1327. // passed to ocall_routing_round_complete when the round is complete.
  1328. void ecall_routing_proceed(void *cbpointer)
  1329. {
  1330. uint8_t my_roles = g_teems_config.roles[g_teems_config.my_node_num];
  1331. if (route_state.step == ROUTE_NOT_STARTED) {
  1332. if (my_roles & ROLE_INGESTION) {
  1333. ingestion_epoch++;
  1334. route_state.cbpointer = cbpointer;
  1335. MsgBuffer &ingbuf = route_state.ingbuf;
  1336. pthread_mutex_lock(&ingbuf.mutex);
  1337. // Ensure there are no pending messages currently being inserted
  1338. // into the buffer
  1339. while (ingbuf.reserved != ingbuf.inserted) {
  1340. pthread_mutex_unlock(&ingbuf.mutex);
  1341. pthread_mutex_lock(&ingbuf.mutex);
  1342. }
  1343. // Sort the messages we've received
  1344. #ifdef TRACE_ROUTING
  1345. show_messages("Start of round 1", ingbuf.buf, ingbuf.inserted);
  1346. #endif
  1347. #ifdef PROFILE_ROUTING
  1348. uint32_t inserted = ingbuf.inserted;
  1349. unsigned long start_round1 =
  1350. printf_with_rtclock("begin round1 processing (%u)\n", inserted);
  1351. unsigned long start_sort =
  1352. printf_with_rtclock("begin oblivious sort (%u,%u)\n", inserted,
  1353. route_state.tot_msg_per_ing);
  1354. #endif
  1355. if (g_teems_config.token_channel) {
  1356. sort_mtobliv<UidKey>(g_teems_config.nthreads, ingbuf.buf,
  1357. g_teems_config.msg_size, ingbuf.inserted,
  1358. route_state.tot_msg_per_ing,
  1359. [&](const uint8_t *msgs, const UidKey *indices, uint32_t N) {
  1360. send_round_robin_msgs<UidKey>(route_state.round1,
  1361. msgs, indices, N);
  1362. });
  1363. } else {
  1364. // Sort received messages by increasing user ID and
  1365. // priority. Larger priority number indicates higher priority.
  1366. sort_mtobliv<UidPriorityKey>(g_teems_config.nthreads, ingbuf.buf,
  1367. g_teems_config.msg_size, ingbuf.inserted,
  1368. route_state.tot_msg_per_ing,
  1369. [&](const uint8_t *msgs,
  1370. const UidPriorityKey *indices, uint32_t N) {
  1371. send_round_robin_msgs<UidPriorityKey>(route_state.round1,
  1372. msgs, indices, N);
  1373. });
  1374. }
  1375. #ifdef PROFILE_ROUTING
  1376. printf_with_rtclock_diff(start_sort, "end oblivious sort (%u,%u)\n",
  1377. inserted, route_state.tot_msg_per_ing);
  1378. printf_with_rtclock_diff(start_round1, "end round1 processing (%u)\n",
  1379. inserted);
  1380. #endif
  1381. ingbuf.reset();
  1382. pthread_mutex_unlock(&ingbuf.mutex);
  1383. }
  1384. if (my_roles & ROLE_ROUTING) {
  1385. MsgBuffer &round1 = route_state.round1;
  1386. pthread_mutex_lock(&round1.mutex);
  1387. round1.completed_prev_round = true;
  1388. nodenum_t nodes_received = round1.nodes_received;
  1389. pthread_mutex_unlock(&round1.mutex);
  1390. if (nodes_received == g_teems_config.num_ingestion_nodes) {
  1391. route_state.step = ROUTE_ROUND_1;
  1392. route_state.cbpointer = NULL;
  1393. ocall_routing_round_complete(cbpointer, 1);
  1394. }
  1395. } else {
  1396. route_state.step = ROUTE_ROUND_1;
  1397. route_state.round1.completed_prev_round = true;
  1398. ocall_routing_round_complete(cbpointer, 1);
  1399. }
  1400. } else if (route_state.step == ROUTE_ROUND_1) {
  1401. if (g_teems_config.token_channel) { // Token channel routing next round
  1402. round2_processing(my_roles, cbpointer, route_state.round1);
  1403. } else { // ID channel routing next round
  1404. round1a_processing(cbpointer);
  1405. }
  1406. } else if (route_state.step == ROUTE_ROUND_1A) {
  1407. round1b_processing(cbpointer);
  1408. } else if (route_state.step == ROUTE_ROUND_1B) {
  1409. round1c_processing(cbpointer);
  1410. } else if (route_state.step == ROUTE_ROUND_1C) {
  1411. round2_processing(my_roles, cbpointer, route_state.round1c);
  1412. } else if (route_state.step == ROUTE_ROUND_2) {
  1413. if (my_roles & ROLE_STORAGE) {
  1414. MsgBuffer &round2 = route_state.round2;
  1415. pthread_mutex_lock(&round2.mutex);
  1416. // Ensure there are no pending messages currently being inserted
  1417. // into the buffer
  1418. while (round2.reserved != round2.inserted) {
  1419. pthread_mutex_unlock(&round2.mutex);
  1420. pthread_mutex_lock(&round2.mutex);
  1421. }
  1422. unsigned long start = printf_with_rtclock("begin storage processing (%u)\n", round2.inserted);
  1423. storage_received(round2);
  1424. printf_with_rtclock_diff(start, "end storage processing (%u)\n", round2.inserted);
  1425. // We're done
  1426. route_state.step = ROUTE_NOT_STARTED;
  1427. ocall_routing_round_complete(cbpointer, 0);
  1428. } else {
  1429. // We're done
  1430. route_state.step = ROUTE_NOT_STARTED;
  1431. ocall_routing_round_complete(cbpointer, 0);
  1432. }
  1433. }
  1434. #ifdef TRACK_HEAP_USAGE
  1435. printf("ecall_routing_proceed end heap %u\n", g_peak_heap_used);
  1436. #endif
  1437. }