Browse Source

finish keygen

Samir Menon 2 years ago
parent
commit
db1400f916
4 changed files with 117 additions and 4 deletions
  1. 68 3
      src/client.rs
  2. 6 0
      src/params.rs
  3. 40 0
      src/poly.rs
  4. 3 1
      src/util.rs

+ 68 - 3
src/client.rs

@@ -1,6 +1,6 @@
 use std::collections::HashMap;
 
-use crate::{poly::*, params::*, discrete_gaussian::*, gadget::*};
+use crate::{poly::*, params::*, discrete_gaussian::*, gadget::*, arith::*};
 
 pub struct PublicParameters<'a> {
     v_packing: Vec<PolyMatrixNTT<'a>>,            // Ws
@@ -69,13 +69,61 @@ impl<'a> Client<'a> {
         p
     }
 
-    fn encrypt_matrix_gsw(&mut self, ag: PolyMatrixNTT<'a>) -> PolyMatrixNTT<'a> {
+    fn get_regev_sample(&mut self) -> PolyMatrixNTT<'a> {
+        let params = self.params;
+        let a = PolyMatrixRaw::random(params, 1, 1);
+        let e = PolyMatrixRaw::noise(params, 1, 1, &mut self.dg);
+        let b_p = &self.sk_reg.ntt() * &a.ntt();
+        let b = &e.ntt() + &b_p;
+        let mut p = PolyMatrixNTT::zero(params, 2, 1);
+        p.copy_into(&(-&a).ntt(), 0, 0);
+        p.copy_into(&b, 1, 0);
+        p
+    }
+
+    fn get_fresh_reg_public_key(&mut self, m: usize) -> PolyMatrixNTT<'a> {
         let params = self.params;
+
+        let mut p = PolyMatrixNTT::zero(params, 2, m);
+
+        for i in 0..m {
+            p.copy_into(&self.get_regev_sample(), 0, i);
+        }
+
+        p
+    }
+
+    fn encrypt_matrix_gsw(&mut self, ag: PolyMatrixNTT<'a>) -> PolyMatrixNTT<'a> {
         let mx = ag.cols;
         let p = self.get_fresh_gsw_public_key(mx);
         let res = &(p.ntt()) + &(ag.pad_top(1));
         res
     }
+    
+    fn encrypt_matrix_reg(&mut self, a: PolyMatrixNTT<'a>) -> PolyMatrixNTT<'a> {
+        let m = a.cols;
+        let p = self.get_fresh_reg_public_key(m);
+        &p + &a.pad_top(1)
+    }
+
+    fn generate_expansion_params(&mut self, num_exp: usize, m_exp: usize) -> Vec<PolyMatrixNTT<'a>> {
+        // MatPoly G_exp = buildGadget(1, m_exp);
+        // MatPoly G_exp_nttd = to_ntt(G_exp);
+        let params = self.params;
+        let g_exp = build_gadget(params, 1, m_exp);
+        let g_exp_ntt = g_exp.ntt();
+        let mut res = Vec::new();
+
+        for i in 0..num_exp {
+            let t = (params.poly_len / (1 << i)) + 1;
+            let tau_sk_reg = automorph_alloc(&self.sk_reg, t);
+            // MatPoly W_exp_i = encryptSimpleRegevMatrix(s0, multiply(tau_s0, G_exp_nttd));
+            let prod = &tau_sk_reg.ntt() * &g_exp_ntt;
+            let w_exp_i = self.encrypt_matrix_reg(prod);
+            res.push(w_exp_i);
+        }
+        res
+    }
 
     pub fn generate_keys(&mut self) -> PublicParameters {
         let params = self.params;
@@ -100,9 +148,26 @@ impl<'a> Client<'a> {
         }
 
         // Params for expansion
+        let further_dims = 1usize << params.db_dim_2;
+        let num_expanded = 1usize << params.db_dim_1;
+        let num_bits_to_gen = params.t_gsw * further_dims + num_expanded;
+        let g = log2(num_bits_to_gen as u64) as usize;
+        let stop_round = log2((params.t_gsw * further_dims) as u64) as usize;
+        pp.v_expansion_left = self.generate_expansion_params(g, params.t_exp_left);
+        pp.v_expansion_right = self.generate_expansion_params(stop_round + 1, params.t_exp_right);
 
         // Params for converison
-
+        let g_conv = build_gadget(params, 2, 2 * m_conv);
+        let sk_reg_squared_ntt = &self.sk_reg.ntt() * &self.sk_reg.ntt();
+        pp.v_conversion = PolyMatrixNTT::zero(params, 2, 2 * m_conv);
+        for i in 0..2*m_conv {
+            if i % 2 == 0 {
+                let val = g_conv.get_poly(0, i)[0];
+                let sigma = &sk_reg_squared_ntt * &single_poly(params, val).ntt();
+                let ct = self.encrypt_matrix_reg(sigma);
+                pp.v_conversion.copy_into(&ct, 0, i);
+            }
+        }
 
         pp
     }

+ 6 - 0
src/params.rs

@@ -21,6 +21,8 @@ pub struct Params {
     pub t_gsw: usize,
 
     pub expand_queries: bool,
+    pub db_dim_1: usize,
+    pub db_dim_2: usize,
 }
 
 impl Params {
@@ -70,6 +72,8 @@ impl Params {
         t_exp_right: usize,
         t_gsw: usize,
         expand_queries: bool,
+        db_dim_1: usize,
+        db_dim_2: usize,
     ) -> Self {
         let poly_len_log2 = log2(poly_len as u64) as usize;
         let crt_count = moduli.len();
@@ -96,6 +100,8 @@ impl Params {
             t_exp_right,
             t_gsw,
             expand_queries,
+            db_dim_1,
+            db_dim_2,
         }
     }
 }

+ 40 - 0
src/poly.rs

@@ -239,6 +239,20 @@ pub fn invert_poly(params: &Params, res: &mut [u64], a: &[u64]) {
     }
 }
 
