discrete_gaussian.rs 2.2 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677
  1. use rand::distributions::WeightedIndex;
  2. use rand::prelude::Distribution;
  3. use rand::Rng;
  4. use rand::{rngs::ThreadRng, thread_rng};
  5. use crate::params::*;
  6. use crate::poly::*;
  7. use std::f64::consts::PI;
  8. pub const NUM_WIDTHS: usize = 8;
  9. pub struct DiscreteGaussian<'a, T: Rng> {
  10. choices: Vec<i64>,
  11. dist: WeightedIndex<f64>,
  12. pub rng: &'a mut T,
  13. }
  14. impl<'a, T: Rng> DiscreteGaussian<'a, T> {
  15. pub fn init(params: &'a Params, rng: &'a mut T) -> Self {
  16. let max_val = (params.noise_width * (NUM_WIDTHS as f64)).ceil() as i64;
  17. let mut choices = Vec::new();
  18. let mut table = vec![0f64; 0];
  19. for i in -max_val..max_val + 1 {
  20. let p_val = f64::exp(-PI * f64::powi(i as f64, 2) / f64::powi(params.noise_width, 2));
  21. choices.push(i);
  22. table.push(p_val);
  23. }
  24. let dist = WeightedIndex::new(&table).unwrap();
  25. Self { choices, dist, rng }
  26. }
  27. // FIXME: not constant-time
  28. pub fn sample(&mut self) -> i64 {
  29. self.choices[self.dist.sample(&mut self.rng)]
  30. }
  31. pub fn sample_matrix(&mut self, p: &mut PolyMatrixRaw) {
  32. let modulus = p.get_params().modulus;
  33. for r in 0..p.rows {
  34. for c in 0..p.cols {
  35. let poly = p.get_poly_mut(r, c);
  36. for z in 0..poly.len() {
  37. let mut s = self.sample();
  38. s += modulus as i64;
  39. s %= modulus as i64; // FIXME: not constant time
  40. poly[z] = s as u64;
  41. }
  42. }
  43. }
  44. }
  45. }
  46. #[cfg(test)]
  47. mod test {
  48. use super::*;
  49. use crate::util::*;
  50. #[test]
  51. fn dg_seems_okay() {
  52. let params = get_test_params();
  53. let mut rng = thread_rng();
  54. let mut dg = DiscreteGaussian::init(&params, &mut rng);
  55. let mut v = Vec::new();
  56. let trials = 10000;
  57. let mut sum = 0;
  58. for _ in 0..trials {
  59. let val = dg.sample();
  60. v.push(val);
  61. sum += val;
  62. }
  63. let mean = sum as f64 / trials as f64;
  64. let std_dev = params.noise_width / f64::sqrt(2f64 * std::f64::consts::PI);
  65. let std_dev_of_mean = std_dev / f64::sqrt(trials as f64);
  66. assert!(f64::abs(mean) < std_dev_of_mean * 5f64);
  67. }
  68. }