#ifndef __WAKSMANNETWORK_HPP__
#define __WAKSMANNETWORK_HPP__

// Declares WNEvalPlan (how a Waksman network is split across threads when
// it is applied) and the WaksmanNetwork class, followed by the templated
// apply/applyInverse member definitions.
//
// NOTE(review): this copy of the file appears damaged. Text starting with
// '<' followed by an identifier seems to have been stripped up to the next
// '>', which removes:
//   - system #include targets,
//   - template parameter/argument lists,
//   - some container element types.
// Restore from version control rather than reconstructing from this copy.
#include  // NOTE(review): header name missing in this copy
#include  // NOTE(review): header name missing in this copy
#include "oasm_lib.h"
#include  // NOTE(review): header name missing in this copy
#include "utils.hpp"
#include "RecursiveShuffle.hpp"
#include "aes.hpp"

// Uncomment to print per-thread timing (via printf_with_rtclock) from the
// multithreaded applyInversePermutation.
// #define PROFILE_MTAPPLYPERM

// 128-bit unsigned integer. Its size is used to size the per-item
// forward_perm buffer in WNMem.
typedef __uint128_t randkey_t;
#define FPERM_OSWAP_STYLE OSWAP_8_16X // OSwap_Style for forward perm (consistent w/ randkey_t)

// A struct to hold a multi-thread evaluation plan for a
// WaksmanNetwork of a given size and number of threads.
//
// If you have a WaksmanNetwork with N>2 items, and you want to apply
// it using nthreads>1 threads, this WaksmanNetwork will itself use
// the first (N-1)/2 input switches and N/2 output switches (in
// addition to those used by its subnetworks). It will recurse into:
//   - the left subnetwork, containing (N+1)/2 items and (nthreads+1)/2
//     threads;
//   - the right subnetwork, containing N/2 items and nthreads/2 threads.
// If N=2, then there is just a single output switch and no subnetworks.
// If N<2, there are no switches and no subnetworks.
// If nthreads<=1, we compute the total number of input and output
// switches used by this WaksmanNetwork and its subnetworks, but just
// store the total in this WNEvalPlan with no WNEvalPlans for the
// subnetworks.
struct WNEvalPlan {
    // The number of items for the represented WaksmanNetwork
    uint32_t N;
    // The number of threads to use to evaluate it
    uint32_t nthreads;
    // The total number of input and output switches used by this
    // WaksmanNetwork and its subnetworks
    size_t subtree_num_inswitches, subtree_num_outswitches;
    // If N>2 and nthreads>1, these are the evaluation plans for the
    // subnetworks. This vector will contain 0 (if N<=2 or nthreads<=1)
    // or 2 (otherwise) items.
    // NOTE(review): the element type is missing in this copy. The
    // emplace_back(uint32_t, uint32_t) calls and the
    // subplans[i].subtree_num_* / subplans[i].dump() uses below imply
    // WNEvalPlan elements.
    std::vector subplans;

    // Make WNEvalPlan objects non-copyable for efficiency
    WNEvalPlan(const WNEvalPlan&) = delete;
    WNEvalPlan& operator=(const WNEvalPlan&) = delete;
    // But moves are OK
    WNEvalPlan(WNEvalPlan &&wn) = default;
    WNEvalPlan& operator=(WNEvalPlan&&) = default;

