Browse Source

Refactoring: put proof code (and global variables) in one shared class the other things inherit from

tristangurtler 3 years ago
parent
commit
f7faf80c5e

+ 149 - 0
prsona/inc/base.hpp

@@ -0,0 +1,149 @@
+#ifndef __PRSONA_BASE_HPP
+#define __PRSONA_BASE_HPP
+
+#include <vector>
+
+#include "Curvepoint.hpp"
+#include "Bipoint.hpp"
+#include "Scalar.hpp"
+
+#include "EGCiphertext.hpp"
+#include "proof.hpp"
+
+class PrsonaBase {
+    public:
+        static size_t MAX_ALLOWED_VOTE;
+        
+        // SETUP FUNCTIONS
+        static void init();
+        static void set_server_malicious();
+        static void set_client_malicious();
+
+        // CONST GETTERS
+        static size_t get_max_allowed_vote();
+        Curvepoint get_blinding_generator() const;
+        Curvepoint get_blinding_generator(std::vector<Proof>& pi) const;
+
+    protected:
+        // Essentially constants, true for both servers and clients
+        static Curvepoint EL_GAMAL_GENERATOR;
+        static Scalar SCALAR_N;
+        static Scalar DEFAULT_TALLY;
+        static Scalar DEFAULT_VOTE;
+
+        static bool SERVER_IS_MALICIOUS;
+        static bool CLIENT_IS_MALICIOUS;
+        
+        std::vector<Proof> elGamalBlindGeneratorProof;
+        Curvepoint elGamalBlindGenerator;
+
+        // PRIVATE ELEMENT SETTER
+        bool set_EG_blind_generator(
+            const std::vector<Proof>& pi,
+            const Curvepoint& currGenerator,
+            size_t numServers);
+
+        // SCHNORR PROOFS
+        Proof schnorr_generation(
+            const Curvepoint& generator,
+            const Curvepoint& commitment,
+            const Scalar& log
+        ) const;
+
+        bool schnorr_verification(
+            const Curvepoint& generator,
+            const Curvepoint& commitment,
+            const Scalar& c,
+            const Scalar& z
+        ) const;
+
+        // OWNERSHIP PROOFS
+        Proof generate_ownership_proof(
+            const Curvepoint& generator,
+            const Curvepoint& commitment,
+            const Scalar& log
+        ) const;
+
+        bool verify_ownership_proof(
+            const Proof& pi,
+            const Curvepoint& generator,
+            const Curvepoint& commitment
+        ) const;
+
+        // ITERATED SCHNORR PROOFS
+        Proof add_to_generator_proof(
+            const Curvepoint& currGenerator, 
+            const Scalar& seed
+        ) const;
+
+        bool verify_generator_proof(
+            const std::vector<Proof>& pi,
+            const Curvepoint& currGenerator,
+            size_t numServers
+        ) const;
+
+        // REPUTATION PROOFS
+        std::vector<Proof> generate_reputation_proof(
+            const Proof& ownershipProof,
+            const EGCiphertext& commitment,
+            const Scalar& currentScore,
+            const Scalar& threshold,
+            const Scalar& inverseKey,
+            size_t numClients
+        ) const;
+
+        bool verify_reputation_proof(
+            const std::vector<Proof>& pi,
+            const Curvepoint& generator,
+            const Curvepoint& owner,
+            const EGCiphertext& commitment,
+            const Scalar& threshold
+        ) const;
+
+        // VALID VOTE PROOFS
+        std::vector<Proof> generate_vote_proof(
+            const Proof& ownershipProof,
+            const CurveBipoint& g,
+            const CurveBipoint& h,
+            const std::vector<bool>& replaces,
+            const std::vector<CurveBipoint>& oldEncryptedVotes,
+            const std::vector<CurveBipoint>& newEncryptedVotes,
+            const std::vector<Scalar>& seeds,
+            const std::vector<Scalar>& votes
+        ) const;
+
+        bool verify_vote_proof(
+            const CurveBipoint& g,
+            const CurveBipoint& h,
+            const std::vector<Proof>& pi,
+            const std::vector<CurveBipoint>& oldEncryptedVotes,
+            const std::vector<CurveBipoint>& newEncryptedVotes,
+            const Curvepoint& freshGenerator,
+            const Curvepoint& owner
+        ) const;
+
+        // EPOCH PROOFS
+        bool verify_update_proof(
+            const Proof& pi
+        ) const;
+
+        // SERVER AGREEMENT PROOFS
+        Proof generate_valid_default_tally_proof() const;
+        Proof generate_valid_fresh_generator_proof() const;
+        Proof generate_votes_valid_proof() const;
+        Proof generate_proof_of_added_user() const;
+        Proof generate_score_proof() const;
+        Proof generate_proof_of_correct_tally() const;
+        Proof generate_proof_of_correct_sum() const;
+        Proof generate_proof_of_shuffle() const;
+        Proof generate_valid_pseudonyms_proof() const;
+
+        bool verify_valid_tally_proof(const Proof& pi) const;
+        bool verify_score_proof(const Proof& pi) const;
+        bool verify_default_tally_proof(const Proof& pi) const;
+        bool verify_default_votes_proof(const Proof& pi) const;
+        bool verify_valid_votes_proof(const Proof& pi) const;
+        bool verify_valid_pseudonyms_proof(const Proof& pi) const;
+};
+
+#endif

+ 20 - 31
prsona/inc/client.hpp

@@ -8,6 +8,7 @@
 #include "Scalar.hpp"
 #include "BGN.hpp"
 
+#include "base.hpp"
 #include "EGCiphertext.hpp"
 #include "proof.hpp"
 
@@ -15,18 +16,14 @@
 // which is needed in some proofs
 class PrsonaServerEntity;
 
-class PrsonaClient {
+class PrsonaClient : public PrsonaBase {
     public:
         // CONSTRUCTORS
         PrsonaClient(
             const BGNPublicKey& serverPublicKey,
+            const std::vector<Proof>& generatorProof,
             const Curvepoint& elGamalBlindGenerator,
-            const PrsonaServerEntity *servers);
-
-        // SETUP FUNCTIONS
-        static void init();
-        static void set_server_malicious();
-        static void set_client_malicious();
+            const PrsonaServerEntity* servers);
 
         // BASIC PUBLIC SYSTEM INFO GETTERS
         Curvepoint get_short_term_public_key(Proof &pi) const;
@@ -39,28 +36,35 @@ class PrsonaClient {
             const std::vector<Scalar>& votes,
             const std::vector<bool>& replaces
         ) const;
-        void receive_fresh_generator(const Curvepoint& freshGenerator);
+        bool receive_fresh_generator(
+            const std::vector<Proof>& pi, const Curvepoint& freshGenerator);
         void receive_vote_tally(const Proof& pi, const EGCiphertext& score);
-        
+
         // REPUTATION PROOFS
         std::vector<Proof> generate_reputation_proof(
-            const Scalar& threshold) const;
+            const Scalar& threshold
+        ) const;
         bool verify_reputation_proof(
             const std::vector<Proof>& pi,
             const Curvepoint& shortTermPublicKey,
-            const Scalar& threshold) const;
+            const Scalar& threshold
+        ) const;
 
+        // NEEDED FOR TESTING PROOFS
         Scalar get_score() const;
 
+    protected:
+        // REQUIRED BY BASE CLASS
+        EGCiphertext get_current_tally(
+            Proof& pi, const Curvepoint& shortTermPublicKey) const;
+
     private:
         // Constants for clients
-        static Curvepoint EL_GAMAL_GENERATOR;
         static bool SERVER_IS_MALICIOUS;
         static bool CLIENT_IS_MALICIOUS;
         
         // Things bound to the servers permanently
         const BGNPublicKey serverPublicKey;
-        const Curvepoint elGamalBlindGenerator;
         const PrsonaServerEntity *servers;
 
         // Things bound to the servers (but change regularly)
@@ -84,30 +88,15 @@ class PrsonaClient {
 
         // OWNERSHIP OF STPK PROOFS
         Proof generate_ownership_proof() const;
-        bool verify_ownership_proof(
-            const Proof& pi, const Curvepoint& shortTermPublicKey) const;
-
-        // PROOF VERIFICATION
-        bool verify_score_proof(const Proof& pi) const;
-        bool verify_generator_proof(
-            const Proof& pi, const Curvepoint& generator) const;
-        bool verify_default_tally_proof(
-            const Proof& pi, const EGCiphertext& generator) const;
-        bool verify_valid_tally_proof(
-            const Proof& pi, const EGCiphertext& score) const;
-        bool verify_default_votes_proof(
-            const Proof& pi, const std::vector<CurveBipoint>& votes) const;
-        bool verify_valid_votes_proof(
-            const Proof& pi, const std::vector<CurveBipoint>& votes) const;
-
-        // PROOF GENERATION
+
+        // VALID VOTE PROOFS
         std::vector<Proof> generate_vote_proof(
             const std::vector<bool>& replaces,
             const std::vector<CurveBipoint>& oldEncryptedVotes,
             const std::vector<CurveBipoint>& newEncryptedVotes,
             const std::vector<Scalar>& seeds,
             const std::vector<Scalar>& votes
-        ) const;
+        ) const;        
 }; 
 
 #endif

+ 9 - 5
prsona/inc/proof.hpp

@@ -11,11 +11,15 @@
 #include "Scalar.hpp"
 #include "Curvepoint.hpp"
 
