storage.cpp 8.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231
  #include "utils.hpp"
  #include "config.hpp"
  #include "ORExpand.hpp"
  #include "sort.hpp"
  #include "storage.hpp"

  #include <new>
  #include <vector>
  6. #define PROFILE_STORAGE
  7. static struct {
  8. uint32_t max_users;
  9. uint32_t my_storage_node_id;
  10. // A local storage buffer, used when we need to do non-in-place
  11. // sorts of the messages that have arrived
  12. MsgBuffer stg_buf;
  13. // The destination vector for ORExpand
  14. std::vector<uint32_t> dest;
  15. // The selected array for compaction during public routing
  16. // Need an bool array for compaction, and std:vector<bool> lacks .data()
  17. bool *pub_selected;
  18. } storage_state;
  19. // route_init will call this function; no one else should call it
  20. // explicitly. The parameter is the number of messages that can fit in
  21. // the storage-side MsgBuffer. Returns true on success, false on
  22. // failure.
  23. bool storage_init(uint32_t max_users, uint32_t msg_buf_size)
  24. {
  25. storage_state.max_users = max_users;
  26. storage_state.stg_buf.alloc(msg_buf_size);
  27. storage_state.dest.resize(msg_buf_size);
  28. storage_state.pub_selected = new bool[msg_buf_size];
  29. uint32_t my_storage_node_id = 0;
  30. for (nodenum_t i=0; i<g_teems_config.num_nodes; ++i) {
  31. if (g_teems_config.roles[i] & ROLE_STORAGE) {
  32. if (i == g_teems_config.my_node_num) {
  33. storage_state.my_storage_node_id = my_storage_node_id << DEST_UID_BITS;
  34. } else {
  35. ++my_storage_node_id;
  36. }
  37. }
  38. }
  39. return true;
  40. }
  41. void storage_close() {
  42. delete[] storage_state.pub_selected;
  43. }
  44. // Handle the messages received by a storage node. Pass a _locked_
  45. // MsgBuffer. This function will itself reset and unlock it when it's
  46. // done with it.
  47. void storage_received(MsgBuffer &storage_buf)
  48. {
  49. uint16_t msg_size = g_teems_config.msg_size;
  50. nodenum_t my_node_num = g_teems_config.my_node_num;
  51. const uint8_t *msgs = storage_buf.buf;
  52. uint32_t num_msgs = storage_buf.inserted;
  53. uint32_t real = 0, padding = 0;
  54. uint32_t uid_mask = (1 << DEST_UID_BITS) - 1;
  55. uint32_t nid_mask = ~uid_mask;
  56. #ifdef PROFILE_STORAGE
  57. unsigned long start_received = printf_with_rtclock("begin storage_received (%u)\n", storage_buf.inserted);
  58. #endif
  59. // It's OK to test for errors in a way that's non-oblivous if
  60. // there's an error (but it should be oblivious if there are no
  61. // errors)
  62. for (uint32_t i=0; i<num_msgs; ++i) {
  63. uint32_t uid = *(const uint32_t*)(storage_buf.buf+(i*msg_size));
  64. bool ok = ((((uid & nid_mask) == storage_state.my_storage_node_id)
  65. & ((uid & uid_mask) < storage_state.max_users))
  66. | ((uid & uid_mask) == uid_mask));
  67. if (!ok) {
  68. printf("Received bad uid: %08x\n", uid);
  69. assert(ok);
  70. }
  71. }
  72. // Testing: report how many real and dummy messages arrived
  73. printf("Storage server received %u messages:\n", num_msgs);
  74. for (uint32_t i=0; i<num_msgs; ++i) {
  75. uint32_t dest_addr = *(const uint32_t*)msgs;
  76. nodenum_t dest_node =
  77. g_teems_config.storage_map[dest_addr >> DEST_UID_BITS];
  78. if (dest_node != my_node_num) {
  79. char hexbuf[2*msg_size + 1];
  80. for (uint32_t j=0;j<msg_size;++j) {
  81. snprintf(hexbuf+2*j, 3, "%02x", msgs[j]);
  82. }
  83. printf("Misrouted message: %s\n", hexbuf);
  84. } else if ((dest_addr & uid_mask) == uid_mask) {
  85. ++padding;
  86. } else {
  87. ++real;
  88. }
  89. msgs += msg_size;
  90. }
  91. printf("%u real, %u padding\n", real, padding);
  92. /*
  93. for (uint32_t i=0;i<num_msgs; ++i) {
  94. printf("%3d: %08x %08x\n", i,
  95. *(uint32_t*)(storage_buf.buf+(i*msg_size)),
  96. *(uint32_t*)(storage_buf.buf+(i*msg_size+4)));
  97. }
  98. */
  99. // Sort the received messages by userid into the
  100. // storage_state.stg_buf MsgBuffer.
  101. #ifdef PROFILE_STORAGE
  102. unsigned long start_sort = printf_with_rtclock("begin oblivious sort (%u)\n", storage_buf.inserted);
  103. #endif
  104. sort_mtobliv<UidKey>(g_teems_config.nthreads, storage_buf.buf,
  105. msg_size, storage_buf.inserted, storage_buf.bufsize,
  106. storage_state.stg_buf.buf);
  107. #ifdef PROFILE_STORAGE
  108. printf_with_rtclock_diff(start_sort, "end oblivious sort (%u)\n", storage_buf.inserted);
  109. #endif
  110. // For public routing, remove excess per-user messages by making them
  111. // padding, and then compact non-padding messages.
  112. if (!g_teems_config.private_routing) {
  113. uint8_t *msg = storage_state.stg_buf.buf;
  114. uint32_t uid;
  115. uint32_t prev_uid = uid_mask; // initialization technically unnecessary
  116. uint32_t num_user_msgs = 0; // number of messages to the user
  117. uint8_t sel;
  118. for (uint32_t i=0; i<num_msgs; ++i) {
  119. uid = (*(uint32_t*) msg) &= uid_mask;
  120. num_user_msgs = oselect_uint32_t(1, num_user_msgs+1,
  121. uid == prev_uid);
  122. // Select if messages per user not exceeded and msg is not padding
  123. sel = ((uint8_t) ((num_user_msgs <= g_teems_config.m_pub_in))) &
  124. ((uint8_t) uid != uid_mask);
  125. storage_state.pub_selected[i] = (bool) sel;
  126. // Make padding if not selected
  127. *(uint32_t *) msg = (*(uint32_t *) msg) & nid_mask;
  128. *(uint32_t *) msg += oselect_uint32_t(uid_mask, uid, sel);
  129. msg += msg_size;
  130. prev_uid = uid;
  131. }
  132. #ifdef PROFILE_STORAGE
  133. unsigned long start_compaction = printf_with_rtclock("begin public-channel compaction (%u)\n", num_msgs);
  134. #endif
  135. TightCompact_parallel<OSWAP_16X>(
  136. (unsigned char *) storage_state.stg_buf.buf,
  137. num_msgs, msg_size, storage_state.pub_selected,
  138. g_teems_config.nthreads);
  139. #ifdef PROFILE_STORAGE
  140. printf_with_rtclock_diff(start_compaction, "end public-channel compaction (%u)\n", num_msgs);
  141. #endif
  142. }
  143. /*
  144. for (uint32_t i=0;i<num_msgs; ++i) {
  145. printf("%3d: %08x %08x\n", i,
  146. *(uint32_t*)(storage_state.stg_buf.buf+(i*msg_size)),
  147. *(uint32_t*)(storage_state.stg_buf.buf+(i*msg_size+4)));
  148. }
  149. */
  150. #ifdef PROFILE_STORAGE
  151. unsigned long start_dest = printf_with_rtclock("begin setting dests (%u)\n", storage_state.stg_buf.bufsize);
  152. #endif
  153. // Obliviously set the dest array
  154. uint32_t *dests = storage_state.dest.data();
  155. uint32_t stg_size = storage_state.stg_buf.bufsize;
  156. const uint8_t *buf = storage_state.stg_buf.buf;
  157. uint32_t m_priv_in = g_teems_config.m_priv_in;
  158. uint32_t uid = *(uint32_t*)(buf);
  159. uid &= uid_mask;
  160. // num_msgs is not a private value
  161. if (num_msgs > 0) {
  162. dests[0] = oselect_uint32_t(uid * m_priv_in, 0xffffffff,
  163. uid == uid_mask);
  164. }
  165. uint32_t prev_uid = uid;
  166. for (uint32_t i=1; i<num_msgs; ++i) {
  167. uid = *(uint32_t*)(buf + i*msg_size);
  168. uid &= uid_mask;
  169. uint32_t next = oselect_uint32_t(uid * m_priv_in, dests[i-1]+1,
  170. uid == prev_uid);
  171. dests[i] = oselect_uint32_t(next, 0xffffffff, uid == uid_mask);
  172. prev_uid = uid;
  173. }
  174. for (uint32_t i=num_msgs; i<stg_size; ++i) {
  175. dests[i] = 0xffffffff;
  176. *(uint32_t*)(buf + i*msg_size) = 0xffffffff;
  177. }
  178. #ifdef PROFILE_STORAGE
  179. printf_with_rtclock_diff(start_dest, "end setting dests (%u)\n", stg_size);
  180. #endif
  181. /*
  182. for (uint32_t i=0;i<stg_size; ++i) {
  183. printf("%3d: %08x %08x %u\n", i,
  184. *(uint32_t*)(storage_state.stg_buf.buf+(i*msg_size)),
  185. *(uint32_t*)(storage_state.stg_buf.buf+(i*msg_size+4)),
  186. dests[i]);
  187. }
  188. */
  189. #ifdef PROFILE_STORAGE
  190. unsigned long start_expand = printf_with_rtclock("begin ORExpand (%u)\n", stg_size);
  191. #endif
  192. ORExpand_parallel<OSWAP_16X>(storage_state.stg_buf.buf, dests,
  193. msg_size, stg_size, g_teems_config.nthreads);
  194. #ifdef PROFILE_STORAGE
  195. printf_with_rtclock_diff(start_expand, "end ORExpand (%u)\n", stg_size);
  196. #endif
  197. /*
  198. for (uint32_t i=0;i<stg_size; ++i) {
  199. printf("%3d: %08x %08x %u\n", i,
  200. *(uint32_t*)(storage_state.stg_buf.buf+(i*msg_size)),
  201. *(uint32_t*)(storage_state.stg_buf.buf+(i*msg_size+4)),
  202. dests[i]);
  203. }
  204. */
  205. // You can do more processing after these lines, as long as they
  206. // don't touch storage_buf. They _can_ touch the backing buffer
  207. // storage_state.stg_buf.
  208. storage_buf.reset();
  209. pthread_mutex_unlock(&storage_buf.mutex);
  210. storage_state.stg_buf.reset();
  211. #ifdef PROFILE_STORAGE
  212. printf_with_rtclock_diff(start_received, "end storage_received (%u)\n", storage_buf.inserted);
  213. #endif
  214. }