Browse Source

changes to make things a little more sane

tristangurtler 3 years ago
parent
commit
a47148d30a

+ 13 - 17
bgn2/inc/Bipoint.hpp

@@ -1,17 +1,9 @@
 #ifndef __BIPOINT_HPP
 #define __BIPOINT_HPP
 
-#include "PublicKey.hpp"
-#include "PrivateKey.hpp"
+#include "Scalar.hpp"
 
-#include "mydouble.h" 
-extern "C" {
-#include "fpe.h"
-}
 #include "curvepoint_fp.h"
-extern "C" {
-#include "fp2e.h"	
-}
 #include "twistpoint_fp2.h"
 
 /* It doesn't actually make sense to instantiate this generally;
@@ -28,17 +20,19 @@ class Bipoint<curvepoint_fp_t>
 		Bipoint();
 		Bipoint(curvepoint_fp_t p1, curvepoint_fp_t p2);
 
-		void receive_encryption(const scalar_t& cleartext, const PublicKey& public_key);
-
 		curvepoint_fp_t& operator[](int n);
 		const curvepoint_fp_t& operator[](int n) const;
 
 		Bipoint<curvepoint_fp_t> operator+(const Bipoint<curvepoint_fp_t>& b) const;
-		Bipoint<curvepoint_fp_t> operator*(const scalar_t& mult) const;
+		Bipoint<curvepoint_fp_t> operator*(const Scalar& mult) const;
 
 		bool operator==(const Bipoint<curvepoint_fp_t>& b) const;
+		bool operator!=(const Bipoint<curvepoint_fp_t>& b) const;
+
+		// "double" is a type, so you can't just name the function that I don't think, but that's all this is
+		Bipoint<curvepoint_fp_t> mult_by_2() const;	
 
-		Bipoint<curvepoint_fp_t> multBy2() const;
+		void make_affine();
 	
 	private:
 		curvepoint_fp_t point[2];	
@@ -51,17 +45,19 @@ class Bipoint<twistpoint_fp2_t>
 		Bipoint(); 
 		Bipoint(twistpoint_fp2_t p1, twistpoint_fp2_t p2);
 
-		receive_encryption(const scalar_t& cleartext, const PublicKey& public_key);
-
 		twistpoint_fp2_t& operator[](int n);
 		const twistpoint_fp2_t& operator[](int n) const;
 
 		Bipoint<twistpoint_fp2_t> operator+(const Bipoint<twistpoint_fp2_t>& b) const;
-		Bipoint<twistpoint_fp2_t> operator*(const scalar_t& mult) const;
+		Bipoint<twistpoint_fp2_t> operator*(const Scalar& mult) const;
 
 		bool operator==(const Bipoint<twistpoint_fp2_t>& b) const;
+		bool operator!=(const Bipoint<twistpoint_fp2_t>& b) const;
+
+		// "double" is a type, so you can't just name the function that I don't think, but that's all this is
+		Bipoint<twistpoint_fp2_t> mult_by_2() const;
 
-		Bipoint<twistpoint_fp2_t> multBy2() const;
+		void make_affine();
 	
 	private:
 		twistpoint_fp2_t point[2];	

+ 19 - 11
bgn2/inc/PrivateKey.hpp

@@ -1,27 +1,35 @@
 #ifndef __PRIVATEKEY_HPP
 #define __PRIVATEKEY_HPP
 
-#include "Fp.hpp"
+#include <unordered_map>
+
+#include "Scalar.hpp"
 #include "Bipoint.hpp"
 #include "Quadripoint.hpp"
-#include "fp12e.h"
+#include "pairing.hpp"
 
 class PrivateKey
 {
     public:
         PrivateKey() = default; 
-        PrivateKey(const Fp& a, const Fp& b, const Fp& c, const Fp& d, const Fp& e, const Fp& f, const Fp& g, const Fp& h);
+        PrivateKey(const Scalar& a1, const Scalar& b1, const Scalar& c1, const Scalar& d1, const Scalar& a2, const Scalar& b2, const Scalar& c2, const Scalar& d2);
 
-        void set(const Fp& a, const Fp& b, const Fp& c, const Fp& d, const Fp& e, const Fp& f, const Fp& g, const Fp& h);
-        void set(string which, const Fp& input);
-        Fp get(string which) const;
-        
-        Bipoint<curvepoint_fp_t>  pi_1(const Bipoint<curvepoint_fp_t>& input) const;
-        Bipoint<twistpoint_fp2_t> pi_2(const Bipoint<twistpoint_fp2_t>& input) const;
-        Quadripoint pi_T(const Quadripoint& input) const;
+        void set(const Scalar& a1, const Scalar& b1, const Scalar& c1, const Scalar& d1, const Scalar& a2, const Scalar& b2, const Scalar& c2, const Scalar& d2);
+
+        Scalar decrypt(const Bipoint<curvepoint_fp_t>& ciphertext) const;
+        Scalar decrypt(const Bipoint<twistpoint_fp2_t>& ciphertext) const;
+        Scalar decrypt(const Quadripoint& ciphertext) const;
     
     private:
-        Fp i1, j1, k1, l1, i2, j2, k2, l2;
+        Scalar a1, b1, c1, d1, a2, b2, c2, d2;
+
+        Bipoint<curvepoint_fp_t>  pi_1_curvegen;
+        Bipoint<twistpoint_fp2_t> pi_2_curvegen;
+        Quadripoint               pi_T_curvegen;
+
+        Bipoint<curvepoint_fp_t>  pi_1(const Bipoint<curvepoint_fp_t> & input) const;
+        Bipoint<twistpoint_fp2_t> pi_2(const Bipoint<twistpoint_fp2_t>& input) const;
+        Quadripoint               pi_T(const Quadripoint              & input) const;
 };
 
 #endif

+ 13 - 8
bgn2/inc/PublicKey.hpp

@@ -2,25 +2,30 @@
 #define __PUBLICKEY_HPP
 
 #include "Bipoint.hpp"
+#include "Scalar.hpp"
 
 class PublicKey
 {
 	public:
 		PublicKey() = default;
-		PublicKey(const Bipoint<curvepoint_fp_t>& a, const Bipoint<twistpoint_fp2_t>& b, const Bipoint<curvepoint_fp_t>& c, const Bipoint<twistpoint_fp2_t>& d);
+		PublicKey(const Bipoint<curvepoint_fp_t>& g, const Bipoint<twistpoint_fp2_t>& h, const Bipoint<curvepoint_fp_t>& g1, const Bipoint<twistpoint_fp2_t>& h1);
 		
-		void set(const Bipoint<curvepoint_fp_t>& a, const Bipoint<twistpoint_fp2_t>& b, const Bipoint<curvepoint_fp_t>& c, const Bipoint<twistpoint_fp2_t>& d);
+		void set(const Bipoint<curvepoint_fp_t>& g, const Bipoint<twistpoint_fp2_t>& h, const Bipoint<curvepoint_fp_t>& g1, const Bipoint<twistpoint_fp2_t>& h1);
+		
+		void encrypt(Bipoint<curvepoint_fp_t>& G_element, const Scalar& cleartext) const;
+		void encrypt(Bipoint<twistpoint_fp2_t>& H_element, const Scalar& cleartext) const;
+		void encrypt(Bipoint<curvepoint_fp_t>& G_element, Bipoint<twistpoint_fp2_t>& H_element, const Scalar& cleartext) const;
 		
 		Bipoint<curvepoint_fp_t> get_bipoint_curvegen() const;
 		Bipoint<twistpoint_fp2_t> get_bipoint_twistgen() const;	
-		Bipoint<curvepoint_fp_t> get_bipoint_curve_groupelt() const;
-		Bipoint<twistpoint_fp2_t> get_bipoint_twist_groupelt() const;	
+		Bipoint<curvepoint_fp_t> get_bipoint_curve_subgroup_gen() const;
+		Bipoint<twistpoint_fp2_t> get_bipoint_twist_subgroup_gen() const;	
 		
 	private:
-		Bipoint<curvepoint_fp_t> bipoint_curvegen; // subgroup_gen(i1g, j1g)
-		Bipoint<twistpoint_fp2_t> bipoint_twistgen; // subgroup_gen(i2h, j2h)
-		Bipoint<curvepoint_fp_t> bipoint_curve_groupelt; // u
-		Bipoint<twistpoint_fp2_t> bipoint_twist_groupelt; // v
+		Bipoint<curvepoint_fp_t> bipoint_curvegen; // g
+		Bipoint<twistpoint_fp2_t> bipoint_twistgen; // h
+		Bipoint<curvepoint_fp_t> bipoint_curve_subgroup_gen; // (g^(a1), g^(b1))
+		Bipoint<twistpoint_fp2_t> bipoint_twist_subgroup_gen; // (h^(a2), h^(b2))
 };
 
 #endif

+ 7 - 10
bgn2/inc/Quadripoint.hpp

@@ -1,13 +1,9 @@
 #ifndef __QUADRIPOINT_HPP
 #define __QUADRIPOINT_HPP
 
-#include "mydouble.h" 
-extern "C" {
-#include "fpe.h"
-#include "fp2e.h"	
-#include "fp6e.h"	
-#include "fp12e.h"	
-}
+#include "Scalar.hpp"
+
+#include "fp12e.h"
 
 class Quadripoint
 {
@@ -18,10 +14,11 @@ class Quadripoint
 		fp12e_t& operator[](int n);
 		const fp12e_t& operator[](int n) const;
 
-		Quadripoint operator*(const Quadripoint& b) const;
-		Quadripoint operator^(const scalar_t& exp) const;
-		Quadripoint operator++(int);
+		Quadripoint operator+(const Quadripoint& b) const;
+		Quadripoint operator*(const Scalar& exp) const;
+		
 		bool operator==(const Quadripoint& b) const;
+		bool operator!=(const Quadripoint& b) const;
 
 		Quadripoint square() const;
 	

+ 48 - 34
bgn2/inc/Scalar.hpp

@@ -2,53 +2,67 @@
 #define __FP_HPP
 
 #include <ostream>
+#include <stdlib.h>
 #include <sstream>
-#include <random>
-#include <gmp.h> 
 #include <gmpxx.h>
 
-#include "mydouble.h"
-extern "C" {
-#include "fpe.h"
-#include "scalar.h"
-}
+#include "Bipoint.hpp"
+#include "Quadripoint.hpp"
 
-class Fp
+class Scalar
 {
-    // Necessary for private and public keys to have access to private members
-    friend class PrivateKey;
-    friend class PublicKey;
-
     public:
-        Fp();
-        Fp(const fpe_t& input);
-        Fp(int input);
+        Scalar();
+        Scalar(const scalar_t& input);
+        Scalar(mpz_class input);
 
-        void set(const fpe_t& input);
-        void set(int input);
+        void set(const scalar_t& input);
+        void set(mpz_class input);
         void set_random();
 
-        Fp operator-() const;
-        Fp operator+(const Fp& b) const;
-        Fp operator-(const Fp& b) const;
-        Fp operator*(const Fp& b) const;
-        Fp operator/(const Fp& b) const;
+        Scalar operator-() const;
+        Scalar operator+(const Scalar& b) const;
+        Scalar operator-(const Scalar& b) const;
+        Scalar operator*(const Scalar& b) const;
+        Scalar operator/(const Scalar& b) const;
+        Scalar& operator++();
+        Scalar operator++(int);
+        Scalar& operator--();
+        Scalar operator--(int);
 
-        bool operator==(const Fp& b) const;
-        bool operator!=(const Fp & b) const;
-        bool Fp::is_zero() const;
-        
-        // Problem: thanks to the magic of weird typedefs, scalar_t is actually an array, which complicates returning it
-        // Solution: make the return value a reference
-        const scalar_t& to_scalar() const;
+        curvepoint_fp_t operator*(const curvepoint_fp_t& b) const;
+        twistpoint_fp2_t operator*(const twistpoint_fp2_t& b) const;
+        fp12e_t operator*(const fp12e_t& b) const;
+        Bipoint<curvepoint_fp_t> operator*(const Bipoint<curvepoint_fp_t>& b) const;
+        Bipoint<twistpoint_fp2_t> operator*(const Bipoint<twistpoint_fp2_t>& b) const;
+        Quadripoint operator*(const Quadripoint& b) const;
+
+        bool operator==(const Scalar& b) const;
+        bool operator!=(const Scalar& b) const;
         
-        friend std::ostream& operator<<(std::ostream& os, const Fp& output);
+        friend std::ostream& operator<<(std::ostream& os, const Scalar& output);
     
     private:
-        unsigned long long Fp::mpz2ull(const mpz_class& n) const;
-        fpe_t element;
-        scalar_t scalar;
-        bool no_change;
+        class SecretScalar
+        {
+            SecretScalar();
+            SecretScalar(const Scalar& input);
+            SecretScalar(mpz_class input);
+
+            // Problem: thanks to the magic of weird typedefs, scalar_t is actually an array, which complicates returning it
+            // Solution: make the return value a reference
+            // This feels bad, I know, but it will only be used in places where the variable remains in scope for the duration of usage
+            const scalar_t& expose() const;
+
+            private:
+                void set(mpz_class input);
+                scalar_t element;
+        };
+        SecretScalar to_scalar_t() const;
+
+        static const mpz_class mpz_bn_n("8FB501E34AA387F9AA6FECB86184DC212E8D8E12F82B39241A2EF45B57AC7261", 16);
+
+        mpz_class element;
 };
 
 #endif

+ 0 - 13
bgn2/inc/dechiffrementL2.hpp

@@ -1,13 +0,0 @@
-#ifndef __DECHIFFREMENTL2_HPP
-
-#define __DECHIFFREMENTL2_HPP
-
-#include "BitEvalL2.hpp"
-#include "keygen.hpp"
-#include "pairing.hpp"
-
-void dechiffrementL2(F2& bit_dechiffre, BitEvalL2 bit_chiffre, PrivateKey private_key);
-void dechiffrementL2(F2& bit_dechiffre, Quadripoint quadripoint, PrivateKey private_key); // routine pour les évalués de niveau 3 et 4
-
-
-#endif /* __DECHIFFREMENTL2_HPP */

