Browse Source

More threading, fix tests, update README

Samir Menon 2 years ago
parent
commit
2274e9df9a
5 changed files with 108 additions and 71 deletions
  1. 8 5
      README.md
  2. 3 3
      client/src/lib.rs
  3. 8 1
      client/static/index.html
  4. 31 24
      spiral-rs/src/client.rs
  5. 58 38
      spiral-rs/src/server.rs

+ 8 - 5
README.md

@@ -6,8 +6,11 @@ This is an implementation of our paper "Spiral: Fast, High-Rate Single-Server PI
 
 ## Building
 
-To build this project, run `cargo build`.
-
-## Structure
-
-...
+- In `spiral-rs/spiral-rs`:
+    - To build the library `spiral-rs`, run `cargo build --release`.
+    - To run the library tests, run `cargo test`.
+    - To build the server, run `cargo build --release --bin server --features server`.
+    - To preprocess a database, run `cargo build --release --bin preprocess_db`.
+    - To run the server, run `target/release/server dbfile.dbp` with the preprocessed database file `dbfile.dbp`
+- In `spiral-rs/client`:
+    - To build the client for our Wikipedia demo, run `wasm-pack build --target web --out-dir static/pkg`

+ 3 - 3
client/src/lib.rs

@@ -1,6 +1,6 @@
 use std::convert::TryInto;
 
-use rand::{thread_rng, SeedableRng, RngCore};
+use rand::{thread_rng, RngCore, SeedableRng};
 use rand_chacha::ChaCha20Rng;
 use spiral_rs::{client::*, discrete_gaussian::*, util::*};
 use wasm_bindgen::prelude::*;
@@ -81,7 +81,7 @@ pub fn initialize(json_params: Option<String>, seed: Box<[u8]>) -> WrappedClient
 pub fn generate_public_parameters(c: &mut WrappedClient) -> Box<[u8]> {
     let res = c.client.generate_keys().serialize().into_boxed_slice();
 
-    // important to re-seed here; only query and public key are deterministic
+    // important to re-seed here; only public key is deterministic, not queries
     let mut rng = thread_rng();
     let mut new_seed = [0u8; 32];
     rng.fill_bytes(&mut new_seed);
@@ -118,4 +118,4 @@ mod test {
         let val2: u64 = Standard.sample(&mut rng2);
         assert_eq!(val1, val2);
     }
-}
+}

+ 8 - 1
client/static/index.html

@@ -51,7 +51,7 @@
 
           <h2>What costs are associated with running the demo?</h2>
           <p>
-            When making your first query, this demo will upload 18 MB of data; each later query requires only 14 KB of upload. 
+            When making your first query, this demo will upload 18 MB of data; each later query requires only 28 KB of upload. 
             The server response to each query is 250 KB.
           </p>
 
@@ -60,6 +60,13 @@
             No! This is research-quality software built for demonstration purposes; it is not intended to be 
             side-channel resistant and has not undergone any kind fo security review. Don't use this code in production.
           </p>
+
+          <h2>Who made this demo?</h2>
+          <p>
+            Samir Menon, a (currently) independent researcher in PIR and lattice-based cryptography. 
+            Please contact me with any questions, comments, and sugggestions at <a href="menon.samir@gmail.com">menon.samir@gmail.com</a>.
+            The scheme that this demo uses is "Spiral", which is a joint work with Prof. David Wu at UT Austin.
+          </p>
           </div>
       </div>
     </div>

+ 31 - 24
spiral-rs/src/client.rs

@@ -2,8 +2,8 @@ use crate::{
     arith::*, discrete_gaussian::*, gadget::*, number_theory::*, params::*, poly::*, util::*,
 };
 use rand::{Rng, SeedableRng};
-use std::{iter::once, mem::size_of};
 use rand_chacha::ChaCha20Rng;
+use std::{iter::once, mem::size_of};
 
 fn new_vec_raw<'a>(
     params: &'a Params,
