use rand::distributions::WeightedIndex; use rand::prelude::Distribution; use rand::Rng; use std::sync::{Mutex, MutexGuard}; use thread_local::ThreadLocal; use crate::params::*; use crate::poly::*; use std::f64::consts::PI; pub const NUM_WIDTHS: usize = 8; pub struct DiscreteGaussian { choices: Vec, dist: WeightedIndex, rng: ThreadLocal>, rnggen: fn() -> T, } impl DiscreteGaussian { pub fn init(params: &Params, rnggen: fn() -> T) -> Self { let max_val = (params.noise_width * (NUM_WIDTHS as f64)).ceil() as i64; let mut choices = Vec::new(); let mut table = vec![0f64; 0]; for i in -max_val..max_val + 1 { let p_val = f64::exp(-PI * f64::powi(i as f64, 2) / f64::powi(params.noise_width, 2)); choices.push(i); table.push(p_val); } let dist = WeightedIndex::new(&table).unwrap(); Self { choices, dist, rng: ThreadLocal::new(), rnggen } } // FIXME: not constant-time fn sample_from_members(choices: &Vec, dist: &WeightedIndex, rng: &mut T) -> i64 { choices[dist.sample(rng)] } pub fn get_rng(&self) -> MutexGuard { self.rng.get_or(|| Mutex::new((self.rnggen)())).lock().unwrap() } #[cfg(test)] fn sample(&self) -> i64 { let mut rng = self.get_rng(); Self::sample_from_members(&self.choices, &self.dist, &mut *rng) } pub fn sample_matrix(&self, p: &mut PolyMatrixRaw) { let modulus = p.get_params().modulus; let choices = &self.choices; let dist = &self.dist; let rng = &mut *self.get_rng(); for r in 0..p.rows { for c in 0..p.cols { let poly = p.get_poly_mut(r, c); for z in 0..poly.len() { let mut s = Self::sample_from_members(choices, dist, rng); s += modulus as i64; s %= modulus as i64; // FIXME: not constant time poly[z] = s as u64; } } } } } #[cfg(test)] mod test { use rand::SeedableRng; use rand_chacha::ChaCha20Rng; use super::*; use crate::util::*; #[test] fn dg_seems_okay() { let params = get_test_params(); let dg = DiscreteGaussian::init(¶ms, ChaCha20Rng::from_entropy); let mut v = Vec::new(); let trials = 10000; let mut sum = 0; for _ in 0..trials { let val = dg.sample(); v.push(val); sum += val; } let mean = sum as f64 / trials as f64; let std_dev = params.noise_width / f64::sqrt(2f64 * std::f64::consts::PI); let std_dev_of_mean = std_dev / f64::sqrt(trials as f64); assert!(f64::abs(mean) < std_dev_of_mean * 5f64); } }