    // Build the plan for a network of N items evaluated with nthreads
    // threads. Subplans are only created while N>2 and nthreads>1;
    // otherwise the subtree switch totals are computed directly.
    WNEvalPlan(uint32_t N, uint32_t nthreads) : N(N), nthreads(nthreads)
    {
        if (N<2) {
            // No switches at all
            subtree_num_inswitches = 0;
            subtree_num_outswitches = 0;
        } else if (N == 2) {
            // A single output switch and no subnetworks
            subtree_num_inswitches = 0;
            subtree_num_outswitches = 1;
        } else if (nthreads <= 1) {
            // Single-threaded: total up the whole subtree without subplans
            subtree_num_inswitches = 0;
            subtree_num_outswitches = 0;
            count_switches(N);
        } else {
            // Split the items and the threads between the left (larger)
            // and right halves, and recurse.
            const uint32_t Nleft = (N+1)/2;
            const uint32_t Nright = N/2;
            const uint32_t numInSwitches = (N-1)/2;
            const uint32_t numOutSwitches = N/2;
            const uint32_t nthr_left = (nthreads+1)/2;
            const uint32_t nthr_right = nthreads/2;
            subplans.emplace_back(Nleft, nthr_left);
            subplans.emplace_back(Nright, nthr_right);
            // This level's switches plus everything below it
            subtree_num_inswitches = numInSwitches +
                subplans[0].subtree_num_inswitches +
                subplans[1].subtree_num_inswitches;
            subtree_num_outswitches = numOutSwitches +
                subplans[0].subtree_num_outswitches +
                subplans[1].subtree_num_outswitches;
        }
    }

    // Count the number of input and output switches used by a
    // WaksmanNetwork with N items. Add those values to
    // subtree_num_inswitches and subtree_num_outswitches.
    // Uses the same (N-1)/2 input / N/2 output split as the constructor
    // and recurses into both halves.
    void count_switches(uint32_t N)
    {
        if (N<2) {
            return;
        }
        if (N == 2) {
            subtree_num_outswitches += 1;
            return;
        }
        const uint32_t Nleft = (N+1)/2;
        const uint32_t Nright = N/2;
        const uint32_t numInSwitches = (N-1)/2;
        const uint32_t numOutSwitches = N/2;
        subtree_num_inswitches += numInSwitches;
        subtree_num_outswitches += numOutSwitches;
        count_switches(Nleft);
        count_switches(Nright);
    }

    // Print this plan and its subplans to stdout. Each level of subplans
    // is indented by two more spaces.
    // NOTE(review): %lu is used for N and nthreads (uint32_t) and for the
    // size_t switch counts. Mismatched printf conversions are undefined
    // behaviour; %u / %zu (or explicit casts) would match the types.
    void dump(int indent = 0) const
    {
        printf("%*sN = %lu, nthreads = %lu, inswitches = %lu, outswitches = %lu\n",
            indent, "", N, nthreads, subtree_num_inswitches,
            subtree_num_outswitches);
        if (subplans.size() > 0) {
            subplans[0].dump(indent+2);
            subplans[1].dump(indent+2);
        }
    }
};

/*
    WaksmanNetwork Class: Contains a Waksman permutation network and can
    apply it to an input array.

    setPermutation(uint32_t *permutation):
        Takes permutation as an array of N index values (i.e. values in
        [N]) and sets the Waksman network switches to that permutation.
        The private recursive overload receives a WNMem holding memory
        (and AES keys) that is allocated once, before the recursion.

    applyPermutation(unsigned char *buf, size_t block_size):
        Takes buffer of N items of block_size bytes each and applies
        stored permutation in-place (i.e. modifying input buffer).

    applyInversePermutation(unsigned char *buf, size_t block_size):
        Takes buffer of N items of block_size bytes each and applies
        inverse of stored permutation in-place.

    applyInversePermutation(unsigned char *buf, size_t block_size,
                            const WNEvalPlan &plan):
        As above, but splits the work across threads as described by plan.

    The apply functions are member templates. Their template parameter
    lists are missing in this copy; the FPERM_OSWAP_STYLE comment suggests
    an OSwap_Style parameter -- confirm.
*/
class WaksmanNetwork {
    uint32_t Ntotal; // number of items to permute
    // NOTE(review): element types are missing in this copy. WNTraversal
    // reads these via .data() as uint32_t* and uint8_t* respectively.
    std::vector inSwitchVec; // input layer of (numbered) switches
    std::vector outSwitchVec; // output layer of switches

    // A struct to keep track of the current subnet number, and input and
    // output switches, for each subnet as we traverse the network.
    // It starts at the first element of each switch array. Each subnetwork
    // takes its own switches from the current position and then advances
    // the pointers, so switches are stored in traversal order.
    struct WNTraversal {
        uint64_t subnetNumber;
        uint32_t *inSwitches;
        uint8_t *outSwitches;
        WNTraversal(WaksmanNetwork &wn) : subnetNumber(0),
            inSwitches(wn.inSwitchVec.data()),
            outSwitches(wn.outSwitchVec.data()) {}
    };

