verifenc.cpp 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397
  1. #include <stdlib.h>
  2. #include <iostream>
  3. #include <fstream>
  4. #include "ecgadget.hpp"
  5. #include "scalarmul.hpp"
  6. using namespace libsnark;
  7. using namespace std;
  8. typedef enum {
  9. MODE_NONE,
  10. MODE_PRIV,
  11. MODE_PUB,
  12. MODE_CONST
  13. } Mode;
  14. // If a is a quadratic residue, set sqrt_a to one of its square roots
  15. // (-sqrt_a will be the other) and return true. Otherwise return false.
  16. template<typename FieldT>
  17. bool sqrt_if_possible(FieldT &sqrt_a, const FieldT &a)
  18. {
  19. // A modification of the Tonelli-Shanks implementation from libff
  20. // to catch the case when you encounter a nonresidue
  21. const FieldT one = FieldT::one();
  22. size_t v = FieldT::s;
  23. FieldT z = FieldT::nqr_to_t;
  24. FieldT w = a^FieldT::t_minus_1_over_2;
  25. FieldT x = a * w;
  26. FieldT b = x * w; // b = a^t
  27. while (b != one)
  28. {
  29. size_t m = 0;
  30. FieldT b2m = b;
  31. while (b2m != one)
  32. {
  33. /* invariant: b2m = b^(2^m) after entering this loop */
  34. b2m = b2m.squared();
  35. m += 1;
  36. }
  37. if (m == v) {
  38. // Not a quadratic residue
  39. return false;
  40. }
  41. int j = v-m-1;
  42. w = z;
  43. while (j > 0)
  44. {
  45. w = w.squared();
  46. --j;
  47. } // w = z^2^(v-m-1)
  48. z = w.squared();
  49. b = b * z;
  50. x = x * w;
  51. v = m;
  52. }
  53. sqrt_a = x;
  54. return true;
  55. }
  56. template<typename FieldT>
  57. class verified_encryption_gadget : public gadget<FieldT> {
  58. private:
  59. const size_t numbits;
  60. FieldT curve_b, Gx, Gy, Hx, Hy;
  61. pb_variable<FieldT> r;
  62. pb_variable<FieldT> xsquared, ysquared;
  63. pb_variable_array<FieldT> kbits, rbits;
  64. pb_variable<FieldT> elgx, elgy;
  65. pb_linear_combination<FieldT> x;
  66. pb_variable<FieldT> s, y;
  67. vector<packing_gadget<FieldT> > packers;
  68. vector<ec_constant_scalarmul_vec_gadget<FieldT> > constmuls;
  69. vector<ec_scalarmul_vec_gadget<FieldT> > muls;
  70. vector<ec_add_gadget<FieldT> > adders;
  71. public:
  72. const Mode mode;
  73. const pb_variable<FieldT> C1x, C1y, C2x, C2y, Kx, Ky;
  74. const pb_variable<FieldT> Px, Py;
  75. const pb_variable_array<FieldT> Ptable;
  76. const pb_variable<FieldT> k;
  77. verified_encryption_gadget(protoboard<FieldT> &pb,
  78. Mode mode,
  79. const pb_variable<FieldT> &C1x,
  80. const pb_variable<FieldT> &C1y,
  81. const pb_variable<FieldT> &C2x,
  82. const pb_variable<FieldT> &C2y,
  83. const pb_variable<FieldT> &Kx,
  84. const pb_variable<FieldT> &Ky,
  85. const pb_variable<FieldT> &Px,
  86. const pb_variable<FieldT> &Py,
  87. const pb_variable_array<FieldT> &Ptable,
  88. const pb_variable<FieldT> &k) :
  89. gadget<FieldT>(pb, "verified_encryption_gadget"),
  90. // Curve parameters and generators
  91. numbits(FieldT::num_bits),
  92. curve_b("7950939520449436327800262930799465135910802758673292356620796789196167463969"),
  93. Gx(0), Gy("11977228949870389393715360594190192321220966033310912010610740966317727761886"),
  94. Hx(1), Hy("21803877843449984883423225223478944275188924769286999517937427649571474907279"),
  95. mode(mode), C1x(C1x), C1y(C1y), C2x(C2x), C2y(C2y),
  96. Kx(Kx), Ky(Ky), Px(Px), Py(Py), Ptable(Ptable), k(k)
  97. {
  98. s.allocate(pb, "s");
  99. y.allocate(pb, "y");
  100. r.allocate(pb, "r");
  101. xsquared.allocate(pb, "xsquared");
  102. ysquared.allocate(pb, "ysquared");
  103. kbits.allocate(pb, numbits-8, "kbits");
  104. rbits.allocate(pb, numbits, "rbits");
  105. // The unpacking gadgets to turn k and r into bits
  106. packers.emplace_back(pb, kbits, k);
  107. packers.emplace_back(pb, rbits, r);
  108. // The El Gamal first component r*G
  109. constmuls.emplace_back(pb, C1x, C1y, rbits, Gx, Gy);
  110. // The El Gamal intermediate value r*P
  111. elgx.allocate(pb, "elgx");
  112. elgy.allocate(pb, "elgy");
  113. if (mode == MODE_CONST) {
  114. constmuls.emplace_back(pb, elgx, elgy, rbits, Hx, Hy);
  115. } else {
  116. muls.emplace_back(pb, elgx, elgy, rbits, Px, Py, Ptable, mode == MODE_PRIV, true);
  117. }
  118. // The El Gamal second component r*P + M
  119. x.assign(pb, k * 256 + s);
  120. adders.emplace_back(pb, C2x, C2y, elgx, elgy, x, y);
  121. // The generated public key k*G
  122. constmuls.emplace_back(pb, Kx, Ky, kbits, Gx, Gy);
  123. }
  124. void generate_r1cs_constraints()
  125. {
  126. // Prove (256*k+s,y) is on the curve
  127. this->pb.add_r1cs_constraint(r1cs_constraint<FieldT>(y, y, ysquared));
  128. this->pb.add_r1cs_constraint(r1cs_constraint<FieldT>(k * 256 + s, k * 256 + s, xsquared));
  129. this->pb.add_r1cs_constraint(r1cs_constraint<FieldT>(xsquared - 3, k * 256 + s, ysquared - curve_b));
  130. for (auto&& gadget : packers) {
  131. gadget.generate_r1cs_constraints(true);
  132. }
  133. for (auto&& gadget : constmuls) {
  134. gadget.generate_r1cs_constraints();
  135. }
  136. for (auto&& gadget : muls) {
  137. gadget.generate_r1cs_constraints();
  138. }
  139. for (auto&& gadget : adders) {
  140. gadget.generate_r1cs_constraints();
  141. }
  142. }
  143. void find_s_y(const FieldT &kval, FieldT &sval, FieldT &yval)
  144. {
  145. FieldT s_candidate = 0;
  146. while (true) {
  147. FieldT x_candidate = kval*256+s_candidate;
  148. FieldT ysq_candidate = (x_candidate.squared() - 3)*x_candidate + curve_b;
  149. if (sqrt_if_possible<FieldT>(yval, ysq_candidate)) {
  150. sval = s_candidate;
  151. return;
  152. }
  153. s_candidate += 1;
  154. }
  155. }
  156. void generate_r1cs_witness()
  157. {
  158. // Find an s and y such that x^3 - 3*x + b = y^2, where x = 256*k + s
  159. FieldT sval, yval;
  160. find_s_y(this->pb.val(k), sval, yval);
  161. this->pb.val(s) = sval;
  162. this->pb.val(y) = yval;
  163. this->pb.val(r) = FieldT::random_element();
  164. this->pb.val(xsquared) = (this->pb.val(k) * 256 + this->pb.val(s)).squared();
  165. this->pb.val(ysquared) = this->pb.val(y).squared();
  166. x.evaluate(this->pb);
  167. for (auto&& gadget : packers) {
  168. gadget.generate_r1cs_witness_from_packed();
  169. }
  170. for (auto&& gadget : constmuls) {
  171. gadget.generate_r1cs_witness();
  172. }
  173. for (auto&& gadget : muls) {
  174. gadget.generate_r1cs_witness();
  175. }
  176. for (auto&& gadget : adders) {
  177. gadget.generate_r1cs_witness();
  178. }
  179. }
  180. };
  181. int main(int argc, char **argv)
  182. {
  183. Mode mode = MODE_NONE;
  184. size_t numverifencs = 1;
  185. if (argc == 2 || argc == 3) {
  186. if (!strcmp(argv[1], "priv")) {
  187. mode = MODE_PRIV;
  188. } else if (!strcmp(argv[1], "pub")) {
  189. mode = MODE_PUB;
  190. } else if (!strcmp(argv[1], "const")) {
  191. mode = MODE_CONST;
  192. }
  193. if (argc == 3) {
  194. numverifencs = atoi(argv[2]);
  195. }
  196. }
  197. if (mode == MODE_NONE || numverifencs < 1) {
  198. cerr << "Usage: " << argv[0] << " mode n" << endl << endl;
  199. cerr << "Where mode is one of:" << endl;
  200. cerr << " priv: use private Ptable" << endl;
  201. cerr << " pub: use public Ptable" << endl;
  202. cerr << " const: use constant public key (no Ptable)" << endl << endl;
  203. cerr << "and where n is the number of verifencs in the circuit" << endl;
  204. exit(1);
  205. }
  206. // Initialize the curve parameters
  207. default_r1cs_gg_ppzksnark_pp::init_public_params();
  208. init_curveparams();
  209. typedef libff::Fr<default_r1cs_gg_ppzksnark_pp> FieldT;
  210. // Create protoboard
  211. libff::start_profiling();
  212. cout << "Keypair" << endl;
  213. protoboard<FieldT> pb;
  214. pb_variable<FieldT> C1x[numverifencs], C1y[numverifencs];
  215. pb_variable<FieldT> C2x[numverifencs], C2y[numverifencs];
  216. pb_variable<FieldT> Kx[numverifencs], Ky[numverifencs];
  217. pb_variable<FieldT> Px[numverifencs], Py[numverifencs];
  218. pb_variable_array<FieldT> Ptable[numverifencs];
  219. pb_variable<FieldT> k[numverifencs];
  220. const size_t numbits = FieldT::num_bits;
  221. // Allocate variables
  222. // Public outputs:
  223. for (size_t i = 0; i < numverifencs; ++i) {
  224. // El Gamal encryption of k under public key P (or H if MODE_CONST)
  225. // C1 = r*G, C2 = r*P + M (where M=(256*k+s,y))
  226. C1x[i].allocate(pb, "C1x");
  227. C1y[i].allocate(pb, "C1y");
  228. C2x[i].allocate(pb, "C2x");
  229. C2y[i].allocate(pb, "C2y");
  230. // Public key corresponding to private key k
  231. // K = k*G
  232. Kx[i].allocate(pb, "Kx");
  233. Ky[i].allocate(pb, "Ky");
  234. // Public inputs:
  235. // The public key P (if not MODE_CONST)
  236. if (mode != MODE_CONST) {
  237. Px[i].allocate(pb, "Px");
  238. Py[i].allocate(pb, "Py");
  239. }
  240. }
  241. if (mode != MODE_CONST) {
  242. for (size_t i = 0; i < numverifencs; ++i) {
  243. // The Ptable might be public or private, according to the mode
  244. Ptable[i].allocate(pb, 2*numbits, "Ptable");
  245. }
  246. }
  247. for (size_t i = 0; i < numverifencs; ++i) {
  248. // Private inputs:
  249. // k is a 246-bit random number
  250. k[i].allocate(pb, "k");
  251. }
  252. // This sets up the protoboard variables so that the first n of them
  253. // represent the public input and the rest is private input
  254. if (mode == MODE_PRIV) {
  255. pb.set_input_sizes(8*numverifencs);
  256. } else if (mode == MODE_PUB) {
  257. pb.set_input_sizes(8*numverifencs+2*numbits*numverifencs);
  258. } else if (mode == MODE_CONST) {
  259. pb.set_input_sizes(6*numverifencs);
  260. }
  261. // Initialize the gadgets
  262. vector<verified_encryption_gadget<FieldT> > vencs;
  263. for (size_t i = 0; i < numverifencs; ++i) {
  264. vencs.emplace_back(pb, mode, C1x[i], C1y[i], C2x[i], C2y[i], Kx[i], Ky[i], Px[i], Py[i], Ptable[i], k[i]);
  265. }
  266. for (auto&& gadget : vencs) {
  267. gadget.generate_r1cs_constraints();
  268. }
  269. const r1cs_constraint_system<FieldT> constraint_system = pb.get_constraint_system();
  270. const r1cs_gg_ppzksnark_keypair<default_r1cs_gg_ppzksnark_pp> keypair = r1cs_gg_ppzksnark_generator<default_r1cs_gg_ppzksnark_pp>(constraint_system);
  271. // Add witness values
  272. cout << "Prover" << endl;
  273. if (mode != MODE_CONST) {
  274. FieldT curve_b("7950939520449436327800262930799465135910802758673292356620796789196167463969");
  275. for (size_t i = 0; i < numverifencs; ++i) {
  276. // A variable base point P
  277. FieldT x, y, ysq;
  278. do {
  279. x = FieldT::random_element();
  280. ysq = (x.squared() - 3)*x + curve_b;
  281. } while (!sqrt_if_possible<FieldT>(y, ysq));
  282. pb.val(Px[i]) = x;
  283. pb.val(Py[i]) = y;
  284. }
  285. }
  286. gmp_randstate_t randstate;
  287. gmp_randinit_default(randstate);
  288. FieldT seed = FieldT::random_element();
  289. mpz_t seed_mpz;
  290. mpz_init(seed_mpz);
  291. seed.mont_repr.to_mpz(seed_mpz);
  292. gmp_randseed(randstate, seed_mpz);
  293. mpz_clear(seed_mpz);
  294. for (size_t i = 0; i < numverifencs; ++i) {
  295. mpz_t kval;
  296. mpz_init(kval);
  297. mpz_urandomb(kval, randstate, 246);
  298. pb.val(k[i]) = FieldT(kval);
  299. mpz_clear(kval);
  300. }
  301. libff::enter_block("PROVER TIME");
  302. for (auto&& gadget : vencs) {
  303. gadget.generate_r1cs_witness();
  304. }
  305. const r1cs_gg_ppzksnark_proof<default_r1cs_gg_ppzksnark_pp> proof = r1cs_gg_ppzksnark_prover<default_r1cs_gg_ppzksnark_pp>(keypair.pk, pb.primary_input(), pb.auxiliary_input());
  306. libff::leave_block("PROVER TIME");
  307. cout << "Verifier" << endl;
  308. libff::enter_block("VERIFIER TIME");
  309. bool verified = r1cs_gg_ppzksnark_verifier_strong_IC<default_r1cs_gg_ppzksnark_pp>(keypair.vk, pb.primary_input(), proof);
  310. libff::leave_block("VERIFIER TIME");
  311. cout << "Number of R1CS constraints: " << constraint_system.num_constraints() << endl;
  312. cout << "Primary (public) input length: " << pb.primary_input().size() << endl;
  313. // cout << "Primary (public) input: " << pb.primary_input() << endl;
  314. cout << "Auxiliary (private) input length: " << pb.auxiliary_input().size() << endl;
  315. // cout << "Auxiliary (private) input: " << pb.auxiliary_input() << endl;
  316. cout << "Verification status: " << verified << endl;
  317. ofstream pkfile(string("pk_verifenc_") + argv[1] + "_" + to_string(numverifencs));
  318. pkfile << keypair.pk;
  319. pkfile.close();
  320. ofstream vkfile(string("vk_verifenc_") + argv[1] + "_" + to_string(numverifencs));
  321. vkfile << keypair.vk;
  322. vkfile.close();
  323. ofstream pffile(string("proof_verifenc_") + argv[1] + "_" + to_string(numverifencs));
  324. pffile << proof;
  325. pffile.close();
  326. return 0;
  327. }