|
@@ -5,6 +5,7 @@ use std::io::BufReader;
|
|
|
use std::io::Read;
|
|
|
use std::io::Seek;
|
|
|
use std::io::SeekFrom;
|
|
|
+use std::time::Instant;
|
|
|
|
|
|
use crate::aligned_memory::*;
|
|
|
use crate::arith::*;
|
|
@@ -15,6 +16,8 @@ use crate::params::*;
|
|
|
use crate::poly::*;
|
|
|
use crate::util::*;
|
|
|
|
|
|
+use rayon::prelude::*;
|
|
|
+
|
|
|
pub fn coefficient_expansion(
|
|
|
v: &mut Vec<PolyMatrixNTT>,
|
|
|
g: usize,
|
|
@@ -46,8 +49,8 @@ pub fn coefficient_expansion(
|
|
|
let neg1 = &v_neg1[r];
|
|
|
|
|
|
for i in 0..num_out {
|
|
|
- if stop_round > 0 && i % 2 == 1 && r > stop_round
|
|
|
- || (r == stop_round && i / 2 >= max_bits_to_gen_right)
|
|
|
+ if (stop_round > 0 && r > stop_round && (i % 2) == 1)
|
|
|
+ || (stop_round > 0 && r == stop_round && (i % 2) == 1 && (i / 2) >= max_bits_to_gen_right)
|
|
|
{
|
|
|
continue;
|
|
|
}
|
|
@@ -454,6 +457,22 @@ pub fn load_db_from_file(params: &Params, file: &mut File) -> AlignedMemory64 {
|
|
|
v
|
|
|
}
|
|
|
|
|
|
+pub fn load_file_unsafe(data: &mut[u64], file: &mut File) {
|
|
|
+ let data_as_u8_mut = unsafe {
|
|
|
+ data.align_to_mut::<u8>().1
|
|
|
+ };
|
|
|
+ file.read_exact(data_as_u8_mut).unwrap();
|
|
|
+}
|
|
|
+
|
|
|
+pub fn load_file(data: &mut[u64], file: &mut File) {
|
|
|
+ let mut reader = BufReader::with_capacity(1 << 24, file);
|
|
|
+ let mut buf = [0u8; 8];
|
|
|
+ for i in 0..data.len() {
|
|
|
+ reader.read(&mut buf).unwrap();
|
|
|
+ data[i] = u64::from_ne_bytes(buf);
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
pub fn load_preprocessed_db_from_file(params: &Params, file: &mut File) -> AlignedMemory64 {
|
|
|
let instances = params.instances;
|
|
|
let trials = params.n * params.n;
|
|
@@ -463,16 +482,10 @@ pub fn load_preprocessed_db_from_file(params: &Params, file: &mut File) -> Align
|
|
|
let db_size_words = instances * trials * num_items * params.poly_len;
|
|
|
let mut v = AlignedMemory64::new(db_size_words);
|
|
|
let v_mut_slice = v.as_mut_slice();
|
|
|
-
|
|
|
- let mut reader = BufReader::with_capacity(1 << 18, file);
|
|
|
- let mut buf = [0u8; 8];
|
|
|
- for i in 0..db_size_words {
|
|
|
- if i % 1000000000 == 0 {
|
|
|
- println!("{} GB loaded", i / 1000000000);
|
|
|
- }
|
|
|
- reader.read(&mut buf).unwrap();
|
|
|
- v_mut_slice[i] = u64::from_ne_bytes(buf);
|
|
|
- }
|
|
|
+
|
|
|
+ let now = Instant::now();
|
|
|
+ load_file(v_mut_slice, file);
|
|
|
+ println!("Done loading ({} ms).", now.elapsed().as_millis());
|
|
|
|
|
|
v
|
|
|
}
|
|
@@ -694,23 +707,20 @@ pub fn process_query(
|
|
|
.v_ct
|
|
|
.as_ref()
|
|
|
.unwrap()
|
|
|
- .clone()
|
|
|
.iter()
|
|
|
.map(|x| x.ntt())
|
|
|
.collect();
|
|
|
}
|
|
|
let v_folding_neg = get_v_folding_neg(params, &v_folding);
|
|
|
|
|
|
- let mut intermediate = Vec::with_capacity(num_per);
|
|
|
- let mut intermediate_raw = Vec::with_capacity(num_per);
|
|
|
- for _ in 0..num_per {
|
|
|
- intermediate.push(PolyMatrixNTT::zero(params, 2, 1));
|
|
|
- intermediate_raw.push(PolyMatrixRaw::zero(params, 2, 1));
|
|
|
- }
|
|
|
-
|
|
|
- let mut v_packed_ct = Vec::new();
|
|
|
-
|
|
|
- for instance in 0..params.instances {
|
|
|
+ let v_packed_ct = (0..params.instances).into_par_iter().map(|instance| {
|
|
|
+ let mut intermediate = Vec::with_capacity(num_per);
|
|
|
+ let mut intermediate_raw = Vec::with_capacity(num_per);
|
|
|
+ for _ in 0..num_per {
|
|
|
+ intermediate.push(PolyMatrixNTT::zero(params, 2, 1));
|
|
|
+ intermediate_raw.push(PolyMatrixRaw::zero(params, 2, 1));
|
|
|
+ }
|
|
|
+
|
|
|
let mut v_ct = Vec::new();
|
|
|
|
|
|
for trial in 0..(params.n * params.n) {
|
|
@@ -737,8 +747,8 @@ pub fn process_query(
|
|
|
|
|
|
let packed_ct = pack(params, &v_ct, &v_packing);
|
|
|
|
|
|
- v_packed_ct.push(packed_ct.raw());
|
|
|
- }
|
|
|
+ packed_ct.raw()
|
|
|
+ }).collect();
|
|
|
|
|
|
encode(params, &v_packed_ct)
|
|
|
}
|
|
@@ -997,7 +1007,7 @@ mod test {
|
|
|
fn full_protocol_is_correct_for_params(params: &Params) {
|
|
|
let mut seeded_rng = get_seeded_rng();
|
|
|
|
|
|
- let target_idx = seeded_rng.gen::<usize>() % (params.db_dim_1 + params.db_dim_2);
|
|
|
+ let target_idx = 22456;//22456;//seeded_rng.gen::<usize>() % (params.db_dim_1 + params.db_dim_2);
|
|
|
|
|
|
let mut client = Client::init(params, &mut seeded_rng);
|
|
|
|
|
@@ -1011,17 +1021,19 @@ mod test {
|
|
|
let result = client.decode_response(response.as_slice());
|
|
|
|
|
|
let p_bits = log2_ceil(params.pt_modulus) as usize;
|
|
|
- let corr_result = corr_item.to_vec(p_bits, params.poly_len);
|
|
|
+ let corr_result = corr_item.to_vec(p_bits, params.modp_words_per_chunk());
|
|
|
+
|
|
|
+ assert_eq!(result.len(), corr_result.len());
|
|
|
|
|
|
for z in 0..corr_result.len() {
|
|
|
- assert_eq!(result[z], corr_result[z]);
|
|
|
+ assert_eq!(result[z], corr_result[z], "at {:?}", z);
|
|
|
}
|
|
|
}
|
|
|
|
|
|
fn full_protocol_is_correct_for_params_real_db(params: &Params) {
|
|
|
let mut seeded_rng = get_seeded_rng();
|
|
|
|
|
|
- let target_idx = seeded_rng.gen::<usize>() % (params.db_dim_1 + params.db_dim_2);
|
|
|
+ let target_idx = 22456; //seeded_rng.gen::<usize>() % (params.db_dim_1 + params.db_dim_2);
|
|
|
|
|
|
let mut client = Client::init(params, &mut seeded_rng);
|
|
|
|
|
@@ -1048,6 +1060,31 @@ mod test {
|
|
|
full_protocol_is_correct_for_params(&get_params());
|
|
|
}
|
|
|
|
|
|
+ #[test]
|
|
|
+ fn larger_full_protocol_is_correct() {
|
|
|
+ let cfg_expand = r#"
|
|
|
+ {
|
|
|
+ 'n': 2,
|
|
|
+ 'nu_1': 10,
|
|
|
+ 'nu_2': 6,
|
|
|
+ 'p': 512,
|
|
|
+ 'q2_bits': 21,
|
|
|
+ 's_e': 85.83255142749422,
|
|
|
+ 't_gsw': 10,
|
|
|
+ 't_conv': 4,
|
|
|
+ 't_exp_left': 16,
|
|
|
+ 't_exp_right': 56,
|
|
|
+ 'instances': 1,
|
|
|
+ 'db_item_size': 9000 }
|
|
|
+ "#;
|
|
|
+ let cfg = cfg_expand;
|
|
|
+ let cfg = cfg.replace("'", "\"");
|
|
|
+ let params = params_from_json(&cfg);
|
|
|
+
|
|
|
+ full_protocol_is_correct_for_params(¶ms);
|
|
|
+ full_protocol_is_correct_for_params_real_db(¶ms);
|
|
|
+ }
|
|
|
+
|
|
|
// #[test]
|
|
|
// fn full_protocol_is_correct_20_256() {
|
|
|
// full_protocol_is_correct_for_params(¶ms_from_json(&CFG_20_256.replace("'", "\"")));
|