Browse Source

reputation proofs fully working now

tristangurtler 3 years ago
parent
commit
80171710c0
6 changed files with 212 additions and 148 deletions
  1. 3 3
      prsona/inc/client.hpp
  2. 0 1
      prsona/inc/proof.hpp
  3. 1 0
      prsona/inc/serverEntity.hpp
  4. 142 105
      prsona/src/client.cpp
  5. 41 4
      prsona/src/main.cpp
  6. 25 35
      prsona/src/serverEntity.cpp

+ 3 - 3
prsona/inc/client.hpp

@@ -33,12 +33,11 @@ class PrsonaClient {
         // SERVER INTERACTIONS
         std::vector<CurveBipoint> make_votes(
             Proof& pi,
+            const std::vector<CurveBipoint>& currentEncryptedVotes,
             const std::vector<Scalar>& vote,
             const std::vector<bool>& replace);
         void receive_fresh_generator(const Curvepoint& freshGenerator);
         void receive_vote_tally(const Proof& pi, const EGCiphertext& score);
-        void receive_encrypted_votes(
-            const Proof& pi, const std::vector<CurveBipoint>& votes);
         
         // REPUTATION PROOFS
         std::vector<Proof> generate_reputation_proof(
@@ -48,6 +47,8 @@ class PrsonaClient {
             const Curvepoint& shortTermPublicKey,
             const Scalar& threshold) const;
 
+        Scalar get_score() const;
+
     private:
         // Constants for clients
         static Curvepoint EL_GAMAL_GENERATOR;
@@ -69,7 +70,6 @@ class PrsonaClient {
         // Things bound to this user (but change regularly)
         EGCiphertext currentEncryptedScore;
         Scalar currentScore;
-        std::vector<CurveBipoint> currentEncryptedVotes;
 
         // Things related to making decryption more efficient
         std::unordered_map<Curvepoint, Scalar, CurvepointHash>

+ 0 - 1
prsona/inc/proof.hpp

@@ -14,7 +14,6 @@
 struct Proof {    
     std::string basic;
     std::vector<Curvepoint> partialUniversals;
-    std::vector<Curvepoint> initParts;
     std::vector<Scalar> challengeParts;
     std::vector<Scalar> responseParts;
 };

+ 1 - 0
prsona/inc/serverEntity.hpp

@@ -16,6 +16,7 @@ class PrsonaServerEntity {
         BGNPublicKey get_bgn_public_key() const;
         Curvepoint get_blinding_generator() const;
         Curvepoint get_fresh_generator() const;
+        size_t get_num_clients() const;
 
         // ENCRYPTED DATA GETTERS
         std::vector<CurveBipoint> get_current_votes_by(

+ 142 - 105
prsona/src/client.cpp

@@ -28,6 +28,11 @@ mpz_class log2(mpz_class x)
     return retval;
 }
 
+mpz_class bit(mpz_class x)
+{
+    return x > 0 ? 1 : 0;
+}
+
 /********************
  * PUBLIC FUNCTIONS *
  ********************/
@@ -84,29 +89,37 @@ Curvepoint PrsonaClient::get_short_term_public_key(Proof &pi) const
  * SERVER INTERACTIONS
  */
 
-// Generate a new vote vector to give to the servers
-// (@replace controls which votes are actually being updated and which are not)
+/* Generate a new vote vector to give to the servers
+ * @replaces controls which votes are actually being updated and which are not
+ *
+ * You may really want to make currentEncryptedVotes a member variable, 
+ * but it doesn't behave correctly when adding new clients after this one. */
 std::vector<CurveBipoint> PrsonaClient::make_votes(
     Proof& pi,
-    const std::vector<Scalar>& vote,
-    const std::vector<bool>& replace)
+    const std::vector<CurveBipoint>& currentEncryptedVotes,
+    const std::vector<Scalar>& votes,
+    const std::vector<bool>& replaces)
 {
     std::vector<CurveBipoint> retval;
 
-    for (size_t i = 0; i < vote.size(); i++)
+    if (!verify_valid_votes_proof(pi, currentEncryptedVotes))
+    {
+        std::cerr << "Could not verify proof of valid votes." << std::endl;
+        return retval;
+    }
+
+    for (size_t i = 0; i < votes.size(); i++)
     {
         CurveBipoint currScore;
-        if (replace[i])
-            serverPublicKey.encrypt(currScore, vote[i]);
+        if (replaces[i])
+            serverPublicKey.encrypt(currScore, votes[i]);
         else
             currScore = serverPublicKey.rerandomize(currentEncryptedVotes[i]);
 
         retval.push_back(currScore);
     }
 
-    currentEncryptedVotes = retval;
-
-    pi = generate_vote_proof(retval, vote);
+    pi = generate_vote_proof(retval, votes);
     return retval;
 }
 
@@ -130,32 +143,21 @@ void PrsonaClient::receive_vote_tally(
     decrypt_score(score);
 }
 
-// Receive a new encrypted vote vector from the servers (each epoch)
-void PrsonaClient::receive_encrypted_votes(
-    const Proof& pi, const std::vector<CurveBipoint>& votes)
-{
-    if (!verify_valid_votes_proof(pi, votes))
-    {
-        std::cerr << "Could not verify proof of valid votes." << std::endl;
-        return;
-    }
-
-    currentEncryptedVotes = votes;
-}
-
 /*
  * REPUTATION PROOFS
  */
 
-// TO BE UPDATED WITH THING IAN SHOWED ME IN MEETING FOR DISJUNCTION
+// A pretty straightforward range proof (generation)
 std::vector<Proof> PrsonaClient::generate_reputation_proof(
     const Scalar& threshold) const
 {
     std::vector<Proof> retval;
 
+    // Don't even try if the user asks to make an illegitimate proof
     if (threshold > currentScore)
         return retval;
 
+    // Base case
     if (!CLIENT_IS_MALICIOUS)
     {
         Proof currProof;
@@ -165,78 +167,79 @@ std::vector<Proof> PrsonaClient::generate_reputation_proof(
         return retval;
     }
 
+    // We really have two consecutive proofs in a junction.
+    // The first is to prove that we are the stpk we claim we are
     retval.push_back(generate_ownership_proof());
 
+    // The value we're actually using in our proof
     mpz_class proofVal = currentScore.curveSub(threshold).toInt();
-    mpz_class proofBits = log2(currentEncryptedVotes.size() * MAX_ALLOWED_VOTE - threshold.toInt());
+    // Top of the range in our proof determined by what scores are even possible
+    mpz_class proofBits =
+        log2(
+            servers->get_num_clients() * MAX_ALLOWED_VOTE -
+            threshold.toInt());
+    
+    // Don't risk a situation that would divulge our private key
+    if (proofBits <= 1)
+        proofBits = 2;
 
+    // This seems weird, but remember our base is A_t^r, not g^t
     std::vector<Scalar> masksPerBit;
-    masksPerBit.push_back(Scalar());
+    masksPerBit.push_back(inversePrivateKey);
     for (size_t i = 1; i < proofBits; i++)
     {
         Scalar currMask;
         currMask.set_random();
 
         masksPerBit.push_back(currMask);
-        masksPerBit[0] = masksPerBit[0].curveSub(currMask.curveMult(Scalar(1 << i)));
+        masksPerBit[0] =
+            masksPerBit[0].curveSub(currMask.curveMult(Scalar(1 << i)));
     }
 
+    // Taken from Fig. 1 in https://eprint.iacr.org/2014/764.pdf
     for (size_t i = 0; i < proofBits; i++)
     {
         Proof currProof;
-        std::stringstream oracleInput;
-        oracleInput << currentFreshGenerator << EL_GAMAL_BLIND_GENERATOR;
+        Curvepoint g, h, c, c_a, c_b;
+        g = currentEncryptedScore.mask;
+        h = EL_GAMAL_BLIND_GENERATOR;
     
-        mpz_class currBit = proofVal & (1 << i);
+        mpz_class currBit = bit(proofVal & (1 << i));
+        Scalar a, s, t, m, r;
+        a.set_random();
+        s.set_random();
+        t.set_random();
+        m = Scalar(currBit);
+        r = masksPerBit[i];
         
-        Curvepoint currentCommitment = currentFreshGenerator * masksPerBit[i] + EL_GAMAL_BLIND_GENERATOR * Scalar(currBit);
-        currProof.partialUniversals.push_back(currentCommitment);
-        oracleInput << currentCommitment;
+        c = g * r + h * m;
+        currProof.partialUniversals.push_back(c);
 
-        if (currBit)
-        {
-            Scalar u_0, c, c_0, c_1, z_0, z_1;
-            u_0.set_random();
-            c_1.set_random();
-            z_1.set_random();
-
-            Curvepoint U_0 = currentFreshGenerator * u_0;
-            Curvepoint U_1 = currentFreshGenerator * z_1 - currentCommitment * c_1 + EL_GAMAL_BLIND_GENERATOR;
-            currProof.initParts.push_back(U_0);
-            currProof.initParts.push_back(U_1);
-            oracleInput << U_0 << U_1;
-
-            c = oracle(oracleInput.str());
-            c_0 = c.curveSub(c_1);
-            z_0 = c_0.curveMult(masksPerBit[i]).curveAdd(u_0);
-
-            currProof.challengeParts.push_back(c_0);
-            currProof.challengeParts.push_back(c_1);
-            currProof.responseParts.push_back(z_0);
-            currProof.responseParts.push_back(z_1);
-        }
-        else
-        {
-            Scalar u_1, c, c_0, c_1, z_0, z_1;
-            u_1.set_random();
-            c_0.set_random();
-            z_0.set_random();
-
-            Curvepoint U_0 = currentFreshGenerator * z_0 - currentCommitment * c_0;
-            Curvepoint U_1 = currentFreshGenerator * u_1;
-            currProof.initParts.push_back(U_0);
-            currProof.initParts.push_back(U_1);
-            oracleInput << U_0 << U_1;
-
-            c = oracle(oracleInput.str());
-            c_1 = c.curveSub(c_0);
-            z_1 = c_1.curveMult(masksPerBit[i]).curveAdd(u_1);
-
-            currProof.challengeParts.push_back(c_0);
-            currProof.challengeParts.push_back(c_1);
-            currProof.responseParts.push_back(z_0);
-            currProof.responseParts.push_back(z_1);
-        }
+        c_a = g * s + h * a;
+
+        Scalar am = a.curveMult(m);
+        c_b = g * t + h * am;
+
+        std::stringstream oracleInput;
+        oracleInput << g << h << c << c_a << c_b;
+
+        Scalar x = oracle(oracleInput.str());
+        currProof.challengeParts.push_back(x);
+
+        Scalar f, z_a, z_b;
+        Scalar mx = m.curveMult(x);
+        f = mx.curveAdd(a);
+
+        Scalar rx = r.curveMult(x);
+        z_a = rx.curveAdd(s);
+
+        Scalar x_f = x.curveSub(f);
+        Scalar r_x_f = r.curveMult(x_f);
+        z_b = r_x_f.curveAdd(t);
+
+        currProof.responseParts.push_back(f);
+        currProof.responseParts.push_back(z_a);
+        currProof.responseParts.push_back(z_b);
 
         retval.push_back(currProof);
     }
@@ -244,56 +247,90 @@ std::vector<Proof> PrsonaClient::generate_reputation_proof(
     return retval;
 }
 
-// TO BE UPDATED WITH THING IAN SHOWED ME IN MEETING FOR DISJUNCTION
+// A pretty straightforward range proof (verification)
 bool PrsonaClient::verify_reputation_proof(
     const std::vector<Proof>& pi,
     const Curvepoint& shortTermPublicKey,
     const Scalar& threshold) const
 {
+    // Reject outright if there's no proof to check
     if (pi.empty())
+    {
+        std::cerr << "Proof was empty, aborting." << std::endl;
         return false;
+    }
 
+    // Base case
     if (!CLIENT_IS_MALICIOUS)
         return pi[0].basic == "PROOF";
 
+    // User should be able to prove they are who they say they are
     if (!verify_ownership_proof(pi[0], shortTermPublicKey))
+    {
+        std::cerr << "Schnorr proof failed, aborting." << std::endl;
+        return false;
+    }
+
+    // Get the encrypted score in question from the servers
+    Proof serverProof;
+    EGCiphertext encryptedScore =
+        servers->get_current_tally(serverProof, shortTermPublicKey);
+
+    // Rough for the prover but if the server messes up,
+    // no way to prove the thing anyways
+    if (!verify_valid_tally_proof(serverProof, encryptedScore))
+    {
+        std::cerr << "Server error prevented proof from working, aborting." << std::endl;
         return false;
+    }
 
+    // X is the thing we're going to be checking in on throughout
+    // to try to get our score commitment back in the end.
     Curvepoint X;
     for (size_t i = 1; i < pi.size(); i++)
     {
-        X = X + pi[i].partialUniversals[0] * Scalar(1 << (i - 1));
+        Curvepoint c, g, h;
+        c = pi[i].partialUniversals[0];
+        g = encryptedScore.mask;
+        h = EL_GAMAL_BLIND_GENERATOR;
 
-        std::stringstream oracleInput;
-        oracleInput << currentFreshGenerator << EL_GAMAL_BLIND_GENERATOR << pi[i].partialUniversals[0];
-        oracleInput << pi[i].initParts[0] << pi[i].initParts[1];
+        X = X + c * Scalar(1 << (i - 1));
 
-        Scalar c = oracle(oracleInput.str());
+        Scalar x, f, z_a, z_b;
+        x = pi[i].challengeParts[0];
+        f = pi[i].responseParts[0];
+        z_a = pi[i].responseParts[1];
+        z_b = pi[i].responseParts[2];
 
-        if (c != pi[i].challengeParts[0] + pi[i].challengeParts[1])
-            return false;
+        // Taken from Fig. 1 in https://eprint.iacr.org/2014/764.pdf
+        Curvepoint c_a, c_b;
+        c_a = g * z_a + h * f - c * x;
+        Scalar x_f = x.curveSub(f);
+        c_b = g * z_b - c * x_f;
 
-        if (currentFreshGenerator * pi[i].responseParts[0] != pi[i].initParts[0] + pi[i].partialUniversals[0] * pi[i].challengeParts[0])
-            return false;
+        std::stringstream oracleInput;
+        oracleInput << g << h << c << c_a << c_b;
 
-        if (currentFreshGenerator * pi[i].responseParts[1] != pi[i].initParts[1] + pi[i].partialUniversals[0] * pi[i].challengeParts[1] - EL_GAMAL_BLIND_GENERATOR)
+        if (oracle(oracleInput.str()) != pi[i].challengeParts[0])
+        {
+            std::cerr << "0 or 1 proof failed at index " << i << " of " << pi.size() - 1 << ", aborting." << std::endl;
             return false;
+        }
     }
 
-    Proof serverProof;
-    EGCiphertext encryptedScore = servers->get_current_tally(serverProof, shortTermPublicKey);
-
-    if (!verify_valid_tally_proof(serverProof, encryptedScore))
-        return false;
-
     Scalar negThreshold;
     negThreshold = Scalar(0).curveSub(threshold);
 
-    Curvepoint scoreCommitment = encryptedScore.encryptedMessage + EL_GAMAL_BLIND_GENERATOR * negThreshold;
-    if (X != scoreCommitment)
-        return false;
+    Curvepoint scoreCommitment =
+        encryptedScore.encryptedMessage +
+        EL_GAMAL_BLIND_GENERATOR * negThreshold;
+    
+    return X == scoreCommitment;
+}
 
-    return true;
+Scalar PrsonaClient::get_score() const
+{
+    return currentScore;
 }
 
 /*********************
@@ -364,7 +401,7 @@ Proof PrsonaClient::generate_ownership_proof() const
     Scalar z = r.curveAdd(c.curveMult(longTermPrivateKey));
 
     retval.basic = "PROOF";
-    retval.initParts.push_back(u);
+    retval.challengeParts.push_back(c);
     retval.responseParts.push_back(z);
 
     return retval;
@@ -377,15 +414,15 @@ bool PrsonaClient::verify_ownership_proof(
     if (!CLIENT_IS_MALICIOUS)
         return pi.basic == "PROOF";
 
-    Curvepoint u = pi.initParts[0];
+    Scalar c = pi.challengeParts[0];
+    Scalar z = pi.responseParts[0];
+
+    Curvepoint u = currentFreshGenerator * z - shortTermPublicKey * c;
 
     std::stringstream oracleInput;
     oracleInput << currentFreshGenerator << shortTermPublicKey << u;
-    Scalar c = oracle(oracleInput.str());
-
-    Scalar z = pi.responseParts[0];
-
-    return (currentFreshGenerator * z) == (shortTermPublicKey * c + u);
+    
+    return c == oracle(oracleInput.str());
 }
 
 /*

+ 41 - 4
prsona/src/main.cpp

@@ -145,8 +145,10 @@ double epoch(
         // Make the actual votes to give to the servers
         Proof pi;
         Curvepoint shortTermPublicKey = users[i].get_short_term_public_key(pi);
-        vector<CurveBipoint> encryptedVotes = users[i].make_votes(
-                                                pi, votes, replace);
+        vector<CurveBipoint> encryptedVotes =
+            servers.get_current_votes_by(pi, shortTermPublicKey);
+        encryptedVotes = users[i].make_votes(
+                            pi, encryptedVotes, votes, replace);
 
         // Give the servers these new votes
         servers.receive_vote(pi, encryptedVotes, shortTermPublicKey);
@@ -169,6 +171,32 @@ double epoch(
     return time_span.count();
 }
 
+void reputation_proof_attempt(default_random_engine& generator, const PrsonaClient& a, const PrsonaClient& b)
+{
+    mpz_class aScore = a.get_score().toInt();
+    int i = 0;
+    while (i < aScore)
+        i++;
+
+    uniform_int_distribution<int> thresholdDistribution(0, i);
+    Scalar threshold(thresholdDistribution(generator));
+
+    cout << "User A's score:            " << aScore << endl;
+    cout << "User A's chosen threshold: " << threshold << endl;
+
+    Proof pi;
+    Curvepoint shortTermPublicKey = a.get_short_term_public_key(pi);
+    vector<Proof> repProof = a.generate_reputation_proof(threshold);
+    if (b.verify_reputation_proof(repProof, shortTermPublicKey, threshold))
+    {
+        cout << "User A proved their reputation to user B!" << endl;
+    }
+    else
+    {
+        cout << "User A failed to prove their reputation to user B!" << endl;
+    }
+}
+
 int main(int argc, char *argv[])
 {
     initialize_prsona_classes();
@@ -179,7 +207,7 @@ int main(int argc, char *argv[])
     size_t numRounds = 3;
     size_t numVotesPerRound = 3;
     bool maliciousServers = false;
-    bool maliciousUsers = false;
+    bool maliciousClients = true;
     string seedStr = "seed";
 
     // Potentially accept command line inputs
@@ -205,7 +233,7 @@ int main(int argc, char *argv[])
         PrsonaServer::set_server_malicious();
         PrsonaClient::set_server_malicious();
     }
-    if (maliciousUsers)
+    if (maliciousClients)
     {
         PrsonaServer::set_client_malicious();
         PrsonaClient::set_client_malicious();
@@ -235,5 +263,14 @@ int main(int argc, char *argv[])
         cout << "Server computation: " << timing << " seconds" << endl;
     }
 
+    uniform_int_distribution<int> userDistribution(0, numUsers - 1);
+    int user_a = userDistribution(generator);
+    int user_b = user_a;
+    while (user_b == user_a)
+        user_b = userDistribution(generator);
+
+    cout << "Attempting a proof of reputation" << endl;
+    reputation_proof_attempt(generator, users[user_a], users[user_b]);
+
     return 0;
 }

+ 25 - 35
prsona/src/serverEntity.cpp

@@ -59,6 +59,11 @@ Curvepoint PrsonaServerEntity::get_fresh_generator() const
     return retval;
 }
 
+size_t PrsonaServerEntity::get_num_clients() const
+{
+    return servers[0].currentPseudonyms.size();
+}
+
 /*
  * ENCRYPTED DATA GETTERS
  */
@@ -139,8 +144,8 @@ void PrsonaServerEntity::receive_vote(
     const std::vector<CurveBipoint>& votes,
     const Curvepoint& shortTermPublicKey)
 {
-    for (size_t j = 0; j < servers.size(); j++)
-        servers[j].receive_vote(pi, votes, shortTermPublicKey);
+    for (size_t i = 0; i < servers.size(); i++)
+        servers[i].receive_vote(pi, votes, shortTermPublicKey);
 }
 
 // After tallying scores and new vote matrix,
@@ -158,10 +163,6 @@ void PrsonaServerEntity::transmit_updates(PrsonaClient& currUser) const
 
     EGCiphertext score = get_current_tally(proofOfScore, shortTermPublicKey);
     currUser.receive_vote_tally(proofOfScore, score);
-
-    std::vector<CurveBipoint> encryptedVotes = get_current_votes_by(
-            proofOfCorrectVotes, shortTermPublicKey);
-    currUser.receive_encrypted_votes(proofOfCorrectVotes, encryptedVotes);
 }
 
 /*
@@ -189,16 +190,13 @@ void PrsonaServerEntity::epoch(Proof& pi)
             currentUserEncryptedTallies,
             currentTallyProofs,
             voteMatrix);
-        if (i < servers.size() - 1)
-        {
-            servers[i + 1].import_updates(
-                pi,
-                previousVoteTally,
-                currentPseudonyms,
-                currentUserEncryptedTallies,
-                currentTallyProofs,
-                voteMatrix);
-        }
+        servers[(i + 1) % servers.size()].import_updates(
+            pi,
+            previousVoteTally,
+            currentPseudonyms,
+            currentUserEncryptedTallies,
+            currentTallyProofs,
+            voteMatrix);
     }
 
     /* Imagine that server 0 is encrypting these, then would do a ZKP that it
@@ -206,13 +204,6 @@ void PrsonaServerEntity::epoch(Proof& pi)
      * knows. Everyone else then adds a mask and proves they added a secret mask
      * to the committed values. */
     currentUserEncryptedTallies = tally_scores(currentTallyProofs, nextGenerator);
-    servers[0].import_updates(
-        pi,
-        previousVoteTally,
-        currentPseudonyms,
-        currentUserEncryptedTallies,
-        currentTallyProofs,
-        voteMatrix);
     
     // go from A_0.5 to A_1
     for (size_t i = 0; i < servers.size(); i++)
@@ -224,20 +215,17 @@ void PrsonaServerEntity::epoch(Proof& pi)
             currentUserEncryptedTallies,
             currentTallyProofs,
             voteMatrix);
-        if (i < servers.size() - 1)
-        {
-            servers[i + 1].import_updates(
-                pi,
-                previousVoteTally,
-                currentPseudonyms,
-                currentUserEncryptedTallies,
-                currentTallyProofs,
-                voteMatrix);
-        }
+        servers[(i + 1) % servers.size()].import_updates(
+            pi,
+            previousVoteTally,
+            currentPseudonyms,
+            currentUserEncryptedTallies,
+            currentTallyProofs,
+            voteMatrix);
     }
 
     // At the end, make sure all servers have same information
-    for (size_t i = 0; i < servers.size() - 1; i++)
+    for (size_t i = 1; i < servers.size() - 1; i++)
     {
         servers[i].import_updates(
             pi,
@@ -272,7 +260,7 @@ std::vector<EGCiphertext> PrsonaServerEntity::tally_scores(
     mpz_class maxScorePossibleThisRound =
         servers[0].get_max_possible_score(maxScoreProof).toInt() *
         MAX_ALLOWED_VOTE;
-    mpz_class topOfScoreRange = retval.size() * MAX_ALLOWED_VOTE;
+    mpz_class topOfScoreRange = decryptedTalliedScores.size() * MAX_ALLOWED_VOTE;
 
     for (size_t i = 0; i < decryptedTalliedScores.size(); i++)
     {
@@ -298,6 +286,8 @@ std::vector<EGCiphertext> PrsonaServerEntity::tally_scores(
             (PrsonaServer::get_blinding_generator() * decryptedTalliedScores[i]);
     }
 
+    servers[0].currentUserEncryptedTallies = retval;
+    servers[0].currentTallyProofs = tallyProofs;
     return retval;
 }