|
@@ -283,6 +283,35 @@ pub fn generate_random_db_and_get_item<'a>(
|
|
|
(item, v)
|
|
|
}
|
|
|
|
|
|
+pub fn fold_ciphertexts(
|
|
|
+ params: &Params,
|
|
|
+ v_cts: &mut Vec<PolyMatrixRaw>,
|
|
|
+ v_folding: &Vec<PolyMatrixNTT>,
|
|
|
+ v_folding_neg: &Vec<PolyMatrixNTT>
|
|
|
+) {
|
|
|
+ let further_dims = log2(v_cts.len() as u64) as usize;
|
|
|
+ let ell = v_folding[0].cols / 2;
|
|
|
+ let mut ginv_c = PolyMatrixRaw::zero(¶ms, 2 * ell, 1);
|
|
|
+ let mut ginv_c_ntt = PolyMatrixNTT::zero(¶ms, 2 * ell, 1);
|
|
|
+ let mut prod = PolyMatrixNTT::zero(¶ms, 2, 1);
|
|
|
+ let mut sum = PolyMatrixNTT::zero(¶ms, 2, 1);
|
|
|
+
|
|
|
+ let mut num_per = v_cts.len();
|
|
|
+ for cur_dim in 0..further_dims {
|
|
|
+ num_per = num_per / 2;
|
|
|
+ for i in 0..num_per {
|
|
|
+ gadget_invert(&mut ginv_c, &v_cts[i]);
|
|
|
+ to_ntt(&mut ginv_c_ntt, &ginv_c);
|
|
|
+ multiply(&mut prod, &v_folding_neg[further_dims - 1 - cur_dim], &ginv_c_ntt);
|
|
|
+ gadget_invert(&mut ginv_c, &v_cts[num_per + i]);
|
|
|
+ to_ntt(&mut ginv_c_ntt, &ginv_c);
|
|
|
+ multiply(&mut sum, &v_folding[further_dims - 1 - cur_dim], &ginv_c_ntt);
|
|
|
+ add_into(&mut sum, &prod);
|
|
|
+ from_ntt(&mut v_cts[i], &sum);
|
|
|
+ }
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
#[cfg(test)]
|
|
|
mod test {
|
|
|
use super::*;
|
|
@@ -462,4 +491,72 @@ mod test {
|
|
|
assert_eq!(dec_rescaled.data[z], corr_item.data[z]);
|
|
|
}
|
|
|
}
|
|
|
+
|
|
|
+ #[test]
|
|
|
+ fn fold_ciphertexts_is_correct() {
|
|
|
+ let params = get_params();
|
|
|
+ let mut seeded_rng = get_seeded_rng();
|
|
|
+
|
|
|
+ let dim0 = 1 << params.db_dim_1;
|
|
|
+ let num_per = 1 << params.db_dim_2;
|
|
|
+ let scale_k = params.modulus / params.pt_modulus;
|
|
|
+
|
|
|
+ let target_idx = seeded_rng.gen::<usize>() % (dim0 * num_per);
|
|
|
+ let target_idx_num_per = target_idx % num_per;
|
|
|
+
|
|
|
+ let mut client = Client::init(¶ms, &mut seeded_rng);
|
|
|
+ _ = client.generate_keys();
|
|
|
+
|
|
|
+ let mut v_reg = Vec::new();
|
|
|
+ for i in 0..num_per {
|
|
|
+ let val = if i == target_idx_num_per { scale_k } else { 0 };
|
|
|
+ let sigma = PolyMatrixRaw::single_value(¶ms, val).ntt();
|
|
|
+ v_reg.push(client.encrypt_matrix_reg(&sigma));
|
|
|
+ }
|
|
|
+
|
|
|
+ let mut v_reg_raw = Vec::new();
|
|
|
+ for i in 0..num_per {
|
|
|
+ v_reg_raw.push(v_reg[i].raw());
|
|
|
+ }
|
|
|
+
|
|
|
+ let bits_per = get_bits_per(¶ms, params.t_gsw);
|
|
|
+ let mut v_folding = Vec::new();
|
|
|
+ for i in 0..params.db_dim_2 {
|
|
|
+ let bit = ((target_idx_num_per as u64) & (1 << (i as u64))) >> (i as u64);
|
|
|
+ let mut ct_gsw = PolyMatrixNTT::zero(¶ms, 2, 2 * params.t_gsw);
|
|
|
+
|
|
|
+ for j in 0..params.t_gsw {
|
|
|
+ let value = (1u64 << (bits_per * j)) * bit;
|
|
|
+ let sigma = PolyMatrixRaw::single_value(¶ms, value);
|
|
|
+ let sigma_ntt = to_ntt_alloc(&sigma);
|
|
|
+ let ct = client.encrypt_matrix_reg(&sigma_ntt);
|
|
|
+ ct_gsw.copy_into(&ct, 0, 2 * j + 1);
|
|
|
+ let prod = &to_ntt_alloc(&client.sk_reg) * &sigma_ntt;
|
|
|
+ let ct = &client.encrypt_matrix_reg(&prod);
|
|
|
+ ct_gsw.copy_into(&ct, 0, 2 * j);
|
|
|
+ }
|
|
|
+
|
|
|
+ v_folding.push(ct_gsw);
|
|
|
+ }
|
|
|
+
|
|
|
+ let gadget_ntt = build_gadget(¶ms, 2, 2 * params.t_gsw).ntt();
|
|
|
+ let mut v_folding_neg = Vec::new();
|
|
|
+ let mut ct_gsw_inv = PolyMatrixRaw::zero(¶ms, 2, 2 * params.t_gsw);
|
|
|
+ for i in 0..params.db_dim_2 {
|
|
|
+ invert(&mut ct_gsw_inv, &v_folding[i].raw());
|
|
|
+ let mut ct_gsw_neg = PolyMatrixNTT::zero(¶ms, 2, 2 * params.t_gsw);
|
|
|
+ add(&mut ct_gsw_neg, &gadget_ntt, &ct_gsw_inv.ntt());
|
|
|
+ v_folding_neg.push(ct_gsw_neg);
|
|
|
+ }
|
|
|
+
|
|
|
+ fold_ciphertexts(
|
|
|
+ ¶ms,
|
|
|
+ &mut v_reg_raw,
|
|
|
+ &v_folding,
|
|
|
+ &v_folding_neg
|
|
|
+ );
|
|
|
+
|
|
|
+ // decrypt
|
|
|
+ assert_eq!(dec_reg(¶ms, &v_reg_raw[0].ntt(), &mut client, scale_k), 1);
|
|
|
+ }
|
|
|
}
|