Browse Source

Prefetch nodeselecttriples

We prefetch all the nodeselecttriples we'll need for RDPF construction
in a single round, rather than creating them on the fly, saving one
round per level.
Ian Goldberg 1 year ago
parent
commit
2b0fc57cfd
3 changed files with 70 additions and 29 deletions
  1. 57 29
      mpcio.cpp
  2. 10 0
      mpcio.hpp
  3. 3 0
      rdpf.tcc

+ 57 - 29
mpcio.cpp

@@ -460,7 +460,8 @@ MPCTIO::MPCTIO(MPCIO &mpcio, int thread_num, int num_threads) :
 #ifdef VERBOSE_COMMS
         round_num(0),
 #endif
-        last_andtriple_bits_remaining(0)
+        last_andtriple_bits_remaining(0),
+        remaining_nodesselecttriples(0)
 {
     if (mpcio.player < 2) {
         MPCPeerIO &mpcpio = static_cast<MPCPeerIO&>(mpcio);
@@ -726,44 +727,71 @@ MultTriple MPCTIO::andtriple(yield_t &yield)
     return val;
 }
 
-SelectTriple<DPFnode> MPCTIO::nodeselecttriple(yield_t &yield)
+void MPCTIO::request_nodeselecttriples(yield_t &yield, size_t num)
 {
-    SelectTriple<DPFnode> val;
     if (mpcio.player < 2) {
         MPCPeerIO &mpcpio = static_cast<MPCPeerIO&>(mpcio);
         if (mpcpio.mode != MODE_ONLINE) {
-            uint8_t Xbyte;
             yield();
-            recv_server(&Xbyte, sizeof(Xbyte));
-            val.X = Xbyte & 1;
-            recv_server(&val.Y, sizeof(val.Y));
-            recv_server(&val.Z, sizeof(val.Z));
+            for (size_t i=0; i<num; ++i) {
+                SelectTriple<DPFnode> v;
+                uint8_t Xbyte;
+                recv_server(&Xbyte, sizeof(Xbyte));
+                v.X = Xbyte & 1;
+                recv_server(&v.Y, sizeof(v.Y));
+                recv_server(&v.Z, sizeof(v.Z));
+                queued_nodeselecttriples.push_back(v);
+            }
+            remaining_nodesselecttriples += num;
         } else {
             std::cerr << "Attempted to read SelectTriple<DPFnode> in online phase\n";
         }
     } else if (mpcio.mode != MODE_ONLINE) {
-        // Create triples (X0,Y0,Z0),(X1,Y1,Z1) such that
-        // (X0*Y1 ^ Y0*X1) = (Z0^Z1)
-        bit_t X0, X1;
-        DPFnode Y0, Z0, Y1, Z1;
-        X0 = arc4random() & 1;
-        arc4random_buf(&Y0, sizeof(Y0));
-        arc4random_buf(&Z0, sizeof(Z0));
-        X1 = arc4random() & 1;
-        arc4random_buf(&Y1, sizeof(Y1));
-        DPFnode X0ext, X1ext;
-        // Sign-extend X0 and X1 (so that 0 -> 0000...0 and
-        // 1 -> 1111...1)
-        X0ext = if128_mask[X0];
-        X1ext = if128_mask[X1];
-        Z1 = ((X0ext & Y1) ^ (X1ext & Y0)) ^ Z0;
-        queue_p0(&X0, sizeof(X0));
-        queue_p0(&Y0, sizeof(Y0));
-        queue_p0(&Z0, sizeof(Z0));
-        queue_p1(&X1, sizeof(X1));
-        queue_p1(&Y1, sizeof(Y1));
-        queue_p1(&Z1, sizeof(Z1));
+        for (size_t i=0; i<num; ++i) {
+            // Create triples (X0,Y0,Z0),(X1,Y1,Z1) such that
+            // (X0*Y1 ^ Y0*X1) = (Z0^Z1)
+            bit_t X0, X1;
+            DPFnode Y0, Z0, Y1, Z1;
+            X0 = arc4random() & 1;
+            arc4random_buf(&Y0, sizeof(Y0));
+            arc4random_buf(&Z0, sizeof(Z0));
+            X1 = arc4random() & 1;
+            arc4random_buf(&Y1, sizeof(Y1));
+            DPFnode X0ext, X1ext;
+            // Sign-extend X0 and X1 (so that 0 -> 0000...0 and
+            // 1 -> 1111...1)
+            X0ext = if128_mask[X0];
+            X1ext = if128_mask[X1];
+            Z1 = ((X0ext & Y1) ^ (X1ext & Y0)) ^ Z0;
+            queue_p0(&X0, sizeof(X0));
+            queue_p0(&Y0, sizeof(Y0));
+            queue_p0(&Z0, sizeof(Z0));
+            queue_p1(&X1, sizeof(X1));
+            queue_p1(&Y1, sizeof(Y1));
+            queue_p1(&Z1, sizeof(Z1));
+        }
         yield();
+        remaining_nodesselecttriples += num;
+    }
+}
+
+SelectTriple<DPFnode> MPCTIO::nodeselecttriple(yield_t &yield)
+{
+    SelectTriple<DPFnode> val;
+    if (remaining_nodesselecttriples == 0) {
+        request_nodeselecttriples(yield, 1);
+    }
+    if (mpcio.player < 2) {
+        MPCPeerIO &mpcpio = static_cast<MPCPeerIO&>(mpcio);
+        if (mpcpio.mode != MODE_ONLINE) {
+            val = queued_nodeselecttriples.front();
+            queued_nodeselecttriples.pop_front();
+            --remaining_nodesselecttriples;
+        } else {
+            std::cerr << "Attempted to read SelectTriple<DPFnode> in online phase\n";
+        }
+    } else if (mpcio.mode != MODE_ONLINE) {
+        --remaining_nodesselecttriples;
     }
     return val;
 }

+ 10 - 0
mpcio.hpp

@@ -353,6 +353,15 @@ class MPCTIO {
     AndTriple last_andtriple;
     nbits_t last_andtriple_bits_remaining;
 
+    // We allow for prefetching of SelectTriple<DPFnode>s to save one
+    // network round per level when constructing RDPFs
+    std::deque<SelectTriple<DPFnode>> queued_nodeselecttriples;
+    // For P0 and P1, it should always be the case that
+    // remaining_nodesselecttriples equals
+    // queued_nodeselecttriples.size().  P2 does not store anything in
+    // queued_nodeselecttriples, however.
+    size_t remaining_nodesselecttriples;
+
 public:
     MPCTIO(MPCIO &mpcio, int thread_num, int num_threads = 1);
 
@@ -425,6 +434,7 @@ public:
     MultTriple multtriple(yield_t &yield);
     HalfTriple halftriple(yield_t &yield, bool tally=true);
     AndTriple andtriple(yield_t &yield);
+    void request_nodeselecttriples(yield_t &yield, size_t num);
     SelectTriple<DPFnode> nodeselecttriple(yield_t &yield);
     SelectTriple<value_t> valselecttriple(yield_t &yield);
     SelectTriple<bit_t> bitselecttriple(yield_t &yield);

+ 3 - 0
rdpf.tcc

@@ -857,6 +857,9 @@ RDPF<WIDTH>::RDPF(MPCTIO &tio, yield_t &yield,
 
     li.resize(incremental ? depth : 1);
 
+    // Prefetch the right number of nodeselecttriples
+    tio.request_nodeselecttriples(yield, incremental ? 2*depth-1 : depth);
+
     // Construct each intermediate level
     while(level < depth) {
         LeafNode *leaflevel = NULL;