main.cpp 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381
  // #include <boost/program_options.hpp>
  // namespace po = boost::program_options;
  #include <algorithm>
  #include <cerrno>
  #include <chrono>
  #include <cstdlib>
  #include <iostream>
  #include <numeric>
  #include <random>
  #include <string>
  #include <utility>
  #include <vector>
  #include "BGN.hpp"
  #include "client.hpp"
  #include "server.hpp"
  #include "serverEntity.hpp"
  using namespace std;
  // Largest vote value a user may cast; make_votes() draws each random vote
  // uniformly from the inclusive range [0, MAX_ALLOWED_VOTE].
  constexpr int MAX_ALLOWED_VOTE = 2;
  13. // Initialize the classes we use
  14. void initialize_prsona_classes()
  15. {
  16. Scalar::init();
  17. PrsonaServer::init();
  18. PrsonaClient::init();
  19. }
  // Arithmetic mean of a set of samples (used to average benchmark timings).
  //
  // @param xx  samples to average; taken by const reference so the caller's
  //            vector is not copied.
  // @return    sum(xx) / xx.size(), or 0.0 when xx is empty (avoids a 0/0
  //            division that would otherwise produce NaN).
  double mean(const std::vector<double>& xx)
  {
      if (xx.empty())
          return 0.0;
      return std::accumulate(xx.begin(), xx.end(), 0.0)
          / static_cast<double>(xx.size());
  }
  25. // Time how long it takes to make a proof of valid votes
  26. vector<double> make_votes(
  27. default_random_engine& generator,
  28. vector<vector<CurveBipoint>>& newEncryptedVotes,
  29. vector<vector<Proof>>& validVoteProofs,
  30. const vector<PrsonaClient>& users,
  31. const PrsonaServerEntity& servers,
  32. size_t numVotes)
  33. {
  34. vector<double> retval;
  35. uniform_int_distribution<int> voteDistribution(0, MAX_ALLOWED_VOTE);
  36. size_t numUsers = users.size();
  37. newEncryptedVotes.clear();
  38. for (size_t i = 0; i < numUsers; i++)
  39. {
  40. // Make the correct number of new votes, but shuffle where they go
  41. vector<Scalar> votes;
  42. vector<bool> replaces;
  43. for (size_t j = 0; j < numUsers; j++)
  44. {
  45. votes.push_back(Scalar(voteDistribution(generator)));
  46. replaces.push_back(j < numVotes);
  47. }
  48. shuffle(replaces.begin(), replaces.end(), generator);
  49. Proof ownerProof;
  50. Curvepoint shortTermPublicKey =
  51. users[i].get_short_term_public_key(ownerProof);
  52. vector<CurveBipoint> currEncryptedVotes =
  53. servers.get_current_votes_by(ownerProof, shortTermPublicKey);
  54. vector<Proof> currVoteProof;
  55. chrono::high_resolution_clock::time_point t0 =
  56. chrono::high_resolution_clock::now();
  57. currEncryptedVotes = users[i].make_votes(
  58. currVoteProof,
  59. ownerProof,
  60. currEncryptedVotes,
  61. votes,
  62. replaces);
  63. chrono::high_resolution_clock::time_point t1 =
  64. chrono::high_resolution_clock::now();
  65. newEncryptedVotes.push_back(currEncryptedVotes);
  66. validVoteProofs.push_back(currVoteProof);
  67. chrono::duration<double> time_span =
  68. chrono::duration_cast<chrono::duration<double>>(t1 - t0);
  69. retval.push_back(time_span.count());
  70. }
  71. return retval;
  72. }
  73. // Time how long it takes to validate a proof of valid votes
  74. vector<double> transmit_votes_to_servers(
  75. const vector<vector<CurveBipoint>>& newEncryptedVotes,
  76. const vector<vector<Proof>>& validVoteProofs,
  77. const vector<PrsonaClient>& users,
  78. PrsonaServerEntity& servers)
  79. {
  80. vector<double> retval;
  81. size_t numUsers = users.size();
  82. size_t numServers = servers.get_num_servers();
  83. for (size_t i = 0; i < numUsers; i++)
  84. {
  85. Proof ownerProof;
  86. Curvepoint shortTermPublicKey =
  87. users[i].get_short_term_public_key(ownerProof);
  88. for (size_t j = 0; j < numServers; j++)
  89. {
  90. chrono::high_resolution_clock::time_point t0 =
  91. chrono::high_resolution_clock::now();
  92. servers.receive_vote(
  93. validVoteProofs[i],
  94. newEncryptedVotes[i],
  95. shortTermPublicKey,
  96. j);
  97. chrono::high_resolution_clock::time_point t1 =
  98. chrono::high_resolution_clock::now();
  99. chrono::duration<double> time_span =
  100. chrono::duration_cast<chrono::duration<double>>(t1 - t0);
  101. retval.push_back(time_span.count());
  102. }
  103. }
  104. return retval;
  105. }
  106. // Time how long it takes to do the operations associated with an epoch
  107. double epoch(PrsonaServerEntity& servers)
  108. {
  109. Proof unused;
  110. // Do the epoch server calculations
  111. chrono::high_resolution_clock::time_point t0 =
  112. chrono::high_resolution_clock::now();
  113. servers.epoch(unused);
  114. chrono::high_resolution_clock::time_point t1 =
  115. chrono::high_resolution_clock::now();
  116. // Return the timing of the epoch server calculations
  117. chrono::duration<double> time_span =
  118. chrono::duration_cast<chrono::duration<double>>(t1 - t0);
  119. return time_span.count();
  120. }
  121. // Time how long it takes each user to decrypt their new scores
  122. vector<double> transmit_epoch_updates(
  123. vector<PrsonaClient>& users, const PrsonaServerEntity& servers)
  124. {
  125. vector<double> retval;
  126. size_t numUsers = users.size();
  127. for (size_t i = 0; i < numUsers; i++)
  128. {
  129. chrono::high_resolution_clock::time_point t0 =
  130. chrono::high_resolution_clock::now();
  131. servers.transmit_updates(users[i]);
  132. chrono::high_resolution_clock::time_point t1 =
  133. chrono::high_resolution_clock::now();
  134. chrono::duration<double> time_span =
  135. chrono::duration_cast<chrono::duration<double>>(t1 - t0);
  136. retval.push_back(time_span.count());
  137. }
  138. return retval;
  139. }
  140. // Test if the proof of reputation level is working as expected
  141. void test_reputation_proof(
  142. default_random_engine& generator,
  143. const PrsonaClient& a,
  144. const PrsonaClient& b)
  145. {
  146. bool flag;
  147. mpz_class aScore = a.get_score().toInt();
  148. int i = 0;
  149. while (i < aScore)
  150. i++;
  151. uniform_int_distribution<int> thresholdDistribution(0, i);
  152. Scalar goodThreshold(thresholdDistribution(generator));
  153. Scalar badThreshold(aScore + 1);
  154. Proof pi;
  155. Curvepoint shortTermPublicKey = a.get_short_term_public_key(pi);
  156. vector<Proof> goodRepProof = a.generate_reputation_proof(goodThreshold);
  157. flag = b.verify_reputation_proof(
  158. goodRepProof, shortTermPublicKey, goodThreshold);
  159. cout << "TEST VALID REPUTATION PROOF: "
  160. << (flag ? "PASSED (Proof verified)" : "FAILED (Proof not verified)" )
  161. << endl;
  162. vector<Proof> badRepProof = a.generate_reputation_proof(badThreshold);
  163. flag = b.verify_reputation_proof(
  164. badRepProof, shortTermPublicKey, badThreshold);
  165. cout << "TEST INVALID REPUTATION PROOF: "
  166. << (flag ? "FAILED (Proof verified)" : "PASSED (Proof not verified)" )
  167. << endl << endl;
  168. }
  169. // Test if the proof of valid votes is working as expected
  170. void test_vote_proof(
  171. default_random_engine& generator,
  172. const PrsonaClient& user,
  173. PrsonaServerEntity& servers)
  174. {
  175. size_t numUsers = servers.get_num_clients();
  176. vector<Scalar> votes;
  177. vector<bool> replaces;
  178. bool flag;
  179. for (size_t i = 0; i < numUsers; i++)
  180. {
  181. votes.push_back(Scalar(1));
  182. replaces.push_back(true);
  183. }
  184. vector<Proof> validVoteProof;
  185. Proof ownerProof;
  186. Curvepoint shortTermPublicKey =
  187. user.get_short_term_public_key(ownerProof);
  188. vector<CurveBipoint> encryptedVotes =
  189. servers.get_current_votes_by(ownerProof, shortTermPublicKey);
  190. encryptedVotes =
  191. user.make_votes(
  192. validVoteProof, ownerProof, encryptedVotes, votes, replaces);
  193. flag = servers.receive_vote(
  194. validVoteProof, encryptedVotes, shortTermPublicKey);
  195. cout << "TEST REPLACE VOTE PROOF: "
  196. << (flag ? "PASSED (Proof verified)" : "FAILED (Proof not verified)" )
  197. << endl;
  198. for (size_t i = 0; i < numUsers; i++)
  199. {
  200. replaces[i] = false;
  201. }
  202. shortTermPublicKey = user.get_short_term_public_key(ownerProof);
  203. encryptedVotes =
  204. servers.get_current_votes_by(ownerProof, shortTermPublicKey);
  205. encryptedVotes =
  206. user.make_votes(
  207. validVoteProof, ownerProof, encryptedVotes, votes, replaces);
  208. flag = servers.receive_vote(
  209. validVoteProof, encryptedVotes, shortTermPublicKey);
  210. cout << "TEST RERANDOMIZE VOTE PROOF: "
  211. << (flag ? "PASSED (Proof verified)" : "FAILED (Proof not verified)" )
  212. << endl;
  213. for (size_t i = 0; i < numUsers; i++)
  214. {
  215. votes[i] = Scalar(3);
  216. replaces[i] = true;
  217. }
  218. shortTermPublicKey = user.get_short_term_public_key(ownerProof);
  219. encryptedVotes =
  220. servers.get_current_votes_by(ownerProof, shortTermPublicKey);
  221. encryptedVotes =
  222. user.make_votes(
  223. validVoteProof, ownerProof, encryptedVotes, votes, replaces);
  224. flag = servers.receive_vote(
  225. validVoteProof, encryptedVotes, shortTermPublicKey);
  226. cout << "TEST INVALID REPLACE VOTE PROOF: "
  227. << (flag ? "FAILED (Proof verified)" : "PASSED (Proof not verified)" )
  228. << endl << endl;
  229. }
  230. int main(int argc, char *argv[])
  231. {
  232. initialize_prsona_classes();
  233. // Defaults
  234. size_t numServers = 2;
  235. size_t numUsers = 5;
  236. size_t numRounds = 3;
  237. size_t numVotesPerRound = 3;
  238. bool maliciousServers = false;
  239. bool maliciousClients = true;
  240. string seedStr = "seed";
  241. // Potentially accept command line inputs
  242. if (argc > 1)
  243. numServers = atoi(argv[1]);
  244. if (argc > 2)
  245. numUsers = atoi(argv[2]);
  246. if (argc > 3)
  247. numRounds = atoi(argv[3]);
  248. if (argc > 4)
  249. numVotesPerRound = atoi(argv[4]);
  250. cout << "Running the protocol with the following parameters: " << endl;
  251. cout << numServers << " PRSONA servers" << endl;
  252. cout << numUsers << " participants (voters/votees)" << endl;
  253. cout << numRounds << " epochs" << endl;
  254. cout << numVotesPerRound << " new (random) votes by each user per epoch"
  255. << endl << endl;
  256. // Set malicious flags where necessary
  257. if (maliciousServers)
  258. {
  259. PrsonaServer::set_server_malicious();
  260. PrsonaClient::set_server_malicious();
  261. }
  262. if (maliciousClients)
  263. {
  264. PrsonaServer::set_client_malicious();
  265. PrsonaClient::set_client_malicious();
  266. }
  267. // Entities we operate with
  268. PrsonaServerEntity servers(numServers);
  269. BGNPublicKey bgnPublicKey = servers.get_bgn_public_key();
  270. Curvepoint elGamalBlindGenerator = servers.get_blinding_generator();
  271. cout << "Initialization: adding users to system" << endl << endl;
  272. vector<PrsonaClient> users;
  273. for (size_t i = 0; i < numUsers; i++)
  274. {
  275. PrsonaClient currUser(bgnPublicKey, elGamalBlindGenerator, &servers);
  276. servers.add_new_client(currUser);
  277. users.push_back(currUser);
  278. }
  279. // Seeded randomness for random votes used in epoch
  280. seed_seq seed(seedStr.begin(), seedStr.end());
  281. default_random_engine generator(seed);
  282. // Do the epoch operations
  283. for (size_t i = 0; i < numRounds; i++)
  284. {
  285. vector<double> timings;
  286. cout << "Round " << i+1 << " of " << numRounds << ": " << endl;
  287. vector<vector<CurveBipoint>> newEncryptedVotes;
  288. vector<vector<Proof>> validVoteProofs;
  289. timings = make_votes(
  290. generator,
  291. newEncryptedVotes,
  292. validVoteProofs,
  293. users,
  294. servers,
  295. numVotesPerRound);
  296. cout << "Vote generation (with proofs): " << mean(timings)
  297. << " seconds per user" << endl;
  298. timings.clear();
  299. timings = transmit_votes_to_servers(
  300. newEncryptedVotes, validVoteProofs, users, servers);
  301. cout << "Vote validation: " << mean(timings)
  302. << " seconds per vote vector/server" << endl;
  303. timings.clear();
  304. timings.push_back(epoch(servers));
  305. cout << "Epoch computation: " << mean(timings) << " seconds" << endl;
  306. timings.clear();
  307. timings = transmit_epoch_updates(users, servers);
  308. cout << "Transmit epoch updates: " << mean(timings)
  309. << " seconds per user" << endl << endl;
  310. }
  311. // Pick random users for our tests
  312. uniform_int_distribution<size_t> userDistribution(0, numUsers - 1);
  313. size_t user_a = userDistribution(generator);
  314. size_t user_b = user_a;
  315. while (user_b == user_a)
  316. user_b = userDistribution(generator);
  317. test_reputation_proof(generator, users[user_a], users[user_b]);
  318. test_vote_proof(generator, users[user_a], servers);
  319. return 0;
  320. }