+ 0 - 15
bgn2/inc/decryption.hpp

@@ -1,15 +0,0 @@
-#ifndef __DECRYPTION_HPP
-#define __DECRYPTION_HPP
-
-#include <unordered_map>
-
-#include "Bipoint.hpp"
-#include "Quadripoint.hpp"
-#include "PrivateKey.hpp"
-#include "pairing.hpp"
-
-int decrypt(const Bipoint<curvepoint_fp_t>& ciphertext, const PrivateKey& private_key);
-int decrypt(const Bipoint<twistpoint_fp2_t>& ciphertext, const PrivateKey& private_key);
-int decrypt(const Quadripoint& ciphertext, const PrivateKey& private_key);
-
-#endif /* __DECRYPTION_HPP */

+ 1 - 3
bgn2/inc/keygen.hpp

@@ -1,12 +1,10 @@
 #ifndef __KEYGEN_HPP
 #define __KEYGEN_HPP
 
-#include "Fp.hpp"
+#include "Scalar.hpp"
 #include "PrivateKey.hpp"
 #include "PublicKey.hpp"
 
-using namespace std;
-
 void keygen(PublicKey& public_key, PrivateKey& private_key);
 
 #endif

+ 62 - 62
bgn2/src/Bipoint.cpp

@@ -1,7 +1,5 @@
 #include "Bipoint.hpp"
 
-extern const scalar_t bn_n;
-
 Bipoint<curvepoint_fp_t>::Bipoint()
 {
 	curvepoint_fp_setneutral(point[0]);
@@ -26,46 +24,6 @@ Bipoint<twistpoint_fp2_t>::Bipoint(twistpoint_fp2_t p1, twistpoint_fp2_t p2)
 	twistpoint_fp2_set(point[1], p2);
 }
 
