|
@@ -30,16 +30,6 @@ pub fn coefficient_expansion(
|
|
|
) {
|
|
|
let poly_len = params.poly_len;
|
|
|
|
|
|
- let mut ct = PolyMatrixRaw::zero(params, 2, 1);
|
|
|
- let mut ct_auto = PolyMatrixRaw::zero(params, 2, 1);
|
|
|
- let mut ct_auto_1 = PolyMatrixRaw::zero(params, 1, 1);
|
|
|
- let mut ct_auto_1_ntt = PolyMatrixNTT::zero(params, 1, 1);
|
|
|
- let mut ginv_ct_left = PolyMatrixRaw::zero(params, params.t_exp_left, 1);
|
|
|
- let mut ginv_ct_left_ntt = PolyMatrixNTT::zero(params, params.t_exp_left, 1);
|
|
|
- let mut ginv_ct_right = PolyMatrixRaw::zero(params, params.t_exp_right, 1);
|
|
|
- let mut ginv_ct_right_ntt = PolyMatrixNTT::zero(params, params.t_exp_right, 1);
|
|
|
- let mut w_times_ginv_ct = PolyMatrixNTT::zero(params, 2, 1);
|
|
|
-
|
|
|
for r in 0..g {
|
|
|
let num_in = 1 << r;
|
|
|
let num_out = 2 * num_in;
|
|
@@ -48,16 +38,27 @@ pub fn coefficient_expansion(
|
|
|
|
|
|
let neg1 = &v_neg1[r];
|
|
|
|
|
|
- for i in 0..num_out {
|
|
|
+ let action_expand = |(i, v_i): (usize, &mut PolyMatrixNTT)| {
|
|
|
if (stop_round > 0 && r > stop_round && (i % 2) == 1)
|
|
|
|| (stop_round > 0
|
|
|
&& r == stop_round
|
|
|
&& (i % 2) == 1
|
|
|
&& (i / 2) >= max_bits_to_gen_right)
|
|
|
{
|
|
|
- continue;
|
|
|
+ return;
|
|
|
}
|
|
|
|
|
|
+ let mut ct = PolyMatrixRaw::zero(params, 2, 1);
|
|
|
+ let mut ct_auto = PolyMatrixRaw::zero(params, 2, 1);
|
|
|
+ let mut ct_auto_1 = PolyMatrixRaw::zero(params, 1, 1);
|
|
|
+ let mut ct_auto_1_ntt = PolyMatrixNTT::zero(params, 1, 1);
|
|
|
+ let mut w_times_ginv_ct = PolyMatrixNTT::zero(params, 2, 1);
|
|
|
+
|
|
|
+ let mut ginv_ct_left = PolyMatrixRaw::zero(params, params.t_exp_left, 1);
|
|
|
+ let mut ginv_ct_left_ntt = PolyMatrixNTT::zero(params, params.t_exp_left, 1);
|
|
|
+ let mut ginv_ct_right = PolyMatrixRaw::zero(params, params.t_exp_right, 1);
|
|
|
+ let mut ginv_ct_right_ntt = PolyMatrixNTT::zero(params, params.t_exp_right, 1);
|
|
|
+
|
|
|
let (w, _gadget_dim, gi_ct, gi_ct_ntt) = match i % 2 {
|
|
|
0 => (
|
|
|
&v_w_left[r],
|
|
@@ -73,12 +74,12 @@ pub fn coefficient_expansion(
|
|
|
),
|
|
|
};
|
|
|
|
|
|
- if i < num_in {
|
|
|
- let (src, dest) = v.split_at_mut(num_in);
|
|
|
- scalar_multiply(&mut dest[i], neg1, &src[i]);
|
|
|
- }
|
|
|
+ // if i < num_in {
|
|
|
+ // let (src, dest) = v.split_at_mut(num_in);
|
|
|
+ // scalar_multiply(&mut dest[i], neg1, &src[i]);
|
|
|
+ // }
|
|
|
|
|
|
- from_ntt(&mut ct, &v[i]);
|
|
|
+ from_ntt(&mut ct, &v_i);
|
|
|
automorph(&mut ct_auto, &ct, t);
|
|
|
gadget_invert_rdim(gi_ct, &ct_auto, 1);
|
|
|
to_ntt_no_reduce(gi_ct_ntt, &gi_ct);
|
|
@@ -93,15 +94,31 @@ pub fn coefficient_expansion(
|
|
|
for j in 0..2 {
|
|
|
for n in 0..params.crt_count {
|
|
|
for z in 0..poly_len {
|
|
|
- let sum = v[i].data[idx]
|
|
|
+ let sum = (*v_i).data[idx]
|
|
|
+ w_times_ginv_ct.data[idx]
|
|
|
+ j * ct_auto_1_ntt.data[n * poly_len + z];
|
|
|
- v[i].data[idx] = barrett_coeff_u64(params, sum, n);
|
|
|
+ (*v_i).data[idx] = barrett_coeff_u64(params, sum, n);
|
|
|
idx += 1;
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
- }
|
|
|
+ };
|
|
|
+
|
|
|
+ let (src, dest) = v.split_at_mut(num_in);
|
|
|
+ src.par_iter_mut()
|
|
|
+ .zip(dest.par_iter_mut())
|
|
|
+ .for_each(|(s, d)| {
|
|
|
+ scalar_multiply(d, neg1, s);
|
|
|
+ });
|
|
|
+
|
|
|
+ v[0..num_in]
|
|
|
+ .par_iter_mut()
|
|
|
+ .enumerate()
|
|
|
+ .for_each(action_expand);
|
|
|
+ v[num_in..num_out]
|
|
|
+ .par_iter_mut()
|
|
|
+ .enumerate()
|
|
|
+ .for_each(action_expand);
|
|
|
}
|
|
|
}
|
|
|
|
|
@@ -116,13 +133,12 @@ pub fn regev_to_gsw<'a>(
|
|
|
assert!(v.rows == 2);
|
|
|
assert!(v.cols == 2 * params.t_conv);
|
|
|
|
|
|
- let mut ginv_c_inp = PolyMatrixRaw::zero(params, 2 * params.t_conv, 1);
|
|
|
- let mut ginv_c_inp_ntt = PolyMatrixNTT::zero(params, 2 * params.t_conv, 1);
|
|
|
- let mut tmp_ct_raw = PolyMatrixRaw::zero(params, 2, 1);
|
|
|
- let mut tmp_ct = PolyMatrixNTT::zero(params, 2, 1);
|
|
|
+ v_gsw.par_iter_mut().enumerate().for_each(|(i, ct)| {
|
|
|
+ let mut ginv_c_inp = PolyMatrixRaw::zero(params, 2 * params.t_conv, 1);
|
|
|
+ let mut ginv_c_inp_ntt = PolyMatrixNTT::zero(params, 2 * params.t_conv, 1);
|
|
|
+ let mut tmp_ct_raw = PolyMatrixRaw::zero(params, 2, 1);
|
|
|
+ let mut tmp_ct = PolyMatrixNTT::zero(params, 2, 1);
|
|
|
|
|
|
- for i in 0..params.db_dim_2 {
|
|
|
- let ct = &mut v_gsw[i];
|
|
|
for j in 0..params.t_gsw {
|
|
|
let idx_ct = i * params.t_gsw + j;
|
|
|
let idx_inp = idx_factor * (idx_ct) + idx_offset;
|
|
@@ -133,7 +149,7 @@ pub fn regev_to_gsw<'a>(
|
|
|
multiply(&mut tmp_ct, v, &ginv_c_inp_ntt);
|
|
|
ct.copy_into(&tmp_ct, 0, 2 * j);
|
|
|
}
|
|
|
- }
|
|
|
+ });
|
|
|
}
|
|
|
|
|
|
pub const MAX_SUMMED: usize = 1 << 6;
|
|
@@ -606,18 +622,21 @@ pub fn encode(params: &Params, v_packed_ct: &Vec<PolyMatrixRaw>) -> Vec<u8> {
|
|
|
|
|
|
pub fn get_v_folding_neg<'a>(
|
|
|
params: &'a Params,
|
|
|
- v_folding: &Vec<PolyMatrixNTT>,
|
|
|
+ v_folding: &Vec<PolyMatrixNTT<'a>>,
|
|
|
) -> Vec<PolyMatrixNTT<'a>> {
|
|
|
- let gadget_ntt = build_gadget(¶ms, 2, 2 * params.t_gsw).ntt(); // TODO: make this better
|
|
|
-
|
|
|
- let mut v_folding_neg = Vec::new();
|
|
|
- let mut ct_gsw_inv = PolyMatrixRaw::zero(¶ms, 2, 2 * params.t_gsw);
|
|
|
- for i in 0..params.db_dim_2 {
|
|
|
- invert(&mut ct_gsw_inv, &v_folding[i].raw());
|
|
|
- let mut ct_gsw_neg = PolyMatrixNTT::zero(¶ms, 2, 2 * params.t_gsw);
|
|
|
- add(&mut ct_gsw_neg, &gadget_ntt, &ct_gsw_inv.ntt());
|
|
|
- v_folding_neg.push(ct_gsw_neg);
|
|
|
- }
|
|
|
+ let gadget_ntt = build_gadget(params, 2, 2 * params.t_gsw).ntt(); // TODO: make this better
|
|
|
+
|
|
|
+ let v_folding_neg = (0..params.db_dim_2)
|
|
|
+ .into_par_iter()
|
|
|
+ .map(|i| {
|
|
|
+ let mut ct_gsw_inv = PolyMatrixRaw::zero(params, 2, 2 * params.t_gsw);
|
|
|
+ invert(&mut ct_gsw_inv, &v_folding[i].raw());
|
|
|
+ let mut ct_gsw_neg = PolyMatrixNTT::zero(params, 2, 2 * params.t_gsw);
|
|
|
+ add(&mut ct_gsw_neg, &gadget_ntt, &ct_gsw_inv.ntt());
|
|
|
+ ct_gsw_neg
|
|
|
+ })
|
|
|
+ .collect();
|
|
|
+
|
|
|
v_folding_neg
|
|
|
}
|
|
|
|
|
@@ -1065,6 +1084,7 @@ mod test {
|
|
|
}
|
|
|
|
|
|
#[test]
|
|
|
+ #[ignore]
|
|
|
fn larger_full_protocol_is_correct() {
|
|
|
let cfg_expand = r#"
|
|
|
{
|