WaksmanNetwork.hpp 23 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706
  1. #ifndef __WAKSMANNETWORK_HPP__
  2. #define __WAKSMANNETWORK_HPP__
  3. #include <unordered_map>
  4. #include <vector>
  5. #include "oasm_lib.h"
  6. #include <sgx_tcrypto.h>
  7. #include "utils.hpp"
  8. #include "RecursiveShuffle.hpp"
  9. #include "aes.hpp"
  10. // #define PROFILE_MTAPPLYPERM
  11. typedef __uint128_t randkey_t;
  12. #define FPERM_OSWAP_STYLE OSWAP_8_16X // OSwap_Style for forward perm (consistent w/ randkey_t)
  13. // A struct to hold a multi-thread evaluation plan for a
  14. // WaksmanNetwork of a given size and number of threads
  15. // If you have a WaksmanNetwork with N>2 items, and you want to apply
  16. // it using nthreads>1 threads, this WaksmanNetwork will itself use
  17. // the first (N-1)/2 input switches and N/2 output switches (in
  18. // addition to those used by its subnetworks); it will recurse into
  19. // the left subnetwork containing (N+1)/2 items and (nthreads+1)/2
  20. // threads, and into the right subnetwork containing N/2 items and
  21. // nthreads/2 threads. If N=2, then there is just a single output
  22. // switch and no subnetworks. If N<2, there are no switches and no
  23. // subnetworks. If nthreads=1, we compute the total number of inputs
  24. // and output switches used by this WaksmanNetwork and its
  25. // subnetworks, but just store the total in this WNEvalPlan with no
  26. // WNEvalPlans for the subnetworks.
  27. struct WNEvalPlan {
  28. // The number of items for the represented WaksmanNetwork
  29. uint32_t N;
  30. // The number of threads to use to evaluate it
  31. uint32_t nthreads;
  32. // The total number of input and output switches used by this
  33. // WaksmanNetwork and its subnetworks
  34. size_t subtree_num_inswitches, subtree_num_outswitches;
  35. // If N>2 and nthreads>1, these are the evaluation plans for the
  36. // subnetworks. This vector will contain 0 (if N<=2 or nthreads=1)
  37. // or 2 (otherwise) items.
  38. std::vector<WNEvalPlan> subplans;
  39. // Make WNEvalPlan objects non-copyable for efficiency
  40. WNEvalPlan(const WNEvalPlan&) = delete;
  41. WNEvalPlan& operator=(const WNEvalPlan&) = delete;
  42. // But moves are OK
  43. WNEvalPlan(WNEvalPlan &&wn) = default;
  44. WNEvalPlan& operator=(WNEvalPlan&&) = default;
  45. WNEvalPlan(uint32_t N_, uint32_t nthreads_) : N(N_),
  46. nthreads(nthreads_) {
  47. if (N<2) {
  48. subtree_num_inswitches = 0;
  49. subtree_num_outswitches = 0;
  50. } else if (N == 2) {
  51. subtree_num_inswitches = 0;
  52. subtree_num_outswitches = 1;
  53. } else if (nthreads <= 1) {
  54. subtree_num_inswitches = 0;
  55. subtree_num_outswitches = 0;
  56. count_switches(N);
  57. } else {
  58. const uint32_t Nleft = (N+1)/2;
  59. const uint32_t Nright = N/2;
  60. const uint32_t numInSwitches = (N-1)/2;
  61. const uint32_t numOutSwitches = N/2;
  62. const uint32_t nthr_left = (nthreads+1)/2;
  63. const uint32_t nthr_right = nthreads/2;
  64. subplans.emplace_back(Nleft, nthr_left);
  65. subplans.emplace_back(Nright, nthr_right);
  66. subtree_num_inswitches = numInSwitches +
  67. subplans[0].subtree_num_inswitches +
  68. subplans[1].subtree_num_inswitches;
  69. subtree_num_outswitches = numOutSwitches +
  70. subplans[0].subtree_num_outswitches +
  71. subplans[1].subtree_num_outswitches;
  72. }
  73. }
  74. // Count the number of input and output switches used by a
  75. // WaksmanNetwork with num items. Add those values to
  76. // subtree_num_inswitches and subtree_num_outswitches.
  77. void count_switches(uint32_t num) {
  78. if (num<2) {
  79. return;
  80. }
  81. if (num == 2) {
  82. subtree_num_outswitches += 1;
  83. return;
  84. }
  85. const uint32_t Nleft = (num+1)/2;
  86. const uint32_t Nright = num/2;
  87. const uint32_t numInSwitches = (num-1)/2;
  88. const uint32_t numOutSwitches = num/2;
  89. subtree_num_inswitches += numInSwitches;
  90. subtree_num_outswitches += numOutSwitches;
  91. count_switches(Nleft);
  92. count_switches(Nright);
  93. }
  94. void dump(int indent = 0) const {
  95. printf("%*sN = %lu, nthreads = %lu, inswitches = %lu, outswitches = %lu\n",
  96. indent, "", N, nthreads, subtree_num_inswitches,
  97. subtree_num_outswitches);
  98. if (subplans.size() > 0) {
  99. subplans[0].dump(indent+2);
  100. subplans[1].dump(indent+2);
  101. }
  102. }
  103. };
  104. /*
  105. WaksmanNetwork Class: Contains a Waksman permutation network and can apply it to an input array.
setPermutation(uint32_t *permutation): Takes permutation as an array of N index values
(i.e. values in [N]) and sets the Waksman network switches to that permutation.
  110. applyPermutation(unsigned char *buf, size_t block_size): Takes buffer of N items of block_size
  111. bytes each and applies stored permutation in-place (i.e. modifying input buffer).
  112. applyInversePermutation(unsigned char *buf, size_t block_size): Takes buffer of N items of
  113. block_size bytes each and applies inverse of stored permutation in-place.
  114. */
  115. class WaksmanNetwork {
  116. uint32_t Ntotal; // number of items to permute
  117. std::vector<uint32_t> inSwitchVec; // input layer of (numbered) switches
  118. std::vector<uint8_t> outSwitchVec; // output layer of switches
  119. // A struct to keep track of the current subnet number, and input and
  120. // output switches, for each subnet as we traverse the network.
  121. struct WNTraversal {
  122. uint64_t subnetNumber;
  123. uint32_t *inSwitches;
  124. uint8_t *outSwitches;
  125. WNTraversal(WaksmanNetwork &wn) : subnetNumber(0),
  126. inSwitches(wn.inSwitchVec.data()),
  127. outSwitches(wn.outSwitchVec.data()) {}
  128. };
  129. struct appInvPermArgs {
  130. WaksmanNetwork &wn;
  131. unsigned char *buf;
  132. size_t block_size;
  133. const WNEvalPlan &plan;
  134. WNTraversal &traversal;
  135. appInvPermArgs(WaksmanNetwork &wn_, unsigned char *buf_,
  136. size_t block_size_, const WNEvalPlan &plan_, WNTraversal
  137. &traversal_) : wn(wn_), buf(buf_), block_size(block_size_),
  138. plan(plan_), traversal(traversal_) {}
  139. };
  140. template <OSwap_Style oswap_style>
  141. static void *applyInversePermutation_launch(void *voidarg)
  142. {
  143. appInvPermArgs *arg = (appInvPermArgs *)voidarg;
  144. arg->wn.applyInversePermutation<oswap_style>(arg->buf, arg->block_size,
  145. arg->plan, arg->traversal);
  146. return NULL;
  147. }
  148. // A struct to hold pre-allocated memory (and AES keys) so that we
  149. // only allocate memory once, before the recursive setPermutation is
  150. // called.
  151. struct WNMem {
  152. unsigned char *forward_perm;
  153. uint32_t *unselected_cnt;
  154. std::unordered_map<randkey_t, std::pair<uint32_t, uint32_t>> *reverse_perm;
  155. AESkey forward_key, reverse_key;
  156. WNMem(WaksmanNetwork &wn) {
  157. // Round Ntotal up to an even number
  158. uint32_t Neven = wn.Ntotal + (wn.Ntotal&1);
  159. forward_perm = new unsigned char[Neven * (sizeof(randkey_t) + 8)];
  160. unselected_cnt = new uint32_t[wn.Ntotal];
  161. reverse_perm = new std::unordered_map<randkey_t, std::pair<uint32_t, uint32_t>>;
  162. __m128i forward_rawkey, reverse_rawkey;
  163. getRandomBytes((unsigned char *) &forward_rawkey, sizeof(forward_rawkey));
  164. getRandomBytes((unsigned char *) &reverse_rawkey, sizeof(reverse_rawkey));
  165. AES_128_Key_Expansion(forward_key, forward_rawkey);
  166. AES_128_Key_Expansion(reverse_key, reverse_rawkey);
  167. }
  168. ~WNMem() {
  169. delete[] forward_perm;
  170. delete[] unselected_cnt;
  171. delete reverse_perm;
  172. }
  173. };
  174. void setPermutation(uint32_t *permutation, uint32_t N,
  175. uint32_t depth, WNTraversal &traversal, const WNMem &mem);
  176. template <OSwap_Style oswap_style>
  177. void applyPermutation(unsigned char *buf, uint32_t N,
  178. size_t block_size, WNTraversal &traversal);
  179. template <OSwap_Style oswap_style>
  180. void applyInversePermutation(unsigned char *buf, uint32_t N,
  181. size_t block_size, WNTraversal &traversal);
  182. template <OSwap_Style oswap_style>
  183. void applyInversePermutation(unsigned char *buf, size_t block_size,
  184. const WNEvalPlan &plan, WNTraversal &traversal);
  185. public:
  186. // Make WaksmanNetwork objects non-copyable for efficiency
  187. WaksmanNetwork(const WaksmanNetwork&) = delete;
  188. WaksmanNetwork& operator=(const WaksmanNetwork&) = delete;
  189. // But moves are OK
  190. WaksmanNetwork(WaksmanNetwork &&wn) = default;
  191. WaksmanNetwork& operator=(WaksmanNetwork&&) = default;
  192. // Set up the WaksmanNetwork for N items. N need not be a power of 2.
  193. // N <= 2^31
  194. WaksmanNetwork(uint32_t N);
  195. void setPermutation(uint32_t *permutation);
  196. template <OSwap_Style oswap_style>
  197. void applyPermutation(unsigned char *buf, size_t block_size);
  198. template <OSwap_Style oswap_style>
  199. void applyInversePermutation(unsigned char *buf, size_t block_size);
  200. template <OSwap_Style oswap_style>
  201. void applyInversePermutation(unsigned char *buf, size_t block_size,
  202. const WNEvalPlan &plan);
  203. };
  204. // Define this to show the intermediate states of applyPermutation
  205. // #define SHOW_APPLYPERM
  206. // Apply permutation encoded by control bits to data elements in buffer. Permutes in place.
  207. template <OSwap_Style oswap_style>
  208. void WaksmanNetwork::applyPermutation(unsigned char *buf, size_t block_size) {
  209. FOAV_SAFE_CNTXT(AP, Ntotal)
  210. if (Ntotal > 1) {
  211. WNTraversal traversal(*this);
  212. applyPermutation<oswap_style>(buf, Ntotal, block_size, traversal);
  213. }
  214. }
  215. // Apply permutation encoded by control bits to data elements in buffer. Permutes in place.
  216. template <OSwap_Style oswap_style>
  217. void WaksmanNetwork::applyPermutation(unsigned char *buf, uint32_t N,
  218. size_t block_size, WNTraversal &traversal) {
  219. FOAV_SAFE_CNTXT(AP, Ntotal)
  220. FOAV_SAFE_CNTXT(AP, N)
  221. if (N < 2) return;
  222. const uint32_t Nleft = (N+1)/2;
  223. const uint32_t Nright = N/2;
  224. const uint32_t numInSwitches = (N-1)/2;
  225. const uint32_t numOutSwitches = N/2;
  226. const uint32_t *inSwitch = traversal.inSwitches;
  227. const uint8_t *outSwitch = traversal.outSwitches;
  228. traversal.subnetNumber += 1;
  229. traversal.inSwitches += numInSwitches;
  230. traversal.outSwitches += numOutSwitches;
  231. #ifdef SHOW_APPLYPERM
  232. printf("s");
  233. for(uint32_t i=0;i<N;++i) {
  234. printf(" %2d", *(uint32_t*)(buf+block_size*i));
  235. }
  236. printf("\n");
  237. #endif
  238. if (N == 2) {
  239. #ifdef SHOW_APPLYPERM
  240. printf("o");
  241. for(uint32_t i=0;i<numOutSwitches;++i) {
  242. printf(" %s", outSwitch[i] ? " X" : "||");
  243. }
  244. printf("\n");
  245. #endif
  246. oswap_buffer<oswap_style>(buf, buf + block_size, (uint32_t) block_size, outSwitch[0]);
  247. #ifdef SHOW_APPLYPERM
  248. printf("e");
  249. for(uint32_t i=0;i<N;++i) {
  250. printf(" %2d", *(uint32_t*)(buf+block_size*i));
  251. }
  252. printf("\n");
  253. #endif
  254. } else {
  255. #ifdef SHOW_APPLYPERM
  256. printf("i");
  257. for(uint32_t i=0;i<numInSwitches;++i) {
  258. printf(" %s", (inSwitch[i]&1) ? " X" : "||");
  259. }
  260. printf("\n");
  261. #endif
  262. // Apply input switches to permutation
  263. const uint32_t *curInSwitchVal = inSwitch;
  264. for (uint32_t i=0; i<numInSwitches; i++) {
  265. oswap_buffer<oswap_style>(buf + block_size*(i), buf + block_size*(Nleft+i), block_size,
  266. (*curInSwitchVal)&1);
  267. curInSwitchVal += 1;
  268. }
  269. #ifdef SHOW_APPLYPERM
  270. printf(" ");
  271. for(uint32_t i=0;i<N;++i) {
  272. printf(" %2d", *(uint32_t*)(buf+block_size*i));
  273. }
  274. printf("\n");
  275. #endif
  276. // Apply subnetwork switches
  277. applyPermutation<oswap_style>(buf, Nleft, block_size, traversal);
  278. applyPermutation<oswap_style>(buf + block_size*Nleft, Nright,
  279. block_size, traversal);
  280. #ifdef SHOW_APPLYPERM
  281. printf("r");
  282. for(uint32_t i=0;i<N;++i) {
  283. printf(" %2d", *(uint32_t*)(buf+block_size*i));
  284. }
  285. printf("\n");
  286. printf("o");
  287. for(uint32_t i=0;i<numOutSwitches;++i) {
  288. printf(" %s", outSwitch[i] ? " X" : "||");
  289. }
  290. printf("\n");
  291. #endif
  292. // Apply output switches to permutation
  293. for (uint32_t i=0; i<numOutSwitches; i++) {
  294. oswap_buffer<oswap_style>(buf + block_size*i, buf + block_size*(Nleft+i), block_size,
  295. *outSwitch);
  296. ++outSwitch;
  297. }
  298. #ifdef SHOW_APPLYPERM
  299. printf("e");
  300. for(uint32_t i=0;i<N;++i) {
  301. printf(" %2d", *(uint32_t*)(buf+block_size*i));
  302. }
  303. printf("\n");
  304. #endif
  305. }
  306. }
