|
@@ -7,8 +7,8 @@ use std::io::Seek;
|
|
|
use std::io::SeekFrom;
|
|
|
use std::mem::size_of;
|
|
|
|
|
|
-use crate::arith::*;
|
|
|
use crate::aligned_memory::*;
|
|
|
+use crate::arith::*;
|
|
|
use crate::client::PublicParameters;
|
|
|
use crate::client::Query;
|
|
|
use crate::gadget::*;
|
|
@@ -206,18 +206,22 @@ pub fn multiply_reg_by_database(
|
|
|
|
|
|
for idx in 0..4 {
|
|
|
let val = sums_out_n0_u64[idx];
|
|
|
- sums_out_n0_u64_acc[idx] = barrett_coeff_u64(params, val + sums_out_n0_u64_acc[idx], 0);
|
|
|
+ sums_out_n0_u64_acc[idx] =
|
|
|
+ barrett_coeff_u64(params, val + sums_out_n0_u64_acc[idx], 0);
|
|
|
}
|
|
|
for idx in 0..4 {
|
|
|
let val = sums_out_n2_u64[idx];
|
|
|
- sums_out_n2_u64_acc[idx] = barrett_coeff_u64(params, val + sums_out_n2_u64_acc[idx], 1);
|
|
|
+ sums_out_n2_u64_acc[idx] =
|
|
|
+ barrett_coeff_u64(params, val + sums_out_n2_u64_acc[idx], 1);
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
|
|
|
for idx in 0..4 {
|
|
|
- sums_out_n0_u64_acc[idx] = barrett_coeff_u64(params, sums_out_n0_u64_acc[idx], 0);
|
|
|
- sums_out_n2_u64_acc[idx] = barrett_coeff_u64(params, sums_out_n2_u64_acc[idx], 1);
|
|
|
+ sums_out_n0_u64_acc[idx] =
|
|
|
+ barrett_coeff_u64(params, sums_out_n0_u64_acc[idx], 0);
|
|
|
+ sums_out_n2_u64_acc[idx] =
|
|
|
+ barrett_coeff_u64(params, sums_out_n2_u64_acc[idx], 1);
|
|
|
}
|
|
|
|
|
|
// output n0
|
|
@@ -271,7 +275,7 @@ pub fn multiply_reg_by_database(
|
|
|
for jm in 0..(dim0 * pt_rows) {
|
|
|
let b = db[idx_b_base];
|
|
|
idx_b_base += 1;
|
|
|
-
|
|
|
+
|
|
|
let v_a0 = v_firstdim[idx_a_base + jm * ct_rows];
|
|
|
let v_a1 = v_firstdim[idx_a_base + jm * ct_rows + 1];
|
|
|
|
|
@@ -283,7 +287,7 @@ pub fn multiply_reg_by_database(
|
|
|
|
|
|
let v_a1_lo = v_a1 as u32;
|
|
|
let v_a1_hi = (v_a1 >> 32) as u32;
|
|
|
-
|
|
|
+
|
|
|
// do n0
|
|
|
sums_out_n0_0 += ((v_a0_lo as u64) * (b_lo as u64)) as u128;
|
|
|
sums_out_n0_1 += ((v_a1_lo as u64) * (b_lo as u64)) as u128;
|
|
@@ -339,13 +343,14 @@ pub fn generate_random_db_and_get_item<'a>(
|
|
|
|
|
|
let mut db_item = PolyMatrixRaw::random_rng(params, 1, 1, &mut rng);
|
|
|
db_item.reduce_mod(params.pt_modulus);
|
|
|
-
|
|
|
+
|
|
|
if i == item_idx && instance == 0 {
|
|
|
item.copy_into(&db_item, trial / params.n, trial % params.n);
|
|
|
}
|
|
|
|
|
|
for z in 0..params.poly_len {
|
|
|
- db_item.data[z] = recenter_mod(db_item.data[z], params.pt_modulus, params.modulus);
|
|
|
+ db_item.data[z] =
|
|
|
+ recenter_mod(db_item.data[z], params.pt_modulus, params.modulus);
|
|
|
}
|
|
|
|
|
|
let db_item_ntt = db_item.ntt();
|
|
@@ -367,9 +372,9 @@ pub fn generate_random_db_and_get_item<'a>(
|
|
|
pub fn load_item_from_file<'a>(
|
|
|
params: &'a Params,
|
|
|
file: &mut File,
|
|
|
- instance: usize,
|
|
|
+ instance: usize,
|
|
|
trial: usize,
|
|
|
- item_idx: usize
|
|
|
+ item_idx: usize,
|
|
|
) -> PolyMatrixRaw<'a> {
|
|
|
let db_item_size = params.db_item_size;
|
|
|
let instances = params.instances;
|
|
@@ -395,7 +400,9 @@ pub fn load_item_from_file<'a>(
|
|
|
return out;
|
|
|
}
|
|
|
let mut data = vec![0u8; 2 * bytes_per_chunk];
|
|
|
- let bytes_read = file.read(&mut data.as_mut_slice()[0..bytes_per_chunk]).unwrap();
|
|
|
+ let bytes_read = file
|
|
|
+ .read(&mut data.as_mut_slice()[0..bytes_per_chunk])
|
|
|
+ .unwrap();
|
|
|
|
|
|
let modp_words_read = f64::ceil((bytes_read * 8) as f64 / logp as f64) as usize;
|
|
|
assert!(modp_words_read <= params.poly_len);
|
|
@@ -404,14 +411,11 @@ pub fn load_item_from_file<'a>(
|
|
|
out.data[i] = read_arbitrary_bits(&data, i * logp, logp);
|
|
|
assert!(out.data[i] <= params.pt_modulus);
|
|
|
}
|
|
|
-
|
|
|
+
|
|
|
out
|
|
|
}
|
|
|
|
|
|
-pub fn load_db_from_file(
|
|
|
- params: &Params,
|
|
|
- file: &mut File
|
|
|
-) -> AlignedMemory64 {
|
|
|
+pub fn load_db_from_file(params: &Params, file: &mut File) -> AlignedMemory64 {
|
|
|
let instances = params.instances;
|
|
|
let trials = params.n * params.n;
|
|
|
let dim0 = 1 << params.db_dim_1;
|
|
@@ -433,9 +437,10 @@ pub fn load_db_from_file(
|
|
|
|
|
|
let mut db_item = load_item_from_file(params, file, instance, trial, i);
|
|
|
// db_item.reduce_mod(params.pt_modulus);
|
|
|
-
|
|
|
+
|
|
|
for z in 0..params.poly_len {
|
|
|
- db_item.data[z] = recenter_mod(db_item.data[z], params.pt_modulus, params.modulus);
|
|
|
+ db_item.data[z] =
|
|
|
+ recenter_mod(db_item.data[z], params.pt_modulus, params.modulus);
|
|
|
}
|
|
|
|
|
|
let db_item_ntt = db_item.ntt();
|
|
@@ -454,10 +459,7 @@ pub fn load_db_from_file(
|
|
|
v
|
|
|
}
|
|
|
|
|
|
-pub fn load_preprocessed_db_from_file(
|
|
|
- params: &Params,
|
|
|
- file: &mut File
|
|
|
-) -> AlignedMemory64 {
|
|
|
+pub fn load_preprocessed_db_from_file(params: &Params, file: &mut File) -> AlignedMemory64 {
|
|
|
let instances = params.instances;
|
|
|
let trials = params.n * params.n;
|
|
|
let dim0 = 1 << params.db_dim_1;
|
|
@@ -484,7 +486,7 @@ pub fn fold_ciphertexts(
|
|
|
params: &Params,
|
|
|
v_cts: &mut Vec<PolyMatrixRaw>,
|
|
|
v_folding: &Vec<PolyMatrixNTT>,
|
|
|
- v_folding_neg: &Vec<PolyMatrixNTT>
|
|
|
+ v_folding_neg: &Vec<PolyMatrixNTT>,
|
|
|
) {
|
|
|
let further_dims = log2(v_cts.len() as u64) as usize;
|
|
|
let ell = v_folding[0].cols / 2;
|
|
@@ -499,10 +501,18 @@ pub fn fold_ciphertexts(
|
|
|
for i in 0..num_per {
|
|
|
gadget_invert(&mut ginv_c, &v_cts[i]);
|
|
|
to_ntt(&mut ginv_c_ntt, &ginv_c);
|
|
|
- multiply(&mut prod, &v_folding_neg[further_dims - 1 - cur_dim], &ginv_c_ntt);
|
|
|
+ multiply(
|
|
|
+ &mut prod,
|
|
|
+ &v_folding_neg[further_dims - 1 - cur_dim],
|
|
|
+ &ginv_c_ntt,
|
|
|
+ );
|
|
|
gadget_invert(&mut ginv_c, &v_cts[num_per + i]);
|
|
|
to_ntt(&mut ginv_c_ntt, &ginv_c);
|
|
|
- multiply(&mut sum, &v_folding[further_dims - 1 - cur_dim], &ginv_c_ntt);
|
|
|
+ multiply(
|
|
|
+ &mut sum,
|
|
|
+ &v_folding[further_dims - 1 - cur_dim],
|
|
|
+ &ginv_c_ntt,
|
|
|
+ );
|
|
|
add_into(&mut sum, &prod);
|
|
|
from_ntt(&mut v_cts[i], &sum);
|
|
|
}
|
|
@@ -512,7 +522,7 @@ pub fn fold_ciphertexts(
|
|
|
pub fn pack<'a>(
|
|
|
params: &'a Params,
|
|
|
v_ct: &Vec<PolyMatrixRaw>,
|
|
|
- v_w: &Vec<PolyMatrixNTT>
|
|
|
+ v_w: &Vec<PolyMatrixNTT>,
|
|
|
) -> PolyMatrixNTT<'a> {
|
|
|
assert!(v_ct.len() >= params.n * params.n);
|
|
|
assert!(v_w.len() == params.n);
|
|
@@ -550,20 +560,15 @@ pub fn pack<'a>(
|
|
|
result
|
|
|
}
|
|
|
|
|
|
-pub fn encode(
|
|
|
- params: &Params,
|
|
|
- v_packed_ct: &Vec<PolyMatrixRaw>
|
|
|
-) -> Vec<u8> {
|
|
|
+pub fn encode(params: &Params, v_packed_ct: &Vec<PolyMatrixRaw>) -> Vec<u8> {
|
|
|
let q1 = 4 * params.pt_modulus;
|
|
|
let q1_bits = log2_ceil(q1) as usize;
|
|
|
let q2 = Q2_VALUES[params.q2_bits as usize];
|
|
|
let q2_bits = params.q2_bits as usize;
|
|
|
|
|
|
- let num_bits = params.instances *
|
|
|
- (
|
|
|
- (q2_bits * params.n * params.poly_len) +
|
|
|
- (q1_bits * params.n * params.n * params.poly_len)
|
|
|
- );
|
|
|
+ let num_bits = params.instances
|
|
|
+ * ((q2_bits * params.n * params.poly_len)
|
|
|
+ + (q1_bits * params.n * params.n * params.poly_len));
|
|
|
let round_to = 64;
|
|
|
let num_bytes_rounded_up = ((num_bits + round_to - 1) / round_to) * round_to / 8;
|
|
|
|
|
@@ -571,11 +576,11 @@ pub fn encode(
|
|
|
let mut bit_offs = 0;
|
|
|
for instance in 0..params.instances {
|
|
|
let packed_ct = &v_packed_ct[instance];
|
|
|
-
|
|
|
+
|
|
|
let mut first_row = packed_ct.submatrix(0, 0, 1, packed_ct.cols);
|
|
|
let mut rest_rows = packed_ct.submatrix(1, 0, packed_ct.rows - 1, packed_ct.cols);
|
|
|
- first_row.apply_func(|x| { rescale(x, params.modulus, q2) });
|
|
|
- rest_rows.apply_func(|x| { rescale(x, params.modulus, q1) });
|
|
|
+ first_row.apply_func(|x| rescale(x, params.modulus, q2));
|
|
|
+ rest_rows.apply_func(|x| rescale(x, params.modulus, q1));
|
|
|
|
|
|
let data = result.as_mut_slice();
|
|
|
for i in 0..params.n * params.poly_len {
|
|
@@ -591,7 +596,7 @@ pub fn encode(
|
|
|
}
|
|
|
|
|
|
pub fn get_v_folding_neg<'a>(
|
|
|
- params: &'a Params,
|
|
|
+ params: &'a Params,
|
|
|
v_folding: &Vec<PolyMatrixNTT>,
|
|
|
) -> Vec<PolyMatrixNTT<'a>> {
|
|
|
let gadget_ntt = build_gadget(¶ms, 2, 2 * params.t_gsw).ntt(); // TODO: make this better
|
|
@@ -608,8 +613,8 @@ pub fn get_v_folding_neg<'a>(
|
|
|
}
|
|
|
|
|
|
pub fn expand_query<'a>(
|
|
|
- params: &'a Params,
|
|
|
- public_params: &PublicParameters<'a>,
|
|
|
+ params: &'a Params,
|
|
|
+ public_params: &PublicParameters<'a>,
|
|
|
query: &Query<'a>,
|
|
|
) -> (AlignedMemory64, Vec<PolyMatrixNTT<'a>>) {
|
|
|
let dim0 = 1 << params.db_dim_1;
|
|
@@ -664,13 +669,13 @@ pub fn expand_query<'a>(
|
|
|
}
|
|
|
|
|
|
regev_to_gsw(&mut v_folding, &v_gsw_inp, &v_conversion, params, 1, 0);
|
|
|
-
|
|
|
+
|
|
|
(v_reg_reoriented, v_folding)
|
|
|
}
|
|
|
|
|
|
pub fn process_query(
|
|
|
- params: &Params,
|
|
|
- public_params: &PublicParameters,
|
|
|
+ params: &Params,
|
|
|
+ public_params: &PublicParameters,
|
|
|
query: &Query,
|
|
|
db: &[u64],
|
|
|
) -> Vec<u8> {
|
|
@@ -683,14 +688,20 @@ pub fn process_query(
|
|
|
let mut v_reg_reoriented;
|
|
|
let v_folding;
|
|
|
if params.expand_queries {
|
|
|
- (v_reg_reoriented, v_folding) =
|
|
|
- expand_query(params, public_params, query);
|
|
|
+ (v_reg_reoriented, v_folding) = expand_query(params, public_params, query);
|
|
|
} else {
|
|
|
v_reg_reoriented = AlignedMemory64::new(query.v_buf.as_ref().unwrap().len());
|
|
|
- v_reg_reoriented.as_mut_slice().copy_from_slice(query.v_buf.as_ref().unwrap());
|
|
|
-
|
|
|
- v_folding = query.v_ct.as_ref().unwrap().clone().iter()
|
|
|
- .map(|x| { x.ntt() })
|
|
|
+ v_reg_reoriented
|
|
|
+ .as_mut_slice()
|
|
|
+ .copy_from_slice(query.v_buf.as_ref().unwrap());
|
|
|
+
|
|
|
+ v_folding = query
|
|
|
+ .v_ct
|
|
|
+ .as_ref()
|
|
|
+ .unwrap()
|
|
|
+ .clone()
|
|
|
+ .iter()
|
|
|
+ .map(|x| x.ntt())
|
|
|
.collect();
|
|
|
}
|
|
|
let v_folding_neg = get_v_folding_neg(params, &v_folding);
|
|
@@ -711,27 +722,25 @@ pub fn process_query(
|
|
|
let idx = (instance * (params.n * params.n) + trial) * db_slice_sz;
|
|
|
let cur_db = &db[idx..(idx + db_slice_sz)];
|
|
|
|
|
|
- multiply_reg_by_database(&mut intermediate, cur_db, v_reg_reoriented.as_slice(), params, dim0, num_per);
|
|
|
+ multiply_reg_by_database(
|
|
|
+ &mut intermediate,
|
|
|
+ cur_db,
|
|
|
+ v_reg_reoriented.as_slice(),
|
|
|
+ params,
|
|
|
+ dim0,
|
|
|
+ num_per,
|
|
|
+ );
|
|
|
|
|
|
for i in 0..intermediate.len() {
|
|
|
from_ntt(&mut intermediate_raw[i], &intermediate[i]);
|
|
|
}
|
|
|
|
|
|
- fold_ciphertexts(
|
|
|
- params,
|
|
|
- &mut intermediate_raw,
|
|
|
- &v_folding,
|
|
|
- &v_folding_neg
|
|
|
- );
|
|
|
+ fold_ciphertexts(params, &mut intermediate_raw, &v_folding, &v_folding_neg);
|
|
|
|
|
|
v_ct.push(intermediate_raw[0].clone());
|
|
|
}
|
|
|
|
|
|
- let packed_ct = pack(
|
|
|
- params,
|
|
|
- &v_ct,
|
|
|
- &v_packing,
|
|
|
- );
|
|
|
+ let packed_ct = pack(params, &v_ct, &v_packing);
|
|
|
|
|
|
v_packed_ct.push(packed_ct.raw());
|
|
|
}
|
|
@@ -742,10 +751,10 @@ pub fn process_query(
|
|
|
#[cfg(test)]
|
|
|
mod test {
|
|
|
use super::*;
|
|
|
- use crate::{client::*};
|
|
|
+ use crate::client::*;
|
|
|
use rand::{prelude::SmallRng, Rng};
|
|
|
|
|
|
- const TEST_PREPROCESSED_DB_PATH: &'static str = "/home/samir/wiki/enwiki-20220320.dbp";
|
|
|
+ const TEST_PREPROCESSED_DB_PATH: &'static str = "/home/samir/wiki/enwiki-20220320.dbp";
|
|
|
|
|
|
fn get_params() -> Params {
|
|
|
let mut params = get_expansion_testing_params();
|
|
@@ -906,7 +915,14 @@ mod test {
|
|
|
for _ in 0..dim0 {
|
|
|
out.push(PolyMatrixNTT::zero(¶ms, 2, 1));
|
|
|
}
|
|
|
- multiply_reg_by_database(&mut out, db.as_slice(), v_reg_reoriented.as_slice(), ¶ms, dim0, num_per);
|
|
|
+ multiply_reg_by_database(
|
|
|
+ &mut out,
|
|
|
+ db.as_slice(),
|
|
|
+ v_reg_reoriented.as_slice(),
|
|
|
+ ¶ms,
|
|
|
+ dim0,
|
|
|
+ num_per,
|
|
|
+ );
|
|
|
|
|
|
// decrypt
|
|
|
let dec = client.decrypt_matrix_reg(&out[target_idx_num_per]).raw();
|
|
@@ -978,15 +994,13 @@ mod test {
|
|
|
v_folding_neg.push(ct_gsw_neg);
|
|
|
}
|
|
|
|
|
|
- fold_ciphertexts(
|
|
|
- ¶ms,
|
|
|
- &mut v_reg_raw,
|
|
|
- &v_folding,
|
|
|
- &v_folding_neg
|
|
|
- );
|
|
|
-
|
|
|
+ fold_ciphertexts(¶ms, &mut v_reg_raw, &v_folding, &v_folding_neg);
|
|
|
+
|
|
|
// decrypt
|
|
|
- assert_eq!(dec_reg(¶ms, &v_reg_raw[0].ntt(), &mut client, scale_k), 1);
|
|
|
+ assert_eq!(
|
|
|
+ dec_reg(¶ms, &v_reg_raw[0].ntt(), &mut client, scale_k),
|
|
|
+ 1
|
|
|
+ );
|
|
|
}
|
|
|
|
|
|
fn full_protocol_is_correct_for_params(params: &Params) {
|
|
@@ -1037,7 +1051,7 @@ mod test {
|
|
|
assert_eq!(result[z], corr_result[z]);
|
|
|
}
|
|
|
}
|
|
|
-
|
|
|
+
|
|
|
#[test]
|
|
|
fn full_protocol_is_correct() {
|
|
|
full_protocol_is_correct_for_params(&get_params());
|
|
@@ -1054,7 +1068,10 @@ mod test {
|
|
|
// }
|
|
|
|
|
|
#[test]
|
|
|
+ #[ignore]
|
|
|
fn full_protocol_is_correct_real_db_16_100000() {
|
|
|
- full_protocol_is_correct_for_params_real_db(¶ms_from_json(&CFG_16_100000.replace("'", "\"")));
|
|
|
+ full_protocol_is_correct_for_params_real_db(¶ms_from_json(
|
|
|
+ &CFG_16_100000.replace("'", "\""),
|
|
|
+ ));
|
|
|
}
|
|
|
}
|