Scalar.cpp 7.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403
#include "Scalar.hpp"
#include <cstdlib>
#include <iomanip>
#include <iostream>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>
#include "proof.hpp"
  4. extern const scalar_t bn_n;
  5. extern const scalar_t bn_p;
  6. mpz_class Scalar::mpz_bn_p = 0;
  7. mpz_class Scalar::mpz_bn_n = 0;
  8. Scalar::Scalar()
  9. {
  10. element = 0;
  11. }
  12. Scalar::Scalar(const scalar_t& input)
  13. {
  14. set(input);
  15. }
  16. Scalar::Scalar(mpz_class input)
  17. {
  18. set(input);
  19. }
  20. void Scalar::init()
  21. {
  22. mpz_bn_p = mpz_class("8FB501E34AA387F9AA6FECB86184DC21EE5B88D120B5B59E185CAC6C5E089667", 16);
  23. mpz_bn_n = mpz_class("8FB501E34AA387F9AA6FECB86184DC212E8D8E12F82B39241A2EF45B57AC7261", 16);
  24. }
  25. void Scalar::set(const scalar_t& input)
  26. {
  27. std::stringstream bufferstream;
  28. std::string buffer;
  29. mpz_class temp;
  30. bufferstream << std::hex << input[3] << input[2] << input[1] << input[0];
  31. bufferstream >> buffer;
  32. temp.set_str(buffer, 16);
  33. mpz_mod(temp.get_mpz_t(), temp.get_mpz_t(), mpz_bn_p.get_mpz_t());
  34. element = temp;
  35. }
  36. void Scalar::set(mpz_class input)
  37. {
  38. mpz_class temp = input;
  39. mpz_mod(temp.get_mpz_t(), temp.get_mpz_t(), mpz_bn_p.get_mpz_t());
  40. element = temp;
  41. }
  42. void Scalar::set_random()
  43. {
  44. scalar_t temp;
  45. /* When we ask for a random number,
  46. * we really mean a seed to find a random element of a group
  47. * and the order of the curve either is bn_n or is divided by it
  48. * (not bn_p) */
  49. scalar_setrandom(temp, bn_n);
  50. set(temp);
  51. }
  52. void Scalar::set_field_random()
  53. {
  54. scalar_t temp;
  55. /* There's only one occasion we actually want a random integer
  56. * in the field, and that's when we generate a private BGN key. */
  57. scalar_setrandom(temp, bn_p);
  58. set(temp);
  59. }
  60. mpz_class Scalar::toInt() const
  61. {
  62. return element;
  63. }
  64. Scalar Scalar::operator+(const Scalar& b) const
  65. {
  66. return this->curveAdd(b);
  67. }
  68. Scalar Scalar::operator-(const Scalar& b) const
  69. {
  70. return this->curveSub(b);
  71. }
  72. Scalar Scalar::operator*(const Scalar& b) const
  73. {
  74. return this->curveMult(b);
  75. }
  76. Scalar Scalar::operator/(const Scalar& b) const
  77. {
  78. return this->curveMult(b.curveMultInverse());
  79. }
  80. Scalar Scalar::operator-() const
  81. {
  82. return Scalar(0).curveSub(*this);
  83. }
  84. Scalar& Scalar::operator++()
  85. {
  86. *this = this->curveAdd(Scalar(1));
  87. return *this;
  88. }
  89. Scalar Scalar::operator++(int)
  90. {
  91. Scalar retval = *this;
  92. *this = this->curveAdd(Scalar(1));
  93. return retval;
  94. }
  95. Scalar& Scalar::operator--()
  96. {
  97. *this = this->curveSub(Scalar(1));
  98. return *this;
  99. }
  100. Scalar Scalar::operator--(int)
  101. {
  102. Scalar retval = *this;
  103. *this = this->curveSub(Scalar(1));
  104. return retval;
  105. }
  106. Scalar Scalar::fieldAdd(const Scalar& b) const
  107. {
  108. mpz_class temp = element + b.element;
  109. mpz_mod(temp.get_mpz_t(), temp.get_mpz_t(), mpz_bn_p.get_mpz_t());
  110. return Scalar(temp);
  111. }
  112. Scalar Scalar::fieldSub(const Scalar& b) const
  113. {
  114. mpz_class temp = element - b.element;
  115. mpz_mod(temp.get_mpz_t(), temp.get_mpz_t(), mpz_bn_p.get_mpz_t());
  116. return Scalar(temp);
  117. }
  118. Scalar Scalar::fieldMult(const Scalar& b) const
  119. {
  120. mpz_class temp = element * b.element;
  121. mpz_mod(temp.get_mpz_t(), temp.get_mpz_t(), mpz_bn_p.get_mpz_t());
  122. return Scalar(temp);
  123. }
  124. Scalar Scalar::fieldMultInverse() const
  125. {
  126. mpz_class temp;
  127. mpz_invert(temp.get_mpz_t(), element.get_mpz_t(), mpz_bn_p.get_mpz_t());
  128. return Scalar(temp);
  129. }
  130. Scalar Scalar::curveAdd(const Scalar& b) const
  131. {
  132. mpz_class temp = element + b.element;
  133. mpz_mod(temp.get_mpz_t(), temp.get_mpz_t(), mpz_bn_n.get_mpz_t());
  134. return Scalar(temp);
  135. }
  136. Scalar Scalar::curveSub(const Scalar& b) const
  137. {
  138. mpz_class temp = element - b.element;
  139. mpz_mod(temp.get_mpz_t(), temp.get_mpz_t(), mpz_bn_n.get_mpz_t());
  140. return Scalar(temp);
  141. }
  142. Scalar Scalar::curveMult(const Scalar& b) const
  143. {
  144. mpz_class temp = element * b.element;
  145. mpz_mod(temp.get_mpz_t(), temp.get_mpz_t(), mpz_bn_n.get_mpz_t());
  146. return Scalar(temp);
  147. }
  148. Scalar Scalar::curveMultInverse() const
  149. {
  150. mpz_class temp;
  151. mpz_invert(temp.get_mpz_t(), element.get_mpz_t(), mpz_bn_n.get_mpz_t());
  152. return Scalar(temp);
  153. }
  154. void Scalar::mult(curvepoint_fp_t rop, const curvepoint_fp_t& op1) const
  155. {
  156. SecretScalar secret_element = to_scalar_t();
  157. curvepoint_fp_scalarmult_vartime(rop, op1, secret_element.expose());
  158. }
  159. void Scalar::mult(twistpoint_fp2_t rop, const twistpoint_fp2_t& op1) const
  160. {
  161. SecretScalar secret_element = to_scalar_t();
  162. twistpoint_fp2_scalarmult_vartime(rop, op1, secret_element.expose());
  163. }
  164. void Scalar::mult(fp12e_t rop, const fp12e_t& op1) const
  165. {
  166. SecretScalar secret_element = to_scalar_t();
  167. fp12e_pow_vartime(rop, op1, secret_element.expose());
  168. }
  169. bool Scalar::operator==(const Scalar& b) const
  170. {
  171. return element == b.element;
  172. }
  173. bool Scalar::operator<(const Scalar& b) const
  174. {
  175. return element < b.element;
  176. }
  177. bool Scalar::operator<=(const Scalar& b) const
  178. {
  179. return element <= b.element;
  180. }
  181. bool Scalar::operator>(const Scalar& b) const
  182. {
  183. return element > b.element;
  184. }
  185. bool Scalar::operator>=(const Scalar& b) const
  186. {
  187. return element >= b.element;
  188. }
  189. bool Scalar::operator!=(const Scalar& b) const
  190. {
  191. return element != b.element;
  192. }
  193. Scalar::SecretScalar::SecretScalar()
  194. { }
  195. Scalar::SecretScalar::SecretScalar(const Scalar& input)
  196. {
  197. set(input.element);
  198. }
  199. Scalar::SecretScalar::SecretScalar(mpz_class input)
  200. {
  201. set(input);
  202. }
  203. const scalar_t& Scalar::SecretScalar::expose() const
  204. {
  205. return element;
  206. }
  207. void Scalar::SecretScalar::set(mpz_class input)
  208. {
  209. std::stringstream buffer;
  210. char temp[17];
  211. buffer << std::setfill('0') << std::setw(64) << input.get_str(16);
  212. for (int i = 3; i >= 0; i--)
  213. {
  214. buffer.get(temp, 17);
  215. element[i] = strtoull(temp, NULL, 16);
  216. }
  217. }
  218. Scalar::SecretScalar Scalar::to_scalar_t() const
  219. {
  220. return SecretScalar(element);
  221. }
  222. char byteToHexByte(char in)
  223. {
  224. if (in < 0xA)
  225. return in + '0';
  226. else
  227. return (in - 0xA) + 'a';
  228. }
  229. char hexByteToByte(char in)
  230. {
  231. if (in >= '0' && in <= '9')
  232. return in - '0';
  233. else if (in >= 'A' && in <= 'F')
  234. return (in - 'A') + 0xA;
  235. else
  236. return (in - 'a') + 0xA;
  237. }
  238. std::vector<char> hexToBytes(const std::string& hex)
  239. {
  240. std::vector<char> bytes;
  241. for (size_t i = 0; i < hex.length(); i += 2)
  242. {
  243. char partA, partB, currByte;
  244. partA = hex[i];
  245. partB = hex[i + 1];
  246. partA = hexByteToByte(partA);
  247. partB = hexByteToByte(partB);
  248. currByte = (partA << 4) | partB;
  249. bytes.push_back(currByte);
  250. }
  251. return bytes;
  252. }
  253. std::string bytesToHex(const std::vector<char>& bytes)
  254. {
  255. std::string hex;
  256. for (size_t i = 0; i < bytes.size(); i++)
  257. {
  258. char partA, partB;
  259. partA = (0xF0 & bytes[i]) >> 4;
  260. partB = (0xF & bytes[i]);
  261. partA = byteToHexByte(partA);
  262. partB = byteToHexByte(partB);
  263. hex += partA;
  264. hex += partB;
  265. }
  266. return hex;
  267. }
  268. std::ostream& operator<<(std::ostream& os, const Scalar& output)
  269. {
  270. if (os.flags() & std::ios::hex)
  271. {
  272. os << output.toInt();
  273. return os;
  274. }
  275. std::string outString = output.element.get_str(16);
  276. if (outString.size() % 2 == 1)
  277. outString = "0" + outString;
  278. std::vector<char> bytes = hexToBytes(outString);
  279. BinarySizeT sizeOfVector(bytes.size());
  280. os << sizeOfVector;
  281. for (size_t i = 0; i < sizeOfVector.val(); i++)
  282. os.write(&(bytes[i]), sizeof(bytes[i]));
  283. return os;
  284. }
  285. std::istream& operator>>(std::istream& is, Scalar& input)
  286. {
  287. std::vector<char> bytes;
  288. BinarySizeT sizeOfVector;
  289. is >> sizeOfVector;
  290. for (size_t i = 0; i < sizeOfVector.val(); i++)
  291. {
  292. char currByte;
  293. is.read(&currByte, sizeof(currByte));
  294. bytes.push_back(currByte);
  295. }
  296. std::string hex = bytesToHex(bytes);
  297. input.element.set_str(hex, 16);
  298. return is;
  299. }