ed25519-donna-batchverify.h 7.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275
  /*
      Ed25519 batch verification
  */
  4. #define max_batch_size 64
  5. #define heap_batch_size ((max_batch_size * 2) + 1)
  6. /* which limb is the 128th bit in? */
  7. static const size_t limb128bits = (128 + bignum256modm_bits_per_limb - 1) / bignum256modm_bits_per_limb;
  8. typedef size_t heap_index_t;
  9. typedef struct batch_heap_t {
  10. unsigned char r[heap_batch_size][16]; /* 128 bit random values */
  11. ge25519 points[heap_batch_size];
  12. bignum256modm scalars[heap_batch_size];
  13. heap_index_t heap[heap_batch_size];
  14. size_t size;
  15. } batch_heap;
  16. /* swap two values in the heap */
  17. static void
  18. heap_swap(heap_index_t *heap, size_t a, size_t b) {
  19. heap_index_t temp;
  20. temp = heap[a];
  21. heap[a] = heap[b];
  22. heap[b] = temp;
  23. }
  24. /* add the scalar at the end of the list to the heap */
  25. static void
  26. heap_insert_next(batch_heap *heap) {
  27. size_t node = heap->size, parent;
  28. heap_index_t *pheap = heap->heap;
  29. bignum256modm *scalars = heap->scalars;
  30. /* insert at the bottom */
  31. pheap[node] = (heap_index_t)node;
  32. /* sift node up to its sorted spot */
  33. parent = (node - 1) / 2;
  34. while (node && lt256_modm_batch(scalars[pheap[parent]], scalars[pheap[node]], bignum256modm_limb_size - 1)) {
  35. heap_swap(pheap, parent, node);
  36. node = parent;
  37. parent = (node - 1) / 2;
  38. }
  39. heap->size++;
  40. }
  41. /* update the heap when the root element is updated */
  42. static void
  43. heap_updated_root(batch_heap *heap, size_t limbsize) {
  44. size_t node, parent, childr, childl;
  45. heap_index_t *pheap = heap->heap;
  46. bignum256modm *scalars = heap->scalars;
  47. /* sift root to the bottom */
  48. parent = 0;
  49. node = 1;
  50. childl = 1;
  51. childr = 2;
  52. while ((childr < heap->size)) {
  53. node = lt256_modm_batch(scalars[pheap[childl]], scalars[pheap[childr]], limbsize) ? childr : childl;
  54. heap_swap(pheap, parent, node);
  55. parent = node;
  56. childl = (parent * 2) + 1;
  57. childr = childl + 1;
  58. }
  59. /* sift root back up to its sorted spot */
  60. parent = (node - 1) / 2;
  61. while (node && lte256_modm_batch(scalars[pheap[parent]], scalars[pheap[node]], limbsize)) {
  62. heap_swap(pheap, parent, node);
  63. node = parent;
  64. parent = (node - 1) / 2;
  65. }
  66. }
  67. /* build the heap with count elements, count must be >= 3 */
  68. static void
  69. heap_build(batch_heap *heap, size_t count) {
  70. heap->heap[0] = 0;
  71. heap->size = 0;
  72. while (heap->size < count)
  73. heap_insert_next(heap);
  74. }
  75. /* extend the heap to contain new_count elements */
  76. static void
  77. heap_extend(batch_heap *heap, size_t new_count) {
  78. while (heap->size < new_count)
  79. heap_insert_next(heap);
  80. }
  81. /* get the top 2 elements of the heap */
  82. static void
  83. heap_get_top2(batch_heap *heap, heap_index_t *max1, heap_index_t *max2, size_t limbsize) {
  84. heap_index_t h0 = heap->heap[0], h1 = heap->heap[1], h2 = heap->heap[2];
  85. if (lt256_modm_batch(heap->scalars[h1], heap->scalars[h2], limbsize))
  86. h1 = h2;
  87. *max1 = h0;
  88. *max2 = h1;
  89. }
  90. /* */
  91. static void
  92. ge25519_multi_scalarmult_vartime_final(ge25519 *r, ge25519 *point, bignum256modm scalar) {
  93. const bignum256modm_element_t topbit = ((bignum256modm_element_t)1 << (bignum256modm_bits_per_limb - 1));
  94. size_t limb = limb128bits;
  95. bignum256modm_element_t flag;
  96. if (isone256_modm_batch(scalar)) {
  97. /* this will happen most of the time after bos-carter */
  98. *r = *point;
  99. return;
  100. } else if (iszero256_modm_batch(scalar)) {
  101. /* this will only happen if all scalars == 0 */
  102. memset(r, 0, sizeof(*r));
  103. r->y[0] = 1;
  104. r->z[0] = 1;
  105. return;
  106. }
  107. *r = *point;
  108. /* find the limb where first bit is set */
  109. while (!scalar[limb])
  110. limb--;
  111. /* find the first bit */
  112. flag = topbit;
  113. while ((scalar[limb] & flag) == 0)
  114. flag >>= 1;
  115. /* exponentiate */
  116. for (;;) {
  117. ge25519_double(r, r);
  118. if (scalar[limb] & flag)
  119. ge25519_add(r, r, point);
  120. flag >>= 1;
  121. if (!flag) {
  122. if (!limb--)
  123. break;
  124. flag = topbit;
  125. }
  126. }
  127. }
  128. /* count must be >= 5 */
  129. static void
  130. ge25519_multi_scalarmult_vartime(ge25519 *r, batch_heap *heap, size_t count) {
  131. heap_index_t max1, max2;
  132. /* start with the full limb size */
  133. size_t limbsize = bignum256modm_limb_size - 1;
  134. /* whether the heap has been extended to include the 128 bit scalars */
  135. int extended = 0;
  136. /* grab an odd number of scalars to build the heap, unknown limb sizes */
  137. heap_build(heap, ((count + 1) / 2) | 1);
  138. for (;;) {
  139. heap_get_top2(heap, &max1, &max2, limbsize);
  140. /* only one scalar remaining, we're done */
  141. if (iszero256_modm_batch(heap->scalars[max2]))
  142. break;
  143. /* exhausted another limb? */
  144. if (!heap->scalars[max1][limbsize])
  145. limbsize -= 1;
  146. /* can we extend to the 128 bit scalars? */
  147. if (!extended && isatmost128bits256_modm_batch(heap->scalars[max1])) {
  148. heap_extend(heap, count);
  149. heap_get_top2(heap, &max1, &max2, limbsize);
  150. extended = 1;
  151. }
  152. sub256_modm_batch(heap->scalars[max1], heap->scalars[max1], heap->scalars[max2], limbsize);
  153. ge25519_add(&heap->points[max2], &heap->points[max2], &heap->points[max1]);
  154. heap_updated_root(heap, limbsize);
  155. }
  156. ge25519_multi_scalarmult_vartime_final(r, &heap->points[max1], heap->scalars[max1]);
  157. }
  158. /* not actually used for anything other than testing */
  159. static unsigned char batch_point_buffer[3][32];
  160. static int
  161. ge25519_is_neutral_vartime(const ge25519 *p) {
  162. static const unsigned char zero[32] = {0};
  163. unsigned char point_buffer[3][32];
  164. curve25519_contract(point_buffer[0], p->x);
  165. curve25519_contract(point_buffer[1], p->y);
  166. curve25519_contract(point_buffer[2], p->z);
  167. memcpy(batch_point_buffer[1], point_buffer[1], 32);
  168. return (memcmp(point_buffer[0], zero, 32) == 0) && (memcmp(point_buffer[1], point_buffer[2], 32) == 0);
  169. }
  170. int
  171. ED25519_FN(ed25519_sign_open_batch) (const unsigned char **m, size_t *mlen, const unsigned char **pk, const unsigned char **RS, size_t num, int *valid) {
  172. batch_heap ALIGN(16) batch;
  173. ge25519 ALIGN(16) p;
  174. bignum256modm *r_scalars;
  175. size_t i, batchsize;
  176. unsigned char hram[64];
  177. int ret = 0;
  178. for (i = 0; i < num; i++)
  179. valid[i] = 1;
  180. while (num > 3) {
  181. batchsize = (num > max_batch_size) ? max_batch_size : num;
  182. /* generate r (scalars[batchsize+1]..scalars[2*batchsize] */
  183. ED25519_FN(ed25519_randombytes_unsafe) (batch.r, batchsize * 16);
  184. r_scalars = &batch.scalars[batchsize + 1];
  185. for (i = 0; i < batchsize; i++)
  186. expand256_modm(r_scalars[i], batch.r[i], 16);
  187. /* compute scalars[0] = ((r1s1 + r2s2 + ...)) */
  188. for (i = 0; i < batchsize; i++) {
  189. expand256_modm(batch.scalars[i], RS[i] + 32, 32);
  190. mul256_modm(batch.scalars[i], batch.scalars[i], r_scalars[i]);
  191. }
  192. for (i = 1; i < batchsize; i++)
  193. add256_modm(batch.scalars[0], batch.scalars[0], batch.scalars[i]);
  194. /* compute scalars[1]..scalars[batchsize] as r[i]*H(R[i],A[i],m[i]) */
  195. for (i = 0; i < batchsize; i++) {
  196. ed25519_hram(hram, RS[i], pk[i], m[i], mlen[i]);
  197. expand256_modm(batch.scalars[i+1], hram, 64);
  198. mul256_modm(batch.scalars[i+1], batch.scalars[i+1], r_scalars[i]);
  199. }
  200. /* compute points */
  201. batch.points[0] = ge25519_basepoint;
  202. for (i = 0; i < batchsize; i++)
  203. if (!ge25519_unpack_negative_vartime(&batch.points[i+1], pk[i]))
  204. goto fallback;
  205. for (i = 0; i < batchsize; i++)
  206. if (!ge25519_unpack_negative_vartime(&batch.points[batchsize+i+1], RS[i]))
  207. goto fallback;
  208. ge25519_multi_scalarmult_vartime(&p, &batch, (batchsize * 2) + 1);
  209. if (!ge25519_is_neutral_vartime(&p)) {
  210. ret |= 2;
  211. fallback:
  212. for (i = 0; i < batchsize; i++) {
  213. valid[i] = ED25519_FN(ed25519_sign_open) (m[i], mlen[i], pk[i], RS[i]) ? 0 : 1;
  214. ret |= (valid[i] ^ 1);
  215. }
  216. }
  217. m += batchsize;
  218. mlen += batchsize;
  219. pk += batchsize;
  220. RS += batchsize;
  221. num -= batchsize;
  222. valid += batchsize;
  223. }
  224. for (i = 0; i < num; i++) {
  225. valid[i] = ED25519_FN(ed25519_sign_open) (m[i], mlen[i], pk[i], RS[i]) ? 0 : 1;
  226. ret |= (valid[i] ^ 1);
  227. }
  228. return ret;
  229. }