Browse Source

enable seeded randomness

Samir Menon 2 years ago
parent
commit
246d7ff8c1

+ 4 - 4
spiral-rs/Cargo.lock

@@ -310,9 +310,9 @@ dependencies = [
 
 [[package]]
 name = "getrandom"
-version = "0.2.4"
+version = "0.2.6"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "418d37c8b1d42553c93648be529cb70f920d3baf8ef469b74b9638df426e0b4c"
+checksum = "9be70c98951c83b8d2f8f60d7065fa6d5146873094452a1008da8c2f1e4205ad"
 dependencies = [
  "cfg-if",
  "js-sys",
@@ -506,9 +506,9 @@ checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646"
 
 [[package]]
 name = "libc"
-version = "0.2.119"
+version = "0.2.122"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "1bf2e165bb3457c8e098ea76f3e3bc9db55f87aa90d52d0e6be741470916aaa4"
+checksum = "ec647867e2bf0772e28c8bcde4f0d19a9216916e890543b5a03ed8ef27b8f259"
 
 [[package]]
 name = "log"

+ 1 - 1
spiral-rs/Cargo.toml

@@ -4,7 +4,7 @@ version = "0.1.0"
 edition = "2021"
 
 [dependencies]
-getrandom = { features = ["js"] }
+getrandom = { features = ["js"], version = "0.2.6" }
 rand = { version = "0.8.5" }
 reqwest = { version = "0.11", features = ["blocking"] }
 serde_json = "1.0"

+ 89 - 7
spiral-rs/src/client.rs

@@ -1,6 +1,8 @@
 use crate::{
     arith::*, discrete_gaussian::*, gadget::*, number_theory::*, params::*, poly::*, util::*,
 };
+use rand::rngs::StdRng;
+use rand::{thread_rng, Rng};
 use std::iter::once;
 
 fn serialize_polymatrix(vec: &mut Vec<u8>, a: &PolyMatrixRaw) {
@@ -106,13 +108,13 @@ impl<'a> Query<'a> {
     }
 }
 
-pub struct Client<'a> {
+pub struct Client<'a, TRng: Rng> {
     params: &'a Params,
     sk_gsw: PolyMatrixRaw<'a>,
     sk_reg: PolyMatrixRaw<'a>,
     sk_gsw_full: PolyMatrixRaw<'a>,
     sk_reg_full: PolyMatrixRaw<'a>,
-    dg: DiscreteGaussian,
+    dg: DiscreteGaussian<'a, TRng>,
     g: usize,
     stop_round: usize,
 }
@@ -145,15 +147,15 @@ fn params_with_moduli(params: &Params, moduli: &Vec<u64>) -> Params {
     )
 }
 
