Browse Source

basic operator multiply

Samir Menon 2 years ago
parent
commit
766c662b66
8 changed files with 134 additions and 17 deletions
  1. 0 2
      Cargo.toml
  2. 1 1
      benches/ntt.rs
  3. 19 0
      benches/poly.rs
  4. 4 0
      src/arith.rs
  5. 1 1
      src/main.rs
  6. 1 1
      src/ntt.rs
  7. 8 2
      src/params.rs
  8. 100 10
      src/poly.rs

+ 0 - 2
Cargo.toml

@@ -3,8 +3,6 @@ name = "spiral-rs"
 version = "0.1.0"
 edition = "2021"
 
-# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
-
 [dependencies]
 rand = "0.8.5"
 

+ 1 - 1
benches/ntt.rs

@@ -5,7 +5,7 @@ use rand::Rng;
 use criterion::{black_box, criterion_group, criterion_main, Criterion};
 
 fn criterion_benchmark(c: &mut Criterion) {
-    let params = Params::init(2048, vec![268369921u64, 249561089u64]);
+    let params = Params::init(2048, &vec![268369921u64, 249561089u64]);
     let mut v1 = vec![0; params.crt_count * params.poly_len];
     let mut rng = rand::thread_rng();
     for i in 0..params.crt_count {

+ 19 - 0
benches/poly.rs

@@ -0,0 +1,19 @@
+use spiral_rs::poly::*;
+use spiral_rs::params::*;
+use spiral_rs::util::*;
+use rand::Rng;
+use rand::distributions::Standard;
+use criterion::{black_box, criterion_group, criterion_main, Criterion};
+
+fn criterion_benchmark(c: &mut Criterion) {
+    let params = Params::init(2048, &vec![268369921u64, 249561089u64]);
+    let mut rng = rand::thread_rng();
+    let mut iter = rng.sample_iter(&Standard);
+    let m1 = PolyMatrixNTT::random(&params, 2, 1, &mut iter);
+    let m2 = PolyMatrixNTT::random(&params, 3, 2, &mut iter);
+    let mut m3 = PolyMatrixNTT::zero(&params, 2, 2);
+    c.bench_function("nttf 2048", |b| b.iter(|| multiply(black_box(&mut m3), black_box(&m1), black_box(&m2))));
+}
+
+criterion_group!(benches, criterion_benchmark);
+criterion_main!(benches);

+ 4 - 0
src/arith.rs

@@ -17,6 +17,10 @@ pub fn multiply_add_modular(params: &Params, a: u64, b: u64, x: u64, c: usize) -
     (a * b + x) % params.moduli[c]
 }
 
+pub fn modular_reduce(params: &Params, x: u64, c: usize) -> u64 {
+    (x) % params.moduli[c]
+}
+
 pub fn exponentiate_uint_mod(operand: u64, mut exponent: u64, modulus: u64) -> u64 {
     if exponent == 0 {
         return 1;

+ 1 - 1
src/main.rs

@@ -4,7 +4,7 @@ use spiral_rs::*;
 
 fn main() {
     println!("Hello, world!");
-    let params = Params::init(2048, vec![7, 31]);
+    let params = Params::init(2048, &vec![7, 31]);
     let m1 = poly::PolyMatrixNTT::zero(&params, 2, 1);
     println!("{}", m1.is_ntt());
     let m2 = poly::PolyMatrixNTT::zero(&params, 3, 2);

+ 1 - 1
src/ntt.rs

@@ -351,7 +351,7 @@ mod test {
     use crate::util::*;
 
     fn get_params() -> Params {
-        Params::init(2048, vec![268369921u64, 249561089u64])
+        Params::init(2048, &vec![268369921u64, 249561089u64])
     }
 
     const REF_VAL: u64 = 519370102;

+ 8 - 2
src/params.rs

@@ -6,6 +6,7 @@ pub struct Params {
     pub ntt_tables: Vec<Vec<Vec<u64>>>,
     pub crt_count: usize,
     pub moduli: Vec<u64>,
+    pub modulus: u64
 }
 
 impl Params {
@@ -25,16 +26,21 @@ impl Params {
         self.ntt_tables[i][3].as_slice()
     }
 
-    pub fn init(poly_len: usize, moduli: Vec<u64>) -> Self {
+    pub fn init(poly_len: usize, moduli: &Vec<u64>) -> Self {
         let poly_len_log2 = log2(poly_len as u64) as usize;
         let crt_count = moduli.len();
         let ntt_tables = build_ntt_tables(poly_len, moduli.as_slice());
+        let mut modulus = 1;
+        for m in moduli {
+            modulus *= m;
+        }
         Self {
             poly_len,
             poly_len_log2,
             ntt_tables,
             crt_count,
-            moduli,
+            moduli: moduli.clone(),
+            modulus
         }
     }
 }

+ 100 - 10
src/poly.rs

@@ -1,4 +1,7 @@
-use crate::{arith::*, params::*};
+use std::arch::x86_64::*;
+use std::ops::Mul;
+
+use crate::{arith::*, params::*, util::calc_index};
 
 pub trait PolyMatrix<'a> {
     fn is_ntt(&self) -> bool;
@@ -6,6 +9,7 @@ pub trait PolyMatrix<'a> {
     fn get_cols(&self) -> usize;
     fn get_params(&self) -> &Params;
     fn zero(params: &'a Params, rows: usize, cols: usize) -> Self;
+    fn random(params: &'a Params, rows: usize, cols: usize, rng: &mut dyn Iterator<Item=u64>) -> Self;
     fn as_slice(&self) -> &[u64];
     fn as_mut_slice(&mut self) -> &mut [u64];
     fn zero_out(&mut self) {
@@ -14,14 +18,14 @@ pub trait PolyMatrix<'a> {
         }
     }
     fn get_poly(&self, row: usize, col: usize) -> &[u64] {
-        let params = self.get_params();
-        let start = (row * self.get_cols() + col) * params.poly_len;
-        &self.as_slice()[start..start + params.poly_len]
+        let num_words = self.get_params().num_words();
+        let start = (row * self.get_cols() + col) * num_words;
+        &self.as_slice()[start..start + num_words]
     }
     fn get_poly_mut(&mut self, row: usize, col: usize) -> &mut [u64] {
-        let poly_len = self.get_params().poly_len;
-        let start = (row * self.get_cols() + col) * poly_len;
-        &mut self.as_mut_slice()[start..start + poly_len]
+        let num_words = self.get_params().num_words();
+        let start = (row * self.get_cols() + col) * num_words;
+        &mut self.as_mut_slice()[start..start + num_words]
     }
 }
 
@@ -68,6 +72,18 @@ impl<'a> PolyMatrix<'a> for PolyMatrixRaw<'a> {
             data,
         }
     }
+    fn random(params: &'a Params, rows: usize, cols: usize, rng: &mut dyn Iterator<Item=u64>) -> Self {
+        let mut out = PolyMatrixRaw::zero(params, rows, cols);
+        for r in 0..rows {
+            for c in 0..cols {
+                for i in 0..params.poly_len {
+                    let val: u64 = rng.next().unwrap();
+                    out.get_poly_mut(r, c)[i] = val % params.modulus;
+                }
+            }
+        }
+        out
+    }
 }
 
 impl<'a> PolyMatrix<'a> for PolyMatrixNTT<'a> {
@@ -99,6 +115,21 @@ impl<'a> PolyMatrix<'a> for PolyMatrixNTT<'a> {
             data,
         }
     }
+    fn random(params: &'a Params, rows: usize, cols: usize, rng: &mut dyn Iterator<Item=u64>) -> Self {
+        let mut out = PolyMatrixNTT::zero(params, rows, cols);
+        for r in 0..rows {
+            for c in 0..cols {
+                for i in 0..params.crt_count {
+                    for j in 0..params.poly_len {
+                        let idx = calc_index(&[i, j], &[params.crt_count, params.poly_len]);
+                        let val: u64 = rng.next().unwrap();
+                        out.get_poly_mut(r, c)[idx] = val % params.moduli[i];
+                    }
+                }
+            }
+        }
+        out
+    }
 }
 
 pub fn multiply_poly(params: &Params, res: &mut [u64], a: &[u64], b: &[u64]) {
@@ -117,6 +148,35 @@ pub fn multiply_add_poly(params: &Params, res: &mut [u64], a: &[u64], b: &[u64])
     }
 }
 
