|
@@ -240,10 +240,79 @@ pub fn multiply_reg_by_database(
|
|
|
}
|
|
|
}
|
|
|
|
|
|
+#[cfg(not(target_feature = "avx2"))]
|
|
|
+pub fn multiply_reg_by_database(
|
|
|
+ out: &mut Vec<PolyMatrixNTT>,
|
|
|
+ db: &[u64],
|
|
|
+ v_firstdim: &[u64],
|
|
|
+ params: &Params,
|
|
|
+ dim0: usize,
|
|
|
+ num_per: usize,
|
|
|
+) {
|
|
|
+ let ct_rows = 2;
|
|
|
+ let ct_cols = 1;
|
|
|
+ let pt_rows = 1;
|
|
|
+ let pt_cols = 1;
|
|
|
+
|
|
|
+ for z in 0..params.poly_len {
|
|
|
+ let idx_a_base = z * (ct_cols * dim0 * ct_rows);
|
|
|
+ let mut idx_b_base = z * (num_per * pt_cols * dim0 * pt_rows);
|
|
|
+
|
|
|
+ for i in 0..num_per {
|
|
|
+ for c in 0..pt_cols {
|
|
|
+ let mut sums_out_n0_0 = 0u128;
|
|
|
+ let mut sums_out_n0_1 = 0u128;
|
|
|
+ let mut sums_out_n1_0 = 0u128;
|
|
|
+ let mut sums_out_n1_1 = 0u128;
|
|
|
+
|
|
|
+ for jm in 0..(dim0 * pt_rows) {
|
|
|
+ let b = db[idx_b_base];
|
|
|
+ idx_b_base += 1;
|
|
|
+
|
|
|
+ let v_a0 = v_firstdim[idx_a_base + jm * ct_rows];
|
|
|
+ let v_a1 = v_firstdim[idx_a_base + jm * ct_rows + 1];
|
|
|
+
|
|
|
+ let b_lo = b as u32;
|
|
|
+ let b_hi = (b >> 32) as u32;
|
|
|
+
|
|
|
+ let v_a0_lo = v_a0 as u32;
|
|
|
+ let v_a0_hi = (v_a0 >> 32) as u32;
|
|
|
+
|
|
|
+ let v_a1_lo = v_a1 as u32;
|
|
|
+ let v_a1_hi = (v_a1 >> 32) as u32;
|
|
|
+
|
|
|
+ // do n0
|
|
|
+ sums_out_n0_0 += ((v_a0_lo as u64) * (b_lo as u64)) as u128;
|
|
|
+ sums_out_n0_1 += ((v_a1_lo as u64) * (b_lo as u64)) as u128;
|
|
|
+
|
|
|
+ // do n1
|
|
|
+ sums_out_n1_0 += ((v_a0_hi as u64) * (b_hi as u64)) as u128;
|
|
|
+ sums_out_n1_1 += ((v_a1_hi as u64) * (b_hi as u64)) as u128;
|
|
|
+ }
|
|
|
+
|
|
|
+ // output n0
|
|
|
+ let (crt_count, poly_len) = (params.crt_count, params.poly_len);
|
|
|
+ let mut n = 0;
|
|
|
+ let mut idx_c = c * (crt_count * poly_len) + n * (poly_len) + z;
|
|
|
+ out[i].data[idx_c] = (sums_out_n0_0 % (params.moduli[0] as u128)) as u64;
|
|
|
+ idx_c += pt_cols * crt_count * poly_len;
|
|
|
+ out[i].data[idx_c] = (sums_out_n0_1 % (params.moduli[0] as u128)) as u64;
|
|
|
+
|
|
|
+ // output n1
|
|
|
+ n = 1;
|
|
|
+ idx_c = c * (crt_count * poly_len) + n * (poly_len) + z;
|
|
|
+ out[i].data[idx_c] = (sums_out_n1_0 % (params.moduli[1] as u128)) as u64;
|
|
|
+ idx_c += pt_cols * crt_count * poly_len;
|
|
|
+ out[i].data[idx_c] = (sums_out_n1_1 % (params.moduli[1] as u128)) as u64;
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
pub fn generate_random_db_and_get_item<'a>(
|
|
|
params: &'a Params,
|
|
|
item_idx: usize,
|
|
|
-) -> (PolyMatrixRaw<'a>, Vec<u64>) {
|
|
|
+) -> (PolyMatrixRaw<'a>, AlignedMemory64) {
|
|
|
let mut rng = get_seeded_rng();
|
|
|
|
|
|
let trials = params.n * params.n;
|
|
@@ -251,7 +320,7 @@ pub fn generate_random_db_and_get_item<'a>(
|
|
|
let num_per = 1 << params.db_dim_2;
|
|
|
let num_items = dim0 * num_per;
|
|
|
let db_size_words = trials * num_items * params.poly_len;
|
|
|
- let mut v = vec![0u64; db_size_words];
|
|
|
+ let mut v = AlignedMemory64::new(db_size_words);
|
|
|
|
|
|
let mut item = PolyMatrixRaw::zero(params, params.n, params.n);
|
|
|
|
|
@@ -341,8 +410,8 @@ pub fn pack<'a>(
|
|
|
for r in 0..params.n {
|
|
|
let w = &v_w[r];
|
|
|
let ct = &v_ct[r * params.n + c];
|
|
|
- ct_1.copy_into(ct, 0, 0);
|
|
|
- ct_2.copy_into(ct, 1, 0);
|
|
|
+ ct_1.get_poly_mut(0, 0).copy_from_slice(ct.get_poly(0, 0));
|
|
|
+ ct_2.get_poly_mut(0, 0).copy_from_slice(ct.get_poly(1, 0));
|
|
|
to_ntt(&mut ct_2_ntt, &ct_2);
|
|
|
gadget_invert(&mut ginv, &ct_1);
|
|
|
to_ntt(&mut ginv_nttd, &ginv);
|
|
@@ -396,17 +465,33 @@ pub fn encode(
|
|
|
result
|
|
|
}
|
|
|
|
|
|
+pub fn get_v_folding_neg<'a>(
|
|
|
+ params: &'a Params,
|
|
|
+ v_folding: &Vec<PolyMatrixNTT>,
|
|
|
+) -> Vec<PolyMatrixNTT<'a>> {
|
|
|
+ let gadget_ntt = build_gadget(¶ms, 2, 2 * params.t_gsw).ntt(); // TODO: make this better
|
|
|
+
|
|
|
+ let mut v_folding_neg = Vec::new();
|
|
|
+ let mut ct_gsw_inv = PolyMatrixRaw::zero(¶ms, 2, 2 * params.t_gsw);
|
|
|
+ for i in 0..params.db_dim_2 {
|
|
|
+ invert(&mut ct_gsw_inv, &v_folding[i].raw());
|
|
|
+ let mut ct_gsw_neg = PolyMatrixNTT::zero(¶ms, 2, 2 * params.t_gsw);
|
|
|
+ add(&mut ct_gsw_neg, &gadget_ntt, &ct_gsw_inv.ntt());
|
|
|
+ v_folding_neg.push(ct_gsw_neg);
|
|
|
+ }
|
|
|
+ v_folding_neg
|
|
|
+}
|
|
|
+
|
|
|
pub fn expand_query<'a>(
|
|
|
params: &'a Params,
|
|
|
public_params: &PublicParameters<'a>,
|
|
|
query: &Query<'a>,
|
|
|
-) -> (AlignedMemory64, Vec<PolyMatrixNTT<'a>>, Vec<PolyMatrixNTT<'a>>) {
|
|
|
+) -> (AlignedMemory64, Vec<PolyMatrixNTT<'a>>) {
|
|
|
let dim0 = 1 << params.db_dim_1;
|
|
|
let further_dims = params.db_dim_2;
|
|
|
|
|
|
let mut v_reg_reoriented;
|
|
|
let mut v_folding;
|
|
|
- let mut v_folding_neg;
|
|
|
|
|
|
let num_bits_to_gen = params.t_gsw * further_dims + dim0;
|
|
|
let g = log2_ceil_usize(num_bits_to_gen);
|
|
@@ -454,21 +539,10 @@ pub fn expand_query<'a>(
|
|
|
}
|
|
|
|
|
|
regev_to_gsw(&mut v_folding, &v_gsw_inp, &v_conversion, params, 1, 0);
|
|
|
-
|
|
|
- let gadget_ntt = build_gadget(¶ms, 2, 2 * params.t_gsw).ntt();
|
|
|
- v_folding_neg = Vec::new();
|
|
|
- let mut ct_gsw_inv = PolyMatrixRaw::zero(¶ms, 2, 2 * params.t_gsw);
|
|
|
- for i in 0..params.db_dim_2 {
|
|
|
- invert(&mut ct_gsw_inv, &v_folding[i].raw());
|
|
|
- let mut ct_gsw_neg = PolyMatrixNTT::zero(¶ms, 2, 2 * params.t_gsw);
|
|
|
- add(&mut ct_gsw_neg, &gadget_ntt, &ct_gsw_inv.ntt());
|
|
|
- v_folding_neg.push(ct_gsw_neg);
|
|
|
- }
|
|
|
-
|
|
|
- (v_reg_reoriented, v_folding, v_folding_neg)
|
|
|
+
|
|
|
+ (v_reg_reoriented, v_folding)
|
|
|
}
|
|
|
|
|
|
-#[cfg(target_feature = "avx2")]
|
|
|
pub fn process_query(
|
|
|
params: &Params,
|
|
|
public_params: &PublicParameters,
|
|
@@ -477,20 +551,28 @@ pub fn process_query(
|
|
|
) -> Vec<u8> {
|
|
|
let dim0 = 1 << params.db_dim_1;
|
|
|
let num_per = 1 << params.db_dim_2;
|
|
|
- let further_dims = params.db_dim_2;
|
|
|
let db_slice_sz = dim0 * num_per * params.poly_len;
|
|
|
|
|
|
-
|
|
|
-
|
|
|
let v_packing = public_params.v_packing.as_ref();
|
|
|
|
|
|
+ let mut v_reg_reoriented;
|
|
|
+ let v_folding;
|
|
|
if params.expand_queries {
|
|
|
-
|
|
|
+ (v_reg_reoriented, v_folding) =
|
|
|
+ expand_query(params, public_params, query);
|
|
|
+ } else {
|
|
|
+ v_reg_reoriented = AlignedMemory64::new(query.v_buf.as_ref().unwrap().len());
|
|
|
+ v_reg_reoriented.as_mut_slice().copy_from_slice(query.v_buf.as_ref().unwrap());
|
|
|
+
|
|
|
+ v_folding = query.v_ct.as_ref().unwrap().clone().iter()
|
|
|
+ .map(|x| { x.ntt() })
|
|
|
+ .collect();
|
|
|
}
|
|
|
+ let v_folding_neg = get_v_folding_neg(params, &v_folding);
|
|
|
|
|
|
let mut intermediate = Vec::with_capacity(num_per);
|
|
|
let mut intermediate_raw = Vec::with_capacity(num_per);
|
|
|
- for _ in 0..dim0 {
|
|
|
+ for _ in 0..num_per {
|
|
|
intermediate.push(PolyMatrixNTT::zero(params, 2, 1));
|
|
|
intermediate_raw.push(PolyMatrixRaw::zero(params, 2, 1));
|
|
|
}
|
|
@@ -499,7 +581,7 @@ pub fn process_query(
|
|
|
for trial in 0..(params.n * params.n) {
|
|
|
let cur_db = &db[(db_slice_sz * trial)..(db_slice_sz * trial + db_slice_sz)];
|
|
|
|
|
|
- multiply_reg_by_database(&mut intermediate, db, v_reg_reoriented.as_slice(), params, dim0, num_per);
|
|
|
+ multiply_reg_by_database(&mut intermediate, cur_db, v_reg_reoriented.as_slice(), params, dim0, num_per);
|
|
|
|
|
|
for i in 0..intermediate.len() {
|
|
|
from_ntt(&mut intermediate_raw[i], &intermediate[i]);
|
|
@@ -512,7 +594,7 @@ pub fn process_query(
|
|
|
&v_folding_neg
|
|
|
);
|
|
|
|
|
|
- v_ct.push(intermediate_raw[0]);
|
|
|
+ v_ct.push(intermediate_raw[0].clone());
|
|
|
}
|
|
|
|
|
|
let packed_ct = pack(
|
|
@@ -780,27 +862,24 @@ mod test {
|
|
|
let params = get_params();
|
|
|
let mut seeded_rng = get_seeded_rng();
|
|
|
|
|
|
- let dim0 = 1 << params.db_dim_1;
|
|
|
- let num_per = 1 << params.db_dim_2;
|
|
|
- let scale_k = params.modulus / params.pt_modulus;
|
|
|
-
|
|
|
- let target_idx = seeded_rng.gen::<usize>() % (dim0 * num_per);
|
|
|
- let target_idx_dim0 = target_idx / num_per;
|
|
|
- let target_idx_num_per = target_idx % num_per;
|
|
|
+ let target_idx = seeded_rng.gen::<usize>() % (params.db_dim_1 + params.db_dim_2);
|
|
|
|
|
|
let mut client = Client::init(¶ms, &mut seeded_rng);
|
|
|
- let public_parameters = client.generate_keys();
|
|
|
+
|
|
|
+ let public_params = client.generate_keys();
|
|
|
let query = client.generate_query(target_idx);
|
|
|
|
|
|
let (corr_item, db) = generate_random_db_and_get_item(¶ms, target_idx);
|
|
|
|
|
|
- let mut v_reg = Vec::new();
|
|
|
- for i in 0..dim0 {
|
|
|
- let val = if i == target_idx_dim0 { scale_k } else { 0 };
|
|
|
- let sigma = PolyMatrixRaw::single_value(¶ms, val).ntt();
|
|
|
- v_reg.push(client.encrypt_matrix_reg(&sigma));
|
|
|
- }
|
|
|
+ let response = process_query(¶ms, &public_params, &query, db.as_slice());
|
|
|
|
|
|
-
|
|
|
+ let result = client.decode_response(response.as_slice());
|
|
|
+
|
|
|
+ let p_bits = log2_ceil(params.pt_modulus) as usize;
|
|
|
+ let corr_result = corr_item.to_vec(p_bits, params.poly_len);
|
|
|
+
|
|
|
+ for z in 0..corr_result.len() {
|
|
|
+ assert_eq!(result[z], corr_result[z]);
|
|
|
+ }
|
|
|
}
|
|
|
}
|