|
@@ -1,7 +1,7 @@
|
|
|
use std::arch::x86_64::*;
|
|
|
use std::ops::Mul;
|
|
|
|
|
|
-use crate::{arith::*, params::*, util::calc_index};
|
|
|
+use crate::{arith::*, params::*, ntt::*, util::calc_index};
|
|
|
|
|
|
pub trait PolyMatrix<'a> {
|
|
|
fn is_ntt(&self) -> bool;
|
|
@@ -27,20 +27,33 @@ pub trait PolyMatrix<'a> {
|
|
|
let start = (row * self.get_cols() + col) * num_words;
|
|
|
&mut self.as_mut_slice()[start..start + num_words]
|
|
|
}
|
|
|
+ fn copy_into(&mut self, p: &Self, target_row: usize, target_col: usize) {
|
|
|
+ assert!(target_row < self.get_rows());
|
|
|
+ assert!(target_col < self.get_cols());
|
|
|
+ assert!(target_row + p.get_rows() < self.get_rows());
|
|
|
+ assert!(target_col + p.get_cols() < self.get_cols());
|
|
|
+ for r in 0..p.get_rows() {
|
|
|
+ for c in 0..p.get_cols() {
|
|
|
+ let pol_src = p.get_poly(r, c);
|
|
|
+ let pol_dst = self.get_poly_mut(target_row + r, target_col + c);
|
|
|
+ pol_dst.copy_from_slice(pol_src);
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
}
|
|
|
|
|
|
pub struct PolyMatrixRaw<'a> {
|
|
|
- params: &'a Params,
|
|
|
- rows: usize,
|
|
|
- cols: usize,
|
|
|
- data: Vec<u64>,
|
|
|
+ pub params: &'a Params,
|
|
|
+ pub rows: usize,
|
|
|
+ pub cols: usize,
|
|
|
+ pub data: Vec<u64>,
|
|
|
}
|
|
|
|
|
|
pub struct PolyMatrixNTT<'a> {
|
|
|
- params: &'a Params,
|
|
|
- rows: usize,
|
|
|
- cols: usize,
|
|
|
- data: Vec<u64>,
|
|
|
+ pub params: &'a Params,
|
|
|
+ pub rows: usize,
|
|
|
+ pub cols: usize,
|
|
|
+ pub data: Vec<u64>,
|
|
|
}
|
|
|
|
|
|
impl<'a> PolyMatrix<'a> for PolyMatrixRaw<'a> {
|
|
@@ -86,6 +99,24 @@ impl<'a> PolyMatrix<'a> for PolyMatrixRaw<'a> {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
+impl<'a> PolyMatrixRaw<'a> {
|
|
|
+ pub fn identity(params: &'a Params, rows: usize, cols: usize) -> PolyMatrixRaw<'a> {
|
|
|
+ let num_coeffs = rows * cols * params.poly_len;
|
|
|
+ let mut data: Vec<u64> = vec![0; num_coeffs];
|
|
|
+ for r in 0..rows {
|
|
|
+ let c = r;
|
|
|
+ let idx = r * cols * params.poly_len + c * params.poly_len;
|
|
|
+ data[idx] = 1;
|
|
|
+ }
|
|
|
+ PolyMatrixRaw {
|
|
|
+ params,
|
|
|
+ rows,
|
|
|
+ cols,
|
|
|
+ data,
|
|
|
+ }
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
impl<'a> PolyMatrix<'a> for PolyMatrixNTT<'a> {
|
|
|
fn is_ntt(&self) -> bool {
|
|
|
true
|
|
@@ -217,6 +248,52 @@ pub fn multiply(res: &mut PolyMatrixNTT, a: &PolyMatrixNTT, b: &PolyMatrixNTT) {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
+pub fn scalar_multiply(res: &mut PolyMatrixNTT, a: &PolyMatrixNTT, b: &PolyMatrixNTT) {
|
|
|
+ assert_eq!(b.rows, 1);
|
|
|
+ assert_eq!(b.cols, 1);
|
|
|
+
|
|
|
+ let params = res.params;
|
|
|
+ let pol2 = b.get_poly(0, 0);
|
|
|
+ for i in 0..a.rows {
|
|
|
+ for j in 0..a.cols {
|
|
|
+ let res_poly = res.get_poly_mut(i, j);
|
|
|
+ let pol1 = a.get_poly(i, j);
|
|
|
+ multiply_poly(params, res_poly, pol1, pol2);
|
|
|
+ }
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+pub fn from_scalar_multiply<'a>(a: &PolyMatrixNTT<'a>, b: &PolyMatrixNTT<'a>) -> PolyMatrixNTT<'a> {
|
|
|
+ let mut res = PolyMatrixNTT::zero(a.params, a.rows, a.cols);
|
|
|
+ scalar_multiply(res, a, b);
|
|
|
+ res
|
|
|
+}
|
|
|
+
|
|
|
+
|
|
|
+pub fn to_ntt(a: &mut PolyMatrixNTT, b: &PolyMatrixRaw) {
|
|
|
+ for r in 0..a.rows {
|
|
|
+ for c in 0..a.cols {
|
|
|
+ let pol_src = a.get_poly_mut(r, c);
|
|
|
+ let pol_dst = b.get_poly_mut(r, c);
|
|
|
+ for n in 0..a.params.crt_count {
|
|
|
+ for z in 0..a.params.poly_len {
|
|
|
+ pol_dst[n * a.params.poly_len + z] = pol_src[z] % a.params.moduli[n];
|
|
|
+ }
|
|
|
+ }
|
|
|
+ ntt_forward(a.params, pol_dst);
|
|
|
+ }
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+pub fn from_ntt(params: &Params, res: &mut [u64]) {
|
|
|
+ for c in 0..params.crt_count {
|
|
|
+ for i in 0..params.poly_len {
|
|
|
+ res[c*params.poly_len + i] %= params.moduli[c];
|
|
|
+ }
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+
|
|
|
impl<'a> Mul for PolyMatrixNTT<'a> {
|
|
|
type Output = Self;
|
|
|
|
|
@@ -232,7 +309,7 @@ mod test {
|
|
|
use super::*;
|
|
|
|
|
|
fn get_params() -> Params {
|
|
|
- Params::init(2048, &vec![268369921u64, 249561089u64])
|
|
|
+ Params::init(2048, &vec![268369921u64, 249561089u64], 2, 6.4)
|
|
|
}
|
|
|
|
|
|
fn assert_all_zero(a: &[u64]) {
|