proof.cpp 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142
  #include "proof.hpp"

  #include <stdexcept>
  #include <string>
  #include <utility>
  2. /* Altered from answer at
  3. * https://stackoverflow.com/questions/51144505/generate-sha-3-hash-in-c-using-openssl-library
  4. */
  5. // Convert the bytes to a single integer, then make that a Scalar
  6. Scalar bytes_to_scalar(const std::vector<uint8_t>& bytes)
  7. {
  8. std::stringstream stream;
  9. for (uint8_t b : bytes)
  10. {
  11. stream << std::setw(2)
  12. << std::setfill('0')
  13. << std::hex
  14. << static_cast<int>(b);
  15. }
  16. mpz_class value;
  17. value.set_str(stream.str(), 16);
  18. return Scalar(value);
  19. }
  20. // Random Oracle (i.e. SHA3_256)
  21. Scalar oracle(const std::string& input)
  22. {
  23. uint32_t digest_length = SHA256_DIGEST_LENGTH;
  24. const EVP_MD* algorithm = EVP_sha3_256();
  25. uint8_t* digest = static_cast<uint8_t*>(OPENSSL_malloc(digest_length));
  26. EVP_MD_CTX* context = EVP_MD_CTX_new();
  27. EVP_DigestInit_ex(context, algorithm, NULL);
  28. EVP_DigestUpdate(context, input.c_str(), input.size());
  29. EVP_DigestFinal_ex(context, digest, &digest_length);
  30. EVP_MD_CTX_destroy(context);
  31. std::vector<uint8_t> digestBytes(digest, digest + digest_length);
  32. Scalar output = bytes_to_scalar(digestBytes);
  33. OPENSSL_free(digest);
  34. return output;
  35. }
  36. Proof::Proof()
  37. { /* Do nothing */ }
  38. Proof::Proof(std::string hbc)
  39. : hbc(hbc)
  40. { /* Do nothing */ }
  41. void Proof::clear()
  42. {
  43. hbc.clear();
  44. curvepointUniversals.clear();
  45. curveBipointUniversals.clear();
  46. challengeParts.clear();
  47. responseParts.clear();
  48. }
  49. std::ostream& operator<<(std::ostream& os, const Proof& output)
  50. {
  51. if (!output.hbc.empty())
  52. {
  53. os << true;
  54. os << output.hbc.size();
  55. os << output.hbc;
  56. return os;
  57. }
  58. os << false;
  59. os << output.curvepointUniversals.size();
  60. for (size_t i = 0; i < output.curvepointUniversals.size(); i++)
  61. os << output.curvepointUniversals[i];
  62. os << output.curveBipointUniversals.size();
  63. for (size_t i = 0; i < output.curveBipointUniversals.size(); i++)
  64. os << output.curveBipointUniversals[i];
  65. os << output.challengeParts.size();
  66. for (size_t i = 0; i < output.challengeParts.size(); i++)
  67. os << output.challengeParts[i];
  68. os << output.responseParts.size();
  69. for (size_t i = 0; i < output.responseParts.size(); i++)
  70. os << output.responseParts[i];
  71. return os;
  72. }
  73. std::istream& operator>>(std::istream& is, Proof& input)
  74. {
  75. bool hbc;
  76. is >> hbc;
  77. if (hbc)
  78. {
  79. size_t numBytes;
  80. is >> numBytes;
  81. char* buffer = new char[numBytes + 1];
  82. is.read(buffer, numBytes);
  83. input.hbc = buffer;
  84. delete buffer;
  85. return is;
  86. }
  87. size_t numElements;
  88. is >> numElements;
  89. for (size_t i = 0; i < numElements; i++)
  90. {
  91. Twistpoint x;
  92. is >> x;
  93. input.curvepointUniversals.push_back(x);
  94. }
  95. is >> numElements;
  96. for (size_t i = 0; i < numElements; i++)
  97. {
  98. TwistBipoint x;
  99. is >> x;
  100. input.curveBipointUniversals.push_back(x);
  101. }
  102. is >> numElements;
  103. for (size_t i = 0; i < numElements; i++)
  104. {
  105. Scalar x;
  106. is >> x;
  107. input.challengeParts.push_back(x);
  108. }
  109. is >> numElements;
  110. for (size_t i = 0; i < numElements; i++)
  111. {
  112. Scalar x;
  113. is >> x;
  114. input.responseParts.push_back(x);
  115. }
  116. return is;
  117. }