discrete_gaussian.rs 2.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899
  1. use rand::distributions::WeightedIndex;
  2. use rand::prelude::Distribution;
  3. use rand::thread_rng;
  4. use rand::Rng;
  5. use rand::SeedableRng;
  6. use rand_chacha::ChaCha20Rng;
  7. use std::cell::*;
  8. use crate::client::*;
  9. use crate::params::*;
  10. use crate::poly::*;
  11. use std::f64::consts::PI;
  12. pub const NUM_WIDTHS: usize = 8;
  13. thread_local!(static RNGS: RefCell<Option<ChaCha20Rng>> = RefCell::new(None));
  14. pub struct DiscreteGaussian {
  15. choices: Vec<i64>,
  16. dist: WeightedIndex<f64>,
  17. }
  18. impl DiscreteGaussian {
  19. pub fn init(params: &Params) -> Self {
  20. Self::init_impl(params, None)
  21. }
  22. pub fn init_with_seed(params: &Params, seed: Seed) -> Self {
  23. Self::init_impl(params, Some(seed))
  24. }
  25. fn init_impl(params: &Params, seed: Option<Seed>) -> Self {
  26. let max_val = (params.noise_width * (NUM_WIDTHS as f64)).ceil() as i64;
  27. let mut choices = Vec::new();
  28. let mut table = vec![0f64; 0];
  29. for i in -max_val..max_val + 1 {
  30. let p_val = f64::exp(-PI * f64::powi(i as f64, 2) / f64::powi(params.noise_width, 2));
  31. choices.push(i);
  32. table.push(p_val);
  33. }
  34. let dist = WeightedIndex::new(&table).unwrap();
  35. if seed.is_some() {
  36. RNGS.with(|f| *f.borrow_mut() = Some(ChaCha20Rng::from_seed(seed.unwrap())));
  37. } else {
  38. RNGS.with(|f| *f.borrow_mut() = Some(ChaCha20Rng::from_seed(thread_rng().gen())));
  39. }
  40. Self { choices, dist }
  41. }
  42. // FIXME: not constant-time
  43. pub fn sample(&self) -> i64 {
  44. RNGS.with(|f| {
  45. let mut rng = f.borrow_mut();
  46. let rng_mut = rng.as_mut().unwrap();
  47. self.choices[self.dist.sample(rng_mut)]
  48. })
  49. }
  50. pub fn sample_matrix(&self, p: &mut PolyMatrixRaw) {
  51. let modulus = p.get_params().modulus;
  52. for r in 0..p.rows {
  53. for c in 0..p.cols {
  54. let poly = p.get_poly_mut(r, c);
  55. for z in 0..poly.len() {
  56. let mut s = self.sample();
  57. s += modulus as i64;
  58. s %= modulus as i64; // FIXME: not constant time
  59. poly[z] = s as u64;
  60. }
  61. }
  62. }
  63. }
  64. }
  65. #[cfg(test)]
  66. mod test {
  67. use super::*;
  68. use crate::util::*;
  69. #[test]
  70. fn dg_seems_okay() {
  71. let params = get_test_params();
  72. let dg = DiscreteGaussian::init_with_seed(&params, get_chacha_seed());
  73. let mut v = Vec::new();
  74. let trials = 10000;
  75. let mut sum = 0;
  76. for _ in 0..trials {
  77. let val = dg.sample();
  78. v.push(val);
  79. sum += val;
  80. }
  81. let mean = sum as f64 / trials as f64;
  82. let std_dev = params.noise_width / f64::sqrt(2f64 * std::f64::consts::PI);
  83. let std_dev_of_mean = std_dev / f64::sqrt(trials as f64);
  84. assert!(f64::abs(mean) < std_dev_of_mean * 5f64);
  85. }
  86. }