Browse Source

making scalars more natural to work with for prsona, as well as fixing a bug in serialization

tristangurtler 3 years ago
parent
commit
3f174bb533
7 changed files with 101 additions and 183 deletions
  1. 7 1
      bgn2/inc/Scalar.hpp
  2. 2 2
      bgn2/src/BGN.cpp
  3. 3 0
      bgn2/src/Bipoint.cpp
  4. 37 148
      bgn2/src/PrivateKey.cpp
  5. 48 30
      bgn2/src/Scalar.cpp
  6. 2 0
      bgn2/src/main.cpp
  7. 2 2
      bgn2/src/print_helpers.cpp

+ 7 - 1
bgn2/inc/Scalar.hpp

@@ -33,15 +33,21 @@ class Scalar
         Scalar operator-(const Scalar& b) const;
         Scalar operator*(const Scalar& b) const;
         Scalar operator/(const Scalar& b) const;
+        Scalar operator-() const;
         Scalar& operator++();
         Scalar operator++(int);
         Scalar& operator--();
         Scalar operator--(int);
 
+        Scalar fieldAdd(const Scalar& b) const;
+        Scalar fieldSub(const Scalar& b) const;
+        Scalar fieldMult(const Scalar& b) const;
+        Scalar fieldMultInverse() const;
+
         Scalar curveAdd(const Scalar& b) const;
         Scalar curveSub(const Scalar& b) const;
         Scalar curveMult(const Scalar& b) const;
-        Scalar curveInverse() const;
+        Scalar curveMultInverse() const;
 
         void mult(curvepoint_fp_t rop, const curvepoint_fp_t& op1) const;
         void mult(twistpoint_fp2_t rop, const twistpoint_fp2_t& op1) const;

+ 2 - 2
bgn2/src/BGN.cpp

@@ -15,7 +15,7 @@ BGN::BGN()
 
         if (a1 != Scalar(0))
         {
-            d1 = (b1 * c1 + Scalar(1)) / a1;
+            d1 = (b1.fieldMult(c1).fieldAdd(Scalar(1))).fieldMult(a1.fieldMultInverse());
             break;
         }
     }
@@ -28,7 +28,7 @@ BGN::BGN()
 
         if (a2 != Scalar(0))
         {
-            d2 = (b2 * c2 + Scalar(1)) / a2;
+            d2 = (b2.fieldMult(c2).fieldAdd(Scalar(1))).fieldMult(a2.fieldMultInverse());
             break;
         }   
     }

+ 3 - 0
bgn2/src/Bipoint.cpp

@@ -240,6 +240,9 @@ std::istream& operator>>(std::istream& is, CurveBipoint& input)
 	for (int i = 0; i < 2; i++)
 	{
 		is >> x >> y >> z;
+		// std::cout << "x: " << std::hex << x << std::dec << std::endl;
+		// std::cout << "y: " << std::hex << y << std::dec << std::endl;
+		// std::cout << "z: " << std::hex << z << std::dec << std::endl;
 		fpe_set(input[i]->m_x, x.data);
 		fpe_set(input[i]->m_y, y.data);
 		fpe_set(input[i]->m_z, z.data);

+ 37 - 148
bgn2/src/PrivateKey.cpp

@@ -153,28 +153,13 @@ CurveBipoint BGNPrivateKey::pi_1(const CurveBipoint& input) const
     CurveBipoint retval;
     curvepoint_fp_t temp0, temp1;
     
+    (-b1 * c1).mult(temp0, input[0]);
+    (a1 * c1).mult(temp1, input[1]);
+    curvepoint_fp_add_vartime(retval[0], temp0, temp1);
 
-    b1.mult(temp0, input[0]);
-    c1.mult(temp0, temp0);
-    curvepoint_fp_neg(temp0, temp0);
-
-    a1.mult(temp1, input[1]);
-    c1.mult(temp1, temp1);
-
-    curvepoint_fp_add_vartime(temp0, temp0, temp1);
-    curvepoint_fp_set(retval[0], temp0);
-
-
-    b1.mult(temp0, input[0]);
-    d1.mult(temp0, temp0);
-    curvepoint_fp_neg(temp0, temp0);
-
-    a1.mult(temp1, input[1]);
-    d1.mult(temp1, temp1);
-
-    curvepoint_fp_add_vartime(temp0, temp0, temp1);
-    curvepoint_fp_set(retval[1], temp0);
-
+    (-b1 * d1).mult(temp0, input[0]);
+    (a1 * d1).mult(temp1, input[1]);
+    curvepoint_fp_add_vartime(retval[1], temp0, temp1);
 
     return retval;
 }
