Samir Menon 2 years ago
parent
commit
23417bc2e1
12 changed files with 330 additions and 65 deletions
  1. 1 2
      benches/ntt.rs
  2. 3 8
      benches/poly.rs
  3. 8 0
      src/arith.rs
  4. 47 10
      src/client.rs
  5. 30 0
      src/gadget.rs
  6. 7 4
      src/lib.rs
  7. 2 2
      src/main.rs
  8. 1 1
      src/ntt.rs
  9. 1 1
      src/number_theory.rs
  10. 24 5
      src/params.rs
  11. 190 32
      src/poly.rs
  12. 16 0
      src/util.rs

+ 1 - 2
benches/ntt.rs

@@ -1,11 +1,10 @@
 use spiral_rs::ntt::*;
-use spiral_rs::params::*;
 use spiral_rs::util::*;
 use rand::Rng;
 use criterion::{black_box, criterion_group, criterion_main, Criterion};
 
 fn criterion_benchmark(c: &mut Criterion) {
-    let params = Params::init(2048, &vec![268369921u64, 249561089u64]);
+    let params = get_test_params();
     let mut v1 = vec![0; params.crt_count * params.poly_len];
     let mut rng = rand::thread_rng();
     for i in 0..params.crt_count {

+ 3 - 8
benches/poly.rs

@@ -1,16 +1,11 @@
 use spiral_rs::poly::*;
-use spiral_rs::params::*;
 use spiral_rs::util::*;
-use rand::Rng;
-use rand::distributions::Standard;
 use criterion::{black_box, criterion_group, criterion_main, Criterion};
 
 fn criterion_benchmark(c: &mut Criterion) {
-    let params = Params::init(2048, &vec![268369921u64, 249561089u64]);
-    let mut rng = rand::thread_rng();
-    let mut iter = rng.sample_iter(&Standard);
-    let m1 = PolyMatrixNTT::random(&params, 2, 1, &mut iter);
-    let m2 = PolyMatrixNTT::random(&params, 3, 2, &mut iter);
+    let params = get_test_params();
+    let m1 = PolyMatrixNTT::random(&params, 2, 1);
+    let m2 = PolyMatrixNTT::random(&params, 3, 2);
     let mut m3 = PolyMatrixNTT::zero(&params, 2, 2);
     c.bench_function("nttf 2048", |b| b.iter(|| multiply(black_box(&mut m3), black_box(&m1), black_box(&m2))));
 }

+ 8 - 0
src/arith.rs

@@ -17,6 +17,14 @@ pub fn multiply_add_modular(params: &Params, a: u64, b: u64, x: u64, c: usize) -
     (a * b + x) % params.moduli[c]
 }
 
+pub fn add_modular(params: &Params, a: u64, b: u64, c: usize) -> u64 {
+    (a + b) % params.moduli[c]
+}
+
+pub fn invert_modular(params: &Params, a: u64, c: usize) -> u64 {
+    params.moduli[c] - a
+}
+
 pub fn modular_reduce(params: &Params, x: u64, c: usize) -> u64 {
     (x) % params.moduli[c]
 }

+ 47 - 10
src/client.rs

@@ -1,6 +1,6 @@
 use std::collections::HashMap;
 
-use crate::{poly::*, params::*, discrete_gaussian::*};
+use crate::{poly::*, params::*, discrete_gaussian::*, gadget::*};
 
 pub struct PublicParameters<'a> {
     v_packing: Vec<PolyMatrixNTT<'a>>,            // Ws
@@ -26,6 +26,7 @@ pub struct Client<'a> {
     sk_reg: PolyMatrixRaw<'a>,
     sk_gsw_full: PolyMatrixRaw<'a>,
     sk_reg_full: PolyMatrixRaw<'a>,
+    dg: DiscreteGaussian,
 }
 
 fn matrix_with_identity<'a> (p: &PolyMatrixRaw<'a>) -> PolyMatrixRaw<'a> {
