Browse Source

fixing underlying issue with affine points

tristangurtler 3 years ago
parent
commit
f03bf85281
3 changed files with 293 additions and 55 deletions
  1. 38 0
      bgn2/inc/Curvepoint.hpp
  2. 29 25
      bgn2/src/Bipoint.cpp
  3. 226 30
      bgn2/src/Curvepoint.cpp

+ 38 - 0
bgn2/inc/Curvepoint.hpp

@@ -34,6 +34,7 @@ class Curvepoint
         void make_affine();
 
         friend class CurvepointHash;
+        friend class CurveBipoint;
         friend std::ostream& operator<<(std::ostream& os, const Curvepoint& output);
         friend std::istream& operator>>(std::istream& is, Curvepoint& input);
         
@@ -47,4 +48,41 @@ class CurvepointHash
         size_t operator()(const Curvepoint& x) const;
 };
 
+class Twistpoint
+{
+    public:
+        Twistpoint();
+        Twistpoint(const twistpoint_fp2_t input);
+
+        twistpoint_fp2_t& toTwistpointFp2T();
+        const twistpoint_fp2_t& toTwistpointFp2T() const;
+
+        Twistpoint operator+(const Twistpoint& b) const;
+        Twistpoint operator-(const Twistpoint& b) const;
+        Twistpoint operator*(const Scalar& mult) const;
+
+        bool operator==(const Twistpoint& b) const;
+        bool operator<(const Twistpoint& b) const;
+        bool operator>(const Twistpoint& b) const;
+        bool operator<=(const Twistpoint& b) const;
+        bool operator>=(const Twistpoint& b) const;
+        bool operator!=(const Twistpoint& b) const;
+
+        void make_affine();
+
+        friend class TwistpointHash;
+        friend class TwistBipoint;
+        friend std::ostream& operator<<(std::ostream& os, const Twistpoint& output);
+        friend std::istream& operator>>(std::istream& is, Twistpoint& input);
+        
+    private:
+        twistpoint_fp2_t point;
+};
+
+class TwistpointHash
+{
+    public:
+        size_t operator()(const Twistpoint& x) const;
+};
+
 #endif

+ 29 - 25
bgn2/src/Bipoint.cpp

@@ -143,20 +143,14 @@ TwistBipoint TwistBipoint::operator*(const Scalar& exp) const
 bool CurveBipoint::equal(const curvepoint_fp_t& op1, const curvepoint_fp_t& op2) const
 {
 	bool retval;
-	curvepoint_fp_t affine_op1, affine_op2;
+	Curvepoint affine_op1(op1), affine_op2(op2);
 
-	curvepoint_fp_set(affine_op1, op1);
-	curvepoint_fp_set(affine_op2, op2);
+	affine_op1.make_affine();
+	affine_op2.make_affine();
 
-	if (!(fpe_isone(affine_op1->m_z) || fpe_iszero(affine_op1->m_z)))
-		curvepoint_fp_makeaffine(affine_op1);
-
-	if (!(fpe_isone(affine_op2->m_z) || fpe_iszero(affine_op2->m_z)))
-		curvepoint_fp_makeaffine(affine_op2);
-
-	retval =           fpe_iseq(affine_op1->m_x, affine_op2->m_x);
-	retval = retval && fpe_iseq(affine_op1->m_y, affine_op2->m_y);
-	retval = retval || (fpe_iszero(affine_op1->m_z) && fpe_iszero(affine_op2->m_z));
+	retval =           fpe_iseq(affine_op1.point->m_x, affine_op2.point->m_x);
+	retval = retval && fpe_iseq(affine_op1.point->m_y, affine_op2.point->m_y);
+	retval = retval || (fpe_iszero(affine_op1.point->m_z) && fpe_iszero(affine_op2.point->m_z));
 
 	return retval;
 }
