Browse Source

begin server

Samir Menon 2 years ago
parent
commit
f40bfc6239

+ 4 - 0
spiral-rs/Cargo.toml

@@ -16,6 +16,10 @@ criterion = "0.3"
 name = "ntt"
 harness = false
 
+[[bench]]
+name = "server"
+harness = false
+
 [profile.release]
 lto = "fat"
 codegen-units = 1

+ 50 - 0
spiral-rs/benches/server.rs

@@ -0,0 +1,50 @@
+use criterion::{black_box, criterion_group, criterion_main, Criterion};
+use spiral_rs::client::*;
+use spiral_rs::poly::*;
+use spiral_rs::server::*;
+use spiral_rs::util::*;
+use std::time::Duration;
+
+fn criterion_benchmark(c: &mut Criterion) {
+    let mut group = c.benchmark_group("sample-size");
+    group
+        .sample_size(10)
+        .measurement_time(Duration::from_secs(10));
+
+    let params = get_short_keygen_params();
+    let v_neg1 = params.get_v_neg1();
+    let mut seeded_rng = get_seeded_rng();
+    let mut client = Client::init(&params, &mut seeded_rng);
+    let public_params = client.generate_keys();
+
+    let mut v = Vec::new();
+    for _ in 0..params.poly_len {
+        v.push(PolyMatrixNTT::zero(&params, 2, 1));
+    }
+    let scale_k = params.modulus / params.pt_modulus;
+    let mut sigma = PolyMatrixRaw::zero(&params, 1, 1);
+    sigma.data[7] = scale_k;
+    v[0] = client.encrypt_matrix_reg(&sigma.ntt());
+
+    let v_w_left = public_params.v_expansion_left.unwrap();
+    let v_w_right = public_params.v_expansion_right.unwrap();
+
+    group.bench_function("coeff exp", |b| {
+        b.iter(|| {
+            coefficient_expansion(
+                black_box(&mut v),
+                black_box(client.g),
+                black_box(client.stop_round),
+                black_box(&params),
+                black_box(&v_w_left),
+                black_box(&v_w_right),
+                black_box(&v_neg1),
+                black_box(params.t_gsw * params.db_dim_2),
+            )
+        });
+    });
+    group.finish();
+}
+
+criterion_group!(benches, criterion_benchmark);
+criterion_main!(benches);

+ 9 - 38
spiral-rs/src/client.rs

@@ -18,10 +18,10 @@ fn serialize_vec_polymatrix(vec: &mut Vec<u8>, a: &Vec<PolyMatrixRaw>) {
 }
 
 pub struct PublicParameters<'a> {
-    v_packing: Vec<PolyMatrixNTT<'a>>, // Ws
-    v_expansion_left: Option<Vec<PolyMatrixNTT<'a>>>,
-    v_expansion_right: Option<Vec<PolyMatrixNTT<'a>>>,
-    v_conversion: Option<Vec<PolyMatrixNTT<'a>>>, // V
+    pub v_packing: Vec<PolyMatrixNTT<'a>>, // Ws
+    pub v_expansion_left: Option<Vec<PolyMatrixNTT<'a>>>,
+    pub v_expansion_right: Option<Vec<PolyMatrixNTT<'a>>>,
+    pub v_conversion: Option<Vec<PolyMatrixNTT<'a>>>, // V
 }
 
 impl<'a> PublicParameters<'a> {
@@ -115,8 +115,8 @@ pub struct Client<'a, TRng: Rng> {
     sk_gsw_full: PolyMatrixRaw<'a>,
     sk_reg_full: PolyMatrixRaw<'a>,
     dg: DiscreteGaussian<'a, TRng>,
-    g: usize,
-    stop_round: usize,
+    pub g: usize,
+    pub stop_round: usize,
 }
 
 fn matrix_with_identity<'a>(p: &PolyMatrixRaw<'a>) -> PolyMatrixRaw<'a> {
@@ -221,7 +221,7 @@ impl<'a, TRng: Rng> Client<'a, TRng> {
         res
     }
 
