Selaa lähdekoodia

Client improvements
1. Make DiscreteGaussian use thread_locals to avoid the need for mutable
references to it
2. Make Client.generate_query and associated methods only require &self,
instead of &mut self; this makes parallel query generation easier
3. Remove type parameters from Client and DiscreteGaussian (simplifying
their implementations)
4. Implement the PRG seed trick to reduce the upload sizes for queries
and public parameters by 2x

These changes (1-3) are adapted from work by Prof. Ian Goldberg,
available here:
https://git-crysp.uwaterloo.ca/iang/spiral-rs-fork/commit/d6d546d

Samir Menon 1 vuosi sitten
vanhempi
commit
e5dd741e2a

+ 2 - 3
spiral-rs/benches/server.rs

@@ -42,7 +42,7 @@ fn test_full_processing(group: &mut BenchmarkGroup<WallTime>) {
 
         let target_idx = seeded_rng.gen::<usize>() % (params.db_dim_1 + params.db_dim_2);
 
-        let mut client = Client::init(&params, &mut seeded_rng);
+        let mut client = Client::init(&params);
         let public_params = client.generate_keys();
         let query = client.generate_query(target_idx);
 
@@ -71,9 +71,8 @@ fn criterion_benchmark(c: &mut Criterion) {
 
     let params = get_expansion_testing_params();
     let v_neg1 = params.get_v_neg1();
-    let mut seeded_rng = get_seeded_rng();
     let mut chacha_rng = get_chacha_rng();
-    let mut client = Client::init(&params, &mut seeded_rng);
+    let mut client = Client::init(&params);
     let public_params = client.generate_keys();
 
     let mut v = Vec::new();

+ 1 - 2
spiral-rs/src/bin/client.rs

@@ -67,8 +67,7 @@ fn main() {
     let idx_target: usize = (&args[2]).parse().unwrap();
 
     println!("initializing client");
-    let mut rng = thread_rng();
-    let mut c = Client::init(&params, &mut rng);
+    let mut c = Client::init(&params);
     println!("generating public parameters");
     let pub_params = c.generate_keys();
     let pub_params_buf = pub_params.serialize();

+ 1 - 1
spiral-rs/src/bin/e2e.rs

@@ -47,7 +47,7 @@ fn main() {
         params.num_items()
     );
     println!("initializing client");
-    let mut client = Client::init(&params, &mut rng);
+    let mut client = Client::init_with_seed(&params, get_chacha_seed());
     println!("generating public parameters");
     let pub_params = client.generate_keys();
     let pub_params_buf = pub_params.serialize();

+ 58 - 43
spiral-rs/src/client.rs

@@ -1,11 +1,11 @@
 use crate::{
     arith::*, discrete_gaussian::*, gadget::*, number_theory::*, params::*, poly::*, util::*,
 };
-use rand::{thread_rng, Rng, SeedableRng};
+use rand::{thread_rng, Rng, RngCore, SeedableRng};
 use rand_chacha::ChaCha20Rng;
 use std::{iter::once, mem::size_of};
 
-type Seed = <ChaCha20Rng as SeedableRng>::Seed;
+pub type Seed = <ChaCha20Rng as SeedableRng>::Seed;
 const SEED_LENGTH: usize = 32;
 
 fn new_vec_raw<'a>(
@@ -282,17 +282,6 @@ impl<'a> Query<'a> {
     }
 }
 
-pub struct Client<'a, T: Rng> {
-    params: &'a Params,
-    sk_gsw: PolyMatrixRaw<'a>,
-    sk_reg: PolyMatrixRaw<'a>,
-    sk_gsw_full: PolyMatrixRaw<'a>,
-    sk_reg_full: PolyMatrixRaw<'a>,
-    dg: DiscreteGaussian<'a, T>,
-    pp_seed: Seed,
-    query_seed: Seed,
-}
-
 fn matrix_with_identity<'a>(p: &PolyMatrixRaw<'a>) -> PolyMatrixRaw<'a> {
     assert_eq!(p.cols, 1);
     let mut r = PolyMatrixRaw::zero(p.params, p.rows, p.rows + 1);
