Browse Source

Prep to use seed for public parts of Regev cts

Samir Menon 1 year ago
parent
commit
08972fe23b
8 changed files with 57 additions and 46 deletions
  1. 13 0
      README.md
  2. 2 2
      client/src/lib.rs
  3. 0 10
      client/src/utils.rs
  4. 1 0
      spiral-rs/Cargo.lock
  5. 1 0
      spiral-rs/Cargo.toml
  6. 2 2
      spiral-rs/benches/server.rs
  7. 33 27
      spiral-rs/src/client.rs
  8. 5 5
      spiral-rs/src/server.rs

+ 13 - 0
README.md

@@ -0,0 +1,13 @@
+# Spiral: Fast, High-Rate Single-Server PIR via FHE Composition
+
+This is an implementation of our paper "Spiral: Fast, High-Rate Single-Server PIR via FHE Composition", available [here](https://eprint.iacr.org/2022/368.pdf). 
+
+> **WARNING**: This is research-quality code; it has not been checked for side-channel leakage or basic logical or memory safety issues. Do not use this in production.
+
+## Building
+
+To build this project, run `cargo build`.
+
+## Structure
+
+...

+ 2 - 2
client/src/lib.rs

@@ -1,9 +1,8 @@
-mod utils;
 use std::convert::TryInto;
 
 use rand::{thread_rng, SeedableRng, RngCore};
 use rand_chacha::ChaCha20Rng;
-use spiral_rs::{client::*, discrete_gaussian::*, params::*, util::*};
+use spiral_rs::{client::*, discrete_gaussian::*, util::*};
 use wasm_bindgen::prelude::*;
 
 const UUID_V4_LEN: usize = 36;
@@ -14,6 +13,7 @@ extern "C" {
     #[wasm_bindgen(js_namespace = console)]
     fn log(s: &str);
 }
+#[allow(unused_macros)]
 macro_rules! console_log {
     ($($t:tt)*) => (log(&format_args!($($t)*).to_string()))
 }

+ 0 - 10
client/src/utils.rs

@@ -1,10 +0,0 @@
-#[cfg(feature = "console_error_panic_hook")]
-pub fn set_panic_hook() {
-    // When the `console_error_panic_hook` feature is enabled, we can call the
-    // `set_panic_hook` function at least once during initialization, and then
-    // we will get better error messages if our code ever panics.
-    //
-    // For more details see
-    // https://github.com/rustwasm/console_error_panic_hook#readme
-    console_error_panic_hook::set_once();
-}

+ 1 - 0
spiral-rs/Cargo.lock

@@ -1745,6 +1745,7 @@ dependencies = [
  "getrandom",
  "pprof",
  "rand",
+ "rand_chacha",
  "rayon",
  "reqwest",
  "serde",

+ 1 - 0
spiral-rs/Cargo.toml

@@ -21,6 +21,7 @@ getrandom = { features = ["js"], version = "0.2.6" }
 rand = { version = "0.8.5", features = ["small_rng"] }
 serde_json = "1.0"
 rayon = "1.5.2"
+rand_chacha = "0.3.1"
 
 reqwest = { version = "0.11", features = ["blocking"], optional = true }
 

+ 2 - 2
spiral-rs/benches/server.rs

@@ -92,8 +92,8 @@ fn criterion_benchmark(c: &mut Criterion) {
         b.iter(|| {
             coefficient_expansion(
                 black_box(&mut v),
-                black_box(client.g),
-                black_box(client.stop_round),
+                black_box(params.g()),
+                black_box(params.stop_round()),
                 black_box(&params),
                 black_box(&v_w_left),
                 black_box(&v_w_right),

+ 33 - 27
spiral-rs/src/client.rs

@@ -1,8 +1,9 @@
 use crate::{
     arith::*, discrete_gaussian::*, gadget::*, number_theory::*, params::*, poly::*, util::*,
 };
-use rand::Rng;
+use rand::{Rng, SeedableRng};
 use std::{iter::once, mem::size_of};
+use rand_chacha::ChaCha20Rng;
 
 fn new_vec_raw<'a>(
     params: &'a Params,
@@ -202,15 +203,15 @@ impl<'a> Query<'a> {
     }
 }
 
-pub struct Client<'a, TRng: Rng> {
+pub struct Client<'a, T: Rng> {
     params: &'a Params,
     sk_gsw: PolyMatrixRaw<'a>,
-    pub sk_reg: PolyMatrixRaw<'a>,
+    sk_reg: PolyMatrixRaw<'a>,
     sk_gsw_full: PolyMatrixRaw<'a>,
     sk_reg_full: PolyMatrixRaw<'a>,
-    dg: DiscreteGaussian<'a, TRng>,
-    pub g: usize,
-    pub stop_round: usize,
+    dg: DiscreteGaussian<'a, T>,
+    public_rng: ChaCha20Rng,
+    public_seed: <ChaCha20Rng as SeedableRng>::Seed,
 }
 
 fn matrix_with_identity<'a>(p: &PolyMatrixRaw<'a>) -> PolyMatrixRaw<'a> {
@@ -241,21 +242,18 @@ fn params_with_moduli(params: &Params, moduli: &Vec<u64>) -> Params {
     )
 }
 
-impl<'a, TRng: Rng> Client<'a, TRng> {
-    pub fn init(params: &'a Params, rng: &'a mut TRng) -> Self {
+impl<'a, T: Rng> Client<'a, T> {
+    pub fn init(params: &'a Params, rng: &'a mut T) -> Self {
         let sk_gsw_dims = params.get_sk_gsw();
         let sk_reg_dims = params.get_sk_reg();
         let sk_gsw = PolyMatrixRaw::zero(params, sk_gsw_dims.0, sk_gsw_dims.1);
         let sk_reg = PolyMatrixRaw::zero(params, sk_reg_dims.0, sk_reg_dims.1);
         let sk_gsw_full = matrix_with_identity(&sk_gsw);
         let sk_reg_full = matrix_with_identity(&sk_reg);
+        let mut public_seed = [0u8; 32];
+        rng.fill_bytes(&mut public_seed);
+        let public_rng = ChaCha20Rng::from_seed(public_seed);
         let dg = DiscreteGaussian::init(params, rng);
-
-        let further_dims = params.db_dim_2;
-        let num_expanded = 1usize << params.db_dim_1;
-        let num_bits_to_gen = params.t_gsw * further_dims + num_expanded;
-        let g = log2_ceil_usize(num_bits_to_gen);
-        let stop_round = log2_ceil_usize(params.t_gsw * further_dims);
         Self {
             params,
             sk_gsw,
@@ -263,15 +261,28 @@ impl<'a, TRng: Rng> Client<'a, TRng> {
             sk_gsw_full,
             sk_reg_full,
             dg,
-            g,
-            stop_round,
+            public_rng,
+            public_seed
         }
     }
+    
+    #[allow(dead_code)]
+    pub(crate) fn get_sk_reg(&self) -> &PolyMatrixRaw<'a> {
+        &self.sk_reg
+    }
 
-    pub fn get_rng(&mut self) -> &mut TRng {
+    pub fn get_rng(&mut self) -> &mut T {
         &mut self.dg.rng
     }
 
+    pub fn get_public_rng(&mut self) -> &mut ChaCha20Rng {
+        &mut self.public_rng
+    }
+
+    pub fn get_public_seed(&mut self) -> <ChaCha20Rng as SeedableRng>::Seed {
+        self.public_seed
+    }
+
     fn get_fresh_gsw_public_key(&mut self, m: usize) -> PolyMatrixRaw<'a> {
         let params = self.params;
         let n = params.n;
@@ -287,7 +298,7 @@ impl<'a, TRng: Rng> Client<'a, TRng> {
 
     fn get_regev_sample(&mut self) -> PolyMatrixNTT<'a> {
         let params = self.params;
-        let a = PolyMatrixRaw::random_rng(params, 1, 1, self.get_rng());
+        let a = PolyMatrixRaw::random_rng(params, 1, 1, self.get_public_rng());
         let e = PolyMatrixRaw::noise(params, 1, 1, &mut self.dg);
         let b_p = &self.sk_reg.ntt() * &a.ntt();
         let b = &e.ntt() + &b_p;
@@ -372,9 +383,9 @@ impl<'a, TRng: Rng> Client<'a, TRng> {
 
         if params.expand_queries {
             // Params for expansion
-            pp.v_expansion_left = Some(self.generate_expansion_params(self.g, params.t_exp_left));
+            pp.v_expansion_left = Some(self.generate_expansion_params(params.g(), params.t_exp_left));
             pp.v_expansion_right =
-                Some(self.generate_expansion_params(self.stop_round + 1, params.t_exp_right));
+                Some(self.generate_expansion_params(params.stop_round() + 1, params.t_exp_right));
 
             // Params for converison
             let g_conv = build_gadget(params, 2, 2 * params.t_conv);
@@ -423,8 +434,8 @@ impl<'a, TRng: Rng> Client<'a, TRng> {
                     sigma.data[2 * idx + 1] = val;
                 }
             }
-            let inv_2_g_first = invert_uint_mod(1 << self.g, params.modulus).unwrap();
-            let inv_2_g_rest = invert_uint_mod(1 << (self.stop_round + 1), params.modulus).unwrap();
+            let inv_2_g_first = invert_uint_mod(1 << params.g(), params.modulus).unwrap();
+            let inv_2_g_rest = invert_uint_mod(1 << (params.stop_round() + 1), params.modulus).unwrap();
 
             for i in 0..params.poly_len / 2 {
                 sigma.data[2 * i] =
@@ -578,10 +589,6 @@ mod test {
         let mut rng = thread_rng();
         let client = Client::init(&params, &mut rng);
 
-        assert_eq!(client.stop_round, 5);
-        assert_eq!(client.stop_round, params.stop_round());
-        assert_eq!(client.g, 10);
-        assert_eq!(client.g, params.g());
         assert_eq!(*client.params, params);
     }
 
@@ -624,7 +631,6 @@ mod test {
         let mut seeded_rng = get_static_seeded_rng();
         let mut client = Client::init(&params, &mut seeded_rng);
         let pub_params = client.generate_keys();
-        assert_eq!(client.stop_round, params.stop_round());
 
         let serialized1 = pub_params.serialize();
         let deserialized1 = PublicParameters::deserialize(&params, &serialized1);

+ 5 - 5
spiral-rs/src/server.rs

@@ -830,8 +830,8 @@ mod test {
         let v_w_right = public_params.v_expansion_right.unwrap();
         coefficient_expansion(
             &mut v,
-            client.g,
-            client.stop_round,
+            params.g(),
+            params.stop_round(),
             &params,
             &v_w_left,
             &v_w_right,
@@ -981,7 +981,7 @@ mod test {
                 let sigma_ntt = to_ntt_alloc(&sigma);
                 let ct = client.encrypt_matrix_reg(&sigma_ntt);
                 ct_gsw.copy_into(&ct, 0, 2 * j + 1);
-                let prod = &to_ntt_alloc(&client.sk_reg) * &sigma_ntt;
+                let prod = &to_ntt_alloc(client.get_sk_reg()) * &sigma_ntt;
                 let ct = &client.encrypt_matrix_reg(&prod);
                 ct_gsw.copy_into(&ct, 0, 2 * j);
             }
@@ -1011,7 +1011,7 @@ mod test {
     fn full_protocol_is_correct_for_params(params: &Params) {
         let mut seeded_rng = get_seeded_rng();
 
-        let target_idx = 22456; //22456;//seeded_rng.gen::<usize>() % (params.db_dim_1 + params.db_dim_2);
+        let target_idx = seeded_rng.gen::<usize>() % (params.db_dim_1 + params.db_dim_2);
 
         let mut client = Client::init(params, &mut seeded_rng);
 
@@ -1037,7 +1037,7 @@ mod test {
     fn full_protocol_is_correct_for_params_real_db(params: &Params) {
         let mut seeded_rng = get_seeded_rng();
 
-        let target_idx = 22456; //seeded_rng.gen::<usize>() % (params.db_dim_1 + params.db_dim_2);
+        let target_idx = seeded_rng.gen::<usize>() % (params.db_dim_1 + params.db_dim_2);
 
         let mut client = Client::init(params, &mut seeded_rng);