Browse Source

Constructing DPFs: the flag correction bits

Ian Goldberg 1 year ago
parent
commit
cfb460526a
3 changed files with 88 additions and 34 deletions
  1. 1 1
      preproc.cpp
  2. 85 31
      rdpf.cpp
  3. 2 2
      types.hpp

+ 1 - 1
preproc.cpp

@@ -82,7 +82,7 @@ void preprocessing_comp(MPCIO &mpcio, int num_threads, char **args)
                         coroutines.emplace_back(
                             [&](yield_t &yield) {
                                 RegXS ri;
-                                ri.randomize();
+                                ri.randomize(type);
                                 RDPF rdpf(tio, yield, ri, type);
                             });
                     }

+ 85 - 31
rdpf.cpp

@@ -6,6 +6,7 @@
 #include "aes.hpp"
 #include "prg.hpp"
 
+#ifdef DPF_DEBUG
 static void dump_node(DPFnode node, const char *label = NULL)
 {
     if (label) printf("%s: ", label);
@@ -20,6 +21,7 @@ static void dump_level(DPFnode *nodes, size_t num, const char *label = NULL)
     }
     printf("\n");
 }
+#endif
 
 // Construct a DPF with the given (XOR-shared) target location, and
 // of the given depth, to be used for random-access memory reads and
