Browse Source

Implementation of ORExpand (single-threaded version)

Ian Goldberg 1 year ago
parent
commit
54e0392e60
5 changed files with 256 additions and 1 deletions
  1. 45 0
      Enclave/OblivAlgs/ORExpand.cpp
  2. 98 0
      Enclave/OblivAlgs/ORExpand.hpp
  3. 102 0
      Enclave/OblivAlgs/ORExpand.tcc
  4. 1 0
      Enclave/storage.cpp
  5. 10 1
      Makefile

+ 45 - 0
Enclave/OblivAlgs/ORExpand.cpp

@@ -0,0 +1,45 @@
+#include "ORExpand.hpp"
+
+#ifdef TEST_OREXPAND
+void test_ORExpand()
+{
+    size_t block_size = 48;
+    uint32_t N = 16;
+
+    uint32_t dest_specified[] = {0, 2, 3, 4, 5, 6, 7, 8, 9, 11, 12, 13, 14};
+    size_t dest_specified_len = sizeof(dest_specified)/sizeof(uint32_t);
+    unsigned char *buf = new unsigned char[N*block_size];
+    uint32_t *dest = new uint32_t[N];
+    for (size_t i=0;i<dest_specified_len;++i) {
+        dest[i] = dest_specified[i];
+    }
+    for (size_t i=dest_specified_len;i<N;++i) {
+        dest[i] = 0xffffffff;
+    }
+
+    for (size_t i=0;i<N;++i) {
+        for (size_t j=0;j<block_size;j+=2) {
+            buf[i*block_size+j] = (unsigned char)i;
+            buf[i*block_size+j+1] = (unsigned char)j/2;
+        }
+    }
+
+    ORExpand<OSWAP_16X>(buf, dest, block_size, N);
+
+    for(size_t i=0;i<N;++i) {
+        printf("%2d ", i);
+        if (dest[i] == 0xffffffff) {
+            printf("PD ");
+        } else {
+            printf("%2d ", dest[i]);
+        }
+        for (size_t j=0;j<block_size;++j) {
+            printf("%02x", buf[i*block_size+j]);
+        }
+        printf("\n");
+    }
+
+    delete[] buf;
+    delete[] dest;
+}
+#endif

+ 98 - 0
Enclave/OblivAlgs/ORExpand.hpp

@@ -0,0 +1,98 @@
+#ifndef __OREXPAND_HPP__
+#define __OREXPAND_HPP__
+
+#include <cstddef>
+#include <cstdint>
+
+#include "oasm_lib.h"
+
+// Notation: for an array A, let A[lo..hi] denote the subarray of
+// A that is A[lo], A[lo+1], ..., A[hi-1].  That is, it is inclusive of
+// lo and exclusive of hi.  The length of the subarray is hi-lo.  If
+// lo < mid < hi, then A[lo..mid] and A[mid..hi] are a partition of
+// A[lo..hi] into two non-empty non-overlapping pieces whose union is
+// exactly A[lo..hi].
+
+// This file implements oblivious recursive expansion (ORExpand), which
+// is basically ORCompact (the TightCompact function in
+// TightCompaction_v2.tcc) run backwards (doing the OSwaps in reverse
+// order).
+
+// Throughout this file, "buf" is the data buffer, composed of blocks
+// that are each block_size bytes long.  block_size must be a size
+// supported by the corresponding OSwap_Style.  "dest" is an array of
+// destinations.  The idea is that if the input to ORExpand has, say,
+// dest[7] = 19, then the block at index 7 in the input buf will end up
+// at index 19 in the output buf (treating buf as an array of
+// block_size-byte blocks).
+
+// These functions do _not_ implement arbitrary permutations, however
+// (see the WaksmanNetwork.* files for that).  They will only expand a
+// contiguous range of blocks into a possibly non-contiguous range,
+// inserting padding blocks in the intervening spaces, and preserving
+// the order of the non-padding blocks.  The contiguous range may (in
+// some cases; see below) start at some offset z, rather than at the
+// beginning of the (sub)array, however, in which case the contiguous
+// range may "wrap around" to the beginning of the (sub)array.
+
+// Note that A[lo+((z+k) mod (hi-lo))] is element k (starting from 0) of a
+// contiguous range in A[lo..hi], starting at offset z, and wrapping
+// around if necessary.
+
+// We denote block i of buf (buf[i]) as a padding block by setting
+// dest[i] = 0xffffffff.
+
+// Therefore, suppose we are working on the subarray buf[lo..hi] (that
+// is, we only perform operations on buf[lo..hi] and on the
+// corresponding dest[lo..hi], and all elements of dest[lo..hi] are
+// either themselves in the range lo..hi (inclusive of lo and exclusive
+// of hi) or are 0xffffffff). If there are r non-padding blocks (and so
+// hi-lo-r padding blocks), and the offset is z, then on input, we must
+// have:
+//
+// lo <= dest[lo+(z mod (hi-lo))]
+//     < dest[lo+((z+1) mod (hi-lo))]
+//     < dest[lo+((z+2) mod (hi-lo))]
+//     ...
+//     < dest[lo+((z+(r-1)) mod (hi-lo))]
+//     < hi
+//
+// and the other elements of dest[lo..hi] are 0xffffffff.
+
+// Note that while the offset may cause the _indices_ to wrap around,
+// the dest _values_ have no offset, and cannot wrap around.
+
+// The restriction on z is that if hi-lo is _not_ a power of 2, then z
+// must be 0, and indeed lo must be 0 (and so the contiguous range of
+// non-padding blocks must start at the beginning of the array).
+
+// Note that z does not need to be passed explicitly; it is implied by
+// the contents of dest.  In fact, the algorithm does not even need to
+// compute it explicitly.
+
+// The values passed are not private, but the _contents_ of the buf and
+// dest arrays are.
+
+template <OSwap_Style oswap_style>
+void ORExpand(unsigned char *buf, uint32_t *dest, size_t block_size,
+    uint32_t lo, uint32_t hi);
+
+// A convenience wrapper for the initial call.  Here, N is the length of
+// the input arrays (in block_size-byte blocks for buf, and in 32-bit
+// words for dest).  z must be 0, so the non-padding elements of dest
+// must be at the beginning.
+
+template <OSwap_Style oswap_style>
+void ORExpand(unsigned char *buf, uint32_t *dest, size_t block_size,
+        uint32_t N) {
+    ORExpand<oswap_style>(buf, dest, block_size, 0, N);
+}
+
+
+#include "ORExpand.tcc"
+
+#ifdef TEST_OREXPAND
+void test_ORExpand();
+#endif
+
+#endif

