123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182 |
#include "proof.hpp"

#include <stdexcept>
/* The OpenSSL EVP digest code in oracle() below is altered from the answer at
 * https://stackoverflow.com/questions/51144505/generate-sha-3-hash-in-c-using-openssl-library
 */
- // Convert the bytes to a single integer, then make that a Scalar
- Scalar bytes_to_scalar(
- const std::vector<uint8_t>& bytes,
- size_t lambda)
- {
- std::stringstream stream;
- for (uint8_t b : bytes)
- stream << std::setw(2) << std::setfill('0') << std::hex << static_cast<int>(b);
- mpz_class value;
- value.set_str(stream.str(), 16);
- if (lambda < 256)
- {
- mpz_class copy = value;
- mpz_class modVal = 1;
- modVal = modVal << lambda;
- mpz_mod(value.get_mpz_t(), copy.get_mpz_t(), modVal.get_mpz_t());
- }
- return Scalar(value);
- }
- // Random Oracle (i.e. SHA256)
- Scalar oracle(
- const std::string& input,
- size_t lambda)
- {
- uint32_t digest_length = SHA256_DIGEST_LENGTH;
- const EVP_MD* algorithm = EVP_sha256();
- uint8_t* digest = static_cast<uint8_t*>(OPENSSL_malloc(digest_length));
-
- EVP_MD_CTX* context = EVP_MD_CTX_create();
- EVP_DigestInit_ex(context, algorithm, NULL);
- EVP_DigestUpdate(context, input.c_str(), input.size());
- EVP_DigestFinal_ex(context, digest, &digest_length);
- EVP_MD_CTX_destroy(context);
- std::vector<uint8_t> digestBytes(digest, digest + digest_length);
- Scalar output = bytes_to_scalar(digestBytes, lambda);
- OPENSSL_free(digest);
- return output;
- }
- void Proof::clear()
- {
- hbc.clear();
- curvepointUniversals.clear();
- curveBipointUniversals.clear();
- challengeParts.clear();
- responseParts.clear();
- }
- bool Proof::operator==(
- const Proof& b)
- {
- bool retval = this->hbc == b.hbc;
-
- retval = retval && this->curvepointUniversals.size() == b.curvepointUniversals.size();
- retval = retval && this->curveBipointUniversals.size() == b.curveBipointUniversals.size();
- retval = retval && this->challengeParts.size() == b.challengeParts.size();
- retval = retval && this->responseParts.size() == b.responseParts.size();
- for (size_t i = 0; retval && i < this->curvepointUniversals.size(); i++)
- retval = retval && this->curvepointUniversals[i] == b.curvepointUniversals[i];
- for (size_t i = 0; retval && i < this->curveBipointUniversals.size(); i++)
- retval = retval && this->curveBipointUniversals[i] == b.curveBipointUniversals[i];
- for (size_t i = 0; retval && i < this->challengeParts.size(); i++)
- retval = retval && this->challengeParts[i] == b.challengeParts[i];
- for (size_t i = 0; retval && i < this->responseParts.size(); i++)
- retval = retval && this->responseParts[i] == b.responseParts[i];
- return retval;
- }
- std::ostream& operator<<(
- std::ostream& os,
- const Proof& output)
- {
- BinaryBool hbc(!output.hbc.empty());
- os << hbc;
- BinarySizeT numElements;
- if (hbc.val())
- {
- numElements.set(output.hbc.length());
-
- os << numElements;
- os << output.hbc;
- return os;
- }
- numElements.set(output.curvepointUniversals.size());
- os << numElements;
- for (size_t i = 0; i < numElements.val(); i++)
- os << output.curvepointUniversals[i];
- numElements.set(output.curveBipointUniversals.size());
- os << numElements;
- for (size_t i = 0; i < numElements.val(); i++)
- os << output.curveBipointUniversals[i];
- numElements.set(output.challengeParts.size());
- os << numElements;
- for (size_t i = 0; i < numElements.val(); i++)
- os << output.challengeParts[i];
- numElements.set(output.responseParts.size());
- os << numElements;
- for (size_t i = 0; i < numElements.val(); i++)
- os << output.responseParts[i];
- return os;
- }
- std::istream& operator>>(
- std::istream& is,
- Proof& input)
- {
- BinaryBool hbc;
- is >> hbc;
- if (hbc.val())
- {
- BinarySizeT numBytes;
- is >> numBytes;
-
- char* buffer = new char[numBytes.val() + 1];
- is.read(buffer, numBytes.val());
- input.hbc = buffer;
- delete buffer;
- return is;
- }
- BinarySizeT numElements;
- is >> numElements;
- for (size_t i = 0; i < numElements.val(); i++)
- {
- Twistpoint x;
- is >> x;
- input.curvepointUniversals.push_back(x);
- }
- is >> numElements;
- for (size_t i = 0; i < numElements.val(); i++)
- {
- TwistBipoint x;
- is >> x;
- input.curveBipointUniversals.push_back(x);
- }
- is >> numElements;
- for (size_t i = 0; i < numElements.val(); i++)
- {
- Scalar x;
- is >> x;
- input.challengeParts.push_back(x);
- }
- is >> numElements;
- for (size_t i = 0; i < numElements.val(); i++)
- {
- Scalar x;
- is >> x;
- input.responseParts.push_back(x);
- }
- return is;
- }
|