    // Argument bundle for applyInversePermutation_launch, which receives
    // it as a void*. Every member is a reference or pointer, so the
    // referenced objects must outlive the call that consumes the bundle.
    struct appInvPermArgs {
        WaksmanNetwork &wn;
        unsigned char *buf;
        size_t block_size;
        const WNEvalPlan &plan;
        WNTraversal &traversal;
        appInvPermArgs(WaksmanNetwork &wn, unsigned char *buf,
                size_t block_size, const WNEvalPlan &plan,
                WNTraversal &traversal) : wn(wn), buf(buf),
                block_size(block_size), plan(plan), traversal(traversal) {}
    };

    // Trampoline with a void*(void*) signature, suitable for handing to a
    // thread/thread-pool dispatcher. It unpacks an appInvPermArgs and runs
    // the plan-based applyInversePermutation on it, then returns NULL.
    // NOTE(review): the template parameter list, and the explicit template
    // argument on the call, appear to be missing in this copy.
    template
    static void *applyInversePermutation_launch(void *voidarg)
    {
        appInvPermArgs *arg = (appInvPermArgs *)voidarg;
        arg->wn.applyInversePermutation(arg->buf, arg->block_size,
            arg->plan, arg->traversal);
        return NULL;
    }

    // A struct to hold pre-allocated memory (and AES keys) so that we
    // only allocate memory once, before the recursive setPermutation is
    // called.
    // NOTE(review): WNMem owns three raw heap allocations that its
    // destructor releases, but it does not delete its copy constructor
    // or copy assignment. An accidental copy would therefore double-free.
    // Deleting the copy operations, or holding std::vector /
    // std::unique_ptr members, would prevent that.
    struct WNMem {
        unsigned char *forward_perm;
        uint32_t *unselected_cnt;
        // NOTE(review): key/value template arguments are missing in this copy.
        std::unordered_map> *reverse_perm;
        AESkey forward_key, reverse_key;
        WNMem(WaksmanNetwork &wn) {
            // Round Ntotal up to an even number
            uint32_t Neven = wn.Ntotal + (wn.Ntotal&1);
            // sizeof(randkey_t) + 8 bytes for each of the Neven items
            forward_perm = new unsigned char[Neven * (sizeof(randkey_t) + 8)];
            unselected_cnt = new uint32_t[wn.Ntotal];
            reverse_perm = new std::unordered_map>;
            // Fill two __m128i raw keys with getRandomBytes and expand
            // them into forward_key and reverse_key.
            __m128i forward_rawkey, reverse_rawkey;
            getRandomBytes((unsigned char *) &forward_rawkey,
                sizeof(forward_rawkey));
            getRandomBytes((unsigned char *) &reverse_rawkey,
                sizeof(reverse_rawkey));
            AES_128_Key_Expansion(forward_key, forward_rawkey);
            AES_128_Key_Expansion(reverse_key, reverse_rawkey);
        }
        ~WNMem() {
            delete[] forward_perm;
            delete[] unselected_cnt;
            delete reverse_perm;
        }
    };

    // The recursive setPermutation. mem carries the buffers and keys
    // allocated once up front.
    void setPermutation(uint32_t *permutation, uint32_t N, uint32_t depth,
        WNTraversal &traversal, const WNMem &mem);

    // Single-threaded recursive workers. Each operates on an N-item
    // subnetwork whose switches start at traversal's current position.
    template
    void applyPermutation(unsigned char *buf, uint32_t N,
        size_t block_size, WNTraversal &traversal);
    template
    void applyInversePermutation(unsigned char *buf, uint32_t N,
        size_t block_size, WNTraversal &traversal);

