main.cpp 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387
  #include <algorithm>
  #include <cerrno>
  #include <chrono>
  #include <cstdlib>
  #include <iostream>
  #include <limits>
  #include <numeric>
  #include <random>
  #include <string>
  #include <vector>

  #include "BGN.hpp"
  #include "client.hpp"
  #include "server.hpp"
  #include "serverEntity.hpp"

  using namespace std;
  10. // Initialize the classes we use
  11. void initialize_prsona_classes()
  12. {
  13. Scalar::init();
  14. PrsonaBase::init();
  15. }
  // Arithmetic mean of a list of samples (used for averaging timings).
  //
  // xx: the samples to average; taken by const reference so that callers'
  //     timing vectors are not copied.
  // Returns the sum of the samples divided by their count, or 0.0 when xx is
  // empty (instead of dividing by zero and reporting NaN).
  double mean(const std::vector<double>& xx)
  {
      if (xx.empty())
          return 0.0;

      const double total = std::accumulate(xx.begin(), xx.end(), 0.0);
      return total / static_cast<double>(xx.size());
  }
  21. void print_user_scores(const vector<PrsonaClient>& users)
  22. {
  23. std::cout << "<";
  24. for (size_t i = 0; i < users.size(); i++)
  25. {
  26. std::cout << users[i].get_score()
  27. << (i == users.size() - 1 ? ">" : " ");
  28. }
  29. std::cout << std::endl;
  30. }
  31. // Time how long it takes to make a proof of valid votes
  32. vector<double> make_votes(
  33. default_random_engine& generator,
  34. vector<vector<CurveBipoint>>& newEncryptedVotes,
  35. vector<vector<Proof>>& validVoteProofs,
  36. const vector<PrsonaClient>& users,
  37. const PrsonaServerEntity& servers,
  38. size_t numVotes)
  39. {
  40. vector<double> retval;
  41. uniform_int_distribution<int> voteDistribution(
  42. 0, PrsonaBase::get_max_allowed_vote());
  43. size_t numUsers = users.size();
  44. newEncryptedVotes.clear();
  45. for (size_t i = 0; i < numUsers; i++)
  46. {
  47. // Make the correct number of new votes, but shuffle where they go
  48. vector<Scalar> votes;
  49. vector<bool> replaces;
  50. for (size_t j = 0; j < numUsers; j++)
  51. {
  52. votes.push_back(Scalar(voteDistribution(generator)));
  53. replaces.push_back(j < numVotes);
  54. }
  55. shuffle(replaces.begin(), replaces.end(), generator);
  56. Proof ownerProof;
  57. Curvepoint shortTermPublicKey =
  58. users[i].get_short_term_public_key(ownerProof);
  59. vector<CurveBipoint> currEncryptedVotes =
  60. servers.get_current_votes_by(ownerProof, shortTermPublicKey);
  61. vector<Proof> currVoteProof;
  62. chrono::high_resolution_clock::time_point t0 =
  63. chrono::high_resolution_clock::now();
  64. currEncryptedVotes = users[i].make_votes(
  65. currVoteProof,
  66. ownerProof,
  67. currEncryptedVotes,
  68. votes,
  69. replaces);
  70. chrono::high_resolution_clock::time_point t1 =
  71. chrono::high_resolution_clock::now();
  72. newEncryptedVotes.push_back(currEncryptedVotes);
  73. validVoteProofs.push_back(currVoteProof);
  74. chrono::duration<double> time_span =
  75. chrono::duration_cast<chrono::duration<double>>(t1 - t0);
  76. retval.push_back(time_span.count());
  77. }
  78. return retval;
  79. }
  80. // Time how long it takes to validate a proof of valid votes
  81. vector<double> transmit_votes_to_servers(
  82. const vector<vector<CurveBipoint>>& newEncryptedVotes,
  83. const vector<vector<Proof>>& validVoteProofs,
  84. const vector<PrsonaClient>& users,
  85. PrsonaServerEntity& servers)
  86. {
  87. vector<double> retval;
  88. size_t numUsers = users.size();
  89. size_t numServers = servers.get_num_servers();
  90. for (size_t i = 0; i < numUsers; i++)
  91. {
  92. Proof ownerProof;
  93. Curvepoint shortTermPublicKey =
  94. users[i].get_short_term_public_key(ownerProof);
  95. for (size_t j = 0; j < numServers; j++)
  96. {
  97. chrono::high_resolution_clock::time_point t0 =
  98. chrono::high_resolution_clock::now();
  99. servers.receive_vote(
  100. validVoteProofs[i],
  101. newEncryptedVotes[i],
  102. shortTermPublicKey,
  103. j);
  104. chrono::high_resolution_clock::time_point t1 =
  105. chrono::high_resolution_clock::now();
  106. chrono::duration<double> time_span =
  107. chrono::duration_cast<chrono::duration<double>>(t1 - t0);
  108. retval.push_back(time_span.count());
  109. }
  110. }
  111. return retval;
  112. }
  113. // Time how long it takes to do the operations associated with an epoch
  114. double epoch(PrsonaServerEntity& servers)
  115. {
  116. Proof unused;
  117. // Do the epoch server calculations
  118. chrono::high_resolution_clock::time_point t0 =
  119. chrono::high_resolution_clock::now();
  120. servers.epoch(unused);
  121. chrono::high_resolution_clock::time_point t1 =
  122. chrono::high_resolution_clock::now();
  123. // Return the timing of the epoch server calculations
  124. chrono::duration<double> time_span =
  125. chrono::duration_cast<chrono::duration<double>>(t1 - t0);
  126. return time_span.count();
  127. }
  128. // Time how long it takes each user to decrypt their new scores
  129. vector<double> transmit_epoch_updates(
  130. vector<PrsonaClient>& users, const PrsonaServerEntity& servers)
  131. {
  132. vector<double> retval;
  133. size_t numUsers = users.size();
  134. for (size_t i = 0; i < numUsers; i++)
  135. {
  136. chrono::high_resolution_clock::time_point t0 =
  137. chrono::high_resolution_clock::now();
  138. servers.transmit_updates(users[i]);
  139. chrono::high_resolution_clock::time_point t1 =
  140. chrono::high_resolution_clock::now();
  141. chrono::duration<double> time_span =
  142. chrono::duration_cast<chrono::duration<double>>(t1 - t0);
  143. retval.push_back(time_span.count());
  144. }
  145. return retval;
  146. }
  147. // Test if the proof of reputation level is working as expected
  148. void test_reputation_proof(
  149. default_random_engine& generator,
  150. const PrsonaClient& a,
  151. const PrsonaClient& b)
  152. {
  153. bool flag;
  154. mpz_class aScore = a.get_score().toInt();
  155. int i = 0;
  156. while (i < aScore)
  157. i++;
  158. uniform_int_distribution<int> thresholdDistribution(0, i);
  159. Scalar goodThreshold(thresholdDistribution(generator));
  160. Scalar badThreshold(aScore + 1);
  161. Proof pi;
  162. Curvepoint shortTermPublicKey = a.get_short_term_public_key(pi);
  163. vector<Proof> goodRepProof = a.generate_reputation_proof(goodThreshold);
  164. flag = b.verify_reputation_proof(
  165. goodRepProof, shortTermPublicKey, goodThreshold);
  166. cout << "TEST VALID REPUTATION PROOF: "
  167. << (flag ? "PASSED (Proof verified)" : "FAILED (Proof not verified)" )
  168. << endl;
  169. vector<Proof> badRepProof = a.generate_reputation_proof(badThreshold);
  170. flag = b.verify_reputation_proof(
  171. badRepProof, shortTermPublicKey, badThreshold);
  172. cout << "TEST INVALID REPUTATION PROOF: "
  173. << (flag ? "FAILED (Proof verified)" : "PASSED (Proof not verified)" )
  174. << endl << endl;
  175. }
  176. // Test if the proof of valid votes is working as expected
  177. void test_vote_proof(
  178. default_random_engine& generator,
  179. const PrsonaClient& user,
  180. PrsonaServerEntity& servers)
  181. {
  182. size_t numUsers = servers.get_num_clients();
  183. vector<Scalar> votes;
  184. vector<bool> replaces;
  185. bool flag;
  186. for (size_t i = 0; i < numUsers; i++)
  187. {
  188. votes.push_back(Scalar(1));
  189. replaces.push_back(true);
  190. }
  191. vector<Proof> validVoteProof;
  192. Proof ownerProof;
  193. Curvepoint shortTermPublicKey =
  194. user.get_short_term_public_key(ownerProof);
  195. vector<CurveBipoint> encryptedVotes =
  196. servers.get_current_votes_by(ownerProof, shortTermPublicKey);
  197. encryptedVotes =
  198. user.make_votes(
  199. validVoteProof, ownerProof, encryptedVotes, votes, replaces);
  200. flag = servers.receive_vote(
  201. validVoteProof, encryptedVotes, shortTermPublicKey);
  202. cout << "TEST REPLACE VOTE PROOF: "
  203. << (flag ? "PASSED (Proof verified)" : "FAILED (Proof not verified)" )
  204. << endl;
  205. for (size_t i = 0; i < numUsers; i++)
  206. {
  207. replaces[i] = false;
  208. }
  209. shortTermPublicKey = user.get_short_term_public_key(ownerProof);
  210. encryptedVotes =
  211. servers.get_current_votes_by(ownerProof, shortTermPublicKey);
  212. encryptedVotes =
  213. user.make_votes(
  214. validVoteProof, ownerProof, encryptedVotes, votes, replaces);
  215. flag = servers.receive_vote(
  216. validVoteProof, encryptedVotes, shortTermPublicKey);
  217. cout << "TEST RERANDOMIZE VOTE PROOF: "
  218. << (flag ? "PASSED (Proof verified)" : "FAILED (Proof not verified)" )
  219. << endl;
  220. for (size_t i = 0; i < numUsers; i++)
  221. {
  222. votes[i] = Scalar(3);
  223. replaces[i] = true;
  224. }
  225. shortTermPublicKey = user.get_short_term_public_key(ownerProof);
  226. encryptedVotes =
  227. servers.get_current_votes_by(ownerProof, shortTermPublicKey);
  228. encryptedVotes =
  229. user.make_votes(
  230. validVoteProof, ownerProof, encryptedVotes, votes, replaces);
  231. flag = servers.receive_vote(
  232. validVoteProof, encryptedVotes, shortTermPublicKey);
  233. cout << "TEST INVALID REPLACE VOTE PROOF: "
  234. << (flag ? "FAILED (Proof verified)" : "PASSED (Proof not verified)" )
  235. << endl << endl;
  236. }
  237. int main(int argc, char *argv[])
  238. {
  239. initialize_prsona_classes();
  240. // Defaults
  241. size_t numServers = 2;
  242. size_t numUsers = 5;
  243. size_t numRounds = 3;
  244. size_t numVotesPerRound = 3;
  245. bool maliciousServers = true;
  246. bool maliciousClients = true;
  247. string seedStr = "seed";
  248. // Potentially accept command line inputs
  249. if (argc > 1)
  250. numServers = atoi(argv[1]);
  251. if (argc > 2)
  252. numUsers = atoi(argv[2]);
  253. if (argc > 3)
  254. numRounds = atoi(argv[3]);
  255. if (argc > 4)
  256. numVotesPerRound = atoi(argv[4]);
  257. cout << "Running the protocol with the following parameters: " << endl;
  258. cout << numServers << " PRSONA servers" << endl;
  259. cout << numUsers << " participants (voters/votees)" << endl;
  260. cout << numRounds << " epochs" << endl;
  261. cout << numVotesPerRound << " new (random) votes by each user per epoch"
  262. << endl << endl;
  263. // Set malicious flags where necessary
  264. if (maliciousServers)
  265. PrsonaBase::set_server_malicious();
  266. if (maliciousClients)
  267. PrsonaBase::set_client_malicious();
  268. // Entities we operate with
  269. PrsonaServerEntity servers(numServers);
  270. vector<Proof> elGamalBlindGeneratorProof;
  271. BGNPublicKey bgnPublicKey = servers.get_bgn_public_key();
  272. Curvepoint elGamalBlindGenerator =
  273. servers.get_blinding_generator(elGamalBlindGeneratorProof);
  274. cout << "Initialization: adding users to system" << endl << endl;
  275. vector<PrsonaClient> users;
  276. for (size_t i = 0; i < numUsers; i++)
  277. {
  278. PrsonaClient currUser(
  279. elGamalBlindGeneratorProof,
  280. elGamalBlindGenerator,
  281. bgnPublicKey,
  282. &servers);
  283. servers.add_new_client(currUser);
  284. users.push_back(currUser);
  285. }
  286. // Seeded randomness for random votes used in epoch
  287. seed_seq seed(seedStr.begin(), seedStr.end());
  288. default_random_engine generator(seed);
  289. // Do the epoch operations
  290. for (size_t i = 0; i < numRounds; i++)
  291. {
  292. vector<double> timings;
  293. cout << "Round " << i+1 << " of " << numRounds << ": " << endl;
  294. vector<vector<CurveBipoint>> newEncryptedVotes;
  295. vector<vector<Proof>> validVoteProofs;
  296. timings = make_votes(
  297. generator,
  298. newEncryptedVotes,
  299. validVoteProofs,
  300. users,
  301. servers,
  302. numVotesPerRound);
  303. cout << "Vote generation (with proofs): " << mean(timings)
  304. << " seconds per user" << endl;
  305. timings.clear();
  306. timings = transmit_votes_to_servers(
  307. newEncryptedVotes, validVoteProofs, users, servers);
  308. cout << "Vote validation: " << mean(timings)
  309. << " seconds per vote vector/server" << endl;
  310. timings.clear();
  311. timings.push_back(epoch(servers));
  312. cout << "Epoch computation: " << mean(timings) << " seconds" << endl;
  313. timings.clear();
  314. timings = transmit_epoch_updates(users, servers);
  315. cout << "Transmit epoch updates: " << mean(timings)
  316. << " seconds per user" << endl << endl;
  317. }
  318. // Pick random users for our tests
  319. uniform_int_distribution<size_t> userDistribution(0, numUsers - 1);
  320. size_t user_a = userDistribution(generator);
  321. size_t user_b = user_a;
  322. while (user_b == user_a)
  323. user_b = userDistribution(generator);
  324. test_reputation_proof(generator, users[user_a], users[user_b]);
  325. test_vote_proof(generator, users[user_a], servers);
  326. return 0;
  327. }