@@ -321,19 +310,46 @@ fn params_with_moduli(params: &Params, moduli: &Vec<u64>) -> Params {
     )
 }
 
-impl<'a, T: Rng> Client<'a, T> {
-    pub fn init(params: &'a Params, rng: &'a mut T) -> Self {
+pub struct Client<'a> {
+    params: &'a Params,
+    sk_gsw: PolyMatrixRaw<'a>,
+    sk_reg: PolyMatrixRaw<'a>,
+    sk_gsw_full: PolyMatrixRaw<'a>,
+    sk_reg_full: PolyMatrixRaw<'a>,
+    dg: DiscreteGaussian,
+    pp_seed: Seed,
+    query_seed: Seed,
+}
+
+impl<'a> Client<'a> {
+    pub fn init(params: &'a Params) -> Self {
+        let mut root_seed = [0u8; 32];
+        thread_rng().fill_bytes(&mut root_seed);
+        Self::init_with_seed_impl(params, root_seed, false)
+    }
+
+    pub fn init_with_seed(params: &'a Params, root_seed: Seed) -> Self {
+        Self::init_with_seed_impl(params, root_seed, true)
+    }
+
+    fn init_with_seed_impl(params: &'a Params, root_seed: Seed, deterministic: bool) -> Self {
         let sk_gsw_dims = params.get_sk_gsw();
         let sk_reg_dims = params.get_sk_reg();
         let sk_gsw = PolyMatrixRaw::zero(params, sk_gsw_dims.0, sk_gsw_dims.1);
         let sk_reg = PolyMatrixRaw::zero(params, sk_reg_dims.0, sk_reg_dims.1);
         let sk_gsw_full = matrix_with_identity(&sk_gsw);
         let sk_reg_full = matrix_with_identity(&sk_reg);
+
+        let mut rng = ChaCha20Rng::from_seed(root_seed);
         let mut pp_seed = [0u8; 32];
         let mut query_seed = [0u8; 32];
         rng.fill_bytes(&mut query_seed);
         rng.fill_bytes(&mut pp_seed);
-        let dg = DiscreteGaussian::init(params, rng);
+        let dg = if deterministic {
+            DiscreteGaussian::init_with_seed(params, root_seed)
+        } else {
+            DiscreteGaussian::init(params)
+        };
         Self {
             params,
             sk_gsw,
@@ -351,16 +367,12 @@ impl<'a, T: Rng> Client<'a, T> {
         &self.sk_reg
     }
 
-    pub fn get_rng(&mut self) -> &mut T {
-        &mut self.dg.rng
-    }
-
-    fn get_fresh_gsw_public_key(&mut self, m: usize, rng: &mut ChaCha20Rng) -> PolyMatrixRaw<'a> {
+    fn get_fresh_gsw_public_key(&self, m: usize, rng: &mut ChaCha20Rng) -> PolyMatrixRaw<'a> {
         let params = self.params;
         let n = params.n;
 
         let a = PolyMatrixRaw::random_rng(params, 1, m, rng);
-        let e = PolyMatrixRaw::noise(params, n, m, &mut self.dg);
+        let e = PolyMatrixRaw::noise(params, n, m, &self.dg);
         let a_inv = -&a;
         let b_p = &self.sk_gsw.ntt() * &a.ntt();
         let b = &e.ntt() + &b_p;
@@ -368,10 +380,10 @@ impl<'a, T: Rng> Client<'a, T> {
         p
     }
 
-    fn get_regev_sample(&mut self, rng: &mut ChaCha20Rng) -> PolyMatrixNTT<'a> {
+    fn get_regev_sample(&self, rng: &mut ChaCha20Rng) -> PolyMatrixNTT<'a> {
         let params = self.params;
         let a = PolyMatrixRaw::random_rng(params, 1, 1, rng);
-        let e = PolyMatrixRaw::noise(params, 1, 1, &mut self.dg);
+        let e = PolyMatrixRaw::noise(params, 1, 1, &self.dg);
         let b_p = &self.sk_reg.ntt() * &a.ntt();
         let b = &e.ntt() + &b_p;
         let mut p = PolyMatrixNTT::zero(params, 2, 1);
@@ -380,7 +392,7 @@ impl<'a, T: Rng> Client<'a, T> {
         p
     }
 
-    fn get_fresh_reg_public_key(&mut self, m: usize, rng: &mut ChaCha20Rng) -> PolyMatrixNTT<'a> {
+    fn get_fresh_reg_public_key(&self, m: usize, rng: &mut ChaCha20Rng) -> PolyMatrixNTT<'a> {
         let params = self.params;
 
         let mut p = PolyMatrixNTT::zero(params, 2, m);
@@ -392,7 +404,7 @@ impl<'a, T: Rng> Client<'a, T> {
     }
 
     fn encrypt_matrix_gsw(
-        &mut self,
+        &self,
         ag: &PolyMatrixNTT<'a>,
         rng: &mut ChaCha20Rng,
     ) -> PolyMatrixNTT<'a> {
@@ -403,7 +415,7 @@ impl<'a, T: Rng> Client<'a, T> {
     }
 
     pub fn encrypt_matrix_reg(
-        &mut self,
+        &self,
         a: &PolyMatrixNTT<'a>,
         rng: &mut ChaCha20Rng,
     ) -> PolyMatrixNTT<'a> {
@@ -412,16 +424,16 @@ impl<'a, T: Rng> Client<'a, T> {
         &p + &a.pad_top(1)
     }
 
-    pub fn decrypt_matrix_reg(&mut self, a: &PolyMatrixNTT<'a>) -> PolyMatrixNTT<'a> {
+    pub fn decrypt_matrix_reg(&self, a: &PolyMatrixNTT<'a>) -> PolyMatrixNTT<'a> {
         &self.sk_reg_full.ntt() * a
     }
 
-    pub fn decrypt_matrix_gsw(&mut self, a: &PolyMatrixNTT<'a>) -> PolyMatrixNTT<'a> {
+    pub fn decrypt_matrix_gsw(&self, a: &PolyMatrixNTT<'a>) -> PolyMatrixNTT<'a> {
         &self.sk_gsw_full.ntt() * a
     }
 
     fn generate_expansion_params(
-        &mut self,
+        &self,
         num_exp: usize,
         m_exp: usize,
         rng: &mut ChaCha20Rng,
@@ -500,7 +512,7 @@ impl<'a, T: Rng> Client<'a, T> {
         pp
     }
 
-    pub fn generate_query(&mut self, idx_target: usize) -> Query<'a> {
+    pub fn generate_query(&self, idx_target: usize) -> Query<'a> {
         let params = self.params;
         let further_dims = params.db_dim_2;
         let idx_dim0 = idx_target / (1 << further_dims);
@@ -672,8 +684,6 @@ impl<'a, T: Rng> Client<'a, T> {
 
 #[cfg(test)]
 mod test {
-    use rand::thread_rng;
-
     use super::*;
 
     fn assert_first8(m: &[u64], gold: [u64; 8]) {
@@ -688,8 +698,7 @@ mod test {
     #[test]
     fn init_is_correct() {
         let params = get_params();
-        let mut rng = thread_rng();
-        let client = Client::init(&params, &mut rng);
+        let client = Client::init(&params);
 
         assert_eq!(*client.params, params);
     }
@@ -697,21 +706,29 @@ mod test {
     #[test]
     fn keygen_is_correct() {
         let params = get_params();
-        let mut seeded_rng = get_static_seeded_rng();
-        let mut client = Client::init(&params, &mut seeded_rng);
+        let mut client = Client::init_with_seed(&params, get_chacha_static_seed());
 
         let pub_params = client.generate_keys();
 
         assert_first8(
             pub_params.v_conversion.unwrap()[0].data.as_slice(),
             [
-                77639594, 190195027, 25198243, 245727165, 99660925, 135957601, 187643645, 116322041,
+                160400272, 10738392, 27480222, 201012452, 42036824, 3955189, 201319389, 181880730,
             ],
         );
 
         assert_first8(
             client.sk_gsw.data.as_slice(),
-            [2, 66974689739603966, 66974689739603967, 5, 2, 1, 1, 0],
+            [
+                1,
+                66974689739603968,
+                4,
+                3,
+                66974689739603965,
+                66974689739603967,
+                1,
+                0,
+            ],
         );
     }
 
@@ -720,8 +737,7 @@ mod test {
     }
 
     fn public_parameters_serialization_is_correct_for_params(params: Params) {
-        let mut seeded_rng = get_static_seeded_rng();
-        let mut client = Client::init(&params, &mut seeded_rng);
+        let mut client = Client::init_with_seed(&params, get_chacha_static_seed());
         let pub_params = client.generate_keys();
 
         let serialized1 = pub_params.serialize();
@@ -785,8 +801,7 @@ mod test {
     }
 
     fn query_serialization_is_correct_for_params(params: Params) {
-        let mut seeded_rng = get_static_seeded_rng();
-        let mut client = Client::init(&params, &mut seeded_rng);
+        let mut client = Client::init(&params);
         _ = client.generate_keys();
         let query = client.generate_query(1);
 

+ 33 - 12
spiral-rs/src/discrete_gaussian.rs

@@ -1,21 +1,35 @@
 use rand::distributions::WeightedIndex;
 use rand::prelude::Distribution;
+use rand::thread_rng;
 use rand::Rng;
+use rand::SeedableRng;
+use rand_chacha::ChaCha20Rng;
+use std::cell::*;
 
+use crate::client::*;
 use crate::params::*;
 use crate::poly::*;
 use std::f64::consts::PI;
 
 pub const NUM_WIDTHS: usize = 8;
 
-pub struct DiscreteGaussian<'a, T: Rng> {
+thread_local!(static RNGS: RefCell<Option<ChaCha20Rng>> = RefCell::new(None));
+
+pub struct DiscreteGaussian {
     choices: Vec<i64>,
     dist: WeightedIndex<f64>,
-    pub rng: &'a mut T,
 }
 
-impl<'a, T: Rng> DiscreteGaussian<'a, T> {
-    pub fn init(params: &'a Params, rng: &'a mut T) -> Self {
+impl DiscreteGaussian {
+    pub fn init(params: &Params) -> Self {
+        Self::init_impl(params, None)
+    }
+
+    pub fn init_with_seed(params: &Params, seed: Seed) -> Self {
+        Self::init_impl(params, Some(seed))
+    }
+
+    fn init_impl(params: &Params, seed: Option<Seed>) -> Self {
         let max_val = (params.noise_width * (NUM_WIDTHS as f64)).ceil() as i64;
         let mut choices = Vec::new();
         let mut table = vec![0f64; 0];
@@ -26,15 +40,25 @@ impl<'a, T: Rng> DiscreteGaussian<'a, T> {
         }
         let dist = WeightedIndex::new(&table).unwrap();
 
-        Self { choices, dist, rng }
+        if seed.is_some() {
+            RNGS.with(|f| *f.borrow_mut() = Some(ChaCha20Rng::from_seed(seed.unwrap())));
+        } else {
+            RNGS.with(|f| *f.borrow_mut() = Some(ChaCha20Rng::from_seed(thread_rng().gen())));
+        }
+
+        Self { choices, dist }
     }
 
     // FIXME: not constant-time
-    pub fn sample(&mut self) -> i64 {
-        self.choices[self.dist.sample(&mut self.rng)]
+    pub fn sample(&self) -> i64 {
+        RNGS.with(|f| {
+            let mut rng = f.borrow_mut();
+            let rng_mut = rng.as_mut().unwrap();
+            self.choices[self.dist.sample(rng_mut)]
+        })
     }
 
-    pub fn sample_matrix(&mut self, p: &mut PolyMatrixRaw) {
+    pub fn sample_matrix(&self, p: &mut PolyMatrixRaw) {
         let modulus = p.get_params().modulus;
         for r in 0..p.rows {
             for c in 0..p.cols {
@@ -52,16 +76,13 @@ impl<'a, T: Rng> DiscreteGaussian<'a, T> {
 
 #[cfg(test)]
 mod test {
-    use rand::thread_rng;
-
     use super::*;
     use crate::util::*;
 
     #[test]
     fn dg_seems_okay() {
         let params = get_test_params();
-        let mut rng = thread_rng();
-        let mut dg = DiscreteGaussian::init(&params, &mut rng);
+        let dg = DiscreteGaussian::init_with_seed(&params, get_chacha_seed());
         let mut v = Vec::new();
         let trials = 10000;
         let mut sum = 0;

+ 1 - 6
spiral-rs/src/poly.rs

@@ -172,12 +172,7 @@ impl<'a> PolyMatrixRaw<'a> {
         }
     }
 
-    pub fn noise<T: Rng>(
-        params: &'a Params,
-        rows: usize,
-        cols: usize,
-        dg: &mut DiscreteGaussian<T>,
-    ) -> Self {
+    pub fn noise(params: &'a Params, rows: usize, cols: usize, dg: &DiscreteGaussian) -> Self {
         let mut out = PolyMatrixRaw::zero(params, rows, cols);
         dg.sample_matrix(&mut out);
         out

+ 9 - 20
spiral-rs/src/server.rs

@@ -5,7 +5,6 @@ use std::io::BufReader;
 use std::io::Read;
 use std::io::Seek;
 use std::io::SeekFrom;
-use std::time::Instant;
 
 use crate::aligned_memory::*;
 use crate::arith::*;
@@ -501,9 +500,7 @@ pub fn load_preprocessed_db_from_file(params: &Params, file: &mut File) -> Align
     let mut v = AlignedMemory64::new(db_size_words);
     let v_mut_slice = v.as_mut_slice();
 
-    let now = Instant::now();
     load_file(v_mut_slice, file);
-    println!("Done loading ({} ms).", now.elapsed().as_millis());
 
     v
 }
@@ -728,9 +725,7 @@ pub fn process_query(
     let mut v_reg_reoriented;
     let v_folding;
     if params.expand_queries {
-        let now = Instant::now();
         (v_reg_reoriented, v_folding) = expand_query(params, public_params, query);
-        println!("expansion (took {} us).", now.elapsed().as_micros());
     } else {
         v_reg_reoriented = AlignedMemory64::new(query.v_buf.as_ref().unwrap().len());
         v_reg_reoriented
@@ -794,7 +789,7 @@ pub fn process_query(
 mod test {
     use super::*;
     use crate::client::*;
-    use rand::{prelude::SmallRng, Rng};
+    use rand::Rng;
 
     const TEST_PREPROCESSED_DB_PATH: &'static str = "/home/samir/wiki/enwiki-20220320.dbp";
 
@@ -805,7 +800,7 @@ mod test {
     fn dec_reg<'a>(
         params: &'a Params,
         ct: &PolyMatrixNTT<'a>,
-        client: &mut Client<'a, SmallRng>,
+        client: &mut Client<'a>,
         scale_k: u64,
     ) -> u64 {
         let dec = client.decrypt_matrix_reg(ct).raw();
@@ -821,11 +816,7 @@ mod test {
         }
     }
 
-    fn dec_gsw<'a>(
-        params: &'a Params,
-        ct: &PolyMatrixNTT<'a>,
-        client: &mut Client<'a, SmallRng>,
-    ) -> u64 {
+    fn dec_gsw<'a>(params: &'a Params, ct: &PolyMatrixNTT<'a>, client: &mut Client<'a>) -> u64 {
         let dec = client.decrypt_matrix_reg(ct).raw();
         let idx = 2 * (params.t_gsw - 1) * params.poly_len + params.poly_len; // this offset should encode a large value
         let mut val = dec.data[idx] as i64;
@@ -843,9 +834,8 @@ mod test {
     fn coefficient_expansion_is_correct() {
         let params = get_params();
         let v_neg1 = params.get_v_neg1();
-        let mut seeded_rng = get_seeded_rng();
         let mut chacha_rng = get_chacha_rng();
-        let mut client = Client::init(&params, &mut seeded_rng);
+        let mut client = Client::init_with_seed(&params, get_chacha_static_seed());
         let public_params = client.generate_keys();
 
         let mut v = Vec::new();
@@ -888,9 +878,8 @@ mod test {
     fn regev_to_gsw_is_correct() {
         let mut params = get_params();
         params.db_dim_2 = 1;
-        let mut seeded_rng = get_seeded_rng();
         let mut chacha_rng = get_chacha_rng();
-        let mut client = Client::init(&params, &mut seeded_rng);
+        let mut client = Client::init_with_seed(&params, get_chacha_static_seed());
         let public_params = client.generate_keys();
 
         let mut enc_constant = |val| {
@@ -936,7 +925,7 @@ mod test {
         let target_idx_dim0 = target_idx / num_per;
         let target_idx_num_per = target_idx % num_per;
 
-        let mut client = Client::init(&params, &mut seeded_rng);
+        let mut client = Client::init_with_seed(&params, get_chacha_static_seed());
         _ = client.generate_keys();
 
         let (corr_item, db) = generate_random_db_and_get_item(&params, target_idx);
@@ -991,7 +980,7 @@ mod test {
         let target_idx = seeded_rng.gen::<usize>() % (dim0 * num_per);
         let target_idx_num_per = target_idx % num_per;
 
-        let mut client = Client::init(&params, &mut seeded_rng);
+        let mut client = Client::init_with_seed(&params, get_chacha_static_seed());
         _ = client.generate_keys();
 
         let mut v_reg = Vec::new();
@@ -1050,7 +1039,7 @@ mod test {
 
         let target_idx = seeded_rng.gen::<usize>() % (params.db_dim_1 + params.db_dim_2);
 
-        let mut client = Client::init(params, &mut seeded_rng);
+        let mut client = Client::init_with_seed(&params, get_chacha_static_seed());
 
         let public_params = client.generate_keys();
         let query = client.generate_query(target_idx);
@@ -1076,7 +1065,7 @@ mod test {
 
         let target_idx = seeded_rng.gen::<usize>() % (params.db_dim_1 + params.db_dim_2);
 
-        let mut client = Client::init(params, &mut seeded_rng);
+        let mut client = Client::init_with_seed(&params, get_chacha_static_seed());
 
         let public_params = client.generate_keys();
         let query = client.generate_query(target_idx);

+ 13 - 2
spiral-rs/src/util.rs

@@ -1,4 +1,4 @@
-use crate::{arith::*, params::*, poly::*};
+use crate::{arith::*, client::Seed, params::*, poly::*};
 use rand::{prelude::SmallRng, thread_rng, Rng, SeedableRng};
 use rand_chacha::ChaCha20Rng;
 use serde_json::Value;
@@ -141,14 +141,25 @@ pub fn get_seeded_rng() -> SmallRng {
     SmallRng::seed_from_u64(get_seed())
 }
 
+pub fn get_chacha_seed() -> Seed {
+    thread_rng().gen::<[u8; 32]>()
+}
+
 pub fn get_chacha_rng() -> ChaCha20Rng {
-    ChaCha20Rng::from_seed(thread_rng().gen::<[u8; 32]>())
+    ChaCha20Rng::from_seed(get_chacha_seed())
 }
 
 pub fn get_static_seed() -> u64 {
     0x123456789
 }
 
+pub fn get_chacha_static_seed() -> Seed {
+    [
+        0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8, 0x9, 0xa, 0xb, 0xc, 0xd, 0xe, 0xf, 0x0, 0x1,
+        0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8, 0x9, 0xa, 0xb, 0xc, 0xd, 0xe, 0xf,
+    ]
+}
+
 pub fn get_static_seeded_rng() -> SmallRng {
     SmallRng::seed_from_u64(get_static_seed())
 }