WaksmanNetwork.cpp 45 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259
  1. #include "SortingNetwork.hpp"
  2. #include "WaksmanNetwork.hpp"
  3. #include "oasm_lib.h"
  4. // Count the number of input and output switches, and the number of
  5. // WaksmanSubnetworks, used to handle N items. Add the numbers to the
  6. // numInSwitches, numOutSwitches, and numSubnetworks parameters.
  7. static void countSwitches(uint32_t N, size_t &numInSwitches,
  8. size_t &numOutSwitches, size_t &numSubnetworks)
  9. {
  10. ++numSubnetworks;
  11. // Base cases
  12. FOAV_SAFE_CNTXT(countswitches, N)
  13. if (N < 2) {
  14. return;
  15. } else if (N == 2) {
  16. ++numOutSwitches;
  17. return;
  18. }
  19. // How many switches do we use ourselves?
  20. // If N is even, we use (N/2)-1 input and N/2 output switches
  21. // If N is odd, we use (N-1)/2 input and (N-1)/2 output switches
  22. // Note that with integer division, both cases can be handled by
  23. // computing (N-1)/2 input and N/2 output switches.
  24. numInSwitches += (N-1)/2;
  25. numOutSwitches += N/2;
  26. // Then recurse into the two children. If N is even, we divide in
  27. // half. If N is odd, the left child will have the extra entry.
  28. countSwitches((N+1)/2, numInSwitches, numOutSwitches, numSubnetworks);
  29. countSwitches(N/2, numInSwitches, numOutSwitches, numSubnetworks);
  30. }
  31. WaksmanNetwork::WaksmanNetwork(uint32_t N) : Ntotal(N) {
  32. size_t numInSwitches = 0, numOutSwitches = 0, numSubnetworks = 0;
  33. countSwitches(N, numInSwitches, numOutSwitches, numSubnetworks);
  34. inSwitchVec.resize(numInSwitches);
  35. outSwitchVec.resize(numOutSwitches);
  36. }
  37. /* Intialize data structure counting unselected permutation mappings for fast random selection.
  38. Call initially with empty=true - argument only needed for recursive calls.
  39. */
  40. static inline void initUnselectedCnt(uint32_t *unselected_cnt, uint32_t num_vals, bool empty = true) {
  41. FOAV_SAFE2_CNTXT(countswitches, num_vals, empty)
  42. if (num_vals == 0) { // Check just in case - this should never be called with num_vals == 0.
  43. return;
  44. }
  45. if (empty == true) {
  46. unselected_cnt[num_vals-1] = num_vals;
  47. }
  48. if (num_vals == 1) {
  49. return;
  50. }
  51. uint32_t num_left = (num_vals+1)/2;
  52. initUnselectedCnt(unselected_cnt, num_left, true);
  53. initUnselectedCnt(unselected_cnt+num_left, num_vals - num_left, false);
  54. }
  55. /* Modifies unselected_cnt to indicate item at index has been selected.
  56. Call initially with unadjusted=true - argument only needed for recursive calls.
  57. */
  58. static inline void updateUnselectedCnt(uint32_t *unselected_cnt, uint32_t num_vals, uint32_t index,
  59. bool unadjusted = true) {
  60. FOAV_SAFE2_CNTXT(updateUnselectedCnt, num_vals, unadjusted)
  61. if (num_vals == 0) { // Check just in case - this should never be called with num_vals == 0.
  62. return;
  63. }
  64. FOAV_SAFE2_CNTXT(updateUnselectedCnt, num_vals, unadjusted)
  65. if (unadjusted == true) {
  66. unselected_cnt[num_vals-1]--;
  67. }
  68. FOAV_SAFE2_CNTXT(updateUnselectedCnt, num_vals, unadjusted)
  69. if (num_vals == 1) {
  70. return;
  71. }
  72. uint32_t num_left = (num_vals+1)/2;
  73. FOAV_SAFE2_CNTXT(updateUnselectedCnt, index, num_left)
  74. if (index < num_left) {
  75. updateUnselectedCnt(unselected_cnt, num_left, index, true);
  76. } else {
  77. updateUnselectedCnt(unselected_cnt+num_left, num_vals - num_left, index-num_left, false);
  78. }
  79. }
  80. /* Computes pseudo-random permutation (PRP), __uint128_t -> __uint128_t.
  81. The input is the 128-bit integer with in_high in the top 64 bits and
  82. in_low in the lower 64 bits.
  83. */
  84. static inline __uint128_t prp128(const AESkey &aeskey,
  85. uint64_t in_high, uint64_t in_low) {
  86. __m128i ciphertext;
  87. AES_ECB_encrypt(ciphertext, _mm_set_epi64x(in_high,in_low), aeskey);
  88. return reinterpret_cast<__uint128_t>(ciphertext);
  89. }
  // Debug helper: write x to stdout as 32 lowercase hexadecimal digits, most
  // significant byte first, with no separator and no trailing newline.
  void print_u128(__uint128_t x) {
    // Extract each byte arithmetically rather than walking the object's bytes
    // in memory, so the digit order does not depend on host endianness (the
    // previous byte walk was only correct on little-endian machines). A signed
    // shift counter also avoids the old int-vs-sizeof sign-compare warning.
    for (int shift = 8 * (int) sizeof(__uint128_t) - 8; shift >= 0; shift -= 8) {
      printf("%02x", (unsigned) ((x >> shift) & 0xffu));
    }
  }
  97. /* Look up either (1) permutation mapping corresponding to hash, or (2) random unselected mapping.
  98. Returns index into forward_perm pointing to the mapping looked up.
  99. */
  100. static inline uint32_t permOrRand(uint32_t N, unsigned char *forward_perm, randkey_t hashval, uint32_t *unselected_cnt,
  101. uint8_t rand_flag) {
  102. uint32_t rand_bytes;
  103. uint32_t start = 0;
  104. uint32_t end = N-1;
  105. uint32_t mid;
  106. uint8_t hash_dir;
  107. uint8_t rand_dir;
  108. uint32_t tot_unselected_cnt = unselected_cnt[end];
  109. uint32_t left_unselected_cnt;
  110. uint64_t rand_val;
  111. getRandomBytes((unsigned char *) &rand_bytes, sizeof(uint32_t));
  112. rand_val = tot_unselected_cnt * rand_bytes;
  113. rand_val >>= 32;
  114. while (true) {
  115. FOAV_SAFE_CNTXT(permOrRand, start)
  116. FOAV_SAFE_CNTXT(permOrRand, end)
  117. if (start == end) {
  118. return start;
  119. }
  120. mid = (start+end)/2;
  121. // Compare desired hash value to hash value just after the current midpoint
  122. hash_dir = ogt<randkey_t>((randkey_t *) (forward_perm + ((mid+1)*(sizeof(randkey_t) + 8))),
  123. &hashval);
  124. // Compare random unselected value to number unselected in left half
  125. left_unselected_cnt = unselected_cnt[mid];
  126. rand_dir = ogt_set_flag(left_unselected_cnt, rand_val);
  127. // Pick between hash_dir and rand_dir based on rand_flag
  128. bool f1 = ((1-rand_flag) & hash_dir);
  129. bool f2 = (rand_flag & rand_dir);
  130. FOAV_SAFE_CNTXT(permOrRand, f1)
  131. FOAV_SAFE_CNTXT(permOrRand, f2)
  132. if ((f1 | f2) == 1) {
  133. end = mid;
  134. tot_unselected_cnt = left_unselected_cnt;
  135. } else {
  136. start = mid+1;
  137. tot_unselected_cnt -= left_unselected_cnt;
  138. rand_val -= left_unselected_cnt;
  139. }
  140. }
  141. }
  142. // If this is defined, set it to the smallest N you want to see
  143. // profiling data for
  144. // #define PROFILE_SETPERM_N 32768
  145. // Define this to show the intermediate states of setPermutation
  146. // #define SHOW_SETPERM
  // Return the partner of x across the two halves: x+Nleft when x < Nleft,
  // x-Nleft when x >= Nleft. Computed arithmetically so nothing branches on x.
  static inline uint32_t PARTNER(uint32_t x, uint32_t Nleft)
  {
    // All-ones when x lies in the right half, zero otherwise.
    const uint32_t mask = 0u - (uint32_t) (x >= Nleft);
    return (x + Nleft) - (mask & (Nleft << 1));
  }
  154. // The elements of the permutation array start off as just 32-bit
  155. // integers, where if j = permutation[i], then the item in position i
  156. // will move to position j. This will sort the permutation; that is, it
  157. // will apply the inverse of the given permutation. So if we want to
  158. // apply the given permutation, we will first use the permutation to set
  159. // the control bits of the Waksman network in a way that will sort it,
  160. // and then apply the inverse permutation by applying the Waksman
  161. // switches in reverse order.
  162. // The strategy of setPermutation is as follows. The invariant is that
  163. // we are given as input a permutation of 0..2k-1, and we will set the
  164. // Waksman network control bits to output the sorted list 0..2k-1. (If
  165. // we are given an input of odd length, so a permutation of 0..2k-2, we
  166. // implicitly append an entry permutation[2k-1] = 2k-1 to it.) We then
  167. // find a setting of the k-1 input switches (switch i OSWAPs
  168. // permutation[i] with permutation[i+k] for i=0..k-2; permutation[k-1]
  169. // never gets swapped, whether or not there was a permutation[2k-1] in
  170. // the original input) such that permutation[0..k-1] mod k ends up being
  171. // a permutation of 0..k-1, and permutation[k..2k-1] mod k ends up being
  172. // a permutation of 0..k-1. (If we were given an odd input initially,
  173. // then it will necessarily be the case that permutation[2k-1] = 2k-1
  174. // and so permutation[2k-1] mod k = k-1, so permutation[k..2k-2] mod k
  175. // will be a permutation of 0..k-2.) We recurse on the left and the
  176. // right, which will set the input and output switches of the
  177. // subnetworks such that, after applying the switches, on the left,
  178. // permutation[0..k-1] mod k will be 0..k-1 (in order), and similarly on
  179. // the right either permutation[k..2k-2] mod k or permutation[k..2k-1]
  180. // mod k, depending on the input size, will be 0..k-2 or 0..k-1
  181. // respectively.
  182. // Then we set the k-1 or k output switches (depending on the length of
  183. // the right side), where switch i again OSWAPs permutation[i] with
  184. // permutation[i+k]. Note that both of these values will necessarily be
  185. // i mod k at this point, so the switch just has to be set to the "high
  186. // bit" of permutation[i]; that is, the bit that is 1 iff
  187. // (permutation[i] >= k). This will yield the desired sorted list.
  188. // Note that when recursing, we only consider the permutation values mod
  189. // k, but we need to remember whether the value v represented the
  190. // original v or the original v+k, so that we can use that bit to set
  191. // output switch v correctly. To keep track of this, when we recurse,
  192. // for each v = permutation[i] in the array, we attach to it a stack of
  193. // the "high bits" it's gone through so far (initially empty). At
  194. // recursive depth d (the initial call is d=0), we have the values in
  195. // the permutation array being (v, [b_0, ..., b_{d-1}]) where v is an
  196. // integer 0 <= v <= 2k-1, and each b_i is a bit. When we recurse to
  197. // depth d+1, we push the new high bit onto the stack (the top of the
  198. // stack is on the right in this notation), to yield (v mod k, [b_0,
  199. // ..., b_{d-1}, b_d]). The recursive call then uses the v mod k values,
  200. // which, as above, will be a permutation of 0..k-1. When the
  201. // recursions finish, the topmost high bit on the stack will be popped
  202. // off to yield (v mod k + b_d*k (= v), [b_0, ..., b_{d-1}]).
  203. // The way we actually internally represent the value at depth d
  204. // (v, [b_0, ..., b_{d-1}]) is by packing that into a single integer,
  205. // with v followed by the d bits: x = v<<d | b_0<<(d-1) | ... | b_{d-1}.
  206. // For example, suppose initially we have N=14, and v = permutation[i] =
  207. // 12. Then the initial representation of v (with depth d=0) is just
  208. // x = v = 12.
  209. // At the first level, k=7, so when we recurse, (12, []) will become
  210. // x = (12 mod 7, [(12 >= 7)]) = (5, [1]) for d=1, which we represent as
  211. // [101][1] in binary (brackets for clarity only) = 11. At the next
  212. // level, k=4 and d=2, so x = (5 mod 4, [1, (5 >= 4)]) = (1, [1,1]),
  213. // which we represent as [1][11] = 7. At the next level (suppose this
  214. // entry ends up in the left recursion), k=2 and d=3, so x = (1 mod 2,
  215. // [1, 1, (1 >= 2)]) = (1, [1,1,0]), which we represent as [1][110] =
  216. // 14. When k<4, there are no more recursive calls. As each layer of
  217. // recursion ends, at k=2 and d=3, 14 = [1][110] becomes [(1+0*2)][11] =
  218. // [1][11] = 7. At k=4 and d=2, [1][11] becomes [(1+1*4)][1] = [101][1]
  219. // = 11. At k=7 and d=1, [101][1] becomes [(5+1*7)][] = 12.
  220. // The following functions manipulate this representation. Note that
  221. // they must all be oblivious to x, but need not be to depth or k.
  // Return the value v encoded by representation x at depth d: drop the d
  // low-order stack bits.
  static inline uint32_t GET(uint32_t x, uint32_t depth)
  {
    const uint32_t value = x >> depth;
    return value;
  }
  // Move representation x of a value v (0 <= v < 2k) from depth d to depth
  // d+1: the stored value becomes v mod k and the bit (v >= k) is appended as
  // the new lowest stack bit. Pass kd = k<<d (k will be Nleft). It must stay
  // oblivious to x, so it uses arithmetic instead of branching on x.
  static inline uint32_t PUSH(uint32_t x, uint32_t kd)
  {
    // x = v<<d | s with s < (1<<d), so x >= (k<<d) exactly when v >= k.
    const uint32_t high = (uint32_t) (x >= kd);
    // Subtracting high*(k<<d) replaces v by v - high*k = v mod k; s is untouched.
    const uint32_t reduced = x - high * kd;
    // Append the high bit below the existing stack bits.
    return (reduced << 1) | high;
  }
  // Inverse of PUSH: move representation x of a value between 0 and k-1 at
  // depth d+1 back to depth d (value between 0 and 2k-1) by removing the
  // lowest stack bit b and adding b*k back onto the value. POP(PUSH(x, kd), kd)
  // == x whenever 0 <= x < 2*kd. Pass kd = k<<d (k will be Nleft).
  static inline uint32_t POP(uint32_t x, uint32_t kd)
  {
    // All-ones when the popped bit is 1, zero otherwise (no branch on x).
    const uint32_t mask = 0u - (x & 1u);
    return (x >> 1) + (mask & kd);
  }
  259. /* Input:
  260. permutation: points to array of integers 0, ..., N-1 in some order, indicating i->permutation[i]
  261. Note: This function modifies the input permutation (and actually sorts it).
  262. */
  263. void WaksmanNetwork::setPermutation(uint32_t *permutation) {
  264. FOAV_SAFE_CNTXT(WN_SetPerm, Ntotal)
  265. if (Ntotal > 1) {
  266. WNTraversal traversal(*this);
  267. WNMem mem(*this);
  268. setPermutation(permutation, Ntotal, 0, traversal, mem);
  269. }
  270. }
  271. /* Input:
  272. permutation: points to array of integers 0, ..., N-1 in some order, indicating i->permutation[i]
  273. Note: This function modifies the input permutation (and actually sorts it).
  274. */
  275. void WaksmanNetwork::setPermutation(uint32_t *permutation, uint32_t N,
  276. uint32_t depth, WNTraversal &traversal, const WNMem &mem) {
  277. //printf("Start setPermutation(): N=%d\n", N);
  278. #ifdef SHOW_SETPERM
  279. printf("S");
  280. for(uint32_t i=0;i<N;++i) {
  281. printf(" %2d", permutation[i]);
  282. }
  283. printf("\n ");
  284. for(uint32_t i=0;i<N;++i) {
  285. printf(" %2d", GET(permutation[i], depth));
  286. }
  287. printf("\n");
  288. #endif
  289. // Handle N<=2 as special cases
  290. FOAV_SAFE_CNTXT(setPermutation, N)
  291. if (N < 2) return;
  292. traversal.subnetNumber += 1;
  293. FOAV_SAFE_CNTXT(setPermutation, N)
  294. if (N == 2) {
  295. // Store output switch value
  296. traversal.outSwitches[0] = GET(permutation[0], depth);
  297. //printf("Set outSwitches[0] to %d\n", outSwitches[0]);
  298. // Apply output switch
  299. oswap_buffer<OSWAP_4>((unsigned char *) permutation,
  300. (unsigned char *) (permutation + 1), 4, traversal.outSwitches[0]);
  301. #ifdef SHOW_SETPERM
  302. printf("O");
  303. for(uint32_t i=0;i<N/2;++i) {
  304. printf(" %s", traversal.outSwitches[i] ? " X" : "||");
  305. }
  306. printf("\n");
  307. printf("E");
  308. for(uint32_t i=0;i<N;++i) {
  309. printf(" %2d", permutation[i]);
  310. }
  311. printf("\n ");
  312. for(uint32_t i=0;i<N;++i) {
  313. printf(" %2d", GET(permutation[i],depth));
  314. }
  315. printf("\n");
  316. #endif
  317. traversal.outSwitches += 1;
  318. return;
  319. }
  320. #ifdef PROFILE_SETPERM_N
  321. unsigned long prof_all, prof_before, prof_flt, prof_sflt, prof_unsel, prof_rlt,
  322. prof_setsw, prof_srtsw, prof_appsw, prof_rec1, prof_rec2, prof_outsw;
  323. if (N >= PROFILE_SETPERM_N) {
  324. prof_all = printf_with_rtclock("begin setPermutation N=%u\n", N);
  325. prof_before = printf_with_rtclock("begin before recursion N=%u\n", N);
  326. }
  327. #endif
  328. // The size of the left recursive half. If N is odd, this is the
  329. // larger half
  330. const uint32_t Nleft = (N+1)/2;
  331. // The size of the right recursive half. This is also the number of
  332. // output switches.
  333. const uint32_t Nright = N/2;
  334. // N, rounded up to an even number
  335. const uint32_t Neven = (Nleft<<1);
  336. if (N > 4) {
  337. #ifdef PROFILE_SETPERM_N
  338. if (N >= PROFILE_SETPERM_N) {
  339. prof_flt = printf_with_rtclock("begin forward lookup table N=%u\n", N);
  340. }
  341. #endif
  342. const uint64_t snNum = traversal.subnetNumber;
  343. // Create forward lookup using pseudorandom permutation (PRP)
  344. // Produced as PRP(i)->(i, GET(permutation[i])) sorted by PRP(i)
  345. // Note: i and permutation[i] are represented as uint32_t values to pack into one uint64_t
  346. unsigned char *cur_forward_hash = mem.forward_perm;
  347. uint32_t *cur_forward_map = (uint32_t *) (mem.forward_perm + sizeof(randkey_t));
  348. // Generate key for forward-lookup PRP
  349. __uint128_t forward_perm_hash;
  350. //printf("Creating forward lookup table\n");
  351. for (uint32_t i=0; i<Neven; i++) {
  352. forward_perm_hash = prp128(mem.forward_key, snNum, (uint64_t) i);
  353. FOAV_SAFE_CNTXT(setPermutation, snNum)
  354. FOAV_SAFE_CNTXT(setPermutation, forward_perm_hash)
  355. memcpy(cur_forward_hash, &forward_perm_hash, sizeof(__uint128_t));
  356. cur_forward_hash += sizeof(randkey_t) + 8;
  357. *cur_forward_map = i;
  358. cur_forward_map += 1;
  359. *cur_forward_map = i < N ? GET(permutation[i], depth) : N;
  360. cur_forward_map = (uint32_t *) (cur_forward_hash + sizeof(randkey_t));
  361. }
  362. #ifdef PROFILE_SETPERM_N
  363. if (N >= PROFILE_SETPERM_N) {
  364. printf_with_rtclock_diff(prof_flt, "end forward lookup table N=%u\n", N);
  365. prof_sflt = printf_with_rtclock("begin sort forward lookup table N=%u\n", N);
  366. }
  367. #endif
  368. BitonicSort<FPERM_OSWAP_STYLE, randkey_t>(mem.forward_perm, (size_t) N, sizeof(randkey_t) + 8, true);
  369. // Print forward lookup table
  370. /*
  371. unsigned char *tmp_cur_forward_hash = forward_perm;
  372. uint32_t *tmp_cur_forward_map = (uint32_t *) (forward_perm + sizeof(randkey_t));
  373. __uint128_t tmp_forward_perm_hash;
  374. for (uint32_t i=0; i<N; i++) {
  375. memcpy(&tmp_forward_perm_hash, tmp_cur_forward_hash, sizeof(__uint128_t));
  376. printf("\t (");
  377. print_u128(tmp_forward_perm_hash);
  378. printf(") %d -> %d\n", *tmp_cur_forward_map, *(tmp_cur_forward_map+1));
  379. tmp_cur_forward_hash += sizeof(randkey_t) + 8;
  380. tmp_cur_forward_map = (uint32_t *) (tmp_cur_forward_hash + sizeof(randkey_t));
  381. }
  382. */
  383. #ifdef PROFILE_SETPERM_N
  384. if (N >= PROFILE_SETPERM_N) {
  385. printf_with_rtclock_diff(prof_sflt, "end sort forward lookup table N=%u\n", N);
  386. prof_unsel = printf_with_rtclock("begin unselected count N=%u\n", N);
  387. }
  388. #endif
  389. // Create cumulative count of unselected items
  390. initUnselectedCnt(mem.unselected_cnt, N);
  391. #ifdef PROFILE_SETPERM_N
  392. if (N >= PROFILE_SETPERM_N) {
  393. printf_with_rtclock_diff(prof_unsel, "end unselected count N=%u\n", N);
  394. prof_rlt = printf_with_rtclock("begin reverse lookup table N=%u\n", N);
  395. }
  396. #endif
  397. // Create reverse lookup using hash table
  398. // Maps \pi(i) to i and index of i->\pi(i) in forward_perm
  399. mem.reverse_perm->reserve(N);
  400. // Lookup done on keyed hash of \pi(i) with reverse key
  401. cur_forward_hash = mem.forward_perm;
  402. cur_forward_map = (uint32_t *) (mem.forward_perm + sizeof(randkey_t));
  403. randkey_t reverse_perm_hash;
  404. //printf("Creating reverse-permutation hash table\n");
  405. FOAV_SAFE_CNTXT(setPermutation, Neven)
  406. for (uint32_t i=0; i<Neven; i++) {
  407. FOAV_SAFE_CNTXT(setPermutation, i)
  408. reverse_perm_hash = prp128(mem.reverse_key, snNum, (uint64_t) *(cur_forward_map+1));
  409. FOAV_SAFE_CNTXT(setPermutation, snNum)
  410. FOAV_SAFE_CNTXT(setPermutation, reverse_perm_hash)
  411. std::pair<uint32_t, uint32_t> reverse_val(*cur_forward_map, i);
  412. //printf("Inserting prp128(%d) = ", *(cur_forward_map+1));
  413. //print_u128(reverse_perm_hash);
  414. //printf(" -> (%d, %d)\n", reverse_val.first, reverse_val.second);
  415. mem.reverse_perm->insert(std::make_pair(reverse_perm_hash, reverse_val));
  416. cur_forward_hash += sizeof(randkey_t) + 8;
  417. cur_forward_map = (uint32_t *) (cur_forward_hash + sizeof(randkey_t));
  418. }
  419. #ifdef PROFILE_SETPERM_N
  420. if (N >= PROFILE_SETPERM_N) {
  421. printf_with_rtclock_diff(prof_rlt, "end reverse lookup table N=%u\n", N);
  422. prof_setsw = printf_with_rtclock("begin set switches N=%u\n", N);
  423. }
  424. #endif
  425. // Set input switch values
  426. uint32_t cycle_start = Neven-1; // start of current permutation cycle
  427. uint32_t forward = 0; // item defining switch to set
  428. uint32_t forward_partner; // forward permutation "partner" (i.e. same input switch)
  429. randkey_t forward_partner_hash;
  430. uint32_t perm_idx;
  431. uint32_t forward_partner_map; // permutation map applied to forward partner
  432. uint32_t switch_num;
  433. uint32_t switch_val;
  434. uint32_t forward_partner_map_partner; // "partner" (i.e. same residue class) of forward_partner_map
  435. randkey_t forward_partner_map_partner_hash;
  436. //const uint32_t input_switch_bit = N >> 1; // bit pattern determining input switch partners
  437. //const uint32_t switch_mask = (N-1) >> 1; // mask to compute input switch number via AND
  438. //const uint32_t crp_xor = N >> 1; // bit pattern to compute composite residue partner via XOR
  439. uint32_t *cur_switch = traversal.inSwitches;
  440. uint8_t rand_flag = 0; // Indicate if next forward lookup should be random (due to cycle end)
  441. // Perform first back-and-forth lookups on items Neven-1 and Nleft-1, which have no input switch
  442. //printf("forward = %d\n", Neven-1);
  443. //printf("forward partner = %d\n", Nleft-1);
  444. forward_partner_hash = (randkey_t) prp128(mem.forward_key, snNum, Nleft-1);
  445. FOAV_SAFE_CNTXT(setPermutation, snNum)
  446. FOAV_SAFE_CNTXT(setPermutation, forward_partner_hash)
  447. perm_idx = permOrRand(N, mem.forward_perm, forward_partner_hash, mem.unselected_cnt, 0);
  448. cur_forward_map = (uint32_t *) (mem.forward_perm + perm_idx*(sizeof(randkey_t) + 8) +
  449. sizeof(randkey_t));
  450. forward_partner_map = *(cur_forward_map+1);
  451. //printf("forward_partner_map = %d\n", forward_partner_map);
  452. updateUnselectedCnt(mem.unselected_cnt, N, perm_idx);
  453. forward_partner_map_partner = PARTNER(forward_partner_map, Nleft);
  454. //printf("forward_partner_map_partner = %d\n", forward_partner_map_partner);
  455. forward_partner_map_partner_hash = prp128(mem.reverse_key, snNum, (uint64_t) forward_partner_map_partner);
  456. FOAV_SAFE_CNTXT(setPermutation, snNum)
  457. FOAV_SAFE_CNTXT(setPermutation, forward_partner_map_partner_hash)
  458. //printf("looking up ");
  459. //print_u128(forward_partner_map_partner_hash);
  460. //printf("\n");
  461. std::pair<uint32_t, uint32_t>& reverse_perm_ret = mem.reverse_perm->at(forward_partner_map_partner_hash);
  462. forward = reverse_perm_ret.first;
  463. perm_idx = reverse_perm_ret.second;
  464. updateUnselectedCnt(mem.unselected_cnt, N, perm_idx);
  465. rand_flag = oe_set_flag(forward, cycle_start);
  466. // Perform remaining back-and-forth lookups and input switch settings
  467. for (uint32_t i=0; i<Nleft-1; i++) {
  468. //printf("forward = %d\n", forward);
  469. // Forward map partner (ignored if random lookup)
  470. forward_partner = PARTNER(forward, Nleft);
  471. //printf("forward partner = %d\n", forward_partner);
  472. // Either map forward_partner under permutation or perform random lookup
  473. forward_partner_hash = (randkey_t) prp128(mem.forward_key, snNum, (uint64_t) forward_partner);
  474. FOAV_SAFE_CNTXT(setPermutation, snNum)
  475. FOAV_SAFE_CNTXT(setPermutation, forward_partner_hash)
  476. perm_idx = permOrRand(N, mem.forward_perm, forward_partner_hash, mem.unselected_cnt, rand_flag);
  477. cur_forward_map = (uint32_t *) (mem.forward_perm + perm_idx*(sizeof(randkey_t) + 8) +
  478. sizeof(randkey_t));
  479. forward_partner_map = *(cur_forward_map+1);
  480. //printf("forward_partner_map = %d\n", forward_partner_map);
  481. // update unselected_cnt with forward lookup
  482. updateUnselectedCnt(mem.unselected_cnt, N, perm_idx);
  483. // Write out current switch setting (need to do after potentially random permOrRand lookup)
  484. switch_val = ((*cur_forward_map) >= Nleft); // value of current input switch
  485. //printf("switch_val = %d\n", switch_val);
  486. switch_num = (*cur_forward_map) - switch_val * Nleft; // number of current input switch
  487. //printf("switch_num = %d\n", switch_num);
  488. *cur_switch = (switch_num<<1) | switch_val;
  489. cur_switch++;
  490. // If random, update cycle_start
  491. oset_value_uint32_t(&cycle_start, PARTNER((*cur_forward_map),Nleft), rand_flag);
  492. // Reverse map the residue-class partner
  493. forward_partner_map_partner = PARTNER(forward_partner_map, Nleft);
  494. //printf("forward_partner_map_partner = %d\n", forward_partner_map_partner);
  495. forward_partner_map_partner_hash = prp128(mem.reverse_key, snNum, (uint64_t) forward_partner_map_partner);
  496. FOAV_SAFE_CNTXT(setPermutation, snNum)
  497. FOAV_SAFE_CNTXT(setPermutation, forward_partner_map_partner_hash)
  498. std::pair<uint32_t, uint32_t>& reverse_perm_ret = mem.reverse_perm->at(forward_partner_map_partner_hash);
  499. forward = reverse_perm_ret.first;
  500. perm_idx = reverse_perm_ret.second;
  501. //printf("forward = %d, perm_idx = %d\n", forward, perm_idx);
  502. // Update unselected_cnt with reverse lookup
  503. updateUnselectedCnt(mem.unselected_cnt, N, perm_idx);
  504. // Indicate random lookup needed if cycle start has been reached
  505. rand_flag = 0; // Needed because oe_set_flag() only sets (i.e. doesn't unset)
  506. rand_flag = oe_set_flag(forward, cycle_start);
  507. //printf("rand_flag = %d\n", rand_flag);
  508. }
  509. // Clear reverse lookup for use by any recursive call
  510. mem.reverse_perm->clear();
  511. #ifdef PROFILE_SETPERM_N
  512. if (N >= PROFILE_SETPERM_N) {
  513. printf_with_rtclock_diff(prof_setsw, "end set switches N=%u\n", N);
  514. prof_srtsw = printf_with_rtclock("begin sort switches N=%u\n", N);
  515. }
  516. #endif
  517. // Put switches in order
  518. BitonicSort<OSWAP_4, uint32_t>((unsigned char *) traversal.inSwitches,
  519. (size_t) Nleft-1, 4, true);
  520. #ifdef PROFILE_SETPERM_N
  521. if (N >= PROFILE_SETPERM_N) {
  522. printf_with_rtclock_diff(prof_srtsw, "end sort switches N=%u\n", N);
  523. }
  524. #endif
  525. // Print switches
  526. /*
  527. printf("Switch\tVal\n");
  528. cur_switch = (uint32_t *) inSwitches.data();
  529. for (uint64_t i : inSwitches) {
  530. printf("%d\t%d\n", (*cur_switch)>>1, *(cur_switch)&1);
  531. cur_switch += 1;
  532. }
  533. */
  534. } else {
  535. // N == 3 or N == 4
  536. // If (GET(permutation[0]) & 1) == (GET(permutation[1]) & 1), set
  537. // the switch to 1 (so that permutation[0] and permutation[2] get
  538. // swapped, otherwise 0. The switch setting is actually stored in
  539. // the low bit of inSwitches[0].
  540. traversal.inSwitches[0] = uint64_t((GET(permutation[0],depth) ^
  541. GET(permutation[1],depth) ^ 1) & 1);
  542. }
  543. #ifdef PROFILE_SETPERM_N
  544. if (N >= PROFILE_SETPERM_N) {
  545. prof_appsw = printf_with_rtclock("begin apply switches N=%u\n", N);
  546. }
  547. #endif
  548. #ifdef SHOW_SETPERM
  549. printf("I");
  550. for(uint32_t i=0;i<Nleft-1;++i) {
  551. printf(" %s", (traversal.inSwitches[i]&1) ? " X" : "||");
  552. }
  553. printf("\n");
  554. #endif
  555. // Apply input switches to permutation
  556. uint32_t *cur_switch_val = traversal.inSwitches;
  557. uint32_t kd = Nleft << depth;
  558. FOAV_SAFE_CNTXT(setPermutation, Nleft)
  559. for (uint32_t i=0; i<Nleft-1; i++) {
  560. FOAV_SAFE2_CNTXT(setPermutation, i, Nleft)
  561. permutation[i] = PUSH(permutation[i], kd);
  562. permutation[i+Nleft] = PUSH(permutation[i+Nleft], kd);
  563. oswap_buffer<OSWAP_4>((unsigned char *) (permutation+i),
  564. (unsigned char *) (permutation+Nleft+i), 4, (*cur_switch_val)&1);
  565. cur_switch_val += 1;
  566. }
  567. permutation[Nleft-1] = PUSH(permutation[Nleft-1], kd);
  568. if (N == Neven) {
  569. permutation[2*Nleft-1] = PUSH(permutation[2*Nleft-1], kd);
  570. }
  571. #ifdef PROFILE_SETPERM_N
  572. if (N >= PROFILE_SETPERM_N) {
  573. printf_with_rtclock_diff(prof_appsw, "end apply switches N=%u\n", N);
  574. printf_with_rtclock_diff(prof_before, "end before recursion N=%u\n", N);
  575. prof_rec1 = printf_with_rtclock("begin recursion1 N=%u\n", N);
  576. }
  577. #endif
  578. #ifdef SHOW_SETPERM
  579. printf(" ");
  580. for(uint32_t i=0;i<N;++i) {
  581. printf(" %2d", permutation[i]);
  582. }
  583. printf("\n ");
  584. for(uint32_t i=0;i<N;++i) {
  585. printf(" %2d", GET(permutation[i], depth+1));
  586. }
  587. printf("\n");
  588. #endif
  589. traversal.inSwitches += (Nleft-1);
  590. uint8_t *outSwitch = traversal.outSwitches;
  591. traversal.outSwitches += Nright;
  592. // Recursively set switches of subnetworks and propagate permutation through network
  593. setPermutation(permutation, Nleft, depth+1, traversal, mem);
  594. #ifdef PROFILE_SETPERM_N
  595. if (N >= PROFILE_SETPERM_N) {
  596. printf_with_rtclock_diff(prof_rec1, "end recursion1 N=%u\n", N);
  597. prof_rec2 = printf_with_rtclock("begin recursion2 N=%u\n", N);
  598. }
  599. #endif
  600. setPermutation(permutation + Nleft, Nright, depth+1, traversal, mem);
  601. #ifdef SHOW_SETPERM
  602. printf("R");
  603. for(uint32_t i=0;i<N;++i) {
  604. printf(" %2d", permutation[i]);
  605. }
  606. printf("\n ");
  607. for(uint32_t i=0;i<N;++i) {
  608. printf(" %2d", GET(permutation[i],depth+1));
  609. }
  610. printf("\n");
  611. #endif
  612. #ifdef PROFILE_SETPERM_N
  613. if (N >= PROFILE_SETPERM_N) {
  614. printf_with_rtclock_diff(prof_rec2, "end recursion2 N=%u\n", N);
  615. prof_outsw = printf_with_rtclock("begin output switches N=%u\n", N);
  616. }
  617. #endif
  618. // Store output switch values and apply to permutation values
  619. //printf("Setting output switches\n");
  620. for (uint32_t i=0; i<Nright; i++) {
  621. outSwitch[i] = permutation[i] & 1;
  622. permutation[i] = POP(permutation[i], kd);
  623. permutation[i+Nleft] = POP(permutation[i+Nleft], kd);
  624. //printf("\toutSwitch[%d] = %d\n", i, outSwitch[i]);
  625. oswap_buffer<OSWAP_4>((unsigned char *) (permutation + i),
  626. (unsigned char *) (permutation + Nleft + i), 4, outSwitch[i]);
  627. }
  628. if (N != Neven) {
  629. permutation[Nright] = POP(permutation[Nright], kd);
  630. }
  631. #ifdef PROFILE_SETPERM_N
  632. if (N >= PROFILE_SETPERM_N) {
  633. printf_with_rtclock_diff(prof_outsw, "end output switches N=%u\n", N);
  634. printf_with_rtclock_diff(prof_all, "end setPermutation N=%u\n", N);
  635. }
  636. #endif
  637. #ifdef SHOW_SETPERM
  638. printf("O");
  639. for(uint32_t i=0;i<Nright;++i) {
  640. printf(" %s", outSwitch[i] ? " X" : "||");
  641. }
  642. printf("\n");
  643. printf("E");
  644. for(uint32_t i=0;i<N;++i) {
  645. printf(" %2d", permutation[i]);
  646. }
  647. printf("\n ");
  648. for(uint32_t i=0;i<N;++i) {
  649. printf(" %2d", GET(permutation[i],depth));
  650. }
  651. printf("\n");
  652. #endif
  653. }
  654. /*
  655. void generateRandomPermutation(uint32_t N, uint32_t *random_permutation){
  //Initialize random permutation as 0,...,N-1 (the identity)
  657. for(uint32_t i=0; i<N; i++) {
  658. random_permutation[i]=i;
  659. }
  //Convert it to a random permutation of [0,N)
  661. RecursiveShuffle_M2((unsigned char *) random_permutation, (uint32_t) N, sizeof(uint32_t));
  662. // To parallelize: RecursiveShuffle_M2_parallel(buf, N, block_size, 1);
  663. }
  664. */
  665. #if 0
  666. void OblivWaksmanShuffle(unsigned char *buffer, uint32_t N, size_t block_size, enc_ret *ret) {
  667. uint32_t *random_permutation;
  668. try {
  669. random_permutation = new uint32_t[N];
  670. } catch (std::bad_alloc&) {
  671. printf("Allocating memory failed in OblivWaksmanShuffle\n");
  672. }
  673. // Generate random permutation
  674. double wt1, wt2;
  675. ocall_wallclock(&wt1, 1);
  676. generateRandomPermutation(N, random_permutation);
  677. ocall_wallclock(&wt2, 1);
  678. ret->gen_perm_time = wt2-wt1;
  679. #ifdef COUNT_OSWAPS
  680. ret->OSWAP_gp = OSWAP_COUNTER;
  681. OSWAP_COUNTER=0;
  682. #endif
  683. #ifdef TEST_WN_OA
  684. uint32_t *correct_permuted_keys = new uint32_t[N];
  685. printf("perm =");
  686. for(size_t i=0; i<N; i++) {
  687. printf(" %2d", random_permutation[i]);
  688. }
  689. printf("\norig =");
  690. for(size_t i=0; i<N; i++) {
  691. printf(" %2d", *((uint32_t*)(buffer + (block_size * i))));
  692. }
  693. printf("\ncorrect =");
  694. for(size_t i=0; i<N; i++) {
  695. uint32_t buffer_key = *((uint32_t*)(buffer + (block_size * random_permutation[i])));
  696. correct_permuted_keys[i] = buffer_key;
  697. printf(" %2d", buffer_key);
  698. }
  699. printf("\n");
  700. #endif
  701. // Set control bits to implement randomly generated permutation
  702. ocall_wallclock(&wt1, 1);
  703. FOAV_SAFE_CNTXT(OWShuffle, N)
  704. WaksmanNetwork wnet((uint32_t) N);
  705. //printf("\nSetting control bits\n");
  706. wnet.setPermutation(random_permutation);
  707. ocall_wallclock(&wt2, 1);
  708. ret->control_bits_time = wt2-wt1;
  709. #ifdef COUNT_OSWAPS
  710. ret->OSWAP_cb=OSWAP_COUNTER;
  711. OSWAP_COUNTER=0;
  712. #endif
  713. // Apply the permutation
  714. //printf("\n Applying permutation\n");
  715. ocall_wallclock(&wt1, 1);
  716. if (block_size == 4) {
  717. wnet.applyInversePermutation<OSWAP_4>(buffer, block_size);
  718. } else if (block_size == 8) {
  719. wnet.applyInversePermutation<OSWAP_8>(buffer, block_size);
  720. } else if (block_size == 12) {
  721. wnet.applyInversePermutation<OSWAP_12>(buffer, block_size);
  722. } else if (block_size%16 == 0) {
  723. wnet.applyInversePermutation<OSWAP_16X>(buffer, block_size);
  724. } else {
  725. wnet.applyInversePermutation<OSWAP_8_16X>(buffer, block_size);
  726. }
  727. ocall_wallclock(&wt2, 1);
  728. ret->apply_perm_time = wt2-wt1;
  729. #ifdef COUNT_OSWAPS
  730. ret->OSWAP_ap = OSWAP_COUNTER;
  731. #endif
  732. #ifdef TEST_WN_OA
  733. printf("output =");
  734. for(size_t i=0; i<N; i++) {
  735. printf(" %2d", *((uint32_t*)(buffer + (block_size * i))));
  736. }
  737. printf("\n");
  738. unsigned char *buffer_ptr = buffer;
  739. for(size_t i=0; i<N; i++) {
  740. uint32_t buffer_key = *((uint32_t*)(buffer_ptr));
  741. if(correct_permuted_keys[i]!=buffer_key) {
  742. printf("TEST_WN_OA: Shuffle Correctness Failed\n");
  743. break;
  744. }
  745. buffer_ptr+=block_size;
  746. }
  747. delete []correct_permuted_keys;
  748. #endif
  749. delete[] random_permutation;
  750. }
  751. void OblivWaksmanShuffle(unsigned char *buffer, uint32_t N,
  752. size_t block_size, uint32_t nthreads, enc_ret *ret) {
  753. uint32_t *random_permutation;
  754. try {
  755. random_permutation = new uint32_t[N];
  756. } catch (std::bad_alloc&) {
  757. printf("Allocating memory failed in OblivWaksmanShuffle\n");
  758. }
  759. // Generate random permutation
  760. double wt1, wt2;
  761. ocall_wallclock(&wt1, 1);
  762. generateRandomPermutation(N, random_permutation);
  763. ocall_wallclock(&wt2, 1);
  764. ret->gen_perm_time = wt2-wt1;
  765. #ifdef COUNT_OSWAPS
  766. ret->OSWAP_gp = OSWAP_COUNTER;
  767. OSWAP_COUNTER=0;
  768. #endif
  769. #ifdef TEST_WN_OA
  770. uint32_t *correct_permuted_keys = new uint32_t[N];
  771. printf("perm =");
  772. for(size_t i=0; i<N; i++) {
  773. printf(" %2d", random_permutation[i]);
  774. }
  775. printf("\norig =");
  776. for(size_t i=0; i<N; i++) {
  777. printf(" %2d", *((uint32_t*)(buffer + (block_size * i))));
  778. }
  779. printf("\ncorrect =");
  780. for(size_t i=0; i<N; i++) {
  781. uint32_t buffer_key = *((uint32_t*)(buffer + (block_size * random_permutation[i])));
  782. correct_permuted_keys[i] = buffer_key;
  783. printf(" %2d", buffer_key);
  784. }
  785. printf("\n");
  786. #endif
  787. // Set control bits to implement randomly generated permutation
  788. ocall_wallclock(&wt1, 1);
  789. FOAV_SAFE_CNTXT(OWShuffle, N)
  790. WaksmanNetwork wnet((uint32_t) N);
  791. //printf("\nSetting control bits\n");
  792. wnet.setPermutation(random_permutation);
  793. WNEvalPlan evalplan(N, nthreads);
  794. ocall_wallclock(&wt2, 1);
  795. ret->control_bits_time = wt2-wt1;
  796. #ifdef COUNT_OSWAPS
  797. ret->OSWAP_cb=OSWAP_COUNTER;
  798. OSWAP_COUNTER=0;
  799. #endif
  800. // Apply the permutation
  801. //printf("\n Applying permutation\n");
  802. ocall_wallclock(&wt1, 1);
  803. if (block_size == 4) {
  804. wnet.applyInversePermutation<OSWAP_4>(buffer, block_size, evalplan);
  805. } else if (block_size == 8) {
  806. wnet.applyInversePermutation<OSWAP_8>(buffer, block_size, evalplan);
  807. } else if (block_size == 12) {
  808. wnet.applyInversePermutation<OSWAP_12>(buffer, block_size, evalplan);
  809. } else if (block_size%16 == 0) {
  810. wnet.applyInversePermutation<OSWAP_16X>(buffer, block_size, evalplan);
  811. } else {
  812. wnet.applyInversePermutation<OSWAP_8_16X>(buffer, block_size, evalplan);
  813. }
  814. ocall_wallclock(&wt2, 1);
  815. ret->apply_perm_time = wt2-wt1;
  816. #ifdef COUNT_OSWAPS
  817. ret->OSWAP_ap = OSWAP_COUNTER;
  818. #endif
  819. #ifdef TEST_WN_OA
  820. printf("output =");
  821. for(size_t i=0; i<N; i++) {
  822. printf(" %2d", *((uint32_t*)(buffer + (block_size * i))));
  823. }
  824. printf("\n");
  825. unsigned char *buffer_ptr = buffer;
  826. for(size_t i=0; i<N; i++) {
  827. uint32_t buffer_key = *((uint32_t*)(buffer_ptr));
  828. if(correct_permuted_keys[i]!=buffer_key) {
  829. printf("TEST_WN_OA: Shuffle Correctness Failed\n");
  830. break;
  831. }
  832. buffer_ptr+=block_size;
  833. }
  834. delete []correct_permuted_keys;
  835. #endif
  836. delete[] random_permutation;
  837. }
  838. void DecryptAndOblivWaksmanShuffle(unsigned char *encrypted_buffer, uint32_t N,
  839. size_t encrypted_block_size, unsigned char *result_buffer, enc_ret *ret) {
  840. double wt1, wt2;
  841. // Decrypt buffer to decrypted_buffer
  842. unsigned char *decrypted_buffer = NULL;
  843. size_t decrypted_block_size = decryptBuffer(encrypted_buffer, (uint64_t) N, encrypted_block_size,
  844. &decrypted_buffer);
  845. // Set the Waksman control bits to implement the permutation
  846. ocall_wallclock(&wt1, 0);
  847. ocall_wallclock(&wt1, 1);
  848. PRB_pool_init(1);
  849. OblivWaksmanShuffle(decrypted_buffer, N, decrypted_block_size, ret);
  850. ocall_wallclock(&wt2, 1);
  851. ret->ptime = wt2-wt1;
  852. #ifdef COUNT_OSWAPS
  853. ret->OSWAP_count = OSWAP_COUNTER;
  854. #endif
  855. // Encrypt buffer to result_buffer
  856. encryptBuffer(decrypted_buffer, (uint64_t) N, decrypted_block_size, result_buffer);
  857. PRB_pool_shutdown();
  858. free(decrypted_buffer);
  859. return;
  860. }
  861. void OblivWaksmanSort(unsigned char *buffer, uint32_t N, size_t block_size, enc_ret *ret) {
  862. uint32_t *sort_permutation;
  863. try {
  864. FOAV_SAFE_CNTXT(OWSort, N)
  865. sort_permutation = new uint32_t[N];
  866. } catch (std::bad_alloc&) {
  867. printf("Allocating memory failed in OblivWaksmanSort\n");
  868. }
  869. // Generate sort permutation
  870. double wt1, wt2;
  871. ocall_wallclock(&wt1, 1);
  872. generateSortPermutation_OA(N, buffer, block_size, sort_permutation);
  873. ocall_wallclock(&wt2, 1);
  874. ret->gen_perm_time = wt2-wt1;
  875. #ifdef COUNT_OSWAPS
  876. ret->OSWAP_gp = OSWAP_COUNTER;
  877. OSWAP_COUNTER=0;
  878. #endif
  879. // Set control bits to implement randomly generated permutation
  880. ocall_wallclock(&wt1, 1);
  881. #ifdef PROFILE_SETPERM_N
  882. unsigned long x = printf_with_rtclock("Creating network\n");
  883. #endif
  884. FOAV_SAFE_CNTXT(OblivWaksmanSort, N)
  885. WaksmanNetwork wnet = WaksmanNetwork((uint32_t) N);
  886. FOAV_SAFE_CNTXT(OblivWaksmanSort, wnet)
  887. #ifdef PROFILE_SETPERM_N
  888. printf_with_rtclock_diff(x, "Created network\n");
  889. #endif
  890. //printf("\nSetting control bits\n");
  891. wnet.setPermutation(sort_permutation);
  892. ocall_wallclock(&wt2, 1);
  893. ret->control_bits_time = wt2-wt1;
  894. #ifdef COUNT_OSWAPS
  895. ret->OSWAP_cb=OSWAP_COUNTER;
  896. OSWAP_COUNTER=0;
  897. #endif
  898. // Apply the permutation
  899. //printf("\nApplying permutation\n");
  900. ocall_wallclock(&wt1, 1);
  901. FOAV_SAFE_CNTXT(AP, block_size)
  902. if (block_size == 4) {
  903. wnet.applyInversePermutation<OSWAP_4>(buffer, block_size);
  904. } else if (block_size == 8) {
  905. wnet.applyInversePermutation<OSWAP_8>(buffer, block_size);
  906. } else if (block_size == 12) {
  907. wnet.applyInversePermutation<OSWAP_12>(buffer, block_size);
  908. } else if (block_size%16 == 0) {
  909. wnet.applyInversePermutation<OSWAP_16X>(buffer, block_size);
  910. } else {
  911. wnet.applyInversePermutation<OSWAP_8_16X>(buffer, block_size);
  912. }
  913. ocall_wallclock(&wt2, 1);
  914. ret->apply_perm_time = wt2-wt1;
  915. #ifdef COUNT_OSWAPS
  916. ret->OSWAP_ap = OSWAP_COUNTER;
  917. #endif
  918. delete[] sort_permutation;
  919. }
  920. void OblivWaksmanSort(unsigned char *buffer, uint32_t N, size_t block_size, uint32_t nthreads, enc_ret *ret) {
  921. uint32_t *sort_permutation;
  922. try {
  923. FOAV_SAFE_CNTXT(OWSort, N)
  924. sort_permutation = new uint32_t[N];
  925. } catch (std::bad_alloc&) {
  926. printf("Allocating memory failed in OblivWaksmanSort\n");
  927. }
  928. // Generate sort permutation
  929. double wt1, wt2;
  930. ocall_wallclock(&wt1, 1);
  931. generateSortPermutation_OA(N, buffer, block_size, sort_permutation);
  932. ocall_wallclock(&wt2, 1);
  933. ret->gen_perm_time = wt2-wt1;
  934. #ifdef COUNT_OSWAPS
  935. ret->OSWAP_gp = OSWAP_COUNTER;
  936. OSWAP_COUNTER=0;
  937. #endif
  938. // Set control bits to implement randomly generated permutation
  939. ocall_wallclock(&wt1, 1);
  940. #ifdef PROFILE_SETPERM_N
  941. unsigned long x = printf_with_rtclock("Creating network\n");
  942. #endif
  943. FOAV_SAFE_CNTXT(OblivWaksmanSort, N)
  944. WaksmanNetwork wnet = WaksmanNetwork((uint32_t) N);
  945. FOAV_SAFE_CNTXT(OblivWaksmanSort, wnet)
  946. #ifdef PROFILE_SETPERM_N
  947. printf_with_rtclock_diff(x, "Created network\n");
  948. #endif
  949. //printf("\nSetting control bits\n");
  950. wnet.setPermutation(sort_permutation);
  951. WNEvalPlan evalplan(N, nthreads);
  952. ocall_wallclock(&wt2, 1);
  953. ret->control_bits_time = wt2-wt1;
  954. #ifdef COUNT_OSWAPS
  955. ret->OSWAP_cb=OSWAP_COUNTER;
  956. OSWAP_COUNTER=0;
  957. #endif
  958. // Apply the permutation
  959. //printf("\nApplying permutation\n");
  960. ocall_wallclock(&wt1, 1);
  961. FOAV_SAFE_CNTXT(AP, block_size)
  962. if (block_size == 4) {
  963. wnet.applyInversePermutation<OSWAP_4>(buffer, block_size, evalplan);
  964. } else if (block_size == 8) {
  965. wnet.applyInversePermutation<OSWAP_8>(buffer, block_size, evalplan);
  966. } else if (block_size == 12) {
  967. wnet.applyInversePermutation<OSWAP_12>(buffer, block_size, evalplan);
  968. } else if (block_size%16 == 0) {
  969. wnet.applyInversePermutation<OSWAP_16X>(buffer, block_size, evalplan);
  970. } else {
  971. wnet.applyInversePermutation<OSWAP_8_16X>(buffer, block_size, evalplan);
  972. }
  973. ocall_wallclock(&wt2, 1);
  974. ret->apply_perm_time = wt2-wt1;
  975. #ifdef COUNT_OSWAPS
  976. ret->OSWAP_ap = OSWAP_COUNTER;
  977. #endif
  978. delete[] sort_permutation;
  979. }
  980. void DecryptAndOblivWaksmanSort(unsigned char *encrypted_buffer, uint32_t N,
  981. size_t encrypted_block_size, uint32_t nthreads, unsigned char *result_buffer, enc_ret *ret) {
  982. double wt1, wt2;
  983. // Decrypt buffer to decrypted_buffer
  984. unsigned char *decrypted_buffer = NULL;
  985. size_t decrypted_block_size = decryptBuffer(encrypted_buffer, (uint64_t) N, encrypted_block_size,
  986. &decrypted_buffer);
  987. // Set the Waksman control bits to implement the permutation
  988. threadpool_init(nthreads);
  989. ocall_wallclock(&wt1, 0);
  990. ocall_wallclock(&wt1, 1);
  991. PRB_pool_init(nthreads);
  992. OblivWaksmanSort(decrypted_buffer, N, decrypted_block_size, nthreads, ret);
  993. ocall_wallclock(&wt2, 1);
  994. ret->ptime = wt2-wt1;
  995. #ifdef COUNT_OSWAPS
  996. ret->OSWAP_count = OSWAP_COUNTER;
  997. #endif
  998. // Encrypt buffer to result_buffer
  999. encryptBuffer(decrypted_buffer, (uint64_t) N, decrypted_block_size, result_buffer);
  1000. PRB_pool_shutdown();
  1001. threadpool_shutdown();
  1002. free(decrypted_buffer);
  1003. return;
  1004. }
  1005. void DecryptAndOblivWaksmanSort(unsigned char *encrypted_buffer, uint32_t N,
  1006. size_t encrypted_block_size, unsigned char *result_buffer, enc_ret *ret) {
  1007. double wt1, wt2;
  1008. // Decrypt buffer to decrypted_buffer
  1009. unsigned char *decrypted_buffer = NULL;
  1010. size_t decrypted_block_size = decryptBuffer(encrypted_buffer, (uint64_t) N, encrypted_block_size,
  1011. &decrypted_buffer);
  1012. // Set the Waksman control bits to implement the permutation
  1013. ocall_wallclock(&wt1, 0);
  1014. ocall_wallclock(&wt1, 1);
  1015. PRB_pool_init(1);
  1016. OblivWaksmanSort(decrypted_buffer, N, decrypted_block_size, ret);
  1017. ocall_wallclock(&wt2, 1);
  1018. ret->ptime = wt2-wt1;
  1019. #ifdef COUNT_OSWAPS
  1020. ret->OSWAP_count = OSWAP_COUNTER;
  1021. #endif
  1022. // Encrypt buffer to result_buffer
  1023. encryptBuffer(decrypted_buffer, (uint64_t) N, decrypted_block_size, result_buffer);
  1024. PRB_pool_shutdown();
  1025. free(decrypted_buffer);
  1026. return;
  1027. }
  1028. void DecryptAndOWSS(unsigned char *encrypted_buffer, uint32_t N,
  1029. size_t encrypted_block_size, unsigned char *result_buffer, enc_ret *ret) {
  1030. double wt1, wt2, wt3;
  1031. // Decrypt buffer to decrypted_buffer
  1032. unsigned char *decrypted_buffer = NULL;
  1033. size_t decrypted_block_size = decryptBuffer(encrypted_buffer, (uint64_t) N, encrypted_block_size,
  1034. &decrypted_buffer);
  1035. // Set the Waksman control bits to implement the permutation
  1036. ocall_wallclock(&wt1, 0);
  1037. ocall_wallclock(&wt1, 1);
  1038. PRB_pool_init(1);
  1039. OblivWaksmanShuffle(decrypted_buffer, N, decrypted_block_size, ret);
  1040. #ifdef COUNT_OSWAPS
  1041. ret->OSWAP_count = OSWAP_COUNTER;
  1042. #endif
  1043. ocall_wallclock(&wt2, 1);
  1044. qsort(decrypted_buffer, N, decrypted_block_size, compare);
  1045. ocall_wallclock(&wt3, 1);
  1046. ret->qsort_time = wt3-wt2;
  1047. ret->ptime = wt3-wt1;
  1048. // Encrypt buffer to result_buffer
  1049. encryptBuffer(decrypted_buffer, (uint64_t) N, decrypted_block_size, result_buffer);
  1050. PRB_pool_shutdown();
  1051. free(decrypted_buffer);
  1052. return;
  1053. }
  1054. #endif
  // Work descriptor for datacopy_range. Output slots [start, end) of outbuf
  // are filled from inbuf, using idx to find each slot's source element.
  // Field order is significant: callers fill this with a positional brace
  // initializer { inbuf, idx, outbuf, start, end, sz }.
  struct datacopy_args {
      const unsigned char *inbuf;  // source elements, sz bytes each
      const uint64_t *idx;         // low 32 bits: source position for each output slot
      unsigned char *outbuf;       // destination elements, sz bytes each
      size_t start;                // first output slot handled (inclusive)
      size_t end;                  // one past the last output slot handled
      size_t sz;                   // element size in bytes
  };
  1061. static void* datacopy_range(void *voidargs)
  1062. {
  1063. const datacopy_args *args = (datacopy_args*)voidargs;
  1064. for (size_t i=args->start; i<args->end; ++i) {
  1065. memmove(args->outbuf+i*args->sz,
  1066. args->inbuf+(args->idx[i]&0xffffffff)*args->sz,
  1067. args->sz);
  1068. }
  1069. return NULL;
  1070. }
  1071. #if 0
  1072. // Sort the given array of N elements, each of size sz, using up to
  1073. // nthreads threads. The output is put into the same memory as the input
  1074. // array. The first 4 bytes of each element is its key.
  1075. static void mtsort(void *buffer, size_t N, size_t sz, threadid_t nthreads)
  1076. {
  1077. uint64_t *idx = new uint64_t[N];
  1078. unsigned char *inbuf = (unsigned char *)buffer;
  1079. unsigned char *outbuf = new unsigned char[N*sz];
  1080. for (size_t i=0; i<N; ++i) {
  1081. uint64_t key = (*(uint32_t*)(inbuf+sz*i));
  1082. idx[i] = (key<<32) + i;
  1083. }
  1084. // Sort the keys and indices
  1085. uint64_t *backingidx = new uint64_t[N];
  1086. bool whichbuf = mtmergesort<uint64_t>(idx, N, backingidx, nthreads);
  1087. uint64_t *sortedidx = whichbuf ? backingidx : idx;
  1088. // Copy the data using the sorted indices, potentially using
  1089. // multiple threads
  1090. threadid_t threads_to_use = nthreads;
  1091. datacopy_args dcargs[threads_to_use];
  1092. size_t inc = N / threads_to_use;
  1093. size_t extra = N % threads_to_use;
  1094. size_t last = 0;
  1095. for (size_t t=0; t<threads_to_use; ++t) {
  1096. size_t next = last + inc + (t < extra);
  1097. dcargs[t] = { inbuf, sortedidx, outbuf, last, next, sz };
  1098. last = next;
  1099. if (t > 0) {
  1100. threadpool_dispatch(g_thread_id+t, datacopy_range,
  1101. &dcargs[t]);
  1102. }
  1103. }
  1104. // Do the first block ourselves
  1105. datacopy_range(&dcargs[0]);
  1106. for (size_t t=1; t<threads_to_use; ++t) {
  1107. threadpool_join(g_thread_id+t, NULL);
  1108. }
  1109. delete[] idx;
  1110. delete[] backingidx;
  1111. memmove(inbuf, outbuf, N*sz);
  1112. delete[] outbuf;
  1113. }
  1114. void DecryptAndMTSS(unsigned char *encrypted_buffer, uint32_t N,
  1115. size_t encrypted_block_size, threadid_t nthreads,
  1116. unsigned char *result_buffer, enc_ret *ret) {
  1117. double wt1, wt2, wt3;
  1118. // Decrypt buffer to decrypted_buffer
  1119. unsigned char *decrypted_buffer = NULL;
  1120. size_t decrypted_block_size = decryptBuffer(encrypted_buffer, (uint64_t) N, encrypted_block_size,
  1121. &decrypted_buffer);
  1122. // Set the Waksman control bits to implement the permutation
  1123. threadpool_init(nthreads);
  1124. ocall_wallclock(&wt1, 0);
  1125. ocall_wallclock(&wt1, 1);
  1126. PRB_pool_init(nthreads);
  1127. OblivWaksmanShuffle(decrypted_buffer, N, decrypted_block_size, nthreads, ret);
  1128. #ifdef COUNT_OSWAPS
  1129. ret->OSWAP_count = OSWAP_COUNTER;
  1130. #endif
  1131. ocall_wallclock(&wt2, 1);
  1132. mtsort(decrypted_buffer, N, decrypted_block_size, nthreads);
  1133. ocall_wallclock(&wt3, 1);
  1134. ret->qsort_time = wt3-wt2;
  1135. ret->ptime = wt3-wt1;
  1136. // Encrypt buffer to result_buffer
  1137. encryptBuffer(decrypted_buffer, (uint64_t) N, decrypted_block_size, result_buffer);
  1138. PRB_pool_shutdown();
  1139. threadpool_shutdown();
  1140. free(decrypted_buffer);
  1141. return;
  1142. }
  1143. #endif