Browse Source

Added servers checking each other when they add a new user to not allow funny business

tristangurtler 3 years ago
parent
commit
ae49c1f1cf

+ 8 - 3
prsona/inc/EGCiphertext.hpp

@@ -1,9 +1,14 @@
 #ifndef __EG_CIPHERTEXT_HPP
 #define __EG_CIPHERTEXT_HPP
 
-struct EGCiphertext {    
-    Curvepoint mask;
-    Curvepoint encryptedMessage;
+#include "Curvepoint.hpp"
+
+class EGCiphertext {
+    public:
+        Curvepoint mask;
+        Curvepoint encryptedMessage;
+
+        bool operator==(const EGCiphertext& other) const;
 };
 
 #endif

+ 1 - 0
prsona/inc/client.hpp

@@ -27,6 +27,7 @@ class PrsonaClient : public PrsonaBase {
 
         // BASIC PUBLIC SYSTEM INFO GETTERS
         Curvepoint get_short_term_public_key() const;
+        Curvepoint get_short_term_public_key(bool unused) const;
         Curvepoint get_short_term_public_key(Proof &pi) const;
 
         // SERVER INTERACTIONS

+ 8 - 1
prsona/inc/server.hpp

@@ -103,9 +103,16 @@ class PrsonaServer : public PrsonaBase {
             Proof& pi, const Curvepoint& nextGenerator);
 
         // DATA MAINTENANCE
+        bool import_new_user_update(
+            const std::vector<Proof>& pi,
+            const std::vector<TwistBipoint>& otherPreviousVoteTallies,
+            const std::vector<Curvepoint>& otherCurrentPseudonyms,
+            const std::vector<EGCiphertext>& otherCurrentUserEncryptedTallies,
+            const std::vector<std::vector<CurveBipoint>>& otherVoteMatrix
+        );
         void import_updates(
             const Proof& pi,
-            const std::vector<TwistBipoint>& otherPreviousVoteTally,
+            const std::vector<TwistBipoint>& otherPreviousVoteTallies,
             const std::vector<Curvepoint>& otherCurrentPseudonyms,
             const std::vector<EGCiphertext>& otherCurrentUserEncryptedTallies,
             const std::vector<std::vector<CurveBipoint>>& otherVoteMatrix

+ 6 - 0
prsona/src/EGCiphertext.cpp

@@ -0,0 +1,6 @@
+#include "EGCiphertext.hpp"
+
+bool EGCiphertext::operator==(const EGCiphertext& other) const
+{
+    return mask == other.mask && encryptedMessage == other.encryptedMessage;
+}

+ 2 - 0
prsona/src/base.cpp

@@ -128,6 +128,8 @@ size_t PrsonaBase::binary_search(
             lo = mid + 1;
         else if (index == list[mid])
             return mid;
+        else if (mid == lo)
+            return lo;
         else hi = mid - 1;
     }
 

+ 15 - 0
prsona/src/client.cpp

@@ -38,6 +38,15 @@ Curvepoint PrsonaClient::get_short_term_public_key() const
     return currentFreshGenerator * longTermPrivateKey;
 }
 
+Curvepoint PrsonaClient::get_short_term_public_key(bool unused) const
+{
+    std::cout << "g^r: " << std::hex << currentFreshGenerator << std::dec << std::endl;
+    std::cout << "ltsk: " << longTermPrivateKey << std::endl;
+    std::cout << "stpk: " << std::hex << currentFreshGenerator * longTermPrivateKey << std::dec << std::endl;
+
+    return currentFreshGenerator * longTermPrivateKey;
+}
+
 Curvepoint PrsonaClient::get_short_term_public_key(Proof &pi) const
 {
     pi = generate_ownership_proof();
@@ -93,7 +102,10 @@ bool PrsonaClient::receive_fresh_generator(
     const std::vector<Proof>& pi, const Curvepoint& freshGenerator)
 {
     if (!verify_generator_proof(pi, freshGenerator, servers->get_num_servers()))
+    {
+        std::cerr << "Issue verifying fresh generator proof." << std::endl;
         return false;
+    }
 
     currentFreshGenerator = freshGenerator;
     return true;
@@ -184,6 +196,9 @@ bool PrsonaClient::receive_new_user_data(const std::vector<Proof>& mainProof)
 
     currentEncryptedScore = userEncryptedScore;
     currentScore = decrypt_score(userEncryptedScore);
+
+    // std::cout << "g^r: " << std::hex << currentFreshGenerator << std::dec << std::endl;
+
     return true;
 }
 

+ 2 - 1
prsona/src/main.cpp

@@ -61,6 +61,7 @@ vector<double> make_votes(
         }
         shuffle(replaces.begin(), replaces.end(), generator);
 
+        
         Proof ownerProof;
         Curvepoint shortTermPublicKey =
             users[i].get_short_term_public_key(ownerProof);
@@ -326,8 +327,8 @@ int main(int argc, char *argv[])
             elGamalBlindGenerator,
             bgnPublicKey,
             &servers);
-        servers.add_new_client(currUser);
         users.push_back(currUser);
+        servers.add_new_client(users[i]);
     }
 
     // Seeded randomness for random votes used in epoch

+ 82 - 0
prsona/src/server.cpp

@@ -84,6 +84,15 @@ std::vector<CurveBipoint> PrsonaServer::get_current_votes_by(
     size_t voteSubmitter = binary_search(shortTermPublicKey);
     retval = voteMatrix[voteSubmitter];
 
+    // if (currentPseudonyms[voteSubmitter] != shortTermPublicKey)
+    // {
+    //     std::cout << "Query user: " << std::hex
+    //         << shortTermPublicKey << std::dec << std::endl;
+    //     for (size_t i = 0; i < currentPseudonyms.size(); i++)
+    //         std::cout << "User " << i + 1 << " of " << currentPseudonyms.size()
+    //             << ": " << std::hex << currentPseudonyms[i] << std::dec << std::endl;
+    // }
+
     pi = generate_valid_vote_row_proof();
     return retval;
 }
@@ -401,6 +410,79 @@ void PrsonaServer::break_down_midway_pseudonyms(
  * DATA MAINTENANCE
  */
 
+bool PrsonaServer::import_new_user_update(
+    const std::vector<Proof>& pi,
+    const std::vector<TwistBipoint>& otherPreviousVoteTallies,
+    const std::vector<Curvepoint>& otherCurrentPseudonyms,
+    const std::vector<EGCiphertext>& otherCurrentUserEncryptedTallies,
+    const std::vector<std::vector<CurveBipoint>>& otherVoteMatrix)
+{
+    size_t newIndex = 0;
+    if (!currentPseudonyms.empty())
+        while (otherCurrentPseudonyms[newIndex] == currentPseudonyms[newIndex])
+            newIndex++;
+
+    Curvepoint shortTermPublicKey = otherCurrentPseudonyms[newIndex];
+
+    bool flag = verify_proof_of_added_user(
+        pi,
+        currentFreshGenerator,
+        shortTermPublicKey,
+        elGamalBlindGenerator,
+        bgnSystem.get_public_key().get_bipoint_curvegen(),
+        bgnSystem.get_public_key().get_bipoint_curve_subgroup_gen(),
+        bgnSystem.get_public_key().get_bipoint_twistgen(),
+        bgnSystem.get_public_key().get_bipoint_twist_subgroup_gen(),
+        newIndex,
+        otherCurrentUserEncryptedTallies[newIndex],
+        otherPreviousVoteTallies[newIndex],
+        otherVoteMatrix);
+
+    if (!flag)
+    {
+        std::cerr << "Other server added new user invalidly, aborting." << std::endl;
+        return false;
+    }
+
+    for (size_t i = 0; i < otherCurrentPseudonyms.size(); i++)
+    {
+        if (i == newIndex)
+            continue;
+
+        size_t otherI = (i > newIndex ? i - 1 : i);
+
+        flag = flag && otherCurrentPseudonyms[i] ==
+                currentPseudonyms[otherI];
+        flag = flag && otherCurrentUserEncryptedTallies[i] ==
+                currentUserEncryptedTallies[otherI];
+        flag = flag && otherPreviousVoteTallies[i] ==
+                previousVoteTallies[otherI];
+
+        for (size_t j = 0; j < otherCurrentPseudonyms.size(); j++)
+        {
+            if (j == newIndex)
+                continue;
+
+            size_t otherJ = (j > newIndex ? j - 1 : j);
+            flag = flag && otherVoteMatrix[i][j] ==
+                    voteMatrix[otherI][otherJ];
+        }
+    }
+
+    if (!flag)
+    {
+        std::cerr << "Other server illicitly changed other value during new user add." << std::endl;
+        return false;
+    }
+
+    previousVoteTallies = otherPreviousVoteTallies;
+    currentPseudonyms = otherCurrentPseudonyms;
+    currentUserEncryptedTallies = otherCurrentUserEncryptedTallies;
+    voteMatrix = otherVoteMatrix;
+
+    return true;
+}
+
 void PrsonaServer::import_updates(
     const Proof& pi,
     const std::vector<TwistBipoint>& otherPreviousVoteTallies,

+ 2 - 5
prsona/src/serverEntity.cpp

@@ -264,15 +264,12 @@ void PrsonaServerEntity::add_new_client(PrsonaClient& newUser, size_t which)
         currentPseudonyms,
         currentUserEncryptedTallies,
         voteMatrix);
-    
-    // FIX ME!!!!!!!
-    Proof unused("PROOF");
 
     for (size_t j = 1; j < servers[which].get_num_servers(); j++)
     {
         size_t index = (which + j) % servers[which].get_num_servers();
-        servers[index].import_updates(
-            unused,
+        servers[index].import_new_user_update(
+            proofOfCorrectAddition,
             previousVoteTally,
             currentPseudonyms,
             currentUserEncryptedTallies,