#include "Scalar.hpp"

// NOTE(review): the original system include on this line lost its target
// (angle-bracket contents were stripped). These are the standard headers
// this file visibly relies on; confirm against the original if it is available.
#include <cstdlib>
#include <iomanip>
#include <iostream>
#include <sstream>
#include <string>
#include <vector>

#include "proof.hpp"

extern const scalar_t bn_n;
extern const scalar_t bn_p;

// Both moduli start at 0 and are only given real values by Scalar::init().
// Every set() reduces modulo mpz_bn_p, so init() must run before any Scalar
// is built from a value (a modulus of 0 is a division by zero in mpz_mod).
mpz_class Scalar::mpz_bn_p = 0;
mpz_class Scalar::mpz_bn_n = 0;

/* Construct a Scalar holding 0. */
Scalar::Scalar() {
    element = 0;
}

/* Construct from a 4-limb scalar_t; see set(const scalar_t&). */
Scalar::Scalar(const scalar_t& input) {
    set(input);
}

/* Construct from an arbitrary integer; see set(mpz_class). */
Scalar::Scalar(mpz_class input) {
    set(input);
}

/* Load the two moduli used by the field*() and curve*() operations. */
void Scalar::init() {
    mpz_bn_p = mpz_class("8FB501E34AA387F9AA6FECB86184DC21EE5B88D120B5B59E185CAC6C5E089667", 16);
    mpz_bn_n = mpz_class("8FB501E34AA387F9AA6FECB86184DC212E8D8E12F82B39241A2EF45B57AC7261", 16);
}

/*
 * Set this Scalar from a 4-limb scalar_t (input[3] is the most significant
 * limb), reduced modulo mpz_bn_p.
 *
 * Every limb is written zero-padded to its full hex width. Without the
 * padding, a limb whose top nibble(s) are zero would be printed short and
 * the concatenated hex string would encode a different (smaller) integer.
 */
void Scalar::set(const scalar_t& input) {
    // Two hex digits per byte of a limb.
    const int limbHexDigits = static_cast<int>(sizeof(input[0]) * 2);

    std::stringstream bufferstream;
    bufferstream << std::hex << std::setfill('0');
    for (int i = 3; i >= 0; i--) {
        // std::setw only applies to the next insertion, so repeat it per limb.
        bufferstream << std::setw(limbHexDigits) << input[i];
    }

    mpz_class temp;
    temp.set_str(bufferstream.str(), 16);
    mpz_mod(temp.get_mpz_t(), temp.get_mpz_t(), mpz_bn_p.get_mpz_t());
    element = temp;
}

/* Set this Scalar from an integer, reduced modulo mpz_bn_p. */
void Scalar::set(mpz_class input) {
    mpz_class temp = input;
    mpz_mod(temp.get_mpz_t(), temp.get_mpz_t(), mpz_bn_p.get_mpz_t());
    element = temp;
}

void Scalar::set_random() {
    scalar_t temp;

    /* When we ask for a random number,
     * we really mean a seed to find a random element of a group
     * and the order of the curve either is bn_n or is divided by it
     * (not bn_p) */
    scalar_setrandom(temp, bn_n);
    set(temp);
}

void Scalar::set_field_random() {
    scalar_t temp;

    /* There's only one occasion we actually want a random integer
     * in the field, and that's when we generate a private BGN key. */
    scalar_setrandom(temp, bn_p);
    set(temp);
}

/* Return a copy of the underlying integer. */
mpz_class Scalar::toInt() const {
    return element;
}

/*
 * Arithmetic operators. All of them delegate to the curve*() helpers,
 * i.e. they work modulo mpz_bn_n.
 */
Scalar Scalar::operator+(const Scalar& b) const {
    return this->curveAdd(b);
}

Scalar Scalar::operator-(const Scalar& b) const {
    return this->curveSub(b);
}

Scalar Scalar::operator*(const Scalar& b) const {
    return this->curveMult(b);
}

Scalar Scalar::operator/(const Scalar& b) const {
    return this->curveMult(b.curveMultInverse());
}

Scalar Scalar::operator-() const {
    return Scalar(0).curveSub(*this);
}

Scalar& Scalar::operator++() {
    *this = this->curveAdd(Scalar(1));
    return *this;
}

Scalar Scalar::operator++(int) {
    Scalar retval = *this;
    *this = this->curveAdd(Scalar(1));
    return retval;
}

Scalar& Scalar::operator--() {
    *this = this->curveSub(Scalar(1));
    return *this;
}

