proof.cpp 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171
#include "proof.hpp"

#include <stdexcept>
#include <string>
#include <utility>
  2. /* Altered from answer at
  3. * https://stackoverflow.com/questions/51144505/generate-sha-3-hash-in-c-using-openssl-library
  4. */
  5. // Convert the bytes to a single integer, then make that a Scalar
  6. Scalar bytes_to_scalar(
  7. const std::vector<uint8_t>& bytes)
  8. {
  9. std::stringstream stream;
  10. for (uint8_t b : bytes)
  11. stream << std::setw(2) << std::setfill('0') << std::hex << static_cast<int>(b);
  12. mpz_class value;
  13. value.set_str(stream.str(), 16);
  14. return Scalar(value);
  15. }
  16. // Random Oracle (i.e. SHA256)
  17. Scalar oracle(
  18. const std::string& input)
  19. {
  20. uint32_t digest_length = SHA256_DIGEST_LENGTH;
  21. const EVP_MD* algorithm = EVP_sha256();
  22. uint8_t* digest = static_cast<uint8_t*>(OPENSSL_malloc(digest_length));
  23. EVP_MD_CTX* context = EVP_MD_CTX_create();
  24. EVP_DigestInit_ex(context, algorithm, NULL);
  25. EVP_DigestUpdate(context, input.c_str(), input.size());
  26. EVP_DigestFinal_ex(context, digest, &digest_length);
  27. EVP_MD_CTX_destroy(context);
  28. std::vector<uint8_t> digestBytes(digest, digest + digest_length);
  29. Scalar output = bytes_to_scalar(digestBytes);
  30. OPENSSL_free(digest);
  31. return output;
  32. }
  33. void Proof::clear()
  34. {
  35. hbc.clear();
  36. curvepointUniversals.clear();
  37. curveBipointUniversals.clear();
  38. challengeParts.clear();
  39. responseParts.clear();
  40. }
  41. bool Proof::operator==(
  42. const Proof& b)
  43. {
  44. bool retval = this->hbc == b.hbc;
  45. retval = retval && this->curvepointUniversals.size() == b.curvepointUniversals.size();
  46. retval = retval && this->curveBipointUniversals.size() == b.curveBipointUniversals.size();
  47. retval = retval && this->challengeParts.size() == b.challengeParts.size();
  48. retval = retval && this->responseParts.size() == b.responseParts.size();
  49. for (size_t i = 0; retval && i < this->curvepointUniversals.size(); i++)
  50. retval = retval && this->curvepointUniversals[i] == b.curvepointUniversals[i];
  51. for (size_t i = 0; retval && i < this->curveBipointUniversals.size(); i++)
  52. retval = retval && this->curveBipointUniversals[i] == b.curveBipointUniversals[i];
  53. for (size_t i = 0; retval && i < this->challengeParts.size(); i++)
  54. retval = retval && this->challengeParts[i] == b.challengeParts[i];
  55. for (size_t i = 0; retval && i < this->responseParts.size(); i++)
  56. retval = retval && this->responseParts[i] == b.responseParts[i];
  57. return retval;
  58. }
  59. std::ostream& operator<<(
  60. std::ostream& os,
  61. const Proof& output)
  62. {
  63. BinaryBool hbc(!output.hbc.empty());
  64. os << hbc;
  65. BinarySizeT numElements;
  66. if (hbc.val())
  67. {
  68. numElements.set(output.hbc.length());
  69. os << numElements;
  70. os << output.hbc;
  71. return os;
  72. }
  73. numElements.set(output.curvepointUniversals.size());
  74. os << numElements;
  75. for (size_t i = 0; i < numElements.val(); i++)
  76. os << output.curvepointUniversals[i];
  77. numElements.set(output.curveBipointUniversals.size());
  78. os << numElements;
  79. for (size_t i = 0; i < numElements.val(); i++)
  80. os << output.curveBipointUniversals[i];
  81. numElements.set(output.challengeParts.size());
  82. os << numElements;
  83. for (size_t i = 0; i < numElements.val(); i++)
  84. os << output.challengeParts[i];
  85. numElements.set(output.responseParts.size());
  86. os << numElements;
  87. for (size_t i = 0; i < numElements.val(); i++)
  88. os << output.responseParts[i];
  89. return os;
  90. }
  91. std::istream& operator>>(
  92. std::istream& is,
  93. Proof& input)
  94. {
  95. BinaryBool hbc;
  96. is >> hbc;
  97. if (hbc.val())
  98. {
  99. BinarySizeT numBytes;
  100. is >> numBytes;
  101. char* buffer = new char[numBytes.val() + 1];
  102. is.read(buffer, numBytes.val());
  103. input.hbc = buffer;
  104. delete buffer;
  105. return is;
  106. }
  107. BinarySizeT numElements;
  108. is >> numElements;
  109. for (size_t i = 0; i < numElements.val(); i++)
  110. {
  111. Twistpoint x;
  112. is >> x;
  113. input.curvepointUniversals.push_back(x);
  114. }
  115. is >> numElements;
  116. for (size_t i = 0; i < numElements.val(); i++)
  117. {
  118. TwistBipoint x;
  119. is >> x;
  120. input.curveBipointUniversals.push_back(x);
  121. }
  122. is >> numElements;
  123. for (size_t i = 0; i < numElements.val(); i++)
  124. {
  125. Scalar x;
  126. is >> x;
  127. input.challengeParts.push_back(x);
  128. }
  129. is >> numElements;
  130. for (size_t i = 0; i < numElements.val(); i++)
  131. {
  132. Scalar x;
  133. is >> x;
  134. input.responseParts.push_back(x);
  135. }
  136. return is;
  137. }