@@ -37,36 +38,72 @@ fn matrix_with_identity<'a> (p: &PolyMatrixRaw<'a>) -> PolyMatrixRaw<'a> {
 }
 
 impl<'a> Client<'a> {
-    fn init(params: &'a Params) -> Self {
+    pub fn init(params: &'a Params) -> Self {
         let sk_gsw_dims = params.get_sk_gsw();
         let sk_reg_dims = params.get_sk_reg();
         let sk_gsw = PolyMatrixRaw::zero(params, sk_gsw_dims.0, sk_gsw_dims.1);
         let sk_reg = PolyMatrixRaw::zero(params, sk_reg_dims.0, sk_reg_dims.1);
         let sk_gsw_full = matrix_with_identity(&sk_gsw);
         let sk_reg_full = matrix_with_identity(&sk_reg);
+        let dg = DiscreteGaussian::init(params);
         Self {
             params,
             sk_gsw,
             sk_reg,
             sk_gsw_full,
             sk_reg_full,
+            dg,
         }
     }
-    fn generate_keys(&mut self) -> PublicParameters {
+
+    fn get_fresh_gsw_public_key(&mut self, m: usize) -> PolyMatrixRaw<'a> {
         let params = self.params;
-        let mut dg = DiscreteGaussian::init(params);
-        dg.sample_matrix(&mut self.sk_gsw);
-        dg.sample_matrix(&mut self.sk_reg);
+        let n = params.n;
+
+        let a = PolyMatrixRaw::random(params, 1, m);
+        let e = PolyMatrixRaw::noise(params, n, m, &mut self.dg);
+        let a_inv = -&a;
+        let b_p = &self.sk_gsw.ntt() * &a.ntt();
+        let b = &e.ntt() + &b_p;
+        let p = stack(&a_inv, &b.raw());
+        p
+    }
+
+    fn encrypt_matrix_gsw(&mut self, ag: PolyMatrixNTT<'a>) -> PolyMatrixNTT<'a> {
+        let params = self.params;
+        let mx = ag.cols;
+        let p = self.get_fresh_gsw_public_key(mx);
+        let res = &(p.ntt()) + &(ag.pad_top(1));
+        res
+    }
+
+    pub fn generate_keys(&mut self) -> PublicParameters {
+        let params = self.params;
+        self.dg.sample_matrix(&mut self.sk_gsw);
+        self.dg.sample_matrix(&mut self.sk_reg);
         self.sk_gsw_full = matrix_with_identity(&self.sk_gsw);
         self.sk_reg_full = matrix_with_identity(&self.sk_reg);
-        let sk_reg_ntt = to_ntt()
-        let pp = PublicParameters::init(params);
+        let sk_reg_ntt = to_ntt_alloc(&self.sk_reg);
+        let m_conv = params.m_conv();
+
+        let mut pp = PublicParameters::init(params);
         
-        // For packing
+        // Params for packing
+        let gadget_conv = build_gadget(params, 1, m_conv);
+        let gadget_conv_ntt = to_ntt_alloc(&gadget_conv);
         for i in 0..params.n {
-            MatPoly scaled = from_scalar_multiply(sk_reg_ntt, )
+            let scaled = scalar_multiply_alloc(&sk_reg_ntt, &gadget_conv_ntt);
+            let mut ag = PolyMatrixNTT::zero(params, params.n, m_conv);
+            ag.copy_into(&scaled, i, 0);
+            let w = self.encrypt_matrix_gsw(ag);
+            pp.v_packing.push(w);
         }
 
+        // Params for expansion
+
+        // Params for converison
+
+
         pp
     }
     // fn generate_query(&self) -> Query<'a, Params>;

+ 30 - 0
src/gadget.rs

@@ -0,0 +1,30 @@
+use crate::{poly::*, params::*};
+
+pub fn get_bits_per(params: &Params, dim: usize) -> usize{
+    let modulus_log2 = params.modulus_log2;
+    if dim as u64 == modulus_log2 { return 1; }
+    ((modulus_log2 as f64) / (dim as f64)).floor() as usize + 1
+}
+
+pub fn build_gadget(params: &Params, rows: usize, cols: usize) -> PolyMatrixRaw {
+    let mut g = PolyMatrixRaw::zero(params, rows, cols);
+    let nx = g.rows;
+    let m = g.cols;
+
+    assert_eq!(m % nx, 0);
+
+    let num_elems = m / nx;
+    let params = g.params;
+    let bits_per = get_bits_per(params, num_elems);
+
+    for i in 0..nx {
+        for j in 0..num_elems {
+            if bits_per * j >= 64 {
+                continue;
+            }
+            let poly = g.get_poly_mut(i, i + j * nx);
+            poly[0] = 1u64 << (bits_per * j);
+        }
+    }
+    g
+}

+ 7 - 4
src/lib.rs

@@ -1,8 +1,11 @@
+pub mod util;
 pub mod arith;
-pub mod ntt;
 pub mod number_theory;
+pub mod discrete_gaussian;
+
+pub mod ntt;
+pub mod gadget;
 pub mod params;
 pub mod poly;
-pub mod util;
-pub mod client;
-pub mod discrete_gaussian;
+
+pub mod client;

+ 2 - 2
src/main.rs

@@ -1,10 +1,10 @@
 use spiral_rs::poly::*;
-use spiral_rs::params::*;
+use spiral_rs::util::*;
 use spiral_rs::*;
 
 fn main() {
     println!("Hello, world!");
-    let params = Params::init(2048, &vec![7, 31]);
+    let params = get_test_params();
     let m1 = poly::PolyMatrixNTT::zero(&params, 2, 1);
     println!("{}", m1.is_ntt());
     let m2 = poly::PolyMatrixNTT::zero(&params, 3, 2);

+ 1 - 1
src/ntt.rs

@@ -351,7 +351,7 @@ mod test {
     use crate::util::*;
 
     fn get_params() -> Params {
-        Params::init(2048, &vec![268369921u64, 249561089u64])
+        get_test_params()
     }
 
     const REF_VAL: u64 = 519370102;

+ 1 - 1
src/number_theory.rs

@@ -93,4 +93,4 @@ pub fn invert_uint_mod(value: u64, modulus: u64) -> Option<u64> {
     } else {
         return Some(gcd_tuple.1 as u64);
     }
-}
+}

+ 24 - 5
src/params.rs

@@ -1,13 +1,15 @@
-use crate::{arith::*, ntt::*};
+use crate::{arith::*, ntt::*, number_theory::*};
 
 pub struct Params {    
     pub poly_len: usize,
     pub poly_len_log2: usize,
     pub ntt_tables: Vec<Vec<Vec<u64>>>,    
+    pub scratch: Vec<u64>,
     
     pub crt_count: usize,
     pub moduli: Vec<u64>,
     pub modulus: u64,
+    pub modulus_log2: u64,
 
     pub noise_width: f64,
 
@@ -17,12 +19,11 @@ pub struct Params {
     pub t_exp_left: usize,
     pub t_exp_right: usize,
     pub t_gsw: usize,
+
+    pub expand_queries: bool,
 }
 
 impl Params {
-    pub fn num_words(&self) -> usize {
-        self.poly_len * self.crt_count
-    }
     pub fn get_ntt_forward_table(&self, i: usize) -> &[u64] {
         self.ntt_tables[i][0].as_slice()
     }
@@ -44,7 +45,19 @@ impl Params {
     }
 
     pub fn m_conv(&self) -> usize {
-        2 * self.t_conv;
+        2 * self.t_conv
+    }
+
+    pub fn crt_compose_2(&self, x: u64, y: u64) -> u64 {
+        assert_eq!(self.crt_count, 2);
+
+        let a = self.moduli[0];
+        let b = self.moduli[1];
+        let a_inv_mod_b = invert_uint_mod(a, b).unwrap();
+        let b_inv_mod_a = invert_uint_mod(b, a).unwrap();
+        let mut val = (x as u128) * (b_inv_mod_a as u128) * (b as u128);
+        val += (y as u128) * (a_inv_mod_b as u128) * (a as u128);
+        (val % (self.modulus as u128)) as u64 // FIXME: use barrett
     }
 
     pub fn init(
@@ -56,27 +69,33 @@ impl Params {
         t_exp_left: usize,
         t_exp_right: usize,
         t_gsw: usize,
+        expand_queries: bool,
     ) -> Self {
         let poly_len_log2 = log2(poly_len as u64) as usize;
         let crt_count = moduli.len();
         let ntt_tables = build_ntt_tables(poly_len, moduli.as_slice());
+        let scratch = vec![0u64; crt_count * poly_len];
         let mut modulus = 1;
         for m in moduli {
             modulus *= m;
         }
+        let modulus_log2 = log2(modulus);
         Self {
             poly_len,
             poly_len_log2,
             ntt_tables,
+            scratch,
             crt_count,
             moduli: moduli.clone(),
             modulus,
+            modulus_log2,
             noise_width,
             n,
             t_conv,
             t_exp_left,
             t_exp_right,
             t_gsw,
+            expand_queries,
         }
     }
 }

+ 190 - 32
src/poly.rs

@@ -1,15 +1,22 @@
 use std::arch::x86_64::*;
-use std::ops::Mul;
+use std::ops::{Add, Mul, Neg};
+use std::cell::RefCell;
+use rand::Rng;
+use rand::distributions::Standard;
 
-use crate::{arith::*, params::*, ntt::*, util::calc_index};
+use crate::{arith::*, params::*, ntt::*, util::*, discrete_gaussian::*};
+
+const SCRATCH_SPACE: usize = 8192;
+thread_local!(static SCRATCH: RefCell<Vec<u64>> = RefCell::new(vec![0u64; SCRATCH_SPACE]));
 
 pub trait PolyMatrix<'a> {
     fn is_ntt(&self) -> bool;
     fn get_rows(&self) -> usize;
     fn get_cols(&self) -> usize;
     fn get_params(&self) -> &Params;
+    fn num_words(&self) -> usize;
     fn zero(params: &'a Params, rows: usize, cols: usize) -> Self;
-    fn random(params: &'a Params, rows: usize, cols: usize, rng: &mut dyn Iterator<Item=u64>) -> Self;
+    fn random(params: &'a Params, rows: usize, cols: usize) -> Self;
     fn as_slice(&self) -> &[u64];
     fn as_mut_slice(&mut self) -> &mut [u64];
     fn zero_out(&mut self) {
@@ -18,12 +25,12 @@ pub trait PolyMatrix<'a> {
         }
     }
     fn get_poly(&self, row: usize, col: usize) -> &[u64] {
-        let num_words = self.get_params().num_words();
+        let num_words = self.num_words();
         let start = (row * self.get_cols() + col) * num_words;
         &self.as_slice()[start..start + num_words]
     }
     fn get_poly_mut(&mut self, row: usize, col: usize) -> &mut [u64] {
-        let num_words = self.get_params().num_words();
+        let num_words = self.num_words();
         let start = (row * self.get_cols() + col) * num_words;
         &mut self.as_mut_slice()[start..start + num_words]
     }
@@ -40,6 +47,7 @@ pub trait PolyMatrix<'a> {
             }
         }
     }
+    fn pad_top(&self, pad_rows: usize) -> Self;
 }
 
 pub struct PolyMatrixRaw<'a> {
@@ -75,6 +83,9 @@ impl<'a> PolyMatrix<'a> for PolyMatrixRaw<'a> {
     fn as_mut_slice(&mut self) -> &mut [u64] {
         self.data.as_mut_slice()
     }
+    fn num_words(&self) -> usize {
+        self.params.poly_len
+    }
     fn zero(params: &'a Params, rows: usize, cols: usize) -> PolyMatrixRaw<'a> {
         let num_coeffs = rows * cols * params.poly_len;
         let data: Vec<u64> = vec![0; num_coeffs];
@@ -85,18 +96,25 @@ impl<'a> PolyMatrix<'a> for PolyMatrixRaw<'a> {
             data,
         }
     }
-    fn random(params: &'a Params, rows: usize, cols: usize, rng: &mut dyn Iterator<Item=u64>) -> Self {
+    fn random(params: &'a Params, rows: usize, cols: usize) -> Self {
+        let rng = rand::thread_rng();
+        let mut iter = rng.sample_iter(&Standard);
         let mut out = PolyMatrixRaw::zero(params, rows, cols);
         for r in 0..rows {
             for c in 0..cols {
                 for i in 0..params.poly_len {
-                    let val: u64 = rng.next().unwrap();
+                    let val: u64 = iter.next().unwrap();
                     out.get_poly_mut(r, c)[i] = val % params.modulus;
                 }
             }
         }
         out
     }
+    fn pad_top(&self, pad_rows: usize) -> Self {
+        let mut padded = Self::zero(self.params, self.rows + pad_rows, self.cols);
+        padded.copy_into(&self, pad_rows, 0);
+        padded
+    }
 }
 
 impl<'a> PolyMatrixRaw<'a> {
@@ -115,6 +133,16 @@ impl<'a> PolyMatrixRaw<'a> {
             data,
         }
     }
+
+    pub fn noise(params: &'a Params, rows: usize, cols: usize, dg: &mut DiscreteGaussian) -> Self {
+        let mut out = PolyMatrixRaw::zero(params, rows, cols);
+        dg.sample_matrix(&mut out);
+        out
+    }
+
+    pub fn ntt(&self) -> PolyMatrixNTT<'a> {
+        to_ntt_alloc(&self)
+    }
 }
 
 impl<'a> PolyMatrix<'a> for PolyMatrixNTT<'a> {
@@ -136,6 +164,9 @@ impl<'a> PolyMatrix<'a> for PolyMatrixNTT<'a> {
     fn as_mut_slice(&mut self) -> &mut [u64] {
         self.data.as_mut_slice()
     }
+    fn num_words(&self) -> usize {
+        self.params.poly_len * self.params.crt_count
+    }
     fn zero(params: &'a Params, rows: usize, cols: usize) -> PolyMatrixNTT<'a> {
         let num_coeffs = rows * cols * params.poly_len * params.crt_count;
         let data: Vec<u64> = vec![0; num_coeffs];
@@ -146,14 +177,16 @@ impl<'a> PolyMatrix<'a> for PolyMatrixNTT<'a> {
             data,
         }
     }
-    fn random(params: &'a Params, rows: usize, cols: usize, rng: &mut dyn Iterator<Item=u64>) -> Self {
+    fn random(params: &'a Params, rows: usize, cols: usize) -> Self {
+        let rng = rand::thread_rng();
+        let mut iter = rng.sample_iter(&Standard);
         let mut out = PolyMatrixNTT::zero(params, rows, cols);
         for r in 0..rows {
             for c in 0..cols {
                 for i in 0..params.crt_count {
                     for j in 0..params.poly_len {
                         let idx = calc_index(&[i, j], &[params.crt_count, params.poly_len]);
-                        let val: u64 = rng.next().unwrap();
+                        let val: u64 = iter.next().unwrap();
                         out.get_poly_mut(r, c)[idx] = val % params.moduli[i];
                     }
                 }
@@ -161,6 +194,17 @@ impl<'a> PolyMatrix<'a> for PolyMatrixNTT<'a> {
         }
         out
     }
+    fn pad_top(&self, pad_rows: usize) -> Self {
+        let mut padded = Self::zero(self.params, self.rows + pad_rows, self.cols);
+        padded.copy_into(&self, pad_rows, 0);
+        padded
+    }
+}
+
+impl<'a> PolyMatrixNTT<'a> {
+    pub fn raw(&self) -> PolyMatrixRaw<'a> {
+        from_ntt_alloc(&self)
+    }
 }
 
 pub fn multiply_poly(params: &Params, res: &mut [u64], a: &[u64], b: &[u64]) {
@@ -179,6 +223,22 @@ pub fn multiply_add_poly(params: &Params, res: &mut [u64], a: &[u64], b: &[u64])
     }
 }
 
