Browse Source

Bitonic sort for arbitrary lengths

The length no longer has to be a power of 2.
Ian Goldberg 1 year ago
parent
commit
36c4daa621
3 changed files with 70 additions and 35 deletions
  1. +3 -3
      duoram.hpp
  2. +36 -23
      duoram.tcc
  3. +31 -9
      online.cpp

+ 3 - 3
duoram.hpp

@@ -266,7 +266,7 @@ class Duoram<T>::Flat : public Duoram<T>::Shape {
     }
 
     // Internal function to aid bitonic_sort
-    void butterfly(address_t start, nbits_t depth, bool dir);
+    void butterfly(address_t start, address_t len, bool dir);
 
 public:
     // Constructor.  len=0 means the maximum size (the parent's size
@@ -338,11 +338,11 @@ public:
     template<typename U,typename V>
     void osort(const U &idx1, const V &idx2, bool dir=0);
 
-    // Bitonic sort the elements from start to start+(1<<depth)-1, in
+    // Bitonic sort the elements from start to start+len-1, in
     // increasing order if dir=0 or decreasing order if dir=1. Note that
     // the elements must be at most 63 bits long each for the notion of
     // ">" to make consistent sense.
-    void bitonic_sort(address_t start, nbits_t depth, bool dir=0);
+    void bitonic_sort(address_t start, address_t len, bool dir=0);
 
     // Assuming the memory is already sorted, do an oblivious binary
     // search for the largest index containing the value at most the

+ 36 - 23
duoram.tcc

@@ -177,62 +177,75 @@ Duoram<T>::Flat::Flat(Duoram &duoram, MPCTIO &tio, yield_t &yield,
     this->set_shape_size(len);
 }
 
-// Bitonic sort the elements from start to start+(1<<depth)-1, in
+// Bitonic sort the elements from start to start+len-1, in
 // increasing order if dir=0 or decreasing order if dir=1. Note that
 // the elements must be at most 63 bits long each for the notion of
 // ">" to make consistent sense.
 template <typename T>
