2 Commits 790e70ca06 ... d6d546d056

Author SHA1 Message Date
  Ian Goldberg d6d546d056 Use interior mutability to allow multithreaded use of Client 1 year ago
  Ian Goldberg 790e70ca06 Use interior mutability to allow multithreaded use of Client 1 year ago
2 changed files with 8 additions and 8 deletions
  1. 4 4
      spiral-rs/src/client.rs
  2. 4 4
      spiral-rs/src/discrete_gaussian.rs

+ 4 - 4
spiral-rs/src/client.rs

@@ -4,7 +4,7 @@ use crate::{
 use rand::{Rng, SeedableRng};
 use rand_chacha::ChaCha20Rng;
 use std::{iter::once, mem::size_of};
-use std::sync::Mutex;
+use std::cell::RefCell;
 use thread_local::ThreadLocal;
 
 fn new_vec_raw<'a>(
@@ -212,7 +212,7 @@ pub struct Client<'a, T: Rng + Send> {
     sk_gsw_full: PolyMatrixRaw<'a>,
     sk_reg_full: PolyMatrixRaw<'a>,
     dg: DiscreteGaussian<T>,
-    public_rng: ThreadLocal<Mutex<ChaCha20Rng>>,
+    public_rng: ThreadLocal<RefCell<ChaCha20Rng>>,
     public_seed: ThreadLocal<<ChaCha20Rng as SeedableRng>::Seed>,
 }
 
@@ -300,9 +300,9 @@ impl<'a, T: Rng + Send> Client<'a, T> {
         let params = self.params;
         let a = {
             let public_rng = &mut *self.public_rng.get_or(|| {
-                Mutex::new(ChaCha20Rng::from_seed(self.get_public_seed()))
+                RefCell::new(ChaCha20Rng::from_seed(self.get_public_seed()))
             })
-            .lock().unwrap();
+            .borrow_mut();
             PolyMatrixRaw::random_rng(params, 1, 1, public_rng)
         };
         let e = PolyMatrixRaw::noise(params, 1, 1, &self.dg);

+ 4 - 4
spiral-rs/src/discrete_gaussian.rs

@@ -2,7 +2,7 @@ use rand::distributions::WeightedIndex;
 use rand::prelude::Distribution;
 use rand::Rng;
 
-use std::sync::{Mutex, MutexGuard};
+use std::cell::{RefCell, RefMut};
 
 use thread_local::ThreadLocal;
 
@@ -15,7 +15,7 @@ pub const NUM_WIDTHS: usize = 8;
 pub struct DiscreteGaussian<T: Rng + Send> {
     choices: Vec<i64>,
     dist: WeightedIndex<f64>,
-    rng: ThreadLocal<Mutex<T>>,
+    rng: ThreadLocal<RefCell<T>>,
     rnggen: fn() -> T,
 }
 
@@ -40,8 +40,8 @@ impl<T: Rng + Send> DiscreteGaussian<T> {
         choices[dist.sample(rng)]
     }
 
-    pub fn get_rng(&self) -> MutexGuard<T> {
-        self.rng.get_or(|| Mutex::new((self.rnggen)())).lock().unwrap()
+    pub fn get_rng(&self) -> RefMut<T> {
+        self.rng.get_or(|| RefCell::new((self.rnggen)())).borrow_mut()
     }
 
     #[cfg(test)]