BitonicSortLib.java 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175
  1. // Copyright (C) 2014 by Xiao Shaun Wang <wangxiao@cs.umd.edu>
  2. package com.oblivm.backend.circuits;
  3. import com.oblivm.backend.circuits.arithmetic.IntegerLib;
  4. import com.oblivm.backend.flexsc.CompEnv;
  5. public class BitonicSortLib<T> extends IntegerLib<T> {
  6. public BitonicSortLib(CompEnv<T> e) {
  7. super(e);
  8. }
  9. public void sortWithPayload(T[][] a, T[][] data, T isAscending) {
  10. bitonicSortWithPayload(a, data, 0, a.length, isAscending);
  11. }
  12. private void bitonicSortWithPayload(T[][] key, T[][] data, int lo, int n, T dir) {
  13. if (n > 1) {
  14. int m = n / 2;
  15. bitonicSortWithPayload(key, data, lo, m, not(dir));
  16. bitonicSortWithPayload(key, data, lo + m, n - m, dir);
  17. bitonicMergeWithPayload(key, data, lo, n, dir);
  18. }
  19. }
  20. protected void bitonicMergeWithPayload(T[][] key, T[][] data, int lo, int n, T dir) {
  21. if (n > 1) {
  22. int m = greatestPowerOfTwoLessThan(n);
  23. for (int i = lo; i < lo + n - m; i++)
  24. compareWithPayload(key, data, i, i + m, dir);
  25. bitonicMergeWithPayload(key, data, lo, m, dir);
  26. bitonicMergeWithPayload(key, data, lo + m, n - m, dir);
  27. }
  28. }
  29. private void compareWithPayload(T[][] key, T[][] data, int i, int j, T dir) {
  30. T greater = not(leq(key[i], key[j]));
  31. T swap = eq(greater, dir);
  32. T[] s = mux(key[j], key[i], swap);
  33. s = xor(s, key[i]);
  34. T[] ki = xor(key[j], s);
  35. T[] kj = xor(key[i], s);
  36. key[i] = ki;
  37. key[j] = kj;
  38. T[] s2 = mux(data[j], data[i], swap);
  39. s2 = xor(s2, data[i]);
  40. T[] di = xor(data[j], s2);
  41. T[] dj = xor(data[i], s2);
  42. data[i] = di;
  43. data[j] = dj;
  44. }
  45. public void sort(T[][] a, T isAscending) {
  46. bitonicSort(a, 0, a.length, isAscending);
  47. }
  48. private void bitonicSort(T[][] key, int lo, int n, T dir) {
  49. if (n > 1) {
  50. int m = n / 2;
  51. bitonicSort(key, lo, m, not(dir));
  52. bitonicSort(key, lo + m, n - m, dir);
  53. bitonicMerge(key, lo, n, dir);
  54. }
  55. }
  56. protected void bitonicMerge(T[][] key, int lo, int n, T dir) {
  57. if (n > 1) {
  58. int m = greatestPowerOfTwoLessThan(n);
  59. for (int i = lo; i < lo + n - m; i++)
  60. compare(key, i, i + m, dir);
  61. bitonicMerge(key, lo, m, dir);
  62. bitonicMerge(key, lo + m, n - m, dir);
  63. }
  64. }
  65. private void compare(T[][] key, int i, int j, T dir) {
  66. T swap = eq(not(leq(key[i], key[j])), dir);
  67. T[] s = mux(key[j], key[i], swap);
  68. s = xor(s, key[i]);
  69. T[] ki = xor(key[j], s);
  70. T[] kj = xor(key[i], s);
  71. key[i] = ki;
  72. key[j] = kj;
  73. }
  74. private int greatestPowerOfTwoLessThan(int n) {
  75. int k = 1;
  76. while (k < n)
  77. k = k << 1;
  78. return k >> 1;
  79. }
  80. public void sortWithPayload(T[] a, T[][] data, T isAscending) {
  81. bitonicSortWithPayload(a, data, 0, a.length, isAscending);
  82. }
  83. private void bitonicSortWithPayload(T[] key, T[][] data, int lo, int n, T dir) {
  84. if (n > 1) {
  85. int m = n / 2;
  86. bitonicSortWithPayload(key, data, lo, m, not(dir));
  87. bitonicSortWithPayload(key, data, lo + m, n - m, dir);
  88. bitonicMergeWithPayload(key, data, lo, n, dir);
  89. }
  90. }
  91. private void bitonicMergeWithPayload(T[] key, T[][] data, int lo, int n, T dir) {
  92. if (n > 1) {
  93. int m = greatestPowerOfTwoLessThan(n);
  94. for (int i = lo; i < lo + n - m; i++)
  95. compareWithPayload(key, data, i, i + m, dir);
  96. bitonicMergeWithPayload(key, data, lo, m, dir);
  97. bitonicMergeWithPayload(key, data, lo + m, n - m, dir);
  98. }
  99. }
  100. private void compareWithPayload(T[] key, T[][] data, int i, int j, T dir) {
  101. T greater = and(key[i], not(key[j]));
  102. T swap = eq(greater, dir);
  103. T s = mux(key[j], key[i], swap);
  104. s = xor(s, key[i]);
  105. T ki = xor(key[j], s);
  106. T kj = xor(key[i], s);
  107. key[i] = ki;
  108. key[j] = kj;
  109. T[] s2 = mux(data[j], data[i], swap);
  110. s2 = xor(s2, data[i]);
  111. T[] di = xor(data[j], s2);
  112. T[] dj = xor(data[i], s2);
  113. data[i] = di;
  114. data[j] = dj;
  115. }
  116. public void sortWithPayload(T[][] a, T[] data, T isAscending) {
  117. bitonicSortWithPayload(a, data, 0, a.length, isAscending);
  118. }
  119. private void bitonicSortWithPayload(T[][] key, T[] data, int lo, int n, T dir) {
  120. if (n > 1) {
  121. int m = n / 2;
  122. bitonicSortWithPayload(key, data, lo, m, not(dir));
  123. bitonicSortWithPayload(key, data, lo + m, n - m, dir);
  124. bitonicMergeWithPayload(key, data, lo, n, dir);
  125. }
  126. }
  127. private void bitonicMergeWithPayload(T[][] key, T[] data, int lo, int n, T dir) {
  128. if (n > 1) {
  129. int m = greatestPowerOfTwoLessThan(n);
  130. for (int i = lo; i < lo + n - m; i++)
  131. compareWithPayload(key, data, i, i + m, dir);
  132. bitonicMergeWithPayload(key, data, lo, m, dir);
  133. bitonicMergeWithPayload(key, data, lo + m, n - m, dir);
  134. }
  135. }
  136. private void compareWithPayload(T[][] key, T[] data, int i, int j, T dir) {
  137. T greater = not(leq(key[i], key[j]));
  138. T swap = eq(greater, dir);
  139. T[] s = mux(key[j], key[i], swap);
  140. s = xor(s, key[i]);
  141. T[] ki = xor(key[j], s);
  142. T[] kj = xor(key[i], s);
  143. key[i] = ki;
  144. key[j] = kj;
  145. T s2 = mux(data[j], data[i], swap);
  146. s2 = xor(s2, data[i]);
  147. T di = xor(data[j], s2);
  148. T dj = xor(data[i], s2);
  149. data[i] = di;
  150. data[j] = dj;
  151. }
  152. }