-    fn encrypt_matrix_reg(&mut self, a: &PolyMatrixNTT<'a>) -> PolyMatrixNTT<'a> {
+    pub fn encrypt_matrix_reg(&mut self, a: &PolyMatrixNTT<'a>) -> PolyMatrixNTT<'a> {
         let m = a.cols;
         let p = self.get_fresh_reg_public_key(m);
         &p + &a.pad_top(1)
@@ -247,7 +247,7 @@ impl<'a, TRng: Rng> Client<'a, TRng> {
         res
     }
 
-    pub fn generate_keys(&mut self) -> PublicParameters {
+    pub fn generate_keys(&mut self) -> PublicParameters<'a> {
         let params = self.params;
         self.dg.sample_matrix(&mut self.sk_gsw);
         self.dg.sample_matrix(&mut self.sk_reg);
@@ -500,44 +500,15 @@ impl<'a, TRng: Rng> Client<'a, TRng> {
 
 #[cfg(test)]
 mod test {
-    use rand::SeedableRng;
-
     use super::*;
 
-    fn get_seed() -> [u8; 32] {
-        [
-            1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5,
-            6, 7, 8,
-        ]
-    }
-
-    fn get_seeded_rng() -> StdRng {
-        StdRng::from_seed(get_seed())
-    }
-
     fn assert_first8(m: &[u64], gold: [u64; 8]) {
         let got: [u64; 8] = m[0..8].try_into().unwrap();
         assert_eq!(got, gold);
     }
 
     fn get_params() -> Params {
-        Params::init(
-            2048,
-            &vec![268369921u64, 249561089u64],
-            6.4,
-            2,
-            256,
-            20,
-            4,
-            4,
-            4,
-            4,
-            true,
-            9,
-            6,
-            1,
-            2048,
-        )
+        get_short_keygen_params()
     }
 
     #[test]

+ 56 - 0
spiral-rs/src/gadget.rs

@@ -1,3 +1,5 @@
+use std::primitive;
+
 use crate::{params::*, poly::*};
 
 pub fn get_bits_per(params: &Params, dim: usize) -> usize {
@@ -30,3 +32,57 @@ pub fn build_gadget(params: &Params, rows: usize, cols: usize) -> PolyMatrixRaw
     }
     g
 }
+
+pub fn gadget_invert<'a>(mx: usize, inp: &PolyMatrixRaw<'a>) -> PolyMatrixRaw<'a> {
+    let params = inp.params;
+
+    let num_elems = mx / inp.rows;
+    let bits_per = get_bits_per(params, num_elems);
+    let mask = (1u64 << bits_per) - 1;
+
+    let mut out = PolyMatrixRaw::zero(params, mx, inp.cols);
+    for i in 0..inp.cols {
+        for j in 0..inp.rows {
+            for z in 0..params.poly_len {
+                let val = inp.get_poly(j, i)[z];
+                for k in 0..num_elems {
+                    let bit_offs = usize::min(k * bits_per, 64) as u64;
+                    let shifted = val.checked_shr(bit_offs as u32);
+                    let piece = match shifted {
+                        Some(x) => x & mask,
+                        None => 0,
+                    };
+
+                    out.get_poly_mut(j + k * inp.rows, i)[z] = piece;
+                }
+            }
+        }
+    }
+    out
+}
+
+#[cfg(test)]
+mod test {
+    use crate::util::get_test_params;
+
+    use super::*;
+
+    #[test]
+    fn gadget_invert_is_correct() {
+        let params = get_test_params();
+        let mut mat = PolyMatrixRaw::zero(&params, 2, 1);
+        mat.get_poly_mut(0, 0)[37] = 3;
+        mat.get_poly_mut(1, 0)[37] = 6;
+        let log_q = params.modulus_log2 as usize;
+        let result = gadget_invert(2 * log_q, &mat);
+
+        assert_eq!(result.get_poly(0, 0)[37], 1);
+        assert_eq!(result.get_poly(2, 0)[37], 1);
+        assert_eq!(result.get_poly(4, 0)[37], 0); // binary for '3'
+
+        assert_eq!(result.get_poly(1, 0)[37], 0);
+        assert_eq!(result.get_poly(3, 0)[37], 1);
+        assert_eq!(result.get_poly(5, 0)[37], 1);
+        assert_eq!(result.get_poly(7, 0)[37], 0); // binary for '6'
+    }
+}

+ 1 - 0
spiral-rs/src/lib.rs

@@ -9,3 +9,4 @@ pub mod params;
 pub mod poly;
 
 pub mod client;
+pub mod server;

+ 3 - 1
spiral-rs/src/main.rs

@@ -1,3 +1,4 @@
+use rand::thread_rng;
 use serde_json::Value;
 use spiral_rs::client::*;
 use spiral_rs::util::*;
@@ -71,7 +72,8 @@ fn main() {
     let idx_target: usize = (&args[1]).parse().unwrap();
 
     println!("initializing client");
-    let mut c = Client::init(&params);
+    let mut rng = thread_rng();
+    let mut c = Client::init(&params, &mut rng);
     println!("generating public parameters");
     let pub_params = c.generate_keys();
     let pub_params_buf = pub_params.serialize();

+ 12 - 1
spiral-rs/src/params.rs

@@ -1,4 +1,4 @@
-use crate::{arith::*, ntt::*, number_theory::*};
+use crate::{arith::*, ntt::*, number_theory::*, poly::*};
 
 pub static Q2_VALUES: [u64; 37] = [
     0,
@@ -82,6 +82,17 @@ impl Params {
         self.ntt_tables[i][3].as_slice()
     }
 
+    pub fn get_v_neg1(&self) -> Vec<PolyMatrixNTT> {
+        let mut v_neg1 = Vec::new();
+        for i in 0..self.poly_len_log2 {
+            let idx = self.poly_len - (1 << i);
+            let mut ng1 = PolyMatrixRaw::zero(&self, 1, 1);
+            ng1.data[idx] = 1;
+            v_neg1.push((-&ng1).ntt());
+        }
+        v_neg1
+    }
+
     pub fn get_sk_gsw(&self) -> (usize, usize) {
         (self.n, 1)
     }

+ 33 - 0
spiral-rs/src/poly.rs

@@ -50,6 +50,8 @@ pub trait PolyMatrix<'a> {
             }
         }
     }
+
+    fn submatrix(&self, target_row: usize, target_col: usize, rows: usize, cols: usize) -> Self;
     fn pad_top(&self, pad_rows: usize) -> Self;
 }
 
@@ -121,6 +123,21 @@ impl<'a> PolyMatrix<'a> for PolyMatrixRaw<'a> {
         padded.copy_into(&self, pad_rows, 0);
         padded
     }
+    fn submatrix(&self, target_row: usize, target_col: usize, rows: usize, cols: usize) -> Self {
+        let mut m = Self::zero(self.params, rows, cols);
+        assert!(target_row < self.rows);
+        assert!(target_col < self.cols);
+        assert!(target_row + rows <= self.rows);
+        assert!(target_col + cols <= self.cols);
+        for r in 0..rows {
+            for c in 0..cols {
+                let pol_src = self.get_poly(target_row + r, target_col + c);
+                let pol_dst = m.get_poly_mut(r, c);
+                pol_dst.copy_from_slice(pol_src);
+            }
+        }
+        m
+    }
 }
 
 impl<'a> PolyMatrixRaw<'a> {
@@ -243,6 +260,22 @@ impl<'a> PolyMatrix<'a> for PolyMatrixNTT<'a> {
         padded.copy_into(&self, pad_rows, 0);
         padded
     }
