Browse Source

improve benchmarking + tests

Samir Menon 2 years ago
parent
commit
e5441a23ab
5 changed files with 159 additions and 81 deletions
  1. 1 1
      spiral-rs/Cargo.toml
  2. 46 17
      spiral-rs/benches/server.rs
  3. 2 3
      spiral-rs/src/client.rs
  4. 76 55
      spiral-rs/src/server.rs
  5. 34 5
      spiral-rs/src/util.rs

+ 1 - 1
spiral-rs/Cargo.toml

@@ -5,7 +5,7 @@ edition = "2021"
 
 [dependencies]
 getrandom = { features = ["js"], version = "0.2.6" }
-rand = { version = "0.8.5" }
+rand = { version = "0.8.5", features = ["small_rng"] }
 reqwest = { version = "0.11", features = ["blocking"] }
 serde_json = "1.0"
 

+ 46 - 17
spiral-rs/benches/server.rs

@@ -6,32 +6,61 @@ use pprof::criterion::{Output, PProfProfiler};
 use rand::Rng;
 use spiral_rs::aligned_memory::AlignedMemory64;
 use spiral_rs::client::*;
+use spiral_rs::params::*;
 use spiral_rs::poly::*;
 use spiral_rs::server::*;
 use spiral_rs::util::*;
 use std::time::Duration;
 
