Browse Source

Extend StreamEval to allow for an XOR input offset in addition to an additive one

Ian Goldberg 1 year ago
parent
commit
ac0020d18d
4 changed files with 46 additions and 21 deletions
  1. 6 6
      duoram.tcc
  2. 4 4
      online.cpp
  3. 10 2
      rdpf.hpp
  4. 26 9
      rdpf.tcc

+ 6 - 6
duoram.tcc

@@ -146,7 +146,7 @@ Duoram<T>::Shape::MemRefAS::operator T()
         auto indshift = combine(indoffset, peerindoffset, shape.addr_size);
 
         // Evaluate the DPFs and compute the dotproducts
-        StreamEval ev(dp, indshift, shape.tio.aes_ops());
+        StreamEval ev(dp, indshift, 0, shape.tio.aes_ops());
         for (size_t i=0; i<shape.shape_size; ++i) {
             auto L = ev.next();
             // The values from the two DPFs
@@ -175,7 +175,7 @@ Duoram<T>::Shape::MemRefAS::operator T()
 
         // Evaluate the DPFs to compute the cancellation terms
         T gamma0, gamma1;
-        StreamEval ev(dp, indshift, shape.tio.aes_ops());
+        StreamEval ev(dp, indshift, 0, shape.tio.aes_ops());
         for (size_t i=0; i<shape.shape_size; ++i) {
             auto L = ev.next();
 
@@ -217,8 +217,8 @@ typename Duoram<T>::Shape::MemRefAS
         RDPFTriple dt = shape.tio.rdpftriple(shape.addr_size);
 
         // Compute the index and message offsets
-        RegAS indoffset = idx;
-        indoffset -= dt.as_target;
+        RegAS indoffset = dt.as_target;
+        indoffset -= idx;
         auto Moffset = std::make_tuple(M, M, M);
         Moffset -= dt.scaled_value<T>();
 
@@ -243,7 +243,7 @@ typename Duoram<T>::Shape::MemRefAS
         auto Mshift = combine(Moffset, peerMoffset);
 
         // Evaluate the DPFs and add them to the database
-        StreamEval ev(dt, -indshift, shape.tio.aes_ops());
+        StreamEval ev(dt, indshift, 0, shape.tio.aes_ops());
         for (size_t i=0; i<shape.shape_size; ++i) {
             auto L = ev.next();
             // The values from the three DPFs
@@ -277,7 +277,7 @@ typename Duoram<T>::Shape::MemRefAS
         auto Mshift = combine(p0Moffset, p1Moffset);
 
         // Evaluate the DPFs and subtract them from the blinds
-        StreamEval ev(dp, -indshift, shape.tio.aes_ops());
+        StreamEval ev(dp, indshift, 0, shape.tio.aes_ops());
         for (size_t i=0; i<shape.shape_size; ++i) {
             auto L = ev.next();
             // The values from the two DPFs

+ 4 - 4
online.cpp

@@ -315,7 +315,7 @@ static void rdpfeval_timing(MPCIO &mpcio, yield_t &yield,
                 for (int i=0;i<2;++i) {
                     RDPF &dpf = dp.dpf[i];
                     RegXS scaled_xor;
-                    auto ev = StreamEval(dpf, start, op_counter, false);
+                    auto ev = StreamEval(dpf, start, 0, op_counter, false);
                     for (address_t x=0;x<(address_t(1)<<depth);++x) {
                         DPFnode leaf = ev.next();
                         RegXS sx = dpf.scaled_xs(leaf);
@@ -330,7 +330,7 @@ static void rdpfeval_timing(MPCIO &mpcio, yield_t &yield,
                 for (int i=0;i<3;++i) {
                     RDPF &dpf = dt.dpf[i];
                     RegXS scaled_xor;
-                    auto ev = StreamEval(dpf, start, op_counter, false);
+                    auto ev = StreamEval(dpf, start, 0, op_counter, false);
                     for (address_t x=0;x<(address_t(1)<<depth);++x) {
                         DPFnode leaf = ev.next();
                         RegXS sx = dpf.scaled_xs(leaf);
@@ -371,7 +371,7 @@ static void tupleeval_timing(MPCIO &mpcio, yield_t &yield,
             if (mpcio.player == 2) {
                 RDPFPair dp = tio.rdpfpair(depth);
                 RegXS scaled_xor0, scaled_xor1;
-                auto ev = StreamEval(dp, start, op_counter, false);
+                auto ev = StreamEval(dp, start, 0, op_counter, false);
                 for (address_t x=0;x<(address_t(1)<<depth);++x) {
                     auto [L0, L1] = ev.next();
                     RegXS sx0 = dp.dpf[0].scaled_xs(L0);
@@ -388,7 +388,7 @@ static void tupleeval_timing(MPCIO &mpcio, yield_t &yield,
             } else {
                 RDPFTriple dt = tio.rdpftriple(depth);
                 RegXS scaled_xor0, scaled_xor1, scaled_xor2;
-                auto ev = StreamEval(dt, start, op_counter, false);
+                auto ev = StreamEval(dt, start, 0, op_counter, false);
                 for (address_t x=0;x<(address_t(1)<<depth);++x) {
                     auto [L0, L1, L2] = ev.next();
                     RegXS sx0 = dt.dpf[0].scaled_xs(L0);

+ 10 - 2
rdpf.hpp

@@ -17,6 +17,7 @@ class StreamEval {
     size_t &op_counter;
     bool use_expansion;
     nbits_t depth;
+    address_t counter_xor_offset;
     address_t indexmask;
     address_t pathindex;
     address_t nextindex;
@@ -26,9 +27,16 @@ public:
     // It will wrap around to 0 when it hits 2^depth.  If use_expansion
     // is true, then if the DPF has been expanded, just output values
     // from that.  If use_expansion=false or if the DPF has not been
-    // expanded, compute the values on the fly.
-    StreamEval(const T &rdpf, address_t start, size_t &op_counter,
+    // expanded, compute the values on the fly.  If xor_offset is
+    // non-zero, then the outputs are actually
+    // DPF(start XOR xor_offset)
+    // DPF((start+1) XOR xor_offset)
+    // DPF((start+2) XOR xor_offset)
+    // etc.
+    StreamEval(const T &rdpf, address_t start,
+        address_t xor_offset, size_t &op_counter,
         bool use_expansion = true);
+
     // Get the next value (or tuple of values) from the evaluator
     typename T::node next();
 };

+ 26 - 9
rdpf.tcc

@@ -4,11 +4,17 @@
 // It will wrap around to 0 when it hits 2^depth.  If use_expansion
 // is true, then if the DPF has been expanded, just output values
 // from that.  If use_expansion=false or if the DPF has not been
-// expanded, compute the values on the fly.
+// expanded, compute the values on the fly.  If xor_offset is non-zero,
+// then the outputs are actually
+// DPF(start XOR xor_offset)
+// DPF((start+1) XOR xor_offset)
+// DPF((start+2) XOR xor_offset)
+// etc.
 template <typename T>
 StreamEval<T>::StreamEval(const T &rdpf, address_t start,
-    size_t &op_counter, bool use_expansion) : rdpf(rdpf),
-    op_counter(op_counter), use_expansion(use_expansion)
+    address_t xor_offset, size_t &op_counter,
+    bool use_expansion) : rdpf(rdpf), op_counter(op_counter),
+    use_expansion(use_expansion), counter_xor_offset(xor_offset)
 {
     depth = rdpf.depth();
     // Prevent overflow of 1<<depth
@@ -18,6 +24,7 @@ StreamEval<T>::StreamEval(const T &rdpf, address_t start,
         indexmask = ~0;
     }
     start &= indexmask;
+    counter_xor_offset &= indexmask;
     // Record that we haven't actually output the leaf for index start
     // itself yet
     nextindex = start;
@@ -30,7 +37,10 @@ StreamEval<T>::StreamEval(const T &rdpf, address_t start,
     path[0] = rdpf.get_seed();
     for (nbits_t i=1;i<depth;++i) {
         bool dir = !!(pathindex & (address_t(1)<<(depth-i)));
-        path[i] = rdpf.descend(path[i-1], i-1, dir, op_counter);
+        bool xor_offset_bit =
+            !!(counter_xor_offset & (address_t(1)<<(depth-i)));
+        path[i] = rdpf.descend(path[i-1], i-1,
+            dir ^ xor_offset_bit, op_counter);
     }
 }
 
@@ -39,7 +49,8 @@ typename T::node StreamEval<T>::next()
 {
     if (use_expansion && rdpf.has_expansion()) {
         // Just use the precomputed values
-        typename T::node leaf = rdpf.get_expansion(nextindex);
+        typename T::node leaf =
+            rdpf.get_expansion(nextindex ^ counter_xor_offset);
         nextindex = (nextindex + 1) & indexmask;
         return leaf;
     }
@@ -70,16 +81,22 @@ typename T::node StreamEval<T>::next()
         // around from the right subtree back to the left, in which case
         // it will be 0.
         bool top_changed_bit =
-            nextindex & (address_t(1) << how_many_1_bits);
+            !!(nextindex & (address_t(1) << how_many_1_bits));
+        bool xor_offset_bit =
+            !!(counter_xor_offset & (address_t(1) << how_many_1_bits));
         path[depth-how_many_1_bits] =
             rdpf.descend(path[depth-how_many_1_bits-1],
-                depth-how_many_1_bits-1, top_changed_bit, op_counter);
+                depth-how_many_1_bits-1,
+                top_changed_bit ^ xor_offset_bit, op_counter);
         for (nbits_t i = depth-how_many_1_bits; i < depth-1; ++i) {
-            path[i+1] = rdpf.descend(path[i], i, 0, op_counter);
+            bool xor_offset_bit =
+                !!(counter_xor_offset & (address_t(1) << (depth-i-1)));
+            path[i+1] = rdpf.descend(path[i], i, xor_offset_bit, op_counter);
         }
     }
+    bool xor_offset_bit = counter_xor_offset & 1;
     typename T::node leaf = rdpf.descend(path[depth-1], depth-1,
-        nextindex & 1, op_counter);
+        (nextindex & 1) ^ xor_offset_bit, op_counter);
     pathindex = nextindex;
     nextindex = (nextindex + 1) & indexmask;
     return leaf;