|
@@ -1,4 +1,6 @@
|
|
|
+#[cfg(target_feature = "avx2")]
|
|
|
use std::arch::x86_64::*;
|
|
|
+
|
|
|
use std::ops::{Add, Mul, Neg};
|
|
|
use std::cell::RefCell;
|
|
|
use rand::Rng;
|
|
@@ -210,7 +212,8 @@ impl<'a> PolyMatrixNTT<'a> {
|
|
|
pub fn multiply_poly(params: &Params, res: &mut [u64], a: &[u64], b: &[u64]) {
|
|
|
for c in 0..params.crt_count {
|
|
|
for i in 0..params.poly_len {
|
|
|
- res[i] = multiply_modular(params, a[i], b[i], c);
|
|
|
+ let idx = c * params.poly_len + i;
|
|
|
+ res[idx] = multiply_modular(params, a[idx], b[idx], c);
|
|
|
}
|
|
|
}
|
|
|
}
|
|
@@ -218,7 +221,8 @@ pub fn multiply_poly(params: &Params, res: &mut [u64], a: &[u64], b: &[u64]) {
|
|
|
pub fn multiply_add_poly(params: &Params, res: &mut [u64], a: &[u64], b: &[u64]) {
|
|
|
for c in 0..params.crt_count {
|
|
|
for i in 0..params.poly_len {
|
|
|
- res[i] = multiply_add_modular(params, a[i], b[i], res[i], c);
|
|
|
+ let idx = c * params.poly_len + i;
|
|
|
+ res[idx] = multiply_add_modular(params, a[idx], b[idx], res[idx], c);
|
|
|
}
|
|
|
}
|
|
|
}
|
|
@@ -226,7 +230,8 @@ pub fn multiply_add_poly(params: &Params, res: &mut [u64], a: &[u64], b: &[u64])
|
|
|
pub fn add_poly(params: &Params, res: &mut [u64], a: &[u64], b: &[u64]) {
|
|
|
for c in 0..params.crt_count {
|
|
|
for i in 0..params.poly_len {
|
|
|
- res[i] = add_modular(params, a[i], b[i], c);
|
|
|
+ let idx = c * params.poly_len + i;
|
|
|
+ res[idx] = add_modular(params, a[idx], b[idx], c);
|
|
|
}
|
|
|
}
|
|
|
}
|
|
@@ -234,7 +239,8 @@ pub fn add_poly(params: &Params, res: &mut [u64], a: &[u64], b: &[u64]) {
|
|
|
pub fn invert_poly(params: &Params, res: &mut [u64], a: &[u64]) {
|
|
|
for c in 0..params.crt_count {
|
|
|
for i in 0..params.poly_len {
|
|
|
- res[i] = invert_modular(params, a[i], c);
|
|
|
+ let idx = c * params.poly_len + i;
|
|
|
+ res[idx] = invert_modular(params, a[idx], c);
|
|
|
}
|
|
|
}
|
|
|
}
|
|
@@ -253,6 +259,7 @@ pub fn automorph_poly(params: &Params, res: &mut [u64], a: &[u64], t: usize) {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
+#[cfg(target_feature = "avx2")]
|
|
|
pub fn multiply_add_poly_avx(params: &Params, res: &mut [u64], a: &[u64], b: &[u64]) {
|
|
|
for c in 0..params.crt_count {
|
|
|
for i in (0..params.poly_len).step_by(4) {
|
|
@@ -283,11 +290,14 @@ pub fn modular_reduce(params: &Params, res: &mut [u64]) {
|
|
|
|
|
|
#[cfg(not(target_feature = "avx2"))]
|
|
|
pub fn multiply(res: &mut PolyMatrixNTT, a: &PolyMatrixNTT, b: &PolyMatrixNTT) {
|
|
|
+ assert!(res.rows == a.rows);
|
|
|
+ assert!(res.cols == b.cols);
|
|
|
assert!(a.cols == b.rows);
|
|
|
|
|
|
+ let params = res.params;
|
|
|
for i in 0..a.rows {
|
|
|
for j in 0..b.cols {
|
|
|
- for z in 0..res.params.poly_len {
|
|
|
+ for z in 0..params.poly_len*params.crt_count {
|
|
|
res.get_poly_mut(i, j)[z] = 0;
|
|
|
}
|
|
|
for k in 0..a.cols {
|
|
@@ -310,7 +320,7 @@ pub fn multiply(res: &mut PolyMatrixNTT, a: &PolyMatrixNTT, b: &PolyMatrixNTT) {
|
|
|
let params = res.params;
|
|
|
for i in 0..a.rows {
|
|
|
for j in 0..b.cols {
|
|
|
- for z in 0..res.params.poly_len {
|
|
|
+ for z in 0..params.poly_len*params.crt_count {
|
|
|
res.get_poly_mut(i, j)[z] = 0;
|
|
|
}
|
|
|
let res_poly = res.get_poly_mut(i, j);
|