-void Duoram<T>::Flat::bitonic_sort(address_t start, nbits_t depth, bool dir)
+void Duoram<T>::Flat::bitonic_sort(address_t start, address_t len, bool dir)
 {
-    if (depth == 0) return;
-    if (depth == 1) {
+    if (len < 2) return;
+    if (len == 2) {
         osort(start, start+1, dir);
         return;
     }
-    // Recurse on the first half (increasing order) and the second half
-    // (decreasing order) in parallel
+    address_t leftlen, rightlen;
+    leftlen = (len+1) >> 1;
+    rightlen = len >> 1;
+
+    // Recurse on the first half (opposite to the desired order)
+    // and the second half (desired order) in parallel
     run_coroutines(this->yield,
-        [this, start, depth](yield_t &yield) {
+        [this, start, leftlen, dir](yield_t &yield) {
             Flat Acoro = context(yield);
-            Acoro.bitonic_sort(start, depth-1, 0);
+            Acoro.bitonic_sort(start, leftlen, !dir);
         },
-        [this, start, depth](yield_t &yield) {
+        [this, start, leftlen, rightlen, dir](yield_t &yield) {
             Flat Acoro = context(yield);
-            Acoro.bitonic_sort(start+(1<<(depth-1)), depth-1, 1);
+            Acoro.bitonic_sort(start+leftlen, rightlen, dir);
         });
     // Merge the two into the desired order
-    butterfly(start, depth, dir);
+    butterfly(start, len, dir);
 }
 
 // Internal function to aid bitonic_sort
 template <typename T>
-void Duoram<T>::Flat::butterfly(address_t start, nbits_t depth, bool dir)
+void Duoram<T>::Flat::butterfly(address_t start, address_t len, bool dir)
 {
-    if (depth == 0) return;
-    if (depth == 1) {
+    if (len < 2) return;
+    if (len == 2) {
         osort(start, start+1, dir);
         return;
     }
-    // Sort pairs of elements half the width apart in parallel
-    address_t halfwidth = address_t(1)<<(depth-1);
+    address_t leftlen, rightlen, offset, num_swaps;
+    // leftlen = (len+1) >> 1;
+    leftlen = 1;
+    while(2*leftlen < len) {
+        leftlen *= 2;
+    }
+    rightlen = len - leftlen;
+    offset = leftlen;
+    num_swaps = rightlen;
+
+    // Sort pairs of elements offset apart in parallel
     std::vector<coro_t> coroutines;
-    for (address_t i=0; i<halfwidth;++i) {
+    for (address_t i=0; i<num_swaps;++i) {
         coroutines.emplace_back(
-            [this, start, halfwidth, dir, i](yield_t &yield) {
+            [this, start, offset, dir, i](yield_t &yield) {
                 Flat Acoro = context(yield);
-                Acoro.osort(start+i, start+i+halfwidth, dir);
+                Acoro.osort(start+i, start+i+offset, dir);
             });
     }
     run_coroutines(this->yield, coroutines);
     // Recurse on each half in parallel
     run_coroutines(this->yield,
-        [this, start, depth, dir](yield_t &yield) {
+        [this, start, leftlen, dir](yield_t &yield) {
             Flat Acoro = context(yield);
-            Acoro.butterfly(start, depth-1, dir);
+            Acoro.butterfly(start, leftlen, dir);
         },
-        [this, start, halfwidth, depth, dir](yield_t &yield) {
+        [this, start, leftlen, rightlen, dir](yield_t &yield) {
             Flat Acoro = context(yield);
-            Acoro.butterfly(start+halfwidth, depth-1, dir);
+            Acoro.butterfly(start+leftlen, rightlen, dir);
         });
 }
 

+ 31 - 9
online.cpp

@@ -1066,13 +1066,18 @@ static void sort_test(MPCIO &mpcio,
         depth = atoi(*args);
         ++args;
     }
+    address_t len = (1<<depth);
+    if (*args) {
+        len = atoi(*args);
+        ++args;
+    }
 
     int num_threads = opts.num_threads;
     boost::asio::thread_pool pool(num_threads);
     for (int thread_num = 0; thread_num < num_threads; ++thread_num) {
-        boost::asio::post(pool, [&mpcio, thread_num, depth] {
+        boost::asio::post(pool, [&mpcio, thread_num, depth, len] {
             MPCTIO tio(mpcio, thread_num);
-            run_coroutines(tio, [&tio, depth] (yield_t &yield) {
+            run_coroutines(tio, [&tio, depth, len] (yield_t &yield) {
                 address_t size = address_t(1)<<depth;
                 // size_t &aes_ops = tio.aes_ops();
                 Duoram<RegAS> oram(tio.player(), size);
@@ -1090,14 +1095,26 @@ static void sort_test(MPCIO &mpcio,
                         });
                 }
                 run_coroutines(yield, coroutines);
-                A.bitonic_sort(0, depth);
+                A.bitonic_sort(0, len);
                 if (depth <= 10) {
                     oram.dump();
-                    auto check = A.reconstruct();
-                    if (tio.player() == 0) {
-                        for (address_t i=0;i<size;++i) {
+                }
+                auto check = A.reconstruct();
+                bool fail = false;
+                if (tio.player() == 0) {
+                    for (address_t i=0;i<size;++i) {
+                        if (depth <= 10) {
                             printf("%04x %016lx\n", i, check[i].share());
                         }
+                        if (i>0 && i<len &&
+                            check[i].share() < check[i-1].share()) {
+                            fail = true;
+                        }
+                    }
+                    if (fail) {
+                        printf("FAIL\n");
+                    } else {
+                        printf("PASS\n");
                     }
                 }
             });
@@ -1118,6 +1135,11 @@ static void bsearch_test(MPCIO &mpcio,
         depth = atoi(*args);
         ++args;
     }
+    address_t len = (1<<depth);
+    if (*args) {
+        len = atoi(*args);
+        ++args;
+    }
     if (*args) {
         target = strtoull(*args, NULL, 16);
         ++args;
@@ -1126,9 +1148,9 @@ static void bsearch_test(MPCIO &mpcio,
     int num_threads = opts.num_threads;
     boost::asio::thread_pool pool(num_threads);
     for (int thread_num = 0; thread_num < num_threads; ++thread_num) {
-        boost::asio::post(pool, [&mpcio, thread_num, depth, target] {
+        boost::asio::post(pool, [&mpcio, thread_num, depth, len, target] {
             MPCTIO tio(mpcio, thread_num);
-            run_coroutines(tio, [&tio, depth, target] (yield_t &yield) {
+            run_coroutines(tio, [&tio, depth, len, target] (yield_t &yield) {
                 address_t size = address_t(1)<<depth;
                 RegAS tshare;
                 if (tio.player() == 2) {
@@ -1163,7 +1185,7 @@ static void bsearch_test(MPCIO &mpcio,
                         });
                 }
                 run_coroutines(yield, coroutines);
-                A.bitonic_sort(0, depth);
+                A.bitonic_sort(0, len);
 
                 // Binary search for the target
                 RegAS tindex = A.obliv_binary_search(tshare);