Browse Source

Use interior mutability to allow multithreaded use of Client

It would be nice to be able to call client.generate_query from
multiple threads in parallel.  It currently wants self to be
mutable, however, so that can't happen.

The only reason it needs mutability of self, though, is for
the embedded random number generators.

So the strategy is this: make client itself immutable, but wrap
the two rngs (the one in DiscreteGaussian and the one in Client itself)
with Mutex.  Then when we need the rng, we get a lock to the mutable rng
from the immutable Mutex using Mutex's interior mutability.

But this would still cause contention if lots of threads were trying to
use the same rng at the same time, so we put each rng inside a ThreadLocal
so that there's actually one instance of each mutex-wrapped rng per
thread.  The locks will therefore never be in contention, and in fact
with ThreadLocal, we can use a simpler RefCell instead of a Mutex.

(Note that the ThreadLocal would not be necessary if we insisted
that the rng passed to client were ThreadRng, but that would break
the existing tests that pass an expicitly seeded rng to check
a deterministic output.)

This requires a change to the Client::init API: instead of passing
a reference to a static rng, you pass a factory function that
_outputs_ a new rng.  This function will be called once per thread.
Also the type of the rng has to be Rng + Send instead of just Rng,
since we'll be putting it in a ThreadLocal.  This means the factory
function can't output a ThreadRng (which wouldn't make sense anyway),
but should be something like ChaCha20Rng::from_entropy.

With this change, almost all of Client's methods (including
generate_query and its callees) can take an immutable &self, and
be able to safely run from multiple threads in parallel.
Ian Goldberg 1 year ago
parent
commit
d6d546d056

+ 10 - 0
spiral-rs/Cargo.lock

@@ -1750,6 +1750,7 @@ dependencies = [
  "reqwest",
  "serde",
  "serde_json",
+ "thread_local",
  "uuid 1.0.0",
 ]
 
@@ -1842,6 +1843,15 @@ dependencies = [
  "syn",
 ]
 
+[[package]]
+name = "thread_local"
+version = "1.1.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "5516c27b78311c50bf42c071425c560ac799b11c30b31f87e3081965fe5e0180"
+dependencies = [
+ "once_cell",
+]
+
 [[package]]
 name = "time"
 version = "0.3.9"

+ 1 - 0
spiral-rs/Cargo.toml

@@ -25,6 +25,7 @@ rand = { version = "0.8.5", features = ["small_rng"] }
 serde_json = "1.0"
 rayon = "1.5.2"
 rand_chacha = "0.3.1"
+thread_local = "1.1"
 
 reqwest = { version = "0.11", features = ["blocking"], optional = true }
 

+ 4 - 2
spiral-rs/src/bin/e2e.rs

@@ -1,5 +1,7 @@
 use rand::Rng;
+use rand::SeedableRng;
 use rand::thread_rng;
+use rand_chacha::ChaCha20Rng;
 use spiral_rs::client::*;
 use spiral_rs::server::*;
 use spiral_rs::util::*;
@@ -38,7 +40,7 @@ fn main() {
 
     println!("fetching index {} out of {} items", idx_target, params.num_items());
     println!("initializing client");
-    let mut client = Client::init(&params, &mut rng);
+    let mut client = Client::init(&params, ChaCha20Rng::from_entropy);
     println!("generating public parameters");
     let pub_params = client.generate_keys();
     let pub_params_buf = pub_params.serialize();
@@ -68,4 +70,4 @@ fn main() {
     }
 
     println!("completed correctly!");
-}
+}

+ 46 - 42
spiral-rs/src/client.rs

@@ -4,6 +4,8 @@ use crate::{
 use rand::{Rng, SeedableRng};
 use rand_chacha::ChaCha20Rng;
 use std::{iter::once, mem::size_of};
+use std::cell::RefCell;
+use thread_local::ThreadLocal;
 
 fn new_vec_raw<'a>(
     params: &'a Params,
@@ -203,15 +205,15 @@ impl<'a> Query<'a> {
     }
 }
 
