Browse Source

add regev to gsw conversion

Samir Menon 1 year ago
parent
commit
9dca8ed1a0
3 changed files with 138 additions and 5 deletions
  1. 8 0
      spiral-rs/src/client.rs
  2. 3 3
      spiral-rs/src/poly.rs
  3. 127 2
      spiral-rs/src/server.rs

+ 8 - 0
spiral-rs/src/client.rs

@@ -226,6 +226,14 @@ impl<'a, TRng: Rng> Client<'a, TRng> {
         &p + &a.pad_top(1)
     }
 
+    pub fn decrypt_matrix_reg(&mut self, a: &PolyMatrixNTT<'a>) -> PolyMatrixNTT<'a> {
+        &self.sk_reg_full.ntt() * a
+    }
+
+    pub fn decrypt_matrix_gsw(&mut self, a: &PolyMatrixNTT<'a>) -> PolyMatrixNTT<'a> {
+        &self.sk_gsw_full.ntt() * a
+    }
+
     fn generate_expansion_params(
         &mut self,
         num_exp: usize,

+ 3 - 3
spiral-rs/src/poly.rs

@@ -386,9 +386,9 @@ pub fn multiply(res: &mut PolyMatrixNTT, a: &PolyMatrixNTT, b: &PolyMatrixNTT) {
 
 #[cfg(target_feature = "avx2")]
 pub fn multiply(res: &mut PolyMatrixNTT, a: &PolyMatrixNTT, b: &PolyMatrixNTT) {
-    assert!(res.rows == a.rows);
-    assert!(res.cols == b.cols);
-    assert!(a.cols == b.rows);
+    assert_eq!(res.rows, a.rows);
+    assert_eq!(res.cols, b.cols);
+    assert_eq!(a.cols, b.rows);
 
     let params = res.params;
     for i in 0..a.rows {

+ 127 - 2
spiral-rs/src/server.rs

@@ -87,14 +87,89 @@ pub fn coefficient_expansion(
     }
 }
 
+pub fn regev_to_gsw<'a>(
+    v_gsw: &mut Vec<PolyMatrixNTT<'a>>,
+    v_inp: &Vec<PolyMatrixNTT<'a>>,
+    v: &PolyMatrixNTT<'a>,
+    params: &'a Params,
+    idx_factor: usize,
+    idx_offset: usize,
+) {
+    assert!(v.rows == 2);
+    assert!(v.cols == 2 * params.t_conv);
+
+    let mut ginv_c_inp = PolyMatrixRaw::zero(params, 2 * params.t_conv, 1);
+    let mut ginv_c_inp_ntt = PolyMatrixNTT::zero(params, 2 * params.t_conv, 1);
+    let mut tmp_ct_raw = PolyMatrixRaw::zero(params, 2, 1);
+    let mut tmp_ct = PolyMatrixNTT::zero(params, 2, 1);
+
+    for i in 0..params.db_dim_2 {
+        let ct = &mut v_gsw[i];
+        for j in 0..params.t_gsw {
+            let idx_ct = i * params.t_gsw + j;
+            let idx_inp = idx_factor * (idx_ct) + idx_offset;
+            ct.copy_into(&v_inp[idx_inp], 0, 2 * j + 1);
+            from_ntt(&mut tmp_ct_raw, &v_inp[idx_inp]);
+            gadget_invert(&mut ginv_c_inp, &tmp_ct_raw);
+            to_ntt(&mut ginv_c_inp_ntt, &ginv_c_inp);
+            multiply(&mut tmp_ct, v, &ginv_c_inp_ntt);
+            ct.copy_into(&tmp_ct, 0, 2 * j);
+        }
+    }
+}
+
 #[cfg(test)]
 mod test {
+    use rand::prelude::StdRng;
+
     use crate::{client::*, util::*};
 
     use super::*;
 
     fn get_params() -> Params {
-        get_expansion_testing_params()
+        let mut params = get_expansion_testing_params();
+        params.db_dim_1 = 6;
+        params.db_dim_2 = 2;
+        params.t_exp_right = 8;
+        params
+    }
+
+    fn dec_reg<'a>(
+        params: &'a Params,
+        ct: &PolyMatrixNTT<'a>,
+        client: &mut Client<'a, StdRng>,
+        scale_k: u64,
+    ) -> u64 {
+        let dec = client.decrypt_matrix_reg(ct).raw();
+        let mut val = dec.data[0] as i64;
+        if val >= (params.modulus / 2) as i64 {
+            val -= params.modulus as i64;
+        }
+        let val_rounded = f64::round(val as f64 / scale_k as f64) as i64;
+        println!("{:?} {:?}", val, val_rounded);
+        if val_rounded == 0 {
+            0
+        } else {
+            1
+        }
+    }
+
+    fn dec_gsw<'a>(
+        params: &'a Params,
+        ct: &PolyMatrixNTT<'a>,
+        client: &mut Client<'a, StdRng>,
+    ) -> u64 {
+        let dec = client.decrypt_matrix_reg(ct).raw();
+        let idx = (params.t_gsw - 1) * params.poly_len + params.poly_len; // this offset should encode a large value
+        let mut val = dec.data[idx] as i64;
+        if val >= (params.modulus / 2) as i64 {
+            val -= params.modulus as i64;
+        }
+        if val < 100 {
+            0
+        } else {
+            1
+        }
     }
 
     #[test]
@@ -109,10 +184,13 @@ mod test {
         for _ in 0..params.poly_len {
             v.push(PolyMatrixNTT::zero(&params, 2, 1));
         }
+
+        let target = 7;
         let scale_k = params.modulus / params.pt_modulus;
         let mut sigma = PolyMatrixRaw::zero(&params, 1, 1);
-        sigma.data[7] = scale_k;
+        sigma.data[target] = scale_k;
         v[0] = client.encrypt_matrix_reg(&sigma.ntt());
+        let test_ct = client.encrypt_matrix_reg(&sigma.ntt());
 
         let v_w_left = public_params.v_expansion_left.unwrap();
         let v_w_right = public_params.v_expansion_right.unwrap();
@@ -126,5 +204,52 @@ mod test {
             &v_neg1,
             params.t_gsw * params.db_dim_2,
         );
+
+        assert_eq!(dec_reg(&params, &test_ct, &mut client, scale_k), 0);
+
+        for i in 0..v.len() {
+            if i == target {
+                assert_eq!(dec_reg(&params, &v[i], &mut client, scale_k), 1);
+            } else {
+                assert_eq!(dec_reg(&params, &v[i], &mut client, scale_k), 0);
+            }
+        }
+    }
+
+    #[test]
+    fn regev_to_gsw_is_correct() {
+        let mut params = get_params();
+        params.db_dim_2 = 1;
+        let mut seeded_rng = get_seeded_rng();
+        let mut client = Client::init(&params, &mut seeded_rng);
+        let public_params = client.generate_keys();
+
+        let mut enc_constant = |val| {
+            let mut sigma = PolyMatrixRaw::zero(&params, 1, 1);
+            sigma.data[0] = val;
+            client.encrypt_matrix_reg(&sigma.ntt())
+        };
+
+        let v = &public_params.v_conversion.unwrap()[0];
+
+        let bits_per = get_bits_per(&params, params.t_gsw);
+        let mut v_inp_1 = Vec::new();
+        let mut v_inp_0 = Vec::new();
+        for i in 0..params.t_gsw {
+            let val = 1u64 << (bits_per * i);
+            v_inp_1.push(enc_constant(val));
+            v_inp_0.push(enc_constant(0));
+        }
+
+        let mut v_gsw = Vec::new();
+        v_gsw.push(PolyMatrixNTT::zero(&params, 2, 2 * params.t_gsw));
+
+        regev_to_gsw(&mut v_gsw, &v_inp_1, v, &params, 1, 0);
+
+        assert_eq!(dec_gsw(&params, &v_gsw[0], &mut client), 1);
+
+        regev_to_gsw(&mut v_gsw, &v_inp_0, v, &params, 1, 0);
+
+        assert_eq!(dec_gsw(&params, &v_gsw[0], &mut client), 0);
     }
 }