storage.cpp 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402
  1. #include "utils.hpp"
  2. #include "config.hpp"
  3. #include "ORExpand.hpp"
  4. #include "sort.hpp"
  5. #include "storage.hpp"
  6. #include "client.hpp"
  7. #define PROFILE_STORAGE
  8. StgClient *clients;
  9. uint8_t *epoch_tokens;
  10. uint8_t *epoch_msgbundles;
  11. static struct {
  12. uint32_t max_users;
  13. uint32_t my_storage_node_id;
  14. // A local storage buffer, used when we need to do non-in-place
  15. // sorts of the messages that have arrived
  16. MsgBuffer stg_buf;
  17. // The destination vector for ORExpand
  18. std::vector<uint32_t> dest;
  19. } storage_state;
  20. bool storage_generateClientKeys(uint32_t num_clients, uint32_t my_stg_no) {
  21. uint16_t num_priv_channels = g_teems_config.m_priv_in;
  22. uint16_t msg_size = g_teems_config.msg_size;
  23. uint32_t pt_msgbundle_size = num_priv_channels * msg_size;
  24. clients = new StgClient[num_clients];
  25. for(uint32_t i =0; i < num_clients; i++) {
  26. uint32_t mid = storage_state.my_storage_node_id + i;
  27. clients[i].my_id = mid;
  28. clients[i].priv_friends = new clientid_t[g_teems_config.m_priv_in];
  29. // Initialize this client's private channel friends as themself
  30. for(int j =0; j <g_teems_config.m_priv_in; j++) {
  31. (clients[i].priv_friends)[j] = mid;
  32. }
  33. }
  34. uint32_t num_stg_nodes = g_teems_config.num_storage_nodes;
  35. uint32_t c_simid = my_stg_no;
  36. //printf("In Ingestion::genCK, num_clients = %d\n", num_clients);
  37. for (uint32_t i=0; i<num_clients; i++) {
  38. const sgx_aes_gcm_128bit_key_t *pESK = &(g_teems_config.ESK);
  39. unsigned char zeroes[SGX_AESGCM_KEY_SIZE];
  40. unsigned char iv[SGX_AESGCM_IV_SIZE];
  41. sgx_aes_gcm_128bit_tag_t tag;
  42. memset(zeroes, 0, SGX_AESGCM_KEY_SIZE);
  43. memset(iv, 0, SGX_AESGCM_IV_SIZE);
  44. memcpy(iv, (uint8_t*) (&c_simid), sizeof(c_simid));
  45. memcpy(iv + sizeof(c_simid), "STG", sizeof("STG"));
  46. sgx_status_t ret = SGX_SUCCESS;
  47. ret = sgx_rijndael128GCM_encrypt(pESK, zeroes, SGX_AESGCM_KEY_SIZE,
  48. (uint8_t*) (clients[i].key), iv, SGX_AESGCM_IV_SIZE, NULL, 0, &tag);
  49. if(ret!=SGX_SUCCESS) {
  50. printf("stg_generateClientKeys FAIL\n");
  51. return false;
  52. }
  53. /*
  54. if(c_simid % 10 == 0) {
  55. printf("Storage: c_simid = %d, Key:", c_simid);
  56. for (int k = 0; k<SGX_AESGCM_KEY_SIZE; k++) {
  57. printf("%x", (clients[i].key)[k]);
  58. }
  59. printf("\n");
  60. }
  61. */
  62. c_simid+=num_stg_nodes;
  63. }
  64. return true;
  65. }
  66. // route_init will call this function; no one else should call it
  67. // explicitly. The parameter is the number of messages that can fit in
  68. // the storage-side MsgBuffer. Returns true on success, false on
  69. // failure.
  70. bool storage_init(uint32_t max_users, uint32_t msg_buf_size)
  71. {
  72. storage_state.max_users = max_users;
  73. storage_state.stg_buf.alloc(msg_buf_size);
  74. storage_state.dest.resize(msg_buf_size);
  75. uint32_t my_storage_node_id = 0;
  76. uint32_t my_stg_pos = 0;
  77. for (nodenum_t i=0; i<g_teems_config.num_nodes; ++i) {
  78. if (g_teems_config.roles[i] & ROLE_STORAGE) {
  79. if (i == g_teems_config.my_node_num) {
  80. storage_state.my_storage_node_id = my_storage_node_id << DEST_UID_BITS;
  81. my_stg_pos = my_storage_node_id;
  82. } else {
  83. ++my_storage_node_id;
  84. }
  85. }
  86. }
  87. printf("my_stg_pos = %d\n", my_stg_pos);
  88. storage_generateClientKeys(max_users, my_stg_pos);
  89. // sendClientTokens();
  90. return true;
  91. }
  92. // Handle the messages received by a storage node. Pass a _locked_
  93. // MsgBuffer. This function will itself reset and unlock it when it's
  94. // done with it.
  95. void storage_received(MsgBuffer &storage_buf)
  96. {
  97. uint16_t msg_size = g_teems_config.msg_size;
  98. nodenum_t my_node_num = g_teems_config.my_node_num;
  99. const uint8_t *msgs = storage_buf.buf;
  100. uint32_t num_msgs = storage_buf.inserted;
  101. uint32_t real = 0, padding = 0;
  102. uint32_t uid_mask = (1 << DEST_UID_BITS) - 1;
  103. uint32_t nid_mask = ~uid_mask;
  104. #ifdef PROFILE_STORAGE
  105. unsigned long start_received = printf_with_rtclock("begin storage_received (%u)\n", storage_buf.inserted);
  106. #endif
  107. // It's OK to test for errors in a way that's non-oblivous if
  108. // there's an error (but it should be oblivious if there are no
  109. // errors)
  110. for (uint32_t i=0; i<num_msgs; ++i) {
  111. uint32_t uid = *(const uint32_t*)(storage_buf.buf+(i*msg_size));
  112. bool ok = ((((uid & nid_mask) == storage_state.my_storage_node_id)
  113. & ((uid & uid_mask) < storage_state.max_users))
  114. | ((uid & uid_mask) == uid_mask));
  115. if (!ok) {
  116. printf("Received bad uid: %08x\n", uid);
  117. assert(ok);
  118. }
  119. }
  120. // Testing: report how many real and dummy messages arrived
  121. printf("Storage server received %u messages:\n", num_msgs);
  122. for (uint32_t i=0; i<num_msgs; ++i) {
  123. uint32_t dest_addr = *(const uint32_t*)msgs;
  124. nodenum_t dest_node =
  125. g_teems_config.storage_map[dest_addr >> DEST_UID_BITS];
  126. if (dest_node != my_node_num) {
  127. char hexbuf[2*msg_size + 1];
  128. for (uint32_t j=0;j<msg_size;++j) {
  129. snprintf(hexbuf+2*j, 3, "%02x", msgs[j]);
  130. }
  131. printf("Misrouted message: %s\n", hexbuf);
  132. } else if ((dest_addr & uid_mask) == uid_mask) {
  133. ++padding;
  134. } else {
  135. ++real;
  136. }
  137. msgs += msg_size;
  138. }
  139. printf("%u real, %u padding\n", real, padding);
  140. /*
  141. for (uint32_t i=0;i<num_msgs; ++i) {
  142. printf("%3d: %08x %08x\n", i,
  143. *(uint32_t*)(storage_buf.buf+(i*msg_size)),
  144. *(uint32_t*)(storage_buf.buf+(i*msg_size+4)));
  145. }
  146. */
  147. // Sort the received messages by userid into the
  148. // storage_state.stg_buf MsgBuffer.
  149. #ifdef PROFILE_STORAGE
  150. unsigned long start_sort = printf_with_rtclock("begin oblivious sort (%u)\n", storage_buf.inserted);
  151. #endif
  152. sort_mtobliv<UidKey>(g_teems_config.nthreads, storage_buf.buf,
  153. msg_size, storage_buf.inserted, storage_buf.bufsize,
  154. storage_state.stg_buf.buf);
  155. #ifdef PROFILE_STORAGE
  156. printf_with_rtclock_diff(start_sort, "end oblivious sort (%u)\n", storage_buf.inserted);
  157. #endif
  158. /*
  159. for (uint32_t i=0;i<num_msgs; ++i) {
  160. printf("%3d: %08x %08x\n", i,
  161. *(uint32_t*)(storage_state.stg_buf.buf+(i*msg_size)),
  162. *(uint32_t*)(storage_state.stg_buf.buf+(i*msg_size+4)));
  163. }
  164. */
  165. #ifdef PROFILE_STORAGE
  166. unsigned long start_dest = printf_with_rtclock("begin setting dests (%u)\n", storage_state.stg_buf.bufsize);
  167. #endif
  168. // Obliviously set the dest array
  169. uint32_t *dests = storage_state.dest.data();
  170. uint32_t stg_size = storage_state.stg_buf.bufsize;
  171. const uint8_t *buf = storage_state.stg_buf.buf;
  172. uint32_t m_priv_in = g_teems_config.m_priv_in;
  173. uint32_t uid = *(uint32_t*)(buf);
  174. uid &= uid_mask;
  175. // num_msgs is not a private value
  176. if (num_msgs > 0) {
  177. dests[0] = oselect_uint32_t(uid * m_priv_in, 0xffffffff,
  178. uid == uid_mask);
  179. }
  180. uint32_t prev_uid = uid;
  181. for (uint32_t i=1; i<num_msgs; ++i) {
  182. uid = *(uint32_t*)(buf + i*msg_size);
  183. uid &= uid_mask;
  184. uint32_t next = oselect_uint32_t(uid * m_priv_in, dests[i-1]+1,
  185. uid == prev_uid);
  186. dests[i] = oselect_uint32_t(next, 0xffffffff, uid == uid_mask);
  187. prev_uid = uid;
  188. }
  189. for (uint32_t i=num_msgs; i<stg_size; ++i) {
  190. dests[i] = 0xffffffff;
  191. *(uint32_t*)(buf + i*msg_size) = 0xffffffff;
  192. }
  193. #ifdef PROFILE_STORAGE
  194. printf_with_rtclock_diff(start_dest, "end setting dests (%u)\n", stg_size);
  195. #endif
  196. /*
  197. for (uint32_t i=0;i<stg_size; ++i) {
  198. printf("%3d: %08x %08x %u\n", i,
  199. *(uint32_t*)(storage_state.stg_buf.buf+(i*msg_size)),
  200. *(uint32_t*)(storage_state.stg_buf.buf+(i*msg_size+4)),
  201. dests[i]);
  202. }
  203. */
  204. #ifdef PROFILE_STORAGE
  205. unsigned long start_expand = printf_with_rtclock("begin ORExpand (%u)\n", stg_size);
  206. #endif
  207. ORExpand_parallel<OSWAP_16X>(storage_state.stg_buf.buf, dests,
  208. msg_size, stg_size, g_teems_config.nthreads);
  209. #ifdef PROFILE_STORAGE
  210. printf_with_rtclock_diff(start_expand, "end ORExpand (%u)\n", stg_size);
  211. #endif
  212. /*
  213. for (uint32_t i=0;i<stg_size; ++i) {
  214. printf("%3d: %08x %08x %u\n", i,
  215. *(uint32_t*)(storage_state.stg_buf.buf+(i*msg_size)),
  216. *(uint32_t*)(storage_state.stg_buf.buf+(i*msg_size+4)),
  217. dests[i]);
  218. }
  219. */
  220. // You can do more processing after these lines, as long as they
  221. // don't touch storage_buf. They _can_ touch the backing buffer
  222. // storage_state.stg_buf.
  223. storage_buf.reset();
  224. pthread_mutex_unlock(&storage_buf.mutex);
  225. // TODO: Send the storage_state.stg_buf to clients
  226. // Encrypt the storage_state.stg_buf into MSG_BUNDLES
  227. // encryptMsgBundles(num_clients, num_priv_channels, msg_size);
  228. // Uses: storage_state.stg_buf, clients, epoch_msgbundles
  229. // generateTokens();
  230. // Uses: clients, epoch_tokens,
  231. // ocall_process_msgbundles(epoch_msgbundles, epoch_tokens);
  232. // sendClientTokens();
  233. //
  234. // processStorage(msg_bundles) :
  235. // a) Send out msg_bundles to clients with open sockets
  236. // b) Store the other msg_bundles in "backend"
  237. storage_state.stg_buf.reset();
  238. #ifdef PROFILE_STORAGE
  239. printf_with_rtclock_diff(start_received, "end storage_received (%u)\n", storage_buf.inserted);
  240. #endif
  241. }
  242. /*
  243. Given a local client identifier, generate the tokens for this client
  244. for their priv_friends for the next round.
  245. Populates the supplied tokens buffer with the correct tokens.
  246. */
  247. bool generate_tokens(clientid_t lcid)
  248. {
  249. uint32_t pt_tokens_size = (g_teems_config.m_priv_in * SGX_AESGCM_KEY_SIZE);
  250. uint32_t enc_tokens_size = (g_teems_config.m_priv_in * SGX_AESGCM_KEY_SIZE) +
  251. SGX_AESGCM_IV_SIZE + SGX_AESGCM_MAC_SIZE;
  252. const sgx_aes_gcm_128bit_key_t *pTSK = &(g_teems_config.TSK);
  253. unsigned char *tkn_iv_ptr = epoch_tokens + enc_tokens_size * lcid;
  254. unsigned char *tkn_ptr = tkn_iv_ptr + SGX_AESGCM_IV_SIZE;
  255. unsigned char *tkn_tag = tkn_ptr + pt_tokens_size;
  256. // We construct the plaintext underlying the token in the correct location
  257. // for this client in epoch_tokens
  258. // The tokens get stored in token_body
  259. // Later we encrypt token_body with the client's storage key and overwrite
  260. // the correct location in epoch_tokens
  261. unsigned char token_body[pt_tokens_size];
  262. memset(token_body, 0, pt_tokens_size);
  263. memset(tkn_iv_ptr, 0, SGX_AESGCM_IV_SIZE);
  264. // IV = client_id | epoch_no
  265. memcpy(tkn_iv_ptr, (uint8_t*) (&(clients[lcid].my_id)), sizeof(clientid_t));
  266. // TODO: Add epoch to IV
  267. //memcpy(tkn_iv_ptr + sizeof(clientid_t), epoch_no, sizeof(epoch_no));
  268. unsigned char *ptr = tkn_ptr;
  269. for(int i = 0; i<g_teems_config.m_priv_in; i++)
  270. {
  271. memcpy(ptr, (uint8_t*) (&(clients[lcid].my_id)), sizeof(clientid_t));
  272. memcpy(ptr + sizeof(clientid_t), (uint8_t*) (&(clients[lcid].priv_friends[i])), sizeof(clientid_t));
  273. ptr+=SGX_AESGCM_KEY_SIZE;
  274. }
  275. sgx_status_t ret = SGX_SUCCESS;
  276. ret = sgx_rijndael128GCM_encrypt(pTSK, tkn_ptr, pt_tokens_size,
  277. (uint8_t*) token_body, tkn_iv_ptr, SGX_AESGCM_IV_SIZE, NULL, 0,
  278. (sgx_aes_gcm_128bit_tag_t*) tkn_tag);
  279. if(ret!=SGX_SUCCESS) {
  280. printf("generate_tokens: Creating token FAIL\n");
  281. return false;
  282. }
  283. /*
  284. if(lcid == 0) {
  285. printf("Checking generated token_body:");
  286. for(uint32_t i = 0; i < pt_tokens_size; i++) {
  287. printf("%x", token_body[i]);
  288. }
  289. printf("\n");
  290. }
  291. */
  292. unsigned char *cl_iv = clients[lcid].iv;
  293. ret = (sgx_rijndael128GCM_encrypt(&(clients[lcid].key), token_body, pt_tokens_size,
  294. (uint8_t*) tkn_ptr, cl_iv, SGX_AESGCM_IV_SIZE, NULL, 0,
  295. (sgx_aes_gcm_128bit_tag_t*) tkn_tag));
  296. if(ret!=SGX_SUCCESS) {
  297. printf("generate_tokens: Encrypting token FAIL\n");
  298. return false;
  299. }
  300. memcpy(tkn_iv_ptr, cl_iv, SGX_AESGCM_IV_SIZE);
  301. // Update IV
  302. uint64_t *iv_ctr = (uint64_t*) cl_iv;
  303. (*iv_ctr)+=1;
  304. /*
  305. if(lcid == 0) {
  306. printf("Encrypted client token:");
  307. for(uint32_t i = 0; i < pt_tokens_size; i++) {
  308. printf("%x", tkn_ptr[i]);
  309. }
  310. printf("\n");
  311. }
  312. */
  313. return true;
  314. }
  315. bool ecall_storage_authenticate(clientid_t cid, unsigned char *auth_message)
  316. {
  317. bool ret = false;
  318. uint32_t lcid = cid / g_teems_config.num_storage_nodes;
  319. const sgx_aes_gcm_128bit_key_t *ckey = &(clients[lcid].key);
  320. ret = authenticateClient(auth_message, ckey);
  321. if(!ret) {
  322. printf("Storage authentication FAIL\n");
  323. }
  324. // When clients authenticate:
  325. // 1) Send back mailbox archive
  326. // 2) Send tokens for current epoch
  327. ret &= generate_tokens(lcid);
  328. /*
  329. uint32_t enc_token_bundle_size = SGX_AESGCM_KEY_SIZE + SGX_AESGCM_IV_SIZE
  330. unsigned char *token_ptr = lcid *
  331. // sendTokens()
  332. */
  333. /*
  334. if(ret) {
  335. printf("ecall_storage_auth : ret = SUCCESS\n");
  336. }
  337. else {
  338. printf("ecall_storage_auth : ret = FAIL\n");
  339. }
  340. */
  341. return ret;
  342. }
  343. void ecall_supply_storage_buffers(unsigned char *msgbundles,
  344. uint32_t msgbundles_size, unsigned char *tokens, uint32_t tokens_size)
  345. {
  346. epoch_msgbundles = msgbundles;
  347. epoch_tokens = tokens;
  348. printf("Storage buffers were set");
  349. }