-pub struct Client<'a, T: Rng> {
+pub struct Client<'a, T: Rng + Send> {
     params: &'a Params,
     sk_gsw: PolyMatrixRaw<'a>,
     sk_reg: PolyMatrixRaw<'a>,
     sk_gsw_full: PolyMatrixRaw<'a>,
     sk_reg_full: PolyMatrixRaw<'a>,
-    dg: DiscreteGaussian<'a, T>,
-    public_rng: ChaCha20Rng,
-    public_seed: <ChaCha20Rng as SeedableRng>::Seed,
+    dg: DiscreteGaussian<T>,
+    public_rng: ThreadLocal<RefCell<ChaCha20Rng>>,
+    public_seed: ThreadLocal<<ChaCha20Rng as SeedableRng>::Seed>,
 }
 
 fn matrix_with_identity<'a>(p: &PolyMatrixRaw<'a>) -> PolyMatrixRaw<'a> {
@@ -242,18 +244,17 @@ fn params_with_moduli(params: &Params, moduli: &Vec<u64>) -> Params {
     )
 }
 
-impl<'a, T: Rng> Client<'a, T> {
-    pub fn init(params: &'a Params, rng: &'a mut T) -> Self {
+impl<'a, T: Rng + Send> Client<'a, T> {
+    pub fn init(params: &'a Params, rnggen: fn() -> T) -> Self {
         let sk_gsw_dims = params.get_sk_gsw();
         let sk_reg_dims = params.get_sk_reg();
         let sk_gsw = PolyMatrixRaw::zero(params, sk_gsw_dims.0, sk_gsw_dims.1);
         let sk_reg = PolyMatrixRaw::zero(params, sk_reg_dims.0, sk_reg_dims.1);
         let sk_gsw_full = matrix_with_identity(&sk_gsw);
         let sk_reg_full = matrix_with_identity(&sk_reg);
-        let mut public_seed = [0u8; 32];
-        rng.fill_bytes(&mut public_seed);
-        let public_rng = ChaCha20Rng::from_seed(public_seed);
-        let dg = DiscreteGaussian::init(params, rng);
+        let public_seed = ThreadLocal::new();
+        let dg = DiscreteGaussian::init(params, rnggen);
+        let public_rng = ThreadLocal::new();
         Self {
             params,
             sk_gsw,
@@ -271,24 +272,23 @@ impl<'a, T: Rng> Client<'a, T> {
         &self.sk_reg
     }
 
-    pub fn get_rng(&mut self) -> &mut T {
-        &mut self.dg.rng
+    pub fn get_public_seed(&self) -> <ChaCha20Rng as SeedableRng>::Seed {
+        *self.public_seed.get_or( || {
+            let mut seed = [0u8; 32];
+            self.dg.get_rng().fill_bytes(&mut seed);
+            seed
+        })
     }
 
-    pub fn get_public_rng(&mut self) -> &mut ChaCha20Rng {
-        &mut self.public_rng
-    }
-
-    pub fn get_public_seed(&mut self) -> <ChaCha20Rng as SeedableRng>::Seed {
-        self.public_seed
-    }
-
-    fn get_fresh_gsw_public_key(&mut self, m: usize) -> PolyMatrixRaw<'a> {
+    fn get_fresh_gsw_public_key(&self, m: usize) -> PolyMatrixRaw<'a> {
         let params = self.params;
         let n = params.n;
 
-        let a = PolyMatrixRaw::random_rng(params, 1, m, self.get_rng());
-        let e = PolyMatrixRaw::noise(params, n, m, &mut self.dg);
+        let a = {
+            let rng = &mut *self.dg.get_rng();
+            PolyMatrixRaw::random_rng(params, 1, m, rng)
+        };
+        let e = PolyMatrixRaw::noise(params, n, m, &self.dg);
         let a_inv = -&a;
         let b_p = &self.sk_gsw.ntt() * &a.ntt();
         let b = &e.ntt() + &b_p;
@@ -296,10 +296,16 @@ impl<'a, T: Rng> Client<'a, T> {
         p
     }
 
-    fn get_regev_sample(&mut self) -> PolyMatrixNTT<'a> {
+    fn get_regev_sample(&self) -> PolyMatrixNTT<'a> {
         let params = self.params;
-        let a = PolyMatrixRaw::random_rng(params, 1, 1, self.get_public_rng());
-        let e = PolyMatrixRaw::noise(params, 1, 1, &mut self.dg);
+        let a = {
+            let public_rng = &mut *self.public_rng.get_or(|| {
+                RefCell::new(ChaCha20Rng::from_seed(self.get_public_seed()))
+            })
+            .borrow_mut();
+            PolyMatrixRaw::random_rng(params, 1, 1, public_rng)
+        };
+        let e = PolyMatrixRaw::noise(params, 1, 1, &self.dg);
         let b_p = &self.sk_reg.ntt() * &a.ntt();
         let b = &e.ntt() + &b_p;
         let mut p = PolyMatrixNTT::zero(params, 2, 1);
@@ -308,7 +314,7 @@ impl<'a, T: Rng> Client<'a, T> {
         p
     }
 
-    fn get_fresh_reg_public_key(&mut self, m: usize) -> PolyMatrixNTT<'a> {
+    fn get_fresh_reg_public_key(&self, m: usize) -> PolyMatrixNTT<'a> {
         let params = self.params;
 
         let mut p = PolyMatrixNTT::zero(params, 2, m);
@@ -319,29 +325,29 @@ impl<'a, T: Rng> Client<'a, T> {
         p
     }
 
-    fn encrypt_matrix_gsw(&mut self, ag: &PolyMatrixNTT<'a>) -> PolyMatrixNTT<'a> {
+    fn encrypt_matrix_gsw(&self, ag: &PolyMatrixNTT<'a>) -> PolyMatrixNTT<'a> {
         let mx = ag.cols;
         let p = self.get_fresh_gsw_public_key(mx);
         let res = &(p.ntt()) + &(ag.pad_top(1));
         res
     }
 
-    pub fn encrypt_matrix_reg(&mut self, a: &PolyMatrixNTT<'a>) -> PolyMatrixNTT<'a> {
+    pub fn encrypt_matrix_reg(&self, a: &PolyMatrixNTT<'a>) -> PolyMatrixNTT<'a> {
         let m = a.cols;
         let p = self.get_fresh_reg_public_key(m);
         &p + &a.pad_top(1)
     }
 
-    pub fn decrypt_matrix_reg(&mut self, a: &PolyMatrixNTT<'a>) -> PolyMatrixNTT<'a> {
+    pub fn decrypt_matrix_reg(&self, a: &PolyMatrixNTT<'a>) -> PolyMatrixNTT<'a> {
         &self.sk_reg_full.ntt() * a
     }
 
-    pub fn decrypt_matrix_gsw(&mut self, a: &PolyMatrixNTT<'a>) -> PolyMatrixNTT<'a> {
+    pub fn decrypt_matrix_gsw(&self, a: &PolyMatrixNTT<'a>) -> PolyMatrixNTT<'a> {
         &self.sk_gsw_full.ntt() * a
     }
 
     fn generate_expansion_params(
-        &mut self,
+        &self,
         num_exp: usize,
         m_exp: usize,
     ) -> Vec<PolyMatrixNTT<'a>> {
@@ -414,7 +420,7 @@ impl<'a, T: Rng> Client<'a, T> {
         pp
     }
 
-    pub fn generate_query(&mut self, idx_target: usize) -> Query<'a> {
+    pub fn generate_query(&self, idx_target: usize) -> Query<'a> {
         let params = self.params;
         let further_dims = params.db_dim_2;
         let idx_dim0 = idx_target / (1 << further_dims);
@@ -582,7 +588,8 @@ impl<'a, T: Rng> Client<'a, T> {
 
 #[cfg(test)]
 mod test {
-    use rand::thread_rng;
+    use rand::SeedableRng;
+    use rand_chacha::ChaCha20Rng;
 
     use super::*;
 
@@ -598,8 +605,7 @@ mod test {
     #[test]
     fn init_is_correct() {
         let params = get_params();
-        let mut rng = thread_rng();
-        let client = Client::init(&params, &mut rng);
+        let client = Client::init(&params, ChaCha20Rng::from_entropy);
 
         assert_eq!(*client.params, params);
     }
@@ -607,8 +613,8 @@ mod test {
     #[test]
     fn keygen_is_correct() {
         let params = get_params();
-        let mut seeded_rng = get_static_seeded_rng();
-        let mut client = Client::init(&params, &mut seeded_rng);
+        let mut client = Client::init(&params, get_static_seeded_rng);
+        client.get_public_seed();
 
         let pub_params = client.generate_keys();
 
@@ -639,8 +645,7 @@ mod test {
     }
 
     fn public_parameters_serialization_is_correct_for_params(params: Params) {
-        let mut seeded_rng = get_static_seeded_rng();
-        let mut client = Client::init(&params, &mut seeded_rng);
+        let mut client = Client::init(&params, get_static_seeded_rng);
         let pub_params = client.generate_keys();
 
         let serialized1 = pub_params.serialize();
@@ -704,8 +709,7 @@ mod test {
     }
 
     fn query_serialization_is_correct_for_params(params: Params) {
-        let mut seeded_rng = get_static_seeded_rng();
-        let mut client = Client::init(&params, &mut seeded_rng);
+        let mut client = Client::init(&params, get_static_seeded_rng);
         _ = client.generate_keys();
         let query = client.generate_query(1);
 

+ 31 - 12
spiral-rs/src/discrete_gaussian.rs

@@ -2,20 +2,25 @@ use rand::distributions::WeightedIndex;
 use rand::prelude::Distribution;
 use rand::Rng;
 
+use std::cell::{RefCell, RefMut};
+
+use thread_local::ThreadLocal;
+
 use crate::params::*;
 use crate::poly::*;
 use std::f64::consts::PI;
 
 pub const NUM_WIDTHS: usize = 8;
 
-pub struct DiscreteGaussian<'a, T: Rng> {
+pub struct DiscreteGaussian<T: Rng + Send> {
     choices: Vec<i64>,
     dist: WeightedIndex<f64>,
-    pub rng: &'a mut T,
+    rng: ThreadLocal<RefCell<T>>,
+    rnggen: fn() -> T,
 }
 
-impl<'a, T: Rng> DiscreteGaussian<'a, T> {
-    pub fn init(params: &'a Params, rng: &'a mut T) -> Self {
+impl<T: Rng + Send> DiscreteGaussian<T> {
+    pub fn init(params: &Params, rnggen: fn() -> T) -> Self {
         let max_val = (params.noise_width * (NUM_WIDTHS as f64)).ceil() as i64;
         let mut choices = Vec::new();
         let mut table = vec![0f64; 0];
@@ -26,21 +31,35 @@ impl<'a, T: Rng> DiscreteGaussian<'a, T> {
         }
         let dist = WeightedIndex::new(&table).unwrap();
 
-        Self { choices, dist, rng }
+        Self { choices, dist, rng: ThreadLocal::new(), rnggen }
     }
 
     // FIXME: not constant-time
-    pub fn sample(&mut self) -> i64 {
-        self.choices[self.dist.sample(&mut self.rng)]
+    fn sample_from_members(choices: &Vec<i64>, dist: &WeightedIndex<f64>,
+        rng: &mut T) -> i64 {
+        choices[dist.sample(rng)]
+    }
+
+    pub fn get_rng(&self) -> RefMut<T> {
+        self.rng.get_or(|| RefCell::new((self.rnggen)())).borrow_mut()
+    }
+
+    #[cfg(test)]
+    fn sample(&self) -> i64 {
+        let mut rng = self.get_rng();
+        Self::sample_from_members(&self.choices, &self.dist, &mut *rng)
     }
 
-    pub fn sample_matrix(&mut self, p: &mut PolyMatrixRaw) {
+    pub fn sample_matrix(&self, p: &mut PolyMatrixRaw) {
         let modulus = p.get_params().modulus;
+        let choices = &self.choices;
+        let dist = &self.dist;
+        let rng = &mut *self.get_rng();
         for r in 0..p.rows {
             for c in 0..p.cols {
                 let poly = p.get_poly_mut(r, c);
                 for z in 0..poly.len() {
-                    let mut s = self.sample();
+                    let mut s = Self::sample_from_members(choices, dist, rng);
                     s += modulus as i64;
                     s %= modulus as i64; // FIXME: not constant time
                     poly[z] = s as u64;
@@ -52,7 +71,8 @@ impl<'a, T: Rng> DiscreteGaussian<'a, T> {
 
 #[cfg(test)]
 mod test {
-    use rand::thread_rng;
+    use rand::SeedableRng;
+    use rand_chacha::ChaCha20Rng;
 
     use super::*;
     use crate::util::*;
@@ -60,8 +80,7 @@ mod test {
     #[test]
     fn dg_seems_okay() {
         let params = get_test_params();
-        let mut rng = thread_rng();
-        let mut dg = DiscreteGaussian::init(&params, &mut rng);
+        let dg = DiscreteGaussian::init(&params, ChaCha20Rng::from_entropy);
         let mut v = Vec::new();
         let trials = 10000;
         let mut sum = 0;

+ 2 - 2
spiral-rs/src/poly.rs

@@ -172,11 +172,11 @@ impl<'a> PolyMatrixRaw<'a> {
         }
     }
 
-    pub fn noise<T: Rng>(
+    pub fn noise<T: Rng + Send>(
         params: &'a Params,
         rows: usize,
         cols: usize,
-        dg: &mut DiscreteGaussian<T>,
+        dg: &DiscreteGaussian<T>,
     ) -> Self {
         let mut out = PolyMatrixRaw::zero(params, rows, cols);
         dg.sample_matrix(&mut out);

+ 7 - 9
spiral-rs/src/server.rs

@@ -843,8 +843,7 @@ mod test {
     fn coefficient_expansion_is_correct() {
         let params = get_params();
         let v_neg1 = params.get_v_neg1();
-        let mut seeded_rng = get_seeded_rng();
-        let mut client = Client::init(&params, &mut seeded_rng);
+        let mut client = Client::init(&params, get_seeded_rng);
         let public_params = client.generate_keys();
 
         let mut v = Vec::new();
@@ -887,11 +886,10 @@ mod test {
     fn regev_to_gsw_is_correct() {
         let mut params = get_params();
         params.db_dim_2 = 1;
-        let mut seeded_rng = get_seeded_rng();
-        let mut client = Client::init(&params, &mut seeded_rng);
+        let mut client = Client::init(&params, get_seeded_rng);
         let public_params = client.generate_keys();
 
-        let mut enc_constant = |val| {
+        let enc_constant = |val| {
             let mut sigma = PolyMatrixRaw::zero(&params, 1, 1);
             sigma.data[0] = val;
             client.encrypt_matrix_reg(&sigma.ntt())
@@ -933,7 +931,7 @@ mod test {
         let target_idx_dim0 = target_idx / num_per;
         let target_idx_num_per = target_idx % num_per;
 
-        let mut client = Client::init(&params, &mut seeded_rng);
+        let mut client = Client::init(&params, get_seeded_rng);
         _ = client.generate_keys();
 
         let (corr_item, db) = generate_random_db_and_get_item(&params, target_idx);
@@ -987,7 +985,7 @@ mod test {
         let target_idx = seeded_rng.gen::<usize>() % (dim0 * num_per);
         let target_idx_num_per = target_idx % num_per;
 
-        let mut client = Client::init(&params, &mut seeded_rng);
+        let mut client = Client::init(&params, get_seeded_rng);
         _ = client.generate_keys();
 
         let mut v_reg = Vec::new();
@@ -1046,7 +1044,7 @@ mod test {
 
         let target_idx = seeded_rng.gen::<usize>() % (params.db_dim_1 + params.db_dim_2);
 
-        let mut client = Client::init(params, &mut seeded_rng);
+        let mut client = Client::init(params, get_seeded_rng);
 
         let public_params = client.generate_keys();
         let query = client.generate_query(target_idx);
@@ -1072,7 +1070,7 @@ mod test {
 
         let target_idx = seeded_rng.gen::<usize>() % (params.db_dim_1 + params.db_dim_2);
 
-        let mut client = Client::init(params, &mut seeded_rng);
+        let mut client = Client::init(params, get_seeded_rng);
 
         let public_params = client.generate_keys();
         let query = client.generate_query(target_idx);