123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466 |
- #include <iostream>
- #include "client.hpp"
- #include "serverEntity.hpp"
- extern const curvepoint_fp_t bn_curvegen;
- const int MAX_ALLOWED_VOTE = 2;
- /* These lines need to be here so these static variables are defined,
- * but in C++ putting code here doesn't actually execute
- * (or at least, with g++, whenever it would execute is not at a useful time)
- * so we have an init() function to actually put the correct values in them. */
- Curvepoint PrsonaClient::EL_GAMAL_GENERATOR = Curvepoint();
- Curvepoint PrsonaClient::EL_GAMAL_BLIND_GENERATOR = Curvepoint();
- bool PrsonaClient::SERVER_IS_MALICIOUS = false;
- bool PrsonaClient::CLIENT_IS_MALICIOUS = false;
- // Quick and dirty function to calculate ceil(log base 2) with mpz_class
- mpz_class log2(mpz_class x)
- {
- mpz_class retval = 0;
- while (x > 0)
- {
- retval++;
- x = x >> 1;
- }
- return retval;
- }
- /********************
- * PUBLIC FUNCTIONS *
- ********************/
- /*
- * CONSTRUCTORS
- */
- PrsonaClient::PrsonaClient(
- const BGNPublicKey& serverPublicKey,
- const PrsonaServerEntity* servers)
- : serverPublicKey(serverPublicKey),
- servers(servers),
- max_checked(0)
- {
- longTermPrivateKey.set_random();
- inversePrivateKey = longTermPrivateKey.curveInverse();
- decryption_memoizer[EL_GAMAL_BLIND_GENERATOR * max_checked] = max_checked;
- }
- /*
- * SETUP FUNCTIONS
- */
- // Must be called once before any usage of this class
- void PrsonaClient::init(const Curvepoint& elGamalBlindGenerator)
- {
- EL_GAMAL_GENERATOR = Curvepoint(bn_curvegen);
- EL_GAMAL_BLIND_GENERATOR = elGamalBlindGenerator;
- }
- void PrsonaClient::set_server_malicious()
- {
- SERVER_IS_MALICIOUS = true;
- }
- void PrsonaClient::set_client_malicious()
- {
- CLIENT_IS_MALICIOUS = true;
- }
- /*
- * BASIC PUBLIC SYSTEM INFO GETTERS
- */
- Curvepoint PrsonaClient::get_short_term_public_key(Proof &pi) const
- {
- pi = generate_ownership_proof();
- return currentFreshGenerator * longTermPrivateKey;
- }
- /*
- * SERVER INTERACTIONS
- */
- // Generate a new vote vector to give to the servers
- // (@replace controls which votes are actually being updated and which are not)
- std::vector<CurveBipoint> PrsonaClient::make_votes(
- Proof& pi,
- const std::vector<Scalar>& vote,
- const std::vector<bool>& replace)
- {
- std::vector<CurveBipoint> retval;
- for (size_t i = 0; i < vote.size(); i++)
- {
- CurveBipoint currScore;
- if (replace[i])
- serverPublicKey.encrypt(currScore, vote[i]);
- else
- currScore = serverPublicKey.rerandomize(currentEncryptedVotes[i]);
- retval.push_back(currScore);
- }
- currentEncryptedVotes = retval;
- pi = generate_vote_proof(retval, vote);
- return retval;
- }
- // Get a new fresh generator (happens at initialization and during each epoch)
- void PrsonaClient::receive_fresh_generator(const Curvepoint& freshGenerator)
- {
- currentFreshGenerator = freshGenerator;
- }
- // Receive a new encrypted score from the servers (each epoch)
- void PrsonaClient::receive_vote_tally(
- const Proof& pi, const EGCiphertext& score)
- {
- if (!verify_valid_tally_proof(pi, score))
- {
- std::cerr << "Could not verify proof of valid tally." << std::endl;
- return;
- }
- currentEncryptedScore = score;
- decrypt_score(score);
- }
- // Receive a new encrypted vote vector from the servers (each epoch)
- void PrsonaClient::receive_encrypted_votes(
- const Proof& pi, const std::vector<CurveBipoint>& votes)
- {
- if (!verify_valid_votes_proof(pi, votes))
- {
- std::cerr << "Could not verify proof of valid votes." << std::endl;
- return;
- }
- currentEncryptedVotes = votes;
- }
- /*
- * REPUTATION PROOFS
- */
- // TO BE UPDATED WITH THING IAN SHOWED ME IN MEETING FOR DISJUNCTION
- std::vector<Proof> PrsonaClient::generate_reputation_proof(
- const Scalar& threshold) const
- {
- std::vector<Proof> retval;
- if (threshold > currentScore)
- return retval;
- if (!CLIENT_IS_MALICIOUS)
- {
- Proof currProof;
- currProof.basic = "PROOF";
- retval.push_back(currProof);
- return retval;
- }
- retval.push_back(generate_ownership_proof());
- mpz_class proofVal = currentScore.curveSub(threshold).toInt();
- mpz_class proofBits = log2(currentEncryptedVotes.size() * MAX_ALLOWED_VOTE - threshold.toInt());
- std::vector<Scalar> masksPerBit;
- masksPerBit.push_back(Scalar());
- for (size_t i = 1; i < proofBits; i++)
- {
- Scalar currMask;
- currMask.set_random();
- masksPerBit.push_back(currMask);
- masksPerBit[0] = masksPerBit[0].curveSub(currMask.curveMult(Scalar(1 << i)));
- }
- for (size_t i = 0; i < proofBits; i++)
- {
- Proof currProof;
- std::stringstream oracleInput;
- oracleInput << currentFreshGenerator << EL_GAMAL_BLIND_GENERATOR;
-
- mpz_class currBit = proofVal & (1 << i);
-
- Curvepoint currentCommitment = currentFreshGenerator * masksPerBit[i] + EL_GAMAL_BLIND_GENERATOR * Scalar(currBit);
- currProof.partialUniversals.push_back(currentCommitment);
- oracleInput << currentCommitment;
- if (currBit)
- {
- Scalar u_0, c, c_0, c_1, z_0, z_1;
- u_0.set_random();
- c_1.set_random();
- z_1.set_random();
- Curvepoint U_0 = currentFreshGenerator * u_0;
- Curvepoint U_1 = currentFreshGenerator * z_1 - currentCommitment * c_1 + EL_GAMAL_BLIND_GENERATOR;
- currProof.initParts.push_back(U_0);
- currProof.initParts.push_back(U_1);
- oracleInput << U_0 << U_1;
- c = oracle(oracleInput.str());
- c_0 = c.curveSub(c_1);
- z_0 = c_0.curveMult(masksPerBit[i]).curveAdd(u_0);
- currProof.challengeParts.push_back(c_0);
- currProof.challengeParts.push_back(c_1);
- currProof.responseParts.push_back(z_0);
- currProof.responseParts.push_back(z_1);
- }
- else
- {
- Scalar u_1, c, c_0, c_1, z_0, z_1;
- u_1.set_random();
- c_0.set_random();
- z_0.set_random();
- Curvepoint U_0 = currentFreshGenerator * z_0 - currentCommitment * c_0;
- Curvepoint U_1 = currentFreshGenerator * u_1;
- currProof.initParts.push_back(U_0);
- currProof.initParts.push_back(U_1);
- oracleInput << U_0 << U_1;
- c = oracle(oracleInput.str());
- c_1 = c.curveSub(c_0);
- z_1 = c_1.curveMult(masksPerBit[i]).curveAdd(u_1);
- currProof.challengeParts.push_back(c_0);
- currProof.challengeParts.push_back(c_1);
- currProof.responseParts.push_back(z_0);
- currProof.responseParts.push_back(z_1);
- }
- retval.push_back(currProof);
- }
- return retval;
- }
- // TO BE UPDATED WITH THING IAN SHOWED ME IN MEETING FOR DISJUNCTION
- bool PrsonaClient::verify_reputation_proof(
- const std::vector<Proof>& pi,
- const Curvepoint& shortTermPublicKey,
- const Scalar& threshold) const
- {
- if (pi.empty())
- return false;
- if (!CLIENT_IS_MALICIOUS)
- return pi[0].basic == "PROOF";
- if (!verify_ownership_proof(pi[0], shortTermPublicKey))
- return false;
- Curvepoint X;
- for (size_t i = 1; i < pi.size(); i++)
- {
- X = X + pi[i].partialUniversals[0] * Scalar(1 << (i - 1));
- std::stringstream oracleInput;
- oracleInput << currentFreshGenerator << EL_GAMAL_BLIND_GENERATOR << pi[i].partialUniversals[0];
- oracleInput << pi[i].initParts[0] << pi[i].initParts[1];
- Scalar c = oracle(oracleInput.str());
- if (c != pi[i].challengeParts[0] + pi[i].challengeParts[1])
- return false;
- if (currentFreshGenerator * pi[i].responseParts[0] != pi[i].initParts[0] + pi[i].partialUniversals[0] * pi[i].challengeParts[0])
- return false;
- if (currentFreshGenerator * pi[i].responseParts[1] != pi[i].initParts[1] + pi[i].partialUniversals[0] * pi[i].challengeParts[1] - EL_GAMAL_BLIND_GENERATOR)
- return false;
- }
- Proof serverProof;
- EGCiphertext encryptedScore = servers->get_current_tally(serverProof, shortTermPublicKey);
- if (!verify_valid_tally_proof(serverProof, encryptedScore))
- return false;
- Scalar negThreshold;
- negThreshold = Scalar(0).curveSub(threshold);
- Curvepoint scoreCommitment = encryptedScore.encryptedMessage + EL_GAMAL_BLIND_GENERATOR * negThreshold;
- if (X != scoreCommitment)
- return false;
- return true;
- }
- /*********************
- * PRIVATE FUNCTIONS *
- *********************/
- /*
- * SCORE DECRYPTION
- */
- // Basic memoized score decryption
- void PrsonaClient::decrypt_score(const EGCiphertext& score)
- {
- Curvepoint s, hashedDecrypted;
- // Remove the mask portion of the ciphertext
- s = score.mask * inversePrivateKey;
- hashedDecrypted = score.encryptedMessage - s;
-
- // Check if it's a value we've already seen
- auto lookup = decryption_memoizer.find(hashedDecrypted);
- if (lookup != decryption_memoizer.end())
- {
- currentScore = lookup->second;
- return;
- }
- // If not, iterate until we find it (adding everything to the memoization)
- max_checked++;
- Curvepoint decryptionCandidate = EL_GAMAL_BLIND_GENERATOR * max_checked;
- while (decryptionCandidate != hashedDecrypted)
- {
- decryption_memoizer[decryptionCandidate] = max_checked;
- decryptionCandidate = decryptionCandidate + EL_GAMAL_BLIND_GENERATOR;
- max_checked++;
- }
- decryption_memoizer[decryptionCandidate] = max_checked;
- // Set the value we found
- currentScore = max_checked;
- }
- /*
- * OWNERSHIP PROOFS
- */
- // Very basic Schnorr proof (generation)
- Proof PrsonaClient::generate_ownership_proof() const
- {
- Proof retval;
- if (!CLIENT_IS_MALICIOUS)
- {
- retval.basic = "PROOF";
- return retval;
- }
- std::stringstream oracleInput;
-
- Scalar r;
- r.set_random();
- Curvepoint shortTermPublicKey = currentFreshGenerator * longTermPrivateKey;
- Curvepoint u = currentFreshGenerator * r;
- oracleInput << currentFreshGenerator << shortTermPublicKey << u;
- Scalar c = oracle(oracleInput.str());
- Scalar z = r.curveAdd(c.curveMult(longTermPrivateKey));
- retval.basic = "PROOF";
- retval.initParts.push_back(u);
- retval.responseParts.push_back(z);
- return retval;
- }
- // Very basic Schnorr proof (verification)
- bool PrsonaClient::verify_ownership_proof(
- const Proof& pi, const Curvepoint& shortTermPublicKey) const
- {
- if (!CLIENT_IS_MALICIOUS)
- return pi.basic == "PROOF";
- Curvepoint u = pi.initParts[0];
- std::stringstream oracleInput;
- oracleInput << currentFreshGenerator << shortTermPublicKey << u;
- Scalar c = oracle(oracleInput.str());
- Scalar z = pi.responseParts[0];
- return (currentFreshGenerator * z) == (shortTermPublicKey * c + u);
- }
- /*
- * PROOF VERIFICATION
- */
- bool PrsonaClient::verify_score_proof(const Proof& pi) const
- {
- if (!SERVER_IS_MALICIOUS)
- return pi.basic == "PROOF";
- return pi.basic == "PROOF";
- }
- bool PrsonaClient::verify_generator_proof(
- const Proof& pi, const Curvepoint& generator) const
- {
- if (!SERVER_IS_MALICIOUS)
- return pi.basic == "PROOF";
- return pi.basic == "PROOF";
- }
- bool PrsonaClient::verify_default_tally_proof(
- const Proof& pi, const EGCiphertext& score) const
- {
- if (!SERVER_IS_MALICIOUS)
- return pi.basic == "PROOF";
- return pi.basic == "PROOF";
- }
- bool PrsonaClient::verify_valid_tally_proof(
- const Proof& pi, const EGCiphertext& score) const
- {
- if (!SERVER_IS_MALICIOUS)
- return pi.basic == "PROOF";
- return pi.basic == "PROOF";
- }
- bool PrsonaClient::verify_default_votes_proof(
- const Proof& pi, const std::vector<CurveBipoint>& votes) const
- {
- if (!SERVER_IS_MALICIOUS)
- return pi.basic == "PROOF";
- return pi.basic == "PROOF";
- }
- bool PrsonaClient::verify_valid_votes_proof(
- const Proof& pi, const std::vector<CurveBipoint>& votes) const
- {
- if (!SERVER_IS_MALICIOUS)
- return pi.basic == "PROOF";
- return pi.basic == "PROOF";
- }
- /*
- * PROOF GENERATION
- */
- Proof PrsonaClient::generate_vote_proof(
- const std::vector<CurveBipoint>& encryptedVotes,
- const std::vector<Scalar>& vote) const
- {
- Proof retval;
- if (!CLIENT_IS_MALICIOUS)
- {
- retval.basic = "PROOF";
- return retval;
- }
- retval.basic = "PROOF";
- return retval;
- }
|