Ian Goldberg пре 1 година
родитељ
комит
706253a51f
3 измењених фајлова са 84 додато и 10 уклоњено
  1. 46 0
      mpcops.cpp
  2. 11 0
      mpcops.hpp
  3. 27 10
      online.cpp

+ 46 - 0
mpcops.cpp

@@ -135,6 +135,52 @@ void mpc_select(MPCTIO &tio, yield_t &yield,
     z.ashare = (z.ashare + x.ashare) & mask;
 }
 
+// P0 and P1 hold bit shares f0 and f1 of the single bit f, and XOR
+// shares of the values x and y; compute XOR shares of z, where z = x if
+// f=0 and z = y if f=1.  x, y, and z are each at most nbits bits long.
+//
+// Cost:
+// 2 words sent in 1 message
+// consumes 1 SelectTriple
+void mpc_select(MPCTIO &tio, yield_t &yield,
+    RegXS &z, RegBS f, RegXS x, RegXS y,
+    nbits_t nbits)
+{
+    const value_t mask = MASKBITS(nbits);
+    size_t nbytes = BITBYTES(nbits);
+    // Sign-extend f (so 0 -> 0000...0; 1 -> 1111...1)
+    value_t fext = (-value_t(f.bshare)) & mask;
+
+    // Compute XOR shares of f & (x ^ y)
+    auto [X, Y, Z] = tio.valselecttriple(yield);
+
+    bit_t blind_f = f.bshare ^ X;
+    value_t d = (x.xshare ^ y.xshare) & mask;
+    value_t blind_d = (d ^ Y) & mask;
+
+    // Send the blinded values
+    tio.queue_peer(&blind_f, sizeof(blind_f));
+    tio.queue_peer(&blind_d, nbytes);
+
+    yield();
+
+    // Read the peer's values
+    bit_t peer_blind_f = 0;
+    value_t peer_blind_d;
+    tio.recv_peer(&peer_blind_f, sizeof(peer_blind_f));
+    peer_blind_f &= 1;
+    tio.recv_peer(&peer_blind_d, nbytes);
+    peer_blind_d &= mask;
+
+    // Compute our share of f ? x : y = (f * (x ^ y))^x
+    value_t peer_blind_fext = -value_t(peer_blind_f);
+    value_t zshare =
+            (fext & peer_blind_d) ^ (Y & peer_blind_fext) ^
+            (fext & d) ^ (Z ^ x.xshare);
+
+    z.set(zshare & mask);
+}
+
 // P0 and P1 hold bit shares f0 and f1 of the single bit f, and additive
 // shares of the values x and y. Obliviously swap x and y; that is,
 // replace x and y with new additive sharings of x and y respectively

+ 11 - 0
mpcops.hpp

@@ -65,6 +65,17 @@ void mpc_select(MPCTIO &tio, yield_t &yield,
     RegAS &z, RegBS f, RegAS x, RegAS y,
     nbits_t nbits = VALUE_BITS);
 
+// P0 and P1 hold bit shares f0 and f1 of the single bit f, and XOR
+// shares of the values x and y; compute XOR shares of z, where z = x if
+// f=0 and z = y if f=1.  x, y, and z are each at most nbits bits long.
+//
+// Cost:
+// 2 words sent in 1 message
+// consumes 1 SelectTriple
+void mpc_select(MPCTIO &tio, yield_t &yield,
+    RegXS &z, RegBS f, RegXS x, RegXS y,
+    nbits_t nbits = VALUE_BITS);
+
 // P0 and P1 hold bit shares f0 and f1 of the single bit f, and additive
 // shares of the values x and y. Obliviously swap x and y; that is,
 // replace x and y with new additive sharings of x and y respectively

+ 27 - 10
online.cpp

@@ -17,12 +17,14 @@ static void online_test(MPCIO &mpcio,
         nbits = atoi(*args);
     }
 
-    size_t memsize = 9;
+    size_t as_memsize = 9;
+    size_t xs_memsize = 3;
 
     MPCTIO tio(mpcio, 0);
     bool is_server = (mpcio.player == 2);
 
-    RegAS *A = new RegAS[memsize];
+    RegAS *A = new RegAS[as_memsize];
+    RegXS *AX = new RegXS[xs_memsize];
     value_t V;
     RegBS F0, F1;
     RegXS X;
