// Scalar.cpp
  1. #include "Scalar.hpp"
  2. #include <iostream>
  3. extern const scalar_t bn_n;
  4. extern const scalar_t bn_p;
  5. mpz_class Scalar::mpz_bn_p = 0;
  6. mpz_class Scalar::mpz_bn_n = 0;
  7. Scalar::Scalar()
  8. {
  9. element = 0;
  10. }
  11. Scalar::Scalar(const scalar_t& input)
  12. {
  13. set(input);
  14. }
  15. Scalar::Scalar(mpz_class input)
  16. {
  17. set(input);
  18. }
  19. void Scalar::init()
  20. {
  21. mpz_bn_p = mpz_class("8FB501E34AA387F9AA6FECB86184DC21EE5B88D120B5B59E185CAC6C5E089667", 16);
  22. mpz_bn_n = mpz_class("8FB501E34AA387F9AA6FECB86184DC212E8D8E12F82B39241A2EF45B57AC7261", 16);
  23. }
  24. void Scalar::set(const scalar_t& input)
  25. {
  26. std::stringstream bufferstream;
  27. std::string buffer;
  28. mpz_class temp;
  29. bufferstream << std::hex << input[3] << input[2] << input[1] << input[0];
  30. bufferstream >> buffer;
  31. temp.set_str(buffer, 16);
  32. mpz_mod(temp.get_mpz_t(), temp.get_mpz_t(), mpz_bn_p.get_mpz_t());
  33. element = temp;
  34. }
  35. void Scalar::set(mpz_class input)
  36. {
  37. mpz_class temp = input;
  38. mpz_mod(temp.get_mpz_t(), temp.get_mpz_t(), mpz_bn_p.get_mpz_t());
  39. element = temp;
  40. }
  41. void Scalar::set_random()
  42. {
  43. scalar_t temp;
  44. /* When we ask for a random number,
  45. * we really mean a seed to find a random element of a group
  46. * and the order of the curve either is bn_n or is divided by it
  47. * (not bn_p) */
  48. scalar_setrandom(temp, bn_n);
  49. set(temp);
  50. }
  51. void Scalar::set_field_random()
  52. {
  53. scalar_t temp;
  54. /* There's only one occasion we actually want a random integer
  55. * in the field, and that's when we generate a private BGN key. */
  56. scalar_setrandom(temp, bn_p);
  57. set(temp);
  58. }
  59. mpz_class Scalar::toInt() const
  60. {
  61. return element;
  62. }
  63. Scalar Scalar::operator+(const Scalar& b) const
  64. {
  65. return this->curveAdd(b);
  66. }
  67. Scalar Scalar::operator-(const Scalar& b) const
  68. {
  69. return this->curveSub(b);
  70. }
  71. Scalar Scalar::operator*(const Scalar& b) const
  72. {
  73. return this->curveMult(b);
  74. }
  75. Scalar Scalar::operator/(const Scalar& b) const
  76. {
  77. return this->curveMult(b.curveMultInverse());
  78. }
  79. Scalar Scalar::operator-() const
  80. {
  81. return Scalar(0).curveSub(*this);
  82. }
  83. Scalar& Scalar::operator++()
  84. {
  85. *this = this->curveAdd(Scalar(1));
  86. return *this;
  87. }
  88. Scalar Scalar::operator++(int)
  89. {
  90. Scalar retval = *this;
  91. *this = this->curveAdd(Scalar(1));
  92. return retval;
  93. }
  94. Scalar& Scalar::operator--()
  95. {
  96. *this = this->curveSub(Scalar(1));
  97. return *this;
  98. }
  99. Scalar Scalar::operator--(int)
  100. {
  101. Scalar retval = *this;
  102. *this = this->curveSub(Scalar(1));
  103. return retval;
  104. }
  105. Scalar Scalar::fieldAdd(const Scalar& b) const
  106. {
  107. mpz_class temp = element + b.element;
  108. mpz_mod(temp.get_mpz_t(), temp.get_mpz_t(), mpz_bn_p.get_mpz_t());
  109. return Scalar(temp);
  110. }
  111. Scalar Scalar::fieldSub(const Scalar& b) const
  112. {
  113. mpz_class temp = element - b.element;
  114. mpz_mod(temp.get_mpz_t(), temp.get_mpz_t(), mpz_bn_p.get_mpz_t());
  115. return Scalar(temp);
  116. }
  117. Scalar Scalar::fieldMult(const Scalar& b) const
  118. {
  119. mpz_class temp = element * b.element;
  120. mpz_mod(temp.get_mpz_t(), temp.get_mpz_t(), mpz_bn_p.get_mpz_t());
  121. return Scalar(temp);
  122. }
  123. Scalar Scalar::fieldMultInverse() const
  124. {
  125. mpz_class temp;
  126. mpz_invert(temp.get_mpz_t(), element.get_mpz_t(), mpz_bn_p.get_mpz_t());
  127. return Scalar(temp);
  128. }
  129. Scalar Scalar::curveAdd(const Scalar& b) const
  130. {
  131. mpz_class temp = element + b.element;
  132. mpz_mod(temp.get_mpz_t(), temp.get_mpz_t(), mpz_bn_n.get_mpz_t());
  133. return Scalar(temp);
  134. }
  135. Scalar Scalar::curveSub(const Scalar& b) const
  136. {
  137. mpz_class temp = element - b.element;
  138. mpz_mod(temp.get_mpz_t(), temp.get_mpz_t(), mpz_bn_n.get_mpz_t());
  139. return Scalar(temp);
  140. }
  141. Scalar Scalar::curveMult(const Scalar& b) const
  142. {
  143. mpz_class temp = element * b.element;
  144. mpz_mod(temp.get_mpz_t(), temp.get_mpz_t(), mpz_bn_n.get_mpz_t());
  145. return Scalar(temp);
  146. }
  147. Scalar Scalar::curveMultInverse() const
  148. {
  149. mpz_class temp;
  150. mpz_invert(temp.get_mpz_t(), element.get_mpz_t(), mpz_bn_n.get_mpz_t());
  151. return Scalar(temp);
  152. }
  153. void Scalar::mult(curvepoint_fp_t rop, const curvepoint_fp_t& op1) const
  154. {
  155. SecretScalar secret_element = to_scalar_t();
  156. curvepoint_fp_scalarmult_vartime(rop, op1, secret_element.expose());
  157. }
  158. void Scalar::mult(twistpoint_fp2_t rop, const twistpoint_fp2_t& op1) const
  159. {
  160. SecretScalar secret_element = to_scalar_t();
  161. twistpoint_fp2_scalarmult_vartime(rop, op1, secret_element.expose());
  162. }
  163. void Scalar::mult(fp12e_t rop, const fp12e_t& op1) const
  164. {
  165. SecretScalar secret_element = to_scalar_t();
  166. fp12e_pow_vartime(rop, op1, secret_element.expose());
  167. }
  168. bool Scalar::operator==(const Scalar& b) const
  169. {
  170. return element == b.element;
  171. }
  172. bool Scalar::operator<(const Scalar& b) const
  173. {
  174. return element < b.element;
  175. }
  176. bool Scalar::operator<=(const Scalar& b) const
  177. {
  178. return element <= b.element;
  179. }
  180. bool Scalar::operator>(const Scalar& b) const
  181. {
  182. return element > b.element;
  183. }
  184. bool Scalar::operator>=(const Scalar& b) const
  185. {
  186. return element >= b.element;
  187. }
  188. bool Scalar::operator!=(const Scalar& b) const
  189. {
  190. return element != b.element;
  191. }
  192. Scalar::SecretScalar::SecretScalar()
  193. { }
  194. Scalar::SecretScalar::SecretScalar(const Scalar& input)
  195. {
  196. set(input.element);
  197. }
  198. Scalar::SecretScalar::SecretScalar(mpz_class input)
  199. {
  200. set(input);
  201. }
  202. const scalar_t& Scalar::SecretScalar::expose() const
  203. {
  204. return element;
  205. }
  206. void Scalar::SecretScalar::set(mpz_class input)
  207. {
  208. std::stringstream buffer;
  209. char temp[17];
  210. buffer << std::setfill('0') << std::setw(64) << input.get_str(16);
  211. for (int i = 3; i >= 0; i--)
  212. {
  213. buffer.get(temp, 17);
  214. element[i] = strtoull(temp, NULL, 16);
  215. }
  216. }
  217. Scalar::SecretScalar Scalar::to_scalar_t() const
  218. {
  219. return SecretScalar(element);
  220. }
  221. std::ostream& operator<<(std::ostream& os, const Scalar& output)
  222. {
  223. os << output.element;
  224. return os;
  225. }
  226. std::istream& operator>>(std::istream& is, Scalar& input)
  227. {
  228. is >> input.element;
  229. return is;
  230. }