+
+    fn submatrix(&self, target_row: usize, target_col: usize, rows: usize, cols: usize) -> Self {
+        let mut m = Self::zero(self.params, rows, cols);
+        assert!(target_row < self.rows);
+        assert!(target_col < self.cols);
+        assert!(target_row + rows <= self.rows);
+        assert!(target_col + cols <= self.cols);
+        for r in 0..rows {
+            for c in 0..cols {
+                let pol_src = self.get_poly(target_row + r, target_col + c);
+                let pol_dst = m.get_poly_mut(r, c);
+                pol_dst.copy_from_slice(pol_src);
+            }
+        }
+        m
+    }
 }
 
 impl<'a> PolyMatrixNTT<'a> {

+ 107 - 0
spiral-rs/src/server.rs

@@ -0,0 +1,107 @@
+use crate::arith;
+use crate::gadget::gadget_invert;
+use crate::params::*;
+use crate::poly::*;
+
+pub fn coefficient_expansion(
+    v: &mut Vec<PolyMatrixNTT>,
+    g: usize,
+    stopround: usize,
+    params: &Params,
+    v_w_left: &Vec<PolyMatrixNTT>,
+    v_w_right: &Vec<PolyMatrixNTT>,
+    v_neg1: &Vec<PolyMatrixNTT>,
+    max_bits_to_gen_right: usize,
+) {
+    let poly_len = params.poly_len;
+
+    for r in 0..g {
+        let num_in = 1 << r;
+        let num_out = 2 * num_in;
+
+        let t = (poly_len / (1 << r)) + 1;
+
+        let neg1 = &v_neg1[r];
+
+        for i in 0..num_out {
+            if stopround > 0 && i % 2 == 1 && r > stopround
+                || (r == stopround && i / 2 >= max_bits_to_gen_right)
+            {
+                continue;
+            }
+
+            let (w, gadget_dim) = match i % 2 {
+                0 => (&v_w_left[r], params.t_exp_left),
+                1 | _ => (&v_w_right[r], params.t_exp_right),
+            };
+
+            if i < num_in {
+                let (src, dest) = v.split_at_mut(num_in);
+                scalar_multiply(&mut dest[i], neg1, &src[i]);
+            }
+
+            let ct = from_ntt_alloc(&v[i]);
+            let ct_auto = automorph_alloc(&ct, t);
+            let ct_auto_0 = ct_auto.submatrix(0, 0, 1, 1);
+            let ct_auto_1_ntt = ct_auto.submatrix(1, 0, 1, 1).ntt();
+            let ginv_ct = gadget_invert(gadget_dim, &ct_auto_0);
+            let ginv_ct_ntt = ginv_ct.ntt();
+            let w_times_ginv_ct = w * &ginv_ct_ntt;
+
+            let mut idx = 0;
+            for j in 0..2 {
+                for n in 0..params.crt_count {
+                    for z in 0..poly_len {
+                        let sum = v[i].data[idx]
+                            + w_times_ginv_ct.data[idx]
+                            + j * ct_auto_1_ntt.data[n * poly_len + z];
+                        v[i].data[idx] = arith::modular_reduce(params, sum, n);
+                        idx += 1;
+                    }
+                }
+            }
+        }
+    }
+}
+
+#[cfg(test)]
+mod test {
+    use crate::{client::*, util::*};
+
+    use super::*;
+
+    fn get_params() -> Params {
+        get_short_keygen_params()
+    }
+
+    #[test]
+    fn coefficient_expansion_is_correct() {
+        let params = get_params();
+        let v_neg1 = params.get_v_neg1();
+        let mut seeded_rng = get_seeded_rng();
+        let mut client = Client::init(&params, &mut seeded_rng);
+        let public_params = client.generate_keys();
+
+        let mut v = Vec::new();
+        for _ in 0..params.poly_len {
+            v.push(PolyMatrixNTT::zero(&params, 2, 1));
+        }
+        let scale_k = params.modulus / params.pt_modulus;
+        let mut sigma = PolyMatrixRaw::zero(&params, 1, 1);
+        sigma.data[7] = scale_k;
+        v[0] = client.encrypt_matrix_reg(&sigma.ntt());
+
+        let v_w_left = public_params.v_expansion_left.unwrap();
+        let v_w_right = public_params.v_expansion_right.unwrap();
+        coefficient_expansion(
+            &mut v,
+            client.g,
+            client.stop_round,
+            &params,
+            &v_w_left,
+            &v_w_right,
+            &v_neg1,
+            params.t_gsw * params.db_dim_2,
+        );
+    }
+}

+ 32 - 0
spiral-rs/src/util.rs

@@ -1,4 +1,5 @@
 use crate::params::*;
+use rand::{prelude::StdRng, SeedableRng};
 use serde_json::Value;
 
 pub fn calc_index(indices: &[usize], lengths: &[usize]) -> usize {
@@ -31,6 +32,37 @@ pub fn get_test_params() -> Params {
     )
 }
 
+pub fn get_short_keygen_params() -> Params {
+    Params::init(
+        2048,
+        &vec![268369921u64, 249561089u64],
+        6.4,
+        2,
+        256,
+        20,
+        4,
+        4,
+        4,
+        4,
+        true,
+        9,
+        6,
+        1,
+        2048,
+    )
+}
+
+pub fn get_seed() -> [u8; 32] {
+    [
+        1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6,
+        7, 8,
+    ]
+}
+
+pub fn get_seeded_rng() -> StdRng {
+    StdRng::from_seed(get_seed())
+}
+
 pub const fn get_empty_params() -> Params {
     Params {
         poly_len: 0,