@@ -164,20 +158,14 @@ bool CurveBipoint::equal(const curvepoint_fp_t& op1, const curvepoint_fp_t& op2)
 bool TwistBipoint::equal(const twistpoint_fp2_t& op1, const twistpoint_fp2_t& op2) const
 {
 	bool retval;
-	twistpoint_fp2_t affine_op1, affine_op2;
-
-	twistpoint_fp2_set(affine_op1, op1);
-	twistpoint_fp2_set(affine_op2, op2);
-
-	if (!(fp2e_isone(affine_op1->m_z) || fp2e_iszero(affine_op1->m_z)))
-		twistpoint_fp2_makeaffine(affine_op1);
-
-	if (!(fp2e_isone(affine_op2->m_z) || fp2e_iszero(affine_op2->m_z)))
-		twistpoint_fp2_makeaffine(affine_op2);
+	Twistpoint affine_op1(op1), affine_op2(op2);
+	
+	affine_op1.make_affine();
+	affine_op2.make_affine();
 
-	retval =           fp2e_iseq(affine_op1->m_x, affine_op2->m_x);
-	retval = retval && fp2e_iseq(affine_op1->m_y, affine_op2->m_y);
-	retval = retval || (fp2e_iszero(affine_op1->m_z) && fp2e_iszero(affine_op2->m_z));
+	retval =           fp2e_iseq(affine_op1.point->m_x, affine_op2.point->m_x);
+	retval = retval && fp2e_iseq(affine_op1.point->m_y, affine_op2.point->m_y);
+	retval = retval || (fp2e_iszero(affine_op1.point->m_z) && fp2e_iszero(affine_op2.point->m_z));
 
 	return retval;
 }
@@ -206,16 +194,32 @@ void CurveBipoint::make_affine()
 {
 	if (!(fpe_isone(point[0]->m_z) || fpe_iszero(point[0]->m_z)))
 		curvepoint_fp_makeaffine(point[0]);
+
 	if (!(fpe_isone(point[1]->m_z) || fpe_iszero(point[1]->m_z)))
 		curvepoint_fp_makeaffine(point[1]);
+
+	fpe_short_coeffred(point[0]->m_x);
+	fpe_short_coeffred(point[0]->m_y);
+	fpe_short_coeffred(point[0]->m_z);
+	fpe_short_coeffred(point[1]->m_x);
+	fpe_short_coeffred(point[1]->m_y);
+	fpe_short_coeffred(point[1]->m_z);
 }
 
 void TwistBipoint::make_affine()
 {
 	if (!(fp2e_isone(point[0]->m_z) || fp2e_iszero(point[0]->m_z)))
 		twistpoint_fp2_makeaffine(point[0]);
+
 	if (!(fp2e_isone(point[1]->m_z) || fp2e_iszero(point[1]->m_z)))
 		twistpoint_fp2_makeaffine(point[1]);
+
+	fp2e_short_coeffred(point[0]->m_x);
+	fp2e_short_coeffred(point[0]->m_y);
+	fp2e_short_coeffred(point[0]->m_z);
+	fp2e_short_coeffred(point[1]->m_x);
+	fp2e_short_coeffred(point[1]->m_y);
+	fp2e_short_coeffred(point[1]->m_z);
 }
 
 std::ostream& operator<<(std::ostream& os, const CurveBipoint& output)

+ 226 - 30
bgn2/src/Curvepoint.cpp

@@ -5,24 +5,43 @@ Curvepoint::Curvepoint()
     curvepoint_fp_setneutral(point);
 }
 
+Twistpoint::Twistpoint()
+{
+    twistpoint_fp2_setneutral(point);
+}
+
 Curvepoint::Curvepoint(const curvepoint_fp_t input)
 {
     curvepoint_fp_set(point, input);
 }
 
+Twistpoint::Twistpoint(const twistpoint_fp2_t input)
+{
+    twistpoint_fp2_set(point, input);
+}
+
 curvepoint_fp_t& Curvepoint::toCurvepointFpT()
 {
     return point;
 }
 
+twistpoint_fp2_t& Twistpoint::toTwistpointFp2T()
+{
+    return point;
+}
+
 const curvepoint_fp_t& Curvepoint::toCurvepointFpT() const
 {
     return point;
 }
 
