// main.cpp (12 KB)

  #include <algorithm>
  #include <cerrno>
  #include <chrono>
  #include <cstdlib>
  #include <iostream>
  #include <limits>
  #include <numeric>
  #include <random>
  #include <string>
  #include <vector>

  #include "BGN.hpp"
  #include "client.hpp"
  #include "server.hpp"
  #include "serverEntity.hpp"

  using namespace std;
  10. // Initialize the classes we use
  11. void initialize_prsona_classes()
  12. {
  13. Scalar::init();
  14. PrsonaBase::init();
  15. }
  // Arithmetic mean of the samples in xx (used for averaging timings).
  //
  // The samples are taken by const reference so that calling this does not
  // copy the whole vector. An empty vector yields 0.0 instead of the NaN that
  // dividing by a zero sample count would produce.
  double mean(const std::vector<double>& xx)
  {
      if (xx.empty())
          return 0.0;

      const double total = std::accumulate(xx.begin(), xx.end(), 0.0);
      return total / static_cast<double>(xx.size());
  }
  21. // Time how long it takes to make a proof of valid votes
  22. vector<double> make_votes(
  23. default_random_engine& generator,
  24. vector<vector<CurveBipoint>>& newEncryptedVotes,
  25. vector<vector<Proof>>& validVoteProofs,
  26. const vector<PrsonaClient>& users,
  27. const PrsonaServerEntity& servers,
  28. size_t numVotes)
  29. {
  30. vector<double> retval;
  31. uniform_int_distribution<int> voteDistribution(
  32. 0, PrsonaBase::get_max_allowed_vote());
  33. size_t numUsers = users.size();
  34. newEncryptedVotes.clear();
  35. for (size_t i = 0; i < numUsers; i++)
  36. {
  37. // Make the correct number of new votes, but shuffle where they go
  38. vector<Scalar> votes;
  39. vector<bool> replaces;
  40. for (size_t j = 0; j < numUsers; j++)
  41. {
  42. votes.push_back(Scalar(voteDistribution(generator)));
  43. replaces.push_back(j < numVotes);
  44. }
  45. shuffle(replaces.begin(), replaces.end(), generator);
  46. Proof ownerProof;
  47. Curvepoint shortTermPublicKey =
  48. users[i].get_short_term_public_key(ownerProof);
  49. vector<CurveBipoint> currEncryptedVotes =
  50. servers.get_current_votes_by(ownerProof, shortTermPublicKey);
  51. vector<Proof> currVoteProof;
  52. chrono::high_resolution_clock::time_point t0 =
  53. chrono::high_resolution_clock::now();
  54. currEncryptedVotes = users[i].make_votes(
  55. currVoteProof,
  56. ownerProof,
  57. currEncryptedVotes,
  58. votes,
  59. replaces);
  60. chrono::high_resolution_clock::time_point t1 =
  61. chrono::high_resolution_clock::now();
  62. newEncryptedVotes.push_back(currEncryptedVotes);
  63. validVoteProofs.push_back(currVoteProof);
  64. chrono::duration<double> time_span =
  65. chrono::duration_cast<chrono::duration<double>>(t1 - t0);
  66. retval.push_back(time_span.count());
  67. }
  68. return retval;
  69. }
  70. // Time how long it takes to validate a proof of valid votes
  71. vector<double> transmit_votes_to_servers(
  72. const vector<vector<CurveBipoint>>& newEncryptedVotes,
  73. const vector<vector<Proof>>& validVoteProofs,
  74. const vector<PrsonaClient>& users,
  75. PrsonaServerEntity& servers)
  76. {
  77. vector<double> retval;
  78. size_t numUsers = users.size();
  79. size_t numServers = servers.get_num_servers();
  80. for (size_t i = 0; i < numUsers; i++)
  81. {
  82. Proof ownerProof;
  83. Curvepoint shortTermPublicKey =
  84. users[i].get_short_term_public_key(ownerProof);
  85. for (size_t j = 0; j < numServers; j++)
  86. {
  87. chrono::high_resolution_clock::time_point t0 =
  88. chrono::high_resolution_clock::now();
  89. servers.receive_vote(
  90. validVoteProofs[i],
  91. newEncryptedVotes[i],
  92. shortTermPublicKey,
  93. j);
  94. chrono::high_resolution_clock::time_point t1 =
  95. chrono::high_resolution_clock::now();
  96. chrono::duration<double> time_span =
  97. chrono::duration_cast<chrono::duration<double>>(t1 - t0);
  98. retval.push_back(time_span.count());
  99. }
  100. }
  101. return retval;
  102. }
  103. // Time how long it takes to do the operations associated with an epoch
  104. double epoch(PrsonaServerEntity& servers)
  105. {
  106. Proof unused;
  107. // Do the epoch server calculations
  108. chrono::high_resolution_clock::time_point t0 =
  109. chrono::high_resolution_clock::now();
  110. servers.epoch(unused);
  111. chrono::high_resolution_clock::time_point t1 =
  112. chrono::high_resolution_clock::now();
  113. // Return the timing of the epoch server calculations
  114. chrono::duration<double> time_span =
  115. chrono::duration_cast<chrono::duration<double>>(t1 - t0);
  116. return time_span.count();
  117. }
  118. // Time how long it takes each user to decrypt their new scores
  119. vector<double> transmit_epoch_updates(
  120. vector<PrsonaClient>& users, const PrsonaServerEntity& servers)
  121. {
  122. vector<double> retval;
  123. size_t numUsers = users.size();
  124. for (size_t i = 0; i < numUsers; i++)
  125. {
  126. chrono::high_resolution_clock::time_point t0 =
  127. chrono::high_resolution_clock::now();
  128. servers.transmit_updates(users[i]);
  129. chrono::high_resolution_clock::time_point t1 =
  130. chrono::high_resolution_clock::now();
  131. chrono::duration<double> time_span =
  132. chrono::duration_cast<chrono::duration<double>>(t1 - t0);
  133. retval.push_back(time_span.count());
  134. }
  135. return retval;
  136. }
  137. // Test if the proof of reputation level is working as expected
  138. void test_reputation_proof(
  139. default_random_engine& generator,
  140. const PrsonaClient& a,
  141. const PrsonaClient& b)
  142. {
  143. bool flag;
  144. mpz_class aScore = a.get_score().toInt();
  145. int i = 0;
  146. while (i < aScore)
  147. i++;
  148. uniform_int_distribution<int> thresholdDistribution(0, i);
  149. Scalar goodThreshold(thresholdDistribution(generator));
  150. Scalar badThreshold(aScore + 1);
  151. Proof pi;
  152. Curvepoint shortTermPublicKey = a.get_short_term_public_key(pi);
  153. vector<Proof> goodRepProof = a.generate_reputation_proof(goodThreshold);
  154. flag = b.verify_reputation_proof(
  155. goodRepProof, shortTermPublicKey, goodThreshold);
  156. cout << "TEST VALID REPUTATION PROOF: "
  157. << (flag ? "PASSED (Proof verified)" : "FAILED (Proof not verified)" )
  158. << endl;
  159. vector<Proof> badRepProof = a.generate_reputation_proof(badThreshold);
  160. flag = b.verify_reputation_proof(
  161. badRepProof, shortTermPublicKey, badThreshold);
  162. cout << "TEST INVALID REPUTATION PROOF: "
  163. << (flag ? "FAILED (Proof verified)" : "PASSED (Proof not verified)" )
  164. << endl << endl;
  165. }
  166. // Test if the proof of valid votes is working as expected
  167. void test_vote_proof(
  168. default_random_engine& generator,
  169. const PrsonaClient& user,
  170. PrsonaServerEntity& servers)
  171. {
  172. size_t numUsers = servers.get_num_clients();
  173. vector<Scalar> votes;
  174. vector<bool> replaces;
  175. bool flag;
  176. for (size_t i = 0; i < numUsers; i++)
  177. {
  178. votes.push_back(Scalar(1));
  179. replaces.push_back(true);
  180. }
  181. vector<Proof> validVoteProof;
  182. Proof ownerProof;
  183. Curvepoint shortTermPublicKey =
  184. user.get_short_term_public_key(ownerProof);
  185. vector<CurveBipoint> encryptedVotes =
  186. servers.get_current_votes_by(ownerProof, shortTermPublicKey);
  187. encryptedVotes =
  188. user.make_votes(
  189. validVoteProof, ownerProof, encryptedVotes, votes, replaces);
  190. flag = servers.receive_vote(
  191. validVoteProof, encryptedVotes, shortTermPublicKey);
  192. cout << "TEST REPLACE VOTE PROOF: "
  193. << (flag ? "PASSED (Proof verified)" : "FAILED (Proof not verified)" )
  194. << endl;
  195. for (size_t i = 0; i < numUsers; i++)
  196. {
  197. replaces[i] = false;
  198. }
  199. shortTermPublicKey = user.get_short_term_public_key(ownerProof);
  200. encryptedVotes =
  201. servers.get_current_votes_by(ownerProof, shortTermPublicKey);
  202. encryptedVotes =
  203. user.make_votes(
  204. validVoteProof, ownerProof, encryptedVotes, votes, replaces);
  205. flag = servers.receive_vote(
  206. validVoteProof, encryptedVotes, shortTermPublicKey);
  207. cout << "TEST RERANDOMIZE VOTE PROOF: "
  208. << (flag ? "PASSED (Proof verified)" : "FAILED (Proof not verified)" )
  209. << endl;
  210. for (size_t i = 0; i < numUsers; i++)
  211. {
  212. votes[i] = Scalar(3);
  213. replaces[i] = true;
  214. }
  215. shortTermPublicKey = user.get_short_term_public_key(ownerProof);
  216. encryptedVotes =
  217. servers.get_current_votes_by(ownerProof, shortTermPublicKey);
  218. encryptedVotes =
  219. user.make_votes(
  220. validVoteProof, ownerProof, encryptedVotes, votes, replaces);
  221. flag = servers.receive_vote(
  222. validVoteProof, encryptedVotes, shortTermPublicKey);
  223. cout << "TEST INVALID REPLACE VOTE PROOF: "
  224. << (flag ? "FAILED (Proof verified)" : "PASSED (Proof not verified)" )
  225. << endl << endl;
  226. }
  227. int main(int argc, char *argv[])
  228. {
  229. initialize_prsona_classes();
  230. // Defaults
  231. size_t numServers = 2;
  232. size_t numUsers = 5;
  233. size_t numRounds = 3;
  234. size_t numVotesPerRound = 3;
  235. bool maliciousServers = true;
  236. bool maliciousClients = true;
  237. string seedStr = "seed";
  238. // Potentially accept command line inputs
  239. if (argc > 1)
  240. numServers = atoi(argv[1]);
  241. if (argc > 2)
  242. numUsers = atoi(argv[2]);
  243. if (argc > 3)
  244. numRounds = atoi(argv[3]);
  245. if (argc > 4)
  246. numVotesPerRound = atoi(argv[4]);
  247. cout << "Running the protocol with the following parameters: " << endl;
  248. cout << numServers << " PRSONA servers" << endl;
  249. cout << numUsers << " participants (voters/votees)" << endl;
  250. cout << numRounds << " epochs" << endl;
  251. cout << numVotesPerRound << " new (random) votes by each user per epoch"
  252. << endl << endl;
  253. // Set malicious flags where necessary
  254. if (maliciousServers)
  255. PrsonaBase::set_server_malicious();
  256. if (maliciousClients)
  257. PrsonaBase::set_client_malicious();
  258. // Entities we operate with
  259. PrsonaServerEntity servers(numServers);
  260. vector<Proof> elGamalBlindGeneratorProof;
  261. BGNPublicKey bgnPublicKey = servers.get_bgn_public_key();
  262. Curvepoint elGamalBlindGenerator =
  263. servers.get_blinding_generator(elGamalBlindGeneratorProof);
  264. cout << "Initialization: adding users to system" << endl << endl;
  265. vector<PrsonaClient> users;
  266. for (size_t i = 0; i < numUsers; i++)
  267. {
  268. PrsonaClient currUser(
  269. bgnPublicKey,
  270. elGamalBlindGeneratorProof,
  271. elGamalBlindGenerator,
  272. &servers);
  273. servers.add_new_client(currUser);
  274. users.push_back(currUser);
  275. }
  276. // Seeded randomness for random votes used in epoch
  277. seed_seq seed(seedStr.begin(), seedStr.end());
  278. default_random_engine generator(seed);
  279. // Do the epoch operations
  280. for (size_t i = 0; i < numRounds; i++)
  281. {
  282. vector<double> timings;
  283. cout << "Round " << i+1 << " of " << numRounds << ": " << endl;
  284. vector<vector<CurveBipoint>> newEncryptedVotes;
  285. vector<vector<Proof>> validVoteProofs;
  286. timings = make_votes(
  287. generator,
  288. newEncryptedVotes,
  289. validVoteProofs,
  290. users,
  291. servers,
  292. numVotesPerRound);
  293. cout << "Vote generation (with proofs): " << mean(timings)
  294. << " seconds per user" << endl;
  295. timings.clear();
  296. timings = transmit_votes_to_servers(
  297. newEncryptedVotes, validVoteProofs, users, servers);
  298. cout << "Vote validation: " << mean(timings)
  299. << " seconds per vote vector/server" << endl;
  300. timings.clear();
  301. timings.push_back(epoch(servers));
  302. cout << "Epoch computation: " << mean(timings) << " seconds" << endl;
  303. timings.clear();
  304. timings = transmit_epoch_updates(users, servers);
  305. cout << "Transmit epoch updates: " << mean(timings)
  306. << " seconds per user" << endl << endl;
  307. }
  308. // Pick random users for our tests
  309. uniform_int_distribution<size_t> userDistribution(0, numUsers - 1);
  310. size_t user_a = userDistribution(generator);
  311. size_t user_b = user_a;
  312. while (user_b == user_a)
  313. user_b = userDistribution(generator);
  314. test_reputation_proof(generator, users[user_a], users[user_b]);
  315. test_vote_proof(generator, users[user_a], servers);
  316. return 0;
  317. }