|
@@ -5,6 +5,9 @@ use std::arch::x86_64::*;
|
|
|
use crate::aligned_memory::*;
|
|
|
|
|
|
use crate::arith::*;
|
|
|
+use crate::aligned_memory::*;
|
|
|
+use crate::client::PublicParameters;
|
|
|
+use crate::client::Query;
|
|
|
use crate::gadget::*;
|
|
|
use crate::params::*;
|
|
|
use crate::poly::*;
|
|
@@ -13,7 +16,7 @@ use crate::util::*;
|
|
|
pub fn coefficient_expansion(
|
|
|
v: &mut Vec<PolyMatrixNTT>,
|
|
|
g: usize,
|
|
|
- stopround: usize,
|
|
|
+ stop_round: usize,
|
|
|
params: &Params,
|
|
|
v_w_left: &Vec<PolyMatrixNTT>,
|
|
|
v_w_right: &Vec<PolyMatrixNTT>,
|
|
@@ -41,8 +44,8 @@ pub fn coefficient_expansion(
|
|
|
let neg1 = &v_neg1[r];
|
|
|
|
|
|
for i in 0..num_out {
|
|
|
- if stopround > 0 && i % 2 == 1 && r > stopround
|
|
|
- || (r == stopround && i / 2 >= max_bits_to_gen_right)
|
|
|
+ if stop_round > 0 && i % 2 == 1 && r > stop_round
|
|
|
+ || (r == stop_round && i / 2 >= max_bits_to_gen_right)
|
|
|
{
|
|
|
continue;
|
|
|
}
|
|
@@ -312,6 +315,218 @@ pub fn fold_ciphertexts(
|
|
|
}
|
|
|
}
|
|
|
|
|
|
+pub fn pack<'a>(
|
|
|
+ params: &'a Params,
|
|
|
+ v_ct: &Vec<PolyMatrixRaw>,
|
|
|
+ v_w: &Vec<PolyMatrixNTT>
|
|
|
+) -> PolyMatrixNTT<'a> {
|
|
|
+ assert!(v_ct.len() >= params.n * params.n);
|
|
|
+ assert!(v_w.len() == params.n);
|
|
|
+ assert!(v_ct[0].rows == 2);
|
|
|
+ assert!(v_ct[0].cols == 1);
|
|
|
+ assert!(v_w[0].rows == (params.n + 1));
|
|
|
+ assert!(v_w[0].cols == params.t_conv);
|
|
|
+
|
|
|
+ let mut result = PolyMatrixNTT::zero(params, params.n + 1, params.n);
|
|
|
+
|
|
|
+ let mut ginv = PolyMatrixRaw::zero(params, params.t_conv, 1);
|
|
|
+ let mut ginv_nttd = PolyMatrixNTT::zero(params, params.t_conv, 1);
|
|
|
+ let mut prod = PolyMatrixNTT::zero(params, params.n + 1, 1);
|
|
|
+ let mut ct_1 = PolyMatrixRaw::zero(params, 1, 1);
|
|
|
+ let mut ct_2 = PolyMatrixRaw::zero(params, 1, 1);
|
|
|
+ let mut ct_2_ntt = PolyMatrixNTT::zero(params, 1, 1);
|
|
|
+
|
|
|
+ for c in 0..params.n {
|
|
|
+ let mut v_int = PolyMatrixNTT::zero(¶ms, params.n + 1, 1);
|
|
|
+ for r in 0..params.n {
|
|
|
+ let w = &v_w[r];
|
|
|
+ let ct = &v_ct[r * params.n + c];
|
|
|
+ ct_1.copy_into(ct, 0, 0);
|
|
|
+ ct_2.copy_into(ct, 1, 0);
|
|
|
+ to_ntt(&mut ct_2_ntt, &ct_2);
|
|
|
+ gadget_invert(&mut ginv, &ct_1);
|
|
|
+ to_ntt(&mut ginv_nttd, &ginv);
|
|
|
+ multiply(&mut prod, &w, &ginv_nttd);
|
|
|
+ add_into_at(&mut v_int, &ct_2_ntt, 1 + r, 0);
|
|
|
+ add_into(&mut v_int, &prod);
|
|
|
+ }
|
|
|
+ result.copy_into(&v_int, 0, c);
|
|
|
+ }
|
|
|
+
|
|
|
+ result
|
|
|
+}
|
|
|
+
|
|
|
+pub fn encode(
|
|
|
+ params: &Params,
|
|
|
+ v_packed_ct: &Vec<PolyMatrixRaw>
|
|
|
+) -> Vec<u8> {
|
|
|
+ let q1 = 4 * params.pt_modulus;
|
|
|
+ let q1_bits = log2_ceil(q1) as usize;
|
|
|
+ let q2 = Q2_VALUES[params.q2_bits as usize];
|
|
|
+ let q2_bits = params.q2_bits as usize;
|
|
|
+
|
|
|
+ let num_bits = params.instances *
|
|
|
+ (
|
|
|
+ (q2_bits * params.n * params.poly_len) +
|
|
|
+ (q1_bits * params.n * params.n * params.poly_len)
|
|
|
+ );
|
|
|
+ let round_to = 64;
|
|
|
+ let num_bytes_rounded_up = ((num_bits + round_to - 1) / round_to) * round_to / 8;
|
|
|
+
|
|
|
+ let mut result = vec![0u8; num_bytes_rounded_up];
|
|
|
+ let mut bit_offs = 0;
|
|
|
+ for instance in 0..params.instances {
|
|
|
+ let packed_ct = &v_packed_ct[instance];
|
|
|
+
|
|
|
+ let mut first_row = packed_ct.submatrix(0, 0, 1, packed_ct.cols);
|
|
|
+ let mut rest_rows = packed_ct.submatrix(1, 0, packed_ct.rows - 1, packed_ct.cols);
|
|
|
+ first_row.apply_func(|x| { rescale(x, params.modulus, q2) });
|
|
|
+ rest_rows.apply_func(|x| { rescale(x, params.modulus, q1) });
|
|
|
+
|
|
|
+ let data = result.as_mut_slice();
|
|
|
+ for i in 0..params.n * params.poly_len {
|
|
|
+ write_arbitrary_bits(data, first_row.data[i], bit_offs, q2_bits);
|
|
|
+ bit_offs += q2_bits;
|
|
|
+ }
|
|
|
+ for i in 0..params.n * params.n * params.poly_len {
|
|
|
+ write_arbitrary_bits(data, rest_rows.data[i], bit_offs, q1_bits);
|
|
|
+ bit_offs += q1_bits;
|
|
|
+ }
|
|
|
+ }
|
|
|
+ result
|
|
|
+}
|
|
|
+
|
|
|
+pub fn expand_query<'a>(
|
|
|
+ params: &'a Params,
|
|
|
+ public_params: &PublicParameters<'a>,
|
|
|
+ query: &Query<'a>,
|
|
|
+) -> (AlignedMemory64, Vec<PolyMatrixNTT<'a>>, Vec<PolyMatrixNTT<'a>>) {
|
|
|
+ let dim0 = 1 << params.db_dim_1;
|
|
|
+ let further_dims = params.db_dim_2;
|
|
|
+
|
|
|
+ let mut v_reg_reoriented;
|
|
|
+ let mut v_folding;
|
|
|
+ let mut v_folding_neg;
|
|
|
+
|
|
|
+ let num_bits_to_gen = params.t_gsw * further_dims + dim0;
|
|
|
+ let g = log2_ceil_usize(num_bits_to_gen);
|
|
|
+ let right_expanded = params.t_gsw * further_dims;
|
|
|
+ let stop_round = log2_ceil_usize(right_expanded);
|
|
|
+
|
|
|
+ let mut v = Vec::new();
|
|
|
+ for _ in 0..(1 << g) {
|
|
|
+ v.push(PolyMatrixNTT::zero(params, 2, 1));
|
|
|
+ }
|
|
|
+ v[0].copy_into(&query.ct.as_ref().unwrap().ntt(), 0, 0);
|
|
|
+
|
|
|
+ let v_conversion = &public_params.v_conversion.as_ref().unwrap()[0];
|
|
|
+ let v_w_left = public_params.v_expansion_left.as_ref().unwrap();
|
|
|
+ let v_w_right = public_params.v_expansion_right.as_ref().unwrap();
|
|
|
+ let v_neg1 = params.get_v_neg1();
|
|
|
+
|
|
|
+ coefficient_expansion(
|
|
|
+ &mut v,
|
|
|
+ g,
|
|
|
+ stop_round,
|
|
|
+ params,
|
|
|
+ &v_w_left,
|
|
|
+ &v_w_right,
|
|
|
+ &v_neg1,
|
|
|
+ params.t_gsw * params.db_dim_2,
|
|
|
+ );
|
|
|
+
|
|
|
+ let mut v_reg_inp = Vec::with_capacity(dim0);
|
|
|
+ for i in 0..dim0 {
|
|
|
+ v_reg_inp.push(v[2 * i].clone());
|
|
|
+ }
|
|
|
+ let mut v_gsw_inp = Vec::with_capacity(right_expanded);
|
|
|
+ for i in 0..right_expanded {
|
|
|
+ v_gsw_inp.push(v[2 * i + 1].clone());
|
|
|
+ }
|
|
|
+
|
|
|
+ let v_reg_sz = dim0 * 2 * params.poly_len;
|
|
|
+ v_reg_reoriented = AlignedMemory64::new(v_reg_sz);
|
|
|
+ reorient_reg_ciphertexts(params, v_reg_reoriented.as_mut_slice(), &v_reg_inp);
|
|
|
+
|
|
|
+ v_folding = Vec::new();
|
|
|
+ for _ in 0..params.db_dim_2 {
|
|
|
+ v_folding.push(PolyMatrixNTT::zero(params, 2, 2 * params.t_gsw));
|
|
|
+ }
|
|
|
+
|
|
|
+ regev_to_gsw(&mut v_folding, &v_gsw_inp, &v_conversion, params, 1, 0);
|
|
|
+
|
|
|
+ let gadget_ntt = build_gadget(¶ms, 2, 2 * params.t_gsw).ntt();
|
|
|
+ v_folding_neg = Vec::new();
|
|
|
+ let mut ct_gsw_inv = PolyMatrixRaw::zero(¶ms, 2, 2 * params.t_gsw);
|
|
|
+ for i in 0..params.db_dim_2 {
|
|
|
+ invert(&mut ct_gsw_inv, &v_folding[i].raw());
|
|
|
+ let mut ct_gsw_neg = PolyMatrixNTT::zero(¶ms, 2, 2 * params.t_gsw);
|
|
|
+ add(&mut ct_gsw_neg, &gadget_ntt, &ct_gsw_inv.ntt());
|
|
|
+ v_folding_neg.push(ct_gsw_neg);
|
|
|
+ }
|
|
|
+
|
|
|
+ (v_reg_reoriented, v_folding, v_folding_neg)
|
|
|
+}
|
|
|
+
|
|
|
+#[cfg(target_feature = "avx2")]
|
|
|
+pub fn process_query(
|
|
|
+ params: &Params,
|
|
|
+ public_params: &PublicParameters,
|
|
|
+ query: &Query,
|
|
|
+ db: &[u64],
|
|
|
+) -> Vec<u8> {
|
|
|
+ let dim0 = 1 << params.db_dim_1;
|
|
|
+ let num_per = 1 << params.db_dim_2;
|
|
|
+ let further_dims = params.db_dim_2;
|
|
|
+ let db_slice_sz = dim0 * num_per * params.poly_len;
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+ let v_packing = public_params.v_packing.as_ref();
|
|
|
+
|
|
|
+ if params.expand_queries {
|
|
|
+
|
|
|
+ }
|
|
|
+
|
|
|
+ let mut intermediate = Vec::with_capacity(num_per);
|
|
|
+ let mut intermediate_raw = Vec::with_capacity(num_per);
|
|
|
+ for _ in 0..dim0 {
|
|
|
+ intermediate.push(PolyMatrixNTT::zero(params, 2, 1));
|
|
|
+ intermediate_raw.push(PolyMatrixRaw::zero(params, 2, 1));
|
|
|
+ }
|
|
|
+
|
|
|
+ let mut v_ct = Vec::new();
|
|
|
+ for trial in 0..(params.n * params.n) {
|
|
|
+ let cur_db = &db[(db_slice_sz * trial)..(db_slice_sz * trial + db_slice_sz)];
|
|
|
+
|
|
|
+ multiply_reg_by_database(&mut intermediate, db, v_reg_reoriented.as_slice(), params, dim0, num_per);
|
|
|
+
|
|
|
+ for i in 0..intermediate.len() {
|
|
|
+ from_ntt(&mut intermediate_raw[i], &intermediate[i]);
|
|
|
+ }
|
|
|
+
|
|
|
+ fold_ciphertexts(
|
|
|
+ params,
|
|
|
+ &mut intermediate_raw,
|
|
|
+ &v_folding,
|
|
|
+ &v_folding_neg
|
|
|
+ );
|
|
|
+
|
|
|
+ v_ct.push(intermediate_raw[0]);
|
|
|
+ }
|
|
|
+
|
|
|
+ let packed_ct = pack(
|
|
|
+ params,
|
|
|
+ &v_ct,
|
|
|
+ &v_packing,
|
|
|
+ );
|
|
|
+
|
|
|
+ let mut v_packed_ct = Vec::new();
|
|
|
+ v_packed_ct.push(packed_ct.raw());
|
|
|
+
|
|
|
+ encode(params, &v_packed_ct)
|
|
|
+}
|
|
|
+
|
|
|
#[cfg(test)]
|
|
|
mod test {
|
|
|
use super::*;
|
|
@@ -351,12 +566,12 @@ mod test {
|
|
|
client: &mut Client<'a, StdRng>,
|
|
|
) -> u64 {
|
|
|
let dec = client.decrypt_matrix_reg(ct).raw();
|
|
|
- let idx = (params.t_gsw - 1) * params.poly_len + params.poly_len; // this offset should encode a large value
|
|
|
+ let idx = 2 * (params.t_gsw - 1) * params.poly_len + params.poly_len; // this offset should encode a large value
|
|
|
let mut val = dec.data[idx] as i64;
|
|
|
if val >= (params.modulus / 2) as i64 {
|
|
|
val -= params.modulus as i64;
|
|
|
}
|
|
|
- if val < 100 {
|
|
|
+ if i64::abs(val) < (1i64 << 10) {
|
|
|
0
|
|
|
} else {
|
|
|
1
|
|
@@ -559,4 +774,33 @@ mod test {
|
|
|
// decrypt
|
|
|
assert_eq!(dec_reg(¶ms, &v_reg_raw[0].ntt(), &mut client, scale_k), 1);
|
|
|
}
|
|
|
+
|
|
|
+ #[test]
|
|
|
+ fn full_protocol_is_correct() {
|
|
|
+ let params = get_params();
|
|
|
+ let mut seeded_rng = get_seeded_rng();
|
|
|
+
|
|
|
+ let dim0 = 1 << params.db_dim_1;
|
|
|
+ let num_per = 1 << params.db_dim_2;
|
|
|
+ let scale_k = params.modulus / params.pt_modulus;
|
|
|
+
|
|
|
+ let target_idx = seeded_rng.gen::<usize>() % (dim0 * num_per);
|
|
|
+ let target_idx_dim0 = target_idx / num_per;
|
|
|
+ let target_idx_num_per = target_idx % num_per;
|
|
|
+
|
|
|
+ let mut client = Client::init(¶ms, &mut seeded_rng);
|
|
|
+ let public_parameters = client.generate_keys();
|
|
|
+ let query = client.generate_query(target_idx);
|
|
|
+
|
|
|
+ let (corr_item, db) = generate_random_db_and_get_item(¶ms, target_idx);
|
|
|
+
|
|
|
+ let mut v_reg = Vec::new();
|
|
|
+ for i in 0..dim0 {
|
|
|
+ let val = if i == target_idx_dim0 { scale_k } else { 0 };
|
|
|
+ let sigma = PolyMatrixRaw::single_value(¶ms, val).ntt();
|
|
|
+ v_reg.push(client.encrypt_matrix_reg(&sigma));
|
|
|
+ }
|
|
|
+
|
|
|
+
|
|
|
+ }
|
|
|
}
|