Browse Source

full correct server processing!

Samir Menon 1 year ago
parent
commit
5131c85f4c

+ 30 - 0
spiral-rs/benches/server.rs

@@ -1,3 +1,5 @@
+use criterion::BenchmarkGroup;
+use criterion::measurement::WallTime;
 use criterion::{black_box, criterion_group, criterion_main, Criterion};
 use pprof::criterion::{Output, PProfProfiler};
 
@@ -9,6 +11,29 @@ use spiral_rs::server::*;
 use spiral_rs::util::*;
 use std::time::Duration;
 
+fn test_full_processing(group: &mut BenchmarkGroup<WallTime>) {
+    let params = get_expansion_testing_params();
+    let mut seeded_rng = get_seeded_rng();
+
+    let target_idx = seeded_rng.gen::<usize>() % (params.db_dim_1 + params.db_dim_2);
+    
+    let mut client = Client::init(&params, &mut seeded_rng);
+    let public_params = client.generate_keys();
+    let query = client.generate_query(target_idx);
+    let (_, db) = generate_random_db_and_get_item(&params, target_idx);
+
+    group.bench_function("server_processing", |b| {
+        b.iter(|| {
+            black_box(process_query(
+                black_box(&params),
+                black_box(&public_params),
+                black_box(&query),
+                black_box(db.as_slice()),
+            ));
+        });
+    });
+}
+
 fn criterion_benchmark(c: &mut Criterion) {
     let mut group = c.benchmark_group("server");
     group
@@ -83,6 +108,11 @@ fn criterion_benchmark(c: &mut Criterion) {
             )
         });
     });
+
+    // full server processing benchmark
+
+    test_full_processing(&mut group);
+
     group.finish();
 }
 

+ 1 - 1
spiral-rs/src/arith.rs

