storage.cpp 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445
  1. #include "utils.hpp"
  2. #include "config.hpp"
  3. #include "ORExpand.hpp"
  4. #include "sort.hpp"
  5. #include "storage.hpp"
  6. #include "client.hpp"
  7. #define PROFILE_STORAGE
  8. StgClient *clients;
  9. uint8_t *epoch_tokens;
  10. uint8_t *epoch_msgbundles;
  11. static struct {
  12. uint32_t max_users;
  13. uint32_t my_storage_node_id;
  14. // A local storage buffer, used when we need to do non-in-place
  15. // sorts of the messages that have arrived
  16. MsgBuffer stg_buf;
  17. // The destination vector for ORExpand
  18. std::vector<uint32_t> dest;
  19. } storage_state;
  20. bool storage_generateClientKeys(uint32_t num_clients, uint32_t my_stg_no) {
  21. uint16_t num_priv_channels = g_teems_config.m_priv_in;
  22. uint16_t msg_size = g_teems_config.msg_size;
  23. uint32_t pt_msgbundle_size = num_priv_channels * msg_size;
  24. clients = new StgClient[num_clients];
  25. for(uint32_t i =0; i < num_clients; i++) {
  26. uint32_t mid = storage_state.my_storage_node_id + i;
  27. clients[i].my_id = mid;
  28. clients[i].priv_friends = new clientid_t[g_teems_config.m_priv_out];
  29. // Initialize this client's private channel friends as themself
  30. for(int j =0; j <g_teems_config.m_priv_out; j++) {
  31. (clients[i].priv_friends)[j] = mid;
  32. }
  33. }
  34. uint32_t num_stg_nodes = g_teems_config.num_storage_nodes;
  35. uint32_t c_simid = my_stg_no;
  36. for (uint32_t i=0; i<num_clients; i++) {
  37. const sgx_aes_gcm_128bit_key_t *pESK = &(g_teems_config.ESK);
  38. unsigned char zeroes[SGX_AESGCM_KEY_SIZE];
  39. unsigned char iv[SGX_AESGCM_IV_SIZE];
  40. sgx_aes_gcm_128bit_tag_t tag;
  41. memset(zeroes, 0, SGX_AESGCM_KEY_SIZE);
  42. memset(iv, 0, SGX_AESGCM_IV_SIZE);
  43. memcpy(iv, (uint8_t*) (&c_simid), sizeof(c_simid));
  44. memcpy(iv + sizeof(c_simid), "STG", sizeof("STG"));
  45. sgx_status_t ret = SGX_SUCCESS;
  46. ret = sgx_rijndael128GCM_encrypt(pESK, zeroes, SGX_AESGCM_KEY_SIZE,
  47. (uint8_t*) (clients[i].key), iv, SGX_AESGCM_IV_SIZE, NULL, 0, &tag);
  48. if(ret!=SGX_SUCCESS) {
  49. printf("stg_generateClientKeys FAIL\n");
  50. return false;
  51. }
  52. /*
  53. if(c_simid % 10 == 0) {
  54. printf("Storage: c_simid = %d, Key:", c_simid);
  55. for (int k = 0; k<SGX_AESGCM_KEY_SIZE; k++) {
  56. printf("%x", (clients[i].key)[k]);
  57. }
  58. printf("\n");
  59. }
  60. */
  61. c_simid+=num_stg_nodes;
  62. }
  63. return true;
  64. }
  65. bool generate_all_tokens()
  66. {
  67. uint32_t pt_tokens_size = (g_teems_config.m_priv_out * SGX_AESGCM_KEY_SIZE);
  68. uint32_t enc_tokens_size = (g_teems_config.m_priv_out * SGX_AESGCM_KEY_SIZE) +
  69. SGX_AESGCM_IV_SIZE + SGX_AESGCM_MAC_SIZE;
  70. unsigned char token_body[pt_tokens_size];
  71. const sgx_aes_gcm_128bit_key_t *pTSK = &(g_teems_config.TSK);
  72. for(uint32_t lcid=0; lcid<storage_state.max_users; lcid++) {
  73. unsigned char *tkn_iv_ptr = epoch_tokens + enc_tokens_size * lcid;
  74. unsigned char *tkn_ptr = tkn_iv_ptr + SGX_AESGCM_IV_SIZE;
  75. unsigned char *tkn_tag = tkn_ptr + pt_tokens_size;
  76. // We construct the plaintext [S|R|Epoch] underlying the token in
  77. // the correct location for this client in the epoch_tokens buffer
  78. // The tokens ( i.e., CMAC over S|R|Epoch) is stored in token_body
  79. // Then encrypt token_body with the client's storage key and overwrite
  80. // the correct location in epoch_tokens
  81. memset(token_body, 0, pt_tokens_size);
  82. memset(tkn_iv_ptr, 0, SGX_AESGCM_IV_SIZE);
  83. // IV = epoch_no, for encrypting the token bundle
  84. memcpy(tkn_iv_ptr, &storage_epoch, sizeof(storage_epoch));
  85. unsigned char *ptr = tkn_ptr;
  86. for(int i = 0; i<g_teems_config.m_priv_out; i++)
  87. {
  88. memcpy(ptr, (&(clients[lcid].my_id)), sizeof(clientid_t));
  89. memcpy(ptr + sizeof(clientid_t), (&(clients[lcid].priv_friends[i])), sizeof(clientid_t));
  90. memcpy(ptr + 2 * sizeof(clientid_t), &storage_epoch, sizeof(storage_epoch));
  91. ptr+=SGX_AESGCM_KEY_SIZE;
  92. }
  93. sgx_status_t ret = SGX_SUCCESS;
  94. ret = sgx_rijndael128_cmac_msg(pTSK, tkn_ptr, pt_tokens_size,
  95. (sgx_cmac_128bit_tag_t*) &token_body);
  96. if(ret!=SGX_SUCCESS) {
  97. printf("generate_tokens: Creating token FAIL\n");
  98. return false;
  99. }
  100. /*
  101. if(lcid == 0) {
  102. printf("Checking generated token_body:");
  103. for(uint32_t i = 0; i < pt_tokens_size; i++) {
  104. printf("%x", token_body[i]);
  105. }
  106. printf("\n");
  107. }
  108. */
  109. unsigned char *cl_iv = clients[lcid].iv;
  110. ret = (sgx_rijndael128GCM_encrypt(&(clients[lcid].key), token_body, pt_tokens_size,
  111. (uint8_t*) tkn_ptr, cl_iv, SGX_AESGCM_IV_SIZE, NULL, 0,
  112. (sgx_aes_gcm_128bit_tag_t*) tkn_tag));
  113. if(ret!=SGX_SUCCESS) {
  114. printf("generate_tokens: Encrypting token FAIL\n");
  115. return false;
  116. }
  117. memcpy(tkn_iv_ptr, cl_iv, SGX_AESGCM_IV_SIZE);
  118. // Update IV
  119. uint64_t *iv_ctr = (uint64_t*) cl_iv;
  120. (*iv_ctr)+=1;
  121. /*
  122. if(lcid == 0) {
  123. printf("Encrypted client token bundle:");
  124. for(uint32_t i = 0; i < enc_tokens_size; i++) {
  125. printf("%x", tkn_iv_ptr[i]);
  126. }
  127. printf("\n");
  128. }
  129. */
  130. }
  131. return true;
  132. }
  133. /* processMsgs
  134. - Take all the messages in storage_state.stg_buf
  135. - Encrypt them all with their corresponding client key and IV and store into
  136. epoch_msgbundles
  137. - ecall_send_msgbundle();
  138. */
  139. bool processMsgs() {
  140. unsigned char *epoch_buf_ptr = epoch_msgbundles;
  141. unsigned char *stg_buf_ptr = storage_state.stg_buf.buf;
  142. uint32_t msg_bundle_size = g_teems_config.m_priv_in * g_teems_config.msg_size;
  143. uint32_t enc_msg_bundle_size = msg_bundle_size + SGX_AESGCM_IV_SIZE + SGX_AESGCM_MAC_SIZE;
  144. sgx_status_t ret = SGX_SUCCESS;
  145. unsigned char *epoch_buf_ct_ptr = epoch_buf_ptr + SGX_AESGCM_IV_SIZE;
  146. unsigned char *epoch_buf_tag_ptr = epoch_buf_ct_ptr + msg_bundle_size;
  147. for(uint32_t lcid = 0; lcid <storage_state.max_users; lcid++) {
  148. memcpy(epoch_buf_ptr, clients[lcid].iv, SGX_AESGCM_IV_SIZE);
  149. ret = sgx_rijndael128GCM_encrypt(&(clients[lcid].key), stg_buf_ptr, msg_bundle_size,
  150. (uint8_t*) epoch_buf_ct_ptr, epoch_buf_ptr, SGX_AESGCM_IV_SIZE, NULL, 0,
  151. (sgx_aes_gcm_128bit_tag_t*) epoch_buf_tag_ptr);
  152. if(ret!=SGX_SUCCESS) {
  153. printf("processMsgs: Encrypting msgs FAIL\n");
  154. return false;
  155. }
  156. // Update IV
  157. uint64_t *iv_ctr = (uint64_t*) clients[lcid].iv;
  158. (*iv_ctr)+=1;
  159. if(lcid==0) {
  160. printf("\n\nMessage for lcid 0, S, R = %d, %d\n\n\n", *((uint32_t*) stg_buf_ptr),
  161. *((uint32_t*) (stg_buf_ptr + 4)));
  162. }
  163. stg_buf_ptr+=msg_bundle_size;
  164. epoch_buf_ptr+=enc_msg_bundle_size;
  165. epoch_buf_ct_ptr+=enc_msg_bundle_size;
  166. epoch_buf_tag_ptr+=enc_msg_bundle_size;
  167. }
  168. return true;
  169. }
  170. bool generateDummies() {
  171. unsigned char *epoch_buf_ptr = epoch_msgbundles;
  172. uint32_t msg_bundle_size = g_teems_config.m_priv_in * g_teems_config.msg_size;
  173. uint32_t enc_msg_bundle_size = msg_bundle_size + SGX_AESGCM_IV_SIZE + SGX_AESGCM_MAC_SIZE;
  174. sgx_status_t ret = SGX_SUCCESS;
  175. unsigned char *epoch_buf_ct_ptr = epoch_buf_ptr + SGX_AESGCM_IV_SIZE;
  176. unsigned char *epoch_buf_tag_ptr = epoch_buf_ct_ptr + msg_bundle_size;
  177. unsigned char *zeroes = new unsigned char[msg_bundle_size];
  178. memset(zeroes, 0, msg_bundle_size);
  179. for(uint32_t lcid = 0; lcid <storage_state.max_users; lcid++) {
  180. memcpy(epoch_buf_ptr, clients[lcid].iv, SGX_AESGCM_IV_SIZE);
  181. ret = sgx_rijndael128GCM_encrypt(&(clients[lcid].key), zeroes, msg_bundle_size,
  182. (uint8_t*) epoch_buf_ct_ptr, epoch_buf_ptr, SGX_AESGCM_IV_SIZE, NULL, 0,
  183. (sgx_aes_gcm_128bit_tag_t*) epoch_buf_tag_ptr);
  184. if(ret!=SGX_SUCCESS) {
  185. printf("processMsgs: Encrypting msgs FAIL\n");
  186. return false;
  187. }
  188. // Update IV
  189. uint64_t *iv_ctr = (uint64_t*) clients[lcid].iv;
  190. (*iv_ctr)+=1;
  191. epoch_buf_ptr+=enc_msg_bundle_size;
  192. epoch_buf_ct_ptr+=enc_msg_bundle_size;
  193. epoch_buf_tag_ptr+=enc_msg_bundle_size;
  194. }
  195. delete zeroes;
  196. return true;
  197. }
  198. // route_init will call this function; no one else should call it
  199. // explicitly. The parameter is the number of messages that can fit in
  200. // the storage-side MsgBuffer. Returns true on success, false on
  201. // failure.
  202. bool storage_init(uint32_t max_users, uint32_t msg_buf_size)
  203. {
  204. storage_state.max_users = max_users;
  205. storage_state.stg_buf.alloc(msg_buf_size);
  206. storage_state.dest.resize(msg_buf_size);
  207. uint32_t my_storage_node_id = 0;
  208. uint32_t my_stg_pos = 0;
  209. for (nodenum_t i=0; i<g_teems_config.num_nodes; ++i) {
  210. if (g_teems_config.roles[i] & ROLE_STORAGE) {
  211. if (i == g_teems_config.my_node_num) {
  212. storage_state.my_storage_node_id = my_storage_node_id << DEST_UID_BITS;
  213. my_stg_pos = my_storage_node_id;
  214. } else {
  215. ++my_storage_node_id;
  216. }
  217. }
  218. }
  219. printf("my_stg_pos = %d\n", my_stg_pos);
  220. storage_generateClientKeys(max_users, my_stg_pos);
  221. // sendClientTokens();
  222. return true;
  223. }
  224. // Handle the messages received by a storage node. Pass a _locked_
  225. // MsgBuffer. This function will itself reset and unlock it when it's
  226. // done with it.
  227. void storage_received(MsgBuffer &storage_buf)
  228. {
  229. uint16_t msg_size = g_teems_config.msg_size;
  230. nodenum_t my_node_num = g_teems_config.my_node_num;
  231. const uint8_t *msgs = storage_buf.buf;
  232. uint32_t num_msgs = storage_buf.inserted;
  233. uint32_t real = 0, padding = 0;
  234. uint32_t uid_mask = (1 << DEST_UID_BITS) - 1;
  235. uint32_t nid_mask = ~uid_mask;
  236. #ifdef PROFILE_STORAGE
  237. unsigned long start_received = printf_with_rtclock("begin storage_received (%u)\n", storage_buf.inserted);
  238. #endif
  239. // It's OK to test for errors in a way that's non-oblivous if
  240. // there's an error (but it should be oblivious if there are no
  241. // errors)
  242. for (uint32_t i=0; i<num_msgs; ++i) {
  243. uint32_t uid = *(const uint32_t*)(storage_buf.buf+(i*msg_size));
  244. bool ok = ((((uid & nid_mask) == storage_state.my_storage_node_id)
  245. & ((uid & uid_mask) < storage_state.max_users))
  246. | ((uid & uid_mask) == uid_mask));
  247. if (!ok) {
  248. printf("Received bad uid: %08x\n", uid);
  249. assert(ok);
  250. }
  251. }
  252. // Testing: report how many real and dummy messages arrived
  253. printf("Storage server received %u messages:\n", num_msgs);
  254. for (uint32_t i=0; i<num_msgs; ++i) {
  255. uint32_t dest_addr = *(const uint32_t*)msgs;
  256. nodenum_t dest_node =
  257. g_teems_config.storage_map[dest_addr >> DEST_UID_BITS];
  258. if (dest_node != my_node_num) {
  259. char hexbuf[2*msg_size + 1];
  260. for (uint32_t j=0;j<msg_size;++j) {
  261. snprintf(hexbuf+2*j, 3, "%02x", msgs[j]);
  262. }
  263. printf("Misrouted message: %s\n", hexbuf);
  264. } else if ((dest_addr & uid_mask) == uid_mask) {
  265. ++padding;
  266. } else {
  267. ++real;
  268. }
  269. msgs += msg_size;
  270. }
  271. printf("%u real, %u padding\n", real, padding);
  272. /*
  273. for (uint32_t i=0;i<num_msgs; ++i) {
  274. printf("%3d: %08x %08x\n", i,
  275. *(uint32_t*)(storage_buf.buf+(i*msg_size)),
  276. *(uint32_t*)(storage_buf.buf+(i*msg_size+4)));
  277. }
  278. */
  279. // Sort the received messages by userid into the
  280. // storage_state.stg_buf MsgBuffer.
  281. #ifdef PROFILE_STORAGE
  282. unsigned long start_sort = printf_with_rtclock("begin oblivious sort (%u)\n", storage_buf.inserted);
  283. #endif
  284. sort_mtobliv<UidKey>(g_teems_config.nthreads, storage_buf.buf,
  285. msg_size, storage_buf.inserted, storage_buf.bufsize,
  286. storage_state.stg_buf.buf);
  287. #ifdef PROFILE_STORAGE
  288. printf_with_rtclock_diff(start_sort, "end oblivious sort (%u)\n", storage_buf.inserted);
  289. #endif
  290. /*
  291. for (uint32_t i=0;i<num_msgs; ++i) {
  292. printf("%3d: %08x %08x\n", i,
  293. *(uint32_t*)(storage_state.stg_buf.buf+(i*msg_size)),
  294. *(uint32_t*)(storage_state.stg_buf.buf+(i*msg_size+4)));
  295. }
  296. */
  297. #ifdef PROFILE_STORAGE
  298. unsigned long start_dest = printf_with_rtclock("begin setting dests (%u)\n", storage_state.stg_buf.bufsize);
  299. #endif
  300. // Obliviously set the dest array
  301. uint32_t *dests = storage_state.dest.data();
  302. uint32_t stg_size = storage_state.stg_buf.bufsize;
  303. const uint8_t *buf = storage_state.stg_buf.buf;
  304. uint32_t m_priv_in = g_teems_config.m_priv_in;
  305. uint32_t uid = *(uint32_t*)(buf);
  306. uid &= uid_mask;
  307. // num_msgs is not a private value
  308. if (num_msgs > 0) {
  309. dests[0] = oselect_uint32_t(uid * m_priv_in, 0xffffffff,
  310. uid == uid_mask);
  311. }
  312. uint32_t prev_uid = uid;
  313. for (uint32_t i=1; i<num_msgs; ++i) {
  314. uid = *(uint32_t*)(buf + i*msg_size);
  315. uid &= uid_mask;
  316. uint32_t next = oselect_uint32_t(uid * m_priv_in, dests[i-1]+1,
  317. uid == prev_uid);
  318. dests[i] = oselect_uint32_t(next, 0xffffffff, uid == uid_mask);
  319. prev_uid = uid;
  320. }
  321. for (uint32_t i=num_msgs; i<stg_size; ++i) {
  322. dests[i] = 0xffffffff;
  323. *(uint32_t*)(buf + i*msg_size) = 0xffffffff;
  324. }
  325. #ifdef PROFILE_STORAGE
  326. printf_with_rtclock_diff(start_dest, "end setting dests (%u)\n", stg_size);
  327. #endif
  328. /*
  329. for (uint32_t i=0;i<stg_size; ++i) {
  330. printf("%3d: %08x %08x %u\n", i,
  331. *(uint32_t*)(storage_state.stg_buf.buf+(i*msg_size)),
  332. *(uint32_t*)(storage_state.stg_buf.buf+(i*msg_size+4)),
  333. dests[i]);
  334. }
  335. */
  336. #ifdef PROFILE_STORAGE
  337. unsigned long start_expand = printf_with_rtclock("begin ORExpand (%u)\n", stg_size);
  338. #endif
  339. ORExpand_parallel<OSWAP_16X>(storage_state.stg_buf.buf, dests,
  340. msg_size, stg_size, g_teems_config.nthreads);
  341. #ifdef PROFILE_STORAGE
  342. printf_with_rtclock_diff(start_expand, "end ORExpand (%u)\n", stg_size);
  343. #endif
  344. /*
  345. for (uint32_t i=0;i<stg_size; ++i) {
  346. printf("%3d: %08x %08x %u\n", i,
  347. *(uint32_t*)(storage_state.stg_buf.buf+(i*msg_size)),
  348. *(uint32_t*)(storage_state.stg_buf.buf+(i*msg_size+4)),
  349. dests[i]);
  350. }
  351. */
  352. // You can do more processing after these lines, as long as they
  353. // don't touch storage_buf. They _can_ touch the backing buffer
  354. // storage_state.stg_buf.
  355. storage_buf.reset();
  356. pthread_mutex_unlock(&storage_buf.mutex);
  357. generate_all_tokens();
  358. uint32_t num_expected_msgs = g_teems_config.m_priv_in * storage_state.max_users;
  359. if(num_msgs!= 0) {
  360. processMsgs();
  361. } else {
  362. generateDummies();
  363. }
  364. storage_epoch++;
  365. storage_state.stg_buf.reset();
  366. #ifdef PROFILE_STORAGE
  367. printf_with_rtclock_diff(start_received, "end storage_received (%u)\n", storage_buf.inserted);
  368. #endif
  369. }
  370. bool ecall_storage_authenticate(clientid_t cid, unsigned char *auth_message)
  371. {
  372. bool ret = false;
  373. uint32_t lcid = cid / g_teems_config.num_storage_nodes;
  374. const sgx_aes_gcm_128bit_key_t *ckey = &(clients[lcid].key);
  375. ret = authenticateClient(auth_message, ckey);
  376. if(!ret) {
  377. printf("Storage authentication FAIL\n");
  378. }
  379. return ret;
  380. }
  381. void ecall_supply_storage_buffers(unsigned char *msgbundles,
  382. uint32_t msgbundles_size, unsigned char *tokens, uint32_t tokens_size)
  383. {
  384. epoch_msgbundles = msgbundles;
  385. epoch_tokens = tokens;
  386. }