// sort.cpp

#include <map>
#include <deque>
#include <cstdio>
#include <cstdint>
#include <cassert>
#include <utility>
#include <optional>
#include <functional>
#include <pthread.h>
#include "sort.hpp"

// A set of precomputed WaksmanNetworks of a given size
struct SizedWNs {
    pthread_mutex_t mutex = PTHREAD_MUTEX_INITIALIZER;
    std::deque<WaksmanNetwork> wns;
};

// A (mutexed) map mapping sizes to SizedWNs
struct PrecompWNs {
    pthread_mutex_t mutex = PTHREAD_MUTEX_INITIALIZER;
    std::map<uint32_t,SizedWNs> sized_wns;
};

static PrecompWNs precomp_wns;

// A (mutexed) map mapping (N, nthreads) pairs to WNEvalPlans
struct EvalPlans {
    pthread_mutex_t mutex = PTHREAD_MUTEX_INITIALIZER;
    std::map<std::pair<uint32_t,threadid_t>,WNEvalPlan> eval_plans;
};

static EvalPlans precomp_eps;
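
// Locking discipline (descriptive note): precomp_wns.mutex guards the
// structure of the sized_wns map itself (lookups and insertions), while
// each SizedWNs carries its own mutex guarding its deque of networks.
// std::map references remain valid across later insertions, so it is
// safe to drop the map mutex before locking an individual SizedWNs.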

// Precompute one WaksmanNetwork of size N, configured with a fresh
// random permutation, and stash it for a later sort_mtobliv call
void sort_precompute(uint32_t N)
{
    uint32_t *random_permutation = new uint32_t[N];
    if (!random_permutation) {
        printf("Allocating memory failed in sort_precompute\n");
    }
    assert(random_permutation);
    // Start from the identity permutation and obliviously shuffle it
    for (uint32_t i=0;i<N;++i) {
        random_permutation[i] = i;
    }
    RecursiveShuffle_M2((unsigned char *) random_permutation, N,
        sizeof(uint32_t));
    WaksmanNetwork wnet(N);
    wnet.setPermutation(random_permutation);
    // Assuming setPermutation does not keep a pointer to the
    // permutation buffer, we can free it once the switches are set
    delete[] random_permutation;
    // Note that sized_wns[N] creates a map entry for N if it doesn't yet exist
    pthread_mutex_lock(&precomp_wns.mutex);
    SizedWNs& szwn = precomp_wns.sized_wns[N];
    pthread_mutex_unlock(&precomp_wns.mutex);
    pthread_mutex_lock(&szwn.mutex);
    szwn.wns.push_back(std::move(wnet));
    pthread_mutex_unlock(&szwn.mutex);
}

// Precompute the WNEvalPlan for evaluating a size-N WaksmanNetwork
// with nthreads threads, if we don't already have one
void sort_precompute_evalplan(uint32_t N, threadid_t nthreads)
{
    std::pair<uint32_t,threadid_t> idx = {N, nthreads};
    pthread_mutex_lock(&precomp_eps.mutex);
    if (!precomp_eps.eval_plans.count(idx)) {
        // try_emplace constructs the WNEvalPlan(N, nthreads) in place
        precomp_eps.eval_plans.try_emplace(idx, N, nthreads);
    }
    pthread_mutex_unlock(&precomp_eps.mutex);
}
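
// A hypothetical precompute phase (the sizes and counts here are just
// for illustration): before any sorts are requested, a setup thread
// might run
//
//     sort_precompute_evalplan(1024, 4);
//     for (int i = 0; i < 10; ++i) {
//         sort_precompute(1024);
//     }
//
// so that ten later sorts of up to 1024 items, each using 4 threads,
// need no online Waksman precomputation.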

// Perform the sort using up to nthreads threads. The items to sort are
// byte arrays of size msg_size. The key is the first 4 bytes of each
// item.
void sort_mtobliv(threadid_t nthreads, uint8_t* items, uint16_t msg_size,
    uint32_t Nr, uint32_t Na,
    // the arguments to the callback are nthreads, items, the sorted
    // indices, and the number of non-padding items
    std::function<void(threadid_t, const uint8_t*, const uint64_t*,
        uint32_t)> cb)
{
    // Find the smallest Nw for which we have a precomputed
    // WaksmanNetwork with Nr <= Nw <= Na
    pthread_mutex_lock(&precomp_wns.mutex);
    std::optional<WaksmanNetwork> wn;
    uint32_t Nw = 0;
    for (auto& N : precomp_wns.sized_wns) {
        // The map is ordered by size, so once we pass Na, nothing
        // that follows can fit
        if (N.first > Na) {
            printf("No precomputed WaksmanNetworks available in the "
                "size range [%u,%u]\n", Nr, Na);
            assert(false);
        }
        if (N.first < Nr) {
            continue;
        }
        // We're in the right range, but see if we have an actual
        // precomputed WaksmanNetwork
        pthread_mutex_lock(&N.second.mutex);
        if (N.second.wns.size() == 0) {
            pthread_mutex_unlock(&N.second.mutex);
            continue;
        }
        wn = std::move(N.second.wns.front());
        N.second.wns.pop_front();
        Nw = N.first;
        pthread_mutex_unlock(&N.second.mutex);
        break;
    }
    pthread_mutex_unlock(&precomp_wns.mutex);
    if (!wn) {
        printf("No precomputed WaksmanNetwork in the size range [%u,%u] "
            "found.\n", Nr, Na);
        assert(wn);
    }
    std::pair<uint32_t,threadid_t> epidx = {Nw, nthreads};
    pthread_mutex_lock(&precomp_eps.mutex);
    if (!precomp_eps.eval_plans.count(epidx)) {
        printf("No precomputed WNEvalPlan with N=%u, nthreads=%hu\n",
            Nw, nthreads);
        assert(false);
    }
    // Map references stay valid after unlocking, and plans are never
    // erased, so holding this reference is safe
    const WNEvalPlan &eval_plan = precomp_eps.eval_plans.at(epidx);
    pthread_mutex_unlock(&precomp_eps.mutex);
    // Mark Nw-Nr items as padding (Nr, Na, and Nw are _not_ private);
    // the key uint32_t(-1) is reserved to identify padding afterwards
    for (uint32_t i=Nr; i<Nw; ++i) {
        (*(uint32_t*)(items+size_t(msg_size)*i)) = uint32_t(-1);
    }
    // Shuffle Nw items
    wn.value().applyInversePermutation<OSWAP_16X>(
        items, msg_size, eval_plan);
    // Create the indices: pack each item's 32-bit key into the high
    // word and its current position into the low word, so sorting the
    // uint64_t values sorts by key, skipping the padding items
    uint64_t *idx = new uint64_t[Nr];
    uint64_t *nextidx = idx;
    for (uint32_t i=0; i<Nw; ++i) {
        uint64_t key = (*(uint32_t*)(items+size_t(msg_size)*i));
        if (key != uint32_t(-1)) {
            *nextidx = (key<<32) + i;
            ++nextidx;
        }
    }
    if (nextidx != idx + Nr) {
        printf("Found %td non-padding items, expected %u\n",
            nextidx-idx, Nr);
        assert(nextidx == idx + Nr);
    }
    // Sort the keys and indices
    uint64_t *backingidx = new uint64_t[Nr];
    bool whichbuf = mtmergesort<uint64_t>(idx, Nr, backingidx, nthreads);
    uint64_t *sortedidx = whichbuf ? backingidx : idx;
    // Strip the keys from the high words, leaving just the item
    // positions in sorted-key order
    for (uint32_t i=0; i<Nr; ++i) {
        sortedidx[i] &= uint64_t(0xffffffff);
    }
    cb(nthreads, items, sortedidx, Nr);
    delete[] idx;
    delete[] backingidx;
}
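
#ifdef SORT_DEMO
// A minimal usage sketch, not part of the library build (compile with
// -DSORT_DEMO to include it). The buffer layout, sizes, and thread
// count here are illustrative assumptions; only the three functions
// above and the types from sort.hpp are taken as given. Note the item
// buffer must have room for Na items, since slots [Nr, Nw) are
// overwritten with padding.
#include <cstring>
int main()
{
    const uint32_t Nr = 1000;      // real (non-padding) items
    const uint32_t Na = 1024;      // largest acceptable padded size
    const uint16_t msg_size = 16;  // bytes per item
    const threadid_t nthreads = 4;

    // Precompute one size-Na network and the matching eval plan
    sort_precompute(Na);
    sort_precompute_evalplan(Na, nthreads);

    // Fill the first Nr slots with items whose keys descend, so the
    // effect of sorting is visible
    uint8_t *items = new uint8_t[size_t(Na) * msg_size];
    for (uint32_t i = 0; i < Nr; ++i) {
        uint32_t key = Nr - i;
        memcpy(items + size_t(i) * msg_size, &key, sizeof(key));
    }

    sort_mtobliv(nthreads, items, msg_size, Nr, Na,
        [msg_size] (threadid_t, const uint8_t *its,
                const uint64_t *sortedidx, uint32_t n) {
            // sortedidx[j] is the position in its of the item with the
            // j-th smallest key; print the first few keys in order
            for (uint32_t j = 0; j < 5 && j < n; ++j) {
                uint32_t key;
                memcpy(&key, its + sortedidx[j] * msg_size, sizeof(key));
                printf("rank %u: key %u\n", j, key);
            }
        });

    delete[] items;
    return 0;
}
#endif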