123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197 |
- // set_key for each kind of key we sort on
- template<>
- inline void set_key<UidKey>(UidKey *key, const uint8_t *item, uint32_t index)
- {
- key->uid_index = (uint64_t(*(const uint32_t *)item) << 32) + index;
- }
- template<>
- inline void set_key<UidPriorityKey>(UidPriorityKey *key, const uint8_t *item, uint32_t index)
- {
- key->uid_priority = (uint64_t(*(const uint32_t *)item) << 32) +
- (*(const uint32_t *)(item+4));
- key->idx = index;
- }
- template<>
- inline void set_key<NidPriorityKey>(NidPriorityKey *key, const uint8_t *item, uint32_t index)
- {
- constexpr uint32_t nid_mask = (~((1<<DEST_UID_BITS)-1));
- key->nid_priority = (uint64_t((*(const uint32_t *)item)&nid_mask) << 32) +
- (*(const uint32_t *)(item+4));
- key->idx = index;
- }
- // compare_keys for each kind of key we sort on. Note that it must not
- // be possible for any of these functions to return 0. These functions
- // must also be oblivious. Return a positive (32-bit signed) int if *a
- // is larger than *b, or a negative (32-bit signed) int otherwise.
- template<>
- int compare_keys<UidKey>(const void* a, const void* b)
- {
- bool alarge = (((const UidKey*)a)->uid_index >
- ((const UidKey *)b)->uid_index);
- return oselect_uint32_t(-1, 1, alarge);
- }
- template<>
- int compare_keys<UidPriorityKey>(const void* a, const void* b)
- {
- uint64_t aup = ((const UidPriorityKey*)a)->uid_priority;
- uint64_t bup = ((const UidPriorityKey*)b)->uid_priority;
- uint32_t aidx = ((const UidPriorityKey*)a)->idx;
- uint32_t bidx = ((const UidPriorityKey*)b)->idx;
- bool auplarge = (aup > bup);
- bool aupeq = (aup == bup);
- bool aidxlarge = (aidx > bidx);
- bool alarge = auplarge | (aupeq & aidxlarge);
- return oselect_uint32_t(-1, 1, alarge);
- }
- template<>
- int compare_keys<NidPriorityKey>(const void* a, const void* b)
- {
- uint64_t anp = ((const NidPriorityKey*)a)->nid_priority;
- uint64_t bnp = ((const NidPriorityKey*)b)->nid_priority;
- uint32_t aidx = ((const NidPriorityKey*)a)->idx;
- uint32_t bidx = ((const NidPriorityKey*)b)->idx;
- bool anplarge = (anp > bnp);
- bool anpeq = (anp == bnp);
- bool aidxlarge = (aidx > bidx);
- bool alarge = anplarge | (anpeq & aidxlarge);
- return oselect_uint32_t(-1, 1, alarge);
- }
- // Sort Nr items at the beginning of an allocated array of Na items
- // using up to nthreads threads. The items to sort are byte arrays of
- // size msg_size. The keys are of type T. T must have set_key<T> and
- // compare_keys<T> defined. The items will be shuffled in-place, and a
- // sorted array of keys will be passed to the provided callback
- // function.
- template<typename T>
- void sort_mtobliv(threadid_t nthreads, uint8_t* items, uint16_t msg_size,
- uint32_t Nr, uint32_t Na,
- // the arguments to the callback are items, the sorted indices, and
- // the number of non-padding items
- std::function<void(const uint8_t*, const T*, uint32_t Nr)> cb)
- {
- // Shuffle the items
- uint32_t Nw = shuffle_mtobliv(nthreads, items, msg_size, Nr, Na);
- // Create the indices
- T *idx = new T[Nr];
- T *nextidx = idx;
- for (uint32_t i=0; i<Nw; ++i) {
- uint64_t padding = (*(uint32_t*)(items+msg_size*i));
- if (padding != uint32_t(-1)) {
- set_key<T>(nextidx, items+msg_size*i, i);
- ++nextidx;
- }
- }
- if (nextidx != idx + Nr) {
- printf("Found %u non-padding items, expected %u\n",
- nextidx-idx, Nr);
- assert(nextidx == idx + Nr);
- }
- // Sort the keys and indices
- T *backingidx = new T[Nr];
- bool whichbuf = mtmergesort<T>(idx, Nr, backingidx, nthreads);
- T *sortedidx = whichbuf ? backingidx : idx;
- cb(items, sortedidx, Nr);
- delete[] idx;
- delete[] backingidx;
- }
// Work description for one move_sorted call: for each slot i in
// [start, start+num), copy the item named by sorted_keys[i] into slot i
// of destbuf. Field order matters, because callers build this struct
// with aggregate (brace) initialization.
template <typename T>
struct move_sorted_args {
    const T *sorted_keys;   // keys in sorted order; index() names the source item
    const uint8_t *items;   // source item array
    uint8_t *destbuf;       // destination item array
    uint32_t start;         // first slot this call handles
    uint32_t num;           // number of slots this call handles
    uint16_t msg_size;      // bytes per item
};
- template <typename T>
- static void *move_sorted(void *voidargs)
- {
- const move_sorted_args<T> *args =
- (move_sorted_args<T> *)voidargs;
- uint16_t msg_size = args->msg_size;
- uint32_t start = args->start;
- uint32_t end = start + args->num;
- const T *sorted_keys = args->sorted_keys;
- const uint8_t *items = args->items;
- uint8_t *destbuf = args->destbuf;
- for (uint32_t i=start; i<end; ++i) {
- memmove(destbuf + i * msg_size,
- items + (sorted_keys[i].index()) * msg_size,
- msg_size);
- }
- return NULL;
- }
- // As above, but also pass an Nr*msg_size-byte buffer outbuf to put
- // the sorted values into, instead of passing a callback. This calls
- // the above function, then copies the data in sorted order into outbuf.
- // Note: the outbuf buffer cannot overlap the items buffer.
- template<typename T>
- void sort_mtobliv(threadid_t nthreads, uint8_t* items, uint16_t msg_size,
- uint32_t Nr_, uint32_t Na, uint8_t *outbuf)
- {
- sort_mtobliv<T>(nthreads, items, msg_size, Nr_, Na,
- [nthreads, msg_size, outbuf]
- (const uint8_t* origitems, const T* keys, uint32_t Nr) {
- // Special-case nthreads=1 for efficiency
- if (nthreads <= 1) {
- move_sorted_args<T> args = {
- keys, origitems, outbuf, 0, Nr, msg_size
- };
- move_sorted<T>(&args);
- } else {
- move_sorted_args<T> args[nthreads];
- uint32_t inc = Nr / nthreads;
- uint32_t extra = Nr % nthreads;
- uint32_t last = 0;
- for (threadid_t i=0; i<nthreads; ++i) {
- uint32_t num = inc + (i < extra);
- args[i] = {
- keys, origitems, outbuf, last, num, msg_size
- };
- last += num;
- }
- // Launch all but the first section into other threads
- for (threadid_t i=1; i<nthreads; ++i) {
- threadpool_dispatch(g_thread_id+i,
- move_sorted<T>, args+i);
- }
- // Do the first section ourselves
- move_sorted<T>(args);
- // Join the threads
- for (threadid_t i=1; i<nthreads; ++i) {
- threadpool_join(g_thread_id+i, NULL);
- }
- }
- });
- }
- // As above, but the first Nr msg_size-byte entries in the items array
- // will end up with the sorted values. Note: if Nr < Na, entries beyond
- // Nr may also change, but you should not even look at those values!
- // This calls the above function with a temporary buffer, then copies
- // that buffer back into the items array, so it's less efficient, both
- // in memory and CPU, than the above functions.
- template<typename T>
- void sort_mtobliv(threadid_t nthreads, uint8_t* items, uint16_t msg_size,
- uint32_t Nr, uint32_t Na)
- {
- uint8_t *tempbuf = new uint8_t[Nr * msg_size];
- sort_mtobliv<T>(nthreads, items, msg_size, Nr, Na, tempbuf);
- memmove(items, tempbuf, Nr * msg_size);
- delete[] tempbuf;
- }
|