瀏覽代碼

fixing serialization of proofs

tristangurtler 3 年之前
父節點
當前提交
dfd57cea95
共有 2 個文件被更改,包括 38 次插入32 次删除
  1. 6 2
      prsona/inc/proof.hpp
  2. 32 30
      prsona/src/proof.cpp

+ 6 - 2
prsona/inc/proof.hpp

@@ -14,8 +14,12 @@
 
 class Proof {
     public:
-        Proof();
-        Proof(std::string hbc);
+        Proof()
+        { /* */ }
+
+        Proof(std::string hbc)
+            : hbc(hbc)
+        { /* */ }
 
         void clear();
     

+ 32 - 30
prsona/src/proof.cpp

@@ -41,13 +41,6 @@ Scalar oracle(const std::string& input)
     return output;
 }
 
-Proof::Proof()
-{ /* Do nothing */ }
-
-Proof::Proof(std::string hbc)
-: hbc(hbc)
-{ /* Do nothing */ }
-
 void Proof::clear()
 {
     hbc.clear();
@@ -59,30 +52,39 @@ void Proof::clear()
 
 std::ostream& operator<<(std::ostream& os, const Proof& output)
 {
-    if (!output.hbc.empty())
+    BinaryBool hbc(!output.hbc.empty());
+    os << hbc;
+
+    BinarySizeT numElements;
+
+    if (hbc.val())
     {
-        os << true;
-        os << output.hbc.size();
+        numElements.set(output.hbc.length());
+        
+        os << numElements;
         os << output.hbc;
+
         return os;
     }
 
-    os << false;
-
-    os << output.curvepointUniversals.size();
-    for (size_t i = 0; i < output.curvepointUniversals.size(); i++)
+    numElements.set(output.curvepointUniversals.size());
+    os << numElements;
+    for (size_t i = 0; i < numElements.val(); i++)
         os << output.curvepointUniversals[i];
 
-    os << output.curveBipointUniversals.size();
-    for (size_t i = 0; i < output.curveBipointUniversals.size(); i++)
+    numElements.set(output.curveBipointUniversals.size());
+    os << numElements;
+    for (size_t i = 0; i < numElements.val(); i++)
         os << output.curveBipointUniversals[i];
 
-    os << output.challengeParts.size();
-    for (size_t i = 0; i < output.challengeParts.size(); i++)
+    numElements.set(output.challengeParts.size());
+    os << numElements;
+    for (size_t i = 0; i < numElements.val(); i++)
         os << output.challengeParts[i];
 
-    os << output.responseParts.size();
-    for (size_t i = 0; i < output.responseParts.size(); i++)
+    numElements.set(output.responseParts.size());
+    os << numElements;
+    for (size_t i = 0; i < numElements.val(); i++)
         os << output.responseParts[i];
 
     return os;
@@ -90,24 +92,24 @@ std::ostream& operator<<(std::ostream& os, const Proof& output)
 
 std::istream& operator>>(std::istream& is, Proof& input)
 {
-    bool hbc;
+    BinaryBool hbc;
     is >> hbc;
-    if (hbc)
+    if (hbc.val())
     {
-        size_t numBytes;
+        BinarySizeT numBytes;
         is >> numBytes;
         
-        char* buffer = new char[numBytes + 1];
-        is.read(buffer, numBytes);
+        char* buffer = new char[numBytes.val() + 1];
+        is.read(buffer, numBytes.val());
         input.hbc = buffer;
         delete buffer;
 
         return is;
     }
 
-    size_t numElements;
+    BinarySizeT numElements;
     is >> numElements;
-    for (size_t i = 0; i < numElements; i++)
+    for (size_t i = 0; i < numElements.val(); i++)
     {
         Twistpoint x;
         is >> x;
@@ -115,7 +117,7 @@ std::istream& operator>>(std::istream& is, Proof& input)
     }
 
     is >> numElements;
-    for (size_t i = 0; i < numElements; i++)
+    for (size_t i = 0; i < numElements.val(); i++)
     {
         TwistBipoint x;
         is >> x;
@@ -123,7 +125,7 @@ std::istream& operator>>(std::istream& is, Proof& input)
     }
 
     is >> numElements;
-    for (size_t i = 0; i < numElements; i++)
+    for (size_t i = 0; i < numElements.val(); i++)
     {
         Scalar x;
         is >> x;
@@ -131,7 +133,7 @@ std::istream& operator>>(std::istream& is, Proof& input)
     }
 
     is >> numElements;
-    for (size_t i = 0; i < numElements; i++)
+    for (size_t i = 0; i < numElements.val(); i++)
     {
         Scalar x;
         is >> x;