storage.cpp 8.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244
  1. #include "utils.hpp"
  2. #include "config.hpp"
  3. #include "ORExpand.hpp"
  4. #include "sort.hpp"
  5. #include "storage.hpp"
  6. #include "client.hpp"
  7. #define PROFILE_STORAGE
  8. StgClient *clients;
  9. static struct {
  10. uint32_t max_users;
  11. uint32_t my_storage_node_id;
  12. // A local storage buffer, used when we need to do non-in-place
  13. // sorts of the messages that have arrived
  14. MsgBuffer stg_buf;
  15. // The destination vector for ORExpand
  16. std::vector<uint32_t> dest;
  17. } storage_state;
  18. // route_init will call this function; no one else should call it
  19. // explicitly. The parameter is the number of messages that can fit in
  20. // the storage-side MsgBuffer. Returns true on success, false on
  21. // failure.
  22. bool storage_init(uint32_t max_users, uint32_t msg_buf_size)
  23. {
  24. storage_state.max_users = max_users;
  25. storage_state.stg_buf.alloc(msg_buf_size);
  26. storage_state.dest.resize(msg_buf_size);
  27. uint32_t my_storage_node_id = 0;
  28. uint32_t my_stg_pos = 0;
  29. for (nodenum_t i=0; i<g_teems_config.num_nodes; ++i) {
  30. if (g_teems_config.roles[i] & ROLE_STORAGE) {
  31. if (i == g_teems_config.my_node_num) {
  32. storage_state.my_storage_node_id = my_storage_node_id << DEST_UID_BITS;
  33. my_stg_pos = my_storage_node_id;
  34. } else {
  35. ++my_storage_node_id;
  36. }
  37. }
  38. }
  39. storage_generateClientKeys(max_users, my_stg_pos);
  40. return true;
  41. }
  42. // Handle the messages received by a storage node. Pass a _locked_
  43. // MsgBuffer. This function will itself reset and unlock it when it's
  44. // done with it.
  45. void storage_received(MsgBuffer &storage_buf)
  46. {
  47. uint16_t msg_size = g_teems_config.msg_size;
  48. nodenum_t my_node_num = g_teems_config.my_node_num;
  49. const uint8_t *msgs = storage_buf.buf;
  50. uint32_t num_msgs = storage_buf.inserted;
  51. uint32_t real = 0, padding = 0;
  52. uint32_t uid_mask = (1 << DEST_UID_BITS) - 1;
  53. uint32_t nid_mask = ~uid_mask;
  54. #ifdef PROFILE_STORAGE
  55. unsigned long start_received = printf_with_rtclock("begin storage_received (%u)\n", storage_buf.inserted);
  56. #endif
  57. // It's OK to test for errors in a way that's non-oblivous if
  58. // there's an error (but it should be oblivious if there are no
  59. // errors)
  60. for (uint32_t i=0; i<num_msgs; ++i) {
  61. uint32_t uid = *(const uint32_t*)(storage_buf.buf+(i*msg_size));
  62. bool ok = ((((uid & nid_mask) == storage_state.my_storage_node_id)
  63. & ((uid & uid_mask) < storage_state.max_users))
  64. | ((uid & uid_mask) == uid_mask));
  65. if (!ok) {
  66. printf("Received bad uid: %08x\n", uid);
  67. assert(ok);
  68. }
  69. }
  70. // Testing: report how many real and dummy messages arrived
  71. printf("Storage server received %u messages:\n", num_msgs);
  72. for (uint32_t i=0; i<num_msgs; ++i) {
  73. uint32_t dest_addr = *(const uint32_t*)msgs;
  74. nodenum_t dest_node =
  75. g_teems_config.storage_map[dest_addr >> DEST_UID_BITS];
  76. if (dest_node != my_node_num) {
  77. char hexbuf[2*msg_size + 1];
  78. for (uint32_t j=0;j<msg_size;++j) {
  79. snprintf(hexbuf+2*j, 3, "%02x", msgs[j]);
  80. }
  81. printf("Misrouted message: %s\n", hexbuf);
  82. } else if ((dest_addr & uid_mask) == uid_mask) {
  83. ++padding;
  84. } else {
  85. ++real;
  86. }
  87. msgs += msg_size;
  88. }
  89. printf("%u real, %u padding\n", real, padding);
  90. /*
  91. for (uint32_t i=0;i<num_msgs; ++i) {
  92. printf("%3d: %08x %08x\n", i,
  93. *(uint32_t*)(storage_buf.buf+(i*msg_size)),
  94. *(uint32_t*)(storage_buf.buf+(i*msg_size+4)));
  95. }
  96. */
  97. // Sort the received messages by userid into the
  98. // storage_state.stg_buf MsgBuffer.
  99. #ifdef PROFILE_STORAGE
  100. unsigned long start_sort = printf_with_rtclock("begin oblivious sort (%u)\n", storage_buf.inserted);
  101. #endif
  102. sort_mtobliv<UidKey>(g_teems_config.nthreads, storage_buf.buf,
  103. msg_size, storage_buf.inserted, storage_buf.bufsize,
  104. storage_state.stg_buf.buf);
  105. #ifdef PROFILE_STORAGE
  106. printf_with_rtclock_diff(start_sort, "end oblivious sort (%u)\n", storage_buf.inserted);
  107. #endif
  108. /*
  109. for (uint32_t i=0;i<num_msgs; ++i) {
  110. printf("%3d: %08x %08x\n", i,
  111. *(uint32_t*)(storage_state.stg_buf.buf+(i*msg_size)),
  112. *(uint32_t*)(storage_state.stg_buf.buf+(i*msg_size+4)));
  113. }
  114. */
  115. #ifdef PROFILE_STORAGE
  116. unsigned long start_dest = printf_with_rtclock("begin setting dests (%u)\n", storage_state.stg_buf.bufsize);
  117. #endif
  118. // Obliviously set the dest array
  119. uint32_t *dests = storage_state.dest.data();
  120. uint32_t stg_size = storage_state.stg_buf.bufsize;
  121. const uint8_t *buf = storage_state.stg_buf.buf;
  122. uint32_t m_priv_in = g_teems_config.m_priv_in;
  123. uint32_t uid = *(uint32_t*)(buf);
  124. uid &= uid_mask;
  125. // num_msgs is not a private value
  126. if (num_msgs > 0) {
  127. dests[0] = oselect_uint32_t(uid * m_priv_in, 0xffffffff,
  128. uid == uid_mask);
  129. }
  130. uint32_t prev_uid = uid;
  131. for (uint32_t i=1; i<num_msgs; ++i) {
  132. uid = *(uint32_t*)(buf + i*msg_size);
  133. uid &= uid_mask;
  134. uint32_t next = oselect_uint32_t(uid * m_priv_in, dests[i-1]+1,
  135. uid == prev_uid);
  136. dests[i] = oselect_uint32_t(next, 0xffffffff, uid == uid_mask);
  137. prev_uid = uid;
  138. }
  139. for (uint32_t i=num_msgs; i<stg_size; ++i) {
  140. dests[i] = 0xffffffff;
  141. *(uint32_t*)(buf + i*msg_size) = 0xffffffff;
  142. }
  143. #ifdef PROFILE_STORAGE
  144. printf_with_rtclock_diff(start_dest, "end setting dests (%u)\n", stg_size);
  145. #endif
  146. /*
  147. for (uint32_t i=0;i<stg_size; ++i) {
  148. printf("%3d: %08x %08x %u\n", i,
  149. *(uint32_t*)(storage_state.stg_buf.buf+(i*msg_size)),
  150. *(uint32_t*)(storage_state.stg_buf.buf+(i*msg_size+4)),
  151. dests[i]);
  152. }
  153. */
  154. #ifdef PROFILE_STORAGE
  155. unsigned long start_expand = printf_with_rtclock("begin ORExpand (%u)\n", stg_size);
  156. #endif
  157. ORExpand_parallel<OSWAP_16X>(storage_state.stg_buf.buf, dests,
  158. msg_size, stg_size, g_teems_config.nthreads);
  159. #ifdef PROFILE_STORAGE
  160. printf_with_rtclock_diff(start_expand, "end ORExpand (%u)\n", stg_size);
  161. #endif
  162. /*
  163. for (uint32_t i=0;i<stg_size; ++i) {
  164. printf("%3d: %08x %08x %u\n", i,
  165. *(uint32_t*)(storage_state.stg_buf.buf+(i*msg_size)),
  166. *(uint32_t*)(storage_state.stg_buf.buf+(i*msg_size+4)),
  167. dests[i]);
  168. }
  169. */
  170. // You can do more processing after these lines, as long as they
  171. // don't touch storage_buf. They _can_ touch the backing buffer
  172. // storage_state.stg_buf.
  173. storage_buf.reset();
  174. pthread_mutex_unlock(&storage_buf.mutex);
  175. storage_state.stg_buf.reset();
  176. #ifdef PROFILE_STORAGE
  177. printf_with_rtclock_diff(start_received, "end storage_received (%u)\n", storage_buf.inserted);
  178. #endif
  179. }
  180. bool storage_generateClientKeys(uint32_t num_clients, uint32_t my_stg_no) {
  181. clients = new StgClient[num_clients];
  182. uint32_t num_stg_nodes = g_teems_config.num_storage_nodes;
  183. uint32_t c_simid = my_stg_no;
  184. //printf("In Ingestion::genCK, num_clients = %d\n", num_clients);
  185. for (uint32_t i=0; i<num_clients; i++) {
  186. const sgx_aes_gcm_128bit_key_t *pESK = &(g_teems_config.ESK);
  187. unsigned char zeroes[SGX_AESGCM_KEY_SIZE];
  188. unsigned char iv[SGX_AESGCM_IV_SIZE];
  189. sgx_aes_gcm_128bit_tag_t tag;
  190. memset(zeroes, 0, SGX_AESGCM_KEY_SIZE);
  191. memset(iv, 0, SGX_AESGCM_IV_SIZE);
  192. memcpy(iv, (uint8_t*) (&c_simid), sizeof(c_simid));
  193. memcpy(iv + sizeof(c_simid), "STG", sizeof("STG"));
  194. sgx_status_t ret = SGX_SUCCESS;
  195. ret = sgx_rijndael128GCM_encrypt(pESK, zeroes, SGX_AESGCM_KEY_SIZE,
  196. (uint8_t*) (clients[i].key), iv, SGX_AESGCM_IV_SIZE, NULL, 0, &tag);
  197. if(ret!=SGX_SUCCESS) {
  198. printf("stg_generateClientKeys FAIL\n");
  199. return false;
  200. }
  201. /*
  202. if(c_simid % 10 == 0) {
  203. printf("Storage: c_simid = %d, Key:", c_simid);
  204. for (int k = 0; k<SGX_AESGCM_KEY_SIZE; k++) {
  205. printf("%x", (clients[i].key)[k]);
  206. }
  207. printf("\n");
  208. }
  209. */
  210. c_simid+=num_stg_nodes;
  211. }
  212. return true;
  213. }
  214. bool authenticateClient()
  215. {
  216. }