sort.tcc 6.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190
  1. // set_key for each kind of key we sort on
  2. template<>
  3. inline void set_key<UidKey>(UidKey *key, const uint8_t *item, uint32_t index)
  4. {
  5. key->uid_index = (uint64_t(*(const uint32_t *)item) << 32) + index;
  6. }
  7. template<>
  8. inline void set_key<UidPriorityKey>(UidPriorityKey *key, const uint8_t *item, uint32_t index)
  9. {
  10. key->uid_priority = (uint64_t(*(const uint32_t *)item) << 32) +
  11. (*(const uint32_t *)(item+4));
  12. key->idx = index;
  13. }
  14. template<>
  15. inline void set_key<NidPriorityKey>(NidPriorityKey *key, const uint8_t *item, uint32_t index)
  16. {
  17. constexpr uint32_t nid_mask = (~((1<<DEST_UID_BITS)-1));
  18. key->nid_priority = (uint64_t((*(const uint32_t *)item)&nid_mask) << 32) +
  19. (*(const uint32_t *)(item+4));
  20. key->idx = index;
  21. }
  22. // compare_keys for each kind of key we sort on. Note that it must not
  23. // be possible for any of these functions to return 0. These functions
  24. // must also be oblivious. Return a positive (32-bit signed) int if *a
  25. // is larger than *b, or a negative (32-bit signed) int otherwise.
  26. template<>
  27. int compare_keys<UidKey>(const void* a, const void* b)
  28. {
  29. bool alarge = (((const UidKey*)a)->uid_index >
  30. ((const UidKey *)b)->uid_index);
  31. return oselect_uint32_t(-1, 1, alarge);
  32. }
  33. template<>
  34. int compare_keys<UidPriorityKey>(const void* a, const void* b)
  35. {
  36. uint64_t aup = ((const UidPriorityKey*)a)->uid_priority;
  37. uint64_t bup = ((const UidPriorityKey*)b)->uid_priority;
  38. uint32_t aidx = ((const UidPriorityKey*)a)->idx;
  39. uint32_t bidx = ((const UidPriorityKey*)b)->idx;
  40. bool auplarge = (aup > bup);
  41. bool aupeq = (aup == bup);
  42. bool aidxlarge = (aidx > bidx);
  43. bool alarge = auplarge | (aupeq & aidxlarge);
  44. return oselect_uint32_t(-1, 1, alarge);
  45. }
  46. template<>
  47. int compare_keys<NidPriorityKey>(const void* a, const void* b)
  48. {
  49. uint64_t anp = ((const NidPriorityKey*)a)->nid_priority;
  50. uint64_t bnp = ((const NidPriorityKey*)b)->nid_priority;
  51. uint32_t aidx = ((const NidPriorityKey*)a)->idx;
  52. uint32_t bidx = ((const NidPriorityKey*)b)->idx;
  53. bool anplarge = (anp > bnp);
  54. bool anpeq = (anp == bnp);
  55. bool aidxlarge = (aidx > bidx);
  56. bool alarge = anplarge | (anpeq & aidxlarge);
  57. return oselect_uint32_t(-1, 1, alarge);
  58. }
  59. // Sort Nr items at the beginning of an allocated array of Na items
  60. // using up to nthreads threads. The items to sort are byte arrays of
  61. // size msg_size. The keys are of type T. T must have set_key<T> and
  62. // compare_keys<T> defined. The items will be shuffled in-place, and a
  63. // sorted array of keys will be passed to the provided callback
  64. // function.
  65. template<typename T>
  66. void sort_mtobliv(threadid_t nthreads, uint8_t* items, uint16_t msg_size,
  67. uint32_t Nr, uint32_t Na,
  68. // the arguments to the callback are items, the sorted indices, and
  69. // the number of non-padding items
  70. std::function<void(const uint8_t*, const T*, uint32_t Nr)> cb)
  71. {
  72. // Shuffle the items
  73. uint32_t Nw = shuffle_mtobliv(nthreads, items, msg_size, Nr, Na);
  74. // Create the indices
  75. T *idx = new T[Nr];
  76. T *nextidx = idx;
  77. for (uint32_t i=0; i<Nw; ++i) {
  78. uint64_t padding = (*(uint32_t*)(items+msg_size*i));
  79. if (padding != uint32_t(-1)) {
  80. set_key<T>(nextidx, items+msg_size*i, i);
  81. ++nextidx;
  82. }
  83. }
  84. if (nextidx != idx + Nr) {
  85. printf("Found %u non-padding items, expected %u\n",
  86. nextidx-idx, Nr);
  87. assert(nextidx == idx + Nr);
  88. }
  89. // Sort the keys and indices
  90. T *backingidx = new T[Nr];
  91. bool whichbuf = mtmergesort<T>(idx, Nr, backingidx, nthreads);
  92. T *sortedidx = whichbuf ? backingidx : idx;
  93. cb(items, sortedidx, Nr);
  94. delete[] idx;
  95. delete[] backingidx;
  96. }
  // Arguments for one move_sorted work item.
  template <typename T>
  struct move_sorted_args {
      const T* sorted_keys;   // keys in sorted order; T::index() gives an item position
      const uint8_t *items;   // source items, msg_size bytes each
      uint8_t *destbuf;       // destination; sorted entry i goes to destbuf + i*msg_size
      uint32_t start, num;    // this work item handles sorted positions [start, start+num)
      uint16_t msg_size;      // size of one item in bytes
  };

  // Copy, for each sorted position i in [start, start+num), the item
  // referenced by sorted_keys[i] from `items` into slot i of `destbuf`.
  // Takes and returns void* so it can be handed to threadpool_dispatch.
  // Always returns nullptr.
  template <typename T>
  static void *move_sorted(void *voidargs)
  {
      const move_sorted_args<T> *args =
          static_cast<const move_sorted_args<T> *>(voidargs);
      // size_t so the byte offsets below cannot wrap at 32 bits
      const size_t msg_size = args->msg_size;
      const uint32_t end = args->start + args->num;
      const T *sorted_keys = args->sorted_keys;
      const uint8_t *items = args->items;
      uint8_t *destbuf = args->destbuf;
      for (uint32_t i = args->start; i < end; ++i) {
          memmove(destbuf + i * msg_size,
              items + size_t(sorted_keys[i].index()) * msg_size,
              msg_size);
      }
      // A value must be returned: falling off the end of a non-void
      // function is undefined behaviour.
      return nullptr;
  }
  122. // As above, but the first Nr msg_size-byte entries in the items array
  123. // will end up with the sorted values. Note: if Nr < Na, entries beyond
  124. // Nr may also change, but you should not even look at those values!
  125. // This calls the above function, and then copies the data in sorted
  126. // order into a temporary buffer, then copies that buffer back into the
  127. // items array, so it's less efficient, both in memory and CPU, than the
  128. // above function.
  129. template<typename T>
  130. void sort_mtobliv(threadid_t nthreads, uint8_t* items, uint16_t msg_size,
  131. uint32_t Nr, uint32_t Na)
  132. {
  133. sort_mtobliv<T>(nthreads, items, msg_size, Nr, Na, [nthreads, msg_size]
  134. (const uint8_t* origitems, const T* keys, uint32_t Nr) {
  135. // A temporary buffer into which to copy the items in sorted
  136. // order
  137. uint8_t *tempbuf = new uint8_t[Nr * msg_size];
  138. // Special-case nthreads=1 for efficiency
  139. if (nthreads <= 1) {
  140. move_sorted_args<T> args = {
  141. keys, origitems, tempbuf, 0, Nr, msg_size
  142. };
  143. move_sorted<T>(&args);
  144. } else {
  145. move_sorted_args<T> args[nthreads];
  146. uint32_t inc = Nr / nthreads;
  147. uint32_t extra = Nr % nthreads;
  148. uint32_t last = 0;
  149. for (threadid_t i=0; i<nthreads; ++i) {
  150. uint32_t num = inc + (i < extra);
  151. args[i] = {
  152. keys, origitems, tempbuf, last, num, msg_size
  153. };
  154. last += num;
  155. }
  156. // Launch all but the first section into other threads
  157. for (threadid_t i=1; i<nthreads; ++i) {
  158. threadpool_dispatch(g_thread_id+i,
  159. move_sorted<T>, args+i);
  160. }
  161. // Do the first section ourselves
  162. move_sorted<T>(args);
  163. // Join the threads
  164. for (threadid_t i=1; i<nthreads; ++i) {
  165. threadpool_join(g_thread_id+i, NULL);
  166. }
  167. }
  168. // Copy the temporary buffer back to items
  169. memmove(origitems, tempbuf, Nr * msg_size);
  170. delete[] tempbuf;
  171. });
  172. }