Browse Source

speed up inverses

Samir Menon 2 years ago
parent
commit
285f21b924
2 changed files with 198 additions and 4 deletions
  1. 3 0
      .cargo/config.toml
  2. 195 4
      src/ntt.rs

+ 3 - 0
.cargo/config.toml

@@ -0,0 +1,3 @@
+[build]
+target = "x86_64-unknown-linux-gnu"
+rustflags = ["-C", "target-feature=+avx2"]

+ 195 - 4
src/ntt.rs

@@ -1,3 +1,5 @@
+use std::arch::x86_64::*;
+
 use crate::{
     arith::*,
     number_theory::*,
@@ -65,6 +67,7 @@ pub fn build_ntt_tables(poly_len: usize, moduli: &[u64]) -> Vec<Vec<Vec<u64>>> {
     output
 }
 
+#[cfg(not(target_feature = "avx2"))]
 pub fn ntt_forward(params: &Params, operand_overall: &mut [u64]) {
     let log_n = params.poly_len_log2;
     let n = 1 << log_n;
@@ -110,6 +113,103 @@ pub fn ntt_forward(params: &Params, operand_overall: &mut [u64]) {
     }
 }
 
+#[cfg(target_feature = "avx2")]
+pub fn ntt_forward(params: &Params, operand_overall: &mut [u64]) {
+    let log_n = params.poly_len_log2;
+    let n = 1 << log_n;
+
+    for coeff_mod in 0..params.crt_count {
+        let operand = &mut operand_overall[coeff_mod*n..coeff_mod*n+n];
+
+        let forward_table = params.get_ntt_forward_table(coeff_mod);
+        let forward_table_prime = params.get_ntt_forward_prime_table(coeff_mod);
+        let modulus_small = params.moduli[coeff_mod] as u32;
+        let two_times_modulus_small: u32 = 2 * modulus_small;
+
+        for mm in 0..log_n {
+            let m = 1 << mm;
+            let t = n >> (mm + 1);
+
+            let mut it = operand.chunks_exact_mut(2 * t);
+
+            for i in 0..m {
+                let w = forward_table[m+i];
+                let w_prime = forward_table_prime[m+i];
+                
+                let op = it.next().unwrap();
+                
+                if t < 4 {
+                    for j in 0..t {
+                        let x: u32 = op[j] as u32;
+                        let y: u32 = op[t + j] as u32;
+
+                        let curr_x: u32 = x - (two_times_modulus_small * ((x >= two_times_modulus_small) as u32));
+                        let q_tmp: u64 = ((y as u64) * (w_prime as u64)) >> 32u64;
+                        let q_new = w * (y as u64) - q_tmp * (modulus_small as u64);
+
+                        op[j] = curr_x as u64 + q_new;
+                        op[t + j] = curr_x as u64 +  ((two_times_modulus_small as u64) - q_new);
+                    }
+                } else {
+                    unsafe {
+                        for j in (0..t).step_by(4) {
+                            // Use AVX2 here
+                            let p_x = &mut op[j] as *mut u64;
+                            let p_y = &mut op[j + t] as *mut u64;
+                            let x = _mm256_loadu_si256(p_x as *const __m256i);
+                            let y = _mm256_loadu_si256(p_y as *const __m256i);
+
+                            let cmp_val = _mm256_set1_epi64x(two_times_modulus_small as i64);
+                            let gt_mask = _mm256_cmpgt_epi64(x, cmp_val);
+
+                            let to_subtract = _mm256_and_si256(gt_mask, cmp_val);
+                            let curr_x = _mm256_sub_epi64(x, to_subtract);
+
+                            // uint32_t q_val = ((y) * (uint64_t)(Wprime)) >> 32;
+                            let w_prime_vec = _mm256_set1_epi64x(w_prime as i64);
+                            let product = _mm256_mul_epu32(y, w_prime_vec);
+                            let q_val = _mm256_srli_epi64(product, 32);
+
+                            // q_val = W * y - q_val * modulus_small;
+                            let w_vec = _mm256_set1_epi64x(w as i64);
+                            let w_times_y = _mm256_mul_epu32(y, w_vec);
+                            let modulus_small_vec = _mm256_set1_epi64x(modulus_small as i64);
+                            let q_scaled = _mm256_mul_epu32(q_val, modulus_small_vec);
+                            let q_final = _mm256_sub_epi64(w_times_y, q_scaled);
+
+                            let new_x = _mm256_add_epi64(curr_x, q_final);
+                            let q_final_inverted = _mm256_sub_epi64(cmp_val, q_final);
+                            let new_y = _mm256_add_epi64(curr_x, q_final_inverted);
+
+                            _mm256_storeu_si256(p_x as *mut __m256i, new_x);
+                            _mm256_storeu_si256(p_y as *mut __m256i, new_y);
+                        }
+                    }
+                }
+            }
+        }
+
+        for i in (0..n).step_by(4) {
+            unsafe {
+                let p_x = &mut operand[i] as *mut u64;
+
+                let cmp_val1 = _mm256_set1_epi64x(two_times_modulus_small as i64);
+                let mut x = _mm256_loadu_si256(p_x as *const __m256i);
+                let mut gt_mask = _mm256_cmpgt_epi64(x, cmp_val1);
+                let mut to_subtract = _mm256_and_si256(gt_mask, cmp_val1);
+                x = _mm256_sub_epi64(x, to_subtract);
+
+                let cmp_val2 = _mm256_set1_epi64x(modulus_small as i64);
+                gt_mask = _mm256_cmpgt_epi64(x, cmp_val2);
+                to_subtract = _mm256_and_si256(gt_mask, cmp_val2);
+                x = _mm256_sub_epi64(x, to_subtract);
+                _mm256_storeu_si256(p_x as *mut __m256i, x);
+            }
+        }
+    }
+}
+
+#[cfg(not(target_feature = "avx2"))]
 pub fn ntt_inverse(params: &Params, operand_overall: &mut [u64]) {
     for coeff_mod in 0..params.crt_count {
         let n = params.poly_len;
@@ -125,13 +225,17 @@ pub fn ntt_inverse(params: &Params, operand_overall: &mut [u64]) {
             let h = 1 << mm;
             let t = n >> (mm + 1);
 
+            let mut it = operand.chunks_exact_mut(2 * t);
+
             for i in 0..h {
                 let w = inverse_table[h+i];
                 let w_prime = inverse_table_prime[h+i];
 
+                let op = it.next().unwrap();
+
                 for j in 0..t {
-                    let x = operand[2 * i * t + j];
-                    let y = operand[2 * i * t + t + j];
+                    let x = op[j];
+                    let y = op[t + j];
                     
                     let t_tmp = two_times_modulus - y + x;
                     let curr_x = x + y - (two_times_modulus * (((x << 1) >= t_tmp) as u64));
@@ -140,8 +244,95 @@ pub fn ntt_inverse(params: &Params, operand_overall: &mut [u64]) {
                     let res_x= (curr_x + (modulus * ((t_tmp & 1) as u64))) >> 1;
                     let res_y= w * t_tmp - h_tmp * modulus;
 
-                    operand[2 * i * t + j] = res_x;
-                    operand[2 * i * t + t + j] = res_y;
+                    op[j] = res_x;
+                    op[t + j] = res_y;
+                }
+            }
+        }
+
+        for i in 0..n {
+            operand[i] -= ((operand[i] >= two_times_modulus) as u64) * two_times_modulus;
+            operand[i] -= ((operand[i] >= modulus) as u64) * modulus;
+        }
+    }
+}
+
+#[cfg(target_feature = "avx2")]
+pub fn ntt_inverse(params: &Params, operand_overall: &mut [u64]) {
+    for coeff_mod in 0..params.crt_count {
+        let n = params.poly_len;
+
+        let operand = &mut operand_overall[coeff_mod*n..coeff_mod*n+n];
+
+        let inverse_table = params.get_ntt_inverse_table(coeff_mod);
+        let inverse_table_prime = params.get_ntt_inverse_prime_table(coeff_mod);
+        let modulus = params.moduli[coeff_mod];
+        let two_times_modulus: u64 = 2 * modulus;
+        for mm in (0..params.poly_len_log2).rev() {
+            let h = 1 << mm;
+            let t = n >> (mm + 1);
+
+            let mut it = operand.chunks_exact_mut(2 * t);
+
+            for i in 0..h {
+                let w = inverse_table[h+i];
+                let w_prime = inverse_table_prime[h+i];
+
+                let op = it.next().unwrap();
+                
+                if t < 4 {
+                    for j in 0..t {
+                        let x = op[j];
+                        let y = op[t + j];
+                        
+                        let t_tmp = two_times_modulus - y + x;
+                        let curr_x = x + y - (two_times_modulus * (((x << 1) >= t_tmp) as u64));
+                        let h_tmp = (t_tmp * w_prime) >> 32;
+
+                        let res_x= (curr_x + (modulus * ((t_tmp & 1) as u64))) >> 1;
+                        let res_y= w * t_tmp - h_tmp * modulus;
+
+                        op[j] = res_x;
+                        op[t + j] = res_y;
+                    }
+                } else {
+                    unsafe {
+                        for j in (0..t).step_by(4) {
+                            // Use AVX2 here
+                            let p_x = &mut op[j] as *mut u64;
+                            let p_y = &mut op[j + t] as *mut u64;
+                            let x = _mm256_loadu_si256(p_x as *const __m256i);
+                            let y = _mm256_loadu_si256(p_y as *const __m256i);
+                            
+                            let modulus_vec = _mm256_set1_epi64x(modulus as i64);
+                            let two_times_modulus_vec = _mm256_set1_epi64x(two_times_modulus as i64);
+                            let mut t_tmp = _mm256_set1_epi64x(two_times_modulus as i64);
+                            t_tmp = _mm256_sub_epi64(t_tmp, y);
+                            t_tmp = _mm256_add_epi64(t_tmp, x);
+                            let gt_mask = _mm256_cmpgt_epi64(_mm256_slli_epi64(x, 1), t_tmp);
+                            let to_subtract = _mm256_and_si256(gt_mask, two_times_modulus_vec);
+                            let mut curr_x = _mm256_add_epi64(x, y);
+                            curr_x = _mm256_sub_epi64(curr_x, to_subtract);
+
+                            let w_prime_vec = _mm256_set1_epi64x(w_prime as i64);
+                            let mut h_tmp = _mm256_mul_epu32(t_tmp, w_prime_vec);
+                            h_tmp = _mm256_srli_epi64(h_tmp, 32);
+
+                            let and_mask = _mm256_set_epi64x(1, 1, 1, 1);
+                            let eq_mask = _mm256_cmpeq_epi64(_mm256_and_si256(t_tmp, and_mask), and_mask);
+                            let to_add = _mm256_and_si256(eq_mask, modulus_vec);
+
+                            let new_x = _mm256_srli_epi64(_mm256_add_epi64(curr_x, to_add), 1);
+
+                            let w_vec = _mm256_set1_epi64x(w as i64);
+                            let w_times_t_tmp = _mm256_mul_epu32(t_tmp, w_vec);
+                            let h_tmp_times_modulus = _mm256_mul_epu32(h_tmp, modulus_vec);
+                            let new_y = _mm256_sub_epi64(w_times_t_tmp, h_tmp_times_modulus);
+
+                            _mm256_storeu_si256(p_x as *mut __m256i, new_x);
+                            _mm256_storeu_si256(p_y as *mut __m256i, new_y);
+                        }
+                    }
                 }
             }
         }