@@ -45,10 +47,15 @@ RDPF::RDPF(MPCTIO &tio, yield_t &yield,
     nextlevel[0] = seed;
 
     // Construct each intermediate level
-    while(level < depth - 1) {
+    while(level < depth) {
         delete[] curlevel;
         curlevel = nextlevel;
-        nextlevel = new DPFnode[1<<(level+1)];
+        // We don't need to store the last level
+        if (level < depth-1) {
+            nextlevel = new DPFnode[1<<(level+1)];
+        } else {
+            nextlevel = NULL;
+        }
         // Invariant: curlevel has 2^level elements; nextlevel has
         // 2^{level+1} elements
 
@@ -58,46 +65,93 @@ RDPF::RDPF(MPCTIO &tio, yield_t &yield,
         size_t curlevel_size = (size_t(1)<<level);
         DPFnode L = _mm_setzero_si128();
         DPFnode R = _mm_setzero_si128();
+        // The server doesn't need to do this computation, but it does
+        // need to execute mpc_reconstruct_choice so that it sends
+        // the AndTriples at the appropriate time.
         if (player < 2) {
             for(size_t i=0;i<curlevel_size;++i) {
-                prgboth(nextlevel[2*i], nextlevel[2*i+1], curlevel[i], aesops);
-                L = _mm_xor_si128(L, nextlevel[2*i]);
-                R = _mm_xor_si128(R, nextlevel[2*i+1]);
+                DPFnode lchild, rchild;
+                prgboth(lchild, rchild, curlevel[i], aesops);
+                L = _mm_xor_si128(L, lchild);
+                R = _mm_xor_si128(R, rchild);
+                if (nextlevel) {
+                    nextlevel[2*i] = lchild;
+                    nextlevel[2*i+1] = rchild;
+                }
             }
         }
+        // If we're going left (bs_choice = 0), we want the correction
+        // work to be the XOR of our right side and our peer's right
+        // side; if bs_choice = 1, it should be the XOR or our left side
+        // and our peer's left side.
+
+        // We have to ensure that the flag bits (the lsb) of the side
+        // that will end up the same be of course the same, but also
+        // that the flag bits (the lsb) of the side that will end up
+        // different _must_ be different.  That is, it's not enough for
+        // the nodes of the child selected by choice to be different as
+        // 128-bit values; they also have to be different in their lsb.
+
+        // Note that the XOR of our left and right child before and
+        // after applying the correction word won't change, since the
+        // correction word is applied to either both children or
+        // neither, depending on the value of the parent's flag. So in
+        // particular, the XOR of the flag bits won't change, and if our
+        // children's flag's XOR equals our peer's children's flag's
+        // XOR, then we won't have different flag bits even for the
+        // children that have different 128-bit values.
+
+        // So we compute our_parity = lsb(L^R)^player, and we XOR that
+        // into the R value in the correction word computation.  At the
+        // same time, we exchange these parity values to compute the
+        // combined parity, which we store in the DPF.  Then when the
+        // DPF is evaluated, if the parent's flag is set, not only apply
+        // the correction work to both children, but also apply the
+        // (combined) parity bit to just the right child.  Then for
+        // unequal nodes (where the flag bit is different), exactly one
+        // of the four children (two for P0 and two for P1) will have
+        // the parity bit applied, which will set the XOR of the lsb of
+        // those four nodes to just L0^R0^L1^R1^our_parity^peer_parity
+        // = 1 because everything cancels out except player (for which
+        // one player is 0 and the other is 1).
+
+        bool our_parity_bit = get_lsb(_mm_xor_si128(L,R)) ^ !!player;
+        DPFnode our_parity = lsb128_mask[our_parity_bit];
+
         DPFnode CW;
-        mpc_reconstruct_choice(tio, yield, CW, bs_choice, R, L);
+        bool peer_parity_bit;
+        // Exchange the parities and do mpc_reconstruct_choice at the
+        // same time (bundled into the same rounds)
+        std::vector<coro_t> coroutines;
+        coroutines.emplace_back(
+            [&](yield_t &yield) {
+                tio.queue_peer(&our_parity_bit, 1);
+                yield();
+                tio.recv_peer(&peer_parity_bit, 1);
+            });
+        coroutines.emplace_back(
+            [&](yield_t &yield) {
+                mpc_reconstruct_choice(tio, yield, CW, bs_choice,
+                    _mm_xor_si128(R,our_parity), L);
+            });
+        run_coroutines(yield, coroutines);
+        bool parity_bit = our_parity_bit ^ peer_parity_bit;
+        cfbits |= (size_t(parity_bit)<<level);
+        DPFnode CWR = _mm_xor_si128(CW,lsb128_mask[parity_bit]);
         if (player < 2) {
-            for(size_t i=0;i<curlevel_size;++i) {
-                bool flag = get_lsb(curlevel[i]);
-                nextlevel[2*i] = xor_if(nextlevel[2*i], CW, flag);
-                nextlevel[2*i+1] = xor_if(nextlevel[2*i+1], CW, flag);
+            if (nextlevel) {
+                for(size_t i=0;i<curlevel_size;++i) {
+                    bool flag = get_lsb(curlevel[i]);
+                    nextlevel[2*i] = xor_if(nextlevel[2*i], CW, flag);
+                    nextlevel[2*i+1] = xor_if(nextlevel[2*i+1], CWR, flag);
+                }
             }
-            printf("%d\n", bs_choice.bshare);
-            dump_level(nextlevel, curlevel_size<<1);
             cw.push_back(CW);
         }
 
         ++level;
     }
 
-    // We don't need to store the last level
-
-    AESkey prgkey;
-    __m128i key = _mm_set_epi64x(314159265, 271828182);
-    AES_128_Key_Expansion(prgkey, key);
-    __m128i left, right;
-    AES_ECB_encrypt(left, set_lsb(seed, 0), prgkey, aesops);
-    AES_ECB_encrypt(right, set_lsb(seed, 1), prgkey, aesops);
-
-    __m128i nleft, nright, oleft, oright;
-    prg(nleft, seed, 0, aesops);
-    prg(nright, seed, 1, aesops);
-    prgboth(oleft, oright, seed, aesops);
-    printf("left : "); for(int i=0;i<16;++i) { printf("%02x", ((unsigned char *)&left)[15-i]); } printf("\n");
-    printf("nleft: "); for(int i=0;i<16;++i) { printf("%02x", ((unsigned char *)&nleft)[15-i]); } printf("\n");
-    printf("oleft: "); for(int i=0;i<16;++i) { printf("%02x", ((unsigned char *)&oleft)[15-i]); } printf("\n");
-    printf("rght : "); for(int i=0;i<16;++i) { printf("%02x", ((unsigned char *)&right)[15-i]); } printf("\n");
-    printf("nrght: "); for(int i=0;i<16;++i) { printf("%02x", ((unsigned char *)&nright)[15-i]); } printf("\n");
-    printf("orght: "); for(int i=0;i<16;++i) { printf("%02x", ((unsigned char *)&oright)[15-i]); } printf("\n");
+    delete[] curlevel;
+    delete[] nextlevel;
 }

+ 2 - 2
types.hpp

@@ -143,9 +143,9 @@ struct RegXS {
     }
 
     // Extract a bit share of bit bitnum of the XOR-shared register
-    inline RegBS bit(bit_t bitnum) const {
+    inline RegBS bit(nbits_t bitnum) const {
         RegBS bs;
-        bs.bshare = !!(xshare & (size_t(1)<<bitnum));
+        bs.bshare = !!(xshare & (value_t(1)<<bitnum));
         return bs;
     }
 };