+pub fn automorph_poly(params: &Params, res: &mut [u64], a: &[u64], t: usize) {
+    let poly_len = params.poly_len;
+    for i in 0..poly_len {
+        let num = (i * t) / poly_len;
+        let rem = (i * t) % poly_len;
+
+        if num % 2 == 0 {
+            res[rem] = a[i];
+        } else {
+            res[rem] = params.modulus - a[i];
+        }
+    }
+}
+
 pub fn multiply_add_poly_avx(params: &Params, res: &mut [u64], a: &[u64], b: &[u64]) {
     for c in 0..params.crt_count {
         for i in (0..params.poly_len).step_by(4) {
@@ -341,6 +355,26 @@ pub fn invert(res: &mut PolyMatrixRaw, a: &PolyMatrixRaw) {
     }
 }
 
+pub fn automorph<'a>(res: &mut PolyMatrixRaw<'a>, a: &PolyMatrixRaw<'a>, t: usize) {
+    assert!(res.rows == a.rows);
+    assert!(res.cols == a.cols);
+
+    let params = res.params;
+    for i in 0..a.rows {
+        for j in 0..a.cols {
+            let res_poly = res.get_poly_mut(i, j);
+            let pol1 = a.get_poly(i, j);
+            automorph_poly(params, res_poly, pol1, t);
+        }
+    }
+}
+
+pub fn automorph_alloc<'a>(a: &PolyMatrixRaw<'a>, t: usize) -> PolyMatrixRaw<'a> {
+    let mut res = PolyMatrixRaw::zero(a.params, a.rows, a.cols);
+    automorph(&mut res, a, t);
+    res
+}
+
 pub fn stack<'a>(a: &PolyMatrixRaw<'a>, b: &PolyMatrixRaw<'a>) -> PolyMatrixRaw<'a> {
     assert_eq!(a.cols, b.cols);
     let mut c = PolyMatrixRaw::zero(a.params, a.rows + b.rows, a.cols);
@@ -370,6 +404,12 @@ pub fn scalar_multiply_alloc<'a>(a: &PolyMatrixNTT<'a>, b: &PolyMatrixNTT<'a>) -
     res
 }
 
+pub fn single_poly<'a>(params: &'a Params, val: u64) -> PolyMatrixRaw<'a> {
+    let mut res = PolyMatrixRaw::zero(params, 1, 1);
+    res.get_poly_mut(0, 0)[0] = val;
+    res
+}
+
 
 pub fn to_ntt(a: &mut PolyMatrixNTT, b: &PolyMatrixRaw) {
     let params = a.params;

+ 3 - 1
src/util.rs

@@ -20,6 +20,8 @@ pub fn get_test_params() -> Params {
         56,
         56,
         56,
-        true
+        true,
+        9,
+        6,
     )
 }