Browse Source

Make explicit types for additive-shared, XOR-shared, and bit-shared registers

Ian Goldberg 1 year ago
parent
commit
b691a8a367
4 changed files with 210 additions and 80 deletions
  1. 30 33
      mpcops.cpp
  2. 7 12
      mpcops.hpp
  3. 50 16
      online.cpp
  4. 123 19
      types.hpp

+ 30 - 33
mpcops.cpp

@@ -1,10 +1,5 @@
 #include "mpcops.hpp"
 
-// as_ denotes additive shares
-// xs_ denotes xor shares
-// bs_ denotes a share of a single bit (which is effectively both an xor
-//     share and an additive share mod 2)
-
 // P0 and P1 both hold additive shares of x (shares are x0 and x1) and y
 // (shares are y0 and y1); compute additive shares of z = x*y =
 // (x0+x1)*(y0+y1). x, y, and z are each at most nbits bits long.
@@ -13,14 +8,14 @@
 // 2 words sent in 1 message
 // consumes 1 MultTriple
 void mpc_mul(MPCTIO &tio, yield_t &yield,
-    value_t &as_z, value_t as_x, value_t as_y,
+    RegAS &z, RegAS x, RegAS y,
     nbits_t nbits)
 {
     const value_t mask = MASKBITS(nbits);
-    // Compute as_z to be an additive share of (x0*y1+y0*x1)
-    mpc_cross(tio, yield, as_z, as_x, as_y, nbits);
+    // Compute z to be an additive share of (x0*y1+y0*x1)
+    mpc_cross(tio, yield, z, x, y, nbits);
     // Add x0*y0 (the peer will add x1*y1)
-    as_z = (as_z + as_x * as_y) & mask;
+    z.ashare = (z.ashare + x.ashare * y.ashare) & mask;
 }
 
 // P0 and P1 both hold additive shares of x (shares are x0 and x1) and y
