123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403 |
#include "Scalar.hpp"

#include <cstdlib>
#include <iomanip>
#include <iostream>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>

#include "proof.hpp"
- extern const scalar_t bn_n;
- extern const scalar_t bn_p;
- mpz_class Scalar::mpz_bn_p = 0;
- mpz_class Scalar::mpz_bn_n = 0;
- Scalar::Scalar()
- {
- element = 0;
- }
- Scalar::Scalar(const scalar_t& input)
- {
- set(input);
- }
- Scalar::Scalar(mpz_class input)
- {
- set(input);
- }
- void Scalar::init()
- {
- mpz_bn_p = mpz_class("8FB501E34AA387F9AA6FECB86184DC21EE5B88D120B5B59E185CAC6C5E089667", 16);
- mpz_bn_n = mpz_class("8FB501E34AA387F9AA6FECB86184DC212E8D8E12F82B39241A2EF45B57AC7261", 16);
- }
- void Scalar::set(const scalar_t& input)
- {
- std::stringstream bufferstream;
- std::string buffer;
- mpz_class temp;
- bufferstream << std::hex << input[3] << input[2] << input[1] << input[0];
- bufferstream >> buffer;
-
- temp.set_str(buffer, 16);
- mpz_mod(temp.get_mpz_t(), temp.get_mpz_t(), mpz_bn_p.get_mpz_t());
- element = temp;
- }
- void Scalar::set(mpz_class input)
- {
- mpz_class temp = input;
- mpz_mod(temp.get_mpz_t(), temp.get_mpz_t(), mpz_bn_p.get_mpz_t());
- element = temp;
- }
- void Scalar::set_random()
- {
- scalar_t temp;
-
- /* When we ask for a random number,
- * we really mean a seed to find a random element of a group
- * and the order of the curve either is bn_n or is divided by it
- * (not bn_p) */
- scalar_setrandom(temp, bn_n);
- set(temp);
- }
- void Scalar::set_field_random()
- {
- scalar_t temp;
-
- /* There's only one occasion we actually want a random integer
- * in the field, and that's when we generate a private BGN key. */
- scalar_setrandom(temp, bn_p);
- set(temp);
- }
- mpz_class Scalar::toInt() const
- {
- return element;
- }
- Scalar Scalar::operator+(const Scalar& b) const
- {
- return this->curveAdd(b);
- }
- Scalar Scalar::operator-(const Scalar& b) const
- {
- return this->curveSub(b);
- }
- Scalar Scalar::operator*(const Scalar& b) const
- {
- return this->curveMult(b);
- }
- Scalar Scalar::operator/(const Scalar& b) const
- {
- return this->curveMult(b.curveMultInverse());
- }
- Scalar Scalar::operator-() const
- {
- return Scalar(0).curveSub(*this);
- }
- Scalar& Scalar::operator++()
- {
- *this = this->curveAdd(Scalar(1));
- return *this;
- }
- Scalar Scalar::operator++(int)
- {
- Scalar retval = *this;
-
- *this = this->curveAdd(Scalar(1));
- return retval;
- }
- Scalar& Scalar::operator--()
- {
- *this = this->curveSub(Scalar(1));
- return *this;
- }
- Scalar Scalar::operator--(int)
- {
- Scalar retval = *this;
-
- *this = this->curveSub(Scalar(1));
- return retval;
- }
- Scalar Scalar::fieldAdd(const Scalar& b) const
- {
- mpz_class temp = element + b.element;
- mpz_mod(temp.get_mpz_t(), temp.get_mpz_t(), mpz_bn_p.get_mpz_t());
- return Scalar(temp);
- }
- Scalar Scalar::fieldSub(const Scalar& b) const
- {
- mpz_class temp = element - b.element;
- mpz_mod(temp.get_mpz_t(), temp.get_mpz_t(), mpz_bn_p.get_mpz_t());
- return Scalar(temp);
- }
- Scalar Scalar::fieldMult(const Scalar& b) const
- {
- mpz_class temp = element * b.element;
- mpz_mod(temp.get_mpz_t(), temp.get_mpz_t(), mpz_bn_p.get_mpz_t());
- return Scalar(temp);
- }
- Scalar Scalar::fieldMultInverse() const
- {
- mpz_class temp;
- mpz_invert(temp.get_mpz_t(), element.get_mpz_t(), mpz_bn_p.get_mpz_t());
- return Scalar(temp);
- }
- Scalar Scalar::curveAdd(const Scalar& b) const
- {
- mpz_class temp = element + b.element;
- mpz_mod(temp.get_mpz_t(), temp.get_mpz_t(), mpz_bn_n.get_mpz_t());
- return Scalar(temp);
- }
- Scalar Scalar::curveSub(const Scalar& b) const
- {
- mpz_class temp = element - b.element;
- mpz_mod(temp.get_mpz_t(), temp.get_mpz_t(), mpz_bn_n.get_mpz_t());
- return Scalar(temp);
- }
- Scalar Scalar::curveMult(const Scalar& b) const
- {
- mpz_class temp = element * b.element;
- mpz_mod(temp.get_mpz_t(), temp.get_mpz_t(), mpz_bn_n.get_mpz_t());
- return Scalar(temp);
- }
- Scalar Scalar::curveMultInverse() const
- {
- mpz_class temp;
- mpz_invert(temp.get_mpz_t(), element.get_mpz_t(), mpz_bn_n.get_mpz_t());
- return Scalar(temp);
- }
- void Scalar::mult(curvepoint_fp_t rop, const curvepoint_fp_t& op1) const
- {
- SecretScalar secret_element = to_scalar_t();
- curvepoint_fp_scalarmult_vartime(rop, op1, secret_element.expose());
- }
- void Scalar::mult(twistpoint_fp2_t rop, const twistpoint_fp2_t& op1) const
- {
- SecretScalar secret_element = to_scalar_t();
- twistpoint_fp2_scalarmult_vartime(rop, op1, secret_element.expose());
- }
- void Scalar::mult(fp12e_t rop, const fp12e_t& op1) const
- {
- SecretScalar secret_element = to_scalar_t();
- fp12e_pow_vartime(rop, op1, secret_element.expose());
- }
- bool Scalar::operator==(const Scalar& b) const
- {
- return element == b.element;
- }
- bool Scalar::operator<(const Scalar& b) const
- {
- return element < b.element;
- }
- bool Scalar::operator<=(const Scalar& b) const
- {
- return element <= b.element;
- }
- bool Scalar::operator>(const Scalar& b) const
- {
- return element > b.element;
- }
- bool Scalar::operator>=(const Scalar& b) const
- {
- return element >= b.element;
- }
- bool Scalar::operator!=(const Scalar& b) const
- {
- return element != b.element;
- }
- Scalar::SecretScalar::SecretScalar()
- { }
- Scalar::SecretScalar::SecretScalar(const Scalar& input)
- {
- set(input.element);
- }
- Scalar::SecretScalar::SecretScalar(mpz_class input)
- {
- set(input);
- }
- const scalar_t& Scalar::SecretScalar::expose() const
- {
- return element;
- }
- void Scalar::SecretScalar::set(mpz_class input)
- {
- std::stringstream buffer;
- char temp[17];
- buffer << std::setfill('0') << std::setw(64) << input.get_str(16);
- for (int i = 3; i >= 0; i--)
- {
- buffer.get(temp, 17);
- element[i] = strtoull(temp, NULL, 16);
- }
- }
-
- Scalar::SecretScalar Scalar::to_scalar_t() const
- {
- return SecretScalar(element);
- }
// Converts a nibble value (expected 0..15) to its lowercase ASCII hex digit.
char byteToHexByte(char in)
{
    if (in < 0xA)
        return in + '0';
    else
        return (in - 0xA) + 'a';
}