-impl<'a> Client<'a> {
-    pub fn init(params: &'a Params) -> Self {
+impl<'a, TRng: Rng> Client<'a, TRng> {
+    pub fn init(params: &'a Params, rng: &'a mut TRng) -> Self {
         let sk_gsw_dims = params.get_sk_gsw();
         let sk_reg_dims = params.get_sk_reg();
         let sk_gsw = PolyMatrixRaw::zero(params, sk_gsw_dims.0, sk_gsw_dims.1);
         let sk_reg = PolyMatrixRaw::zero(params, sk_reg_dims.0, sk_reg_dims.1);
         let sk_gsw_full = matrix_with_identity(&sk_gsw);
         let sk_reg_full = matrix_with_identity(&sk_reg);
-        let dg = DiscreteGaussian::init(params);
+        let dg = DiscreteGaussian::init(params, rng);
 
         let further_dims = params.db_dim_2;
         let num_expanded = 1usize << params.db_dim_1;
@@ -172,11 +174,15 @@ impl<'a> Client<'a> {
         }
     }
 
+    pub fn get_rng(&mut self) -> &mut TRng {
+        &mut self.dg.rng
+    }
+
     fn get_fresh_gsw_public_key(&mut self, m: usize) -> PolyMatrixRaw<'a> {
         let params = self.params;
         let n = params.n;
 
-        let a = PolyMatrixRaw::random(params, 1, m);
+        let a = PolyMatrixRaw::random_rng(params, 1, m, self.get_rng());
         let e = PolyMatrixRaw::noise(params, n, m, &mut self.dg);
         let a_inv = -&a;
         let b_p = &self.sk_gsw.ntt() * &a.ntt();
@@ -187,7 +193,7 @@ impl<'a> Client<'a> {
 
     fn get_regev_sample(&mut self) -> PolyMatrixNTT<'a> {
         let params = self.params;
-        let a = PolyMatrixRaw::random(params, 1, 1);
+        let a = PolyMatrixRaw::random_rng(params, 1, 1, self.get_rng());
         let e = PolyMatrixRaw::noise(params, 1, 1, &mut self.dg);
         let b_p = &self.sk_reg.ntt() * &a.ntt();
         let b = &e.ntt() + &b_p;
@@ -491,3 +497,79 @@ impl<'a> Client<'a> {
         result.to_vec(p_bits as usize, modp_words_per_chunk)
     }
 }
+
+#[cfg(test)]
+mod test {
+    use rand::SeedableRng;
+
+    use super::*;
+
+    fn get_seed() -> [u8; 32] {
+        [
+            1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5,
+            6, 7, 8,
+        ]
+    }
+
+    fn get_seeded_rng() -> StdRng {
+        StdRng::from_seed(get_seed())
+    }
+
+    fn assert_first8(m: &[u64], gold: [u64; 8]) {
+        let got: [u64; 8] = m[0..8].try_into().unwrap();
+        assert_eq!(got, gold);
+    }
+
+    fn get_params() -> Params {
+        Params::init(
+            2048,
+            &vec![268369921u64, 249561089u64],
+            6.4,
+            2,
+            256,
+            20,
+            4,
+            4,
+            4,
+            4,
+            true,
+            9,
+            6,
+            1,
+            2048,
+        )
+    }
+
+    #[test]
+    fn init_is_correct() {
+        let params = get_params();
+        let mut rng = thread_rng();
+        let client = Client::init(&params, &mut rng);
+
+        assert_eq!(client.stop_round, 6);
+        assert_eq!(client.g, 10);
+        assert_eq!(*client.params, params);
+    }
+
+    #[test]
+    fn keygen_is_correct() {
+        let params = get_params();
+        let mut seeded_rng = get_seeded_rng();
+        let mut client = Client::init(&params, &mut seeded_rng);
+
+        let public_params = client.generate_keys();
+
+        assert_first8(
+            &public_params.v_conversion.unwrap()[0].data,
+            [
+                253586619, 247235120, 141892996, 163163429, 15531298, 200914775, 125109567,
+                75889562,
+            ],
+        );
+
+        assert_first8(
+            &client.sk_gsw.data,
+            [1, 5, 0, 3, 1, 3, 66974689739603967, 3],
+        );
+    }
+}

+ 8 - 10
spiral-rs/src/discrete_gaussian.rs

@@ -1,5 +1,6 @@
 use rand::distributions::WeightedIndex;
 use rand::prelude::Distribution;
+use rand::Rng;
 use rand::{rngs::ThreadRng, thread_rng};
 
 use crate::params::*;
@@ -8,14 +9,14 @@ use std::f64::consts::PI;
 
 pub const NUM_WIDTHS: usize = 8;
 
-pub struct DiscreteGaussian {
+pub struct DiscreteGaussian<'a, T: Rng> {
     choices: Vec<i64>,
     dist: WeightedIndex<f64>,
-    rng: ThreadRng,
+    pub rng: &'a mut T,
 }
 
-impl DiscreteGaussian {
-    pub fn init(params: &Params) -> Self {
+impl<'a, T: Rng> DiscreteGaussian<'a, T> {
+    pub fn init(params: &'a Params, rng: &'a mut T) -> Self {
         let max_val = (params.noise_width * (NUM_WIDTHS as f64)).ceil() as i64;
         let mut choices = Vec::new();
         let mut table = vec![0f64; 0];
@@ -26,11 +27,7 @@ impl DiscreteGaussian {
         }
         let dist = WeightedIndex::new(&table).unwrap();
 
-        Self {
-            choices,
-            dist,
-            rng: thread_rng(),
-        }
+        Self { choices, dist, rng }
     }
 
     // FIXME: not constant-time
@@ -62,7 +59,8 @@ mod test {
     #[test]
     fn dg_seems_okay() {
         let params = get_test_params();
-        let mut dg = DiscreteGaussian::init(&params);
+        let mut rng = thread_rng();
+        let mut dg = DiscreteGaussian::init(&params, &mut rng);
         let mut v = Vec::new();
         let trials = 10000;
         let mut sum = 0;

+ 1 - 1
spiral-rs/src/main.rs

@@ -50,7 +50,7 @@ fn main() {
         'instances': 11,
         'db_item_size': 100000 }
     "#;
-    let cfg_direct = r#"
+    let _cfg_direct = r#"
         {'kinda_direct_upload': 1,
         'n': 5,
         'nu_1': 11,

+ 17 - 6
spiral-rs/src/poly.rs

@@ -1,4 +1,3 @@
-use core::num;
 #[cfg(target_feature = "avx2")]
 use std::arch::x86_64::*;
 
@@ -20,6 +19,7 @@ pub trait PolyMatrix<'a> {
     fn num_words(&self) -> usize;
     fn zero(params: &'a Params, rows: usize, cols: usize) -> Self;
     fn random(params: &'a Params, rows: usize, cols: usize) -> Self;
+    fn random_rng<T: Rng>(params: &'a Params, rows: usize, cols: usize, rng: &mut T) -> Self;
     fn as_slice(&self) -> &[u64];
     fn as_mut_slice(&mut self) -> &mut [u64];
     fn zero_out(&mut self) {
@@ -99,8 +99,7 @@ impl<'a> PolyMatrix<'a> for PolyMatrixRaw<'a> {
             data,
         }
     }
-    fn random(params: &'a Params, rows: usize, cols: usize) -> Self {
-        let rng = rand::thread_rng();
+    fn random_rng<T: Rng>(params: &'a Params, rows: usize, cols: usize, rng: &mut T) -> Self {
         let mut iter = rng.sample_iter(&Standard);
         let mut out = PolyMatrixRaw::zero(params, rows, cols);
         for r in 0..rows {
@@ -113,6 +112,10 @@ impl<'a> PolyMatrix<'a> for PolyMatrixRaw<'a> {
         }
         out
     }
+    fn random(params: &'a Params, rows: usize, cols: usize) -> Self {
+        let mut rng = rand::thread_rng();
+        Self::random_rng(params, rows, cols, &mut rng)
+    }
     fn pad_top(&self, pad_rows: usize) -> Self {
         let mut padded = Self::zero(self.params, self.rows + pad_rows, self.cols);
         padded.copy_into(&self, pad_rows, 0);
@@ -137,7 +140,12 @@ impl<'a> PolyMatrixRaw<'a> {
         }
     }
 
-    pub fn noise(params: &'a Params, rows: usize, cols: usize, dg: &mut DiscreteGaussian) -> Self {
+    pub fn noise<T: Rng>(
+        params: &'a Params,
+        rows: usize,
+        cols: usize,
+        dg: &mut DiscreteGaussian<T>,
+    ) -> Self {
         let mut out = PolyMatrixRaw::zero(params, rows, cols);
         dg.sample_matrix(&mut out);
         out
@@ -210,8 +218,7 @@ impl<'a> PolyMatrix<'a> for PolyMatrixNTT<'a> {
             data,
         }
     }
-    fn random(params: &'a Params, rows: usize, cols: usize) -> Self {
-        let rng = rand::thread_rng();
+    fn random_rng<T: Rng>(params: &'a Params, rows: usize, cols: usize, rng: &mut T) -> Self {
         let mut iter = rng.sample_iter(&Standard);
         let mut out = PolyMatrixNTT::zero(params, rows, cols);
         for r in 0..rows {
@@ -227,6 +234,10 @@ impl<'a> PolyMatrix<'a> for PolyMatrixNTT<'a> {
         }
         out
     }
+    fn random(params: &'a Params, rows: usize, cols: usize) -> Self {
+        let mut rng = rand::thread_rng();
+        Self::random_rng(params, rows, cols, &mut rng)
+    }
     fn pad_top(&self, pad_rows: usize) -> Self {
         let mut padded = Self::zero(self.params, self.rows + pad_rows, self.cols);
         padded.copy_into(&self, pad_rows, 0);