ntt.rs 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450
  1. #[cfg(target_feature = "avx2")]
  2. use std::arch::x86_64::*;
  3. use crate::{arith::*, number_theory::*, params::*};
  4. pub fn powers_of_primitive_root(root: u64, modulus: u64, poly_len_log2: usize) -> Vec<u64> {
  5. let poly_len = 1usize << poly_len_log2;
  6. let mut root_powers = vec![0u64; poly_len];
  7. let mut power = root;
  8. for i in 1..poly_len {
  9. let idx = reverse_bits(i as u64, poly_len_log2) as usize;
  10. root_powers[idx] = power;
  11. power = multiply_uint_mod(power, root, modulus);
  12. }
  13. root_powers[0] = 1;
  14. root_powers
  15. }
  /// Scales the first `poly_len` entries of `inp` by `2^64 / modulus`.
  ///
  /// Each output entry is `floor(inp[i] * 2^64 / modulus)`, computed in 128-bit
  /// arithmetic and then truncated to its low 64 bits.
  ///
  /// # Panics
  ///
  /// Panics if `inp` has fewer than `poly_len` entries, or if `modulus` is zero
  /// while there is at least one entry to scale.
  pub fn scale_powers_u64(modulus: u64, poly_len: usize, inp: &[u64]) -> Vec<u64> {
      let divisor = modulus as u128;
      inp[..poly_len]
          .iter()
          .map(|&value| {
              let widened = (value as u128) << 64u128;
              (widened / divisor) as u64
          })
          .collect()
  }
  /// Scales the first `poly_len` entries of `inp` by `2^32 / modulus`.
  ///
  /// Each output entry is `floor((inp[i] << 32) / modulus)` truncated to its low
  /// 32 bits and widened back to `u64`. The shift is done in 64-bit arithmetic,
  /// so any bits of `inp[i]` above bit 31 are discarded by it.
  ///
  /// # Panics
  ///
  /// Panics if `inp` has fewer than `poly_len` entries, or if `modulus` is zero
  /// while there is at least one entry to scale.
  pub fn scale_powers_u32(modulus: u32, poly_len: usize, inp: &[u64]) -> Vec<u64> {
      let divisor = modulus as u64;
      inp[..poly_len]
          .iter()
          .map(|&value| {
              let shifted = value << 32;
              let quotient = shifted / divisor;
              (quotient as u32) as u64
          })
          .collect()
  }
  34. pub fn build_ntt_tables(poly_len: usize, moduli: &[u64]) -> Vec<Vec<Vec<u64>>> {
  35. let poly_len_log2 = log2(poly_len as u64) as usize;
  36. let mut output: Vec<Vec<Vec<u64>>> = vec![Vec::new(); moduli.len()];
  37. for coeff_mod in 0..moduli.len() {
  38. let modulus = moduli[coeff_mod];
  39. let modulus_as_u32 = modulus.try_into().unwrap();
  40. let root = get_minimal_primitive_root(2 * poly_len as u64, modulus).unwrap();
  41. let inv_root = invert_uint_mod(root, modulus).unwrap();
  42. let root_powers = powers_of_primitive_root(root, modulus, poly_len_log2);
  43. let scaled_root_powers = scale_powers_u32(modulus_as_u32, poly_len, root_powers.as_slice());
  44. let mut inv_root_powers = powers_of_primitive_root(inv_root, modulus, poly_len_log2);
  45. for i in 0..poly_len {
  46. inv_root_powers[i] = div2_uint_mod(inv_root_powers[i], modulus);
  47. }
  48. let scaled_inv_root_powers =
  49. scale_powers_u32(modulus_as_u32, poly_len, inv_root_powers.as_slice());
  50. output[coeff_mod] = vec![
  51. root_powers,
  52. scaled_root_powers,
  53. inv_root_powers,
  54. scaled_inv_root_powers,
  55. ];
  56. }
  57. output
  58. }
  59. #[cfg(not(target_feature = "avx2"))]
  60. pub fn ntt_forward(params: &Params, operand_overall: &mut [u64]) {
  61. let log_n = params.poly_len_log2;
  62. let n = 1 << log_n;
  63. for coeff_mod in 0..params.crt_count {
  64. let operand = &mut operand_overall[coeff_mod * n..coeff_mod * n + n];
  65. let forward_table = params.get_ntt_forward_table(coeff_mod);
  66. let forward_table_prime = params.get_ntt_forward_prime_table(coeff_mod);
  67. let modulus_small = params.moduli[coeff_mod] as u32;
  68. let two_times_modulus_small: u32 = 2 * modulus_small;
  69. for mm in 0..log_n {
  70. let m = 1 << mm;
  71. let t = n >> (mm + 1);
  72. let mut it = operand.chunks_exact_mut(2 * t);
  73. for i in 0..m {
  74. let w = forward_table[m + i];
  75. let w_prime = forward_table_prime[m + i];
  76. let op = it.next().unwrap();
  77. for j in 0..t {
  78. let x: u32 = op[j] as u32;
  79. let y: u32 = op[t + j] as u32;
  80. let curr_x: u32 =
  81. x - (two_times_modulus_small * ((x >= two_times_modulus_small) as u32));
  82. let q_tmp: u64 = ((y as u64) * (w_prime as u64)) >> 32u64;
  83. let q_new = w * (y as u64) - q_tmp * (modulus_small as u64);
  84. op[j] = curr_x as u64 + q_new;
  85. op[t + j] = curr_x as u64 + ((two_times_modulus_small as u64) - q_new);
  86. }
  87. }
  88. }
  89. for i in 0..n {
  90. operand[i] -= ((operand[i] >= two_times_modulus_small as u64) as u64)
  91. * two_times_modulus_small as u64;
  92. operand[i] -= ((operand[i] >= modulus_small as u64) as u64) * modulus_small as u64;
  93. }
  94. }
  95. }
  /// Forward NTT, AVX2 implementation (compiled when AVX2 is statically enabled).
  ///
  /// `operand_overall` is transformed in place. It is read as `params.crt_count`
  /// consecutive blocks of `n = 2^params.poly_len_log2` coefficients, block `i`
  /// being processed with modulus `params.moduli[i]` (cast to `u32`) and that
  /// modulus's forward tables.
  ///
  /// Stages whose half-width `t` is below 4 use the scalar butterfly; wider
  /// stages process four butterflies per iteration. Every comparison is a
  /// `>=` test, exactly as in the scalar code, so both code paths produce the
  /// same values and the output is fully reduced.
  ///
  /// Loads and stores use the unaligned intrinsics, so `operand_overall` does
  /// not need any particular alignment.
  // NOTE(review): the arithmetic presumably relies on `4 * q` fitting in 32 bits
  // (the SIMD multiplies only read the low 32 bits of each lane) — confirm.
  #[cfg(target_feature = "avx2")]
  pub fn ntt_forward(params: &Params, operand_overall: &mut [u64]) {
      let log_n = params.poly_len_log2;
      let n = 1 << log_n;
      for coeff_mod in 0..params.crt_count {
          let operand = &mut operand_overall[coeff_mod * n..coeff_mod * n + n];
          let forward_table = params.get_ntt_forward_table(coeff_mod);
          let forward_table_prime = params.get_ntt_forward_prime_table(coeff_mod);
          let modulus_small = params.moduli[coeff_mod] as u32;
          let two_times_modulus_small: u32 = 2 * modulus_small;

          for mm in 0..log_n {
              let m = 1 << mm;
              let t = n >> (mm + 1);
              let mut it = operand.chunks_exact_mut(2 * t);
              for i in 0..m {
                  let w = forward_table[m + i];
                  let w_prime = forward_table_prime[m + i];
                  let op = it.next().unwrap();
                  if t < 4 {
                      // Too narrow for a 4-lane vector: scalar butterflies.
                      for j in 0..t {
                          let x: u32 = op[j] as u32;
                          let y: u32 = op[t + j] as u32;
                          let curr_x: u32 =
                              x - (two_times_modulus_small * ((x >= two_times_modulus_small) as u32));
                          let q_tmp: u64 = ((y as u64) * (w_prime as u64)) >> 32u64;
                          let q_new = w * (y as u64) - q_tmp * (modulus_small as u64);
                          op[j] = curr_x as u64 + q_new;
                          op[t + j] = curr_x as u64 + ((two_times_modulus_small as u64) - q_new);
                      }
                  } else {
                      let (xs, ys) = op.split_at_mut(t);
                      // SAFETY: AVX2 is statically enabled (see the `cfg` on this
                      // function). Every lane slice is exactly four `u64`s, so each
                      // 256-bit unaligned load/store stays inside its slice.
                      unsafe {
                          let two_q_vec = _mm256_set1_epi64x(two_times_modulus_small as i64);
                          let q_vec = _mm256_set1_epi64x(modulus_small as i64);
                          let w_vec = _mm256_set1_epi64x(w as i64);
                          let w_prime_vec = _mm256_set1_epi64x(w_prime as i64);

                          for (x_lane, y_lane) in
                              xs.chunks_exact_mut(4).zip(ys.chunks_exact_mut(4))
                          {
                              let p_x = x_lane.as_mut_ptr() as *mut __m256i;
                              let p_y = y_lane.as_mut_ptr() as *mut __m256i;
                              let x = _mm256_loadu_si256(p_x as *const __m256i);
                              let y = _mm256_loadu_si256(p_y as *const __m256i);

                              // curr_x = x - 2q * (x >= 2q). AVX2 only offers a strict
                              // `>`, so build `x >= 2q` as `!(2q > x)` via andnot.
                              let below_two_q = _mm256_cmpgt_epi64(two_q_vec, x);
                              let to_subtract = _mm256_andnot_si256(below_two_q, two_q_vec);
                              let curr_x = _mm256_sub_epi64(x, to_subtract);

                              // q_val = (y * w_prime) >> 32
                              let product = _mm256_mul_epu32(y, w_prime_vec);
                              let q_val = _mm256_srli_epi64(product, 32);

                              // q_final = w * y - q_val * q
                              let w_times_y = _mm256_mul_epu32(y, w_vec);
                              let q_scaled = _mm256_mul_epu32(q_val, q_vec);
                              let q_final = _mm256_sub_epi64(w_times_y, q_scaled);

                              let new_x = _mm256_add_epi64(curr_x, q_final);
                              let q_final_inverted = _mm256_sub_epi64(two_q_vec, q_final);
                              let new_y = _mm256_add_epi64(curr_x, q_final_inverted);

                              _mm256_storeu_si256(p_x, new_x);
                              _mm256_storeu_si256(p_y, new_y);
                          }
                      }
                  }
              }
          }

          // Final reduction: subtract 2q where value >= 2q, then q where value >= q.
          let mut lanes = operand.chunks_exact_mut(4);
          // SAFETY: AVX2 is statically enabled, and every `lane` is exactly four
          // `u64`s, so the unaligned 256-bit load/store stays in bounds.
          unsafe {
              let two_q_vec = _mm256_set1_epi64x(two_times_modulus_small as i64);
              let q_vec = _mm256_set1_epi64x(modulus_small as i64);
              for lane in &mut lanes {
                  let p = lane.as_mut_ptr() as *mut __m256i;
                  let mut x = _mm256_loadu_si256(p as *const __m256i);
                  let below_two_q = _mm256_cmpgt_epi64(two_q_vec, x);
                  x = _mm256_sub_epi64(x, _mm256_andnot_si256(below_two_q, two_q_vec));
                  let below_q = _mm256_cmpgt_epi64(q_vec, x);
                  x = _mm256_sub_epi64(x, _mm256_andnot_si256(below_q, q_vec));
                  _mm256_storeu_si256(p, x);
              }
          }
          // Blocks shorter than one vector (n < 4) are reduced with scalar code.
          for value in lanes.into_remainder() {
              if *value >= two_times_modulus_small as u64 {
                  *value -= two_times_modulus_small as u64;
              }
              if *value >= modulus_small as u64 {
                  *value -= modulus_small as u64;
              }
          }
      }
  }
  174. #[cfg(not(target_feature = "avx2"))]
  175. pub fn ntt_inverse(params: &Params, operand_overall: &mut [u64]) {
  176. for coeff_mod in 0..params.crt_count {
  177. let n = params.poly_len;
  178. let operand = &mut operand_overall[coeff_mod * n..coeff_mod * n + n];
  179. let inverse_table = params.get_ntt_inverse_table(coeff_mod);
  180. let inverse_table_prime = params.get_ntt_inverse_prime_table(coeff_mod);
  181. let modulus = params.moduli[coeff_mod];
  182. let two_times_modulus: u64 = 2 * modulus;
  183. for mm in (0..params.poly_len_log2).rev() {
  184. let h = 1 << mm;
  185. let t = n >> (mm + 1);
  186. let mut it = operand.chunks_exact_mut(2 * t);
  187. for i in 0..h {
  188. let w = inverse_table[h + i];
  189. let w_prime = inverse_table_prime[h + i];
  190. let op = it.next().unwrap();
  191. for j in 0..t {
  192. let x = op[j];
  193. let y = op[t + j];
  194. let t_tmp = two_times_modulus - y + x;
  195. let curr_x = x + y - (two_times_modulus * (((x << 1) >= t_tmp) as u64));
  196. let h_tmp = (t_tmp * w_prime) >> 32;
  197. let res_x = (curr_x + (modulus * ((t_tmp & 1) as u64))) >> 1;
  198. let res_y = w * t_tmp - h_tmp * modulus;
  199. op[j] = res_x;
  200. op[t + j] = res_y;
  201. }
  202. }
  203. }
  204. for i in 0..n {
  205. operand[i] -= ((operand[i] >= two_times_modulus) as u64) * two_times_modulus;
  206. operand[i] -= ((operand[i] >= modulus) as u64) * modulus;
  207. }
  208. }
  209. }
  /// Inverse NTT, AVX2 implementation (compiled when AVX2 is statically enabled).
  ///
  /// `operand_overall` is transformed in place. It is read as `params.crt_count`
  /// consecutive blocks of `params.poly_len` coefficients, block `i` being
  /// processed with modulus `params.moduli[i]` and that modulus's inverse
  /// tables.
  ///
  /// Stages whose half-width `t` is below 4 use the scalar butterfly; wider
  /// stages process four butterflies per iteration with the same `>=`
  /// comparison as the scalar code. The final reduction is done with scalar
  /// code.
  ///
  /// Loads and stores use the unaligned intrinsics, so `operand_overall` does
  /// not need any particular alignment.
  // NOTE(review): presumably requires `params.poly_len == 1 << params.poly_len_log2`
  // and `4 * q` fitting in 32 bits (the SIMD multiplies only read the low 32 bits
  // of each lane) — confirm.
  #[cfg(target_feature = "avx2")]
  pub fn ntt_inverse(params: &Params, operand_overall: &mut [u64]) {
      for coeff_mod in 0..params.crt_count {
          let n = params.poly_len;
          let operand = &mut operand_overall[coeff_mod * n..coeff_mod * n + n];
          let inverse_table = params.get_ntt_inverse_table(coeff_mod);
          let inverse_table_prime = params.get_ntt_inverse_prime_table(coeff_mod);
          let modulus = params.moduli[coeff_mod];
          let two_times_modulus: u64 = 2 * modulus;

          for mm in (0..params.poly_len_log2).rev() {
              let h = 1 << mm;
              let t = n >> (mm + 1);
              let mut it = operand.chunks_exact_mut(2 * t);
              for i in 0..h {
                  let w = inverse_table[h + i];
                  let w_prime = inverse_table_prime[h + i];
                  let op = it.next().unwrap();
                  if t < 4 {
                      // Too narrow for a 4-lane vector: scalar butterflies.
                      for j in 0..t {
                          let x = op[j];
                          let y = op[t + j];
                          let t_tmp = two_times_modulus - y + x;
                          let curr_x = x + y - (two_times_modulus * (((x << 1) >= t_tmp) as u64));
                          let h_tmp = (t_tmp * w_prime) >> 32;
                          let res_x = (curr_x + (modulus * ((t_tmp & 1) as u64))) >> 1;
                          let res_y = w * t_tmp - h_tmp * modulus;
                          op[j] = res_x;
                          op[t + j] = res_y;
                      }
                  } else {
                      let (xs, ys) = op.split_at_mut(t);
                      // SAFETY: AVX2 is statically enabled (see the `cfg` on this
                      // function). Every lane slice is exactly four `u64`s, so each
                      // 256-bit unaligned load/store stays inside its slice.
                      unsafe {
                          let modulus_vec = _mm256_set1_epi64x(modulus as i64);
                          let two_times_modulus_vec = _mm256_set1_epi64x(two_times_modulus as i64);
                          let w_vec = _mm256_set1_epi64x(w as i64);
                          let w_prime_vec = _mm256_set1_epi64x(w_prime as i64);
                          let one_vec = _mm256_set1_epi64x(1);

                          for (x_lane, y_lane) in
                              xs.chunks_exact_mut(4).zip(ys.chunks_exact_mut(4))
                          {
                              let p_x = x_lane.as_mut_ptr() as *mut __m256i;
                              let p_y = y_lane.as_mut_ptr() as *mut __m256i;
                              let x = _mm256_loadu_si256(p_x as *const __m256i);
                              let y = _mm256_loadu_si256(p_y as *const __m256i);

                              // t_tmp = 2q - y + x
                              let t_tmp = _mm256_add_epi64(
                                  _mm256_sub_epi64(two_times_modulus_vec, y),
                                  x,
                              );

                              // curr_x = x + y - 2q * (2x >= t_tmp). AVX2 only offers a
                              // strict `>`, so build `2x >= t_tmp` as `!(t_tmp > 2x)`.
                              let two_x = _mm256_slli_epi64(x, 1);
                              let two_x_below = _mm256_cmpgt_epi64(t_tmp, two_x);
                              let to_subtract =
                                  _mm256_andnot_si256(two_x_below, two_times_modulus_vec);
                              let curr_x = _mm256_sub_epi64(_mm256_add_epi64(x, y), to_subtract);

                              // h_tmp = (t_tmp * w_prime) >> 32
                              let h_tmp =
                                  _mm256_srli_epi64(_mm256_mul_epu32(t_tmp, w_prime_vec), 32);

                              // new_x = (curr_x + q * (t_tmp & 1)) >> 1
                              let odd_mask =
                                  _mm256_cmpeq_epi64(_mm256_and_si256(t_tmp, one_vec), one_vec);
                              let to_add = _mm256_and_si256(odd_mask, modulus_vec);
                              let new_x = _mm256_srli_epi64(_mm256_add_epi64(curr_x, to_add), 1);

                              // new_y = w * t_tmp - h_tmp * q
                              let w_times_t_tmp = _mm256_mul_epu32(t_tmp, w_vec);
                              let h_tmp_times_modulus = _mm256_mul_epu32(h_tmp, modulus_vec);
                              let new_y = _mm256_sub_epi64(w_times_t_tmp, h_tmp_times_modulus);

                              _mm256_storeu_si256(p_x, new_x);
                              _mm256_storeu_si256(p_y, new_y);
                          }
                      }
                  }
              }
          }

          // Final conditional subtractions: first 2q, then q.
          for value in operand.iter_mut() {
              if *value >= two_times_modulus {
                  *value -= two_times_modulus;
              }
              if *value >= modulus {
                  *value -= modulus;
              }
          }
      }
  }
  #[cfg(test)]
  mod test {
      use super::*;
      use crate::{aligned_memory::AlignedMemory64, util::*};
      use rand::Rng;

      /// Expected XOR of every entry produced by `build_ntt_tables` for the
      /// moduli and polynomial length used in `build_ntt_tables_correct`.
      const REF_VAL: u64 = 519370102;

      /// Parameter set shared by the transform tests.
      fn get_params() -> Params {
          get_test_params()
      }

      /// Checks the shape of the tables, two known entries, and an XOR checksum
      /// over all entries.
      #[test]
      fn build_ntt_tables_correct() {
          let moduli = [268369921u64, 249561089u64];
          let poly_len = 2048usize;
          let tables = build_ntt_tables(poly_len, &moduli);

          assert_eq!(tables.len(), 2);
          assert_eq!(tables[0].len(), 4);
          assert_eq!(tables[0][0].len(), poly_len);
          assert_eq!(tables[0][2][0], 134184961u64);
          assert_eq!(tables[0][2][1], 96647580u64);

          // Every modulus gets four tables of `poly_len` entries, so flattening
          // visits exactly the entries the shape assertions above describe.
          let checksum = tables
              .iter()
              .flatten()
              .flatten()
              .fold(0u64, |acc, &entry| acc ^ entry);
          assert_eq!(checksum, REF_VAL);
      }

      /// Only coefficient 0 of each block is non-zero (100) on input; after the
      /// forward transform, entry 50 of each block must equal 100.
      #[test]
      fn ntt_forward_correct() {
          let params = get_params();
          let mut data = AlignedMemory64::new(2 * 2048);
          data[0] = 100;
          data[2048] = 100;

          ntt_forward(&params, data.as_mut_slice());

          assert_eq!(data[50], 100);
          assert_eq!(data[2048 + 50], 100);
      }

      /// Every entry is 100 on input; after the inverse transform, entry 0 of
      /// each block must be 100 and entry 50 must be 0.
      #[test]
      fn ntt_inverse_correct() {
          let params = get_params();
          let mut data = AlignedMemory64::new(2 * 2048);
          let total = data.len();
          for idx in 0..total {
              data[idx] = 100;
          }

          ntt_inverse(&params, data.as_mut_slice());

          assert_eq!(data[0], 100);
          assert_eq!(data[2048], 100);
          assert_eq!(data[50], 0);
          assert_eq!(data[2048 + 50], 0);
      }

      /// Forward followed by inverse must reproduce random reduced input.
      #[test]
      fn ntt_correct() {
          let params = get_params();
          let total = params.crt_count * params.poly_len;
          let mut original = AlignedMemory64::new(total);
          let mut rng = rand::thread_rng();

          for modulus_idx in 0..params.crt_count {
              for coeff_idx in 0..params.poly_len {
                  let idx = calc_index(
                      &[modulus_idx, coeff_idx],
                      &[params.crt_count, params.poly_len],
                  );
                  let sample: u64 = rng.gen();
                  original[idx] = sample % params.moduli[modulus_idx];
              }
          }

          let mut round_trip = original.clone();
          ntt_forward(&params, round_trip.as_mut_slice());
          ntt_inverse(&params, round_trip.as_mut_slice());

          for idx in 0..total {
              assert_eq!(original[idx], round_trip[idx]);
          }
      }

      /// Two fixed inputs pin down the value `calc_index` returns.
      #[test]
      fn calc_index_correct() {
          assert_eq!(calc_index(&[2, 3, 4], &[10, 10, 100]), 2304);
          assert_eq!(calc_index(&[2, 3, 4], &[3, 5, 7]), 95);
      }
  }