Browse Source

Adding functionality that wasn't already present, which was necessary for PRSONA

tristangurtler 3 years ago
parent
commit
ec1ba5c9d6
4 changed files with 39 additions and 3 deletions
  1. 1 0
      bgn2/inc/Curvepoint.hpp
  2. 6 1
      bgn2/inc/Scalar.hpp
  3. 15 0
      bgn2/src/Curvepoint.cpp
  4. 17 2
      bgn2/src/Scalar.cpp

+ 1 - 0
bgn2/inc/Curvepoint.hpp

@@ -18,6 +18,7 @@ class Curvepoint
         Curvepoint(const curvepoint_fp_t input);
 
         Curvepoint operator+(const Curvepoint& b) const;
+        Curvepoint operator-(const Curvepoint& b) const;
         Curvepoint operator*(const Scalar& mult) const;
 
         bool operator==(const Curvepoint& b) const;

+ 6 - 1
bgn2/inc/Scalar.hpp

@@ -21,6 +21,8 @@ class Scalar
         Scalar(const scalar_t& input);
         Scalar(mpz_class input);
 
+        static void init();
+
         void set(const scalar_t& input);
         void set(mpz_class input);
         void set_random();
@@ -34,6 +36,8 @@ class Scalar
         Scalar& operator--();
         Scalar operator--(int);
 
+        Scalar curveInverse() const;
+
         void mult(curvepoint_fp_t rop, const curvepoint_fp_t& op1) const;
         void mult(twistpoint_fp2_t rop, const twistpoint_fp2_t& op1) const;
         void mult(fp12e_t rop, const fp12e_t& op1) const;
@@ -73,7 +77,8 @@ class Scalar
          *  have arithmetic done on them prior to interacting with curvepoints,
          *  if you're calculating something like an exponentiation of products 
          *  of Scalars (or similar). */
-        static const mpz_class mpz_bn_p;
+        static mpz_class mpz_bn_p;
+        static mpz_class mpz_bn_n;
 
         mpz_class element;
 };

+ 15 - 0
bgn2/src/Curvepoint.cpp

@@ -12,6 +12,7 @@ Curvepoint::Curvepoint(const curvepoint_fp_t input)
 
 Curvepoint Curvepoint::operator+(const Curvepoint& b) const
 {
+    
     Curvepoint retval;
 
     if (*this == b)
@@ -22,6 +23,20 @@ Curvepoint Curvepoint::operator+(const Curvepoint& b) const
     return retval;
 }
 
+Curvepoint Curvepoint::operator-(const Curvepoint& b) const
+{
+    Curvepoint retval;
+
+    if (!(*this == b))
+    {
+        Curvepoint inverseB;
+        curvepoint_fp_neg(inverseB.point, b.point);
+        curvepoint_fp_add_vartime(retval.point, point, inverseB.point);
+    }
+
+    return retval;
+}
+
 Curvepoint Curvepoint::operator*(const Scalar& exp) const
 {
     Curvepoint retval;

+ 17 - 2
bgn2/src/Scalar.cpp

@@ -1,8 +1,9 @@
 #include "Scalar.hpp"
+#include <iostream>
 
 extern const scalar_t bn_n;
-
-const mpz_class Scalar::mpz_bn_p("8FB501E34AA387F9AA6FECB86184DC21EE5B88D120B5B59E185CAC6C5E089667", 16);
+mpz_class Scalar::mpz_bn_p = 0;
+mpz_class Scalar::mpz_bn_n = 0;
 
 Scalar::Scalar()
 {
@@ -19,6 +20,12 @@ Scalar::Scalar(mpz_class input)
     set(input);
 }
 
+void Scalar::init()
+{
+    mpz_bn_p = mpz_class("8FB501E34AA387F9AA6FECB86184DC21EE5B88D120B5B59E185CAC6C5E089667", 16);
+    mpz_bn_n = mpz_class("8FB501E34AA387F9AA6FECB86184DC212E8D8E12F82B39241A2EF45B57AC7261", 16);
+}
+
 void Scalar::set(const scalar_t& input)
 {
     std::stringstream bufferstream;
@@ -131,6 +138,14 @@ Scalar Scalar::operator--(int)
     return retval;
 }
 
+Scalar Scalar::curveInverse() const
+{
+    mpz_class temp;
+    mpz_invert(temp.get_mpz_t(), element.get_mpz_t(), mpz_bn_n.get_mpz_t());
+
+    return Scalar(temp);
+}
+
 void Scalar::mult(curvepoint_fp_t rop, const curvepoint_fp_t& op1) const
 {
     SecretScalar secret_element = to_scalar_t();