Browse Source

Switch from crossbeam to rayon

Ian Goldberg 1 year ago
parent
commit
4bf304da4f
3 changed files with 9 additions and 9 deletions
  1. 1 1
      Cargo.toml
  2. 5 4
      src/main.rs
  3. 3 4
      src/spiral_mt.rs

+ 1 - 1
Cargo.toml

@@ -17,7 +17,7 @@ lazy_static = "1"
 sha2 = "0.9"
 subtle = { package = "subtle-ng", version = "2.4" }
 spiral-rs = { git = "https://github.com/menonsamir/spiral-rs/", rev = "0f9bdc157" }
-crossbeam = "0.8"
+rayon = "1.5"
 
 [features]
 default = ["u64_backend"]

+ 5 - 4
src/main.rs

@@ -27,7 +27,8 @@ use curve25519_dalek::ristretto::RistrettoBasepointTable;
 use curve25519_dalek::ristretto::RistrettoPoint;
 use curve25519_dalek::scalar::Scalar;
 
-use crossbeam::thread;
+use rayon::scope;
+use rayon::ThreadPoolBuilder;
 
 use spiral_rs::client::*;
 use spiral_rs::params::*;
@@ -75,7 +76,7 @@ fn encdb_xor_keys(
     let num_records: usize = 1 << r;
     let mut ret = Vec::<u8>::with_capacity(num_records * mem::size_of::<DbEntry>());
     ret.resize(num_records * mem::size_of::<DbEntry>(), 0);
-    thread::scope(|s| {
+    scope(|s| {
         let mut record_thread_start = 0usize;
         let records_per_thread_base = num_records / num_threads;
         let records_per_thread_extra = num_records % num_threads;
@@ -108,8 +109,7 @@ fn encdb_xor_keys(
             });
             record_thread_start = record_thread_end;
         }
-    })
-    .unwrap();
+    });
     ret
 }
 
@@ -263,6 +263,7 @@ fn main() {
     let otsetup_start = Instant::now();
     let spiral_params = params::get_spiral_params(r);
     let mut rng = rand::thread_rng();
+    ThreadPoolBuilder::new().num_threads(num_threads).build_global().unwrap();
     one_time_setup();
     let otsetup_us = otsetup_start.elapsed().as_micros();
     print_params_summary(&spiral_params);

+ 3 - 4
src/spiral_mt.rs

@@ -4,7 +4,7 @@ use spiral_rs::poly::*;
 use spiral_rs::server::*;
 use spiral_rs::util::*;
 
-use crossbeam::thread;
+use rayon::scope;
 
 use crate::aligned_memory_mt::*;
 
@@ -61,7 +61,7 @@ pub fn load_db_from_slice_mt(
 
     for instance in 0..instances {
         for trial in 0..trials {
-            thread::scope(|s| {
+            scope(|s| {
                 let mut item_thread_start = 0usize;
                 let items_per_thread_base = num_items / num_threads;
                 let items_per_thread_extra = num_items % num_threads;
@@ -112,8 +112,7 @@ pub fn load_db_from_slice_mt(
                     });
                     item_thread_start = item_thread_end;
                 }
-            })
-            .unwrap();
+            });
         }
     }
     v