// client.cpp (13 KB)
  1. #include <iostream>
  2. #include "client.hpp"
  3. #include "serverEntity.hpp"
  4. extern const curvepoint_fp_t bn_curvegen;
  5. const int MAX_ALLOWED_VOTE = 2;
  6. /* These lines need to be here so these static variables are defined,
  7. * but in C++ putting code here doesn't actually execute
  8. * (or at least, with g++, whenever it would execute is not at a useful time)
  9. * so we have an init() function to actually put the correct values in them. */
  10. Curvepoint PrsonaClient::EL_GAMAL_GENERATOR = Curvepoint();
  11. Curvepoint PrsonaClient::EL_GAMAL_BLIND_GENERATOR = Curvepoint();
  12. bool PrsonaClient::SERVER_IS_MALICIOUS = false;
  13. bool PrsonaClient::CLIENT_IS_MALICIOUS = false;
  14. // Quick and dirty function to calculate ceil(log base 2) with mpz_class
  15. mpz_class log2(mpz_class x)
  16. {
  17. mpz_class retval = 0;
  18. while (x > 0)
  19. {
  20. retval++;
  21. x = x >> 1;
  22. }
  23. return retval;
  24. }
  25. mpz_class bit(mpz_class x)
  26. {
  27. return x > 0 ? 1 : 0;
  28. }
  29. /********************
  30. * PUBLIC FUNCTIONS *
  31. ********************/
  32. /*
  33. * CONSTRUCTORS
  34. */
  35. PrsonaClient::PrsonaClient(
  36. const BGNPublicKey& serverPublicKey,
  37. const PrsonaServerEntity* servers)
  38. : serverPublicKey(serverPublicKey),
  39. servers(servers),
  40. max_checked(0)
  41. {
  42. longTermPrivateKey.set_random();
  43. inversePrivateKey = longTermPrivateKey.curveInverse();
  44. decryption_memoizer[EL_GAMAL_BLIND_GENERATOR * max_checked] = max_checked;
  45. }
  46. /*
  47. * SETUP FUNCTIONS
  48. */
  49. // Must be called once before any usage of this class
  50. void PrsonaClient::init(const Curvepoint& elGamalBlindGenerator)
  51. {
  52. EL_GAMAL_GENERATOR = Curvepoint(bn_curvegen);
  53. EL_GAMAL_BLIND_GENERATOR = elGamalBlindGenerator;
  54. }
  55. void PrsonaClient::set_server_malicious()
  56. {
  57. SERVER_IS_MALICIOUS = true;
  58. }
  59. void PrsonaClient::set_client_malicious()
  60. {
  61. CLIENT_IS_MALICIOUS = true;
  62. }
  63. /*
  64. * BASIC PUBLIC SYSTEM INFO GETTERS
  65. */
  66. Curvepoint PrsonaClient::get_short_term_public_key(Proof &pi) const
  67. {
  68. pi = generate_ownership_proof();
  69. return currentFreshGenerator * longTermPrivateKey;
  70. }
  71. /*
  72. * SERVER INTERACTIONS
  73. */
  74. /* Generate a new vote vector to give to the servers
  75. * @replaces controls which votes are actually being updated and which are not
  76. *
  77. * You may really want to make currentEncryptedVotes a member variable,
  78. * but it doesn't behave correctly when adding new clients after this one. */
  79. std::vector<CurveBipoint> PrsonaClient::make_votes(
  80. Proof& pi,
  81. const std::vector<CurveBipoint>& currentEncryptedVotes,
  82. const std::vector<Scalar>& votes,
  83. const std::vector<bool>& replaces)
  84. {
  85. std::vector<CurveBipoint> retval;
  86. if (!verify_valid_votes_proof(pi, currentEncryptedVotes))
  87. {
  88. std::cerr << "Could not verify proof of valid votes." << std::endl;
  89. return retval;
  90. }
  91. for (size_t i = 0; i < votes.size(); i++)
  92. {
  93. CurveBipoint currScore;
  94. if (replaces[i])
  95. serverPublicKey.encrypt(currScore, votes[i]);
  96. else
  97. currScore = serverPublicKey.rerandomize(currentEncryptedVotes[i]);
  98. retval.push_back(currScore);
  99. }
  100. pi = generate_vote_proof(retval, votes);
  101. return retval;
  102. }
  103. // Get a new fresh generator (happens at initialization and during each epoch)
  104. void PrsonaClient::receive_fresh_generator(const Curvepoint& freshGenerator)
  105. {
  106. currentFreshGenerator = freshGenerator;
  107. }
  108. // Receive a new encrypted score from the servers (each epoch)
  109. void PrsonaClient::receive_vote_tally(
  110. const Proof& pi, const EGCiphertext& score)
  111. {
  112. if (!verify_valid_tally_proof(pi, score))
  113. {
  114. std::cerr << "Could not verify proof of valid tally." << std::endl;
  115. return;
  116. }
  117. currentEncryptedScore = score;
  118. decrypt_score(score);
  119. }
  120. /*
  121. * REPUTATION PROOFS
  122. */
  123. // A pretty straightforward range proof (generation)
  124. std::vector<Proof> PrsonaClient::generate_reputation_proof(
  125. const Scalar& threshold) const
  126. {
  127. std::vector<Proof> retval;
  128. // Don't even try if the user asks to make an illegitimate proof
  129. if (threshold > currentScore)
  130. return retval;
  131. // Base case
  132. if (!CLIENT_IS_MALICIOUS)
  133. {
  134. Proof currProof;
  135. currProof.basic = "PROOF";
  136. retval.push_back(currProof);
  137. return retval;
  138. }
  139. // We really have two consecutive proofs in a junction.
  140. // The first is to prove that we are the stpk we claim we are
  141. retval.push_back(generate_ownership_proof());
  142. // The value we're actually using in our proof
  143. mpz_class proofVal = currentScore.curveSub(threshold).toInt();
  144. // Top of the range in our proof determined by what scores are even possible
  145. mpz_class proofBits =
  146. log2(
  147. servers->get_num_clients() * MAX_ALLOWED_VOTE -
  148. threshold.toInt());
  149. // Don't risk a situation that would divulge our private key
  150. if (proofBits <= 1)
  151. proofBits = 2;
  152. // This seems weird, but remember our base is A_t^r, not g^t
  153. std::vector<Scalar> masksPerBit;
  154. masksPerBit.push_back(inversePrivateKey);
  155. for (size_t i = 1; i < proofBits; i++)
  156. {
  157. Scalar currMask;
  158. currMask.set_random();
  159. masksPerBit.push_back(currMask);
  160. masksPerBit[0] =
  161. masksPerBit[0].curveSub(currMask.curveMult(Scalar(1 << i)));
  162. }
  163. // Taken from Fig. 1 in https://eprint.iacr.org/2014/764.pdf
  164. for (size_t i = 0; i < proofBits; i++)
  165. {
  166. Proof currProof;
  167. Curvepoint g, h, c, c_a, c_b;
  168. g = currentEncryptedScore.mask;
  169. h = EL_GAMAL_BLIND_GENERATOR;
  170. mpz_class currBit = bit(proofVal & (1 << i));
  171. Scalar a, s, t, m, r;
  172. a.set_random();
  173. s.set_random();
  174. t.set_random();
  175. m = Scalar(currBit);
  176. r = masksPerBit[i];
  177. c = g * r + h * m;
  178. currProof.partialUniversals.push_back(c);
  179. c_a = g * s + h * a;
  180. Scalar am = a.curveMult(m);
  181. c_b = g * t + h * am;
  182. std::stringstream oracleInput;
  183. oracleInput << g << h << c << c_a << c_b;
  184. Scalar x = oracle(oracleInput.str());
  185. currProof.challengeParts.push_back(x);
  186. Scalar f, z_a, z_b;
  187. Scalar mx = m.curveMult(x);
  188. f = mx.curveAdd(a);
  189. Scalar rx = r.curveMult(x);
  190. z_a = rx.curveAdd(s);
  191. Scalar x_f = x.curveSub(f);
  192. Scalar r_x_f = r.curveMult(x_f);
  193. z_b = r_x_f.curveAdd(t);
  194. currProof.responseParts.push_back(f);
  195. currProof.responseParts.push_back(z_a);
  196. currProof.responseParts.push_back(z_b);
  197. retval.push_back(currProof);
  198. }
  199. return retval;
  200. }
  201. // A pretty straightforward range proof (verification)
  202. bool PrsonaClient::verify_reputation_proof(
  203. const std::vector<Proof>& pi,
  204. const Curvepoint& shortTermPublicKey,
  205. const Scalar& threshold) const
  206. {
  207. // Reject outright if there's no proof to check
  208. if (pi.empty())
  209. {
  210. std::cerr << "Proof was empty, aborting." << std::endl;
  211. return false;
  212. }
  213. // Base case
  214. if (!CLIENT_IS_MALICIOUS)
  215. return pi[0].basic == "PROOF";
  216. // User should be able to prove they are who they say they are
  217. if (!verify_ownership_proof(pi[0], shortTermPublicKey))
  218. {
  219. std::cerr << "Schnorr proof failed, aborting." << std::endl;
  220. return false;
  221. }
  222. // Get the encrypted score in question from the servers
  223. Proof serverProof;
  224. EGCiphertext encryptedScore =
  225. servers->get_current_tally(serverProof, shortTermPublicKey);
  226. // Rough for the prover but if the server messes up,
  227. // no way to prove the thing anyways
  228. if (!verify_valid_tally_proof(serverProof, encryptedScore))
  229. {
  230. std::cerr << "Server error prevented proof from working, aborting." << std::endl;
  231. return false;
  232. }
  233. // X is the thing we're going to be checking in on throughout
  234. // to try to get our score commitment back in the end.
  235. Curvepoint X;
  236. for (size_t i = 1; i < pi.size(); i++)
  237. {
  238. Curvepoint c, g, h;
  239. c = pi[i].partialUniversals[0];
  240. g = encryptedScore.mask;
  241. h = EL_GAMAL_BLIND_GENERATOR;
  242. X = X + c * Scalar(1 << (i - 1));
  243. Scalar x, f, z_a, z_b;
  244. x = pi[i].challengeParts[0];
  245. f = pi[i].responseParts[0];
  246. z_a = pi[i].responseParts[1];
  247. z_b = pi[i].responseParts[2];
  248. // Taken from Fig. 1 in https://eprint.iacr.org/2014/764.pdf
  249. Curvepoint c_a, c_b;
  250. c_a = g * z_a + h * f - c * x;
  251. Scalar x_f = x.curveSub(f);
  252. c_b = g * z_b - c * x_f;
  253. std::stringstream oracleInput;
  254. oracleInput << g << h << c << c_a << c_b;
  255. if (oracle(oracleInput.str()) != pi[i].challengeParts[0])
  256. {
  257. std::cerr << "0 or 1 proof failed at index " << i << " of " << pi.size() - 1 << ", aborting." << std::endl;
  258. return false;
  259. }
  260. }
  261. Scalar negThreshold;
  262. negThreshold = Scalar(0).curveSub(threshold);
  263. Curvepoint scoreCommitment =
  264. encryptedScore.encryptedMessage +
  265. EL_GAMAL_BLIND_GENERATOR * negThreshold;
  266. return X == scoreCommitment;
  267. }
  268. Scalar PrsonaClient::get_score() const
  269. {
  270. return currentScore;
  271. }
  272. /*********************
  273. * PRIVATE FUNCTIONS *
  274. *********************/
  275. /*
  276. * SCORE DECRYPTION
  277. */
  278. // Basic memoized score decryption
  279. void PrsonaClient::decrypt_score(const EGCiphertext& score)
  280. {
  281. Curvepoint s, hashedDecrypted;
  282. // Remove the mask portion of the ciphertext
  283. s = score.mask * inversePrivateKey;
  284. hashedDecrypted = score.encryptedMessage - s;
  285. // Check if it's a value we've already seen
  286. auto lookup = decryption_memoizer.find(hashedDecrypted);
  287. if (lookup != decryption_memoizer.end())
  288. {
  289. currentScore = lookup->second;
  290. return;
  291. }
  292. // If not, iterate until we find it (adding everything to the memoization)
  293. max_checked++;
  294. Curvepoint decryptionCandidate = EL_GAMAL_BLIND_GENERATOR * max_checked;
  295. while (decryptionCandidate != hashedDecrypted)
  296. {
  297. decryption_memoizer[decryptionCandidate] = max_checked;
  298. decryptionCandidate = decryptionCandidate + EL_GAMAL_BLIND_GENERATOR;
  299. max_checked++;
  300. }
  301. decryption_memoizer[decryptionCandidate] = max_checked;
  302. // Set the value we found
  303. currentScore = max_checked;
  304. }
  305. /*
  306. * OWNERSHIP PROOFS
  307. */
  308. // Very basic Schnorr proof (generation)
  309. Proof PrsonaClient::generate_ownership_proof() const
  310. {
  311. Proof retval;
  312. if (!CLIENT_IS_MALICIOUS)
  313. {
  314. retval.basic = "PROOF";
  315. return retval;
  316. }
  317. std::stringstream oracleInput;
  318. Scalar r;
  319. r.set_random();
  320. Curvepoint shortTermPublicKey = currentFreshGenerator * longTermPrivateKey;
  321. Curvepoint u = currentFreshGenerator * r;
  322. oracleInput << currentFreshGenerator << shortTermPublicKey << u;
  323. Scalar c = oracle(oracleInput.str());
  324. Scalar z = r.curveAdd(c.curveMult(longTermPrivateKey));
  325. retval.basic = "PROOF";
  326. retval.challengeParts.push_back(c);
  327. retval.responseParts.push_back(z);
  328. return retval;
  329. }
  330. // Very basic Schnorr proof (verification)
  331. bool PrsonaClient::verify_ownership_proof(
  332. const Proof& pi, const Curvepoint& shortTermPublicKey) const
  333. {
  334. if (!CLIENT_IS_MALICIOUS)
  335. return pi.basic == "PROOF";
  336. Scalar c = pi.challengeParts[0];
  337. Scalar z = pi.responseParts[0];
  338. Curvepoint u = currentFreshGenerator * z - shortTermPublicKey * c;
  339. std::stringstream oracleInput;
  340. oracleInput << currentFreshGenerator << shortTermPublicKey << u;
  341. return c == oracle(oracleInput.str());
  342. }
  343. /*
  344. * PROOF VERIFICATION
  345. */
  346. bool PrsonaClient::verify_score_proof(const Proof& pi) const
  347. {
  348. if (!SERVER_IS_MALICIOUS)
  349. return pi.basic == "PROOF";
  350. return pi.basic == "PROOF";
  351. }
  352. bool PrsonaClient::verify_generator_proof(
  353. const Proof& pi, const Curvepoint& generator) const
  354. {
  355. if (!SERVER_IS_MALICIOUS)
  356. return pi.basic == "PROOF";
  357. return pi.basic == "PROOF";
  358. }
  359. bool PrsonaClient::verify_default_tally_proof(
  360. const Proof& pi, const EGCiphertext& score) const
  361. {
  362. if (!SERVER_IS_MALICIOUS)
  363. return pi.basic == "PROOF";
  364. return pi.basic == "PROOF";
  365. }
  366. bool PrsonaClient::verify_valid_tally_proof(
  367. const Proof& pi, const EGCiphertext& score) const
  368. {
  369. if (!SERVER_IS_MALICIOUS)
  370. return pi.basic == "PROOF";
  371. return pi.basic == "PROOF";
  372. }
  373. bool PrsonaClient::verify_default_votes_proof(
  374. const Proof& pi, const std::vector<CurveBipoint>& votes) const
  375. {
  376. if (!SERVER_IS_MALICIOUS)
  377. return pi.basic == "PROOF";
  378. return pi.basic == "PROOF";
  379. }
  380. bool PrsonaClient::verify_valid_votes_proof(
  381. const Proof& pi, const std::vector<CurveBipoint>& votes) const
  382. {
  383. if (!SERVER_IS_MALICIOUS)
  384. return pi.basic == "PROOF";
  385. return pi.basic == "PROOF";
  386. }
  387. /*
  388. * PROOF GENERATION
  389. */
  390. Proof PrsonaClient::generate_vote_proof(
  391. const std::vector<CurveBipoint>& encryptedVotes,
  392. const std::vector<Scalar>& vote) const
  393. {
  394. Proof retval;
  395. if (!CLIENT_IS_MALICIOUS)
  396. {
  397. retval.basic = "PROOF";
  398. return retval;
  399. }
  400. retval.basic = "PROOF";
  401. return retval;
  402. }