|
@@ -2,8 +2,6 @@ use crate::{
|
|
|
arith::*,
|
|
|
number_theory::*,
|
|
|
params::*,
|
|
|
- poly::*,
|
|
|
- util::*,
|
|
|
};
|
|
|
|
|
|
pub fn powers_of_primitive_root(root: u64, modulus: u64, poly_len_log2: usize) -> Vec<u64> {
|
|
@@ -54,7 +52,7 @@ pub fn build_ntt_tables(poly_len: usize, moduli: &[u64]) -> Vec<Vec<Vec<u64>>> {
|
|
|
for i in 0..poly_len {
|
|
|
inv_root_powers[i] = div2_uint_mod(inv_root_powers[i], modulus);
|
|
|
}
|
|
|
- let mut scaled_inv_root_powers =
|
|
|
+ let scaled_inv_root_powers =
|
|
|
scale_powers_u32(modulus_as_u32, poly_len, inv_root_powers.as_slice());
|
|
|
|
|
|
output[coeff_mod] = vec![
|
|
@@ -95,14 +93,12 @@ pub fn ntt_forward(params: &Params, operand_overall: &mut [u64]) {
|
|
|
let x: u32 = op[j] as u32;
|
|
|
let y: u32 = op[t + j] as u32;
|
|
|
|
|
|
- let currX: u32 = x - (two_times_modulus_small * ((x >= two_times_modulus_small) as u32));
|
|
|
+ let curr_x: u32 = x - (two_times_modulus_small * ((x >= two_times_modulus_small) as u32));
|
|
|
+ let q_tmp: u64 = ((y as u64) * (w_prime as u64)) >> 32u64;
|
|
|
+ let q_new = w * (y as u64) - q_tmp * (modulus_small as u64);
|
|
|
|
|
|
- let Q: u64 = ((y as u64) * (w_prime as u64)) >> 32u64;
|
|
|
-
|
|
|
- let new_Q = w * (y as u64) - Q * (modulus_small as u64);
|
|
|
-
|
|
|
- op[j] = currX as u64 + new_Q;
|
|
|
- op[t + j] = currX as u64 + ((two_times_modulus_small as u64) - new_Q);
|
|
|
+ op[j] = curr_x as u64 + q_new;
|
|
|
+ op[t + j] = curr_x as u64 + ((two_times_modulus_small as u64) - q_new);
|
|
|
}
|
|
|
}
|
|
|
}
|
|
@@ -116,7 +112,7 @@ pub fn ntt_forward(params: &Params, operand_overall: &mut [u64]) {
|
|
|
|
|
|
pub fn ntt_inverse(params: &Params, operand_overall: &mut [u64]) {
|
|
|
for coeff_mod in 0..params.crt_count {
|
|
|
- let mut n = params.poly_len;
|
|
|
+ let n = params.poly_len;
|
|
|
|
|
|
let operand = &mut operand_overall[coeff_mod*n..coeff_mod*n+n];
|
|
|
|
|
@@ -137,14 +133,15 @@ pub fn ntt_inverse(params: &Params, operand_overall: &mut [u64]) {
|
|
|
let x = operand[2 * i * t + j];
|
|
|
let y = operand[2 * i * t + t + j];
|
|
|
|
|
|
- let T = two_times_modulus - y + x;
|
|
|
- let currU = x + y - (two_times_modulus * (((x << 1) >= T) as u64));
|
|
|
+ let t_tmp = two_times_modulus - y + x;
|
|
|
+ let curr_x = x + y - (two_times_modulus * (((x << 1) >= t_tmp) as u64));
|
|
|
+ let h_tmp = (t_tmp * w_prime) >> 32;
|
|
|
|
|
|
- let resX= (currU + (modulus * ((T & 1) as u64))) >> 1;
|
|
|
- let H = (T * w_prime) >> 32;
|
|
|
+ let res_x= (curr_x + (modulus * ((t_tmp & 1) as u64))) >> 1;
|
|
|
+ let res_y= w * t_tmp - h_tmp * modulus;
|
|
|
|
|
|
- operand[2 * i * t + j] = resX;
|
|
|
- operand[2 * i * t + t + j] = w * T - H * modulus;
|
|
|
+ operand[2 * i * t + j] = res_x;
|
|
|
+ operand[2 * i * t + t + j] = res_y;
|
|
|
}
|
|
|
}
|
|
|
}
|
|
@@ -160,6 +157,7 @@ pub fn ntt_inverse(params: &Params, operand_overall: &mut [u64]) {
|
|
|
mod test {
|
|
|
use super::*;
|
|
|
use rand::Rng;
|
|
|
+ use crate::util::*;
|
|
|
|
|
|
fn get_params() -> Params {
|
|
|
Params::init(2048, vec![268369921u64, 249561089u64])
|