localMain.cpp 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484
  #include <algorithm>
  #include <chrono>
  #include <cstdlib>
  #include <iostream>
  #include <numeric>
  #include <random>
  #include <sstream>
  #include <string>
  #include <utility>
  #include <vector>
  #include "BGN.hpp"
  #include "client.hpp"
  #include "server.hpp"
  #include "serverEntity.hpp"
  using namespace std;
  10. // Initialize the classes we use
  11. void initialize_prsona_classes()
  12. {
  13. Scalar::init();
  14. PrsonaBase::init();
  15. }
  // Arithmetic mean of a set of timings (used to average per-user and
  // per-server measurements).
  //  xx - the samples; taken by const reference so callers' vectors are not copied.
  // Returns 0.0 for an empty set rather than dividing by zero (which yields NaN).
  double mean(
      const std::vector<double>& xx)
  {
      if (xx.empty())
          return 0.0;
      return std::accumulate(xx.begin(), xx.end(), 0.0) / static_cast<double>(xx.size());
  }
  22. void print_user_scores(
  23. const vector<PrsonaClient>& users)
  24. {
  25. std::cout << "<";
  26. for (size_t i = 0; i < users.size(); i++)
  27. std::cout << users[i].get_score() << (i == users.size() - 1 ? ">" : " ");
  28. std::cout << std::endl;
  29. }
  30. bool test_proof_output(
  31. const vector<Proof>& pi)
  32. {
  33. vector<Proof> copy;
  34. stringstream buffer;
  35. for (size_t i = 0; i < pi.size(); i++)
  36. {
  37. Proof currProof;
  38. buffer << pi[i];
  39. buffer >> currProof;
  40. copy.push_back(currProof);
  41. }
  42. bool retval = true;
  43. for (size_t i = 0; i < pi.size(); i++)
  44. {
  45. if (!(copy[i] == pi[i]))
  46. cout << "FAILURE at index " << i+1 << " of " << pi.size() << endl;
  47. retval = retval && copy[i] == pi[i];
  48. }
  49. cout << "TEST PROOF OUTPUT: " << (retval ? "PASSED" : "FAILED") << endl;
  50. return retval;
  51. }
  52. // Time how long it takes to make a proof of valid votes
  53. vector<double> make_votes(
  54. default_random_engine& generator,
  55. vector<vector<TwistBipoint>>& newEncryptedVotes,
  56. vector<vector<Proof>>& validVoteProofs,
  57. const vector<PrsonaClient>& users,
  58. const PrsonaServerEntity& servers,
  59. size_t numVotes)
  60. {
  61. vector<double> retval;
  62. uniform_int_distribution<int> voteDistribution(0, PrsonaBase::get_max_allowed_vote());
  63. size_t numUsers = users.size();
  64. newEncryptedVotes.clear();
  65. for (size_t i = 0; i < numUsers; i++)
  66. {
  67. // Make the correct number of new votes, but shuffle where they go
  68. vector<Scalar> votes;
  69. vector<bool> replaces;
  70. for (size_t j = 0; j < numUsers; j++)
  71. {
  72. votes.push_back(Scalar(voteDistribution(generator)));
  73. replaces.push_back(j < numVotes);
  74. }
  75. shuffle(replaces.begin(), replaces.end(), generator);
  76. Proof baseProof;
  77. vector<Proof> fullProof;
  78. Twistpoint shortTermPublicKey = users[i].get_short_term_public_key();
  79. vector<TwistBipoint> currEncryptedVotes = servers.get_current_votes_by(baseProof, shortTermPublicKey);
  80. fullProof.push_back(baseProof);
  81. servers.get_other_vote_row_commitments(fullProof, shortTermPublicKey);
  82. vector<Proof> currVoteProof;
  83. chrono::high_resolution_clock::time_point t0 = chrono::high_resolution_clock::now();
  84. currEncryptedVotes = users[i].make_votes(currVoteProof, fullProof, currEncryptedVotes, votes, replaces);
  85. chrono::high_resolution_clock::time_point t1 = chrono::high_resolution_clock::now();
  86. newEncryptedVotes.push_back(currEncryptedVotes);
  87. validVoteProofs.push_back(currVoteProof);
  88. chrono::duration<double> time_span = chrono::duration_cast<chrono::duration<double>>(t1 - t0);
  89. retval.push_back(time_span.count());
  90. }
  91. return retval;
  92. }
  93. // Time how long it takes to validate a proof of valid votes
  94. vector<double> transmit_votes_to_servers(
  95. const vector<vector<TwistBipoint>>& newEncryptedVotes,
  96. const vector<vector<Proof>>& validVoteProofs,
  97. const vector<PrsonaClient>& users,
  98. PrsonaServerEntity& servers)
  99. {
  100. vector<double> retval;
  101. size_t numUsers = users.size();
  102. size_t numServers = servers.get_num_servers();
  103. for (size_t i = 0; i < numUsers; i++)
  104. {
  105. Proof ownerProof;
  106. Twistpoint shortTermPublicKey = users[i].get_short_term_public_key(ownerProof);
  107. for (size_t j = 0; j < numServers; j++)
  108. {
  109. chrono::high_resolution_clock::time_point t0 = chrono::high_resolution_clock::now();
  110. servers.receive_vote(validVoteProofs[i], newEncryptedVotes[i], shortTermPublicKey, j);
  111. chrono::high_resolution_clock::time_point t1 = chrono::high_resolution_clock::now();
  112. chrono::duration<double> time_span = chrono::duration_cast<chrono::duration<double>>(t1 - t0);
  113. retval.push_back(time_span.count());
  114. }
  115. }
  116. return retval;
  117. }
  118. // Time how long it takes to do the operations associated with an epoch
  119. double epoch(
  120. PrsonaServerEntity& servers)
  121. {
  122. // Do the epoch server calculations
  123. chrono::high_resolution_clock::time_point t0 = chrono::high_resolution_clock::now();
  124. servers.epoch();
  125. chrono::high_resolution_clock::time_point t1 = chrono::high_resolution_clock::now();
  126. // Return the timing of the epoch server calculations
  127. chrono::duration<double> time_span = chrono::duration_cast<chrono::duration<double>>(t1 - t0);
  128. return time_span.count();
  129. }
  130. // Time how long it takes each user to decrypt their new scores
  131. vector<double> transmit_epoch_updates(
  132. vector<PrsonaClient>& users,
  133. const PrsonaServerEntity& servers)
  134. {
  135. vector<double> retval;
  136. size_t numUsers = users.size();
  137. for (size_t i = 0; i < numUsers; i++)
  138. {
  139. chrono::high_resolution_clock::time_point t0 = chrono::high_resolution_clock::now();
  140. servers.transmit_updates(users[i]);
  141. chrono::high_resolution_clock::time_point t1 = chrono::high_resolution_clock::now();
  142. chrono::duration<double> time_span = chrono::duration_cast<chrono::duration<double>>(t1 - t0);
  143. retval.push_back(time_span.count());
  144. }
  145. return retval;
  146. }
  147. // Test if the proof of reputation level is working as expected
  148. void test_reputation_proof(
  149. default_random_engine& generator,
  150. const PrsonaServerEntity& servers,
  151. const PrsonaClient& a,
  152. const PrsonaClient& b)
  153. {
  154. bool flag;
  155. mpz_class aScore = a.get_score().toInt();
  156. int i = 0;
  157. while (i < aScore)
  158. i++;
  159. uniform_int_distribution<int> thresholdDistribution(0, i);
  160. Scalar goodThreshold(thresholdDistribution(generator));
  161. Scalar badThreshold(aScore + 1);
  162. Twistpoint shortTermPublicKey = a.get_short_term_public_key();
  163. vector<Proof> goodRepProof = a.generate_reputation_proof(goodThreshold, servers.get_num_clients());
  164. Proof baseProof;
  165. vector<Proof> fullProof;
  166. EGCiphertext currEncryptedScore = servers.get_current_user_encrypted_tally(baseProof, shortTermPublicKey);
  167. fullProof.push_back(baseProof);
  168. servers.get_other_user_tally_commitments(fullProof, shortTermPublicKey);
  169. flag = b.verify_reputation_proof(goodRepProof, shortTermPublicKey, goodThreshold, fullProof, currEncryptedScore);
  170. cout << "TEST VALID REPUTATION PROOF: " << (flag ? "PASSED (Proof verified)" : "FAILED (Proof not verified)" ) << endl;
  171. vector<Proof> badRepProof = a.generate_reputation_proof(badThreshold, servers.get_num_clients());
  172. baseProof.clear();
  173. fullProof.clear();
  174. currEncryptedScore = servers.get_current_user_encrypted_tally(baseProof, shortTermPublicKey);
  175. fullProof.push_back(baseProof);
  176. servers.get_other_user_tally_commitments(fullProof, shortTermPublicKey);
  177. flag = b.verify_reputation_proof(badRepProof, shortTermPublicKey, goodThreshold, fullProof, currEncryptedScore);
  178. cout << "TEST INVALID REPUTATION PROOF: " << (flag ? "FAILED (Proof verified)" : "PASSED (Proof not verified)" ) << endl << endl;
  179. }
  180. // Test if the proof of valid votes is working as expected
  181. void test_vote_proof(
  182. default_random_engine& generator,
  183. const PrsonaClient& user,
  184. PrsonaServerEntity& servers)
  185. {
  186. size_t numUsers = servers.get_num_clients();
  187. vector<Scalar> votes;
  188. vector<bool> replaces;
  189. bool flag;
  190. for (size_t i = 0; i < numUsers; i++)
  191. {
  192. votes.push_back(Scalar(1));
  193. replaces.push_back(true);
  194. }
  195. vector<Proof> validVoteProof;
  196. Proof baseProof;
  197. vector<Proof> fullProof;
  198. Twistpoint shortTermPublicKey = user.get_short_term_public_key();
  199. vector<TwistBipoint> encryptedVotes = servers.get_current_votes_by(baseProof, shortTermPublicKey);
  200. fullProof.push_back(baseProof);
  201. servers.get_other_vote_row_commitments(fullProof, shortTermPublicKey);
  202. encryptedVotes = user.make_votes(validVoteProof, fullProof, encryptedVotes, votes, replaces);
  203. flag = servers.receive_vote(validVoteProof, encryptedVotes, shortTermPublicKey);
  204. cout << "TEST REPLACE VOTE PROOF: " << (flag ? "PASSED (Proof verified)" : "FAILED (Proof not verified)" ) << endl;
  205. for (size_t i = 0; i < numUsers; i++)
  206. replaces[i] = false;
  207. baseProof.clear();
  208. fullProof.clear();
  209. encryptedVotes = servers.get_current_votes_by(baseProof, shortTermPublicKey);
  210. fullProof.push_back(baseProof);
  211. servers.get_other_vote_row_commitments(fullProof, shortTermPublicKey);
  212. encryptedVotes = user.make_votes(validVoteProof, fullProof, encryptedVotes, votes, replaces);
  213. flag = servers.receive_vote(validVoteProof, encryptedVotes, shortTermPublicKey);
  214. cout << "TEST RERANDOMIZE VOTE PROOF: " << (flag ? "PASSED (Proof verified)" : "FAILED (Proof not verified)" ) << endl;
  215. for (size_t i = 0; i < numUsers; i++)
  216. {
  217. votes[i] = Scalar(3);
  218. replaces[i] = true;
  219. }
  220. baseProof.clear();
  221. fullProof.clear();
  222. encryptedVotes = servers.get_current_votes_by(baseProof, shortTermPublicKey);
  223. fullProof.push_back(baseProof);
  224. servers.get_other_vote_row_commitments(fullProof, shortTermPublicKey);
  225. encryptedVotes = user.make_votes(validVoteProof, fullProof, encryptedVotes, votes, replaces);
  226. flag = servers.receive_vote(validVoteProof, encryptedVotes, shortTermPublicKey);
  227. cout << "TEST INVALID REPLACE VOTE PROOF: " << (flag ? "FAILED (Proof verified)" : "PASSED (Proof not verified)" ) << endl << endl;
  228. }
  229. void check_vote_matrix_updates()
  230. {
  231. size_t numServers = 2;
  232. size_t numUsers = 3;
  233. cout << "Testing how the vote matrix updates." << endl;
  234. PrsonaBase::set_client_malicious();
  235. // Entities we operate with
  236. PrsonaServerEntity servers(numServers);
  237. vector<Proof> elGamalBlindGeneratorProof;
  238. BGNPublicKey bgnPublicKey = servers.get_bgn_public_key();
  239. Twistpoint elGamalBlindGenerator = servers.get_blinding_generator(elGamalBlindGeneratorProof);
  240. vector<PrsonaClient> users;
  241. for (size_t i = 0; i < numUsers; i++)
  242. {
  243. PrsonaClient currUser(elGamalBlindGeneratorProof, elGamalBlindGenerator, bgnPublicKey, numServers);
  244. users.push_back(currUser);
  245. servers.add_new_client(users[i]);
  246. }
  247. Proof pseudonymsProof;
  248. vector<Twistpoint> currentPseudonyms = servers.get_current_pseudonyms(pseudonymsProof);
  249. cout << "Making votes." << endl;
  250. for (size_t i = 0; i < numUsers; i++)
  251. {
  252. Twistpoint shortTermPublicKey = users[i].get_short_term_public_key();
  253. size_t myIndex = users[i].binary_search(currentPseudonyms, shortTermPublicKey);
  254. cout << "User " << i+1 << " has initial index " << myIndex << endl;
  255. vector<Scalar> votes;
  256. vector<bool> replaces;
  257. for (size_t j = 0; j < numUsers; j++)
  258. {
  259. if (j == myIndex)
  260. votes.push_back(Scalar(2));
  261. else if (j > myIndex)
  262. votes.push_back(Scalar(1));
  263. else
  264. votes.push_back(Scalar(0));
  265. replaces.push_back(true);
  266. }
  267. Proof baseProof;
  268. vector<Proof> fullProof;
  269. vector<TwistBipoint> currEncryptedVotes = servers.get_current_votes_by(baseProof, shortTermPublicKey);
  270. fullProof.push_back(baseProof);
  271. servers.get_other_vote_row_commitments(fullProof, shortTermPublicKey);
  272. vector<Proof> currVoteProof;
  273. currEncryptedVotes = users[i].make_votes(currVoteProof, fullProof, currEncryptedVotes, votes, replaces);
  274. servers.receive_vote(currVoteProof, currEncryptedVotes, shortTermPublicKey);
  275. cout << "User " << i+1 << " now has the following votes:" << endl;
  276. servers.print_current_votes_by(shortTermPublicKey);
  277. }
  278. servers.print_votes();
  279. epoch(servers);
  280. cout << "First epoch done." << endl;
  281. transmit_epoch_updates(users, servers);
  282. cout << "Updates given to users." << endl;
  283. servers.print_votes();
  284. for (size_t i = 0; i < numUsers; i++)
  285. {
  286. Proof ownerProof;
  287. Twistpoint shortTermPublicKey = users[i].get_short_term_public_key(ownerProof);
  288. cout << "User " << i+1 << " now has the following votes:" << endl;
  289. servers.print_current_votes_by(shortTermPublicKey);
  290. }
  291. }
  292. int main(int argc, char *argv[])
  293. {
  294. initialize_prsona_classes();
  295. // Defaults
  296. size_t numServers = 2;
  297. size_t numUsers = 5;
  298. size_t numRounds = 3;
  299. size_t numVotesPerRound = 3;
  300. size_t lambda = 0;
  301. bool maliciousServers = true;
  302. bool maliciousClients = true;
  303. string seedStr = "seed";
  304. // Potentially accept command line inputs
  305. if (argc > 1)
  306. numServers = atoi(argv[1]);
  307. if (argc > 2)
  308. numUsers = atoi(argv[2]);
  309. if (argc > 3)
  310. numRounds = atoi(argv[3]);
  311. if (argc > 4)
  312. numVotesPerRound = atoi(argv[4]);
  313. if (argc > 5)
  314. lambda = atoi(argv[5]);
  315. if (argc > 6)
  316. {
  317. bool setting = argv[6][0] == 't' || argv[6][0] == 'T';
  318. maliciousServers = setting;
  319. }
  320. if (argc > 7)
  321. {
  322. bool setting = argv[7][0] == 't' || argv[7][0] == 'T';
  323. maliciousClients = setting;
  324. }
  325. if (argc > 8)
  326. seedStr = argv[8];
  327. cout << "Running the protocol with the following parameters: " << endl;
  328. cout << numServers << " PRSONA servers" << endl;
  329. cout << numUsers << " participants (voters/votees)" << endl;
  330. cout << numRounds << " epochs" << endl;
  331. cout << numVotesPerRound << " new (random) votes by each user per epoch" << endl;
  332. cout << "Proof batching " << (lambda > 0 ? "IS" : "is NOT") << " in use." << (lambda > 0 ? " Batch parameter: " : "");
  333. if (lambda > 0)
  334. cout << lambda;
  335. cout << endl;
  336. cout << "Servers are set to " << (maliciousServers ? "MALICIOUS" : "HBC") << " security" << endl;
  337. cout << "Clients are set to " << (maliciousClients ? "MALICIOUS" : "HBC") << " security" << endl;
  338. cout << "Current randomness seed: \"" << seedStr << "\"" << endl;
  339. cout << endl;
  340. // Set malicious flags where necessary
  341. if (maliciousServers)
  342. PrsonaBase::set_server_malicious();
  343. if (maliciousClients)
  344. PrsonaBase::set_client_malicious();
  345. if (lambda > 0)
  346. PrsonaBase::set_lambda(lambda);
  347. // Entities we operate with
  348. PrsonaServerEntity servers(numServers);
  349. vector<Proof> elGamalBlindGeneratorProof;
  350. BGNPublicKey bgnPublicKey = servers.get_bgn_public_key();
  351. Twistpoint elGamalBlindGenerator = servers.get_blinding_generator(elGamalBlindGeneratorProof);
  352. test_proof_output(elGamalBlindGeneratorProof);
  353. cout << "Initialization: adding users to system" << endl << endl;
  354. vector<PrsonaClient> users;
  355. for (size_t i = 0; i < numUsers; i++)
  356. {
  357. PrsonaClient currUser(elGamalBlindGeneratorProof, elGamalBlindGenerator, bgnPublicKey, numServers);
  358. users.push_back(currUser);
  359. servers.add_new_client(users[i]);
  360. }
  361. // Seeded randomness for random votes used in epoch
  362. seed_seq seed(seedStr.begin(), seedStr.end());
  363. default_random_engine generator(seed);
  364. // Do the epoch operations
  365. for (size_t i = 0; i < numRounds; i++)
  366. {
  367. vector<double> timings;
  368. cout << "Round " << i+1 << " of " << numRounds << ": " << endl;
  369. vector<vector<TwistBipoint>> newEncryptedVotes;
  370. vector<vector<Proof>> validVoteProofs;
  371. timings = make_votes(generator, newEncryptedVotes, validVoteProofs, users, servers, numVotesPerRound);
  372. cout << "Vote generation (with proofs): " << mean(timings) << " seconds per user" << endl;
  373. timings.clear();
  374. timings = transmit_votes_to_servers(newEncryptedVotes, validVoteProofs, users, servers);
  375. cout << "Vote validation: " << mean(timings) << " seconds per vote vector/server" << endl;
  376. timings.clear();
  377. timings.push_back(epoch(servers));
  378. cout << "Epoch computation: " << mean(timings) << " seconds" << endl;
  379. timings.clear();
  380. timings = transmit_epoch_updates(users, servers);
  381. cout << "Transmit epoch updates: " << mean(timings) << " seconds per user" << endl << endl;
  382. }
  383. // Pick random users for our tests
  384. uniform_int_distribution<size_t> userDistribution(0, numUsers - 1);
  385. size_t user_a = userDistribution(generator);
  386. size_t user_b = user_a;
  387. while (user_b == user_a)
  388. user_b = userDistribution(generator);
  389. test_reputation_proof(generator, servers, users[user_a], users[user_b]);
  390. test_vote_proof(generator, users[user_a], servers);
  391. return 0;
  392. }