WaksmanNetwork.tcc 8.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241
  1. // #define PROFILE_MTMERGESORT
  // Three-way comparison of two T objects, used both by the single-threaded
  // fallback (it is handed to qsort()) and by the merge routines below, so
  // it must follow the qsort comparator contract: negative, zero or positive
  // as *a sorts before, equal to, or after *b. Its definition is not in this
  // section (presumably provided per element type T -- confirm).
  template <typename T>
  static int compare_keys(const void* a, const void* b);
  // Argument bundle for a single merge() work item. A pointer to one of these
  // is what gets passed (as void*) through threadpool_dispatch(). Field order
  // matters: callers fill it with brace initialization.
  template <typename T>
  struct MergeArgs {
      T*       dst;       // output; receives Nleft + Nright elements
      const T* leftsrc;   // first sorted input run
      size_t   Nleft;     // number of elements in leftsrc
      const T* rightsrc;  // second sorted input run
      size_t   Nright;    // number of elements in rightsrc
  };
  11. // Merge two sorted arrays into one. The (sorted) source arrays are
  12. // leftsrc and rightsrc of lengths Nleft and Nright respectively. Put
  13. // the merged sorted array into dst[0..Nleft+Nright-1]. Use up to the
  14. // given number of threads.
  15. template<typename T>
  16. static void* merge(void *voidargs)
  17. {
  18. const MergeArgs<T>* args = (const MergeArgs<T>*)voidargs;
  19. #ifdef PROFILE_MTMERGESORT
  20. unsigned long start = printf_with_rtclock("begin merge(dst=%p, leftsrc=%p, Nleft=%lu, rightsrc=%p, Nright=%lu, nthreads=%lu)\n", args->dst, args->leftsrc, args->Nleft, args->rightsrc, args->Nright);
  21. #endif
  22. T* dst = args->dst;
  23. const T* left = args->leftsrc;
  24. const T* right = args->rightsrc;
  25. const T* leftend = args->leftsrc + args->Nleft;
  26. const T* rightend = args->rightsrc + args->Nright;
  27. while (left != leftend && right != rightend) {
  28. if (compare_keys<T>(left, right) < 0) {
  29. *dst = *left;
  30. ++dst;
  31. ++left;
  32. } else {
  33. *dst = *right;
  34. ++dst;
  35. ++right;
  36. }
  37. }
  38. if (left != leftend) {
  39. memmove(dst, left, (leftend-left)*sizeof(T));
  40. }
  41. if (right != rightend) {
  42. memmove(dst, right, (rightend-right)*sizeof(T));
  43. }
  44. #ifdef PROFILE_MTMERGESORT
  45. printf_with_rtclock_diff(start, "end merge(dst=%p, leftsrc=%p, Nleft=%lu, rightsrc=%p, Nright=%lu, nthreads=%lu)\n", args->dst, args->leftsrc, args->Nleft, args->rightsrc, args->Nright);
  46. #endif
  47. return NULL;
  48. }
  49. // In the sorted subarray src[0 .. len-1], binary search for the first
  50. // element that's larger than the target. The return value is the index
  51. // into that subarray, so it's 0 if src[0] > target, and it's len if all
  52. // the elements are less than the target. Remember that all elements
  53. // have to be different, so no comparison will ever return that the
  54. // elements are equal.
  55. template<typename T>
  56. static size_t binsearch(const T* src, size_t len, const T* target)
  57. {
  58. size_t left = 0;
  59. size_t right = len;
  60. if (len == 0) {
  61. return 0;
  62. }
  63. if (compare_keys<T>(src + left, target) > 0) {
  64. return 0;
  65. }
  66. if (len > 0 && compare_keys<T>(src + right - 1, target) < 0) {
  67. return len;
  68. }
  69. // Invariant: src[left] < target and src[right] > target (where
  70. // src[len] is considered to be greater than all targets)
  71. while (right - left > 1) {
  72. size_t mid = left + (right - left)/2;
  73. if (compare_keys<T>(src + mid, target) > 0) {
  74. right = mid;
  75. } else {
  76. left = mid;
  77. }
  78. }
  79. return right;
  80. }
  81. // Merge two sorted arrays into one. The (sorted) source arrays are
  82. // leftsrc and rightsrc of lengths Nleft and Nright respectively. Put
  83. // the merged sorted array into dst[0..Nleft+Nright-1]. Use up to the
  84. // given number of threads.
  85. template<typename T>
  86. static void mtmerge(T* dst, const T* leftsrc, size_t Nleft,
  87. const T* rightsrc, size_t Nright, threadid_t nthreads)
  88. {
  89. #ifdef PROFILE_MTMERGESORT
  90. unsigned long start = printf_with_rtclock("begin mtmerge(dst=%p, leftsrc=%p, Nleft=%lu, rightsrc=%p, Nright=%lu, nthreads=%lu)\n", dst, leftsrc, Nleft, rightsrc, Nright, nthreads);
  91. #endif
  92. threadid_t threads_to_use = nthreads;
  93. if (Nleft < 500 || Nright < 500) {
  94. threads_to_use = 1;
  95. }
  96. // Break the left array into threads_to_use approximately
  97. // equal-sized pieces
  98. MergeArgs<T> margs[threads_to_use];
  99. size_t leftinc = Nleft / threads_to_use;
  100. size_t leftextra = Nleft % threads_to_use;
  101. size_t leftlast = 0;
  102. size_t rightlast = 0;
  103. for (threadid_t t=0; t<threads_to_use; ++t) {
  104. size_t leftlen = leftinc + (t < leftextra);
  105. // Find the segment in the right array corresponding to this
  106. // segment in the lest array. If this is the last segment of
  107. // the left array, that's just the whole remaining right array.
  108. size_t rightlen;
  109. if (t == threads_to_use - 1) {
  110. rightlen = Nright - rightlast;
  111. } else {
  112. // The first element of the next left segment
  113. const T* target = leftsrc + leftlast + leftlen;
  114. // In the sorted subarray rightsrc[rightlast .. Nright-1],
  115. // binary search for the first element that's larger than
  116. // the target. The return value is the index into that
  117. // subarray, so it's 0 if rightsrc[rightlast] > target, and
  118. // it's Nright-rightlast if all the elements are less than
  119. // the target.
  120. rightlen = binsearch<T>(rightsrc + rightlast,
  121. Nright-rightlast, target);
  122. }
  123. margs[t] = { dst + leftlast + rightlast,
  124. leftsrc + leftlast, leftlen,
  125. rightsrc + rightlast, rightlen };
  126. leftlast += leftlen;
  127. rightlast += rightlen;
  128. if (t > 0) {
  129. threadpool_dispatch(g_thread_id+t, merge<T>, &margs[t]);
  130. }
  131. }
  132. // Do the first block ourselves
  133. merge<T>(&margs[0]);
  134. for (size_t t=1; t<threads_to_use; ++t) {
  135. threadpool_join(g_thread_id+t, NULL);
  136. }
  137. #ifdef PROFILE_MTMERGESORT
  138. printf_with_rtclock_diff(start, "end mtmerge(dst=%p, leftsrc=%p, Nleft=%lu, rightsrc=%p, Nright=%lu, nthreads=%lu)\n", dst, leftsrc, Nleft, rightsrc, Nright, nthreads);
  139. #endif
  140. }
  141. template<typename T>
  142. struct MTMergesortArgs {
  143. T* buf;
  144. size_t N;
  145. T* backing;
  146. threadid_t nthreads;
  147. bool ret;
  148. };
  149. template<typename T>
  150. static bool mtmergesort(T* buf, size_t N, T* backing, threadid_t nthreads);
  151. template<typename T>
  152. static void *mtmergesort_launch(void *voidargs)
  153. {
  154. MTMergesortArgs<T>* args = (MTMergesortArgs<T>*)voidargs;
  155. args->ret = mtmergesort<T>(args->buf, args->N, args->backing,
  156. args->nthreads);
  157. return NULL;
  158. }
  159. // Multithreaded mergesort. Pass the data of type T to sort, as a
  160. // pointer and number of elements. Also pass a backing store of the
  161. // same size. The sorted data will end up in either the original data
  162. // array or the backing store; this function will return false if it's
  163. // in the original data and true if it's in the backing store. Use up
  164. // to the given number of threads.
  165. template<typename T>
  166. bool mtmergesort(T* buf, size_t N, T* backing, threadid_t nthreads)
  167. {
  168. if (nthreads == 1 || N < 1000) {
  169. // Just sort naively
  170. #ifdef PROFILE_MTMERGESORT
  171. unsigned long start = printf_with_rtclock("begin qsort(buf=%p, N=%lu)\n", buf, N);
  172. #endif
  173. qsort(buf, N, sizeof(T), compare_keys<T>);
  174. #ifdef PROFILE_MTMERGESORT
  175. printf_with_rtclock_diff(start, "end qsort(buf=%p, N=%lu)\n", buf, N);
  176. #endif
  177. return false;
  178. }
  179. #ifdef PROFILE_MTMERGESORT
  180. unsigned long start = printf_with_rtclock("begin mtmergesort(buf=%p, N=%lu, backing=%p, nthreads=%lu)\n", buf, N, backing, nthreads);
  181. #endif
  182. size_t Nleft = (N+1)/2;
  183. size_t Nright = N/2;
  184. threadid_t threads_left = (nthreads+1)/2;
  185. threadid_t threads_right = nthreads/2;
  186. MTMergesortArgs<T> ms_right_args =
  187. { buf + Nleft, Nright, backing + Nleft, threads_right, false };
  188. threadpool_dispatch(g_thread_id+threads_left, mtmergesort_launch<T>,
  189. &ms_right_args);
  190. bool leftret = mtmergesort<T>(buf, Nleft, backing, threads_left);
  191. threadpool_join(g_thread_id+threads_left, NULL);
  192. bool rightret = ms_right_args.ret;
  193. // If the left and right sorts put their answers in different
  194. // places, move the right answer to match the left
  195. if (leftret != rightret) {
  196. if (leftret) {
  197. // The left is in backing, and the right is in buf
  198. memmove(backing + Nleft, buf + Nleft, Nright * sizeof(T));
  199. } else {
  200. // The left is in buf, and the right is in backing
  201. memmove(buf + Nleft, backing + Nleft, Nright * sizeof(T));
  202. }
  203. }
  204. // Merge the two halves
  205. if (leftret) {
  206. // The recursive outputs are in backing; merge them into buf
  207. mtmerge<T>(buf, backing, Nleft, backing+Nleft, Nright, nthreads);
  208. } else {
  209. // The recursive outputs are in buf; merge them into backing
  210. mtmerge<T>(backing, buf, Nleft, buf+Nleft, Nright, nthreads);
  211. }
  212. #ifdef PROFILE_MTMERGESORT
  213. printf_with_rtclock_diff(start, "end mtmergesort(buf=%p, N=%lu, backing=%p, nthreads=%lu)\n", buf, N, backing, nthreads);
  214. #endif
  215. return !leftret;
  216. }