Browse Source

fix warnings

Samir Menon 2 years ago
parent
commit
404487502c
3 changed files with 16 additions and 22 deletions
  1. 0 4
      src/arith.rs
  2. 15 17
      src/ntt.rs
  3. 1 1
      src/poly.rs

+ 0 - 4
src/arith.rs

@@ -17,10 +17,6 @@ pub fn multiply_add_modular(params: &Params, a: u64, b: u64, x: u64, c: usize) -
     (a * b + x) % params.moduli[c]
 }
 
-fn swap(a: u64, b: u64) -> (u64, u64) {
-    (b, a)
-}
-
 pub fn exponentiate_uint_mod(operand: u64, mut exponent: u64, modulus: u64) -> u64 {
     if exponent == 0 {
         return 1;

+ 15 - 17
src/ntt.rs

@@ -2,8 +2,6 @@ use crate::{
     arith::*,
     number_theory::*,
     params::*,
-    poly::*,
-    util::*,
 };
 
 pub fn powers_of_primitive_root(root: u64, modulus: u64, poly_len_log2: usize) -> Vec<u64> {
@@ -54,7 +52,7 @@ pub fn build_ntt_tables(poly_len: usize, moduli: &[u64]) -> Vec<Vec<Vec<u64>>> {
         for i in 0..poly_len {
             inv_root_powers[i] = div2_uint_mod(inv_root_powers[i], modulus);
         }
-        let mut scaled_inv_root_powers =
+        let scaled_inv_root_powers =
             scale_powers_u32(modulus_as_u32, poly_len, inv_root_powers.as_slice());
 
         output[coeff_mod] = vec![
@@ -95,14 +93,12 @@ pub fn ntt_forward(params: &Params, operand_overall: &mut [u64]) {
                     let x: u32 = op[j] as u32;
                     let y: u32 = op[t + j] as u32;
 
-                    let currX: u32 = x - (two_times_modulus_small * ((x >= two_times_modulus_small) as u32));
+                    let curr_x: u32 = x - (two_times_modulus_small * ((x >= two_times_modulus_small) as u32));
+                    let q_tmp: u64 = ((y as u64) * (w_prime as u64)) >> 32u64;
+                    let q_new = w * (y as u64) - q_tmp * (modulus_small as u64);
 
-                    let Q: u64 = ((y as u64) * (w_prime as u64)) >> 32u64;
-
-                    let new_Q = w * (y as u64) - Q * (modulus_small as u64);
-
-                    op[j] = currX as u64 + new_Q;
-                    op[t + j] = currX as u64 +  ((two_times_modulus_small as u64) - new_Q);
+                    op[j] = curr_x as u64 + q_new;
+                    op[t + j] = curr_x as u64 +  ((two_times_modulus_small as u64) - q_new);
                 }
             }
         }
@@ -116,7 +112,7 @@ pub fn ntt_forward(params: &Params, operand_overall: &mut [u64]) {
 
 pub fn ntt_inverse(params: &Params, operand_overall: &mut [u64]) {
     for coeff_mod in 0..params.crt_count {
-        let mut n = params.poly_len;
+        let n = params.poly_len;
 
         let operand = &mut operand_overall[coeff_mod*n..coeff_mod*n+n];
 
@@ -137,14 +133,15 @@ pub fn ntt_inverse(params: &Params, operand_overall: &mut [u64]) {
                     let x = operand[2 * i * t + j];
                     let y = operand[2 * i * t + t + j];
                     
-                    let T = two_times_modulus - y + x;
-                    let currU = x + y - (two_times_modulus * (((x << 1) >= T) as u64));
+                    let t_tmp = two_times_modulus - y + x;
+                    let curr_x = x + y - (two_times_modulus * (((x << 1) >= t_tmp) as u64));
+                    let h_tmp = (t_tmp * w_prime) >> 32;
 
-                    let resX= (currU + (modulus * ((T & 1) as u64))) >> 1;
-                    let H = (T * w_prime) >> 32;
+                    let res_x= (curr_x + (modulus * ((t_tmp & 1) as u64))) >> 1;
+                    let res_y= w * t_tmp - h_tmp * modulus;
 
-                    operand[2 * i * t + j] = resX;
-                    operand[2 * i * t + t + j] = w * T - H * modulus;
+                    operand[2 * i * t + j] = res_x;
+                    operand[2 * i * t + t + j] = res_y;
                 }
             }
         }
@@ -160,6 +157,7 @@ pub fn ntt_inverse(params: &Params, operand_overall: &mut [u64]) {
 mod test {
     use super::*;
     use rand::Rng;
+    use crate::util::*;
 
     fn get_params() -> Params {
         Params::init(2048, vec![268369921u64, 249561089u64])

+ 1 - 1
src/poly.rs

@@ -141,7 +141,7 @@ mod test {
     use super::*;
 
     fn get_params() -> Params {
-        Params::init(2048, vec![7, 31])
+        Params::init(2048, vec![268369921u64, 249561089u64])
     }
 
     fn assert_all_zero(a: &[u64]) {