Browse Source

Comparison code complete

Ian Goldberg 1 year ago
parent
commit
a954959e22
5 changed files with 176 additions and 14 deletions
  1. 3 2
      Makefile
  2. 3 0
      bitutils.hpp
  3. 19 3
      cdpf.cpp
  4. 22 9
      cdpf.hpp
  5. 129 0
      online.cpp

+ 3 - 2
Makefile

@@ -35,8 +35,9 @@ preproc.o: types.hpp coroutine.hpp mpcio.hpp preproc.hpp options.hpp rdpf.hpp
 preproc.o: bitutils.hpp dpf.hpp prg.hpp aes.hpp rdpf.tcc cdpf.hpp cdpf.tcc
 online.o: online.hpp mpcio.hpp types.hpp options.hpp mpcops.hpp coroutine.hpp
 online.o: rdpf.hpp bitutils.hpp dpf.hpp prg.hpp aes.hpp rdpf.tcc duoram.hpp
-online.o: duoram.tcc
+online.o: duoram.tcc cdpf.hpp cdpf.tcc
 mpcops.o: mpcops.hpp types.hpp mpcio.hpp coroutine.hpp bitutils.hpp
 rdpf.o: rdpf.hpp mpcio.hpp types.hpp coroutine.hpp bitutils.hpp dpf.hpp
 rdpf.o: prg.hpp aes.hpp rdpf.tcc mpcops.hpp
-cdpf.o: bitutils.hpp cdpf.hpp types.hpp dpf.hpp prg.hpp aes.hpp cdpf.tcc
+cdpf.o: bitutils.hpp cdpf.hpp mpcio.hpp types.hpp coroutine.hpp dpf.hpp
+cdpf.o: prg.hpp aes.hpp cdpf.tcc

+ 3 - 0
bitutils.hpp

@@ -58,6 +58,9 @@ inline __m128i set_lsb(const __m128i & block, const bool val = true)
     return _mm_or_si128(clear_lsb(block, 0b01), lsb128_mask[val ? 0b01 : 0b00]);
 }
 
+// The following can probably be improved by someone who knows the SIMD
+// instruction sets better than I do.
+
 // Return the parity of the number of bits set in block; that is, 1 if
 // there are an odd number of bits set in block; 0 if even
 inline uint8_t parity(const __m128i & block)

+ 19 - 3
cdpf.cpp

@@ -3,6 +3,9 @@
 #include "cdpf.hpp"
 
 // Generate a pair of CDPFs with the given target value
+//
+// Cost:
+// 4*VALUE_BITS - 28 = 228 local AES operations
 std::tuple<CDPF,CDPF> CDPF::generate(value_t target, size_t &aes_ops)
 {
     CDPF dpf0, dpf1;
@@ -129,6 +132,9 @@ std::tuple<CDPF,CDPF> CDPF::generate(value_t target, size_t &aes_ops)
 }
 
 // Generate a pair of CDPFs with a random target value
