123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658 |
- #include <iostream>
- #include "client.hpp"
- #include "serverEntity.hpp"
- extern const curvepoint_fp_t bn_curvegen;
- const int MAX_ALLOWED_VOTE = 2;
- /* These lines need to be here so these static variables are defined,
- * but in C++ putting code here doesn't actually execute
- * (or at least, with g++, whenever it would execute is not at a useful time)
- * so we have an init() function to actually put the correct values in them. */
- Curvepoint PrsonaClient::EL_GAMAL_GENERATOR = Curvepoint();
- bool PrsonaClient::SERVER_IS_MALICIOUS = false;
- bool PrsonaClient::CLIENT_IS_MALICIOUS = false;
- // Quick and dirty function to calculate ceil(log base 2) with mpz_class
- mpz_class log2(mpz_class x)
- {
- mpz_class retval = 0;
- while (x > 0)
- {
- retval++;
- x = x >> 1;
- }
- return retval;
- }
- mpz_class bit(mpz_class x)
- {
- return x > 0 ? 1 : 0;
- }
- /********************
- * PUBLIC FUNCTIONS *
- ********************/
- /*
- * CONSTRUCTORS
- */
- PrsonaClient::PrsonaClient(
- const BGNPublicKey& serverPublicKey,
- const Curvepoint& elGamalBlindGenerator,
- const PrsonaServerEntity* servers)
- : serverPublicKey(serverPublicKey),
- elGamalBlindGenerator(elGamalBlindGenerator),
- servers(servers),
- max_checked(0)
- {
- longTermPrivateKey.set_random();
- inversePrivateKey = longTermPrivateKey.curveInverse();
- decryption_memoizer[elGamalBlindGenerator * max_checked] = max_checked;
- }
- /*
- * SETUP FUNCTIONS
- */
- // Must be called once before any usage of this class
- void PrsonaClient::init()
- {
- EL_GAMAL_GENERATOR = Curvepoint(bn_curvegen);
- }
- void PrsonaClient::set_server_malicious()
- {
- SERVER_IS_MALICIOUS = true;
- }
- void PrsonaClient::set_client_malicious()
- {
- CLIENT_IS_MALICIOUS = true;
- }
- /*
- * BASIC PUBLIC SYSTEM INFO GETTERS
- */
- Curvepoint PrsonaClient::get_short_term_public_key(Proof &pi) const
- {
- pi = generate_ownership_proof();
- return currentFreshGenerator * longTermPrivateKey;
- }
- /*
- * SERVER INTERACTIONS
- */
- /* Generate a new vote vector to give to the servers
- * @replaces controls which votes are actually being updated and which are not
- *
- * You may really want to make currentEncryptedVotes a member variable,
- * but it doesn't behave correctly when adding new clients after this one. */
- std::vector<CurveBipoint> PrsonaClient::make_votes(
- std::vector<Proof>& validVoteProof,
- const Proof& serverProof,
- const std::vector<CurveBipoint>& oldEncryptedVotes,
- const std::vector<Scalar>& votes,
- const std::vector<bool>& replaces) const
- {
- std::vector<Scalar> seeds(oldEncryptedVotes.size());
- std::vector<CurveBipoint> newEncryptedVotes(oldEncryptedVotes.size());
- if (!verify_valid_votes_proof(serverProof, oldEncryptedVotes))
- {
- std::cerr << "Could not verify proof of valid votes." << std::endl;
- return newEncryptedVotes;
- }
- for (size_t i = 0; i < votes.size(); i++)
- {
- if (replaces[i])
- {
- newEncryptedVotes[i] = serverPublicKey.encrypt(seeds[i], votes[i]);
- }
- else
- {
- newEncryptedVotes[i] =
- serverPublicKey.rerandomize(seeds[i], oldEncryptedVotes[i]);
- }
- }
- validVoteProof = generate_vote_proof(
- replaces, oldEncryptedVotes, newEncryptedVotes, seeds, votes);
- return newEncryptedVotes;
- }
- // Get a new fresh generator (happens at initialization and during each epoch)
- void PrsonaClient::receive_fresh_generator(const Curvepoint& freshGenerator)
- {
- currentFreshGenerator = freshGenerator;
- }
- // Receive a new encrypted score from the servers (each epoch)
- void PrsonaClient::receive_vote_tally(
- const Proof& pi, const EGCiphertext& score)
- {
- if (!verify_valid_tally_proof(pi, score))
- {
- std::cerr << "Could not verify proof of valid tally." << std::endl;
- return;
- }
- currentEncryptedScore = score;
- decrypt_score(score);
- }
- /*
- * REPUTATION PROOFS
- */
- // A pretty straightforward range proof (generation)
- std::vector<Proof> PrsonaClient::generate_reputation_proof(
- const Scalar& threshold) const
- {
- std::vector<Proof> retval;
- // Don't even try if the user asks to make an illegitimate proof
- if (threshold.toInt() > (servers->get_num_clients() * MAX_ALLOWED_VOTE))
- return retval;
- // Base case
- if (!CLIENT_IS_MALICIOUS)
- {
- Proof currProof;
- currProof.basic = "PROOF";
- retval.push_back(currProof);
- return retval;
- }
- // We really have two consecutive proofs in a junction.
- // The first is to prove that we are the stpk we claim we are
- retval.push_back(generate_ownership_proof());
- // The value we're actually using in our proof
- mpz_class proofVal = currentScore.curveSub(threshold).toInt();
- // Top of the range in our proof determined by what scores are even possible
- mpz_class proofBits =
- log2(
- servers->get_num_clients() * MAX_ALLOWED_VOTE -
- threshold.toInt());
-
- // Don't risk a situation that would divulge our private key
- if (proofBits <= 1)
- proofBits = 2;
- // This seems weird, but remember our base is A_t^r, not g^t
- std::vector<Scalar> masksPerBit;
- masksPerBit.push_back(inversePrivateKey);
- for (size_t i = 1; i < proofBits; i++)
- {
- Scalar currMask;
- currMask.set_random();
- masksPerBit.push_back(currMask);
- masksPerBit[0] =
- masksPerBit[0].curveSub(currMask.curveMult(Scalar(1 << i)));
- }
- // Taken from Fig. 1 in https://eprint.iacr.org/2014/764.pdf
- for (size_t i = 0; i < proofBits; i++)
- {
- Proof currProof;
- Curvepoint g, h, c, c_a, c_b;
- g = currentEncryptedScore.mask;
- h = elGamalBlindGenerator;
-
- mpz_class currBit = bit(proofVal & (1 << i));
- Scalar a, s, t, m, r;
- a.set_random();
- s.set_random();
- t.set_random();
- m = Scalar(currBit);
- r = masksPerBit[i];
-
- c = g * r + h * m;
- currProof.partialUniversals.push_back(c);
- c_a = g * s + h * a;
- Scalar am = a.curveMult(m);
- c_b = g * t + h * am;
- std::stringstream oracleInput;
- oracleInput << g << h << c << c_a << c_b;
- Scalar x = oracle(oracleInput.str());
- currProof.challengeParts.push_back(x);
- Scalar f, z_a, z_b;
- Scalar mx = m.curveMult(x);
- f = mx.curveAdd(a);
- Scalar rx = r.curveMult(x);
- z_a = rx.curveAdd(s);
- Scalar x_f = x.curveSub(f);
- Scalar r_x_f = r.curveMult(x_f);
- z_b = r_x_f.curveAdd(t);
- currProof.responseParts.push_back(f);
- currProof.responseParts.push_back(z_a);
- currProof.responseParts.push_back(z_b);
- retval.push_back(currProof);
- }
- return retval;
- }
- // A pretty straightforward range proof (verification)
- bool PrsonaClient::verify_reputation_proof(
- const std::vector<Proof>& pi,
- const Curvepoint& shortTermPublicKey,
- const Scalar& threshold) const
- {
- // Reject outright if there's no proof to check
- if (pi.empty())
- {
- std::cerr << "Proof was empty, aborting." << std::endl;
- return false;
- }
- // Base case
- if (!CLIENT_IS_MALICIOUS)
- return pi[0].basic == "PROOF";
- // User should be able to prove they are who they say they are
- if (!verify_ownership_proof(pi[0], shortTermPublicKey))
- {
- std::cerr << "Schnorr proof failed, aborting." << std::endl;
- return false;
- }
- // Get the encrypted score in question from the servers
- Proof serverProof;
- EGCiphertext encryptedScore =
- servers->get_current_tally(serverProof, shortTermPublicKey);
- // Rough for the prover but if the server messes up,
- // no way to prove the thing anyways
- if (!verify_valid_tally_proof(serverProof, encryptedScore))
- {
- std::cerr << "Server error prevented proof from working, aborting." << std::endl;
- return false;
- }
- // X is the thing we're going to be checking in on throughout
- // to try to get our score commitment back in the end.
- Curvepoint X;
- for (size_t i = 1; i < pi.size(); i++)
- {
- Curvepoint c, g, h;
- c = pi[i].partialUniversals[0];
- g = encryptedScore.mask;
- h = elGamalBlindGenerator;
- X = X + c * Scalar(1 << (i - 1));
- Scalar x, f, z_a, z_b;
- x = pi[i].challengeParts[0];
- f = pi[i].responseParts[0];
- z_a = pi[i].responseParts[1];
- z_b = pi[i].responseParts[2];
- // Taken from Fig. 1 in https://eprint.iacr.org/2014/764.pdf
- Curvepoint c_a, c_b;
- c_a = g * z_a + h * f - c * x;
- Scalar x_f = x.curveSub(f);
- c_b = g * z_b - c * x_f;
- std::stringstream oracleInput;
- oracleInput << g << h << c << c_a << c_b;
- if (oracle(oracleInput.str()) != pi[i].challengeParts[0])
- {
- std::cerr << "0 or 1 proof failed at index " << i << " of " << pi.size() - 1 << ", aborting." << std::endl;
- return false;
- }
- }
- Scalar negThreshold;
- negThreshold = Scalar(0).curveSub(threshold);
- Curvepoint scoreCommitment =
- encryptedScore.encryptedMessage +
- elGamalBlindGenerator * negThreshold;
-
- return X == scoreCommitment;
- }
- Scalar PrsonaClient::get_score() const
- {
- return currentScore;
- }
- /*********************
- * PRIVATE FUNCTIONS *
- *********************/
- /*
- * SCORE DECRYPTION
- */
- // Basic memoized score decryption
- void PrsonaClient::decrypt_score(const EGCiphertext& score)
- {
- Curvepoint s, hashedDecrypted;
- // Remove the mask portion of the ciphertext
- s = score.mask * inversePrivateKey;
- hashedDecrypted = score.encryptedMessage - s;
-
- // Check if it's a value we've already seen
- auto lookup = decryption_memoizer.find(hashedDecrypted);
- if (lookup != decryption_memoizer.end())
- {
- currentScore = lookup->second;
- return;
- }
- // If not, iterate until we find it (adding everything to the memoization)
- max_checked++;
- Curvepoint decryptionCandidate = elGamalBlindGenerator * max_checked;
- while (decryptionCandidate != hashedDecrypted)
- {
- decryption_memoizer[decryptionCandidate] = max_checked;
- decryptionCandidate = decryptionCandidate + elGamalBlindGenerator;
- max_checked++;
- }
- decryption_memoizer[decryptionCandidate] = max_checked;
- // Set the value we found
- currentScore = max_checked;
- }
- /*
- * OWNERSHIP PROOFS
- */
- // Very basic Schnorr proof (generation)
- Proof PrsonaClient::generate_ownership_proof() const
- {
- Proof retval;
- if (!CLIENT_IS_MALICIOUS)
- {
- retval.basic = "PROOF";
- return retval;
- }
- std::stringstream oracleInput;
-
- Scalar r;
- r.set_random();
- Curvepoint shortTermPublicKey = currentFreshGenerator * longTermPrivateKey;
- Curvepoint u = currentFreshGenerator * r;
- oracleInput << currentFreshGenerator << shortTermPublicKey << u;
- Scalar c = oracle(oracleInput.str());
- Scalar z = r.curveAdd(c.curveMult(longTermPrivateKey));
- retval.basic = "PROOF";
- retval.challengeParts.push_back(c);
- retval.responseParts.push_back(z);
- return retval;
- }
- // Very basic Schnorr proof (verification)
- bool PrsonaClient::verify_ownership_proof(
- const Proof& pi, const Curvepoint& shortTermPublicKey) const
- {
- if (!CLIENT_IS_MALICIOUS)
- return pi.basic == "PROOF";
- Scalar c = pi.challengeParts[0];
- Scalar z = pi.responseParts[0];
- Curvepoint u = currentFreshGenerator * z - shortTermPublicKey * c;
- std::stringstream oracleInput;
- oracleInput << currentFreshGenerator << shortTermPublicKey << u;
-
- return c == oracle(oracleInput.str());
- }
- /*
- * PROOF VERIFICATION
- */
- bool PrsonaClient::verify_score_proof(const Proof& pi) const
- {
- if (!SERVER_IS_MALICIOUS)
- return pi.basic == "PROOF";
- return pi.basic == "PROOF";
- }
- bool PrsonaClient::verify_generator_proof(
- const Proof& pi, const Curvepoint& generator) const
- {
- if (!SERVER_IS_MALICIOUS)
- return pi.basic == "PROOF";
- return pi.basic == "PROOF";
- }
- bool PrsonaClient::verify_default_tally_proof(
- const Proof& pi, const EGCiphertext& score) const
- {
- if (!SERVER_IS_MALICIOUS)
- return pi.basic == "PROOF";
- return pi.basic == "PROOF";
- }
- bool PrsonaClient::verify_valid_tally_proof(
- const Proof& pi, const EGCiphertext& score) const
- {
- if (!SERVER_IS_MALICIOUS)
- return pi.basic == "PROOF";
- return pi.basic == "PROOF";
- }
- bool PrsonaClient::verify_default_votes_proof(
- const Proof& pi, const std::vector<CurveBipoint>& votes) const
- {
- if (!SERVER_IS_MALICIOUS)
- return pi.basic == "PROOF";
- return pi.basic == "PROOF";
- }
- bool PrsonaClient::verify_valid_votes_proof(
- const Proof& pi, const std::vector<CurveBipoint>& votes) const
- {
- if (!SERVER_IS_MALICIOUS)
- return pi.basic == "PROOF";
- return pi.basic == "PROOF";
- }
- /*
- * PROOF GENERATION
- */
- std::vector<Proof> PrsonaClient::generate_vote_proof(
- const std::vector<bool>& replaces,
- const std::vector<CurveBipoint>& oldEncryptedVotes,
- const std::vector<CurveBipoint>& newEncryptedVotes,
- const std::vector<Scalar>& seeds,
- const std::vector<Scalar>& votes) const
- {
- std::vector<Proof> retval;
- // Base case
- if (!CLIENT_IS_MALICIOUS)
- {
- Proof currProof;
- currProof.basic = "PROOF";
- retval.push_back(currProof);
-
- return retval;
- }
- // The first need is to prove that we are the stpk we claim we are
- retval.push_back(generate_ownership_proof());
- // Then, we iterate over all votes for the proofs that they are correct
- for (size_t i = 0; i < replaces.size(); i++)
- {
- std::stringstream oracleInput;
- oracleInput << serverPublicKey.get_bipoint_curvegen()
- << serverPublicKey.get_bipoint_curve_subgroup_gen()
- << oldEncryptedVotes[i] << newEncryptedVotes[i];
-
- /* This proof structure is documented in my notes.
- * It's inspired by the proof in Fig. 1 at
- * https://eprint.iacr.org/2014/764.pdf, but adapted so that you prove
- * m(m-1)(m-2) = 0 instead of m(m-1) = 0.
- *
- * The rerandomization part is just a slight variation on an
- * ordinary Schnorr proof, so that part's less scary. */
- if (replaces[i]) // CASE: Make new vote
- {
- Proof currProof;
- Scalar c_r, z_r, a, b, s_1, s_2, t_1, t_2;
- c_r.set_random();
- z_r.set_random();
- a.set_random();
- b.set_random();
- s_1.set_random();
- s_2.set_random();
- t_1.set_random();
- t_2.set_random();
- CurveBipoint U =
- serverPublicKey.get_bipoint_curve_subgroup_gen() * z_r +
- oldEncryptedVotes[i] * c_r -
- newEncryptedVotes[i] * c_r;
- CurveBipoint C_a = serverPublicKey.get_bipoint_curvegen() * a +
- serverPublicKey.get_bipoint_curve_subgroup_gen() * s_1;
- CurveBipoint C_b = serverPublicKey.get_bipoint_curvegen() * b +
- serverPublicKey.get_bipoint_curve_subgroup_gen() * s_2;
- Scalar power = (a.curveAdd(b)).curveMult(votes[i].curveMult(votes[i]));
- power =
- power.curveSub((a.curveAdd(a).curveAdd(b)).curveMult(votes[i]));
- CurveBipoint C_c = serverPublicKey.get_bipoint_curvegen() * power +
- serverPublicKey.get_bipoint_curve_subgroup_gen() * t_1;
- currProof.partialUniversals.push_back(C_c[0]);
- currProof.partialUniversals.push_back(C_c[1]);
- CurveBipoint C_d =
- serverPublicKey.get_bipoint_curvegen() *
- a.curveMult(b.curveMult(votes[i])) +
- serverPublicKey.get_bipoint_curve_subgroup_gen() * t_2;
- oracleInput << U << C_a << C_b << C_c << C_d;
- Scalar c = oracle(oracleInput.str());
- Scalar c_n = c.curveSub(c_r);
- currProof.challengeParts.push_back(c_r);
- currProof.challengeParts.push_back(c_n);
- Scalar f_1 = (votes[i].curveMult(c_n)).curveAdd(a);
- Scalar f_2 = (votes[i].curveMult(c_n)).curveAdd(b);
- Scalar z_na = (seeds[i].curveMult(c_n)).curveAdd(s_1);
- Scalar z_nb = (seeds[i].curveMult(c_n)).curveAdd(s_2);
- Scalar t_1_c_n_t_2 = (t_1.curveMult(c_n)).curveAdd(t_2);
- Scalar f_1_c_n = f_1.curveSub(c_n);
- Scalar c_n_f_2 = c_n.curveAdd(c_n).curveSub(f_2);
- Scalar z_nc =
- (seeds[i].curveMult(f_1_c_n).curveMult(c_n_f_2)).curveAdd(
- t_1_c_n_t_2);
- currProof.responseParts.push_back(z_r);
- currProof.responseParts.push_back(f_1);
- currProof.responseParts.push_back(f_2);
- currProof.responseParts.push_back(z_na);
- currProof.responseParts.push_back(z_nb);
- currProof.responseParts.push_back(z_nc);
- retval.push_back(currProof);
- }
- else // CASE: Rerandomize existing vote
- {
- Proof currProof;
- Scalar u, commitmentLambda_1, commitmentLambda_2,
- c_n, z_na, z_nb, z_nc, f_1, f_2;
- u.set_random();
- commitmentLambda_1.set_random();
- commitmentLambda_2.set_random();
- c_n.set_random();
- z_na.set_random();
- z_nb.set_random();
- z_nc.set_random();
- f_1.set_random();
- f_2.set_random();
- CurveBipoint U =
- serverPublicKey.get_bipoint_curve_subgroup_gen() * u;
- CurveBipoint C_a = serverPublicKey.get_bipoint_curvegen() * f_1 +
- serverPublicKey.get_bipoint_curve_subgroup_gen() * z_na -
- newEncryptedVotes[i] * c_n;
- CurveBipoint C_b = serverPublicKey.get_bipoint_curvegen() * f_2 +
- serverPublicKey.get_bipoint_curve_subgroup_gen() * z_nb -
- newEncryptedVotes[i] * c_n;
- CurveBipoint C_c =
- serverPublicKey.get_bipoint_curvegen() * commitmentLambda_1 +
- serverPublicKey.get_bipoint_curve_subgroup_gen() *
- commitmentLambda_2;
- currProof.partialUniversals.push_back(C_c[0]);
- currProof.partialUniversals.push_back(C_c[1]);
- Scalar f_1_c_n = f_1.curveSub(c_n);
- Scalar c_n_f_2 = c_n.curveAdd(c_n).curveSub(f_2);
- CurveBipoint C_d =
- serverPublicKey.get_bipoint_curve_subgroup_gen() * z_nc -
- newEncryptedVotes[i] * f_1_c_n.curveMult(c_n_f_2) -
- C_c * c_n;
- oracleInput << U << C_a << C_b << C_c << C_d;
- Scalar c = oracle(oracleInput.str());
- Scalar c_r = c.curveSub(c_n);
- currProof.challengeParts.push_back(c_r);
- currProof.challengeParts.push_back(c_n);
- Scalar z_r = u.curveAdd(c_r.curveMult(seeds[i]));
- currProof.responseParts.push_back(z_r);
- currProof.responseParts.push_back(f_1);
- currProof.responseParts.push_back(f_2);
- currProof.responseParts.push_back(z_na);
- currProof.responseParts.push_back(z_nb);
- currProof.responseParts.push_back(z_nc);
- retval.push_back(currProof);
- }
- }
- return retval;
- }
|