Scalar Scalar::operator--(int) {
    Scalar retval = *this;
    *this = this->curveSub(Scalar(1));
    return retval;
}

/* (this + b) mod mpz_bn_p. */
Scalar Scalar::fieldAdd(const Scalar& b) const {
    mpz_class temp = element + b.element;
    mpz_mod(temp.get_mpz_t(), temp.get_mpz_t(), mpz_bn_p.get_mpz_t());
    return Scalar(temp);
}

/* (this - b) mod mpz_bn_p. */
Scalar Scalar::fieldSub(const Scalar& b) const {
    mpz_class temp = element - b.element;
    mpz_mod(temp.get_mpz_t(), temp.get_mpz_t(), mpz_bn_p.get_mpz_t());
    return Scalar(temp);
}

/* (this * b) mod mpz_bn_p. */
Scalar Scalar::fieldMult(const Scalar& b) const {
    mpz_class temp = element * b.element;
    mpz_mod(temp.get_mpz_t(), temp.get_mpz_t(), mpz_bn_p.get_mpz_t());
    return Scalar(temp);
}

/*
 * Multiplicative inverse modulo mpz_bn_p.
 * NOTE(review): the return value of mpz_invert is not checked, so the
 * result is not meaningful when no inverse exists (e.g. element == 0).
 */
Scalar Scalar::fieldMultInverse() const {
    mpz_class temp;
    mpz_invert(temp.get_mpz_t(), element.get_mpz_t(), mpz_bn_p.get_mpz_t());
    return Scalar(temp);
}

/* (this + b) mod mpz_bn_n. */
Scalar Scalar::curveAdd(const Scalar& b) const {
    mpz_class temp = element + b.element;
    mpz_mod(temp.get_mpz_t(), temp.get_mpz_t(), mpz_bn_n.get_mpz_t());
    return Scalar(temp);
}

/* (this - b) mod mpz_bn_n. */
Scalar Scalar::curveSub(const Scalar& b) const {
    mpz_class temp = element - b.element;
    mpz_mod(temp.get_mpz_t(), temp.get_mpz_t(), mpz_bn_n.get_mpz_t());
    return Scalar(temp);
}

/* (this * b) mod mpz_bn_n. */
Scalar Scalar::curveMult(const Scalar& b) const {
    mpz_class temp = element * b.element;
    mpz_mod(temp.get_mpz_t(), temp.get_mpz_t(), mpz_bn_n.get_mpz_t());
    return Scalar(temp);
}

/*
 * Multiplicative inverse modulo mpz_bn_n.
 * NOTE(review): the return value of mpz_invert is not checked, so the
 * result is not meaningful when no inverse exists (e.g. element == 0).
 */
Scalar Scalar::curveMultInverse() const {
    mpz_class temp;
    mpz_invert(temp.get_mpz_t(), element.get_mpz_t(), mpz_bn_n.get_mpz_t());
    return Scalar(temp);
}

/*
 * Scalar multiplication / exponentiation of op1 by this Scalar, written
 * into rop. Each overload converts the value to a scalar_t via SecretScalar
 * and calls the library's *_vartime routine.
 */
void Scalar::mult(curvepoint_fp_t rop, const curvepoint_fp_t& op1) const {
    SecretScalar secret_element = to_scalar_t();
    curvepoint_fp_scalarmult_vartime(rop, op1, secret_element.expose());
}

void Scalar::mult(twistpoint_fp2_t rop, const twistpoint_fp2_t& op1) const {
    SecretScalar secret_element = to_scalar_t();
    twistpoint_fp2_scalarmult_vartime(rop, op1, secret_element.expose());
}

void Scalar::mult(fp12e_t rop, const fp12e_t& op1) const {
    SecretScalar secret_element = to_scalar_t();
    fp12e_pow_vartime(rop, op1, secret_element.expose());
}

/* Comparisons are plain integer comparisons of the stored values. */
bool Scalar::operator==(const Scalar& b) const {
    return element == b.element;
}

bool Scalar::operator<(const Scalar& b) const {
    return element < b.element;
}

bool Scalar::operator<=(const Scalar& b) const {
    return element <= b.element;
}

bool Scalar::operator>(const Scalar& b) const {
    return element > b.element;
}

bool Scalar::operator>=(const Scalar& b) const {
    return element >= b.element;
}

bool Scalar::operator!=(const Scalar& b) const {
    return element != b.element;
}

/* Default construction performs no initialisation of the limbs here. */
Scalar::SecretScalar::SecretScalar() {
}

Scalar::SecretScalar::SecretScalar(const Scalar& input) {
    set(input.element);
}

