Browse Source

add folding

Samir Menon 2 years ago
parent
commit
bad75819dd
3 changed files with 121 additions and 1 deletions
  1. 1 1
      spiral-rs/src/client.rs
  2. 23 0
      spiral-rs/src/poly.rs
  3. 97 0
      spiral-rs/src/server.rs

+ 1 - 1
spiral-rs/src/client.rs

@@ -110,7 +110,7 @@ impl<'a> Query<'a> {
 pub struct Client<'a, TRng: Rng> {
     params: &'a Params,
     sk_gsw: PolyMatrixRaw<'a>,
-    sk_reg: PolyMatrixRaw<'a>,
+    pub sk_reg: PolyMatrixRaw<'a>,
     sk_gsw_full: PolyMatrixRaw<'a>,
     sk_reg_full: PolyMatrixRaw<'a>,
     dg: DiscreteGaussian<'a, TRng>,

+ 23 - 0
spiral-rs/src/poly.rs

@@ -321,6 +321,15 @@ pub fn add_poly(params: &Params, res: &mut [u64], a: &[u64], b: &[u64]) {
     }
 }
 
+pub fn add_poly_into(params: &Params, res: &mut [u64], a: &[u64]) {
+    for c in 0..params.crt_count {
+        for i in 0..params.poly_len {
+            let idx = c * params.poly_len + i;
+            res[idx] = add_modular(params, res[idx], a[idx], c);
+        }
+    }
+}
+
 pub fn invert_poly(params: &Params, res: &mut [u64], a: &[u64]) {
     for i in 0..params.poly_len {
         res[i] = params.modulus - a[i];
@@ -434,6 +443,20 @@ pub fn add(res: &mut PolyMatrixNTT, a: &PolyMatrixNTT, b: &PolyMatrixNTT) {
     }
 }
 
+pub fn add_into(res: &mut PolyMatrixNTT, a: &PolyMatrixNTT) {
+    assert!(res.rows == a.rows);
+    assert!(res.cols == a.cols);
+
+    let params = res.params;
+    for i in 0..res.rows {
+        for j in 0..res.cols {
+            let res_poly = res.get_poly_mut(i, j);
+            let pol2 = a.get_poly(i, j);
+            add_poly_into(params, res_poly, pol2);
+        }
+    }
+}
+
 pub fn invert(res: &mut PolyMatrixRaw, a: &PolyMatrixRaw) {
     assert!(res.rows == a.rows);
     assert!(res.cols == a.cols);

+ 97 - 0
spiral-rs/src/server.rs

@@ -283,6 +283,35 @@ pub fn generate_random_db_and_get_item<'a>(
     (item, v)
 }
 
+pub fn fold_ciphertexts(
+    params: &Params,
+    v_cts: &mut Vec<PolyMatrixRaw>,
+    v_folding: &Vec<PolyMatrixNTT>,
+    v_folding_neg: &Vec<PolyMatrixNTT>
+) {
+    let further_dims = log2(v_cts.len() as u64) as usize;
+    let ell = v_folding[0].cols / 2;
+    let mut ginv_c = PolyMatrixRaw::zero(&params, 2 * ell, 1);
+    let mut ginv_c_ntt = PolyMatrixNTT::zero(&params, 2 * ell, 1);
+    let mut prod = PolyMatrixNTT::zero(&params, 2, 1);
+    let mut sum = PolyMatrixNTT::zero(&params, 2, 1);
+
+    let mut num_per = v_cts.len();
+    for cur_dim in 0..further_dims {
+        num_per = num_per / 2;
+        for i in 0..num_per {
+            gadget_invert(&mut ginv_c, &v_cts[i]);
+            to_ntt(&mut ginv_c_ntt, &ginv_c);
+            multiply(&mut prod, &v_folding_neg[further_dims - 1 - cur_dim], &ginv_c_ntt);
+            gadget_invert(&mut ginv_c, &v_cts[num_per + i]);
+            to_ntt(&mut ginv_c_ntt, &ginv_c);
+            multiply(&mut sum, &v_folding[further_dims - 1 - cur_dim], &ginv_c_ntt);
+            add_into(&mut sum, &prod);
+            from_ntt(&mut v_cts[i], &sum);
+        }
+    }
+}
+
 #[cfg(test)]
 mod test {
     use super::*;
@@ -462,4 +491,72 @@ mod test {
             assert_eq!(dec_rescaled.data[z], corr_item.data[z]);
         }
     }
+
+    #[test]
+    fn fold_ciphertexts_is_correct() {
+        let params = get_params();
+        let mut seeded_rng = get_seeded_rng();
+
+        let dim0 = 1 << params.db_dim_1;
+        let num_per = 1 << params.db_dim_2;
+        let scale_k = params.modulus / params.pt_modulus;
+
+        let target_idx = seeded_rng.gen::<usize>() % (dim0 * num_per);
+        let target_idx_num_per = target_idx % num_per;
+
+        let mut client = Client::init(&params, &mut seeded_rng);
+        _ = client.generate_keys();
+
+        let mut v_reg = Vec::new();
+        for i in 0..num_per {
+            let val = if i == target_idx_num_per { scale_k } else { 0 };
+            let sigma = PolyMatrixRaw::single_value(&params, val).ntt();
+            v_reg.push(client.encrypt_matrix_reg(&sigma));
+        }
+
+        let mut v_reg_raw = Vec::new();
+        for i in 0..num_per {
+            v_reg_raw.push(v_reg[i].raw());
+        }
+
+        let bits_per = get_bits_per(&params, params.t_gsw);
+        let mut v_folding = Vec::new();
+        for i in 0..params.db_dim_2 {
+            let bit = ((target_idx_num_per as u64) & (1 << (i as u64))) >> (i as u64);
+            let mut ct_gsw = PolyMatrixNTT::zero(&params, 2, 2 * params.t_gsw);
+
+            for j in 0..params.t_gsw {
+                let value = (1u64 << (bits_per * j)) * bit;
+                let sigma = PolyMatrixRaw::single_value(&params, value);
+                let sigma_ntt = to_ntt_alloc(&sigma);
+                let ct = client.encrypt_matrix_reg(&sigma_ntt);
+                ct_gsw.copy_into(&ct, 0, 2 * j + 1);
+                let prod = &to_ntt_alloc(&client.sk_reg) * &sigma_ntt;
+                let ct = &client.encrypt_matrix_reg(&prod);
+                ct_gsw.copy_into(&ct, 0, 2 * j);
+            }
+
+            v_folding.push(ct_gsw);
+        }
+
+        let gadget_ntt = build_gadget(&params, 2, 2 * params.t_gsw).ntt();
+        let mut v_folding_neg = Vec::new();
+        let mut ct_gsw_inv = PolyMatrixRaw::zero(&params, 2, 2 * params.t_gsw);
+        for i in 0..params.db_dim_2 {
+            invert(&mut ct_gsw_inv, &v_folding[i].raw());
+            let mut ct_gsw_neg = PolyMatrixNTT::zero(&params, 2, 2 * params.t_gsw);
+            add(&mut ct_gsw_neg, &gadget_ntt, &ct_gsw_inv.ntt());
+            v_folding_neg.push(ct_gsw_neg);
+        }
+
+        fold_ciphertexts(
+            &params,
+            &mut v_reg_raw,
+            &v_folding,
+            &v_folding_neg
+        );
+        
+        // decrypt
+        assert_eq!(dec_reg(&params, &v_reg_raw[0].ntt(), &mut client, scale_k), 1);
+    }
 }