// WaksmanNetwork.tcc
  // #define PROFILE_MTMERGESORT
  // (Uncomment the line above to enable the printf_with_rtclock timing
  // output in the merge and sort routines below.)

  // Three-way comparison in qsort() style: returns a negative value if
  // *a < *b, zero if they are equal, and a positive value if *a > *b.
  // Specialized per element type T.
  template<typename T> static int compare(const void *a, const void *b);

  // uint64_t elements are ordered by numeric value. The result is built
  // from two comparisons instead of a subtraction, so it cannot wrap and
  // has the correct sign for every pair of 64-bit values, independent of
  // byte order. No data-dependent branch or select is written here.
  template<>
  int compare<uint64_t>(const void *a, const void *b)
  {
      const uint64_t x = *static_cast<const uint64_t*>(a);
      const uint64_t y = *static_cast<const uint64_t*>(b);
      return (x > y) - (x < y);
  }
  // Argument bundle for merge<T>, passed to it as a void*. Keep the field
  // order as is: instances may be brace-initialized positionally.
  template<typename T>
  struct MergeArgs {
      T *dst;              // output; receives Nleft + Nright elements
      const T *leftsrc;    // first sorted input
      size_t Nleft;        // element count of leftsrc
      const T *rightsrc;   // second sorted input
      size_t Nright;       // element count of rightsrc
  };
  20. // Merge two sorted arrays into one. The (sorted) source arrays are
  21. // leftsrc and rightsrc of lengths Nleft and Nright respectively. Put
  22. // the merged sorted array into dst[0..Nleft+Nright-1]. Use up to the
  23. // given number of threads.
  24. template<typename T>
  25. static void* merge(void *voidargs)
  26. {
  27. const MergeArgs<T>* args = (const MergeArgs<T>*)voidargs;
  28. #ifdef PROFILE_MTMERGESORT
  29. unsigned long start = printf_with_rtclock("begin merge(dst=%p, leftsrc=%p, Nleft=%lu, rightsrc=%p, Nright=%lu, nthreads=%lu)\n", args->dst, args->leftsrc, args->Nleft, args->rightsrc, args->Nright);
  30. #endif
  31. T* dst = args->dst;
  32. const T* left = args->leftsrc;
  33. const T* right = args->rightsrc;
  34. const T* leftend = args->leftsrc + args->Nleft;
  35. const T* rightend = args->rightsrc + args->Nright;
  36. while (left != leftend && right != rightend) {
  37. if (compare<T>(left, right) < 0) {
  38. *dst = *left;
  39. ++dst;
  40. ++left;
  41. } else {
  42. *dst = *right;
  43. ++dst;
  44. ++right;
  45. }
  46. }
  47. if (left != leftend) {
  48. memmove(dst, left, (leftend-left)*sizeof(T));
  49. }
  50. if (right != rightend) {
  51. memmove(dst, right, (rightend-right)*sizeof(T));
  52. }
  53. #ifdef PROFILE_MTMERGESORT
  54. printf_with_rtclock_diff(start, "end merge(dst=%p, leftsrc=%p, Nleft=%lu, rightsrc=%p, Nright=%lu, nthreads=%lu)\n", args->dst, args->leftsrc, args->Nleft, args->rightsrc, args->Nright);
  55. #endif
  56. return NULL;
  57. }
  58. // In the sorted subarray src[0 .. len-1], binary search for the first
  59. // element that's larger than the target. The return value is the index
  60. // into that subarray, so it's 0 if src[0] > target, and it's len if all
  61. // the elements are less than the target. Remember that all elements
  62. // have to be different, so no comparison will ever return that the
  63. // elements are equal.
  64. template<typename T>
  65. static size_t binsearch(const T* src, size_t len, const T* target)
  66. {
  67. size_t left = 0;
  68. size_t right = len;
  69. if (len == 0) {
  70. return 0;
  71. }
  72. if (compare<T>(src + left, target) > 0) {
  73. return 0;
  74. }
  75. if (len > 0 && compare<T>(src + right - 1, target) < 0) {
  76. return len;
  77. }
  78. // Invariant: src[left] < target and src[right] > target (where
  79. // src[len] is considered to be greater than all targets)
  80. while (right - left > 1) {
  81. size_t mid = left + (right - left)/2;
  82. if (compare<T>(src + mid, target) > 0) {
  83. right = mid;
  84. } else {
  85. left = mid;
  86. }
  87. }
  88. return right;
  89. }
  90. // Merge two sorted arrays into one. The (sorted) source arrays are
  91. // leftsrc and rightsrc of lengths Nleft and Nright respectively. Put
  92. // the merged sorted array into dst[0..Nleft+Nright-1]. Use up to the
  93. // given number of threads.
  94. template<typename T>
  95. static void mtmerge(T* dst, const T* leftsrc, size_t Nleft,
  96. const T* rightsrc, size_t Nright, threadid_t nthreads)
  97. {
  98. #ifdef PROFILE_MTMERGESORT
  99. unsigned long start = printf_with_rtclock("begin mtmerge(dst=%p, leftsrc=%p, Nleft=%lu, rightsrc=%p, Nright=%lu, nthreads=%lu)\n", dst, leftsrc, Nleft, rightsrc, Nright, nthreads);
  100. #endif
  101. threadid_t threads_to_use = nthreads;
  102. if (Nleft < 500 || Nright < 500) {
  103. threads_to_use = 1;
  104. }
  105. // Break the left array into threads_to_use approximately
  106. // equal-sized pieces
  107. MergeArgs<T> margs[threads_to_use];
  108. size_t leftinc = Nleft / threads_to_use;
  109. size_t leftextra = Nleft % threads_to_use;
  110. size_t leftlast = 0;
  111. size_t rightlast = 0;
  112. for (threadid_t t=0; t<threads_to_use; ++t) {
  113. size_t leftlen = leftinc + (t < leftextra);
  114. // Find the segment in the right array corresponding to this
  115. // segment in the lest array. If this is the last segment of
  116. // the left array, that's just the whole remaining right array.
  117. size_t rightlen;
  118. if (t == threads_to_use - 1) {
  119. rightlen = Nright - rightlast;
  120. } else {
  121. // The first element of the next left segment
  122. const T* target = leftsrc + leftlast + leftlen;
  123. // In the sorted subarray rightsrc[rightlast .. Nright-1],
  124. // binary search for the first element that's larger than
  125. // the target. The return value is the index into that
  126. // subarray, so it's 0 if rightsrc[rightlast] > target, and
  127. // it's Nright-rightlast if all the elements are less than
  128. // the target.
  129. rightlen = binsearch<T>(rightsrc + rightlast,
  130. Nright-rightlast, target);
  131. }
  132. margs[t] = { dst + leftlast + rightlast,
  133. leftsrc + leftlast, leftlen,
  134. rightsrc + rightlast, rightlen };
  135. leftlast += leftlen;
  136. rightlast += rightlen;
  137. if (t > 0) {
  138. threadpool_dispatch(g_thread_id+t, merge<T>, &margs[t]);
  139. }
  140. }
  141. // Do the first block ourselves
  142. merge<T>(&margs[0]);
  143. for (size_t t=1; t<threads_to_use; ++t) {
  144. threadpool_join(g_thread_id+t, NULL);
  145. }
  146. #ifdef PROFILE_MTMERGESORT
  147. printf_with_rtclock_diff(start, "end mtmerge(dst=%p, leftsrc=%p, Nleft=%lu, rightsrc=%p, Nright=%lu, nthreads=%lu)\n", dst, leftsrc, Nleft, rightsrc, Nright, nthreads);
  148. #endif
  149. }
  150. template<typename T>
  151. struct MTMergesortArgs {
  152. T* buf;
  153. size_t N;
  154. T* backing;
  155. threadid_t nthreads;
  156. bool ret;
  157. };
  158. template<typename T>
  159. static bool mtmergesort(T* buf, size_t N, T* backing, threadid_t nthreads);
  160. template<typename T>
  161. static void *mtmergesort_launch(void *voidargs)
  162. {
  163. MTMergesortArgs<T>* args = (MTMergesortArgs<T>*)voidargs;
  164. args->ret = mtmergesort<T>(args->buf, args->N, args->backing,
  165. args->nthreads);
  166. return NULL;
  167. }
  168. // Multithreaded mergesort. Pass the data of type T to sort, as a
  169. // pointer and number of elements. Also pass a backing store of the
  170. // same size. The sorted data will end up in either the original data
  171. // array or the backing store; this function will return false if it's
  172. // in the original data and true if it's in the backing store. Use up
  173. // to the given number of threads.
  174. template<typename T>
  175. bool mtmergesort(T* buf, size_t N, T* backing, threadid_t nthreads)
  176. {
  177. if (nthreads == 1 || N < 1000) {
  178. // Just sort naively
  179. #ifdef PROFILE_MTMERGESORT
  180. unsigned long start = printf_with_rtclock("begin qsort(buf=%p, N=%lu)\n", buf, N);
  181. #endif
  182. qsort(buf, N, sizeof(T), compare<T>);
  183. #ifdef PROFILE_MTMERGESORT
  184. printf_with_rtclock_diff(start, "end qsort(buf=%p, N=%lu)\n", buf, N);
  185. #endif
  186. return false;
  187. }
  188. #ifdef PROFILE_MTMERGESORT
  189. unsigned long start = printf_with_rtclock("begin mtmergesort(buf=%p, N=%lu, backing=%p, nthreads=%lu)\n", buf, N, backing, nthreads);
  190. #endif
  191. size_t Nleft = (N+1)/2;
  192. size_t Nright = N/2;
  193. threadid_t threads_left = (nthreads+1)/2;
  194. threadid_t threads_right = nthreads/2;
  195. MTMergesortArgs<T> ms_right_args =
  196. { buf + Nleft, Nright, backing + Nleft, threads_right, false };
  197. threadpool_dispatch(g_thread_id+threads_left, mtmergesort_launch<T>,
  198. &ms_right_args);
  199. bool leftret = mtmergesort<T>(buf, Nleft, backing, threads_left);
  200. threadpool_join(g_thread_id+threads_left, NULL);
  201. bool rightret = ms_right_args.ret;
  202. // If the left and right sorts put their answers in different
  203. // places, move the right answer to match the left
  204. if (leftret != rightret) {
  205. if (leftret) {
  206. // The left is in backing, and the right is in buf
  207. memmove(backing + Nleft, buf + Nleft, Nright * sizeof(T));
  208. } else {
  209. // The left is in buf, and the right is in backing
  210. memmove(buf + Nleft, backing + Nleft, Nright * sizeof(T));
  211. }
  212. }
  213. // Merge the two halves
  214. if (leftret) {
  215. // The recursive outputs are in backing; merge them into buf
  216. mtmerge<T>(buf, backing, Nleft, backing+Nleft, Nright, nthreads);
  217. } else {
  218. // The recursive outputs are in buf; merge them into backing
  219. mtmerge<T>(backing, buf, Nleft, buf+Nleft, Nright, nthreads);
  220. }
  221. #ifdef PROFILE_MTMERGESORT
  222. printf_with_rtclock_diff(start, "end mtmergesort(buf=%p, N=%lu, backing=%p, nthreads=%lu)\n", buf, N, backing, nthreads);
  223. #endif
  224. return !leftret;
  225. }