+pub fn generate_random_incorrect_db(params: &Params) -> AlignedMemory64 {
+    let instances = params.instances;
+    let trials = params.n * params.n;
+    let dim0 = 1 << params.db_dim_1;
+    let num_per = 1 << params.db_dim_2;
+    let num_items = dim0 * num_per;
+    let db_size_words = instances * trials * num_items * params.poly_len;
+    let mut out = AlignedMemory64::new(db_size_words);
+    let out_mut_slice = out.as_mut_slice();
+    let mut rng = get_seeded_rng();
+    for i in 0..out_mut_slice.len() {
+        out_mut_slice[i] = rng.gen();
+    }
+    out
+}
+
 fn test_full_processing(group: &mut BenchmarkGroup<WallTime>) {
-    let params = get_expansion_testing_params();
-    let mut seeded_rng = get_seeded_rng();
+    // let names = ["server_processing_20_256", "server_processing_16_100000"];
+    // let cfgs = [CFG_20_256, CFG_16_100000];
+    let names = ["server_processing_16_100000"];
+    let cfgs = [CFG_16_100000];
 
-    let target_idx = seeded_rng.gen::<usize>() % (params.db_dim_1 + params.db_dim_2);
-    
-    let mut client = Client::init(&params, &mut seeded_rng);
-    let public_params = client.generate_keys();
-    let query = client.generate_query(target_idx);
-    let (_, db) = generate_random_db_and_get_item(&params, target_idx);
+    for i in 0..names.len() {
+        let name = names[i];
+        let cfg = cfgs[i];
+        let params = params_from_json(&cfg.replace("'", "\""));
+        let mut seeded_rng = get_seeded_rng();
 
-    group.bench_function("server_processing", |b| {
-        b.iter(|| {
-            black_box(process_query(
-                black_box(&params),
-                black_box(&public_params),
-                black_box(&query),
-                black_box(db.as_slice()),
-            ));
+        let target_idx = seeded_rng.gen::<usize>() % (params.db_dim_1 + params.db_dim_2);
+        
+        let mut client = Client::init(&params, &mut seeded_rng);
+        let public_params = client.generate_keys();
+        let query = client.generate_query(target_idx);
+
+        println!("Generating database...");
+        let db = generate_random_incorrect_db(&params);
+        println!("Done generating database.");
+
+        group.bench_function(name, |b| {
+            b.iter(|| {
+                black_box(process_query(
+                    black_box(&params),
+                    black_box(&public_params),
+                    black_box(&query),
+                    black_box(db.as_slice()),
+                ));
+            });
         });
-    });
+    }
 }
 
 fn criterion_benchmark(c: &mut Criterion) {

+ 2 - 3
spiral-rs/src/client.rs

@@ -507,14 +507,13 @@ mod test {
         assert_first8(
             public_params.v_conversion.unwrap()[0].data.as_slice(),
             [
-                253586619, 247235120, 141892996, 163163429, 15531298, 200914775, 125109567,
-                75889562,
+                122680182, 165987256, 137892309, 95732358, 221787731, 13233184, 156136764, 259944211,
             ],
         );
 
         assert_first8(
             client.sk_gsw.data.as_slice(),
-            [1, 5, 0, 3, 1, 3, 66974689739603967, 3],
+            [66974689739603965, 66974689739603965, 0, 1, 0, 5, 66974689739603967, 2],
         );
     }
 }

+ 76 - 55
spiral-rs/src/server.rs

@@ -1,9 +1,6 @@
 #[cfg(target_feature = "avx2")]
 use std::arch::x86_64::*;
 
-#[cfg(target_feature = "avx2")]
-use crate::aligned_memory::*;
-
 use crate::arith::*;
 use crate::aligned_memory::*;
 use crate::client::PublicParameters;
@@ -315,40 +312,46 @@ pub fn generate_random_db_and_get_item<'a>(
 ) -> (PolyMatrixRaw<'a>, AlignedMemory64) {
     let mut rng = get_seeded_rng();
 
+    let instances = params.instances;
     let trials = params.n * params.n;
     let dim0 = 1 << params.db_dim_1;
     let num_per = 1 << params.db_dim_2;
     let num_items = dim0 * num_per;
-    let db_size_words = trials * num_items * params.poly_len;
+    let db_size_words = instances * trials * num_items * params.poly_len;
     let mut v = AlignedMemory64::new(db_size_words);
 
+    let mut tmp_item_ntt = PolyMatrixNTT::zero(params, 1, 1);
     let mut item = PolyMatrixRaw::zero(params, params.n, params.n);
 
-    for trial in 0..trials {
-        for i in 0..num_items {
-            let ii = i % num_per;
-            let j = i / num_per;
-
-            let mut db_item = PolyMatrixRaw::random_rng(params, 1, 1, &mut rng);
-            db_item.reduce_mod(params.pt_modulus);
-            
-            if i == item_idx {
-                item.copy_into(&db_item, trial / params.n, trial % params.n);
-            }
+    for instance in 0..instances {
+        println!("Instance {:?}", instance);
+        for trial in 0..trials {
+            println!("Trial {:?}", trial);
+            for i in 0..num_items {
+                let ii = i % num_per;
+                let j = i / num_per;
+
+                let mut db_item = PolyMatrixRaw::random_rng(params, 1, 1, &mut rng);
+                db_item.reduce_mod(params.pt_modulus);
+                
+                if i == item_idx && instance == 0 {
+                    item.copy_into(&db_item, trial / params.n, trial % params.n);
+                }
 
-            for z in 0..params.poly_len {
-                db_item.data[z] = recenter_mod(db_item.data[z], params.pt_modulus, params.modulus);
-            }
+                for z in 0..params.poly_len {
+                    db_item.data[z] = recenter_mod(db_item.data[z], params.pt_modulus, params.modulus);
+                }
 
-            let db_item_ntt = db_item.ntt();
-            for z in 0..params.poly_len {
-                let idx_dst = calc_index(
-                    &[trial, z, ii, j],
-                    &[trials, params.poly_len, num_per, dim0],
-                );
+                let db_item_ntt = db_item.ntt();
+                for z in 0..params.poly_len {
+                    let idx_dst = calc_index(
+                        &[instance, trial, z, ii, j],
+                        &[instances, trials, params.poly_len, num_per, dim0],
+                    );
 
-                v[idx_dst] = db_item_ntt.data[z]
-                    | (db_item_ntt.data[params.poly_len + z] << PACKED_OFFSET_2);
+                    v[idx_dst] = db_item_ntt.data[z]
+                        | (db_item_ntt.data[params.poly_len + z] << PACKED_OFFSET_2);
+                }
             }
         }
     }
@@ -577,35 +580,40 @@ pub fn process_query(
         intermediate_raw.push(PolyMatrixRaw::zero(params, 2, 1));
     }
 
-    let mut v_ct = Vec::new();
-    for trial in 0..(params.n * params.n) {
-        let cur_db = &db[(db_slice_sz * trial)..(db_slice_sz * trial + db_slice_sz)];
+    let mut v_packed_ct = Vec::new();
+
+    for instance in 0..params.instances {
+        let mut v_ct = Vec::new();
 
-        multiply_reg_by_database(&mut intermediate, cur_db, v_reg_reoriented.as_slice(), params, dim0, num_per);
+        for trial in 0..(params.n * params.n) {
+            let idx = (instance * (params.n * params.n) + trial) * db_slice_sz;
+            let cur_db = &db[idx..(idx + db_slice_sz)];
 
-        for i in 0..intermediate.len() {
-            from_ntt(&mut intermediate_raw[i], &intermediate[i]);
+            multiply_reg_by_database(&mut intermediate, cur_db, v_reg_reoriented.as_slice(), params, dim0, num_per);
+
+            for i in 0..intermediate.len() {
+                from_ntt(&mut intermediate_raw[i], &intermediate[i]);
+            }
+
+            fold_ciphertexts(
+                params,
+                &mut intermediate_raw,
+                &v_folding,
+                &v_folding_neg
+            );
+
+            v_ct.push(intermediate_raw[0].clone());
         }
 
-        fold_ciphertexts(
+        let packed_ct = pack(
             params,
-            &mut intermediate_raw,
-            &v_folding,
-            &v_folding_neg
+            &v_ct,
+            &v_packing,
         );
 
-        v_ct.push(intermediate_raw[0].clone());
+        v_packed_ct.push(packed_ct.raw());
     }
 
-    let packed_ct = pack(
-        params,
-        &v_ct,
-        &v_packing,
-    );
-
-    let mut v_packed_ct = Vec::new();
-    v_packed_ct.push(packed_ct.raw());
-
     encode(params, &v_packed_ct)
 }
 
@@ -613,7 +621,7 @@ pub fn process_query(
 mod test {
     use super::*;
     use crate::{client::*};
-    use rand::{prelude::StdRng, Rng};
+    use rand::{prelude::SmallRng, Rng};
 
     fn get_params() -> Params {
         let mut params = get_expansion_testing_params();
@@ -626,7 +634,7 @@ mod test {
     fn dec_reg<'a>(
         params: &'a Params,
         ct: &PolyMatrixNTT<'a>,
-        client: &mut Client<'a, StdRng>,
+        client: &mut Client<'a, SmallRng>,
         scale_k: u64,
     ) -> u64 {
         let dec = client.decrypt_matrix_reg(ct).raw();
@@ -645,7 +653,7 @@ mod test {
     fn dec_gsw<'a>(
         params: &'a Params,
         ct: &PolyMatrixNTT<'a>,
-        client: &mut Client<'a, StdRng>,
+        client: &mut Client<'a, SmallRng>,
     ) -> u64 {
         let dec = client.decrypt_matrix_reg(ct).raw();
         let idx = 2 * (params.t_gsw - 1) * params.poly_len + params.poly_len; // this offset should encode a large value
@@ -857,21 +865,19 @@ mod test {
         assert_eq!(dec_reg(&params, &v_reg_raw[0].ntt(), &mut client, scale_k), 1);
     }
 
-    #[test]
-    fn full_protocol_is_correct() {
-        let params = get_params();
+    fn full_protocol_is_correct_for_params(params: &Params) {
         let mut seeded_rng = get_seeded_rng();
 
         let target_idx = seeded_rng.gen::<usize>() % (params.db_dim_1 + params.db_dim_2);
 
-        let mut client = Client::init(&params, &mut seeded_rng);
+        let mut client = Client::init(params, &mut seeded_rng);
 
         let public_params = client.generate_keys();
         let query = client.generate_query(target_idx);
 
-        let (corr_item, db) = generate_random_db_and_get_item(&params, target_idx);
+        let (corr_item, db) = generate_random_db_and_get_item(params, target_idx);
 
-        let response = process_query(&params, &public_params, &query, db.as_slice());
+        let response = process_query(params, &public_params, &query, db.as_slice());
 
         let result = client.decode_response(response.as_slice());
 
@@ -882,4 +888,19 @@ mod test {
             assert_eq!(result[z], corr_result[z]);
         }
     }
+    
+    #[test]
+    fn full_protocol_is_correct() {
+        full_protocol_is_correct_for_params(&get_params());
+    }
+
+    // #[test]
+    // fn full_protocol_is_correct_20_256() {
+    //     full_protocol_is_correct_for_params(&params_from_json(&CFG_20_256.replace("'", "\"")));
+    // }
+
+    // #[test]
+    // fn full_protocol_is_correct_16_100000() {
+    //     full_protocol_is_correct_for_params(&params_from_json(&CFG_16_100000.replace("'", "\"")));
+    // }
 }

+ 34 - 5
spiral-rs/src/util.rs

@@ -1,7 +1,36 @@
 use crate::{arith::*, params::*, poly::*};
-use rand::{prelude::StdRng, SeedableRng, thread_rng, Rng};
+use rand::{prelude::{SmallRng}, SeedableRng, thread_rng, Rng};
 use serde_json::Value;
 
+pub const CFG_20_256: &'static str = r#"
+        {'n': 2,
+        'nu_1': 9,
+        'nu_2': 6,
+        'p': 256,
+        'q_prime_bits': 20,
+        's_e': 87.62938774292914,
+        't_GSW': 8,
+        't_conv': 4,
+        't_exp': 8,
+        't_exp_right': 56,
+        'instances': 1,
+        'db_item_size': 8192 }
+    "#;
+pub const CFG_16_100000: &'static str = r#"
+        {'n': 2,
+        'nu_1': 10,
+        'nu_2': 6,
+        'p': 512,
+        'q_prime_bits': 21,
+        's_e': 85.83255142749422,
+        't_GSW': 10,
+        't_conv': 4,
+        't_exp': 16,
+        't_exp_right': 56,
+        'instances': 11,
+        'db_item_size': 100000 }
+    "#;
+
 pub fn calc_index(indices: &[usize], lengths: &[usize]) -> usize {
     let mut idx = 0usize;
     let mut prod = 1usize;
@@ -76,8 +105,8 @@ pub fn get_seed() -> [u8; 32] {
     thread_rng().gen::<[u8; 32]>()
 }
 
-pub fn get_seeded_rng() -> StdRng {
-    StdRng::from_seed(get_seed())
+pub fn get_seeded_rng() -> SmallRng {
+    SmallRng::from_seed(get_seed())
 }
 
 pub fn get_static_seed() -> [u8; 32] {
@@ -87,8 +116,8 @@ pub fn get_static_seed() -> [u8; 32] {
     ]
 }
 
-pub fn get_static_seeded_rng() -> StdRng {
-    StdRng::from_seed(get_static_seed())
+pub fn get_static_seeded_rng() -> SmallRng {
+    SmallRng::from_seed(get_static_seed())
 }
 
 pub const fn get_empty_params() -> Params {