@@ -184,28 +169,13 @@ TwistBipoint BGNPrivateKey::pi_2(const TwistBipoint& input) const
     TwistBipoint retval;
     twistpoint_fp2_t temp0, temp1;
     
+    (-b2 * c2).mult(temp0, input[0]);
+    (a2 * c2).mult(temp1, input[1]);
+    twistpoint_fp2_add_vartime(retval[0], temp0, temp1);
 
-    b2.mult(temp0, input[0]);
-    c2.mult(temp0, temp0);
-    twistpoint_fp2_neg(temp0, temp0);
-
-    a2.mult(temp1, input[1]);
-    c2.mult(temp1, temp1);
-
-    twistpoint_fp2_add_vartime(temp0, temp0, temp1);
-    twistpoint_fp2_set(retval[0], temp0);
-
-
-    b2.mult(temp0, input[0]);
-    d2.mult(temp0, temp0);
-    twistpoint_fp2_neg(temp0, temp0);
-
-    a2.mult(temp1, input[1]);
-    d2.mult(temp1, temp1);
-
-    twistpoint_fp2_add_vartime(temp0, temp0, temp1);
-    twistpoint_fp2_set(retval[1], temp0);
-
+    (-b2 * d2).mult(temp0, input[0]);
+    (a2 * d2).mult(temp1, input[1]);
+    twistpoint_fp2_add_vartime(retval[1], temp0, temp1);
 
     return retval;     
 }
