123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241 |
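// Uncomment the next line to compile in the "begin ..."/"end ..."
// profiling printouts that are guarded by PROFILE_MTMERGESORT below.
// #define PROFILE_MTMERGESORT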
// Forward declaration of the three-way key comparison for elements of
// type T.  The arguments are untyped pointers so the same function can
// be handed to qsort() as its comparison callback.  The code in this
// file treats a negative result as "*a sorts before *b" and a positive
// result as "*a sorts after *b".
template <typename T>
static int compare_keys(const void *a, const void *b);
// Argument bundle describing one merge<T>() job.  merge<T>() receives
// it through a void*, so the same routine can be called directly or
// handed to threadpool_dispatch.  The field order matters: mtmerge fills
// this struct with a positional brace initializer.
template <typename T>
struct MergeArgs {
    T *dst;             // output; receives Nleft + Nright elements
    const T *leftsrc;   // first sorted input run
    size_t Nleft;       // number of elements in leftsrc
    const T *rightsrc;  // second sorted input run
    size_t Nright;      // number of elements in rightsrc
};
- // Merge two sorted arrays into one. The (sorted) source arrays are
- // leftsrc and rightsrc of lengths Nleft and Nright respectively. Put
- // the merged sorted array into dst[0..Nleft+Nright-1]. Use up to the
- // given number of threads.
- template<typename T>
- static void* merge(void *voidargs)
- {
- const MergeArgs<T>* args = (const MergeArgs<T>*)voidargs;
- #ifdef PROFILE_MTMERGESORT
- unsigned long start = printf_with_rtclock("begin merge(dst=%p, leftsrc=%p, Nleft=%lu, rightsrc=%p, Nright=%lu, nthreads=%lu)\n", args->dst, args->leftsrc, args->Nleft, args->rightsrc, args->Nright);
- #endif
- T* dst = args->dst;
- const T* left = args->leftsrc;
- const T* right = args->rightsrc;
- const T* leftend = args->leftsrc + args->Nleft;
- const T* rightend = args->rightsrc + args->Nright;
- while (left != leftend && right != rightend) {
- if (compare_keys<T>(left, right) < 0) {
- *dst = *left;
- ++dst;
- ++left;
- } else {
- *dst = *right;
- ++dst;
- ++right;
- }
- }
- if (left != leftend) {
- memmove(dst, left, (leftend-left)*sizeof(T));
- }
- if (right != rightend) {
- memmove(dst, right, (rightend-right)*sizeof(T));
- }
- #ifdef PROFILE_MTMERGESORT
- printf_with_rtclock_diff(start, "end merge(dst=%p, leftsrc=%p, Nleft=%lu, rightsrc=%p, Nright=%lu, nthreads=%lu)\n", args->dst, args->leftsrc, args->Nleft, args->rightsrc, args->Nright);
- #endif
- return NULL;
- }
- // In the sorted subarray src[0 .. len-1], binary search for the first
- // element that's larger than the target. The return value is the index
- // into that subarray, so it's 0 if src[0] > target, and it's len if all
- // the elements are less than the target. Remember that all elements
- // have to be different, so no comparison will ever return that the
- // elements are equal.
- template<typename T>
- static size_t binsearch(const T* src, size_t len, const T* target)
- {
- size_t left = 0;
- size_t right = len;
- if (len == 0) {
- return 0;
- }
- if (compare_keys<T>(src + left, target) > 0) {
- return 0;
- }
- if (len > 0 && compare_keys<T>(src + right - 1, target) < 0) {
- return len;
- }
- // Invariant: src[left] < target and src[right] > target (where
- // src[len] is considered to be greater than all targets)
- while (right - left > 1) {
- size_t mid = left + (right - left)/2;
- if (compare_keys<T>(src + mid, target) > 0) {
- right = mid;
- } else {
- left = mid;
- }
- }
- return right;
- }
- // Merge two sorted arrays into one. The (sorted) source arrays are
- // leftsrc and rightsrc of lengths Nleft and Nright respectively. Put
- // the merged sorted array into dst[0..Nleft+Nright-1]. Use up to the
- // given number of threads.
- template<typename T>
- static void mtmerge(T* dst, const T* leftsrc, size_t Nleft,
- const T* rightsrc, size_t Nright, threadid_t nthreads)
- {
- #ifdef PROFILE_MTMERGESORT
- unsigned long start = printf_with_rtclock("begin mtmerge(dst=%p, leftsrc=%p, Nleft=%lu, rightsrc=%p, Nright=%lu, nthreads=%lu)\n", dst, leftsrc, Nleft, rightsrc, Nright, nthreads);
- #endif
- threadid_t threads_to_use = nthreads;
- if (Nleft < 500 || Nright < 500) {
- threads_to_use = 1;
- }
- // Break the left array into threads_to_use approximately
- // equal-sized pieces
- MergeArgs<T> margs[threads_to_use];
- size_t leftinc = Nleft / threads_to_use;
- size_t leftextra = Nleft % threads_to_use;
- size_t leftlast = 0;
- size_t rightlast = 0;
- for (threadid_t t=0; t<threads_to_use; ++t) {
- size_t leftlen = leftinc + (t < leftextra);
- // Find the segment in the right array corresponding to this
- // segment in the lest array. If this is the last segment of
- // the left array, that's just the whole remaining right array.
- size_t rightlen;
- if (t == threads_to_use - 1) {
- rightlen = Nright - rightlast;
- } else {
- // The first element of the next left segment
- const T* target = leftsrc + leftlast + leftlen;
- // In the sorted subarray rightsrc[rightlast .. Nright-1],
- // binary search for the first element that's larger than
- // the target. The return value is the index into that
- // subarray, so it's 0 if rightsrc[rightlast] > target, and
- // it's Nright-rightlast if all the elements are less than
- // the target.
- rightlen = binsearch<T>(rightsrc + rightlast,
- Nright-rightlast, target);
- }
- margs[t] = { dst + leftlast + rightlast,
- leftsrc + leftlast, leftlen,
- rightsrc + rightlast, rightlen };
- leftlast += leftlen;
- rightlast += rightlen;
- if (t > 0) {
- threadpool_dispatch(g_thread_id+t, merge<T>, &margs[t]);
- }
- }
- // Do the first block ourselves
- merge<T>(&margs[0]);
- for (size_t t=1; t<threads_to_use; ++t) {
- threadpool_join(g_thread_id+t, NULL);
- }
- #ifdef PROFILE_MTMERGESORT
- printf_with_rtclock_diff(start, "end mtmerge(dst=%p, leftsrc=%p, Nleft=%lu, rightsrc=%p, Nright=%lu, nthreads=%lu)\n", dst, leftsrc, Nleft, rightsrc, Nright, nthreads);
- #endif
- }
- template<typename T>
- struct MTMergesortArgs {
- T* buf;
- size_t N;
- T* backing;
- threadid_t nthreads;
- bool ret;
- };
- template<typename T>
- static bool mtmergesort(T* buf, size_t N, T* backing, threadid_t nthreads);
- template<typename T>
- static void *mtmergesort_launch(void *voidargs)
- {
- MTMergesortArgs<T>* args = (MTMergesortArgs<T>*)voidargs;
- args->ret = mtmergesort<T>(args->buf, args->N, args->backing,
- args->nthreads);
- return NULL;
- }
- // Multithreaded mergesort. Pass the data of type T to sort, as a
- // pointer and number of elements. Also pass a backing store of the
- // same size. The sorted data will end up in either the original data
- // array or the backing store; this function will return false if it's
- // in the original data and true if it's in the backing store. Use up
- // to the given number of threads.
- template<typename T>
- bool mtmergesort(T* buf, size_t N, T* backing, threadid_t nthreads)
- {
- if (nthreads == 1 || N < 1000) {
- // Just sort naively
- #ifdef PROFILE_MTMERGESORT
- unsigned long start = printf_with_rtclock("begin qsort(buf=%p, N=%lu)\n", buf, N);
- #endif
- qsort(buf, N, sizeof(T), compare_keys<T>);
- #ifdef PROFILE_MTMERGESORT
- printf_with_rtclock_diff(start, "end qsort(buf=%p, N=%lu)\n", buf, N);
- #endif
- return false;
- }
- #ifdef PROFILE_MTMERGESORT
- unsigned long start = printf_with_rtclock("begin mtmergesort(buf=%p, N=%lu, backing=%p, nthreads=%lu)\n", buf, N, backing, nthreads);
- #endif
- size_t Nleft = (N+1)/2;
- size_t Nright = N/2;
- threadid_t threads_left = (nthreads+1)/2;
- threadid_t threads_right = nthreads/2;
- MTMergesortArgs<T> ms_right_args =
- { buf + Nleft, Nright, backing + Nleft, threads_right, false };
- threadpool_dispatch(g_thread_id+threads_left, mtmergesort_launch<T>,
- &ms_right_args);
- bool leftret = mtmergesort<T>(buf, Nleft, backing, threads_left);
- threadpool_join(g_thread_id+threads_left, NULL);
- bool rightret = ms_right_args.ret;
- // If the left and right sorts put their answers in different
- // places, move the right answer to match the left
- if (leftret != rightret) {
- if (leftret) {
- // The left is in backing, and the right is in buf
- memmove(backing + Nleft, buf + Nleft, Nright * sizeof(T));
- } else {
- // The left is in buf, and the right is in backing
- memmove(buf + Nleft, backing + Nleft, Nright * sizeof(T));
- }
- }
- // Merge the two halves
- if (leftret) {
- // The recursive outputs are in backing; merge them into buf
- mtmerge<T>(buf, backing, Nleft, backing+Nleft, Nright, nthreads);
- } else {
- // The recursive outputs are in buf; merge them into backing
- mtmerge<T>(backing, buf, Nleft, buf+Nleft, Nright, nthreads);
- }
- #ifdef PROFILE_MTMERGESORT
- printf_with_rtclock_diff(start, "end mtmergesort(buf=%p, N=%lu, backing=%p, nthreads=%lu)\n", buf, N, backing, nthreads);
- #endif
- return !leftret;
- }
|