#include "Enclave_t.h"

#include "config.hpp"
#include "utils.hpp"
#include "sort.hpp"
#include "comms.hpp"
#include "obliv.hpp"
#include "storage.hpp"
#include "route.hpp"

#define PROFILE_ROUTING

RouteState route_state;

// Computes ceil(x/y) where x and y are integers, x>=0, y>0.
#define CEILDIV(x,y) (((x)+(y)-1)/(y))
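// For example, CEILDIV(7,3) == 3, whereas plain integer division
// gives 7/3 == 2.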

// Call this near the end of ecall_config_load, but before
// comms_init_nodestate. Returns true on success, false on failure.
bool route_init()
{
    // Compute the maximum number of messages we could receive by direct
    // ingestion. Each ingestion node will have at most
    // ceil(user_count/num_ingestion_nodes) users, and each user will
    // send at most m_priv_out (private routing) or m_pub_out (public
    // routing) messages.
    uint32_t users_per_ing = CEILDIV(g_teems_config.user_count,
        g_teems_config.num_ingestion_nodes);
    uint32_t tot_msg_per_ing;
    if (g_teems_config.private_routing) {
        tot_msg_per_ing = users_per_ing * g_teems_config.m_priv_out;
    } else {
        tot_msg_per_ing = users_per_ing * g_teems_config.m_pub_out;
    }
    // Compute the maximum number of messages we could receive in round 1.
    // In private routing, each ingestion node will send us a
    // my_weight/tot_weight fraction of the messages they hold.
    uint32_t max_msg_from_each_ing;
    max_msg_from_each_ing = CEILDIV(tot_msg_per_ing, g_teems_config.tot_weight) *
        g_teems_config.my_weight;
    // And the maximum number we can receive in total is that times the
    // number of ingestion nodes
    uint32_t max_round1_msgs = max_msg_from_each_ing *
        g_teems_config.num_ingestion_nodes;
    // Compute the maximum number of messages we could send in round 2.
    // Each storage node has at most this many users
    uint32_t users_per_stg = CEILDIV(g_teems_config.user_count,
        g_teems_config.num_storage_nodes);
    // And so can receive at most this many messages
    uint32_t tot_msg_per_stg;
    if (g_teems_config.private_routing) {
        tot_msg_per_stg = users_per_stg * g_teems_config.m_priv_in;
    } else {
        tot_msg_per_stg = users_per_stg * g_teems_config.m_pub_in;
    }
    // Of which at most this many will come from us
    uint32_t max_msg_to_each_stg;
    max_msg_to_each_stg = CEILDIV(tot_msg_per_stg, g_teems_config.tot_weight) *
        g_teems_config.my_weight;
    // But we can't send more messages to each storage server than we
    // could receive in total
    if (max_msg_to_each_stg > max_round1_msgs) {
        max_msg_to_each_stg = max_round1_msgs;
    }
    // And the max total number of outgoing messages in round 2 is then
    uint32_t max_round2_msgs = max_msg_to_each_stg *
        g_teems_config.num_storage_nodes;
    // In case we have a weird configuration where users can send more
    // messages per epoch than they can receive, ensure the round 2
    // buffer is large enough to hold the incoming messages as well
    if (max_round2_msgs < max_round1_msgs) {
        max_round2_msgs = max_round1_msgs;
    }
    // The max number of messages that can arrive at a storage server
    uint32_t max_stg_msgs;
    max_stg_msgs = tot_msg_per_stg + g_teems_config.tot_weight;
    // Calculate the public-routing buffer sizes; weights are not used
    // in public routing
    uint32_t max_round1b_msgs_to_adj_rtr =
        (g_teems_config.num_routing_nodes-1)*(g_teems_config.num_routing_nodes-1);
    // Ensure the columnroute constraint that the column height is
    // >= 2*(num_routing_nodes-1)^2
    uint32_t max_round1a_msgs = std::max(max_round1_msgs,
        2*max_round1b_msgs_to_adj_rtr);
    uint32_t max_round1c_msgs = max_round1a_msgs;
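    // A purely illustrative sizing example (not from any real config):
    // user_count=1000000, num_ingestion_nodes=num_storage_nodes=10,
    // m_priv_out=m_priv_in=5, tot_weight=8, my_weight=2 give
    // users_per_ing=100000, tot_msg_per_ing=500000,
    // max_msg_from_each_ing=ceil(500000/8)*2=125000,
    // max_round1_msgs=1250000, tot_msg_per_stg=500000,
    // max_msg_to_each_stg=125000, max_round2_msgs=1250000, and
    // max_stg_msgs=500008.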
    /*
    printf("users_per_ing=%u, tot_msg_per_ing=%u, max_msg_from_each_ing=%u, max_round1_msgs=%u, users_per_stg=%u, tot_msg_per_stg=%u, max_msg_to_each_stg=%u, max_round2_msgs=%u, max_stg_msgs=%u\n", users_per_ing, tot_msg_per_ing, max_msg_from_each_ing, max_round1_msgs, users_per_stg, tot_msg_per_stg, max_msg_to_each_stg, max_round2_msgs, max_stg_msgs);
    */
    // Create the route state
    uint8_t my_roles = g_teems_config.roles[g_teems_config.my_node_num];
    try {
        if (my_roles & ROLE_INGESTION) {
            route_state.ingbuf.alloc(tot_msg_per_ing);
        }
        if (my_roles & ROLE_ROUTING) {
            route_state.round1.alloc(max_round2_msgs);
            if (!g_teems_config.private_routing) {
                route_state.round1a.alloc(max_round1a_msgs);
                route_state.round1a_sorted.alloc(max_round1a_msgs);
                // Double the round 1b buffers so that they can be
                // sorted together with some round 1a messages
                route_state.round1b_prev.alloc(2*max_round1b_msgs_to_adj_rtr);
                route_state.round1b_next.alloc(2*max_round1b_msgs_to_adj_rtr);
                route_state.round1c.alloc(max_round1c_msgs);
            }
        }
        if (my_roles & ROLE_STORAGE) {
            route_state.round2.alloc(max_stg_msgs);
            if (!storage_init(users_per_stg, max_stg_msgs)) {
                return false;
            }
        }
    } catch (std::bad_alloc&) {
        printf("Memory allocation failed in route_init\n");
        return false;
    }
    route_state.step = ROUTE_NOT_STARTED;
    route_state.tot_msg_per_ing = tot_msg_per_ing;
    route_state.max_round1_msgs = max_round1_msgs;
    route_state.max_round1a_msgs = max_round1a_msgs;
    route_state.max_round1b_msgs_to_adj_rtr = max_round1b_msgs_to_adj_rtr;
    route_state.max_round1c_msgs = max_round1c_msgs;
    route_state.max_msg_to_each_stg = max_msg_to_each_stg;
    route_state.max_round2_msgs = max_round2_msgs;
    route_state.max_stg_msgs = max_stg_msgs;
    route_state.cbpointer = NULL;
    threadid_t nthreads = g_teems_config.nthreads;
#ifdef PROFILE_ROUTING
    unsigned long start = printf_with_rtclock("begin precompute evalplans (%u,%hu) (%u,%hu)\n", tot_msg_per_ing, nthreads, max_round2_msgs, nthreads);
#endif
    if (my_roles & ROLE_INGESTION) {
        sort_precompute_evalplan(tot_msg_per_ing, nthreads);
    }
    if (my_roles & ROLE_ROUTING) {
        sort_precompute_evalplan(max_round2_msgs, nthreads);
        if (!g_teems_config.private_routing) {
            sort_precompute_evalplan(max_round1a_msgs, nthreads);
            sort_precompute_evalplan(2*max_round1b_msgs_to_adj_rtr, nthreads);
        }
    }
    if (my_roles & ROLE_STORAGE) {
        sort_precompute_evalplan(max_stg_msgs, nthreads);
    }
#ifdef PROFILE_ROUTING
    printf_with_rtclock_diff(start, "end precompute evalplans\n");
#endif
    return true;
}

// Call this when shutting the system down, to deallocate the routing state
void route_close() {
    uint8_t my_roles = g_teems_config.roles[g_teems_config.my_node_num];
    if (my_roles & ROLE_STORAGE) {
        storage_close();
    }
}

// Precompute the WaksmanNetworks needed for the sorts. If you pass -1,
// it will return the number of different sizes it needs to regenerate.
// If you pass a value in [0,sizes-1], it will compute one WaksmanNetwork
// with that size index and return the number of available
// WaksmanNetworks of that size. If you pass anything else, it will
// rebuild the full list of sizes it needs and return the length of
// that list.

// The list of sizes that need refilling, updated when you pass -1
static std::vector<uint32_t> used_sizes;

size_t ecall_precompute_sort(int sizeidx)
{
    size_t ret = 0;
    if (sizeidx == -1) {
        used_sizes = sort_get_used();
        ret = used_sizes.size();
    } else if (sizeidx >= 0 && size_t(sizeidx) < used_sizes.size()) {
        uint32_t size = used_sizes[sizeidx];
#ifdef PROFILE_ROUTING
        unsigned long start = printf_with_rtclock("begin precompute WaksmanNetwork (%u)\n", size);
#endif
        ret = sort_precompute(size);
#ifdef PROFILE_ROUTING
        printf_with_rtclock_diff(start, "end precompute WaksmanNetwork (%u)\n", size);
#endif
    } else {
        uint8_t my_roles = g_teems_config.roles[g_teems_config.my_node_num];
        if (my_roles & ROLE_INGESTION) {
            used_sizes.push_back(route_state.tot_msg_per_ing);
        }
        if (my_roles & ROLE_ROUTING) {
            used_sizes.push_back(route_state.max_round2_msgs);
            if (!g_teems_config.private_routing) {
                used_sizes.push_back(route_state.max_round1a_msgs);
                used_sizes.push_back(2*route_state.max_round1b_msgs_to_adj_rtr);
                used_sizes.push_back(2*route_state.max_round1b_msgs_to_adj_rtr);
                used_sizes.push_back(route_state.max_round1c_msgs);
            }
        }
        if (my_roles & ROLE_STORAGE) {
            used_sizes.push_back(route_state.max_stg_msgs);
            if (!g_teems_config.private_routing) {
                used_sizes.push_back(route_state.max_stg_msgs);
            }
        }
        ret = used_sizes.size();
    }
    return ret;
}
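
// A minimal untrusted-side usage sketch for the above (hypothetical
// app code; an EDL-generated proxy would look something like
// ecall_precompute_sort(eid, &ret, sizeidx)):
//
//   size_t nsizes = 0;
//   ecall_precompute_sort(eid, &nsizes, -1);   // how many sizes to refill?
//   for (int i = 0; i < (int)nsizes; ++i) {
//       size_t avail = 0;
//       ecall_precompute_sort(eid, &avail, i); // precompute one network
//   }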

static uint8_t* msgbuffer_get_buf(MsgBuffer &msgbuf,
    NodeCommState &, uint32_t tot_enc_chunk_size)
{
    uint16_t msg_size = g_teems_config.msg_size;
    // Chunks will be encrypted and will have a MAC tag attached, which
    // does not correspond to plaintext bytes, so we can trim it.
    // The minimum number of chunks needed to transmit this message
    uint32_t min_num_chunks =
        (tot_enc_chunk_size + (FRAME_SIZE-1)) / FRAME_SIZE;
    // The number of plaintext bytes this message could contain
    uint32_t plaintext_bytes = tot_enc_chunk_size -
        SGX_AESGCM_MAC_SIZE * min_num_chunks;
    assert((plaintext_bytes % uint32_t(msg_size)) == 0);
    uint32_t num_msgs = plaintext_bytes/uint32_t(msg_size);
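    // Worked example with illustrative values (FRAME_SIZE = 4096,
    // SGX_AESGCM_MAC_SIZE = 16, msg_size = 240 are assumptions here,
    // not the real configuration): tot_enc_chunk_size = 8192 gives
    // min_num_chunks = 2, plaintext_bytes = 8192 - 16*2 = 8160, and
    // num_msgs = 8160/240 = 34.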
    pthread_mutex_lock(&msgbuf.mutex);
    uint32_t start = msgbuf.reserved;
    if (start + num_msgs > msgbuf.bufsize) {
        pthread_mutex_unlock(&msgbuf.mutex);
        printf("Max %u messages exceeded\n", msgbuf.bufsize);
        return NULL;
    }
    msgbuf.reserved += num_msgs;
    pthread_mutex_unlock(&msgbuf.mutex);
    return msgbuf.buf + start * msg_size;
}

static void round1a_received(NodeCommState &nodest,
    uint8_t *data, uint32_t plaintext_len, uint32_t);
static void round1b_prev_received(NodeCommState &nodest,
    uint8_t *data, uint32_t plaintext_len, uint32_t);
static void round1b_next_received(NodeCommState &nodest,
    uint8_t *data, uint32_t plaintext_len, uint32_t);
static void round1c_received(NodeCommState &nodest, uint8_t *data,
    uint32_t plaintext_len, uint32_t);
static void round2_received(NodeCommState &nodest,
    uint8_t *data, uint32_t plaintext_len, uint32_t);

// A round 1 message was received by a routing node from an ingestion
// node; we put it into the round 2 buffer for processing in round 2
static void round1_received(NodeCommState &nodest,
    uint8_t *data, uint32_t plaintext_len, uint32_t)
{
    uint16_t msg_size = g_teems_config.msg_size;
    assert((plaintext_len % uint32_t(msg_size)) == 0);
    uint32_t num_msgs = plaintext_len / uint32_t(msg_size);
    uint8_t our_roles = g_teems_config.roles[g_teems_config.my_node_num];
    uint8_t their_roles = g_teems_config.roles[nodest.node_num];
    pthread_mutex_lock(&route_state.round1.mutex);
    route_state.round1.inserted += num_msgs;
    route_state.round1.nodes_received += 1;
    nodenum_t nodes_received = route_state.round1.nodes_received;
    bool completed_prev_round = route_state.round1.completed_prev_round;
    pthread_mutex_unlock(&route_state.round1.mutex);
    // What is the next message we expect from this node?
    if (g_teems_config.private_routing) {
        if ((our_roles & ROLE_STORAGE) && (their_roles & ROLE_ROUTING)) {
            nodest.in_msg_get_buf = [&](NodeCommState &commst,
                    uint32_t tot_enc_chunk_size) {
                return msgbuffer_get_buf(route_state.round2, commst,
                    tot_enc_chunk_size);
            };
            nodest.in_msg_received = round2_received;
        }
        // Otherwise, it's just the next round 1 message, so don't change
        // the handlers.
    } else {
        if (their_roles & ROLE_ROUTING) {
            nodest.in_msg_get_buf = [&](NodeCommState &commst,
                    uint32_t tot_enc_chunk_size) {
                return msgbuffer_get_buf(route_state.round1a, commst,
                    tot_enc_chunk_size);
            };
            nodest.in_msg_received = round1a_received;
        }
        // Otherwise, it's just the next round 1 message, so don't change
        // the handlers.
    }
    if (nodes_received == g_teems_config.num_ingestion_nodes &&
            completed_prev_round) {
        route_state.step = ROUTE_ROUND_1;
        void *cbpointer = route_state.cbpointer;
        route_state.cbpointer = NULL;
        ocall_routing_round_complete(cbpointer, 1);
    }
}

// A round 1a message was received by a routing node
static void round1a_received(NodeCommState &nodest,
    uint8_t *data, uint32_t plaintext_len, uint32_t)
{
    uint16_t msg_size = g_teems_config.msg_size;
    assert((plaintext_len % uint32_t(msg_size)) == 0);
    uint32_t num_msgs = plaintext_len / uint32_t(msg_size);
    pthread_mutex_lock(&route_state.round1a.mutex);
    route_state.round1a.inserted += num_msgs;
    route_state.round1a.nodes_received += 1;
    nodenum_t nodes_received = route_state.round1a.nodes_received;
    bool completed_prev_round = route_state.round1a.completed_prev_round;
    pthread_mutex_unlock(&route_state.round1a.mutex);
    // Both we and the sender are routing nodes; we only expect a
    // message from the previous and next nodes (if they exist)
    nodenum_t my_node_num = g_teems_config.my_node_num;
    nodenum_t num_routing_nodes = g_teems_config.num_routing_nodes;
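    // Note: routing_nodes appears to be ordered with our own node
    // first, so the previous routing node sits at index
    // num_routing_nodes-1 and the next at index 1 (the same indexing
    // send_round1b_msgs uses below).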
    uint16_t prev_nodes = g_teems_config.weights[my_node_num].startweight;
    if ((prev_nodes > 0) &&
            (nodest.node_num == g_teems_config.routing_nodes[num_routing_nodes-1])) {
        // Node is the previous routing node
        nodest.in_msg_get_buf = [&](NodeCommState &commst,
                uint32_t tot_enc_chunk_size) {
            return msgbuffer_get_buf(route_state.round1b_prev, commst,
                tot_enc_chunk_size);
        };
        nodest.in_msg_received = round1b_prev_received;
    } else if ((prev_nodes < num_routing_nodes-1) &&
            (nodest.node_num == g_teems_config.routing_nodes[1])) {
        // Node is the next routing node
        nodest.in_msg_get_buf = [&](NodeCommState &commst,
                uint32_t tot_enc_chunk_size) {
            return msgbuffer_get_buf(route_state.round1b_next, commst,
                tot_enc_chunk_size);
        };
        nodest.in_msg_received = round1b_next_received;
    } else {
        // Other routing nodes will not send to this node until round 1c
        nodest.in_msg_get_buf = [&](NodeCommState &commst,
                uint32_t tot_enc_chunk_size) {
            return msgbuffer_get_buf(route_state.round1c, commst,
                tot_enc_chunk_size);
        };
        nodest.in_msg_received = round1c_received;
    }
    if (nodes_received == g_teems_config.num_routing_nodes &&
            completed_prev_round) {
        route_state.step = ROUTE_ROUND_1A;
        void *cbpointer = route_state.cbpointer;
        route_state.cbpointer = NULL;
        ocall_routing_round_complete(cbpointer, ROUND_1A);
    }
}

// A round 1b message was received from the previous routing node
static void round1b_prev_received(NodeCommState &nodest,
    uint8_t *data, uint32_t plaintext_len, uint32_t)
{
    uint16_t msg_size = g_teems_config.msg_size;
    assert((plaintext_len % uint32_t(msg_size)) == 0);
    uint32_t num_msgs = plaintext_len / uint32_t(msg_size);
    pthread_mutex_lock(&route_state.round1b_prev.mutex);
    route_state.round1b_prev.inserted += num_msgs;
    route_state.round1b_prev.nodes_received += 1;
    nodenum_t nodes_received = route_state.round1b_prev.nodes_received;
    bool completed_prev_round = route_state.round1b_prev.completed_prev_round;
    pthread_mutex_unlock(&route_state.round1b_prev.mutex);
    pthread_mutex_lock(&route_state.round1b_next.mutex);
    nodes_received += route_state.round1b_next.nodes_received;
    pthread_mutex_unlock(&route_state.round1b_next.mutex);
    nodest.in_msg_get_buf = [&](NodeCommState &commst,
            uint32_t tot_enc_chunk_size) {
        return msgbuffer_get_buf(route_state.round1c, commst,
            tot_enc_chunk_size);
    };
    nodest.in_msg_received = round1c_received;
    nodenum_t my_node_num = g_teems_config.my_node_num;
    uint16_t prev_nodes = g_teems_config.weights[my_node_num].startweight;
    nodenum_t num_routing_nodes = g_teems_config.num_routing_nodes;
    nodenum_t adjacent_nodes;
    if (num_routing_nodes == 1) {
        adjacent_nodes = 0;
    } else if ((prev_nodes == 0) || (prev_nodes == num_routing_nodes-1)) {
        adjacent_nodes = 1;
    } else {
        adjacent_nodes = 2;
    }
    if (nodes_received == adjacent_nodes && completed_prev_round) {
        route_state.step = ROUTE_ROUND_1B;
        void *cbpointer = route_state.cbpointer;
        route_state.cbpointer = NULL;
        ocall_routing_round_complete(cbpointer, ROUND_1B);
    }
}

// A round 1b message was received from the next routing node
static void round1b_next_received(NodeCommState &nodest,
    uint8_t *data, uint32_t plaintext_len, uint32_t)
{
    uint16_t msg_size = g_teems_config.msg_size;
    assert((plaintext_len % uint32_t(msg_size)) == 0);
    uint32_t num_msgs = plaintext_len / uint32_t(msg_size);
    pthread_mutex_lock(&route_state.round1b_next.mutex);
    route_state.round1b_next.inserted += num_msgs;
    route_state.round1b_next.nodes_received += 1;
    nodenum_t nodes_received = route_state.round1b_next.nodes_received;
    bool completed_prev_round = route_state.round1b_next.completed_prev_round;
    pthread_mutex_unlock(&route_state.round1b_next.mutex);
    pthread_mutex_lock(&route_state.round1b_prev.mutex);
    nodes_received += route_state.round1b_prev.nodes_received;
    pthread_mutex_unlock(&route_state.round1b_prev.mutex);
    nodest.in_msg_get_buf = [&](NodeCommState &commst,
            uint32_t tot_enc_chunk_size) {
        return msgbuffer_get_buf(route_state.round1c, commst,
            tot_enc_chunk_size);
    };
    nodest.in_msg_received = round1c_received;
    nodenum_t my_node_num = g_teems_config.my_node_num;
    uint16_t prev_nodes = g_teems_config.weights[my_node_num].startweight;
    nodenum_t num_routing_nodes = g_teems_config.num_routing_nodes;
    nodenum_t adjacent_nodes;
    if (num_routing_nodes == 1) {
        adjacent_nodes = 0;
    } else if ((prev_nodes == 0) || (prev_nodes == num_routing_nodes-1)) {
        adjacent_nodes = 1;
    } else {
        adjacent_nodes = 2;
    }
    if (nodes_received == adjacent_nodes && completed_prev_round) {
        route_state.step = ROUTE_ROUND_1B;
        void *cbpointer = route_state.cbpointer;
        route_state.cbpointer = NULL;
        ocall_routing_round_complete(cbpointer, ROUND_1B);
    }
}

// A message was received in round 1c
static void round1c_received(NodeCommState &nodest, uint8_t *data,
    uint32_t plaintext_len, uint32_t)
{
    uint16_t msg_size = g_teems_config.msg_size;
    assert((plaintext_len % uint32_t(msg_size)) == 0);
    uint32_t num_msgs = plaintext_len / uint32_t(msg_size);
    uint8_t our_roles = g_teems_config.roles[g_teems_config.my_node_num];
    uint8_t their_roles = g_teems_config.roles[nodest.node_num];
    pthread_mutex_lock(&route_state.round1c.mutex);
    route_state.round1c.inserted += num_msgs;
    route_state.round1c.nodes_received += 1;
    nodenum_t nodes_received = route_state.round1c.nodes_received;
    bool completed_prev_round = route_state.round1c.completed_prev_round;
    pthread_mutex_unlock(&route_state.round1c.mutex);
    // What is the next message we expect from this node?
    if (our_roles & ROLE_STORAGE) {
        nodest.in_msg_get_buf = [&](NodeCommState &commst,
                uint32_t tot_enc_chunk_size) {
            return msgbuffer_get_buf(route_state.round2, commst,
                tot_enc_chunk_size);
        };
        nodest.in_msg_received = round2_received;
    } else if (their_roles & ROLE_INGESTION) {
        nodest.in_msg_get_buf = [&](NodeCommState &commst,
                uint32_t tot_enc_chunk_size) {
            return msgbuffer_get_buf(route_state.round1, commst,
                tot_enc_chunk_size);
        };
        nodest.in_msg_received = round1_received;
    } else {
        nodest.in_msg_get_buf = [&](NodeCommState &commst,
                uint32_t tot_enc_chunk_size) {
            return msgbuffer_get_buf(route_state.round1a, commst,
                tot_enc_chunk_size);
        };
        nodest.in_msg_received = round1a_received;
    }
    if (nodes_received == g_teems_config.num_routing_nodes &&
            completed_prev_round) {
        route_state.step = ROUTE_ROUND_1C;
        void *cbpointer = route_state.cbpointer;
        route_state.cbpointer = NULL;
        ocall_routing_round_complete(cbpointer, ROUND_1C);
    }
}

// A round 2 message was received by a storage node from a routing node
static void round2_received(NodeCommState &nodest,
    uint8_t *data, uint32_t plaintext_len, uint32_t)
{
    uint16_t msg_size = g_teems_config.msg_size;
    assert((plaintext_len % uint32_t(msg_size)) == 0);
    uint32_t num_msgs = plaintext_len / uint32_t(msg_size);
    uint8_t our_roles = g_teems_config.roles[g_teems_config.my_node_num];
    uint8_t their_roles = g_teems_config.roles[nodest.node_num];
    pthread_mutex_lock(&route_state.round2.mutex);
    route_state.round2.inserted += num_msgs;
    route_state.round2.nodes_received += 1;
    nodenum_t nodes_received = route_state.round2.nodes_received;
    bool completed_prev_round = route_state.round2.completed_prev_round;
    pthread_mutex_unlock(&route_state.round2.mutex);
    // What is the next message we expect from this node?
    if ((our_roles & ROLE_ROUTING) && (their_roles & ROLE_INGESTION)) {
        nodest.in_msg_get_buf = [&](NodeCommState &commst,
                uint32_t tot_enc_chunk_size) {
            return msgbuffer_get_buf(route_state.round1, commst,
                tot_enc_chunk_size);
        };
        nodest.in_msg_received = round1_received;
    }
    // Otherwise, it's just the next round 2 message, so don't change
    // the handlers.
    if (nodes_received == g_teems_config.num_routing_nodes &&
            completed_prev_round) {
        route_state.step = ROUTE_ROUND_2;
        void *cbpointer = route_state.cbpointer;
        route_state.cbpointer = NULL;
        ocall_routing_round_complete(cbpointer, 2);
    }
}

// For a given other node, set the received message handler to the first
// message we would expect from them, given their roles and our roles.
void route_init_msg_handler(nodenum_t node_num)
{
    // Our roles and their roles
    uint8_t our_roles = g_teems_config.roles[g_teems_config.my_node_num];
    uint8_t their_roles = g_teems_config.roles[node_num];
    // The node communication state
    NodeCommState &nodest = g_commstates[node_num];
    // If we are a routing node (possibly among other roles) and they
    // are an ingestion node (possibly among other roles), a round 1
    // routing message is the first thing we expect from them. We put
    // these messages into the round1 buffer for processing.
    if ((our_roles & ROLE_ROUTING) && (their_roles & ROLE_INGESTION)) {
        nodest.in_msg_get_buf = [&](NodeCommState &commst,
                uint32_t tot_enc_chunk_size) {
            return msgbuffer_get_buf(route_state.round1, commst,
                tot_enc_chunk_size);
        };
        nodest.in_msg_received = round1_received;
    }
    // Otherwise, if we are a storage node (possibly among other roles)
    // and they are a routing node (possibly among other roles), a round
    // 2 routing message is the first thing we expect from them.
    else if ((our_roles & ROLE_STORAGE) && (their_roles & ROLE_ROUTING)) {
        nodest.in_msg_get_buf = [&](NodeCommState &commst,
                uint32_t tot_enc_chunk_size) {
            return msgbuffer_get_buf(route_state.round2, commst,
                tot_enc_chunk_size);
        };
        nodest.in_msg_received = round2_received;
    }
    // Otherwise, we don't expect a message from this node. Set the
    // unknown message handler.
    else {
        nodest.in_msg_get_buf = default_in_msg_get_buf;
        nodest.in_msg_received = unknown_in_msg_received;
    }
}

// Directly ingest a buffer of num_msgs messages into the ingbuf buffer.
// Return true on success, false on failure.
bool ecall_ingest_raw(uint8_t *msgs, uint32_t num_msgs)
{
    uint16_t msg_size = g_teems_config.msg_size;
    MsgBuffer &ingbuf = route_state.ingbuf;
    pthread_mutex_lock(&ingbuf.mutex);
    uint32_t start = ingbuf.reserved;
    if (start + num_msgs > route_state.tot_msg_per_ing) {
        pthread_mutex_unlock(&ingbuf.mutex);
        printf("Max %u messages exceeded\n",
            route_state.tot_msg_per_ing);
        return false;
    }
    ingbuf.reserved += num_msgs;
    pthread_mutex_unlock(&ingbuf.mutex);
    memmove(ingbuf.buf + start * msg_size,
        msgs, num_msgs * msg_size);
    pthread_mutex_lock(&ingbuf.mutex);
    ingbuf.inserted += num_msgs;
    pthread_mutex_unlock(&ingbuf.mutex);
    return true;
}

// Send messages round-robin, used in rounds 1 and 1c. Note that N here
// is not private.
template<typename T>
static void send_round_robin_msgs(MsgBuffer &round, const uint8_t *msgs,
    const T *indices, uint32_t N)
{
    uint16_t msg_size = g_teems_config.msg_size;
    uint16_t tot_weight;
    tot_weight = g_teems_config.tot_weight;
    nodenum_t my_node_num = g_teems_config.my_node_num;
    uint32_t full_rows;
    uint32_t last_row;
    full_rows = N / uint32_t(tot_weight);
    last_row = N % uint32_t(tot_weight);
    for (auto &routing_node: g_teems_config.routing_nodes) {
        uint8_t weight = g_teems_config.weights[routing_node].weight;
        if (weight == 0) {
            // This shouldn't happen, but just in case
            continue;
        }
        uint16_t start_weight =
            g_teems_config.weights[routing_node].startweight;
        // The number of messages headed for this routing node from the
        // full rows
        uint32_t num_msgs_full_rows = full_rows * uint32_t(weight);
        // The number of messages headed for this routing node from the
        // incomplete last row is:
        //   0 if last_row < start_weight
        //   last_row-start_weight if start_weight <= last_row < start_weight + weight
        //   weight if start_weight + weight <= last_row
        uint32_t num_msgs_last_row = 0;
        if (start_weight <= last_row && last_row < start_weight + weight) {
            num_msgs_last_row = last_row-start_weight;
        } else if (start_weight + weight <= last_row) {
            num_msgs_last_row = weight;
        }
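        // Worked example (illustrative numbers): with tot_weight = 4,
        // N = 10 gives full_rows = 2 and last_row = 2. A node with
        // weight = 2 and start_weight = 1 then gets
        // num_msgs_full_rows = 4 and num_msgs_last_row = 2-1 = 1,
        // for 5 of the 10 messages.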
        // The total number of messages headed for this routing node
        uint32_t num_msgs = num_msgs_full_rows + num_msgs_last_row;
        if (routing_node == my_node_num) {
            // Special case: we're sending to ourselves; just put the
            // messages in our own buffer
            pthread_mutex_lock(&round.mutex);
            uint32_t start = round.reserved;
            if (start + num_msgs > round.bufsize) {
                pthread_mutex_unlock(&round.mutex);
                printf("Max %u messages exceeded\n", round.bufsize);
                return;
            }
            round.reserved += num_msgs;
            pthread_mutex_unlock(&round.mutex);
            uint8_t *buf = round.buf + start * msg_size;
            for (uint32_t i=0; i<full_rows; ++i) {
                const T *idxp = indices + i*tot_weight + start_weight;
                for (uint32_t j=0; j<weight; ++j) {
                    memmove(buf, msgs + idxp[j].index()*msg_size, msg_size);
                    buf += msg_size;
                }
            }
            const T *idxp = indices + full_rows*tot_weight + start_weight;
            for (uint32_t j=0; j<num_msgs_last_row; ++j) {
                memmove(buf, msgs + idxp[j].index()*msg_size, msg_size);
                buf += msg_size;
            }
            pthread_mutex_lock(&round.mutex);
            round.inserted += num_msgs;
            round.nodes_received += 1;
            pthread_mutex_unlock(&round.mutex);
        } else {
            NodeCommState &nodecom = g_commstates[routing_node];
            nodecom.message_start(num_msgs * msg_size);
            for (uint32_t i=0; i<full_rows; ++i) {
                const T *idxp = indices + i*tot_weight + start_weight;
                for (uint32_t j=0; j<weight; ++j) {
                    nodecom.message_data(msgs + idxp[j].index()*msg_size, msg_size);
                }
            }
            const T *idxp = indices + full_rows*tot_weight + start_weight;
            for (uint32_t j=0; j<num_msgs_last_row; ++j) {
                nodecom.message_data(msgs + idxp[j].index()*msg_size, msg_size);
            }
        }
    }
}

// Send the round 1a messages from the round 1 buffer, which only occurs
// in public-channel routing. msgs points to the message buffer, indices
// points to the sorted indices, and N is the number of non-padding
// items.
static void send_round1a_msgs(const uint8_t *msgs,
    const UidPriorityKey *indices, uint32_t N)
{
    uint16_t msg_size = g_teems_config.msg_size;
    nodenum_t my_node_num = g_teems_config.my_node_num;
    nodenum_t num_routing_nodes = g_teems_config.num_routing_nodes;
    uint32_t min_msgs_per_node = route_state.max_round1a_msgs / num_routing_nodes;
    uint32_t extra_msgs = route_state.max_round1a_msgs % num_routing_nodes;
    for (auto &routing_node: g_teems_config.routing_nodes) {
        // In this unweighted setting, start_weight represents the
        // position among the routing nodes
        uint16_t prev_nodes = g_teems_config.weights[routing_node].startweight;
        uint32_t start_msg, num_msgs;
        if (prev_nodes >= extra_msgs) {
            start_msg = min_msgs_per_node * prev_nodes + extra_msgs;
            num_msgs = min_msgs_per_node;
        } else {
            start_msg = min_msgs_per_node * prev_nodes + prev_nodes;
            num_msgs = min_msgs_per_node + 1;
        }
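        // The split above gives the first extra_msgs nodes
        // min_msgs_per_node+1 messages and the rest min_msgs_per_node.
        // Illustrative example: max_round1a_msgs = 10 and
        // num_routing_nodes = 3 yield min_msgs_per_node = 3 and
        // extra_msgs = 1, so the nodes at positions 0, 1, 2 cover the
        // index ranges [0,4), [4,7), and [7,10) respectively.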
        // Take the actual number of messages into account
        if (start_msg >= N) {
            num_msgs = 0;
        } else if (start_msg + num_msgs > N) {
            num_msgs = N - start_msg;
        }
        if (routing_node == my_node_num) {
            // Special case: we're sending to ourselves; just put the
            // messages in our own buffer
            MsgBuffer &round1a = route_state.round1a;
            pthread_mutex_lock(&round1a.mutex);
            uint32_t start = round1a.reserved;
            if (start + num_msgs > round1a.bufsize) {
                pthread_mutex_unlock(&round1a.mutex);
                printf("Max %u messages exceeded in round 1a\n", round1a.bufsize);
                return;
            }
            round1a.reserved += num_msgs;
            pthread_mutex_unlock(&round1a.mutex);
            uint8_t *buf = round1a.buf + start * msg_size;
            for (uint32_t i=0; i<num_msgs; ++i) {
                const UidPriorityKey *idxp = indices + start_msg + i;
                memmove(buf, msgs + idxp->index()*msg_size, msg_size);
                buf += msg_size;
            }
            pthread_mutex_lock(&round1a.mutex);
            round1a.inserted += num_msgs;
            round1a.nodes_received += 1;
            pthread_mutex_unlock(&round1a.mutex);
        } else {
            NodeCommState &nodecom = g_commstates[routing_node];
            nodecom.message_start(num_msgs * msg_size);
            for (uint32_t i=0; i<num_msgs; ++i) {
                const UidPriorityKey *idxp = indices + start_msg + i;
                nodecom.message_data(msgs + idxp->index()*msg_size, msg_size);
            }
        }
    }
}

// Send the round 1b messages from the round 1a buffer, which only
// occurs in public-channel routing. msgs points to the (sorted) message
// buffer, and N is the number of non-padding items.
static void send_round1b_msgs(const uint8_t *msgs, uint32_t N)
{
    uint16_t msg_size = g_teems_config.msg_size;
    nodenum_t my_node_num = g_teems_config.my_node_num;
    nodenum_t num_routing_nodes = g_teems_config.num_routing_nodes;
    uint16_t prev_nodes = g_teems_config.weights[my_node_num].startweight;
    // Send to the previous node (if there is one)
    if (prev_nodes > 0) {
        nodenum_t prev_node = g_teems_config.routing_nodes[num_routing_nodes-1];
        NodeCommState &nodecom = g_commstates[prev_node];
        uint32_t num_msgs = std::min(N, route_state.max_round1b_msgs_to_adj_rtr);
        nodecom.message_start(num_msgs * msg_size);
        for (uint32_t i=0; i<num_msgs; ++i) {
            nodecom.message_data(msgs + i*msg_size, msg_size);
        }
    }
    // Send to the next node (if there is one)
    if (prev_nodes < num_routing_nodes-1) {
        nodenum_t next_node = g_teems_config.routing_nodes[1];
        NodeCommState &nodecom = g_commstates[next_node];
        if (N <= route_state.max_round1a_msgs - route_state.max_round1b_msgs_to_adj_rtr) {
            // No messages to exchange with the next node
            nodecom.message_start(0);
            // No need to call message_data()
        } else {
            uint32_t start_msg =
                route_state.max_round1a_msgs - route_state.max_round1b_msgs_to_adj_rtr;
            uint32_t num_msgs = N - start_msg;
            nodecom.message_start(num_msgs * msg_size);
            for (uint32_t i=0; i<num_msgs; ++i) {
                // Send the final window of our sorted messages, starting
                // at start_msg; these are the ones the next node will
                // merge with its own initial messages in round 1c
                nodecom.message_data(msgs + (start_msg + i)*msg_size, msg_size);
            }
        }
    }
}
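
// Illustrative example of the round 1b exchange windows: with
// max_round1a_msgs = 12, max_round1b_msgs_to_adj_rtr = 2, and N = 11
// sorted messages, we send our first min(11,2) = 2 messages (lowest
// keys) to the previous node, and the N - (12-2) = 1 message at index
// 10 (the final window) to the next node.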

// Send the round 2 messages from the previous-round buffer, which are
// already padded and shuffled, so this can be done non-obliviously.
// tot_msgs is the total number of messages in the input buffer, which
// may include padding messages added by the shuffle; those messages
// are not sent anywhere. Exactly num_msgs_per_stg messages are
// labelled for each storage node.
static void send_round2_msgs(uint32_t tot_msgs, uint32_t num_msgs_per_stg,
    MsgBuffer &prevround)
{
    uint16_t msg_size = g_teems_config.msg_size;
    const uint8_t* buf = prevround.buf;
    nodenum_t num_storage_nodes = g_teems_config.num_storage_nodes;
    nodenum_t my_node_num = g_teems_config.my_node_num;
    uint8_t *myself_buf = NULL;
    for (nodenum_t i=0; i<num_storage_nodes; ++i) {
        nodenum_t node = g_teems_config.storage_nodes[i];
        if (node != my_node_num) {
            g_commstates[node].message_start(msg_size * num_msgs_per_stg);
        } else {
            MsgBuffer &round2 = route_state.round2;
            pthread_mutex_lock(&round2.mutex);
            uint32_t start = round2.reserved;
            if (start + num_msgs_per_stg > round2.bufsize) {
                pthread_mutex_unlock(&round2.mutex);
                printf("Max %u messages exceeded\n", round2.bufsize);
                return;
            }
            round2.reserved += num_msgs_per_stg;
            pthread_mutex_unlock(&round2.mutex);
            myself_buf = round2.buf + start * msg_size;
        }
    }
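    // Each message starts with a 32-bit header word whose low
    // DEST_UID_BITS bits hold the destination user id and whose high
    // bits hold the destination storage node index. Padding messages
    // carry an out-of-range storage node index, so the bounds check
    // below silently drops them.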
    while (tot_msgs) {
        nodenum_t storage_node_id =
            nodenum_t((*(const uint32_t *)buf)>>DEST_UID_BITS);
        if (storage_node_id < num_storage_nodes) {
            nodenum_t node = g_teems_config.storage_map[storage_node_id];
            if (node == my_node_num) {
                memmove(myself_buf, buf, msg_size);
                myself_buf += msg_size;
            } else {
                g_commstates[node].message_data(buf, msg_size);
            }
        }
        buf += msg_size;
        --tot_msgs;
    }
    if (myself_buf) {
        MsgBuffer &round2 = route_state.round2;
        pthread_mutex_lock(&round2.mutex);
        round2.inserted += num_msgs_per_stg;
        round2.nodes_received += 1;
        pthread_mutex_unlock(&round2.mutex);
    }
}

static void round1a_processing(void *cbpointer) {
    uint8_t my_roles = g_teems_config.roles[g_teems_config.my_node_num];
    MsgBuffer &round1 = route_state.round1;
    if (my_roles & ROLE_ROUTING) {
        route_state.cbpointer = cbpointer;
        pthread_mutex_lock(&round1.mutex);
        // Ensure there are no pending messages currently being inserted
        // into the buffer
        while (round1.reserved != round1.inserted) {
            pthread_mutex_unlock(&round1.mutex);
            pthread_mutex_lock(&round1.mutex);
        }
#ifdef PROFILE_ROUTING
        uint32_t inserted = round1.inserted;
        unsigned long start_round1a = printf_with_rtclock("begin round1a processing (%u)\n", inserted);
        // Sort the messages we've received
        unsigned long start_sort = printf_with_rtclock("begin oblivious sort (%u,%u)\n", inserted, route_state.max_round1_msgs);
#endif
        // Sort received messages by increasing user ID and priority.
        // A smaller priority number indicates higher priority.
        sort_mtobliv<UidPriorityKey>(g_teems_config.nthreads, round1.buf,
            g_teems_config.msg_size, round1.inserted, route_state.max_round1_msgs,
            send_round1a_msgs);
#ifdef PROFILE_ROUTING
        printf_with_rtclock_diff(start_sort, "end oblivious sort (%u,%u)\n", inserted, route_state.max_round1_msgs);
        printf_with_rtclock_diff(start_round1a, "end round1a processing (%u)\n", inserted);
#endif
        round1.reset();
        pthread_mutex_unlock(&round1.mutex);
        MsgBuffer &round1a = route_state.round1a;
        pthread_mutex_lock(&round1a.mutex);
        round1a.completed_prev_round = true;
        nodenum_t nodes_received = round1a.nodes_received;
        pthread_mutex_unlock(&round1a.mutex);
        if (nodes_received == g_teems_config.num_routing_nodes) {
            route_state.step = ROUTE_ROUND_1A;
            route_state.cbpointer = NULL;
            ocall_routing_round_complete(cbpointer, ROUND_1A);
        }
    } else {
        route_state.step = ROUTE_ROUND_1A;
        route_state.round1a.completed_prev_round = true;
        ocall_routing_round_complete(cbpointer, ROUND_1A);
    }
}

static void round1b_processing(void *cbpointer) {
    uint8_t my_roles = g_teems_config.roles[g_teems_config.my_node_num];
    nodenum_t my_node_num = g_teems_config.my_node_num;
    uint16_t prev_nodes = g_teems_config.weights[my_node_num].startweight;
    nodenum_t num_routing_nodes = g_teems_config.num_routing_nodes;
    MsgBuffer &round1a = route_state.round1a;
    MsgBuffer &round1a_sorted = route_state.round1a_sorted;
    if (my_roles & ROLE_ROUTING) {
        route_state.cbpointer = cbpointer;
        pthread_mutex_lock(&round1a.mutex);
        // Ensure there are no pending messages currently being inserted
        // into the buffer
        while (round1a.reserved != round1a.inserted) {
            pthread_mutex_unlock(&round1a.mutex);
            pthread_mutex_lock(&round1a.mutex);
        }
        pthread_mutex_lock(&round1a_sorted.mutex);
        uint32_t inserted = round1a.inserted;
#ifdef PROFILE_ROUTING
        unsigned long start_round1b = printf_with_rtclock("begin round1b processing (%u)\n", inserted);
        // Sort the messages we've received
        unsigned long start_sort = printf_with_rtclock("begin oblivious sort (%u,%u)\n", inserted, route_state.max_round1a_msgs);
#endif
        // Sort received messages by increasing user ID and priority.
        // A smaller priority number indicates higher priority.
        if (inserted > 0) {
            // Copy the items in sorted order into round1a_sorted
            sort_mtobliv<UidPriorityKey>(g_teems_config.nthreads, round1a.buf,
                g_teems_config.msg_size, round1a.inserted, route_state.max_round1a_msgs,
                route_state.round1a_sorted.buf);
            send_round1b_msgs(round1a_sorted.buf, round1a.inserted);
        } else {
            send_round1b_msgs(NULL, 0);
        }
#ifdef PROFILE_ROUTING
        printf_with_rtclock_diff(start_sort, "end oblivious sort (%u,%u)\n", inserted, route_state.max_round1a_msgs);
        printf_with_rtclock_diff(start_round1b, "end round1b processing (%u)\n", inserted);
#endif
        pthread_mutex_unlock(&round1a_sorted.mutex);
        pthread_mutex_unlock(&round1a.mutex);
        MsgBuffer &round1b_prev = route_state.round1b_prev;
        pthread_mutex_lock(&round1b_prev.mutex);
        round1b_prev.completed_prev_round = true;
        nodenum_t nodes_received = round1b_prev.nodes_received;
        pthread_mutex_unlock(&round1b_prev.mutex);
        MsgBuffer &round1b_next = route_state.round1b_next;
        pthread_mutex_lock(&round1b_next.mutex);
        round1b_next.completed_prev_round = true;
        nodes_received += round1b_next.nodes_received;
        pthread_mutex_unlock(&round1b_next.mutex);
        nodenum_t adjacent_nodes;
        if (num_routing_nodes == 1) {
            adjacent_nodes = 0;
        } else if ((prev_nodes == 0) || (prev_nodes == num_routing_nodes-1)) {
            adjacent_nodes = 1;
        } else {
            adjacent_nodes = 2;
        }
        if (nodes_received == adjacent_nodes) {
            route_state.step = ROUTE_ROUND_1B;
            route_state.cbpointer = NULL;
            ocall_routing_round_complete(cbpointer, ROUND_1B);
        }
    } else {
        route_state.step = ROUTE_ROUND_1B;
        route_state.round1b_prev.completed_prev_round = true;
        route_state.round1b_next.completed_prev_round = true;
        ocall_routing_round_complete(cbpointer, ROUND_1B);
    }
}
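
// Copy num_copy messages from src into dst, following the sorted order
// given by indices, starting at sorted position start_msg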
static void copy_msgs(uint8_t *dst, uint32_t start_msg, uint32_t num_copy,
    const uint8_t *src, const UidPriorityKey *indices)
{
    uint16_t msg_size = g_teems_config.msg_size;
    const UidPriorityKey *idxp = indices + start_msg;
    uint8_t *buf = dst;
    for (uint32_t i=0; i<num_copy; i++) {
        memmove(buf, src + idxp[i].index()*msg_size, msg_size);
        buf += msg_size;
    }
}

static void round1c_processing(void *cbpointer) {
    uint8_t my_roles = g_teems_config.roles[g_teems_config.my_node_num];
    nodenum_t my_node_num = g_teems_config.my_node_num;
    nodenum_t num_routing_nodes = g_teems_config.num_routing_nodes;
    uint16_t prev_nodes = g_teems_config.weights[my_node_num].startweight;
    uint16_t msg_size = g_teems_config.msg_size;
    uint32_t max_round1b_msgs_to_adj_rtr = route_state.max_round1b_msgs_to_adj_rtr;
    uint32_t max_round1a_msgs = route_state.max_round1a_msgs;
    MsgBuffer &round1a = route_state.round1a;
    MsgBuffer &round1a_sorted = route_state.round1a_sorted;
    MsgBuffer &round1b_prev = route_state.round1b_prev;
    MsgBuffer &round1b_next = route_state.round1b_next;
    if (my_roles & ROLE_ROUTING) {
        route_state.cbpointer = cbpointer;
        pthread_mutex_lock(&round1b_prev.mutex);
        pthread_mutex_lock(&round1b_next.mutex);
        // Ensure there are no pending messages currently being inserted
        // into the round 1b buffers
        while (round1b_prev.reserved != round1b_prev.inserted) {
            pthread_mutex_unlock(&round1b_prev.mutex);
            pthread_mutex_lock(&round1b_prev.mutex);
        }
        while (round1b_next.reserved != round1b_next.inserted) {
            pthread_mutex_unlock(&round1b_next.mutex);
            pthread_mutex_lock(&round1b_next.mutex);
        }
        pthread_mutex_lock(&round1a.mutex);
        pthread_mutex_lock(&round1a_sorted.mutex);
#ifdef PROFILE_ROUTING
        unsigned long start_round1c = printf_with_rtclock("begin round1c processing (%u)\n", round1a.inserted);
#endif
        // Sort the round1b_prev msgs with the initial msgs in round1a_sorted
        if (prev_nodes > 0) {
            // Copy the initial msgs in round1a_sorted to the
            // round1b_prev buffer for sorting. Note that all inserted
            // values and buffer sizes are non-secret.
            uint32_t num_init_round1a = std::min(round1a.inserted,
                max_round1b_msgs_to_adj_rtr);
            uint32_t num_round1b_prev = round1b_prev.inserted;
            if (num_round1b_prev + num_init_round1a <= max_round1b_msgs_to_adj_rtr) {
                // All our round 1a messages "belong" to the previous
                // router and can be removed here
                round1a.inserted = 0;
            } else {
                // Copy the initial round1a msgs after the round1b_prev msgs
                memmove(round1b_prev.buf+num_round1b_prev*msg_size, round1a_sorted.buf,
                    num_init_round1a*msg_size);
                // Sort, and take the final msgs as the initial round1a msgs
#ifdef PROFILE_ROUTING
                unsigned long start_sort = printf_with_rtclock("begin round1b_prev oblivious sort (%u,%u)\n", num_round1b_prev + num_init_round1a, 2*max_round1b_msgs_to_adj_rtr);
#endif
                uint32_t num_copy = num_round1b_prev+num_init_round1a-max_round1b_msgs_to_adj_rtr;
                sort_mtobliv<UidPriorityKey>(g_teems_config.nthreads, round1b_prev.buf,
                    msg_size, num_round1b_prev + num_init_round1a,
                    2*max_round1b_msgs_to_adj_rtr,
                    [&](const uint8_t *src, const UidPriorityKey *indices, uint32_t Nr) {
                        return copy_msgs(round1a_sorted.buf, max_round1b_msgs_to_adj_rtr,
                            num_copy, src, indices);
                    }
                );
                round1a.inserted -= (max_round1b_msgs_to_adj_rtr-num_round1b_prev);
#ifdef PROFILE_ROUTING
                printf_with_rtclock_diff(start_sort, "end round1b_prev oblivious sort (%u,%u)\n", num_round1b_prev + num_init_round1a, 2*max_round1b_msgs_to_adj_rtr);
#endif
            }
        }
        // Sort the round1b_next msgs with the final msgs in round1a_sorted
        if ((prev_nodes < num_routing_nodes-1) && (round1b_next.inserted > 0)) {
            // Copy the final msgs in round1a_sorted to the round1b_next
            // buffer for sorting. Note that all inserted values and
            // buffer sizes are non-secret.
            // round1b_next.inserted > 0, so round1a.inserted >=
            // max_round1a_msgs - max_round1b_msgs_to_adj_rtr
            uint32_t round1a_msg_start = max_round1a_msgs-max_round1b_msgs_to_adj_rtr;
            uint32_t num_final_round1a = round1a.inserted - round1a_msg_start;
            uint32_t num_round1b_next = round1b_next.inserted;
            memmove(round1b_next.buf+num_round1b_next*msg_size,
                round1a_sorted.buf + round1a_msg_start*msg_size,
                num_final_round1a*msg_size);
            // Sort, and take the initial msgs as the final round1a msgs
#ifdef PROFILE_ROUTING
            unsigned long start_sort = printf_with_rtclock("begin round1b_next oblivious sort (%u,%u)\n", num_round1b_next + num_final_round1a, 2*max_round1b_msgs_to_adj_rtr);
#endif
            uint32_t num_copy = std::min(num_final_round1a+num_round1b_next,
                max_round1b_msgs_to_adj_rtr);
            sort_mtobliv<UidPriorityKey>(g_teems_config.nthreads, round1b_next.buf,
                msg_size, num_round1b_next + num_final_round1a, 2*max_round1b_msgs_to_adj_rtr,
                [&](const uint8_t *src, const UidPriorityKey *indices, uint32_t Nr) {
                    return copy_msgs(round1a_sorted.buf + round1a_msg_start*msg_size, 0,
                        num_copy, src, indices);
                }
            );
            round1a.inserted += (num_copy - num_final_round1a);
#ifdef PROFILE_ROUTING
            printf_with_rtclock_diff(start_sort, "end round1b_next oblivious sort (%u,%u)\n", num_round1b_next + num_final_round1a, 2*max_round1b_msgs_to_adj_rtr);
#endif
        }
#ifdef PROFILE_ROUTING
        unsigned long start_sort = printf_with_rtclock("begin full oblivious sort (%u,%u)\n", round1a.inserted, route_state.max_round1a_msgs);
#endif
        // Sort received messages by increasing user ID and priority.
        // A smaller priority number indicates higher priority.
        if (round1a.inserted > 0) {
            sort_mtobliv<UidPriorityKey>(g_teems_config.nthreads, round1a_sorted.buf,
                msg_size, round1a.inserted, route_state.max_round1a_msgs,
                [&](const uint8_t *msgs, const UidPriorityKey *indices, uint32_t N) {
                    send_round_robin_msgs<UidPriorityKey>(route_state.round1c, msgs, indices, N);
                });
        } else {
            send_round_robin_msgs<UidPriorityKey>(route_state.round1c, NULL, NULL, 0);
        }
#ifdef PROFILE_ROUTING
        printf_with_rtclock_diff(start_sort, "end full oblivious sort (%u,%u)\n", round1a.inserted, route_state.max_round1a_msgs);
        printf_with_rtclock_diff(start_round1c, "end round1c processing (%u)\n", round1a.inserted);
#endif
        round1a.reset();
        round1a_sorted.reset();
        round1b_prev.reset();
        round1b_next.reset();
        pthread_mutex_unlock(&round1a_sorted.mutex);
        pthread_mutex_unlock(&round1a.mutex);
        pthread_mutex_unlock(&round1b_next.mutex);
        pthread_mutex_unlock(&round1b_prev.mutex);
        MsgBuffer &round1c = route_state.round1c;
        pthread_mutex_lock(&round1c.mutex);
        round1c.completed_prev_round = true;
        nodenum_t nodes_received = round1c.nodes_received;
        pthread_mutex_unlock(&round1c.mutex);
        if (nodes_received == num_routing_nodes) {
            route_state.step = ROUTE_ROUND_1C;
            route_state.cbpointer = NULL;
            ocall_routing_round_complete(cbpointer, ROUND_1C);
        }
    } else {
        route_state.step = ROUTE_ROUND_1C;
        route_state.round1c.completed_prev_round = true;
        ocall_routing_round_complete(cbpointer, ROUND_1C);
    }
}

// Process messages in round 2
static void round2_processing(uint8_t my_roles, void *cbpointer,
    MsgBuffer &prevround)
{
    if (my_roles & ROLE_ROUTING) {
        route_state.cbpointer = cbpointer;
        pthread_mutex_lock(&prevround.mutex);
        // Ensure there are no pending messages currently being inserted
        // into the buffer
        while (prevround.reserved != prevround.inserted) {
            pthread_mutex_unlock(&prevround.mutex);
            pthread_mutex_lock(&prevround.mutex);
        }
        // If the _total_ number of messages we received in round 1
        // is less than the max number of messages we could send to
        // _each_ storage node, then cap the number of messages we
        // will send to each storage node to that number.
        uint32_t msgs_per_stg = route_state.max_msg_to_each_stg;
        if (prevround.inserted < msgs_per_stg) {
            msgs_per_stg = prevround.inserted;
        }
        // Note: at this point, it is required that each message in
        // the prevround buffer have a _valid_ storage node id field.
        // Obliviously tally the number of messages we received in
        // the previous round destined for each storage node
#ifdef PROFILE_ROUTING
        unsigned long start_round2 = printf_with_rtclock("begin round2 processing (%u,%u)\n", prevround.inserted, prevround.bufsize);
        unsigned long start_tally = printf_with_rtclock("begin tally (%u)\n", prevround.inserted);
#endif
        uint16_t msg_size = g_teems_config.msg_size;
        nodenum_t num_storage_nodes = g_teems_config.num_storage_nodes;
        std::vector<uint32_t> tally = obliv_tally_stg(
            prevround.buf, msg_size, prevround.inserted, num_storage_nodes);
#ifdef PROFILE_ROUTING
        printf_with_rtclock_diff(start_tally, "end tally (%u)\n", prevround.inserted);
#endif
        // Note: tally contains private values! It's OK to
        // non-obliviously check for an error condition, though.
        // While we're at it, obliviously change the tally of
        // messages received to a tally of padding messages
        // required.
        uint32_t tot_padding = 0;
        for (nodenum_t i=0; i<num_storage_nodes; ++i) {
            if (tally[i] > msgs_per_stg) {
                printf("Received too many messages for storage node %u\n", i);
                assert(tally[i] <= msgs_per_stg);
            }
            tally[i] = msgs_per_stg - tally[i];
            tot_padding += tally[i];
        }
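        // For example (illustrative numbers): with msgs_per_stg = 5
        // and received tallies {3,5,4}, the padding tallies become
        // {2,0,1} and tot_padding = 3, so every storage node ends up
        // with exactly 5 messages.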
        prevround.reserved += tot_padding;
        assert(prevround.reserved <= prevround.bufsize);
        // Obliviously add padding for each storage node according
        // to the (private) padding tally.
#ifdef PROFILE_ROUTING
        unsigned long start_pad = printf_with_rtclock("begin pad (%u)\n", tot_padding);
#endif
        obliv_pad_stg(prevround.buf + prevround.inserted * msg_size,
            msg_size, tally, tot_padding);
#ifdef PROFILE_ROUTING
        printf_with_rtclock_diff(start_pad, "end pad (%u)\n", tot_padding);
#endif
        prevround.inserted += tot_padding;
        // Obliviously shuffle the messages
#ifdef PROFILE_ROUTING
        unsigned long start_shuffle = printf_with_rtclock("begin shuffle (%u,%u)\n", prevround.inserted, prevround.bufsize);
#endif
        uint32_t num_shuffled = shuffle_mtobliv(g_teems_config.nthreads,
            prevround.buf, msg_size, prevround.inserted, prevround.bufsize);
#ifdef PROFILE_ROUTING
        printf_with_rtclock_diff(start_shuffle, "end shuffle (%u,%u)\n", prevround.inserted, prevround.bufsize);
        printf_with_rtclock_diff(start_round2, "end round2 processing (%u,%u)\n", prevround.inserted, prevround.bufsize);
#endif
        // Now we can handle the messages non-obliviously, since we
        // know there will be exactly msgs_per_stg messages to each
        // storage node, and the oblivious shuffle broke the
        // connection between where each message came from and where
        // it's going.
        send_round2_msgs(num_shuffled, msgs_per_stg, prevround);
        prevround.reset();
        pthread_mutex_unlock(&prevround.mutex);
    }
    if (my_roles & ROLE_STORAGE) {
        route_state.cbpointer = cbpointer;
        MsgBuffer &round2 = route_state.round2;
        pthread_mutex_lock(&round2.mutex);
        round2.completed_prev_round = true;
        nodenum_t nodes_received = round2.nodes_received;
        pthread_mutex_unlock(&round2.mutex);
        if (nodes_received == g_teems_config.num_routing_nodes) {
            route_state.step = ROUTE_ROUND_2;
            route_state.cbpointer = NULL;
            ocall_routing_round_complete(cbpointer, 2);
        }
    } else {
        route_state.step = ROUTE_ROUND_2;
        route_state.round2.completed_prev_round = true;
        ocall_routing_round_complete(cbpointer, 2);
    }
}
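
// The routing state machine: in private routing an epoch proceeds
// NOT_STARTED -> ROUND_1 -> ROUND_2 -> NOT_STARTED; in public routing
// it proceeds NOT_STARTED -> ROUND_1 -> ROUND_1A -> ROUND_1B ->
// ROUND_1C -> ROUND_2 -> NOT_STARTED.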
// Perform the next round of routing. The callback pointer will be
// passed to ocall_routing_round_complete when the round is complete.
void ecall_routing_proceed(void *cbpointer)
{
    uint8_t my_roles = g_teems_config.roles[g_teems_config.my_node_num];
    if (route_state.step == ROUTE_NOT_STARTED) {
        if (my_roles & ROLE_INGESTION) {
            ingestion_epoch++;
            route_state.cbpointer = cbpointer;
            MsgBuffer &ingbuf = route_state.ingbuf;
            pthread_mutex_lock(&ingbuf.mutex);
            // Ensure there are no pending messages currently being
            // inserted into the buffer
            while (ingbuf.reserved != ingbuf.inserted) {
                pthread_mutex_unlock(&ingbuf.mutex);
                pthread_mutex_lock(&ingbuf.mutex);
            }
            // Sort the messages we've received
#ifdef PROFILE_ROUTING
            uint32_t inserted = ingbuf.inserted;
            unsigned long start_round1 = printf_with_rtclock("begin round1 processing (%u)\n", inserted);
            unsigned long start_sort = printf_with_rtclock("begin oblivious sort (%u,%u)\n", inserted, route_state.tot_msg_per_ing);
#endif
            if (g_teems_config.private_routing) {
                sort_mtobliv<UidKey>(g_teems_config.nthreads, ingbuf.buf,
                    g_teems_config.msg_size, ingbuf.inserted,
                    route_state.tot_msg_per_ing,
                    [&](const uint8_t *msgs, const UidKey *indices, uint32_t N) {
                        send_round_robin_msgs<UidKey>(route_state.round1, msgs, indices, N);
                    });
            } else {
                // Sort received messages by increasing user ID and
                // priority. A smaller priority number indicates higher
                // priority.
                sort_mtobliv<UidPriorityKey>(g_teems_config.nthreads, ingbuf.buf,
                    g_teems_config.msg_size, ingbuf.inserted, route_state.tot_msg_per_ing,
                    [&](const uint8_t *msgs, const UidPriorityKey *indices, uint32_t N) {
                        send_round_robin_msgs<UidPriorityKey>(route_state.round1, msgs, indices, N);
                    });
            }
#ifdef PROFILE_ROUTING
            printf_with_rtclock_diff(start_sort, "end oblivious sort (%u,%u)\n", inserted, route_state.tot_msg_per_ing);
            printf_with_rtclock_diff(start_round1, "end round1 processing (%u)\n", inserted);
#endif
            ingbuf.reset();
            pthread_mutex_unlock(&ingbuf.mutex);
        }
        if (my_roles & ROLE_ROUTING) {
            MsgBuffer &round1 = route_state.round1;
            pthread_mutex_lock(&round1.mutex);
            round1.completed_prev_round = true;
            nodenum_t nodes_received = round1.nodes_received;
            pthread_mutex_unlock(&round1.mutex);
            if (nodes_received == g_teems_config.num_ingestion_nodes) {
                route_state.step = ROUTE_ROUND_1;
                route_state.cbpointer = NULL;
                ocall_routing_round_complete(cbpointer, 1);
            }
        } else {
            route_state.step = ROUTE_ROUND_1;
            route_state.round1.completed_prev_round = true;
            ocall_routing_round_complete(cbpointer, 1);
        }
    } else if (route_state.step == ROUTE_ROUND_1) {
        if (g_teems_config.private_routing) { // private routing: round 2 is next
            round2_processing(my_roles, cbpointer, route_state.round1);
        } else { // public routing: round 1a is next
            round1a_processing(cbpointer);
        }
    } else if (route_state.step == ROUTE_ROUND_1A) {
        round1b_processing(cbpointer);
    } else if (route_state.step == ROUTE_ROUND_1B) {
        round1c_processing(cbpointer);
    } else if (route_state.step == ROUTE_ROUND_1C) {
        round2_processing(my_roles, cbpointer, route_state.round1c);
    } else if (route_state.step == ROUTE_ROUND_2) {
        if (my_roles & ROLE_STORAGE) {
            MsgBuffer &round2 = route_state.round2;
            pthread_mutex_lock(&round2.mutex);
            // Ensure there are no pending messages currently being
            // inserted into the buffer
            while (round2.reserved != round2.inserted) {
                pthread_mutex_unlock(&round2.mutex);
                pthread_mutex_lock(&round2.mutex);
            }
            unsigned long start = printf_with_rtclock("begin storage processing (%u)\n", round2.inserted);
            storage_received(round2);
            printf_with_rtclock_diff(start, "end storage processing (%u)\n", round2.inserted);
            // We're done
            route_state.step = ROUTE_NOT_STARTED;
            ocall_routing_round_complete(cbpointer, 0);
        } else {
            // We're done
            route_state.step = ROUTE_NOT_STARTED;
            ocall_routing_round_complete(cbpointer, 0);
        }
    }
}