|
@@ -1,7 +1,14 @@
|
|
|
+#[cfg(target_feature = "avx2")]
|
|
|
+use std::arch::x86_64::*;
|
|
|
+
|
|
|
+#[cfg(target_feature = "avx2")]
|
|
|
+use crate::aligned_memory::*;
|
|
|
+
|
|
|
use crate::arith::*;
|
|
|
use crate::gadget::*;
|
|
|
use crate::params::*;
|
|
|
use crate::poly::*;
|
|
|
+use crate::util::*;
|
|
|
|
|
|
pub fn coefficient_expansion(
|
|
|
v: &mut Vec<PolyMatrixNTT>,
|
|
@@ -118,13 +125,169 @@ pub fn regev_to_gsw<'a>(
|
|
|
}
|
|
|
}
|
|
|
|
|
|
-#[cfg(test)]
|
|
|
-mod test {
|
|
|
- use rand::prelude::StdRng;
|
|
|
+pub const MAX_SUMMED: usize = 1 << 6;
|
|
|
+pub const PACKED_OFFSET_2: i32 = 32;
|
|
|
+
|
|
|
+#[cfg(target_feature = "avx2")]
|
|
|
+pub fn multiply_reg_by_database(
|
|
|
+ out: &mut Vec<PolyMatrixNTT>,
|
|
|
+ db: &[u64],
|
|
|
+ v_firstdim: &[u64],
|
|
|
+ params: &Params,
|
|
|
+ dim0: usize,
|
|
|
+ num_per: usize,
|
|
|
+) {
|
|
|
+ let ct_rows = 2;
|
|
|
+ let ct_cols = 1;
|
|
|
+ let pt_rows = 1;
|
|
|
+ let pt_cols = 1;
|
|
|
+
|
|
|
+ assert!(dim0 * ct_rows >= MAX_SUMMED);
|
|
|
+
|
|
|
+ let mut sums_out_n0_u64 = AlignedMemory64::new(4);
|
|
|
+ let mut sums_out_n2_u64 = AlignedMemory64::new(4);
|
|
|
+
|
|
|
+ for z in 0..params.poly_len {
|
|
|
+ let idx_a_base = z * (ct_cols * dim0 * ct_rows);
|
|
|
+ let mut idx_b_base = z * (num_per * pt_cols * dim0 * pt_rows);
|
|
|
+
|
|
|
+ for i in 0..num_per {
|
|
|
+ for c in 0..pt_cols {
|
|
|
+ let inner_limit = MAX_SUMMED;
|
|
|
+ let outer_limit = dim0 * ct_rows / inner_limit;
|
|
|
+
|
|
|
+ let mut sums_out_n0_u64_acc = [0u64, 0, 0, 0];
|
|
|
+ let mut sums_out_n2_u64_acc = [0u64, 0, 0, 0];
|
|
|
+
|
|
|
+ for o_jm in 0..outer_limit {
|
|
|
+ unsafe {
|
|
|
+ let mut sums_out_n0 = _mm256_setzero_si256();
|
|
|
+ let mut sums_out_n2 = _mm256_setzero_si256();
|
|
|
+
|
|
|
+ for i_jm in 0..inner_limit / 4 {
|
|
|
+ let jm = o_jm * inner_limit + (4 * i_jm);
|
|
|
+
|
|
|
+ let b_inp_1 = *db.get_unchecked(idx_b_base) as i64;
|
|
|
+ idx_b_base += 1;
|
|
|
+ let b_inp_2 = *db.get_unchecked(idx_b_base) as i64;
|
|
|
+ idx_b_base += 1;
|
|
|
+ let b = _mm256_set_epi64x(b_inp_2, b_inp_2, b_inp_1, b_inp_1);
|
|
|
+
|
|
|
+ let v_a = v_firstdim.get_unchecked(idx_a_base + jm) as *const u64;
|
|
|
+
|
|
|
+ let a = _mm256_load_si256(v_a as *const __m256i);
|
|
|
+ let a_lo = a;
|
|
|
+ let a_hi_hi = _mm256_srli_epi64(a, PACKED_OFFSET_2);
|
|
|
+ let b_lo = b;
|
|
|
+ let b_hi_hi = _mm256_srli_epi64(b, PACKED_OFFSET_2);
|
|
|
+
|
|
|
+ sums_out_n0 =
|
|
|
+ _mm256_add_epi64(sums_out_n0, _mm256_mul_epu32(a_lo, b_lo));
|
|
|
+ sums_out_n2 =
|
|
|
+ _mm256_add_epi64(sums_out_n2, _mm256_mul_epu32(a_hi_hi, b_hi_hi));
|
|
|
+ }
|
|
|
+
|
|
|
+ // reduce here, otherwise we will overflow
|
|
|
+
|
|
|
+ _mm256_store_si256(
|
|
|
+ sums_out_n0_u64.as_mut_ptr() as *mut __m256i,
|
|
|
+ sums_out_n0,
|
|
|
+ );
|
|
|
+ _mm256_store_si256(
|
|
|
+ sums_out_n2_u64.as_mut_ptr() as *mut __m256i,
|
|
|
+ sums_out_n2,
|
|
|
+ );
|
|
|
+
|
|
|
+ for idx in 0..4 {
|
|
|
+ let val = sums_out_n0_u64[idx];
|
|
|
+ sums_out_n0_u64_acc[idx] = barrett_coeff_u64(params, val + sums_out_n0_u64_acc[idx], 0);
|
|
|
+ }
|
|
|
+ for idx in 0..4 {
|
|
|
+ let val = sums_out_n2_u64[idx];
|
|
|
+ sums_out_n2_u64_acc[idx] = barrett_coeff_u64(params, val + sums_out_n2_u64_acc[idx], 1);
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ for idx in 0..4 {
|
|
|
+ sums_out_n0_u64_acc[idx] = barrett_coeff_u64(params, sums_out_n0_u64_acc[idx], 0);
|
|
|
+ sums_out_n2_u64_acc[idx] = barrett_coeff_u64(params, sums_out_n2_u64_acc[idx], 1);
|
|
|
+ }
|
|
|
+
|
|
|
+ // output n0
|
|
|
+ let (crt_count, poly_len) = (params.crt_count, params.poly_len);
|
|
|
+ let mut n = 0;
|
|
|
+ let mut idx_c = c * (crt_count * poly_len) + n * (poly_len) + z;
|
|
|
+ out[i].data[idx_c] =
|
|
|
+ barrett_coeff_u64(params, sums_out_n0_u64_acc[0] + sums_out_n0_u64_acc[2], 0);
|
|
|
+ idx_c += pt_cols * crt_count * poly_len;
|
|
|
+ out[i].data[idx_c] =
|
|
|
+ barrett_coeff_u64(params, sums_out_n0_u64_acc[1] + sums_out_n0_u64_acc[3], 0);
|
|
|
+
|
|
|
+ // output n1
|
|
|
+ n = 1;
|
|
|
+ idx_c = c * (crt_count * poly_len) + n * (poly_len) + z;
|
|
|
+ out[i].data[idx_c] =
|
|
|
+ barrett_coeff_u64(params, sums_out_n2_u64_acc[0] + sums_out_n2_u64_acc[2], 1);
|
|
|
+ idx_c += pt_cols * crt_count * poly_len;
|
|
|
+ out[i].data[idx_c] =
|
|
|
+ barrett_coeff_u64(params, sums_out_n2_u64_acc[1] + sums_out_n2_u64_acc[3], 1);
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+pub fn generate_random_db_and_get_item<'a>(
|
|
|
+ params: &'a Params,
|
|
|
+ item_idx: usize,
|
|
|
+) -> (PolyMatrixRaw<'a>, Vec<u64>) {
|
|
|
+ let mut rng = get_seeded_rng();
|
|
|
+
|
|
|
+ let trials = params.n * params.n;
|
|
|
+ let dim0 = 1 << params.db_dim_1;
|
|
|
+ let num_per = 1 << params.db_dim_2;
|
|
|
+ let num_items = dim0 * num_per;
|
|
|
+ let db_size_words = trials * num_items * params.poly_len;
|
|
|
+ let mut v = vec![0u64; db_size_words];
|
|
|
+
|
|
|
+ let mut item = PolyMatrixRaw::zero(params, params.n, params.n);
|
|
|
+
|
|
|
+ for trial in 0..trials {
|
|
|
+ for i in 0..num_items {
|
|
|
+ let ii = i % num_per;
|
|
|
+ let j = i / num_per;
|
|
|
+
|
|
|
+ let mut db_item = PolyMatrixRaw::random_rng(params, 1, 1, &mut rng);
|
|
|
+ db_item.reduce_mod(params.pt_modulus);
|
|
|
+
|
|
|
+ if i == item_idx {
|
|
|
+ item.copy_into(&db_item, trial / params.n, trial % params.n);
|
|
|
+ }
|
|
|
+
|
|
|
+ for z in 0..params.poly_len {
|
|
|
+ db_item.data[z] = recenter_mod(db_item.data[z], params.pt_modulus, params.modulus);
|
|
|
+ }
|
|
|
|
|
|
- use crate::{client::*, util::*};
|
|
|
+ let db_item_ntt = db_item.ntt();
|
|
|
+ for z in 0..params.poly_len {
|
|
|
+ let idx_dst = calc_index(
|
|
|
+ &[trial, z, ii, j],
|
|
|
+ &[trials, params.poly_len, num_per, dim0],
|
|
|
+ );
|
|
|
+
|
|
|
+ v[idx_dst] = db_item_ntt.data[z]
|
|
|
+ | (db_item_ntt.data[params.poly_len + z] << PACKED_OFFSET_2);
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+ (item, v)
|
|
|
+}
|
|
|
|
|
|
+#[cfg(test)]
|
|
|
+mod test {
|
|
|
use super::*;
|
|
|
+ use crate::{client::*};
|
|
|
+ use rand::{prelude::StdRng, Rng};
|
|
|
|
|
|
fn get_params() -> Params {
|
|
|
let mut params = get_expansion_testing_params();
|
|
@@ -146,7 +309,6 @@ mod test {
|
|
|
val -= params.modulus as i64;
|
|
|
}
|
|
|
let val_rounded = f64::round(val as f64 / scale_k as f64) as i64;
|
|
|
- println!("{:?} {:?}", val, val_rounded);
|
|
|
if val_rounded == 0 {
|
|
|
0
|
|
|
} else {
|
|
@@ -181,7 +343,7 @@ mod test {
|
|
|
let public_params = client.generate_keys();
|
|
|
|
|
|
let mut v = Vec::new();
|
|
|
- for _ in 0..params.poly_len {
|
|
|
+ for _ in 0..(1 << (params.db_dim_1 + 1)) {
|
|
|
v.push(PolyMatrixNTT::zero(¶ms, 2, 1));
|
|
|
}
|
|
|
|
|
@@ -252,4 +414,52 @@ mod test {
|
|
|
|
|
|
assert_eq!(dec_gsw(¶ms, &v_gsw[0], &mut client), 0);
|
|
|
}
|
|
|
+
|
|
|
+ #[test]
|
|
|
+ fn multiply_reg_by_database_is_correct() {
|
|
|
+ let params = get_params();
|
|
|
+ let mut seeded_rng = get_seeded_rng();
|
|
|
+
|
|
|
+ let dim0 = 1 << params.db_dim_1;
|
|
|
+ let num_per = 1 << params.db_dim_2;
|
|
|
+ let scale_k = params.modulus / params.pt_modulus;
|
|
|
+
|
|
|
+ let target_idx = seeded_rng.gen::<usize>() % (dim0 * num_per);
|
|
|
+ let target_idx_dim0 = target_idx / num_per;
|
|
|
+ let target_idx_num_per = target_idx % num_per;
|
|
|
+
|
|
|
+ let mut client = Client::init(¶ms, &mut seeded_rng);
|
|
|
+ _ = client.generate_keys();
|
|
|
+
|
|
|
+ let (corr_item, db) = generate_random_db_and_get_item(¶ms, target_idx);
|
|
|
+
|
|
|
+ let mut v_reg = Vec::new();
|
|
|
+ for i in 0..dim0 {
|
|
|
+ let val = if i == target_idx_dim0 { scale_k } else { 0 };
|
|
|
+ let sigma = PolyMatrixRaw::single_value(¶ms, val).ntt();
|
|
|
+ v_reg.push(client.encrypt_matrix_reg(&sigma));
|
|
|
+ }
|
|
|
+
|
|
|
+ let v_reg_sz = dim0 * 2 * params.poly_len;
|
|
|
+ let mut v_reg_reoriented = AlignedMemory64::new(v_reg_sz);
|
|
|
+ reorient_reg_ciphertexts(¶ms, v_reg_reoriented.as_mut_slice(), &v_reg);
|
|
|
+
|
|
|
+ let mut out = Vec::with_capacity(num_per);
|
|
|
+ for _ in 0..dim0 {
|
|
|
+ out.push(PolyMatrixNTT::zero(¶ms, 2, 1));
|
|
|
+ }
|
|
|
+ multiply_reg_by_database(&mut out, db.as_slice(), v_reg_reoriented.as_slice(), ¶ms, dim0, num_per);
|
|
|
+
|
|
|
+ // decrypt
|
|
|
+ let dec = client.decrypt_matrix_reg(&out[target_idx_num_per]).raw();
|
|
|
+ let mut dec_rescaled = PolyMatrixRaw::zero(¶ms, 1, 1);
|
|
|
+ for z in 0..params.poly_len {
|
|
|
+ dec_rescaled.data[z] = rescale(dec.data[z], params.modulus, params.pt_modulus);
|
|
|
+ }
|
|
|
+
|
|
|
+ for z in 0..params.poly_len {
|
|
|
+ // println!("{:?} {:?}", dec_rescaled.data[z], corr_item.data[z]);
|
|
|
+ assert_eq!(dec_rescaled.data[z], corr_item.data[z]);
|
|
|
+ }
|
|
|
+ }
|
|
|
}
|