sort.cpp 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169
  #include <pthread.h>

  #include <deque>
  #include <map>
  #include <memory>

  #include "sort.hpp"
  5. // A set of precomputed WaksmanNetworks of a given size
  6. struct SizedWNs {
  7. pthread_mutex_t mutex;
  8. std::deque<WaksmanNetwork> wns;
  9. SizedWNs() { pthread_mutex_init(&mutex, NULL); }
  10. };
  11. // A (mutexed) map mapping sizes to SizedWNs
  12. struct PrecompWNs {
  13. pthread_mutex_t mutex;
  14. std::map<uint32_t,SizedWNs> sized_wns;
  15. PrecompWNs() { pthread_mutex_init(&mutex, NULL); }
  16. };
  17. static PrecompWNs precomp_wns;
  18. // A (mutexed) map mapping (N, nthreads) pairs to WNEvalPlans
  19. struct EvalPlans {
  20. pthread_mutex_t mutex;
  21. std::map<std::pair<uint32_t,threadid_t>,WNEvalPlan> eval_plans;
  22. EvalPlans() { pthread_mutex_init(&mutex, NULL); }
  23. };
  24. static EvalPlans precomp_eps;
  25. size_t sort_precompute(uint32_t N)
  26. {
  27. uint32_t *random_permutation = NULL;
  28. try {
  29. random_permutation = new uint32_t[N];
  30. } catch (std::bad_alloc&) {
  31. printf("Allocating memory failed in sort_precompute\n");
  32. assert(false);
  33. }
  34. for (uint32_t i=0;i<N;++i) {
  35. random_permutation[i] = i;
  36. }
  37. RecursiveShuffle_M2((unsigned char *) random_permutation, N, sizeof(uint32_t));
  38. WaksmanNetwork wnet(N);
  39. wnet.setPermutation(random_permutation);
  40. // Note that sized_wns[N] creates a map entry for N if it doesn't yet exist
  41. pthread_mutex_lock(&precomp_wns.mutex);
  42. SizedWNs& szwn = precomp_wns.sized_wns[N];
  43. pthread_mutex_unlock(&precomp_wns.mutex);
  44. pthread_mutex_lock(&szwn.mutex);
  45. szwn.wns.push_back(std::move(wnet));
  46. size_t ret = szwn.wns.size();
  47. pthread_mutex_unlock(&szwn.mutex);
  48. return ret;
  49. }
  50. void sort_precompute_evalplan(uint32_t N, threadid_t nthreads)
  51. {
  52. std::pair<uint32_t,threadid_t> idx = {N, nthreads};
  53. pthread_mutex_lock(&precomp_eps.mutex);
  54. if (!precomp_eps.eval_plans.count(idx)) {
  55. precomp_eps.eval_plans.try_emplace(idx, N, nthreads);
  56. }
  57. pthread_mutex_unlock(&precomp_eps.mutex);
  58. }
  59. // Shuffle Nr items at the beginning of an allocated array of Na items
  60. // using up to nthreads threads. The items to shuffle are byte arrays
  61. // of size msg_size. Return Nw, the size of the Waksman network we
  62. // used, which must satisfy Nr <= Nw <= Na.
  63. uint32_t shuffle_mtobliv(threadid_t nthreads, uint8_t* items, uint16_t msg_size,
  64. uint32_t Nr, uint32_t Na)
  65. {
  66. // Find the smallest Nw for which we have a precomputed
  67. // WaksmanNetwork with Nr <= Nw <= Na
  68. pthread_mutex_lock(&precomp_wns.mutex);
  69. std::optional<WaksmanNetwork> wn;
  70. uint32_t Nw = 0;
  71. for (auto& N : precomp_wns.sized_wns) {
  72. if (N.first > Na) {
  73. printf("No precomputed WaksmanNetworks of size at most %u\n", Na);
  74. assert(false);
  75. }
  76. if (N.first < Nr) {
  77. continue;
  78. }
  79. // We're in the right range, but see if we have an actual
  80. // precomputed WaksmanNetwork
  81. pthread_mutex_lock(&N.second.mutex);
  82. if (N.second.wns.size() == 0) {
  83. pthread_mutex_unlock(&N.second.mutex);
  84. continue;
  85. }
  86. wn = std::move(N.second.wns.front());
  87. N.second.wns.pop_front();
  88. Nw = N.first;
  89. pthread_mutex_unlock(&N.second.mutex);
  90. break;
  91. }
  92. pthread_mutex_unlock(&precomp_wns.mutex);
  93. if (!wn) {
  94. printf("No precomputed WaksmanNetwork of size range [%u,%u] found.\n",
  95. Nr, Na);
  96. assert(wn);
  97. }
  98. std::pair<uint32_t,threadid_t> epidx = {Nw, nthreads};
  99. pthread_mutex_lock(&precomp_eps.mutex);
  100. if (!precomp_eps.eval_plans.count(epidx)) {
  101. printf("No precomputed WNEvalPlan with N=%u, nthreads=%hu\n",
  102. Nw, nthreads);
  103. assert(false);
  104. }
  105. const WNEvalPlan &eval_plan = precomp_eps.eval_plans.at(epidx);
  106. pthread_mutex_unlock(&precomp_eps.mutex);
  107. // Mark Nw-Nr items as padding (Nr, Na, and Nw are _not_ private)
  108. for (uint32_t i=Nr; i<Nw; ++i) {
  109. (*(uint32_t*)(items+msg_size*i)) = uint32_t(-1);
  110. }
  111. // Shuffle Nw items
  112. wn.value().applyInversePermutation<OSWAP_16X>(
  113. items, msg_size, eval_plan);
  114. return Nw;
  115. }
  116. // Perform the sort using up to nthreads threads. The items to sort are
  117. // byte arrays of size msg_size. The key is the 10-bit storage server
  118. // id concatenated with the 22-bit uid at the storage server.
  119. void sort_mtobliv(threadid_t nthreads, uint8_t* items, uint16_t msg_size,
  120. uint32_t Nr, uint32_t Na,
  121. // the arguments to the callback are items, the sorted indices, and
  122. // the number of non-padding items
  123. std::function<void(const uint8_t*, const uint64_t*, uint32_t Nr)> cb)
  124. {
  125. // Shuffle the items
  126. uint32_t Nw = shuffle_mtobliv(nthreads, items, msg_size, Nr, Na);
  127. // Create the indices
  128. uint64_t *idx = new uint64_t[Nr];
  129. uint64_t *nextidx = idx;
  130. for (uint32_t i=0; i<Nw; ++i) {
  131. uint64_t key = (*(uint32_t*)(items+msg_size*i));
  132. if (key != uint32_t(-1)) {
  133. *nextidx = (key<<32) + i;
  134. ++nextidx;
  135. }
  136. }
  137. if (nextidx != idx + Nr) {
  138. printf("Found %u non-padding items, expected %u\n",
  139. nextidx-idx, Nr);
  140. assert(nextidx == idx + Nr);
  141. }
  142. // Sort the keys and indices
  143. uint64_t *backingidx = new uint64_t[Nr];
  144. bool whichbuf = mtmergesort<uint64_t>(idx, Nr, backingidx, nthreads);
  145. uint64_t *sortedidx = whichbuf ? backingidx : idx;
  146. for (uint32_t i=0; i<Nr; ++i) {
  147. sortedidx[i] &= uint64_t(0xffffffff);
  148. }
  149. cb(items, sortedidx, Nr);
  150. delete[] idx;
  151. delete[] backingidx;
  152. }