client.cpp 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466
  1. #include <iostream>
  2. #include "client.hpp"
  3. #include "serverEntity.hpp"
  4. extern const curvepoint_fp_t bn_curvegen;
  5. const int MAX_ALLOWED_VOTE = 2;
  6. /* These lines need to be here so these static variables are defined,
  7. * but in C++ putting code here doesn't actually execute
  8. * (or at least, with g++, whenever it would execute is not at a useful time)
  9. * so we have an init() function to actually put the correct values in them. */
  10. Curvepoint PrsonaClient::EL_GAMAL_GENERATOR = Curvepoint();
  11. Curvepoint PrsonaClient::EL_GAMAL_BLIND_GENERATOR = Curvepoint();
  12. bool PrsonaClient::SERVER_IS_MALICIOUS = false;
  13. bool PrsonaClient::CLIENT_IS_MALICIOUS = false;
  14. // Quick and dirty function to calculate ceil(log base 2) with mpz_class
  15. mpz_class log2(mpz_class x)
  16. {
  17. mpz_class retval = 0;
  18. while (x > 0)
  19. {
  20. retval++;
  21. x = x >> 1;
  22. }
  23. return retval;
  24. }
  25. /********************
  26. * PUBLIC FUNCTIONS *
  27. ********************/
  28. /*
  29. * CONSTRUCTORS
  30. */
  31. PrsonaClient::PrsonaClient(
  32. const BGNPublicKey& serverPublicKey,
  33. const PrsonaServerEntity* servers)
  34. : serverPublicKey(serverPublicKey),
  35. servers(servers),
  36. max_checked(0)
  37. {
  38. longTermPrivateKey.set_random();
  39. inversePrivateKey = longTermPrivateKey.curveInverse();
  40. decryption_memoizer[EL_GAMAL_BLIND_GENERATOR * max_checked] = max_checked;
  41. }
  42. /*
  43. * SETUP FUNCTIONS
  44. */
  45. // Must be called once before any usage of this class
  46. void PrsonaClient::init(const Curvepoint& elGamalBlindGenerator)
  47. {
  48. EL_GAMAL_GENERATOR = Curvepoint(bn_curvegen);
  49. EL_GAMAL_BLIND_GENERATOR = elGamalBlindGenerator;
  50. }
  51. void PrsonaClient::set_server_malicious()
  52. {
  53. SERVER_IS_MALICIOUS = true;
  54. }
  55. void PrsonaClient::set_client_malicious()
  56. {
  57. CLIENT_IS_MALICIOUS = true;
  58. }
  59. /*
  60. * BASIC PUBLIC SYSTEM INFO GETTERS
  61. */
  62. Curvepoint PrsonaClient::get_short_term_public_key(Proof &pi) const
  63. {
  64. pi = generate_ownership_proof();
  65. return currentFreshGenerator * longTermPrivateKey;
  66. }
  67. /*
  68. * SERVER INTERACTIONS
  69. */
  70. // Generate a new vote vector to give to the servers
  71. // (@replace controls which votes are actually being updated and which are not)
  72. std::vector<CurveBipoint> PrsonaClient::make_votes(
  73. Proof& pi,
  74. const std::vector<Scalar>& vote,
  75. const std::vector<bool>& replace)
  76. {
  77. std::vector<CurveBipoint> retval;
  78. for (size_t i = 0; i < vote.size(); i++)
  79. {
  80. CurveBipoint currScore;
  81. if (replace[i])
  82. serverPublicKey.encrypt(currScore, vote[i]);
  83. else
  84. currScore = serverPublicKey.rerandomize(currentEncryptedVotes[i]);
  85. retval.push_back(currScore);
  86. }
  87. currentEncryptedVotes = retval;
  88. pi = generate_vote_proof(retval, vote);
  89. return retval;
  90. }
  91. // Get a new fresh generator (happens at initialization and during each epoch)
  92. void PrsonaClient::receive_fresh_generator(const Curvepoint& freshGenerator)
  93. {
  94. currentFreshGenerator = freshGenerator;
  95. }
  96. // Receive a new encrypted score from the servers (each epoch)
  97. void PrsonaClient::receive_vote_tally(
  98. const Proof& pi, const EGCiphertext& score)
  99. {
  100. if (!verify_valid_tally_proof(pi, score))
  101. {
  102. std::cerr << "Could not verify proof of valid tally." << std::endl;
  103. return;
  104. }
  105. currentEncryptedScore = score;
  106. decrypt_score(score);
  107. }
  108. // Receive a new encrypted vote vector from the servers (each epoch)
  109. void PrsonaClient::receive_encrypted_votes(
  110. const Proof& pi, const std::vector<CurveBipoint>& votes)
  111. {
  112. if (!verify_valid_votes_proof(pi, votes))
  113. {
  114. std::cerr << "Could not verify proof of valid votes." << std::endl;
  115. return;
  116. }
  117. currentEncryptedVotes = votes;
  118. }
  119. /*
  120. * REPUTATION PROOFS
  121. */
  122. // TO BE UPDATED WITH THING IAN SHOWED ME IN MEETING FOR DISJUNCTION
  123. std::vector<Proof> PrsonaClient::generate_reputation_proof(
  124. const Scalar& threshold) const
  125. {
  126. std::vector<Proof> retval;
  127. if (threshold > currentScore)
  128. return retval;
  129. if (!CLIENT_IS_MALICIOUS)
  130. {
  131. Proof currProof;
  132. currProof.basic = "PROOF";
  133. retval.push_back(currProof);
  134. return retval;
  135. }
  136. retval.push_back(generate_ownership_proof());
  137. mpz_class proofVal = currentScore.curveSub(threshold).toInt();
  138. mpz_class proofBits = log2(currentEncryptedVotes.size() * MAX_ALLOWED_VOTE - threshold.toInt());
  139. std::vector<Scalar> masksPerBit;
  140. masksPerBit.push_back(Scalar());
  141. for (size_t i = 1; i < proofBits; i++)
  142. {
  143. Scalar currMask;
  144. currMask.set_random();
  145. masksPerBit.push_back(currMask);
  146. masksPerBit[0] = masksPerBit[0].curveSub(currMask.curveMult(Scalar(1 << i)));
  147. }
  148. for (size_t i = 0; i < proofBits; i++)
  149. {
  150. Proof currProof;
  151. std::stringstream oracleInput;
  152. oracleInput << currentFreshGenerator << EL_GAMAL_BLIND_GENERATOR;
  153. mpz_class currBit = proofVal & (1 << i);
  154. Curvepoint currentCommitment = currentFreshGenerator * masksPerBit[i] + EL_GAMAL_BLIND_GENERATOR * Scalar(currBit);
  155. currProof.partialUniversals.push_back(currentCommitment);
  156. oracleInput << currentCommitment;
  157. if (currBit)
  158. {
  159. Scalar u_0, c, c_0, c_1, z_0, z_1;
  160. u_0.set_random();
  161. c_1.set_random();
  162. z_1.set_random();
  163. Curvepoint U_0 = currentFreshGenerator * u_0;
  164. Curvepoint U_1 = currentFreshGenerator * z_1 - currentCommitment * c_1 + EL_GAMAL_BLIND_GENERATOR;
  165. currProof.initParts.push_back(U_0);
  166. currProof.initParts.push_back(U_1);
  167. oracleInput << U_0 << U_1;
  168. c = oracle(oracleInput.str());
  169. c_0 = c.curveSub(c_1);
  170. z_0 = c_0.curveMult(masksPerBit[i]).curveAdd(u_0);
  171. currProof.challengeParts.push_back(c_0);
  172. currProof.challengeParts.push_back(c_1);
  173. currProof.responseParts.push_back(z_0);
  174. currProof.responseParts.push_back(z_1);
  175. }
  176. else
  177. {
  178. Scalar u_1, c, c_0, c_1, z_0, z_1;
  179. u_1.set_random();
  180. c_0.set_random();
  181. z_0.set_random();
  182. Curvepoint U_0 = currentFreshGenerator * z_0 - currentCommitment * c_0;
  183. Curvepoint U_1 = currentFreshGenerator * u_1;
  184. currProof.initParts.push_back(U_0);
  185. currProof.initParts.push_back(U_1);
  186. oracleInput << U_0 << U_1;
  187. c = oracle(oracleInput.str());
  188. c_1 = c.curveSub(c_0);
  189. z_1 = c_1.curveMult(masksPerBit[i]).curveAdd(u_1);
  190. currProof.challengeParts.push_back(c_0);
  191. currProof.challengeParts.push_back(c_1);
  192. currProof.responseParts.push_back(z_0);
  193. currProof.responseParts.push_back(z_1);
  194. }
  195. retval.push_back(currProof);
  196. }
  197. return retval;
  198. }
  199. // TO BE UPDATED WITH THING IAN SHOWED ME IN MEETING FOR DISJUNCTION
  200. bool PrsonaClient::verify_reputation_proof(
  201. const std::vector<Proof>& pi,
  202. const Curvepoint& shortTermPublicKey,
  203. const Scalar& threshold) const
  204. {
  205. if (pi.empty())
  206. return false;
  207. if (!CLIENT_IS_MALICIOUS)
  208. return pi[0].basic == "PROOF";
  209. if (!verify_ownership_proof(pi[0], shortTermPublicKey))
  210. return false;
  211. Curvepoint X;
  212. for (size_t i = 1; i < pi.size(); i++)
  213. {
  214. X = X + pi[i].partialUniversals[0] * Scalar(1 << (i - 1));
  215. std::stringstream oracleInput;
  216. oracleInput << currentFreshGenerator << EL_GAMAL_BLIND_GENERATOR << pi[i].partialUniversals[0];
  217. oracleInput << pi[i].initParts[0] << pi[i].initParts[1];
  218. Scalar c = oracle(oracleInput.str());
  219. if (c != pi[i].challengeParts[0] + pi[i].challengeParts[1])
  220. return false;
  221. if (currentFreshGenerator * pi[i].responseParts[0] != pi[i].initParts[0] + pi[i].partialUniversals[0] * pi[i].challengeParts[0])
  222. return false;
  223. if (currentFreshGenerator * pi[i].responseParts[1] != pi[i].initParts[1] + pi[i].partialUniversals[0] * pi[i].challengeParts[1] - EL_GAMAL_BLIND_GENERATOR)
  224. return false;
  225. }
  226. Proof serverProof;
  227. EGCiphertext encryptedScore = servers->get_current_tally(serverProof, shortTermPublicKey);
  228. if (!verify_valid_tally_proof(serverProof, encryptedScore))
  229. return false;
  230. Scalar negThreshold;
  231. negThreshold = Scalar(0).curveSub(threshold);
  232. Curvepoint scoreCommitment = encryptedScore.encryptedMessage + EL_GAMAL_BLIND_GENERATOR * negThreshold;
  233. if (X != scoreCommitment)
  234. return false;
  235. return true;
  236. }
  237. /*********************
  238. * PRIVATE FUNCTIONS *
  239. *********************/
  240. /*
  241. * SCORE DECRYPTION
  242. */
  243. // Basic memoized score decryption
  244. void PrsonaClient::decrypt_score(const EGCiphertext& score)
  245. {
  246. Curvepoint s, hashedDecrypted;
  247. // Remove the mask portion of the ciphertext
  248. s = score.mask * inversePrivateKey;
  249. hashedDecrypted = score.encryptedMessage - s;
  250. // Check if it's a value we've already seen
  251. auto lookup = decryption_memoizer.find(hashedDecrypted);
  252. if (lookup != decryption_memoizer.end())
  253. {
  254. currentScore = lookup->second;
  255. return;
  256. }
  257. // If not, iterate until we find it (adding everything to the memoization)
  258. max_checked++;
  259. Curvepoint decryptionCandidate = EL_GAMAL_BLIND_GENERATOR * max_checked;
  260. while (decryptionCandidate != hashedDecrypted)
  261. {
  262. decryption_memoizer[decryptionCandidate] = max_checked;
  263. decryptionCandidate = decryptionCandidate + EL_GAMAL_BLIND_GENERATOR;
  264. max_checked++;
  265. }
  266. decryption_memoizer[decryptionCandidate] = max_checked;
  267. // Set the value we found
  268. currentScore = max_checked;
  269. }
  270. /*
  271. * OWNERSHIP PROOFS
  272. */
  273. // Very basic Schnorr proof (generation)
  274. Proof PrsonaClient::generate_ownership_proof() const
  275. {
  276. Proof retval;
  277. if (!CLIENT_IS_MALICIOUS)
  278. {
  279. retval.basic = "PROOF";
  280. return retval;
  281. }
  282. std::stringstream oracleInput;
  283. Scalar r;
  284. r.set_random();
  285. Curvepoint shortTermPublicKey = currentFreshGenerator * longTermPrivateKey;
  286. Curvepoint u = currentFreshGenerator * r;
  287. oracleInput << currentFreshGenerator << shortTermPublicKey << u;
  288. Scalar c = oracle(oracleInput.str());
  289. Scalar z = r.curveAdd(c.curveMult(longTermPrivateKey));
  290. retval.basic = "PROOF";
  291. retval.initParts.push_back(u);
  292. retval.responseParts.push_back(z);
  293. return retval;
  294. }
  295. // Very basic Schnorr proof (verification)
  296. bool PrsonaClient::verify_ownership_proof(
  297. const Proof& pi, const Curvepoint& shortTermPublicKey) const
  298. {
  299. if (!CLIENT_IS_MALICIOUS)
  300. return pi.basic == "PROOF";
  301. Curvepoint u = pi.initParts[0];
  302. std::stringstream oracleInput;
  303. oracleInput << currentFreshGenerator << shortTermPublicKey << u;
  304. Scalar c = oracle(oracleInput.str());
  305. Scalar z = pi.responseParts[0];
  306. return (currentFreshGenerator * z) == (shortTermPublicKey * c + u);
  307. }
  308. /*
  309. * PROOF VERIFICATION
  310. */
  311. bool PrsonaClient::verify_score_proof(const Proof& pi) const
  312. {
  313. if (!SERVER_IS_MALICIOUS)
  314. return pi.basic == "PROOF";
  315. return pi.basic == "PROOF";
  316. }
  317. bool PrsonaClient::verify_generator_proof(
  318. const Proof& pi, const Curvepoint& generator) const
  319. {
  320. if (!SERVER_IS_MALICIOUS)
  321. return pi.basic == "PROOF";
  322. return pi.basic == "PROOF";
  323. }
  324. bool PrsonaClient::verify_default_tally_proof(
  325. const Proof& pi, const EGCiphertext& score) const
  326. {
  327. if (!SERVER_IS_MALICIOUS)
  328. return pi.basic == "PROOF";
  329. return pi.basic == "PROOF";
  330. }
  331. bool PrsonaClient::verify_valid_tally_proof(
  332. const Proof& pi, const EGCiphertext& score) const
  333. {
  334. if (!SERVER_IS_MALICIOUS)
  335. return pi.basic == "PROOF";
  336. return pi.basic == "PROOF";
  337. }
  338. bool PrsonaClient::verify_default_votes_proof(
  339. const Proof& pi, const std::vector<CurveBipoint>& votes) const
  340. {
  341. if (!SERVER_IS_MALICIOUS)
  342. return pi.basic == "PROOF";
  343. return pi.basic == "PROOF";
  344. }
  345. bool PrsonaClient::verify_valid_votes_proof(
  346. const Proof& pi, const std::vector<CurveBipoint>& votes) const
  347. {
  348. if (!SERVER_IS_MALICIOUS)
  349. return pi.basic == "PROOF";
  350. return pi.basic == "PROOF";
  351. }
  352. /*
  353. * PROOF GENERATION
  354. */
  355. Proof PrsonaClient::generate_vote_proof(
  356. const std::vector<CurveBipoint>& encryptedVotes,
  357. const std::vector<Scalar>& vote) const
  358. {
  359. Proof retval;
  360. if (!CLIENT_IS_MALICIOUS)
  361. {
  362. retval.basic = "PROOF";
  363. return retval;
  364. }
  365. retval.basic = "PROOF";
  366. return retval;
  367. }