main.cpp 15 KB


  #include <algorithm>
  #include <chrono>
  #include <iostream>
  #include <limits>
  #include <numeric>
  #include <random>
  #include <string>
  #include <utility>
  #include <vector>

  #include "BGN.hpp"
  #include "client.hpp"
  #include "server.hpp"
  #include "serverEntity.hpp"
  using namespace std;
  10. // Initialize the classes we use
  11. void initialize_prsona_classes()
  12. {
  13. Scalar::init();
  14. PrsonaBase::init();
  15. }
  // Arithmetic mean of a list of values (used for averaging timings).
  //
  // xx - the samples; taken by const reference so callers' vectors are not
  //      copied on every call.
  // Returns the mean, or 0.0 for an empty list (instead of 0.0 / 0 == NaN).
  double mean(const std::vector<double>& xx)
  {
      if (xx.empty())
          return 0.0;
      return std::accumulate(xx.begin(), xx.end(), 0.0)
          / static_cast<double>(xx.size());
  }
  21. void print_user_scores(const vector<PrsonaClient>& users)
  22. {
  23. std::cout << "<";
  24. for (size_t i = 0; i < users.size(); i++)
  25. {
  26. std::cout << users[i].get_score()
  27. << (i == users.size() - 1 ? ">" : " ");
  28. }
  29. std::cout << std::endl;
  30. }
  31. // Time how long it takes to make a proof of valid votes
  32. vector<double> make_votes(
  33. default_random_engine& generator,
  34. vector<vector<CurveBipoint>>& newEncryptedVotes,
  35. vector<vector<Proof>>& validVoteProofs,
  36. const vector<PrsonaClient>& users,
  37. const PrsonaServerEntity& servers,
  38. size_t numVotes)
  39. {
  40. vector<double> retval;
  41. uniform_int_distribution<int> voteDistribution(
  42. 0, PrsonaBase::get_max_allowed_vote());
  43. size_t numUsers = users.size();
  44. newEncryptedVotes.clear();
  45. for (size_t i = 0; i < numUsers; i++)
  46. {
  47. // Make the correct number of new votes, but shuffle where they go
  48. vector<Scalar> votes;
  49. vector<bool> replaces;
  50. for (size_t j = 0; j < numUsers; j++)
  51. {
  52. votes.push_back(Scalar(voteDistribution(generator)));
  53. replaces.push_back(j < numVotes);
  54. }
  55. shuffle(replaces.begin(), replaces.end(), generator);
  56. Proof ownerProof;
  57. Curvepoint shortTermPublicKey =
  58. users[i].get_short_term_public_key(ownerProof);
  59. vector<CurveBipoint> currEncryptedVotes =
  60. servers.get_current_votes_by(ownerProof, shortTermPublicKey);
  61. vector<Proof> currVoteProof;
  62. chrono::high_resolution_clock::time_point t0 =
  63. chrono::high_resolution_clock::now();
  64. currEncryptedVotes = users[i].make_votes(
  65. currVoteProof,
  66. ownerProof,
  67. currEncryptedVotes,
  68. votes,
  69. replaces);
  70. chrono::high_resolution_clock::time_point t1 =
  71. chrono::high_resolution_clock::now();
  72. newEncryptedVotes.push_back(currEncryptedVotes);
  73. validVoteProofs.push_back(currVoteProof);
  74. chrono::duration<double> time_span =
  75. chrono::duration_cast<chrono::duration<double>>(t1 - t0);
  76. retval.push_back(time_span.count());
  77. }
  78. return retval;
  79. }
  80. // Time how long it takes to validate a proof of valid votes
  81. vector<double> transmit_votes_to_servers(
  82. const vector<vector<CurveBipoint>>& newEncryptedVotes,
  83. const vector<vector<Proof>>& validVoteProofs,
  84. const vector<PrsonaClient>& users,
  85. PrsonaServerEntity& servers)
  86. {
  87. vector<double> retval;
  88. size_t numUsers = users.size();
  89. size_t numServers = servers.get_num_servers();
  90. for (size_t i = 0; i < numUsers; i++)
  91. {
  92. Proof ownerProof;
  93. Curvepoint shortTermPublicKey =
  94. users[i].get_short_term_public_key(ownerProof);
  95. for (size_t j = 0; j < numServers; j++)
  96. {
  97. chrono::high_resolution_clock::time_point t0 =
  98. chrono::high_resolution_clock::now();
  99. servers.receive_vote(
  100. validVoteProofs[i],
  101. newEncryptedVotes[i],
  102. shortTermPublicKey,
  103. j);
  104. chrono::high_resolution_clock::time_point t1 =
  105. chrono::high_resolution_clock::now();
  106. chrono::duration<double> time_span =
  107. chrono::duration_cast<chrono::duration<double>>(t1 - t0);
  108. retval.push_back(time_span.count());
  109. }
  110. }
  111. return retval;
  112. }
  113. // Time how long it takes to do the operations associated with an epoch
  114. double epoch(PrsonaServerEntity& servers)
  115. {
  116. // Do the epoch server calculations
  117. chrono::high_resolution_clock::time_point t0 =
  118. chrono::high_resolution_clock::now();
  119. servers.epoch();
  120. chrono::high_resolution_clock::time_point t1 =
  121. chrono::high_resolution_clock::now();
  122. // Return the timing of the epoch server calculations
  123. chrono::duration<double> time_span =
  124. chrono::duration_cast<chrono::duration<double>>(t1 - t0);
  125. return time_span.count();
  126. }
  127. // Time how long it takes each user to decrypt their new scores
  128. vector<double> transmit_epoch_updates(
  129. vector<PrsonaClient>& users, const PrsonaServerEntity& servers)
  130. {
  131. vector<double> retval;
  132. size_t numUsers = users.size();
  133. for (size_t i = 0; i < numUsers; i++)
  134. {
  135. chrono::high_resolution_clock::time_point t0 =
  136. chrono::high_resolution_clock::now();
  137. servers.transmit_updates(users[i]);
  138. chrono::high_resolution_clock::time_point t1 =
  139. chrono::high_resolution_clock::now();
  140. chrono::duration<double> time_span =
  141. chrono::duration_cast<chrono::duration<double>>(t1 - t0);
  142. retval.push_back(time_span.count());
  143. }
  144. return retval;
  145. }
  146. // Test if the proof of reputation level is working as expected
  147. void test_reputation_proof(
  148. default_random_engine& generator,
  149. const PrsonaClient& a,
  150. const PrsonaClient& b)
  151. {
  152. bool flag;
  153. mpz_class aScore = a.get_score().toInt();
  154. int i = 0;
  155. while (i < aScore)
  156. i++;
  157. uniform_int_distribution<int> thresholdDistribution(0, i);
  158. Scalar goodThreshold(thresholdDistribution(generator));
  159. Scalar badThreshold(aScore + 1);
  160. Proof pi;
  161. Curvepoint shortTermPublicKey = a.get_short_term_public_key(pi);
  162. vector<Proof> goodRepProof = a.generate_reputation_proof(goodThreshold);
  163. flag = b.verify_reputation_proof(
  164. goodRepProof, shortTermPublicKey, goodThreshold);
  165. cout << "TEST VALID REPUTATION PROOF: "
  166. << (flag ? "PASSED (Proof verified)" : "FAILED (Proof not verified)" )
  167. << endl;
  168. vector<Proof> badRepProof = a.generate_reputation_proof(badThreshold);
  169. flag = b.verify_reputation_proof(
  170. badRepProof, shortTermPublicKey, badThreshold);
  171. cout << "TEST INVALID REPUTATION PROOF: "
  172. << (flag ? "FAILED (Proof verified)" : "PASSED (Proof not verified)" )
  173. << endl << endl;
  174. }
  175. // Test if the proof of valid votes is working as expected
  176. void test_vote_proof(
  177. default_random_engine& generator,
  178. const PrsonaClient& user,
  179. PrsonaServerEntity& servers)
  180. {
  181. size_t numUsers = servers.get_num_clients();
  182. vector<Scalar> votes;
  183. vector<bool> replaces;
  184. bool flag;
  185. for (size_t i = 0; i < numUsers; i++)
  186. {
  187. votes.push_back(Scalar(1));
  188. replaces.push_back(true);
  189. }
  190. vector<Proof> validVoteProof;
  191. Proof ownerProof;
  192. Curvepoint shortTermPublicKey =
  193. user.get_short_term_public_key(ownerProof);
  194. vector<CurveBipoint> encryptedVotes =
  195. servers.get_current_votes_by(ownerProof, shortTermPublicKey);
  196. encryptedVotes =
  197. user.make_votes(
  198. validVoteProof, ownerProof, encryptedVotes, votes, replaces);
  199. flag = servers.receive_vote(
  200. validVoteProof, encryptedVotes, shortTermPublicKey);
  201. cout << "TEST REPLACE VOTE PROOF: "
  202. << (flag ? "PASSED (Proof verified)" : "FAILED (Proof not verified)" )
  203. << endl;
  204. for (size_t i = 0; i < numUsers; i++)
  205. {
  206. replaces[i] = false;
  207. }
  208. shortTermPublicKey = user.get_short_term_public_key(ownerProof);
  209. encryptedVotes =
  210. servers.get_current_votes_by(ownerProof, shortTermPublicKey);
  211. encryptedVotes =
  212. user.make_votes(
  213. validVoteProof, ownerProof, encryptedVotes, votes, replaces);
  214. flag = servers.receive_vote(
  215. validVoteProof, encryptedVotes, shortTermPublicKey);
  216. cout << "TEST RERANDOMIZE VOTE PROOF: "
  217. << (flag ? "PASSED (Proof verified)" : "FAILED (Proof not verified)" )
  218. << endl;
  219. for (size_t i = 0; i < numUsers; i++)
  220. {
  221. votes[i] = Scalar(3);
  222. replaces[i] = true;
  223. }
  224. shortTermPublicKey = user.get_short_term_public_key(ownerProof);
  225. encryptedVotes =
  226. servers.get_current_votes_by(ownerProof, shortTermPublicKey);
  227. encryptedVotes =
  228. user.make_votes(
  229. validVoteProof, ownerProof, encryptedVotes, votes, replaces);
  230. flag = servers.receive_vote(
  231. validVoteProof, encryptedVotes, shortTermPublicKey);
  232. cout << "TEST INVALID REPLACE VOTE PROOF: "
  233. << (flag ? "FAILED (Proof verified)" : "PASSED (Proof not verified)" )
  234. << endl << endl;
  235. }
  236. void check_vote_matrix_updates()
  237. {
  238. size_t numServers = 2;
  239. size_t numUsers = 3;
  240. cout << "Testing how the vote matrix updates." << endl;
  241. PrsonaBase::set_client_malicious();
  242. // Entities we operate with
  243. PrsonaServerEntity servers(numServers);
  244. vector<Proof> elGamalBlindGeneratorProof;
  245. BGNPublicKey bgnPublicKey = servers.get_bgn_public_key();
  246. Curvepoint elGamalBlindGenerator =
  247. servers.get_blinding_generator(elGamalBlindGeneratorProof);
  248. vector<PrsonaClient> users;
  249. for (size_t i = 0; i < numUsers; i++)
  250. {
  251. PrsonaClient currUser(
  252. elGamalBlindGeneratorProof,
  253. elGamalBlindGenerator,
  254. bgnPublicKey,
  255. &servers);
  256. users.push_back(currUser);
  257. servers.add_new_client(users[i]);
  258. }
  259. Proof pseudonymsProof;
  260. vector<Curvepoint> currentPseudonyms =
  261. servers.get_current_pseudonyms(pseudonymsProof);
  262. cout << "Making votes." << endl;
  263. for (size_t i = 0; i < numUsers; i++)
  264. {
  265. Proof ownerProof;
  266. Curvepoint shortTermPublicKey =
  267. users[i].get_short_term_public_key(ownerProof);
  268. size_t myIndex =
  269. users[i].binary_search(currentPseudonyms, shortTermPublicKey);
  270. cout << "User " << i+1 << " has initial index " << myIndex << endl;
  271. vector<Scalar> votes;
  272. vector<bool> replaces;
  273. for (size_t j = 0; j < numUsers; j++)
  274. {
  275. if (j == myIndex)
  276. votes.push_back(Scalar(2));
  277. else if (j > myIndex)
  278. votes.push_back(Scalar(1));
  279. else
  280. votes.push_back(Scalar(0));
  281. replaces.push_back(true);
  282. }
  283. vector<CurveBipoint> currEncryptedVotes =
  284. servers.get_current_votes_by(ownerProof, shortTermPublicKey);
  285. vector<Proof> currVoteProof;
  286. currEncryptedVotes = users[i].make_votes(
  287. currVoteProof,
  288. ownerProof,
  289. currEncryptedVotes,
  290. votes,
  291. replaces);
  292. servers.receive_vote(
  293. currVoteProof,
  294. currEncryptedVotes,
  295. shortTermPublicKey);
  296. cout << "User " << i+1 << " now has the following votes:" << endl;
  297. servers.print_current_votes_by(shortTermPublicKey);
  298. }
  299. servers.print_votes();
  300. epoch(servers);
  301. cout << "First epoch done." << endl;
  302. transmit_epoch_updates(users, servers);
  303. cout << "Updates given to users." << endl;
  304. servers.print_votes();
  305. for (size_t i = 0; i < numUsers; i++)
  306. {
  307. Proof ownerProof;
  308. Curvepoint shortTermPublicKey =
  309. users[i].get_short_term_public_key(ownerProof);
  310. cout << "User " << i+1 << " now has the following votes:" << endl;
  311. servers.print_current_votes_by(shortTermPublicKey);
  312. }
  313. }
  314. int main(int argc, char *argv[])
  315. {
  316. initialize_prsona_classes();
  317. // Defaults
  318. size_t numServers = 2;
  319. size_t numUsers = 5;
  320. size_t numRounds = 3;
  321. size_t numVotesPerRound = 3;
  322. bool maliciousServers = true;
  323. bool maliciousClients = true;
  324. string seedStr = "seed";
  325. // Potentially accept command line inputs
  326. if (argc > 1)
  327. numServers = atoi(argv[1]);
  328. if (argc > 2)
  329. numUsers = atoi(argv[2]);
  330. if (argc > 3)
  331. numRounds = atoi(argv[3]);
  332. if (argc > 4)
  333. numVotesPerRound = atoi(argv[4]);
  334. cout << "Running the protocol with the following parameters: " << endl;
  335. cout << numServers << " PRSONA servers" << endl;
  336. cout << numUsers << " participants (voters/votees)" << endl;
  337. cout << numRounds << " epochs" << endl;
  338. cout << numVotesPerRound << " new (random) votes by each user per epoch"
  339. << endl << endl;
  340. // Set malicious flags where necessary
  341. if (maliciousServers)
  342. PrsonaBase::set_server_malicious();
  343. if (maliciousClients)
  344. PrsonaBase::set_client_malicious();
  345. // Entities we operate with
  346. PrsonaServerEntity servers(numServers);
  347. vector<Proof> elGamalBlindGeneratorProof;
  348. BGNPublicKey bgnPublicKey = servers.get_bgn_public_key();
  349. Curvepoint elGamalBlindGenerator =
  350. servers.get_blinding_generator(elGamalBlindGeneratorProof);
  351. cout << "Initialization: adding users to system" << endl << endl;
  352. vector<PrsonaClient> users;
  353. for (size_t i = 0; i < numUsers; i++)
  354. {
  355. PrsonaClient currUser(
  356. elGamalBlindGeneratorProof,
  357. elGamalBlindGenerator,
  358. bgnPublicKey,
  359. &servers);
  360. users.push_back(currUser);
  361. servers.add_new_client(users[i]);
  362. }
  363. // Seeded randomness for random votes used in epoch
  364. seed_seq seed(seedStr.begin(), seedStr.end());
  365. default_random_engine generator(seed);
  366. // Do the epoch operations
  367. for (size_t i = 0; i < numRounds; i++)
  368. {
  369. vector<double> timings;
  370. cout << "Round " << i+1 << " of " << numRounds << ": " << endl;
  371. vector<vector<CurveBipoint>> newEncryptedVotes;
  372. vector<vector<Proof>> validVoteProofs;
  373. timings = make_votes(
  374. generator,
  375. newEncryptedVotes,
  376. validVoteProofs,
  377. users,
  378. servers,
  379. numVotesPerRound);
  380. cout << "Vote generation (with proofs): " << mean(timings)
  381. << " seconds per user" << endl;
  382. timings.clear();
  383. timings = transmit_votes_to_servers(
  384. newEncryptedVotes, validVoteProofs, users, servers);
  385. cout << "Vote validation: " << mean(timings)
  386. << " seconds per vote vector/server" << endl;
  387. timings.clear();
  388. timings.push_back(epoch(servers));
  389. cout << "Epoch computation: " << mean(timings) << " seconds" << endl;
  390. timings.clear();
  391. timings = transmit_epoch_updates(users, servers);
  392. cout << "Transmit epoch updates: " << mean(timings)
  393. << " seconds per user" << endl << endl;
  394. }
  395. // Pick random users for our tests
  396. uniform_int_distribution<size_t> userDistribution(0, numUsers - 1);
  397. size_t user_a = userDistribution(generator);
  398. size_t user_b = user_a;
  399. while (user_b == user_a)
  400. user_b = userDistribution(generator);
  401. test_reputation_proof(generator, users[user_a], users[user_b]);
  402. test_vote_proof(generator, users[user_a], servers);
  403. return 0;
  404. }