+const twistpoint_fp2_t& Twistpoint::toTwistpointFp2T() const
+{
+    return point;
+}
+
 Curvepoint Curvepoint::operator+(const Curvepoint& b) const
 {
-    
     Curvepoint retval;
 
     if (*this == b)
@@ -33,6 +52,18 @@ Curvepoint Curvepoint::operator+(const Curvepoint& b) const
     return retval;
 }
 
+Twistpoint Twistpoint::operator+(const Twistpoint& b) const
+{
+    Twistpoint retval;
+
+    if (*this == b)
+        twistpoint_fp2_double(retval.point, point);
+    else
+        twistpoint_fp2_add_vartime(retval.point, point, b.point);
+
+    return retval;
+}
+
 Curvepoint Curvepoint::operator-(const Curvepoint& b) const
 {
     Curvepoint retval;
@@ -47,6 +78,20 @@ Curvepoint Curvepoint::operator-(const Curvepoint& b) const
     return retval;
 }
 
+Twistpoint Twistpoint::operator-(const Twistpoint& b) const
+{
+    Twistpoint retval;
+
+    if (!(*this == b))
+    {
+        Twistpoint inverseB;
+        twistpoint_fp2_neg(inverseB.point, b.point);
+        twistpoint_fp2_add_vartime(retval.point, point, inverseB.point);
+    }
+
+    return retval;
+}
+
 Curvepoint Curvepoint::operator*(const Scalar& exp) const
 {
     Curvepoint retval;
@@ -56,23 +101,41 @@ Curvepoint Curvepoint::operator*(const Scalar& exp) const
     return retval;
 }
 
+Twistpoint Twistpoint::operator*(const Scalar& exp) const
+{
+    Twistpoint retval;
+
+    exp.mult(retval.point, point);
+
+    return retval;
+}
+
 bool Curvepoint::operator==(const Curvepoint& b) const
 {
     bool retval;
-    curvepoint_fp_t affine_this_point, affine_b_point;
+    Curvepoint affine_this(point), affine_b(b.point);
 
-    curvepoint_fp_set(affine_this_point, point);
-    curvepoint_fp_set(affine_b_point, b.point);
+    affine_this.make_affine();
+    affine_b.make_affine();
 
-    if (!(fpe_isone(affine_this_point->m_z) || fpe_iszero(affine_this_point->m_z)))
-        curvepoint_fp_makeaffine(affine_this_point);
+    retval =           fpe_iseq(affine_this.point->m_x, affine_b.point->m_x);
+    retval = retval && fpe_iseq(affine_this.point->m_y, affine_b.point->m_y);
+    retval = retval || (fpe_iszero(affine_this.point->m_z) && fpe_iszero(affine_b.point->m_z));
 
-    if (!(fpe_isone(affine_b_point->m_z) || fpe_iszero(affine_b_point->m_z)))
-        curvepoint_fp_makeaffine(affine_b_point);
+    return retval;
+}
 
-    retval =           fpe_iseq(affine_this_point->m_x, affine_b_point->m_x);
-    retval = retval && fpe_iseq(affine_this_point->m_y, affine_b_point->m_y);
-    retval = retval || (fpe_iszero(affine_this_point->m_z) && fpe_iszero(affine_b_point->m_z));
+bool Twistpoint::operator==(const Twistpoint& b) const
+{
+    bool retval;
+    Twistpoint affine_this(point), affine_b(b.point);
+
+    affine_this.make_affine();
+    affine_b.make_affine();
+
+    retval =           fp2e_iseq(affine_this.point->m_x, affine_b.point->m_x);
+    retval = retval && fp2e_iseq(affine_this.point->m_y, affine_b.point->m_y);
+    retval = retval || (fp2e_iszero(affine_this.point->m_z) && fp2e_iszero(affine_b.point->m_z));
 
     return retval;
 }