+pub fn add_poly(params: &Params, res: &mut [u64], a: &[u64], b: &[u64]) {
+    for c in 0..params.crt_count {
+        for i in 0..params.poly_len {
+            res[i] = add_modular(params, a[i], b[i], c);
+        }
+    }
+}
+
+pub fn invert_poly(params: &Params, res: &mut [u64], a: &[u64]) {
+    for c in 0..params.crt_count {
+        for i in 0..params.poly_len {
+            res[i] = invert_modular(params, a[i], c);
+        }
+    }
+}
+
 pub fn multiply_add_poly_avx(params: &Params, res: &mut [u64], a: &[u64], b: &[u64]) {
     for c in 0..params.crt_count {
         for i in (0..params.poly_len).step_by(4) {
@@ -229,6 +289,8 @@ pub fn multiply(res: &mut PolyMatrixNTT, a: &PolyMatrixNTT, b: &PolyMatrixNTT) {
 
 #[cfg(target_feature = "avx2")]
 pub fn multiply(res: &mut PolyMatrixNTT, a: &PolyMatrixNTT, b: &PolyMatrixNTT) {
+    assert!(res.rows == a.rows);
+    assert!(res.cols == b.cols);
     assert!(a.cols == b.rows);
 
     let params = res.params;
@@ -248,58 +310,140 @@ pub fn multiply(res: &mut PolyMatrixNTT, a: &PolyMatrixNTT, b: &PolyMatrixNTT) {
     }
 }
 
-pub fn scalar_multiply(res: &mut PolyMatrixNTT, a: &PolyMatrixNTT, b: &PolyMatrixNTT) {
-    assert_eq!(b.rows, 1);
-    assert_eq!(b.cols, 1);
+pub fn add(res: &mut PolyMatrixNTT, a: &PolyMatrixNTT, b: &PolyMatrixNTT) {
+    assert!(res.rows == a.rows);
+    assert!(res.cols == a.cols);
+    assert!(a.rows == b.rows);
+    assert!(a.cols == b.cols);
 
     let params = res.params;
-    let pol2 = b.get_poly(0, 0);
     for i in 0..a.rows {
         for j in 0..a.cols {
             let res_poly = res.get_poly_mut(i, j);
             let pol1 = a.get_poly(i, j);
+            let pol2 = b.get_poly(i, j);
+            add_poly(params, res_poly, pol1, pol2);
+        }
+    }
+}
+
+pub fn invert(res: &mut PolyMatrixRaw, a: &PolyMatrixRaw) {
+    assert!(res.rows == a.rows);
+    assert!(res.cols == a.cols);
+
+    let params = res.params;
+    for i in 0..a.rows {
+        for j in 0..a.cols {
+            let res_poly = res.get_poly_mut(i, j);
+            let pol1 = a.get_poly(i, j);
+            invert_poly(params, res_poly, pol1);
+        }
+    }
+}
+
+pub fn stack<'a>(a: &PolyMatrixRaw<'a>, b: &PolyMatrixRaw<'a>) -> PolyMatrixRaw<'a> {
+    assert_eq!(a.cols, b.cols);
+    let mut c = PolyMatrixRaw::zero(a.params, a.rows + b.rows, a.cols);
+    c.copy_into(a, 0, 0);
+    c.copy_into(b, a.rows, 0);
+    c
+}
+
+pub fn scalar_multiply(res: &mut PolyMatrixNTT, a: &PolyMatrixNTT, b: &PolyMatrixNTT) {
+    assert_eq!(a.rows, 1);
+    assert_eq!(a.cols, 1);
+
+    let params = res.params;
+    let pol2 = a.get_poly(0, 0);
+    for i in 0..b.rows {
+        for j in 0..b.cols {
+            let res_poly = res.get_poly_mut(i, j);
+            let pol1 = b.get_poly(i, j);
             multiply_poly(params, res_poly, pol1, pol2);
         }
     }
 }
 
-pub fn from_scalar_multiply<'a>(a: &PolyMatrixNTT<'a>, b: &PolyMatrixNTT<'a>) -> PolyMatrixNTT<'a> {
-    let mut res = PolyMatrixNTT::zero(a.params, a.rows, a.cols);
-    scalar_multiply(res, a, b);
+pub fn scalar_multiply_alloc<'a>(a: &PolyMatrixNTT<'a>, b: &PolyMatrixNTT<'a>) -> PolyMatrixNTT<'a> {
+    let mut res = PolyMatrixNTT::zero(b.params, b.rows, b.cols);
+    scalar_multiply(&mut res, a, b);
     res
 }
 
 
 pub fn to_ntt(a: &mut PolyMatrixNTT, b: &PolyMatrixRaw) {
+    let params = a.params;
     for r in 0..a.rows {
         for c in 0..a.cols {
-            let pol_src = a.get_poly_mut(r, c);
-            let pol_dst = b.get_poly_mut(r, c);
-            for n in 0..a.params.crt_count {
-                for z in 0..a.params.poly_len {
-                    pol_dst[n * a.params.poly_len + z] = pol_src[z] % a.params.moduli[n];
+            let pol_src = b.get_poly(r, c);
+            let pol_dst = a.get_poly_mut(r, c);
+            for n in 0..params.crt_count {
+                for z in 0..params.poly_len {
+                    pol_dst[n * params.poly_len + z] = pol_src[z] % params.moduli[n];
                 }
             }
-            ntt_forward(a.params, pol_dst);
+            ntt_forward(params, pol_dst);
         }
     }
 }
 
-pub fn from_ntt(params: &Params, res: &mut [u64]) {
-    for c in 0..params.crt_count {
-        for i in 0..params.poly_len {
-            res[c*params.poly_len + i] %= params.moduli[c];
+pub fn to_ntt_alloc<'a>(b: &PolyMatrixRaw<'a>) -> PolyMatrixNTT<'a> {
+    let mut a = PolyMatrixNTT::zero(b.params, b.rows, b.cols);
+    to_ntt(&mut a, b);
+    a
+}
+
+pub fn from_ntt(a: &mut PolyMatrixRaw, b: &PolyMatrixNTT) {
+    let params = a.params;
+    SCRATCH.with(|scratch_cell| {
+        let scratch_vec = &mut *scratch_cell.borrow_mut();
+        let scratch = scratch_vec.as_mut_slice();
+        for r in 0..a.rows {
+            for c in 0..a.cols {
+                let pol_src = b.get_poly(r, c);
+                let pol_dst = a.get_poly_mut(r, c);
+                scratch[0..pol_src.len()].copy_from_slice(pol_src);
+                ntt_inverse(params, scratch);
+                for z in 0..params.poly_len {
+                    pol_dst[z] = params.crt_compose_2(scratch[z], scratch[params.poly_len + z]);
+                }
+            }
         }
-    }
+    });
 }
 
+pub fn from_ntt_alloc<'a>(b: &PolyMatrixNTT<'a>) -> PolyMatrixRaw<'a> {
+    let mut a = PolyMatrixRaw::zero(b.params, b.rows, b.cols);
+    from_ntt(&mut a, b);
+    a
+}
 
-impl<'a> Mul for PolyMatrixNTT<'a> {
-    type Output = Self;
+impl<'a, 'b> Neg for &'b PolyMatrixRaw<'a> {
+    type Output = PolyMatrixRaw<'a>;
+
+    fn neg(self) -> Self::Output {
+        let mut out = PolyMatrixRaw::zero(self.params, self.rows, self.cols);
+        invert(&mut out, self);
+        out
+    }
+}
+
+impl<'a, 'b> Mul for &'b PolyMatrixNTT<'a> {
+    type Output = PolyMatrixNTT<'a>;
 
     fn mul(self, rhs: Self) -> Self::Output {
         let mut out = PolyMatrixNTT::zero(self.params, self.rows, rhs.cols);
-        multiply(&mut out, &self, &rhs);
+        multiply(&mut out, self, rhs);
+        out
+    }
+}
+
+impl<'a, 'b> Add for &'b PolyMatrixNTT<'a> {
+    type Output = PolyMatrixNTT<'a>;
+
+    fn add(self, rhs: Self) -> Self::Output {
+        let mut out = PolyMatrixNTT::zero(self.params, self.rows, self.cols);
+        add(&mut out, self, rhs);
         out
     }
 }
@@ -309,7 +453,7 @@ mod test {
     use super::*;
 
     fn get_params() -> Params {
-        Params::init(2048, &vec![268369921u64, 249561089u64], 2, 6.4)
+        get_test_params()
     }
 
     fn assert_all_zero(a: &[u64]) {
@@ -330,7 +474,21 @@ mod test {
         let params = get_params();
         let m1 = PolyMatrixNTT::zero(&params, 2, 1);
         let m2 = PolyMatrixNTT::zero(&params, 3, 2);
-        let m3 = m2 * m1;
+        let m3 = &m2 * &m1;
         assert_all_zero(m3.as_slice());
     }
+
+    #[test]
+    fn full_multiply_correctness() {
+        let params = get_params();
+        let mut m1 = PolyMatrixRaw::zero(&params, 1, 1);
+        let mut m2 = PolyMatrixRaw::zero(&params, 1, 1);
+        m1.get_poly_mut(0, 0)[1] = 100;
+        m2.get_poly_mut(0, 0)[1] = 7;
+        let m1_ntt = to_ntt_alloc(&m1);
+        let m2_ntt = to_ntt_alloc(&m2);
+        let m3_ntt = &m1_ntt * &m2_ntt;
+        let m3 = from_ntt_alloc(&m3_ntt);
+        assert_eq!(m3.get_poly(0, 0)[2], 700);
+    }
 }

+ 16 - 0
src/util.rs

@@ -1,3 +1,5 @@
+use crate::params::*;
+
 pub fn calc_index(indices: &[usize], lengths: &[usize]) -> usize {
     let mut idx = 0usize;
     let mut prod = 1usize;
@@ -7,3 +9,17 @@ pub fn calc_index(indices: &[usize], lengths: &[usize]) -> usize {
     }
     idx
 }
+
+pub fn get_test_params() -> Params {
+    Params::init(
+        2048, 
+        &vec![268369921u64, 249561089u64], 
+        6.4,
+        2,
+        56,
+        56,
+        56,
+        56,
+        true
+    )
+}