@@ -123,7 +123,7 @@ pub fn barrett_raw_u64(input: u64, const_ratio_1: u64, modulus: u64) -> u64 {
     let tmp = (((input as u128) * (const_ratio_1 as u128)) >> 64) as u64;
 
     // Barrett subtraction
-    let mut res = input - tmp * modulus;
+    let res = input - tmp * modulus;
 
     // One more subtraction is enough
     if res >= modulus {

+ 1 - 0
spiral-rs/src/client.rs

@@ -465,6 +465,7 @@ impl<'a, TRng: Rng> Client<'a, TRng> {
         let bytes_per_chunk = f64::ceil(params.db_item_size as f64 / chunks as f64) as usize;
         let logp = log2(params.pt_modulus);
         let modp_words_per_chunk = f64::ceil((bytes_per_chunk * 8) as f64 / logp as f64) as usize;
+        println!("modp_words_per_chunk {:?}", modp_words_per_chunk);
         result.to_vec(p_bits as usize, modp_words_per_chunk)
     }
 }

+ 6 - 2
spiral-rs/src/params.rs

@@ -167,8 +167,12 @@ impl Params {
         let modulus_log2 = log2_ceil(modulus);
         let (barrett_cr_0, barrett_cr_1) = get_barrett(moduli);
         let (barrett_cr_0_modulus, barrett_cr_1_modulus) = get_barrett_crs(modulus);
-        let mod0_inv_mod1 = moduli[0] * invert_uint_mod(moduli[0], moduli[1]).unwrap();
-        let mod1_inv_mod0 = moduli[1] * invert_uint_mod(moduli[1], moduli[0]).unwrap();
+        let mut mod0_inv_mod1 = 0;
+        let mut mod1_inv_mod0 = 0;
+        if crt_count == 2 {
+            mod0_inv_mod1 = moduli[0] * invert_uint_mod(moduli[0], moduli[1]).unwrap();
+            mod1_inv_mod0 = moduli[1] * invert_uint_mod(moduli[1], moduli[0]).unwrap();
+        }
         Self {
             poly_len,
             poly_len_log2,

+ 120 - 41
spiral-rs/src/server.rs

@@ -240,10 +240,79 @@ pub fn multiply_reg_by_database(
     }
 }
 
+#[cfg(not(target_feature = "avx2"))]
+pub fn multiply_reg_by_database(
+    out: &mut Vec<PolyMatrixNTT>,
+    db: &[u64],
+    v_firstdim: &[u64],
+    params: &Params,
+    dim0: usize,
+    num_per: usize,
+) {
+    let ct_rows = 2;
+    let ct_cols = 1;
+    let pt_rows = 1;
+    let pt_cols = 1;
+
+    for z in 0..params.poly_len {
+        let idx_a_base = z * (ct_cols * dim0 * ct_rows);
+        let mut idx_b_base = z * (num_per * pt_cols * dim0 * pt_rows);
+
+        for i in 0..num_per {
+            for c in 0..pt_cols {
+                let mut sums_out_n0_0 = 0u128;
+                let mut sums_out_n0_1 = 0u128;
+                let mut sums_out_n1_0 = 0u128;
+                let mut sums_out_n1_1 = 0u128;
+
+                for jm in 0..(dim0 * pt_rows) {
+                    let b = db[idx_b_base];
+                    idx_b_base += 1;
+                        
+                    let v_a0 = v_firstdim[idx_a_base + jm * ct_rows];
+                    let v_a1 = v_firstdim[idx_a_base + jm * ct_rows + 1];
+
+                    let b_lo = b as u32;
+                    let b_hi = (b >> 32) as u32;
+
+                    let v_a0_lo = v_a0 as u32;
+                    let v_a0_hi = (v_a0 >> 32) as u32;
+
+                    let v_a1_lo = v_a1 as u32;
+                    let v_a1_hi = (v_a1 >> 32) as u32;
+                    
+                    // do n0
+                    sums_out_n0_0 += ((v_a0_lo as u64) * (b_lo as u64)) as u128;
+                    sums_out_n0_1 += ((v_a1_lo as u64) * (b_lo as u64)) as u128;
+
+                    // do n1
+                    sums_out_n1_0 += ((v_a0_hi as u64) * (b_hi as u64)) as u128;
+                    sums_out_n1_1 += ((v_a1_hi as u64) * (b_hi as u64)) as u128;
+                }
+
+                // output n0
+                let (crt_count, poly_len) = (params.crt_count, params.poly_len);
+                let mut n = 0;
+                let mut idx_c = c * (crt_count * poly_len) + n * (poly_len) + z;
+                out[i].data[idx_c] = (sums_out_n0_0 % (params.moduli[0] as u128)) as u64;
+                idx_c += pt_cols * crt_count * poly_len;
+                out[i].data[idx_c] = (sums_out_n0_1 % (params.moduli[0] as u128)) as u64;
+
+                // output n1
+                n = 1;
+                idx_c = c * (crt_count * poly_len) + n * (poly_len) + z;
+                out[i].data[idx_c] = (sums_out_n1_0 % (params.moduli[1] as u128)) as u64;
+                idx_c += pt_cols * crt_count * poly_len;
+                out[i].data[idx_c] = (sums_out_n1_1 % (params.moduli[1] as u128)) as u64;
+            }
+        }
+    }
+}
+
 pub fn generate_random_db_and_get_item<'a>(
     params: &'a Params,
     item_idx: usize,
-) -> (PolyMatrixRaw<'a>, Vec<u64>) {
+) -> (PolyMatrixRaw<'a>, AlignedMemory64) {
     let mut rng = get_seeded_rng();
 
     let trials = params.n * params.n;
@@ -251,7 +320,7 @@ pub fn generate_random_db_and_get_item<'a>(
     let num_per = 1 << params.db_dim_2;
     let num_items = dim0 * num_per;
     let db_size_words = trials * num_items * params.poly_len;
-    let mut v = vec![0u64; db_size_words];
+    let mut v = AlignedMemory64::new(db_size_words);
 
     let mut item = PolyMatrixRaw::zero(params, params.n, params.n);
 
@@ -341,8 +410,8 @@ pub fn pack<'a>(
         for r in 0..params.n {
             let w = &v_w[r];
             let ct = &v_ct[r * params.n + c];
-            ct_1.copy_into(ct, 0, 0);
-            ct_2.copy_into(ct, 1, 0);
+            ct_1.get_poly_mut(0, 0).copy_from_slice(ct.get_poly(0, 0));
+            ct_2.get_poly_mut(0, 0).copy_from_slice(ct.get_poly(1, 0));
             to_ntt(&mut ct_2_ntt, &ct_2);
             gadget_invert(&mut ginv, &ct_1);
             to_ntt(&mut ginv_nttd, &ginv);
@@ -396,17 +465,33 @@ pub fn encode(
     result
 }
 
+pub fn get_v_folding_neg<'a>(
+    params: &'a Params, 
+    v_folding: &Vec<PolyMatrixNTT>,
+) -> Vec<PolyMatrixNTT<'a>> {
+    let gadget_ntt = build_gadget(&params, 2, 2 * params.t_gsw).ntt(); // TODO: make this better
+
+    let mut v_folding_neg = Vec::new();
+    let mut ct_gsw_inv = PolyMatrixRaw::zero(&params, 2, 2 * params.t_gsw);
+    for i in 0..params.db_dim_2 {
+        invert(&mut ct_gsw_inv, &v_folding[i].raw());
+        let mut ct_gsw_neg = PolyMatrixNTT::zero(&params, 2, 2 * params.t_gsw);
+        add(&mut ct_gsw_neg, &gadget_ntt, &ct_gsw_inv.ntt());
+        v_folding_neg.push(ct_gsw_neg);
+    }
+    v_folding_neg
+}
+
 pub fn expand_query<'a>(
     params: &'a Params, 
     public_params: &PublicParameters<'a>, 
     query: &Query<'a>,
-) -> (AlignedMemory64, Vec<PolyMatrixNTT<'a>>, Vec<PolyMatrixNTT<'a>>) {
+) -> (AlignedMemory64, Vec<PolyMatrixNTT<'a>>) {
     let dim0 = 1 << params.db_dim_1;
     let further_dims = params.db_dim_2;
 
     let mut v_reg_reoriented;
     let mut v_folding;
-    let mut v_folding_neg;
 
     let num_bits_to_gen = params.t_gsw * further_dims + dim0;
     let g = log2_ceil_usize(num_bits_to_gen);
@@ -454,21 +539,10 @@ pub fn expand_query<'a>(
     }
 
     regev_to_gsw(&mut v_folding, &v_gsw_inp, &v_conversion, params, 1, 0);
