WaksmanNetwork.hpp 23 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705
  1. #ifndef __WAKSMANNETWORK_HPP__
  2. #define __WAKSMANNETWORK_HPP__
  3. #include <unordered_map>
  4. #include <vector>
  5. #include "oasm_lib.h"
  6. #include <sgx_tcrypto.h>
  7. #include "utils.hpp"
  8. #include "RecursiveShuffle.hpp"
  9. #include "aes.hpp"
  10. // #define PROFILE_MTAPPLYPERM
  11. typedef __uint128_t randkey_t;
  12. #define FPERM_OSWAP_STYLE OSWAP_8_16X // OSwap_Style for forward perm (consistent w/ randkey_t)
  13. // A struct to hold a multi-thread evaluation plan for a
  14. // WaksmanNetwork of a given size and number of threads
  15. // If you have a WaksmanNetwork with N>2 items, and you want to apply
  16. // it using nthreads>1 threads, this WaksmanNetwork will itself use
  17. // the first (N-1)/2 input switches and N/2 output switches (in
  18. // addition to those used by its subnetworks); it will recurse into
  19. // the left subnetwork containing (N+1)/2 items and (nthreads+1)/2
  20. // threads, and into the right subnetwork containing N/2 items and
  21. // nthreads/2 threads. If N=2, then there is just a single output
  22. // switch and no subnetworks. If N<2, there are no switches and no
  23. // subnetworks. If nthreads=1, we compute the total number of inputs
  24. // and output switches used by this WaksmanNetwork and its
  25. // subnetworks, but just store the total in this WNEvalPlan with no
  26. // WNEvalPlans for the subnetworks.
  27. struct WNEvalPlan {
  28. // The number of items for the represented WaksmanNetwork
  29. uint32_t N;
  30. // The number of threads to use to evaluate it
  31. uint32_t nthreads;
  32. // The total number of input and output switches used by this
  33. // WaksmanNetwork and its subnetworks
  34. size_t subtree_num_inswitches, subtree_num_outswitches;
  35. // If N>2 and nthreads>1, these are the evaluation plans for the
  36. // subnetworks. This vector will contain 0 (if N<=2 or nthreads=1)
  37. // or 2 (otherwise) items.
  38. std::vector<WNEvalPlan> subplans;
  39. // Make WNEvalPlan objects non-copyable for efficiency
  40. WNEvalPlan(const WNEvalPlan&) = delete;
  41. WNEvalPlan& operator=(const WNEvalPlan&) = delete;
  42. // But moves are OK
  43. WNEvalPlan(WNEvalPlan &&wn) = default;
  44. WNEvalPlan& operator=(WNEvalPlan&&) = default;
  45. WNEvalPlan(uint32_t N, uint32_t nthreads) : N(N), nthreads(nthreads) {
  46. if (N<2) {
  47. subtree_num_inswitches = 0;
  48. subtree_num_outswitches = 0;
  49. } else if (N == 2) {
  50. subtree_num_inswitches = 0;
  51. subtree_num_outswitches = 1;
  52. } else if (nthreads <= 1) {
  53. subtree_num_inswitches = 0;
  54. subtree_num_outswitches = 0;
  55. count_switches(N);
  56. } else {
  57. const uint32_t Nleft = (N+1)/2;
  58. const uint32_t Nright = N/2;
  59. const uint32_t numInSwitches = (N-1)/2;
  60. const uint32_t numOutSwitches = N/2;
  61. const uint32_t nthr_left = (nthreads+1)/2;
  62. const uint32_t nthr_right = nthreads/2;
  63. subplans.emplace_back(Nleft, nthr_left);
  64. subplans.emplace_back(Nright, nthr_right);
  65. subtree_num_inswitches = numInSwitches +
  66. subplans[0].subtree_num_inswitches +
  67. subplans[1].subtree_num_inswitches;
  68. subtree_num_outswitches = numOutSwitches +
  69. subplans[0].subtree_num_outswitches +
  70. subplans[1].subtree_num_outswitches;
  71. }
  72. }
  73. // Count the number of input and output switches used by a
  74. // WaksmanNetwork with N items. Add those values to
  75. // subtree_num_inswitches and subtree_num_outswitches.
  76. void count_switches(uint32_t N) {
  77. if (N<2) {
  78. return;
  79. }
  80. if (N == 2) {
  81. subtree_num_outswitches += 1;
  82. return;
  83. }
  84. const uint32_t Nleft = (N+1)/2;
  85. const uint32_t Nright = N/2;
  86. const uint32_t numInSwitches = (N-1)/2;
  87. const uint32_t numOutSwitches = N/2;
  88. subtree_num_inswitches += numInSwitches;
  89. subtree_num_outswitches += numOutSwitches;
  90. count_switches(Nleft);
  91. count_switches(Nright);
  92. }
  93. void dump(int indent = 0) const {
  94. printf("%*sN = %lu, nthreads = %lu, inswitches = %lu, outswitches = %lu\n",
  95. indent, "", N, nthreads, subtree_num_inswitches,
  96. subtree_num_outswitches);
  97. if (subplans.size() > 0) {
  98. subplans[0].dump(indent+2);
  99. subplans[1].dump(indent+2);
  100. }
  101. }
  102. };
