// NOTE(review): This section cannot compile as it stands, for two reasons.
//
// First, the file has been collapsed onto three physical lines. Each "//"
// comment on those lines now also swallows all the code that follows it on
// the same line.
//
// Second, every span from a '<' to the next '>' appears to have been stripped,
// as if by an HTML tag filter. Visible symptoms:
//   - "template<>" and "template" have no parameter lists.
//   - The three compare_keys specializations below have lost their key-type
//     arguments, so they are now indistinguishable from one another.
//   - "std::function cb" has lost its call signature.
//   - Loop headers such as "for (uint32_t i=0; i(nextidx, ..." have lost
//     their "i<..." bound and part of their body.
//   - The nid_mask initializer "(~((1<nid_priority = ..." has lost its shift
//     amount and the start of the following statement.
//
// TODO: Restore this file from version control rather than hand-repairing it.
// The lost text cannot be recovered from what remains here. That text includes
// the padding test, the per-item move, and the thread dispatch/join calls.
//
// Contents of the next line:
//
// set_key specializations: each fills one sort key from the message bytes at
// `item` and the item's position `index`. Let word0 and word1 be the uint32_t
// values read (in host byte order) at item and item+4.
//   UidKey:         uid_index    = (uint64_t(word0) << 32) + index
//   UidPriorityKey: uid_priority = (uint64_t(word0) << 32) + word1,
//                   idx = index
//   NidPriorityKey: nid_priority = (uint64_t(word0 & nid_mask) << 32) + word1,
//                   idx = index
// nid_mask is a ~((1<<k)-1)-style mask whose k was lost, so it presumably
// keeps only the high-order bits of word0. Confirm against the original.
//
// NOTE(review): The reads use *(const uint32_t *)item on a byte buffer. If
// items or msg_size are not 4-byte aligned, these reads are misaligned. A
// memcpy into a uint32_t would be the portable form.
//
// compare_keys specializations (UidKey, UidPriorityKey, NidPriorityKey, in
// that order): each returns oselect_uint32_t(-1, 1, alarge), where alarge
// means "*a > *b". oselect_uint32_t is defined elsewhere. Per the contract
// comment, it presumably yields 1 when alarge is true and -1 otherwise.
// The boolean combination uses non-short-circuit '|' and '&' instead of '||'
// and '&&'. This keeps the comparison free of data-dependent branches, as the
// "must be oblivious" contract requires.
// Ties on the primary 64-bit field are broken by the item index:
//   - For UidKey, the index is embedded in the low 32 bits of uid_index.
//   - For the other two keys, it is the separate idx field.
// As long as indices are distinct, no two keys are ever equal. That keeps the
// never-zero result a consistent total order.
// SOURCE L1 (verbatim):
// set_key for each kind of key we sort on template<> inline void set_key(UidKey *key, const uint8_t *item, uint32_t index) { key->uid_index = (uint64_t(*(const uint32_t *)item) << 32) + index; } template<> inline void set_key(UidPriorityKey *key, const uint8_t *item, uint32_t index) { key->uid_priority = (uint64_t(*(const uint32_t *)item) << 32) + (*(const uint32_t *)(item+4)); key->idx = index; } template<> inline void set_key(NidPriorityKey *key, const uint8_t *item, uint32_t index) { constexpr uint32_t nid_mask = (~((1<nid_priority = (uint64_t((*(const uint32_t *)item)&nid_mask) << 32) + (*(const uint32_t *)(item+4)); key->idx = index; } // compare_keys for each kind of key we sort on. Note that it must not // be possible for any of these functions to return 0. These functions // must also be oblivious. Return a positive (32-bit signed) int if *a // is larger than *b, or a negative (32-bit signed) int otherwise. template<> int compare_keys(const void* a, const void* b) { bool alarge = (((const UidKey*)a)->uid_index > ((const UidKey *)b)->uid_index); return oselect_uint32_t(-1, 1, alarge); } template<> int compare_keys(const void* a, const void* b) { uint64_t aup = ((const UidPriorityKey*)a)->uid_priority; uint64_t bup = ((const UidPriorityKey*)b)->uid_priority; uint32_t aidx = ((const UidPriorityKey*)a)->idx; uint32_t bidx = ((const UidPriorityKey*)b)->idx; bool auplarge = (aup > bup); bool aupeq = (aup == bup); bool aidxlarge = (aidx > bidx); bool alarge = auplarge | (aupeq & aidxlarge); return oselect_uint32_t(-1, 1, alarge); } template<> int compare_keys(const void* a, const void* b) { uint64_t anp = ((const NidPriorityKey*)a)->nid_priority; uint64_t bnp = ((const NidPriorityKey*)b)->nid_priority; uint32_t aidx = ((const NidPriorityKey*)a)->idx; uint32_t bidx = ((const NidPriorityKey*)b)->idx; bool anplarge = (anp > bnp); bool anpeq = (anp == bnp); bool aidxlarge = (aidx > bidx); bool alarge = anplarge | (anpeq & aidxlarge); return oselect_uint32_t(-1, 1, 
// Contents of the next line:
//
// The closing "alarge); }" of the NidPriorityKey compare_keys. It is followed
// by the callback form of sort_mtobliv<T>, which works as follows:
//   1. shuffle_mtobliv(nthreads, items, msg_size, Nr, Na) shuffles the item
//      buffer in place and returns Nw. Nw is presumably the number of slots
//      to scan for real items; confirm.
//   2. For each non-padding item, a key is built with set_key<T>. The key uses
//      the item's post-shuffle slot i as its index. The padding test itself
//      was lost in the mangling.
//   3. If the number of keys built is not exactly Nr, the function prints a
//      message and asserts. With NDEBUG, the assert is compiled out and the
//      function carries on after printing.
//   4. mtmergesort(idx, Nr, backingidx, nthreads) presumably sorts the Nr keys
//      with compare_keys<T>, using backingidx as scratch. Its bool result
//      selects the buffer holding the sorted keys: true means backingidx,
//      false means idx.
//   5. cb(items, sortedidx, Nr) receives the shuffled items and the sorted
//      keys. Both key arrays are freed after cb returns, so cb must not keep
//      the keys pointer.
//
// NOTE(review): printf("... %u ...", nextidx-idx, Nr) passes a ptrdiff_t to
// %u. That is undefined behaviour. Use %td, or cast to unsigned.
// NOTE(review): idx and backingidx are raw new[] arrays. If cb throws, they
// leak. std::vector<T> or std::unique_ptr<T[]> would release them
// automatically.
// SOURCE L2 (verbatim):
alarge); } // Sort Nr items at the beginning of an allocated array of Na items // using up to nthreads threads. The items to sort are byte arrays of // size msg_size. The keys are of type T. T must have set_key and // compare_keys defined. The items will be shuffled in-place, and a // sorted array of keys will be passed to the provided callback // function. template void sort_mtobliv(threadid_t nthreads, uint8_t* items, uint16_t msg_size, uint32_t Nr, uint32_t Na, // the arguments to the callback are items, the sorted indices, and // the number of non-padding items std::function cb) { // Shuffle the items uint32_t Nw = shuffle_mtobliv(nthreads, items, msg_size, Nr, Na); // Create the indices T *idx = new T[Nr]; T *nextidx = idx; for (uint32_t i=0; i(nextidx, items+msg_size*i, i); ++nextidx; } } if (nextidx != idx + Nr) { printf("Found %u non-padding items, expected %u\n", nextidx-idx, Nr); assert(nextidx == idx + Nr); } // Sort the keys and indices T *backingidx = new T[Nr]; bool whichbuf = mtmergesort(idx, Nr, backingidx, nthreads); T *sortedidx = whichbuf ? 
// Contents of the next line:
//
// The tail of the callback sort_mtobliv, then the following definitions.
//
// move_sorted_args<T>: one thread's work description. It holds:
//   - the sorted keys,
//   - the source items,
//   - the destination buffer,
//   - the range [start, start+num),
//   - msg_size.
//
// move_sorted<T>(void *voidargs): thread-entry-style function that unpacks a
// move_sorted_args<T> and loops i over [start, start+num). The loop body was
// lost. It presumably copies the item referenced by sorted_keys[i] from items
// into slot i of destbuf; confirm.
//
// sort_mtobliv<T>(..., uint8_t *outbuf): runs the callback sort_mtobliv with a
// lambda that moves items into outbuf in sorted-key order.
//   - If nthreads <= 1, it calls move_sorted directly on the whole range
//     [0, Nr).
//   - Otherwise, it splits Nr into per-thread ranges. Each range gets
//     inc = Nr / nthreads items, and the extra = Nr % nthreads leftover items
//     are presumably spread over the first ranges. Some of this loop was lost.
//     It then dispatches the ranges to other threads, runs range 0 itself, and
//     joins threads 1..nthreads-1.
//
// sort_mtobliv<T>(...) in-place overload: sorts into a temporary
// Nr*msg_size-byte buffer, then copies the result back over the first Nr
// items.
//
// NOTE(review): "move_sorted_args args[nthreads]" is a variable-length array.
// That is a GCC/Clang extension, not standard C++. std::vector is portable.
// NOTE(review): "Nr * msg_size" is evaluated in 32-bit unsigned arithmetic
// (uint32_t times a promoted uint16_t), so it can wrap for large batches.
// Consider size_t(Nr) * msg_size.
// SOURCE L3 (verbatim):
backingidx : idx; cb(items, sortedidx, Nr); delete[] idx; delete[] backingidx; } template struct move_sorted_args { const T* sorted_keys; const uint8_t *items; uint8_t *destbuf; uint32_t start, num; uint16_t msg_size; }; template static void *move_sorted(void *voidargs) { const move_sorted_args *args = (move_sorted_args *)voidargs; uint16_t msg_size = args->msg_size; uint32_t start = args->start; uint32_t end = start + args->num; const T *sorted_keys = args->sorted_keys; const uint8_t *items = args->items; uint8_t *destbuf = args->destbuf; for (uint32_t i=start; i void sort_mtobliv(threadid_t nthreads, uint8_t* items, uint16_t msg_size, uint32_t Nr, uint32_t Na, uint8_t *outbuf) { sort_mtobliv(nthreads, items, msg_size, Nr, Na, [nthreads, msg_size, outbuf] (const uint8_t* origitems, const T* keys, uint32_t Nr) { // Special-case nthreads=1 for efficiency if (nthreads <= 1) { move_sorted_args args = { keys, origitems, outbuf, 0, Nr, msg_size }; move_sorted(&args); } else { move_sorted_args args[nthreads]; uint32_t inc = Nr / nthreads; uint32_t extra = Nr % nthreads; uint32_t last = 0; for (threadid_t i=0; i, args+i); } // Do the first section ourselves move_sorted(args); // Join the threads for (threadid_t i=1; i void sort_mtobliv(threadid_t nthreads, uint8_t* items, uint16_t msg_size, uint32_t Nr, uint32_t Na) { uint8_t *tempbuf = new uint8_t[Nr * msg_size]; sort_mtobliv(nthreads, items, msg_size, Nr, Na, tempbuf); memmove(items, tempbuf, Nr * msg_size); delete[] tempbuf; }