-
-    let gadget_ntt = build_gadget(&params, 2, 2 * params.t_gsw).ntt();
-    v_folding_neg = Vec::new();
-    let mut ct_gsw_inv = PolyMatrixRaw::zero(&params, 2, 2 * params.t_gsw);
-    for i in 0..params.db_dim_2 {
-        invert(&mut ct_gsw_inv, &v_folding[i].raw());
-        let mut ct_gsw_neg = PolyMatrixNTT::zero(&params, 2, 2 * params.t_gsw);
-        add(&mut ct_gsw_neg, &gadget_ntt, &ct_gsw_inv.ntt());
-        v_folding_neg.push(ct_gsw_neg);
-    }
-
-    (v_reg_reoriented, v_folding, v_folding_neg)
+    
+    (v_reg_reoriented, v_folding)
 }
 
-#[cfg(target_feature = "avx2")]
 pub fn process_query(
     params: &Params, 
     public_params: &PublicParameters, 
@@ -477,20 +551,28 @@ pub fn process_query(
 ) -> Vec<u8> {
     let dim0 = 1 << params.db_dim_1;
     let num_per = 1 << params.db_dim_2;
-    let further_dims = params.db_dim_2;
     let db_slice_sz = dim0 * num_per * params.poly_len;
 
-    
-
     let v_packing = public_params.v_packing.as_ref();
 
+    let mut v_reg_reoriented;
+    let v_folding;
     if params.expand_queries {
-        
+        (v_reg_reoriented, v_folding) = 
+            expand_query(params, public_params, query);
+    } else {
+        v_reg_reoriented = AlignedMemory64::new(query.v_buf.as_ref().unwrap().len());
+        v_reg_reoriented.as_mut_slice().copy_from_slice(query.v_buf.as_ref().unwrap());
+
+        v_folding = query.v_ct.as_ref().unwrap().clone().iter()
+            .map(|x| { x.ntt() })
+            .collect();
     }
+    let v_folding_neg = get_v_folding_neg(params, &v_folding);
 
     let mut intermediate = Vec::with_capacity(num_per);
     let mut intermediate_raw = Vec::with_capacity(num_per);
-    for _ in 0..dim0 {
+    for _ in 0..num_per {
         intermediate.push(PolyMatrixNTT::zero(params, 2, 1));
         intermediate_raw.push(PolyMatrixRaw::zero(params, 2, 1));
     }
@@ -499,7 +581,7 @@ pub fn process_query(
     for trial in 0..(params.n * params.n) {
         let cur_db = &db[(db_slice_sz * trial)..(db_slice_sz * trial + db_slice_sz)];
 
-        multiply_reg_by_database(&mut intermediate, db, v_reg_reoriented.as_slice(), params, dim0, num_per);
+        multiply_reg_by_database(&mut intermediate, cur_db, v_reg_reoriented.as_slice(), params, dim0, num_per);
 
         for i in 0..intermediate.len() {
             from_ntt(&mut intermediate_raw[i], &intermediate[i]);
@@ -512,7 +594,7 @@ pub fn process_query(
             &v_folding_neg
         );
 
-        v_ct.push(intermediate_raw[0]);
+        v_ct.push(intermediate_raw[0].clone());
     }
 
     let packed_ct = pack(
@@ -780,27 +862,24 @@ mod test {
         let params = get_params();
         let mut seeded_rng = get_seeded_rng();
 
-        let dim0 = 1 << params.db_dim_1;
-        let num_per = 1 << params.db_dim_2;
-        let scale_k = params.modulus / params.pt_modulus;
-
-        let target_idx = seeded_rng.gen::<usize>() % (dim0 * num_per);
-        let target_idx_dim0 = target_idx / num_per;
-        let target_idx_num_per = target_idx % num_per;
+        let target_idx = seeded_rng.gen::<usize>() % (params.db_dim_1 + params.db_dim_2);
 
         let mut client = Client::init(&params, &mut seeded_rng);
-        let public_parameters = client.generate_keys();
+
+        let public_params = client.generate_keys();
         let query = client.generate_query(target_idx);
 
         let (corr_item, db) = generate_random_db_and_get_item(&params, target_idx);
 
-        let mut v_reg = Vec::new();
-        for i in 0..dim0 {
-            let val = if i == target_idx_dim0 { scale_k } else { 0 };
-            let sigma = PolyMatrixRaw::single_value(&params, val).ntt();
-            v_reg.push(client.encrypt_matrix_reg(&sigma));
-        }
+        let response = process_query(&params, &public_params, &query, db.as_slice());
 
-        
+        let result = client.decode_response(response.as_slice());
+
+        let p_bits = log2_ceil(params.pt_modulus) as usize;
+        let corr_result = corr_item.to_vec(p_bits, params.poly_len);
+
+        for z in 0..corr_result.len() {
+            assert_eq!(result[z], corr_result[z]);
+        }
     }
 }

+ 1 - 1
spiral-rs/src/util.rs

@@ -65,7 +65,7 @@ pub fn get_expansion_testing_params() -> Params {
         't_exp': 8,
         't_exp_right': 56,
         'instances': 1,
-        'db_item_size': 256 }
+        'db_item_size': 8192 }
     "#;
     let cfg = cfg.replace("'", "\"");
     let b = params_from_json(&cfg);