// Apply inverse of permutation encoded by control bits to data elements
// in buffer using a multithread evaluation plan. Permutes in place.
  309. template <OSwap_Style oswap_style>
  310. void WaksmanNetwork::applyInversePermutation(unsigned char *buf,
  311. size_t block_size, const WNEvalPlan &plan) {
  312. FOAV_SAFE_CNTXT(AP, Ntotal)
  313. if (Ntotal > 1) {
  314. WNTraversal traversal(*this);
  315. applyInversePermutation<oswap_style>(buf, block_size, plan, traversal);
  316. }
  317. }
  318. template <typename CBT>
  319. struct ApplySwitchesArgs {
  320. unsigned char *buf;
  321. size_t block_size;
  322. const CBT* switches;
  323. uint32_t swStart, swEnd;
  324. uint32_t stride;
  325. };
  326. // Apply a consecutive sequence of input or output switches, using
  327. // arguments passed as an ApplySwitchesArgs*. CBT is the
  328. // control bit type (uint32_t for input switches or uint8_t for output
  329. // switches).
  330. template <OSwap_Style oswap_style, typename CBT>
  331. static void* applySwitchesRange(void *voidargs)
  332. {
  333. const ApplySwitchesArgs<CBT>* args =
  334. (const ApplySwitchesArgs<CBT> *)voidargs;
  335. unsigned char *buf = args->buf;
  336. const size_t block_size = args->block_size;
  337. const uint32_t swStart = args->swStart;
  338. const CBT* switches = args->switches + swStart;
  339. const uint32_t swEnd = args->swEnd;
  340. const uint32_t stride = args->stride;
  341. FOAV_SAFE_CNTXT(applySwitchesRange, swEnd)
  342. for (uint32_t i=swStart; i<swEnd; ++i) {
  343. FOAV_SAFE2_CNTXT(applySwitchesRange, i, swEnd)
  344. oswap_buffer<oswap_style>(buf + block_size*(i),
  345. buf + block_size*(stride+i), block_size,
  346. (*switches)&1);
  347. ++switches;
  348. }
  349. return NULL;
  350. }
  351. // Apply a consecutive sequence of input or output switches using
  352. // up to nthreads threads. CBT is the control bit type (uint32_t for
  353. // input switches or uint8_t for output switches), but it will be
  354. // deduced automatically from the type of the switches argument.
  355. template <OSwap_Style oswap_style, typename CBT>
  356. static void applySwitches(unsigned char *buf, size_t block_size,
  357. const CBT* switches, uint32_t numSwitches, uint32_t stride,
  358. uint32_t nthreads)
  359. {
  360. uint32_t threads_to_use = nthreads;
  361. ApplySwitchesArgs<CBT> asargs[threads_to_use];
  362. uint32_t inc = numSwitches / threads_to_use;
  363. uint32_t extra = numSwitches % threads_to_use;
  364. uint32_t last = 0;
  365. for (uint32_t t=0; t<threads_to_use; ++t) {
  366. uint32_t next = last + inc + (t < extra);
  367. asargs[t] = { buf, block_size, switches, last, next, stride };
  368. last = next;
  369. if (t > 0) {
  370. threadpool_dispatch(g_thread_id+t,
  371. applySwitchesRange<oswap_style,CBT>, &asargs[t]);
  372. }
  373. }
  374. // Do the first block ourselves
  375. applySwitchesRange<oswap_style,CBT>(&asargs[0]);
  376. for (size_t t=1; t<threads_to_use; ++t) {
  377. threadpool_join(g_thread_id+t, NULL);
  378. }
  379. }
  380. // Apply inverse of permutation encoded by control bits to data elements
  381. // in buffer using a multithread evaluation plan. Permutes in place.
  382. template <OSwap_Style oswap_style>
  383. void WaksmanNetwork::applyInversePermutation(unsigned char *buf,
  384. size_t block_size, const WNEvalPlan &plan, WNTraversal &traversal) {
  385. const uint32_t N = plan.N;
  386. const uint32_t nthreads = plan.nthreads;
  387. if (N < 2) return;
  388. if (nthreads <= 1) {
  389. #ifdef PROFILE_MTAPPLYPERM
  390. unsigned long start = printf_with_rtclock("Thread %u starting single-threaded applyInversePermutation(N=%lu)\n", g_thread_id, N);
  391. #endif
  392. applyInversePermutation<oswap_style>(buf, N, block_size, traversal);
  393. #ifdef PROFILE_MTAPPLYPERM
  394. printf_with_rtclock_diff(start, "Thread %u ending single-threaded applyInversePermutation(N=%lu)\n", g_thread_id, N);
  395. #endif
  396. return;
  397. }
  398. #ifdef PROFILE_MTAPPLYPERM
  399. unsigned long start = printf_with_rtclock("Thread %u starting applyInversePermutation(N=%lu, nthreads=%lu)\n", g_thread_id, N, nthreads);
  400. #endif
  401. const uint32_t Nleft = (N+1)/2;
  402. const uint32_t Nright = N/2;
  403. const uint32_t numInSwitches = (N-1)/2;
  404. const uint32_t numOutSwitches = N/2;
  405. const uint32_t *inSwitch = traversal.inSwitches;
  406. const uint8_t *outSwitch = traversal.outSwitches;
  407. const uint32_t nthr_left = (nthreads+1)/2;
  408. const uint32_t nthr_right = nthreads/2;
  409. WNTraversal lefttraversal = traversal;
  410. lefttraversal.inSwitches += numInSwitches;
  411. lefttraversal.outSwitches += numOutSwitches;
  412. traversal.inSwitches += numInSwitches;
  413. traversal.outSwitches += numOutSwitches;
  414. if (plan.subplans.size() > 0) {
  415. traversal.inSwitches += plan.subplans[0].subtree_num_inswitches;
  416. traversal.outSwitches += plan.subplans[0].subtree_num_outswitches;
  417. }
  418. #ifdef SHOW_APPLYPERM
  419. printf("s");
  420. for(uint32_t i=0;i<N;++i) {
  421. printf(" %2d", *(uint32_t*)(buf+block_size*i));
  422. }
  423. printf("\n");
  424. #endif
  425. FOAV_SAFE_CNTXT(AIP, N)
  426. if (N == 2) {
  427. #ifdef SHOW_APPLYPERM
  428. printf("o");
  429. for(uint32_t i=0;i<numOutSwitches;++i) {
  430. printf(" %s", outSwitch[i] ? " X" : "||");
  431. }
  432. printf("\n");
  433. #endif
  434. oswap_buffer<oswap_style>(buf, buf + block_size, (uint32_t) block_size,
  435. outSwitch[0]);
  436. #ifdef SHOW_APPLYPERM
  437. printf("e");
  438. for(uint32_t i=0;i<N;++i) {
  439. printf(" %2d", *(uint32_t*)(buf+block_size*i));
  440. }
  441. printf("\n");
  442. #endif
  443. } else {
  444. // Apply output switches to permutation
  445. #ifdef SHOW_APPLYPERM
  446. printf("o");
  447. for(uint32_t i=0;i<numOutSwitches;++i) {
  448. printf(" %s", outSwitch[i] ? " X" : "||");
  449. }
  450. printf("\n");
  451. #endif
  452. #ifdef PROFILE_MTAPPLYPERM
  453. unsigned long outswstart = printf_with_rtclock("Thread %u starting output switches (N=%lu, nthreads=%lu)\n", g_thread_id, N, nthreads);
  454. #endif
  455. applySwitches<oswap_style>(buf, block_size, outSwitch, numOutSwitches,
  456. Nleft, nthreads);
  457. #ifdef PROFILE_MTAPPLYPERM
  458. printf_with_rtclock_diff(outswstart, "Thread %u ending output switches (N=%lu, nthreads=%lu)\n", g_thread_id, N, nthreads);
  459. #endif
  460. #ifdef SHOW_APPLYPERM
  461. printf(" ");
  462. for(uint32_t i=0;i<N;++i) {
  463. printf(" %2d", *(uint32_t*)(buf+block_size*i));
  464. }
  465. printf("\n");
  466. #endif
  467. // Apply subnetwork switches
  468. threadid_t rightthreadid = g_thread_id + nthr_left;
  469. appInvPermArgs rightargs(*this, buf + block_size*Nleft,
  470. block_size, plan.subplans[1], traversal);
  471. threadpool_dispatch(rightthreadid,
  472. applyInversePermutation_launch<oswap_style>, &rightargs);
  473. applyInversePermutation<oswap_style>(buf, block_size,
  474. plan.subplans[0], lefttraversal);
  475. threadpool_join(rightthreadid, NULL);
  476. // Apply input switches to permutation
  477. #ifdef SHOW_APPLYPERM
  478. printf("r");
  479. for(uint32_t i=0;i<N;++i) {
  480. printf(" %2d", *(uint32_t*)(buf+block_size*i));
  481. }
  482. printf("\n");
  483. printf("i");
  484. for(uint32_t i=0;i<numInSwitches;++i) {
  485. printf(" %s", (inSwitch[i]&1) ? " X" : "||");
  486. }
  487. printf("\n");
  488. #endif
  489. #ifdef PROFILE_MTAPPLYPERM
  490. unsigned long inswstart = printf_with_rtclock("Thread %u starting input switches (N=%lu, nthreads=%lu)\n", g_thread_id, N, nthreads);
  491. #endif
  492. applySwitches<oswap_style>(buf, block_size, inSwitch, numInSwitches,
  493. Nleft, nthreads);
  494. #ifdef PROFILE_MTAPPLYPERM
  495. printf_with_rtclock_diff(inswstart, "Thread %u ending input switches (N=%lu, nthreads=%lu)\n", g_thread_id, N, nthreads);
  496. #endif
  497. #ifdef SHOW_APPLYPERM
  498. printf("e");
  499. for(uint32_t i=0;i<N;++i) {
  500. printf(" %2d", *(uint32_t*)(buf+block_size*i));
  501. }
  502. printf("\n");
  503. #endif
  504. }
  505. #ifdef PROFILE_MTAPPLYPERM
  506. printf_with_rtclock_diff(start, "Thread %u ending applyInversePermutation(N=%lu, nthreads=%lu)\n", g_thread_id, N, nthreads);
  507. #endif
  508. }
  509. // Apply inverse of permutation in control bits to data elements in buffer. Permutes in place.
  510. template <OSwap_Style oswap_style>
  511. void WaksmanNetwork::applyInversePermutation(unsigned char *buf, size_t block_size) {
  512. FOAV_SAFE_CNTXT(AIP, Ntotal)
  513. if (Ntotal > 1) {
  514. WNTraversal traversal(*this);
  515. applyInversePermutation<oswap_style>(buf, Ntotal, block_size,
  516. traversal);
  517. }
  518. }
  519. // Apply inverse of permutation in control bits to data elements in buffer. Permutes in place.
  520. template <OSwap_Style oswap_style>
  521. void WaksmanNetwork::applyInversePermutation(unsigned char *buf,
  522. uint32_t N, size_t block_size, WNTraversal &traversal) {
  523. FOAV_SAFE_CNTXT(AIP, N)
  524. if (N < 2) return;
  525. const uint32_t Nleft = (N+1)/2;
  526. const uint32_t Nright = N/2;
  527. const uint32_t numInSwitches = (N-1)/2;
  528. const uint32_t numOutSwitches = N/2;
  529. const uint32_t *inSwitch = traversal.inSwitches;
  530. const uint8_t *outSwitch = traversal.outSwitches;
  531. traversal.subnetNumber += 1;
  532. traversal.inSwitches += numInSwitches;
  533. traversal.outSwitches += numOutSwitches;
  534. #ifdef SHOW_APPLYPERM
  535. printf("s");
  536. for(uint32_t i=0;i<N;++i) {
  537. printf(" %2d", *(uint32_t*)(buf+block_size*i));
  538. }
  539. printf("\n");
  540. #endif
  541. FOAV_SAFE_CNTXT(AIP, N)
  542. if (N == 2) {
  543. #ifdef SHOW_APPLYPERM
  544. printf("o");
  545. for(uint32_t i=0;i<numOutSwitches;++i) {
  546. printf(" %s", outSwitch[i] ? " X" : "||");
  547. }
  548. printf("\n");
  549. #endif
  550. oswap_buffer<oswap_style>(buf, buf + block_size, (uint32_t) block_size,
  551. outSwitch[0]);
  552. #ifdef SHOW_APPLYPERM
  553. printf("e");
  554. for(uint32_t i=0;i<N;++i) {
  555. printf(" %2d", *(uint32_t*)(buf+block_size*i));
  556. }
  557. printf("\n");
  558. #endif
  559. } else {
  560. // Apply output switches to permutation
  561. #ifdef SHOW_APPLYPERM
  562. printf("o");
  563. for(uint32_t i=0;i<numOutSwitches;++i) {
  564. printf(" %s", outSwitch[i] ? " X" : "||");
  565. }
  566. printf("\n");
  567. #endif
  568. FOAV_SAFE_CNTXT(AIP, numOutSwitches)
  569. for (uint32_t i=0; i<numOutSwitches; i++) {
  570. FOAV_SAFE2_CNTXT(AIP, i, numOutSwitches)
  571. oswap_buffer<oswap_style>(buf + block_size*i, buf + block_size*(Nleft+i), block_size,
  572. *outSwitch);
  573. ++outSwitch;
  574. }
  575. #ifdef SHOW_APPLYPERM
  576. printf(" ");
  577. for(uint32_t i=0;i<N;++i) {
  578. printf(" %2d", *(uint32_t*)(buf+block_size*i));
  579. }
  580. printf("\n");
  581. #endif
  582. // Apply subnetwork switches
  583. applyInversePermutation<oswap_style>(buf, Nleft,
  584. block_size, traversal);
  585. applyInversePermutation<oswap_style>(buf + block_size*Nleft, Nright,
  586. block_size, traversal);
  587. // Apply input switches to permutation
  588. #ifdef SHOW_APPLYPERM
  589. printf("r");
  590. for(uint32_t i=0;i<N;++i) {
  591. printf(" %2d", *(uint32_t*)(buf+block_size*i));
  592. }
  593. printf("\n");
  594. printf("i");
  595. for(uint32_t i=0;i<numInSwitches;++i) {
  596. printf(" %s", (inSwitch[i]&1) ? " X" : "||");
  597. }
  598. printf("\n");
  599. #endif
  600. const uint32_t *curInSwitchVal = inSwitch;
  601. FOAV_SAFE_CNTXT(AIP, numInSwitches)
  602. for (uint32_t i=0; i<numInSwitches; i++) {
  603. FOAV_SAFE2_CNTXT(AIP, i, numInSwitches)
  604. oswap_buffer<oswap_style>(buf + block_size*(i), buf + block_size*(Nleft+i), block_size,
  605. (*curInSwitchVal&1));
  606. curInSwitchVal += 1;
  607. }
  608. #ifdef SHOW_APPLYPERM
  609. printf("e");
  610. for(uint32_t i=0;i<N;++i) {
  611. printf(" %2d", *(uint32_t*)(buf+block_size*i));
  612. }
  613. printf("\n");
  614. #endif
  615. }
  616. }
  617. #if 0
  618. void OblivWaksmanShuffle(unsigned char *buffer, uint32_t N, size_t block_size, enc_ret *ret);
  619. void DecryptAndOblivWaksmanShuffle(unsigned char *encrypted_buffer, uint32_t N,
  620. size_t encrypted_block_size, unsigned char *result_buffer, enc_ret *ret);
  621. void DecryptAndOWSS(unsigned char *encrypted_buffer, uint32_t N,
  622. size_t encrypted_block_size, unsigned char *result_buffer, enc_ret *ret);
  623. void DecryptAndMTSS(unsigned char *encrypted_buffer, uint32_t N,
  624. size_t encrypted_block_size, size_t nthreads,
  625. unsigned char *result_buffer, enc_ret *ret);
  626. #endif
  627. #include "WaksmanNetwork.tcc"
  628. #endif