// Shared helpers for the enclave code (presumably utils.cpp; it includes
// utils.hpp): thread-local randomness state, 64/32-bit label comparators,
// OCALL-based printf wrappers, a PRB_buffer pseudorandom-byte cache built on
// sgx_aes_ctr_encrypt, a pointer-based merge sort, and several compiled-out
// (#if 0) helpers. This span ends inside the comment that introduces the
// thread pool.
//
// NOTE(review): this span reached review whitespace-collapsed and with every
// "<...>" run stripped out, so it cannot be compiled as shown. Restore from
// version control before editing:
//   * the target of the first #include;
//   * the loop body and sort call in the disabled generateSortPermutation_DJB,
//     the end of its commented-out printf loop, and an #endif - the visible
//     #if/#endif count is one #endif short, most plausibly the one closing
//     the #if 0 around generateSortPermutation_DJB;
//   * the signature and first statements of the live function that ends
//     ">= 1; return N1; }" (it shifts N1 right by one and returns it -
//     confirm its name and contract);
//   * the tail of displayORPPacket's commented-out hex dump and whatever
//     followed it before isDummy.
//
// Globals: OSWAP_COUNTER (COUNT_OSWAPS builds only), PRB_buf, PRB_rand_bits
// and PRB_rand_bits_remaining are thread_local. Within this file,
// bulk_initialized / bulk_random_seed / bulk_counter are referenced only by
// the commented-out getBulkRandomBytes / initialize_BRB.
//
// compare / compare_32: qsort-style comparators on the first 8 / 4 bytes of
//   each element, read with memcpy (no alignment requirement on the input).
//   NOTE(review): both return (int)(label1 - label2). The subtraction is
//   unsigned and the cast truncates (assuming 32-bit int). Labels whose
//   difference does not fit in an int can therefore compare with the wrong
//   sign, and 64-bit labels that differ by a multiple of 2^32 compare as
//   equal. (label1 > label2) - (label1 < label2) would be exact.
//
// printf / printf_with_rtclock / printf_with_rtclock_diff (only when
//   BEFTS_MODE is not defined): format into a BUFSIZ-byte stack buffer with
//   vsnprintf (longer output is truncated) and pass it to the matching
//   ocall_print_string* OCALL. The rtclock variants return the value the
//   OCALL writes into `ret` (a timestamp, per their original comments). The
//   _diff variant also forwards `before` to its OCALL.
//
// displayORPPacket, isDummy/setDummy, isORPDummy/setORPDummy and
//   packetsConsumedUptoMSN appear to be compiled out by the #if 0 before
//   displayORPPacket, whose #endif follows packetsConsumedUptoMSN. Confirm
//   this once the lost text between them is restored.
//   NOTE(review) if re-enabled:
//   - displayORPPacket writes data[block_size], one past its block_size-byte
//     array.
//   - setORPDummy sets words 0..2, but isORPDummy tests only word 1.
//   - packetsConsumedUptoMSN counts packets_per_entry_msn+1 per extra-packet
//     MSN in its first branch, but only packets_per_entry_msn for those same
//     MSNs in its second branch.
//
// USE_PRB builds:
//   PRB_pool_init / PRB_pool_shutdown are intentional no-ops.
//   PRB_buffer::init_PRB_buffer seeds random_seed and counter from
//   sgx_read_rand the first time. It then fills random_bytes with
//   sgx_aes_ctr_encrypt over an uninitialised stack array; per the inline
//   comment, this is deliberate.
//   NOTE(review): init_PRB_buffer encrypts buffer_size bytes but always sets
//   random_bytes_left = PRB_BUFFER_SIZE, so a non-default buffer_size breaks
//   the accounting. A buffer_size larger than random_bytes' capacity would
//   also overrun it (the class is declared elsewhere - confirm the capacity).
//   A default argument on an out-of-class definition is only legal if the
//   in-class declaration has none.
//   PRB_buffer::getRandomBytes serves bytes from the cached block. Once a
//   request is >= the remaining count, it drains the cache, re-encrypts the
//   PRB_BUFFER_SIZE-byte block in place, and serves the rest from it.
//   NOTE(review): the status of the lazy init_PRB_buffer() call is ignored.
//   A request needing more than PRB_BUFFER_SIZE fresh bytes makes the second
//   memcpy read past the refilled block and underflows random_bytes_left, so
//   callers presumably must keep size <= PRB_BUFFER_SIZE.
//   After each sgx_aes_ctr_encrypt call, both functions also add 1 to the
//   first 8 bytes of counter (read as a native uint64_t), on top of whatever
//   the SDK itself does to the counter.
//   getBulkRandomBytes / initialize_BRB are commented out.
// Other builds: the free function getRandomBytes just calls sgx_read_rand.
//
// compare_keys: returns whichever packet has the smaller leading uint64_t;
//   on a tie it returns packet_2.
// merge: merges the inclusive index ranges [l, m] and [m+1, r] of
//   data_size-byte records through a malloc'd scratch buffer (s1 = m+1 and
//   s2 = r+1 are the exclusive ends). It takes from the left run only when
//   comparator returns the left pointer. With compare_keys, equal keys
//   therefore come from the right run first, so the merge is not stable.
//   NOTE(review): the malloc result is not checked.
// mergeSort: top-down recursive sort of the inclusive range
//   [start_index, end_index].
// mergeSort_OPRM: currently behaves exactly like mergeSort (it recurses
//   through mergeSort, not itself).
//
// The trailing #if 0 region (isBlockReal_16/32/64, oswap_buffer,
// isCorrect16x, isCorrect8_16x, swapBuckets) is compiled out.
// NOTE(review) if re-enabled:
// - isBlockReal_* return true for an all-ones label, the same value
//   isDummy/setDummy use to mark dummies, so the name looks inverted.
// - isCorrect16x and isCorrect8_16x have identical bodies.
#include #include "utils.hpp" #ifdef COUNT_OSWAPS thread_local uint64_t OSWAP_COUNTER=0; #endif thread_local PRB_buffer PRB_buf; thread_local uint64_t PRB_rand_bits = 0; thread_local uint32_t PRB_rand_bits_remaining = 0; bool bulk_initialized = false; sgx_aes_ctr_128bit_key_t bulk_random_seed[SGX_AESCTR_KEY_SIZE]; unsigned char bulk_counter[SGX_AESCTR_KEY_SIZE]; int compare(const void *buf1, const void *buf2) { uint64_t label1, label2; memcpy(&label1, (const unsigned char*) buf1, 8); memcpy(&label2, (const unsigned char*) buf2, 8); return((int)(label1 - label2)); } int compare_32(const void *buf1, const void *buf2) { uint32_t label1, label2; memcpy(&label1, (const unsigned char*) buf1, 4); memcpy(&label2, (const unsigned char*) buf2, 4); return((int)(label1 - label2)); } #if 0 void generateSortPermutation_DJB(size_t N, unsigned char *buffer, size_t block_size, size_t *permutation) { size_t *keys; try { keys = new size_t[N]; } catch (std::bad_alloc&) { printf("Allocating memory failed in generateSortPermutation_DJB\n"); } unsigned char *buffer_ptr = buffer; for(size_t i=0; i ((unsigned char*) keys, N, (unsigned char*) permutation, NULL, 4, true); /* printf("\nSort Permutation:\n"); for(size_t i=0; i>= 1; return N1; } #ifndef BEFTS_MODE /* * printf: * Invokes OCALL to display the enclave buffer to the terminal. */ void printf(const char *fmt, ...) { char buf[BUFSIZ] = {'\0'}; va_list ap; va_start(ap, fmt); vsnprintf(buf, BUFSIZ, fmt, ap); va_end(ap); ocall_print_string(buf); } /* * printf_with_rtclock: * Invokes OCALL to display the enclave buffer to the terminal with a * timestamp and returns the timestamp. */ unsigned long printf_with_rtclock(const char *fmt, ...) 
{ unsigned long ret; char buf[BUFSIZ] = {'\0'}; va_list ap; va_start(ap, fmt); vsnprintf(buf, BUFSIZ, fmt, ap); va_end(ap); ocall_print_string_with_rtclock(&ret, buf); return ret; } /* * printf_with_rtclock_diff: * Invokes OCALL to display the enclave buffer to the terminal with a * timestamp and returns the timestamp. Also prints the difference from * the before timestamp. */ unsigned long printf_with_rtclock_diff(unsigned long before, const char *fmt, ...) { unsigned long ret; char buf[BUFSIZ] = {'\0'}; va_list ap; va_start(ap, fmt); vsnprintf(buf, BUFSIZ, fmt, ap); va_end(ap); ocall_print_string_with_rtclock_diff(&ret, buf, before); return ret; } #endif #if 0 void displayORPPacket(unsigned char* packet_in, size_t block_size) { unsigned char *packet_ptr = packet_in; uint64_t evict_stream, ORP_label, key; unsigned char data[block_size]; memcpy(&evict_stream, packet_ptr, sizeof(uint64_t)); packet_ptr+=sizeof(uint64_t); memcpy(&ORP_label, packet_ptr, sizeof(uint64_t)); packet_ptr+=sizeof(uint64_t); memcpy(&key, packet_ptr, sizeof(uint64_t)); packet_ptr+=sizeof(uint64_t); memcpy(data, packet_ptr, block_size); data[block_size]='\0'; printf("(evict_stream = %ld, ORP_label = %ld, Key = %ld)\n", evict_stream, ORP_label, key); //printf("Hex of data is :"); //for(int i=0;i bool isDummy(unsigned char *ptr_to_serialized_packet){ return(((uint64_t*) ptr_to_serialized_packet)[0] == UINT64_MAX); } void setDummy(unsigned char *ptr_to_serialized_packet){ ((uint64_t*) ptr_to_serialized_packet)[0] = UINT64_MAX; } // isORPDummy and setORPDummy works on ORP packets : bool isORPDummy(unsigned char *ptr_to_serialized_packet){ return(((uint64_t*) ptr_to_serialized_packet)[1] == UINT64_MAX); } void setORPDummy(unsigned char *ptr_to_packet){ ((uint64_t*) ptr_to_packet)[0] = UINT64_MAX; ((uint64_t*) ptr_to_packet)[1] = UINT64_MAX; ((uint64_t*) ptr_to_packet)[2] = UINT64_MAX; } size_t packetsConsumedUptoMSN(signed long msn_no, size_t msns_with_extra_packets, size_t packets_per_entry_msn) { 
if(msn_no<0) return 0; if(msn_no<=msns_with_extra_packets){ return (msn_no * (packets_per_entry_msn+1)); } else{ size_t reg_msn = msn_no - msns_with_extra_packets; return ((reg_msn * packets_per_entry_msn) + (msns_with_extra_packets * packets_per_entry_msn)); } } #endif #ifdef USE_PRB void PRB_pool_init(int nthreads) { // Nothing needs to be done any more } void PRB_pool_shutdown() { // Nothing needs to be done any more } PRB_buffer::PRB_buffer() { } sgx_status_t PRB_buffer::init_PRB_buffer(uint32_t buffer_size = PRB_BUFFER_SIZE) { sgx_status_t rt = SGX_SUCCESS; if(initialized==false) { rt = sgx_read_rand((unsigned char*) random_seed, SGX_AESCTR_KEY_SIZE); if(rt!=SGX_SUCCESS){ printf("Failed sgx_read_rand (%x)", rt); return rt; } rt = sgx_read_rand((unsigned char*) counter, SGX_AESCTR_KEY_SIZE); if(rt!=SGX_SUCCESS){ printf("Failed sgx_read_rand (%x)", rt); return rt; } initialized=true; } char zeroes[buffer_size]; // We don't bother initializing to zeroes since AES_CTR just adds the PRB_stream to the buffer // Use AES CTR to populate random_bytes rt = sgx_aes_ctr_encrypt(random_seed, (const uint8_t*) zeroes, buffer_size, (uint8_t*) counter, CTR_INC_BITS, random_bytes); *(uint64_t*)counter += 1; if(rt!=SGX_SUCCESS){ printf("Failed sgx_aes_ctr_encrypt (%x) in init_getRandomBytes\n", rt); return rt; } random_bytes_left = PRB_BUFFER_SIZE; random_bytes_ptr = random_bytes; return rt; } sgx_status_t PRB_buffer::getRandomBytes(unsigned char *buffer, size_t size) { sgx_status_t rt = SGX_SUCCESS; if(initialized==false) init_PRB_buffer(); if(size < random_bytes_left) { // Supply buffer with random bytes from random_bytes memcpy(buffer, random_bytes_ptr, size); random_bytes_ptr+=size; random_bytes_left-= size; return rt; } else { // Consume all the random bytes we have left unsigned char *ptr = buffer; size_t size_left_for_req = size - random_bytes_left; memcpy(ptr, random_bytes_ptr, random_bytes_left); ptr+= random_bytes_left; // Use AES CTR to populate random_bytes rt = 
sgx_aes_ctr_encrypt(random_seed, (const uint8_t*) random_bytes, PRB_BUFFER_SIZE, (uint8_t*) counter, CTR_INC_BITS, random_bytes); *(uint64_t*)counter += 1; if(rt!=SGX_SUCCESS){ printf("Failed sgx_aes_ctr_encrypt (%x)", rt); return rt; } random_bytes_left = PRB_BUFFER_SIZE; random_bytes_ptr = random_bytes; // Add size_left_for_req random bytes to the buffer memcpy(ptr, random_bytes_ptr, size_left_for_req); random_bytes_ptr+=size_left_for_req; random_bytes_left-=size_left_for_req; return rt; } } /* sgx_status_t PRB_buffer::getBulkRandomBytes(unsigned char *buffer, size_t size) { sgx_status_t rt = SGX_SUCCESS; rt = sgx_aes_ctr_encrypt(random_seed, (const uint8_t*) buffer, size, (uint8_t*) counter, CTR_INC_BITS, buffer); *(uint64_t*)counter += 1; if(rt!=SGX_SUCCESS){ printf("Failed sgx_aes_ctr_encrypt (%x) in getBulkRandomBytes [%p %p %lu %p %d %p]\n", rt, random_seed, (const uint8_t*) buffer, size, (uint8_t*) counter, CTR_INC_BITS, buffer); return rt; } return rt; } sgx_status_t initialize_BRB() { sgx_status_t rt = SGX_SUCCESS; rt = sgx_read_rand((unsigned char*) bulk_random_seed, SGX_AESCTR_KEY_SIZE); if(rt!=SGX_SUCCESS){ printf("initialize_BRB(): Failed sgx_read_rand (%x)", rt); return rt; } rt = sgx_read_rand((unsigned char*) bulk_counter, SGX_AESCTR_KEY_SIZE); if(rt!=SGX_SUCCESS){ printf("initialize_BRB(): Failed sgx_read_rand (%x)", rt); return rt; } bulk_initialized = true; return rt; } sgx_status_t getBulkRandomBytes(unsigned char *buffer, size_t size) { if(bulk_initialized == false){ initialize_BRB(); } sgx_status_t rt = SGX_SUCCESS; rt = sgx_aes_ctr_encrypt(bulk_random_seed, (const uint8_t*) buffer, size, (uint8_t*) bulk_counter, CTR_INC_BITS, buffer); if(rt!=SGX_SUCCESS){ printf("getBulkRandomBytes: Failed sgx_aes_ctr_encrypt (%x) in getBulkRandomBytes [%p %p %lu %p %d %p]\n", rt, bulk_random_seed, (const uint8_t*) buffer, size, (uint8_t*) bulk_counter, CTR_INC_BITS, buffer); return rt; } return rt; } */ #else sgx_status_t getRandomBytes(unsigned char 
*random_bytes, size_t size) { sgx_status_t rt = SGX_SUCCESS; rt = sgx_read_rand((unsigned char*) random_bytes, size); return rt; } #endif unsigned char* compare_keys(unsigned char *packet_1, unsigned char *packet_2){ if( *((uint64_t*)(packet_1)) < *((uint64_t*)(packet_2))){ return packet_1; } else { return packet_2; } } void merge(unsigned char *data, size_t data_size, size_t l, size_t m, size_t r, unsigned char* (*comparator)(unsigned char*, unsigned char*)){ uint64_t i=0, j=0, k=0; size_t s1, s2; s1 = l+(m-l+1); s2 = (m+1)+(r-m); //unsigned char merged_array[(r-l+1)*data_size]; unsigned char *merged_array = (unsigned char*) malloc((r-l+1)*data_size); i = l; j = m+1; k = 0; while (i < s1 && j < s2) { unsigned char *smaller_pkt = comparator(data+(i*data_size), data+(j*data_size)); if(smaller_pkt == data+(i*data_size)){ memcpy(merged_array+(k*data_size), smaller_pkt, data_size); i++; } else{ memcpy(merged_array+(k*data_size), smaller_pkt, data_size); j++; } k++; } while (i < s1) { memcpy(merged_array + (k*data_size), data+(i*data_size), data_size); i++; k++; } while (j < s2) { memcpy(merged_array + (k*data_size), data+(j*data_size), data_size); j++; k++; } memcpy(data+(l*data_size), merged_array, data_size * ((r-l)+1)); free(merged_array); } void mergeSort(unsigned char *data, size_t data_size, size_t start_index, size_t end_index, unsigned char* (*comparator)(unsigned char*, unsigned char*)){ if(start_index < end_index){ size_t m = start_index + (end_index-start_index)/2; mergeSort(data, data_size, start_index, m, comparator); mergeSort(data, data_size, m+1, end_index, comparator); merge(data, data_size, start_index, m , end_index, comparator); } } void mergeSort_OPRM(unsigned char *data, size_t data_size, size_t start_index, size_t end_index, unsigned char* (*comparator)(unsigned char*, unsigned char*)){ if(start_index < end_index){ size_t m = start_index + (end_index-start_index)/2; mergeSort(data, data_size, start_index, m, comparator); mergeSort(data, 
data_size, m+1, end_index, comparator); merge(data, data_size, start_index, m , end_index, comparator); } } #if 0 //Tight Compaction and Expansion utility functions for testing if a Block is real/dummy uint8_t isBlockReal_16(unsigned char *block_ptr) { uint16_t label = *((uint16_t *)(block_ptr)); return (label==UINT16_MAX); } uint8_t isBlockReal_32(unsigned char *block_ptr) { uint32_t label = *((uint32_t *)(block_ptr)); return (label==UINT32_MAX); } uint8_t isBlockReal_64(unsigned char *block_ptr) { uint64_t label = *((uint64_t *)(block_ptr)); return (label==UINT64_MAX); } void oswap_buffer(unsigned char *dest, unsigned char *source, uint32_t buffer_size, uint8_t flag){ #ifdef COUNT_OSWAPS uint64_t *ltvp = &OSWAP_COUNTER; FOAV_SAFE2_CNTXT(oswap_buffer, buffer_size, *ltvp) OSWAP_COUNTER++; #endif if(buffer_size%16==0){ oswap_buffer_16x(dest, source, buffer_size, flag); } else if(buffer_size==8){ oswap_buffer_byte(dest, source, buffer_size, flag); } else{ oswap_buffer_byte(dest, source, 8, flag); oswap_buffer_16x(dest+8, source+8, buffer_size-8, flag); } } uint8_t isCorrect16x(uint32_t block_size){ printf("Entered Correctness Tester!!!\n"); bool is_correct = true; unsigned char *b1 = new unsigned char[block_size]; unsigned char *b2 = new unsigned char[block_size]; unsigned char *b3 = new unsigned char[block_size]; unsigned char *b4 = new unsigned char[block_size]; getBulkRandomBytes(b1, block_size); getBulkRandomBytes(b2, block_size); memcpy(b3, b1, block_size); memcpy(b4, b2, block_size); bool swap_flag = false; oswap_buffer(b1, b2, block_size, swap_flag); if(memcmp(b1, b3, block_size)){ is_correct=false; printf("Failed Test 1\n"); } if(memcmp(b2, b4, block_size)){ is_correct=false; printf("Failed Test 2\n"); } memcpy(b1, b3, block_size); memcpy(b2, b4, block_size); swap_flag = true; oswap_buffer(b1, b2, block_size, swap_flag); if(memcmp(b1, b4, block_size)){ is_correct=false; printf("Failed Test 3\n"); } if(memcmp(b2, b3, block_size)){ is_correct=false; 
printf("Failed Test 4\n"); } delete []b1; delete []b2; delete []b3; delete []b4; if(is_correct){ printf("Correctness test SUCCESS! \n"); return true; } return false; } uint8_t isCorrect8_16x(uint32_t block_size){ printf("Entered Correctness Tester!!!\n"); bool is_correct = true; unsigned char *b1 = new unsigned char[block_size]; unsigned char *b2 = new unsigned char[block_size]; unsigned char *b3 = new unsigned char[block_size]; unsigned char *b4 = new unsigned char[block_size]; getBulkRandomBytes(b1, block_size); getBulkRandomBytes(b2, block_size); memcpy(b3, b1, block_size); memcpy(b4, b2, block_size); bool swap_flag = false; oswap_buffer(b1, b2, block_size, swap_flag); if(memcmp(b1, b3, block_size)){ is_correct=false; printf("Failed Test 1\n"); } if(memcmp(b2, b4, block_size)){ is_correct=false; printf("Failed Test 2\n"); } memcpy(b1, b3, block_size); memcpy(b2, b4, block_size); swap_flag = true; oswap_buffer(b1, b2, block_size, swap_flag); if(memcmp(b1, b4, block_size)){ is_correct=false; printf("Failed Test 3\n"); } if(memcmp(b2, b3, block_size)){ is_correct=false; printf("Failed Test 4\n"); } delete []b1; delete []b2; delete []b3; delete []b4; if(is_correct){ printf("Correctness test SUCCESS! \n"); return true; } return false; } void swapBuckets(unsigned char *bkt1, unsigned char *bkt2, unsigned char *temp_bucket, size_t bucket_size) { memcpy(temp_bucket, bkt2, bucket_size); memcpy(bkt2, bkt1, bucket_size); memcpy(bkt1, temp_bucket, bucket_size); } #endif /*** Thread pool implementation ***/ /* Implements a restricted-model thread pool. The restriction is that * every thread is the "parent" of a number of other threads (and no * thread has more than one parent). Each thread can be dispatched and * joined only by its parent, so there's no contention on the dispatch * and join inter-thread communication. A parent thread has to specify * the exact thread id of the child thread it dispatches work to.
*/ thread_local threadid_t g_thread_id = 0; enum threadstate_t { THREADSTATE_NONE, THREADSTATE_WAITING, THREADSTATE_DISPATCHING, THREADSTATE_WORKING, THREADSTATE_TERMINATE }; struct threadblock_t { threadid_t threadid; threadstate_t state; pthread_t thread_handle; pthread_mutex_t mutex; pthread_cond_t dispatch_cond; void *(*dispatch_func)(void *data); void *dispatch_data; pthread_cond_t join_cond; void *ret_data; #ifdef COUNT_OSWAPS size_t num_oswaps; #endif }; static threadblock_t *threadpool_control_blocks = NULL; static threadid_t threadpool_numthreads = 0; /* The main thread loop */ static void* threadloop(void *vdata) { threadblock_t *block = (threadblock_t *)vdata; /* Initialize any per-thread state */ g_thread_id = block->threadid; PRB_rand_bits = 0; PRB_rand_bits_remaining = 0; pthread_mutex_lock(&block->mutex); while(1) { /* Wait for work */ block->state = THREADSTATE_WAITING; pthread_cond_wait(&block->dispatch_cond, &block->mutex); if (block->state == THREADSTATE_TERMINATE) { break; } /* Do the work */ block->state = THREADSTATE_WORKING; pthread_mutex_unlock(&block->mutex); block->ret_data = (block->dispatch_func)(block->dispatch_data); #ifdef COUNT_OSWAPS /* Account for the oswaps done in this thread */ block->num_oswaps = OSWAP_COUNTER; OSWAP_COUNTER = 0; #endif /* Signal the parent thread that we're done, and loop back to * wait for more work. */ pthread_mutex_lock(&block->mutex); pthread_cond_signal(&block->join_cond); } block->state = THREADSTATE_NONE; pthread_mutex_unlock(&block->mutex); return NULL; } /* Create the threadpool, with numthreads-1 additional threads (numbered * 1 through numthreads-1) in addition to the current "main" thread * (numbered 0). Returns 0 on success, -1 on failure. It is allowed, but * not very useful, to pass 1 here. 
*/ int threadpool_init(threadid_t numthreads) { g_thread_id = 0; PRB_rand_bits = 0; PRB_rand_bits_remaining = 0; if (numthreads < 1) { return -1; } else if (numthreads == 1) { threadpool_numthreads = 1; return 0; } /* We don't actually create a thread control block for the main * thread 0, so the internal indexing into this array will be that * thread i's control block lives at index i-1 in this array. */ threadpool_control_blocks = new threadblock_t[numthreads-1]; if (threadpool_control_blocks == NULL) { return -1; } threadpool_numthreads = numthreads; /* Init each thread control block */ bool thread_create_failure = false; for (threadid_t i = 0; i < numthreads-1; ++i) { threadblock_t *block = threadpool_control_blocks + i; block->threadid = i+1; block->state = THREADSTATE_NONE; pthread_mutex_init(&block->mutex, NULL); pthread_cond_init(&block->dispatch_cond, NULL); pthread_cond_init(&block->join_cond, NULL); block->thread_handle = NULL; int create_ret = pthread_create(&block->thread_handle, NULL, threadloop, block); if (create_ret) { thread_create_failure = true; printf("Failed to launch thread %lu; ret=%d\n", i+1, create_ret); } } if (thread_create_failure) { threadpool_shutdown(); return -1; } return 0; } /* Ask all the threads to terminate, wait for that to happen, and clean * up. */ void threadpool_shutdown() { /* Note that this function may be called when some threads failed to * launch at all in threadpool_init. In that case, the thread field * in the thread's control block will be NULL. The mutex/cond * variables will still have been initialized, however, and need * cleaning. */ if (threadpool_numthreads == 0) { /* Nothing to do */ return; } if (threadpool_numthreads == 1) { /* Almost nothing to do */ threadpool_numthreads = 0; return; } for (threadid_t i=0;imutex); if (block->state == THREADSTATE_WORKING) { /* There's a thread actively running? Wait for it to * finish. 
*/ pthread_mutex_unlock(&block->mutex); threadpool_join(i+1, NULL); pthread_mutex_lock(&block->mutex); } if (block->state == THREADSTATE_WAITING) { /* Tell the thread to exit */ block->state = THREADSTATE_TERMINATE; pthread_mutex_unlock(&block->mutex); pthread_cond_signal(&block->dispatch_cond); pthread_join(block->thread_handle, NULL); block->thread_handle = NULL; } if (block->state != THREADSTATE_NONE) { printf("Unexpected state on thread %lu during shutdown: %u\n", i+1, block->state); pthread_cond_destroy(&block->dispatch_cond); pthread_cond_destroy(&block->join_cond); pthread_mutex_destroy(&block->mutex); } } delete[] threadpool_control_blocks; threadpool_control_blocks = NULL; threadpool_numthreads = 0; } /* Dispatch some work to a particular thread in the thread pool. */ void threadpool_dispatch(threadid_t threadid, void *(*func)(void*), void *data) { threadblock_t *block = threadpool_control_blocks + (threadid-1); pthread_mutex_lock(&block->mutex); if (block->state != THREADSTATE_WAITING) { printf("Thread %lu not in expected WAITING state: %u\n", threadid, block->state); pthread_mutex_unlock(&block->mutex); return; } block->dispatch_func = func; block->dispatch_data = data; block->state = THREADSTATE_DISPATCHING; pthread_mutex_unlock(&block->mutex); /* Tell the thread there's work to do */ pthread_cond_signal(&block->dispatch_cond); } /* Join a thread */ void threadpool_join(threadid_t threadid, void **resp) { threadblock_t *block = threadpool_control_blocks + (threadid-1); pthread_mutex_lock(&block->mutex); /* Did the thread finish already? 
*/ if (block->state == THREADSTATE_DISPATCHING || block->state == THREADSTATE_WORKING) { /* Wait until the thread completes */ pthread_cond_wait(&block->join_cond, &block->mutex); } else if (block->state != THREADSTATE_WAITING) { printf("Thread %lu in unexpected state (not WORKING or WAITING) on join: %u\n", threadid, block->state); } if (resp) { *resp = block->ret_data; } #ifdef COUNT_OSWAPS uint64_t *ltvp = &OSWAP_COUNTER; FOAV_SAFE_CNTXT(oswap_buffer, *ltvp) OSWAP_COUNTER += block->num_oswaps; block->num_oswaps = 0; #endif pthread_mutex_unlock(&block->mutex); }