@@ -81,41 +144,34 @@ bool Curvepoint::operator<(const Curvepoint& b) const
 {
     bool lessThan[2];
     bool equal[2];
-    curvepoint_fp_t affine_this_point, affine_b_point;
+    Curvepoint affine_this(point), affine_b(b.point);
 
     lessThan[0] = lessThan[1] = false;
 
-    curvepoint_fp_set(affine_this_point, point);
-    curvepoint_fp_set(affine_b_point, b.point);
-
-    if (fpe_iszero(affine_this_point->m_z))
+    if (fpe_iszero(affine_this.point->m_z))
     {
         // this case would be equal
-        if (fpe_iszero(affine_b_point->m_z))
+        if (fpe_iszero(affine_b.point->m_z))
             return false;
 
         // point at infinity is less than all other points
         return true;
     }
 
-    if (fpe_iszero(affine_b_point->m_z))
+    if (fpe_iszero(affine_b.point->m_z))
         return false;
 
-    // already checked for the point at infinity, so we don't have to redo that here
-    if (!fpe_isone(affine_this_point->m_z))
-        curvepoint_fp_makeaffine(affine_this_point);
-
-    if (!fpe_isone(affine_b_point->m_z))
-        curvepoint_fp_makeaffine(affine_b_point);
+    affine_this.make_affine();
+    affine_b.make_affine();
 
     for (int i = 11; i >= 0; i--)
     {
-        if (affine_this_point->m_x->v[i] > affine_b_point->m_x->v[i])
+        if (affine_this.point->m_x->v[i] > affine_b.point->m_x->v[i])
         {
             lessThan[0] = false;
             break;
         }
-        if (affine_this_point->m_x->v[i] < affine_b_point->m_x->v[i])
+        if (affine_this.point->m_x->v[i] < affine_b.point->m_x->v[i])
         {
             lessThan[0] = true;
             break;
@@ -124,20 +180,79 @@ bool Curvepoint::operator<(const Curvepoint& b) const
 
     for (int i = 11; i >= 0; i--)
     {
-        if (affine_this_point->m_y->v[i] > affine_b_point->m_y->v[i])
+        if (affine_this.point->m_y->v[i] > affine_b.point->m_y->v[i])
+        {
+            lessThan[1] = false;
+            break;
+        }
+        if (affine_this.point->m_y->v[i] < affine_b.point->m_y->v[i])
+        {
+            lessThan[1] = true;
+            break;
+        }
+    }
+
+    equal[0] = fpe_iseq(affine_this.point->m_x, affine_b.point->m_x);
+    equal[1] = fpe_iseq(affine_this.point->m_y, affine_b.point->m_y);
+
+    // sort is lesser x value first, and then lesser y value second if x's are equal
+    return equal[0] ? (equal[1] ? false : lessThan[1]) : lessThan[0];
+}
+
+bool Twistpoint::operator<(const Twistpoint& b) const
+{
+    bool lessThan[2];
+    bool equal[2];
+    Twistpoint affine_this(point), affine_b(b.point);
+
+    lessThan[0] = lessThan[1] = false;
+
+    if (fp2e_iszero(affine_this.point->m_z))
+    {
+        // this case would be equal
+        if (fp2e_iszero(affine_b.point->m_z))
+            return false;
+
+        // point at infinity is less than all other points
+        return true;
+    }
+
+    if (fp2e_iszero(affine_b.point->m_z))
+        return false;
+
+    affine_this.make_affine();
+    affine_b.make_affine();
+
+    for (int i = 23; i >= 0; i--)
+    {
+        if (affine_this.point->m_x->v[i] > affine_b.point->m_x->v[i])
+        {
+            lessThan[0] = false;
+            break;
+        }
+        if (affine_this.point->m_x->v[i] < affine_b.point->m_x->v[i])
+        {
+            lessThan[0] = true;
+            break;
+        }
+    }
+
+    for (int i = 23; i >= 0; i--)
+    {
+        if (affine_this.point->m_y->v[i] > affine_b.point->m_y->v[i])
         {
             lessThan[1] = false;
             break;
         }
-        if (affine_this_point->m_y->v[i] < affine_b_point->m_y->v[i])
+        if (affine_this.point->m_y->v[i] < affine_b.point->m_y->v[i])
         {
             lessThan[1] = true;
             break;
         }
     }
 
-    equal[0] = fpe_iseq(affine_this_point->m_x, affine_b_point->m_x);
-    equal[1] = fpe_iseq(affine_this_point->m_y, affine_b_point->m_y);
+    equal[0] = fp2e_iseq(affine_this.point->m_x, affine_b.point->m_x);
+    equal[1] = fp2e_iseq(affine_this.point->m_y, affine_b.point->m_y);
 
     // sort is lesser x value first, and then lesser y value second if x's are equal
     return equal[0] ? (equal[1] ? false : lessThan[1]) : lessThan[0];
@@ -148,25 +263,59 @@ bool Curvepoint::operator>(const Curvepoint& b) const
     return !(*this < b);
 }
 
+bool Twistpoint::operator>(const Twistpoint& b) const
+{
+    return !(*this < b);
+}
+
 bool Curvepoint::operator<=(const Curvepoint& b) const
 {
     return (*this == b) || (*this < b);
 }
 
+bool Twistpoint::operator<=(const Twistpoint& b) const
+{
+    return (*this == b) || (*this < b);
+}
+
 bool Curvepoint::operator>=(const Curvepoint& b) const
 {
     return (*this == b) || !(*this < b);
 }
 
+bool Twistpoint::operator>=(const Twistpoint& b) const
+{
+    return (*this == b) || !(*this < b);
+}
+
 bool Curvepoint::operator!=(const Curvepoint& b) const
 {
     return !(*this == b);
 }
 
+bool Twistpoint::operator!=(const Twistpoint& b) const
+{
+    return !(*this == b);
+}
+
 void Curvepoint::make_affine()
 {
     if (!(fpe_isone(point->m_z) || fpe_iszero(point->m_z)))
         curvepoint_fp_makeaffine(point);
+
+    fpe_short_coeffred(point->m_x);
+    fpe_short_coeffred(point->m_y);
+    fpe_short_coeffred(point->m_z);
+}
+
+void Twistpoint::make_affine()
+{
+    if (!(fp2e_isone(point->m_z) || fp2e_iszero(point->m_z)))
+        twistpoint_fp2_makeaffine(point);
+
+    fp2e_short_coeffred(point->m_x);
+    fp2e_short_coeffred(point->m_y);
+    fp2e_short_coeffred(point->m_z);
 }
 
 std::ostream& operator<<(std::ostream& os, const Curvepoint& output)
@@ -182,6 +331,19 @@ std::ostream& operator<<(std::ostream& os, const Curvepoint& output)
     return os;
 }
 
+std::ostream& operator<<(std::ostream& os, const Twistpoint& output)
+{
+    Twistpoint affine_out = output;
+    affine_out.make_affine();
+    
+    if ((os.flags() & std::ios::hex) && fp2e_iszero(affine_out.point->m_z))
+        os << "Infinity";
+    else
+        os << Fp2e(affine_out.point->m_x) << Fp2e(affine_out.point->m_y) << Fp2e(affine_out.point->m_z);
+
+    return os;
+}
+
 std::istream& operator>>(std::istream& is, Curvepoint& input)
 {
     Fpe x, y, z;
@@ -194,6 +356,18 @@ std::istream& operator>>(std::istream& is, Curvepoint& input)
     return is;
 }
 
+std::istream& operator>>(std::istream& is, Twistpoint& input)
+{
+    Fp2e x, y, z;
+    is >> x >> y >> z;
+
+    fp2e_set(input.point->m_x, x.data);
+    fp2e_set(input.point->m_y, y.data);
+    fp2e_set(input.point->m_z, z.data);
+
+    return is;
+}
+
 size_t CurvepointHash::operator()(const Curvepoint& x) const
 {
     if (fpe_iszero(x.point->m_z))
@@ -215,3 +389,25 @@ size_t CurvepointHash::operator()(const Curvepoint& x) const
 
     return retval;
 }
+
+size_t TwistpointHash::operator()(const Twistpoint& x) const
+{
+    if (fp2e_iszero(x.point->m_z))
+    {
+        return 0;
+    }
+
+    size_t retval;
+    std::hash<double> hasher;
+
+    Twistpoint affine_x = x;
+    affine_x.make_affine();
+
+    for (int j = 0; j < 24; j++)
+    {
+        retval ^= hasher(affine_x.point->m_x->v[j]);
+        retval ^= hasher(affine_x.point->m_y->v[j]);
+    }
+
+    return retval;
+}