WaksmanNetwork.hpp 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687
  1. #ifndef __WAKSMANNETWORK_HPP__
  2. #define __WAKSMANNETWORK_HPP__
  3. #include <unordered_map>
  4. #include <vector>
  5. #include "oasm_lib.h"
  6. #include <sgx_tcrypto.h>
  7. #include "utils.hpp"
  8. #include "RecursiveShuffle.hpp"
  9. #include "aes.hpp"
  10. // #define PROFILE_MTAPPLYPERM
  11. typedef __uint128_t randkey_t;
  12. #define FPERM_OSWAP_STYLE OSWAP_8_16X // OSwap_Style for forward perm (consistent w/ randkey_t)
  // A multi-thread evaluation plan for a WaksmanNetwork of a given size
  // and number of threads.
  //
  // A WaksmanNetwork with N>2 items itself uses (N-1)/2 input switches and
  // N/2 output switches (in addition to those used by its subnetworks).
  // When it is applied with nthreads>1 threads, the plan recurses into the
  // left subnetwork with (N+1)/2 items and (nthreads+1)/2 threads, and into
  // the right subnetwork with N/2 items and nthreads/2 threads.  If N=2,
  // there is just a single output switch and no subnetworks.  If N<2, there
  // are no switches and no subnetworks.  If nthreads<=1, only the total
  // number of input and output switches used by the network and all of its
  // subnetworks is stored, with no WNEvalPlans for the subnetworks.
  struct WNEvalPlan {
      // The number of items for the represented WaksmanNetwork
      uint32_t N;
      // The number of threads to use to evaluate it
      uint32_t nthreads;
      // The total number of input and output switches used by this
      // WaksmanNetwork and its subnetworks
      size_t subtree_num_inswitches, subtree_num_outswitches;
      // If N>2 and nthreads>1, these are the evaluation plans for the
      // subnetworks (left first, then right).  This vector contains
      // 0 items (if N<=2 or nthreads<=1) or 2 items (otherwise).
      std::vector<WNEvalPlan> subplans;

      // Build the plan for N items evaluated by nthreads threads.
      WNEvalPlan(uint32_t N, uint32_t nthreads) : N(N), nthreads(nthreads),
              subtree_num_inswitches(0), subtree_num_outswitches(0) {
          if (N <= 2 || nthreads <= 1) {
              // Leaf of the plan: no subplans, just the switch totals of
              // the entire subtree (0 for N<2, one output switch for N=2).
              count_switches(N);
              return;
          }
          const uint32_t Nleft = (N+1)/2;
          const uint32_t Nright = N/2;
          const uint32_t numInSwitches = (N-1)/2;
          const uint32_t numOutSwitches = N/2;
          const uint32_t nthr_left = (nthreads+1)/2;
          const uint32_t nthr_right = nthreads/2;
          subplans.emplace_back(Nleft, nthr_left);
          subplans.emplace_back(Nright, nthr_right);
          subtree_num_inswitches = numInSwitches +
              subplans[0].subtree_num_inswitches +
              subplans[1].subtree_num_inswitches;
          subtree_num_outswitches = numOutSwitches +
              subplans[0].subtree_num_outswitches +
              subplans[1].subtree_num_outswitches;
      }

      // Count the number of input and output switches used by a
      // WaksmanNetwork with n items (including its subnetworks).  Add
      // those values to subtree_num_inswitches and subtree_num_outswitches.
      void count_switches(uint32_t n) {
          if (n < 2) {
              return;
          }
          if (n == 2) {
              subtree_num_outswitches += 1;
              return;
          }
          subtree_num_inswitches += (n-1)/2;
          subtree_num_outswitches += n/2;
          count_switches((n+1)/2);
          count_switches(n/2);
      }

      // Print this plan, then its subplans indented two further columns.
      // indent is the number of spaces printed before this plan's line.
      void dump(int indent = 0) const {
          // N and nthreads are uint32_t (%u); the switch totals are
          // size_t (%zu).
          printf("%*sN = %u, nthreads = %u, inswitches = %zu, outswitches = %zu\n",
              indent, "", N, nthreads, subtree_num_inswitches,
              subtree_num_outswitches);
          for (const WNEvalPlan &sub : subplans) {
              sub.dump(indent+2);
          }
      }
  };
  97. /*
  98. WaksmanNetwork Class: Contains a Waksman permutation network and can apply it to an input array.
  99. setPermutation(uint32_t *permutation, unsigned char *forward_perm, [optional preallocated
  100. memory regions]): Takes permutation as an array of N index values (i.e. values in [N]) and
  101. optional pointers to allocated memory. It sets the Waksman network switches to that
  102. permutation.
  103. applyPermutation(unsigned char *buf, size_t block_size): Takes buffer of N items of block_size
  104. bytes each and applies stored permutation in-place (i.e. modifying input buffer).
  105. applyInversePermutation(unsigned char *buf, size_t block_size): Takes buffer of N items of
  106. block_size bytes each and applies inverse of stored permutation in-place.
  107. */
  108. class WaksmanNetwork {
  109. uint32_t Ntotal; // number of items to permute
  110. std::vector<uint32_t> inSwitchVec; // input layer of (numbered) switches
  111. std::vector<uint8_t> outSwitchVec; // output layer of switches
  112. // A struct to keep track of the current subnet number, and input and
  113. // output switches, for each subnet as we traverse the network.
  114. struct WNTraversal {
  115. uint64_t subnetNumber;
  116. uint32_t *inSwitches;
  117. uint8_t *outSwitches;
  118. WNTraversal(WaksmanNetwork &wn) : subnetNumber(0),
  119. inSwitches(wn.inSwitchVec.data()),
  120. outSwitches(wn.outSwitchVec.data()) {}
  121. };
  122. struct appInvPermArgs {
  123. WaksmanNetwork &wn;
  124. unsigned char *buf;
  125. size_t block_size;
  126. const WNEvalPlan &plan;
  127. WNTraversal &traversal;
  128. appInvPermArgs(WaksmanNetwork &wn, unsigned char *buf,
  129. size_t block_size, const WNEvalPlan &plan, WNTraversal &traversal)
  130. : wn(wn), buf(buf), block_size(block_size), plan(plan),
  131. traversal(traversal) {}
  132. };
  133. template <OSwap_Style oswap_style>
  134. static void *applyInversePermutation_launch(void *voidarg)
  135. {
  136. appInvPermArgs *arg = (appInvPermArgs *)voidarg;
  137. arg->wn.applyInversePermutation<oswap_style>(arg->buf, arg->block_size,
  138. arg->plan, arg->traversal);
  139. return NULL;
  140. }
  141. // A struct to hold pre-allocated memory (and AES keys) so that we
  142. // only allocate memory once, before the recursive setPermutation is
  143. // called.
  144. struct WNMem {
  145. unsigned char *forward_perm;
  146. uint32_t *unselected_cnt;
  147. std::unordered_map<randkey_t, std::pair<uint32_t, uint32_t>> *reverse_perm;
  148. AESkey forward_key, reverse_key;
  149. WNMem(WaksmanNetwork &wn) {
  150. // Round Ntotal up to an even number
  151. uint32_t Neven = wn.Ntotal + (wn.Ntotal&1);
  152. forward_perm = new unsigned char[Neven * (sizeof(randkey_t) + 8)];
  153. unselected_cnt = new uint32_t[wn.Ntotal];
  154. reverse_perm = new std::unordered_map<randkey_t, std::pair<uint32_t, uint32_t>>;
  155. __m128i forward_rawkey, reverse_rawkey;
  156. getRandomBytes((unsigned char *) &forward_rawkey, sizeof(forward_rawkey));
  157. getRandomBytes((unsigned char *) &reverse_rawkey, sizeof(reverse_rawkey));
  158. AES_128_Key_Expansion(forward_key, forward_rawkey);
  159. AES_128_Key_Expansion(reverse_key, reverse_rawkey);
  160. }
  161. ~WNMem() {
  162. delete[] forward_perm;
  163. delete[] unselected_cnt;
  164. delete reverse_perm;
  165. }
  166. };
  167. void setPermutation(uint32_t *permutation, uint32_t N,
  168. uint32_t depth, WNTraversal &traversal, const WNMem &mem);
  169. template <OSwap_Style oswap_style>
  170. void applyPermutation(unsigned char *buf, uint32_t N,
  171. size_t block_size, WNTraversal &traversal);
  172. template <OSwap_Style oswap_style>
  173. void applyInversePermutation(unsigned char *buf, uint32_t N,
  174. size_t block_size, WNTraversal &traversal);
  175. template <OSwap_Style oswap_style>
  176. void applyInversePermutation(unsigned char *buf, size_t block_size,
  177. const WNEvalPlan &plan, WNTraversal &traversal);
  178. public:
  179. // Set up the WaksmanNetwork for N items. N need not be a power of 2.
  180. // N <= 2^31
  181. WaksmanNetwork(uint32_t N);
  182. void setPermutation(uint32_t *permutation);
  183. template <OSwap_Style oswap_style>
  184. void applyPermutation(unsigned char *buf, size_t block_size);
  185. template <OSwap_Style oswap_style>
  186. void applyInversePermutation(unsigned char *buf, size_t block_size);
  187. template <OSwap_Style oswap_style>
  188. void applyInversePermutation(unsigned char *buf, size_t block_size,
  189. const WNEvalPlan &plan);
  190. };
  191. // Define this to show the intermediate states of applyPermutation
  192. // #define SHOW_APPLYPERM
  193. // Apply permutation encoded by control bits to data elements in buffer. Permutes in place.
  194. template <OSwap_Style oswap_style>
  195. void WaksmanNetwork::applyPermutation(unsigned char *buf, size_t block_size) {
  196. FOAV_SAFE_CNTXT(AP, Ntotal)
  197. if (Ntotal > 1) {
  198. WNTraversal traversal(*this);
  199. applyPermutation<oswap_style>(buf, Ntotal, block_size, traversal);
  200. }
  201. }
  202. // Apply permutation encoded by control bits to data elements in buffer. Permutes in place.
  203. template <OSwap_Style oswap_style>
  204. void WaksmanNetwork::applyPermutation(unsigned char *buf, uint32_t N,
  205. size_t block_size, WNTraversal &traversal) {
  206. FOAV_SAFE_CNTXT(AP, Ntotal)
  207. FOAV_SAFE_CNTXT(AP, N)
  208. if (N < 2) return;
  209. const uint32_t Nleft = (N+1)/2;
  210. const uint32_t Nright = N/2;
  211. const uint32_t numInSwitches = (N-1)/2;
  212. const uint32_t numOutSwitches = N/2;
  213. const uint32_t *inSwitch = traversal.inSwitches;
  214. const uint8_t *outSwitch = traversal.outSwitches;
  215. traversal.subnetNumber += 1;
  216. traversal.inSwitches += numInSwitches;
  217. traversal.outSwitches += numOutSwitches;
  218. #ifdef SHOW_APPLYPERM
  219. printf("s");
  220. for(uint32_t i=0;i<N;++i) {
  221. printf(" %2d", *(uint32_t*)(buf+block_size*i));
  222. }
  223. printf("\n");
  224. #endif
  225. if (N == 2) {
  226. #ifdef SHOW_APPLYPERM
  227. printf("o");
  228. for(uint32_t i=0;i<numOutSwitches;++i) {
  229. printf(" %s", outSwitch[i] ? " X" : "||");
  230. }
  231. printf("\n");
  232. #endif
  233. oswap_buffer<oswap_style>(buf, buf + block_size, (uint32_t) block_size, outSwitch[0]);
  234. #ifdef SHOW_APPLYPERM
  235. printf("e");
  236. for(uint32_t i=0;i<N;++i) {
  237. printf(" %2d", *(uint32_t*)(buf+block_size*i));
  238. }
  239. printf("\n");
  240. #endif
  241. } else {
  242. #ifdef SHOW_APPLYPERM
  243. printf("i");
  244. for(uint32_t i=0;i<numInSwitches;++i) {
  245. printf(" %s", (inSwitch[i]&1) ? " X" : "||");
  246. }
  247. printf("\n");
  248. #endif
  249. // Apply input switches to permutation
  250. const uint32_t *curInSwitchVal = inSwitch;
  251. for (uint32_t i=0; i<numInSwitches; i++) {
  252. oswap_buffer<oswap_style>(buf + block_size*(i), buf + block_size*(Nleft+i), block_size,
  253. (*curInSwitchVal)&1);
  254. curInSwitchVal += 1;
  255. }
  256. #ifdef SHOW_APPLYPERM
  257. printf(" ");
  258. for(uint32_t i=0;i<N;++i) {
  259. printf(" %2d", *(uint32_t*)(buf+block_size*i));
  260. }
  261. printf("\n");
  262. #endif
  263. // Apply subnetwork switches
  264. applyPermutation<oswap_style>(buf, Nleft, block_size, traversal);
  265. applyPermutation<oswap_style>(buf + block_size*Nleft, Nright,
  266. block_size, traversal);
  267. #ifdef SHOW_APPLYPERM
  268. printf("r");
  269. for(uint32_t i=0;i<N;++i) {
  270. printf(" %2d", *(uint32_t*)(buf+block_size*i));
  271. }
  272. printf("\n");
  273. printf("o");
  274. for(uint32_t i=0;i<numOutSwitches;++i) {
  275. printf(" %s", outSwitch[i] ? " X" : "||");
  276. }
  277. printf("\n");
  278. #endif
  279. // Apply output switches to permutation
  280. for (uint32_t i=0; i<numOutSwitches; i++) {
  281. oswap_buffer<oswap_style>(buf + block_size*i, buf + block_size*(Nleft+i), block_size,
  282. *outSwitch);
  283. ++outSwitch;
  284. }
  285. #ifdef SHOW_APPLYPERM
  286. printf("e");
  287. for(uint32_t i=0;i<N;++i) {
  288. printf(" %2d", *(uint32_t*)(buf+block_size*i));
  289. }
  290. printf("\n");
  291. #endif
  292. }
  293. }
  294. // Apply permutation encoded by control bits to data elements in buffer
  295. // using a multithread evaluation plan. Permutes in place.
  296. template <OSwap_Style oswap_style>
  297. void WaksmanNetwork::applyInversePermutation(unsigned char *buf,
  298. size_t block_size, const WNEvalPlan &plan) {
  299. FOAV_SAFE_CNTXT(AP, Ntotal)
  300. if (Ntotal > 1) {
  301. WNTraversal traversal(*this);
  302. applyInversePermutation<oswap_style>(buf, block_size, plan, traversal);
  303. }
  304. }
  // Work description for applySwitchesRange: which switches to apply and
  // to which buffer.  CBT is the control bit type of the switches.  This
  // is a plain aggregate; callers brace-initialise it in member order.
  template <typename CBT>
  struct ApplySwitchesArgs {
      // Start of the item buffer the switches act on
      unsigned char *buf;
      // Size of one item in bytes
      size_t block_size;
      // Control values; entry i controls switch i
      const CBT* switches;
      // Half-open range [swStart, swEnd) of switch indices to apply
      uint32_t swStart;
      uint32_t swEnd;
      // Switch i pairs item i with item stride+i
      uint32_t stride;
  };
  313. // Apply a consecutive sequence of input or output switches, using
  314. // arguments passed as an ApplySwitchesArgs*. CBT is the
  315. // control bit type (uint32_t for input switches or uint8_t for output
  316. // switches).
  317. template <OSwap_Style oswap_style, typename CBT>
  318. static void* applySwitchesRange(void *voidargs)
  319. {
  320. const ApplySwitchesArgs<CBT>* args =
  321. (const ApplySwitchesArgs<CBT> *)voidargs;
  322. unsigned char *buf = args->buf;
  323. const size_t block_size = args->block_size;
  324. const uint32_t swStart = args->swStart;
  325. const CBT* switches = args->switches + swStart;
  326. const uint32_t swEnd = args->swEnd;
  327. const uint32_t stride = args->stride;
  328. FOAV_SAFE_CNTXT(applySwitchesRange, swEnd)
  329. for (uint32_t i=swStart; i<swEnd; ++i) {
  330. FOAV_SAFE2_CNTXT(applySwitchesRange, i, swEnd)
  331. oswap_buffer<oswap_style>(buf + block_size*(i),
  332. buf + block_size*(stride+i), block_size,
  333. (*switches)&1);
  334. ++switches;
  335. }
  336. return NULL;
  337. }
  338. // Apply a consecutive sequence of input or output switches using
  339. // up to nthreads threads. CBT is the control bit type (uint32_t for
  340. // input switches or uint8_t for output switches), but it will be
  341. // deduced automatically from the type of the switches argument.
  342. template <OSwap_Style oswap_style, typename CBT>
  343. static void applySwitches(unsigned char *buf, size_t block_size,
  344. const CBT* switches, uint32_t numSwitches, uint32_t stride,
  345. uint32_t nthreads)
  346. {
  347. uint32_t threads_to_use = nthreads;
  348. ApplySwitchesArgs<CBT> asargs[threads_to_use];
  349. uint32_t inc = numSwitches / threads_to_use;
  350. uint32_t extra = numSwitches % threads_to_use;
  351. uint32_t last = 0;
  352. for (uint32_t t=0; t<threads_to_use; ++t) {
  353. uint32_t next = last + inc + (t < extra);
  354. asargs[t] = { buf, block_size, switches, last, next, stride };
  355. last = next;
  356. if (t > 0) {
  357. threadpool_dispatch(g_thread_id+t,
  358. applySwitchesRange<oswap_style,CBT>, &asargs[t]);
  359. }
  360. }
  361. // Do the first block ourselves
  362. applySwitchesRange<oswap_style,CBT>(&asargs[0]);
  363. for (size_t t=1; t<threads_to_use; ++t) {
  364. threadpool_join(g_thread_id+t, NULL);
  365. }
  366. }
  367. // Apply inverse of permutation encoded by control bits to data elements
  368. // in buffer using a multithread evaluation plan. Permutes in place.
  369. template <OSwap_Style oswap_style>
  370. void WaksmanNetwork::applyInversePermutation(unsigned char *buf,
  371. size_t block_size, const WNEvalPlan &plan, WNTraversal &traversal) {
  372. const uint32_t N = plan.N;
  373. const uint32_t nthreads = plan.nthreads;
  374. if (N < 2) return;
  375. if (nthreads <= 1) {
  376. #ifdef PROFILE_MTAPPLYPERM
  377. unsigned long start = printf_with_rtclock("Thread %u starting single-threaded applyInversePermutation(N=%lu)\n", g_thread_id, N);
  378. #endif
  379. applyInversePermutation<oswap_style>(buf, N, block_size, traversal);
  380. #ifdef PROFILE_MTAPPLYPERM
  381. printf_with_rtclock_diff(start, "Thread %u ending single-threaded applyInversePermutation(N=%lu)\n", g_thread_id, N);
  382. #endif
  383. return;
  384. }
  385. #ifdef PROFILE_MTAPPLYPERM
  386. unsigned long start = printf_with_rtclock("Thread %u starting applyInversePermutation(N=%lu, nthreads=%lu)\n", g_thread_id, N, nthreads);
  387. #endif
  388. const uint32_t Nleft = (N+1)/2;
  389. const uint32_t Nright = N/2;
  390. const uint32_t numInSwitches = (N-1)/2;
  391. const uint32_t numOutSwitches = N/2;
  392. const uint32_t *inSwitch = traversal.inSwitches;
  393. const uint8_t *outSwitch = traversal.outSwitches;
  394. const uint32_t nthr_left = (nthreads+1)/2;
  395. const uint32_t nthr_right = nthreads/2;
  396. WNTraversal lefttraversal = traversal;
  397. lefttraversal.inSwitches += numInSwitches;
  398. lefttraversal.outSwitches += numOutSwitches;
  399. traversal.inSwitches +=
  400. plan.subplans[0].subtree_num_inswitches + numInSwitches;
  401. traversal.outSwitches +=
  402. plan.subplans[0].subtree_num_outswitches + numOutSwitches;
  403. #ifdef SHOW_APPLYPERM
  404. printf("s");
  405. for(uint32_t i=0;i<N;++i) {
  406. printf(" %2d", *(uint32_t*)(buf+block_size*i));
  407. }
  408. printf("\n");
  409. #endif
  410. FOAV_SAFE_CNTXT(AIP, N)
  411. if (N == 2) {
  412. #ifdef SHOW_APPLYPERM
  413. printf("o");
  414. for(uint32_t i=0;i<numOutSwitches;++i) {
  415. printf(" %s", outSwitch[i] ? " X" : "||");
  416. }
  417. printf("\n");
  418. #endif
  419. oswap_buffer<oswap_style>(buf, buf + block_size, (uint32_t) block_size,
  420. outSwitch[0]);
  421. #ifdef SHOW_APPLYPERM
  422. printf("e");
  423. for(uint32_t i=0;i<N;++i) {
  424. printf(" %2d", *(uint32_t*)(buf+block_size*i));
  425. }
  426. printf("\n");
  427. #endif
  428. } else {
  429. // Apply output switches to permutation
  430. #ifdef SHOW_APPLYPERM
  431. printf("o");
  432. for(uint32_t i=0;i<numOutSwitches;++i) {
  433. printf(" %s", outSwitch[i] ? " X" : "||");
  434. }
  435. printf("\n");
  436. #endif
  437. #ifdef PROFILE_MTAPPLYPERM
  438. unsigned long outswstart = printf_with_rtclock("Thread %u starting output switches (N=%lu, nthreads=%lu)\n", g_thread_id, N, nthreads);
  439. #endif
  440. applySwitches<oswap_style>(buf, block_size, outSwitch, numOutSwitches,
  441. Nleft, nthreads);
  442. #ifdef PROFILE_MTAPPLYPERM
  443. printf_with_rtclock_diff(outswstart, "Thread %u ending output switches (N=%lu, nthreads=%lu)\n", g_thread_id, N, nthreads);
  444. #endif
  445. #ifdef SHOW_APPLYPERM
  446. printf(" ");
  447. for(uint32_t i=0;i<N;++i) {
  448. printf(" %2d", *(uint32_t*)(buf+block_size*i));
  449. }
  450. printf("\n");
  451. #endif
  452. // Apply subnetwork switches
  453. threadid_t rightthreadid = g_thread_id + nthr_left;
  454. appInvPermArgs rightargs(*this, buf + block_size*Nleft,
  455. block_size, plan.subplans[1], traversal);
  456. threadpool_dispatch(rightthreadid,
  457. applyInversePermutation_launch<oswap_style>, &rightargs);
  458. applyInversePermutation<oswap_style>(buf, block_size,
  459. plan.subplans[0], lefttraversal);
  460. threadpool_join(rightthreadid, NULL);
  461. // Apply input switches to permutation
  462. #ifdef SHOW_APPLYPERM
  463. printf("r");
  464. for(uint32_t i=0;i<N;++i) {
  465. printf(" %2d", *(uint32_t*)(buf+block_size*i));
  466. }
  467. printf("\n");
  468. printf("i");
  469. for(uint32_t i=0;i<numInSwitches;++i) {
  470. printf(" %s", (inSwitch[i]&1) ? " X" : "||");
  471. }
  472. printf("\n");
  473. #endif
  474. #ifdef PROFILE_MTAPPLYPERM
  475. unsigned long inswstart = printf_with_rtclock("Thread %u starting input switches (N=%lu, nthreads=%lu)\n", g_thread_id, N, nthreads);
  476. #endif
  477. applySwitches<oswap_style>(buf, block_size, inSwitch, numInSwitches,
  478. Nleft, nthreads);
  479. #ifdef PROFILE_MTAPPLYPERM
  480. printf_with_rtclock_diff(inswstart, "Thread %u ending input switches (N=%lu, nthreads=%lu)\n", g_thread_id, N, nthreads);
  481. #endif
  482. #ifdef SHOW_APPLYPERM
  483. printf("e");
  484. for(uint32_t i=0;i<N;++i) {
  485. printf(" %2d", *(uint32_t*)(buf+block_size*i));
  486. }
  487. printf("\n");
  488. #endif
  489. }
  490. #ifdef PROFILE_MTAPPLYPERM
  491. printf_with_rtclock_diff(start, "Thread %u ending applyInversePermutation(N=%lu, nthreads=%lu)\n", g_thread_id, N, nthreads);
  492. #endif
  493. }
  494. // Apply inverse of permutation in control bits to data elements in buffer. Permutes in place.
  495. template <OSwap_Style oswap_style>
  496. void WaksmanNetwork::applyInversePermutation(unsigned char *buf, size_t block_size) {
  497. FOAV_SAFE_CNTXT(AIP, Ntotal)
  498. if (Ntotal > 1) {
  499. WNTraversal traversal(*this);
  500. applyInversePermutation<oswap_style>(buf, Ntotal, block_size,
  501. traversal);
  502. }
  503. }
  504. // Apply inverse of permutation in control bits to data elements in buffer. Permutes in place.
  505. template <OSwap_Style oswap_style>
  506. void WaksmanNetwork::applyInversePermutation(unsigned char *buf,
  507. uint32_t N, size_t block_size, WNTraversal &traversal) {
  508. FOAV_SAFE_CNTXT(AIP, N)
  509. if (N < 2) return;
  510. const uint32_t Nleft = (N+1)/2;
  511. const uint32_t Nright = N/2;
  512. const uint32_t numInSwitches = (N-1)/2;
  513. const uint32_t numOutSwitches = N/2;
  514. const uint32_t *inSwitch = traversal.inSwitches;
  515. const uint8_t *outSwitch = traversal.outSwitches;
  516. traversal.subnetNumber += 1;
  517. traversal.inSwitches += numInSwitches;
  518. traversal.outSwitches += numOutSwitches;
  519. #ifdef SHOW_APPLYPERM
  520. printf("s");
  521. for(uint32_t i=0;i<N;++i) {
  522. printf(" %2d", *(uint32_t*)(buf+block_size*i));
  523. }
  524. printf("\n");
  525. #endif
  526. FOAV_SAFE_CNTXT(AIP, N)
  527. if (N == 2) {
  528. #ifdef SHOW_APPLYPERM
  529. printf("o");
  530. for(uint32_t i=0;i<numOutSwitches;++i) {
  531. printf(" %s", outSwitch[i] ? " X" : "||");
  532. }
  533. printf("\n");
  534. #endif
  535. oswap_buffer<oswap_style>(buf, buf + block_size, (uint32_t) block_size,
  536. outSwitch[0]);
  537. #ifdef SHOW_APPLYPERM
  538. printf("e");
  539. for(uint32_t i=0;i<N;++i) {
  540. printf(" %2d", *(uint32_t*)(buf+block_size*i));
  541. }
  542. printf("\n");
  543. #endif
  544. } else {
  545. // Apply output switches to permutation
  546. #ifdef SHOW_APPLYPERM
  547. printf("o");
  548. for(uint32_t i=0;i<numOutSwitches;++i) {
  549. printf(" %s", outSwitch[i] ? " X" : "||");
  550. }
  551. printf("\n");
  552. #endif
  553. FOAV_SAFE_CNTXT(AIP, numOutSwitches)
  554. for (uint32_t i=0; i<numOutSwitches; i++) {
  555. FOAV_SAFE2_CNTXT(AIP, i, numOutSwitches)
  556. oswap_buffer<oswap_style>(buf + block_size*i, buf + block_size*(Nleft+i), block_size,
  557. *outSwitch);
  558. ++outSwitch;
  559. }
  560. #ifdef SHOW_APPLYPERM
  561. printf(" ");
  562. for(uint32_t i=0;i<N;++i) {
  563. printf(" %2d", *(uint32_t*)(buf+block_size*i));
  564. }
  565. printf("\n");
  566. #endif
  567. // Apply subnetwork switches
  568. applyInversePermutation<oswap_style>(buf, Nleft,
  569. block_size, traversal);
  570. applyInversePermutation<oswap_style>(buf + block_size*Nleft, Nright,
  571. block_size, traversal);
  572. // Apply input switches to permutation
  573. #ifdef SHOW_APPLYPERM
  574. printf("r");
  575. for(uint32_t i=0;i<N;++i) {
  576. printf(" %2d", *(uint32_t*)(buf+block_size*i));
  577. }
  578. printf("\n");
  579. printf("i");
  580. for(uint32_t i=0;i<numInSwitches;++i) {
  581. printf(" %s", (inSwitch[i]&1) ? " X" : "||");
  582. }
  583. printf("\n");
  584. #endif
  585. const uint32_t *curInSwitchVal = inSwitch;
  586. FOAV_SAFE_CNTXT(AIP, numInSwitches)
  587. for (uint32_t i=0; i<numInSwitches; i++) {
  588. FOAV_SAFE2_CNTXT(AIP, i, numInSwitches)
  589. oswap_buffer<oswap_style>(buf + block_size*(i), buf + block_size*(Nleft+i), block_size,
  590. (*curInSwitchVal&1));
  591. curInSwitchVal += 1;
  592. }
  593. #ifdef SHOW_APPLYPERM
  594. printf("e");
  595. for(uint32_t i=0;i<N;++i) {
  596. printf(" %2d", *(uint32_t*)(buf+block_size*i));
  597. }
  598. printf("\n");
  599. #endif
  600. }
  601. }
  602. #if 0
  603. void OblivWaksmanShuffle(unsigned char *buffer, uint32_t N, size_t block_size, enc_ret *ret);
  604. void DecryptAndOblivWaksmanShuffle(unsigned char *encrypted_buffer, uint32_t N,
  605. size_t encrypted_block_size, unsigned char *result_buffer, enc_ret *ret);
  606. void DecryptAndOWSS(unsigned char *encrypted_buffer, uint32_t N,
  607. size_t encrypted_block_size, unsigned char *result_buffer, enc_ret *ret);
  608. void DecryptAndMTSS(unsigned char *encrypted_buffer, uint32_t N,
  609. size_t encrypted_block_size, size_t nthreads,
  610. unsigned char *result_buffer, enc_ret *ret);
  611. #endif
  612. #endif