+pub fn multiply_add_poly_avx(params: &Params, res: &mut [u64], a: &[u64], b: &[u64]) {
+    for c in 0..params.crt_count {
+        for i in (0..params.poly_len).step_by(4) {
+            unsafe {
+                let p_x = &a[c*params.poly_len + i] as *const u64;
+                let p_y = &b[c*params.poly_len + i] as *const u64;
+                let p_z = &mut res[c*params.poly_len + i] as *mut u64;
+                let x = _mm256_loadu_si256(p_x as *const __m256i);
+                let y = _mm256_loadu_si256(p_y as *const __m256i);
+                let z = _mm256_loadu_si256(p_z as *const __m256i);
+
+                let product = _mm256_mul_epu32(x, y);
+                let out = _mm256_add_epi64(z, product);
+                
+                _mm256_storeu_si256(p_z as *mut __m256i, out);
+            }
+        }
+    }
+}
+
+pub fn modular_reduce(params: &Params, res: &mut [u64]) {
+    for c in 0..params.crt_count {
+        for i in 0..params.poly_len {
+            res[c*params.poly_len + i] %= params.moduli[c];
+        }
+    }
+}
+
+#[cfg(not(target_feature = "avx2"))]
 pub fn multiply(res: &mut PolyMatrixNTT, a: &PolyMatrixNTT, b: &PolyMatrixNTT) {
     assert!(a.cols == b.rows);
 
@@ -136,12 +196,43 @@ pub fn multiply(res: &mut PolyMatrixNTT, a: &PolyMatrixNTT, b: &PolyMatrixNTT) {
     }
 }
 
