Browse Source

Be able to use a CDPF to compare RegAS or RegXS for equality

Ian Goldberg 1 year ago
parent
commit
7ea4769df8
4 changed files with 111 additions and 7 deletions
  1. 46 0
      cdpf.cpp
  2. 20 0
      cdpf.hpp
  3. 34 0
      cdpf.tcc
  4. 11 7
      online.cpp

+ 46 - 0
cdpf.cpp

@@ -296,3 +296,49 @@ std::tuple<RegBS,RegBS,RegBS> CDPF::compare(value_t S, size_t &aes_ops)
 
     return std::make_tuple(lt, eq, gt);
 }
+
+// You can call this version directly if you already have S = target-x
+// reconstructed.  This routine is entirely local; no communication
+// is needed.  This function is identical to compare, above, except that
+// it only computes what's needed for the eq output.
+//
+// Cost:
+// VALUE_BITS - 7 = 57 local AES operations
+RegBS CDPF::is_zero(value_t S, size_t &aes_ops)
+{
+    RegBS eq;
+
+    // We' descend the DPF tree for the values S.
+
+    // Invariant: Snode is the node on level curlevel on the path to
+    // S.
+    nbits_t curlevel = 0;
+    const nbits_t depth = VALUE_BITS - 7;
+    DPFnode Snode = seed;
+
+    bool Sdir = !!(S & (value_t(1)<<63));
+    Snode = descend(Snode, curlevel, Sdir, aes_ops);
+    curlevel = 1;
+
+    // The last level is special
+    while(curlevel < depth-1) {
+        Sdir = !!(S & (value_t(1)<<((depth+7)-curlevel-1)));
+        Snode = descend(Snode, curlevel, Sdir, aes_ops);
+        ++curlevel;
+    }
+    // Now we're at the level just above the leaves.  If we go left,
+    // include *all* the bits (not just the low bit) of the right
+    // child of Snode, and if we go right, include all the bits of
+    // the left child of Tnode.
+    Sdir = !!(S & (value_t(1)<<((depth+7)-curlevel-1)));
+    Snode = descend_to_leaf(Snode, Sdir, aes_ops);
+    ++curlevel;
+
+    // Now Snode is the leaf containing S.  Pull out the bit in Snode
+    // for S itself into eq.
+
+    nbits_t Spos = S & 0x7f;
+    eq.bshare = bit_at(Snode, Spos);
+
+    return eq;
+}

+ 20 - 0
cdpf.hpp

@@ -138,6 +138,26 @@ struct CDPF : public DPF {
     // 3*VALUE_BITS - 22 = 170 local AES operations
     std::tuple<RegBS,RegBS,RegBS> compare(value_t S, size_t &aes_ops);
 
+    // Determine whether the given additively or XOR shared element is 0.
+    // The output is a bit share, which is a share of 1 iff the passed
+    // element is a share of 0.  Note also that you can compare two RegAS or
+    // RegXS values A and B for equality by passing A-B here.
+    //
+    // Cost:
+    // 1 word sent in 1 message
+    // VALUE_BITS - 7 = 57 local AES operations
+    template <typename T>
+    RegBS is_zero(MPCTIO &tio, yield_t &yield,
+        const T &x, size_t &aes_ops);
+
+    // You can call this version directly if you already have S = target-x
+    // reconstructed.  This routine is entirely local; no communication
+    // is needed.  This function is identical to compare, above, except that
+    // it only computes what's needed for the eq output.
+    //
+    // Cost:
+    // VALUE_BITS - 7 = 57 local AES operations
+    RegBS is_zero(value_t S, size_t &aes_ops);
 };
 
 // Descend from the parent of a leaf node to the leaf node

+ 34 - 0
cdpf.tcc

@@ -43,3 +43,37 @@ T& operator<<(T &os, const CDPF &cdpf)
 
     return os;
 }
+
+// Determine whether the given additively or XOR shared element is 0.
+// The output is a bit share, which is a share of 1 iff the passed
+// element is a share of 0.  Note also that you can compare two RegAS or
+// RegXS values A and B for equality by passing A-B here.
+//
+// Cost:
+// 1 word sent in 1 message
+// VALUE_BITS - 7 = 57 local AES operations
+template <typename T>
+RegBS CDPF::is_zero(MPCTIO &tio, yield_t &yield,
+    const T &x, size_t &aes_ops)
+{
+    // Reconstruct S = target-x
+    // The server does nothing in this protocol
+    if (tio.player() < 2) {
+        T S_share = as_target - x;
+        tio.iostream_peer() << S_share;
+        yield();
+        T peer_S_share;
+        tio.iostream_peer() >> peer_S_share;
+        S_share += peer_S_share;
+        value_t S = S_share.share();
+
+        // After that one single-word exchange, the rest of this
+        // algorithm is entirely a local computation.
+        return is_zero(S, aes_ops);
+    } else {
+        yield();
+    }
+    // The server gets a share of 0
+    RegBS eq;
+    return eq;
+}

+ 11 - 7
online.cpp

@@ -866,26 +866,30 @@ static int compare_test_one(MPCTIO &tio, yield_t &yield,
         RegAS xsh;
         tio.iostream_server() >> dpf >> xsh;
         auto [lt, eq, gt] = dpf.compare(tio, yield, xsh, aes_ops);
-        printf("%016lx %016lx %d %d %d ", target, x, lt.bshare,
-            eq.bshare, gt.bshare);
+        RegBS eeq = dpf.is_zero(tio, yield, xsh, aes_ops);
+        printf("%016lx %016lx %d %d %d %d ", target, x, lt.bshare,
+            eq.bshare, gt.bshare, eeq.bshare);
         // Check the answer
         if (player == 1) {
-            tio.iostream_peer() << xsh << lt << eq << gt;
+            tio.iostream_peer() << xsh << lt << eq << gt << eeq;
         } else {
             RegAS peer_xsh;
-            RegBS peer_lt, peer_eq, peer_gt;
-            tio.iostream_peer() >> peer_xsh >> peer_lt >> peer_eq >> peer_gt;
+            RegBS peer_lt, peer_eq, peer_gt, peer_eeq;
+            tio.iostream_peer() >> peer_xsh >> peer_lt >> peer_eq >>
+                peer_gt >> peer_eeq;
             lt ^= peer_lt;
             eq ^= peer_eq;
             gt ^= peer_gt;
+            eeq ^= peer_eeq;
             xsh += peer_xsh;
             int lti = int(lt.bshare);
             int eqi = int(eq.bshare);
             int gti = int(gt.bshare);
+            int eeqi = int(eeq.bshare);
             x = xsh.share();
-            printf(": %d %d %d ", lti, eqi, gti);
+            printf(": %d %d %d %d ", lti, eqi, gti, eeqi);
             bool signbit = (x >> 63);
-            if (lti + eqi + gti != 1) {
+            if (lti + eqi + gti != 1 || eqi != eeqi) {
                 printf("INCONSISTENT");
                 res = 0;
             } else if (x == 0 && eqi) {