#include "proof.hpp"

/* Altered from answer at
 * https://stackoverflow.com/questions/51144505/generate-sha-3-hash-in-c-using-openssl-library
 */

namespace {

// Returns true when both containers have the same length and every pair of
// corresponding elements compares equal. The left operand is taken by
// non-const reference so the element comparison is invoked exactly as it is
// from the (non-const) Proof::operator==.
template <typename Container>
bool elementwise_equal(Container& lhs, const Container& rhs)
{
    if (lhs.size() != rhs.size())
        return false;

    for (size_t i = 0; i < lhs.size(); i++)
    {
        if (!(lhs[i] == rhs[i]))
            return false;
    }

    return true;
}

// Writes the element count of `items` as a BinarySizeT, followed by each
// element in order.
template <typename Container>
void write_counted(std::ostream& os, const Container& items)
{
    BinarySizeT count;
    count.set(items.size());
    os << count;

    for (const auto& item : items)
        os << item;
}

// Reads a BinarySizeT count, then that many `Element` values, appending each
// one to `destination` (existing contents are kept).
// NOTE(review): the stream state is not checked between reads, so a
// truncated or corrupt stream is not detected here -- confirm that callers
// validate the stream afterwards.
template <typename Element, typename Container>
void read_counted(std::istream& is, Container& destination)
{
    BinarySizeT count;
    is >> count;

    for (size_t i = 0; i < count.val(); i++)
    {
        Element element;
        is >> element;
        destination.push_back(element);
    }
}

} // namespace

// Convert the bytes to a single integer, then make that a Scalar.
//
// The bytes are interpreted as one big-endian unsigned integer (first byte
// most significant). When lambda < 256 the integer is reduced modulo
// 2^lambda, i.e. only its low lambda bits are kept.
Scalar bytes_to_scalar(
    const std::vector<uint8_t>& bytes,
    size_t lambda)
{
    mpz_class value; // default-constructed mpz_class is 0

    // Import the raw bytes directly: one byte per word, most significant
    // word first. This yields the same integer as formatting every byte as
    // two hex digits and parsing the resulting string in base 16.
    if (!bytes.empty())
        mpz_import(value.get_mpz_t(), bytes.size(), 1, 1, 0, 0, bytes.data());

    if (lambda < 256)
    {
        mpz_class modulus = 1;
        modulus = modulus << lambda;
        value %= modulus; // value is non-negative, so this is a plain mod
    }

    return Scalar(value);
}

// Random Oracle (i.e. SHA256)
//
// Hashes `input` with SHA-256 and converts the digest to a Scalar via
// bytes_to_scalar(), which keeps only the low `lambda` bits when
// lambda < 256.
//
// NOTE(review): the OpenSSL return codes are not checked (unchanged from the
// previous behaviour). The digest buffer is zero-initialised, so a failed
// hash yields a deterministic value instead of uninitialised memory.
Scalar oracle(
    const std::string& input,
    size_t lambda)
{
    std::vector<uint8_t> digest(SHA256_DIGEST_LENGTH);
    unsigned int digestLength = SHA256_DIGEST_LENGTH;

    EVP_MD_CTX* context = EVP_MD_CTX_create();
    EVP_DigestInit_ex(context, EVP_sha256(), nullptr);
    EVP_DigestUpdate(context, input.data(), input.size());
    EVP_DigestFinal_ex(context, digest.data(), &digestLength);
    EVP_MD_CTX_destroy(context);

    // Keep exactly the number of bytes OpenSSL reported writing.
    digest.resize(digestLength);

    return bytes_to_scalar(digest, lambda);
}

// Resets the proof to an empty state: clears the hbc string and every
// vector of universals, challenge parts and response parts.
void Proof::clear()
{
    hbc.clear();

    curvepointUniversals.clear();
    curveBipointUniversals.clear();

    challengeParts.clear();
    responseParts.clear();
}

// Two proofs are equal when their hbc values match and all four element
// sequences have the same lengths and pairwise-equal contents.
bool Proof::operator==(
    const Proof& b)
{
    if (!(this->hbc == b.hbc))
        return false;

    return elementwise_equal(this->curvepointUniversals, b.curvepointUniversals)
        && elementwise_equal(this->curveBipointUniversals, b.curveBipointUniversals)
        && elementwise_equal(this->challengeParts, b.challengeParts)
        && elementwise_equal(this->responseParts, b.responseParts);
}

// Serialises a Proof.
//
// Layout: a BinaryBool flag saying whether hbc is non-empty. If it is, the
// hbc length and the hbc contents follow and nothing else is written.
// Otherwise four count-prefixed sequences follow, in this order:
// curvepointUniversals, curveBipointUniversals, challengeParts,
// responseParts.
std::ostream& operator<<(
    std::ostream& os,
    const Proof& output)
{
    BinaryBool hasHbc(!output.hbc.empty());
    os << hasHbc;

    if (hasHbc.val())
    {
        BinarySizeT numBytes;
        numBytes.set(output.hbc.length());
        os << numBytes;
        os << output.hbc;
        return os;
    }

    write_counted(os, output.curvepointUniversals);
    write_counted(os, output.curveBipointUniversals);
    write_counted(os, output.challengeParts);
    write_counted(os, output.responseParts);

    return os;
}

// Deserialises a Proof written by operator<<.
//
// In the hbc case `input.hbc` is replaced with the bytes read. In the other
// case the elements read are appended to `input`'s vectors; the vectors are
// not cleared first.
std::istream& operator>>(
    std::istream& is,
    Proof& input)
{
    BinaryBool hasHbc;
    is >> hasHbc;

    if (hasHbc.val())
    {
        BinarySizeT numBytes;
        is >> numBytes;

        // Read into an owned, exactly-sized buffer. This replaces a
        // new[]/delete pair (mismatched with scalar delete) and an
        // assignment from a buffer that was never NUL-terminated.
        std::string buffer(numBytes.val(), '\0');
        is.read(&buffer[0], static_cast<std::streamsize>(buffer.size()));

        // On a short read keep only the bytes that were actually extracted.
        buffer.resize(static_cast<size_t>(is.gcount()));

        input.hbc = buffer;
        return is;
    }

    read_counted<Twistpoint>(is, input.curvepointUniversals);
    read_counted<TwistBipoint>(is, input.curveBipointUniversals);
    read_counted<Scalar>(is, input.challengeParts);
    read_counted<Scalar>(is, input.responseParts);

    return is;
}