+ 102 - 0
Enclave/OblivAlgs/ORExpand.tcc

@@ -0,0 +1,102 @@
+#include "utils.hpp"
+
+// See ORExpand.hpp for explanations of notation and inputs.
+// Particularly note that all subarrays [lo..hi] are _inclusive_ of lo
+// but _exclusive_ of hi.
+
+// buf is an array of block_size-byte blocks. dest is an array of 32-bit
+// words.  We are given two (contiguous) subarrays [lo..mid] and
+// [mid..hi], and indices a and b with a in [lo..mid] and b in
+// [mid..hi].  If (mid <= dest[a] < hi) or (lo <= dest[b] <= mid), then
+// swap dest[a] with dest[b] and buf[a] with buf[b]; otherwise, do not.
+// However, all tests and swaps must be done obliviously to the values
+// of dest[a] and dest[b] (and the contents of buf).  It's OK to not be
+// oblivious to the values of lo, mid, hi, a, and b themselves, however.
+template <OSwap_Style oswap_style>
+static inline void mid_oswap(unsigned char *buf, uint32_t *dest,
+    size_t block_size, uint32_t lo, uint32_t mid, uint32_t hi,
+    uint32_t a, uint32_t b)
+{
+    uint32_t desta = dest[a];
+    uint32_t destb = dest[b];
+    uint8_t swap_flag = ((mid <= desta) & (desta < hi))
+        | ((lo <= destb) & (destb < mid));
+    oswap_buffer<OSWAP_4>((unsigned char *)(dest+a),
+        (unsigned char *)(dest+b), 4, swap_flag);
+    oswap_buffer<oswap_style>(buf+a*block_size, buf+b*block_size,
+        (uint32_t)block_size, swap_flag);
+}
+
+template <OSwap_Style oswap_style>
+void ORExpand(unsigned char *buf, uint32_t *dest, size_t block_size,
+    uint32_t lo, uint32_t hi)
+{
+    // Passing hi < lo is an illegal input
+    assert(hi >= lo);
+
+    // The length of the subarray
+    const uint32_t N = hi-lo;
+
+    // Nothing to do on inputs where [lo..hi] has length 0 or 1
+    if (N < 2) {
+        return;
+    }
+
+    // The largest power of 2 strictly less than N
+    const uint32_t N2 = uint32_t(pow2_lt(N));
+
+    // We divide [lo..hi] (of length N) into two pieces:
+    // [lo..mid] and [mid..hi], where [mid..hi] has length N2 (the
+    // largest power of 2 strictly less than N).  Note that mid is just
+    // _somewhere_ in the middle of [lo..hi]; it will not be the exact
+    // midpoint if N is not itself a power of 2.  (It will be the exact
+    // midpoint if N is a power of 2, however.)
+    const uint32_t mid = hi-N2;
+
+    // N1 is the length of [lo..mid].  Note that N1 <= N2, with equality
+    // if and only if N is a power of 2.
+    const uint32_t N1 = N-N2;
+
+    // We're going to do N1 oblivious swaps on the buf and dest arrays,
+    // between items lo+i and hi-N1+i for 0 <= i < N1.  If dest[lo+i]
+    // lies in [mid..hi] (and is not 0xffffffff to indicate padding)
+    // _or_ if dest[hi-N1+i] lies in [lo..mid] (again and is not
+    // 0xffffffff), then we swap them and their corresponding buf
+    // blocks.  The cool part is that it cannot be the case that both
+    // dest[lo+i] and dest[hi-N1+i] are not padding and they both have
+    // values on the same side of mid.  Why is that?
+
+    // Case 1: If dest[lo+i] < dest[hi-N1+i], then all of the blocks
+    // from lo+i to hi-N1+1 inclusive must be non-padding blocks, and
+    // since this contiguous block has strictly increasing values, it
+    // must be that dest[hi-N1+i] - dest[lo+i] >= (hi-N1+i)-(lo+i) =
+    // N-N1 = N2.  Since the lengths of [lo..mid] and [mid..hi] are each
+    // at most N2, dest[lo+i] and dest[hi-N1+i] cannot be both in the
+    // same one of those subarrays.
+
+    // Case 2: If dest[hi-N1+i] < dest[lo+i], then the contiguous range
+    // of non-padding blocks wraps around hi back to lo, so we must have
+    // that dest[lo+i] - dest[hi-N1+i] >= (hi+i) - (hi-N1+i) = N1, and
+    // also since the range wraps around, it must start at a non-zero
+    // offset z, which means that N had to be a power of 2, and so
+    // N1=N2.  Therefore dest[lo+i] - dest[hi-N1+i] >= N2, and as above,
+    // dest[lo+i] and dest[hi-N1+i] cannot both be in [lo..mid] or both
+    // be in [mid..hi], each of which has length N1=N2.
+
+    // So these oblivious swaps will ensure that all the blocks with
+    // dest in [lo..mid] end up in [lo..mid] and all the blocks with
+    // dest in [mid..hi] end up in [mid..hi].  In addition, the property
+    // that all the non-padding blocks are contiguous (possibly wrapping
+    // around for the [mid..hi] subarray which has length a power of 2)
+    // and monotonicly increasing are preserved for both the [lo..mid]
+    // and [mid..hi] subarrays.
+
+    for (uint32_t i=0; i<N1; ++i) {
+        mid_oswap<oswap_style>(buf, dest, block_size, lo, mid, hi,
+            lo+i, hi-N1+i);
+    }
+
+    // And now we just recurse on the two subarrays.
+    ORExpand<oswap_style>(buf, dest, block_size, lo, mid);
+    ORExpand<oswap_style>(buf, dest, block_size, mid, hi);
+}

