djn.cpp 9.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357
  1. /**
  2. \file djn.cpp
  3. \author Daniel Demmler
  4. \copyright Copyright (C) 2019 ENCRYPTO Group, TU Darmstadt
  5. This program is free software: you can redistribute it and/or modify
  6. it under the terms of the GNU Lesser General Public License as published
  7. by the Free Software Foundation, either version 3 of the License, or
  8. (at your option) any later version.
  9. ABY is distributed in the hope that it will be useful,
  10. but WITHOUT ANY WARRANTY; without even the implied warranty of
  11. MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  12. GNU Lesser General Public License for more details.
  13. You should have received a copy of the GNU Lesser General Public License
  14. along with this program. If not, see <http://www.gnu.org/licenses/>.
  15. \brief
  16. libdjn - v0.9
  17. A library implementing the Damgaard Jurik Nielsen cryptosystem with s=1 (~Paillier).
  18. based on:
  19. libpaillier - A library implementing the Paillier cryptosystem.
  20. (http://hms.isi.jhu.edu/acsc/libpaillier/)
  21. */
  22. #include "djn.h"
  23. #include "../powmod.h"
  24. #include "../utils.h"
  25. #include <cstdlib>
  26. #define DJN_DEBUG 0
  27. #define DJN_CHECKSIZE 0
  28. void djn_complete_pubkey(unsigned int modulusbits, djn_pubkey_t** pub, mpz_t n, mpz_t h) {
  29. *pub = (djn_pubkey_t*) malloc(sizeof(djn_pubkey_t));
  30. /* initialize our integers */
  31. mpz_init((*pub)->n);
  32. mpz_init((*pub)->n_squared);
  33. mpz_init((*pub)->h);
  34. mpz_init((*pub)->h_s);
  35. mpz_set((*pub)->n, n);
  36. mpz_set((*pub)->h, h);
  37. mpz_mul((*pub)->n_squared, n, n);
  38. mpz_powm((*pub)->h_s, h, n, (*pub)->n_squared);
  39. (*pub)->bits = modulusbits;
  40. (*pub)->rbits = modulusbits % 2 ? modulusbits / 2 + 1 : modulusbits / 2; // rbits = ceil(bits/2)
  41. }
  42. void djn_keygen(unsigned int modulusbits, djn_pubkey_t** pub, djn_prvkey_t** prv) {
  43. mpz_t test, x;
  44. /* allocate the new key structures */
  45. *pub = (djn_pubkey_t*) malloc(sizeof(djn_pubkey_t));
  46. *prv = (djn_prvkey_t*) malloc(sizeof(djn_prvkey_t));
  47. /* initialize our integers */
  48. mpz_init((*pub)->n);
  49. mpz_init((*pub)->n_squared);
  50. mpz_init((*pub)->h);
  51. mpz_init((*pub)->h_s);
  52. mpz_init((*prv)->lambda);
  53. mpz_init((*prv)->lambda_inverse);
  54. mpz_init((*prv)->p);
  55. mpz_init((*prv)->q);
  56. mpz_init((*prv)->p_squared);
  57. mpz_init((*prv)->q_squared);
  58. mpz_init((*prv)->q_inverse);
  59. mpz_init((*prv)->q_squared_inverse);
  60. mpz_init((*prv)->p_minusone);
  61. mpz_init((*prv)->q_minusone);
  62. mpz_init((*prv)->ordpsq);
  63. mpz_init((*prv)->ordqsq);
  64. mpz_init(test);
  65. mpz_init(x);
  66. do {
  67. // choose bits of p and q randomly
  68. aby_prng((*prv)->p, modulusbits / 2);
  69. aby_prng((*prv)->q, modulusbits / 2);
  70. // set highest bit to 1 to ensure high length
  71. mpz_setbit((*prv)->p, modulusbits / 2);
  72. mpz_setbit((*prv)->q, modulusbits / 2);
  73. //find next primes
  74. do {
  75. mpz_nextprime((*prv)->p, (*prv)->p);
  76. } while (!mpz_tstbit((*prv)->p, 1)); //make sure p mod 4 = 3
  77. do {
  78. mpz_nextprime((*prv)->q, (*prv)->q);
  79. } while (!mpz_cmp((*prv)->p, (*prv)->q) || !mpz_tstbit((*prv)->q, 1)); //make sure p!=q and q mod 4 = 3
  80. /* p-1 and q-1 */
  81. mpz_sub_ui((*prv)->p_minusone, (*prv)->p, 1);
  82. mpz_sub_ui((*prv)->q_minusone, (*prv)->q, 1);
  83. mpz_gcd(test, (*prv)->p_minusone, (*prv)->q_minusone);
  84. } while (mpz_cmp_ui(test, 2)); // make sure gcd(p-1,q-1)=2
  85. //} while((mpz_cmp_ui(test,2) || !mpz_tstbit((*pub)->n, modulusbits - 1) ); // make sure gcd(p-1,q-1)=2 and first bit of n is set
  86. //complete_pubkey(*pub);
  87. /* compute the public modulus n = p q */
  88. mpz_mul((*pub)->n, (*prv)->p, (*prv)->q);
  89. mpz_mul((*pub)->n_squared, (*pub)->n, (*pub)->n);
  90. #if DJN_DEBUG
  91. if (!mpz_tstbit((*pub)->n, modulusbits - 1)) {
  92. printf("DJN n too small!?\n");
  93. }
  94. #endif
  95. /* p^2 and q^2 */
  96. mpz_mul((*prv)->p_squared, (*prv)->p, (*prv)->p);
  97. mpz_mul((*prv)->q_squared, (*prv)->q, (*prv)->q);
  98. mpz_sub((*prv)->ordpsq, (*prv)->p_squared, (*prv)->p);
  99. mpz_sub((*prv)->ordqsq, (*prv)->q_squared, (*prv)->q);
  100. /* computer multiplicative inverse of q mod p and q^2 mod p^2 for CRT*/
  101. mpz_invert((*prv)->q_inverse, (*prv)->q, (*prv)->p);
  102. mpz_invert((*prv)->q_squared_inverse, (*prv)->q_squared, (*prv)->p_squared);
  103. /* save one multiplication for CRT */
  104. mpz_mul((*prv)->q_squared_inverse, (*prv)->q_squared_inverse, (*prv)->q_squared);
  105. mpz_mul((*prv)->q_inverse, (*prv)->q_inverse, (*prv)->q);
  106. #if DJN_DEBUG
  107. gmp_printf("p = %Zd\nq = %Zd\nn = %Zd\nn^2 = %Zd\n", (*prv)->p, (*prv)->q, (*pub)->n, (*pub)->n_squared);
  108. #endif
  109. /* pick random x in Z_n^* */
  110. do {
  111. aby_prng(x, mpz_sizeinbase((*pub)->n, 2) + 128);
  112. mpz_mod(x, x, (*pub)->n);
  113. mpz_gcd(test, x, (*pub)->n);
  114. } while (mpz_cmp_ui(test, 1));
  115. // gmp_printf("x = %Zd\n", x);
  116. mpz_mul(x, x, x);
  117. // gmp_printf("x^2 = %Zd\n", x);
  118. mpz_neg(x, x);
  119. // gmp_printf("-x^2 = %Zd\n", x);
  120. mpz_mod((*pub)->h, x, (*pub)->n);
  121. // mpz_powm((*pub)->h_s, (*pub)->h, (*pub)->n, (*pub)->n_squared);
  122. djn_pow_mod_n_squared_crt((*pub)->h_s, (*pub)->h, (*pub)->n, *pub, *prv);
  123. (*pub)->bits = modulusbits;
  124. (*pub)->rbits = modulusbits % 2 ? modulusbits / 2 + 1 : modulusbits / 2; // rbits = ceil(bits/2)
  125. /* compute the private key lambda = lcm(p-1,q-1) = (p-1)(q-1)/2 */
  126. //mpz_lcm((*prv)->lambda, (*prv)->p_minusone, (*prv)->q_minusone);
  127. mpz_mul((*prv)->lambda, (*prv)->p_minusone, (*prv)->q_minusone);
  128. mpz_fdiv_q_2exp((*prv)->lambda, (*prv)->lambda, 1); // division by two
  129. /* compute multiplicative inverse of lambda */
  130. mpz_invert((*prv)->lambda_inverse, (*prv)->lambda, (*pub)->n);
  131. #if DJN_DEBUG
  132. gmp_printf("h = %Zd\nh_s = %Zd\n", (*pub)->h, (*pub)->h_s);
  133. printf("rbits = %d, bits = %d\n", (*pub)->rbits, (*pub)->bits);
  134. gmp_printf("lambda = %Zd\nlambda_inverse = %Zd\n", (*prv)->lambda, (*prv)->lambda_inverse);
  135. #endif
  136. /* clear temporary integers */
  137. mpz_clears(x, test, NULL);
  138. }
  139. /**
  140. * encrypt plaintext to res
  141. */
  142. void djn_encrypt(mpz_t res, djn_pubkey_t* pub, mpz_t plaintext) {
  143. mpz_t r;
  144. mpz_init(r);
  145. #if DJN_CHECKSIZE
  146. if (mpz_cmp(plaintext, pub->n) >= 0) {
  147. printf("WARNING: m>=N!\n");
  148. }
  149. #endif
  150. /* pick random blinding factor r */
  151. aby_prng(r, pub->rbits);
  152. #if DJN_DEBUG
  153. gmp_printf("r = %Zd\n", r);
  154. #endif
  155. mpz_mul(res, plaintext, pub->n);
  156. mpz_add_ui(res, res, 1);
  157. mpz_mod(res, res, pub->n_squared);
  158. mpz_powm(r, pub->h_s, r, pub->n_squared);
  159. mpz_mul(res, res, r);
  160. mpz_mod(res, res, pub->n_squared);
  161. mpz_clear(r);
  162. }
  163. /**
  164. * encrypt plaintext using crt if private key is known
  165. */
  166. void djn_encrypt_crt(mpz_t res, djn_pubkey_t* pub, djn_prvkey_t* prv, mpz_t plaintext) {
  167. mpz_t r;
  168. mpz_init(r);
  169. #if DJN_CHECKSIZE
  170. if (mpz_cmp(plaintext, pub->n) >= 0) {
  171. printf("WARNING: m>=N!\n");
  172. }
  173. #endif
  174. /* pick random blinding factor r */
  175. aby_prng(r, pub->rbits);
  176. #if DJN_DEBUG
  177. gmp_printf("r = %Zd\n", r);
  178. #endif
  179. mpz_mul(res, plaintext, pub->n);
  180. mpz_add_ui(res, res, 1);
  181. mpz_mod(res, res, pub->n_squared);
  182. djn_pow_mod_n_squared_crt(r, pub->h_s, r, pub, prv);
  183. mpz_mul(res, res, r);
  184. mpz_mod(res, res, pub->n_squared);
  185. mpz_clear(r);
  186. }
  187. /**
  188. * mpz_t version of encrypt_crt
  189. */
  190. void djn_encrypt_fb(mpz_t res, djn_pubkey_t* pub, mpz_t plaintext) {
  191. mpz_t r;
  192. mpz_init(r);
  193. #if DJN_CHECKSIZE
  194. if (mpz_cmp(plaintext, pub->n) >= 0) {
  195. printf("WARNING: m>=N!\n");
  196. }
  197. #endif
  198. /* pick random blinding factor r */
  199. aby_prng(r, pub->rbits);
  200. #if DJN_DEBUG
  201. gmp_printf("r = %Zd\n", r);
  202. #endif
  203. mpz_mul(res, plaintext, pub->n);
  204. mpz_add_ui(res, res, 1);
  205. mpz_mod(res, res, pub->n_squared);
  206. // r = h_s ^ r
  207. fbpowmod_g(r, r);
  208. mpz_mul(res, res, r);
  209. mpz_mod(res, res, pub->n_squared);
  210. mpz_clear(r);
  211. }
  212. /**
  213. * decrypt, using CRT, assumes res to be initialized
  214. */
  215. void djn_decrypt(mpz_t res, djn_pubkey_t* pub, djn_prvkey_t* prv, mpz_t ciphertext) {
  216. /* powmod using CRT */
  217. djn_pow_mod_n_squared_crt(res, ciphertext, prv->lambda, pub, prv);
  218. mpz_sub_ui(res, res, 1);
  219. mpz_divexact(res, res, pub->n);
  220. mpz_mul(res, res, prv->lambda_inverse);
  221. mpz_mod(res, res, pub->n);
  222. }
  223. /**
  224. * plain decrypt version without crt (= much slower), assumes res to be initialized
  225. */
  226. void djn_decrypt_plain(mpz_t res, djn_pubkey_t* pub, djn_prvkey_t* prv, mpz_t ciphertext) {
  227. mpz_powm(res, ciphertext, prv->lambda, pub->n_squared);
  228. mpz_sub_ui(res, res, 1);
  229. mpz_divexact(res, res, pub->n);
  230. mpz_mul(res, res, prv->lambda_inverse);
  231. mpz_mod(res, res, pub->n);
  232. }
  233. void djn_freepubkey(djn_pubkey_t* pub) {
  234. mpz_clear(pub->n);
  235. mpz_clear(pub->h);
  236. mpz_clear(pub->n_squared);
  237. mpz_clear(pub->h_s);
  238. free(pub);
  239. }
  240. void djn_freeprvkey(djn_prvkey_t* prv) {
  241. mpz_clears(prv->lambda, prv->lambda_inverse, prv->ordpsq, prv->ordqsq, prv->p, prv->p_minusone, prv->p_squared, prv->q,
  242. prv->q_minusone, prv->q_squared, prv->q_inverse, prv->q_squared_inverse,
  243. NULL);
  244. free(prv);
  245. }
  246. /* calculate base^exp mod n using fermats little theorem and CRT */
  247. void djn_pow_mod_n_crt(mpz_t res, const mpz_t base, const mpz_t exp, const djn_pubkey_t* pub, const djn_prvkey_t* prv) {
  248. mpz_t temp, cp, cq;
  249. mpz_inits(cp, cq, temp, NULL);
  250. /* smaller exponents due to fermat: e mod (p-1), e mod (q-1) */
  251. mpz_mod(cp, exp, prv->p_minusone);
  252. mpz_mod(cq, exp, prv->q_minusone);
  253. /* smaller exponentiations of base mod p, q */
  254. mpz_mod(temp, base, prv->p);
  255. mpz_powm(cp, temp, cp, prv->p);
  256. mpz_mod(temp, base, prv->q);
  257. mpz_powm(cq, temp, cq, prv->q);
  258. /* CRT to calculate base^exp mod (pq) */
  259. mpz_sub(cp, cp, cq);
  260. mpz_addmul(cq, cp, prv->q_inverse);
  261. mpz_mod(res, cq, pub->n);
  262. mpz_clears(cp, cq, temp, NULL);
  263. }
  264. /* calculate base^exp mod n^2 using fermats little theorem and CRT */
  265. void djn_pow_mod_n_squared_crt(mpz_t res, const mpz_t base, const mpz_t exp, const djn_pubkey_t* pub, const djn_prvkey_t* prv) {
  266. mpz_t temp, cp, cq;
  267. mpz_inits(cp, cq, temp, NULL);
  268. /* smaller exponents due to fermat: e mod (p-1), e mod (q-1) */
  269. mpz_mod(cp, exp, prv->ordpsq);
  270. mpz_mod(cq, exp, prv->ordqsq);
  271. /* smaller exponentiations of base mod p^2, q^2 */
  272. mpz_mod(temp, base, prv->p_squared);
  273. mpz_powm(cp, temp, cp, prv->p_squared);
  274. mpz_mod(temp, base, prv->q_squared);
  275. mpz_powm(cq, temp, cq, prv->q_squared);
  276. /* CRT to calculate base^exp mod n^2 */
  277. mpz_sub(cp, cp, cq);
  278. mpz_addmul(cq, cp, prv->q_squared_inverse);
  279. mpz_mod(res, cq, pub->n_squared);
  280. mpz_clears(cp, cq, temp, NULL);
  281. }