Browse Source

Abstract verified encryption into a gadget

Ian Goldberg 4 years ago
parent
commit
801df4ad64
1 changed files with 133 additions and 77 deletions
  1. 133 77
      verifenc.cpp

+ 133 - 77
verifenc.cpp

@@ -8,14 +8,139 @@
 using namespace libsnark;
 using namespace std;
 
-int main(int argc, char **argv)
-{
-  enum {
+typedef enum {
     MODE_NONE,
     MODE_PRIV,
     MODE_PUB,
     MODE_CONST
-  } mode = MODE_NONE;
+} Mode;
+
+template<typename FieldT>
+class verified_encryption_gadget : public gadget<FieldT> {
+private:
+  const size_t numbits;
+  FieldT curve_b, Gx, Gy, Hx, Hy;
+  pb_variable<FieldT> r;
+  pb_variable<FieldT> xsquared, ysquared;
+  pb_variable_array<FieldT> kbits, rbits;
+  pb_variable<FieldT> elgx, elgy;
+  pb_linear_combination<FieldT> x;
+  vector<packing_gadget<FieldT> > packers;
+  vector<ec_constant_scalarmul_vec_gadget<FieldT> > constmuls;
+  vector<ec_scalarmul_vec_gadget<FieldT> > muls;
+  vector<ec_add_gadget<FieldT> > adders;
+
+public:
+  const Mode mode;
+  const pb_variable<FieldT> C1x, C1y, C2x, C2y, Kx, Ky;
+  const pb_variable<FieldT> Px, Py, k, s, y;
+  const pb_variable_array<FieldT> Ptable;
+
+  verified_encryption_gadget(protoboard<FieldT> &pb,
+              Mode mode,
+              const pb_variable<FieldT> &C1x,
+              const pb_variable<FieldT> &C1y,
+              const pb_variable<FieldT> &C2x,
+              const pb_variable<FieldT> &C2y,
+              const pb_variable<FieldT> &Kx,
+              const pb_variable<FieldT> &Ky,
+              const pb_variable<FieldT> &Px,
+              const pb_variable<FieldT> &Py,
+              const pb_variable_array<FieldT> &Ptable,
+              const pb_variable<FieldT> &k,
+              const pb_variable<FieldT> &s,
+              const pb_variable<FieldT> &y) :
+    gadget<FieldT>(pb, "verified_encryption_gadget"),
+    // Curve parameters and generators
+    numbits(FieldT::num_bits),
+    curve_b("7950939520449436327800262930799465135910802758673292356620796789196167463969"),
+    Gx(0), Gy("11977228949870389393715360594190192321220966033310912010610740966317727761886"),
+    Hx(1), Hy("21803877843449984883423225223478944275188924769286999517937427649571474907279"),
+    mode(mode), C1x(C1x), C1y(C1y), C2x(C2x), C2y(C2y),
+    Kx(Kx), Ky(Ky), Px(Px), Py(Py), Ptable(Ptable),
+    k(k), s(s), y(y)
+  {
+    r.allocate(pb, "r");
+    xsquared.allocate(pb, "xsquared");
+    ysquared.allocate(pb, "ysquared");
+    kbits.allocate(pb, numbits-8, "kbits");
+    rbits.allocate(pb, numbits, "rbits");
+
+    // The unpacking gadgets to turn k and r into bits
+    packers.emplace_back(pb, kbits, k);
+    packers.emplace_back(pb, rbits, r);
+
+    // The El Gamal first component r*G
+    constmuls.emplace_back(pb, C1x, C1y, rbits, Gx, Gy);
+
+    // The El Gamal intermediate value r*P
+    elgx.allocate(pb, "elgx");
+    elgy.allocate(pb, "elgy");
+    if (mode == MODE_CONST) {
+        constmuls.emplace_back(pb, elgx, elgy, rbits, Hx, Hy);
+    } else {
+        muls.emplace_back(pb, elgx, elgy, rbits, Px, Py, Ptable, mode == MODE_PRIV, true);
+    }
+
+    // The El Gamal second component r*P + M
+    x.assign(pb, k * 256 + s);
+    adders.emplace_back(pb, C2x, C2y, elgx, elgy, x, y);
+
+    // The generated public key k*G
+    constmuls.emplace_back(pb, Kx, Ky, kbits, Gx, Gy);
+  }
+
+  void generate_r1cs_constraints()
+  {
+    // Prove (256*k+s,y) is on the curve
+    this->pb.add_r1cs_constraint(r1cs_constraint<FieldT>(y, y, ysquared));
+    this->pb.add_r1cs_constraint(r1cs_constraint<FieldT>(k * 256 + s, k * 256 + s, xsquared));
+    this->pb.add_r1cs_constraint(r1cs_constraint<FieldT>(xsquared - 3, k * 256 + s, ysquared - curve_b));
+
+    for (auto&& gadget : packers) {
+        gadget.generate_r1cs_constraints(true);
+    }
+
+    for (auto&& gadget : constmuls) {
+        gadget.generate_r1cs_constraints();
+    }
+
+    for (auto&& gadget : muls) {
+        gadget.generate_r1cs_constraints();
+    }
+
+    for (auto&& gadget : adders) {
+        gadget.generate_r1cs_constraints();
+    }
+  }
+
+  void generate_r1cs_witness()
+  {
+    this->pb.val(r) = FieldT::random_element();
+    this->pb.val(xsquared) = (this->pb.val(k) * 256 + this->pb.val(s)).squared();
+    this->pb.val(ysquared) = this->pb.val(y).squared();
+    x.evaluate(this->pb);
+    for (auto&& gadget : packers) {
+        gadget.generate_r1cs_witness_from_packed();
+    }
+
+    for (auto&& gadget : constmuls) {
+        gadget.generate_r1cs_witness();
+    }
+
+    for (auto&& gadget : muls) {
+        gadget.generate_r1cs_witness();
+    }
+
+    for (auto&& gadget : adders) {
+        gadget.generate_r1cs_witness();
+    }
+  }
+};
+
+int main(int argc, char **argv)
+{
+  Mode mode = MODE_NONE;
 
   if (argc == 2) {
     if (!strcmp(argv[1], "priv")) {
@@ -53,11 +178,10 @@ int main(int argc, char **argv)
   pb_variable_array<FieldT> Ptable;
   pb_variable<FieldT> k, s, y, r;
 
+  const size_t numbits = FieldT::num_bits;
 
   // Allocate variables
 
-  size_t numbits = FieldT::num_bits;
-
   // Public outputs:
 
   // El Gamal encryption of k under public key P (or H if MODE_CONST)
@@ -89,8 +213,6 @@ int main(int argc, char **argv)
   // s and y are such that M = (256*k+s,y) is a point on the curve
   s.allocate(pb, "s");
   y.allocate(pb, "y");
-  // r is the randomness for the El Gamal encryption
-  r.allocate(pb, "r");
 
   // This sets up the protoboard variables so that the first n of them
   // represent the public input and the rest is private input
@@ -104,56 +226,8 @@ int main(int argc, char **argv)
   }
 
   // Initialize the gadgets
-
-  // Curve parameters and generators
-  FieldT curve_b("7950939520449436327800262930799465135910802758673292356620796789196167463969");
-  FieldT Gx(0), Gy("11977228949870389393715360594190192321220966033310912010610740966317727761886");
-  FieldT Hx(1), Hy("21803877843449984883423225223478944275188924769286999517937427649571474907279");
-
-  // Prove (256*k+s,y) is on the curve
-  pb_variable<FieldT> xsquared, ysquared;
-  xsquared.allocate(pb, "xsquared");
-  ysquared.allocate(pb, "ysquared");
-  pb.add_r1cs_constraint(r1cs_constraint<FieldT>(y, y, ysquared));
-  pb.add_r1cs_constraint(r1cs_constraint<FieldT>(k * 256 + s, k * 256 + s, xsquared));
-  pb.add_r1cs_constraint(r1cs_constraint<FieldT>(xsquared - 3, k * 256 + s, ysquared - curve_b));
-
-  // The unpacking gadgets to turn k and r into bits
-  pb_variable_array<FieldT> kbits, rbits;
-  kbits.allocate(pb, numbits-8, "kbits");
-  rbits.allocate(pb, numbits, "rbits");
-  packing_gadget<FieldT> kpacker(pb, kbits, k);
-  packing_gadget<FieldT> rpacker(pb, rbits, r);
-  kpacker.generate_r1cs_constraints(true);
-  rpacker.generate_r1cs_constraints(true);
-
-  // The El Gamal first component r*G
-  ec_constant_scalarmul_vec_gadget<FieldT> C1gadget(pb, C1x, C1y, rbits, Gx, Gy);
-  C1gadget.generate_r1cs_constraints();
-
-  // The El Gamal intermediate value r*P
-  pb_variable<FieldT> elgx, elgy;
-  elgx.allocate(pb, "elgx");
-  elgy.allocate(pb, "elgy");
-
-  gadget<FieldT> *ElGgadgetp = NULL;
-  if (mode == MODE_CONST) {
-      ElGgadgetp = new ec_constant_scalarmul_vec_gadget<FieldT> (pb, elgx, elgy, rbits, Hx, Hy);
-      (static_cast<ec_constant_scalarmul_vec_gadget<FieldT>*>(ElGgadgetp))->generate_r1cs_constraints();
-  } else {
-      ElGgadgetp = new ec_scalarmul_vec_gadget<FieldT> (pb, elgx, elgy, rbits, Px, Py, Ptable, mode == MODE_PRIV, true);
-      (static_cast<ec_scalarmul_vec_gadget<FieldT>*>(ElGgadgetp))->generate_r1cs_constraints();
-  }
-
-  // The El Gamal second component r*P + M
-  pb_linear_combination<FieldT> x;
-  x.assign(pb, k * 256 + s);
-  ec_add_gadget<FieldT> ElGfinal(pb, C2x, C2y, elgx, elgy, x, y);
-  ElGfinal.generate_r1cs_constraints();
-
-  // The generated public key k*G
-  ec_constant_scalarmul_vec_gadget<FieldT> Kgadget(pb, Kx, Ky, kbits, Gx, Gy);
-  Kgadget.generate_r1cs_constraints();
+  verified_encryption_gadget<FieldT> venc(pb, mode, C1x, C1y, C2x, C2y, Kx, Ky, Px, Py, Ptable, k, s, y);
+  venc.generate_r1cs_constraints();
 
   const r1cs_constraint_system<FieldT> constraint_system = pb.get_constraint_system();
 
@@ -171,26 +245,8 @@ int main(int argc, char **argv)
   pb.val(k) = FieldT("31329510635628557928212225120518124937732397714111203844965919301557399521");
   pb.val(s) = FieldT(1);
   pb.val(y) = FieldT("4364798287654239504994818950156019747851405522689486598132350453516910863367");
-  pb.val(r) = FieldT::random_element();
-
-  pb.val(xsquared) = (pb.val(k) * 256 + pb.val(s)).squared();
-  pb.val(ysquared) = pb.val(y).squared();
-
-  kpacker.generate_r1cs_witness_from_packed();
-  rpacker.generate_r1cs_witness_from_packed();
-
-  C1gadget.generate_r1cs_witness();
-  if (mode == MODE_CONST) {
-      (static_cast<ec_constant_scalarmul_vec_gadget<FieldT>*>(ElGgadgetp))->generate_r1cs_witness();
-  } else {
-      (static_cast<ec_scalarmul_vec_gadget<FieldT>*>(ElGgadgetp))->generate_r1cs_witness();
-  }
-  delete ElGgadgetp;
-
-  x.evaluate(pb);
-  ElGfinal.generate_r1cs_witness();
 
-  Kgadget.generate_r1cs_witness();
+  venc.generate_r1cs_witness();
 
   const r1cs_gg_ppzksnark_proof<default_r1cs_gg_ppzksnark_pp> proof = r1cs_gg_ppzksnark_prover<default_r1cs_gg_ppzksnark_pp>(keypair.pk, pb.primary_input(), pb.auxiliary_input());