Browse Source

add multithreading, fix bug

Samir Menon 2 years ago
parent
commit
a0b7ad0ce2
5 changed files with 125 additions and 45 deletions
  1. 5 5
      spiral-rs/Cargo.lock
  2. 4 3
      spiral-rs/Cargo.toml
  3. 36 8
      spiral-rs/src/client.rs
  4. 14 0
      spiral-rs/src/params.rs
  5. 66 29
      spiral-rs/src/server.rs

+ 5 - 5
spiral-rs/Cargo.lock

@@ -1468,9 +1468,9 @@ dependencies = [
 
 [[package]]
 name = "rayon"
-version = "1.5.1"
+version = "1.5.2"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "c06aca804d41dbc8ba42dfd964f0d01334eceb64314b9ecf7c5fad5188a06d90"
+checksum = "fd249e82c21598a9a426a4e00dd7adc1d640b22445ec8545feef801d1a74c221"
 dependencies = [
  "autocfg",
  "crossbeam-deque",
@@ -1480,14 +1480,13 @@ dependencies = [
 
 [[package]]
 name = "rayon-core"
-version = "1.9.1"
+version = "1.9.2"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "d78120e2c850279833f1dd3582f730c4ab53ed95aeaaaa862a2a5c71b1656d8e"
+checksum = "9f51245e1e62e1f1629cbfec37b5793bbabcaeb90f30e94d2ba03564687353e4"
 dependencies = [
  "crossbeam-channel",
  "crossbeam-deque",
  "crossbeam-utils",
- "lazy_static",
  "num_cpus",
 ]
 
@@ -1755,6 +1754,7 @@ dependencies = [
  "openssl",
  "pprof",
  "rand",
+ "rayon",
  "reqwest",
  "serde_json",
  "uuid 1.0.0",

+ 4 - 3
spiral-rs/Cargo.toml

@@ -15,6 +15,7 @@ getrandom = { features = ["js"], version = "0.2.6" }
 rand = { version = "0.8.5", features = ["small_rng"] }
 reqwest = { version = "0.11", features = ["blocking"] }
 serde_json = "1.0"
+rayon = "1.5.2"
 actix-web = { version = "4.0.1", features = ["openssl"], optional = true }
 openssl = { version = "0.10", features = ["v110"], optional = true }
 futures = { version = "0.3", optional = true }
@@ -37,6 +38,6 @@ name = "poly"
 harness = false
 
 [profile.release]
-lto = "fat"
-codegen-units = 1
-panic = "abort"
+# lto = "fat"
+# codegen-units = 1
+# panic = "abort"

+ 36 - 8
spiral-rs/src/client.rs

@@ -159,7 +159,7 @@ impl<'a> Query<'a> {
         }
     }
 
-    pub fn serialize(&self) -> Vec<u8> {
+    pub fn serialize(&self) -> Vec<u8> {        
         let mut data = Vec::new();
         if self.ct.is_some() {
             let ct = self.ct.as_ref().unwrap();
@@ -179,6 +179,8 @@ impl<'a> Query<'a> {
     }
 
     pub fn deserialize(params: &'a Params, data: &[u8]) -> Self {
+        assert_eq!(params.query_bytes(), data.len());
+
         let mut out = Query::empty();
         if params.expand_queries {
             let mut ct = PolyMatrixRaw::zero(params, 2, 1);
@@ -550,13 +552,8 @@ impl<'a, TRng: Rng> Client<'a, TRng> {
             }
         }
 
-        // println!("{:?}", result.data);
-        let trials = params.n * params.n;
-        let chunks = params.instances * trials;
-        let bytes_per_chunk = f64::ceil(params.db_item_size as f64 / chunks as f64) as usize;
-        let logp = log2(params.pt_modulus);
-        let modp_words_per_chunk = f64::ceil((bytes_per_chunk * 8) as f64 / logp as f64) as usize;
-        result.to_vec(p_bits as usize, modp_words_per_chunk)
+        // println!("{:?}", result.data.as_slice().to_vec());
+        result.to_vec(p_bits as usize, params.modp_words_per_chunk())
     }
 }
 
@@ -582,7 +579,9 @@ mod test {
         let client = Client::init(&params, &mut rng);
 
         assert_eq!(client.stop_round, 5);
+        assert_eq!(client.stop_round, params.stop_round());
         assert_eq!(client.g, 10);
+        assert_eq!(client.g, params.g());
         assert_eq!(*client.params, params);
     }
 
@@ -617,6 +616,10 @@ mod test {
         );
     }
 
+    fn get_vec(v: &Vec<PolyMatrixNTT>) -> Vec<u64> {
+        v.iter().map(|d| d.as_slice().to_vec()).flatten().collect()
+    }
+
     fn public_parameters_serialization_is_correct_for_params(params: Params) {
         let mut seeded_rng = get_static_seeded_rng();
         let mut client = Client::init(&params, &mut seeded_rng);
@@ -628,6 +631,10 @@ mod test {
         let serialized2 = deserialized1.serialize();
 
         assert_eq!(serialized1, serialized2);
+        assert_eq!(get_vec(&pub_params.v_packing), get_vec(&deserialized1.v_packing));
+        assert_eq!(get_vec(&pub_params.v_conversion.unwrap()), get_vec(&deserialized1.v_conversion.unwrap()));
+        assert_eq!(get_vec(&pub_params.v_expansion_left.unwrap()), get_vec(&deserialized1.v_expansion_left.unwrap()));
+        assert_eq!(get_vec(&pub_params.v_expansion_right.unwrap()), get_vec(&deserialized1.v_expansion_right.unwrap()));
     }
 
     #[test]
@@ -635,6 +642,27 @@ mod test {
         public_parameters_serialization_is_correct_for_params(get_params())
     }
 
+    #[test]
+    fn real_public_parameters_serialization_is_correct() {
+        let cfg_expand = r#"
+            {'n': 2,
+            'nu_1': 10,
+            'nu_2': 6,
+            'p': 512,
+            'q2_bits': 21,
+            's_e': 85.83255142749422,
+            't_gsw': 10,
+            't_conv': 4,
+            't_exp_left': 16,
+            't_exp_right': 56,
+            'instances': 11,
+            'db_item_size': 100000 }
+        "#;
+        let cfg = cfg_expand.replace("'", "\"");
+        let params = params_from_json(&cfg);
+        public_parameters_serialization_is_correct_for_params(params)
+    }
+
     #[test]
     fn no_expansion_public_parameters_serialization_is_correct() {
         public_parameters_serialization_is_correct_for_params(get_no_expansion_testing_params())

+ 14 - 0
spiral-rs/src/params.rs

@@ -160,6 +160,20 @@ impl Params {
         self.num_expanded() * 2 * self.poly_len * size_of::<u64>()
     }
 
+    pub fn bytes_per_chunk(&self) -> usize {
+        let trials = self.n * self.n;
+        let chunks = self.instances * trials;
+        let bytes_per_chunk = f64::ceil(self.db_item_size as f64 / chunks as f64) as usize;
+        bytes_per_chunk
+    }
+
+    pub fn modp_words_per_chunk(&self) -> usize {
+        let bytes_per_chunk = self.bytes_per_chunk();
+        let logp = log2(self.pt_modulus);
+        let modp_words_per_chunk = f64::ceil((bytes_per_chunk * 8) as f64 / logp as f64) as usize;
+        modp_words_per_chunk
+    }
+
     pub fn crt_compose_1(&self, x: u64) -> u64 {
         assert_eq!(self.crt_count, 1);
         x

+ 66 - 29
spiral-rs/src/server.rs

@@ -5,6 +5,7 @@ use std::io::BufReader;
 use std::io::Read;
 use std::io::Seek;
 use std::io::SeekFrom;
+use std::time::Instant;
 
 use crate::aligned_memory::*;
 use crate::arith::*;
@@ -15,6 +16,8 @@ use crate::params::*;
 use crate::poly::*;
 use crate::util::*;
 
+use rayon::prelude::*;
+
 pub fn coefficient_expansion(
     v: &mut Vec<PolyMatrixNTT>,
     g: usize,
@@ -46,8 +49,8 @@ pub fn coefficient_expansion(
         let neg1 = &v_neg1[r];
 
         for i in 0..num_out {
-            if stop_round > 0 && i % 2 == 1 && r > stop_round
-                || (r == stop_round && i / 2 >= max_bits_to_gen_right)
+            if (stop_round > 0 && r > stop_round && (i % 2) == 1)
+                || (stop_round > 0 && r == stop_round && (i % 2) == 1 && (i / 2) >= max_bits_to_gen_right)
             {
                 continue;
             }
@@ -454,6 +457,22 @@ pub fn load_db_from_file(params: &Params, file: &mut File) -> AlignedMemory64 {
     v
 }
 
+pub fn load_file_unsafe(data: &mut[u64], file: &mut File) {
+    let data_as_u8_mut = unsafe {
+        data.align_to_mut::<u8>().1
+    };
+    file.read_exact(data_as_u8_mut).unwrap();
+}
+
+pub fn load_file(data: &mut[u64], file: &mut File) {
+    let mut reader = BufReader::with_capacity(1 << 24, file);
+    let mut buf = [0u8; 8];
+    for i in 0..data.len() {
+        reader.read(&mut buf).unwrap();
+        data[i] = u64::from_ne_bytes(buf);
+    }
+}
+
 pub fn load_preprocessed_db_from_file(params: &Params, file: &mut File) -> AlignedMemory64 {
     let instances = params.instances;
     let trials = params.n * params.n;
@@ -463,16 +482,10 @@ pub fn load_preprocessed_db_from_file(params: &Params, file: &mut File) -> Align
     let db_size_words = instances * trials * num_items * params.poly_len;
     let mut v = AlignedMemory64::new(db_size_words);
     let v_mut_slice = v.as_mut_slice();
-
-    let mut reader = BufReader::with_capacity(1 << 18, file);
-    let mut buf = [0u8; 8];
-    for i in 0..db_size_words {
-        if i % 1000000000 == 0 {
-            println!("{} GB loaded", i / 1000000000);
-        }
-        reader.read(&mut buf).unwrap();
-        v_mut_slice[i] = u64::from_ne_bytes(buf);
-    }
+    
+    let now = Instant::now();
+    load_file(v_mut_slice, file);
+    println!("Done loading ({} ms).", now.elapsed().as_millis());
 
     v
 }
@@ -694,23 +707,20 @@ pub fn process_query(
             .v_ct
             .as_ref()
             .unwrap()
-            .clone()
             .iter()
             .map(|x| x.ntt())
             .collect();
     }
     let v_folding_neg = get_v_folding_neg(params, &v_folding);
 
-    let mut intermediate = Vec::with_capacity(num_per);
-    let mut intermediate_raw = Vec::with_capacity(num_per);
-    for _ in 0..num_per {
-        intermediate.push(PolyMatrixNTT::zero(params, 2, 1));
-        intermediate_raw.push(PolyMatrixRaw::zero(params, 2, 1));
-    }
-
-    let mut v_packed_ct = Vec::new();
-
-    for instance in 0..params.instances {
+    let v_packed_ct = (0..params.instances).into_par_iter().map(|instance| {
+        let mut intermediate = Vec::with_capacity(num_per);
+        let mut intermediate_raw = Vec::with_capacity(num_per);
+        for _ in 0..num_per {
+            intermediate.push(PolyMatrixNTT::zero(params, 2, 1));
+            intermediate_raw.push(PolyMatrixRaw::zero(params, 2, 1));
+        }
+    
         let mut v_ct = Vec::new();
 
         for trial in 0..(params.n * params.n) {
@@ -737,8 +747,8 @@ pub fn process_query(
 
         let packed_ct = pack(params, &v_ct, &v_packing);
 
-        v_packed_ct.push(packed_ct.raw());
-    }
+        packed_ct.raw()
+    }).collect();
 
     encode(params, &v_packed_ct)
 }
@@ -997,7 +1007,7 @@ mod test {
     fn full_protocol_is_correct_for_params(params: &Params) {
         let mut seeded_rng = get_seeded_rng();
 
-        let target_idx = seeded_rng.gen::<usize>() % (params.db_dim_1 + params.db_dim_2);
+        let target_idx = 22456;//22456;//seeded_rng.gen::<usize>() % (params.db_dim_1 + params.db_dim_2);
 
         let mut client = Client::init(params, &mut seeded_rng);
 
@@ -1011,17 +1021,19 @@ mod test {
         let result = client.decode_response(response.as_slice());
 
         let p_bits = log2_ceil(params.pt_modulus) as usize;
-        let corr_result = corr_item.to_vec(p_bits, params.poly_len);
+        let corr_result = corr_item.to_vec(p_bits, params.modp_words_per_chunk());
+
+        assert_eq!(result.len(), corr_result.len());
 
         for z in 0..corr_result.len() {
-            assert_eq!(result[z], corr_result[z]);
+            assert_eq!(result[z], corr_result[z], "at {:?}", z);
         }
     }
 
     fn full_protocol_is_correct_for_params_real_db(params: &Params) {
         let mut seeded_rng = get_seeded_rng();
 
-        let target_idx = seeded_rng.gen::<usize>() % (params.db_dim_1 + params.db_dim_2);
+        let target_idx = 22456; //seeded_rng.gen::<usize>() % (params.db_dim_1 + params.db_dim_2);
 
         let mut client = Client::init(params, &mut seeded_rng);
 
@@ -1048,6 +1060,31 @@ mod test {
         full_protocol_is_correct_for_params(&get_params());
     }
 
+    #[test]
+    fn larger_full_protocol_is_correct() {
+        let cfg_expand = r#"
+            {
+            'n': 2,
+            'nu_1': 10,
+            'nu_2': 6,
+            'p': 512,
+            'q2_bits': 21,
+            's_e': 85.83255142749422,
+            't_gsw': 10,
+            't_conv': 4,
+            't_exp_left': 16,
+            't_exp_right': 56,
+            'instances': 1,
+            'db_item_size': 9000 }
+        "#;
+        let cfg = cfg_expand;
+        let cfg = cfg.replace("'", "\"");
+        let params = params_from_json(&cfg);
+
+        full_protocol_is_correct_for_params(&params);
+        full_protocol_is_correct_for_params_real_db(&params);
+    }
+
     // #[test]
     // fn full_protocol_is_correct_20_256() {
     //     full_protocol_is_correct_for_params(&params_from_json(&CFG_20_256.replace("'", "\"")));