+#[cfg(target_feature = "avx2")]
+pub fn multiply(res: &mut PolyMatrixNTT, a: &PolyMatrixNTT, b: &PolyMatrixNTT) {
+    assert!(a.cols == b.rows);
+
+    let params = res.params;
+    for i in 0..a.rows {
+        for j in 0..b.cols {
+            for z in 0..res.params.poly_len {
+                res.get_poly_mut(i, j)[z] = 0;
+            }
+            let res_poly = res.get_poly_mut(i, j);
+            for k in 0..a.cols {
+                let pol1 = a.get_poly(i, k);
+                let pol2 = b.get_poly(k, j);
+                multiply_add_poly_avx(params, res_poly, pol1, pol2);
+            }
+            modular_reduce(params, res_poly);
+        }
+    }
+}
+
+impl<'a> Mul for PolyMatrixNTT<'a> {
+    type Output = Self;
+
+    fn mul(self, rhs: Self) -> Self::Output {
+        let mut out = PolyMatrixNTT::zero(self.params, self.rows, rhs.cols);
+        multiply(&mut out, &self, &rhs);
+        out
+    }
+}
+
 #[cfg(test)]
 mod test {
     use super::*;
 
     fn get_params() -> Params {
-        Params::init(2048, vec![268369921u64, 249561089u64])
+        Params::init(2048, &vec![268369921u64, 249561089u64])
     }
 
     fn assert_all_zero(a: &[u64]) {
@@ -162,8 +253,7 @@ mod test {
         let params = get_params();
         let m1 = PolyMatrixNTT::zero(&params, 2, 1);
         let m2 = PolyMatrixNTT::zero(&params, 3, 2);
-        let mut m3 = PolyMatrixNTT::zero(&params, 3, 1);
-        multiply(&mut m3, &m2, &m1);
+        let m3 = m2 * m1;
         assert_all_zero(m3.as_slice());
     }
 }