proof.cpp 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182
  #include "proof.hpp"

  #include <stdexcept>
  /* The SHA-256 hashing in oracle() below is adapted from the answer at
   * https://stackoverflow.com/questions/51144505/generate-sha-3-hash-in-c-using-openssl-library
   * (that answer produces SHA-3 hashes; this file uses SHA-256). */
  5. // Convert the bytes to a single integer, then make that a Scalar
  6. Scalar bytes_to_scalar(
  7. const std::vector<uint8_t>& bytes,
  8. size_t lambda)
  9. {
  10. std::stringstream stream;
  11. for (uint8_t b : bytes)
  12. stream << std::setw(2) << std::setfill('0') << std::hex << static_cast<int>(b);
  13. mpz_class value;
  14. value.set_str(stream.str(), 16);
  15. if (lambda < 256)
  16. {
  17. mpz_class copy = value;
  18. mpz_class modVal = 1;
  19. modVal = modVal << lambda;
  20. mpz_mod(value.get_mpz_t(), copy.get_mpz_t(), modVal.get_mpz_t());
  21. }
  22. return Scalar(value);
  23. }
  24. // Random Oracle (i.e. SHA256)
  25. Scalar oracle(
  26. const std::string& input,
  27. size_t lambda)
  28. {
  29. uint32_t digest_length = SHA256_DIGEST_LENGTH;
  30. const EVP_MD* algorithm = EVP_sha256();
  31. uint8_t* digest = static_cast<uint8_t*>(OPENSSL_malloc(digest_length));
  32. EVP_MD_CTX* context = EVP_MD_CTX_create();
  33. EVP_DigestInit_ex(context, algorithm, NULL);
  34. EVP_DigestUpdate(context, input.c_str(), input.size());
  35. EVP_DigestFinal_ex(context, digest, &digest_length);
  36. EVP_MD_CTX_destroy(context);
  37. std::vector<uint8_t> digestBytes(digest, digest + digest_length);
  38. Scalar output = bytes_to_scalar(digestBytes, lambda);
  39. OPENSSL_free(digest);
  40. return output;
  41. }
  42. void Proof::clear()
  43. {
  44. hbc.clear();
  45. curvepointUniversals.clear();
  46. curveBipointUniversals.clear();
  47. challengeParts.clear();
  48. responseParts.clear();
  49. }
  50. bool Proof::operator==(
  51. const Proof& b)
  52. {
  53. bool retval = this->hbc == b.hbc;
  54. retval = retval && this->curvepointUniversals.size() == b.curvepointUniversals.size();
  55. retval = retval && this->curveBipointUniversals.size() == b.curveBipointUniversals.size();
  56. retval = retval && this->challengeParts.size() == b.challengeParts.size();
  57. retval = retval && this->responseParts.size() == b.responseParts.size();
  58. for (size_t i = 0; retval && i < this->curvepointUniversals.size(); i++)
  59. retval = retval && this->curvepointUniversals[i] == b.curvepointUniversals[i];
  60. for (size_t i = 0; retval && i < this->curveBipointUniversals.size(); i++)
  61. retval = retval && this->curveBipointUniversals[i] == b.curveBipointUniversals[i];
  62. for (size_t i = 0; retval && i < this->challengeParts.size(); i++)
  63. retval = retval && this->challengeParts[i] == b.challengeParts[i];
  64. for (size_t i = 0; retval && i < this->responseParts.size(); i++)
  65. retval = retval && this->responseParts[i] == b.responseParts[i];
  66. return retval;
  67. }
  68. std::ostream& operator<<(
  69. std::ostream& os,
  70. const Proof& output)
  71. {
  72. BinaryBool hbc(!output.hbc.empty());
  73. os << hbc;
  74. BinarySizeT numElements;
  75. if (hbc.val())
  76. {
  77. numElements.set(output.hbc.length());
  78. os << numElements;
  79. os << output.hbc;
  80. return os;
  81. }
  82. numElements.set(output.curvepointUniversals.size());
  83. os << numElements;
  84. for (size_t i = 0; i < numElements.val(); i++)
  85. os << output.curvepointUniversals[i];
  86. numElements.set(output.curveBipointUniversals.size());
  87. os << numElements;
  88. for (size_t i = 0; i < numElements.val(); i++)
  89. os << output.curveBipointUniversals[i];
  90. numElements.set(output.challengeParts.size());
  91. os << numElements;
  92. for (size_t i = 0; i < numElements.val(); i++)
  93. os << output.challengeParts[i];
  94. numElements.set(output.responseParts.size());
  95. os << numElements;
  96. for (size_t i = 0; i < numElements.val(); i++)
  97. os << output.responseParts[i];
  98. return os;
  99. }
  100. std::istream& operator>>(
  101. std::istream& is,
  102. Proof& input)
  103. {
  104. BinaryBool hbc;
  105. is >> hbc;
  106. if (hbc.val())
  107. {
  108. BinarySizeT numBytes;
  109. is >> numBytes;
  110. char* buffer = new char[numBytes.val() + 1];
  111. is.read(buffer, numBytes.val());
  112. input.hbc = buffer;
  113. delete buffer;
  114. return is;
  115. }
  116. BinarySizeT numElements;
  117. is >> numElements;
  118. for (size_t i = 0; i < numElements.val(); i++)
  119. {
  120. Twistpoint x;
  121. is >> x;
  122. input.curvepointUniversals.push_back(x);
  123. }
  124. is >> numElements;
  125. for (size_t i = 0; i < numElements.val(); i++)
  126. {
  127. TwistBipoint x;
  128. is >> x;
  129. input.curveBipointUniversals.push_back(x);
  130. }
  131. is >> numElements;
  132. for (size_t i = 0; i < numElements.val(); i++)
  133. {
  134. Scalar x;
  135. is >> x;
  136. input.challengeParts.push_back(x);
  137. }
  138. is >> numElements;
  139. for (size_t i = 0; i < numElements.val(); i++)
  140. {
  141. Scalar x;
  142. is >> x;
  143. input.responseParts.push_back(x);
  144. }
  145. return is;
  146. }