ntt.rs 7.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227
  1. use crate::{
  2. arith::*,
  3. number_theory::*,
  4. params::*,
  5. poly::*,
  6. util::*,
  7. };
  8. pub fn powers_of_primitive_root(root: u64, modulus: u64, poly_len_log2: usize) -> Vec<u64> {
  9. let poly_len = 1usize << poly_len_log2;
  10. let mut root_powers = vec![0u64; poly_len];
  11. let mut power = root;
  12. for i in 1..poly_len {
  13. let idx = reverse_bits(i as u64, poly_len_log2) as usize;
  14. root_powers[idx] = power;
  15. power = multiply_uint_mod(power, root, modulus);
  16. }
  17. root_powers[0] = 1;
  18. root_powers
  19. }
  20. pub fn scale_powers_u64(modulus: u64, poly_len: usize, inp: &[u64]) -> Vec<u64> {
  21. let mut scaled_powers = vec![0; poly_len];
  22. for i in 0..poly_len {
  23. let wide_val = (inp[i] as u128) << 64u128;
  24. let quotient = wide_val / (modulus as u128);
  25. scaled_powers[i] = quotient as u64;
  26. }
  27. scaled_powers
  28. }
  29. pub fn scale_powers_u32(modulus: u32, poly_len: usize, inp: &[u64]) -> Vec<u64> {
  30. let mut scaled_powers = vec![0; poly_len];
  31. for i in 0..poly_len {
  32. let wide_val = inp[i] << 32;
  33. let quotient = wide_val / (modulus as u64);
  34. scaled_powers[i] = (quotient as u32) as u64;
  35. }
  36. scaled_powers
  37. }
  38. pub fn build_ntt_tables(poly_len: usize, moduli: &[u64]) -> Vec<Vec<Vec<u64>>> {
  39. let poly_len_log2 = log2(poly_len as u64) as usize;
  40. let mut output: Vec<Vec<Vec<u64>>> = vec![Vec::new(); moduli.len()];
  41. for coeff_mod in 0..moduli.len() {
  42. let modulus = moduli[coeff_mod];
  43. let modulus_as_u32 = modulus.try_into().unwrap();
  44. let root = get_minimal_primitive_root(2 * poly_len as u64, modulus).unwrap();
  45. let inv_root = invert_uint_mod(root, modulus).unwrap();
  46. let root_powers = powers_of_primitive_root(root, modulus, poly_len_log2);
  47. let scaled_root_powers = scale_powers_u32(modulus_as_u32, poly_len, root_powers.as_slice());
  48. let mut inv_root_powers = powers_of_primitive_root(inv_root, modulus, poly_len_log2);
  49. for i in 0..poly_len {
  50. inv_root_powers[i] = div2_uint_mod(inv_root_powers[i], modulus);
  51. }
  52. let mut scaled_inv_root_powers =
  53. scale_powers_u32(modulus_as_u32, poly_len, inv_root_powers.as_slice());
  54. output[coeff_mod] = vec![
  55. root_powers,
  56. scaled_root_powers,
  57. inv_root_powers,
  58. scaled_inv_root_powers,
  59. ];
  60. }
  61. output
  62. }
  63. pub fn ntt_forward(params: &Params, operand_overall: &mut [u64]) {
  64. let log_n = params.poly_len_log2;
  65. let n = 1 << log_n;
  66. for coeff_mod in 0..params.crt_count {
  67. let operand = &mut operand_overall[coeff_mod*n..coeff_mod*n+n];
  68. let forward_table = params.get_ntt_forward_table(coeff_mod);
  69. let forward_table_prime = params.get_ntt_forward_prime_table(coeff_mod);
  70. let modulus_small = params.moduli[coeff_mod] as u32;
  71. let two_times_modulus_small: u32 = 2 * modulus_small;
  72. for mm in 0..log_n {
  73. let m = 1 << mm;
  74. let t = n >> (mm + 1);
  75. let mut it = operand.chunks_exact_mut(2 * t);
  76. for i in 0..m {
  77. let w = forward_table[m+i];
  78. let w_prime = forward_table_prime[m+i];
  79. let op = it.next().unwrap();
  80. for j in 0..t {
  81. let x: u32 = op[j] as u32;
  82. let y: u32 = op[t + j] as u32;
  83. let currX: u32 = x - (two_times_modulus_small * ((x >= two_times_modulus_small) as u32));
  84. let Q: u64 = ((y as u64) * (w_prime as u64)) >> 32u64;
  85. let new_Q = w * (y as u64) - Q * (modulus_small as u64);
  86. op[j] = currX as u64 + new_Q;
  87. op[t + j] = currX as u64 + ((two_times_modulus_small as u64) - new_Q);
  88. }
  89. }
  90. }
  91. for i in 0..n {
  92. operand[i] -= ((operand[i] >= two_times_modulus_small as u64) as u64) * two_times_modulus_small as u64;
  93. operand[i] -= ((operand[i] >= modulus_small as u64) as u64) * modulus_small as u64;
  94. }
  95. }
  96. }
  97. pub fn ntt_inverse(params: &Params, operand_overall: &mut [u64]) {
  98. for coeff_mod in 0..params.crt_count {
  99. let mut n = params.poly_len;
  100. let operand = &mut operand_overall[coeff_mod*n..coeff_mod*n+n];
  101. let inverse_table = params.get_ntt_inverse_table(coeff_mod);
  102. let inverse_table_prime = params.get_ntt_inverse_prime_table(coeff_mod);
  103. let modulus = params.moduli[coeff_mod];
  104. let two_times_modulus: u64 = 2 * modulus;
  105. for mm in (0..params.poly_len_log2).rev() {
  106. let h = 1 << mm;
  107. let t = n >> (mm + 1);
  108. for i in 0..h {
  109. let w = inverse_table[h+i];
  110. let w_prime = inverse_table_prime[h+i];
  111. for j in 0..t {
  112. let x = operand[2 * i * t + j];
  113. let y = operand[2 * i * t + t + j];
  114. let T = two_times_modulus - y + x;
  115. let currU = x + y - (two_times_modulus * (((x << 1) >= T) as u64));
  116. let resX= (currU + (modulus * ((T & 1) as u64))) >> 1;
  117. let H = (T * w_prime) >> 32;
  118. operand[2 * i * t + j] = resX;
  119. operand[2 * i * t + t + j] = w * T - H * modulus;
  120. }
  121. }
  122. }
  123. for i in 0..n {
  124. operand[i] -= ((operand[i] >= two_times_modulus) as u64) * two_times_modulus;
  125. operand[i] -= ((operand[i] >= modulus) as u64) * modulus;
  126. }
  127. }
  128. }
  129. #[cfg(test)]
  130. mod test {
  131. use super::*;
  132. use rand::Rng;
  133. fn get_params() -> Params {
  134. Params::init(2048, vec![268369921u64, 249561089u64])
  135. }
  136. const REF_VAL: u64 = 519370102;
  137. #[test]
  138. fn build_ntt_tables_correct() {
  139. let moduli = [268369921u64, 249561089u64];
  140. let poly_len = 2048usize;
  141. let res = build_ntt_tables(poly_len, moduli.as_slice());
  142. assert_eq!(res.len(), 2);
  143. assert_eq!(res[0].len(), 4);
  144. assert_eq!(res[0][0].len(), poly_len);
  145. assert_eq!(res[0][2][0], 134184961u64);
  146. assert_eq!(res[0][2][1], 96647580u64);
  147. let mut x1 = 0u64;
  148. for i in 0..res.len() {
  149. for j in 0..res[0].len() {
  150. for k in 0..res[0][0].len() {
  151. x1 ^= res[i][j][k];
  152. }
  153. }
  154. }
  155. assert_eq!(x1, REF_VAL);
  156. }
  157. #[test]
  158. fn ntt_forward_correct() {
  159. let params = get_params();
  160. let mut v1 = vec![0; 2*2048];
  161. v1[0] = 100;
  162. v1[2048] = 100;
  163. ntt_forward(&params, v1.as_mut_slice());
  164. assert_eq!(v1[50], 100);
  165. assert_eq!(v1[2048 + 50], 100);
  166. }
  167. #[test]
  168. fn ntt_correct() {
  169. let params = get_params();
  170. let mut v1 = vec![0; params.crt_count * params.poly_len];
  171. let mut rng = rand::thread_rng();
  172. for i in 0..params.crt_count {
  173. for j in 0..params.poly_len {
  174. let idx = calc_index(&[i, j], &[params.crt_count, params.poly_len]);
  175. let val: u64 = rng.gen();
  176. v1[idx] = val % params.moduli[i];
  177. }
  178. }
  179. let mut v2 = v1.clone();
  180. ntt_forward(&params, v2.as_mut_slice());
  181. ntt_inverse(&params, v2.as_mut_slice());
  182. for i in 0..params.crt_count*params.poly_len {
  183. assert_eq!(v1[i], v2[i]);
  184. }
  185. }
  186. #[test]
  187. fn calc_index_correct() {
  188. assert_eq!(calc_index(&[2, 3, 4], &[10, 10, 100]), 2304);
  189. assert_eq!(calc_index(&[2, 3, 4], &[3, 5, 7]), 95);
  190. }
  191. }