Scalar::SecretScalar::SecretScalar(mpz_class input) {
    set(input);
}

/* Read-only access to the underlying scalar_t. */
const scalar_t& Scalar::SecretScalar::expose() const {
    return element;
}

/*
 * Fill the four limbs from an integer: the value is rendered as hex,
 * left-padded with zeros to 64 digits, and consumed 16 digits at a time
 * from the most significant limb (index 3) down to index 0.
 * NOTE(review): values needing more than 64 hex digits are not handled;
 * Scalar always stores values already reduced modulo mpz_bn_p.
 */
void Scalar::SecretScalar::set(mpz_class input) {
    std::stringstream buffer;
    char temp[17];

    buffer << std::setfill('0') << std::setw(64) << input.get_str(16);

    for (int i = 3; i >= 0; i--) {
        buffer.get(temp, 17);  // reads at most 16 chars and NUL-terminates
        element[i] = strtoull(temp, NULL, 16);
    }
}

Scalar::SecretScalar Scalar::to_scalar_t() const {
    return SecretScalar(element);
}

/* Map a nibble value (0..15) to its lowercase hex digit. */
char byteToHexByte(char in) {
    if (in < 0xA)
        return in + '0';
    else
        return (in - 0xA) + 'a';
}

/*
 * Map a hex digit to its nibble value. Anything that is not '0'-'9' or
 * 'A'-'F' is treated as a lowercase digit without validation.
 */
char hexByteToByte(char in) {
    if (in >= '0' && in <= '9')
        return in - '0';
    else if (in >= 'A' && in <= 'F')
        return (in - 'A') + 0xA;
    else
        return (in - 'a') + 0xA;
}

/*
 * Convert a hex string to bytes, two digits per byte, most significant
 * digit first.
 *
 * An odd-length string is treated as if it had one leading '0'. This
 * matters because operator<< feeds this function mpz_class::get_str(16),
 * which does not pad to an even number of digits; pairing digits naively
 * would read the string terminator as a "digit" and emit a garbage byte.
 */
std::vector<char> hexToBytes(const std::string& hex) {
    std::vector<char> bytes;
    bytes.reserve((hex.length() + 1) / 2);

    size_t i = 0;
    if (hex.length() % 2 != 0) {
        // Lone leading nibble: the high nibble of the first byte is 0.
        bytes.push_back(hexByteToByte(hex[0]));
        i = 1;
    }

    for (; i < hex.length(); i += 2) {
        char partA = hexByteToByte(hex[i]);
        char partB = hexByteToByte(hex[i + 1]);
        bytes.push_back(static_cast<char>((partA << 4) | partB));
    }

    return bytes;
}

/* Convert bytes to a lowercase hex string, two digits per byte. */
std::string bytesToHex(const std::vector<char>& bytes) {
    std::string hex;
    hex.reserve(bytes.size() * 2);

    for (size_t i = 0; i < bytes.size(); i++) {
        char partA = (0xF0 & bytes[i]) >> 4;
        char partB = (0xF & bytes[i]);

        hex += byteToHexByte(partA);
        hex += byteToHexByte(partB);
    }

    return hex;
}

/*
 * Serialise a Scalar: a BinarySizeT byte count followed by the value's
 * big-endian bytes (derived from its hex representation).
 */
std::ostream& operator<<(std::ostream& os, const Scalar& output) {
    std::string outString = output.element.get_str(16);
    std::vector<char> bytes = hexToBytes(outString);

    BinarySizeT sizeOfVector(bytes.size());
    os << sizeOfVector;

    for (size_t i = 0; i < sizeOfVector.val(); i++)
        os.write(&(bytes[i]), sizeof(bytes[i]));

    return os;
}

/*
 * Deserialise a Scalar written by operator<<.
 *
 * If the stream fails while reading the size or any payload byte, the
 * function returns immediately and leaves `input` untouched, rather than
 * building a value out of uninitialised bytes. The stream's own state
 * flags tell the caller that the read failed.
 */
std::istream& operator>>(std::istream& is, Scalar& input) {
    std::vector<char> bytes;

    BinarySizeT sizeOfVector;
    is >> sizeOfVector;
    if (!is)
        return is;

    for (size_t i = 0; i < sizeOfVector.val(); i++) {
        char currByte;
        if (!is.read(&currByte, sizeof(currByte)))
            return is;

        bytes.push_back(currByte);
    }

    std::string hex = bytesToHex(bytes);
    input.element.set_str(hex, 16);

    return is;
}