    // Multithreaded recursive worker driven by a WNEvalPlan.
    template
    void applyInversePermutation(unsigned char *buf, size_t block_size,
        const WNEvalPlan &plan, WNTraversal &traversal);

public:
    // Make WaksmanNetwork objects non-copyable for efficiency
    WaksmanNetwork(const WaksmanNetwork&) = delete;
    WaksmanNetwork& operator=(const WaksmanNetwork&) = delete;
    // But moves are OK
    WaksmanNetwork(WaksmanNetwork &&wn) = default;
    WaksmanNetwork& operator=(WaksmanNetwork&&) = default;

    // Set up the WaksmanNetwork for N items. N need not be a power of 2.
    // N <= 2^31
    WaksmanNetwork(uint32_t N);

    // Set the network's switches to realize permutation, an array of
    // Ntotal index values.
    void setPermutation(uint32_t *permutation);

    // Permute buf (Ntotal items of block_size bytes each) in place.
    template
    void applyPermutation(unsigned char *buf, size_t block_size);

    // Apply the inverse permutation to buf in place, single-threaded.
    template
    void applyInversePermutation(unsigned char *buf, size_t block_size);

    // Apply the inverse permutation to buf in place, splitting the work
    // across threads as described by plan.
    template
    void applyInversePermutation(unsigned char *buf, size_t block_size,
        const WNEvalPlan &plan);
};

// Define this to show the intermediate states of applyPermutation
// #define SHOW_APPLYPERM

// Apply permutation encoded by control bits to data elements in buffer. Permutes in place.
// A network of fewer than two items needs no work. Otherwise, start a
// traversal at the first switch of each layer and run the recursive
// worker over all Ntotal items.
// FOAV_SAFE_CNTXT is a project macro defined elsewhere (presumably it
// marks the named value as safe to branch on for an obliviousness
// checker -- confirm).
template
void WaksmanNetwork::applyPermutation(unsigned char *buf, size_t block_size)
{
    FOAV_SAFE_CNTXT(AP, Ntotal)
    if (Ntotal > 1) {
        WNTraversal traversal(*this);
        applyPermutation(buf, Ntotal, block_size, traversal);
    }
}

