Samir Menon 2 years ago
parent
commit
cbb8c95400
5 changed files with 271 additions and 16 deletions
  1. 74 0
      src/client.rs
  2. 66 0
      src/discrete_gaussian.rs
  3. 3 1
      src/lib.rs
  4. 41 5
      src/params.rs
  5. 87 10
      src/poly.rs

+ 74 - 0
src/client.rs

@@ -0,0 +1,74 @@
+use std::collections::HashMap;
+
+use crate::{poly::*, params::*, discrete_gaussian::*};
+
+pub struct PublicParameters<'a> {
+    v_packing: Vec<PolyMatrixNTT<'a>>,            // Ws
+    v_expansion_left: Vec<PolyMatrixNTT<'a>>,
+    v_expansion_right: Vec<PolyMatrixNTT<'a>>,
+    v_conversion: PolyMatrixNTT<'a>,              // V
+}
+
+impl<'a> PublicParameters<'a> {
+    fn init(params: &'a Params) -> Self {
+        PublicParameters { 
+            v_packing: Vec::new(), 
+            v_expansion_left: Vec::new(), 
+            v_expansion_right: Vec::new(), 
+            v_conversion: PolyMatrixNTT::zero(params, 2, 2 * params.m_conv()) 
+        }
+    }
+}
+
+pub struct Client<'a> {
+    params: &'a Params,
+    sk_gsw: PolyMatrixRaw<'a>,
+    sk_reg: PolyMatrixRaw<'a>,
+    sk_gsw_full: PolyMatrixRaw<'a>,
+    sk_reg_full: PolyMatrixRaw<'a>,
+}
+
+fn matrix_with_identity<'a> (p: &PolyMatrixRaw<'a>) -> PolyMatrixRaw<'a> {
+    assert_eq!(p.cols, 1);
+    let mut r = PolyMatrixRaw::zero(p.params, p.rows, p.rows + 1);
+    r.copy_into(p, 0, 0);
+    r.copy_into(&PolyMatrixRaw::identity(p.params, p.rows, p.rows), 0, 1);
+    r
+}
+
+impl<'a> Client<'a> {
+    fn init(params: &'a Params) -> Self {
+        let sk_gsw_dims = params.get_sk_gsw();
+        let sk_reg_dims = params.get_sk_reg();
+        let sk_gsw = PolyMatrixRaw::zero(params, sk_gsw_dims.0, sk_gsw_dims.1);
+        let sk_reg = PolyMatrixRaw::zero(params, sk_reg_dims.0, sk_reg_dims.1);
+        let sk_gsw_full = matrix_with_identity(&sk_gsw);
+        let sk_reg_full = matrix_with_identity(&sk_reg);
+        Self {
+            params,
+            sk_gsw,
+            sk_reg,
+            sk_gsw_full,
+            sk_reg_full,
+        }
+    }
+    fn generate_keys(&mut self) -> PublicParameters {
+        let params = self.params;
+        let mut dg = DiscreteGaussian::init(params);
+        dg.sample_matrix(&mut self.sk_gsw);
+        dg.sample_matrix(&mut self.sk_reg);
+        self.sk_gsw_full = matrix_with_identity(&self.sk_gsw);
+        self.sk_reg_full = matrix_with_identity(&self.sk_reg);
+        let sk_reg_ntt = to_ntt()
+        let pp = PublicParameters::init(params);
+        
+        // For packing
+        for i in 0..params.n {
+            MatPoly scaled = from_scalar_multiply(sk_reg_ntt, )
+        }
+
+        pp
+    }
+    // fn generate_query(&self) -> Query<'a, Params>;
+    
+}

+ 66 - 0
src/discrete_gaussian.rs

