main.cpp 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551
#include <algorithm>
#include <cerrno>
#include <chrono>
#include <cstdlib>
#include <iostream>
#include <limits>
#include <numeric>
#include <random>
#include <string>
#include <vector>

#include "BGN.hpp"
#include "client.hpp"
#include "server.hpp"
#include "serverEntity.hpp"
  9. using namespace std;
  10. // Initialize the classes we use
  11. void initialize_prsona_classes()
  12. {
  13. Scalar::init();
  14. PrsonaBase::init();
  15. }
  16. // Quick and dirty mean calculation (used for averaging timings)
  17. double mean(vector<double> xx)
  18. {
  19. return accumulate(xx.begin(), xx.end(), 0.0) / xx.size();
  20. }
  21. void print_user_scores(const vector<PrsonaClient>& users)
  22. {
  23. std::cout << "<";
  24. for (size_t i = 0; i < users.size(); i++)
  25. {
  26. std::cout << users[i].get_score()
  27. << (i == users.size() - 1 ? ">" : " ");
  28. }
  29. std::cout << std::endl;
  30. }
  31. // Time how long it takes to make a proof of valid votes
  32. vector<double> make_votes(
  33. default_random_engine& generator,
  34. vector<vector<TwistBipoint>>& newEncryptedVotes,
  35. vector<vector<Proof>>& validVoteProofs,
  36. const vector<PrsonaClient>& users,
  37. const PrsonaServerEntity& servers,
  38. size_t numVotes)
  39. {
  40. vector<double> retval;
  41. uniform_int_distribution<int> voteDistribution(
  42. 0, PrsonaBase::get_max_allowed_vote());
  43. size_t numUsers = users.size();
  44. newEncryptedVotes.clear();
  45. for (size_t i = 0; i < numUsers; i++)
  46. {
  47. // Make the correct number of new votes, but shuffle where they go
  48. vector<Scalar> votes;
  49. vector<bool> replaces;
  50. for (size_t j = 0; j < numUsers; j++)
  51. {
  52. votes.push_back(Scalar(voteDistribution(generator)));
  53. replaces.push_back(j < numVotes);
  54. }
  55. shuffle(replaces.begin(), replaces.end(), generator);
  56. Proof baseProof;
  57. vector<Proof> fullProof;
  58. Twistpoint shortTermPublicKey =
  59. users[i].get_short_term_public_key();
  60. vector<TwistBipoint> currEncryptedVotes =
  61. servers.get_current_votes_by(baseProof, shortTermPublicKey);
  62. fullProof.push_back(baseProof);
  63. servers.get_other_vote_row_commitments(fullProof, shortTermPublicKey);
  64. vector<Proof> currVoteProof;
  65. chrono::high_resolution_clock::time_point t0 =
  66. chrono::high_resolution_clock::now();
  67. currEncryptedVotes = users[i].make_votes(
  68. currVoteProof,
  69. fullProof,
  70. currEncryptedVotes,
  71. votes,
  72. replaces);
  73. chrono::high_resolution_clock::time_point t1 =
  74. chrono::high_resolution_clock::now();
  75. newEncryptedVotes.push_back(currEncryptedVotes);
  76. validVoteProofs.push_back(currVoteProof);
  77. chrono::duration<double> time_span =
  78. chrono::duration_cast<chrono::duration<double>>(t1 - t0);
  79. retval.push_back(time_span.count());
  80. }
  81. return retval;
  82. }
  83. // Time how long it takes to validate a proof of valid votes
  84. vector<double> transmit_votes_to_servers(
  85. const vector<vector<TwistBipoint>>& newEncryptedVotes,
  86. const vector<vector<Proof>>& validVoteProofs,
  87. const vector<PrsonaClient>& users,
  88. PrsonaServerEntity& servers)
  89. {
  90. vector<double> retval;
  91. size_t numUsers = users.size();
  92. size_t numServers = servers.get_num_servers();
  93. for (size_t i = 0; i < numUsers; i++)
  94. {
  95. Proof ownerProof;
  96. Twistpoint shortTermPublicKey =
  97. users[i].get_short_term_public_key(ownerProof);
  98. for (size_t j = 0; j < numServers; j++)
  99. {
  100. chrono::high_resolution_clock::time_point t0 =
  101. chrono::high_resolution_clock::now();
  102. servers.receive_vote(
  103. validVoteProofs[i],
  104. newEncryptedVotes[i],
  105. shortTermPublicKey,
  106. j);
  107. chrono::high_resolution_clock::time_point t1 =
  108. chrono::high_resolution_clock::now();
  109. chrono::duration<double> time_span =
  110. chrono::duration_cast<chrono::duration<double>>(t1 - t0);
  111. retval.push_back(time_span.count());
  112. }
  113. }
  114. return retval;
  115. }
  116. // Time how long it takes to do the operations associated with an epoch
  117. double epoch(PrsonaServerEntity& servers)
  118. {
  119. // Do the epoch server calculations
  120. chrono::high_resolution_clock::time_point t0 =
  121. chrono::high_resolution_clock::now();
  122. servers.epoch();
  123. chrono::high_resolution_clock::time_point t1 =
  124. chrono::high_resolution_clock::now();
  125. // Return the timing of the epoch server calculations
  126. chrono::duration<double> time_span =
  127. chrono::duration_cast<chrono::duration<double>>(t1 - t0);
  128. return time_span.count();
  129. }
  130. // Time how long it takes each user to decrypt their new scores
  131. vector<double> transmit_epoch_updates(
  132. vector<PrsonaClient>& users, const PrsonaServerEntity& servers)
  133. {
  134. vector<double> retval;
  135. size_t numUsers = users.size();
  136. for (size_t i = 0; i < numUsers; i++)
  137. {
  138. chrono::high_resolution_clock::time_point t0 =
  139. chrono::high_resolution_clock::now();
  140. servers.transmit_updates(users[i]);
  141. chrono::high_resolution_clock::time_point t1 =
  142. chrono::high_resolution_clock::now();
  143. chrono::duration<double> time_span =
  144. chrono::duration_cast<chrono::duration<double>>(t1 - t0);
  145. retval.push_back(time_span.count());
  146. }
  147. return retval;
  148. }
  149. // Test if the proof of reputation level is working as expected
  150. void test_reputation_proof(
  151. default_random_engine& generator,
  152. const PrsonaServerEntity& servers,
  153. const PrsonaClient& a,
  154. const PrsonaClient& b)
  155. {
  156. bool flag;
  157. mpz_class aScore = a.get_score().toInt();
  158. int i = 0;
  159. while (i < aScore)
  160. i++;
  161. uniform_int_distribution<int> thresholdDistribution(0, i);
  162. Scalar goodThreshold(thresholdDistribution(generator));
  163. Scalar badThreshold(aScore + 1);
  164. Twistpoint shortTermPublicKey = a.get_short_term_public_key();
  165. vector<Proof> goodRepProof =
  166. a.generate_reputation_proof(goodThreshold, servers.get_num_clients());
  167. Proof baseProof;
  168. vector<Proof> fullProof;
  169. EGCiphertext currEncryptedScore =
  170. servers.get_current_user_encrypted_tally(baseProof, shortTermPublicKey);
  171. fullProof.push_back(baseProof);
  172. servers.get_other_user_tally_commitments(fullProof, shortTermPublicKey);
  173. flag = b.verify_reputation_proof(
  174. goodRepProof,
  175. shortTermPublicKey,
  176. goodThreshold,
  177. fullProof,
  178. currEncryptedScore);
  179. cout << "TEST VALID REPUTATION PROOF: "
  180. << (flag ? "PASSED (Proof verified)" : "FAILED (Proof not verified)" )
  181. << endl;
  182. vector<Proof> badRepProof =
  183. a.generate_reputation_proof(badThreshold, servers.get_num_clients());
  184. baseProof.clear();
  185. fullProof.clear();
  186. currEncryptedScore =
  187. servers.get_current_user_encrypted_tally(baseProof, shortTermPublicKey);
  188. fullProof.push_back(baseProof);
  189. servers.get_other_user_tally_commitments(fullProof, shortTermPublicKey);
  190. flag = b.verify_reputation_proof(
  191. badRepProof,
  192. shortTermPublicKey,
  193. goodThreshold,
  194. fullProof,
  195. currEncryptedScore);
  196. cout << "TEST INVALID REPUTATION PROOF: "
  197. << (flag ? "FAILED (Proof verified)" : "PASSED (Proof not verified)" )
  198. << endl << endl;
  199. }
  200. // Test if the proof of valid votes is working as expected
  201. void test_vote_proof(
  202. default_random_engine& generator,
  203. const PrsonaClient& user,
  204. PrsonaServerEntity& servers)
  205. {
  206. size_t numUsers = servers.get_num_clients();
  207. vector<Scalar> votes;
  208. vector<bool> replaces;
  209. bool flag;
  210. for (size_t i = 0; i < numUsers; i++)
  211. {
  212. votes.push_back(Scalar(1));
  213. replaces.push_back(true);
  214. }
  215. vector<Proof> validVoteProof;
  216. Proof baseProof;
  217. vector<Proof> fullProof;
  218. Twistpoint shortTermPublicKey =
  219. user.get_short_term_public_key();
  220. vector<TwistBipoint> encryptedVotes =
  221. servers.get_current_votes_by(baseProof, shortTermPublicKey);
  222. fullProof.push_back(baseProof);
  223. servers.get_other_vote_row_commitments(fullProof, shortTermPublicKey);
  224. encryptedVotes = user.make_votes(
  225. validVoteProof,
  226. fullProof,
  227. encryptedVotes,
  228. votes,
  229. replaces);
  230. flag = servers.receive_vote(
  231. validVoteProof, encryptedVotes, shortTermPublicKey);
  232. cout << "TEST REPLACE VOTE PROOF: "
  233. << (flag ? "PASSED (Proof verified)" : "FAILED (Proof not verified)" )
  234. << endl;
  235. for (size_t i = 0; i < numUsers; i++)
  236. replaces[i] = false;
  237. baseProof.clear();
  238. fullProof.clear();
  239. encryptedVotes =
  240. servers.get_current_votes_by(baseProof, shortTermPublicKey);
  241. fullProof.push_back(baseProof);
  242. servers.get_other_vote_row_commitments(fullProof, shortTermPublicKey);
  243. encryptedVotes = user.make_votes(
  244. validVoteProof,
  245. fullProof,
  246. encryptedVotes,
  247. votes,
  248. replaces);
  249. flag = servers.receive_vote(
  250. validVoteProof, encryptedVotes, shortTermPublicKey);
  251. cout << "TEST RERANDOMIZE VOTE PROOF: "
  252. << (flag ? "PASSED (Proof verified)" : "FAILED (Proof not verified)" )
  253. << endl;
  254. for (size_t i = 0; i < numUsers; i++)
  255. {
  256. votes[i] = Scalar(3);
  257. replaces[i] = true;
  258. }
  259. baseProof.clear();
  260. fullProof.clear();
  261. encryptedVotes =
  262. servers.get_current_votes_by(baseProof, shortTermPublicKey);
  263. fullProof.push_back(baseProof);
  264. servers.get_other_vote_row_commitments(fullProof, shortTermPublicKey);
  265. encryptedVotes = user.make_votes(
  266. validVoteProof,
  267. fullProof,
  268. encryptedVotes,
  269. votes,
  270. replaces);
  271. flag = servers.receive_vote(
  272. validVoteProof, encryptedVotes, shortTermPublicKey);
  273. cout << "TEST INVALID REPLACE VOTE PROOF: "
  274. << (flag ? "FAILED (Proof verified)" : "PASSED (Proof not verified)" )
  275. << endl << endl;
  276. }
  277. void check_vote_matrix_updates()
  278. {
  279. size_t numServers = 2;
  280. size_t numUsers = 3;
  281. cout << "Testing how the vote matrix updates." << endl;
  282. PrsonaBase::set_client_malicious();
  283. // Entities we operate with
  284. PrsonaServerEntity servers(numServers);
  285. vector<Proof> elGamalBlindGeneratorProof;
  286. BGNPublicKey bgnPublicKey = servers.get_bgn_public_key();
  287. Twistpoint elGamalBlindGenerator =
  288. servers.get_blinding_generator(elGamalBlindGeneratorProof);
  289. vector<PrsonaClient> users;
  290. for (size_t i = 0; i < numUsers; i++)
  291. {
  292. PrsonaClient currUser(
  293. elGamalBlindGeneratorProof,
  294. elGamalBlindGenerator,
  295. bgnPublicKey,
  296. numServers);
  297. users.push_back(currUser);
  298. servers.add_new_client(users[i]);
  299. }
  300. Proof pseudonymsProof;
  301. vector<Twistpoint> currentPseudonyms =
  302. servers.get_current_pseudonyms(pseudonymsProof);
  303. cout << "Making votes." << endl;
  304. for (size_t i = 0; i < numUsers; i++)
  305. {
  306. Twistpoint shortTermPublicKey =
  307. users[i].get_short_term_public_key();
  308. size_t myIndex =
  309. users[i].binary_search(currentPseudonyms, shortTermPublicKey);
  310. cout << "User " << i+1 << " has initial index " << myIndex << endl;
  311. vector<Scalar> votes;
  312. vector<bool> replaces;
  313. for (size_t j = 0; j < numUsers; j++)
  314. {
  315. if (j == myIndex)
  316. votes.push_back(Scalar(2));
  317. else if (j > myIndex)
  318. votes.push_back(Scalar(1));
  319. else
  320. votes.push_back(Scalar(0));
  321. replaces.push_back(true);
  322. }
  323. Proof baseProof;
  324. vector<Proof> fullProof;
  325. vector<TwistBipoint> currEncryptedVotes =
  326. servers.get_current_votes_by(baseProof, shortTermPublicKey);
  327. fullProof.push_back(baseProof);
  328. servers.get_other_vote_row_commitments(fullProof, shortTermPublicKey);
  329. vector<Proof> currVoteProof;
  330. currEncryptedVotes = users[i].make_votes(
  331. currVoteProof,
  332. fullProof,
  333. currEncryptedVotes,
  334. votes,
  335. replaces);
  336. servers.receive_vote(
  337. currVoteProof,
  338. currEncryptedVotes,
  339. shortTermPublicKey);
  340. cout << "User " << i+1 << " now has the following votes:" << endl;
  341. servers.print_current_votes_by(shortTermPublicKey);
  342. }
  343. servers.print_votes();
  344. epoch(servers);
  345. cout << "First epoch done." << endl;
  346. transmit_epoch_updates(users, servers);
  347. cout << "Updates given to users." << endl;
  348. servers.print_votes();
  349. for (size_t i = 0; i < numUsers; i++)
  350. {
  351. Proof ownerProof;
  352. Twistpoint shortTermPublicKey =
  353. users[i].get_short_term_public_key(ownerProof);
  354. cout << "User " << i+1 << " now has the following votes:" << endl;
  355. servers.print_current_votes_by(shortTermPublicKey);
  356. }
  357. }
  358. int main(int argc, char *argv[])
  359. {
  360. initialize_prsona_classes();
  361. // Defaults
  362. size_t numServers = 2;
  363. size_t numUsers = 5;
  364. size_t numRounds = 3;
  365. size_t numVotesPerRound = 3;
  366. bool maliciousServers = true;
  367. bool maliciousClients = true;
  368. string seedStr = "seed";
  369. // Potentially accept command line inputs
  370. if (argc > 1)
  371. numServers = atoi(argv[1]);
  372. if (argc > 2)
  373. numUsers = atoi(argv[2]);
  374. if (argc > 3)
  375. numRounds = atoi(argv[3]);
  376. if (argc > 4)
  377. numVotesPerRound = atoi(argv[4]);
  378. if (argc > 5)
  379. {
  380. bool setting = argv[5][0] == 't' || argv[5][0] == 'T';
  381. maliciousServers = setting;
  382. }
  383. if (argc > 6)
  384. {
  385. bool setting = argv[6][0] == 't' || argv[6][0] == 'T';
  386. maliciousClients = setting;
  387. }
  388. if (argc > 7)
  389. seedStr = argv[7];
  390. cout << "Running the protocol with the following parameters: " << endl;
  391. cout << numServers << " PRSONA servers" << endl;
  392. cout << numUsers << " participants (voters/votees)" << endl;
  393. cout << numRounds << " epochs" << endl;
  394. cout << numVotesPerRound << " new (random) votes by each user per epoch" << endl;
  395. cout << "Servers are set to " << (maliciousServers ? "MALICIOUS" : "HBC") << " security" << endl;
  396. cout << "Clients are set to " << (maliciousClients ? "MALICIOUS" : "HBC") << " security" << endl;
  397. cout << "Current randomness seed: \"" << seedStr << "\"" << endl;
  398. cout << endl;
  399. // Set malicious flags where necessary
  400. if (maliciousServers)
  401. PrsonaBase::set_server_malicious();
  402. if (maliciousClients)
  403. PrsonaBase::set_client_malicious();
  404. // Entities we operate with
  405. PrsonaServerEntity servers(numServers);
  406. vector<Proof> elGamalBlindGeneratorProof;
  407. BGNPublicKey bgnPublicKey = servers.get_bgn_public_key();
  408. Twistpoint elGamalBlindGenerator =
  409. servers.get_blinding_generator(elGamalBlindGeneratorProof);
  410. cout << "Initialization: adding users to system" << endl << endl;
  411. vector<PrsonaClient> users;
  412. for (size_t i = 0; i < numUsers; i++)
  413. {
  414. PrsonaClient currUser(
  415. elGamalBlindGeneratorProof,
  416. elGamalBlindGenerator,
  417. bgnPublicKey,
  418. numServers);
  419. users.push_back(currUser);
  420. servers.add_new_client(users[i]);
  421. }
  422. // Seeded randomness for random votes used in epoch
  423. seed_seq seed(seedStr.begin(), seedStr.end());
  424. default_random_engine generator(seed);
  425. // Do the epoch operations
  426. for (size_t i = 0; i < numRounds; i++)
  427. {
  428. vector<double> timings;
  429. cout << "Round " << i+1 << " of " << numRounds << ": " << endl;
  430. vector<vector<TwistBipoint>> newEncryptedVotes;
  431. vector<vector<Proof>> validVoteProofs;
  432. timings = make_votes(
  433. generator,
  434. newEncryptedVotes,
  435. validVoteProofs,
  436. users,
  437. servers,
  438. numVotesPerRound);
  439. cout << "Vote generation (with proofs): " << mean(timings)
  440. << " seconds per user" << endl;
  441. timings.clear();
  442. timings = transmit_votes_to_servers(
  443. newEncryptedVotes, validVoteProofs, users, servers);
  444. cout << "Vote validation: " << mean(timings)
  445. << " seconds per vote vector/server" << endl;
  446. timings.clear();
  447. timings.push_back(epoch(servers));
  448. cout << "Epoch computation: " << mean(timings) << " seconds" << endl;
  449. timings.clear();
  450. timings = transmit_epoch_updates(users, servers);
  451. cout << "Transmit epoch updates: " << mean(timings)
  452. << " seconds per user" << endl << endl;
  453. }
  454. // Pick random users for our tests
  455. uniform_int_distribution<size_t> userDistribution(0, numUsers - 1);
  456. size_t user_a = userDistribution(generator);
  457. size_t user_b = user_a;
  458. while (user_b == user_a)
  459. user_b = userDistribution(generator);
  460. test_reputation_proof(generator, servers, users[user_a], users[user_b]);
  461. test_vote_proof(generator, users[user_a], servers);
  462. return 0;
  463. }