Bladeren bron

fixed user tally batched proofs

tristangurtler 3 jaren geleden
bovenliggende
commit
65db225dec
2 gewijzigde bestanden met toevoegingen van 459 en 1 verwijderingen
  1. 54 0
      prsona/inc/base.hpp
  2. 405 1
      prsona/src/base.cpp

+ 54 - 0
prsona/inc/base.hpp

@@ -316,6 +316,60 @@ class PrsonaBase {
             const std::vector<std::vector<Twistpoint>>& userTallySeedCommits
         ) const;
 
+        std::vector<Proof> generate_unbatched_user_tally_proofs(
+            const std::vector<std::vector<Scalar>>& permutations,
+            const Scalar& power,
+            const Twistpoint& nextGenerator,
+            const std::vector<std::vector<Scalar>>& permutationSeeds,
+            const std::vector<std::vector<Scalar>>& userTallySeeds,
+            const std::vector<Twistpoint>& currPseudonyms,
+            const std::vector<Twistpoint>& userTallyMasks,
+            const std::vector<Twistpoint>& userTallyMessages,
+            const std::vector<std::vector<Twistpoint>>& permutationCommits,
+            const std::vector<std::vector<Twistpoint>>& userTallyMaskCommits,
+            const std::vector<std::vector<Twistpoint>>& userTallyMessageCommits,
+            const std::vector<std::vector<Twistpoint>>& userTallySeedCommits
+        ) const;
+
+        bool verify_unbatched_user_tally_proofs(
+            const std::vector<Proof>& pi,
+            const Twistpoint& nextGenerator,
+            const std::vector<Twistpoint>& currPseudonyms,
+            const std::vector<Twistpoint>& userTallyMasks,
+            const std::vector<Twistpoint>& userTallyMessages,
+            const std::vector<std::vector<Twistpoint>>& permutationCommits,
+            const std::vector<std::vector<Twistpoint>>& userTallyMaskCommits,
+            const std::vector<std::vector<Twistpoint>>& userTallyMessageCommits,
+            const std::vector<std::vector<Twistpoint>>& userTallySeedCommits
+        ) const;
+
+        std::vector<Proof> generate_batched_user_tally_proofs(
+            const std::vector<std::vector<Scalar>>& permutations,
+            const Scalar& power,
+            const Twistpoint& nextGenerator,
+            const std::vector<std::vector<Scalar>>& permutationSeeds,
+            const std::vector<std::vector<Scalar>>& userTallySeeds,
+            const std::vector<Twistpoint>& currPseudonyms,
+            const std::vector<Twistpoint>& userTallyMasks,
+            const std::vector<Twistpoint>& userTallyMessages,
+            const std::vector<std::vector<Twistpoint>>& permutationCommits,
+            const std::vector<std::vector<Twistpoint>>& userTallyMaskCommits,
+            const std::vector<std::vector<Twistpoint>>& userTallyMessageCommits,
+            const std::vector<std::vector<Twistpoint>>& userTallySeedCommits
+        ) const;
+
+        bool verify_batched_user_tally_proofs(
+            const std::vector<Proof>& pi,
+            const Twistpoint& nextGenerator,
+            const std::vector<Twistpoint>& currPseudonyms,
+            const std::vector<Twistpoint>& userTallyMasks,
+            const std::vector<Twistpoint>& userTallyMessages,
+            const std::vector<std::vector<Twistpoint>>& permutationCommits,
+            const std::vector<std::vector<Twistpoint>>& userTallyMaskCommits,
+            const std::vector<std::vector<Twistpoint>>& userTallyMessageCommits,
+            const std::vector<std::vector<Twistpoint>>& userTallySeedCommits
+        ) const;
+
         // SERVER AGREEMENT PROOFS
         Proof generate_valid_vote_row_proof(
             const std::vector<TwistBipoint>& commitment

+ 405 - 1
prsona/src/base.cpp

@@ -1803,6 +1803,43 @@ std::vector<Proof> PrsonaBase::generate_user_tally_proofs(
     const std::vector<std::vector<Twistpoint>>& userTallyMaskCommits,
     const std::vector<std::vector<Twistpoint>>& userTallyMessageCommits,
     const std::vector<std::vector<Twistpoint>>& userTallySeedCommits) const
+{
+    if (LAMBDA > 0)
+        return generate_batched_user_tally_proofs(permutations, power, nextGenerator, permutationSeeds, userTallySeeds, currPseudonyms, userTallyMasks, userTallyMessages, permutationCommits, userTallyMaskCommits, userTallyMessageCommits, userTallySeedCommits);
+    else
+        return generate_unbatched_user_tally_proofs(permutations, power, nextGenerator, permutationSeeds, userTallySeeds, currPseudonyms, userTallyMasks, userTallyMessages, permutationCommits, userTallyMaskCommits, userTallyMessageCommits, userTallySeedCommits);
+}
+
+bool PrsonaBase::verify_user_tally_proofs(
+    const std::vector<Proof>& pi,
+    const Twistpoint& nextGenerator,
+    const std::vector<Twistpoint>& currPseudonyms,
+    const std::vector<Twistpoint>& userTallyMasks,
+    const std::vector<Twistpoint>& userTallyMessages,
+    const std::vector<std::vector<Twistpoint>>& permutationCommits,
+    const std::vector<std::vector<Twistpoint>>& userTallyMaskCommits,
+    const std::vector<std::vector<Twistpoint>>& userTallyMessageCommits,
+    const std::vector<std::vector<Twistpoint>>& userTallySeedCommits) const
+{
+    if (LAMBDA > 0)
+        return verify_batched_user_tally_proofs(pi, nextGenerator, currPseudonyms, userTallyMasks, userTallyMessages, permutationCommits, userTallyMaskCommits, userTallyMessageCommits, userTallySeedCommits);
+    else
+        return verify_unbatched_user_tally_proofs(pi, nextGenerator, currPseudonyms, userTallyMasks, userTallyMessages, permutationCommits, userTallyMaskCommits, userTallyMessageCommits, userTallySeedCommits);
+}
+
+std::vector<Proof> PrsonaBase::generate_unbatched_user_tally_proofs(
+    const std::vector<std::vector<Scalar>>& permutations,
+    const Scalar& power,
+    const Twistpoint& nextGenerator,
+    const std::vector<std::vector<Scalar>>& permutationSeeds,
+    const std::vector<std::vector<Scalar>>& userTallySeeds,
+    const std::vector<Twistpoint>& currPseudonyms,
+    const std::vector<Twistpoint>& userTallyMasks,
+    const std::vector<Twistpoint>& userTallyMessages,
+    const std::vector<std::vector<Twistpoint>>& permutationCommits,
+    const std::vector<std::vector<Twistpoint>>& userTallyMaskCommits,
+    const std::vector<std::vector<Twistpoint>>& userTallyMessageCommits,
+    const std::vector<std::vector<Twistpoint>>& userTallySeedCommits) const
 {
     std::vector<Proof> retval;
 
@@ -1928,7 +1965,7 @@ std::vector<Proof> PrsonaBase::generate_user_tally_proofs(
     return retval;
 }
 
-bool PrsonaBase::verify_user_tally_proofs(
+bool PrsonaBase::verify_unbatched_user_tally_proofs(
     const std::vector<Proof>& pi,
     const Twistpoint& nextGenerator,
     const std::vector<Twistpoint>& currPseudonyms,
@@ -2038,6 +2075,373 @@ bool PrsonaBase::verify_user_tally_proofs(
     return true;
 }
 
+std::vector<Proof> PrsonaBase::generate_batched_user_tally_proofs(
+    const std::vector<std::vector<Scalar>>& permutations,
+    const Scalar& power,
+    const Twistpoint& nextGenerator,
+    const std::vector<std::vector<Scalar>>& permutationSeeds,
+    const std::vector<std::vector<Scalar>>& userTallySeeds,
+    const std::vector<Twistpoint>& currPseudonyms,
+    const std::vector<Twistpoint>& userTallyMasks,
+    const std::vector<Twistpoint>& userTallyMessages,
+    const std::vector<std::vector<Twistpoint>>& permutationCommits,
+    const std::vector<std::vector<Twistpoint>>& userTallyMaskCommits,
+    const std::vector<std::vector<Twistpoint>>& userTallyMessageCommits,
+    const std::vector<std::vector<Twistpoint>>& userTallySeedCommits) const
+{
+    std::vector<Proof> retval;
+
+    if (!SERVER_IS_MALICIOUS)
+    {
+        retval.push_back(Proof("PROOF"));
+        return retval;
+    }
+
+    Proof first;
+    retval.push_back(first);
+    
+    Twistpoint g = EL_GAMAL_GENERATOR;
+    Twistpoint h = elGamalBlindGenerator;
+
+    std::stringstream oracleInput;
+    oracleInput << g << h << nextGenerator;
+
+    for (size_t i = 0; i < currPseudonyms.size(); i++)
+        oracleInput << currPseudonyms[i];
+
+    for (size_t i = 0; i < userTallyMasks.size(); i++)
+        oracleInput << userTallyMasks[i];
+
+    for (size_t i = 0; i < userTallyMessages.size(); i++)
+        oracleInput << userTallyMessages[i];
+
+    for (size_t i = 0; i < permutationCommits.size(); i++)
+        for (size_t j = 0; j < permutationCommits[i].size(); j++)
+            oracleInput << permutationCommits[i][j];
+
+    for (size_t i = 0; i < userTallyMaskCommits.size(); i++)
+        for (size_t j = 0; j < userTallyMaskCommits[i].size(); j++)
+            oracleInput << userTallyMaskCommits[i][j];
+
+    for (size_t i = 0; i < userTallyMessageCommits.size(); i++)
+        for (size_t j = 0; j < userTallyMessageCommits[i].size(); j++)
+            oracleInput << userTallyMessageCommits[i][j];
+
+    for (size_t i = 0; i < userTallySeedCommits.size(); i++)
+        for (size_t j = 0; j < userTallySeedCommits[i].size(); j++)
+            oracleInput << userTallySeedCommits[i][j];
+
+    std::vector<Scalar> b1;
+    std::vector<Scalar> b2;
+    std::vector<Scalar> t1;
+    std::vector<Scalar> t2;
+
+    for (size_t i = 0; i < permutationCommits.size(); i++)
+    {
+        Proof currProof;
+
+        Scalar currb1;
+        Scalar currb2;
+        Scalar currt1;
+        Scalar currt2;
+
+        Twistpoint U1, U4, U6, U7;
+        std::vector<Twistpoint> U2, U3, U5;
+
+        currb1.set_random();
+        currb2.set_random();
+        currt1.set_random();
+        currt2.set_random();
+
+        U1 = g * currb2 + h * currt1;
+        U4 = currPseudonyms[i] * (currb1 * currb2 * currt2);
+        U6 = nextGenerator * (currb2 * currt2);
+        U7 = g * currt2;
+        for (size_t j = 0; j < permutationCommits[i].size(); j++)
+        {
+            Twistpoint currU2, currU3, currU5;
+            
+            Scalar U2MaskMult = power * currb2 + currb1 * permutations[i][j];
+            Scalar U2PseudMult = 
+                power * currt2 * permutations[i][j] +
+                power * currb2 * userTallySeeds[j][i] +
+                currb1 * permutations[i][j] * userTallySeeds[j][i];
+            Scalar U2hMult = currt2;
+
+            Scalar U3MaskMult = currb1 * currb2;
+            Scalar U3PseudMult = power * currb2 * currt2 + 
+                currb1 * currt2 * permutations[i][j] +
+                currb1 * currb2 * userTallySeeds[j][i];
+            Scalar U3hMult = Scalar(0);
+
+            Scalar U5MessageMult = currb2;
+            Scalar U5GenMult = currb2 * userTallySeeds[j][i] + currt2 * permutations[i][j];
+            Scalar U5hMult = currt2;
+            for (size_t k = 0; k < permutationCommits[i].size(); k++)
+            {
+                if (k == j)
+                    continue;
+
+                std::stringstream xkOracle;
+                xkOracle << permutationCommits[i][k] << userTallyMaskCommits[k][i] << userTallyMessageCommits[k][i] << userTallySeedCommits[k][i];
+                Scalar x_k = oracle(xkOracle.str(), LAMBDA);
+
+                for (size_t l = 0; l < permutationCommits[i].size(); l++)
+                {
+                    if (l == j)
+                        continue;
+                    if (l == k)
+                        continue;
+
+                    std::stringstream xlOracle;
+                    xlOracle << permutationCommits[i][l] << userTallyMaskCommits[l][i] << userTallyMessageCommits[l][i] << userTallySeedCommits[l][i];
+                    Scalar x_l = oracle(xlOracle.str(), LAMBDA);
+
+                    U3MaskMult = U3MaskMult +
+                        power * permutations[i][j] * x_k * x_l;
+                    U3PseudMult = U3PseudMult +
+                        power * permutations[i][j] * userTallySeeds[k][i] * x_k * x_l;
+                    U3hMult = U3hMult +
+                        userTallySeeds[j][i] * x_k * x_l;
+                }
+
+                U2MaskMult = U2MaskMult +
+                    Scalar(2) * power * permutations[i][j] * x_k +
+                    power * permutations[i][k] * x_k;
+                U2PseudMult = U2PseudMult +
+                    power * permutations[i][j] * userTallySeeds[j][i] * x_k +
+                    power * permutations[i][k] * userTallySeeds[j][i] * x_k +
+                    power * permutations[i][j] * userTallySeeds[k][i] * x_k;
+                U2hMult = U2hMult +
+                    Scalar(2) * userTallySeeds[j][i] * x_k +
+                    userTallySeeds[k][i] * x_k;
+
+                U3MaskMult = U3MaskMult +
+                    power * currb2 * x_k +
+                    currb1 * permutations[i][j] * x_k;
+                U3PseudMult = U3PseudMult +
+                    power * currt2 * permutations[i][j] * x_k +
+                    power * currb2 * userTallySeeds[j][i] * x_k +
+                    currb1 * permutations[i][j] * userTallySeeds[k][i] * x_k;
+                U3hMult = U3hMult +
+                    currt2 * x_k;
+
+                U5MessageMult = U5MessageMult +
+                    permutations[i][j] * x_k;
+                U5GenMult = U5GenMult + 
+                    permutations[i][j] * userTallySeeds[k][i] * x_k;
+                U5hMult = U5hMult +
+                    userTallySeeds[j][i] * x_k;
+            }
+            currU2 = userTallyMasks[i] * U2MaskMult +
+                currPseudonyms[i] * U2PseudMult +
+                h * U2hMult;
+
+            currU3 = userTallyMasks[i] * U3MaskMult +
+                currPseudonyms[i] * U3PseudMult +
+                h * U3hMult;
+
+            currU5 = userTallyMessages[i] * U5MessageMult +
+                nextGenerator * U5GenMult +
+                h * U5hMult;
+        
+            U2.push_back(currU2);
+            U3.push_back(currU3);
+            U5.push_back(currU5);
+        }
+
+        oracleInput << U1;
+        for (size_t j = 0; j < permutationCommits[i].size(); j++)
+        {
+            oracleInput << U2[j];
+            currProof.curvepointUniversals.push_back(U2[j]);
+        }
+        for (size_t j = 0; j < permutationCommits[i].size(); j++)
+        {
+            oracleInput << U3[j];
+            currProof.curvepointUniversals.push_back(U3[j]);
+        }
+        oracleInput << U4;
+        for (size_t j = 0; j < permutationCommits[i].size(); j++)
+        {
+            oracleInput << U5[j];
+            currProof.curvepointUniversals.push_back(U5[j]);
+        }
+        oracleInput << U6 << U7;
+
+        b1.push_back(currb1);
+        b2.push_back(currb2);
+        t1.push_back(currt1);
+        t2.push_back(currt2);
+
+        retval.push_back(currProof);
+    }
+
+    Scalar x = oracle(oracleInput.str());
+    retval[0].challengeParts.push_back(x);
+
+    for (size_t i = 0; i < permutationCommits.size(); i++)
+    {
+        size_t piIndex = i + 1;
+
+        Scalar f1 = b1[i];
+        Scalar f2 = b2[i];
+        Scalar z1 = t1[i];
+        Scalar z2 = t2[i];
+
+        for (size_t j = 0; j < permutationCommits[i].size(); j++)
+        {
+            std::stringstream currOracle;
+            currOracle << permutationCommits[i][j] << userTallyMaskCommits[j][i] << userTallyMessageCommits[j][i] << userTallySeedCommits[j][i];
+            Scalar currx = oracle(currOracle.str(), LAMBDA);
+
+            f1 = f1 + power * currx;
+            f2 = f2 + permutations[i][j] * currx;
+            z1 = z1 + permutationSeeds[i][j] * currx;
+            z2 = z2 + userTallySeeds[j][i] * currx;
+        }
+
+        retval[piIndex].responseParts.push_back(f1);
+        retval[piIndex].responseParts.push_back(f2);
+        retval[piIndex].responseParts.push_back(z1);
+        retval[piIndex].responseParts.push_back(z2);
+    }
+
+    return retval;
+}
+
+bool PrsonaBase::verify_batched_user_tally_proofs(
+    const std::vector<Proof>& pi,
+    const Twistpoint& nextGenerator,
+    const std::vector<Twistpoint>& currPseudonyms,
+    const std::vector<Twistpoint>& userTallyMasks,
+    const std::vector<Twistpoint>& userTallyMessages,
+    const std::vector<std::vector<Twistpoint>>& permutationCommits,
+    const std::vector<std::vector<Twistpoint>>& userTallyMaskCommits,
+    const std::vector<std::vector<Twistpoint>>& userTallyMessageCommits,
+    const std::vector<std::vector<Twistpoint>>& userTallySeedCommits) const
+{
+    if (pi.empty())
+        return false;
+
+    if (!SERVER_IS_MALICIOUS)
+        return pi[0].hbc == "PROOF";
+    
+    Twistpoint g = EL_GAMAL_GENERATOR;
+    Twistpoint h = elGamalBlindGenerator;
+
+    std::stringstream oracleInput;
+    oracleInput << g << h << nextGenerator;
+
+    for (size_t i = 0; i < currPseudonyms.size(); i++)
+        oracleInput << currPseudonyms[i];
+
+    for (size_t i = 0; i < userTallyMasks.size(); i++)
+        oracleInput << userTallyMasks[i];
+
+    for (size_t i = 0; i < userTallyMessages.size(); i++)
+        oracleInput << userTallyMessages[i];
+
+    for (size_t i = 0; i < permutationCommits.size(); i++)
+        for (size_t j = 0; j < permutationCommits[i].size(); j++)
+            oracleInput << permutationCommits[i][j];
+
+    for (size_t i = 0; i < userTallyMaskCommits.size(); i++)
+        for (size_t j = 0; j < userTallyMaskCommits[i].size(); j++)
+            oracleInput << userTallyMaskCommits[i][j];
+
+    for (size_t i = 0; i < userTallyMessageCommits.size(); i++)
+        for (size_t j = 0; j < userTallyMessageCommits[i].size(); j++)
+            oracleInput << userTallyMessageCommits[i][j];
+
+    for (size_t i = 0; i < userTallySeedCommits.size(); i++)
+        for (size_t j = 0; j < userTallySeedCommits[i].size(); j++)
+            oracleInput << userTallySeedCommits[i][j];
+
+    Scalar x = pi[0].challengeParts[0];
+
+    for (size_t i = 0; i < permutationCommits.size(); i++)
+    {
+        Scalar sum_of_sub_xs = Scalar(0);
+        std::vector<Scalar> sub_xs;
+        for (size_t j = 0; j < permutationCommits[i].size(); j++)
+        {
+            std::stringstream currOracle;
+            currOracle << permutationCommits[i][j] << userTallyMaskCommits[j][i] << userTallyMessageCommits[j][i] << userTallySeedCommits[j][i];
+            sub_xs.push_back(oracle(currOracle.str(), LAMBDA));
+            sum_of_sub_xs = sum_of_sub_xs + sub_xs[j];
+        }
+
+        size_t piIndex = i + 1;
+
+        std::vector<Twistpoint> U2(pi[piIndex].curvepointUniversals.begin(), pi[piIndex].curvepointUniversals.begin() + permutationCommits.size());
+        std::vector<Twistpoint> U3(pi[piIndex].curvepointUniversals.begin() + permutationCommits.size(), pi[piIndex].curvepointUniversals.begin() + (2 * permutationCommits.size()));
+        std::vector<Twistpoint> U5(pi[piIndex].curvepointUniversals.begin() + (2 * permutationCommits.size()), pi[piIndex].curvepointUniversals.end());
+
+        Scalar f1 = pi[piIndex].responseParts[0];
+        Scalar f2 = pi[piIndex].responseParts[1];
+        Scalar z1 = pi[piIndex].responseParts[2];
+        Scalar z2 = pi[piIndex].responseParts[3];
+
+        Twistpoint U1 = g * f2 + h * z1;
+        Twistpoint U4 = userTallyMasks[i] * (f1 * f2 * sum_of_sub_xs) +
+            currPseudonyms[i] * (f1 * f2 * z2) +
+            h * (z2 * sum_of_sub_xs * sum_of_sub_xs);
+        Twistpoint U6 = userTallyMessages[i] * (f2 * sum_of_sub_xs) +
+            nextGenerator * (f2 * z2) +
+            h * (z2 * sum_of_sub_xs);
+        Twistpoint U7 = g * z2;
+
+        for (size_t j = 0; j < permutationCommits[i].size(); j++)
+        {
+            U1 = U1 - permutationCommits[i][j] * sub_xs[j];
+            
+            U4 = U4 -
+                userTallyMaskCommits[j][i] * (sub_xs[j] * sub_xs[j] * sub_xs[j]) -
+                U2[j] * (sub_xs[j] * sub_xs[j]) -
+                U3[j] * sub_xs[j];
+
+            U6 = U6 -
+                userTallyMessageCommits[j][i] * (sub_xs[j] * sub_xs[j]) -
+                U5[j] * sub_xs[j];
+
+            U7 = U7 - userTallySeedCommits[j][i] * sub_xs[j];
+        }
+
+        oracleInput << U1;
+        for (size_t j = 0; j < permutationCommits[i].size(); j++)
+            oracleInput << U2[j];
+        for (size_t j = 0; j < permutationCommits[i].size(); j++)
+            oracleInput << U3[j];
+        oracleInput << U4;
+        for (size_t j = 0; j < permutationCommits[i].size(); j++)
+            oracleInput << U5[j];
+        oracleInput << U6 << U7;
+    }
+
+    if (x != oracle(oracleInput.str()))
+    {
+        std::cerr << "User tallies not generated by permutation matrix." << std::endl;
+        return false;
+    }
+
+    for (size_t i = 0; i < userTallySeedCommits.size(); i++)
+    {
+        Twistpoint sum = userTallySeedCommits[i][0];
+
+        for (size_t j = 1; j < userTallySeedCommits[i].size(); j++)
+            sum = sum + userTallySeedCommits[i][j];
+
+        if (sum != Twistpoint())
+        {
+            std::cerr << "seed commits did not sum to 0, aborting." << std::endl;
+            return false;
+        }
+    }
+
+    return true;
+}
+
 /*
  * SERVER AGREEMENT PROOFS
  */