+ 1 - 0
Enclave/storage.cpp

@@ -1,6 +1,7 @@
 #include "utils.hpp"
 #include "config.hpp"
 #include "storage.hpp"
+#include "ORExpand.hpp"
 
 // Handle the messages received by a storage node
 void storage_received(const uint8_t *msgs, uint32_t num_msgs)

+ 10 - 1
Makefile

@@ -338,7 +338,16 @@ Enclave/storage.o: Enclave/enclave_api.h Enclave/OblivAlgs/CONFIG.h
 Enclave/storage.o: Enclave/OblivAlgs/oasm_lib.h
 Enclave/storage.o: Enclave/OblivAlgs/oasm_lib.tcc Enclave/OblivAlgs/foav.h
 Enclave/storage.o: Enclave/config.hpp Enclave/enclave_api.h
-Enclave/storage.o: Enclave/storage.hpp
+Enclave/storage.o: Enclave/storage.hpp Enclave/OblivAlgs/ORExpand.hpp
+Enclave/storage.o: Enclave/OblivAlgs/ORExpand.tcc
+Enclave/OblivAlgs/ORExpand.o: Enclave/OblivAlgs/ORExpand.hpp
+Enclave/OblivAlgs/ORExpand.o: Enclave/OblivAlgs/oasm_lib.h
+Enclave/OblivAlgs/ORExpand.o: Enclave/OblivAlgs/CONFIG.h
+Enclave/OblivAlgs/ORExpand.o: Enclave/OblivAlgs/oasm_lib.tcc
+Enclave/OblivAlgs/ORExpand.o: Enclave/OblivAlgs/foav.h
+Enclave/OblivAlgs/ORExpand.o: Enclave/OblivAlgs/ORExpand.tcc
+Enclave/OblivAlgs/ORExpand.o: Enclave/OblivAlgs/utils.hpp Enclave/Enclave_t.h
+Enclave/OblivAlgs/ORExpand.o: Enclave/enclave_api.h
 Enclave/OblivAlgs/RecursiveShuffle.o: Enclave/OblivAlgs/oasm_lib.h
 Enclave/OblivAlgs/RecursiveShuffle.o: Enclave/OblivAlgs/CONFIG.h
 Enclave/OblivAlgs/RecursiveShuffle.o: Enclave/OblivAlgs/oasm_lib.tcc