@@ -213,120 +183,39 @@ TwistBipoint BGNPrivateKey::pi_2(const TwistBipoint& input) const
 Quadripoint BGNPrivateKey::pi_T(const Quadripoint& input) const
 {
     Quadripoint retval;
-    fp12e_t temp0, temp1, temp2, temp3;
-
-
-    b1.mult(temp0, input[0]);
-    c1.mult(temp0, temp0);
-    b2.mult(temp0, temp0);
-    c2.mult(temp0, temp0);
-    
-    b1.mult(temp1, input[1]);
-    c1.mult(temp1, temp1);
-    a2.mult(temp1, temp1);
-    c2.mult(temp1, temp1);
-    fp12e_invert(temp1, temp1);
-
-    a1.mult(temp2, input[2]);
-    c1.mult(temp2, temp2);
-    b2.mult(temp2, temp2);
-    c2.mult(temp2, temp2);
-    fp12e_invert(temp2, temp2);
-    
-    a1.mult(temp3, input[3]);
-    c1.mult(temp3, temp3);
-    a2.mult(temp3, temp3);
-    c2.mult(temp3, temp3);
+    fp12e_t temp0, temp1, temp2;
 
+    (b1 * c1 * b2 * c2).mult(temp0, input[0]);
+    (-b1 * c1 * a2 * c2).mult(temp1, input[1]);
     fp12e_mul(temp0, temp0, temp1);
-    fp12e_mul(temp1, temp2, temp3);
-    fp12e_mul(temp0, temp0, temp1);
-    fp12e_set(retval[0], temp0);
-
-
-    b1.mult(temp0, input[0]);
-    c1.mult(temp0, temp0);
-    b2.mult(temp0, temp0);
-    d2.mult(temp0, temp0);
-    
-    b1.mult(temp1, input[1]);
-    c1.mult(temp1, temp1);
-    a2.mult(temp1, temp1);
-    d2.mult(temp1, temp1);
-    fp12e_invert(temp1, temp1);
-
-    a1.mult(temp2, input[2]);
-    c1.mult(temp2, temp2);
-    b2.mult(temp2, temp2);
-    d2.mult(temp2, temp2);
-    fp12e_invert(temp2, temp2);
-    
-    a1.mult(temp3, input[3]);
-    c1.mult(temp3, temp3);
-    a2.mult(temp3, temp3);
-    d2.mult(temp3, temp3);
+    (-a1 * c1 * b2 * c2).mult(temp1, input[2]);
+    (a1 * c1 * a2 * c2).mult(temp2, input[3]);
+    fp12e_mul(temp1, temp1, temp2);
+    fp12e_mul(retval[0], temp0, temp1);
 
+    (b1 * c1 * b2 * d2).mult(temp0, input[0]);
+    (-b1 * c1 * a2 * d2).mult(temp1, input[1]);
     fp12e_mul(temp0, temp0, temp1);
-    fp12e_mul(temp1, temp2, temp3);
-    fp12e_mul(temp0, temp0, temp1);
-    fp12e_set(retval[1], temp0);
-    
-
-    b1.mult(temp0, input[0]);
-    d1.mult(temp0, temp0);
-    b2.mult(temp0, temp0);
-    c2.mult(temp0, temp0);
-    
-    b1.mult(temp1, input[1]);
-    d1.mult(temp1, temp1);
-    a2.mult(temp1, temp1);
-    c2.mult(temp1, temp1);
-    fp12e_invert(temp1, temp1);
-
-    a1.mult(temp2, input[2]);
-    d1.mult(temp2, temp2);
-    b2.mult(temp2, temp2);
-    c2.mult(temp2, temp2);
-    fp12e_invert(temp2, temp2);
+    (-a1 * c1 * b2 * d2).mult(temp1, input[2]);
+    (a1 * c1 * a2 * d2).mult(temp2, input[3]);
+    fp12e_mul(temp1, temp1, temp2);
+    fp12e_mul(retval[1], temp0, temp1);
     
-    a1.mult(temp3, input[3]);
-    d1.mult(temp3, temp3);
-    a2.mult(temp3, temp3);
-    c2.mult(temp3, temp3);
-    
-    fp12e_mul(temp0, temp0, temp1);
-    fp12e_mul(temp1, temp2, temp3);
+    (b1 * d1 * b2 * c2).mult(temp0, input[0]);
+    (-b1 * d1 * a2 * c2).mult(temp1, input[1]);
     fp12e_mul(temp0, temp0, temp1);
-    fp12e_set(retval[2], temp0);
-    
-
-    b1.mult(temp0, input[0]);
-    d1.mult(temp0, temp0);
-    b2.mult(temp0, temp0);
-    d2.mult(temp0, temp0);
+    (-a1 * d1 * b2 * c2).mult(temp1, input[2]);
+    (a1 * d1 * a2 * c2).mult(temp2, input[3]);
+    fp12e_mul(temp1, temp1, temp2);
+    fp12e_mul(retval[2], temp0, temp1);
     
-    b1.mult(temp1, input[1]);
-    d1.mult(temp1, temp1);
-    a2.mult(temp1, temp1);
-    d2.mult(temp1, temp1);
-    fp12e_invert(temp1, temp1);
-
-    a1.mult(temp2, input[2]);
-    d1.mult(temp2, temp2);
-    b2.mult(temp2, temp2);
-    d2.mult(temp2, temp2);
-    fp12e_invert(temp2, temp2);
-    
-    a1.mult(temp3, input[3]);
-    d1.mult(temp3, temp3);
-    a2.mult(temp3, temp3);
-    d2.mult(temp3, temp3);
-
+    (b1 * d1 * b2 * d2).mult(temp0, input[0]);
+    (-b1 * d1 * a2 * d2).mult(temp1, input[1]);
     fp12e_mul(temp0, temp0, temp1);
-    fp12e_mul(temp1, temp2, temp3);
-    fp12e_mul(temp0, temp0, temp1);
-    fp12e_set(retval[3], temp0);
-    
+    (-a1 * d1 * b2 * d2).mult(temp1, input[2]);
+    (a1 * d1 * a2 * d2).mult(temp2, input[3]);
+    fp12e_mul(temp1, temp1, temp2);
+    fp12e_mul(retval[3], temp0, temp1);
 
     return retval;
 }

+ 48 - 30
bgn2/src/Scalar.cpp

@@ -71,46 +71,32 @@ mpz_class Scalar::toInt() const
 
 Scalar Scalar::operator+(const Scalar& b) const
 {
-    mpz_class temp = element + b.element;
-
-    mpz_mod(temp.get_mpz_t(), temp.get_mpz_t(), mpz_bn_p.get_mpz_t());
-
-    return Scalar(temp);
+    return this->curveAdd(b);
 }
 
 Scalar Scalar::operator-(const Scalar& b) const
 {
-    mpz_class temp = element - b.element;
-
-    mpz_mod(temp.get_mpz_t(), temp.get_mpz_t(), mpz_bn_p.get_mpz_t());
-
-    return Scalar(temp);
+    return this->curveSub(b);
 }
 
 Scalar Scalar::operator*(const Scalar& b) const
 {
-    mpz_class temp = element * b.element;
-
-    mpz_mod(temp.get_mpz_t(), temp.get_mpz_t(), mpz_bn_p.get_mpz_t());
-
-    return Scalar(temp);
+    return this->curveMult(b);
 }
 
 Scalar Scalar::operator/(const Scalar& b) const
 {
-    mpz_class temp;
-    mpz_invert(temp.get_mpz_t(), b.element.get_mpz_t(), mpz_bn_p.get_mpz_t());
-
-    temp *= element;
-    mpz_mod(temp.get_mpz_t(), temp.get_mpz_t(), mpz_bn_p.get_mpz_t());
+    return this->curveMult(b.curveMultInverse());
+}
 
