|
@@ -1,9 +1,6 @@
|
|
|
#[cfg(target_feature = "avx2")]
|
|
|
use std::arch::x86_64::*;
|
|
|
|
|
|
-#[cfg(target_feature = "avx2")]
|
|
|
-use crate::aligned_memory::*;
|
|
|
-
|
|
|
use crate::arith::*;
|
|
|
use crate::aligned_memory::*;
|
|
|
use crate::client::PublicParameters;
|
|
@@ -315,40 +312,46 @@ pub fn generate_random_db_and_get_item<'a>(
|
|
|
) -> (PolyMatrixRaw<'a>, AlignedMemory64) {
|
|
|
let mut rng = get_seeded_rng();
|
|
|
|
|
|
+ let instances = params.instances;
|
|
|
let trials = params.n * params.n;
|
|
|
let dim0 = 1 << params.db_dim_1;
|
|
|
let num_per = 1 << params.db_dim_2;
|
|
|
let num_items = dim0 * num_per;
|
|
|
- let db_size_words = trials * num_items * params.poly_len;
|
|
|
+ let db_size_words = instances * trials * num_items * params.poly_len;
|
|
|
let mut v = AlignedMemory64::new(db_size_words);
|
|
|
|
|
|
+ let mut tmp_item_ntt = PolyMatrixNTT::zero(params, 1, 1);
|
|
|
let mut item = PolyMatrixRaw::zero(params, params.n, params.n);
|
|
|
|
|
|
- for trial in 0..trials {
|
|
|
- for i in 0..num_items {
|
|
|
- let ii = i % num_per;
|
|
|
- let j = i / num_per;
|
|
|
-
|
|
|
- let mut db_item = PolyMatrixRaw::random_rng(params, 1, 1, &mut rng);
|
|
|
- db_item.reduce_mod(params.pt_modulus);
|
|
|
-
|
|
|
- if i == item_idx {
|
|
|
- item.copy_into(&db_item, trial / params.n, trial % params.n);
|
|
|
- }
|
|
|
+ for instance in 0..instances {
|
|
|
+ println!("Instance {:?}", instance);
|
|
|
+ for trial in 0..trials {
|
|
|
+ println!("Trial {:?}", trial);
|
|
|
+ for i in 0..num_items {
|
|
|
+ let ii = i % num_per;
|
|
|
+ let j = i / num_per;
|
|
|
+
|
|
|
+ let mut db_item = PolyMatrixRaw::random_rng(params, 1, 1, &mut rng);
|
|
|
+ db_item.reduce_mod(params.pt_modulus);
|
|
|
+
|
|
|
+ if i == item_idx && instance == 0 {
|
|
|
+ item.copy_into(&db_item, trial / params.n, trial % params.n);
|
|
|
+ }
|
|
|
|
|
|
- for z in 0..params.poly_len {
|
|
|
- db_item.data[z] = recenter_mod(db_item.data[z], params.pt_modulus, params.modulus);
|
|
|
- }
|
|
|
+ for z in 0..params.poly_len {
|
|
|
+ db_item.data[z] = recenter_mod(db_item.data[z], params.pt_modulus, params.modulus);
|
|
|
+ }
|
|
|
|
|
|
- let db_item_ntt = db_item.ntt();
|
|
|
- for z in 0..params.poly_len {
|
|
|
- let idx_dst = calc_index(
|
|
|
- &[trial, z, ii, j],
|
|
|
- &[trials, params.poly_len, num_per, dim0],
|
|
|
- );
|
|
|
+ let db_item_ntt = db_item.ntt();
|
|
|
+ for z in 0..params.poly_len {
|
|
|
+ let idx_dst = calc_index(
|
|
|
+ &[instance, trial, z, ii, j],
|
|
|
+ &[instances, trials, params.poly_len, num_per, dim0],
|
|
|
+ );
|
|
|
|
|
|
- v[idx_dst] = db_item_ntt.data[z]
|
|
|
- | (db_item_ntt.data[params.poly_len + z] << PACKED_OFFSET_2);
|
|
|
+ v[idx_dst] = db_item_ntt.data[z]
|
|
|
+ | (db_item_ntt.data[params.poly_len + z] << PACKED_OFFSET_2);
|
|
|
+ }
|
|
|
}
|
|
|
}
|
|
|
}
|
|
@@ -577,35 +580,40 @@ pub fn process_query(
|
|
|
intermediate_raw.push(PolyMatrixRaw::zero(params, 2, 1));
|
|
|
}
|
|
|
|
|
|
- let mut v_ct = Vec::new();
|
|
|
- for trial in 0..(params.n * params.n) {
|
|
|
- let cur_db = &db[(db_slice_sz * trial)..(db_slice_sz * trial + db_slice_sz)];
|
|
|
+ let mut v_packed_ct = Vec::new();
|
|
|
+
|
|
|
+ for instance in 0..params.instances {
|
|
|
+ let mut v_ct = Vec::new();
|
|
|
|
|
|
- multiply_reg_by_database(&mut intermediate, cur_db, v_reg_reoriented.as_slice(), params, dim0, num_per);
|
|
|
+ for trial in 0..(params.n * params.n) {
|
|
|
+ let idx = (instance * (params.n * params.n) + trial) * db_slice_sz;
|
|
|
+ let cur_db = &db[idx..(idx + db_slice_sz)];
|
|
|
|
|
|
- for i in 0..intermediate.len() {
|
|
|
- from_ntt(&mut intermediate_raw[i], &intermediate[i]);
|
|
|
+ multiply_reg_by_database(&mut intermediate, cur_db, v_reg_reoriented.as_slice(), params, dim0, num_per);
|
|
|
+
|
|
|
+ for i in 0..intermediate.len() {
|
|
|
+ from_ntt(&mut intermediate_raw[i], &intermediate[i]);
|
|
|
+ }
|
|
|
+
|
|
|
+ fold_ciphertexts(
|
|
|
+ params,
|
|
|
+ &mut intermediate_raw,
|
|
|
+ &v_folding,
|
|
|
+ &v_folding_neg
|
|
|
+ );
|
|
|
+
|
|
|
+ v_ct.push(intermediate_raw[0].clone());
|
|
|
}
|
|
|
|
|
|
- fold_ciphertexts(
|
|
|
+ let packed_ct = pack(
|
|
|
params,
|
|
|
- &mut intermediate_raw,
|
|
|
- &v_folding,
|
|
|
- &v_folding_neg
|
|
|
+ &v_ct,
|
|
|
+ &v_packing,
|
|
|
);
|
|
|
|
|
|
- v_ct.push(intermediate_raw[0].clone());
|
|
|
+ v_packed_ct.push(packed_ct.raw());
|
|
|
}
|
|
|
|
|
|
- let packed_ct = pack(
|
|
|
- params,
|
|
|
- &v_ct,
|
|
|
- &v_packing,
|
|
|
- );
|
|
|
-
|
|
|
- let mut v_packed_ct = Vec::new();
|
|
|
- v_packed_ct.push(packed_ct.raw());
|
|
|
-
|
|
|
encode(params, &v_packed_ct)
|
|
|
}
|
|
|
|
|
@@ -613,7 +621,7 @@ pub fn process_query(
|
|
|
mod test {
|
|
|
use super::*;
|
|
|
use crate::{client::*};
|
|
|
- use rand::{prelude::StdRng, Rng};
|
|
|
+ use rand::{prelude::SmallRng, Rng};
|
|
|
|
|
|
fn get_params() -> Params {
|
|
|
let mut params = get_expansion_testing_params();
|
|
@@ -626,7 +634,7 @@ mod test {
|
|
|
fn dec_reg<'a>(
|
|
|
params: &'a Params,
|
|
|
ct: &PolyMatrixNTT<'a>,
|
|
|
- client: &mut Client<'a, StdRng>,
|
|
|
+ client: &mut Client<'a, SmallRng>,
|
|
|
scale_k: u64,
|
|
|
) -> u64 {
|
|
|
let dec = client.decrypt_matrix_reg(ct).raw();
|
|
@@ -645,7 +653,7 @@ mod test {
|
|
|
fn dec_gsw<'a>(
|
|
|
params: &'a Params,
|
|
|
ct: &PolyMatrixNTT<'a>,
|
|
|
- client: &mut Client<'a, StdRng>,
|
|
|
+ client: &mut Client<'a, SmallRng>,
|
|
|
) -> u64 {
|
|
|
let dec = client.decrypt_matrix_reg(ct).raw();
|
|
|
let idx = 2 * (params.t_gsw - 1) * params.poly_len + params.poly_len; // this offset should encode a large value
|
|
@@ -857,21 +865,19 @@ mod test {
|
|
|
assert_eq!(dec_reg(¶ms, &v_reg_raw[0].ntt(), &mut client, scale_k), 1);
|
|
|
}
|
|
|
|
|
|
- #[test]
|
|
|
- fn full_protocol_is_correct() {
|
|
|
- let params = get_params();
|
|
|
+ fn full_protocol_is_correct_for_params(params: &Params) {
|
|
|
let mut seeded_rng = get_seeded_rng();
|
|
|
|
|
|
let target_idx = seeded_rng.gen::<usize>() % (params.db_dim_1 + params.db_dim_2);
|
|
|
|
|
|
- let mut client = Client::init(¶ms, &mut seeded_rng);
|
|
|
+ let mut client = Client::init(params, &mut seeded_rng);
|
|
|
|
|
|
let public_params = client.generate_keys();
|
|
|
let query = client.generate_query(target_idx);
|
|
|
|
|
|
- let (corr_item, db) = generate_random_db_and_get_item(¶ms, target_idx);
|
|
|
+ let (corr_item, db) = generate_random_db_and_get_item(params, target_idx);
|
|
|
|
|
|
- let response = process_query(¶ms, &public_params, &query, db.as_slice());
|
|
|
+ let response = process_query(params, &public_params, &query, db.as_slice());
|
|
|
|
|
|
let result = client.decode_response(response.as_slice());
|
|
|
|
|
@@ -882,4 +888,19 @@ mod test {
|
|
|
assert_eq!(result[z], corr_result[z]);
|
|
|
}
|
|
|
}
|
|
|
+
|
|
|
+ #[test]
|
|
|
+ fn full_protocol_is_correct() {
|
|
|
+ full_protocol_is_correct_for_params(&get_params());
|
|
|
+ }
|
|
|
+
|
|
|
+ // #[test]
|
|
|
+ // fn full_protocol_is_correct_20_256() {
|
|
|
+ // full_protocol_is_correct_for_params(¶ms_from_json(&CFG_20_256.replace("'", "\"")));
|
|
|
+ // }
|
|
|
+
|
|
|
+ // #[test]
|
|
|
+ // fn full_protocol_is_correct_16_100000() {
|
|
|
+ // full_protocol_is_correct_for_params(¶ms_from_json(&CFG_16_100000.replace("'", "\"")));
|
|
|
+ // }
|
|
|
}
|