@@ -262,10 +262,10 @@ impl<'a, T: Rng> Client<'a, T> {
             sk_reg_full,
             dg,
             public_rng,
-            public_seed
+            public_seed,
         }
     }
-    
+
     #[allow(dead_code)]
     pub(crate) fn get_sk_reg(&self) -> &PolyMatrixRaw<'a> {
         &self.sk_reg
@@ -383,7 +383,8 @@ impl<'a, T: Rng> Client<'a, T> {
 
         if params.expand_queries {
             // Params for expansion
-            pp.v_expansion_left = Some(self.generate_expansion_params(params.g(), params.t_exp_left));
+            pp.v_expansion_left =
+                Some(self.generate_expansion_params(params.g(), params.t_exp_left));
             pp.v_expansion_right =
                 Some(self.generate_expansion_params(params.stop_round() + 1, params.t_exp_right));
 
@@ -435,7 +436,8 @@ impl<'a, T: Rng> Client<'a, T> {
                 }
             }
             let inv_2_g_first = invert_uint_mod(1 << params.g(), params.modulus).unwrap();
-            let inv_2_g_rest = invert_uint_mod(1 << (params.stop_round() + 1), params.modulus).unwrap();
+            let inv_2_g_rest =
+                invert_uint_mod(1 << (params.stop_round() + 1), params.modulus).unwrap();
 
             for i in 0..params.poly_len / 2 {
                 sigma.data[2 * i] =
@@ -603,22 +605,21 @@ mod test {
         assert_first8(
             pub_params.v_conversion.unwrap()[0].data.as_slice(),
             [
-                122680182, 165987256, 137892309, 95732358, 221787731, 13233184, 156136764,
-                259944211,
+                48110940, 101047152, 169193903, 71831480, 48301935, 106009656, 97287006, 51905893,
             ],
         );
 
         assert_first8(
             client.sk_gsw.data.as_slice(),
             [
-                66974689739603965,
-                66974689739603965,
-                0,
+                2,
                 1,
-                0,
                 5,
-                66974689739603967,
+                66974689739603968,
                 2,
+                66974689739603966,
+                66974689739603967,
+                5,
             ],
         );
     }
@@ -641,18 +642,24 @@ mod test {
             get_vec(&pub_params.v_packing),
             get_vec(&deserialized1.v_packing)
         );
-        assert_eq!(
-            get_vec(&pub_params.v_conversion.unwrap()),
-            get_vec(&deserialized1.v_conversion.unwrap())
-        );
-        assert_eq!(
-            get_vec(&pub_params.v_expansion_left.unwrap()),
-            get_vec(&deserialized1.v_expansion_left.unwrap())
-        );
-        assert_eq!(
-            get_vec(&pub_params.v_expansion_right.unwrap()),
-            get_vec(&deserialized1.v_expansion_right.unwrap())
-        );
+        if pub_params.v_conversion.is_some() {
+            assert_eq!(
+                get_vec(&pub_params.v_conversion.unwrap()),
+                get_vec(&deserialized1.v_conversion.unwrap())
+            );
+        }
+        if pub_params.v_expansion_left.is_some() {
+            assert_eq!(
+                get_vec(&pub_params.v_expansion_left.unwrap()),
+                get_vec(&deserialized1.v_expansion_left.unwrap())
+            );
+        }
+        if pub_params.v_expansion_right.is_some() {
+            assert_eq!(
+                get_vec(&pub_params.v_expansion_right.unwrap()),
+                get_vec(&deserialized1.v_expansion_right.unwrap())
+            );
+        }
     }
 
     #[test]

+ 58 - 38
spiral-rs/src/server.rs

@@ -30,16 +30,6 @@ pub fn coefficient_expansion(
 ) {
     let poly_len = params.poly_len;
 
-    let mut ct = PolyMatrixRaw::zero(params, 2, 1);
-    let mut ct_auto = PolyMatrixRaw::zero(params, 2, 1);
-    let mut ct_auto_1 = PolyMatrixRaw::zero(params, 1, 1);
-    let mut ct_auto_1_ntt = PolyMatrixNTT::zero(params, 1, 1);
-    let mut ginv_ct_left = PolyMatrixRaw::zero(params, params.t_exp_left, 1);
-    let mut ginv_ct_left_ntt = PolyMatrixNTT::zero(params, params.t_exp_left, 1);
-    let mut ginv_ct_right = PolyMatrixRaw::zero(params, params.t_exp_right, 1);
-    let mut ginv_ct_right_ntt = PolyMatrixNTT::zero(params, params.t_exp_right, 1);
-    let mut w_times_ginv_ct = PolyMatrixNTT::zero(params, 2, 1);
-
     for r in 0..g {
         let num_in = 1 << r;
         let num_out = 2 * num_in;
@@ -48,16 +38,27 @@ pub fn coefficient_expansion(
 
         let neg1 = &v_neg1[r];
 
-        for i in 0..num_out {
+        let action_expand = |(i, v_i): (usize, &mut PolyMatrixNTT)| {
             if (stop_round > 0 && r > stop_round && (i % 2) == 1)
                 || (stop_round > 0
                     && r == stop_round
                     && (i % 2) == 1
                     && (i / 2) >= max_bits_to_gen_right)
             {
-                continue;
+                return;
             }
 
+            let mut ct = PolyMatrixRaw::zero(params, 2, 1);
+            let mut ct_auto = PolyMatrixRaw::zero(params, 2, 1);
+            let mut ct_auto_1 = PolyMatrixRaw::zero(params, 1, 1);
+            let mut ct_auto_1_ntt = PolyMatrixNTT::zero(params, 1, 1);
+            let mut w_times_ginv_ct = PolyMatrixNTT::zero(params, 2, 1);
+
+            let mut ginv_ct_left = PolyMatrixRaw::zero(params, params.t_exp_left, 1);
+            let mut ginv_ct_left_ntt = PolyMatrixNTT::zero(params, params.t_exp_left, 1);
+            let mut ginv_ct_right = PolyMatrixRaw::zero(params, params.t_exp_right, 1);
+            let mut ginv_ct_right_ntt = PolyMatrixNTT::zero(params, params.t_exp_right, 1);
+
             let (w, _gadget_dim, gi_ct, gi_ct_ntt) = match i % 2 {
                 0 => (
                     &v_w_left[r],
@@ -73,12 +74,12 @@ pub fn coefficient_expansion(
                 ),
             };
 
-            if i < num_in {
-                let (src, dest) = v.split_at_mut(num_in);
-                scalar_multiply(&mut dest[i], neg1, &src[i]);
-            }
+            // if i < num_in {
+            //     let (src, dest) = v.split_at_mut(num_in);
+            //     scalar_multiply(&mut dest[i], neg1, &src[i]);
+            // }
 
-            from_ntt(&mut ct, &v[i]);
+            from_ntt(&mut ct, &v_i);
             automorph(&mut ct_auto, &ct, t);
             gadget_invert_rdim(gi_ct, &ct_auto, 1);
             to_ntt_no_reduce(gi_ct_ntt, &gi_ct);
@@ -93,15 +94,31 @@ pub fn coefficient_expansion(
             for j in 0..2 {
                 for n in 0..params.crt_count {
                     for z in 0..poly_len {
-                        let sum = v[i].data[idx]
+                        let sum = (*v_i).data[idx]
                             + w_times_ginv_ct.data[idx]
                             + j * ct_auto_1_ntt.data[n * poly_len + z];
-                        v[i].data[idx] = barrett_coeff_u64(params, sum, n);
+                        (*v_i).data[idx] = barrett_coeff_u64(params, sum, n);
                         idx += 1;
                     }
                 }
             }
-        }
+        };
+
+        let (src, dest) = v.split_at_mut(num_in);
+        src.par_iter_mut()
+            .zip(dest.par_iter_mut())
+            .for_each(|(s, d)| {
+                scalar_multiply(d, neg1, s);
+            });
+
+        v[0..num_in]
+            .par_iter_mut()
+            .enumerate()
+            .for_each(action_expand);
+        v[num_in..num_out]
+            .par_iter_mut()
+            .enumerate()
+            .for_each(action_expand);
     }
 }
 
@@ -116,13 +133,12 @@ pub fn regev_to_gsw<'a>(
     assert!(v.rows == 2);
     assert!(v.cols == 2 * params.t_conv);
 
-    let mut ginv_c_inp = PolyMatrixRaw::zero(params, 2 * params.t_conv, 1);
-    let mut ginv_c_inp_ntt = PolyMatrixNTT::zero(params, 2 * params.t_conv, 1);
-    let mut tmp_ct_raw = PolyMatrixRaw::zero(params, 2, 1);
-    let mut tmp_ct = PolyMatrixNTT::zero(params, 2, 1);
+    v_gsw.par_iter_mut().enumerate().for_each(|(i, ct)| {
+        let mut ginv_c_inp = PolyMatrixRaw::zero(params, 2 * params.t_conv, 1);
+        let mut ginv_c_inp_ntt = PolyMatrixNTT::zero(params, 2 * params.t_conv, 1);
+        let mut tmp_ct_raw = PolyMatrixRaw::zero(params, 2, 1);
+        let mut tmp_ct = PolyMatrixNTT::zero(params, 2, 1);
 
-    for i in 0..params.db_dim_2 {
-        let ct = &mut v_gsw[i];
         for j in 0..params.t_gsw {
             let idx_ct = i * params.t_gsw + j;
             let idx_inp = idx_factor * (idx_ct) + idx_offset;
@@ -133,7 +149,7 @@ pub fn regev_to_gsw<'a>(
             multiply(&mut tmp_ct, v, &ginv_c_inp_ntt);
             ct.copy_into(&tmp_ct, 0, 2 * j);
         }
-    }
+    });
 }
 
 pub const MAX_SUMMED: usize = 1 << 6;
@@ -606,18 +622,21 @@ pub fn encode(params: &Params, v_packed_ct: &Vec<PolyMatrixRaw>) -> Vec<u8> {
 
 pub fn get_v_folding_neg<'a>(
     params: &'a Params,
-    v_folding: &Vec<PolyMatrixNTT>,
+    v_folding: &Vec<PolyMatrixNTT<'a>>,
 ) -> Vec<PolyMatrixNTT<'a>> {
-    let gadget_ntt = build_gadget(&params, 2, 2 * params.t_gsw).ntt(); // TODO: make this better
-
-    let mut v_folding_neg = Vec::new();
-    let mut ct_gsw_inv = PolyMatrixRaw::zero(&params, 2, 2 * params.t_gsw);
-    for i in 0..params.db_dim_2 {
-        invert(&mut ct_gsw_inv, &v_folding[i].raw());
-        let mut ct_gsw_neg = PolyMatrixNTT::zero(&params, 2, 2 * params.t_gsw);
-        add(&mut ct_gsw_neg, &gadget_ntt, &ct_gsw_inv.ntt());
-        v_folding_neg.push(ct_gsw_neg);
-    }
+    let gadget_ntt = build_gadget(params, 2, 2 * params.t_gsw).ntt(); // TODO: make this better
+
+    let v_folding_neg = (0..params.db_dim_2)
+        .into_par_iter()
+        .map(|i| {
+            let mut ct_gsw_inv = PolyMatrixRaw::zero(params, 2, 2 * params.t_gsw);
+            invert(&mut ct_gsw_inv, &v_folding[i].raw());
+            let mut ct_gsw_neg = PolyMatrixNTT::zero(params, 2, 2 * params.t_gsw);
+            add(&mut ct_gsw_neg, &gadget_ntt, &ct_gsw_inv.ntt());
+            ct_gsw_neg
+        })
+        .collect();
+
     v_folding_neg
 }
 
@@ -1065,6 +1084,7 @@ mod test {
     }
 
     #[test]
+    #[ignore]
     fn larger_full_protocol_is_correct() {
         let cfg_expand = r#"
             {