main.cpp 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378
#include <algorithm>
#include <cctype>
#include <cerrno>
#include <chrono>
#include <cstdlib>
#include <iostream>
#include <limits>
#include <numeric>
#include <random>
#include <string>
#include <utility>
#include <vector>

#include "BGN.hpp"
#include "client.hpp"
#include "server.hpp"
#include "serverEntity.hpp"

using namespace std;
  10. const int MAX_ALLOWED_VOTE = 2;
  11. // Initialize the classes we use
  12. void initialize_prsona_classes()
  13. {
  14. Scalar::init();
  15. PrsonaServer::init();
  16. PrsonaClient::init();
  17. }
  18. // Quick and dirty mean calculation (used for averaging timings)
  19. double mean(vector<double> xx)
  20. {
  21. return accumulate(xx.begin(), xx.end(), 0.0) / xx.size();
  22. }
  23. // Time how long it takes to make a proof of valid votes
  24. vector<double> make_votes(
  25. default_random_engine& generator,
  26. vector<vector<CurveBipoint>>& newEncryptedVotes,
  27. vector<vector<Proof>>& validVoteProofs,
  28. const vector<PrsonaClient>& users,
  29. const PrsonaServerEntity& servers,
  30. size_t numVotes)
  31. {
  32. vector<double> retval;
  33. uniform_int_distribution<int> voteDistribution(0, MAX_ALLOWED_VOTE);
  34. size_t numUsers = users.size();
  35. newEncryptedVotes.clear();
  36. for (size_t i = 0; i < numUsers; i++)
  37. {
  38. // Make the correct number of new votes, but shuffle where they go
  39. vector<Scalar> votes;
  40. vector<bool> replaces;
  41. for (size_t j = 0; j < numUsers; j++)
  42. {
  43. votes.push_back(Scalar(voteDistribution(generator)));
  44. replaces.push_back(j < numVotes);
  45. }
  46. shuffle(replaces.begin(), replaces.end(), generator);
  47. Proof ownerProof;
  48. Curvepoint shortTermPublicKey =
  49. users[i].get_short_term_public_key(ownerProof);
  50. vector<CurveBipoint> currEncryptedVotes =
  51. servers.get_current_votes_by(ownerProof, shortTermPublicKey);
  52. vector<Proof> currVoteProof;
  53. chrono::high_resolution_clock::time_point t0 =
  54. chrono::high_resolution_clock::now();
  55. currEncryptedVotes = users[i].make_votes(
  56. currVoteProof,
  57. ownerProof,
  58. currEncryptedVotes,
  59. votes,
  60. replaces);
  61. chrono::high_resolution_clock::time_point t1 =
  62. chrono::high_resolution_clock::now();
  63. newEncryptedVotes.push_back(currEncryptedVotes);
  64. validVoteProofs.push_back(currVoteProof);
  65. chrono::duration<double> time_span =
  66. chrono::duration_cast<chrono::duration<double>>(t1 - t0);
  67. retval.push_back(time_span.count());
  68. }
  69. return retval;
  70. }
  71. // Time how long it takes to validate a proof of valid votes
  72. vector<double> transmit_votes_to_servers(
  73. const vector<vector<CurveBipoint>>& newEncryptedVotes,
  74. const vector<vector<Proof>>& validVoteProofs,
  75. const vector<PrsonaClient>& users,
  76. PrsonaServerEntity& servers)
  77. {
  78. vector<double> retval;
  79. size_t numUsers = users.size();
  80. size_t numServers = servers.get_num_servers();
  81. for (size_t i = 0; i < numUsers; i++)
  82. {
  83. Proof ownerProof;
  84. Curvepoint shortTermPublicKey =
  85. users[i].get_short_term_public_key(ownerProof);
  86. for (size_t j = 0; j < numServers; j++)
  87. {
  88. chrono::high_resolution_clock::time_point t0 =
  89. chrono::high_resolution_clock::now();
  90. servers.receive_vote(
  91. validVoteProofs[i],
  92. newEncryptedVotes[i],
  93. shortTermPublicKey,
  94. j);
  95. chrono::high_resolution_clock::time_point t1 =
  96. chrono::high_resolution_clock::now();
  97. chrono::duration<double> time_span =
  98. chrono::duration_cast<chrono::duration<double>>(t1 - t0);
  99. retval.push_back(time_span.count());
  100. }
  101. }
  102. return retval;
  103. }
  104. // Time how long it takes to do the operations associated with an epoch
  105. double epoch(PrsonaServerEntity& servers)
  106. {
  107. Proof unused;
  108. // Do the epoch server calculations
  109. chrono::high_resolution_clock::time_point t0 =
  110. chrono::high_resolution_clock::now();
  111. servers.epoch(unused);
  112. chrono::high_resolution_clock::time_point t1 =
  113. chrono::high_resolution_clock::now();
  114. // Return the timing of the epoch server calculations
  115. chrono::duration<double> time_span =
  116. chrono::duration_cast<chrono::duration<double>>(t1 - t0);
  117. return time_span.count();
  118. }
  119. // Time how long it takes each user to decrypt their new scores
  120. vector<double> transmit_epoch_updates(
  121. vector<PrsonaClient>& users, const PrsonaServerEntity& servers)
  122. {
  123. vector<double> retval;
  124. size_t numUsers = users.size();
  125. for (size_t i = 0; i < numUsers; i++)
  126. {
  127. chrono::high_resolution_clock::time_point t0 =
  128. chrono::high_resolution_clock::now();
  129. servers.transmit_updates(users[i]);
  130. chrono::high_resolution_clock::time_point t1 =
  131. chrono::high_resolution_clock::now();
  132. chrono::duration<double> time_span =
  133. chrono::duration_cast<chrono::duration<double>>(t1 - t0);
  134. retval.push_back(time_span.count());
  135. }
  136. return retval;
  137. }
  138. // Test if the proof of reputation level is working as expected
  139. void test_reputation_proof(
  140. default_random_engine& generator,
  141. const PrsonaClient& a,
  142. const PrsonaClient& b)
  143. {
  144. bool flag;
  145. mpz_class aScore = a.get_score().toInt();
  146. int i = 0;
  147. while (i < aScore)
  148. i++;
  149. uniform_int_distribution<int> thresholdDistribution(0, i);
  150. Scalar goodThreshold(thresholdDistribution(generator));
  151. Scalar badThreshold(aScore + 1);
  152. Proof pi;
  153. Curvepoint shortTermPublicKey = a.get_short_term_public_key(pi);
  154. vector<Proof> goodRepProof = a.generate_reputation_proof(goodThreshold);
  155. flag = b.verify_reputation_proof(
  156. goodRepProof, shortTermPublicKey, goodThreshold);
  157. cout << "TEST VALID REPUTATION PROOF: "
  158. << (flag ? "PASSED (Proof verified)" : "FAILED (Proof not verified)" )
  159. << endl;
  160. vector<Proof> badRepProof = a.generate_reputation_proof(badThreshold);
  161. flag = b.verify_reputation_proof(
  162. badRepProof, shortTermPublicKey, badThreshold);
  163. cout << "TEST INVALID REPUTATION PROOF: "
  164. << (flag ? "FAILED (Proof verified)" : "PASSED (Proof not verified)" )
  165. << endl << endl;
  166. }
  167. // Test if the proof of valid votes is working as expected
  168. void test_vote_proof(
  169. default_random_engine& generator,
  170. const PrsonaClient& user,
  171. PrsonaServerEntity& servers)
  172. {
  173. size_t numUsers = servers.get_num_clients();
  174. vector<Scalar> votes;
  175. vector<bool> replaces;
  176. bool flag;
  177. for (size_t i = 0; i < numUsers; i++)
  178. {
  179. votes.push_back(Scalar(1));
  180. replaces.push_back(true);
  181. }
  182. vector<Proof> validVoteProof;
  183. Proof ownerProof;
  184. Curvepoint shortTermPublicKey =
  185. user.get_short_term_public_key(ownerProof);
  186. vector<CurveBipoint> encryptedVotes =
  187. servers.get_current_votes_by(ownerProof, shortTermPublicKey);
  188. encryptedVotes =
  189. user.make_votes(
  190. validVoteProof, ownerProof, encryptedVotes, votes, replaces);
  191. flag = servers.receive_vote(
  192. validVoteProof, encryptedVotes, shortTermPublicKey);
  193. cout << "TEST REPLACE VOTE PROOF: "
  194. << (flag ? "PASSED (Proof verified)" : "FAILED (Proof not verified)" )
  195. << endl;
  196. for (size_t i = 0; i < numUsers; i++)
  197. {
  198. replaces[i] = false;
  199. }
  200. shortTermPublicKey = user.get_short_term_public_key(ownerProof);
  201. encryptedVotes =
  202. servers.get_current_votes_by(ownerProof, shortTermPublicKey);
  203. encryptedVotes =
  204. user.make_votes(
  205. validVoteProof, ownerProof, encryptedVotes, votes, replaces);
  206. flag = servers.receive_vote(
  207. validVoteProof, encryptedVotes, shortTermPublicKey);
  208. cout << "TEST RERANDOMIZE VOTE PROOF: "
  209. << (flag ? "PASSED (Proof verified)" : "FAILED (Proof not verified)" )
  210. << endl;
  211. for (size_t i = 0; i < numUsers; i++)
  212. {
  213. votes[i] = Scalar(3);
  214. replaces[i] = true;
  215. }
  216. shortTermPublicKey = user.get_short_term_public_key(ownerProof);
  217. encryptedVotes =
  218. servers.get_current_votes_by(ownerProof, shortTermPublicKey);
  219. encryptedVotes =
  220. user.make_votes(
  221. validVoteProof, ownerProof, encryptedVotes, votes, replaces);
  222. flag = servers.receive_vote(
  223. validVoteProof, encryptedVotes, shortTermPublicKey);
  224. cout << "TEST INVALID REPLACE VOTE PROOF: "
  225. << (flag ? "FAILED (Proof verified)" : "PASSED (Proof not verified)" )
  226. << endl << endl;
  227. }
  228. int main(int argc, char *argv[])
  229. {
  230. initialize_prsona_classes();
  231. // Defaults
  232. size_t numServers = 2;
  233. size_t numUsers = 5;
  234. size_t numRounds = 3;
  235. size_t numVotesPerRound = 3;
  236. bool maliciousServers = false;
  237. bool maliciousClients = true;
  238. string seedStr = "seed";
  239. // Potentially accept command line inputs
  240. if (argc > 1)
  241. numServers = atoi(argv[1]);
  242. if (argc > 2)
  243. numUsers = atoi(argv[2]);
  244. if (argc > 3)
  245. numRounds = atoi(argv[3]);
  246. if (argc > 4)
  247. numVotesPerRound = atoi(argv[4]);
  248. cout << "Running the protocol with the following parameters: " << endl;
  249. cout << numServers << " PRSONA servers" << endl;
  250. cout << numUsers << " participants (voters/votees)" << endl;
  251. cout << numRounds << " epochs" << endl;
  252. cout << numVotesPerRound << " new (random) votes by each user per epoch"
  253. << endl << endl;
  254. // Set malicious flags where necessary
  255. if (maliciousServers)
  256. {
  257. PrsonaServer::set_server_malicious();
  258. PrsonaClient::set_server_malicious();
  259. }
  260. if (maliciousClients)
  261. {
  262. PrsonaServer::set_client_malicious();
  263. PrsonaClient::set_client_malicious();
  264. }
  265. // Entities we operate with
  266. PrsonaServerEntity servers(numServers);
  267. BGNPublicKey bgnPublicKey = servers.get_bgn_public_key();
  268. Curvepoint elGamalBlindGenerator = servers.get_blinding_generator();
  269. cout << "Initialization: adding users to system" << endl << endl;
  270. vector<PrsonaClient> users;
  271. for (size_t i = 0; i < numUsers; i++)
  272. {
  273. PrsonaClient currUser(bgnPublicKey, elGamalBlindGenerator, &servers);
  274. servers.add_new_client(currUser);
  275. users.push_back(currUser);
  276. }
  277. // Seeded randomness for random votes used in epoch
  278. seed_seq seed(seedStr.begin(), seedStr.end());
  279. default_random_engine generator(seed);
  280. // Do the epoch operations
  281. for (size_t i = 0; i < numRounds; i++)
  282. {
  283. vector<double> timings;
  284. cout << "Round " << i+1 << " of " << numRounds << ": " << endl;
  285. vector<vector<CurveBipoint>> newEncryptedVotes;
  286. vector<vector<Proof>> validVoteProofs;
  287. timings = make_votes(
  288. generator,
  289. newEncryptedVotes,
  290. validVoteProofs,
  291. users,
  292. servers,
  293. numVotesPerRound);
  294. cout << "Vote generation (with proofs): " << mean(timings)
  295. << " seconds per user" << endl;
  296. timings.clear();
  297. timings = transmit_votes_to_servers(
  298. newEncryptedVotes, validVoteProofs, users, servers);
  299. cout << "Vote validation: " << mean(timings)
  300. << " seconds per vote vector/server" << endl;
  301. timings.clear();
  302. timings.push_back(epoch(servers));
  303. cout << "Epoch computation: " << mean(timings) << " seconds" << endl;
  304. timings.clear();
  305. timings = transmit_epoch_updates(users, servers);
  306. cout << "Transmit epoch updates: " << mean(timings)
  307. << " seconds per user" << endl << endl;
  308. }
  309. // Pick random users for our tests
  310. uniform_int_distribution<size_t> userDistribution(0, numUsers - 1);
  311. size_t user_a = userDistribution(generator);
  312. size_t user_b = user_a;
  313. while (user_b == user_a)
  314. user_b = userDistribution(generator);
  315. test_reputation_proof(generator, users[user_a], users[user_b]);
  316. test_vote_proof(generator, users[user_a], servers);
  317. return 0;
  318. }