@@ -36,8 +38,11 @@ static void online_test(MPCIO &mpcio,
         A[6].randomize();
         A[7].randomize();
         X.randomize();
+        AX[0].randomize();
+        AX[1].randomize();
         arc4random_buf(&V, sizeof(V));
-        printf("A:\n"); for (size_t i=0; i<memsize; ++i) printf("%3lu: %016lX\n", i, A[i].ashare);
+        printf("A:\n"); for (size_t i=0; i<as_memsize; ++i) printf("%3lu: %016lX\n", i, A[i].ashare);
+        printf("AX:\n"); for (size_t i=0; i<xs_memsize; ++i) printf("%3lu: %016lX\n", i, AX[i].xshare);
         printf("V  : %016lX\n", V);
         printf("F0 : %01X\n", F0.bshare);
         printf("F1 : %01X\n", F1.bshare);
@@ -64,38 +69,49 @@ static void online_test(MPCIO &mpcio,
         [&tio, &A, &X, nbits](yield_t &yield) {
             mpc_xs_to_as(tio, yield, A[8], X, nbits);
         });
+    coroutines.emplace_back(
+        [&tio, &AX, &F0, nbits](yield_t &yield) {
+            mpc_select(tio, yield, AX[2], F0, AX[0], AX[1], nbits);
+        });
     run_coroutines(tio, coroutines);
     if (!is_server) {
         printf("\n");
-        printf("A:\n"); for (size_t i=0; i<memsize; ++i) printf("%3lu: %016lX\n", i, A[i].ashare);
+        printf("A:\n"); for (size_t i=0; i<as_memsize; ++i) printf("%3lu: %016lX\n", i, A[i].ashare);
+        printf("AX:\n"); for (size_t i=0; i<xs_memsize; ++i) printf("%3lu: %016lX\n", i, AX[i].xshare);
     }
 
     // Check the answers
     if (mpcio.player == 1) {
-        tio.queue_peer(A, memsize*sizeof(RegAS));
+        tio.queue_peer(A, as_memsize*sizeof(RegAS));
+        tio.queue_peer(AX, xs_memsize*sizeof(RegXS));
         tio.queue_peer(&V, sizeof(V));
         tio.queue_peer(&F0, sizeof(RegBS));
         tio.queue_peer(&F1, sizeof(RegBS));
         tio.queue_peer(&X, sizeof(RegXS));
         tio.send();
     } else if (mpcio.player == 0) {
-        RegAS *B = new RegAS[memsize];
+        RegAS *B = new RegAS[as_memsize];
+        RegXS *BAX = new RegXS[xs_memsize];
         RegBS BF0, BF1;
         RegXS BX;
         value_t BV;
-        value_t *S = new value_t[memsize];
+        value_t *S = new value_t[as_memsize];
+        value_t *Y = new value_t[xs_memsize];
         bit_t SF0, SF1;
         value_t SX;
-        tio.recv_peer(B, memsize*sizeof(RegAS));
+        tio.recv_peer(B, as_memsize*sizeof(RegAS));
+        tio.recv_peer(BAX, xs_memsize*sizeof(RegXS));
         tio.recv_peer(&BV, sizeof(BV));
         tio.recv_peer(&BF0, sizeof(RegBS));
         tio.recv_peer(&BF1, sizeof(RegBS));
         tio.recv_peer(&BX, sizeof(RegXS));
-        for(size_t i=0; i<memsize; ++i) S[i] = A[i].ashare+B[i].ashare;
+        for(size_t i=0; i<as_memsize; ++i) S[i] = A[i].ashare+B[i].ashare;
+        for(size_t i=0; i<xs_memsize; ++i) Y[i] = AX[i].xshare^BAX[i].xshare;
         SF0 = F0.bshare ^ BF0.bshare;
         SF1 = F1.bshare ^ BF1.bshare;
         SX = X.xshare ^ BX.xshare;
-        printf("S:\n"); for (size_t i=0; i<memsize; ++i) printf("%3lu: %016lX\n", i, S[i]);
+        printf("S:\n"); for (size_t i=0; i<as_memsize; ++i) printf("%3lu: %016lX\n", i, S[i]);
+        printf("Y:\n"); for (size_t i=0; i<xs_memsize; ++i) printf("%3lu: %016lX\n", i, Y[i]);
         printf("SF0: %01X\n", SF0);
         printf("SF1: %01X\n", SF1);
         printf("SX : %016lX\n", SX);
@@ -108,6 +124,7 @@ static void online_test(MPCIO &mpcio,
     }
 
     delete[] A;
+    delete[] AX;
 }
 
 static void lamport_test(MPCIO &mpcio,