Browse Source

Use seeds to encode query + pub param data

Samir Menon 1 year ago
parent
commit
b4c7162f03

File diff suppressed because it is too large
+ 0 - 0
params_store.json


+ 2 - 1
spiral-rs/benches/server.rs

@@ -72,6 +72,7 @@ fn criterion_benchmark(c: &mut Criterion) {
     let params = get_expansion_testing_params();
     let v_neg1 = params.get_v_neg1();
     let mut seeded_rng = get_seeded_rng();
+    let mut chacha_rng = get_chacha_rng();
     let mut client = Client::init(&params, &mut seeded_rng);
     let public_params = client.generate_keys();
 
@@ -82,7 +83,7 @@ fn criterion_benchmark(c: &mut Criterion) {
     let scale_k = params.modulus / params.pt_modulus;
     let mut sigma = PolyMatrixRaw::zero(&params, 1, 1);
     sigma.data[7] = scale_k;
-    v[0] = client.encrypt_matrix_reg(&sigma.ntt());
+    v[0] = client.encrypt_matrix_reg(&sigma.ntt(), &mut chacha_rng);
 
     let v_w_left = public_params.v_expansion_left.unwrap();
     let v_w_right = public_params.v_expansion_right.unwrap();

+ 15 - 6
spiral-rs/src/bin/e2e.rs

@@ -1,10 +1,10 @@
-use rand::Rng;
 use rand::thread_rng;
+use rand::Rng;
+use spiral_rs::arith::*;
 use spiral_rs::client::*;
+use spiral_rs::params::*;
 use spiral_rs::server::*;
 use spiral_rs::util::*;
-use spiral_rs::arith::*;
-use spiral_rs::params::*;
 use std::env;
 use std::fs;
 use std::time::Instant;
@@ -12,7 +12,12 @@ use std::time::Instant;
 fn print_params_summary(params: &Params) {
     let db_elem_size = params.item_size();
     let total_size = params.num_items() * db_elem_size;
-    println!("Using a {} x {} byte database ({} bytes total)", params.num_items(), db_elem_size, total_size);    
+    println!(
+        "Using a {} x {} byte database ({} bytes total)",
+        params.num_items(),
+        db_elem_size,
+        total_size
+    );
 }
 
 fn main() {
@@ -36,7 +41,11 @@ fn main() {
     let mut rng = thread_rng();
     let idx_target: usize = rng.gen::<usize>() % params.num_items();
 
-    println!("fetching index {} out of {} items", idx_target, params.num_items());
+    println!(
+        "fetching index {} out of {} items",
+        idx_target,
+        params.num_items()
+    );
     println!("initializing client");
     let mut client = Client::init(&params, &mut rng);
     println!("generating public parameters");
@@ -68,4 +77,4 @@ fn main() {
     }
 
     println!("completed correctly!");
-}
+}

+ 166 - 82
spiral-rs/src/client.rs

@@ -1,10 +1,13 @@
 use crate::{
     arith::*, discrete_gaussian::*, gadget::*, number_theory::*, params::*, poly::*, util::*,
 };
-use rand::{Rng, SeedableRng};
+use rand::{thread_rng, Rng, SeedableRng};
 use rand_chacha::ChaCha20Rng;
 use std::{iter::once, mem::size_of};
 
+type Seed = <ChaCha20Rng as SeedableRng>::Seed;
+const SEED_LENGTH: usize = 32;
+
 fn new_vec_raw<'a>(
     params: &'a Params,
     num: usize,
@@ -18,43 +21,95 @@ fn new_vec_raw<'a>(
     v
 }
 
-fn serialize_polymatrix(vec: &mut Vec<u8>, a: &PolyMatrixRaw) {
-    for i in 0..a.rows * a.cols * a.params.poly_len {
-        vec.extend_from_slice(&u64::to_ne_bytes(a.data[i]));
-    }
+fn get_inv_from_rng(params: &Params, rng: &mut ChaCha20Rng) -> u64 {
+    params.modulus - (rng.gen::<u64>() % params.modulus)
 }
 
-fn serialize_vec_polymatrix(vec: &mut Vec<u8>, a: &Vec<PolyMatrixRaw>) {
-    for i in 0..a.len() {
-        serialize_polymatrix(vec, &a[i]);
+fn mat_sz_bytes_excl_first_row(a: &PolyMatrixRaw) -> usize {
+    (a.rows - 1) * a.cols * a.params.poly_len * size_of::<u64>()
+}
+
+fn serialize_polymatrix_for_rng(vec: &mut Vec<u8>, a: &PolyMatrixRaw) {
+    let offs = a.cols * a.params.poly_len; // skip the first row
+    for i in 0..(a.rows - 1) * a.cols * a.params.poly_len {
+        vec.extend_from_slice(&u64::to_ne_bytes(a.data[offs + i]));
     }
 }
 
-fn mat_sz_bytes(a: &PolyMatrixRaw) -> usize {
-    a.rows * a.cols * a.params.poly_len * size_of::<u64>()
+fn serialize_vec_polymatrix_for_rng(vec: &mut Vec<u8>, a: &Vec<PolyMatrixRaw>) {
+    for i in 0..a.len() {
+        serialize_polymatrix_for_rng(vec, &a[i]);
+    }
 }
 
-fn deserialize_polymatrix(a: &mut PolyMatrixRaw, data: &[u8]) -> usize {
+fn deserialize_polymatrix_rng(a: &mut PolyMatrixRaw, data: &[u8], rng: &mut ChaCha20Rng) -> usize {
+    let (first_row, rest) = a
+        .data
+        .as_mut_slice()
+        .split_at_mut(a.cols * a.params.poly_len);
+    for i in 0..first_row.len() {
+        first_row[i] = get_inv_from_rng(a.params, rng);
+    }
     for (i, chunk) in data.chunks(size_of::<u64>()).enumerate() {
-        a.data[i] = u64::from_ne_bytes(chunk.try_into().unwrap());
+        rest[i] = u64::from_ne_bytes(chunk.try_into().unwrap());
     }
-    mat_sz_bytes(a)
+    mat_sz_bytes_excl_first_row(a)
 }
 
-fn deserialize_vec_polymatrix(a: &mut Vec<PolyMatrixRaw>, data: &[u8]) -> usize {
-    let mut chunks = data.chunks(mat_sz_bytes(&a[0]));
+fn deserialize_vec_polymatrix_rng(
+    a: &mut Vec<PolyMatrixRaw>,
+    data: &[u8],
+    rng: &mut ChaCha20Rng,
+) -> usize {
+    let mut chunks = data.chunks(mat_sz_bytes_excl_first_row(&a[0]));
     let mut bytes_read = 0;
     for i in 0..a.len() {
-        bytes_read += deserialize_polymatrix(&mut a[i], chunks.next().unwrap());
+        bytes_read += deserialize_polymatrix_rng(&mut a[i], chunks.next().unwrap(), rng);
     }
     bytes_read
 }
 
+fn extract_excl_rng_data(v_buf: &[u64]) -> Vec<u64> {
+    let mut out = Vec::new();
+    for i in 0..v_buf.len() {
+        if i % 2 == 1 {
+            out.push(v_buf[i]);
+        }
+    }
+    out
+}
+
+fn interleave_rng_data(params: &Params, v_buf: &[u64], rng: &mut ChaCha20Rng) -> Vec<u64> {
+    let mut out = Vec::new();
+
+    let mut reg_cts = Vec::new();
+    for _ in 0..params.num_expanded() {
+        let mut sigma = PolyMatrixRaw::zero(&params, 2, 1);
+        for z in 0..params.poly_len {
+            sigma.data[z] = get_inv_from_rng(params, rng);
+        }
+        reg_cts.push(sigma.ntt());
+    }
+    // reorient into server's preferred indexing
+    let reg_cts_buf_words = params.num_expanded() * 2 * params.poly_len;
+    let mut reg_cts_buf = vec![0u64; reg_cts_buf_words];
+    reorient_reg_ciphertexts(params, reg_cts_buf.as_mut_slice(), &reg_cts);
+
+    assert_eq!(reg_cts_buf_words, 2 * v_buf.len());
+
+    for i in 0..v_buf.len() {
+        out.push(reg_cts_buf[2 * i]);
+        out.push(v_buf[i]);
+    }
+    out
+}
+
 pub struct PublicParameters<'a> {
     pub v_packing: Vec<PolyMatrixNTT<'a>>, // Ws
     pub v_expansion_left: Option<Vec<PolyMatrixNTT<'a>>>,
     pub v_expansion_right: Option<Vec<PolyMatrixNTT<'a>>>,
     pub v_conversion: Option<Vec<PolyMatrixNTT<'a>>>, // V
+    pub seed: Option<Seed>,
 }
 
 impl<'a> PublicParameters<'a> {
@@ -65,6 +120,7 @@ impl<'a> PublicParameters<'a> {
                 v_expansion_left: Some(Vec::new()),
                 v_expansion_right: Some(Vec::new()),
                 v_conversion: Some(Vec::new()),
+                seed: None,
             }
         } else {
             PublicParameters {
@@ -72,6 +128,7 @@ impl<'a> PublicParameters<'a> {
                 v_expansion_left: None,
                 v_expansion_right: None,
                 v_conversion: None,
+                seed: None,
             }
         }
     }
@@ -101,38 +158,47 @@ impl<'a> PublicParameters<'a> {
 
     pub fn serialize(&self) -> Vec<u8> {
         let mut data = Vec::new();
+        if self.seed.is_some() {
+            let seed = self.seed.as_ref().unwrap();
+            data.extend(seed);
+        }
         for v in self.to_raw().iter() {
             if v.is_some() {
-                serialize_vec_polymatrix(&mut data, v.as_ref().unwrap());
+                serialize_vec_polymatrix_for_rng(&mut data, v.as_ref().unwrap());
             }
         }
         data
     }
 
     pub fn deserialize(params: &'a Params, data: &[u8]) -> Self {
-        assert_eq!(params.setup_bytes(), data.len());
+        assert_eq!(SEED_LENGTH + params.setup_bytes(), data.len());
 
         let mut idx = 0;
 
+        let seed = data[0..SEED_LENGTH].try_into().unwrap();
+        let mut rng = ChaCha20Rng::from_seed(seed);
+        idx += SEED_LENGTH;
+
         let mut v_packing = new_vec_raw(params, params.n, params.n + 1, params.t_conv);
-        idx += deserialize_vec_polymatrix(&mut v_packing, &data[idx..]);
+        idx += deserialize_vec_polymatrix_rng(&mut v_packing, &data[idx..], &mut rng);
 
         if params.expand_queries {
             let mut v_expansion_left = new_vec_raw(params, params.g(), 2, params.t_exp_left);
-            idx += deserialize_vec_polymatrix(&mut v_expansion_left, &data[idx..]);
+            idx += deserialize_vec_polymatrix_rng(&mut v_expansion_left, &data[idx..], &mut rng);
 
             let mut v_expansion_right =
                 new_vec_raw(params, params.stop_round() + 1, 2, params.t_exp_right);
-            idx += deserialize_vec_polymatrix(&mut v_expansion_right, &data[idx..]);
+            idx += deserialize_vec_polymatrix_rng(&mut v_expansion_right, &data[idx..], &mut rng);
 
             let mut v_conversion = new_vec_raw(params, 1, 2, 2 * params.t_conv);
-            _ = deserialize_vec_polymatrix(&mut v_conversion, &data[idx..]);
+            _ = deserialize_vec_polymatrix_rng(&mut v_conversion, &data[idx..], &mut rng);
 
             Self {
                 v_packing: Self::to_ntt_alloc_vec(&v_packing).unwrap(),
                 v_expansion_left: Self::to_ntt_alloc_vec(&v_expansion_left),
                 v_expansion_right: Self::to_ntt_alloc_vec(&v_expansion_right),
                 v_conversion: Self::to_ntt_alloc_vec(&v_conversion),
+                seed: Some(seed),
             }
         } else {
             Self {
@@ -140,6 +206,7 @@ impl<'a> PublicParameters<'a> {
                 v_expansion_left: None,
                 v_expansion_right: None,
                 v_conversion: None,
+                seed: Some(seed),
             }
         }
     }
@@ -149,6 +216,7 @@ pub struct Query<'a> {
     pub ct: Option<PolyMatrixRaw<'a>>,
     pub v_buf: Option<Vec<u64>>,
     pub v_ct: Option<Vec<PolyMatrixRaw<'a>>>,
+    pub seed: Option<Seed>,
 }
 
 impl<'a> Query<'a> {
@@ -157,46 +225,57 @@ impl<'a> Query<'a> {
             ct: None,
             v_ct: None,
             v_buf: None,
+            seed: None,
         }
     }
 
     pub fn serialize(&self) -> Vec<u8> {
         let mut data = Vec::new();
+        if self.seed.is_some() {
+            let seed = self.seed.as_ref().unwrap();
+            data.extend(seed);
+        }
         if self.ct.is_some() {
             let ct = self.ct.as_ref().unwrap();
-            serialize_polymatrix(&mut data, &ct);
+            serialize_polymatrix_for_rng(&mut data, &ct);
         }
         if self.v_buf.is_some() {
             let v_buf = self.v_buf.as_ref().unwrap();
-            data.extend(v_buf.iter().map(|x| x.to_ne_bytes()).flatten());
+            let v_buf_extracted = extract_excl_rng_data(&v_buf);
+            data.extend(v_buf_extracted.iter().map(|x| x.to_ne_bytes()).flatten());
         }
         if self.v_ct.is_some() {
             let v_ct = self.v_ct.as_ref().unwrap();
             for x in v_ct {
-                serialize_polymatrix(&mut data, x);
+                serialize_polymatrix_for_rng(&mut data, x);
             }
         }
         data
     }
 
-    pub fn deserialize(params: &'a Params, data: &[u8]) -> Self {
-        assert_eq!(params.query_bytes(), data.len());
+    pub fn deserialize(params: &'a Params, mut data: &[u8]) -> Self {
+        assert_eq!(SEED_LENGTH + params.query_bytes(), data.len());
 
         let mut out = Query::empty();
+        let seed = data[0..SEED_LENGTH].try_into().unwrap();
+        out.seed = Some(seed);
+        let mut rng = ChaCha20Rng::from_seed(seed);
+        data = &data[SEED_LENGTH..];
         if params.expand_queries {
             let mut ct = PolyMatrixRaw::zero(params, 2, 1);
-            deserialize_polymatrix(&mut ct, data);
+            deserialize_polymatrix_rng(&mut ct, data, &mut rng);
             out.ct = Some(ct);
         } else {
             let v_buf_bytes = params.query_v_buf_bytes();
-            let v_buf = (&data[..v_buf_bytes])
+            let v_buf: Vec<u64> = (&data[..v_buf_bytes])
                 .chunks(size_of::<u64>())
                 .map(|x| u64::from_ne_bytes(x.try_into().unwrap()))
                 .collect();
-            out.v_buf = Some(v_buf);
+            let v_buf_interleaved = interleave_rng_data(params, &v_buf, &mut rng);
+            out.v_buf = Some(v_buf_interleaved);
 
             let mut v_ct = new_vec_raw(params, params.db_dim_2, 2, 2 * params.t_gsw);
-            deserialize_vec_polymatrix(&mut v_ct, &data[v_buf_bytes..]);
+            deserialize_vec_polymatrix_rng(&mut v_ct, &data[v_buf_bytes..], &mut rng);
             out.v_ct = Some(v_ct);
         }
         out
@@ -210,8 +289,8 @@ pub struct Client<'a, T: Rng> {
     sk_gsw_full: PolyMatrixRaw<'a>,
     sk_reg_full: PolyMatrixRaw<'a>,
     dg: DiscreteGaussian<'a, T>,
-    public_rng: ChaCha20Rng,
-    public_seed: <ChaCha20Rng as SeedableRng>::Seed,
+    pp_seed: Seed,
+    query_seed: Seed,
 }
 
 fn matrix_with_identity<'a>(p: &PolyMatrixRaw<'a>) -> PolyMatrixRaw<'a> {
@@ -250,9 +329,10 @@ impl<'a, T: Rng> Client<'a, T> {
         let sk_reg = PolyMatrixRaw::zero(params, sk_reg_dims.0, sk_reg_dims.1);
         let sk_gsw_full = matrix_with_identity(&sk_gsw);
         let sk_reg_full = matrix_with_identity(&sk_reg);
-        let mut public_seed = [0u8; 32];
-        rng.fill_bytes(&mut public_seed);
-        let public_rng = ChaCha20Rng::from_seed(public_seed);
+        let mut pp_seed = [0u8; 32];
+        let mut query_seed = [0u8; 32];
+        rng.fill_bytes(&mut query_seed);
+        rng.fill_bytes(&mut pp_seed);
         let dg = DiscreteGaussian::init(params, rng);
         Self {
             params,
@@ -261,8 +341,8 @@ impl<'a, T: Rng> Client<'a, T> {
             sk_gsw_full,
             sk_reg_full,
             dg,
-            public_rng,
-            public_seed,
+            pp_seed,
+            query_seed,
         }
     }
 
@@ -275,19 +355,11 @@ impl<'a, T: Rng> Client<'a, T> {
         &mut self.dg.rng
     }
 
-    pub fn get_public_rng(&mut self) -> &mut ChaCha20Rng {
-        &mut self.public_rng
-    }
-
-    pub fn get_public_seed(&mut self) -> <ChaCha20Rng as SeedableRng>::Seed {
-        self.public_seed
-    }
-
-    fn get_fresh_gsw_public_key(&mut self, m: usize) -> PolyMatrixRaw<'a> {
+    fn get_fresh_gsw_public_key(&mut self, m: usize, rng: &mut ChaCha20Rng) -> PolyMatrixRaw<'a> {
         let params = self.params;
         let n = params.n;
 
-        let a = PolyMatrixRaw::random_rng(params, 1, m, self.get_rng());
+        let a = PolyMatrixRaw::random_rng(params, 1, m, rng);
         let e = PolyMatrixRaw::noise(params, n, m, &mut self.dg);
         let a_inv = -&a;
         let b_p = &self.sk_gsw.ntt() * &a.ntt();
@@ -296,9 +368,9 @@ impl<'a, T: Rng> Client<'a, T> {
         p
     }
 
-    fn get_regev_sample(&mut self) -> PolyMatrixNTT<'a> {
+    fn get_regev_sample(&mut self, rng: &mut ChaCha20Rng) -> PolyMatrixNTT<'a> {
         let params = self.params;
-        let a = PolyMatrixRaw::random_rng(params, 1, 1, self.get_public_rng());
+        let a = PolyMatrixRaw::random_rng(params, 1, 1, rng);
         let e = PolyMatrixRaw::noise(params, 1, 1, &mut self.dg);
         let b_p = &self.sk_reg.ntt() * &a.ntt();
         let b = &e.ntt() + &b_p;
@@ -308,27 +380,35 @@ impl<'a, T: Rng> Client<'a, T> {
         p
     }
 
-    fn get_fresh_reg_public_key(&mut self, m: usize) -> PolyMatrixNTT<'a> {
+    fn get_fresh_reg_public_key(&mut self, m: usize, rng: &mut ChaCha20Rng) -> PolyMatrixNTT<'a> {
         let params = self.params;
 
         let mut p = PolyMatrixNTT::zero(params, 2, m);
 
         for i in 0..m {
-            p.copy_into(&self.get_regev_sample(), 0, i);
+            p.copy_into(&self.get_regev_sample(rng), 0, i);
         }
         p
     }
 
-    fn encrypt_matrix_gsw(&mut self, ag: &PolyMatrixNTT<'a>) -> PolyMatrixNTT<'a> {
+    fn encrypt_matrix_gsw(
+        &mut self,
+        ag: &PolyMatrixNTT<'a>,
+        rng: &mut ChaCha20Rng,
+    ) -> PolyMatrixNTT<'a> {
         let mx = ag.cols;
-        let p = self.get_fresh_gsw_public_key(mx);
+        let p = self.get_fresh_gsw_public_key(mx, rng);
         let res = &(p.ntt()) + &(ag.pad_top(1));
         res
     }
 
-    pub fn encrypt_matrix_reg(&mut self, a: &PolyMatrixNTT<'a>) -> PolyMatrixNTT<'a> {
+    pub fn encrypt_matrix_reg(
+        &mut self,
+        a: &PolyMatrixNTT<'a>,
+        rng: &mut ChaCha20Rng,
+    ) -> PolyMatrixNTT<'a> {
         let m = a.cols;
-        let p = self.get_fresh_reg_public_key(m);
+        let p = self.get_fresh_reg_public_key(m, rng);
         &p + &a.pad_top(1)
     }
 
@@ -344,6 +424,7 @@ impl<'a, T: Rng> Client<'a, T> {
         &mut self,
         num_exp: usize,
         m_exp: usize,
+        rng: &mut ChaCha20Rng,
     ) -> Vec<PolyMatrixNTT<'a>> {
         let params = self.params;
         let g_exp = build_gadget(params, 1, m_exp);
@@ -354,7 +435,7 @@ impl<'a, T: Rng> Client<'a, T> {
             let t = (params.poly_len / (1 << i)) + 1;
             let tau_sk_reg = automorph_alloc(&self.sk_reg, t);
             let prod = &tau_sk_reg.ntt() * &g_exp_ntt;
-            let w_exp_i = self.encrypt_matrix_reg(&prod);
+            let w_exp_i = self.encrypt_matrix_reg(&prod, rng);
             res.push(w_exp_i);
         }
         res
@@ -369,6 +450,8 @@ impl<'a, T: Rng> Client<'a, T> {
         let sk_reg_ntt = to_ntt_alloc(&self.sk_reg);
 
         let mut pp = PublicParameters::init(params);
+        pp.seed = Some(self.pp_seed);
+        let mut rng = ChaCha20Rng::from_seed(self.pp_seed);
 
         // Params for packing
         let gadget_conv = build_gadget(params, 1, params.t_conv);
@@ -377,16 +460,19 @@ impl<'a, T: Rng> Client<'a, T> {
             let scaled = scalar_multiply_alloc(&sk_reg_ntt, &gadget_conv_ntt);
             let mut ag = PolyMatrixNTT::zero(params, params.n, params.t_conv);
             ag.copy_into(&scaled, i, 0);
-            let w = self.encrypt_matrix_gsw(&ag);
+            let w = self.encrypt_matrix_gsw(&ag, &mut rng);
             pp.v_packing.push(w);
         }
 
         if params.expand_queries {
             // Params for expansion
             pp.v_expansion_left =
-                Some(self.generate_expansion_params(params.g(), params.t_exp_left));
-            pp.v_expansion_right =
-                Some(self.generate_expansion_params(params.stop_round() + 1, params.t_exp_right));
+                Some(self.generate_expansion_params(params.g(), params.t_exp_left, &mut rng));
+            pp.v_expansion_right = Some(self.generate_expansion_params(
+                params.stop_round() + 1,
+                params.t_exp_right,
+                &mut rng,
+            ));
 
             // Params for converison
             let g_conv = build_gadget(params, 2, 2 * params.t_conv);
@@ -406,7 +492,7 @@ impl<'a, T: Rng> Client<'a, T> {
                     let val = g_conv.get_poly(1, i)[0];
                     sigma = &sk_reg_ntt * &single_poly(params, val).ntt();
                 }
-                let ct = self.encrypt_matrix_reg(&sigma);
+                let ct = self.encrypt_matrix_reg(&sigma, &mut rng);
                 pp.v_conversion.as_mut().unwrap()[0].copy_into(&ct, 0, i);
             }
         }
@@ -423,6 +509,8 @@ impl<'a, T: Rng> Client<'a, T> {
         let bits_per = get_bits_per(params, params.t_gsw);
 
         let mut query = Query::empty();
+        query.seed = Some(self.query_seed);
+        let mut rng = ChaCha20Rng::from_seed(self.query_seed);
         if params.expand_queries {
             // pack query into single ciphertext
             let mut sigma = PolyMatrixRaw::zero(params, 1, 1);
@@ -433,12 +521,11 @@ impl<'a, T: Rng> Client<'a, T> {
             if params.db_dim_2 == 0 {
                 sigma.data[idx_dim0] = scale_k;
                 for i in 0..params.poly_len {
-                    sigma.data[i] =
-                        multiply_uint_mod(sigma.data[i], inv_2_g_first, params.modulus);
+                    sigma.data[i] = multiply_uint_mod(sigma.data[i], inv_2_g_first, params.modulus);
                 }
             } else {
                 sigma.data[2 * idx_dim0] = scale_k;
-            
+
                 for i in 0..further_dims as u64 {
                     let bit: u64 = ((idx_further as u64) & (1 << i)) >> i;
                     for j in 0..params.t_gsw {
@@ -457,7 +544,7 @@ impl<'a, T: Rng> Client<'a, T> {
             }
 
             query.ct = Some(from_ntt_alloc(
-                &self.encrypt_matrix_reg(&to_ntt_alloc(&sigma)),
+                &self.encrypt_matrix_reg(&to_ntt_alloc(&sigma), &mut rng),
             ));
         } else {
             let num_expanded = 1 << params.db_dim_1;
@@ -470,7 +557,7 @@ impl<'a, T: Rng> Client<'a, T> {
             for i in 0..num_expanded {
                 let value = ((i == idx_dim0) as u64) * scale_k;
                 let sigma = PolyMatrixRaw::single_value(&params, value);
-                reg_cts.push(self.encrypt_matrix_reg(&to_ntt_alloc(&sigma)));
+                reg_cts.push(self.encrypt_matrix_reg(&to_ntt_alloc(&sigma), &mut rng));
             }
             // reorient into server's preferred indexing
             reorient_reg_ciphertexts(self.params, reg_cts_buf.as_mut_slice(), &reg_cts);
@@ -484,11 +571,14 @@ impl<'a, T: Rng> Client<'a, T> {
                     let value = (1u64 << (bits_per * j)) * bit;
                     let sigma = PolyMatrixRaw::single_value(&params, value);
                     let sigma_ntt = to_ntt_alloc(&sigma);
-                    let ct = &self.encrypt_matrix_reg(&sigma_ntt);
-                    ct_gsw.copy_into(ct, 0, 2 * j + 1);
+
+                    // important to rng in the right order here
                     let prod = &to_ntt_alloc(&self.sk_reg) * &sigma_ntt;
-                    let ct = &self.encrypt_matrix_reg(&prod);
+                    let ct = &self.encrypt_matrix_reg(&prod, &mut rng);
                     ct_gsw.copy_into(ct, 0, 2 * j);
+
+                    let ct = &self.encrypt_matrix_reg(&sigma_ntt, &mut rng);
+                    ct_gsw.copy_into(ct, 0, 2 * j + 1);
                 }
                 sigma_v.push(ct_gsw);
             }
@@ -615,22 +705,13 @@ mod test {
         assert_first8(
             pub_params.v_conversion.unwrap()[0].data.as_slice(),
             [
-                48110940, 101047152, 169193903, 71831480, 48301935, 106009656, 97287006, 51905893,
+                77639594, 190195027, 25198243, 245727165, 99660925, 135957601, 187643645, 116322041,
             ],
         );
 
         assert_first8(
             client.sk_gsw.data.as_slice(),
-            [
-                2,
-                1,
-                5,
-                66974689739603968,
-                2,
-                66974689739603966,
-                66974689739603967,
-                5,
-            ],
+            [2, 66974689739603966, 66974689739603967, 5, 2, 1, 1, 0],
         );
     }
 
@@ -713,7 +794,10 @@ mod test {
         let deserialized1 = Query::deserialize(&params, &serialized1);
         let serialized2 = deserialized1.serialize();
 
-        assert_eq!(serialized1, serialized2);
+        assert_eq!(serialized1.len(), serialized2.len());
+        for i in 0..serialized1.len() {
+            assert_eq!(serialized1[i], serialized2[i], "at {}", i);
+        }
     }
 
     #[test]

+ 13 - 9
spiral-rs/src/params.rs

@@ -134,19 +134,23 @@ impl Params {
     }
 
     pub fn factor_on_first_dim(&self) -> usize {
-        if self.db_dim_2 == 0 { 1 } else { 2 }
+        if self.db_dim_2 == 0 {
+            1
+        } else {
+            2
+        }
     }
 
     pub fn setup_bytes(&self) -> usize {
         let mut sz_polys = 0;
 
-        let packing_sz = (self.n + 1) * self.t_conv;
+        let packing_sz = ((self.n + 1) - 1) * self.t_conv;
         sz_polys += self.n * packing_sz;
 
         if self.expand_queries {
-            let expansion_left_sz = self.g() * 2 * self.t_exp_left;
-            let expansion_right_sz = (self.stop_round() + 1) * 2 * self.t_exp_right;
-            let conversion_sz = 2 * (2 * self.t_conv);
+            let expansion_left_sz = self.g() * self.t_exp_left;
+            let expansion_right_sz = (self.stop_round() + 1) * self.t_exp_right;
+            let conversion_sz = 2 * self.t_conv;
 
             sz_polys += expansion_left_sz + expansion_right_sz + conversion_sz;
         }
@@ -159,10 +163,10 @@ impl Params {
         let sz_polys;
 
         if self.expand_queries {
-            sz_polys = 2;
+            sz_polys = 1;
         } else {
-            let first_dimension_sz = self.num_expanded() * 2;
-            let further_dimension_sz = self.db_dim_2 * 2 * (2 * self.t_gsw);
+            let first_dimension_sz = self.num_expanded();
+            let further_dimension_sz = self.db_dim_2 * (2 * self.t_gsw);
             sz_polys = first_dimension_sz + further_dimension_sz;
         }
 
@@ -171,7 +175,7 @@ impl Params {
     }
 
     pub fn query_v_buf_bytes(&self) -> usize {
-        self.num_expanded() * 2 * self.poly_len * size_of::<u64>()
+        self.num_expanded() * self.poly_len * size_of::<u64>()
     }
 
     pub fn bytes_per_chunk(&self) -> usize {

+ 18 - 19
spiral-rs/src/server.rs

@@ -367,7 +367,11 @@ pub fn generate_random_db_and_get_item<'a>(
                 db_item.reduce_mod(params.pt_modulus);
 
                 if i == item_idx {
-                    item.copy_into(&db_item, instance * params.n + trial / params.n, trial % params.n);
+                    item.copy_into(
+                        &db_item,
+                        instance * params.n + trial / params.n,
+                        trial % params.n,
+                    );
                 }
 
                 for z in 0..params.poly_len {
@@ -513,7 +517,7 @@ pub fn fold_ciphertexts(
     if v_cts.len() == 1 {
         return;
     }
-    
+
     let further_dims = log2(v_cts.len() as u64) as usize;
     let ell = v_folding[0].cols / 2;
     let mut ginv_c = PolyMatrixRaw::zero(&params, 2 * ell, 1);
@@ -689,16 +693,7 @@ pub fn expand_query<'a>(
             v_gsw_inp.push(v[2 * i + 1].clone());
         }
     } else {
-        coefficient_expansion(
-            &mut v,
-            g,
-            0,
-            params,
-            &v_w_left,
-            &v_w_left,
-            &v_neg1,
-            0,
-        );
+        coefficient_expansion(&mut v, g, 0, params, &v_w_left, &v_w_left, &v_neg1, 0);
         for i in 0..dim0 {
             v_reg_inp.push(v[i].clone());
         }
@@ -849,6 +844,7 @@ mod test {
         let params = get_params();
         let v_neg1 = params.get_v_neg1();
         let mut seeded_rng = get_seeded_rng();
+        let mut chacha_rng = get_chacha_rng();
         let mut client = Client::init(&params, &mut seeded_rng);
         let public_params = client.generate_keys();
 
@@ -861,8 +857,8 @@ mod test {
         let scale_k = params.modulus / params.pt_modulus;
         let mut sigma = PolyMatrixRaw::zero(&params, 1, 1);
         sigma.data[target] = scale_k;
-        v[0] = client.encrypt_matrix_reg(&sigma.ntt());
-        let test_ct = client.encrypt_matrix_reg(&sigma.ntt());
+        v[0] = client.encrypt_matrix_reg(&sigma.ntt(), &mut chacha_rng);
+        let test_ct = client.encrypt_matrix_reg(&sigma.ntt(), &mut chacha_rng);
 
         let v_w_left = public_params.v_expansion_left.unwrap();
         let v_w_right = public_params.v_expansion_right.unwrap();
@@ -893,13 +889,14 @@ mod test {
         let mut params = get_params();
         params.db_dim_2 = 1;
         let mut seeded_rng = get_seeded_rng();
+        let mut chacha_rng = get_chacha_rng();
         let mut client = Client::init(&params, &mut seeded_rng);
         let public_params = client.generate_keys();
 
         let mut enc_constant = |val| {
             let mut sigma = PolyMatrixRaw::zero(&params, 1, 1);
             sigma.data[0] = val;
-            client.encrypt_matrix_reg(&sigma.ntt())
+            client.encrypt_matrix_reg(&sigma.ntt(), &mut chacha_rng)
         };
 
         let v = &public_params.v_conversion.unwrap()[0];
@@ -929,6 +926,7 @@ mod test {
     fn multiply_reg_by_database_is_correct() {
         let params = get_params();
         let mut seeded_rng = get_seeded_rng();
+        let mut chacha_rng = get_chacha_rng();
 
         let dim0 = 1 << params.db_dim_1;
         let num_per = 1 << params.db_dim_2;
@@ -947,7 +945,7 @@ mod test {
         for i in 0..dim0 {
             let val = if i == target_idx_dim0 { scale_k } else { 0 };
             let sigma = PolyMatrixRaw::single_value(&params, val).ntt();
-            v_reg.push(client.encrypt_matrix_reg(&sigma));
+            v_reg.push(client.encrypt_matrix_reg(&sigma, &mut chacha_rng));
         }
 
         let v_reg_sz = dim0 * 2 * params.poly_len;
@@ -984,6 +982,7 @@ mod test {
     fn fold_ciphertexts_is_correct() {
         let params = get_params();
         let mut seeded_rng = get_seeded_rng();
+        let mut chacha_rng = get_chacha_rng();
 
         let dim0 = 1 << params.db_dim_1;
         let num_per = 1 << params.db_dim_2;
@@ -999,7 +998,7 @@ mod test {
         for i in 0..num_per {
             let val = if i == target_idx_num_per { scale_k } else { 0 };
             let sigma = PolyMatrixRaw::single_value(&params, val).ntt();
-            v_reg.push(client.encrypt_matrix_reg(&sigma));
+            v_reg.push(client.encrypt_matrix_reg(&sigma, &mut chacha_rng));
         }
 
         let mut v_reg_raw = Vec::new();
@@ -1017,10 +1016,10 @@ mod test {
                 let value = (1u64 << (bits_per * j)) * bit;
                 let sigma = PolyMatrixRaw::single_value(&params, value);
                 let sigma_ntt = to_ntt_alloc(&sigma);
-                let ct = client.encrypt_matrix_reg(&sigma_ntt);
+                let ct = client.encrypt_matrix_reg(&sigma_ntt, &mut chacha_rng);
                 ct_gsw.copy_into(&ct, 0, 2 * j + 1);
                 let prod = &to_ntt_alloc(client.get_sk_reg()) * &sigma_ntt;
-                let ct = &client.encrypt_matrix_reg(&prod);
+                let ct = &client.encrypt_matrix_reg(&prod, &mut chacha_rng);
                 ct_gsw.copy_into(&ct, 0, 2 * j);
             }
 

+ 16 - 5
spiral-rs/src/util.rs

@@ -1,5 +1,6 @@
 use crate::{arith::*, params::*, poly::*};
 use rand::{prelude::SmallRng, thread_rng, Rng, SeedableRng};
+use rand_chacha::ChaCha20Rng;
 use serde_json::Value;
 use std::fs;
 
@@ -140,6 +141,10 @@ pub fn get_seeded_rng() -> SmallRng {
     SmallRng::seed_from_u64(get_seed())
 }
 
+pub fn get_chacha_rng() -> ChaCha20Rng {
+    ChaCha20Rng::from_seed(thread_rng().gen::<[u8; 32]>())
+}
+
 pub fn get_static_seed() -> u64 {
     0x123456789
 }
@@ -225,18 +230,24 @@ pub fn params_from_json_obj(v: &Value) -> Params {
 static ALL_PARAMS_STORE_FNAME: &str = "../params_store.json";
 
 pub fn get_params_from_store(target_num_log2: usize, item_size: usize) -> Params {
-    
     let params_store_str = fs::read_to_string(ALL_PARAMS_STORE_FNAME).unwrap();
     let v: Value = serde_json::from_str(&params_store_str).unwrap();
     let nearest_target_num = target_num_log2;
     let nearest_item_size = 1 << usize::max(log2_ceil_usize(item_size), 8);
-    println!("Starting with parameters for 2^{} x {} bytes...", nearest_target_num, nearest_item_size);
-    let target = v.as_array().unwrap().iter()
-        .map(|x| x.as_object().unwrap() )
+    println!(
+        "Starting with parameters for 2^{} x {} bytes...",
+        nearest_target_num, nearest_item_size
+    );
+    let target = v
+        .as_array()
+        .unwrap()
+        .iter()
+        .map(|x| x.as_object().unwrap())
         .filter(|x| x.get("target_num").unwrap().as_u64().unwrap() == (nearest_target_num as u64))
         .filter(|x| x.get("item_size").unwrap().as_u64().unwrap() == (nearest_item_size as u64))
         .map(|x| x.get("params").unwrap())
-        .next().unwrap();
+        .next()
+        .unwrap();
     params_from_json_obj(target)
 }
 

Some files were not shown because too many files changed in this diff