@@ -0,0 +1,66 @@
+use rand::{thread_rng, Rng, rngs::ThreadRng};
+use rand::distributions::OpenClosed01;
+
+use crate::params::*;
+use crate::poly::*;
+use std::f64::consts::PI;
+
+pub const NUM_WIDTHS: usize = 8;
+
+pub struct DiscreteGaussian {
+    max_val: i64,
+    table: Vec<f64>,
+    rng: ThreadRng
+}
+
+impl DiscreteGaussian {
+    pub fn init(params: &Params) -> Self {
+        let max_val = (params.noise_width * (NUM_WIDTHS as f64)).ceil() as i64;
+        let mut table = vec![0f64; 0];
+        let mut sum_p = 0f64;
+        table.push(0f64);
+        for i in -max_val..max_val+1 {
+            let p_val = f64::exp(-PI * f64::powi(i as f64, 2) / f64::powi(params.noise_width, 2));
+            table.push(p_val);
+            sum_p += p_val;
+        }
+        for i in 0..table.len() {
+            table[i] /= sum_p;
+        }
+        table.push(1.0);
+    
+        Self {
+            max_val,
+            table,
+            rng: thread_rng()
+        }
+    }
+
+    // FIXME: this is not necessarily constant-time w/ optimization
+    pub fn sample(&mut self) -> i64 {
+        let val: f64 = self.rng.sample(OpenClosed01);
+        let mut found = 0i64;
+        for i in 0..self.table.len()-1 {
+           let bit1: i64 = (val <= self.table[i]) as i64;
+           let bit2: i64 = (val > self.table[i+1]) as i64;
+            found += bit1 * bit2 * (i as i64);
+        }
+        found -= self.max_val;
+        found
+    }
+
+    pub fn sample_matrix(&mut self, p: &mut PolyMatrixRaw) {
+        let modulus = p.get_params().modulus;
+        for r in 0..p.rows {
+            for c in 0..p.cols {
+                let poly = p.get_poly_mut(r, c);
+                for z in 0..poly.len() {
+                    let mut s = self.sample();
+                    s += modulus as i64;
+                    s %= modulus as i64; // FIXME: not constant time
+                    poly[z] = s as u64;
+                }
+            }
+        }
+    }
+}

+ 3 - 1
src/lib.rs

@@ -3,4 +3,6 @@ pub mod ntt;
 pub mod number_theory;
 pub mod params;
 pub mod poly;
-pub mod util;
+pub mod util;
+pub mod client;
+pub mod discrete_gaussian;

+ 41 - 5
src/params.rs

@@ -1,12 +1,22 @@
 use crate::{arith::*, ntt::*};
 
-pub struct Params {
+pub struct Params {    
     pub poly_len: usize,
     pub poly_len_log2: usize,
-    pub ntt_tables: Vec<Vec<Vec<u64>>>,
+    pub ntt_tables: Vec<Vec<Vec<u64>>>,    
+    
     pub crt_count: usize,
     pub moduli: Vec<u64>,
-    pub modulus: u64
+    pub modulus: u64,
+
+    pub noise_width: f64,
+
+    pub n: usize,
+
+    pub t_conv: usize,
+    pub t_exp_left: usize,
+    pub t_exp_right: usize,
+    pub t_gsw: usize,
 }
 
 impl Params {
@@ -26,7 +36,27 @@ impl Params {
         self.ntt_tables[i][3].as_slice()
     }
 
-    pub fn init(poly_len: usize, moduli: &Vec<u64>) -> Self {
+    pub fn get_sk_gsw(&self) -> (usize, usize) {
+        (self.n, 1)
+    }
+    pub fn get_sk_reg(&self) -> (usize, usize) {
+        (1, 1)
+    }
+
+    pub fn m_conv(&self) -> usize {
+        2 * self.t_conv;
+    }
+
+    pub fn init(
+        poly_len: usize,
+        moduli: &Vec<u64>,
+        noise_width: f64,
+        n: usize,
+        t_conv: usize,
+        t_exp_left: usize,
+        t_exp_right: usize,
+        t_gsw: usize,
+    ) -> Self {
         let poly_len_log2 = log2(poly_len as u64) as usize;
         let crt_count = moduli.len();
         let ntt_tables = build_ntt_tables(poly_len, moduli.as_slice());
@@ -40,7 +70,13 @@ impl Params {
             ntt_tables,
             crt_count,
             moduli: moduli.clone(),
-            modulus
+            modulus,
+            noise_width,
+            n,
+            t_conv,
+            t_exp_left,
+            t_exp_right,
+            t_gsw,
         }
     }
 }

+ 87 - 10
src/poly.rs

@@ -1,7 +1,7 @@
 use std::arch::x86_64::*;
 use std::ops::Mul;
 
-use crate::{arith::*, params::*, util::calc_index};
+use crate::{arith::*, params::*, ntt::*, util::calc_index};
 
 pub trait PolyMatrix<'a> {
     fn is_ntt(&self) -> bool;
@@ -27,20 +27,33 @@ pub trait PolyMatrix<'a> {
         let start = (row * self.get_cols() + col) * num_words;
         &mut self.as_mut_slice()[start..start + num_words]
     }
+    fn copy_into(&mut self, p: &Self, target_row: usize, target_col: usize) {
+        assert!(target_row < self.get_rows());
+        assert!(target_col < self.get_cols());
+        assert!(target_row + p.get_rows() < self.get_rows());
+        assert!(target_col + p.get_cols() < self.get_cols());
+        for r in 0..p.get_rows() {
+            for c in 0..p.get_cols() {
+                let pol_src = p.get_poly(r, c);
+                let pol_dst = self.get_poly_mut(target_row + r, target_col + c);
+                pol_dst.copy_from_slice(pol_src);
+            }
+        }
+    }
 }
 
 pub struct PolyMatrixRaw<'a> {
