main.cpp 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386
  #include <algorithm>
  #include <cctype>
  #include <cerrno>
  #include <chrono>
  #include <cstdlib>
  #include <iostream>
  #include <limits>
  #include <numeric>
  #include <random>
  #include <string>
  #include <utility>
  #include <vector>

  #include "BGN.hpp"
  #include "client.hpp"
  #include "server.hpp"
  #include "serverEntity.hpp"

  using namespace std;
  10. // Initialize the classes we use
  11. void initialize_prsona_classes()
  12. {
  13. Scalar::init();
  14. PrsonaBase::init();
  15. }
  // Arithmetic mean of a set of samples (used to average the per-user and
  // per-server timings below).
  //
  // xx - the samples; taken by const reference so the caller's vector is not
  //      copied on every call.
  // Returns the mean, or 0.0 for an empty set. The bare division would
  // produce NaN (0.0 / 0), which prints as "nan seconds".
  double mean(const std::vector<double>& xx)
  {
      if (xx.empty())
          return 0.0;
      return std::accumulate(xx.begin(), xx.end(), 0.0)
          / static_cast<double>(xx.size());
  }
  21. void print_user_scores(const vector<PrsonaClient>& users)
  22. {
  23. std::cout << "<";
  24. for (size_t i = 0; i < users.size(); i++)
  25. {
  26. std::cout << users[i].get_score()
  27. << (i == users.size() - 1 ? ">" : " ");
  28. }
  29. std::cout << std::endl;
  30. }
  31. // Time how long it takes to make a proof of valid votes
  32. vector<double> make_votes(
  33. default_random_engine& generator,
  34. vector<vector<CurveBipoint>>& newEncryptedVotes,
  35. vector<vector<Proof>>& validVoteProofs,
  36. const vector<PrsonaClient>& users,
  37. const PrsonaServerEntity& servers,
  38. size_t numVotes)
  39. {
  40. vector<double> retval;
  41. uniform_int_distribution<int> voteDistribution(
  42. 0, PrsonaBase::get_max_allowed_vote());
  43. size_t numUsers = users.size();
  44. newEncryptedVotes.clear();
  45. for (size_t i = 0; i < numUsers; i++)
  46. {
  47. // Make the correct number of new votes, but shuffle where they go
  48. vector<Scalar> votes;
  49. vector<bool> replaces;
  50. for (size_t j = 0; j < numUsers; j++)
  51. {
  52. votes.push_back(Scalar(voteDistribution(generator)));
  53. replaces.push_back(j < numVotes);
  54. }
  55. shuffle(replaces.begin(), replaces.end(), generator);
  56. Proof ownerProof;
  57. Curvepoint shortTermPublicKey =
  58. users[i].get_short_term_public_key(ownerProof);
  59. vector<CurveBipoint> currEncryptedVotes =
  60. servers.get_current_votes_by(ownerProof, shortTermPublicKey);
  61. vector<Proof> currVoteProof;
  62. chrono::high_resolution_clock::time_point t0 =
  63. chrono::high_resolution_clock::now();
  64. currEncryptedVotes = users[i].make_votes(
  65. currVoteProof,
  66. ownerProof,
  67. currEncryptedVotes,
  68. votes,
  69. replaces);
  70. chrono::high_resolution_clock::time_point t1 =
  71. chrono::high_resolution_clock::now();
  72. newEncryptedVotes.push_back(currEncryptedVotes);
  73. validVoteProofs.push_back(currVoteProof);
  74. chrono::duration<double> time_span =
  75. chrono::duration_cast<chrono::duration<double>>(t1 - t0);
  76. retval.push_back(time_span.count());
  77. }
  78. return retval;
  79. }
  80. // Time how long it takes to validate a proof of valid votes
  81. vector<double> transmit_votes_to_servers(
  82. const vector<vector<CurveBipoint>>& newEncryptedVotes,
  83. const vector<vector<Proof>>& validVoteProofs,
  84. const vector<PrsonaClient>& users,
  85. PrsonaServerEntity& servers)
  86. {
  87. vector<double> retval;
  88. size_t numUsers = users.size();
  89. size_t numServers = servers.get_num_servers();
  90. for (size_t i = 0; i < numUsers; i++)
  91. {
  92. Proof ownerProof;
  93. Curvepoint shortTermPublicKey =
  94. users[i].get_short_term_public_key(ownerProof);
  95. for (size_t j = 0; j < numServers; j++)
  96. {
  97. chrono::high_resolution_clock::time_point t0 =
  98. chrono::high_resolution_clock::now();
  99. servers.receive_vote(
  100. validVoteProofs[i],
  101. newEncryptedVotes[i],
  102. shortTermPublicKey,
  103. j);
  104. chrono::high_resolution_clock::time_point t1 =
  105. chrono::high_resolution_clock::now();
  106. chrono::duration<double> time_span =
  107. chrono::duration_cast<chrono::duration<double>>(t1 - t0);
  108. retval.push_back(time_span.count());
  109. }
  110. }
  111. return retval;
  112. }
  113. // Time how long it takes to do the operations associated with an epoch
  114. double epoch(PrsonaServerEntity& servers)
  115. {
  116. // Do the epoch server calculations
  117. chrono::high_resolution_clock::time_point t0 =
  118. chrono::high_resolution_clock::now();
  119. servers.epoch();
  120. chrono::high_resolution_clock::time_point t1 =
  121. chrono::high_resolution_clock::now();
  122. // Return the timing of the epoch server calculations
  123. chrono::duration<double> time_span =
  124. chrono::duration_cast<chrono::duration<double>>(t1 - t0);
  125. return time_span.count();
  126. }
  127. // Time how long it takes each user to decrypt their new scores
  128. vector<double> transmit_epoch_updates(
  129. vector<PrsonaClient>& users, const PrsonaServerEntity& servers)
  130. {
  131. vector<double> retval;
  132. size_t numUsers = users.size();
  133. for (size_t i = 0; i < numUsers; i++)
  134. {
  135. chrono::high_resolution_clock::time_point t0 =
  136. chrono::high_resolution_clock::now();
  137. servers.transmit_updates(users[i]);
  138. chrono::high_resolution_clock::time_point t1 =
  139. chrono::high_resolution_clock::now();
  140. chrono::duration<double> time_span =
  141. chrono::duration_cast<chrono::duration<double>>(t1 - t0);
  142. retval.push_back(time_span.count());
  143. }
  144. return retval;
  145. }
  146. // Test if the proof of reputation level is working as expected
  147. void test_reputation_proof(
  148. default_random_engine& generator,
  149. const PrsonaClient& a,
  150. const PrsonaClient& b)
  151. {
  152. bool flag;
  153. mpz_class aScore = a.get_score().toInt();
  154. int i = 0;
  155. while (i < aScore)
  156. i++;
  157. uniform_int_distribution<int> thresholdDistribution(0, i);
  158. Scalar goodThreshold(thresholdDistribution(generator));
  159. Scalar badThreshold(aScore + 1);
  160. Proof pi;
  161. Curvepoint shortTermPublicKey = a.get_short_term_public_key(pi);
  162. vector<Proof> goodRepProof = a.generate_reputation_proof(goodThreshold);
  163. flag = b.verify_reputation_proof(
  164. goodRepProof, shortTermPublicKey, goodThreshold);
  165. cout << "TEST VALID REPUTATION PROOF: "
  166. << (flag ? "PASSED (Proof verified)" : "FAILED (Proof not verified)" )
  167. << endl;
  168. vector<Proof> badRepProof = a.generate_reputation_proof(badThreshold);
  169. flag = b.verify_reputation_proof(
  170. badRepProof, shortTermPublicKey, badThreshold);
  171. cout << "TEST INVALID REPUTATION PROOF: "
  172. << (flag ? "FAILED (Proof verified)" : "PASSED (Proof not verified)" )
  173. << endl << endl;
  174. }
  175. // Test if the proof of valid votes is working as expected
  176. void test_vote_proof(
  177. default_random_engine& generator,
  178. const PrsonaClient& user,
  179. PrsonaServerEntity& servers)
  180. {
  181. size_t numUsers = servers.get_num_clients();
  182. vector<Scalar> votes;
  183. vector<bool> replaces;
  184. bool flag;
  185. for (size_t i = 0; i < numUsers; i++)
  186. {
  187. votes.push_back(Scalar(1));
  188. replaces.push_back(true);
  189. }
  190. vector<Proof> validVoteProof;
  191. Proof ownerProof;
  192. Curvepoint shortTermPublicKey =
  193. user.get_short_term_public_key(ownerProof);
  194. vector<CurveBipoint> encryptedVotes =
  195. servers.get_current_votes_by(ownerProof, shortTermPublicKey);
  196. encryptedVotes =
  197. user.make_votes(
  198. validVoteProof, ownerProof, encryptedVotes, votes, replaces);
  199. flag = servers.receive_vote(
  200. validVoteProof, encryptedVotes, shortTermPublicKey);
  201. cout << "TEST REPLACE VOTE PROOF: "
  202. << (flag ? "PASSED (Proof verified)" : "FAILED (Proof not verified)" )
  203. << endl;
  204. for (size_t i = 0; i < numUsers; i++)
  205. {
  206. replaces[i] = false;
  207. }
  208. shortTermPublicKey = user.get_short_term_public_key(ownerProof);
  209. encryptedVotes =
  210. servers.get_current_votes_by(ownerProof, shortTermPublicKey);
  211. encryptedVotes =
  212. user.make_votes(
  213. validVoteProof, ownerProof, encryptedVotes, votes, replaces);
  214. flag = servers.receive_vote(
  215. validVoteProof, encryptedVotes, shortTermPublicKey);
  216. cout << "TEST RERANDOMIZE VOTE PROOF: "
  217. << (flag ? "PASSED (Proof verified)" : "FAILED (Proof not verified)" )
  218. << endl;
  219. for (size_t i = 0; i < numUsers; i++)
  220. {
  221. votes[i] = Scalar(3);
  222. replaces[i] = true;
  223. }
  224. shortTermPublicKey = user.get_short_term_public_key(ownerProof);
  225. encryptedVotes =
  226. servers.get_current_votes_by(ownerProof, shortTermPublicKey);
  227. encryptedVotes =
  228. user.make_votes(
  229. validVoteProof, ownerProof, encryptedVotes, votes, replaces);
  230. flag = servers.receive_vote(
  231. validVoteProof, encryptedVotes, shortTermPublicKey);
  232. cout << "TEST INVALID REPLACE VOTE PROOF: "
  233. << (flag ? "FAILED (Proof verified)" : "PASSED (Proof not verified)" )
  234. << endl << endl;
  235. }
  236. int main(int argc, char *argv[])
  237. {
  238. initialize_prsona_classes();
  239. // Defaults
  240. size_t numServers = 2;
  241. size_t numUsers = 3;
  242. size_t numRounds = 1;
  243. size_t numVotesPerRound = 1;
  244. bool maliciousServers = true;
  245. bool maliciousClients = true;
  246. string seedStr = "seed";
  247. // Potentially accept command line inputs
  248. if (argc > 1)
  249. numServers = atoi(argv[1]);
  250. if (argc > 2)
  251. numUsers = atoi(argv[2]);
  252. if (argc > 3)
  253. numRounds = atoi(argv[3]);
  254. if (argc > 4)
  255. numVotesPerRound = atoi(argv[4]);
  256. cout << "Running the protocol with the following parameters: " << endl;
  257. cout << numServers << " PRSONA servers" << endl;
  258. cout << numUsers << " participants (voters/votees)" << endl;
  259. cout << numRounds << " epochs" << endl;
  260. cout << numVotesPerRound << " new (random) votes by each user per epoch"
  261. << endl << endl;
  262. // Set malicious flags where necessary
  263. if (maliciousServers)
  264. PrsonaBase::set_server_malicious();
  265. if (maliciousClients)
  266. PrsonaBase::set_client_malicious();
  267. // Entities we operate with
  268. PrsonaServerEntity servers(numServers);
  269. vector<Proof> elGamalBlindGeneratorProof;
  270. BGNPublicKey bgnPublicKey = servers.get_bgn_public_key();
  271. Curvepoint elGamalBlindGenerator =
  272. servers.get_blinding_generator(elGamalBlindGeneratorProof);
  273. cout << "Initialization: adding users to system" << endl << endl;
  274. vector<PrsonaClient> users;
  275. for (size_t i = 0; i < numUsers; i++)
  276. {
  277. PrsonaClient currUser(
  278. elGamalBlindGeneratorProof,
  279. elGamalBlindGenerator,
  280. bgnPublicKey,
  281. &servers);
  282. users.push_back(currUser);
  283. servers.add_new_client(users[i]);
  284. }
  285. // Seeded randomness for random votes used in epoch
  286. seed_seq seed(seedStr.begin(), seedStr.end());
  287. default_random_engine generator(seed);
  288. // Do the epoch operations
  289. for (size_t i = 0; i < numRounds; i++)
  290. {
  291. vector<double> timings;
  292. cout << "Round " << i+1 << " of " << numRounds << ": " << endl;
  293. vector<vector<CurveBipoint>> newEncryptedVotes;
  294. vector<vector<Proof>> validVoteProofs;
  295. timings = make_votes(
  296. generator,
  297. newEncryptedVotes,
  298. validVoteProofs,
  299. users,
  300. servers,
  301. numVotesPerRound);
  302. cout << "Vote generation (with proofs): " << mean(timings)
  303. << " seconds per user" << endl;
  304. timings.clear();
  305. timings = transmit_votes_to_servers(
  306. newEncryptedVotes, validVoteProofs, users, servers);
  307. cout << "Vote validation: " << mean(timings)
  308. << " seconds per vote vector/server" << endl;
  309. timings.clear();
  310. timings.push_back(epoch(servers));
  311. cout << "Epoch computation: " << mean(timings) << " seconds" << endl;
  312. timings.clear();
  313. timings = transmit_epoch_updates(users, servers);
  314. cout << "Transmit epoch updates: " << mean(timings)
  315. << " seconds per user" << endl << endl;
  316. }
  317. // // Pick random users for our tests
  318. // uniform_int_distribution<size_t> userDistribution(0, numUsers - 1);
  319. // size_t user_a = userDistribution(generator);
  320. // size_t user_b = user_a;
  321. // while (user_b == user_a)
  322. // user_b = userDistribution(generator);
  323. // test_reputation_proof(generator, users[user_a], users[user_b]);
  324. // test_vote_proof(generator, users[user_a], servers);
  325. return 0;
  326. }