-    return Scalar(temp);
+Scalar Scalar::operator-() const
+{
+    return Scalar(0).curveSub(*this);
 }
 
 Scalar& Scalar::operator++()
 {
-    element++;
-    mpz_mod(element.get_mpz_t(), element.get_mpz_t(), mpz_bn_p.get_mpz_t());
+    *this = this->curveAdd(Scalar(1));
 
     return *this;
 }
@@ -119,16 +105,14 @@ Scalar Scalar::operator++(int)
 {
     Scalar retval = *this;
     
-    element++;
-    mpz_mod(element.get_mpz_t(), element.get_mpz_t(), mpz_bn_p.get_mpz_t());
+    *this = this->curveAdd(Scalar(1));
 
     return retval;
 }
 
 Scalar& Scalar::operator--()
 {
-    element--;
-    mpz_mod(element.get_mpz_t(), element.get_mpz_t(), mpz_bn_p.get_mpz_t());
+    *this = this->curveSub(Scalar(1));
 
     return *this;
 }
@@ -137,12 +121,46 @@ Scalar Scalar::operator--(int)
 {
     Scalar retval = *this;
     
-    element--;
-    mpz_mod(element.get_mpz_t(), element.get_mpz_t(), mpz_bn_p.get_mpz_t());
+    *this = this->curveSub(Scalar(1));
 
     return retval;
 }
 
+Scalar Scalar::fieldAdd(const Scalar& b) const
+{
+    mpz_class temp = element + b.element;
+
+    mpz_mod(temp.get_mpz_t(), temp.get_mpz_t(), mpz_bn_p.get_mpz_t());
+
+    return Scalar(temp);
+}
+
+Scalar Scalar::fieldSub(const Scalar& b) const
+{
+    mpz_class temp = element - b.element;
+
+    mpz_mod(temp.get_mpz_t(), temp.get_mpz_t(), mpz_bn_p.get_mpz_t());
+
+    return Scalar(temp);
+}
+
+Scalar Scalar::fieldMult(const Scalar& b) const
+{
+    mpz_class temp = element * b.element;
+
+    mpz_mod(temp.get_mpz_t(), temp.get_mpz_t(), mpz_bn_p.get_mpz_t());
+
+    return Scalar(temp);
+}
+
+Scalar Scalar::fieldMultInverse() const
+{
+    mpz_class temp;
+    mpz_invert(temp.get_mpz_t(), element.get_mpz_t(), mpz_bn_p.get_mpz_t());
+
+    return Scalar(temp);
+}
+
 Scalar Scalar::curveAdd(const Scalar& b) const
 {
     mpz_class temp = element + b.element;
@@ -170,7 +188,7 @@ Scalar Scalar::curveMult(const Scalar& b) const
     return Scalar(temp);
 }
 
-Scalar Scalar::curveInverse() const
+Scalar Scalar::curveMultInverse() const
 {
     mpz_class temp;
     mpz_invert(temp.get_mpz_t(), element.get_mpz_t(), mpz_bn_n.get_mpz_t());

+ 2 - 0
bgn2/src/main.cpp

@@ -691,6 +691,8 @@ double testQuadRerandomizeSpeed(default_random_engine& generator)
 
 int main(int argc, char *argv[])
 {
+    Scalar::init();
+
     string seedStr("default");
     if (argc > 1)
         seedStr = argv[1];

+ 2 - 2
bgn2/src/print_helpers.cpp

@@ -98,7 +98,7 @@ std::istream& operator>>(std::istream& is, Fp2e& input)
 
 std::ostream& operator<<(std::ostream& os, const Fpe& output)
 {
-    if (os.flags() | std::ios::hex)
+    if (os.flags() & std::ios::hex)
     {
         for (int i = 0; i < 12; i++)
             hex_double(os, output.data->v[i]);
@@ -114,7 +114,7 @@ std::ostream& operator<<(std::ostream& os, const Fpe& output)
 
 std::istream& operator>>(std::istream& is, Fpe& input)
 {
-    if (is.flags() | std::ios::hex)
+    if (is.flags() & std::ios::hex)
     {
         for (int i = 0; i < 12; i++)
             hex_double(is, input.data->v[i]);