Browse Source

add first dimension processing

Samir Menon 2 years ago
parent
commit
f908ead77b
6 changed files with 333 additions and 47 deletions
  1. 39 2
      spiral-rs/benches/server.rs
  2. 32 1
      spiral-rs/src/arith.rs
  3. 1 37
      spiral-rs/src/client.rs
  4. 10 0
      spiral-rs/src/poly.rs
  5. 216 6
      spiral-rs/src/server.rs
  6. 35 1
      spiral-rs/src/util.rs

+ 39 - 2
spiral-rs/benches/server.rs

@@ -1,6 +1,8 @@
 use criterion::{black_box, criterion_group, criterion_main, Criterion};
 use pprof::criterion::{Output, PProfProfiler};
 
+use rand::Rng;
+use spiral_rs::aligned_memory::AlignedMemory64;
 use spiral_rs::client::*;
 use spiral_rs::poly::*;
 use spiral_rs::server::*;
@@ -8,7 +10,7 @@ use spiral_rs::util::*;
 use std::time::Duration;
 
 fn criterion_benchmark(c: &mut Criterion) {
-    let mut group = c.benchmark_group("sample-size");
+    let mut group = c.benchmark_group("server");
     group
         .sample_size(10)
         .measurement_time(Duration::from_secs(30));
@@ -32,7 +34,7 @@ fn criterion_benchmark(c: &mut Criterion) {
     let v_w_right = public_params.v_expansion_right.unwrap();
 
     // note: the benchmark on AVX2 is 545ms for the c++ impl
-    group.bench_function("coeff_exp", |b| {
+    group.bench_function("coefficient_expansion", |b| {
         b.iter(|| {
             coefficient_expansion(
                 black_box(&mut v),
@@ -46,6 +48,41 @@ fn criterion_benchmark(c: &mut Criterion) {
             )
         });
     });
+
+    let mut seeded_rng = get_seeded_rng();
+    let trials = params.n * params.n;
+    let dim0 = 1 << params.db_dim_1;
+    let num_per = 1 << params.db_dim_2;
+    let num_items = dim0 * num_per;
+    let db_size_words = trials * num_items * params.poly_len;
+    let mut db = vec![0u64; db_size_words];
+    for i in 0..db_size_words {
+        db[i] = seeded_rng.gen();
+    }
+    
+    let v_reg_sz = dim0 * 2 * params.poly_len;
+    let mut v_reg_reoriented = AlignedMemory64::new(v_reg_sz);
+    for i in 0..v_reg_sz {
+        v_reg_reoriented[i] = seeded_rng.gen();
+    }
+    let mut out = Vec::with_capacity(num_per);
+    for _ in 0..dim0 {
+        out.push(PolyMatrixNTT::zero(&params, 2, 1));
+    }
+
+    // note: the benchmark on AVX2 is 45ms for the c++ impl
+    group.bench_function("first_dimension_processing", |b| {
+        b.iter(|| {
+            multiply_reg_by_database(
+                black_box(&mut out),
+                black_box(db.as_slice()),
+                black_box(v_reg_reoriented.as_slice()),
+                black_box(&params),
+                black_box(dim0), 
+                black_box(num_per)
+            )
+        });
+    });
     group.finish();
 }
 

+ 32 - 1
spiral-rs/src/arith.rs

@@ -123,7 +123,7 @@ pub fn barrett_raw_u64(input: u64, const_ratio_1: u64, modulus: u64) -> u64 {
     let tmp = (((input as u128) * (const_ratio_1 as u128)) >> 64) as u64;
 
     // Barrett subtraction
-    let res = input - tmp * modulus;
+    let mut res = input - tmp * modulus;
 
     // One more subtraction is enough
     if res >= modulus {
@@ -408,6 +408,37 @@ pub fn divide_uint192_inplace(mut numerator: [u64; 3], denominator: u64) -> ([u6
     (numerator, quotient)
 }
 
/// Lifts `val` from Z_{small_modulus} to Z_{large_modulus} by recentering:
/// values in the upper half of the small modulus are treated as negative and
/// mapped to the corresponding residue just below the large modulus.
pub fn recenter_mod(val: u64, small_modulus: u64, large_modulus: u64) -> u64 {
    assert!(val < small_modulus);
    let (sm, lm) = (small_modulus as i64, large_modulus as i64);
    // Center into (-sm/2, sm/2]: anything strictly above sm/2 wraps negative.
    let centered = if (val as i64) > sm / 2 {
        val as i64 - sm
    } else {
        val as i64
    };
    // Re-embed a negative representative into [0, lm).
    let lifted = if centered < 0 { centered + lm } else { centered };
    lifted as u64
}
+
/// Rescales `a` from Z_{inp_mod} to Z_{out_mod} with rounding to nearest:
/// the input is first recentered around zero, scaled by out_mod/inp_mod with
/// a half-modulus rounding offset, and the result is mapped back into
/// [0, out_mod).
pub fn rescale(a: u64, inp_mod: u64, out_mod: u64) -> u64 {
    let q_in = inp_mod as i64;
    let q_out = out_mod as i128;

    // Center the input into [-q_in/2, q_in/2).
    let mut centered = (a % inp_mod) as i64;
    if centered >= q_in / 2 {
        centered -= q_in;
    }

    // round(centered * out_mod / inp_mod), implemented with truncating i128
    // division plus a sign-matched half-denominator offset.
    let half = q_in / 2;
    let rounding: i128 = (if centered >= 0 { half } else { -half }) as i128;
    let scaled = (centered as i128) * (out_mod as i128);
    let quotient = (scaled + rounding) / (inp_mod as i128);

    // Shift by multiples of out_mod so the remainder is nonnegative before
    // the final reduction.
    let shifted = (quotient + ((inp_mod / out_mod) * out_mod) as i128 + (2 * q_out)) % q_out;

    assert!(shifted >= 0);

    ((shifted + q_out) % q_out) as u64
}
+
 #[cfg(test)]
 mod test {
     use super::*;

+ 1 - 37
spiral-rs/src/client.rs

@@ -309,42 +309,6 @@ impl<'a, TRng: Rng> Client<'a, TRng> {
         pp
     }
 
-    // reindexes a vector of regev ciphertexts, to help server
-    fn reorient_reg_ciphertexts(&self, out: &mut [u64], v_reg: &Vec<PolyMatrixNTT>) {
-        let params = self.params;
-        let poly_len = params.poly_len;
-        let crt_count = params.crt_count;
-
-        assert_eq!(crt_count, 2);
-        assert!(log2(params.moduli[0]) <= 32);
-
-        let num_reg_expanded = 1 << params.db_dim_1;
-        let ct_rows = v_reg[0].rows;
-        let ct_cols = v_reg[0].cols;
-
-        assert_eq!(ct_rows, 2);
-        assert_eq!(ct_cols, 1);
-
-        for j in 0..num_reg_expanded {
-            for r in 0..ct_rows {
-                for m in 0..ct_cols {
-                    for z in 0..params.poly_len {
-                        let idx_a_in =
-                            r * (ct_cols * crt_count * poly_len) + m * (crt_count * poly_len);
-                        let idx_a_out = z * (num_reg_expanded * ct_cols * ct_rows)
-                            + j * (ct_cols * ct_rows)
-                            + m * (ct_rows)
-                            + r;
-                        let val1 = v_reg[j].data[idx_a_in + z] % params.moduli[0];
-                        let val2 = v_reg[j].data[idx_a_in + params.poly_len + z] % params.moduli[1];
-
-                        out[idx_a_out] = val1 | (val2 << 32);
-                    }
-                }
-            }
-        }
-    }
-
     pub fn generate_query(&mut self, idx_target: usize) -> Query<'a> {
         let params = self.params;
         let further_dims = params.db_dim_2;
@@ -393,7 +357,7 @@ impl<'a, TRng: Rng> Client<'a, TRng> {
                 reg_cts.push(self.encrypt_matrix_reg(&to_ntt_alloc(&sigma)));
             }
             // reorient into server's preferred indexing
-            self.reorient_reg_ciphertexts(reg_cts_buf.as_mut_slice(), &reg_cts);
+            reorient_reg_ciphertexts(self.params, reg_cts_buf.as_mut_slice(), &reg_cts);
 
             // generate GSW ciphertexts
             for i in 0..further_dims {

+ 10 - 0
spiral-rs/src/poly.rs

@@ -172,6 +172,16 @@ impl<'a> PolyMatrixRaw<'a> {
         to_ntt_alloc(&self)
     }
 
+    pub fn reduce_mod(&mut self, modulus: u64) {
+        for r in 0..self.rows {
+            for c in 0..self.cols {
+                for z in 0..self.params.poly_len {
+                    self.get_poly_mut(r, c)[z] %= modulus;
+                }
+            }
+        }
+    }
+
     pub fn to_vec(&self, modulus_bits: usize, num_coeffs: usize) -> Vec<u8> {
         let sz_bits = self.rows * self.cols * num_coeffs * modulus_bits;
         let sz_bytes = f64::ceil((sz_bits as f64) / 8f64) as usize + 32;

+ 216 - 6
spiral-rs/src/server.rs

@@ -1,7 +1,14 @@
+#[cfg(target_feature = "avx2")]
+use std::arch::x86_64::*;
+
+#[cfg(target_feature = "avx2")]
+use crate::aligned_memory::*;
+
 use crate::arith::*;
 use crate::gadget::*;
 use crate::params::*;
 use crate::poly::*;
+use crate::util::*;
 
 pub fn coefficient_expansion(
     v: &mut Vec<PolyMatrixNTT>,
@@ -118,13 +125,169 @@ pub fn regev_to_gsw<'a>(
     }
 }
 
-#[cfg(test)]
-mod test {
-    use rand::prelude::StdRng;
// Number of 32-bit partial products accumulated before a Barrett reduction;
// bounding the run length keeps the 64-bit lane accumulators from overflowing.
pub const MAX_SUMMED: usize = 1 << 6;
// Bit offset of the second CRT residue inside each packed u64
// (low 32 bits = residue mod moduli[0], high 32 bits = residue mod moduli[1]).
pub const PACKED_OFFSET_2: i32 = 32;

#[cfg(target_feature = "avx2")]
/// First-dimension processing: multiplies the reoriented Regev query vector
/// `v_firstdim` against the packed database `db`, writing one 2x1 output
/// ciphertext per second-dimension index into `out[0..num_per]`.
///
/// Both inputs pack the two 32-bit CRT residues of each coefficient into one
/// u64 (see `PACKED_OFFSET_2`); `v_firstdim` is assumed to be in the layout
/// produced by `reorient_reg_ciphertexts` (z-major, then dim0, then ct row)
/// and 32-byte aligned (it comes from `AlignedMemory64`) — the aligned AVX2
/// load below requires this. NOTE(review): `db` is assumed laid out as
/// [z][num_per][dim0] packed words, matching `generate_random_db_and_get_item`.
pub fn multiply_reg_by_database(
    out: &mut Vec<PolyMatrixNTT>,
    db: &[u64],
    v_firstdim: &[u64],
    params: &Params,
    dim0: usize,
    num_per: usize,
) {
    // Fixed shapes for this stage: 2x1 Regev ciphertexts, 1x1 plaintexts.
    let ct_rows = 2;
    let ct_cols = 1;
    let pt_rows = 1;
    let pt_cols = 1;

    // The inner accumulation loop processes MAX_SUMMED lanes at a time, so the
    // query vector must be at least that long.
    assert!(dim0 * ct_rows >= MAX_SUMMED);

    // 32-byte-aligned scratch for spilling the four 64-bit AVX2 lanes.
    let mut sums_out_n0_u64 = AlignedMemory64::new(4);
    let mut sums_out_n2_u64 = AlignedMemory64::new(4);

    for z in 0..params.poly_len {
        let idx_a_base = z * (ct_cols * dim0 * ct_rows);
        // idx_b_base advances monotonically through db across the i/c loops.
        let mut idx_b_base = z * (num_per * pt_cols * dim0 * pt_rows);

        for i in 0..num_per {
            for c in 0..pt_cols {
                let inner_limit = MAX_SUMMED;
                let outer_limit = dim0 * ct_rows / inner_limit;

                // Running Barrett-reduced sums for each of the 4 vector lanes,
                // per CRT modulus (n0 = moduli[0], n2 = moduli[1]).
                let mut sums_out_n0_u64_acc = [0u64, 0, 0, 0];
                let mut sums_out_n2_u64_acc = [0u64, 0, 0, 0];

                for o_jm in 0..outer_limit {
                    // SAFETY: indices are in-bounds by construction given the
                    // layout assumptions above; get_unchecked skips the bounds
                    // checks in this hot loop, and the aligned load requires
                    // v_firstdim to be 32-byte aligned.
                    unsafe {
                        let mut sums_out_n0 = _mm256_setzero_si256();
                        let mut sums_out_n2 = _mm256_setzero_si256();

                        for i_jm in 0..inner_limit / 4 {
                            let jm = o_jm * inner_limit + (4 * i_jm);

                            // Each db word serves both ciphertext rows, so two
                            // words are broadcast across the four lanes.
                            let b_inp_1 = *db.get_unchecked(idx_b_base) as i64;
                            idx_b_base += 1;
                            let b_inp_2 = *db.get_unchecked(idx_b_base) as i64;
                            idx_b_base += 1;
                            let b = _mm256_set_epi64x(b_inp_2, b_inp_2, b_inp_1, b_inp_1);

                            let v_a = v_firstdim.get_unchecked(idx_a_base + jm) as *const u64;

                            let a = _mm256_load_si256(v_a as *const __m256i);
                            // Low 32 bits: residue mod moduli[0]; shifting by
                            // PACKED_OFFSET_2 exposes the residue mod moduli[1].
                            let a_lo = a;
                            let a_hi_hi = _mm256_srli_epi64(a, PACKED_OFFSET_2);
                            let b_lo = b;
                            let b_hi_hi = _mm256_srli_epi64(b, PACKED_OFFSET_2);

                            // _mm256_mul_epu32 multiplies the low 32 bits of
                            // each 64-bit lane, producing full 64-bit products.
                            sums_out_n0 =
                                _mm256_add_epi64(sums_out_n0, _mm256_mul_epu32(a_lo, b_lo));
                            sums_out_n2 =
                                _mm256_add_epi64(sums_out_n2, _mm256_mul_epu32(a_hi_hi, b_hi_hi));
                        }

                        // reduce here, otherwise we will overflow

                        _mm256_store_si256(
                            sums_out_n0_u64.as_mut_ptr() as *mut __m256i,
                            sums_out_n0,
                        );
                        _mm256_store_si256(
                            sums_out_n2_u64.as_mut_ptr() as *mut __m256i,
                            sums_out_n2,
                        );

                        // Fold the spilled lanes into the per-lane accumulators,
                        // Barrett-reducing so the next chunk cannot overflow.
                        for idx in 0..4 {
                            let val = sums_out_n0_u64[idx];
                            sums_out_n0_u64_acc[idx] = barrett_coeff_u64(params, val + sums_out_n0_u64_acc[idx], 0);
                        }
                        for idx in 0..4 {
                            let val = sums_out_n2_u64[idx];
                            sums_out_n2_u64_acc[idx] = barrett_coeff_u64(params, val + sums_out_n2_u64_acc[idx], 1);
                        }
                    }
                }

                // Final reduction of each lane into canonical range.
                for idx in 0..4 {
                    sums_out_n0_u64_acc[idx] = barrett_coeff_u64(params, sums_out_n0_u64_acc[idx], 0);
                    sums_out_n2_u64_acc[idx] = barrett_coeff_u64(params, sums_out_n2_u64_acc[idx], 1);
                }

                // output n0
                // Lanes {0,2} belong to ciphertext row 0 and {1,3} to row 1;
                // write coefficient z of each row for CRT component 0.
                let (crt_count, poly_len) = (params.crt_count, params.poly_len);
                let mut n = 0;
                let mut idx_c = c * (crt_count * poly_len) + n * (poly_len) + z;
                out[i].data[idx_c] =
                    barrett_coeff_u64(params, sums_out_n0_u64_acc[0] + sums_out_n0_u64_acc[2], 0);
                idx_c += pt_cols * crt_count * poly_len;
                out[i].data[idx_c] =
                    barrett_coeff_u64(params, sums_out_n0_u64_acc[1] + sums_out_n0_u64_acc[3], 0);

                // output n1
                n = 1;
                idx_c = c * (crt_count * poly_len) + n * (poly_len) + z;
                out[i].data[idx_c] =
                    barrett_coeff_u64(params, sums_out_n2_u64_acc[0] + sums_out_n2_u64_acc[2], 1);
                idx_c += pt_cols * crt_count * poly_len;
                out[i].data[idx_c] =
                    barrett_coeff_u64(params, sums_out_n2_u64_acc[1] + sums_out_n2_u64_acc[3], 1);
            }
        }
    }
}
+
+pub fn generate_random_db_and_get_item<'a>(
+    params: &'a Params,
+    item_idx: usize,
+) -> (PolyMatrixRaw<'a>, Vec<u64>) {
+    let mut rng = get_seeded_rng();
+
+    let trials = params.n * params.n;
+    let dim0 = 1 << params.db_dim_1;
+    let num_per = 1 << params.db_dim_2;
+    let num_items = dim0 * num_per;
+    let db_size_words = trials * num_items * params.poly_len;
+    let mut v = vec![0u64; db_size_words];
+
+    let mut item = PolyMatrixRaw::zero(params, params.n, params.n);
+
+    for trial in 0..trials {
+        for i in 0..num_items {
+            let ii = i % num_per;
+            let j = i / num_per;
+
+            let mut db_item = PolyMatrixRaw::random_rng(params, 1, 1, &mut rng);
+            db_item.reduce_mod(params.pt_modulus);
+            
+            if i == item_idx {
+                item.copy_into(&db_item, trial / params.n, trial % params.n);
+            }
+
+            for z in 0..params.poly_len {
+                db_item.data[z] = recenter_mod(db_item.data[z], params.pt_modulus, params.modulus);
+            }
 
-    use crate::{client::*, util::*};
+            let db_item_ntt = db_item.ntt();
+            for z in 0..params.poly_len {
+                let idx_dst = calc_index(
+                    &[trial, z, ii, j],
+                    &[trials, params.poly_len, num_per, dim0],
+                );
+
+                v[idx_dst] = db_item_ntt.data[z]
+                    | (db_item_ntt.data[params.poly_len + z] << PACKED_OFFSET_2);
+            }
+        }
+    }
+    (item, v)
+}
 
+#[cfg(test)]
+mod test {
     use super::*;
+    use crate::{client::*};
+    use rand::{prelude::StdRng, Rng};
 
     fn get_params() -> Params {
         let mut params = get_expansion_testing_params();
@@ -146,7 +309,6 @@ mod test {
             val -= params.modulus as i64;
         }
         let val_rounded = f64::round(val as f64 / scale_k as f64) as i64;
-        println!("{:?} {:?}", val, val_rounded);
         if val_rounded == 0 {
             0
         } else {
@@ -181,7 +343,7 @@ mod test {
         let public_params = client.generate_keys();
 
         let mut v = Vec::new();
-        for _ in 0..params.poly_len {
+        for _ in 0..(1 << (params.db_dim_1 + 1)) {
             v.push(PolyMatrixNTT::zero(&params, 2, 1));
         }
 
@@ -252,4 +414,52 @@ mod test {
 
         assert_eq!(dec_gsw(&params, &v_gsw[0], &mut client), 0);
     }
+
+    #[test]
+    fn multiply_reg_by_database_is_correct() {
+        let params = get_params();
+        let mut seeded_rng = get_seeded_rng();
+
+        let dim0 = 1 << params.db_dim_1;
+        let num_per = 1 << params.db_dim_2;
+        let scale_k = params.modulus / params.pt_modulus;
+
+        let target_idx = seeded_rng.gen::<usize>() % (dim0 * num_per);
+        let target_idx_dim0 = target_idx / num_per;
+        let target_idx_num_per = target_idx % num_per;
+
+        let mut client = Client::init(&params, &mut seeded_rng);
+        _ = client.generate_keys();
+
+        let (corr_item, db) = generate_random_db_and_get_item(&params, target_idx);
+
+        let mut v_reg = Vec::new();
+        for i in 0..dim0 {
+            let val = if i == target_idx_dim0 { scale_k } else { 0 };
+            let sigma = PolyMatrixRaw::single_value(&params, val).ntt();
+            v_reg.push(client.encrypt_matrix_reg(&sigma));
+        }
+
+        let v_reg_sz = dim0 * 2 * params.poly_len;
+        let mut v_reg_reoriented = AlignedMemory64::new(v_reg_sz);
+        reorient_reg_ciphertexts(&params, v_reg_reoriented.as_mut_slice(), &v_reg);
+
+        let mut out = Vec::with_capacity(num_per);
+        for _ in 0..dim0 {
+            out.push(PolyMatrixNTT::zero(&params, 2, 1));
+        }
+        multiply_reg_by_database(&mut out, db.as_slice(), v_reg_reoriented.as_slice(), &params, dim0, num_per);
+
+        // decrypt
+        let dec = client.decrypt_matrix_reg(&out[target_idx_num_per]).raw();
+        let mut dec_rescaled = PolyMatrixRaw::zero(&params, 1, 1);
+        for z in 0..params.poly_len {
+            dec_rescaled.data[z] = rescale(dec.data[z], params.modulus, params.pt_modulus);
+        }
+
+        for z in 0..params.poly_len {
+            // println!("{:?} {:?}", dec_rescaled.data[z], corr_item.data[z]);
+            assert_eq!(dec_rescaled.data[z], corr_item.data[z]);
+        }
+    }
 }

+ 35 - 1
spiral-rs/src/util.rs

@@ -1,4 +1,4 @@
-use crate::params::*;
+use crate::{arith::*, params::*, poly::*};
 use rand::{prelude::StdRng, SeedableRng};
 use serde_json::Value;
 
@@ -182,6 +182,40 @@ pub fn write_arbitrary_bits(data: &mut [u8], mut val: u64, bit_offs: usize, num_
     }
 }
 
+pub fn reorient_reg_ciphertexts(params: &Params, out: &mut [u64], v_reg: &Vec<PolyMatrixNTT>) {
+    let poly_len = params.poly_len;
+    let crt_count = params.crt_count;
+
+    assert_eq!(crt_count, 2);
+    assert!(log2(params.moduli[0]) <= 32);
+
+    let num_reg_expanded = 1 << params.db_dim_1;
+    let ct_rows = v_reg[0].rows;
+    let ct_cols = v_reg[0].cols;
+
+    assert_eq!(ct_rows, 2);
+    assert_eq!(ct_cols, 1);
+
+    for j in 0..num_reg_expanded {
+        for r in 0..ct_rows {
+            for m in 0..ct_cols {
+                for z in 0..params.poly_len {
+                    let idx_a_in =
+                        r * (ct_cols * crt_count * poly_len) + m * (crt_count * poly_len);
+                    let idx_a_out = z * (num_reg_expanded * ct_cols * ct_rows)
+                        + j * (ct_cols * ct_rows)
+                        + m * (ct_rows)
+                        + r;
+                    let val1 = v_reg[j].data[idx_a_in + z] % params.moduli[0];
+                    let val2 = v_reg[j].data[idx_a_in + params.poly_len + z] % params.moduli[1];
+
+                    out[idx_a_out] = val1 | (val2 << 32);
+                }
+            }
+        }
+    }
+}
+
 #[cfg(test)]
 mod test {
     use super::*;