-void Bipoint<curvepoint_fp_t>::receive_encryption(const scalar_t& cleartext, const PublicKey& public_key)
-{
-	scalar_t lambda;
-	scalar_setrandom(lambda, bn_n);
-
-	Bipoint<curvepoint_fp_t> cleartext_as_element, random_mask;
-
-	cleartext_as_element = public_key.get_bipoint_curvegen().scalarmult_vartime(cleartext);
-	cleartext_as_element.makeaffine();
-
-	random_mask = public_key.get_bipoint_curve_groupelt().scalarmult_vartime(lambda);
-	random_mask.makeaffine();
-
-	ciphertext = cleartext_as_element + random_mask;
-	ciphertext.makeaffine();
-
-	point[0] = ciphertext.point[0];
-	point[1] = ciphertext.point[1];
-}
-
-void Bipoint<twistpoint_fp2_t>::receive_encryption(const scalar_t& cleartext, const PublicKey& public_key)
-{
-	scalar_t lambda;
-	scalar_setrandom(lambda, bn_n);
-
-	Bipoint<twistpoint_fp2_t> cleartext_as_element, random_mask;
-
-	cleartext_as_element = public_key.get_bipoint_twistgen().scalarmult_vartime(cleartext);
-	cleartext_as_element.makeaffine();
-
-	random_mask = public_key.get_bipoint_twist_groupelt().scalarmult_vartime(lambda);
-	random_mask.makeaffine();
-
-	ciphertext = cleartext_as_element + random_mask;
-	ciphertext.makeaffine();
-
-	point[0] = ciphertext.point[0];
-	point[1] = ciphertext.point[1];
-}
-
 curvepoint_fp_t& Bipoint<curvepoint_fp_t>::operator[](int n)
 {
 	return point[n];
@@ -106,51 +64,81 @@ Bipoint<twistpoint_fp2_t> Bipoint<twistpoint_fp2_t>::operator+(const Bipoint<twi
 	return retval;
 }
 
-Bipoint<curvepoint_fp_t> Bipoint<curvepoint_fp_t>::operator*(const scalar_t& mult) const
+Bipoint<curvepoint_fp_t> Bipoint<curvepoint_fp_t>::operator*(const Scalar& mult) const
 {
 	Bipoint<curvepoint_fp_t> retval;
 
-	curvepoint_fp_scalarmult_vartime(retval[0], point[0], mult);
-	curvepoint_fp_scalarmult_vartime(retval[1], point[1], mult);
+	retval[0] = mult * point[0];
+	retval[1] = mult * point[1];
 
 	return retval;
 }
 
-Bipoint<twistpoint_fp2_t> Bipoint<twistpoint_fp2_t>::operator*(const scalar_t& mult) const
+Bipoint<twistpoint_fp2_t> Bipoint<twistpoint_fp2_t>::operator*(const Scalar& mult) const
 {
 	Bipoint<twistpoint_fp2_t> retval;
 
-	twistpoint_fp2_scalarmult_vartime(retval[0], point[0], mult);
-	twistpoint_fp2_scalarmult_vartime(retval[1], point[1], mult);
+	retval[0] = mult * point[0];
+	retval[1] = mult * point[1];
 
 	return retval;
 }
 
 bool Bipoint<curvepoint_fp_t>::operator==(const Bipoint<curvepoint_fp_t>& b) const
 {
-	bool retval = fpe_iseq(point[0]->m_x, b[0]->m_x);
-	retval &&= fpe_iseq(point[0]->m_y, b[0]->m_y); 
-	retval &&= fpe_iseq(point[0]->m_z, b[0]->m_z);
-	retval &&= fpe_iseq(point[1]->m_x, b[1]->m_x); 
-	retval &&= fpe_iseq(point[1]->m_y, b[1]->m_y);
-	retval &&= fpe_iseq(point[1]->m_z, b[1]->m_z); 
+	bool retval;
+
+	fpe_t point0_x1z2, point0_y1z2, point0_x2z1, point0_y2z1, point1_x1z2, point1_y1z2, point1_x2z1, point1_y2z1;
+	fpe_mul(point0_x1z2, point[0]->m_x, b[0]->m_z);
+	fpe_mul(point0_y1z2, point[0]->m_y, b[0]->m_z);
+	fpe_mul(point0_x2z1, point[0]->m_z, b[0]->m_x);
+	fpe_mul(point0_y2z1, point[0]->m_z, b[0]->m_y);
+	fpe_mul(point1_x1z2, point[1]->m_x, b[1]->m_z);
+	fpe_mul(point1_y1z2, point[1]->m_y, b[1]->m_z);
+	fpe_mul(point1_x2z1, point[1]->m_z, b[1]->m_x);
+	fpe_mul(point1_y2z1, point[1]->m_z, b[1]->m_y);
+
+	retval   = fpe_iseq(point0_x1z2, point0_x2z1);
+	retval &&= fpe_iseq(point0_y1z2, point0_y2z1); 
+	retval &&= fpe_iseq(point1_x1z2, point1_x2z1);
+	retval &&= fpe_iseq(point1_y1z2, point1_y2z1); 
 
 	return retval;
 }
 
 bool Bipoint<twistpoint_fp2_t>::operator==(const Bipoint<twistpoint_fp2_t>& b) const
 {
-	bool retval = fp2e_iseq(point[0]->m_x, b[0]->m_x);
-	retval &&= fp2e_iseq(point[0]->m_y, b[0]->m_y); 
-	retval &&= fp2e_iseq(point[0]->m_z, b[0]->m_z);
-	retval &&= fp2e_iseq(point[1]->m_x, b[1]->m_x); 
-	retval &&= fp2e_iseq(point[1]->m_y, b[1]->m_y);
-	retval &&= fp2e_iseq(point[1]->m_z, b[1]->m_z); 
+	bool retval;
+
+	fp2e_t point0_x1z2, point0_y1z2, point0_x2z1, point0_y2z1, point1_x1z2, point1_y1z2, point1_x2z1, point1_y2z1;
+	fp2e_mul(point0_x1z2, point[0]->m_x, b[0]->m_z);
+	fp2e_mul(point0_y1z2, point[0]->m_y, b[0]->m_z);
+	fp2e_mul(point0_x2z1, point[0]->m_z, b[0]->m_x);
+	fp2e_mul(point0_y2z1, point[0]->m_z, b[0]->m_y);
+	fp2e_mul(point1_x1z2, point[1]->m_x, b[1]->m_z);
+	fp2e_mul(point1_y1z2, point[1]->m_y, b[1]->m_z);
+	fp2e_mul(point1_x2z1, point[1]->m_z, b[1]->m_x);
+	fp2e_mul(point1_y2z1, point[1]->m_z, b[1]->m_y);
+
+	retval   = fp2e_iseq(point0_x1z2, point0_x2z1);
+	retval &&= fp2e_iseq(point0_y1z2, point0_y2z1); 
+	retval &&= fp2e_iseq(point1_x1z2, point1_x2z1);
+	retval &&= fp2e_iseq(point1_y1z2, point1_y2z1);
 
 	return retval;
 }
 
-Bipoint<curvepoint_fp_t> Bipoint<curvepoint_fp_t>::multBy2() const
+bool Bipoint<curvepoint_fp_t>::operator!=(const Bipoint<curvepoint_fp_t>& b) const
+{
+	return !(*this == b);
+}
+
+bool Bipoint<twistpoint_fp2_t>::operator!=(const Bipoint<twistpoint_fp2_t>& b) const
+{
+	return !(*this == b);
+}
+
+Bipoint<curvepoint_fp_t> Bipoint<curvepoint_fp_t>::mult_by_2() const
 {
 	Bipoint<curvepoint_fp_t> retval;
 
@@ -160,7 +148,7 @@ Bipoint<curvepoint_fp_t> Bipoint<curvepoint_fp_t>::multBy2() const
 	return retval;
 }
 
-Bipoint<twistpoint_fp2_t> Bipoint<twistpoint_fp2_t>::multBy2() const
+Bipoint<twistpoint_fp2_t> Bipoint<twistpoint_fp2_t>::mult_by_2() const
 {
 	Bipoint<twistpoint_fp2_t> retval;
 
@@ -169,3 +157,15 @@ Bipoint<twistpoint_fp2_t> Bipoint<twistpoint_fp2_t>::multBy2() const
 
 	return retval;
 }
+
+void Bipoint<curvepoint_fp_t>::make_affine()
+{
+	curvepoint_fp_makeaffine(point[0]);
+	curvepoint_fp_makeaffine(point[1]);
+}
+
+void Bipoint<twistpoint_fp2_t>::make_affine()
+{
+	twistpoint_fp2_makeaffine(point[0]);
+	twistpoint_fp2_makeaffine(point[1]);
+}

+ 155 - 246
bgn2/src/PrivateKey.cpp

@@ -2,149 +2,122 @@
 
 extern const curvepoint_fp_t bn_curvegen;
 
-PrivateKey::PrivateKey(const Fp& a, const Fp& b, const Fp& c, const Fp& d, const Fp& e, const Fp& f, const Fp& g, const Fp& h)
+PrivateKey::PrivateKey(const PublicKey& pub_key, const Scalar& a1, const Scalar& b1, const Scalar& c1, const Scalar& d1, const Scalar& a2, const Scalar& b2, const Scalar& c2, const Scalar& d2)
 {
-    set(a, b, c, d, e, f, g, h);
+    set(pub_key, a1, b1, c1, d1, a2, b2, c2, d2);
 }
 
-void PrivateKey::set(const Fp& a, const Fp& b, const Fp& c, const Fp& d, const Fp& e, const Fp& f, const Fp& g, const Fp& h)
+void PrivateKey::set(const PublicKey& pub_key, const Scalar& a1, const Scalar& b1, const Scalar& c1, const Scalar& d1, const Scalar& a2, const Scalar& b2, const Scalar& c2, const Scalar& d2)
 {
-    i1 = a;
-    j1 = b;
-    k1 = c;
-    l1 = d;
-
-    i2 = e;
-    j2 = f;
-    k2 = g;
-    l2 = h;
+    this->a1 = a1;
+    this->b1 = b1;
+    this->c1 = c1;
+    this->d1 = d1;
+
+    this->a2 = a2;
+    this->b2 = b2;
+    this->c2 = c2;
+    this->d2 = d2;
+
+    this->pi_1_curvegen = pi_1(pub_key.get_bipoint_curvegen());
+    this->pi_2_curvegen = pi_2(pub_key.get_bipoint_twistgen());
+    this->pi_T_curvegen = pi_T(pairing(pub_key.get_bipoint_curvegen(), pub_key.get_bipoint_twistgen()));
 }
 
-void PrivateKey::set(string which, const Fp& input);
+Scalar PrivateKey::decrypt(const Bipoint<curvepoint_fp_t>& ciphertext)
 {
-    if (which.length() != 2)
-        return;
+    static std::unordered_map<Bipoint<curvepoint_fp_t>, Scalar> memoizer;
+    static Scalar max_checked = Scalar(0);
 
-    bool var_is_1 = false;
-    switch (which[1])
-    {
-        case '1':
-            var_is_1 = true;
-            break;
-
-        case '2':
-            break;
+    Bipoint<curvepoint_fp_t> pi_1_ciphertext = pi_1(ciphertext); 
 
-        default:
-            return;
+    auto lookup = memoizer.find(pi_1_ciphertext);
+    if (lookup != memoizer.end())
+    {
+        return lookup->second;
     }
 
-    switch (which[0])
+    Bipoint<curvepoint_fp_t> i = pi_1_curvegen * max_checked;
+    do
     {
-        case 'i':
-            if (var_is_1)
-                i1 = input;
-            else i2 = input;
-            break;
-
-        case 'j':
-            if (var_is_1)
-                j1 = input;
-            else j2 = input;
-            break;
-
-        case 'k':
-            if (var_is_1)
-                k1 = input;
-            else k2 = input;
-            break;
-
-        case 'l':
-            if (var_is_1)
-                l1 = input;
-            else l2 = input;
-            break;
-
-        default:
-            break;
-    } 
+        memoizer[pi_1_ciphertext] = max_checked++;
+        i = i + pi_1_curvegen;
+    } while (i != pi_1_ciphertext);
+
+    return max_checked - Scalar(1);
 }
 
-Fp PrivateKey::get(string which) const
+Scalar PrivateKey::decrypt(const Bipoint<twistpoint_fp2_t>& ciphertext)
 {
-    if (which.length() != 2)
-        return Fp();
-
-    bool var_is_1 = false;
-    switch (which[1])
-    {
-        case '1':
-            var_is_1 = true;
-            break;
+    static std::unordered_map<Bipoint<twistpoint_fp2_t>, Scalar> memoizer;
+    static Scalar max_checked = Scalar(0);
 
-        case '2':
-            break;
+    Bipoint<twistpoint_fp2_t> pi_2_ciphertext = pi_2(ciphertext);
 
-        default:
-            return Fp();
+    auto lookup = memoizer.find(pi_2_ciphertext);
+    if (lookup != memoizer.end())
+    {
+        return lookup->second;
     }
 
-    switch (which[0])
+    Bipoint<twistpoint_fp2_t> i = pi_2_twistgen * max_checked;
+    do
     {
-        case 'i':
-            return var_is_1 ? i1 : i2;
+        memoizer[pi_2_ciphertext] = max_checked++;
+        i = i + pi_2_twistgen;
+    } while (i != pi_2_ciphertext);
+
+    return max_checked - Scalar(1);
+}
+
 
-        case 'j':
-            return var_is_1 ? j1 : j2;
+void PrivateKey::decrypt(const Quadripoint& ciphertext)
+{
+    static std::unordered_map<Quadripoint, Scalar> memoizer;
+    static Scalar max_checked = Scalar(0);
 
-        case 'k':
-            return var_is_1 ? k1 : k2;
+    Quadripoint pi_T_ciphertext = pi_T(ciphertext); 
 
-        case 'l':
-            return var_is_1 ? l1 : l2;
+    auto lookup = memoizer.find(pi_T_ciphertext);
+    if (lookup != memoizer.end())
+    {
+        return lookup->second;
+    }
+
+    Quadripoint i = pi_T_pairgen * max_checked;
+    do
+    {
+        memoizer[pi_2_ciphertext] = max_checked++;
+        i = i + pi_T_pairgen;
+    } while (i != pi_T_ciphertext);
 
-        default:
-            return Fp();
-    } 
+    return max_checked - Scalar(1);
 }
 
 Bipoint<curvepoint_fp_t> PrivateKey::pi_1(const Bipoint<curvepoint_fp_t>& input) const
 {
     Bipoint<curvepoint_fp_t> retval;
-    curvepoint_fp_t temp1, temp2;
+    curvepoint_fp_t temp0, temp1;
     
-    const scalar_t i1_s = i1.to_scalar();
-    const scalar_t j1_s = j1.to_scalar();
-    const scalar_t k1_s = k1.to_scalar();
-    const scalar_t l1_s = l1.to_scalar();
 
+    temp0 = b1 * (c1 * input[0]);
+    curvepoint_fp_neg(temp0, temp0);
 
-    curvepoint_fp_scalarmult_vartime(temp1, input[0], j1_s);
-    curvepoint_fp_scalarmult_vartime(temp1, temp1, k1_s);
-    curvepoint_fp_neg(temp1, temp1);
-    curvepoint_fp_makeaffine(temp1);
+    temp1 = a1 * (c1 * input[1]);
+    curvepoint_fp_add_vartime(temp0, temp0, temp1);
+    
+    curvepoint_fp_makeaffine(temp0);
+    curvepoint_fp_set(retval[0], temp0);
 
-    curvepoint_fp_scalarmult_vartime(temp2, input[1], i1_s);
-    curvepoint_fp_scalarmult_vartime(temp2, temp2, k1_s);
-    curvepoint_fp_makeaffine(temp2);
 
-    curvepoint_fp_add_vartime(temp1, temp1, temp2);
-    curvepoint_fp_makeaffine(temp1);
-    curvepoint_fp_set(retval[0], temp1);
-    
+    temp0 = b1 * (d1 * input[0]);
+    curvepoint_fp_neg(temp0, temp0);
 
-    curvepoint_fp_scalarmult_vartime(temp1, input[0], j1_s);
-    curvepoint_fp_scalarmult_vartime(temp1, temp1, l1_s);
-    curvepoint_fp_neg(temp1, temp1);
-    curvepoint_fp_makeaffine(temp1);
-    
-    curvepoint_fp_scalarmult_vartime(temp2, input[1], i1_s);
-    curvepoint_fp_scalarmult_vartime(temp2, temp2, l1_s);
-    curvepoint_fp_makeaffine(temp2);
+    temp1 = a1 * (d1 * input[1]);
+    curvepoint_fp_add_vartime(temp0, temp0, temp1);
     
-    curvepoint_fp_add_vartime(temp1, temp1, temp2);
-    curvepoint_fp_makeaffine(temp1);
-    curvepoint_fp_set(retval[1], temp1); 
+    curvepoint_fp_makeaffine(temp0);
+    curvepoint_fp_set(retval[1], temp0);
 
 
     return retval;
@@ -153,40 +126,27 @@ Bipoint<curvepoint_fp_t> PrivateKey::pi_1(const Bipoint<curvepoint_fp_t>& input)
 Bipoint<twistpoint_fp2_t> PrivateKey::pi_2(const Bipoint<twistpoint_fp2_t>& input) const
 {
     Bipoint<twistpoint_fp2_t> retval;
-    twistpoint_fp2_t temp1, temp2;
-
-    const scalar_t i2_s = i2.to_scalar();
-    const scalar_t j2_s = j2.to_scalar();
-    const scalar_t k2_s = k2.to_scalar();
-    const scalar_t l2_s = l2.to_scalar();
+    twistpoint_fp2_t temp0, temp1;
     
 
-    twistpoint_fp2_scalarmult_vartime(temp1, input[0], j2_s);
-    twistpoint_fp2_scalarmult_vartime(temp1, temp1, k2_s);
-    twistpoint_fp2_neg(temp1,temp1);
-    twistpoint_fp2_makeaffine(temp1);
-    
-    twistpoint_fp2_scalarmult_vartime(temp2, input[1], i2_s);
-    twistpoint_fp2_scalarmult_vartime(temp2, temp2, k2_s);
-    twistpoint_fp2_makeaffine(temp2);
+    temp0 = b2 * (c2 * input[0]);
+    twistpoint_fp2_neg(temp0, temp0);
 
-    twistpoint_fp2_add_vartime(temp1, temp1, temp2);
-    twistpoint_fp2_makeaffine(temp1);
-    twistpoint_fp2_set(retval[0], temp1);
+    temp1 = a2 * (c2 * input[1]);
+    twistpoint_fp2_add_vartime(temp0, temp0, temp1);
     
+    twistpoint_fp2_makeaffine(temp0);
+    twistpoint_fp2_set(retval[0], temp0);
 
-    twistpoint_fp2_scalarmult_vartime(temp1, input[0], j2_s);
-    twistpoint_fp2_scalarmult_vartime(temp1, temp1, l2_s);
-    twistpoint_fp2_neg(temp1, temp1);      
-    twistpoint_fp2_makeaffine(temp1);
-    
-    twistpoint_fp2_scalarmult_vartime(temp2, input[1], i2_s);
-    twistpoint_fp2_scalarmult_vartime(temp2, temp2, l2_s);
-    twistpoint_fp2_makeaffine(temp2);
 
-    twistpoint_fp2_add_vartime(temp1, temp1, temp2);
-    twistpoint_fp2_makeaffine(temp1);
-    twistpoint_fp2_set(retval[1], temp1);
+    temp0 = b2 * (d2 * input[0]);
+    twistpoint_fp2_neg(temp0, temp0);
+
+    temp1 = a2 * (d2 * input[1]);
+    twistpoint_fp2_add_vartime(temp0, temp0, temp1);
+    
+    twistpoint_fp2_makeaffine(temp0);
+    twistpoint_fp2_set(retval[1], temp0);
 
 
     return retval;     
@@ -195,128 +155,77 @@ Bipoint<twistpoint_fp2_t> PrivateKey::pi_2(const Bipoint<twistpoint_fp2_t>& inpu
 Quadripoint PrivateKey::pi_T(const Quadripoint& input) const
 {
     Quadripoint retval;
-    fp12e_t temp1, temp2, temp3, temp4;
-
-    const scalar_t i1_s = i1.to_scalar();
-    const scalar_t j1_s = j1.to_scalar();
-    const scalar_t k1_s = k1.to_scalar();
-    const scalar_t l1_s = l1.to_scalar();
-    const scalar_t i2_s = i2.to_scalar();
-    const scalar_t j2_s = j2.to_scalar();
-    const scalar_t k2_s = k2.to_scalar();
-    const scalar_t l2_s = l2.to_scalar();
-
-
-    fp12e_pow_vartime(temp1, input[0], j1_s);
-    fp12e_pow_vartime(temp1, temp1, k1_s);
-    fp12e_pow_vartime(temp1, temp1, j2_s); 
-    fp12e_pow_vartime(temp1, temp1, k2_s);
+    fp12e_t temp0, temp1, temp2, temp3;
+
+
+    temp0 = c2 * (b2 * (c1 * (b1 * input[0])));
     
-    fp12e_invert(temp2, input[1]);
-    fp12e_pow_vartime(temp2, temp2, j1_s);
-    fp12e_pow_vartime(temp2, temp2, k1_s);
-    fp12e_pow_vartime(temp2, temp2, i2_s);
-    fp12e_pow_vartime(temp2, temp2, k2_s);
-
-    fp12e_invert(temp3, input[2]);
-    fp12e_pow_vartime(temp3, temp3, i1_s);
-    fp12e_pow_vartime(temp3, temp3, k1_s);
-    fp12e_pow_vartime(temp3, temp3, j2_s);
-    fp12e_pow_vartime(temp3, temp3, k2_s);
-
-    fp12e_pow_vartime(temp4, input[3], i1_s);
-    fp12e_pow_vartime(temp4, temp4, k1_s);
-    fp12e_pow_vartime(temp4, temp4, i2_s);
-    fp12e_pow_vartime(temp4, temp4, k2_s);
-
-    fp12e_mul(temp1, temp1, temp2);
-    fp12e_mul(temp2, temp3, temp4);
-    fp12e_mul(temp1, temp1, temp2);
-    retval.set(temp1, 0);
-
-
-    fp12e_pow_vartime(temp1, input[0], j1_s);
-    fp12e_pow_vartime(temp1, temp1, k1_s);
-    fp12e_pow_vartime(temp1, temp1, j2_s);
-    fp12e_pow_vartime(temp1, temp1, l2_s);
-
-    fp12e_pow_vartime(temp2, input[1], j1_s);
+    fp12e_invert(temp1, input[1]);
+    temp1 = c2 * (a2 * (c1 * (b1 * temp1)));
+    
+    fp12e_invert(temp2, input[2]);
+    temp2 = c2 * (b2 * (c1 * (a1 * temp2)));
+    
+    temp3 = c2 * (a2 * (c1 * (a1 * input[3])));
+
+    fp12e_mul(temp0, temp0, temp1);
+    fp12e_mul(temp1, temp2, temp3);
+    fp12e_mul(temp0, temp0, temp1);
+    fp12e_set(retval[0], temp0);
+
+
+    temp0 = d2 * (b2 * (c1 * (b1 * input[0])));
+    
+    temp1 = b1 * input[0];
+    fp12e_invert(temp1, temp1);
+    temp1 = d2 * (a2 * (c1 * temp1));
+
+    temp2 = b2 * (c1 * (a1 * input[2]));
     fp12e_invert(temp2, temp2);
-    fp12e_pow_vartime(temp2, temp2, k1_s);
-    fp12e_pow_vartime(temp2, temp2, i2_s);
-    fp12e_pow_vartime(temp2, temp2, l2_s);
-
-    fp12e_pow_vartime(temp3, input[2], i1_s);
-    fp12e_pow_vartime(temp3, temp3, k1_s);
-    fp12e_pow_vartime(temp3, temp3, j2_s);
-    fp12e_invert(temp3, temp3);
-    fp12e_pow_vartime(temp3, temp3, l2_s);
-
-    fp12e_pow_vartime(temp4, input[3], i1_s);
-    fp12e_pow_vartime(temp4, temp4, k1_s);
-    fp12e_pow_vartime(temp4, temp4, i2_s);
-    fp12e_pow_vartime(temp4, temp4, l2_s);
-
-    fp12e_mul(temp1, temp1, temp2);
-    fp12e_mul(temp2, temp3, temp4);
-    fp12e_mul(temp1, temp1, temp2);
-    retval.set(temp1, 1);
+    temp2 = d2 * temp2;
+
+    temp3 = d2 * (a2 * (c1 * (a1 * input[3])));
+
+    fp12e_mul(temp0, temp0, temp1);
+    fp12e_mul(temp1, temp2, temp3);
+    fp12e_mul(temp0, temp0, temp1);
+    fp12e_set(retval[1], temp0);
     
 
-    fp12e_pow_vartime(temp1, input[0], j1_s);
-    fp12e_pow_vartime(temp1, temp1, l1_s);
-    fp12e_pow_vartime(temp1, temp1, j2_s);
-    fp12e_pow_vartime(temp1, temp1, k2_s);
+    temp0 = c2 * (b2 * (d1 * (b1 * input[0])));
+    
+    temp1 = b1 * input[0];
+    fp12e_invert(temp1, temp1);
+    temp1 = c2 * (a2 * (d1 * temp1));
 
-    fp12e_pow_vartime(temp2, input[1], j1_s);
+    temp2 = b2 * (d1 * (a1 * input[2]));
     fp12e_invert(temp2, temp2);
-    fp12e_pow_vartime(temp2, temp2, l1_s);
-    fp12e_pow_vartime(temp2, temp2, i2_s);
-    fp12e_pow_vartime(temp2, temp2, k2_s);
-
-    fp12e_pow_vartime(temp3, input[2], i1_s);
-    fp12e_pow_vartime(temp3, temp3, l1_s);
-    fp12e_pow_vartime(temp3, temp3, j2_s);
-    fp12e_invert(temp3, temp3);
-    fp12e_pow_vartime(temp3, temp3, k2_s);
-
-    fp12e_pow_vartime(temp4, input[3], i1_s);
-    fp12e_pow_vartime(temp4, temp4, l1_s);
-    fp12e_pow_vartime(temp4, temp4, i2_s);
-    fp12e_pow_vartime(temp4, temp4, k2_s);
+    temp2 = c2 * temp2;
+
+    temp3 = c2 * (a2 * (d1 * (a1 * input[3])));
     
-    fp12e_mul(temp1, temp1, temp2);
-    fp12e_mul(temp2, temp3, temp4);
-    fp12e_mul(temp1, temp1, temp2);
-    retval.set(temp1, 2);
+    fp12e_mul(temp0, temp0, temp1);
+    fp12e_mul(temp1, temp2, temp3);
+    fp12e_mul(temp0, temp0, temp1);
+    fp12e_set(retval[2], temp0);
     
 
-    fp12e_pow_vartime(temp1, input[0], j1_s);
-    fp12e_pow_vartime(temp1, temp1, l1_s);
-    fp12e_pow_vartime(temp1, temp1, j2_s);
-    fp12e_pow_vartime(temp1, temp1, l2_s);
+    temp0 = d2 * (b2 * (d1 * (b1 * input[0])));
+    
+    temp1 = b1 * input[0];
+    fp12e_invert(temp1, temp1);
+    temp1 = d2 * (a2 * (d1 * temp1));
 
-    fp12e_pow_vartime(temp2, input[1], j1_s);
+    temp2 = b2 * (d1 * (a1 * input[2]));
     fp12e_invert(temp2, temp2);
-    fp12e_pow_vartime(temp2, temp2, l1_s);
-    fp12e_pow_vartime(temp2, temp2, i2_s);
-    fp12e_pow_vartime(temp2, temp2, l2_s);
-
-    fp12e_pow_vartime(temp3, input[2], i1_s);
-    fp12e_pow_vartime(temp3, temp3, l1_s);
-    fp12e_pow_vartime(temp3, temp3, j2_s);
-    fp12e_invert(temp3, temp3);
-    fp12e_pow_vartime(temp3, temp3, l2_s);
-
-    fp12e_pow_vartime(temp4, input[3], i1_s);
-    fp12e_pow_vartime(temp4, temp4, l1_s);
-    fp12e_pow_vartime(temp4, temp4, i2_s);
-    fp12e_pow_vartime(temp4, temp4, l2_s);
-
-    fp12e_mul(temp1, temp1, temp2);
-    fp12e_mul(temp2, temp3, temp4);
-    fp12e_mul(temp1, temp1, temp2);
-    retval.set(temp1, 3);
+    temp2 = d2 * temp2;
+
+    temp3 = d2 * (a2 * (d1 * (a1 * input[3])));
+
+    fp12e_mul(temp0, temp0, temp1);
+    fp12e_mul(temp1, temp2, temp3);
+    fp12e_mul(temp0, temp0, temp1);
+    fp12e_set(retval[3], temp0);
     
 
     return retval;

+ 41 - 16
bgn2/src/PublicKey.cpp

@@ -1,22 +1,47 @@
 #include "PublicKey.hpp"
 
-PublicKey(const Bipoint<curvepoint_fp_t>& a, const Bipoint<twistpoint_fp2_t>& b, const Bipoint<curvepoint_fp_t>& c, const Bipoint<twistpoint_fp2_t>& d)
+PublicKey(const Bipoint<curvepoint_fp_t>& g, const Bipoint<twistpoint_fp2_t>& h, const Bipoint<curvepoint_fp_t>& g1, const Bipoint<twistpoint_fp2_t>& h1)
 {
-    set(a, b, c, d);
+    set(g, h, g1, h1);
 }
 
-void set(const Bipoint<curvepoint_fp_t>& a, const Bipoint<twistpoint_fp2_t>& b, const Bipoint<curvepoint_fp_t>& c, const Bipoint<twistpoint_fp2_t>& d)
+void set(const Bipoint<curvepoint_fp_t>& g, const Bipoint<twistpoint_fp2_t>& h, const Bipoint<curvepoint_fp_t>& g1, const Bipoint<twistpoint_fp2_t>& h1)
 {
-    a.makeaffine();
-    b.makeaffine();
-    c.makeaffine();
-    d.makeaffine();
-
-    bipoint_curvegen = a;
-    bipoint_twistgen = b;
+    bipoint_curvegen = g;
+    bipoint_twistgen = h;
     
-    bipoint_curve_groupelt = c;
-    bipoint_twist_groupelt = d;
+    bipoint_curve_subgroup_gen = g1;
+    bipoint_twist_subgroup_gen = h1;
+}
+
+void PublicKey::encrypt(Bipoint<curvepoint_fp_t>& G_element, const Scalar& cleartext) const
+{
+    Scalar lambda;
+    lambda.set_random();
+
+    Bipoint<curvepoint_fp_t> cleartext_as_element, random_mask;
+    cleartext_as_element = get_bipoint_curvegen() * cleartext;
+    random_mask = get_bipoint_curve_subgroup_gen() * lambda;
+
+    G_element = cleartext_as_element + random_mask;
+}
+
+void PublicKey::encrypt(Bipoint<twistpoint_fp2_t>& H_element, const Scalar& cleartext) const
+{
+    Scalar lambda;
+    lambda.set_random();
+
+    Bipoint<twistpoint_fp2_t> cleartext_as_element, random_mask;
+    cleartext_as_element = get_bipoint_twistgen() * cleartext;
+    random_mask = get_bipoint_twist_subgroup_gen() * lambda;
+
+    H_element = cleartext_as_element + random_mask;
+}
+
+void PublicKey::encrypt(Bipoint<curvepoint_fp_t>& G_element, Bipoint<twistpoint_fp2_t>& H_element, const Scalar& cleartext) const
+{
+    encrypt(G_element, cleartext);
+    encrypt(H_element, cleartext);
 }
 
 Bipoint<curvepoint_fp_t> PublicKey::get_bipoint_curvegen() const
@@ -29,12 +54,12 @@ Bipoint<twistpoint_fp2_t> PublicKey::get_bipoint_twistgen() const
     return bipoint_twistgen;
 }
 
-Bipoint<curvepoint_fp_t> PublicKey::get_bipoint_curve_groupelt() const
+Bipoint<curvepoint_fp_t> PublicKey::get_bipoint_curve_subgroup_gen() const
 {
-    return bipoint_curve_groupelt;
+    return bipoint_curve_subgroup_gen;
 }
 
-Bipoint<twistpoint_fp2_t> PublicKey::get_bipoint_twist_groupelt() const
+Bipoint<twistpoint_fp2_t> PublicKey::get_bipoint_twist_subgroup_gen() const
 {
-    return bipoint_twist_groupelt;
+    return bipoint_twist_subgroup_gen;
 }

+ 11 - 6
bgn2/src/Quadripoint.cpp

@@ -26,7 +26,7 @@ const fp12e_t& Quadripoint::operator[](int n) const
 	return point[n];
 }
 
-Quadripoint Quadripoint::operator*(const Quadripoint& b) const
+Quadripoint Quadripoint::operator+(const Quadripoint& b) const
 {
 	Quadripoint retval;
 
@@ -38,14 +38,14 @@ Quadripoint Quadripoint::operator*(const Quadripoint& b) const
 	return retval;
 }
 
-Quadripoint Quadripoint::operator^(const scalar_t& exp) const
+Quadripoint Quadripoint::operator*(const Scalar& exp) const
 {
 	Quadripoint retval;
 	
-	fp12e_pow_vartime(retval[0], point[0], exp);
-	fp12e_pow_vartime(retval[1], point[1], exp);
-	fp12e_pow_vartime(retval[2], point[2], exp);
-	fp12e_pow_vartime(retval[3], point[3], exp);
+	retval[0] = exp * point[0];
+	retval[1] = exp * point[1];
+	retval[2] = exp * point[2];
+	retval[3] = exp * point[3];
 
 	return retval;	
 }
@@ -60,6 +60,11 @@ bool Quadripoint::operator==(const Quadripoint& b) const
 	return retval;
 }
 
+bool Quadripoint::operator!=(const Quadripoint& b) const
+{
+	return !(*this == b);
+}
+
 Quadripoint Quadripoint::square() const
 {
 	Quadripoint retval;

+ 142 - 146
bgn2/src/Scalar.cpp

@@ -1,216 +1,212 @@
-#include "Fp.hpp"
+#include "Scalar.hpp"
 
-extern const double bn_v;
+extern const scalar_t bn_n;
 
-Fp::Fp()
+Scalar::Scalar()
 {
-    fpe_setzero(element);
-    no_change = false;
+    element = 0;
 }
 
-Fp::Fp(const fpe_t & input)
+Scalar::Scalar(const scalar_t& input)
 {
     set(input);
-    no_change = false;
 }
 
-Fp::Fp(int input)
+Scalar::Scalar(mpz_class input)
 {
-    set(input);
-    no_change = false;
+    element = input;
 }
 
-void Fp::set(const fpe_t & fpe)
+void Scalar::set(const scalar_t& input)
 {
-    fpe_set(element, fpe);
-    no_change = false;
+    std::stringstream buffer;
+    std::string temp;
+    buffer << std::hex << input[3] << input[2] << input[1] << input[0];
+    buffer >> temp;
+    
+    element.set_str(temp, 16);
 }
 
-void Fp::set(int input)
+void Scalar::set(mpz_class input)
 {
-    mydouble[12] coefficient_matrix;
-    for (int i = 0; i < 12; i++)
-    {
-        switch (i)
-        {
-            case 0:
-                coefficient_matrix[i] = ((mydouble) input);
-                break;
-
-            default:
-                coefficient_matrix[i] = 0.;             
-        }
-    }
+    element = input;
+}
 
-    fpe_set_doublearray(element, coefficient_matrix);
-    if (input > ((int) bn_v) * 3)
-    {
-        fpe_short_coeffred(element)
-    }
+void Scalar::set_random()
+{
+    scalar_t temp;
+    
+    scalar_setrandom(temp, bn_n);
 
-    no_change = false;
+    set(temp);
 }
 
-void Fp::set_random()
+Scalar Scalar::operator+(const Scalar& b) const
 {
-    // c.f. https://www.cryptojedi.org/papers/dclxvi-20100714.pdf for these maxes
-    const int MAX_A = ((int) bn_v) * 3;
-    const int MAX_B = ((int) bn_v) / 2 + 1;
-    
-    std::random_device generator;
-    std::uniform_int_distribution<int> distribution_a(-MAX_A, MAX_A);
-    std::uniform_int_distribution<int> distribution_b(-MAX_B, MAX_B);
+    mpz_class temp = element + b.element;
 
-    mydouble[12] coefficient_matrix;
-    for (int i = 0; i < 12; i++)
-    {
-        switch (i)
-        {
-            case 0:
-            case 6:
-                coefficient_matrix[i] = ((mydouble) distribution_a(generator));
-                break;
-
-            default:
-                coefficient_matrix[i] = ((mydouble) distribution_b(generator));             
-        }
-    }
+    mpz_mod(temp.get_mpz_t(), temp.get_mpz_t(), mpz_bn_n.get_mpz_t());
+
+    return Scalar(temp);
+}
+
+Scalar Scalar::operator-(const Scalar& b) const
+{
+    mpz_class temp = element - b.element;
 
-    fpe_set_doublearray(element, coefficient_matrix);
+    mpz_mod(temp.get_mpz_t(), temp.get_mpz_t(), mpz_bn_n.get_mpz_t());
 
-    no_change = false;
+    return Scalar(temp);
 }
 
-Fp Fp::operator-() const
+Scalar Scalar::operator*(const Scalar& b) const
 {
-    fpe_t temp;
+    mpz_class temp = element * b.element;
+
+    mpz_mod(temp.get_mpz_t(), temp.get_mpz_t(), mpz_bn_n.get_mpz_t());
 
-    fpe_neg(temp, element);
-    return Fp(temp);
+    return Scalar(temp);
 }
 
-Fp Fp::operator+(const Fp & b) const
+Scalar Scalar::operator/(const Scalar& b) const
 {
-    fpe_t temp;
-    fpe_add(temp, element, b.element);
+    mpz_class temp;
+    mpz_invert(temp.get_mpz_t(), b.element.get_mpz_t(), mpz_bn_n.get_mpz_t());
 
-    return Fp(temp);
+    temp *= element;
+    mpz_mod(temp.get_mpz_t(), temp.get_mpz_t(), mpz_bn_n.get_mpz_t());
+
+    return Scalar(temp);
 }
 
-Fp Fp::operator-(const Fp & b) const
+Scalar& Scalar::operator++()
 {
-    fpe_t temp;
-    fpe_sub(temp, element, b.element);
+    element++;
+    mpz_mod(element.get_mpz_t(), element.get_mpz_t(), mpz_bn_n.get_mpz_t());
 
-    return Fp(temp);
+    return *this;
 }
 
-Fp Fp::operator*(const Fp & b) const
+Scalar Scalar::operator++(int)
 {
-    fpe_t temp;
-    fpe_mul(temp, element, b.element);
+    Scalar retval = *this;
+    
+    element++;
+    mpz_mod(element.get_mpz_t(), element.get_mpz_t(), mpz_bn_n.get_mpz_t());
 
-    return Fp(temp);
+    return retval;
 }
 
-Fp Fp::operator/(const Fp & b) const
+Scalar& Scalar::operator--()
 {
-    fpe_t temp;
-    fpe_invert(temp, b.element);
-    fpe_mul(temp, element, temp);
+    element--;
+    mpz_mod(element.get_mpz_t(), element.get_mpz_t(), mpz_bn_n.get_mpz_t());
 
-    return Fp(temp);
+    return *this;
 }
 
-bool Fp::operator==(const Fp & b) const
+Scalar Scalar::operator--(int)
 {
-    return fpe_iseq(element, b.element) == 1;
+    Scalar retval = *this;
+    
+    element--;
+    mpz_mod(element.get_mpz_t(), element.get_mpz_t(), mpz_bn_n.get_mpz_t());
+
+    return retval;
 }
 
-bool Fp::operator!=(const Fp & b) const
+curvepoint_fp_t Scalar::operator*(const curvepoint_fp_t& b) const
 {
-    return fpe_iseq(element, b.element) == 0;
+    curvepoint_fp_t retval;
+
+    curvepoint_fp_scalarmult_vartime(retval, b, element.to_scalar_t().expose());
+
+    return retval;
 }
 
-bool Fp::is_zero() const
+twistpoint_fp2_t Scalar::operator*(const twistpoint_fp2_t& b) const
 {
-    return fpe_iszero(element) == 1;
+    twistpoint_fp2_t retval;
+
+    twistpoint_fp2_scalarmult_vartime(retval, b, element.to_scalar_t().expose());
+
+    return retval;
 }
-    
-scalar_t& Fp::to_scalar() const
+
+fp12e_t Scalar::operator*(const fp12e_t& b) const
 {
-    if (no_change)
-        return scalar;
+    fp12e_t retval;
 
-    mpz_class poly_at_one = 1.;
-    mpz_class increment_factor = bn_v * 6;
-    for (int i = 0; i < 12; i++)
-    {
-        switch (i)
-        {
-            case 0:
-                poly_at_one = todouble(element->v[0]);
-                break;
-
-            case 1:
-            case 2:
-            case 3:
-            case 4:
-            case 5:
-            case 6:
-                poly_at_one += increment_factor * todouble(element->v[i]);
-                increment_factor *= bn_v;
-                break;
-
-            case 7:
-                increment_factor *= 6.;
-                poly_at_one += increment_factor * todouble(element->v[i]);
-                increment_factor *= bn_v;
-                break;
-
-            default:
-                poly_at_one += increment_factor * todouble(element->v[i]);
-                increment_factor *= bn_v;
-                break;
-        }
-    }
-    
-    mpz_class bn_u = 1;
-    for (int i = 0; i < 3; i++)
-        bn_u *= bn_v;
-        
-    mpz_class bn_p;
-    bn_p = 36*bn_u*bn_u*bn_u*bn_u + 36*bn_u*bn_u*bn_u + 24*bn_u*bn_u + 6*bn_u + 1;
+    fp12e_pow_vartime(retval, b, element.to_scalar_t().expose());
 
-    mpz_class field_element = poly_at_one % bn_p;
+    return retval;
+}
 
-    mpz_class mask = 0xffffffffffffffff; // 8 octets 64 bits
-    scalar[0] = mpz2ull( field_element         & mask);
-    scalar[1] = mpz2ull((field_element >> 64)  & mask);
-    scalar[2] = mpz2ull((field_element >> 128) & mask);
-    scalar[3] = mpz2ull((field_element >> 192) & mask);
+Bipoint<curvepoint_fp_t> Scalar::operator*(const Bipoint<curvepoint_fp_t>& b) const
+{
+    return b * *this;
+}
 
-    no_change = true;
+Bipoint<twistpoint_fp2_t> Scalar::operator*(const Bipoint<twistpoint_fp2_t>& b) const
+{
+    return b * *this;
+}
 
-    return scalar;
+Quadripoint Scalar::operator*(const Quadripoint& b) const
+{
+    return b * *this;
 }
 
-std::ostream& operator<<(std::ostream& os, const Fp& output)
+bool Scalar::operator==(const Scalar& b) const
 {
-    if (!output.no_change)
-        output.to_scalar();
+    return element == b.element;
+}
 
-    os << output.scalar[3] << output.scalar[2] << output.scalar[1] << output.scalar[0];
-    return os;
+bool Scalar::operator!=(const Scalar& b) const
+{
+    return element != b.element;
+}
+
+Scalar::SecretScalar::SecretScalar()
+{
+    element = {0,0,0,0};
+}
+
+Scalar::SecretScalar::SecretScalar(const Scalar& input)
+{
+    set(input.element);
 }
 
-unsigned long long Fp::mpz2ull(const mpz_class& n) const
+Scalar::SecretScalar::SecretScalar(mpz_class input)
 {
-    stringstream str;
-    unsigned long long retval;
+    set(input);
+}
+
+const scalar_t& Scalar::SecretScalar::expose() const
+{
+    return element;
+}
+
+void Scalar::SecretScalar::set(mpz_class input)
+{
+    std::stringstream buffer;
+    char temp[17];
+    buffer << std::setfill('0') << std::setw(64) << input.get_string(16);
+
+    for (int i = 3; i >= 0; i--)
+    {
+        buffer.get(temp, 17);
+        element[i] = strtoull(temp, NULL, 16);
+    }
+}
     
-    str << n;
-    str >> retval;
+Scalar::SecretScalar Scalar::to_scalar_t() const
+{
+    return SecretScalar(element);
+}
 
-    return retval;
+std::ostream& operator<<(std::ostream& os, const Scalar& output)
+{
+    os << output.element;
+    return os;
 }

+ 0 - 129
bgn2/src/dechiffrementL2.cpp

@@ -1,129 +0,0 @@
-#include "dechiffrementL2.hpp"
-
-void dechiffrementL2(F2& bit_dechiffre, BitEvalL2 bit_chiffre, PrivateKey private_key) //valeur du log b1b2+a1b2+a2b1+s
-{
-	signature;
-	//ecris(log=b1b2+a1b2+a2b1+s);
-	Quadripoint quadripoint_pi_T_chiffre;
-	quadripoint_pi_T_chiffre = private_key.pi_T(bit_chiffre.get_quadripoint());  //pi_T(beta)
-	//ecris(affichage de beta);
-	//bit_chiffre.get_quadripoint().print();
-	//ecris(affichage de pi_T(beta));
-	//quadripoint_pi_T_chiffre.print(0);
-	Quadripoint base_log, pow_log; 
-	base_log = private_key.pi_T(pairing(public_key.get_bipoint_curve_groupelt(),public_key.get_bipoint_twist_groupelt()));  //pi_T(e(u,v))
-	//ecris(affichage de pi_T(e(u,v)));
-	//base_log.print(0);
-	//ecris(affichage de private_key);
-	//private_key.print();
-	
-	if (fp12e_isone(quadripoint_pi_T_chiffre[0]) && fp12e_isone(quadripoint_pi_T_chiffre[1]) && fp12e_isone(quadripoint_pi_T_chiffre[2]) && fp12e_isone(quadripoint_pi_T_chiffre[3]))
-	{
-		//cout << "cas log=0" << endl;
-		bit_dechiffre = bit_chiffre.get_bit_masque();
-		return;
-	}
-	if (quadripoint_pi_T_chiffre == base_log)
-	{
-		//cout << "cas log=1" << endl;
-		bit_dechiffre = (bit_chiffre.get_bit_masque()+1)%2;
-		return;
-	}
-	pow_log=base_log*base_log;
-	//ecris(affichage de pi_T(e(u,v))^2);
-	//pow_log.print();	
-	if (quadripoint_pi_T_chiffre == pow_log)
-	{
-		//cout << "cas log=2" << endl;
-		bit_dechiffre = bit_chiffre.get_bit_masque();
-		return;
-	}
-	pow_log=pow_log*base_log;
-	//ecris(affichage de pi_T(e(u,v))^3);
-	//pow_log.print();	
-	if (quadripoint_pi_T_chiffre == pow_log)
-	{
-		//cout << "cas log=3" << endl;
-		bit_dechiffre = (bit_chiffre.get_bit_masque()+1)%2;
-		return;
-	}	
-	pow_log=pow_log*base_log;
-	//ecris(affichage de pi_T(e(u,v))^4);
-	//pow_log.print();
-	if (quadripoint_pi_T_chiffre == pow_log)
-	{
-		//cout << "cas log=4" << endl;
-		bit_dechiffre = bit_chiffre.get_bit_masque();
-		return;
-	}
-	int log=4;
-	abc;
-	while (!(quadripoint_pi_T_chiffre == pow_log))
-	{
-		pow_log=pow_log*base_log;
-		log++;
-		//zout(log);
-	}
-	xyz;
-	//zout(log);
-	bit_dechiffre = (bit_chiffre.get_bit_masque()+ log)%2;
-	
-
-}
-
-void dechiffrementL2(F2& bit_dechiffre, Quadripoint quadripoint, PrivateKey private_key)
-//routine pour les évalués de niveau 3 et 4, déchiffrement sans Catalano Fiore, calcul d'un log seulement, prend en entrée un quadripoint et non pas un évalué de niveau 2
-{
-	Quadripoint quadripoint_pi_T_chiffre;
-	quadripoint_pi_T_chiffre = private_key.pi_T(quadripoint);  //pi_T(beta)
-	Quadripoint base_log, pow_log; 
-	base_log = private_key.pi_T(pairing(public_key.get_bipoint_curve_groupelt(),public_key.get_bipoint_twist_groupelt()));  //pi_T(e(u,v))
-	if (fp12e_isone(quadripoint_pi_T_chiffre[0]) && fp12e_isone(quadripoint_pi_T_chiffre[1]) && fp12e_isone(quadripoint_pi_T_chiffre[2]) && fp12e_isone(quadripoint_pi_T_chiffre[3]))
-	{
-		//cout << "cas log=0" << endl;
-		bit_dechiffre = 0;
-		return;
-	}
-	if (quadripoint_pi_T_chiffre == base_log)
-	{
-		//cout << "cas log=1" << endl;
-		bit_dechiffre = 1;
-		return;
-	}	
-	pow_log=base_log*base_log;
-	if (quadripoint_pi_T_chiffre == pow_log)
-	{
-		//cout << "cas log=2" << endl;
-		bit_dechiffre = 0;
-		return;
-	}
-	pow_log=pow_log*base_log;
-	if (quadripoint_pi_T_chiffre == pow_log)
-	{
-		//cout << "cas log=3" << endl;
-		bit_dechiffre = 1;
-		return;
-	}	
-	pow_log=pow_log*base_log;
-	if (quadripoint_pi_T_chiffre == pow_log)
-	{
-		//cout << "cas log=4" << endl;
-		bit_dechiffre = 0;
-		return;
-	}
-	int log=4;
-	
-	//quadripoint_pi_T_chiffre.print(0);
-	//quadripoint_pi_T_chiffre.print(1);
-	//quadripoint_pi_T_chiffre.print(2);
-	//quadripoint_pi_T_chiffre.print(3);
-	//abc;
-	while (!(quadripoint_pi_T_chiffre == pow_log))
-	{
-		pow_log=pow_log*base_log;
-		log++;
-	}
-	//xyz;
-	//zout(log);
-	bit_dechiffre = log%2;
-}

+ 0 - 74
bgn2/src/decryption.cpp

@@ -1,74 +0,0 @@
-#include "decryption.hpp"
-
-int decrypt(const Bipoint<curvepoint_fp_t>& ciphertext, const PublicKey& public_key, const PrivateKey& private_key)
-{
-	static std::unordered_map<Bipoint<curvepoint_fp_t>, int> memoizer;
-	static int max_checked = 0;
-	static Bipoint<curvepoint_fp_t> pi_1_curvegen = private_key.pi_1(public_key.get_bipoint_curvegen());
-
-	Bipoint<curvepoint_fp_t> pi_1_ciphertext = private_key.pi_1(ciphertext); 
-
-	auto lookup = memoizer.find(pi_1_ciphertext);
-	if (lookup != memoizer.end())
-	{
-		return lookup->second;
-	}
-
-	Bipoint<curvepoint_fp_t> i = pi_1_curvegen * max_checked;
-	do
-	{
-		memoizer[pi_1_ciphertext] = max_checked++;
-		i = i + pi_1_curvegen;
-	} while (i != pi_1_ciphertext);
-
-	return max_checked - 1;
-}
-
-int decrypt(const Bipoint<twistpoint_fp2_t>& ciphertext, const PrivateKey& private_key) // pour les chiffrés de niveau 1
-{
-	static std::unordered_map<Bipoint<twistpoint_fp2_t>, int> memoizer;
-	static int max_checked = 0;
-	static Bipoint<twistpoint_fp2_t> pi_2_twistgen = private_key.pi_2(public_key.get_bipoint_twistgen());
-
-	Bipoint<twistpoint_fp2_t> pi_2_ciphertext = private_key.pi_2(ciphertext); 
-
-	auto lookup = memoizer.find(pi_2_ciphertext);
-	if (lookup != memoizer.end())
-	{
-		return lookup->second;
-	}
-
-	Bipoint<twistpoint_fp2_t> i = pi_2_twistgen * max_checked;
-	do
-	{
-		memoizer[pi_2_ciphertext] = max_checked++;
-		i = i + pi_2_twistgen;
-	} while (i != pi_2_ciphertext);
-
-	return max_checked - 1;
-}
-
-
-void decrypt(const Quadripoint& ciphertext, const PrivateKey& private_key)
-{
-	static std::unordered_map<Quadripoint, int> memoizer;
-	static int max_checked = 0;
-	static Quadripoint pi_T_pairgen = private_key.pi_T(pairing(public_key.get_bipoint_curvegen(), public_key.get_bipoint_twistgen()));
-
-	Quadripoint pi_T_ciphertext = private_key.pi_T(ciphertext); 
-
-	auto lookup = memoizer.find(pi_T_ciphertext);
-	if (lookup != memoizer.end())
-	{
-		return lookup->second;
-	}
-
-	Quadripoint i = pi_T_pairgen ^ max_checked;
-	do
-	{
-		memoizer[pi_2_ciphertext] = max_checked++;
-		i = i * pi_T_pairgen;
-	} while (i != pi_T_ciphertext);
-
-	return max_checked - 1;
-}

+ 34 - 54
bgn2/src/keygen.cpp

@@ -1,81 +1,61 @@
 #include "keygen.hpp"
 
-extern const scalar_t bn_n;
 extern const curvepoint_fp_t bn_curvegen;   
 extern const twistpoint_fp2_t bn_twistgen;
 
 void keygen(PublicKey& public_key, PrivateKey& private_key)
 {
-    Fp i1, j1, k1, l1, i2, j2, k2, l2;
+    Scalar a1, b1, c1, d1, a2, b2, c2, d2;
     
     while (true)
     {
-        j1.set_random();
-        k1.set_random();
-        l1.set_random();
+        a1.set_random();
+        b1.set_random();
+        c1.set_random();
 
-        if (!l1.is_zero())
+        if (a1 != 0)
         {
-            i1 = (j1 * k1 + Fp(1)) / l1;
+            d1 = (b1 * c1 + Scalar(1)) / a1;
             break;
         }
     }
 
     while (true)
     {
-        j2.set_random();
-        k2.set_random();
-        l2.set_random();
+        a2.set_random();
+        b2.set_random();
+        c2.set_random();
 
-        if (!l2.is_zero())
+        if (a2 != 0)
         {
-            i2 = (j2 * k2 + Fp(1)) / l2;
+            d2 = (b2 * c2 + Scalar(1)) / a2;
             break;
         }   
     }
 
-    private_key.set(i1, j1, k1, l1, i2, j2, k2, l2);
+    private_key.set(a1, b1, c1, d1, a2, b2, c2, d2);
     
-    curvepoint_fp_t c1, c2, c3, c4; 
-    
-    curvepoint_fp_scalarmult_vartime(c1, bn_curvegen, i1.to_scalar());  
-    curvepoint_fp_makeaffine(c1);
-
-    curvepoint_fp_scalarmult_vartime(c2, bn_curvegen, j1.to_scalar());
-    curvepoint_fp_makeaffine(c2);
-
-    Bipoint<curvepoint_fp_t> b1(c1, c2);
-
-    twistpoint_fp2_t t1, t2, t3, t4;
-
-    twistpoint_fp2_scalarmult_vartime(t1, bn_twistgen,i2.scalar());
-    twistpoint_fp2_makeaffine(t1);
-
-    twistpoint_fp2_scalarmult_vartime(t2, bn_twistgen,j2.scalar());
-    twistpoint_fp2_makeaffine(t2);
-
-    Bipoint<twistpoint_fp2_t> b2(t1, t2);
-    
-    scalar_t s1, s2, s3, s4;
-    scalar_setrandom(s1, bn_n);
-    scalar_setrandom(s2, bn_n);
-    scalar_setrandom(s3, bn_n);
-    scalar_setrandom(s4, bn_n);
-    
-    curvepoint_fp_scalarmult_vartime(c3, bn_curvegen, s1);
-    curvepoint_fp_makeaffine(c3);
-
-    curvepoint_fp_scalarmult_vartime(c4, bn_curvegen, s2);
-    curvepoint_fp_makeaffine(c4);
-
-    Bipoint<curvepoint_fp_t> b3(c3, c4);    
-    
-    twistpoint_fp2_scalarmult_vartime(t3, bn_twistgen, s3);
-    twistpoint_fp2_makeaffine(t3);
-    twistpoint_fp2_scalarmult_vartime(t4, bn_twistgen, s4);
-    twistpoint_fp2_makeaffine(t4);
-    
-    Bipoint<twistpoint_fp2_t> b4(t3, t4);
+    Scalar r1, r2, r3, r4;
+    r1.set_random();
+    r2.set_random();
+    r3.set_random();
+    r4.set_random();
+
+    curvepoint_fp_t g_part1, g_part2, g_a1, g_b1;
+    g_part1 = r1 * bn_curvegen;
+    g_part2 = r2 * bn_curvegen;
+    g_a1 = a1 * bn_curvegen;
+    g_b1 = b1 * bn_curvegen;
+    Bipoint<curvepoint_fp_t> full_g(g_part1, g_part2);
+    Bipoint<curvepoint_fp_t> full_g1(g_a1, g_b1);
+
+    twistpoint_fp2_t h_part1, h_part2, h_a2, h_b2;
+    h_part1 = r3 * bn_twistgen;
+    h_part2 = r4 * bn_twistgen;
+    h_a2 = a2 * bn_twistgen;
+    h_b2 = b2 * bn_twistgen;
+    Bipoint<twistpoint_fp2_t> full_h(h_part1, h_part2);
+    Bipoint<twistpoint_fp2_t> full_h1(h_a2, h_b2);
         
-    public_key.set(b1, b2, b3, b4);
+    public_key.set(full_g, full_h, full_g1, full_h1);
 }