-    params: &'a Params,
-    rows: usize,
-    cols: usize,
-    data: Vec<u64>,
+    pub params: &'a Params,
+    pub rows: usize,
+    pub cols: usize,
+    pub data: Vec<u64>,
 }
 
 pub struct PolyMatrixNTT<'a> {
-    params: &'a Params,
-    rows: usize,
-    cols: usize,
-    data: Vec<u64>,
+    pub params: &'a Params,
+    pub rows: usize,
+    pub cols: usize,
+    pub data: Vec<u64>,
 }
 
 impl<'a> PolyMatrix<'a> for PolyMatrixRaw<'a> {
@@ -86,6 +99,24 @@ impl<'a> PolyMatrix<'a> for PolyMatrixRaw<'a> {
     }
 }
 
+impl<'a> PolyMatrixRaw<'a> {
+    pub fn identity(params: &'a Params, rows: usize, cols: usize) -> PolyMatrixRaw<'a> {
+        let num_coeffs = rows * cols * params.poly_len;
+        let mut data: Vec<u64> = vec![0; num_coeffs];
+        for r in 0..rows {
+            let c = r;
+            let idx = r * cols * params.poly_len + c * params.poly_len;
+            data[idx] = 1;
+        }
+        PolyMatrixRaw {
+            params,
+            rows,
+            cols,
+            data,
+        }
+    }
+}
+
 impl<'a> PolyMatrix<'a> for PolyMatrixNTT<'a> {
     fn is_ntt(&self) -> bool {
         true
@@ -217,6 +248,52 @@ pub fn multiply(res: &mut PolyMatrixNTT, a: &PolyMatrixNTT, b: &PolyMatrixNTT) {
     }
 }
 
+pub fn scalar_multiply(res: &mut PolyMatrixNTT, a: &PolyMatrixNTT, b: &PolyMatrixNTT) {
+    assert_eq!(b.rows, 1);
+    assert_eq!(b.cols, 1);
+
+    let params = res.params;
+    let pol2 = b.get_poly(0, 0);
+    for i in 0..a.rows {
+        for j in 0..a.cols {
+            let res_poly = res.get_poly_mut(i, j);
+            let pol1 = a.get_poly(i, j);
+            multiply_poly(params, res_poly, pol1, pol2);
+        }
+    }
+}
+
+pub fn from_scalar_multiply<'a>(a: &PolyMatrixNTT<'a>, b: &PolyMatrixNTT<'a>) -> PolyMatrixNTT<'a> {
+    let mut res = PolyMatrixNTT::zero(a.params, a.rows, a.cols);
+    scalar_multiply(res, a, b);
+    res
+}
+
+
+pub fn to_ntt(a: &mut PolyMatrixNTT, b: &PolyMatrixRaw) {
+    for r in 0..a.rows {
+        for c in 0..a.cols {
+            let pol_src = a.get_poly_mut(r, c);
+            let pol_dst = b.get_poly_mut(r, c);
+            for n in 0..a.params.crt_count {
+                for z in 0..a.params.poly_len {
+                    pol_dst[n * a.params.poly_len + z] = pol_src[z] % a.params.moduli[n];
+                }
+            }
+            ntt_forward(a.params, pol_dst);
+        }
+    }
+}
+
+pub fn from_ntt(params: &Params, res: &mut [u64]) {
+    for c in 0..params.crt_count {
+        for i in 0..params.poly_len {
+            res[c*params.poly_len + i] %= params.moduli[c];
+        }
+    }
+}
+
+
 impl<'a> Mul for PolyMatrixNTT<'a> {
     type Output = Self;
 
@@ -232,7 +309,7 @@ mod test {
     use super::*;
 
     fn get_params() -> Params {
-        Params::init(2048, &vec![268369921u64, 249561089u64])
+        Params::init(2048, &vec![268369921u64, 249561089u64], 2, 6.4)
     }
 
     fn assert_all_zero(a: &[u64]) {