/*
    WaksmanNetwork Class: Contains a Waksman permutation network and can apply it to an input array.

    setPermutation(uint32_t *permutation): Takes permutation as an array of N index values
        (i.e. values in [N]) and sets the Waksman network switches to that permutation.
        The working memory it needs is allocated once, internally (see WNMem), before the
        recursive setPermutation is called.

    applyPermutation<oswap_style>(unsigned char *buf, size_t block_size): Takes buffer of N
        items of block_size bytes each and applies stored permutation in-place (i.e.
        modifying input buffer).

    applyInversePermutation<oswap_style>(unsigned char *buf, size_t block_size): Takes buffer
        of N items of block_size bytes each and applies inverse of stored permutation in-place.

    applyInversePermutation<oswap_style>(unsigned char *buf, size_t block_size,
        const WNEvalPlan &plan): As above, but evaluated according to a multi-thread
        evaluation plan.
*/
  114. class WaksmanNetwork {
  115. uint32_t Ntotal; // number of items to permute
  116. std::vector<uint32_t> inSwitchVec; // input layer of (numbered) switches
  117. std::vector<uint8_t> outSwitchVec; // output layer of switches
  118. // A struct to keep track of the current subnet number, and input and
  119. // output switches, for each subnet as we traverse the network.
  120. struct WNTraversal {
  121. uint64_t subnetNumber;
  122. uint32_t *inSwitches;
  123. uint8_t *outSwitches;
  124. WNTraversal(WaksmanNetwork &wn) : subnetNumber(0),
  125. inSwitches(wn.inSwitchVec.data()),
  126. outSwitches(wn.outSwitchVec.data()) {}
  127. };
  128. struct appInvPermArgs {
  129. WaksmanNetwork &wn;
  130. unsigned char *buf;
  131. size_t block_size;
  132. const WNEvalPlan &plan;
  133. WNTraversal &traversal;
  134. appInvPermArgs(WaksmanNetwork &wn, unsigned char *buf,
  135. size_t block_size, const WNEvalPlan &plan, WNTraversal &traversal)
  136. : wn(wn), buf(buf), block_size(block_size), plan(plan),
  137. traversal(traversal) {}
  138. };
  139. template <OSwap_Style oswap_style>
  140. static void *applyInversePermutation_launch(void *voidarg)
  141. {
  142. appInvPermArgs *arg = (appInvPermArgs *)voidarg;
  143. arg->wn.applyInversePermutation<oswap_style>(arg->buf, arg->block_size,
  144. arg->plan, arg->traversal);
  145. return NULL;
  146. }
  147. // A struct to hold pre-allocated memory (and AES keys) so that we
  148. // only allocate memory once, before the recursive setPermutation is
  149. // called.
  150. struct WNMem {
  151. unsigned char *forward_perm;
  152. uint32_t *unselected_cnt;
  153. std::unordered_map<randkey_t, std::pair<uint32_t, uint32_t>> *reverse_perm;
  154. AESkey forward_key, reverse_key;
  155. WNMem(WaksmanNetwork &wn) {
  156. // Round Ntotal up to an even number
  157. uint32_t Neven = wn.Ntotal + (wn.Ntotal&1);
  158. forward_perm = new unsigned char[Neven * (sizeof(randkey_t) + 8)];
  159. unselected_cnt = new uint32_t[wn.Ntotal];
  160. reverse_perm = new std::unordered_map<randkey_t, std::pair<uint32_t, uint32_t>>;
  161. __m128i forward_rawkey, reverse_rawkey;
  162. getRandomBytes((unsigned char *) &forward_rawkey, sizeof(forward_rawkey));
  163. getRandomBytes((unsigned char *) &reverse_rawkey, sizeof(reverse_rawkey));
  164. AES_128_Key_Expansion(forward_key, forward_rawkey);
  165. AES_128_Key_Expansion(reverse_key, reverse_rawkey);
  166. }
  167. ~WNMem() {
  168. delete[] forward_perm;
  169. delete[] unselected_cnt;
  170. delete reverse_perm;
  171. }
  172. };
  173. void setPermutation(uint32_t *permutation, uint32_t N,
  174. uint32_t depth, WNTraversal &traversal, const WNMem &mem);
  175. template <OSwap_Style oswap_style>
  176. void applyPermutation(unsigned char *buf, uint32_t N,
  177. size_t block_size, WNTraversal &traversal);
  178. template <OSwap_Style oswap_style>
  179. void applyInversePermutation(unsigned char *buf, uint32_t N,
  180. size_t block_size, WNTraversal &traversal);
  181. template <OSwap_Style oswap_style>
  182. void applyInversePermutation(unsigned char *buf, size_t block_size,
  183. const WNEvalPlan &plan, WNTraversal &traversal);
  184. public:
  185. // Make WaksmanNetwork objects non-copyable for efficiency
  186. WaksmanNetwork(const WaksmanNetwork&) = delete;
  187. WaksmanNetwork& operator=(const WaksmanNetwork&) = delete;
  188. // But moves are OK
  189. WaksmanNetwork(WaksmanNetwork &&wn) = default;
  190. WaksmanNetwork& operator=(WaksmanNetwork&&) = default;
  191. // Set up the WaksmanNetwork for N items. N need not be a power of 2.
  192. // N <= 2^31
  193. WaksmanNetwork(uint32_t N);
  194. void setPermutation(uint32_t *permutation);
  195. template <OSwap_Style oswap_style>
  196. void applyPermutation(unsigned char *buf, size_t block_size);
  197. template <OSwap_Style oswap_style>
  198. void applyInversePermutation(unsigned char *buf, size_t block_size);
  199. template <OSwap_Style oswap_style>
  200. void applyInversePermutation(unsigned char *buf, size_t block_size,
  201. const WNEvalPlan &plan);
  202. };
  203. // Define this to show the intermediate states of applyPermutation
  204. // #define SHOW_APPLYPERM
  205. // Apply permutation encoded by control bits to data elements in buffer. Permutes in place.
  206. template <OSwap_Style oswap_style>
  207. void WaksmanNetwork::applyPermutation(unsigned char *buf, size_t block_size) {
  208. FOAV_SAFE_CNTXT(AP, Ntotal)
  209. if (Ntotal > 1) {
  210. WNTraversal traversal(*this);
  211. applyPermutation<oswap_style>(buf, Ntotal, block_size, traversal);
  212. }
  213. }
  214. // Apply permutation encoded by control bits to data elements in buffer. Permutes in place.
  215. template <OSwap_Style oswap_style>
  216. void WaksmanNetwork::applyPermutation(unsigned char *buf, uint32_t N,
  217. size_t block_size, WNTraversal &traversal) {
  218. FOAV_SAFE_CNTXT(AP, Ntotal)
  219. FOAV_SAFE_CNTXT(AP, N)
  220. if (N < 2) return;
  221. const uint32_t Nleft = (N+1)/2;
  222. const uint32_t Nright = N/2;
  223. const uint32_t numInSwitches = (N-1)/2;
  224. const uint32_t numOutSwitches = N/2;
  225. const uint32_t *inSwitch = traversal.inSwitches;
  226. const uint8_t *outSwitch = traversal.outSwitches;
  227. traversal.subnetNumber += 1;
  228. traversal.inSwitches += numInSwitches;
  229. traversal.outSwitches += numOutSwitches;
  230. #ifdef SHOW_APPLYPERM
  231. printf("s");
  232. for(uint32_t i=0;i<N;++i) {
  233. printf(" %2d", *(uint32_t*)(buf+block_size*i));
  234. }
  235. printf("\n");
  236. #endif
  237. if (N == 2) {
  238. #ifdef SHOW_APPLYPERM
  239. printf("o");
  240. for(uint32_t i=0;i<numOutSwitches;++i) {
  241. printf(" %s", outSwitch[i] ? " X" : "||");
  242. }
  243. printf("\n");
  244. #endif
  245. oswap_buffer<oswap_style>(buf, buf + block_size, (uint32_t) block_size, outSwitch[0]);
  246. #ifdef SHOW_APPLYPERM
  247. printf("e");
  248. for(uint32_t i=0;i<N;++i) {
  249. printf(" %2d", *(uint32_t*)(buf+block_size*i));
  250. }
  251. printf("\n");
  252. #endif
  253. } else {
  254. #ifdef SHOW_APPLYPERM
  255. printf("i");
  256. for(uint32_t i=0;i<numInSwitches;++i) {
  257. printf(" %s", (inSwitch[i]&1) ? " X" : "||");
  258. }
  259. printf("\n");
  260. #endif
  261. // Apply input switches to permutation
  262. const uint32_t *curInSwitchVal = inSwitch;
  263. for (uint32_t i=0; i<numInSwitches; i++) {
  264. oswap_buffer<oswap_style>(buf + block_size*(i), buf + block_size*(Nleft+i), block_size,
  265. (*curInSwitchVal)&1);
  266. curInSwitchVal += 1;
  267. }
  268. #ifdef SHOW_APPLYPERM
  269. printf(" ");
  270. for(uint32_t i=0;i<N;++i) {
  271. printf(" %2d", *(uint32_t*)(buf+block_size*i));
  272. }
  273. printf("\n");
  274. #endif
  275. // Apply subnetwork switches
  276. applyPermutation<oswap_style>(buf, Nleft, block_size, traversal);
  277. applyPermutation<oswap_style>(buf + block_size*Nleft, Nright,
  278. block_size, traversal);
  279. #ifdef SHOW_APPLYPERM
  280. printf("r");
  281. for(uint32_t i=0;i<N;++i) {
  282. printf(" %2d", *(uint32_t*)(buf+block_size*i));
  283. }
  284. printf("\n");
  285. printf("o");
  286. for(uint32_t i=0;i<numOutSwitches;++i) {
  287. printf(" %s", outSwitch[i] ? " X" : "||");
  288. }
  289. printf("\n");
  290. #endif
  291. // Apply output switches to permutation
  292. for (uint32_t i=0; i<numOutSwitches; i++) {
  293. oswap_buffer<oswap_style>(buf + block_size*i, buf + block_size*(Nleft+i), block_size,
  294. *outSwitch);
  295. ++outSwitch;
  296. }
  297. #ifdef SHOW_APPLYPERM
  298. printf("e");
  299. for(uint32_t i=0;i<N;++i) {
  300. printf(" %2d", *(uint32_t*)(buf+block_size*i));
  301. }
  302. printf("\n");
  303. #endif
  304. }
  305. }
  306. // Apply permutation encoded by control bits to data elements in buffer
  307. // using a multithread evaluation plan. Permutes in place.
  308. template <OSwap_Style oswap_style>
  309. void WaksmanNetwork::applyInversePermutation(unsigned char *buf,
  310. size_t block_size, const WNEvalPlan &plan) {
  311. FOAV_SAFE_CNTXT(AP, Ntotal)
  312. if (Ntotal > 1) {
  313. WNTraversal traversal(*this);
  314. applyInversePermutation<oswap_style>(buf, block_size, plan, traversal);
  315. }
  316. }
  317. template <typename CBT>
  318. struct ApplySwitchesArgs {
  319. unsigned char *buf;
  320. size_t block_size;
  321. const CBT* switches;
  322. uint32_t swStart, swEnd;
  323. uint32_t stride;
  324. };
  325. // Apply a consecutive sequence of input or output switches, using
  326. // arguments passed as an ApplySwitchesArgs*. CBT is the
  327. // control bit type (uint32_t for input switches or uint8_t for output
  328. // switches).
  329. template <OSwap_Style oswap_style, typename CBT>
  330. static void* applySwitchesRange(void *voidargs)
  331. {
  332. const ApplySwitchesArgs<CBT>* args =
  333. (const ApplySwitchesArgs<CBT> *)voidargs;
  334. unsigned char *buf = args->buf;
  335. const size_t block_size = args->block_size;
  336. const uint32_t swStart = args->swStart;
  337. const CBT* switches = args->switches + swStart;
  338. const uint32_t swEnd = args->swEnd;
  339. const uint32_t stride = args->stride;
  340. FOAV_SAFE_CNTXT(applySwitchesRange, swEnd)
  341. for (uint32_t i=swStart; i<swEnd; ++i) {
  342. FOAV_SAFE2_CNTXT(applySwitchesRange, i, swEnd)
  343. oswap_buffer<oswap_style>(buf + block_size*(i),
  344. buf + block_size*(stride+i), block_size,
  345. (*switches)&1);
  346. ++switches;
  347. }
  348. return NULL;
  349. }
  350. // Apply a consecutive sequence of input or output switches using
  351. // up to nthreads threads. CBT is the control bit type (uint32_t for
  352. // input switches or uint8_t for output switches), but it will be
  353. // deduced automatically from the type of the switches argument.
  354. template <OSwap_Style oswap_style, typename CBT>
  355. static void applySwitches(unsigned char *buf, size_t block_size,
  356. const CBT* switches, uint32_t numSwitches, uint32_t stride,
  357. uint32_t nthreads)
  358. {
  359. uint32_t threads_to_use = nthreads;
  360. ApplySwitchesArgs<CBT> asargs[threads_to_use];
  361. uint32_t inc = numSwitches / threads_to_use;
  362. uint32_t extra = numSwitches % threads_to_use;
  363. uint32_t last = 0;
  364. for (uint32_t t=0; t<threads_to_use; ++t) {
  365. uint32_t next = last + inc + (t < extra);
  366. asargs[t] = { buf, block_size, switches, last, next, stride };
  367. last = next;
  368. if (t > 0) {
  369. threadpool_dispatch(g_thread_id+t,
  370. applySwitchesRange<oswap_style,CBT>, &asargs[t]);
  371. }
  372. }
  373. // Do the first block ourselves
  374. applySwitchesRange<oswap_style,CBT>(&asargs[0]);
  375. for (size_t t=1; t<threads_to_use; ++t) {
  376. threadpool_join(g_thread_id+t, NULL);
  377. }
  378. }
  379. // Apply inverse of permutation encoded by control bits to data elements
  380. // in buffer using a multithread evaluation plan. Permutes in place.
  381. template <OSwap_Style oswap_style>
  382. void WaksmanNetwork::applyInversePermutation(unsigned char *buf,
  383. size_t block_size, const WNEvalPlan &plan, WNTraversal &traversal) {
  384. const uint32_t N = plan.N;
  385. const uint32_t nthreads = plan.nthreads;
  386. if (N < 2) return;
  387. if (nthreads <= 1) {
  388. #ifdef PROFILE_MTAPPLYPERM
  389. unsigned long start = printf_with_rtclock("Thread %u starting single-threaded applyInversePermutation(N=%lu)\n", g_thread_id, N);
  390. #endif
  391. applyInversePermutation<oswap_style>(buf, N, block_size, traversal);
  392. #ifdef PROFILE_MTAPPLYPERM
  393. printf_with_rtclock_diff(start, "Thread %u ending single-threaded applyInversePermutation(N=%lu)\n", g_thread_id, N);
  394. #endif
  395. return;
  396. }
  397. #ifdef PROFILE_MTAPPLYPERM
  398. unsigned long start = printf_with_rtclock("Thread %u starting applyInversePermutation(N=%lu, nthreads=%lu)\n", g_thread_id, N, nthreads);
  399. #endif
  400. const uint32_t Nleft = (N+1)/2;
  401. const uint32_t Nright = N/2;
  402. const uint32_t numInSwitches = (N-1)/2;
  403. const uint32_t numOutSwitches = N/2;
  404. const uint32_t *inSwitch = traversal.inSwitches;
  405. const uint8_t *outSwitch = traversal.outSwitches;
  406. const uint32_t nthr_left = (nthreads+1)/2;
  407. const uint32_t nthr_right = nthreads/2;
  408. WNTraversal lefttraversal = traversal;
  409. lefttraversal.inSwitches += numInSwitches;
  410. lefttraversal.outSwitches += numOutSwitches;
  411. traversal.inSwitches += numInSwitches;
  412. traversal.outSwitches += numOutSwitches;
  413. if (plan.subplans.size() > 0) {
  414. traversal.inSwitches += plan.subplans[0].subtree_num_inswitches;
  415. traversal.outSwitches += plan.subplans[0].subtree_num_outswitches;
  416. }
  417. #ifdef SHOW_APPLYPERM
  418. printf("s");
  419. for(uint32_t i=0;i<N;++i) {
  420. printf(" %2d", *(uint32_t*)(buf+block_size*i));
  421. }
  422. printf("\n");
  423. #endif
  424. FOAV_SAFE_CNTXT(AIP, N)
  425. if (N == 2) {
  426. #ifdef SHOW_APPLYPERM
  427. printf("o");
  428. for(uint32_t i=0;i<numOutSwitches;++i) {
  429. printf(" %s", outSwitch[i] ? " X" : "||");
  430. }
  431. printf("\n");
  432. #endif
  433. oswap_buffer<oswap_style>(buf, buf + block_size, (uint32_t) block_size,
  434. outSwitch[0]);
  435. #ifdef SHOW_APPLYPERM
  436. printf("e");
  437. for(uint32_t i=0;i<N;++i) {
  438. printf(" %2d", *(uint32_t*)(buf+block_size*i));
  439. }
  440. printf("\n");
  441. #endif
  442. } else {
  443. // Apply output switches to permutation
  444. #ifdef SHOW_APPLYPERM
  445. printf("o");
  446. for(uint32_t i=0;i<numOutSwitches;++i) {
  447. printf(" %s", outSwitch[i] ? " X" : "||");
  448. }
  449. printf("\n");
  450. #endif
  451. #ifdef PROFILE_MTAPPLYPERM
  452. unsigned long outswstart = printf_with_rtclock("Thread %u starting output switches (N=%lu, nthreads=%lu)\n", g_thread_id, N, nthreads);
  453. #endif
  454. applySwitches<oswap_style>(buf, block_size, outSwitch, numOutSwitches,
  455. Nleft, nthreads);
  456. #ifdef PROFILE_MTAPPLYPERM
  457. printf_with_rtclock_diff(outswstart, "Thread %u ending output switches (N=%lu, nthreads=%lu)\n", g_thread_id, N, nthreads);
  458. #endif
  459. #ifdef SHOW_APPLYPERM
  460. printf(" ");
  461. for(uint32_t i=0;i<N;++i) {
  462. printf(" %2d", *(uint32_t*)(buf+block_size*i));
  463. }
  464. printf("\n");
  465. #endif
  466. // Apply subnetwork switches
  467. threadid_t rightthreadid = g_thread_id + nthr_left;
  468. appInvPermArgs rightargs(*this, buf + block_size*Nleft,
  469. block_size, plan.subplans[1], traversal);
  470. threadpool_dispatch(rightthreadid,
  471. applyInversePermutation_launch<oswap_style>, &rightargs);
  472. applyInversePermutation<oswap_style>(buf, block_size,
  473. plan.subplans[0], lefttraversal);
  474. threadpool_join(rightthreadid, NULL);
  475. // Apply input switches to permutation
  476. #ifdef SHOW_APPLYPERM
  477. printf("r");
  478. for(uint32_t i=0;i<N;++i) {
  479. printf(" %2d", *(uint32_t*)(buf+block_size*i));
  480. }
  481. printf("\n");
  482. printf("i");
  483. for(uint32_t i=0;i<numInSwitches;++i) {
  484. printf(" %s", (inSwitch[i]&1) ? " X" : "||");
  485. }
  486. printf("\n");
  487. #endif
  488. #ifdef PROFILE_MTAPPLYPERM
  489. unsigned long inswstart = printf_with_rtclock("Thread %u starting input switches (N=%lu, nthreads=%lu)\n", g_thread_id, N, nthreads);
  490. #endif
  491. applySwitches<oswap_style>(buf, block_size, inSwitch, numInSwitches,
  492. Nleft, nthreads);
  493. #ifdef PROFILE_MTAPPLYPERM
  494. printf_with_rtclock_diff(inswstart, "Thread %u ending input switches (N=%lu, nthreads=%lu)\n", g_thread_id, N, nthreads);
  495. #endif
  496. #ifdef SHOW_APPLYPERM
  497. printf("e");
  498. for(uint32_t i=0;i<N;++i) {
  499. printf(" %2d", *(uint32_t*)(buf+block_size*i));
  500. }
  501. printf("\n");
  502. #endif
  503. }
  504. #ifdef PROFILE_MTAPPLYPERM
  505. printf_with_rtclock_diff(start, "Thread %u ending applyInversePermutation(N=%lu, nthreads=%lu)\n", g_thread_id, N, nthreads);
  506. #endif
  507. }
  508. // Apply inverse of permutation in control bits to data elements in buffer. Permutes in place.
  509. template <OSwap_Style oswap_style>
  510. void WaksmanNetwork::applyInversePermutation(unsigned char *buf, size_t block_size) {
  511. FOAV_SAFE_CNTXT(AIP, Ntotal)
  512. if (Ntotal > 1) {
  513. WNTraversal traversal(*this);
  514. applyInversePermutation<oswap_style>(buf, Ntotal, block_size,
  515. traversal);
  516. }
  517. }
  518. // Apply inverse of permutation in control bits to data elements in buffer. Permutes in place.
  519. template <OSwap_Style oswap_style>
  520. void WaksmanNetwork::applyInversePermutation(unsigned char *buf,
  521. uint32_t N, size_t block_size, WNTraversal &traversal) {
  522. FOAV_SAFE_CNTXT(AIP, N)
  523. if (N < 2) return;
  524. const uint32_t Nleft = (N+1)/2;
  525. const uint32_t Nright = N/2;
  526. const uint32_t numInSwitches = (N-1)/2;
  527. const uint32_t numOutSwitches = N/2;
  528. const uint32_t *inSwitch = traversal.inSwitches;
  529. const uint8_t *outSwitch = traversal.outSwitches;
  530. traversal.subnetNumber += 1;
  531. traversal.inSwitches += numInSwitches;
  532. traversal.outSwitches += numOutSwitches;
  533. #ifdef SHOW_APPLYPERM
  534. printf("s");
  535. for(uint32_t i=0;i<N;++i) {
  536. printf(" %2d", *(uint32_t*)(buf+block_size*i));
  537. }
  538. printf("\n");
  539. #endif
  540. FOAV_SAFE_CNTXT(AIP, N)
  541. if (N == 2) {
  542. #ifdef SHOW_APPLYPERM
  543. printf("o");
  544. for(uint32_t i=0;i<numOutSwitches;++i) {
  545. printf(" %s", outSwitch[i] ? " X" : "||");
  546. }
  547. printf("\n");
  548. #endif
  549. oswap_buffer<oswap_style>(buf, buf + block_size, (uint32_t) block_size,
  550. outSwitch[0]);
  551. #ifdef SHOW_APPLYPERM
  552. printf("e");
  553. for(uint32_t i=0;i<N;++i) {
  554. printf(" %2d", *(uint32_t*)(buf+block_size*i));
  555. }
  556. printf("\n");
  557. #endif
  558. } else {
  559. // Apply output switches to permutation
  560. #ifdef SHOW_APPLYPERM
  561. printf("o");
  562. for(uint32_t i=0;i<numOutSwitches;++i) {
  563. printf(" %s", outSwitch[i] ? " X" : "||");
  564. }
  565. printf("\n");
  566. #endif
  567. FOAV_SAFE_CNTXT(AIP, numOutSwitches)
  568. for (uint32_t i=0; i<numOutSwitches; i++) {
  569. FOAV_SAFE2_CNTXT(AIP, i, numOutSwitches)
  570. oswap_buffer<oswap_style>(buf + block_size*i, buf + block_size*(Nleft+i), block_size,
  571. *outSwitch);
  572. ++outSwitch;
  573. }
  574. #ifdef SHOW_APPLYPERM
  575. printf(" ");
  576. for(uint32_t i=0;i<N;++i) {
  577. printf(" %2d", *(uint32_t*)(buf+block_size*i));
  578. }
  579. printf("\n");
  580. #endif
  581. // Apply subnetwork switches
  582. applyInversePermutation<oswap_style>(buf, Nleft,
  583. block_size, traversal);
  584. applyInversePermutation<oswap_style>(buf + block_size*Nleft, Nright,
  585. block_size, traversal);
  586. // Apply input switches to permutation
  587. #ifdef SHOW_APPLYPERM
  588. printf("r");
  589. for(uint32_t i=0;i<N;++i) {
  590. printf(" %2d", *(uint32_t*)(buf+block_size*i));
  591. }
  592. printf("\n");
  593. printf("i");
  594. for(uint32_t i=0;i<numInSwitches;++i) {
  595. printf(" %s", (inSwitch[i]&1) ? " X" : "||");
  596. }
  597. printf("\n");
  598. #endif
  599. const uint32_t *curInSwitchVal = inSwitch;
  600. FOAV_SAFE_CNTXT(AIP, numInSwitches)
  601. for (uint32_t i=0; i<numInSwitches; i++) {
  602. FOAV_SAFE2_CNTXT(AIP, i, numInSwitches)
  603. oswap_buffer<oswap_style>(buf + block_size*(i), buf + block_size*(Nleft+i), block_size,
  604. (*curInSwitchVal&1));
  605. curInSwitchVal += 1;
  606. }
  607. #ifdef SHOW_APPLYPERM
  608. printf("e");
  609. for(uint32_t i=0;i<N;++i) {
  610. printf(" %2d", *(uint32_t*)(buf+block_size*i));
  611. }
  612. printf("\n");
  613. #endif
  614. }
  615. }
  616. #if 0
  617. void OblivWaksmanShuffle(unsigned char *buffer, uint32_t N, size_t block_size, enc_ret *ret);
  618. void DecryptAndOblivWaksmanShuffle(unsigned char *encrypted_buffer, uint32_t N,
  619. size_t encrypted_block_size, unsigned char *result_buffer, enc_ret *ret);
  620. void DecryptAndOWSS(unsigned char *encrypted_buffer, uint32_t N,
  621. size_t encrypted_block_size, unsigned char *result_buffer, enc_ret *ret);
  622. void DecryptAndMTSS(unsigned char *encrypted_buffer, uint32_t N,
  623. size_t encrypted_block_size, size_t nthreads,
  624. unsigned char *result_buffer, enc_ret *ret);
  625. #endif
  626. #include "WaksmanNetwork.tcc"
  627. #endif