Browse Source

Make valid vote proof a little more efficient

tristangurtler 4 years ago
parent
commit
12b3a8dc96
2 changed files with 46 additions and 72 deletions
  1. 32 49
      prsona/src/client.cpp
  2. 14 23
      prsona/src/server.cpp

+ 32 - 49
prsona/src/client.cpp

@@ -532,13 +532,11 @@ std::vector<Proof> PrsonaClient::generate_vote_proof(
         {
             Proof currProof;
 
-            Scalar c_r, z_r, a, b, s_1, s_2, t_1, t_2;
+            Scalar c_r, z_r, a, s, t_1, t_2;
             c_r.set_random();
             z_r.set_random();
             a.set_random();
-            b.set_random();
-            s_1.set_random();
-            s_2.set_random();
+            s.set_random();
             t_1.set_random();
             t_2.set_random();
 
@@ -548,49 +546,42 @@ std::vector<Proof> PrsonaClient::generate_vote_proof(
                 newEncryptedVotes[i] * c_r;
 
             CurveBipoint C_a = serverPublicKey.get_bipoint_curvegen() * a +
-                serverPublicKey.get_bipoint_curve_subgroup_gen() * s_1;
+                serverPublicKey.get_bipoint_curve_subgroup_gen() * s;
 
-            CurveBipoint C_b = serverPublicKey.get_bipoint_curvegen() * b +
-                serverPublicKey.get_bipoint_curve_subgroup_gen() * s_2;
-
-            Scalar power = (a.curveAdd(b)).curveMult(votes[i].curveMult(votes[i]));
+            Scalar power = (a.curveAdd(a)).curveMult(votes[i].curveMult(votes[i]));
             power =
-                power.curveSub((a.curveAdd(a).curveAdd(b)).curveMult(votes[i]));
-            CurveBipoint C_c = serverPublicKey.get_bipoint_curvegen() * power +
+                power.curveSub((a.curveAdd(a).curveAdd(a)).curveMult(votes[i]));
+            CurveBipoint C_b = serverPublicKey.get_bipoint_curvegen() * power +
                 serverPublicKey.get_bipoint_curve_subgroup_gen() * t_1;
-            currProof.partialUniversals.push_back(C_c[0]);
-            currProof.partialUniversals.push_back(C_c[1]);
+            currProof.partialUniversals.push_back(C_b[0]);
+            currProof.partialUniversals.push_back(C_b[1]);
 
-            CurveBipoint C_d =
+            CurveBipoint C_c =
                 serverPublicKey.get_bipoint_curvegen() *
-                    a.curveMult(b.curveMult(votes[i])) +
+                    a.curveMult(a.curveMult(votes[i])) +
                 serverPublicKey.get_bipoint_curve_subgroup_gen() * t_2;
 
-            oracleInput << U << C_a << C_b << C_c << C_d;
+            oracleInput << U << C_a << C_b << C_c;
 
             Scalar c = oracle(oracleInput.str());
             Scalar c_n = c.curveSub(c_r);
             currProof.challengeParts.push_back(c_r);
             currProof.challengeParts.push_back(c_n);
 
-            Scalar f_1 = (votes[i].curveMult(c_n)).curveAdd(a);
-            Scalar f_2 = (votes[i].curveMult(c_n)).curveAdd(b);
-            Scalar z_na = (seeds[i].curveMult(c_n)).curveAdd(s_1);
-            Scalar z_nb = (seeds[i].curveMult(c_n)).curveAdd(s_2);
+            Scalar f = (votes[i].curveMult(c_n)).curveAdd(a);
+            Scalar z_na = (seeds[i].curveMult(c_n)).curveAdd(s);
 
             Scalar t_1_c_n_t_2 = (t_1.curveMult(c_n)).curveAdd(t_2);
-            Scalar f_1_c_n = f_1.curveSub(c_n);
-            Scalar c_n_f_2 = c_n.curveAdd(c_n).curveSub(f_2);
-            Scalar z_nc = 
-                (seeds[i].curveMult(f_1_c_n).curveMult(c_n_f_2)).curveAdd(
+            Scalar f_c_n = f.curveSub(c_n);
+            Scalar c_n2_f = c_n.curveAdd(c_n).curveSub(f);
+            Scalar z_nb = 
+                (seeds[i].curveMult(f_c_n).curveMult(c_n2_f)).curveAdd(
                     t_1_c_n_t_2);
 
             currProof.responseParts.push_back(z_r);
-            currProof.responseParts.push_back(f_1);
-            currProof.responseParts.push_back(f_2);
+            currProof.responseParts.push_back(f);
             currProof.responseParts.push_back(z_na);
             currProof.responseParts.push_back(z_nb);
-            currProof.responseParts.push_back(z_nc);
 
             retval.push_back(currProof);
         }
@@ -599,43 +590,37 @@ std::vector<Proof> PrsonaClient::generate_vote_proof(
             Proof currProof;
 
             Scalar u, commitmentLambda_1, commitmentLambda_2,
-                c_n, z_na, z_nb, z_nc, f_1, f_2;
+                c_n, z_na, z_nb, f;
             u.set_random();
             commitmentLambda_1.set_random();
             commitmentLambda_2.set_random();
             c_n.set_random();
             z_na.set_random();
             z_nb.set_random();
-            z_nc.set_random();
-            f_1.set_random();
-            f_2.set_random();
+            f.set_random();
 
             CurveBipoint U =
                 serverPublicKey.get_bipoint_curve_subgroup_gen() * u;
 
-            CurveBipoint C_a = serverPublicKey.get_bipoint_curvegen() * f_1 +
+            CurveBipoint C_a = serverPublicKey.get_bipoint_curvegen() * f +
                 serverPublicKey.get_bipoint_curve_subgroup_gen() * z_na -
                 newEncryptedVotes[i] * c_n;
 
-            CurveBipoint C_b = serverPublicKey.get_bipoint_curvegen() * f_2 +
-                serverPublicKey.get_bipoint_curve_subgroup_gen() * z_nb -
-                newEncryptedVotes[i] * c_n;
-
-            CurveBipoint C_c = 
+            CurveBipoint C_b = 
                 serverPublicKey.get_bipoint_curvegen() * commitmentLambda_1 +
                 serverPublicKey.get_bipoint_curve_subgroup_gen() *
                     commitmentLambda_2;
-            currProof.partialUniversals.push_back(C_c[0]);
-            currProof.partialUniversals.push_back(C_c[1]);
+            currProof.partialUniversals.push_back(C_b[0]);
+            currProof.partialUniversals.push_back(C_b[1]);
 
-            Scalar f_1_c_n = f_1.curveSub(c_n);
-            Scalar c_n_f_2 = c_n.curveAdd(c_n).curveSub(f_2);
-            CurveBipoint C_d =
-                serverPublicKey.get_bipoint_curve_subgroup_gen() * z_nc -
-                newEncryptedVotes[i] * f_1_c_n.curveMult(c_n_f_2) -
-                C_c * c_n;
+            Scalar f_c_n = f.curveSub(c_n);
+            Scalar c_n2_f = c_n.curveAdd(c_n).curveSub(f);
+            CurveBipoint C_c =
+                serverPublicKey.get_bipoint_curve_subgroup_gen() * z_nb -
+                newEncryptedVotes[i] * f_c_n.curveMult(c_n2_f) -
+                C_b * c_n;
 
-            oracleInput << U << C_a << C_b << C_c << C_d;
+            oracleInput << U << C_a << C_b << C_c;
 
             Scalar c = oracle(oracleInput.str());
             Scalar c_r = c.curveSub(c_n);
@@ -644,11 +629,9 @@ std::vector<Proof> PrsonaClient::generate_vote_proof(
 
             Scalar z_r = u.curveAdd(c_r.curveMult(seeds[i]));
             currProof.responseParts.push_back(z_r);
-            currProof.responseParts.push_back(f_1);
-            currProof.responseParts.push_back(f_2);
+            currProof.responseParts.push_back(f);
             currProof.responseParts.push_back(z_na);
             currProof.responseParts.push_back(z_nb);
-            currProof.responseParts.push_back(z_nc);
 
             retval.push_back(currProof);
         }

+ 14 - 23
prsona/src/server.cpp

@@ -204,10 +204,7 @@ bool PrsonaServer::receive_vote(
     std::vector<CurveBipoint> oldVotes = voteMatrix[voteSubmitter];
 
     if (!verify_vote_proof(pi, oldVotes, newVotes, shortTermPublicKey))
-    {
-        std::cerr << "Could not verify votes." << std::endl;
         return false;
-    }
 
     voteMatrix[voteSubmitter] = newVotes;
     return true;
@@ -575,44 +572,38 @@ bool PrsonaServer::verify_vote_proof(
     for (size_t i = 1; i < pi.size(); i++)
     {
         size_t voteIndex = i - 1;
-        Curvepoint C_c_0, C_c_1;
-        C_c_0 = pi[i].partialUniversals[0];
-        C_c_1 = pi[i].partialUniversals[1];
+        Curvepoint C_b_0, C_b_1;
+        C_b_0 = pi[i].partialUniversals[0];
+        C_b_1 = pi[i].partialUniversals[1];
 
         CurveBipoint g, h;
         g = bgn_system.get_public_key().get_bipoint_curvegen();
         h = bgn_system.get_public_key().get_bipoint_curve_subgroup_gen();
-        CurveBipoint C_c(C_c_0, C_c_1);
+        CurveBipoint C_b(C_b_0, C_b_1);
 
-        Scalar c_r, c_n, z_r, f_1, f_2, z_na, z_nb, z_nc;
+        Scalar c_r, c_n, z_r, f, z_na, z_nb;
         c_r = pi[i].challengeParts[0];
         c_n = pi[i].challengeParts[1];
 
         z_r  = pi[i].responseParts[0];
-        f_1  = pi[i].responseParts[1];
-        f_2  = pi[i].responseParts[2];
-        z_na = pi[i].responseParts[3];
-        z_nb = pi[i].responseParts[4];
-        z_nc = pi[i].responseParts[5];
+        f  = pi[i].responseParts[1];
+        z_na = pi[i].responseParts[2];
+        z_nb = pi[i].responseParts[3];
 
-        CurveBipoint U, C_a, C_b, C_d;
+        CurveBipoint U, C_a, C_c;
         U = h * z_r + oldVotes[voteIndex] * c_r - newVotes[voteIndex] * c_r;
-        C_a = g * f_1 + h * z_na - newVotes[voteIndex] * c_n;
-        C_b = g * f_2 + h * z_nb - newVotes[voteIndex] * c_n;
+        C_a = g * f + h * z_na - newVotes[voteIndex] * c_n;
 
-        Scalar f_1_c_n = f_1.curveSub(c_n);
-        Scalar c_n_f_2 = c_n.curveAdd(c_n).curveSub(f_2);
-        C_d = h * z_nc - newVotes[voteIndex] * f_1_c_n.curveMult(c_n_f_2) - C_c * c_n;
+        Scalar f_c_n = f.curveSub(c_n);
+        Scalar c_n2_f = c_n.curveAdd(c_n).curveSub(f);
+        C_c = h * z_nb - newVotes[voteIndex] * f_c_n.curveMult(c_n2_f) - C_b * c_n;
 
         std::stringstream oracleInput;
         oracleInput << g << h << oldVotes[voteIndex] << newVotes[voteIndex]
-            << U << C_a << C_b << C_c << C_d;
+            << U << C_a << C_b << C_c;
 
         if (oracle(oracleInput.str()) != c_r.curveAdd(c_n))
-        {
-            std::cerr << "Valid vote proof failed at index " << i << " of " << pi.size() - 1 << ", aborting." << std::endl;
             return false;
-        }
     }
 
     return true;