sort.tcc 6.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197
  1. // set_key for each kind of key we sort on
  2. template<>
  3. inline void set_key<UidKey>(UidKey *key, const uint8_t *item, uint32_t index)
  4. {
  5. key->uid_index = (uint64_t(*(const uint32_t *)item) << 32) + index;
  6. }
  7. template<>
  8. inline void set_key<UidPriorityKey>(UidPriorityKey *key, const uint8_t *item, uint32_t index)
  9. {
  10. key->uid_priority = (uint64_t(*(const uint32_t *)item) << 32) +
  11. (*(const uint32_t *)(item+4));
  12. key->idx = index;
  13. }
  14. template<>
  15. inline void set_key<NidPriorityKey>(NidPriorityKey *key, const uint8_t *item, uint32_t index)
  16. {
  17. constexpr uint32_t nid_mask = (~((1<<DEST_UID_BITS)-1));
  18. key->nid_priority = (uint64_t((*(const uint32_t *)item)&nid_mask) << 32) +
  19. (*(const uint32_t *)(item+4));
  20. key->idx = index;
  21. }
  22. // compare_keys for each kind of key we sort on. Note that it must not
  23. // be possible for any of these functions to return 0. These functions
  24. // must also be oblivious. Return a positive (32-bit signed) int if *a
  25. // is larger than *b, or a negative (32-bit signed) int otherwise.
  26. template<>
  27. int compare_keys<UidKey>(const void* a, const void* b)
  28. {
  29. bool alarge = (((const UidKey*)a)->uid_index >
  30. ((const UidKey *)b)->uid_index);
  31. return oselect_uint32_t(-1, 1, alarge);
  32. }
  33. template<>
  34. int compare_keys<UidPriorityKey>(const void* a, const void* b)
  35. {
  36. uint64_t aup = ((const UidPriorityKey*)a)->uid_priority;
  37. uint64_t bup = ((const UidPriorityKey*)b)->uid_priority;
  38. uint32_t aidx = ((const UidPriorityKey*)a)->idx;
  39. uint32_t bidx = ((const UidPriorityKey*)b)->idx;
  40. bool auplarge = (aup > bup);
  41. bool aupeq = (aup == bup);
  42. bool aidxlarge = (aidx > bidx);
  43. bool alarge = auplarge | (aupeq & aidxlarge);
  44. return oselect_uint32_t(-1, 1, alarge);
  45. }
  46. template<>
  47. int compare_keys<NidPriorityKey>(const void* a, const void* b)
  48. {
  49. uint64_t anp = ((const NidPriorityKey*)a)->nid_priority;
  50. uint64_t bnp = ((const NidPriorityKey*)b)->nid_priority;
  51. uint32_t aidx = ((const NidPriorityKey*)a)->idx;
  52. uint32_t bidx = ((const NidPriorityKey*)b)->idx;
  53. bool anplarge = (anp > bnp);
  54. bool anpeq = (anp == bnp);
  55. bool aidxlarge = (aidx > bidx);
  56. bool alarge = anplarge | (anpeq & aidxlarge);
  57. return oselect_uint32_t(-1, 1, alarge);
  58. }
  59. // Sort Nr items at the beginning of an allocated array of Na items
  60. // using up to nthreads threads. The items to sort are byte arrays of
  61. // size msg_size. The keys are of type T. T must have set_key<T> and
  62. // compare_keys<T> defined. The items will be shuffled in-place, and a
  63. // sorted array of keys will be passed to the provided callback
  64. // function.
  65. template<typename T>
  66. void sort_mtobliv(threadid_t nthreads, uint8_t* items, uint16_t msg_size,
  67. uint32_t Nr, uint32_t Na,
  68. // the arguments to the callback are items, the sorted indices, and
  69. // the number of non-padding items
  70. std::function<void(const uint8_t*, const T*, uint32_t Nr)> cb)
  71. {
  72. // Shuffle the items
  73. uint32_t Nw = shuffle_mtobliv(nthreads, items, msg_size, Nr, Na);
  74. // Create the indices
  75. T *idx = new T[Nr];
  76. T *nextidx = idx;
  77. for (uint32_t i=0; i<Nw; ++i) {
  78. uint64_t padding = (*(uint32_t*)(items+msg_size*i));
  79. if (padding != uint32_t(-1)) {
  80. set_key<T>(nextidx, items+msg_size*i, i);
  81. ++nextidx;
  82. }
  83. }
  84. if (nextidx != idx + Nr) {
  85. printf("Found %u non-padding items, expected %u\n",
  86. nextidx-idx, Nr);
  87. assert(nextidx == idx + Nr);
  88. }
  89. // Sort the keys and indices
  90. T *backingidx = new T[Nr];
  91. bool whichbuf = mtmergesort<T>(idx, Nr, backingidx, nthreads);
  92. T *sortedidx = whichbuf ? backingidx : idx;
  93. cb(items, sortedidx, Nr);
  94. delete[] idx;
  95. delete[] backingidx;
  96. }
  // Work description for one move_sorted call: fill destination slots
  // [start, start + num) of destbuf from the items selected by
  // sorted_keys.
  template <typename T>
  struct move_sorted_args {
      const T* sorted_keys;   // keys in sorted order; key.index() selects an item
      const uint8_t *items;   // source items, msg_size bytes each
      uint8_t *destbuf;       // destination, msg_size bytes per slot
      uint32_t start;         // first destination slot to fill
      uint32_t num;           // number of destination slots to fill
      uint16_t msg_size;      // bytes per item / slot
  };

  // Worker with a void*(void*) signature so it can be handed to
  // threadpool_dispatch. voidargs points to a move_sorted_args<T>. For
  // every slot s in [start, start + num), copies msg_size bytes from item
  // number sorted_keys[s].index() into slot s of destbuf. Returns NULL.
  template <typename T>
  static void *move_sorted(void *voidargs)
  {
      const move_sorted_args<T> &job =
          *static_cast<const move_sorted_args<T> *>(voidargs);
      const uint32_t stop = job.start + job.num;
      for (uint32_t slot = job.start; slot < stop; ++slot) {
          memmove(job.destbuf + slot * job.msg_size,
                  job.items + (job.sorted_keys[slot].index()) * job.msg_size,
                  job.msg_size);
      }
      return NULL;
  }
  123. // As above, but also pass an Nr*msg_size-byte buffer outbuf to put
  124. // the sorted values into, instead of passing a callback. This calls
  125. // the above function, then copies the data in sorted order into outbuf.
  126. // Note: the outbuf buffer cannot overlap the items buffer.
  127. template<typename T>
  128. void sort_mtobliv(threadid_t nthreads, uint8_t* items, uint16_t msg_size,
  129. uint32_t Nr_, uint32_t Na, uint8_t *outbuf)
  130. {
  131. sort_mtobliv<T>(nthreads, items, msg_size, Nr_, Na,
  132. [nthreads, msg_size, outbuf]
  133. (const uint8_t* origitems, const T* keys, uint32_t Nr) {
  134. // Special-case nthreads=1 for efficiency
  135. if (nthreads <= 1) {
  136. move_sorted_args<T> args = {
  137. keys, origitems, outbuf, 0, Nr, msg_size
  138. };
  139. move_sorted<T>(&args);
  140. } else {
  141. move_sorted_args<T> args[nthreads];
  142. uint32_t inc = Nr / nthreads;
  143. uint32_t extra = Nr % nthreads;
  144. uint32_t last = 0;
  145. for (threadid_t i=0; i<nthreads; ++i) {
  146. uint32_t num = inc + (i < extra);
  147. args[i] = {
  148. keys, origitems, outbuf, last, num, msg_size
  149. };
  150. last += num;
  151. }
  152. // Launch all but the first section into other threads
  153. for (threadid_t i=1; i<nthreads; ++i) {
  154. threadpool_dispatch(g_thread_id+i,
  155. move_sorted<T>, args+i);
  156. }
  157. // Do the first section ourselves
  158. move_sorted<T>(args);
  159. // Join the threads
  160. for (threadid_t i=1; i<nthreads; ++i) {
  161. threadpool_join(g_thread_id+i, NULL);
  162. }
  163. }
  164. });
  165. }
  166. // As above, but the first Nr msg_size-byte entries in the items array
  167. // will end up with the sorted values. Note: if Nr < Na, entries beyond
  168. // Nr may also change, but you should not even look at those values!
  169. // This calls the above function with a temporary buffer, then copies
  170. // that buffer back into the items array, so it's less efficient, both
  171. // in memory and CPU, than the above functions.
  172. template<typename T>
  173. void sort_mtobliv(threadid_t nthreads, uint8_t* items, uint16_t msg_size,
  174. uint32_t Nr, uint32_t Na)
  175. {
  176. uint8_t *tempbuf = new uint8_t[Nr * msg_size];
  177. sort_mtobliv<T>(nthreads, items, msg_size, Nr, Na, tempbuf);
  178. memmove(items, tempbuf, Nr * msg_size);
  179. delete[] tempbuf;
  180. }