// Converts one ASCII hex digit ('0'-'9', 'A'-'F', 'a'-'f') to its value.
// Other characters are not rejected: they fall through to the lowercase
// branch and yield meaningless values.
char hexByteToByte(char in)
{
    if (in >= '0' && in <= '9')
        return in - '0';
    else if (in >= 'A' && in <= 'F')
        return (in - 'A') + 0xA;
    else
        return (in - 'a') + 0xA;
}

// Decodes a hex string into bytes, most significant byte first.
// A string with an odd number of digits is treated as if it had a leading
// '0' (so "abc" -> {0x0a, 0xbc}). Previously the final byte paired the last
// digit with the string's terminating '\0', producing a garbage byte.
std::vector<char> hexToBytes(const std::string& hex)
{
    std::vector<char> bytes;
    bytes.reserve((hex.length() + 1) / 2);

    size_t i = 0;
    if (hex.length() % 2 == 1)
    {
        bytes.push_back(hexByteToByte(hex[0]));
        i = 1;
    }
    for (; i < hex.length(); i += 2)
    {
        char partA = hexByteToByte(hex[i]);
        char partB = hexByteToByte(hex[i + 1]);
        char currByte = (partA << 4) | partB;
        bytes.push_back(currByte);
    }
    return bytes;
}

// Encodes bytes as lowercase hex, two digits per byte.
std::string bytesToHex(const std::vector<char>& bytes)
{
    std::string hex;
    hex.reserve(bytes.size() * 2);
    for (size_t i = 0; i < bytes.size(); i++)
    {
        // Mask after integer promotion so negative chars still yield 0..15.
        char partA = (0xF0 & bytes[i]) >> 4;
        char partB = (0xF & bytes[i]);
        hex += byteToHexByte(partA);
        hex += byteToHexByte(partB);
    }
    return hex;
}
- std::ostream& operator<<(std::ostream& os, const Scalar& output)
- {
- if (os.flags() & std::ios::hex)
- {
- os << output.toInt();
- return os;
- }
- std::string outString = output.element.get_str(16);
-
- if (outString.size() % 2 == 1)
- outString = "0" + outString;
- std::vector<char> bytes = hexToBytes(outString);
- BinarySizeT sizeOfVector(bytes.size());
- os << sizeOfVector;
- for (size_t i = 0; i < sizeOfVector.val(); i++)
- os.write(&(bytes[i]), sizeof(bytes[i]));
- return os;
- }
- std::istream& operator>>(std::istream& is, Scalar& input)
- {
- std::vector<char> bytes;
- BinarySizeT sizeOfVector;
- is >> sizeOfVector;
- for (size_t i = 0; i < sizeOfVector.val(); i++)
- {
- char currByte;
- is.read(&currByte, sizeof(currByte));
- bytes.push_back(currByte);
- }
- std::string hex = bytesToHex(bytes);
- input.element.set_str(hex, 16);
- return is;
- }
|