// discrete_gaussian.rs
  1. use rand::distributions::WeightedIndex;
  2. use rand::prelude::Distribution;
  3. use rand::Rng;
  4. use std::cell::{RefCell, RefMut};
  5. use thread_local::ThreadLocal;
  6. use crate::params::*;
  7. use crate::poly::*;
  8. use std::f64::consts::PI;
  /// Truncation bound of the sampler's support, in multiples of the noise width.
  ///
  /// `DiscreteGaussian::init` only considers integers whose magnitude is at
  /// most `ceil(noise_width * NUM_WIDTHS)`.
  pub const NUM_WIDTHS: usize = 8;
  10. pub struct DiscreteGaussian<T: Rng + Send> {
  11. choices: Vec<i64>,
  12. dist: WeightedIndex<f64>,
  13. rng: ThreadLocal<RefCell<T>>,
  14. rnggen: fn() -> T,
  15. }
  16. impl<T: Rng + Send> DiscreteGaussian<T> {
  17. pub fn init(params: &Params, rnggen: fn() -> T) -> Self {
  18. let max_val = (params.noise_width * (NUM_WIDTHS as f64)).ceil() as i64;
  19. let mut choices = Vec::new();
  20. let mut table = vec![0f64; 0];
  21. for i in -max_val..max_val + 1 {
  22. let p_val = f64::exp(-PI * f64::powi(i as f64, 2) / f64::powi(params.noise_width, 2));
  23. choices.push(i);
  24. table.push(p_val);
  25. }
  26. let dist = WeightedIndex::new(&table).unwrap();
  27. Self { choices, dist, rng: ThreadLocal::new(), rnggen }
  28. }
  29. // FIXME: not constant-time
  30. fn sample_from_members(choices: &Vec<i64>, dist: &WeightedIndex<f64>,
  31. rng: &mut T) -> i64 {
  32. choices[dist.sample(rng)]
  33. }
  34. pub fn get_rng(&self) -> RefMut<T> {
  35. self.rng.get_or(|| RefCell::new((self.rnggen)())).borrow_mut()
  36. }
  37. #[cfg(test)]
  38. fn sample(&self) -> i64 {
  39. let mut rng = self.get_rng();
  40. Self::sample_from_members(&self.choices, &self.dist, &mut *rng)
  41. }
  42. pub fn sample_matrix(&self, p: &mut PolyMatrixRaw) {
  43. let modulus = p.get_params().modulus;
  44. let choices = &self.choices;
  45. let dist = &self.dist;
  46. let rng = &mut *self.get_rng();
  47. for r in 0..p.rows {
  48. for c in 0..p.cols {
  49. let poly = p.get_poly_mut(r, c);
  50. for z in 0..poly.len() {
  51. let mut s = Self::sample_from_members(choices, dist, rng);
  52. s += modulus as i64;
  53. s %= modulus as i64; // FIXME: not constant time
  54. poly[z] = s as u64;
  55. }
  56. }
  57. }
  58. }
  59. }
  #[cfg(test)]
  mod test {
      use rand::SeedableRng;
      use rand_chacha::ChaCha20Rng;

      use super::*;
      use crate::util::*;

      /// Statistical smoke test: the empirical mean of many samples must lie
      /// within five standard errors of zero.
      #[test]
      fn dg_seems_okay() {
          let params = get_test_params();
          let dg = DiscreteGaussian::init(&params, ChaCha20Rng::from_entropy);

          let trials = 10_000;
          let sum: i64 = (0..trials).map(|_| dg.sample()).sum();
          let mean = sum as f64 / trials as f64;

          // The sampler weights are proportional to exp(-PI * x^2 / w^2), which
          // matches a Gaussian with sigma = w / sqrt(2 * PI).
          let std_dev = params.noise_width / f64::sqrt(2f64 * std::f64::consts::PI);
          let std_dev_of_mean = std_dev / f64::sqrt(trials as f64);

          assert!(
              f64::abs(mean) < std_dev_of_mean * 5f64,
              "sample mean {} is more than 5 standard errors ({}) from zero",
              mean,
              std_dev_of_mean
          );
      }
  }
  83. }