RecursiveShuffle.cpp 9.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266
  1. #ifndef BEFTS_MODE
  2. #include <array>
  3. #include <sgx_tcrypto.h>
  4. #include "oasm_lib.h"
  5. #include "utils.hpp"
  6. #include "RecursiveShuffle.hpp"
  7. #endif
  // Byte size of random_bytes_buffer when the RS_M2_MEM_OPT1 path caps the
  // allocation at RS_RB_BUFFER_LIMIT; written by the functions that allocate it.
  size_t RS_RB_BUFFER_SIZE = 0;

  // Pre-generated random bytes for the recursive shuffle. Starts out null.
  // NOTE(review): in the visible part of this file these globals are only
  // written by code under #if 0; presumably the inner shuffle templates read
  // them — confirm before removing.
  unsigned char *random_bytes_buffer = nullptr;

  // The start of random_bytes_buffer viewed as 32-bit words (set right after
  // allocation); presumably advanced by the inner shuffle as words are consumed.
  uint32_t *random_bytes_buffer_ptr = nullptr;

  // One-past-the-end of random_bytes_buffer viewed as 32-bit words; only set
  // on the RS_M2_MEM_OPT1 path.
  uint32_t *random_bytes_buffer_ptr_end = nullptr;
  12. /*
  13. MarkHalf: Marks half of the elements of an N sized array randomly.
  14. Pass in a bool array of size N, which will be populated with 1's at indexes which
  15. r get marked by MarkHalf
  16. NOTE: MarkHalf assumes selected_list is initialized to all 0's before passed to MarkHalf
  17. */
  18. void MarkHalf(uint64_t N, bool *selected_list) {
  19. uint64_t left_to_mark = N/2;
  20. uint64_t total_left = N;
  21. uint32_t coins[RS_MARKHALF_MAX_COINS];
  22. size_t coinsleft=0;
  23. FOAV_SAFE_CNTXT(MarkHalf_marking_half, N)
  24. for(uint64_t i=0; i<N; i++){
  25. FOAV_SAFE2_CNTXT(MarkHalf_marking_half, i, coinsleft)
  26. if (coinsleft == 0) {
  27. size_t numcoins = (N-i);
  28. FOAV_SAFE_CNTXT(MarkHalf_marking_half, numcoins)
  29. if (numcoins > RS_MARKHALF_MAX_COINS) {
  30. numcoins = RS_MARKHALF_MAX_COINS;
  31. }
  32. PRB_buf.getRandomBytes((unsigned char *) coins,
  33. sizeof(coins[0])*numcoins);
  34. coinsleft = numcoins;
  35. }
  36. //Mark with probability left_to_mark/total_left;
  37. uint32_t random_coin;
  38. random_coin = (total_left * coins[--coinsleft]) >> 32;
  39. uint32_t mark_threshold = total_left - left_to_mark;
  40. uint8_t mark_element = oge_set_flag(random_coin, mark_threshold);
  41. //If mark_element, obliviously set selected_list[i] to 1
  42. FOAV_SAFE_CNTXT(MarkHalf_marking_half, i)
  43. selected_list[i] = mark_element;
  44. left_to_mark-= mark_element;
  45. total_left--;
  46. FOAV_SAFE2_CNTXT(MarkHalf_marking_half, i, N)
  47. }
  48. }
  #if 0
  #ifndef BEFTS_MODE
  /*
   * RecursiveShuffle_M1: compiled out by #if 0; kept for reference.
   *
   * Pre-generates random bytes into the global random_bytes_buffer. The size is
   * calculatelog2(N) * N 32-bit words, capped at RS_RB_BUFFER_LIMIT bytes when
   * RS_M2_MEM_OPT1 is defined. random_bytes_buffer_ptr is pointed at the start
   * of the buffer.
   *
   * It then dispatches to RecursiveShuffle_M1_inner with an OSWAP style chosen
   * from block_size: 4 -> OSWAP_4, 8 -> OSWAP_8, a multiple of 16 -> OSWAP_16X,
   * anything else -> OSWAP_8_16X.
   *
   * The zero-initialised bool[N] scratch array is local to this call. If an
   * allocation fails, an error is printed, anything already allocated is
   * released, and the function returns without shuffling.
   */
  void RecursiveShuffle_M1(unsigned char *buf, uint64_t N, size_t block_size) {
    FOAV_SAFE2_CNTXT(RS_M1, N, block_size)
    size_t num_random_bytes = calculatelog2(N) * N * sizeof(uint32_t);
    // Both buffers start null so the failure path can free exactly what was
    // allocated, and never a stale pointer left in the global by an earlier call.
    bool *selected_list = nullptr;
    random_bytes_buffer = nullptr;
  #ifdef RS_M2_MEM_OPT1
    FOAV_SAFE2_CNTXT(RS_M1, num_random_bytes, RS_RB_BUFFER_LIMIT)
    if (num_random_bytes > RS_RB_BUFFER_LIMIT) {
      RS_RB_BUFFER_SIZE = RS_RB_BUFFER_LIMIT;
    } else {
      RS_RB_BUFFER_SIZE = num_random_bytes;
    }
    size_t alloc_bytes = RS_RB_BUFFER_SIZE;
  #else
    size_t alloc_bytes = num_random_bytes;
  #endif
    try {
      random_bytes_buffer = new unsigned char[alloc_bytes];
      //FOAV_SAFE_CNTXT(RS_M1_initializing_selected_list, N)
      selected_list = new bool[N]{};
    } catch (std::bad_alloc&) {
      printf("Allocating memory failed in RS_M1\n");
      // Continuing would hand null/invalid buffers to the inner shuffle.
      delete[] random_bytes_buffer;
      random_bytes_buffer = nullptr;
      return;
    }
    getBulkRandomBytes(random_bytes_buffer, alloc_bytes);
  #ifdef RS_M2_MEM_OPT1
    random_bytes_buffer_ptr_end = reinterpret_cast<uint32_t *>(random_bytes_buffer + RS_RB_BUFFER_SIZE);
  #endif
    random_bytes_buffer_ptr = reinterpret_cast<uint32_t *>(random_bytes_buffer);
    FOAV_SAFE_CNTXT(RS_M1_branching_on_block_size_for_OSwap_Style_templates, block_size)
    if (block_size == 4) {
      FOAV_SAFE_CNTXT(RS_M1_branching_on_block_size_for_OSwap_Style_templates, block_size)
      RecursiveShuffle_M1_inner<OSWAP_4>(buf, N, block_size, selected_list);
      FOAV_SAFE_CNTXT(RS_M1_branching_on_block_size_for_OSwap_Style_templates, block_size)
    } else if (block_size == 8) {
      FOAV_SAFE_CNTXT(RS_M1_branching_on_block_size_for_OSwap_Style_templates, block_size)
      RecursiveShuffle_M1_inner<OSWAP_8>(buf, N, block_size, selected_list);
      FOAV_SAFE_CNTXT(RS_M1_branching_on_block_size_for_OSwap_Style_templates, block_size)
    } else if (block_size % 16 == 0) {
      FOAV_SAFE_CNTXT(RS_M1_branching_on_block_size_for_OSwap_Style_templates, block_size)
      RecursiveShuffle_M1_inner<OSWAP_16X>(buf, N, block_size, selected_list);
      FOAV_SAFE_CNTXT(RS_M1_branching_on_block_size_for_OSwap_Style_templates, block_size)
    } else {
      FOAV_SAFE_CNTXT(RS_M1_branching_on_block_size_for_OSwap_Style_templates, block_size)
      RecursiveShuffle_M1_inner<OSWAP_8_16X>(buf, N, block_size, selected_list);
      FOAV_SAFE_CNTXT(RS_M1_branching_on_block_size_for_OSwap_Style_templates, block_size)
    }
    FOAV_SAFE_CNTXT(RecursiveShuffle_M1_delete, random_bytes_buffer)
    delete[] random_bytes_buffer;
    // Reset the global so no later call can see (or double-free) a dangling pointer.
    random_bytes_buffer = nullptr;
    FOAV_SAFE_CNTXT(RecursiveShuffle_M1_delete, selected_list)
    delete[] selected_list;
  }
  #endif
  #endif
  106. void RecursiveShuffle_M2(unsigned char *buf, uint64_t N, size_t block_size){
  107. RecursiveShuffle_M2_parallel(buf, N, block_size, 1);
  108. }
  109. void RecursiveShuffle_M2_parallel(unsigned char *buf, uint64_t N, size_t block_size, size_t nthreads){
  110. FOAV_SAFE2_CNTXT(RS_M2, N, block_size)
  111. bool *selected_list;
  112. try {
  113. selected_list = new bool[N]{};
  114. } catch (std::bad_alloc&){
  115. printf("Allocating memory failed in RS_M2\n");
  116. }
  117. threadpool_init(nthreads);
  118. FOAV_SAFE_CNTXT(RS_M2_branching_on_block_size_for_OSwap_Style_templates, block_size)
  119. if(block_size==4){
  120. RecursiveShuffle_M2_inner_parallel<OSWAP_4>(buf, N, block_size, selected_list, nthreads);
  121. } else if(block_size==8){
  122. RecursiveShuffle_M2_inner_parallel<OSWAP_8>(buf, N, block_size, selected_list, nthreads);
  123. } else if(block_size%16==0) {
  124. RecursiveShuffle_M2_inner_parallel<OSWAP_16X>(buf, N, block_size, selected_list, nthreads);
  125. } else {
  126. RecursiveShuffle_M2_inner_parallel<OSWAP_8_16X>(buf, N, block_size, selected_list, nthreads);
  127. }
  128. threadpool_shutdown();
  129. FOAV_SAFE_CNTXT(RecursiveShuffle_M2_delete, selected_list)
  130. delete []selected_list;
  131. }
  132. #if 0
  133. // We maintain a double type return version of RecusiveShuffle_M2,
  134. // to time strictly the RS_M2 component when using it without any encryption or decryption
  135. // We need this only for the BOS optimizer!!
  136. double RecursiveShuffle_M2_opt(unsigned char *buf, uint64_t N, size_t block_size){
  137. FOAV_SAFE2_CNTXT(RS_M2_opt, N, block_size)
  138. //In a single call allocate all the randomness we need here!
  139. size_t num_random_bytes = calculatelog2(N) * N * sizeof(uint32_t);
  140. long t0, t1;
  141. ocall_clock(&t0);
  142. #ifdef RS_M2_MEM_OPT1
  143. if(num_random_bytes > RS_RB_BUFFER_LIMIT) {
  144. RS_RB_BUFFER_SIZE = RS_RB_BUFFER_LIMIT;
  145. }
  146. else{
  147. RS_RB_BUFFER_SIZE = num_random_bytes;
  148. }
  149. try {
  150. random_bytes_buffer = new unsigned char[RS_RB_BUFFER_SIZE];
  151. selected_list = new bool[N]{};
  152. } catch (std::bad_alloc&){
  153. printf("Allocating memory failed in RS_M2\n");
  154. }
  155. getBulkRandomBytes((unsigned char*)random_bytes_buffer, RS_RB_BUFFER_SIZE);
  156. random_bytes_buffer_ptr_end = (uint32_t*)(random_bytes_buffer + RS_RB_BUFFER_SIZE);
  157. #else
  158. try {
  159. random_bytes_buffer = new unsigned char[num_random_bytes];
  160. selected_list = new bool[N]{};
  161. } catch (std::bad_alloc&){
  162. printf("Allocating memory failed in RS_M2\n");
  163. }
  164. getBulkRandomBytes((unsigned char*)random_bytes_buffer, num_random_bytes);
  165. #endif
  166. random_bytes_buffer_ptr = (uint32_t*) random_bytes_buffer;
  167. FOAV_SAFE_CNTXT(RS_M2_opt, num_random_bytes)
  168. FOAV_SAFE2_CNTXT(RS_M2_opt, N, block_size)
  169. FOAV_SAFE_CNTXT(RS_M2_opt, block_size)
  170. if(block_size==4){
  171. RecursiveShuffle_M2_inner<OSWAP_4>(buf, N, block_size, selected_list);
  172. } else if(block_size==8){
  173. RecursiveShuffle_M2_inner<OSWAP_8>(buf, N, block_size, selected_list);
  174. } else if(block_size%16==0) {
  175. RecursiveShuffle_M2_inner<OSWAP_16X>(buf, N, block_size, selected_list);
  176. } else {
  177. RecursiveShuffle_M2_inner<OSWAP_8_16X>(buf, N, block_size, selected_list);
  178. }
  179. delete []random_bytes_buffer;
  180. delete []selected_list;
  181. ocall_clock(&t1);
  182. double ptime = ((double)(t1-t0))/1000.0;
  183. return ptime;
  184. }
  185. #ifndef BEFTS_MODE
  186. double DecryptAndShuffleM1(unsigned char *encrypted_buffer, size_t N, size_t encrypted_block_size, unsigned char *result_buffer, enc_ret *ret) {
  187. // Decrypt buffer to decrypted_buffer
  188. unsigned char *decrypted_buffer = NULL;
  189. size_t decrypted_block_size = decryptBuffer(encrypted_buffer, N, encrypted_block_size, &decrypted_buffer);
  190. long t0, t1;
  191. ocall_clock(&t0);
  192. // ShuffleM1 on decrypted_buffer
  193. PRB_pool_init(1);
  194. RecursiveShuffle_M1(decrypted_buffer, N, decrypted_block_size);
  195. ocall_clock(&t1);
  196. // Encrypt buffer to result_buffer
  197. encryptBuffer(decrypted_buffer, N, decrypted_block_size, result_buffer);
  198. PRB_pool_shutdown();
  199. free(decrypted_buffer);
  200. double ptime = ((double)(t1-t0))/1000.0;
  201. ret->OSWAP_count = OSWAP_COUNTER;
  202. ret->ptime = ptime;
  203. return(ptime);
  204. }
  205. #endif
  206. double DecryptAndShuffleM2(unsigned char *encrypted_buffer, size_t N, size_t encrypted_block_size, size_t nthreads, unsigned char *result_buffer, enc_ret *ret) {
  207. // Decrypt buffer to decrypted_buffer
  208. unsigned char *decrypted_buffer = NULL;
  209. size_t decrypted_block_size = decryptBuffer(encrypted_buffer, N, encrypted_block_size, &decrypted_buffer);
  210. long t0, t1;
  211. ocall_clock(&t0);
  212. // ShuffleM2 on decrypted_buffer
  213. PRB_pool_init(nthreads);
  214. RecursiveShuffle_M2_parallel(decrypted_buffer, N, decrypted_block_size, nthreads);
  215. ocall_clock(&t1);
  216. // Encrypt buffer to result_buffer
  217. encryptBuffer(decrypted_buffer, N, decrypted_block_size, result_buffer);
  218. PRB_pool_shutdown();
  219. #ifdef TIME_MARKHALF
  220. printf("Time taken in MarkHalf calls = %f\n", MARKHALF_TIME);
  221. #endif
  222. free(decrypted_buffer);
  223. double ptime = ((double)(t1-t0))/1000.0;
  224. ret->OSWAP_count = OSWAP_COUNTER;
  225. ret->ptime = ptime;
  226. return(ptime);
  227. }
  228. #endif