storage.cpp 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530
  1. #include "utils.hpp"
  2. #include "config.hpp"
  3. #include "ORExpand.hpp"
  4. #include "sort.hpp"
  5. #include "storage.hpp"
  6. #include "client.hpp"
  7. #define PROFILE_STORAGE
  8. StgClient *clients;
  9. uint8_t *epoch_tokens;
  10. uint8_t *epoch_mailboxes;
  11. static struct {
  12. uint32_t max_users;
  13. uint32_t my_storage_node_id;
  14. // A local storage buffer, used when we need to do non-in-place
  15. // sorts of the messages that have arrived
  16. MsgBuffer stg_buf;
  17. // The destination vector for ORExpand
  18. std::vector<uint32_t> dest;
  19. // The selected array for compaction during public routing
  20. // Need a bool array for compaction, and std:vector<bool> lacks .data()
  21. bool *pub_selected;
  22. } storage_state;
  23. bool storage_generateClientKeys(uint32_t num_clients, uint32_t my_stg_no) {
  24. uint16_t num_priv_channels = g_teems_config.m_priv_in;
  25. uint16_t msg_size = g_teems_config.msg_size;
  26. uint32_t pt_msgbundle_size = num_priv_channels * msg_size;
  27. clients = new StgClient[num_clients];
  28. for(uint32_t i =0; i < num_clients; i++) {
  29. uint32_t mid = storage_state.my_storage_node_id + i;
  30. clients[i].my_id = mid;
  31. clients[i].priv_friends = new clientid_t[g_teems_config.m_priv_out];
  32. // Initialize this client's private channel friends as themself
  33. for(int j =0; j <g_teems_config.m_priv_out; j++) {
  34. (clients[i].priv_friends)[j] = mid;
  35. }
  36. }
  37. uint32_t num_stg_nodes = g_teems_config.num_storage_nodes;
  38. uint32_t c_simid = my_stg_no;
  39. for (uint32_t i=0; i<num_clients; i++) {
  40. const sgx_aes_gcm_128bit_key_t *pESK = &(g_teems_config.ESK);
  41. unsigned char zeroes[SGX_AESGCM_KEY_SIZE];
  42. unsigned char iv[SGX_AESGCM_IV_SIZE];
  43. sgx_aes_gcm_128bit_tag_t tag;
  44. memset(zeroes, 0, SGX_AESGCM_KEY_SIZE);
  45. memset(iv, 0, SGX_AESGCM_IV_SIZE);
  46. memcpy(iv, (uint8_t*) (&c_simid), sizeof(c_simid));
  47. memcpy(iv + sizeof(c_simid), "STG", sizeof("STG"));
  48. sgx_status_t ret = SGX_SUCCESS;
  49. ret = sgx_rijndael128GCM_encrypt(pESK, zeroes, SGX_AESGCM_KEY_SIZE,
  50. (uint8_t*) (clients[i].key), iv, SGX_AESGCM_IV_SIZE, NULL, 0, &tag);
  51. if(ret!=SGX_SUCCESS) {
  52. printf("stg_generateClientKeys FAIL\n");
  53. return false;
  54. }
  55. /*
  56. if(c_simid % 10 == 0) {
  57. printf("Storage: c_simid = %d, Key:", c_simid);
  58. for (int k = 0; k<SGX_AESGCM_KEY_SIZE; k++) {
  59. printf("%x", (clients[i].key)[k]);
  60. }
  61. printf("\n");
  62. }
  63. */
  64. c_simid+=num_stg_nodes;
  65. }
  66. return true;
  67. }
  68. struct UserRange {
  69. uint32_t start, num;
  70. bool ret;
  71. };
  72. static void* generate_all_tokens_launch(void *voidargs)
  73. {
  74. UserRange *args = (UserRange *)voidargs;
  75. uint32_t pt_tokens_size = (g_teems_config.m_priv_out * SGX_CMAC_MAC_SIZE);
  76. uint32_t enc_tokens_size = pt_tokens_size +
  77. SGX_AESGCM_IV_SIZE + SGX_AESGCM_MAC_SIZE;
  78. unsigned char token_body[pt_tokens_size];
  79. uint32_t user_start = args->start;
  80. uint32_t user_end = args->start + args->num;
  81. const sgx_aes_gcm_128bit_key_t *pTSK = &(g_teems_config.TSK);
  82. for(uint32_t lcid=user_start; lcid < user_end; lcid++) {
  83. unsigned char *tkn_iv_ptr = epoch_tokens + enc_tokens_size * lcid;
  84. unsigned char *tkn_ptr = tkn_iv_ptr + SGX_AESGCM_IV_SIZE;
  85. unsigned char *tkn_tag = tkn_ptr + pt_tokens_size;
  86. // We construct the plaintext [S|R|Epoch] underlying the token in
  87. // the correct location for this client in the epoch_tokens buffer
  88. // The tokens ( i.e., CMAC over S|R|Epoch) is stored in token_body
  89. // Then encrypt token_body with the client's storage key and overwrite
  90. // the correct location in epoch_tokens
  91. memset(token_body, 0, pt_tokens_size);
  92. memset(tkn_iv_ptr, 0, SGX_AESGCM_IV_SIZE);
  93. // IV = epoch_no, for encrypting the token bundle
  94. // epoch_no used is for the next epoch
  95. unsigned long epoch_val = storage_epoch + 1;
  96. memcpy(tkn_iv_ptr, &epoch_val, sizeof(epoch_val));
  97. sgx_status_t ret = SGX_SUCCESS;
  98. unsigned char *ptr = tkn_ptr;
  99. unsigned char *tkn_body_ptr = token_body;
  100. for(int i = 0; i<g_teems_config.m_priv_out; i++)
  101. {
  102. memcpy(ptr, (&(clients[lcid].my_id)), sizeof(clientid_t));
  103. memcpy(ptr + sizeof(clientid_t), (&(clients[lcid].priv_friends[i])), sizeof(clientid_t));
  104. memcpy(ptr + 2 * sizeof(clientid_t), &epoch_val, sizeof(epoch_val));
  105. ret = sgx_rijndael128_cmac_msg(pTSK, ptr, pt_tokens_size,
  106. (sgx_cmac_128bit_tag_t*) tkn_body_ptr);
  107. if(ret!=SGX_SUCCESS) {
  108. printf("generate_tokens: Creating token FAIL\n");
  109. args->ret = false;
  110. return NULL;
  111. }
  112. ptr+=SGX_CMAC_MAC_SIZE;
  113. tkn_body_ptr+=SGX_CMAC_MAC_SIZE;
  114. }
  115. /*
  116. if(lcid == 0) {
  117. printf("Checking generated token_body:");
  118. for(uint32_t i = 0; i < pt_tokens_size; i++) {
  119. printf("%x", token_body[i]);
  120. }
  121. printf("\n");
  122. }
  123. */
  124. unsigned char *cl_iv = clients[lcid].iv;
  125. ret = (sgx_rijndael128GCM_encrypt(&(clients[lcid].key), token_body, pt_tokens_size,
  126. (uint8_t*) tkn_ptr, cl_iv, SGX_AESGCM_IV_SIZE, NULL, 0,
  127. (sgx_aes_gcm_128bit_tag_t*) tkn_tag));
  128. if(ret!=SGX_SUCCESS) {
  129. printf("generate_tokens: Encrypting token FAIL\n");
  130. args->ret = false;
  131. return NULL;
  132. }
  133. memcpy(tkn_iv_ptr, cl_iv, SGX_AESGCM_IV_SIZE);
  134. // Update IV
  135. uint64_t *iv_ctr = (uint64_t*) cl_iv;
  136. (*iv_ctr)+=1;
  137. /*
  138. if(lcid == 0) {
  139. printf("Encrypted client token bundle:");
  140. for(uint32_t i = 0; i < enc_tokens_size; i++) {
  141. printf("%x", tkn_iv_ptr[i]);
  142. }
  143. printf("\n");
  144. }
  145. */
  146. }
  147. args->ret = true;
  148. return NULL;
  149. }
  150. static bool launch_all_users(void *(*launch)(void *)) {
  151. threadid_t nthreads = g_teems_config.nthreads;
  152. // Special-case nthread=1 for efficiency
  153. if (nthreads <= 1) {
  154. UserRange args = { 0, storage_state.max_users };
  155. return launch(&args);
  156. }
  157. UserRange args[nthreads];
  158. uint32_t inc = storage_state.max_users / nthreads;
  159. uint32_t extra = storage_state.max_users % nthreads;
  160. uint32_t last = 0;
  161. for (threadid_t i=0; i<nthreads; ++i) {
  162. uint32_t num = inc + (i < extra);
  163. args[i] = { last, num };
  164. last += num;
  165. }
  166. // Launch all but the first section into other threads
  167. for (threadid_t i=1; i<nthreads; ++i) {
  168. threadpool_dispatch(g_thread_id+i, launch, args+i);
  169. }
  170. // Do the first section ourselves
  171. launch(args);
  172. // Join the threads
  173. for (threadid_t i=1; i<nthreads; ++i) {
  174. threadpool_join(g_thread_id+i, NULL);
  175. }
  176. bool ret = true;
  177. for (threadid_t i=0; i<nthreads; ++i) {
  178. ret &= args[i].ret;
  179. }
  180. return ret;
  181. }
  182. bool generate_all_tokens() {
  183. return launch_all_users(generate_all_tokens_launch);
  184. }
  185. /* processMsgs
  186. - Take all the messages in storage_state.stg_buf
  187. - Encrypt them all with their corresponding client key and IV and store into
  188. epoch_mailboxes
  189. */
  190. static void *processMsgs_launch(void *voidargs) {
  191. UserRange *args = (UserRange *)voidargs;
  192. uint32_t user_start = args->start;
  193. uint32_t user_end = args->start + args->num;
  194. uint32_t mailbox_size, num_expected_msgs;
  195. if (g_teems_config.private_routing) {
  196. mailbox_size = g_teems_config.m_priv_in * g_teems_config.msg_size;
  197. num_expected_msgs = g_teems_config.m_priv_in * storage_state.max_users;
  198. } else {
  199. mailbox_size = g_teems_config.m_pub_in * g_teems_config.msg_size;
  200. num_expected_msgs = g_teems_config.m_pub_in * storage_state.max_users;
  201. }
  202. uint32_t enc_mailbox_size = mailbox_size + SGX_AESGCM_IV_SIZE + SGX_AESGCM_MAC_SIZE;
  203. unsigned char *epoch_buf_ptr = epoch_mailboxes +
  204. enc_mailbox_size * user_start;
  205. unsigned char *stg_buf_ptr = storage_state.stg_buf.buf +
  206. mailbox_size * user_start;
  207. sgx_status_t ret = SGX_SUCCESS;
  208. unsigned char *epoch_buf_ct_ptr = epoch_buf_ptr + SGX_AESGCM_IV_SIZE;
  209. unsigned char *epoch_buf_tag_ptr = epoch_buf_ct_ptr + mailbox_size;
  210. for(uint32_t lcid = user_start; lcid < user_end; lcid++) {
  211. memcpy(epoch_buf_ptr, clients[lcid].iv, SGX_AESGCM_IV_SIZE);
  212. ret = sgx_rijndael128GCM_encrypt(&(clients[lcid].key), stg_buf_ptr, mailbox_size,
  213. (uint8_t*) epoch_buf_ct_ptr, epoch_buf_ptr, SGX_AESGCM_IV_SIZE, NULL, 0,
  214. (sgx_aes_gcm_128bit_tag_t*) epoch_buf_tag_ptr);
  215. if(ret!=SGX_SUCCESS) {
  216. printf("processMsgs: Encrypting msgs FAIL\n");
  217. args->ret = false;
  218. return NULL;
  219. }
  220. // Update IV
  221. uint64_t *iv_ctr = (uint64_t*) clients[lcid].iv;
  222. (*iv_ctr)+=1;
  223. /*
  224. if(lcid==0) {
  225. printf("\n\nMessage for lcid 0, S, R = %d, %d\n\n\n", *((uint32_t*) stg_buf_ptr),
  226. *((uint32_t*) (stg_buf_ptr + 4)));
  227. }
  228. */
  229. stg_buf_ptr+=mailbox_size;
  230. epoch_buf_ptr+=enc_mailbox_size;
  231. epoch_buf_ct_ptr+=enc_mailbox_size;
  232. epoch_buf_tag_ptr+=enc_mailbox_size;
  233. }
  234. args->ret = true;
  235. return NULL;
  236. }
  237. bool processMsgs() {
  238. return launch_all_users(processMsgs_launch);
  239. }
  240. // route_init will call this function; no one else should call it
  241. // explicitly. The parameter is the number of messages that can fit in
  242. // the storage-side MsgBuffer. Returns true on success, false on
  243. // failure.
  244. bool storage_init(uint32_t max_users, uint32_t msg_buf_size)
  245. {
  246. storage_state.max_users = max_users;
  247. storage_state.stg_buf.alloc(msg_buf_size);
  248. storage_state.dest.resize(msg_buf_size);
  249. storage_state.pub_selected = new bool[msg_buf_size];
  250. uint32_t my_storage_node_id = 0;
  251. uint32_t my_stg_pos = 0;
  252. for (nodenum_t i=0; i<g_teems_config.num_nodes; ++i) {
  253. if (g_teems_config.roles[i] & ROLE_STORAGE) {
  254. if (i == g_teems_config.my_node_num) {
  255. storage_state.my_storage_node_id = my_storage_node_id << DEST_UID_BITS;
  256. my_stg_pos = my_storage_node_id;
  257. } else {
  258. ++my_storage_node_id;
  259. }
  260. }
  261. }
  262. storage_generateClientKeys(max_users, my_stg_pos);
  263. return true;
  264. }
  265. void storage_close() {
  266. delete[] storage_state.pub_selected;
  267. }
  268. // Handle the messages received by a storage node. Pass a _locked_
  269. // MsgBuffer. This function will itself reset and unlock it when it's
  270. // done with it.
  271. void storage_received(MsgBuffer &storage_buf)
  272. {
  273. uint16_t msg_size = g_teems_config.msg_size;
  274. nodenum_t my_node_num = g_teems_config.my_node_num;
  275. const uint8_t *msgs = storage_buf.buf;
  276. uint32_t num_msgs = storage_buf.inserted;
  277. uint32_t real = 0, padding = 0;
  278. uint32_t uid_mask = (1 << DEST_UID_BITS) - 1;
  279. uint32_t nid_mask = ~uid_mask;
  280. #ifdef PROFILE_STORAGE
  281. unsigned long start_received = printf_with_rtclock("begin storage_received (%u)\n", storage_buf.inserted);
  282. #endif
  283. // It's OK to test for errors in a way that's non-oblivous if
  284. // there's an error (but it should be oblivious if there are no
  285. // errors)
  286. for (uint32_t i=0; i<num_msgs; ++i) {
  287. uint32_t uid = *(const uint32_t*)(storage_buf.buf+(i*msg_size));
  288. bool ok = ((((uid & nid_mask) == storage_state.my_storage_node_id)
  289. & ((uid & uid_mask) < storage_state.max_users))
  290. | ((uid & uid_mask) == uid_mask));
  291. if (!ok) {
  292. printf("Received bad uid: %08x\n", uid);
  293. assert(ok);
  294. }
  295. }
  296. // Testing: report how many real and dummy messages arrived
  297. printf("Storage server received %u messages:\n", num_msgs);
  298. for (uint32_t i=0; i<num_msgs; ++i) {
  299. uint32_t dest_addr = *(const uint32_t*)msgs;
  300. nodenum_t dest_node =
  301. g_teems_config.storage_map[dest_addr >> DEST_UID_BITS];
  302. if (dest_node != my_node_num) {
  303. char hexbuf[2*msg_size + 1];
  304. for (uint32_t j=0;j<msg_size;++j) {
  305. snprintf(hexbuf+2*j, 3, "%02x", msgs[j]);
  306. }
  307. printf("Misrouted message: %s\n", hexbuf);
  308. } else if ((dest_addr & uid_mask) == uid_mask) {
  309. ++padding;
  310. } else {
  311. ++real;
  312. }
  313. msgs += msg_size;
  314. }
  315. printf("%u real, %u padding\n", real, padding);
  316. /*
  317. for (uint32_t i=0;i<num_msgs; ++i) {
  318. printf("%3d: %08x %08x\n", i,
  319. *(uint32_t*)(storage_buf.buf+(i*msg_size)),
  320. *(uint32_t*)(storage_buf.buf+(i*msg_size+4)));
  321. }
  322. */
  323. // Sort the received messages by userid into the
  324. // storage_state.stg_buf MsgBuffer.
  325. #ifdef PROFILE_STORAGE
  326. unsigned long start_sort = printf_with_rtclock("begin oblivious sort (%u)\n", storage_buf.inserted);
  327. #endif
  328. sort_mtobliv<UidKey>(g_teems_config.nthreads, storage_buf.buf,
  329. msg_size, storage_buf.inserted, storage_buf.bufsize,
  330. storage_state.stg_buf.buf);
  331. #ifdef PROFILE_STORAGE
  332. printf_with_rtclock_diff(start_sort, "end oblivious sort (%u)\n", storage_buf.inserted);
  333. #endif
  334. // For public routing, remove excess per-user messages by making them
  335. // padding, and then compact non-padding messages.
  336. if (!g_teems_config.private_routing) {
  337. uint8_t *msg = storage_state.stg_buf.buf;
  338. uint32_t uid;
  339. uint32_t prev_uid = uid_mask; // initialization technically unnecessary
  340. uint32_t num_user_msgs = 0; // number of messages to the user
  341. uint8_t sel;
  342. for (uint32_t i=0; i<num_msgs; ++i) {
  343. uid = (*(uint32_t*) msg) &= uid_mask;
  344. num_user_msgs = oselect_uint32_t(1, num_user_msgs+1,
  345. uid == prev_uid);
  346. // Select if messages per user not exceeded and msg is not padding
  347. sel = ((uint8_t) ((num_user_msgs <= g_teems_config.m_pub_in))) &
  348. ((uint8_t) uid != uid_mask);
  349. storage_state.pub_selected[i] = (bool) sel;
  350. // Make padding if not selected
  351. *(uint32_t *) msg = (*(uint32_t *) msg) & nid_mask;
  352. *(uint32_t *) msg += oselect_uint32_t(uid_mask, uid, sel);
  353. msg += msg_size;
  354. prev_uid = uid;
  355. }
  356. #ifdef PROFILE_STORAGE
  357. unsigned long start_compaction = printf_with_rtclock("begin public-channel compaction (%u)\n", num_msgs);
  358. #endif
  359. TightCompact_parallel<OSWAP_16X>(
  360. (unsigned char *) storage_state.stg_buf.buf,
  361. num_msgs, msg_size, storage_state.pub_selected,
  362. g_teems_config.nthreads);
  363. #ifdef PROFILE_STORAGE
  364. printf_with_rtclock_diff(start_compaction, "end public-channel compaction (%u)\n", num_msgs);
  365. #endif
  366. }
  367. /*
  368. for (uint32_t i=0;i<num_msgs; ++i) {
  369. printf("%3d: %08x %08x\n", i,
  370. *(uint32_t*)(storage_state.stg_buf.buf+(i*msg_size)),
  371. *(uint32_t*)(storage_state.stg_buf.buf+(i*msg_size+4)));
  372. }
  373. */
  374. #ifdef PROFILE_STORAGE
  375. unsigned long start_dest = printf_with_rtclock("begin setting dests (%u)\n", storage_state.stg_buf.bufsize);
  376. #endif
  377. // Obliviously set the dest array
  378. uint32_t *dests = storage_state.dest.data();
  379. uint32_t stg_size = storage_state.stg_buf.bufsize;
  380. const uint8_t *buf = storage_state.stg_buf.buf;
  381. uint32_t m_in = g_teems_config.private_routing ? g_teems_config.m_priv_in : g_teems_config.m_pub_in;
  382. uint32_t uid = *(uint32_t*)(buf);
  383. uid &= uid_mask;
  384. // num_msgs is not a private value
  385. if (num_msgs > 0) {
  386. dests[0] = oselect_uint32_t(uid * m_in, 0xffffffff,
  387. uid == uid_mask);
  388. }
  389. uint32_t prev_uid = uid;
  390. for (uint32_t i=1; i<num_msgs; ++i) {
  391. uid = *(uint32_t*)(buf + i*msg_size);
  392. uid &= uid_mask;
  393. uint32_t next = oselect_uint32_t(uid * m_in, dests[i-1]+1,
  394. uid == prev_uid);
  395. dests[i] = oselect_uint32_t(next, 0xffffffff, uid == uid_mask);
  396. prev_uid = uid;
  397. }
  398. for (uint32_t i=num_msgs; i<stg_size; ++i) {
  399. dests[i] = 0xffffffff;
  400. *(uint32_t*)(buf + i*msg_size) = 0xffffffff;
  401. }
  402. #ifdef PROFILE_STORAGE
  403. printf_with_rtclock_diff(start_dest, "end setting dests (%u)\n", stg_size);
  404. #endif
  405. /*
  406. for (uint32_t i=0;i<stg_size; ++i) {
  407. printf("%3d: %08x %08x %u\n", i,
  408. *(uint32_t*)(storage_state.stg_buf.buf+(i*msg_size)),
  409. *(uint32_t*)(storage_state.stg_buf.buf+(i*msg_size+4)),
  410. dests[i]);
  411. }
  412. */
  413. #ifdef PROFILE_STORAGE
  414. unsigned long start_expand = printf_with_rtclock("begin ORExpand (%u)\n", stg_size);
  415. #endif
  416. ORExpand_parallel<OSWAP_16X>(storage_state.stg_buf.buf, dests,
  417. msg_size, stg_size, g_teems_config.nthreads);
  418. #ifdef PROFILE_STORAGE
  419. printf_with_rtclock_diff(start_expand, "end ORExpand (%u)\n", stg_size);
  420. #endif
  421. /*
  422. for (uint32_t i=0;i<stg_size; ++i) {
  423. printf("%3d: %08x %08x %u\n", i,
  424. *(uint32_t*)(storage_state.stg_buf.buf+(i*msg_size)),
  425. *(uint32_t*)(storage_state.stg_buf.buf+(i*msg_size+4)),
  426. dests[i]);
  427. }
  428. */
  429. // You can do more processing after these lines, as long as they
  430. // don't touch storage_buf. They _can_ touch the backing buffer
  431. // storage_state.stg_buf.
  432. storage_buf.reset();
  433. pthread_mutex_unlock(&storage_buf.mutex);
  434. if (g_teems_config.private_routing) {
  435. bool ret = generate_all_tokens();
  436. }
  437. processMsgs();
  438. storage_epoch++;
  439. storage_state.stg_buf.reset();
  440. #ifdef PROFILE_STORAGE
  441. printf_with_rtclock_diff(start_received, "end storage_received (%u)\n", storage_buf.inserted);
  442. #endif
  443. }
  444. bool ecall_storage_authenticate(clientid_t cid, unsigned char *auth_message)
  445. {
  446. bool ret = false;
  447. uint32_t lcid = cid / g_teems_config.num_storage_nodes;
  448. const sgx_aes_gcm_128bit_key_t *ckey = &(clients[lcid].key);
  449. ret = authenticateClient(auth_message, ckey);
  450. if(!ret) {
  451. printf("Storage authentication FAIL\n");
  452. }
  453. return ret;
  454. }
  455. void ecall_supply_storage_buffers(unsigned char *mailboxes,
  456. uint32_t mailboxes_size, unsigned char *tokens, uint32_t tokens_size)
  457. {
  458. epoch_mailboxes = mailboxes;
  459. epoch_tokens = tokens;
  460. }