// Apply permutation encoded by control bits to data elements in buffer. Permutes in place.
// Recursive worker. Applies this N-item subnetwork in forward order:
//   1. input switches,
//   2. the left subnetwork (first Nleft items),
//   3. the right subnetwork (remaining Nright items),
//   4. output switches.
// This level's switches are taken from the traversal's current position,
// and the traversal is advanced past them before recursing. The left
// subtree therefore consumes its switches before the right subtree does.
// NOTE(review): this copy of the body is damaged. The template parameter
// list is missing. Each "for(uint32_t i=0;i" below has lost everything
// from the loop bound through the next '>', including:
//   - the SHOW_APPLYPERM debug loops and their #endif,
//   - the switch-loop headers,
//   - the name and template argument of the conditional-swap routine
//     being called.
template
void WaksmanNetwork::applyPermutation(unsigned char *buf, uint32_t N,
    size_t block_size, WNTraversal &traversal)
{
    FOAV_SAFE_CNTXT(AP, Ntotal)
    FOAV_SAFE_CNTXT(AP, N)
    if (N < 2) return;

    // The left half gets the extra item when N is odd.
    const uint32_t Nleft = (N+1)/2;
    const uint32_t Nright = N/2;
    const uint32_t numInSwitches = (N-1)/2;
    const uint32_t numOutSwitches = N/2;

    // Take this level's switches, then advance the shared cursor past them.
    const uint32_t *inSwitch = traversal.inSwitches;
    const uint8_t *outSwitch = traversal.outSwitches;
    traversal.subnetNumber += 1;
    traversal.inSwitches += numInSwitches;
    traversal.outSwitches += numOutSwitches;

#ifdef SHOW_APPLYPERM
    printf("s");
    // NOTE(review): stripped span. The visible call pairs items 0 and 1
    // under outSwitch[0], apparently the N==2 case (its guard was stripped).
    for(uint32_t i=0;i(buf, buf + block_size, (uint32_t) block_size, outSwitch[0]);
#ifdef SHOW_APPLYPERM
    printf("e");
    // NOTE(review): stripped span. The visible tail is the input-switch
    // loop: item i is paired with item Nleft+i, controlled by the low bit
    // of that switch's uint32_t word.
    for(uint32_t i=0;i(buf + block_size*(i), buf + block_size*(Nleft+i), block_size, (*curInSwitchVal)&1);
        curInSwitchVal += 1;
    }
#ifdef SHOW_APPLYPERM
    printf(" ");
    // NOTE(review): stripped span. The visible tail recurses into the left
    // (first Nleft items) and then the right (next Nright items)
    // subnetworks.
    for(uint32_t i=0;i(buf, Nleft, block_size, traversal);
    applyPermutation(buf + block_size*Nleft, Nright, block_size, traversal);
#ifdef SHOW_APPLYPERM
    printf("r");
    // NOTE(review): stripped span. The visible tail is the output-switch
    // loop, pairing item i with item Nleft+i under *outSwitch.
    for(uint32_t i=0;i(buf + block_size*i, buf + block_size*(Nleft+i), block_size, *outSwitch);
        ++outSwitch;
    }
#ifdef SHOW_APPLYPERM
    printf("e");
    // NOTE(review): stripped span covering the end of the function above
    // and the template header of the one below.
    //
    // Multithreaded public entry point: apply the inverse permutation,
    // splitting the work as described by plan. A network of fewer than
    // two items needs no work.
    for(uint32_t i=0;i void WaksmanNetwork::applyInversePermutation(unsigned char *buf, size_t block_size, const WNEvalPlan &plan)
{
    FOAV_SAFE_CNTXT(AP, Ntotal)
    if (Ntotal > 1) {
        WNTraversal traversal(*this);
        applyInversePermutation(buf, block_size, plan, traversal);
    }
}

// Per-thread argument block for applySwitchesRange. Switch k of the range
// (switches[k], k starting at swStart and presumably ending before swEnd)
// controls whether item k of buf is exchanged with item stride+k.
// NOTE(review): the template parameter list (CBT) is missing in this copy.
template
struct ApplySwitchesArgs {
    unsigned char *buf;
    size_t block_size;
    const CBT* switches;
    uint32_t swStart, swEnd;
    uint32_t stride;
};

// Apply a consecutive sequence of input or output switches, using
// arguments passed as an ApplySwitchesArgs*. CBT is the
// control bit type (uint32_t for input switches or uint8_t for output
// switches).
// Worker with a void*(void*) signature, so applySwitches can hand it to
// the thread pool. For each switch index i starting at swStart, it pairs
// item i with item stride+i under the low bit of that switch, then
// returns NULL.
// NOTE(review): the following are missing in this copy:
//   - the template parameter list,
//   - the ApplySwitchesArgs template arguments,
//   - the loop bound (presumably swEnd),
//   - the name of the conditional-swap routine.
template
static void* applySwitchesRange(void *voidargs)
{
    const ApplySwitchesArgs* args = (const ApplySwitchesArgs *)voidargs;
    unsigned char *buf = args->buf;
    const size_t block_size = args->block_size;
    const uint32_t swStart = args->swStart;
    const CBT* switches = args->switches + swStart;
    const uint32_t swEnd = args->swEnd;
    const uint32_t stride = args->stride;
    FOAV_SAFE_CNTXT(applySwitchesRange, swEnd)
    for (uint32_t i=swStart; i(buf + block_size*(i), buf + block_size*(stride+i), block_size, (*switches)&1);
        ++switches;
    }
    return NULL;
}

// Apply a consecutive sequence of input or output switches using
// up to nthreads threads. CBT is the control bit type (uint32_t for
// input switches or uint8_t for output switches), but it will be
// deduced automatically from the type of the switches argument.
// The switches are divided into nthreads chunks of
// numSwitches/nthreads switches, with the numSwitches%nthreads remainder
// ("extra") presumably spread over the first chunks. The other chunks are
// dispatched to thread-pool threads g_thread_id+t, and chunk 0 runs on the
// calling thread.
// NOTE(review): asargs is a variable-length array, which is a GCC/Clang
// extension rather than standard C++; a std::vector would be portable.
// numSwitches / threads_to_use would also divide by zero if
// nthreads == 0. The caller below only reaches this with nthreads > 1.
// NOTE(review): the chunk set-up loop body and the join loop are mostly
// stripped in this copy.
template
static void applySwitches(unsigned char *buf, size_t block_size,
    const CBT* switches, uint32_t numSwitches, uint32_t stride,
    uint32_t nthreads)
{
    uint32_t threads_to_use = nthreads;
    ApplySwitchesArgs asargs[threads_to_use];
    uint32_t inc = numSwitches / threads_to_use;
    uint32_t extra = numSwitches % threads_to_use;
    uint32_t last = 0;
    for (uint32_t t=0; t 0) {
            threadpool_dispatch(g_thread_id+t,
                applySwitchesRange, &asargs[t]);
        }
    }
    // Do the first block ourselves
    applySwitchesRange(&asargs[0]);
    // NOTE(review): stripped span covering the join loop, the end of this
    // function, and the template header of the one below.
    //
    // Multithreaded recursive worker: apply the inverse of this plan's
    // subnetwork to buf in place.
    //   - If plan.nthreads <= 1, fall back to the single-threaded recursion.
    //   - Otherwise:
    //       1. apply this level's output switches, split across nthreads;
    //       2. run the right subnetwork via rightargs (its dispatch is in
    //          a stripped span) while this thread runs the left one;
    //       3. join the right subnetwork;
    //       4. apply this level's input switches, split across nthreads.
    // Switch layout:
    //   - The left subtree's switches start right after this level's.
    //   - The right subtree's start after the left subtree's (skipped via
    //     plan.subplans[0].subtree_num_*).
    // The two halves therefore use independent traversal cursors.
    // NOTE(review): the PROFILE_MTAPPLYPERM messages pass uint32_t N and
    // nthreads to %lu, a printf type mismatch (undefined behaviour).
    // NOTE(review): the visible parts of this function never update
    // traversal.subnetNumber, unlike the single-threaded recursion.
    // Confirm nothing relies on it.
    for (size_t t=1; t void WaksmanNetwork::applyInversePermutation(unsigned char *buf, size_t block_size, const WNEvalPlan &plan, WNTraversal &traversal)
{
    const uint32_t N = plan.N;
    const uint32_t nthreads = plan.nthreads;
    if (N < 2) return;
    if (nthreads <= 1) {
#ifdef PROFILE_MTAPPLYPERM
        unsigned long start = printf_with_rtclock("Thread %u starting single-threaded applyInversePermutation(N=%lu)\n", g_thread_id, N);
#endif
        applyInversePermutation(buf, N, block_size, traversal);
#ifdef PROFILE_MTAPPLYPERM
        printf_with_rtclock_diff(start, "Thread %u ending single-threaded applyInversePermutation(N=%lu)\n", g_thread_id, N);
#endif
        return;
    }
#ifdef PROFILE_MTAPPLYPERM
    unsigned long start = printf_with_rtclock("Thread %u starting applyInversePermutation(N=%lu, nthreads=%lu)\n", g_thread_id, N, nthreads);
#endif
    const uint32_t Nleft = (N+1)/2;
    const uint32_t Nright = N/2;
    const uint32_t numInSwitches = (N-1)/2;
    const uint32_t numOutSwitches = N/2;
    const uint32_t *inSwitch = traversal.inSwitches;
    const uint8_t *outSwitch = traversal.outSwitches;
    const uint32_t nthr_left = (nthreads+1)/2;
    const uint32_t nthr_right = nthreads/2;

    // Left subtree cursor: immediately after this level's switches.
    WNTraversal lefttraversal = traversal;
    lefttraversal.inSwitches += numInSwitches;
    lefttraversal.outSwitches += numOutSwitches;
    // Right subtree cursor (reuses traversal): additionally skip every
    // switch the left subtree uses.
    traversal.inSwitches += numInSwitches;
    traversal.outSwitches += numOutSwitches;
    if (plan.subplans.size() > 0) {
        traversal.inSwitches += plan.subplans[0].subtree_num_inswitches;
        traversal.outSwitches += plan.subplans[0].subtree_num_outswitches;
    }

#ifdef SHOW_APPLYPERM
    printf("s");
    // NOTE(review): stripped span. The visible call pairs items 0 and 1
    // under outSwitch[0], apparently the N==2 case.
    for(uint32_t i=0;i(buf, buf + block_size, (uint32_t) block_size, outSwitch[0]);
#ifdef SHOW_APPLYPERM
    printf("e");
    // NOTE(review): stripped span. The visible tail applies this level's
    // output switches across nthreads threads with stride Nleft.
    for(uint32_t i=0;i(buf, block_size, outSwitch, numOutSwitches, Nleft, nthreads);
#ifdef PROFILE_MTAPPLYPERM
    printf_with_rtclock_diff(outswstart, "Thread %u ending output switches (N=%lu, nthreads=%lu)\n", g_thread_id, N, nthreads);
#endif
#ifdef SHOW_APPLYPERM
    printf(" ");
    // NOTE(review): stripped span; this is where the right subnetwork is
    // dispatched with rightargs. The left subnetwork then runs on this
    // thread, and the right one is joined.
    for(uint32_t i=0;i, &rightargs);
    applyInversePermutation(buf, block_size, plan.subplans[0], lefttraversal);
    threadpool_join(rightthreadid, NULL);
    // Apply input switches to permutation
#ifdef SHOW_APPLYPERM
    printf("r");
    // NOTE(review): stripped span. The visible tail applies this level's
    // input switches across nthreads threads with stride Nleft.
    for(uint32_t i=0;i(buf, block_size, inSwitch, numInSwitches, Nleft, nthreads);
#ifdef PROFILE_MTAPPLYPERM
    printf_with_rtclock_diff(inswstart, "Thread %u ending input switches (N=%lu, nthreads=%lu)\n", g_thread_id, N, nthreads);
#endif
#ifdef SHOW_APPLYPERM
    printf("e");
    // NOTE(review): stripped span covering the end of the function above
    // and the template header of the one below.
    //
    // Single-threaded public entry point: apply the inverse permutation to
    // buf in place. A network of fewer than two items needs no work.
    for(uint32_t i=0;i void WaksmanNetwork::applyInversePermutation(unsigned char *buf, size_t block_size)
{
    FOAV_SAFE_CNTXT(AIP, Ntotal)
    if (Ntotal > 1) {
        WNTraversal traversal(*this);
        applyInversePermutation(buf, Ntotal, block_size, traversal);
    }
}

// Apply
inverse of permutation in control bits to data elements in buffer. Permutes in place. template void WaksmanNetwork::applyInversePermutation(unsigned char *buf, uint32_t N, size_t block_size, WNTraversal &traversal) { FOAV_SAFE_CNTXT(AIP, N) if (N < 2) return; const uint32_t Nleft = (N+1)/2; const uint32_t Nright = N/2; const uint32_t numInSwitches = (N-1)/2; const uint32_t numOutSwitches = N/2; const uint32_t *inSwitch = traversal.inSwitches; const uint8_t *outSwitch = traversal.outSwitches; traversal.subnetNumber += 1; traversal.inSwitches += numInSwitches; traversal.outSwitches += numOutSwitches; #ifdef SHOW_APPLYPERM printf("s"); for(uint32_t i=0;i(buf, buf + block_size, (uint32_t) block_size, outSwitch[0]); #ifdef SHOW_APPLYPERM printf("e"); for(uint32_t i=0;i(buf + block_size*i, buf + block_size*(Nleft+i), block_size, *outSwitch); ++outSwitch; } #ifdef SHOW_APPLYPERM printf(" "); for(uint32_t i=0;i(buf, Nleft, block_size, traversal); applyInversePermutation(buf + block_size*Nleft, Nright, block_size, traversal); // Apply input switches to permutation #ifdef SHOW_APPLYPERM printf("r"); for(uint32_t i=0;i(buf + block_size*(i), buf + block_size*(Nleft+i), block_size, (*curInSwitchVal&1)); curInSwitchVal += 1; } #ifdef SHOW_APPLYPERM printf("e"); for(uint32_t i=0;i