@@ -31,7 +26,7 @@ void mpc_mul(MPCTIO &tio, yield_t &yield,
 // 2 words sent in 1 message
 // consumes 1 MultTriple
 void mpc_cross(MPCTIO &tio, yield_t &yield,
-    value_t &as_z, value_t as_x, value_t as_y,
+    RegAS &z, RegAS x, RegAS y,
     nbits_t nbits)
 {
     const value_t mask = MASKBITS(nbits);
@@ -39,8 +34,8 @@ void mpc_cross(MPCTIO &tio, yield_t &yield,
     auto [X, Y, Z] = tio.triple();
 
     // Send x+X and y+Y
-    value_t blind_x = (as_x + X) & mask;
-    value_t blind_y = (as_y + Y) & mask;
+    value_t blind_x = (x.ashare + X) & mask;
+    value_t blind_y = (y.ashare + Y) & mask;
 
     tio.queue_peer(&blind_x, nbytes);
     tio.queue_peer(&blind_y, nbytes);
@@ -52,7 +47,7 @@ void mpc_cross(MPCTIO &tio, yield_t &yield,
     tio.recv_peer(&peer_blind_x, nbytes);
     tio.recv_peer(&peer_blind_y, nbytes);
 
-    as_z = ((as_x * peer_blind_y) - (Y * peer_blind_x) + Z) & mask;
+    z.ashare = ((x.ashare * peer_blind_y) - (Y * peer_blind_x) + Z) & mask;
 }
 
 // P0 holds the (complete) value x, P1 holds the (complete) value y;
@@ -64,7 +59,7 @@ void mpc_cross(MPCTIO &tio, yield_t &yield,
 // 1 word sent in 1 message
 // consumes 1 HalfTriple
 void mpc_valuemul(MPCTIO &tio, yield_t &yield,
-    value_t &as_z, value_t x,
+    RegAS &z, value_t x,
     nbits_t nbits)
 {
     const value_t mask = MASKBITS(nbits);
@@ -83,9 +78,9 @@ void mpc_valuemul(MPCTIO &tio, yield_t &yield,
     tio.recv_peer(&peer_blind_y, nbytes);
 
     if (tio.player() == 0) {
-        as_z = ((x * peer_blind_y) + Z) & mask;
+        z.ashare = ((x * peer_blind_y) + Z) & mask;
     } else if (tio.player() == 1) {
-        as_z = ((-X * peer_blind_y) + Z) & mask;
+        z.ashare = ((-X * peer_blind_y) + Z) & mask;
     }
 }
 
@@ -98,17 +93,19 @@ void mpc_valuemul(MPCTIO &tio, yield_t &yield,
 // 2 words sent in 1 message
 // consumes 1 MultTriple
 void mpc_flagmult(MPCTIO &tio, yield_t &yield,
-    value_t &as_z, bit_t bs_f, value_t as_y,
+    RegAS &z, RegBS f, RegAS y,
     nbits_t nbits)
 {
     const value_t mask = MASKBITS(nbits);
 
     // Compute additive shares of [(1-2*f0)*y0]*f1 + [(1-2*f1)*y1]*f0
-    value_t bs_fval = value_t(bs_f);
-    mpc_cross(tio, yield, as_z, (1-2*bs_fval)*as_y, bs_fval, nbits);
+    value_t bs_fval = value_t(f.bshare);
+    RegAS fval;
+    fval.ashare = bs_fval;
+    mpc_cross(tio, yield, z, y*(1-2*bs_fval), fval, nbits);
 
     // Add f0*y0 (and the peer will add f1*y1)
-    as_z = (as_z + bs_fval*as_y) & mask;
+    z.ashare = (z.ashare + bs_fval*y.ashare) & mask;
 
     // Now the shares add up to:
     // [(1-2*f0)*y0]*f1 + [(1-2*f1)*y1]*f0 + f0*y0 + f1*y1
@@ -125,14 +122,14 @@ void mpc_flagmult(MPCTIO &tio, yield_t &yield,
 // 2 words sent in 1 message
 // consumes 1 MultTriple
 void mpc_select(MPCTIO &tio, yield_t &yield,
-    value_t &as_z, bit_t bs_f, value_t as_x, value_t as_y,
+    RegAS &z, RegBS f, RegAS x, RegAS y,
     nbits_t nbits)
 {
     const value_t mask = MASKBITS(nbits);
 
     // The desired result is z = x + f * (y-x)
-    mpc_flagmult(tio, yield, as_z, bs_f, as_y-as_x, nbits);
-    as_z = (as_z + as_x) & mask;
+    mpc_flagmult(tio, yield, z, f, y-x, nbits);
+    z.ashare = (z.ashare + x.ashare) & mask;
 }
 
 // P0 and P1 hold bit shares f0 and f1 of the single bit f, and additive
@@ -145,17 +142,17 @@ void mpc_select(MPCTIO &tio, yield_t &yield,
 // 2 words sent in 1 message
 // consumes 1 MultTriple
 void mpc_oswap(MPCTIO &tio, yield_t &yield,
-    value_t &as_x, value_t &as_y, bit_t bs_f,
+    RegAS &x, RegAS &y, RegBS f,
     nbits_t nbits)
 {
     const value_t mask = MASKBITS(nbits);
 
     // Let s = f*(y-x).  Then the desired result is
     // x <- x + s, y <- y - s.
-    value_t as_s;
-    mpc_flagmult(tio, yield, as_s, bs_f, as_y-as_x, nbits);
-    as_x = (as_x + as_s) & mask;
-    as_y = (as_y - as_s) & mask;
+    RegAS s;
+    mpc_flagmult(tio, yield, s, f, y-x, nbits);
+    x.ashare = (x.ashare + s.ashare) & mask;
+    y.ashare = (y.ashare - s.ashare) & mask;
 }
 
 // P0 and P1 hold XOR shares of x. Compute additive shares of the same
@@ -165,7 +162,7 @@ void mpc_oswap(MPCTIO &tio, yield_t &yield,
 // nbits-1 words sent in 1 message
 // consumes nbits-1 HalfTriples
 void mpc_xs_to_as(MPCTIO &tio, yield_t &yield,
-    value_t &as_x, value_t xs_x,
+    RegAS &as_x, RegXS xs_x,
     nbits_t nbits)
 {
     const value_t mask = MASKBITS(nbits);
@@ -188,18 +185,18 @@ void mpc_xs_to_as(MPCTIO &tio, yield_t &yield,
     // message, then yield, so that all of their messages get sent at
     // once, then each will read their results.
 
-    value_t as_bitand[nbits-1];
+    RegAS as_bitand[nbits-1];
     std::vector<coro_t> coroutines;
     for (nbits_t i=0; i<nbits-1; ++i) {
         coroutines.emplace_back(
             [&](yield_t &yield) {
-                mpc_valuemul(tio, yield, as_bitand[i], (xs_x>>i)&1, nbits);
+                mpc_valuemul(tio, yield, as_bitand[i], (xs_x.xshare>>i)&1, nbits);
             });
     }
     run_coroutines(yield, coroutines);
     value_t as_C = 0;
     for (nbits_t i=0; i<nbits-1; ++i) {
-        as_C += (as_bitand[i]<<(i+1));
+        as_C += (as_bitand[i].ashare<<(i+1));
     }
-    as_x = (xs_x - as_C) & mask;
+    as_x.ashare = (xs_x.xshare - as_C) & mask;
 }

+ 7 - 12
mpcops.hpp

@@ -5,11 +5,6 @@
 #include "mpcio.hpp"
 #include "coroutine.hpp"
 
-// as_ denotes additive shares
-// xs_ denotes xor shares
-// bs_ denotes a share of a single bit (which is effectively both an xor
-//     share and an additive share mod 2)
-
 // P0 and P1 both hold additive shares of x (shares are x0 and x1) and y
 // (shares are y0 and y1); compute additive shares of z = x*y =
 // (x0+x1)*(y0+y1). x, y, and z are each at most nbits bits long.
@@ -18,7 +13,7 @@
 // 2 words sent in 1 message
 // consumes 1 MultTriple
 void mpc_mul(MPCTIO &tio, yield_t &yield,
-    value_t &as_z, value_t as_x, value_t as_y,
+    RegAS &z, RegAS x, RegAS y,
     nbits_t nbits = VALUE_BITS);
 
 // P0 and P1 both hold additive shares of x (shares are x0 and x1) and y
@@ -29,7 +24,7 @@ void mpc_mul(MPCTIO &tio, yield_t &yield,
 // 2 words sent in 1 message
 // consumes 1 MultTriple
 void mpc_cross(MPCTIO &tio, yield_t &yield,
-    value_t &as_z, value_t as_x, value_t as_y,
+    RegAS &z, RegAS x, RegAS y,
     nbits_t nbits = VALUE_BITS);
 
 // P0 holds the (complete) value x, P1 holds the (complete) value y;
@@ -41,7 +36,7 @@ void mpc_cross(MPCTIO &tio, yield_t &yield,
 // 1 word sent in 1 message
 // consumes 1 HalfTriple
 void mpc_valuemul(MPCTIO &tio, yield_t &yield,
-    value_t &as_z, value_t x,
+    RegAS &z, value_t x,
     nbits_t nbits = VALUE_BITS);
 
 // P0 and P1 hold bit shares f0 and f1 of the single bit f, and additive
@@ -53,7 +48,7 @@ void mpc_valuemul(MPCTIO &tio, yield_t &yield,
 // 2 words sent in 1 message
 // consumes 1 MultTriple
 void mpc_flagmult(MPCTIO &tio, yield_t &yield,
-    value_t &as_z, bit_t bs_f, value_t as_y,
+    RegAS &z, RegBS f, RegAS y,
     nbits_t nbits = VALUE_BITS);
 
 // P0 and P1 hold bit shares f0 and f1 of the single bit f, and additive
@@ -65,7 +60,7 @@ void mpc_flagmult(MPCTIO &tio, yield_t &yield,
 // 2 words sent in 1 message
 // consumes 1 MultTriple
 void mpc_select(MPCTIO &tio, yield_t &yield,
-    value_t &as_z, bit_t bs_f, value_t as_x, value_t as_y,
+    RegAS &z, RegBS f, RegAS x, RegAS y,
     nbits_t nbits = VALUE_BITS);
 
 // P0 and P1 hold bit shares f0 and f1 of the single bit f, and additive
@@ -78,7 +73,7 @@ void mpc_select(MPCTIO &tio, yield_t &yield,
 // 2 words sent in 1 message
 // consumes 1 MultTriple
 void mpc_oswap(MPCTIO &tio, yield_t &yield,
-    value_t &as_x, value_t &as_y, bit_t bs_f,
+    RegAS &x, RegAS &y, RegBS f,
     nbits_t nbits = VALUE_BITS);
 
 // P0 and P1 hold XOR shares of x. Compute additive shares of the same
@@ -88,7 +83,7 @@ void mpc_oswap(MPCTIO &tio, yield_t &yield,
 // nbits-1 words sent in 1 message
 // consumes nbits-1 HalfTriples
 void mpc_xs_to_as(MPCTIO &tio, yield_t &yield,
-    value_t &as_x, value_t xs_x,
+    RegAS &as_x, RegXS xs_x,
     nbits_t nbits = VALUE_BITS);
 
 #endif

+ 50 - 16
online.cpp

@@ -12,18 +12,31 @@ static void online_test(MPCIO &mpcio, int num_threads, char **args)
         nbits = atoi(*args);
     }
 
-    size_t memsize = 13;
+    size_t memsize = 9;
 
     MPCTIO tio(mpcio, 0);
     bool is_server = (mpcio.player == 2);
 
-    value_t *A = new value_t[memsize];
+    RegAS *A = new RegAS[memsize];
+    value_t V;
+    RegBS F0, F1;
+    RegXS X;
 
     if (!is_server) {
-        arc4random_buf(A, memsize*sizeof(value_t));
-        A[5] &= 1;
-        A[8] &= 1;
-        printf("A:\n"); for (size_t i=0; i<memsize; ++i) printf("%3lu: %016lX\n", i, A[i]);
+        A[0].randomize();
+        A[1].randomize();
+        F0.randomize();
+        A[4].randomize();
+        F1.randomize();
+        A[6].randomize();
+        A[7].randomize();
+        X.randomize();
+        arc4random_buf(&V, sizeof(V));
+        printf("A:\n"); for (size_t i=0; i<memsize; ++i) printf("%3lu: %016lX\n", i, A[i].ashare);
+        printf("V  : %016lX\n", V);
+        printf("F0 : %01X\n", F0.bshare);
+        printf("F1 : %01X\n", F1.bshare);
+        printf("X  : %016lX\n", X.xshare);
     }
     std::vector<coro_t> coroutines;
     coroutines.emplace_back(
@@ -32,38 +45,59 @@ static void online_test(MPCIO &mpcio, int num_threads, char **args)
         });
     coroutines.emplace_back(
         [&](yield_t &yield) {
-            mpc_valuemul(tio, yield, A[4], A[3], nbits);
+            mpc_valuemul(tio, yield, A[3], V, nbits);
         });
     coroutines.emplace_back(
         [&](yield_t &yield) {
-            mpc_flagmult(tio, yield, A[7], A[5], A[6], nbits);
+            mpc_flagmult(tio, yield, A[5], F0, A[4], nbits);
         });
     coroutines.emplace_back(
         [&](yield_t &yield) {
-            mpc_oswap(tio, yield, A[9], A[10], A[8], nbits);
+            mpc_oswap(tio, yield, A[6], A[7], F1, nbits);
         });
     coroutines.emplace_back(
         [&](yield_t &yield) {
-            mpc_xs_to_as(tio, yield, A[12], A[11], nbits);
+            mpc_xs_to_as(tio, yield, A[8], X, nbits);
         });
     run_coroutines(tio, coroutines);
     if (!is_server) {
         printf("\n");
-        printf("A:\n"); for (size_t i=0; i<memsize; ++i) printf("%3lu: %016lX\n", i, A[i]);
+        printf("A:\n"); for (size_t i=0; i<memsize; ++i) printf("%3lu: %016lX\n", i, A[i].ashare);
     }
 
     // Check the answers
     if (mpcio.player == 1) {
-        tio.queue_peer(A, memsize*sizeof(value_t));
+        tio.queue_peer(A, memsize*sizeof(RegAS));
+        tio.queue_peer(&V, sizeof(V));
+        tio.queue_peer(&F0, sizeof(RegBS));
+        tio.queue_peer(&F1, sizeof(RegBS));
+        tio.queue_peer(&X, sizeof(RegXS));
         tio.send();
     } else if (mpcio.player == 0) {
-        value_t *B = new value_t[memsize];
+        RegAS *B = new RegAS[memsize];
+        RegBS BF0, BF1;
+        RegXS BX;
+        value_t BV;
         value_t *S = new value_t[memsize];
-        tio.recv_peer(B, memsize*sizeof(value_t));
-        for(size_t i=0; i<memsize; ++i) S[i] = A[i]+B[i];
+        bit_t SF0, SF1;
+        value_t SX;
+        tio.recv_peer(B, memsize*sizeof(RegAS));
+        tio.recv_peer(&BV, sizeof(BV));
+        tio.recv_peer(&BF0, sizeof(RegBS));
+        tio.recv_peer(&BF1, sizeof(RegBS));
+        tio.recv_peer(&BX, sizeof(RegXS));
+        for(size_t i=0; i<memsize; ++i) S[i] = A[i].ashare+B[i].ashare;
+        SF0 = F0.bshare ^ BF0.bshare;
+        SF1 = F1.bshare ^ BF1.bshare;
+        SX = X.xshare ^ BX.xshare;
         printf("S:\n"); for (size_t i=0; i<memsize; ++i) printf("%3lu: %016lX\n", i, S[i]);
+        printf("SF0: %01X\n", SF0);
+        printf("SF1: %01X\n", SF1);
+        printf("SX : %016lX\n", SX);
         printf("\n%016lx\n", S[0]*S[1]-S[2]);
-        printf("%016lx\n", (A[3]*B[3])-S[4]);
+        printf("%016lx\n", (V*BV)-S[3]);
+        printf("%016lx\n", (SF0*S[4])-S[5]);
+        printf("%016lx\n", S[8]-SX);
         delete[] B;
         delete[] S;
     }

+ 123 - 19
types.hpp

@@ -3,6 +3,7 @@
 
 #include <tuple>
 #include <cstdint>
+#include <bsd/stdlib.h> // arc4random_buf
 
 // The number of bits in an MPC secret-shared memory word
 
@@ -10,7 +11,9 @@
 #define VALUE_BITS 64
 #endif
 
-// Values in MPC secret-shared memory are of this type
+// Values in MPC secret-shared memory are of this type.
+// This is the type of the underlying shared value, not the types of the
+// shares themselves.
 
 #if VALUE_BITS == 64
 using value_t = uint64_t;
@@ -20,6 +23,125 @@ using value_t = uint32_t;
 #error "Unsupported value of VALUE_BITS"
 #endif
 
+// Secret-shared bits are of this type.  Note that it is standards
+// compliant to treat a bool as an unsigned integer type with values 0
+// and 1.
+
+using bit_t = bool;
+
+// Counts of the number of bits in a value are of this type, which must
+// be large enough to store the _value_ VALUE_BITS
+using nbits_t = uint8_t;
+
+// Convert a number of bits to the number of bytes required to store (or
+// more to the point, send) them.
+#define BITBYTES(nbits) (((nbits)+7)>>3)
+
+// A mask of this many bits; the test is to prevent 1<<nbits from
+// overflowing if nbits == VALUE_BITS
+#define MASKBITS(nbits) (((nbits) < VALUE_BITS) ? (value_t(1)<<(nbits))-1 : ~0)
+
+// The type of a register holding an additive share of a value
+struct RegAS {
+    value_t ashare;
+
+    // Set each side's share to a random value nbits bits long
+    inline void randomize(size_t nbits = VALUE_BITS) {
+        value_t mask = MASKBITS(nbits);
+        arc4random_buf(&ashare, sizeof(ashare));
+        ashare &= mask;
+    }
+
+    inline RegAS &operator+=(RegAS &rhs) {
+        this->ashare += rhs.ashare;
+        return *this;
+    }
+
+    inline RegAS operator+(RegAS &rhs) const {
+        RegAS res = *this;
+        res += rhs;
+        return res;
+    }
+
+    inline RegAS &operator-=(RegAS &rhs) {
+        this->ashare -= rhs.ashare;
+        return *this;
+    }
+
+    inline RegAS operator-(RegAS &rhs) const {
+        RegAS res = *this;
+        res -= rhs;
+        return res;
+    }
+
+    inline RegAS &operator*=(value_t rhs) {
+        this->ashare *= rhs;
+        return *this;
+    }
+
+    inline RegAS operator*(value_t rhs) const {
+        RegAS res = *this;
+        res *= rhs;
+        return res;
+    }
+
+    inline RegAS &operator&=(value_t mask) {
+        this->ashare &= mask;
+        return *this;
+    }
+
+    inline RegAS operator&(value_t mask) const {
+        RegAS res = *this;
+        res &= mask;
+        return res;
+    }
+};
+
+// The type of a register holding an XOR share of a value
+struct RegXS {
+    value_t xshare;
+
+    // Set each side's share to a random value nbits bits long
+    inline void randomize(size_t nbits = VALUE_BITS) {
+        value_t mask = MASKBITS(nbits);
+        arc4random_buf(&xshare, sizeof(xshare));
+        xshare &= mask;
+    }
+
+    inline RegXS &operator^=(RegXS &rhs) {
+        this->xshare ^= rhs.xshare;
+        return *this;
+    }
+
+    inline RegXS operator^(RegXS &rhs) const {
+        RegXS res = *this;
+        res ^= rhs;
+        return res;
+    }
+
+    inline RegXS &operator&=(value_t mask) {
+        this->xshare &= mask;
+        return *this;
+    }
+
+    inline RegXS operator&(value_t mask) const {
+        RegXS res = *this;
+        res &= mask;
+        return res;
+    }
+};
+
+// The type of a register holding a bit share
+struct RegBS {
+    bit_t bshare;
+
+    // Set each side's share to a random bit
+    inline void randomize() {
+        arc4random_buf(&bshare, sizeof(bshare));
+        bshare &= 1;
+    }
+};
+
 // The _maximum_ number of bits in an MPC address; the actual size of
 // the memory will typically be set at runtime, but it cannot exceed
 // this value.  It is more efficient (in terms of communication) in some
@@ -43,24 +165,6 @@ using address_t = uint64_t;
 #error "VALUE_BITS must be at least as large as ADDRESS_MAX_BITS"
 #endif
 
-// Secret-shared bits are of this type.  Note that it is standards
-// compliant to treat a bool as an unsigned integer type with values 0
-// and 1.
-
-using bit_t = bool;
-
-// Counts of the number of bits in a value are of this type, which must
-// be large enough to store the _value_ VALUE_BITS
-using nbits_t = uint8_t;
-
-// Convert a number of bits to the number of bytes required to store (or
-// more to the point, send) them.
-#define BITBYTES(nbits) (((nbits)+7)>>3)
-
-// A mask of this many bits; the test is to prevent 1<<nbits from
-// overflowing if nbits == VALUE_BITS
-#define MASKBITS(nbits) (((nbits) < VALUE_BITS) ? (value_t(1)<<(nbits))-1 : ~0)
-
 // A multiplication triple is a triple (X0,Y0,Z0) held by P0 (and
 // correspondingly (X1,Y1,Z1) held by P1), with all values random,
 // but subject to the relation that X0*Y1 + Y0*X1 = Z0+Z1