Browse Source

Added proofs of valid votes by users, as well as more timing data to be printed when the program is run

tristangurtler 3 years ago
parent
commit
d6215063bb
7 changed files with 612 additions and 219 deletions
  1. 16 9
      prsona/inc/client.hpp
  2. 17 11
      prsona/inc/server.hpp
  3. 9 3
      prsona/inc/serverEntity.hpp
  4. 183 28
      prsona/src/client.cpp
  5. 243 142
      prsona/src/main.cpp
  6. 110 21
      prsona/src/server.cpp
  7. 34 5
      prsona/src/serverEntity.cpp

+ 16 - 9
prsona/inc/client.hpp

@@ -20,10 +20,11 @@ class PrsonaClient {
         // CONSTRUCTORS
         PrsonaClient(
             const BGNPublicKey& serverPublicKey,
+            const Curvepoint& elGamalBlindGenerator,
             const PrsonaServerEntity *servers);
 
         // SETUP FUNCTIONS
-        static void init(const Curvepoint& elGamalBlindGenerator);
+        static void init();
         static void set_server_malicious();
         static void set_client_malicious();
 
@@ -32,10 +33,12 @@ class PrsonaClient {
 
         // SERVER INTERACTIONS
         std::vector<CurveBipoint> make_votes(
-            Proof& pi,
-            const std::vector<CurveBipoint>& currentEncryptedVotes,
-            const std::vector<Scalar>& vote,
-            const std::vector<bool>& replace);
+            std::vector<Proof>& validVoteProof,
+            const Proof& serverProof,
+            const std::vector<CurveBipoint>& oldEncryptedVotes,
+            const std::vector<Scalar>& votes,
+            const std::vector<bool>& replaces
+        ) const;
         void receive_fresh_generator(const Curvepoint& freshGenerator);
         void receive_vote_tally(const Proof& pi, const EGCiphertext& score);
         
@@ -52,12 +55,12 @@ class PrsonaClient {
     private:
         // Constants for clients
         static Curvepoint EL_GAMAL_GENERATOR;
-        static Curvepoint EL_GAMAL_BLIND_GENERATOR;
         static bool SERVER_IS_MALICIOUS;
         static bool CLIENT_IS_MALICIOUS;
         
         // Things bound to the servers permanently
         const BGNPublicKey serverPublicKey;
+        const Curvepoint elGamalBlindGenerator;
         const PrsonaServerEntity *servers;
 
         // Things bound to the servers (but change regularly)
@@ -98,9 +101,13 @@ class PrsonaClient {
             const Proof& pi, const std::vector<CurveBipoint>& votes) const;
 
         // PROOF GENERATION
-        Proof generate_vote_proof(
-            const std::vector<CurveBipoint>& encryptedVotes,
-            const std::vector<Scalar>& vote) const;
+        std::vector<Proof> generate_vote_proof(
+            const std::vector<bool>& replaces,
+            const std::vector<CurveBipoint>& oldEncryptedVotes,
+            const std::vector<CurveBipoint>& newEncryptedVotes,
+            const std::vector<Scalar>& seeds,
+            const std::vector<Scalar>& votes
+        ) const;
 }; 
 
 #endif

+ 17 - 11
prsona/inc/server.hpp

@@ -16,12 +16,12 @@ class PrsonaServer {
         PrsonaServer(const BGN& other_bgn);
 
         // SETUP FUNCTIONS
-        static Curvepoint init();
+        static void init();
         static void set_server_malicious();
         static void set_client_malicious();
 
         // BASIC PUBLIC SYSTEM INFO GETTERS
-        static Curvepoint get_blinding_generator();
+        Curvepoint get_blinding_generator() const;
         BGNPublicKey get_bgn_public_key() const;
         
         // FRESH GENERATOR CALCULATION
@@ -41,23 +41,23 @@ class PrsonaServer {
             const Proof& proofOfValidKey,
             Proof& proofOfValidAddition,
             const Curvepoint& shortTermPublicKey);
-        void receive_vote(
-            const Proof& pi,
-            const std::vector<CurveBipoint>& votes,
+        bool receive_vote(
+            const std::vector<Proof>& pi,
+            const std::vector<CurveBipoint>& newVotes,
             const Curvepoint& shortTermPublicKey);
 
     private:
         // Constants for servers
         static Curvepoint EL_GAMAL_GENERATOR;
-        static Curvepoint EL_GAMAL_BLIND_GENERATOR;
         static Scalar SCALAR_N;
         static Scalar DEFAULT_TALLY;
         static Scalar DEFAULT_VOTE;
         static bool SERVER_IS_MALICIOUS;
         static bool CLIENT_IS_MALICIOUS;
 
-        // Identical between all servers
+        // Identical between all servers (but collaboratively constructed)
         BGN bgn_system;
+        Curvepoint elGamalBlindGenerator;
 
         // Private; different for each server
         Scalar currentSeed;
@@ -85,6 +85,9 @@ class PrsonaServer {
         // CONSTRUCTOR HELPERS
         const BGN& get_bgn_details() const;
         void initialize_fresh_generator(const Curvepoint& firstGenerator);
+        Curvepoint add_rand_seed_to_generator(
+            const Curvepoint& currGenerator) const;
+        void set_EG_blind_generator(const Curvepoint& currGenerator);
         
         // SCORE TALLYING
         std::vector<Scalar> tally_scores(std::vector<Proof>& tallyProofs);
@@ -129,16 +132,19 @@ class PrsonaServer {
         // BINARY SEARCH
         size_t binary_search(const Curvepoint& index) const;
 
-        // PROOF VERIFICATION
-        bool verify_valid_key_proof(
+        // CLIENT PROOF VERIFICATION
+        bool verify_ownership_proof(
             const Proof& pi,
             const Curvepoint& shortTermPublicKey
         ) const;
         bool verify_vote_proof(
-            const Proof& pi,
-            const std::vector<CurveBipoint>& votes,
+            const std::vector<Proof>& pi,
+            const std::vector<CurveBipoint>& oldVotes,
+            const std::vector<CurveBipoint>& newVotes,
             const Curvepoint& shortTermPublicKey
         ) const;
+
+        // SERVER PROOF VERIFICATION
         bool verify_update_proof(
             const Proof& pi
         ) const;

+ 9 - 3
prsona/inc/serverEntity.hpp

@@ -17,6 +17,7 @@ class PrsonaServerEntity {
         Curvepoint get_blinding_generator() const;
         Curvepoint get_fresh_generator() const;
         size_t get_num_clients() const;
+        size_t get_num_servers() const;
 
         // ENCRYPTED DATA GETTERS
         std::vector<CurveBipoint> get_current_votes_by(
@@ -26,10 +27,15 @@ class PrsonaServerEntity {
 
         // CLIENT INTERACTIONS
         void add_new_client(PrsonaClient& newUser);
-        void receive_vote(
-            const Proof& pi,
-            const std::vector<CurveBipoint>& votes,
+        bool receive_vote(
+            const std::vector<Proof>& pi,
+            const std::vector<CurveBipoint>& newVotes,
             const Curvepoint& shortTermPublicKey);
+        bool receive_vote(
+            const std::vector<Proof>& pi,
+            const std::vector<CurveBipoint>& newVotes,
+            const Curvepoint& shortTermPublicKey,
+            size_t which);
         void transmit_updates(PrsonaClient& currUser) const;
 
         // EPOCH

+ 183 - 28
prsona/src/client.cpp

@@ -11,7 +11,6 @@ const int MAX_ALLOWED_VOTE = 2;
  * (or at least, with g++, whenever it would execute is not at a useful time)
  * so we have an init() function to actually put the correct values in them. */
 Curvepoint PrsonaClient::EL_GAMAL_GENERATOR = Curvepoint();
-Curvepoint PrsonaClient::EL_GAMAL_BLIND_GENERATOR = Curvepoint();
 bool PrsonaClient::SERVER_IS_MALICIOUS = false;
 bool PrsonaClient::CLIENT_IS_MALICIOUS = false;
 
@@ -43,15 +42,17 @@ mpz_class bit(mpz_class x)
 
 PrsonaClient::PrsonaClient(
     const BGNPublicKey& serverPublicKey,
+    const Curvepoint& elGamalBlindGenerator,
     const PrsonaServerEntity* servers)
     : serverPublicKey(serverPublicKey),
+        elGamalBlindGenerator(elGamalBlindGenerator),
         servers(servers),
         max_checked(0)
 {
     longTermPrivateKey.set_random();
     inversePrivateKey = longTermPrivateKey.curveInverse();
 
-    decryption_memoizer[EL_GAMAL_BLIND_GENERATOR * max_checked] = max_checked;
+    decryption_memoizer[elGamalBlindGenerator * max_checked] = max_checked;
 }
 
 /*
@@ -59,10 +60,9 @@ PrsonaClient::PrsonaClient(
  */
 
 // Must be called once before any usage of this class
-void PrsonaClient::init(const Curvepoint& elGamalBlindGenerator)
+void PrsonaClient::init()
 {
     EL_GAMAL_GENERATOR = Curvepoint(bn_curvegen);
-    EL_GAMAL_BLIND_GENERATOR = elGamalBlindGenerator;
 }
 
 void PrsonaClient::set_server_malicious()
@@ -95,32 +95,37 @@ Curvepoint PrsonaClient::get_short_term_public_key(Proof &pi) const
  * You may really want to make currentEncryptedVotes a member variable, 
  * but it doesn't behave correctly when adding new clients after this one. */
 std::vector<CurveBipoint> PrsonaClient::make_votes(
-    Proof& pi,
-    const std::vector<CurveBipoint>& currentEncryptedVotes,
+    std::vector<Proof>& validVoteProof,
+    const Proof& serverProof,
+    const std::vector<CurveBipoint>& oldEncryptedVotes,
     const std::vector<Scalar>& votes,
-    const std::vector<bool>& replaces)
+    const std::vector<bool>& replaces) const
 {
-    std::vector<CurveBipoint> retval;
+    std::vector<Scalar> seeds(oldEncryptedVotes.size());
+    std::vector<CurveBipoint> newEncryptedVotes(oldEncryptedVotes.size());
 
-    if (!verify_valid_votes_proof(pi, currentEncryptedVotes))
+    if (!verify_valid_votes_proof(serverProof, oldEncryptedVotes))
     {
         std::cerr << "Could not verify proof of valid votes." << std::endl;
-        return retval;
+        return newEncryptedVotes;
     }
 
     for (size_t i = 0; i < votes.size(); i++)
     {
-        CurveBipoint currScore;
         if (replaces[i])
-            serverPublicKey.encrypt(currScore, votes[i]);
+        {
+            newEncryptedVotes[i] = serverPublicKey.encrypt(seeds[i], votes[i]);
+        }
         else
-            currScore = serverPublicKey.rerandomize(currentEncryptedVotes[i]);
-
-        retval.push_back(currScore);
+        {
+            newEncryptedVotes[i] =
+                serverPublicKey.rerandomize(seeds[i], oldEncryptedVotes[i]);
+        }
     }
 
-    pi = generate_vote_proof(retval, votes);
-    return retval;
+    validVoteProof = generate_vote_proof(
+        replaces, oldEncryptedVotes, newEncryptedVotes, seeds, votes);
+    return newEncryptedVotes;
 }
 
 // Get a new fresh generator (happens at initialization and during each epoch)
@@ -202,7 +207,7 @@ std::vector<Proof> PrsonaClient::generate_reputation_proof(
         Proof currProof;
         Curvepoint g, h, c, c_a, c_b;
         g = currentEncryptedScore.mask;
-        h = EL_GAMAL_BLIND_GENERATOR;
+        h = elGamalBlindGenerator;
     
         mpz_class currBit = bit(proofVal & (1 << i));
         Scalar a, s, t, m, r;
@@ -292,7 +297,7 @@ bool PrsonaClient::verify_reputation_proof(
         Curvepoint c, g, h;
         c = pi[i].partialUniversals[0];
         g = encryptedScore.mask;
-        h = EL_GAMAL_BLIND_GENERATOR;
+        h = elGamalBlindGenerator;
 
         X = X + c * Scalar(1 << (i - 1));
 
@@ -323,7 +328,7 @@ bool PrsonaClient::verify_reputation_proof(
 
     Curvepoint scoreCommitment =
         encryptedScore.encryptedMessage +
-        EL_GAMAL_BLIND_GENERATOR * negThreshold;
+        elGamalBlindGenerator * negThreshold;
     
     return X == scoreCommitment;
 }
@@ -360,12 +365,12 @@ void PrsonaClient::decrypt_score(const EGCiphertext& score)
 
     // If not, iterate until we find it (adding everything to the memoization)
     max_checked++;
-    Curvepoint decryptionCandidate = EL_GAMAL_BLIND_GENERATOR * max_checked;
+    Curvepoint decryptionCandidate = elGamalBlindGenerator * max_checked;
     while (decryptionCandidate != hashedDecrypted)
     {
         decryption_memoizer[decryptionCandidate] = max_checked;
 
-        decryptionCandidate = decryptionCandidate + EL_GAMAL_BLIND_GENERATOR;
+        decryptionCandidate = decryptionCandidate + elGamalBlindGenerator;
         max_checked++;
     }
     decryption_memoizer[decryptionCandidate] = max_checked;
@@ -486,18 +491,168 @@ bool PrsonaClient::verify_valid_votes_proof(
  * PROOF GENERATION
  */
 
-Proof PrsonaClient::generate_vote_proof(
-    const std::vector<CurveBipoint>& encryptedVotes,
-    const std::vector<Scalar>& vote) const
+std::vector<Proof> PrsonaClient::generate_vote_proof(
+    const std::vector<bool>& replaces,
+    const std::vector<CurveBipoint>& oldEncryptedVotes,
+    const std::vector<CurveBipoint>& newEncryptedVotes,
+    const std::vector<Scalar>& seeds,
+    const std::vector<Scalar>& votes) const
 {
-    Proof retval;
+    std::vector<Proof> retval;
 
+    // Base case
     if (!CLIENT_IS_MALICIOUS)
     {
-        retval.basic = "PROOF";
+        Proof currProof;
+        currProof.basic = "PROOF";
+        retval.push_back(currProof);
+        
         return retval;
     }
 
-    retval.basic = "PROOF";
+    // The first need is to prove that we are the stpk we claim we are
+    retval.push_back(generate_ownership_proof());
+
+    // Then, we iterate over all votes for the proofs that they are correct
+    for (size_t i = 0; i < replaces.size(); i++)
+    {
+        std::stringstream oracleInput;
+        oracleInput << serverPublicKey.get_bipoint_curvegen()
+            << serverPublicKey.get_bipoint_curve_subgroup_gen()
+            << oldEncryptedVotes[i] << newEncryptedVotes[i];
+        
+        /* This proof structure is documented in my notes.
+         * It's inspired by the proof in Fig. 1 at
+         * https://eprint.iacr.org/2014/764.pdf, but adapted so that you prove
+         * m(m-1)(m-2) = 0 instead of m(m-1) = 0.
+         *
+         * The rerandomization part is just a slight variation on an
+         * ordinary Schnorr proof, so that part's less scary. */
+        if (replaces[i])     // CASE: Make new vote
+        {
+            Proof currProof;
+
+            Scalar c_r, z_r, a, b, s_1, s_2, t_1, t_2;
+            c_r.set_random();
+            z_r.set_random();
+            a.set_random();
+            b.set_random();
+            s_1.set_random();
+            s_2.set_random();
+            t_1.set_random();
+            t_2.set_random();
+
+            CurveBipoint U =
+                serverPublicKey.get_bipoint_curve_subgroup_gen() * z_r +
+                oldEncryptedVotes[i] * c_r -
+                newEncryptedVotes[i] * c_r;
+
+            CurveBipoint C_a = serverPublicKey.get_bipoint_curvegen() * a +
+                serverPublicKey.get_bipoint_curve_subgroup_gen() * s_1;
+
+            CurveBipoint C_b = serverPublicKey.get_bipoint_curvegen() * b +
+                serverPublicKey.get_bipoint_curve_subgroup_gen() * s_2;
+
+            Scalar power = (a.curveAdd(b)).curveMult(votes[i].curveMult(votes[i]));
+            power =
+                power.curveSub((a.curveAdd(a).curveAdd(b)).curveMult(votes[i]));
+            CurveBipoint C_c = serverPublicKey.get_bipoint_curvegen() * power +
+                serverPublicKey.get_bipoint_curve_subgroup_gen() * t_1;
+            currProof.partialUniversals.push_back(C_c[0]);
+            currProof.partialUniversals.push_back(C_c[1]);
+
+            CurveBipoint C_d =
+                serverPublicKey.get_bipoint_curvegen() *
+                    a.curveMult(b.curveMult(votes[i])) +
+                serverPublicKey.get_bipoint_curve_subgroup_gen() * t_2;
+
+            oracleInput << U << C_a << C_b << C_c << C_d;
+
+            Scalar c = oracle(oracleInput.str());
+            Scalar c_n = c.curveSub(c_r);
+            currProof.challengeParts.push_back(c_r);
+            currProof.challengeParts.push_back(c_n);
+
+            Scalar f_1 = (votes[i].curveMult(c_n)).curveAdd(a);
+            Scalar f_2 = (votes[i].curveMult(c_n)).curveAdd(b);
+            Scalar z_na = (seeds[i].curveMult(c_n)).curveAdd(s_1);
+            Scalar z_nb = (seeds[i].curveMult(c_n)).curveAdd(s_2);
+
+            Scalar t_1_c_n_t_2 = (t_1.curveMult(c_n)).curveAdd(t_2);
+            Scalar f_1_c_n = f_1.curveSub(c_n);
+            Scalar c_n_f_2 = c_n.curveAdd(c_n).curveSub(f_2);
+            Scalar z_nc = 
+                (seeds[i].curveMult(f_1_c_n).curveMult(c_n_f_2)).curveAdd(
+                    t_1_c_n_t_2);
+
+            currProof.responseParts.push_back(z_r);
+            currProof.responseParts.push_back(f_1);
+            currProof.responseParts.push_back(f_2);
+            currProof.responseParts.push_back(z_na);
+            currProof.responseParts.push_back(z_nb);
+            currProof.responseParts.push_back(z_nc);
+
+            retval.push_back(currProof);
+        }
+        else                // CASE: Rerandomize existing vote
+        {
+            Proof currProof;
+
+            Scalar u, commitmentLambda_1, commitmentLambda_2,
+                c_n, z_na, z_nb, z_nc, f_1, f_2;
+            u.set_random();
+            commitmentLambda_1.set_random();
+            commitmentLambda_2.set_random();
+            c_n.set_random();
+            z_na.set_random();
+            z_nb.set_random();
+            z_nc.set_random();
+            f_1.set_random();
+            f_2.set_random();
+
+            CurveBipoint U =
+                serverPublicKey.get_bipoint_curve_subgroup_gen() * u;
+
+            CurveBipoint C_a = serverPublicKey.get_bipoint_curvegen() * f_1 +
+                serverPublicKey.get_bipoint_curve_subgroup_gen() * z_na -
+                newEncryptedVotes[i] * c_n;
+
+            CurveBipoint C_b = serverPublicKey.get_bipoint_curvegen() * f_2 +
+                serverPublicKey.get_bipoint_curve_subgroup_gen() * z_nb -
+                newEncryptedVotes[i] * c_n;
+
+            CurveBipoint C_c = 
+                serverPublicKey.get_bipoint_curvegen() * commitmentLambda_1 +
+                serverPublicKey.get_bipoint_curve_subgroup_gen() *
+                    commitmentLambda_2;
+            currProof.partialUniversals.push_back(C_c[0]);
+            currProof.partialUniversals.push_back(C_c[1]);
+
+            Scalar f_1_c_n = f_1.curveSub(c_n);
+            Scalar c_n_f_2 = c_n.curveAdd(c_n).curveSub(f_2);
+            CurveBipoint C_d =
+                serverPublicKey.get_bipoint_curve_subgroup_gen() * z_nc -
+                newEncryptedVotes[i] * f_1_c_n.curveMult(c_n_f_2) -
+                C_c * c_n;
+
+            oracleInput << U << C_a << C_b << C_c << C_d;
+
+            Scalar c = oracle(oracleInput.str());
+            Scalar c_r = c.curveSub(c_n);
+            currProof.challengeParts.push_back(c_r);
+            currProof.challengeParts.push_back(c_n);
+
+            Scalar z_r = u.curveAdd(c_r.curveMult(seeds[i]));
+            currProof.responseParts.push_back(z_r);
+            currProof.responseParts.push_back(f_1);
+            currProof.responseParts.push_back(f_2);
+            currProof.responseParts.push_back(z_na);
+            currProof.responseParts.push_back(z_nb);
+            currProof.responseParts.push_back(z_nc);
+
+            retval.push_back(currProof);
+        }
+    }
+
     return retval;
 }

+ 243 - 142
prsona/src/main.cpp

@@ -15,155 +15,125 @@ using namespace std;
 
 const int MAX_ALLOWED_VOTE = 2;
 
-// int argparse(
-//     int argc,
-//     char *argv[],
-//     size_t &numServers,
-//     size_t &numClients,
-//     size_t &numRounds,
-//     size_t &numVotes,
-//     bool &maliciousServers,
-//     bool &maliciousUsers,
-//     string &seedStr) 
-// {
-//     string config_file;
-
-//     // Declare a group of options that will be 
-//     // allowed only on command line
-//     po::options_description generic("General options");
-//     generic.add_options()
-//         ("help", "produce this help message")
-//         ("config,c", po::value<string>(&config_file)->default_value(""),
-//               "name of a configuration file")
-//         ;
-
-//     // Declare a group of options that will be 
-//     // allowed both on command line and in
-//     // config file
-//     po::options_description config("Configuration");
-//     config.add_options()
-//         ("malicious-servers,M",
-//                "presence of this flag indicates servers will operate in malicious model")
-//         ("malicious-users,U",
-//                "presence of this flag indicates users will operate in malicious model")
-//         ("seed", po::value<string>(&seedStr)->default_value("default"),
-//             "the random seed to use for this test")
-//         ;
-
-//     // Hidden options, will be allowed both on command line and
-//     // in config file, but will not be shown to the user.
-//     po::options_description hidden("Hidden options");
-//     hidden.add_options()
-//         ("number-servers,S", po::value<size_t>(&numServers)->default_value(2), 
-//               "number of servers in test")
-//         ("number-users,N", po::value<size_t>(&numClients)->default_value(5), 
-//               "number of users in test")
-//         ("number-rounds,R", po::value<size_t>(&numRounds)->default_value(3), 
-//               "number of rounds to perform in test")
-//         ("number-votes,V", po::value<size_t>(&numVotes)->default_value(3), 
-//               "number of votes each user makes per round during test")
-//         ;
-    
-//     po::options_description cmdline_options;
-//     cmdline_options.add(generic).add(config).add(hidden);
-
-//     po::options_description config_file_options;
-//     config_file_options.add(config).add(hidden);
-
-//     po::options_description visible("Allowed options");
-//     visible.add(generic).add(config);
-    
-//     po::positional_options_description p;
-//     p.add("number-servers", 1);
-//     p.add("number-users", 1);
-//     p.add("number-rounds", 1);
-//     p.add("number-votes", 1);
-    
-//     po::variables_map vm;
-//     store(po::command_line_parser(argc, argv).
-//           options(cmdline_options).positional(p).run(), vm);
-//     notify(vm);
-
-//     if (!config_file.empty())
-//     {
-//         ifstream config(config_file.c_str());
-//         if (!config)
-//         {
-//             cerr << "Cannot open config file: " << config_file << "\n";
-//             return 2;
-//         }
-//         else
-//         {
-//             store(parse_config_file(config, config_file_options), vm);
-//             notify(vm);
-//         }
-//     }
-
-//     if (vm.count("help"))
-//     {
-//         cout << visible << endl;
-//         return 1;
-//     }
-
-//     maliciousServers = vm.count("malicious-servers");
-//     maliciousUsers = vm.count("malicious-users");
-
-//     return 0;
-// }
-
 // Initialize the classes we use
 void initialize_prsona_classes()
 {
     Scalar::init();
-    Curvepoint elGamalBlindGenerator = PrsonaServer::init();
-    PrsonaClient::init(elGamalBlindGenerator);
+    PrsonaServer::init();
+    PrsonaClient::init();
+}
+
+// Quick and dirty mean calculation (used for averaging timings)
+double mean(vector<double> xx)
+{
+    return accumulate(xx.begin(), xx.end(), 0.0) / xx.size(); 
 }
 
-// Do an epoch (including votes, etc.), and return the timing to print out
-double epoch(
+// Time how long it takes to make a proof of valid votes
+vector<double> make_votes(
     default_random_engine& generator,
-    PrsonaServerEntity& servers,
-    vector<PrsonaClient>& users,
+    vector<vector<CurveBipoint>>& newEncryptedVotes,
+    vector<vector<Proof>>& validVoteProofs,
+    const vector<PrsonaClient>& users,
+    const PrsonaServerEntity& servers,
     size_t numVotes)
 {
-    Proof unused;
+    vector<double> retval;
     uniform_int_distribution<int> voteDistribution(0, MAX_ALLOWED_VOTE);
     size_t numUsers = users.size();
+    newEncryptedVotes.clear();
 
     for (size_t i = 0; i < numUsers; i++)
     {
         // Make the correct number of new votes, but shuffle where they go
         vector<Scalar> votes;
-        vector<bool> replace;
+        vector<bool> replaces;
         for (size_t j = 0; j < numUsers; j++)
         {
             votes.push_back(Scalar(voteDistribution(generator)));
-            replace.push_back(j < numVotes);
+            replaces.push_back(j < numVotes);
         }
-        shuffle(replace.begin(), replace.end(), generator);
-
-        // Make the actual votes to give to the servers
-        Proof pi;
-        Curvepoint shortTermPublicKey = users[i].get_short_term_public_key(pi);
-        vector<CurveBipoint> encryptedVotes =
-            servers.get_current_votes_by(pi, shortTermPublicKey);
-        encryptedVotes = users[i].make_votes(
-                            pi, encryptedVotes, votes, replace);
-
-        // Give the servers these new votes
-        servers.receive_vote(pi, encryptedVotes, shortTermPublicKey);
+        shuffle(replaces.begin(), replaces.end(), generator);
+
+        Proof ownerProof;
+        Curvepoint shortTermPublicKey =
+            users[i].get_short_term_public_key(ownerProof);
+        vector<CurveBipoint> currEncryptedVotes =
+            servers.get_current_votes_by(ownerProof, shortTermPublicKey);
+        vector<Proof> currVoteProof;
+        
+        chrono::high_resolution_clock::time_point t0 =
+            chrono::high_resolution_clock::now();
+        currEncryptedVotes = users[i].make_votes(
+                                currVoteProof,
+                                ownerProof,
+                                currEncryptedVotes,
+                                votes,
+                                replaces);
+        chrono::high_resolution_clock::time_point t1 =
+            chrono::high_resolution_clock::now();
+
+        newEncryptedVotes.push_back(currEncryptedVotes);
+        validVoteProofs.push_back(currVoteProof);
+
+        chrono::duration<double> time_span =
+            chrono::duration_cast<chrono::duration<double>>(t1 - t0);
+        retval.push_back(time_span.count());
     }
 
+    return retval;
+}
+
+// Time how long it takes to validate a proof of valid votes
+vector<double> transmit_votes_to_servers(
+    const vector<vector<CurveBipoint>>& newEncryptedVotes,
+    const vector<vector<Proof>>& validVoteProofs,
+    const vector<PrsonaClient>& users,
+    PrsonaServerEntity& servers)
+{
+    vector<double> retval;
+    size_t numUsers = users.size();
+    size_t numServers = servers.get_num_servers();
+
+    for (size_t i = 0; i < numUsers; i++)
+    {
+        Proof ownerProof;
+        Curvepoint shortTermPublicKey =
+            users[i].get_short_term_public_key(ownerProof);
+
+        for (size_t j = 0; j < numServers; j++)
+        {
+            chrono::high_resolution_clock::time_point t0 =
+                chrono::high_resolution_clock::now();
+            servers.receive_vote(
+                validVoteProofs[i],
+                newEncryptedVotes[i],
+                shortTermPublicKey,
+                j);
+            chrono::high_resolution_clock::time_point t1 =
+                chrono::high_resolution_clock::now();
+
+            chrono::duration<double> time_span =
+                chrono::duration_cast<chrono::duration<double>>(t1 - t0);
+            retval.push_back(time_span.count());
+        }
+            
+    }
+
+    return retval;
+}
+
+// Time how long it takes to do the operations associated with an epoch
+double epoch(PrsonaServerEntity& servers)
+{
+    Proof unused;
+
     // Do the epoch server calculations
     chrono::high_resolution_clock::time_point t0 =
         chrono::high_resolution_clock::now();
     servers.epoch(unused);
     chrono::high_resolution_clock::time_point t1 =
         chrono::high_resolution_clock::now();
-    
-    // Transmit the results of the epoch to each user
-    for (size_t i = 0; i < numUsers; i++)
-        servers.transmit_updates(users[i]);
 
     // Return the timing of the epoch server calculations
     chrono::duration<double> time_span =
@@ -171,8 +141,36 @@ double epoch(
     return time_span.count();
 }
 
-void reputation_proof_attempt(default_random_engine& generator, const PrsonaClient& a, const PrsonaClient& b)
+// Time how long it takes each user to decrypt their new scores
+vector<double> transmit_epoch_updates(
+    vector<PrsonaClient>& users, const PrsonaServerEntity& servers)
 {
+    vector<double> retval;
+    size_t numUsers = users.size();
+
+    for (size_t i = 0; i < numUsers; i++)
+    {
+        chrono::high_resolution_clock::time_point t0 =
+            chrono::high_resolution_clock::now();
+        servers.transmit_updates(users[i]);
+        chrono::high_resolution_clock::time_point t1 =
+            chrono::high_resolution_clock::now();
+
+        chrono::duration<double> time_span =
+            chrono::duration_cast<chrono::duration<double>>(t1 - t0);
+        retval.push_back(time_span.count());
+    }
+
+    return retval;
+}
+
+// Test if the proof of reputation level is working as expected
+void test_reputation_proof(
+    default_random_engine& generator,
+    const PrsonaClient& a,
+    const PrsonaClient& b)
+{
+    bool flag;
     mpz_class aScore = a.get_score().toInt();
     int i = 0;
     while (i < aScore)
@@ -182,23 +180,92 @@ void reputation_proof_attempt(default_random_engine& generator, const PrsonaClie
     Scalar goodThreshold(thresholdDistribution(generator));
     Scalar badThreshold(aScore + 1);
 
-    cout << "User A's score: " << aScore << endl;
-    cout << "User A's chosen good threshold: " << goodThreshold << endl;
-    cout << "User A's chosen bad threshold: " << badThreshold << endl;
-
     Proof pi;
     Curvepoint shortTermPublicKey = a.get_short_term_public_key(pi);
     vector<Proof> goodRepProof = a.generate_reputation_proof(goodThreshold);
-    cout << "TEST VALID PROOF:   "
-        << (b.verify_reputation_proof(goodRepProof, shortTermPublicKey, goodThreshold) ?
-            "PASSED (Proof verified)" : "FAILED (Proof not verified)" )
+    flag = b.verify_reputation_proof(
+                goodRepProof, shortTermPublicKey, goodThreshold);
+    cout << "TEST VALID REPUTATION PROOF:     "
+        << (flag ? "PASSED (Proof verified)" : "FAILED (Proof not verified)" )
         << endl;
     
     vector<Proof> badRepProof = a.generate_reputation_proof(badThreshold);
-    cout << "TEST INVALID PROOF: "
-        << (b.verify_reputation_proof(badRepProof, shortTermPublicKey, badThreshold) ?
-            "FAILED (Proof verified)" : "PASSED (Proof not verified)" )
+    flag = b.verify_reputation_proof(
+                badRepProof, shortTermPublicKey, badThreshold);
+    cout << "TEST INVALID REPUTATION PROOF:   "
+        << (flag ? "FAILED (Proof verified)" : "PASSED (Proof not verified)" )
+        << endl << endl;
+}
+
+// Test if the proof of valid votes is working as expected
+void test_vote_proof(
+    default_random_engine& generator,
+    const PrsonaClient& user,
+    PrsonaServerEntity& servers)
+{
+    size_t numUsers = servers.get_num_clients();
+    vector<Scalar> votes;
+    vector<bool> replaces;
+    bool flag;
+
+    for (size_t i = 0; i < numUsers; i++)
+    {
+        votes.push_back(Scalar(1));
+        replaces.push_back(true); 
+    }
+
+    vector<Proof> validVoteProof;
+    Proof ownerProof;
+    Curvepoint shortTermPublicKey =
+        user.get_short_term_public_key(ownerProof);
+    vector<CurveBipoint> encryptedVotes =
+        servers.get_current_votes_by(ownerProof, shortTermPublicKey);
+    encryptedVotes =
+        user.make_votes(
+            validVoteProof, ownerProof, encryptedVotes, votes, replaces);
+
+    flag = servers.receive_vote(
+                validVoteProof, encryptedVotes, shortTermPublicKey);
+    cout << "TEST REPLACE VOTE PROOF:         "
+        << (flag ? "PASSED (Proof verified)" : "FAILED (Proof not verified)" )
+        << endl;
+
+    for (size_t i = 0; i < numUsers; i++)
+    {
+        replaces[i] = false;
+    }
+
+    shortTermPublicKey = user.get_short_term_public_key(ownerProof);
+    encryptedVotes = 
+        servers.get_current_votes_by(ownerProof, shortTermPublicKey);
+    encryptedVotes =
+        user.make_votes(
+            validVoteProof, ownerProof, encryptedVotes, votes, replaces);
+
+    flag = servers.receive_vote(
+                validVoteProof, encryptedVotes, shortTermPublicKey);
+    cout << "TEST RERANDOMIZE VOTE PROOF:     "
+        << (flag ? "PASSED (Proof verified)" : "FAILED (Proof not verified)" )
         << endl;
+
+    for (size_t i = 0; i < numUsers; i++)
+    {
+        votes[i] = Scalar(3);
+        replaces[i] = true;
+    }
+
+    shortTermPublicKey = user.get_short_term_public_key(ownerProof);
+    encryptedVotes = 
+        servers.get_current_votes_by(ownerProof, shortTermPublicKey);
+    encryptedVotes =
+        user.make_votes(
+            validVoteProof, ownerProof, encryptedVotes, votes, replaces);
+
+    flag = servers.receive_vote(
+                validVoteProof, encryptedVotes, shortTermPublicKey);
+    cout << "TEST INVALID REPLACE VOTE PROOF: "
+        << (flag ? "FAILED (Proof verified)" : "PASSED (Proof not verified)" )
+        << endl << endl;
 }
 
 int main(int argc, char *argv[])
@@ -229,7 +296,7 @@ int main(int argc, char *argv[])
     cout << numUsers << " participants (voters/votees)" << endl;
     cout << numRounds << " epochs" << endl;
     cout << numVotesPerRound << " new (random) votes by each user per epoch"
-        << endl;
+        << endl << endl;
 
     // Set malicious flags where necessary
     if (maliciousServers)
@@ -246,12 +313,13 @@ int main(int argc, char *argv[])
     // Entities we operate with
     PrsonaServerEntity servers(numServers);
     BGNPublicKey bgnPublicKey = servers.get_bgn_public_key();
+    Curvepoint elGamalBlindGenerator = servers.get_blinding_generator();
 
-    cout << "Initialization: adding users to system" << endl;
+    cout << "Initialization: adding users to system" << endl << endl;
     vector<PrsonaClient> users;
     for (size_t i = 0; i < numUsers; i++)
     {
-        PrsonaClient currUser(bgnPublicKey, &servers);
+        PrsonaClient currUser(bgnPublicKey, elGamalBlindGenerator, &servers);
         servers.add_new_client(currUser);
         users.push_back(currUser);
     }
@@ -260,21 +328,54 @@ int main(int argc, char *argv[])
     seed_seq seed(seedStr.begin(), seedStr.end());
     default_random_engine generator(seed);
 
+    // Do the epoch operations
     for (size_t i = 0; i < numRounds; i++)
     {
+        vector<double> timings;
+
         cout << "Round " << i+1 << " of " << numRounds << ": " << endl;
-        double timing = epoch(generator, servers, users, numVotesPerRound);
-        cout << "Server computation: " << timing << " seconds" << endl;
+        
+        vector<vector<CurveBipoint>> newEncryptedVotes;
+        vector<vector<Proof>> validVoteProofs;
+        timings = make_votes(
+            generator,
+            newEncryptedVotes,
+            validVoteProofs,
+            users,
+            servers,
+            numVotesPerRound);
+        
+        cout << "Vote generation (with proofs): " << mean(timings)
+            << " seconds per user" << endl;
+        timings.clear();
+
+        timings = transmit_votes_to_servers(
+            newEncryptedVotes, validVoteProofs, users, servers);
+
+        cout << "Vote validation: " << mean(timings)
+            << " seconds per vote vector/server" << endl;
+        timings.clear();
+
+        timings.push_back(epoch(servers));
+        
+        cout << "Epoch computation: " << mean(timings) << " seconds" << endl;
+        timings.clear();
+
+        timings = transmit_epoch_updates(users, servers);
+
+        cout << "Transmit epoch updates: " << mean(timings)
+            << " seconds per user" << endl << endl;
     }
 
-    uniform_int_distribution<int> userDistribution(0, numUsers - 1);
-    int user_a = userDistribution(generator);
-    int user_b = user_a;
+    // Pick random users for our tests
+    uniform_int_distribution<size_t> userDistribution(0, numUsers - 1);
+    size_t user_a = userDistribution(generator);
+    size_t user_b = user_a;
     while (user_b == user_a)
         user_b = userDistribution(generator);
 
-    cout << "Attempting a proof of reputation" << endl;
-    reputation_proof_attempt(generator, users[user_a], users[user_b]);
+    test_reputation_proof(generator, users[user_a], users[user_b]);
+    test_vote_proof(generator, users[user_a], servers);
 
     return 0;
 }

+ 110 - 21
prsona/src/server.cpp

@@ -11,7 +11,6 @@ const int MAX_ALLOWED_VOTE = 2;
  * (or at least, with g++, whenever it would execute is not at a useful time)
  * so we have an init() function to actually put the correct values in them. */
 Curvepoint PrsonaServer::EL_GAMAL_GENERATOR = Curvepoint();
-Curvepoint PrsonaServer::EL_GAMAL_BLIND_GENERATOR = Curvepoint();
 Scalar PrsonaServer::SCALAR_N = Scalar();
 Scalar PrsonaServer::DEFAULT_TALLY = Scalar();
 Scalar PrsonaServer::DEFAULT_VOTE = Scalar();
@@ -44,18 +43,15 @@ PrsonaServer::PrsonaServer(const BGN& other_bgn)
  */
 
 // Must be called once before any usage of this class
-Curvepoint PrsonaServer::init()
+void PrsonaServer::init()
 {
     Scalar lambda;
     lambda.set_random();
 
     EL_GAMAL_GENERATOR = Curvepoint(bn_curvegen);
-    EL_GAMAL_BLIND_GENERATOR = EL_GAMAL_GENERATOR * lambda;
     SCALAR_N = Scalar(bn_n);
     DEFAULT_TALLY = Scalar(1);
     DEFAULT_VOTE = Scalar(1);
-
-    return EL_GAMAL_BLIND_GENERATOR;
 }
 
 // Call this (once) if using malicious-security servers
@@ -74,9 +70,9 @@ void PrsonaServer::set_client_malicious()
  * BASIC PUBLIC SYSTEM INFO GETTERS
  */
 
-Curvepoint PrsonaServer::get_blinding_generator()
+Curvepoint PrsonaServer::get_blinding_generator() const
 {
-    return EL_GAMAL_BLIND_GENERATOR;
+    return elGamalBlindGenerator;
 }
 
 BGNPublicKey PrsonaServer::get_bgn_public_key() const
@@ -150,7 +146,7 @@ void PrsonaServer::add_new_client(
     Proof& proofOfValidAddition,
     const Curvepoint& shortTermPublicKey)
 {
-    if (!verify_valid_key_proof(proofOfValidKey, shortTermPublicKey))
+    if (!verify_ownership_proof(proofOfValidKey, shortTermPublicKey))
     {
         std::cerr << "Could not verify proof of valid key." << std::endl;
         return;
@@ -199,19 +195,22 @@ void PrsonaServer::add_new_client(
 }
 
 // Receive a new vote row from a user (identified by short term public key).
-void PrsonaServer::receive_vote(
-    const Proof& pi,
-    const std::vector<CurveBipoint>& votes,
+bool PrsonaServer::receive_vote(
+    const std::vector<Proof>& pi,
+    const std::vector<CurveBipoint>& newVotes,
     const Curvepoint& shortTermPublicKey)
 {
-    if (!verify_vote_proof(pi, votes, shortTermPublicKey))
+    size_t voteSubmitter = binary_search(shortTermPublicKey);
+    std::vector<CurveBipoint> oldVotes = voteMatrix[voteSubmitter];
+
+    if (!verify_vote_proof(pi, oldVotes, newVotes, shortTermPublicKey))
     {
         std::cerr << "Could not verify votes." << std::endl;
-        return;
+        return false;
     }
 
-    size_t voteSubmitter = binary_search(shortTermPublicKey);
-    voteMatrix[voteSubmitter] = votes;
+    voteMatrix[voteSubmitter] = newVotes;
+    return true;
 }
 
 /*********************
@@ -232,6 +231,22 @@ void PrsonaServer::initialize_fresh_generator(const Curvepoint& firstGenerator)
     currentFreshGenerator = firstGenerator;
 }
 
+// To calculate the blind generator for ElGamal, start from the base generator,
+// then have every server call this function on it iteratively (in any order).
+Curvepoint PrsonaServer::add_rand_seed_to_generator(
+    const Curvepoint& currGenerator) const
+{
+    Scalar lambda;
+    lambda.set_random();
+
+    return currGenerator + EL_GAMAL_GENERATOR * lambda;
+}
+
+void PrsonaServer::set_EG_blind_generator(const Curvepoint& currGenerator)
+{
+    elGamalBlindGenerator = currGenerator;
+}
+
 /*
  * SCORE TALLYING
  */
@@ -508,25 +523,99 @@ size_t PrsonaServer::binary_search(const Curvepoint& index) const
  * PROOF VERIFICATION
  */
 
-bool PrsonaServer::verify_valid_key_proof(
+bool PrsonaServer::verify_ownership_proof(
     const Proof& pi,
     const Curvepoint& shortTermPublicKey) const
 {
     if (!CLIENT_IS_MALICIOUS)
         return pi.basic == "PROOF";
 
-    return pi.basic == "PROOF";
+    Scalar c = pi.challengeParts[0];
+    Scalar z = pi.responseParts[0];
+
+    Curvepoint u = currentFreshGenerator * z - shortTermPublicKey * c;
+
+    std::stringstream oracleInput;
+    oracleInput << currentFreshGenerator << shortTermPublicKey << u;
+    
+    return c == oracle(oracleInput.str());
 }
 
 bool PrsonaServer::verify_vote_proof(
-    const Proof& pi,
-    const std::vector<CurveBipoint>& votes,
+    const std::vector<Proof>& pi,
+    const std::vector<CurveBipoint>& oldVotes,
+    const std::vector<CurveBipoint>& newVotes,
     const Curvepoint& shortTermPublicKey) const
 {
+    // Reject outright if there's no proof to check
+    if (pi.empty())
+    {
+        std::cerr << "Proof was empty, aborting." << std::endl;
+        return false;
+    }
+
+    // Base case
     if (!CLIENT_IS_MALICIOUS)
-        return pi.basic == "PROOF";
+        return pi[0].basic == "PROOF";
 
-    return pi.basic == "PROOF";
+    // User should be able to prove they are who they say they are
+    if (!verify_ownership_proof(pi[0], shortTermPublicKey))
+    {
+        std::cerr << "Schnorr proof failed, aborting." << std::endl;
+        return false;
+    }
+
+    /* This proof structure is documented in my notes.
+     * It's inspired by the proof in Fig. 1 at
+     * https://eprint.iacr.org/2014/764.pdf, but adapted so that you prove
+     * m(m-1)(m-2) = 0 instead of m(m-1) = 0.
+     *
+     * The rerandomization part is just a slight variation on an
+     * ordinary Schnorr proof, so that part's less scary. */
+    for (size_t i = 1; i < pi.size(); i++)
+    {
+        size_t voteIndex = i - 1;
+        Curvepoint C_c_0, C_c_1;
+        C_c_0 = pi[i].partialUniversals[0];
+        C_c_1 = pi[i].partialUniversals[1];
+
+        CurveBipoint g, h;
+        g = bgn_system.get_public_key().get_bipoint_curvegen();
+        h = bgn_system.get_public_key().get_bipoint_curve_subgroup_gen();
+        CurveBipoint C_c(C_c_0, C_c_1);
+
+        Scalar c_r, c_n, z_r, f_1, f_2, z_na, z_nb, z_nc;
+        c_r = pi[i].challengeParts[0];
+        c_n = pi[i].challengeParts[1];
+
+        z_r  = pi[i].responseParts[0];
+        f_1  = pi[i].responseParts[1];
+        f_2  = pi[i].responseParts[2];
+        z_na = pi[i].responseParts[3];
+        z_nb = pi[i].responseParts[4];
+        z_nc = pi[i].responseParts[5];
+
+        CurveBipoint U, C_a, C_b, C_d;
+        U = h * z_r + oldVotes[voteIndex] * c_r - newVotes[voteIndex] * c_r;
+        C_a = g * f_1 + h * z_na - newVotes[voteIndex] * c_n;
+        C_b = g * f_2 + h * z_nb - newVotes[voteIndex] * c_n;
+
+        Scalar f_1_c_n = f_1.curveSub(c_n);
+        Scalar c_n_f_2 = c_n.curveAdd(c_n).curveSub(f_2);
+        C_d = h * z_nc - newVotes[voteIndex] * f_1_c_n.curveMult(c_n_f_2) - C_c * c_n;
+
+        std::stringstream oracleInput;
+        oracleInput << g << h << oldVotes[voteIndex] << newVotes[voteIndex]
+            << U << C_a << C_b << C_c << C_d;
+
+        if (oracle(oracleInput.str()) != c_r.curveAdd(c_n))
+        {
+            std::cerr << "Valid vote proof failed at index " << i << " of " << pi.size() - 1 << ", aborting." << std::endl;
+            return false;
+        }
+    }
+
+    return true;
 }
 
 bool PrsonaServer::verify_update_proof(

+ 34 - 5
prsona/src/serverEntity.cpp

@@ -34,6 +34,14 @@ PrsonaServerEntity::PrsonaServerEntity(size_t numServers)
     Curvepoint firstGenerator = get_fresh_generator();
     for (size_t i = 0; i < numServers; i++)
         servers[i].initialize_fresh_generator(firstGenerator);
+
+    // It's important that no server knows the DLOG between g and h for ElGamal,
+    // so have each server collaborate to make h.
+    Curvepoint blindGenerator = PrsonaServer::EL_GAMAL_GENERATOR;
+    for (size_t i = 0; i < numServers; i++)
+        blindGenerator = servers[i].add_rand_seed_to_generator(blindGenerator);
+    for (size_t i = 0; i < numServers; i++)
+        servers[i].set_EG_blind_generator(blindGenerator);
 }
 
 /*
@@ -64,6 +72,11 @@ size_t PrsonaServerEntity::get_num_clients() const
     return servers[0].currentPseudonyms.size();
 }
 
+size_t PrsonaServerEntity::get_num_servers() const
+{
+    return servers.size();
+}
+
 /*
  * ENCRYPTED DATA GETTERS
  */
@@ -139,13 +152,29 @@ void PrsonaServerEntity::add_new_client(PrsonaClient& newUser)
 }
 
 // Receive a new vote row from a user (identified by short term public key).
-void PrsonaServerEntity::receive_vote(
-    const Proof& pi,
-    const std::vector<CurveBipoint>& votes,
+bool PrsonaServerEntity::receive_vote(
+    const std::vector<Proof>& pi,
+    const std::vector<CurveBipoint>& newVotes,
     const Curvepoint& shortTermPublicKey)
 {
+    bool retval = true;
+
     for (size_t i = 0; i < servers.size(); i++)
-        servers[i].receive_vote(pi, votes, shortTermPublicKey);
+    {
+        retval =
+            retval && servers[i].receive_vote(pi, newVotes, shortTermPublicKey);
+    }
+
+    return retval;
+}
+
+bool PrsonaServerEntity::receive_vote(
+    const std::vector<Proof>& pi,
+    const std::vector<CurveBipoint>& newVotes,
+    const Curvepoint& shortTermPublicKey,
+    size_t which)
+{
+    return servers[which].receive_vote(pi, newVotes, shortTermPublicKey);
 }
 
 // After tallying scores and new vote matrix,
@@ -283,7 +312,7 @@ std::vector<EGCiphertext> PrsonaServerEntity::tally_scores(
         retval[i].mask = servers[0].currentPseudonyms[i] * currMask;
         retval[i].encryptedMessage =
             (nextGenerator * currMask) +
-            (PrsonaServer::get_blinding_generator() * decryptedTalliedScores[i]);
+            (servers[0].get_blinding_generator() * decryptedTalliedScores[i]);
     }
 
     servers[0].currentUserEncryptedTallies = retval;