+//
+// Cost:
+// 4*VALUE_BITS - 28 = 228 local AES operations
 std::tuple<CDPF,CDPF> CDPF::generate(size_t &aes_ops)
 {
     value_t target;
@@ -162,6 +168,10 @@ DPFnode CDPF::leaf(value_t input, size_t &aes_ops) const
 // equal to" just by adding the greater and equal outputs together.
 // Note also that you can compare two RegAS values A and B by
 // passing A-B here.
+//
+// Cost:
+// 1 word sent in 1 message
+// 3*VALUE_BITS - 22 = 170 local AES operations
 std::tuple<RegBS,RegBS,RegBS> CDPF::compare(MPCTIO &tio, yield_t &yield,
     RegAS x, size_t &aes_ops)
 {
@@ -169,7 +179,7 @@ std::tuple<RegBS,RegBS,RegBS> CDPF::compare(MPCTIO &tio, yield_t &yield,
     // The server does nothing in this protocol
     if (tio.player() < 2) {
         RegAS S_share = as_target - x;
-        tio.iostream_peer() << x;
+        tio.iostream_peer() << S_share;
         yield();
         RegAS peer_S_share;
         tio.iostream_peer() >> peer_S_share;
@@ -188,6 +198,9 @@ std::tuple<RegBS,RegBS,RegBS> CDPF::compare(MPCTIO &tio, yield_t &yield,
 // You can call this version directly if you already have S = target-x
 // reconstructed.  This routine is entirely local; no communication
 // is needed.
+//
+// Cost:
+// 3*VALUE_BITS - 22 = 170 local AES operations
 std::tuple<RegBS,RegBS,RegBS> CDPF::compare(value_t S, size_t &aes_ops)
 {
     RegBS gt, eq;
@@ -269,12 +282,15 @@ std::tuple<RegBS,RegBS,RegBS> CDPF::compare(value_t S, size_t &aes_ops)
     // and all the higher bits into gt.  Also pull out the bits strictly
     // below that for T in Tnode into gt.
 
-    // TODO...
+    nbits_t Spos = S & 0x7f;
+    eq.bshare = bit_at(Snode, Spos);
+    gt ^= parity_above(Snode, Spos);
+    gt ^= parity_below(Tnode, Spos);
 
     // Once we have gt and eq (which cannot both be 1), lt is just 1
     // exactly if they're both 0.
     RegBS lt;
-    lt.bshare = 1 ^ eq.bshare ^ gt.bshare;
+    lt.bshare = whichhalf ^ eq.bshare ^ gt.bshare;
 
     return std::make_tuple(lt, eq, gt);
 }

+ 22 - 9
cdpf.hpp

@@ -65,18 +65,18 @@
 // of the DPF.
 //
 // So at the end, we've computed a bit sharing of [x>0] with local
-// computation linear in the depth of the DPF (concretely, fewer than
-// 200 AES operations), and only a *single word* of communication in
-// each direction (exchanging the target{i}-x{i} values).  Of course,
-// this assumes you have one pair of these DPFs lying around, and you
-// have to use a fresh pair with a fresh random target value for each
+// computation linear in the depth of the DPF (concretely, 170 AES
+// operations), and only a *single word* of communication in each
+// direction (exchanging the target{i}-x{i} values).  Of course, this
+// assumes you have one pair of these DPFs lying around, and you have to
+// use a fresh pair with a fresh random target value for each
 // comparison, since revealing target-x for two different x's but the
 // same target leaks the difference of the x's. But in the 3-party
 // setting (or even the 2+1-party setting), you can just have the server
-// precompute a bunch of these pairs in advance, and hand bunches of the
-// first item in each pair to player 0 and the second item in each pair
-// to player 1, at preprocessing time (a single message from the server
-// to each of player 0 and player 1), and these DPFs are very fast to
+// at preprocessing time precompute a bunch of these pairs in advance,
+// and hand bunches of the first item in each pair to player 0 and the
+// second item in each pair to player 1 (a single message from the
+// server to each of player 0 and player 1). These DPFs are very fast to
 // compute, and very small (< 1KB each) to transmit and store.
 
 // See also dpf.hpp for the differences between these DPFs and the ones
@@ -95,9 +95,15 @@ struct CDPF : public DPF {
     DPFnode leaf_cwr;
 
     // Generate a pair of CDPFs with the given target value
+    //
+    // Cost:
+    // 4*VALUE_BITS - 28 = 228 local AES operations
     static std::tuple<CDPF,CDPF> generate(value_t target, size_t &aes_ops);
 
     // Generate a pair of CDPFs with a random target value
+    //
+    // Cost:
+    // 4*VALUE_BITS - 28 = 228 local AES operations
     static std::tuple<CDPF,CDPF> generate(size_t &aes_ops);
 
     // Descend from the parent of a leaf node to the leaf node
@@ -117,12 +123,19 @@ struct CDPF : public DPF {
     // equal to" just by adding the greater and equal outputs together.
     // Note also that you can compare two RegAS values A and B by
     // passing A-B here.
+    //
+    // Cost:
+    // 1 word sent in 1 message
+    // 3*VALUE_BITS - 22 = 170 local AES operations
     std::tuple<RegBS,RegBS,RegBS> compare(MPCTIO &tio, yield_t &yield,
         RegAS x, size_t &aes_ops);
 
     // You can call this version directly if you already have S = target-x
     // reconstructed.  This routine is entirely local; no communication
     // is needed.
+    //
+    // Cost:
+    // 3*VALUE_BITS - 22 = 170 local AES operations
     std::tuple<RegBS,RegBS,RegBS> compare(value_t S, size_t &aes_ops);
 
 };

+ 129 - 0
online.cpp

@@ -547,6 +547,132 @@ static void cdpf_test(MPCIO &mpcio, yield_t &yield,
     pool.join();
 }
 
+static int compare_test_one(MPCTIO &tio, yield_t &yield,
+    value_t target, value_t x)
+{
+    int player = tio.player();
+    size_t &aes_ops = tio.aes_ops();
+    int res = 1;
+    if (player == 2) {
+        // Create a CDPF pair with the given target
+        auto [dpf0, dpf1] = CDPF::generate(target, aes_ops);
+        // Send it and a share of x to the computational parties
+        RegAS x0, x1;
+        x0.randomize();
+        x1.set(x-x0.share());
+        tio.iostream_p0() << dpf0 << x0;
+        tio.iostream_p1() << dpf1 << x1;
+    } else {
+        CDPF dpf;
+        RegAS xsh;
+        tio.iostream_server() >> dpf >> xsh;
+        auto [lt, eq, gt] = dpf.compare(tio, yield, xsh, aes_ops);
+        printf("%016lx %016lx %d %d %d ", target, x, lt.bshare,
+            eq.bshare, gt.bshare);
+        // Check the answer
+        if (player == 1) {
+            tio.iostream_peer() << xsh << lt << eq << gt;
+        } else {
+            RegAS peer_xsh;
+            RegBS peer_lt, peer_eq, peer_gt;
+            tio.iostream_peer() >> peer_xsh >> peer_lt >> peer_eq >> peer_gt;
+            lt ^= peer_lt;
+            eq ^= peer_eq;
+            gt ^= peer_gt;
+            xsh += peer_xsh;
+            int lti = int(lt.bshare);
+            int eqi = int(eq.bshare);
+            int gti = int(gt.bshare);
+            x = xsh.share();
+            printf(": %d %d %d ", lti, eqi, gti);
+            bool signbit = (x >> 63);
+            if (lti + eqi + gti != 1) {
+                printf("INCONSISTENT");
+                res = 0;
+            } else if (x == 0 && eqi) {
+                printf("=");
+            } else if (!signbit && gti) {
+                printf(">");
+            } else if (signbit && lti) {
+                printf("<");
+            } else {
+                printf("INCORRECT");
+                res = 0;
+            }
+        }
+        printf("\n");
+    }
+    return res;
+}
+
+static int compare_test_target(MPCTIO &tio, yield_t &yield,
+    value_t target, value_t x)
+{
+    int res = 1;
+    res &= compare_test_one(tio, yield, target, x);
+    res &= compare_test_one(tio, yield, target, 0);
+    res &= compare_test_one(tio, yield, target, 1);
+    res &= compare_test_one(tio, yield, target, 15);
+    res &= compare_test_one(tio, yield, target, 16);
+    res &= compare_test_one(tio, yield, target, 17);
+    res &= compare_test_one(tio, yield, target, -1);
+    res &= compare_test_one(tio, yield, target, -15);
+    res &= compare_test_one(tio, yield, target, -16);
+    res &= compare_test_one(tio, yield, target, -17);
+    res &= compare_test_one(tio, yield, target, (value_t(1)<<63));
+    res &= compare_test_one(tio, yield, target, (value_t(1)<<63)+1);
+    res &= compare_test_one(tio, yield, target, (value_t(1)<<63)-1);
+    return res;
+}
+
+static void compare_test(MPCIO &mpcio, yield_t &yield,
+    const PRACOptions &opts, char **args)
+{
+    value_t target, x;
+    arc4random_buf(&target, sizeof(target));
+    arc4random_buf(&x, sizeof(x));
+
+    if (*args) {
+        target = strtoull(*args, NULL, 16);
+        ++args;
+    }
+    if (*args) {
+        x = strtoull(*args, NULL, 16);
+        ++args;
+    }
+
+    int num_threads = opts.num_threads;
+    boost::asio::thread_pool pool(num_threads);
+    for (int thread_num = 0; thread_num < num_threads; ++thread_num) {
+        boost::asio::post(pool, [&mpcio, &yield, thread_num, &target, &x] {
+            MPCTIO tio(mpcio, thread_num);
+            int res = 1;
+            res &= compare_test_target(tio, yield, target, x);
+            res &= compare_test_target(tio, yield, 0, x);
+            res &= compare_test_target(tio, yield, 1, x);
+            res &= compare_test_target(tio, yield, 15, x);
+            res &= compare_test_target(tio, yield, 16, x);
+            res &= compare_test_target(tio, yield, 17, x);
+            res &= compare_test_target(tio, yield, -1, x);
+            res &= compare_test_target(tio, yield, -15, x);
+            res &= compare_test_target(tio, yield, -16, x);
+            res &= compare_test_target(tio, yield, -17, x);
+            res &= compare_test_target(tio, yield, (value_t(1)<<63), x);
+            res &= compare_test_target(tio, yield, (value_t(1)<<63)+1, x);
+            res &= compare_test_target(tio, yield, (value_t(1)<<63)-1, x);
+            tio.send();
+            if (tio.player() == 0) {
+                if (res == 1) {
+                    printf("All tests passed!\n");
+                } else {
+                    printf("TEST FAILURES\n");
+                }
+            }
+        });
+    }
+    pool.join();
+}
+
 void online_main(MPCIO &mpcio, const PRACOptions &opts, char **args)
 {
     // Run everything inside a coroutine so that simple tests don't have
@@ -586,6 +712,9 @@ void online_main(MPCIO &mpcio, const PRACOptions &opts, char **args)
             } else if (!strcmp(*args, "cdpftest")) {
                 ++args;
                 cdpf_test(mpcio, yield, opts, args);
+            } else if (!strcmp(*args, "cmptest")) {
+                ++args;
+                compare_test(mpcio, yield, opts, args);
             } else {
                 std::cerr << "Unknown mode " << *args << "\n";
             }