-struct Proof {    
-    std::string basic;
-    std::vector<Curvepoint> partialUniversals;
-    std::vector<Scalar> challengeParts;
-    std::vector<Scalar> responseParts;
+class Proof {
+    public:
+        Proof();
+        Proof(std::string basic);
+        
+        std::string basic;
+        std::vector<Curvepoint> partialUniversals;
+        std::vector<Scalar> challengeParts;
+        std::vector<Scalar> responseParts;
 };
 
 Scalar oracle(const std::string& input);

+ 21 - 60
prsona/inc/server.hpp

@@ -6,28 +6,28 @@
 #include "BGN.hpp"
 #include "Curvepoint.hpp"
 #include "Bipoint.hpp"
+
+#include "base.hpp"
 #include "EGCiphertext.hpp"
 #include "proof.hpp"
 
-class PrsonaServer {
+class PrsonaServer : public PrsonaBase {
     public:
         // CONSTRUCTORS
-        PrsonaServer();
-        PrsonaServer(const BGN& other_bgn);
-
-        // SETUP FUNCTIONS
-        static void init();
-        static void set_server_malicious();
-        static void set_client_malicious();
+        PrsonaServer(size_t numServers);
+        PrsonaServer(size_t numServers, const BGN& other_bgn);
 
         // BASIC PUBLIC SYSTEM INFO GETTERS
-        Curvepoint get_blinding_generator() const;
         BGNPublicKey get_bgn_public_key() const;
+        size_t get_num_clients() const;
+        size_t get_num_servers() const;
         
         // FRESH GENERATOR CALCULATION
         Curvepoint add_curr_seed_to_generator(
+            std::vector<Proof>& pi,
             const Curvepoint& currGenerator) const;
         Curvepoint add_next_seed_to_generator(
+            std::vector<Proof>& pi,
             const Curvepoint& currGenerator) const;
 
         // ENCRYPTED DATA GETTERS
@@ -35,6 +35,7 @@ class PrsonaServer {
             Proof& pi, const Curvepoint& shortTermPublicKey) const;
         EGCiphertext get_current_tally(
             Proof& pi, const Curvepoint& shortTermPublicKey) const;
+        std::vector<Curvepoint> get_current_pseudonyms(Proof& pi) const;
 
         // CLIENT INTERACTIONS
         void add_new_client(
@@ -47,17 +48,11 @@ class PrsonaServer {
             const Curvepoint& shortTermPublicKey);
 
     private:
-        // Constants for servers
-        static Curvepoint EL_GAMAL_GENERATOR;
-        static Scalar SCALAR_N;
-        static Scalar DEFAULT_TALLY;
-        static Scalar DEFAULT_VOTE;
-        static bool SERVER_IS_MALICIOUS;
-        static bool CLIENT_IS_MALICIOUS;
+        // constants for servers
+        const size_t numServers;
 
         // Identical between all servers (but collaboratively constructed)
-        BGN bgn_system;
-        Curvepoint elGamalBlindGenerator;
+        BGN bgnSystem;
 
         // Private; different for each server
         Scalar currentSeed;
@@ -84,10 +79,15 @@ class PrsonaServer {
 
         // CONSTRUCTOR HELPERS
         const BGN& get_bgn_details() const;
-        void initialize_fresh_generator(const Curvepoint& firstGenerator);
+        bool initialize_fresh_generator(
+            const std::vector<Proof>& pi,
+            const Curvepoint& firstGenerator);
         Curvepoint add_rand_seed_to_generator(
+            std::vector<Proof>& pi,
             const Curvepoint& currGenerator) const;
-        void set_EG_blind_generator(const Curvepoint& currGenerator);
+        bool set_EG_blind_generator(
+            const std::vector<Proof>& pi,
+            const Curvepoint& currGenerator);
         
         // SCORE TALLYING
         std::vector<Scalar> tally_scores(std::vector<Proof>& tallyProofs);
@@ -132,52 +132,13 @@ class PrsonaServer {
         // BINARY SEARCH
         size_t binary_search(const Curvepoint& index) const;
 
-        // CLIENT PROOF VERIFICATION
-        bool verify_ownership_proof(
-            const Proof& pi,
-            const Curvepoint& shortTermPublicKey
-        ) const;
+        // VALID VOTE PROOFS
         bool verify_vote_proof(
             const std::vector<Proof>& pi,
             const std::vector<CurveBipoint>& oldVotes,
             const std::vector<CurveBipoint>& newVotes,
             const Curvepoint& shortTermPublicKey
         ) const;
-
-        // SERVER PROOF VERIFICATION
-        bool verify_update_proof(
-            const Proof& pi
-        ) const;
-
-        // PROOF GENERATION
-        Proof generate_valid_default_tally_proof(
-            const EGCiphertext& newUserEncryptedTally,
-            const Scalar& mask
-        ) const;
-        Proof generate_valid_fresh_generator_proof(
-            const Proof& pi
-        ) const;
-        Proof generate_votes_valid_proof(
-            const std::vector<CurveBipoint>& votes,
-            const Curvepoint& voter
-        ) const;
-        Proof generate_proof_of_added_user(
-            const Curvepoint& shortTermPublicKey
-        ) const;
-        Proof generate_score_proof(
-            const EGCiphertext& score
-        ) const;
-        Proof generate_proof_of_correct_tally(
-            const Quadripoint& BGNEncryptedTally,
-            const Scalar& decryptedTally
-        ) const;
-        Proof generate_proof_of_correct_sum(
-            const TwistBipoint& BGNEncryptedSum,
-            const Scalar& decryptedSum
-        ) const;
-        Proof generate_proof_of_shuffle(
-            const std::vector<size_t>& shuffle_order
-        ) const;
 }; 
 
 #endif

+ 23 - 2
prsona/inc/serverEntity.hpp

@@ -14,19 +14,35 @@ class PrsonaServerEntity {
 
         // BASIC PUBLIC SYSTEM INFO GETTERS
         BGNPublicKey get_bgn_public_key() const;
+        BGNPublicKey get_bgn_public_key(size_t which) const;
         Curvepoint get_blinding_generator() const;
+        Curvepoint get_blinding_generator(size_t which) const;
+        Curvepoint get_blinding_generator(std::vector<Proof>& pi) const;
+        Curvepoint get_blinding_generator(
+            std::vector<Proof>& pi, size_t which) const;
         Curvepoint get_fresh_generator() const;
+        Curvepoint get_fresh_generator(size_t which) const;
+        Curvepoint get_fresh_generator(std::vector<Proof>& pi) const;
+        Curvepoint get_fresh_generator(
+            std::vector<Proof>& pi, size_t which) const;
         size_t get_num_clients() const;
+        size_t get_num_clients(size_t which) const;
         size_t get_num_servers() const;
+        size_t get_num_servers(size_t which) const;
 
         // ENCRYPTED DATA GETTERS
         std::vector<CurveBipoint> get_current_votes_by(
             Proof& pi, const Curvepoint& shortTermPublicKey) const;
+        std::vector<CurveBipoint> get_current_votes_by(
+            Proof& pi, const Curvepoint& shortTermPublicKey, size_t which) const;
         EGCiphertext get_current_tally(
             Proof& pi, const Curvepoint& shortTermPublicKey) const;
+        EGCiphertext get_current_tally(
+            Proof& pi, const Curvepoint& shortTermPublicKey, size_t which) const;
 
         // CLIENT INTERACTIONS
         void add_new_client(PrsonaClient& newUser);
+        void add_new_client(PrsonaClient& newUser, size_t which);
         bool receive_vote(
             const std::vector<Proof>& pi,
             const std::vector<CurveBipoint>& newVotes,
@@ -37,19 +53,24 @@ class PrsonaServerEntity {
             const Curvepoint& shortTermPublicKey,
             size_t which);
         void transmit_updates(PrsonaClient& currUser) const;
+        void transmit_updates(PrsonaClient& currUser, size_t which) const;
 
         // EPOCH
         void epoch(Proof& pi);
+        void epoch(Proof& pi, size_t which);
 
     private:
         std::vector<PrsonaServer> servers;
 
         // SCORE TALLYING
         std::vector<EGCiphertext> tally_scores(
-            std::vector<Proof>& tallyProofs, const Curvepoint& nextGenerator);
+            std::vector<Proof>& tallyProofs,
+            const Curvepoint& nextGenerator,
+            size_t which);
         
         // BINARY SEARCH
-        size_t binary_search(const Curvepoint& index) const;
+        size_t binary_search(
+            const Curvepoint& shortTermPublicKey, size_t which) const;
 }; 
 
 #endif

+ 842 - 0
prsona/src/base.cpp

@@ -0,0 +1,842 @@
+#include <iostream>
+
+#include "base.hpp"
+
+extern const scalar_t bn_n;
+extern const curvepoint_fp_t bn_curvegen;
+
+/* These lines need to be here so these static variables are defined,
+ * but in C++ putting code here doesn't actually execute
+ * (or at least, with g++, whenever it would execute is not at a useful time)
+ * so we have an init() function to actually put the correct values in them. */
+Curvepoint PrsonaBase::EL_GAMAL_GENERATOR = Curvepoint();
+Scalar PrsonaBase::SCALAR_N = Scalar();
+Scalar PrsonaBase::DEFAULT_TALLY = Scalar();
+Scalar PrsonaBase::DEFAULT_VOTE = Scalar();
+
+bool PrsonaBase::SERVER_IS_MALICIOUS = false;
+bool PrsonaBase::CLIENT_IS_MALICIOUS = false;
+size_t PrsonaBase::MAX_ALLOWED_VOTE = 2;
+
+// Quick and dirty function to calculate ceil(log base 2) with mpz_class
+mpz_class log2(mpz_class x)
+{
+    mpz_class retval = 0;
+    while (x > 0)
+    {
+        retval++;
+        x = x >> 1;
+    }
+
+    return retval;
+}
+
+mpz_class bit(mpz_class x)
+{
+    return x > 0 ? 1 : 0;
+}
+
+/********************
+ * PUBLIC FUNCTIONS *
+ ********************/
+
+/*
+ * SETUP FUNCTIONS
+ */
+
+// Must be called once before any usage of this class
+void PrsonaBase::init()
+{
+    EL_GAMAL_GENERATOR = Curvepoint(bn_curvegen);
+    SCALAR_N = Scalar(bn_n);
+    DEFAULT_TALLY = Scalar(1);
+    DEFAULT_VOTE = Scalar(1);
+}
+
+// Call this (once) if using malicious-security servers
+void PrsonaBase::set_server_malicious()
+{
+    SERVER_IS_MALICIOUS = true;
+}
+
+// Call this (once) if using malicious-security clients
+void PrsonaBase::set_client_malicious()
+{
+    CLIENT_IS_MALICIOUS = true;
+}
+
+/*
+ * CONST GETTERS
+ */
+
+size_t PrsonaBase::get_max_allowed_vote()
+{
+    return MAX_ALLOWED_VOTE;
+}
+
+Curvepoint PrsonaBase::get_blinding_generator() const
+{
+    return elGamalBlindGenerator;
+}
+
+Curvepoint PrsonaBase::get_blinding_generator(std::vector<Proof>& pi) const
+{
+    pi = elGamalBlindGeneratorProof;
+
+    return elGamalBlindGenerator;
+}
+
+/***********************
+ * PROTECTED FUNCTIONS *
+ ***********************/
+
+/*
+ * PRIVATE ELEMENT SETTER
+ */
+
+bool PrsonaBase::set_EG_blind_generator(
+    const std::vector<Proof>& pi,
+    const Curvepoint& currGenerator,
+    size_t numServers)
+{
+    if (!verify_generator_proof(pi, currGenerator, numServers))
+        return false;
+
+    elGamalBlindGeneratorProof = pi;
+    elGamalBlindGenerator = currGenerator;
+    return true;
+}
+
+/*
+ * SCHNORR PROOFS
+ */
+
+Proof PrsonaBase::schnorr_generation(
+    const Curvepoint& generator,
+    const Curvepoint& commitment,
+    const Scalar& log) const
+{
+    Proof retval;
+
+    std::stringstream oracleInput;
+    
+    Scalar r;
+    r.set_random();
+    
+    Curvepoint U = generator * r;
+    oracleInput << generator << commitment << U;
+
+    Scalar c = oracle(oracleInput.str());
+    Scalar z = r.curveAdd(c.curveMult(log));
+
+    retval.challengeParts.push_back(c);
+    retval.responseParts.push_back(z);
+
+    return retval;
+}
+
+bool PrsonaBase::schnorr_verification(
+    const Curvepoint& generator,
+    const Curvepoint& commitment,
+    const Scalar& c,
+    const Scalar& z) const
+{
+    Curvepoint U = generator * z - commitment * c;
+
+    std::stringstream oracleInput;
+    oracleInput << generator << commitment << U;
+    
+    return c == oracle(oracleInput.str());
+}
+
+/*
+ * OWNERSHIP PROOFS
+ */
+
+// Prove ownership of the short term public key
+Proof PrsonaBase::generate_ownership_proof(
+    const Curvepoint& generator,
+    const Curvepoint& commitment,
+    const Scalar& log) const
+{
+    if (!CLIENT_IS_MALICIOUS)
+    {
+        Proof retval;
+        retval.basic = "PROOF";
+
+        return retval;
+    }
+
+    return schnorr_generation(generator, commitment, log);
+}
+
+bool PrsonaBase::verify_ownership_proof(
+    const Proof& pi,
+    const Curvepoint& generator,
+    const Curvepoint& commitment) const
+{
+    if (!CLIENT_IS_MALICIOUS)
+        return pi.basic == "PROOF";
+
+    Scalar c = pi.challengeParts[0];
+    Scalar z = pi.responseParts[0];
+
+    return schnorr_verification(generator, commitment, c, z);
+}
+
+/*
+ * ITERATED SCHNORR PROOFS
+ */
+
+Proof PrsonaBase::add_to_generator_proof(
+    const Curvepoint& currGenerator, 
+    const Scalar& seed) const
+{
+    Proof retval;
+    if (!CLIENT_IS_MALICIOUS)
+    {
+        retval.basic = "PROOF";
+        return retval;
+    }
+
+    Curvepoint nextGenerator = currGenerator * seed;
+    retval = schnorr_generation(currGenerator, nextGenerator, seed);
+
+    retval.partialUniversals.push_back(currGenerator);
+    return retval;
+}
+
+bool PrsonaBase::verify_generator_proof(
+    const std::vector<Proof>& pi,
+    const Curvepoint& currGenerator,
+    size_t numServers) const
+{
+    if (pi.size() != numServers || numServers == 0)
+        return false;
+
+    bool retval = true;
+
+    if (!SERVER_IS_MALICIOUS)
+    {
+        for (size_t i = 0; i < pi.size(); i++)
+            retval = retval && pi[i].basic == "PROOF";
+
+        return retval;
+    }
+
+    if (pi[0].partialUniversals[0] != EL_GAMAL_GENERATOR)
+        return false;
+
+    for (size_t i = 0; i < pi.size(); i++)
+    {
+        Curvepoint generator = pi[i].partialUniversals[0];
+        Curvepoint commitment = (i == pi.size() - 1 ?
+                                    currGenerator :
+                                    pi[i + 1].partialUniversals[0]);
+        Scalar c = pi[i].challengeParts[0];
+        Scalar z = pi[i].responseParts[0];
+
+        retval = retval && 
+            schnorr_verification(generator, commitment, c, z);
+        if (!retval)
+            std::cerr << "Error in index " << i+1 << " of " << pi.size() << std::endl;
+    }
+    
+    return retval;
+}
+
+/*
+ * REPUTATION PROOFS
+ */
+
+// A pretty straightforward range proof (generation)
+std::vector<Proof> PrsonaBase::generate_reputation_proof(
+    const Proof& ownershipProof,
+    const EGCiphertext& commitment,
+    const Scalar& currentScore,
+    const Scalar& threshold,
+    const Scalar& inverseKey,
+    size_t numClients) const
+{
+    std::vector<Proof> retval;
+
+    // Base case
+    if (!CLIENT_IS_MALICIOUS)
+    {
+        retval.push_back(Proof("PROOF"));
+
+        return retval;
+    }
+
+    // Don't even try if the user asks to make an illegitimate proof
+    if (threshold.toInt() > (numClients * MAX_ALLOWED_VOTE))
+        return retval;
+
+    // We really have two consecutive proofs in a junction.
+    // The first is to prove that we are the stpk we claim we are
+    retval.push_back(ownershipProof);
+
+    // The value we're actually using in our proof
+    mpz_class proofVal = currentScore.curveSub(threshold).toInt();
+    // Top of the range in our proof determined by what scores are even possible
+    mpz_class proofBits =
+        log2(numClients * MAX_ALLOWED_VOTE - threshold.toInt());
+    
+    // Don't risk a situation that would divulge our private key
+    if (proofBits <= 1)
+        proofBits = 2;
+
+    // This seems weird, but remember our base is A_t^r, not g^t
+    std::vector<Scalar> masksPerBit;
+    masksPerBit.push_back(inverseKey);
+    for (size_t i = 1; i < proofBits; i++)
+    {
+        Scalar currMask;
+        currMask.set_random();
+
+        masksPerBit.push_back(currMask);
+        masksPerBit[0] =
+            masksPerBit[0].curveSub(currMask.curveMult(Scalar(1 << i)));
+    }
+
+    // Taken from Fig. 1 in https://eprint.iacr.org/2014/764.pdf
+    for (size_t i = 0; i < proofBits; i++)
+    {
+        Proof currProof;
+        Curvepoint g, h, c, c_a, c_b;
+        g = commitment.mask;
+        h = elGamalBlindGenerator;
+    
+        mpz_class currBit = bit(proofVal & (1 << i));
+        Scalar a, s, t, m, r;
+        a.set_random();
+        s.set_random();
+        t.set_random();
+        m = Scalar(currBit);
+        r = masksPerBit[i];
+        
+        c = g * r + h * m;
+        currProof.partialUniversals.push_back(c);
+
+        c_a = g * s + h * a;
+
+        Scalar am = a.curveMult(m);
+        c_b = g * t + h * am;
+
+        std::stringstream oracleInput;
+        oracleInput << g << h << c << c_a << c_b;
+
+        Scalar x = oracle(oracleInput.str());
+        currProof.challengeParts.push_back(x);
+
+        Scalar f, z_a, z_b;
+        Scalar mx = m.curveMult(x);
+        f = mx.curveAdd(a);
+
+        Scalar rx = r.curveMult(x);
+        z_a = rx.curveAdd(s);
+
+        Scalar x_f = x.curveSub(f);
+        Scalar r_x_f = r.curveMult(x_f);
+        z_b = r_x_f.curveAdd(t);
+
+        currProof.responseParts.push_back(f);
+        currProof.responseParts.push_back(z_a);
+        currProof.responseParts.push_back(z_b);
+
+        retval.push_back(currProof);
+    }
+
+    return retval;
+}
+
+// A pretty straightforward range proof (verification)
+bool PrsonaBase::verify_reputation_proof(
+    const std::vector<Proof>& pi,
+    const Curvepoint& generator,
+    const Curvepoint& owner,
+    const EGCiphertext& commitment,
+    const Scalar& threshold) const
+{
+    // Reject outright if there's no proof to check
+    if (pi.empty())
+    {
+        std::cerr << "Proof was empty, aborting." << std::endl;
+        return false;
+    }
+
+    // If the range is so big that it wraps around mod n,
+    // there's a chance the user actually made a proof for a very low reputation
+    if (pi.size() > 256)
+    {
+        std::cerr << "Proof was too big, prover could have cheated." << std::endl;
+        return false;
+    }
+
+    if (!CLIENT_IS_MALICIOUS)
+        return pi[0].basic == "PROOF";
+
+    Scalar c, z;
+    c = pi[0].challengeParts[0];
+    z = pi[0].responseParts[0];
+
+    // User should be able to prove they are who they say they are
+    if (!schnorr_verification(generator, owner, c, z))
+    {
+        std::cerr << "Schnorr proof failed, aborting." << std::endl;
+        return false;
+    }
+
+    // X is the thing we're going to be checking in on throughout
+    // to try to get our score commitment back in the end.
+    Curvepoint X;
+    for (size_t i = 1; i < pi.size(); i++)
+    {
+        Curvepoint c, g, h;
+        c = pi[i].partialUniversals[0];
+        g = commitment.mask;
+        h = elGamalBlindGenerator;
+
+        X = X + c * Scalar(1 << (i - 1));
+
+        Scalar x, f, z_a, z_b;
+        x = pi[i].challengeParts[0];
+        f = pi[i].responseParts[0];
+        z_a = pi[i].responseParts[1];
+        z_b = pi[i].responseParts[2];
+
+        // Taken from Fig. 1 in https://eprint.iacr.org/2014/764.pdf
+        Curvepoint c_a, c_b;
+        c_a = g * z_a + h * f - c * x;
+        Scalar x_f = x.curveSub(f);
+        c_b = g * z_b - c * x_f;
+
+        std::stringstream oracleInput;
+        oracleInput << g << h << c << c_a << c_b;
+
+        if (oracle(oracleInput.str()) != pi[i].challengeParts[0])
+        {
+            std::cerr << "0 or 1 proof failed at index " << i << " of " << pi.size() - 1 << ", aborting." << std::endl;
+            return false;
+        }
+    }
+
+    Scalar negThreshold;
+    negThreshold = Scalar(0).curveSub(threshold);
+
+    Curvepoint scoreCommitment =
+        commitment.encryptedMessage +
+        elGamalBlindGenerator * negThreshold;
+    
+    return X == scoreCommitment;
+}
+
+/*
+ * VALID VOTE PROOFS
+ */
+
+std::vector<Proof> PrsonaBase::generate_vote_proof(
+    const Proof& ownershipProof,
+    const CurveBipoint& g,
+    const CurveBipoint& h,
+    const std::vector<bool>& replaces,
+    const std::vector<CurveBipoint>& oldEncryptedVotes,
+    const std::vector<CurveBipoint>& newEncryptedVotes,
+    const std::vector<Scalar>& seeds,
+    const std::vector<Scalar>& votes) const
+{
+    std::vector<Proof> retval;
+
+    // Base case
+    if (!CLIENT_IS_MALICIOUS)
+    {
+        retval.push_back(Proof("PROOF"));
+        
+        return retval;
+    }
+
+    // The first need is to prove that we are the stpk we claim we are
+    retval.push_back(ownershipProof);
+
+    // Then, we iterate over all votes for the proofs that they are correct
+    for (size_t i = 0; i < replaces.size(); i++)
+    {
+        std::stringstream oracleInput;
+        oracleInput << g << h << oldEncryptedVotes[i] << newEncryptedVotes[i];
+        
+        /* This proof structure is documented in my notes.
+         * It's inspired by the proof in Fig. 1 at
+         * https://eprint.iacr.org/2014/764.pdf, but adapted so that you prove
+         * m(m-1)(m-2) = 0 instead of m(m-1) = 0.
+         *
+         * The rerandomization part is just a slight variation on an
+         * ordinary Schnorr proof, so that part's less scary. */
+        if (replaces[i])     // CASE: Make new vote
+        {
+            Proof currProof;
+
+            Scalar c_r, z_r, a, s, t_1, t_2;
+            c_r.set_random();
+            z_r.set_random();
+            a.set_random();
+            s.set_random();
+            t_1.set_random();
+            t_2.set_random();
+
+            CurveBipoint U = h * z_r +
+                                oldEncryptedVotes[i] * c_r -
+                                newEncryptedVotes[i] * c_r;
+
+            CurveBipoint C_a = g * a + h * s;
+
+            Scalar power = (a.curveAdd(a)).curveMult(votes[i].curveMult(votes[i]));
+            power =
+                power.curveSub((a.curveAdd(a).curveAdd(a)).curveMult(votes[i]));
+            CurveBipoint C_b = g * power + h * t_1;
+            currProof.partialUniversals.push_back(C_b[0]);
+            currProof.partialUniversals.push_back(C_b[1]);
+
+            CurveBipoint C_c = g * a.curveMult(a.curveMult(votes[i])) +
+                                h * t_2;
+
+            oracleInput << U << C_a << C_b << C_c;
+
+            Scalar c = oracle(oracleInput.str());
+            Scalar c_n = c.curveSub(c_r);
+            currProof.challengeParts.push_back(c_r);
+            currProof.challengeParts.push_back(c_n);
+
+            Scalar f = (votes[i].curveMult(c_n)).curveAdd(a);
+            Scalar z_na = (seeds[i].curveMult(c_n)).curveAdd(s);
+
+            Scalar t_1_c_n_t_2 = (t_1.curveMult(c_n)).curveAdd(t_2);
+            Scalar f_c_n = f.curveSub(c_n);
+            Scalar c_n2_f = c_n.curveAdd(c_n).curveSub(f);
+            Scalar z_nb = 
+                (seeds[i].curveMult(f_c_n).curveMult(c_n2_f)).curveAdd(
+                    t_1_c_n_t_2);
+
+            currProof.responseParts.push_back(z_r);
+            currProof.responseParts.push_back(f);
+            currProof.responseParts.push_back(z_na);
+            currProof.responseParts.push_back(z_nb);
+
+            retval.push_back(currProof);
+        }
+        else                // CASE: Rerandomize existing vote
+        {
+            Proof currProof;
+
+            Scalar u, commitmentLambda_1, commitmentLambda_2,
+                c_n, z_na, z_nb, f;
+            u.set_random();
+            commitmentLambda_1.set_random();
+            commitmentLambda_2.set_random();
+            c_n.set_random();
+            z_na.set_random();
+            z_nb.set_random();
+            f.set_random();
+
+            CurveBipoint U = h * u;
+
+            CurveBipoint C_a = g * f +
+                h * z_na -
+                newEncryptedVotes[i] * c_n;
+
+            CurveBipoint C_b = g * commitmentLambda_1 + h * commitmentLambda_2;
+            currProof.partialUniversals.push_back(C_b[0]);
+            currProof.partialUniversals.push_back(C_b[1]);
+
+            Scalar f_c_n = f.curveSub(c_n);
+            Scalar c_n2_f = c_n.curveAdd(c_n).curveSub(f);
+            CurveBipoint C_c =
+                h * z_nb -
+                newEncryptedVotes[i] * f_c_n.curveMult(c_n2_f) -
+                C_b * c_n;
+
+            oracleInput << U << C_a << C_b << C_c;
+
+            Scalar c = oracle(oracleInput.str());
+            Scalar c_r = c.curveSub(c_n);
+            currProof.challengeParts.push_back(c_r);
+            currProof.challengeParts.push_back(c_n);
+
+            Scalar z_r = u.curveAdd(c_r.curveMult(seeds[i]));
+            currProof.responseParts.push_back(z_r);
+            currProof.responseParts.push_back(f);
+            currProof.responseParts.push_back(z_na);
+            currProof.responseParts.push_back(z_nb);
+
+            retval.push_back(currProof);
+        }
+    }
+
+    return retval;
+}
+
+bool PrsonaBase::verify_vote_proof(
+    const CurveBipoint& g,
+    const CurveBipoint& h,
+    const std::vector<Proof>& pi,
+    const std::vector<CurveBipoint>& oldEncryptedVotes,
+    const std::vector<CurveBipoint>& newEncryptedVotes,
+    const Curvepoint& freshGenerator,
+    const Curvepoint& owner) const
+{
+    // Reject outright if there's no proof to check
+    if (pi.empty())
+    {
+        std::cerr << "Proof was empty, aborting." << std::endl;
+        return false;
+    }
+
+    // Base case
+    if (!CLIENT_IS_MALICIOUS)
+        return pi[0].basic == "PROOF";
+
+    // User should be able to prove they are who they say they are
+    if (!verify_ownership_proof(pi[0], freshGenerator, owner))
+    {
+        std::cerr << "Schnorr proof failed, aborting." << std::endl;
+        return false;
+    }
+
+    /* This proof structure is documented in my notes.
+     * It's inspired by the proof in Fig. 1 at
+     * https://eprint.iacr.org/2014/764.pdf, but adapted so that you prove
+     * m(m-1)(m-2) = 0 instead of m(m-1) = 0.
+     *
+     * The rerandomization part is just a slight variation on an
+     * ordinary Schnorr proof, so that part's less scary. */
+    for (size_t i = 1; i < pi.size(); i++)
+    {
+        size_t voteIndex = i - 1;
+        Curvepoint C_b_0, C_b_1;
+        C_b_0 = pi[i].partialUniversals[0];
+        C_b_1 = pi[i].partialUniversals[1];
+
+        CurveBipoint C_b(C_b_0, C_b_1);
+
+        Scalar c_r, c_n, z_r, f, z_na, z_nb;
+        c_r = pi[i].challengeParts[0];
+        c_n = pi[i].challengeParts[1];
+
+        z_r  = pi[i].responseParts[0];
+        f  = pi[i].responseParts[1];
+        z_na = pi[i].responseParts[2];
+        z_nb = pi[i].responseParts[3];
+
+        CurveBipoint U, C_a, C_c;
+        U = h * z_r +
+            oldEncryptedVotes[voteIndex] * c_r -
+            newEncryptedVotes[voteIndex] * c_r;
+        C_a = g * f + h * z_na - newEncryptedVotes[voteIndex] * c_n;
+
+        Scalar f_c_n = f.curveSub(c_n);
+        Scalar c_n2_f = c_n.curveAdd(c_n).curveSub(f);
+        C_c = h * z_nb -
+            newEncryptedVotes[voteIndex] * f_c_n.curveMult(c_n2_f) -
+            C_b * c_n;
+
+        std::stringstream oracleInput;
+        oracleInput << g << h
+            << oldEncryptedVotes[voteIndex] << newEncryptedVotes[voteIndex]
+            << U << C_a << C_b << C_c;
+
+        if (oracle(oracleInput.str()) != c_r.curveAdd(c_n))
+            return false;
+    }
+
+    return true;
+}
+
+/*
+ * EPOCH PROOFS
+ */
+
+bool PrsonaBase::verify_update_proof(
+    const Proof& pi) const
+{
+    if (!SERVER_IS_MALICIOUS)
+        return pi.basic == "PROOF";
+
+    return pi.basic == "PROOF";
+}
+
+/*
+ * SERVER AGREEMENT PROOFS
+ */
+
+Proof PrsonaBase::generate_valid_default_tally_proof() const
+{
+    Proof retval;
+
+    if (!SERVER_IS_MALICIOUS)
+    {
+        retval.basic = "PROOF";
+        return retval;
+    }
+
+    retval.basic = "PROOF";
+    return retval;
+}
+
+Proof PrsonaBase::generate_valid_fresh_generator_proof() const
+{
+    Proof retval;
+
+    if (!SERVER_IS_MALICIOUS)
+    {
+        retval.basic = "PROOF";
+        return retval;
+    }
+
+    retval.basic = "PROOF";
+    return retval;
+}
+
+Proof PrsonaBase::generate_votes_valid_proof() const
+{
+    Proof retval;
+
+    if (!SERVER_IS_MALICIOUS)
+    {
+        retval.basic = "PROOF";
+        return retval;
+    }
+
+    retval.basic = "PROOF";
+    return retval;
+}
+
+Proof PrsonaBase::generate_proof_of_added_user() const
+{
+    Proof retval;
+
+    if (!SERVER_IS_MALICIOUS)
+    {
+        retval.basic = "PROOF";
+        return retval;
+    }
+
+    retval.basic = "PROOF";
+    return retval;
+}
+
+Proof PrsonaBase::generate_score_proof() const
+{
+    Proof retval;
+
+    if (!SERVER_IS_MALICIOUS)
+    {
+        retval.basic = "PROOF";
+        return retval;
+    }
+
+    retval.basic = "PROOF";
+    return retval;
+}
+
+Proof PrsonaBase::generate_proof_of_correct_tally() const
+{
+    Proof retval;
+
+    if (!SERVER_IS_MALICIOUS)
+    {
+        retval.basic = "PROOF";
+        return retval;
+    }
+
+    retval.basic = "PROOF";
+    return retval;
+}
+
+Proof PrsonaBase::generate_proof_of_correct_sum() const
+{
+    Proof retval;
+
+    if (!SERVER_IS_MALICIOUS)
+    {
+        retval.basic = "PROOF";
+        return retval;
+    }
+
+    retval.basic = "PROOF";
+    return retval;
+}
+
+Proof PrsonaBase::generate_proof_of_shuffle() const
+{
+    Proof retval;
+
+    if (!SERVER_IS_MALICIOUS)
+    {
+        retval.basic = "PROOF";
+        return retval;
+    }
+
+    retval.basic = "PROOF";
+    return retval;
+}
+
+Proof PrsonaBase::generate_valid_pseudonyms_proof() const
+{
+    Proof retval;
+
+    if (!SERVER_IS_MALICIOUS)
+    {
+        retval.basic = "PROOF";
+        return retval;
+    }
+
+    retval.basic = "PROOF";
+    return retval;
+}
+
+bool PrsonaBase::verify_valid_tally_proof(const Proof& pi) const
+{
+    if (!SERVER_IS_MALICIOUS)
+        return pi.basic == "PROOF";
+
+    return pi.basic == "PROOF";
+}
+
+bool PrsonaBase::verify_score_proof(const Proof& pi) const
+{
+    if (!SERVER_IS_MALICIOUS)
+        return pi.basic == "PROOF";
+
+    return pi.basic == "PROOF";
+}
+
+bool PrsonaBase::verify_default_tally_proof(const Proof& pi) const
+{
+    if (!SERVER_IS_MALICIOUS)
+        return pi.basic == "PROOF";
+
+    return pi.basic == "PROOF";
+}
+
+bool PrsonaBase::verify_default_votes_proof(const Proof& pi) const
+{
+    if (!SERVER_IS_MALICIOUS)
+        return pi.basic == "PROOF";
+
+    return pi.basic == "PROOF";
+}
+
+bool PrsonaBase::verify_valid_votes_proof(const Proof& pi) const
+{
+    if (!SERVER_IS_MALICIOUS)
+        return pi.basic == "PROOF";
+
+    return pi.basic == "PROOF";
+}
+
+bool PrsonaBase::verify_valid_pseudonyms_proof(const Proof& pi) const
+{
+    if (!SERVER_IS_MALICIOUS)
+        return pi.basic == "PROOF";
+
+    return pi.basic == "PROOF";
+}

+ 39 - 451
prsona/src/client.cpp

@@ -3,35 +3,6 @@
 #include "client.hpp"
 #include "serverEntity.hpp"
 
-extern const curvepoint_fp_t bn_curvegen;
-const int MAX_ALLOWED_VOTE = 2;
-
-/* These lines need to be here so these static variables are defined,
- * but in C++ putting code here doesn't actually execute
- * (or at least, with g++, whenever it would execute is not at a useful time)
- * so we have an init() function to actually put the correct values in them. */
-Curvepoint PrsonaClient::EL_GAMAL_GENERATOR = Curvepoint();
-bool PrsonaClient::SERVER_IS_MALICIOUS = false;
-bool PrsonaClient::CLIENT_IS_MALICIOUS = false;
-
-// Quick and dirty function to calculate ceil(log base 2) with mpz_class
-mpz_class log2(mpz_class x)
-{
-    mpz_class retval = 0;
-    while (x > 0)
-    {
-        retval++;
-        x = x >> 1;
-    }
-
-    return retval;
-}
-
-mpz_class bit(mpz_class x)
-{
-    return x > 0 ? 1 : 0;
-}
-
 /********************
  * PUBLIC FUNCTIONS *
  ********************/
@@ -42,39 +13,22 @@ mpz_class bit(mpz_class x)
 
 PrsonaClient::PrsonaClient(
     const BGNPublicKey& serverPublicKey,
+    const std::vector<Proof>& generatorProof,
     const Curvepoint& elGamalBlindGenerator,
     const PrsonaServerEntity* servers)
     : serverPublicKey(serverPublicKey),
-        elGamalBlindGenerator(elGamalBlindGenerator),
         servers(servers),
         max_checked(0)
 {
+    set_EG_blind_generator(
+        generatorProof, elGamalBlindGenerator, servers->get_num_servers());
+
     longTermPrivateKey.set_random();
     inversePrivateKey = longTermPrivateKey.curveInverse();
 
     decryption_memoizer[elGamalBlindGenerator * max_checked] = max_checked;
 }
 
-/*
- * SETUP FUNCTIONS
- */
-
-// Must be called once before any usage of this class
-void PrsonaClient::init()
-{
-    EL_GAMAL_GENERATOR = Curvepoint(bn_curvegen);
-}
-
-void PrsonaClient::set_server_malicious()
-{
-    SERVER_IS_MALICIOUS = true;
-}
-
-void PrsonaClient::set_client_malicious()
-{
-    CLIENT_IS_MALICIOUS = true;
-}
-
 /*
  * BASIC PUBLIC SYSTEM INFO GETTERS
  */
@@ -104,7 +58,7 @@ std::vector<CurveBipoint> PrsonaClient::make_votes(
     std::vector<Scalar> seeds(oldEncryptedVotes.size());
     std::vector<CurveBipoint> newEncryptedVotes(oldEncryptedVotes.size());
 
-    if (!verify_valid_votes_proof(serverProof, oldEncryptedVotes))
+    if (!verify_valid_votes_proof(serverProof))
     {
         std::cerr << "Could not verify proof of valid votes." << std::endl;
         return newEncryptedVotes;
@@ -129,16 +83,21 @@ std::vector<CurveBipoint> PrsonaClient::make_votes(
 }
 
 // Get a new fresh generator (happens at initialization and during each epoch)
-void PrsonaClient::receive_fresh_generator(const Curvepoint& freshGenerator)
+bool PrsonaClient::receive_fresh_generator(
+    const std::vector<Proof>& pi, const Curvepoint& freshGenerator)
 {
+    if (!verify_generator_proof(pi, freshGenerator, servers->get_num_servers()))
+        return false;
+
     currentFreshGenerator = freshGenerator;
+    return true;
 }
 
 // Receive a new encrypted score from the servers (each epoch)
 void PrsonaClient::receive_vote_tally(
     const Proof& pi, const EGCiphertext& score)
 {
-    if (!verify_valid_tally_proof(pi, score))
+    if (!verify_valid_tally_proof(pi))
     {
         std::cerr << "Could not verify proof of valid tally." << std::endl;
         return;
@@ -156,181 +115,34 @@ void PrsonaClient::receive_vote_tally(
 std::vector<Proof> PrsonaClient::generate_reputation_proof(
     const Scalar& threshold) const
 {
-    std::vector<Proof> retval;
-
-    // Don't even try if the user asks to make an illegitimate proof
-    if (threshold.toInt() > (servers->get_num_clients() * MAX_ALLOWED_VOTE))
-        return retval;
-
-    // Base case
-    if (!CLIENT_IS_MALICIOUS)
-    {
-        Proof currProof;
-        currProof.basic = "PROOF";
-
-        retval.push_back(currProof);
-        return retval;
-    }
-
-    // We really have two consecutive proofs in a junction.
-    // The first is to prove that we are the stpk we claim we are
-    retval.push_back(generate_ownership_proof());
-
-    // The value we're actually using in our proof
-    mpz_class proofVal = currentScore.curveSub(threshold).toInt();
-    // Top of the range in our proof determined by what scores are even possible
-    mpz_class proofBits =
-        log2(
-            servers->get_num_clients() * MAX_ALLOWED_VOTE -
-            threshold.toInt());
+    Proof ownershipProof = generate_ownership_proof();
     
-    // Don't risk a situation that would divulge our private key
-    if (proofBits <= 1)
-        proofBits = 2;
-
-    // This seems weird, but remember our base is A_t^r, not g^t
-    std::vector<Scalar> masksPerBit;
-    masksPerBit.push_back(inversePrivateKey);
-    for (size_t i = 1; i < proofBits; i++)
-    {
-        Scalar currMask;
-        currMask.set_random();
-
-        masksPerBit.push_back(currMask);
-        masksPerBit[0] =
-            masksPerBit[0].curveSub(currMask.curveMult(Scalar(1 << i)));
-    }
-
-    // Taken from Fig. 1 in https://eprint.iacr.org/2014/764.pdf
-    for (size_t i = 0; i < proofBits; i++)
-    {
-        Proof currProof;
-        Curvepoint g, h, c, c_a, c_b;
-        g = currentEncryptedScore.mask;
-        h = elGamalBlindGenerator;
-    
-        mpz_class currBit = bit(proofVal & (1 << i));
-        Scalar a, s, t, m, r;
-        a.set_random();
-        s.set_random();
-        t.set_random();
-        m = Scalar(currBit);
-        r = masksPerBit[i];
-        
-        c = g * r + h * m;
-        currProof.partialUniversals.push_back(c);
-
-        c_a = g * s + h * a;
-
-        Scalar am = a.curveMult(m);
-        c_b = g * t + h * am;
-
-        std::stringstream oracleInput;
-        oracleInput << g << h << c << c_a << c_b;
-
-        Scalar x = oracle(oracleInput.str());
-        currProof.challengeParts.push_back(x);
-
-        Scalar f, z_a, z_b;
-        Scalar mx = m.curveMult(x);
-        f = mx.curveAdd(a);
-
-        Scalar rx = r.curveMult(x);
-        z_a = rx.curveAdd(s);
-
-        Scalar x_f = x.curveSub(f);
-        Scalar r_x_f = r.curveMult(x_f);
-        z_b = r_x_f.curveAdd(t);
-
-        currProof.responseParts.push_back(f);
-        currProof.responseParts.push_back(z_a);
-        currProof.responseParts.push_back(z_b);
-
-        retval.push_back(currProof);
-    }
-
-    return retval;
+    return PrsonaBase::generate_reputation_proof(
+        ownershipProof,
+        currentEncryptedScore,
+        currentScore,
+        threshold,
+        inversePrivateKey,
+        servers->get_num_clients());
 }
 
-// A pretty straightforward range proof (verification)
 bool PrsonaClient::verify_reputation_proof(
     const std::vector<Proof>& pi,
     const Curvepoint& shortTermPublicKey,
     const Scalar& threshold) const
 {
-    // Reject outright if there's no proof to check
-    if (pi.empty())
-    {
-        std::cerr << "Proof was empty, aborting." << std::endl;
-        return false;
-    }
-
-    // Base case
-    if (!CLIENT_IS_MALICIOUS)
-        return pi[0].basic == "PROOF";
-
-    // User should be able to prove they are who they say they are
-    if (!verify_ownership_proof(pi[0], shortTermPublicKey))
-    {
-        std::cerr << "Schnorr proof failed, aborting." << std::endl;
-        return false;
-    }
-
-    // Get the encrypted score in question from the servers
     Proof serverProof;
     EGCiphertext encryptedScore =
         servers->get_current_tally(serverProof, shortTermPublicKey);
 
-    // Rough for the prover but if the server messes up,
-    // no way to prove the thing anyways
-    if (!verify_valid_tally_proof(serverProof, encryptedScore))
+    if (!verify_valid_tally_proof(serverProof))
     {
-        std::cerr << "Server error prevented proof from working, aborting." << std::endl;
+        std::cerr << "Error getting score from server, aborting." << std::endl;
         return false;
     }
 
-    // X is the thing we're going to be checking in on throughout
-    // to try to get our score commitment back in the end.
-    Curvepoint X;
-    for (size_t i = 1; i < pi.size(); i++)
-    {
-        Curvepoint c, g, h;
-        c = pi[i].partialUniversals[0];
-        g = encryptedScore.mask;
-        h = elGamalBlindGenerator;
-
-        X = X + c * Scalar(1 << (i - 1));
-
-        Scalar x, f, z_a, z_b;
-        x = pi[i].challengeParts[0];
-        f = pi[i].responseParts[0];
-        z_a = pi[i].responseParts[1];
-        z_b = pi[i].responseParts[2];
-
-        // Taken from Fig. 1 in https://eprint.iacr.org/2014/764.pdf
-        Curvepoint c_a, c_b;
-        c_a = g * z_a + h * f - c * x;
-        Scalar x_f = x.curveSub(f);
-        c_b = g * z_b - c * x_f;
-
-        std::stringstream oracleInput;
-        oracleInput << g << h << c << c_a << c_b;
-
-        if (oracle(oracleInput.str()) != pi[i].challengeParts[0])
-        {
-            std::cerr << "0 or 1 proof failed at index " << i << " of " << pi.size() - 1 << ", aborting." << std::endl;
-            return false;
-        }
-    }
-
-    Scalar negThreshold;
-    negThreshold = Scalar(0).curveSub(threshold);
-
-    Curvepoint scoreCommitment =
-        encryptedScore.encryptedMessage +
-        elGamalBlindGenerator * negThreshold;
-    
-    return X == scoreCommitment;
+    return PrsonaBase::verify_reputation_proof(
+        pi, currentFreshGenerator, shortTermPublicKey, encryptedScore, threshold);
 }
 
 Scalar PrsonaClient::get_score() const
@@ -383,112 +195,17 @@ void PrsonaClient::decrypt_score(const EGCiphertext& score)
  * OWNERSHIP PROOFS
  */
 
-// Very basic Schnorr proof (generation)
+// Prove ownership of the short term public key
 Proof PrsonaClient::generate_ownership_proof() const
 {
-    Proof retval;
-    if (!CLIENT_IS_MALICIOUS)
-    {
-        retval.basic = "PROOF";
-        return retval;
-    }
-
-    std::stringstream oracleInput;
-    
-    Scalar r;
-    r.set_random();
-
     Curvepoint shortTermPublicKey = currentFreshGenerator * longTermPrivateKey;
-    Curvepoint u = currentFreshGenerator * r;
-    oracleInput << currentFreshGenerator << shortTermPublicKey << u;
-
-    Scalar c = oracle(oracleInput.str());
-    Scalar z = r.curveAdd(c.curveMult(longTermPrivateKey));
-
-    retval.basic = "PROOF";
-    retval.challengeParts.push_back(c);
-    retval.responseParts.push_back(z);
-
-    return retval;
-}
-
-// Very basic Schnorr proof (verification)
-bool PrsonaClient::verify_ownership_proof(
-    const Proof& pi, const Curvepoint& shortTermPublicKey) const
-{
-    if (!CLIENT_IS_MALICIOUS)
-        return pi.basic == "PROOF";
-
-    Scalar c = pi.challengeParts[0];
-    Scalar z = pi.responseParts[0];
-
-    Curvepoint u = currentFreshGenerator * z - shortTermPublicKey * c;
-
-    std::stringstream oracleInput;
-    oracleInput << currentFreshGenerator << shortTermPublicKey << u;
-    
-    return c == oracle(oracleInput.str());
-}
-
-/*
- * PROOF VERIFICATION
- */
-
-bool PrsonaClient::verify_score_proof(const Proof& pi) const
-{
-    if (!SERVER_IS_MALICIOUS)
-        return pi.basic == "PROOF";
-
-    return pi.basic == "PROOF";
-}
 
-bool PrsonaClient::verify_generator_proof(
-    const Proof& pi, const Curvepoint& generator) const
-{
-    if (!SERVER_IS_MALICIOUS)
-        return pi.basic == "PROOF";
-
-    return pi.basic == "PROOF";
-}
-
-bool PrsonaClient::verify_default_tally_proof(
-    const Proof& pi, const EGCiphertext& score) const
-{
-    if (!SERVER_IS_MALICIOUS)
-        return pi.basic == "PROOF";
-
-    return pi.basic == "PROOF";
-}
-
-bool PrsonaClient::verify_valid_tally_proof(
-    const Proof& pi, const EGCiphertext& score) const
-{
-    if (!SERVER_IS_MALICIOUS)
-        return pi.basic == "PROOF";
-
-    return pi.basic == "PROOF";
-}
-
-bool PrsonaClient::verify_default_votes_proof(
-    const Proof& pi, const std::vector<CurveBipoint>& votes) const
-{
-    if (!SERVER_IS_MALICIOUS)
-        return pi.basic == "PROOF";
-
-    return pi.basic == "PROOF";
-}
-
-bool PrsonaClient::verify_valid_votes_proof(
-    const Proof& pi, const std::vector<CurveBipoint>& votes) const
-{
-    if (!SERVER_IS_MALICIOUS)
-        return pi.basic == "PROOF";
-
-    return pi.basic == "PROOF";
+    return PrsonaBase::generate_ownership_proof(
+        currentFreshGenerator, shortTermPublicKey, longTermPrivateKey);
 }
 
 /*
- * PROOF GENERATION
+ * VALID VOTE PROOFS
  */
 
 std::vector<Proof> PrsonaClient::generate_vote_proof(
@@ -498,144 +215,15 @@ std::vector<Proof> PrsonaClient::generate_vote_proof(
     const std::vector<Scalar>& seeds,
     const std::vector<Scalar>& votes) const
 {
-    std::vector<Proof> retval;
-
-    // Base case
-    if (!CLIENT_IS_MALICIOUS)
-    {
-        Proof currProof;
-        currProof.basic = "PROOF";
-        retval.push_back(currProof);
-        
-        return retval;
-    }
-
-    // The first need is to prove that we are the stpk we claim we are
-    retval.push_back(generate_ownership_proof());
-
-    // Then, we iterate over all votes for the proofs that they are correct
-    for (size_t i = 0; i < replaces.size(); i++)
-    {
-        std::stringstream oracleInput;
-        oracleInput << serverPublicKey.get_bipoint_curvegen()
-            << serverPublicKey.get_bipoint_curve_subgroup_gen()
-            << oldEncryptedVotes[i] << newEncryptedVotes[i];
-        
-        /* This proof structure is documented in my notes.
-         * It's inspired by the proof in Fig. 1 at
-         * https://eprint.iacr.org/2014/764.pdf, but adapted so that you prove
-         * m(m-1)(m-2) = 0 instead of m(m-1) = 0.
-         *
-         * The rerandomization part is just a slight variation on an
-         * ordinary Schnorr proof, so that part's less scary. */
-        if (replaces[i])     // CASE: Make new vote
-        {
-            Proof currProof;
-
-            Scalar c_r, z_r, a, s, t_1, t_2;
-            c_r.set_random();
-            z_r.set_random();
-            a.set_random();
-            s.set_random();
-            t_1.set_random();
-            t_2.set_random();
-
-            CurveBipoint U =
-                serverPublicKey.get_bipoint_curve_subgroup_gen() * z_r +
-                oldEncryptedVotes[i] * c_r -
-                newEncryptedVotes[i] * c_r;
-
-            CurveBipoint C_a = serverPublicKey.get_bipoint_curvegen() * a +
-                serverPublicKey.get_bipoint_curve_subgroup_gen() * s;
-
-            Scalar power = (a.curveAdd(a)).curveMult(votes[i].curveMult(votes[i]));
-            power =
-                power.curveSub((a.curveAdd(a).curveAdd(a)).curveMult(votes[i]));
-            CurveBipoint C_b = serverPublicKey.get_bipoint_curvegen() * power +
-                serverPublicKey.get_bipoint_curve_subgroup_gen() * t_1;
-            currProof.partialUniversals.push_back(C_b[0]);
-            currProof.partialUniversals.push_back(C_b[1]);
-
-            CurveBipoint C_c =
-                serverPublicKey.get_bipoint_curvegen() *
-                    a.curveMult(a.curveMult(votes[i])) +
-                serverPublicKey.get_bipoint_curve_subgroup_gen() * t_2;
-
-            oracleInput << U << C_a << C_b << C_c;
-
-            Scalar c = oracle(oracleInput.str());
-            Scalar c_n = c.curveSub(c_r);
-            currProof.challengeParts.push_back(c_r);
-            currProof.challengeParts.push_back(c_n);
-
-            Scalar f = (votes[i].curveMult(c_n)).curveAdd(a);
-            Scalar z_na = (seeds[i].curveMult(c_n)).curveAdd(s);
-
-            Scalar t_1_c_n_t_2 = (t_1.curveMult(c_n)).curveAdd(t_2);
-            Scalar f_c_n = f.curveSub(c_n);
-            Scalar c_n2_f = c_n.curveAdd(c_n).curveSub(f);
-            Scalar z_nb = 
-                (seeds[i].curveMult(f_c_n).curveMult(c_n2_f)).curveAdd(
-                    t_1_c_n_t_2);
-
-            currProof.responseParts.push_back(z_r);
-            currProof.responseParts.push_back(f);
-            currProof.responseParts.push_back(z_na);
-            currProof.responseParts.push_back(z_nb);
-
-            retval.push_back(currProof);
-        }
-        else                // CASE: Rerandomize existing vote
-        {
-            Proof currProof;
-
-            Scalar u, commitmentLambda_1, commitmentLambda_2,
-                c_n, z_na, z_nb, f;
-            u.set_random();
-            commitmentLambda_1.set_random();
-            commitmentLambda_2.set_random();
-            c_n.set_random();
-            z_na.set_random();
-            z_nb.set_random();
-            f.set_random();
-
-            CurveBipoint U =
-                serverPublicKey.get_bipoint_curve_subgroup_gen() * u;
-
-            CurveBipoint C_a = serverPublicKey.get_bipoint_curvegen() * f +
-                serverPublicKey.get_bipoint_curve_subgroup_gen() * z_na -
-                newEncryptedVotes[i] * c_n;
-
-            CurveBipoint C_b = 
-                serverPublicKey.get_bipoint_curvegen() * commitmentLambda_1 +
-                serverPublicKey.get_bipoint_curve_subgroup_gen() *
-                    commitmentLambda_2;
-            currProof.partialUniversals.push_back(C_b[0]);
-            currProof.partialUniversals.push_back(C_b[1]);
-
-            Scalar f_c_n = f.curveSub(c_n);
-            Scalar c_n2_f = c_n.curveAdd(c_n).curveSub(f);
-            CurveBipoint C_c =
-                serverPublicKey.get_bipoint_curve_subgroup_gen() * z_nb -
-                newEncryptedVotes[i] * f_c_n.curveMult(c_n2_f) -
-                C_b * c_n;
-
-            oracleInput << U << C_a << C_b << C_c;
-
-            Scalar c = oracle(oracleInput.str());
-            Scalar c_r = c.curveSub(c_n);
-            currProof.challengeParts.push_back(c_r);
-            currProof.challengeParts.push_back(c_n);
-
-            Scalar z_r = u.curveAdd(c_r.curveMult(seeds[i]));
-            currProof.responseParts.push_back(z_r);
-            currProof.responseParts.push_back(f);
-            currProof.responseParts.push_back(z_na);
-            currProof.responseParts.push_back(z_nb);
-
-            retval.push_back(currProof);
-        }
-    }
-
-    return retval;
+    Proof pi = generate_ownership_proof();
+
+    return PrsonaBase::generate_vote_proof(
+        pi,
+        serverPublicKey.get_bipoint_curvegen(),
+        serverPublicKey.get_bipoint_curve_subgroup_gen(),
+        replaces,
+        oldEncryptedVotes,
+        newEncryptedVotes,
+        seeds,
+        votes);
 }

+ 14 - 16
prsona/src/main.cpp

@@ -10,14 +10,11 @@
 
 using namespace std;
 
-const int MAX_ALLOWED_VOTE = 2;
-
 // Initialize the classes we use
 void initialize_prsona_classes()
 {
     Scalar::init();
-    PrsonaServer::init();
-    PrsonaClient::init();
+    PrsonaBase::init();
 }
 
 // Quick and dirty mean calculation (used for averaging timings)
@@ -36,7 +33,8 @@ vector<double> make_votes(
     size_t numVotes)
 {
     vector<double> retval;
-    uniform_int_distribution<int> voteDistribution(0, MAX_ALLOWED_VOTE);
+    uniform_int_distribution<int> voteDistribution(
+        0, PrsonaBase::get_max_allowed_vote());
     size_t numUsers = users.size();
     newEncryptedVotes.clear();
 
@@ -274,7 +272,7 @@ int main(int argc, char *argv[])
     size_t numUsers = 5;
     size_t numRounds = 3;
     size_t numVotesPerRound = 3;
-    bool maliciousServers = false;
+    bool maliciousServers = true;
     bool maliciousClients = true;
     string seedStr = "seed";
 
@@ -297,26 +295,26 @@ int main(int argc, char *argv[])
 
     // Set malicious flags where necessary
     if (maliciousServers)
-    {
-        PrsonaServer::set_server_malicious();
-        PrsonaClient::set_server_malicious();
-    }
+        PrsonaBase::set_server_malicious();
     if (maliciousClients)
-    {
-        PrsonaServer::set_client_malicious();
-        PrsonaClient::set_client_malicious();
-    }
+        PrsonaBase::set_client_malicious();
 
     // Entities we operate with
     PrsonaServerEntity servers(numServers);
+    vector<Proof> elGamalBlindGeneratorProof;
     BGNPublicKey bgnPublicKey = servers.get_bgn_public_key();
-    Curvepoint elGamalBlindGenerator = servers.get_blinding_generator();
+    Curvepoint elGamalBlindGenerator =
+        servers.get_blinding_generator(elGamalBlindGeneratorProof);
 
     cout << "Initialization: adding users to system" << endl << endl;
     vector<PrsonaClient> users;
     for (size_t i = 0; i < numUsers; i++)
     {
-        PrsonaClient currUser(bgnPublicKey, elGamalBlindGenerator, &servers);
+        PrsonaClient currUser(
+            bgnPublicKey,
+            elGamalBlindGeneratorProof,
+            elGamalBlindGenerator,
+            &servers);
         servers.add_new_client(currUser);
         users.push_back(currUser);
     }

+ 7 - 0
prsona/src/proof.cpp

@@ -40,3 +40,10 @@ Scalar oracle(const std::string& input)
 
     return output;
 }
+
+Proof::Proof()
+{ /* Do nothing */ }
+
+Proof::Proof(std::string basic)
+: basic(basic)
+{ /* Do nothing */ }

+ 72 - 290
prsona/src/server.cpp

@@ -2,21 +2,6 @@
 
 #include "server.hpp"
 
-extern const curvepoint_fp_t bn_curvegen;
-extern const scalar_t bn_n;
-const int MAX_ALLOWED_VOTE = 2;
-
-/* These lines need to be here so these static variables are defined,
- * but in C++ putting code here doesn't actually execute
- * (or at least, with g++, whenever it would execute is not at a useful time)
- * so we have an init() function to actually put the correct values in them. */
-Curvepoint PrsonaServer::EL_GAMAL_GENERATOR = Curvepoint();
-Scalar PrsonaServer::SCALAR_N = Scalar();
-Scalar PrsonaServer::DEFAULT_TALLY = Scalar();
-Scalar PrsonaServer::DEFAULT_VOTE = Scalar();
-bool PrsonaServer::SERVER_IS_MALICIOUS = false;
-bool PrsonaServer::CLIENT_IS_MALICIOUS = false;
-
 /********************
  * PUBLIC FUNCTIONS *
  ********************/
@@ -26,58 +11,36 @@ bool PrsonaServer::CLIENT_IS_MALICIOUS = false;
  */
 
 // Used to generate the first server; instantiates BGN for the first time
-PrsonaServer::PrsonaServer()
+PrsonaServer::PrsonaServer(size_t numServers)
+: numServers(numServers)
 {
     currentSeed.set_random();
 }
 
 // Used for all other servers, so they have the same BGN parameters
-PrsonaServer::PrsonaServer(const BGN& other_bgn)
-: bgn_system(other_bgn)
+PrsonaServer::PrsonaServer(size_t numServers, const BGN& otherBgn)
+: numServers(numServers), bgnSystem(otherBgn)
 {
     currentSeed.set_random();
 }
 
 /*
- * SETUP FUNCTIONS
+ * BASIC PUBLIC SYSTEM INFO GETTERS
  */
 
-// Must be called once before any usage of this class
-void PrsonaServer::init()
+BGNPublicKey PrsonaServer::get_bgn_public_key() const
 {
-    Scalar lambda;
-    lambda.set_random();
-
-    EL_GAMAL_GENERATOR = Curvepoint(bn_curvegen);
-    SCALAR_N = Scalar(bn_n);
-    DEFAULT_TALLY = Scalar(1);
-    DEFAULT_VOTE = Scalar(1);
+    return bgnSystem.get_public_key();
 }
 
-// Call this (once) if using malicious-security servers
-void PrsonaServer::set_server_malicious()
+size_t PrsonaServer::get_num_clients() const
 {
-    SERVER_IS_MALICIOUS = true;
+    return currentPseudonyms.size();
 }
 
-// Call this (once) if using malicious-security clients
-void PrsonaServer::set_client_malicious()
+size_t PrsonaServer::get_num_servers() const
 {
-    CLIENT_IS_MALICIOUS = true;
-}
-
-/*
- * BASIC PUBLIC SYSTEM INFO GETTERS
- */
-
-Curvepoint PrsonaServer::get_blinding_generator() const
-{
-    return elGamalBlindGenerator;
-}
-
-BGNPublicKey PrsonaServer::get_bgn_public_key() const
-{
-    return bgn_system.get_public_key();
+    return numServers;
 }
 
 /*
@@ -87,16 +50,22 @@ BGNPublicKey PrsonaServer::get_bgn_public_key() const
 // To calculate the current epoch's generator, start from the base generator,
 // then have every server call this function on it iteratively (in any order).
 Curvepoint PrsonaServer::add_curr_seed_to_generator(
+    std::vector<Proof>& pi,
     const Curvepoint& currGenerator) const
 {
+    pi.push_back(add_to_generator_proof(currGenerator, currentSeed));
+
     return currGenerator * currentSeed;
 }
 
 // To calculate the next epoch's generator, start from the base generator,
 // then have every server call this function on it iteratively (in any order).
 Curvepoint PrsonaServer::add_next_seed_to_generator(
+    std::vector<Proof>& pi,
     const Curvepoint& currGenerator) const
 {
+    pi.push_back(add_to_generator_proof(currGenerator, nextSeed));
+
     return currGenerator * nextSeed;
 }
 
@@ -115,7 +84,7 @@ std::vector<CurveBipoint> PrsonaServer::get_current_votes_by(
     size_t voteSubmitter = binary_search(shortTermPublicKey);
     retval = voteMatrix[voteSubmitter];
 
-    pi = generate_votes_valid_proof(retval, shortTermPublicKey);
+    pi = generate_votes_valid_proof();
     return retval;
 }
 
@@ -134,6 +103,12 @@ EGCiphertext PrsonaServer::get_current_tally(
     return retval;
 }
 
+std::vector<Curvepoint> PrsonaServer::get_current_pseudonyms(Proof& pi) const
+{    
+    pi = generate_valid_pseudonyms_proof();
+    return currentPseudonyms;
+}
+
 /*
  * CLIENT INTERACTIONS
  */
@@ -146,7 +121,8 @@ void PrsonaServer::add_new_client(
     Proof& proofOfValidAddition,
     const Curvepoint& shortTermPublicKey)
 {
-    if (!verify_ownership_proof(proofOfValidKey, shortTermPublicKey))
+    if (!verify_ownership_proof(
+            proofOfValidKey, currentFreshGenerator, shortTermPublicKey))
     {
         std::cerr << "Could not verify proof of valid key." << std::endl;
         return;
@@ -157,7 +133,7 @@ void PrsonaServer::add_new_client(
     // The first epoch's score for a new user will be low,
     // but will typically converge on an average score quickly
     TwistBipoint encryptedDefaultTally;
-    bgn_system.encrypt(encryptedDefaultTally, DEFAULT_TALLY);
+    bgnSystem.encrypt(encryptedDefaultTally, DEFAULT_TALLY);
     previousVoteTallies.push_back(encryptedDefaultTally);
 
     Scalar mask;
@@ -168,20 +144,19 @@ void PrsonaServer::add_new_client(
         currentFreshGenerator * mask +
         get_blinding_generator() * DEFAULT_TALLY;
     currentUserEncryptedTallies.push_back(newUserEncryptedTally);
-    currentTallyProofs.push_back(
-        generate_valid_default_tally_proof(newUserEncryptedTally, mask));
+    currentTallyProofs.push_back(generate_valid_default_tally_proof());
 
     // Users are defaulted to casting a neutral vote for others.
     CurveBipoint encryptedDefaultVote, encryptedSelfVote;
-    bgn_system.encrypt(encryptedDefaultVote, DEFAULT_VOTE);
-    bgn_system.encrypt(encryptedSelfVote, Scalar(MAX_ALLOWED_VOTE));
+    bgnSystem.encrypt(encryptedDefaultVote, DEFAULT_VOTE);
+    bgnSystem.encrypt(encryptedSelfVote, Scalar(MAX_ALLOWED_VOTE));
     std::vector<CurveBipoint> newRow;
     for (size_t i = 0; i < voteMatrix.size(); i++)
     {
-        encryptedDefaultVote = bgn_system.rerandomize(encryptedDefaultVote);
+        encryptedDefaultVote = bgnSystem.rerandomize(encryptedDefaultVote);
         voteMatrix[i].push_back(encryptedDefaultVote);
 
-        encryptedDefaultVote = bgn_system.rerandomize(encryptedDefaultVote);
+        encryptedDefaultVote = bgnSystem.rerandomize(encryptedDefaultVote);
         newRow.push_back(encryptedDefaultVote);
     }
     // Because we are adding the new user to the end (and then sorting it),
@@ -191,7 +166,7 @@ void PrsonaServer::add_new_client(
 
     order_data(proofOfValidAddition);
 
-    proofOfValidAddition = generate_proof_of_added_user(shortTermPublicKey);
+    proofOfValidAddition = generate_proof_of_added_user();
 }
 
 // Receive a new vote row from a user (identified by short term public key).
@@ -220,28 +195,42 @@ bool PrsonaServer::receive_vote(
 
 const BGN& PrsonaServer::get_bgn_details() const
 {
-    return bgn_system;
+    return bgnSystem;
 }
 
-void PrsonaServer::initialize_fresh_generator(const Curvepoint& firstGenerator)
+bool PrsonaServer::initialize_fresh_generator(
+    const std::vector<Proof>& pi,
+    const Curvepoint& firstGenerator)
 {
+    if (!verify_generator_proof(pi, firstGenerator, numServers))
+    {
+        std::cerr << "Could not verify generator proof, aborting." << std::endl;
+        return false;
+    }
+
     currentFreshGenerator = firstGenerator;
+    return true;
 }
 
 // To calculate the blind generator for ElGamal, start from the base generator,
 // then have every server call this function on it iteratively (in any order).
 Curvepoint PrsonaServer::add_rand_seed_to_generator(
+    std::vector<Proof>& pi,
     const Curvepoint& currGenerator) const
 {
     Scalar lambda;
     lambda.set_random();
 
-    return currGenerator + EL_GAMAL_GENERATOR * lambda;
+    pi.push_back(add_to_generator_proof(currGenerator, lambda));
+
+    return currGenerator * lambda;
 }
 
-void PrsonaServer::set_EG_blind_generator(const Curvepoint& currGenerator)
+bool PrsonaServer::set_EG_blind_generator(
+    const std::vector<Proof>& pi,
+    const Curvepoint& currGenerator)
 {
-    elGamalBlindGenerator = currGenerator;
+    return PrsonaBase::set_EG_blind_generator(pi, currGenerator, numServers);
 }
 
 /*
@@ -265,7 +254,7 @@ std::vector<Scalar> PrsonaServer::tally_scores(std::vector<Proof>& tallyProofs)
         for (size_t j = 0; j < previousVoteTallies.size(); j++)
         {
             Quadripoint curr =
-                bgn_system.homomorphic_multiplication_no_rerandomize(
+                bgnSystem.homomorphic_multiplication_no_rerandomize(
                     voteMatrix[j][i], previousVoteTallies[j]);
 
             weightedVotes.push_back(curr);
@@ -276,15 +265,13 @@ std::vector<Scalar> PrsonaServer::tally_scores(std::vector<Proof>& tallyProofs)
         for (size_t j = 1; j < weightedVotes.size(); j++)
         {
             currEncryptedTally =
-                bgn_system.homomorphic_addition_no_rerandomize(
+                bgnSystem.homomorphic_addition_no_rerandomize(
                     currEncryptedTally, weightedVotes[j]);
         }
 
         // DECRYPT
-        decryptedTallies.push_back(bgn_system.decrypt(currEncryptedTally));
-        tallyProofs.push_back(
-            generate_proof_of_correct_tally(
-                currEncryptedTally, decryptedTallies[i]));
+        decryptedTallies.push_back(bgnSystem.decrypt(currEncryptedTally));
+        tallyProofs.push_back(generate_proof_of_correct_tally());
     }
 
     return decryptedTallies;
@@ -302,14 +289,14 @@ Scalar PrsonaServer::get_max_possible_score(Proof& pi)
     for (size_t i = 1; i < previousVoteTallies.size(); i++)
     {
         currEncryptedVal =
-            bgn_system.homomorphic_addition_no_rerandomize(
+            bgnSystem.homomorphic_addition_no_rerandomize(
                 currEncryptedVal, previousVoteTallies[i]);
     }
 
     // DECRYPT
-    Scalar retval = bgn_system.decrypt(currEncryptedVal);
+    Scalar retval = bgnSystem.decrypt(currEncryptedVal);
 
-    pi = generate_proof_of_correct_sum(currEncryptedVal, retval);
+    pi = generate_proof_of_correct_sum();
     return retval;
 }
 
@@ -407,9 +394,9 @@ void PrsonaServer::rerandomize_data()
     for (size_t i = 0; i < voteMatrix.size(); i++)
     {
         for (size_t j = 0; j < voteMatrix[0].size(); j++)
-            voteMatrix[i][j] = bgn_system.rerandomize(voteMatrix[i][j]);
+            voteMatrix[i][j] = bgnSystem.rerandomize(voteMatrix[i][j]);
 
-        bgn_system.rerandomize(previousVoteTallies[i]);
+        bgnSystem.rerandomize(previousVoteTallies[i]);
         if (!currentUserEncryptedTallies.empty())
         {
             Scalar rerandomizer;
@@ -486,7 +473,7 @@ std::vector<size_t> PrsonaServer::order_data(Proof& pi)
     currentTallyProofs = newTallyProofs;
     voteMatrix = newVoteMatrix;
 
-    pi = generate_proof_of_shuffle(retval);
+    pi = generate_proof_of_shuffle();
     return retval;
 }
 
@@ -517,227 +504,22 @@ size_t PrsonaServer::binary_search(const Curvepoint& index) const
 }
 
 /*
- * PROOF VERIFICATION
+ * VALID VOTE PROOFS
  */
 
-bool PrsonaServer::verify_ownership_proof(
-    const Proof& pi,
-    const Curvepoint& shortTermPublicKey) const
-{
-    if (!CLIENT_IS_MALICIOUS)
-        return pi.basic == "PROOF";
-
-    Scalar c = pi.challengeParts[0];
-    Scalar z = pi.responseParts[0];
-
-    Curvepoint u = currentFreshGenerator * z - shortTermPublicKey * c;
-
-    std::stringstream oracleInput;
-    oracleInput << currentFreshGenerator << shortTermPublicKey << u;
-    
-    return c == oracle(oracleInput.str());
-}
-
 bool PrsonaServer::verify_vote_proof(
     const std::vector<Proof>& pi,
     const std::vector<CurveBipoint>& oldVotes,
     const std::vector<CurveBipoint>& newVotes,
     const Curvepoint& shortTermPublicKey) const
 {
-    // Reject outright if there's no proof to check
-    if (pi.empty())
-    {
-        std::cerr << "Proof was empty, aborting." << std::endl;
-        return false;
-    }
-
-    // Base case
-    if (!CLIENT_IS_MALICIOUS)
-        return pi[0].basic == "PROOF";
-
-    // User should be able to prove they are who they say they are
-    if (!verify_ownership_proof(pi[0], shortTermPublicKey))
-    {
-        std::cerr << "Schnorr proof failed, aborting." << std::endl;
-        return false;
-    }
-
-    /* This proof structure is documented in my notes.
-     * It's inspired by the proof in Fig. 1 at
-     * https://eprint.iacr.org/2014/764.pdf, but adapted so that you prove
-     * m(m-1)(m-2) = 0 instead of m(m-1) = 0.
-     *
-     * The rerandomization part is just a slight variation on an
-     * ordinary Schnorr proof, so that part's less scary. */
-    for (size_t i = 1; i < pi.size(); i++)
-    {
-        size_t voteIndex = i - 1;
-        Curvepoint C_b_0, C_b_1;
-        C_b_0 = pi[i].partialUniversals[0];
-        C_b_1 = pi[i].partialUniversals[1];
-
-        CurveBipoint g, h;
-        g = bgn_system.get_public_key().get_bipoint_curvegen();
-        h = bgn_system.get_public_key().get_bipoint_curve_subgroup_gen();
-        CurveBipoint C_b(C_b_0, C_b_1);
-
-        Scalar c_r, c_n, z_r, f, z_na, z_nb;
-        c_r = pi[i].challengeParts[0];
-        c_n = pi[i].challengeParts[1];
-
-        z_r  = pi[i].responseParts[0];
-        f  = pi[i].responseParts[1];
-        z_na = pi[i].responseParts[2];
-        z_nb = pi[i].responseParts[3];
-
-        CurveBipoint U, C_a, C_c;
-        U = h * z_r + oldVotes[voteIndex] * c_r - newVotes[voteIndex] * c_r;
-        C_a = g * f + h * z_na - newVotes[voteIndex] * c_n;
-
-        Scalar f_c_n = f.curveSub(c_n);
-        Scalar c_n2_f = c_n.curveAdd(c_n).curveSub(f);
-        C_c = h * z_nb - newVotes[voteIndex] * f_c_n.curveMult(c_n2_f) - C_b * c_n;
-
-        std::stringstream oracleInput;
-        oracleInput << g << h << oldVotes[voteIndex] << newVotes[voteIndex]
-            << U << C_a << C_b << C_c;
-
-        if (oracle(oracleInput.str()) != c_r.curveAdd(c_n))
-            return false;
-    }
-
-    return true;
-}
-
-bool PrsonaServer::verify_update_proof(
-    const Proof& pi) const
-{
-    if (!SERVER_IS_MALICIOUS)
-        return pi.basic == "PROOF";
-
-    return pi.basic == "PROOF";
-}
-
-/*
- * PROOF GENERATION
- */
-
-Proof PrsonaServer::generate_valid_default_tally_proof(
-    const EGCiphertext& newUserEncryptedTally, const Scalar& mask) const
-{
-    Proof retval;
-
-    if (!SERVER_IS_MALICIOUS)
-    {
-        retval.basic = "PROOF";
-        return retval;
-    }
-
-    retval.basic = "PROOF";
-    return retval;
-}
-
-Proof PrsonaServer::generate_valid_fresh_generator_proof(
-    const Proof& oldProof) const
-{
-    Proof retval;
-
-    if (!SERVER_IS_MALICIOUS)
-    {
-        retval.basic = "PROOF";
-        return retval;
-    }
-
-    retval.basic = "PROOF";
-    return retval;
-}
-
-Proof PrsonaServer::generate_votes_valid_proof(
-    const std::vector<CurveBipoint>& votes, const Curvepoint& voter) const
-{
-    Proof retval;
-
-    if (!SERVER_IS_MALICIOUS)
-    {
-        retval.basic = "PROOF";
-        return retval;
-    }
-
-    retval.basic = "PROOF";
-    return retval;
-}
-
-Proof PrsonaServer::generate_proof_of_added_user(
-    const Curvepoint& shortTermPublicKey) const
-{
-    Proof retval;
-
-    if (!SERVER_IS_MALICIOUS)
-    {
-        retval.basic = "PROOF";
-        return retval;
-    }
-
-    retval.basic = "PROOF";
-    return retval;
-}
-
-Proof PrsonaServer::generate_score_proof(const EGCiphertext& score) const
-{
-    Proof retval;
-
-    if (!SERVER_IS_MALICIOUS)
-    {
-        retval.basic = "PROOF";
-        return retval;
-    }
-
-    retval.basic = "PROOF";
-    return retval;
-}
-
-Proof PrsonaServer::generate_proof_of_correct_tally(
-    const Quadripoint& BGNEncryptedTally,
-    const Scalar& decryptedTally) const
-{
-    Proof retval;
-
-    if (!SERVER_IS_MALICIOUS)
-    {
-        retval.basic = "PROOF";
-        return retval;
-    }
-
-    retval.basic = "PROOF";
-    return retval;
-}
-
-Proof PrsonaServer::generate_proof_of_correct_sum(
-    const TwistBipoint& BGNEncryptedSum, const Scalar& decryptedSum) const
-{
-    Proof retval;
-
-    if (!SERVER_IS_MALICIOUS)
-    {
-        retval.basic = "PROOF";
-        return retval;
-    }
-
-    retval.basic = "PROOF";
-    return retval;
-}
-
-Proof PrsonaServer::generate_proof_of_shuffle(
-    const std::vector<size_t>& shuffle_order) const
-{
-    Proof retval;
-
-    if (!SERVER_IS_MALICIOUS)
-    {
-        retval.basic = "PROOF";
-        return retval;
-    }
-
-    retval.basic = "PROOF";
-    return retval;
+    const BGNPublicKey& pubKey = bgnSystem.get_public_key();
+    return PrsonaBase::verify_vote_proof(
+        pubKey.get_bipoint_curvegen(),
+        pubKey.get_bipoint_curve_subgroup_gen(),
+        pi,
+        oldVotes,
+        newVotes,
+        currentFreshGenerator,
+        shortTermPublicKey);
 }

+ 191 - 65
prsona/src/serverEntity.cpp

@@ -2,8 +2,6 @@
 
 #include "serverEntity.hpp"
 
-const int MAX_ALLOWED_VOTE = 2;
-
 /********************
  * PUBLIC FUNCTIONS *
  ********************/
@@ -21,27 +19,33 @@ PrsonaServerEntity::PrsonaServerEntity(size_t numServers)
     }
 
     // Make the first server, which makes the BGN parameters
-    PrsonaServer firstServer;
+    PrsonaServer firstServer(numServers);
     servers.push_back(firstServer);
 
     // Make the rest of the servers, which take the BGN parameters
     const BGN& sharedBGN = firstServer.get_bgn_details();
     for (size_t i = 1; i < numServers; i++)
-        servers.push_back(PrsonaServer(sharedBGN));
+        servers.push_back(PrsonaServer(numServers, sharedBGN));
 
     // After all servers have made their seeds,
     // make sure they have the initial fresh generator
-    Curvepoint firstGenerator = get_fresh_generator();
+    std::vector<Proof> pi;
+    Curvepoint firstGenerator = get_fresh_generator(pi);
     for (size_t i = 0; i < numServers; i++)
-        servers[i].initialize_fresh_generator(firstGenerator);
+        servers[i].initialize_fresh_generator(pi, firstGenerator);
+
+    pi.clear();
 
     // It's important that no server knows the DLOG between g and h for ElGamal,
     // so have each server collaborate to make h.
     Curvepoint blindGenerator = PrsonaServer::EL_GAMAL_GENERATOR;
     for (size_t i = 0; i < numServers; i++)
-        blindGenerator = servers[i].add_rand_seed_to_generator(blindGenerator);
+    {
+        blindGenerator =
+            servers[i].add_rand_seed_to_generator(pi, blindGenerator);
+    }
     for (size_t i = 0; i < numServers; i++)
-        servers[i].set_EG_blind_generator(blindGenerator);
+        servers[i].set_EG_blind_generator(pi, blindGenerator);
 }
 
 /*
@@ -50,45 +54,131 @@ PrsonaServerEntity::PrsonaServerEntity(size_t numServers)
 
 BGNPublicKey PrsonaServerEntity::get_bgn_public_key() const
 {
-    return servers[0].get_bgn_public_key();
+    return get_bgn_public_key(0);
+}
+
+BGNPublicKey PrsonaServerEntity::get_bgn_public_key(size_t which) const
+{
+    return servers[which].get_bgn_public_key();
 }
 
 Curvepoint PrsonaServerEntity::get_blinding_generator() const
 {
-    return servers[0].get_blinding_generator();
+    return get_blinding_generator(0);
+}
+
+Curvepoint PrsonaServerEntity::get_blinding_generator(size_t which) const
+{
+    std::vector<Proof> pi;
+    Curvepoint retval = get_blinding_generator(pi, which);
+
+    if (!servers[which].verify_generator_proof(
+            pi, retval, servers[which].get_num_servers()))
+    {
+        std::cerr << "Error making the generator, aborting." << std::endl;
+        return Curvepoint();
+    }
+
+    return retval;
+}
+
+Curvepoint PrsonaServerEntity::get_blinding_generator(
+    std::vector<Proof>& pi) const
+{
+    return get_blinding_generator(pi, 0);
+}
+
+Curvepoint PrsonaServerEntity::get_blinding_generator(
+    std::vector<Proof>& pi, size_t which) const
+{
+    return servers[which].get_blinding_generator(pi);
 }
 
 Curvepoint PrsonaServerEntity::get_fresh_generator() const
+{
+    return get_fresh_generator(0);
+}
+
+Curvepoint PrsonaServerEntity::get_fresh_generator(size_t which) const
+{
+    std::vector<Proof> pi;
+    Curvepoint retval = get_fresh_generator(pi, which);
+
+    if (!servers[which].verify_generator_proof(
+            pi, retval, servers[which].get_num_servers()))
+    {
+        std::cerr << "Error making the generator, aborting." << std::endl;
+        return Curvepoint();
+    }
+
+    return retval;
+}
+
+Curvepoint PrsonaServerEntity::get_fresh_generator(
+    std::vector<Proof>& pi) const
+{
+    return get_fresh_generator(pi, 0);
+}
+
+Curvepoint PrsonaServerEntity::get_fresh_generator(
+    std::vector<Proof>& pi, size_t which) const
 {
     Curvepoint retval = PrsonaServer::EL_GAMAL_GENERATOR;
-    for (size_t j = 0; j < servers.size(); j++)
-        retval = servers[j].add_curr_seed_to_generator(retval);
+
+    pi.clear();   
+    for (size_t j = 0; j < servers[which].get_num_servers(); j++)
+    {
+        size_t index = (which + j) % servers[which].get_num_servers();
+        retval = servers[index].add_curr_seed_to_generator(pi, retval);
+    }
 
     return retval;
 }
 
 size_t PrsonaServerEntity::get_num_clients() const
 {
-    return servers[0].currentPseudonyms.size();
+    return get_num_clients(0);
+}
+
+size_t PrsonaServerEntity::get_num_clients(size_t which) const
+{
+    return servers[which].get_num_clients();
 }
 
 size_t PrsonaServerEntity::get_num_servers() const
 {
-    return servers.size();
+    return get_num_servers(0);
+}
+
+size_t PrsonaServerEntity::get_num_servers(size_t which) const
+{
+    return servers[which].get_num_servers();
 }
 
 /*
  * ENCRYPTED DATA GETTERS
  */
 
+std::vector<CurveBipoint> PrsonaServerEntity::get_current_votes_by(
+    Proof& pi, const Curvepoint& shortTermPublicKey) const
+{
+    return get_current_votes_by(pi, shortTermPublicKey, 0);
+}
+
 /* Call this in order to get the current encrypted votes cast by a given user
  * (who is identified by their short term public key).
  * In practice, this is intended for clients,
  * who need to know their current votes in order to rerandomize them. */
 std::vector<CurveBipoint> PrsonaServerEntity::get_current_votes_by(
+    Proof& pi, const Curvepoint& shortTermPublicKey, size_t which) const
+{
+    return servers[which].get_current_votes_by(pi, shortTermPublicKey);
+}
+
+EGCiphertext PrsonaServerEntity::get_current_tally(
     Proof& pi, const Curvepoint& shortTermPublicKey) const
 {
-    return servers[0].get_current_votes_by(pi, shortTermPublicKey);
+    return get_current_tally(pi, shortTermPublicKey, 0);
 }
 
 /* Call this in order to get the current encrypted tally of a given user
@@ -96,31 +186,38 @@ std::vector<CurveBipoint> PrsonaServerEntity::get_current_votes_by(
  * In practice, this is intended for clients, so that the servers vouch
  * for their ciphertexts being valid as part of their reputation proofs. */
 EGCiphertext PrsonaServerEntity::get_current_tally(
-    Proof& pi, const Curvepoint& shortTermPublicKey) const
+    Proof& pi, const Curvepoint& shortTermPublicKey, size_t which) const
 {
-    return servers[0].get_current_tally(pi, shortTermPublicKey);
+    return servers[which].get_current_tally(pi, shortTermPublicKey);
 }
 
 /*
  * CLIENT INTERACTIONS
  */
 
+void PrsonaServerEntity::add_new_client(PrsonaClient& newUser)
+{
+    add_new_client(newUser, 0);
+}
+
 /* Add a new client (who is identified only by their short term public key)
  * One server does the main work, then other servers import their (proven)
  * exported data. */
-void PrsonaServerEntity::add_new_client(PrsonaClient& newUser)
+void PrsonaServerEntity::add_new_client(PrsonaClient& newUser, size_t which)
 {
     Proof proofOfValidSTPK, proofOfCorrectAddition, proofOfValidVotes;
-    Curvepoint freshGenerator = get_fresh_generator();
+    std::vector<Proof> proofOfValidGenerator;
+    Curvepoint freshGenerator =
+        get_fresh_generator(proofOfValidGenerator, which);
 
     // Users can't actually announce a short term public key
     // if they don't know the fresh generator.
-    newUser.receive_fresh_generator(freshGenerator);
+    newUser.receive_fresh_generator(proofOfValidGenerator, freshGenerator);
     Curvepoint shortTermPublicKey = newUser.get_short_term_public_key(
                                         proofOfValidSTPK);
 
     // Do the actual work of adding the client to the first server
-    servers[0].add_new_client(
+    servers[which].add_new_client(
         proofOfValidSTPK, proofOfCorrectAddition, shortTermPublicKey);
 
     // Then, export the data to the rest of the servers
@@ -129,15 +226,16 @@ void PrsonaServerEntity::add_new_client(PrsonaClient& newUser)
     std::vector<EGCiphertext> currentUserEncryptedTallies;
     std::vector<Proof> currentTallyProofs;
     std::vector<std::vector<CurveBipoint>> voteMatrix;
-    servers[0].export_updates(
+    servers[which].export_updates(
         previousVoteTally,
         currentPseudonyms,
         currentUserEncryptedTallies,
         currentTallyProofs,
         voteMatrix);
-    for (size_t j = 1; j < servers.size(); j++)
+    for (size_t j = 1; j < servers[which].get_num_servers(); j++)
     {
-        servers[j].import_updates(
+        size_t index = (which + j) % servers[which].get_num_servers();
+        servers[index].import_updates(
             proofOfCorrectAddition,
             previousVoteTally,
             currentPseudonyms,
@@ -148,7 +246,7 @@ void PrsonaServerEntity::add_new_client(PrsonaClient& newUser)
 
     // Finally, give the user the information it needs
     // about its current tally and votes
-    transmit_updates(newUser);
+    transmit_updates(newUser, which);
 }
 
 // Receive a new vote row from a user (identified by short term public key).
@@ -156,41 +254,52 @@ bool PrsonaServerEntity::receive_vote(
     const std::vector<Proof>& pi,
     const std::vector<CurveBipoint>& newVotes,
     const Curvepoint& shortTermPublicKey)
+{
+    return receive_vote(pi, newVotes, shortTermPublicKey, 0);
+}
+
+bool PrsonaServerEntity::receive_vote(
+    const std::vector<Proof>& pi,
+    const std::vector<CurveBipoint>& newVotes,
+    const Curvepoint& shortTermPublicKey,
+    size_t which)
 {
     bool retval = true;
 
-    for (size_t i = 0; i < servers.size(); i++)
+    for (size_t i = 0; i < servers[which].get_num_servers(); i++)
     {
-        retval =
-            retval && servers[i].receive_vote(pi, newVotes, shortTermPublicKey);
+        size_t index = (i + which) % servers[which].get_num_servers();
+
+        retval = retval &&
+            servers[index].receive_vote(pi, newVotes, shortTermPublicKey);
     }
 
     return retval;
 }
 
-bool PrsonaServerEntity::receive_vote(
-    const std::vector<Proof>& pi,
-    const std::vector<CurveBipoint>& newVotes,
-    const Curvepoint& shortTermPublicKey,
-    size_t which)
+void PrsonaServerEntity::transmit_updates(PrsonaClient& currUser) const
 {
-    return servers[which].receive_vote(pi, newVotes, shortTermPublicKey);
+    transmit_updates(currUser, 0);
 }
 
 // After tallying scores and new vote matrix,
 // give those to a user for the new epoch
-void PrsonaServerEntity::transmit_updates(PrsonaClient& currUser) const
+void PrsonaServerEntity::transmit_updates(
+    PrsonaClient& currUser, size_t which) const
 {
     Proof proofOfValidSTPK, proofOfScore, proofOfCorrectVotes;
-    Curvepoint freshGenerator = get_fresh_generator();
+    std::vector<Proof> proofOfValidGenerator;
+    Curvepoint freshGenerator =
+        get_fresh_generator(proofOfValidGenerator, which);
 
     // Get users the next fresh generator so they can correctly
     // ask for their new scores and vote row
-    currUser.receive_fresh_generator(freshGenerator);
-    Curvepoint shortTermPublicKey = currUser.get_short_term_public_key(
-                                        proofOfValidSTPK);
+    currUser.receive_fresh_generator(proofOfValidGenerator, freshGenerator);
+    Curvepoint shortTermPublicKey =
+        currUser.get_short_term_public_key(proofOfValidSTPK);
 
-    EGCiphertext score = get_current_tally(proofOfScore, shortTermPublicKey);
+    EGCiphertext score =
+        get_current_tally(proofOfScore, shortTermPublicKey, which);
     currUser.receive_vote_tally(proofOfScore, score);
 }
 
@@ -198,8 +307,13 @@ void PrsonaServerEntity::transmit_updates(PrsonaClient& currUser) const
  * EPOCH
  */
 
-// Do the epoch process
 void PrsonaServerEntity::epoch(Proof& pi)
+{
+    epoch(pi, 0);
+}
+
+// Do the epoch process
+void PrsonaServerEntity::epoch(Proof& pi, size_t which)
 {
     Curvepoint nextGenerator = PrsonaServer::EL_GAMAL_GENERATOR;
     
@@ -210,16 +324,19 @@ void PrsonaServerEntity::epoch(Proof& pi)
     std::vector<std::vector<CurveBipoint>> voteMatrix;
 
     // go from A_0 to A_0.5
-    for (size_t i = 0; i < servers.size(); i++)
+    for (size_t i = 0; i < servers[which].get_num_servers(); i++)
     {
-        servers[i].build_up_midway_pseudonyms(pi, nextGenerator);
-        servers[i].export_updates(
+        size_t index = (which + i) % servers[which].get_num_servers();
+        size_t nextIndex = (which + i + 1) % servers[which].get_num_servers();
+
+        servers[index].build_up_midway_pseudonyms(pi, nextGenerator);
+        servers[index].export_updates(
             previousVoteTally,
             currentPseudonyms,
             currentUserEncryptedTallies,
             currentTallyProofs,
             voteMatrix);
-        servers[(i + 1) % servers.size()].import_updates(
+        servers[nextIndex].import_updates(
             pi,
             previousVoteTally,
             currentPseudonyms,
@@ -232,19 +349,23 @@ void PrsonaServerEntity::epoch(Proof& pi)
      * knows a secret mask and encrypted the correct value everyone else already
      * knows. Everyone else then adds a mask and proves they added a secret mask
      * to the committed values. */
-    currentUserEncryptedTallies = tally_scores(currentTallyProofs, nextGenerator);
+    currentUserEncryptedTallies =
+        tally_scores(currentTallyProofs, nextGenerator, which);
     
     // go from A_0.5 to A_1
-    for (size_t i = 0; i < servers.size(); i++)
+    for (size_t i = 0; i < servers[which].get_num_servers(); i++)
     {
-        servers[i].break_down_midway_pseudonyms(pi, nextGenerator);
-        servers[i].export_updates(
+        size_t index = (which + i) % servers[which].get_num_servers();
+        size_t nextIndex = (which + i + 1) % servers[which].get_num_servers();
+        
+        servers[index].break_down_midway_pseudonyms(pi, nextGenerator);
+        servers[index].export_updates(
             previousVoteTally,
             currentPseudonyms,
             currentUserEncryptedTallies,
             currentTallyProofs,
             voteMatrix);
-        servers[(i + 1) % servers.size()].import_updates(
+        servers[nextIndex].import_updates(
             pi,
             previousVoteTally,
             currentPseudonyms,
@@ -254,9 +375,11 @@ void PrsonaServerEntity::epoch(Proof& pi)
     }
 
     // At the end, make sure all servers have same information
-    for (size_t i = 1; i < servers.size() - 1; i++)
+    for (size_t i = 1; i < servers[which].get_num_servers() - 1; i++)
     {
-        servers[i].import_updates(
+        size_t index = (which + i) % servers[which].get_num_servers();
+
+        servers[index].import_updates(
             pi,
             previousVoteTally,
             currentPseudonyms,
@@ -280,16 +403,19 @@ void PrsonaServerEntity::epoch(Proof& pi)
  * We're treating it as if we are one server, so that server gets the updated
  * weights to be sent to all other servers for the next epoch. */
 std::vector<EGCiphertext> PrsonaServerEntity::tally_scores(
-    std::vector<Proof>& tallyProofs, const Curvepoint& nextGenerator)
+    std::vector<Proof>& tallyProofs,
+    const Curvepoint& nextGenerator,
+    size_t which)
 {
     std::vector<EGCiphertext> retval;
     Proof maxScoreProof;
-    std::vector<Scalar> decryptedTalliedScores = servers[0].tally_scores(
+    std::vector<Scalar> decryptedTalliedScores = servers[which].tally_scores(
                                                     tallyProofs);
     mpz_class maxScorePossibleThisRound =
-        servers[0].get_max_possible_score(maxScoreProof).toInt() *
-        MAX_ALLOWED_VOTE;
-    mpz_class topOfScoreRange = decryptedTalliedScores.size() * MAX_ALLOWED_VOTE;
+        servers[which].get_max_possible_score(maxScoreProof).toInt() *
+        PrsonaBase::get_max_allowed_vote();
+    mpz_class topOfScoreRange =
+        decryptedTalliedScores.size() * PrsonaBase::get_max_allowed_vote();
 
     for (size_t i = 0; i < decryptedTalliedScores.size(); i++)
     {
@@ -306,17 +432,17 @@ std::vector<EGCiphertext> PrsonaServerEntity::tally_scores(
 
         // Give the server the new weights,
         // to get passed around to the other servers
-        servers[0].bgn_system.encrypt(
-            servers[0].previousVoteTallies[i], decryptedTalliedScores[i]);
+        servers[which].bgnSystem.encrypt(
+            servers[which].previousVoteTallies[i], decryptedTalliedScores[i]);
         
-        retval[i].mask = servers[0].currentPseudonyms[i] * currMask;
+        retval[i].mask = servers[which].currentPseudonyms[i] * currMask;
         retval[i].encryptedMessage =
             (nextGenerator * currMask) +
-            (servers[0].get_blinding_generator() * decryptedTalliedScores[i]);
+            (servers[which].get_blinding_generator() * decryptedTalliedScores[i]);
     }
 
-    servers[0].currentUserEncryptedTallies = retval;
-    servers[0].currentTallyProofs = tallyProofs;
+    servers[which].currentUserEncryptedTallies = retval;
+    servers[which].currentTallyProofs = tallyProofs;
     return retval;
 }
 
@@ -326,7 +452,7 @@ std::vector<EGCiphertext> PrsonaServerEntity::tally_scores(
 
 // Completely normal binary search
 size_t PrsonaServerEntity::binary_search(
-    const Curvepoint& shortTermPublicKey) const
+    const Curvepoint& shortTermPublicKey, size_t which) const
 {
-    return servers[0].binary_search(shortTermPublicKey);
+    return servers[which].binary_search(shortTermPublicKey);
 }