123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480 |
#include <algorithm>
#include <cerrno>
#include <chrono>
#include <cstdlib>
#include <iostream>
#include <limits>
#include <numeric>
#include <random>
#include <sstream>
#include <string>
#include <vector>

#include "BGN.hpp"
#include "client.hpp"
#include "server.hpp"
#include "serverEntity.hpp"
using namespace std;
- // Initialize the classes we use
- void initialize_prsona_classes()
- {
- Scalar::init();
- PrsonaBase::init();
- }
// Arithmetic mean of a list of samples (used for averaging timings), in the
// same units as the inputs.
//
// The samples are taken by const reference so no copy of the vector is made.
// An empty list has no meaningful mean; NaN is returned explicitly in that
// case rather than relying on a 0.0 / 0 division.
double mean(
    const std::vector<double>& xx)
{
    if (xx.empty())
        return std::numeric_limits<double>::quiet_NaN();
    return std::accumulate(xx.begin(), xx.end(), 0.0) / static_cast<double>(xx.size());
}
- void print_user_scores(
- const vector<PrsonaClient>& users)
- {
- std::cout << "<";
- for (size_t i = 0; i < users.size(); i++)
- std::cout << users[i].get_score() << (i == users.size() - 1 ? ">" : " ");
- std::cout << std::endl;
- }
- bool test_proof_output(
- const vector<Proof>& pi)
- {
- vector<Proof> copy;
- stringstream buffer;
- for (size_t i = 0; i < pi.size(); i++)
- {
- Proof currProof;
- buffer << pi[i];
- buffer >> currProof;
- copy.push_back(currProof);
- }
- bool retval = true;
- for (size_t i = 0; i < pi.size(); i++)
- {
- if (!(copy[i] == pi[i]))
- cout << "FAILURE at index " << i+1 << " of " << pi.size() << endl;
- retval = retval && copy[i] == pi[i];
- }
- cout << "TEST PROOF OUTPUT: " << (retval ? "PASSED" : "FAILED") << endl;
- return retval;
- }
- // Time how long it takes to make a proof of valid votes
- vector<double> make_votes(
- default_random_engine& generator,
- vector<vector<TwistBipoint>>& newEncryptedVotes,
- vector<vector<Proof>>& validVoteProofs,
- const vector<PrsonaClient>& users,
- const PrsonaServerEntity& servers,
- size_t numVotes)
- {
- vector<double> retval;
- uniform_int_distribution<int> voteDistribution(0, PrsonaBase::get_max_allowed_vote());
- size_t numUsers = users.size();
- newEncryptedVotes.clear();
- for (size_t i = 0; i < numUsers; i++)
- {
- // Make the correct number of new votes, but shuffle where they go
- vector<Scalar> votes;
- vector<bool> replaces;
- for (size_t j = 0; j < numUsers; j++)
- {
- votes.push_back(Scalar(voteDistribution(generator)));
- replaces.push_back(j < numVotes);
- }
- shuffle(replaces.begin(), replaces.end(), generator);
-
- Proof baseProof;
- vector<Proof> fullProof;
- Twistpoint shortTermPublicKey = users[i].get_short_term_public_key();
- vector<TwistBipoint> currEncryptedVotes = servers.get_current_votes_by(baseProof, shortTermPublicKey);
- fullProof.push_back(baseProof);
- servers.get_other_vote_row_commitments(fullProof, shortTermPublicKey);
-
- vector<Proof> currVoteProof;
- chrono::high_resolution_clock::time_point t0 = chrono::high_resolution_clock::now();
- currEncryptedVotes = users[i].make_votes(currVoteProof, fullProof, currEncryptedVotes, votes, replaces);
- chrono::high_resolution_clock::time_point t1 = chrono::high_resolution_clock::now();
- newEncryptedVotes.push_back(currEncryptedVotes);
- validVoteProofs.push_back(currVoteProof);
- chrono::duration<double> time_span = chrono::duration_cast<chrono::duration<double>>(t1 - t0);
- retval.push_back(time_span.count());
- }
- return retval;
- }
- // Time how long it takes to validate a proof of valid votes
- vector<double> transmit_votes_to_servers(
- const vector<vector<TwistBipoint>>& newEncryptedVotes,
- const vector<vector<Proof>>& validVoteProofs,
- const vector<PrsonaClient>& users,
- PrsonaServerEntity& servers)
- {
- vector<double> retval;
- size_t numUsers = users.size();
- size_t numServers = servers.get_num_servers();
- for (size_t i = 0; i < numUsers; i++)
- {
- Proof ownerProof;
- Twistpoint shortTermPublicKey = users[i].get_short_term_public_key(ownerProof);
- for (size_t j = 0; j < numServers; j++)
- {
- chrono::high_resolution_clock::time_point t0 = chrono::high_resolution_clock::now();
- servers.receive_vote(validVoteProofs[i], newEncryptedVotes[i], shortTermPublicKey, j);
- chrono::high_resolution_clock::time_point t1 = chrono::high_resolution_clock::now();
- chrono::duration<double> time_span = chrono::duration_cast<chrono::duration<double>>(t1 - t0);
- retval.push_back(time_span.count());
- }
-
- }
- return retval;
- }
- // Time how long it takes to do the operations associated with an epoch
- double epoch(
- PrsonaServerEntity& servers)
- {
- // Do the epoch server calculations
- chrono::high_resolution_clock::time_point t0 = chrono::high_resolution_clock::now();
- servers.epoch();
- chrono::high_resolution_clock::time_point t1 = chrono::high_resolution_clock::now();
- // Return the timing of the epoch server calculations
- chrono::duration<double> time_span = chrono::duration_cast<chrono::duration<double>>(t1 - t0);
- return time_span.count();
- }
- // Time how long it takes each user to decrypt their new scores
- vector<double> transmit_epoch_updates(
- vector<PrsonaClient>& users,
- const PrsonaServerEntity& servers)
- {
- vector<double> retval;
- size_t numUsers = users.size();
- for (size_t i = 0; i < numUsers; i++)
- {
- chrono::high_resolution_clock::time_point t0 = chrono::high_resolution_clock::now();
- servers.transmit_updates(users[i]);
- chrono::high_resolution_clock::time_point t1 = chrono::high_resolution_clock::now();
- chrono::duration<double> time_span = chrono::duration_cast<chrono::duration<double>>(t1 - t0);
- retval.push_back(time_span.count());
- }
- return retval;
- }
- // Test if the proof of reputation level is working as expected
- void test_reputation_proof(
- default_random_engine& generator,
- const PrsonaServerEntity& servers,
- const PrsonaClient& a,
- const PrsonaClient& b)
- {
- bool flag;
- mpz_class aScore = a.get_score().toInt();
- size_t intScore = 0;
- while (intScore < aScore)
- intScore++;
- intScore = (intScore == 0 ? 1 : intScore);
- uniform_int_distribution<size_t> thresholdDistribution(0, intScore-1);
- Scalar goodThreshold(thresholdDistribution(generator));
- Scalar badThreshold(aScore + 1);
- Twistpoint shortTermPublicKey = a.get_short_term_public_key();
- vector<Proof> goodRepProof = a.generate_reputation_proof(goodThreshold, servers.get_num_clients());
- Proof baseProof;
- vector<Proof> fullProof;
- EGCiphertext currEncryptedScore = servers.get_current_user_encrypted_tally(baseProof, shortTermPublicKey);
- fullProof.push_back(baseProof);
- servers.get_other_user_tally_commitments(fullProof, shortTermPublicKey);
- flag = b.verify_reputation_proof(goodRepProof, shortTermPublicKey, goodThreshold, fullProof, currEncryptedScore);
- cout << "TEST VALID REPUTATION PROOF: " << (flag ? "PASSED (Proof verified)" : "FAILED (Proof not verified)" ) << endl;
-
- vector<Proof> badRepProof = a.generate_reputation_proof(badThreshold, servers.get_num_clients());
- baseProof.clear();
- fullProof.clear();
- currEncryptedScore = servers.get_current_user_encrypted_tally(baseProof, shortTermPublicKey);
- fullProof.push_back(baseProof);
- servers.get_other_user_tally_commitments(fullProof, shortTermPublicKey);
- flag = b.verify_reputation_proof(badRepProof, shortTermPublicKey, badThreshold, fullProof, currEncryptedScore);
- cout << "TEST INVALID REPUTATION PROOF: " << (flag ? "FAILED (Proof verified)" : "PASSED (Proof not verified)" ) << endl << endl;
- }
- // Test if the proof of valid votes is working as expected
- void test_vote_proof(
- default_random_engine& generator,
- const PrsonaClient& user,
- PrsonaServerEntity& servers)
- {
- size_t numUsers = servers.get_num_clients();
- vector<Scalar> votes;
- vector<bool> replaces;
- bool flag;
- for (size_t i = 0; i < numUsers; i++)
- {
- votes.push_back(Scalar(1));
- replaces.push_back(true);
- }
- vector<Proof> validVoteProof;
- Proof baseProof;
- vector<Proof> fullProof;
- Twistpoint shortTermPublicKey = user.get_short_term_public_key();
- vector<TwistBipoint> encryptedVotes = servers.get_current_votes_by(baseProof, shortTermPublicKey);
- fullProof.push_back(baseProof);
- servers.get_other_vote_row_commitments(fullProof, shortTermPublicKey);
- encryptedVotes = user.make_votes(validVoteProof, fullProof, encryptedVotes, votes, replaces);
- flag = servers.receive_vote(validVoteProof, encryptedVotes, shortTermPublicKey);
- cout << "TEST REPLACE VOTE PROOF: " << (flag ? "PASSED (Proof verified)" : "FAILED (Proof not verified)" ) << endl;
- for (size_t i = 0; i < numUsers; i++)
- replaces[i] = false;
- baseProof.clear();
- fullProof.clear();
- encryptedVotes = servers.get_current_votes_by(baseProof, shortTermPublicKey);
- fullProof.push_back(baseProof);
- servers.get_other_vote_row_commitments(fullProof, shortTermPublicKey);
- encryptedVotes = user.make_votes(validVoteProof, fullProof, encryptedVotes, votes, replaces);
- flag = servers.receive_vote(validVoteProof, encryptedVotes, shortTermPublicKey);
- cout << "TEST RERANDOMIZE VOTE PROOF: " << (flag ? "PASSED (Proof verified)" : "FAILED (Proof not verified)" ) << endl;
- for (size_t i = 0; i < numUsers; i++)
- {
- votes[i] = Scalar(3);
- replaces[i] = true;
- }
- baseProof.clear();
- fullProof.clear();
- encryptedVotes = servers.get_current_votes_by(baseProof, shortTermPublicKey);
- fullProof.push_back(baseProof);
- servers.get_other_vote_row_commitments(fullProof, shortTermPublicKey);
- encryptedVotes = user.make_votes(validVoteProof, fullProof, encryptedVotes, votes, replaces);
- flag = servers.receive_vote(validVoteProof, encryptedVotes, shortTermPublicKey);
- cout << "TEST INVALID REPLACE VOTE PROOF: " << (flag ? "FAILED (Proof verified)" : "PASSED (Proof not verified)" ) << endl << endl;
- }
- void check_vote_matrix_updates()
- {
- size_t numServers = 2;
- size_t numUsers = 3;
- cout << "Testing how the vote matrix updates." << endl;
- PrsonaBase::set_client_malicious();
- // Entities we operate with
- PrsonaServerEntity servers(numServers);
- vector<Proof> elGamalBlindGeneratorProof;
- BGNPublicKey bgnPublicKey = servers.get_bgn_public_key();
- Twistpoint elGamalBlindGenerator = servers.get_blinding_generator(elGamalBlindGeneratorProof);
- vector<PrsonaClient> users;
- for (size_t i = 0; i < numUsers; i++)
- {
- PrsonaClient currUser(elGamalBlindGeneratorProof, elGamalBlindGenerator, bgnPublicKey, numServers);
- users.push_back(currUser);
- servers.add_new_client(users[i]);
- }
- Proof pseudonymsProof;
- vector<Twistpoint> currentPseudonyms = servers.get_current_pseudonyms(pseudonymsProof);
- cout << "Making votes." << endl;
- for (size_t i = 0; i < numUsers; i++)
- {
- Twistpoint shortTermPublicKey = users[i].get_short_term_public_key();
- size_t myIndex = users[i].binary_search(currentPseudonyms, shortTermPublicKey);
- cout << "User " << i+1 << " has initial index " << myIndex << endl;
- vector<Scalar> votes;
- vector<bool> replaces;
- for (size_t j = 0; j < numUsers; j++)
- {
- if (j == myIndex)
- votes.push_back(Scalar(2));
- else if (j > myIndex)
- votes.push_back(Scalar(1));
- else
- votes.push_back(Scalar(0));
- replaces.push_back(true);
- }
- Proof baseProof;
- vector<Proof> fullProof;
- vector<TwistBipoint> currEncryptedVotes = servers.get_current_votes_by(baseProof, shortTermPublicKey);
- fullProof.push_back(baseProof);
- servers.get_other_vote_row_commitments(fullProof, shortTermPublicKey);
- vector<Proof> currVoteProof;
- currEncryptedVotes = users[i].make_votes(currVoteProof, fullProof, currEncryptedVotes, votes, replaces);
- servers.receive_vote(currVoteProof, currEncryptedVotes, shortTermPublicKey);
- cout << "User " << i+1 << " now has the following votes:" << endl;
- servers.print_current_votes_by(shortTermPublicKey);
- }
-
- servers.print_votes();
-
- epoch(servers);
- cout << "First epoch done." << endl;
-
- transmit_epoch_updates(users, servers);
- cout << "Updates given to users." << endl;
- servers.print_votes();
- for (size_t i = 0; i < numUsers; i++)
- {
- Proof ownerProof;
- Twistpoint shortTermPublicKey = users[i].get_short_term_public_key(ownerProof);
- cout << "User " << i+1 << " now has the following votes:" << endl;
- servers.print_current_votes_by(shortTermPublicKey);
- }
- }
- int main(int argc, char *argv[])
- {
- initialize_prsona_classes();
- // Defaults
- size_t numServers = 2;
- size_t numUsers = 5;
- size_t numRounds = 3;
- size_t numVotesPerRound = 3;
- size_t lambda = 0;
- bool maliciousServers = true;
- bool maliciousClients = true;
- string seedStr = "seed";
- // Potentially accept command line inputs
- if (argc > 1)
- numServers = atoi(argv[1]);
- if (argc > 2)
- numUsers = atoi(argv[2]);
- if (argc > 3)
- numRounds = atoi(argv[3]);
- if (argc > 4)
- numVotesPerRound = atoi(argv[4]);
- if (argc > 5)
- lambda = atoi(argv[5]);
- if (argc > 6)
- maliciousServers = argv[6][0] == 't' || argv[6][0] == 'T';
- if (argc > 7)
- maliciousClients = argv[7][0] == 't' || argv[7][0] == 'T';
- if (argc > 8)
- seedStr = argv[8];
- cout << "Running the protocol with the following parameters: " << endl;
- cout << numServers << " PRSONA servers" << endl;
- cout << numUsers << " participants (voters/votees)" << endl;
- cout << numRounds << " epochs" << endl;
- cout << numVotesPerRound << " new (random) votes by each user per epoch" << endl;
- cout << "Proof batching " << (lambda > 0 ? "IS" : "is NOT") << " in use." << (lambda > 0 ? " Batch parameter: " : "");
- if (lambda > 0)
- cout << lambda;
- cout << endl;
- cout << "Servers are set to " << (maliciousServers ? "MALICIOUS" : "HBC") << " security" << endl;
- cout << "Clients are set to " << (maliciousClients ? "MALICIOUS" : "HBC") << " security" << endl;
- cout << "Current randomness seed: \"" << seedStr << "\"" << endl;
- cout << endl;
- // Set malicious flags where necessary
- if (maliciousServers)
- PrsonaBase::set_server_malicious();
- if (maliciousClients)
- PrsonaBase::set_client_malicious();
- if (lambda > 0)
- PrsonaBase::set_lambda(lambda);
- // Entities we operate with
- PrsonaServerEntity servers(numServers);
- vector<Proof> elGamalBlindGeneratorProof;
- BGNPublicKey bgnPublicKey = servers.get_bgn_public_key();
- Twistpoint elGamalBlindGenerator = servers.get_blinding_generator(elGamalBlindGeneratorProof);
- test_proof_output(elGamalBlindGeneratorProof);
- cout << "Initialization: adding users to system" << endl << endl;
- vector<PrsonaClient> users;
- for (size_t i = 0; i < numUsers; i++)
- {
- PrsonaClient currUser(elGamalBlindGeneratorProof, elGamalBlindGenerator, bgnPublicKey, numServers);
- users.push_back(currUser);
- servers.add_new_client(users[i]);
- }
- // Seeded randomness for random votes used in epoch
- seed_seq seed(seedStr.begin(), seedStr.end());
- default_random_engine generator(seed);
- // Do the epoch operations
- for (size_t i = 0; i < numRounds; i++)
- {
- vector<double> timings;
- cout << "Round " << i+1 << " of " << numRounds << ": " << endl;
-
- vector<vector<TwistBipoint>> newEncryptedVotes;
- vector<vector<Proof>> validVoteProofs;
- timings = make_votes(generator, newEncryptedVotes, validVoteProofs, users, servers, numVotesPerRound);
-
- cout << "Vote generation (with proofs): " << mean(timings) << " seconds per user" << endl;
- timings.clear();
- timings = transmit_votes_to_servers(newEncryptedVotes, validVoteProofs, users, servers);
- cout << "Vote validation: " << mean(timings) << " seconds per vote vector/server" << endl;
- timings.clear();
- timings.push_back(epoch(servers));
-
- cout << "Epoch computation: " << mean(timings) << " seconds" << endl;
- timings.clear();
- timings = transmit_epoch_updates(users, servers);
- cout << "Transmit epoch updates: " << mean(timings) << " seconds per user" << endl << endl;
- }
- // Pick random users for our tests
- uniform_int_distribution<size_t> userDistribution(0, numUsers - 1);
- size_t user_a = userDistribution(generator);
- size_t user_b = user_a;
- while (user_b == user_a)
- user_b = userDistribution(generator);
- test_reputation_proof(generator, servers, users[user_a], users[user_b]);
- test_vote_proof(generator, users[user_a], servers);
- return 0;
- }
|