Browse Source

Add support for nu_2=0, and e2e test binary

Samir Menon 2 years ago
parent
commit
31bf0df166

File diff suppressed because it is too large
+ 0 - 0
params_store.json


+ 3 - 0
spiral-rs/Cargo.toml

@@ -16,6 +16,9 @@ required-features = ["client"]
 name = "server"
 required-features = ["server"]
 
+[[bin]]
+name = "e2e"
+
 [dependencies]
 getrandom = { features = ["js"], version = "0.2.6" }
 rand = { version = "0.8.5", features = ["small_rng"] }

+ 1 - 0
spiral-rs/benches/server.rs

@@ -113,6 +113,7 @@ fn criterion_benchmark(c: &mut Criterion) {
     for i in 0..db_size_words {
         db[i] = seeded_rng.gen();
     }
+    println!("{} polys in db", trials * num_items);
 
     let v_reg_sz = dim0 * 2 * params.poly_len;
     let mut v_reg_reoriented = AlignedMemory64::new(v_reg_sz);

+ 4 - 0
spiral-rs/src/arith.rs

@@ -133,6 +133,10 @@ pub fn barrett_raw_u64(input: u64, const_ratio_1: u64, modulus: u64) -> u64 {
     }
 }
 
+pub fn barrett_u64(params: &Params, val: u64) -> u64 {
+    barrett_raw_u64(val, params.barrett_cr_1_modulus, params.modulus)
+}
+
 pub fn barrett_coeff_u64(params: &Params, val: u64, n: usize) -> u64 {
     barrett_raw_u64(val, params.barrett_cr_1[n], params.moduli[n])
 }

+ 70 - 0
spiral-rs/src/bin/e2e.rs

@@ -0,0 +1,70 @@
+use rand::Rng;
+use rand::thread_rng;
+use spiral_rs::client::*;
+use spiral_rs::server::*;
+use spiral_rs::util::*;
+use spiral_rs::arith::*;
+use spiral_rs::params::*;
+use std::env;
+use std::fs;
+use std::time::Instant;
+
+fn print_params_summary(params: &Params) {
+    let total_size = params.num_items() * params.db_item_size;
+    println!("{} x {} database ({} bytes total)", params.num_items(), params.db_item_size, total_size);    
+}
+
+fn main() {
+    let params;
+    let args: Vec<String> = env::args().collect();
+
+    if args.len() == 2 {
+        let inp_params_fname = &args[1];
+        let params_json_str = fs::read_to_string(inp_params_fname).unwrap();
+
+        params = params_from_json(&params_json_str);
+    } else {
+        let target_num_log2: usize = args[1].parse().unwrap();
+        let item_size_bytes: usize = args[2].parse().unwrap();
+
+        params = get_params_from_store(target_num_log2, item_size_bytes);
+    }
+
+    print_params_summary(&params);
+
+    let mut rng = thread_rng();
+    let idx_target: usize = rng.gen::<usize>() % params.num_items();
+
+    println!("fetching index {} out of {} items", idx_target, params.num_items());
+    println!("initializing client");
+    let mut client = Client::init(&params, &mut rng);
+    println!("generating public parameters");
+    let pub_params = client.generate_keys();
+    let pub_params_buf = pub_params.serialize();
+    println!("public parameters size: {} bytes", pub_params_buf.len());
+    let query = client.generate_query(idx_target);
+    let query_buf = query.serialize();
+    println!("initial query size: {} bytes", query_buf.len());
+
+    println!("generating db");
+    let (corr_item, db) = generate_random_db_and_get_item(&params, idx_target);
+
+    println!("processing query");
+    let now = Instant::now();
+    let response = process_query(&params, &pub_params, &query, db.as_slice());
+    println!("done processing (took {} us).", now.elapsed().as_micros());
+    println!("response size: {} bytes", response.len());
+
+    println!("decoding response");
+    let result = client.decode_response(response.as_slice());
+
+    let p_bits = log2_ceil(params.pt_modulus) as usize;
+    let corr_result = corr_item.to_vec(p_bits, params.modp_words_per_chunk());
+
+    assert_eq!(result.len(), corr_result.len());
+    for z in 0..corr_result.len() {
+        assert_eq!(result[z], corr_result[z], "error in response at {:?}", z);
+    }
+
+    println!("completed correctly!");
+}

+ 1 - 1
spiral-rs/src/bin/preprocess_db.rs

@@ -15,7 +15,7 @@ fn main() {
 
     let mut inp_file = File::open(inp_db_path).unwrap();
 
-    let db = load_db_from_file(params, &mut inp_file);
+    let db = load_db_from_seek(params, &mut inp_file);
     let db_slice = db.as_slice();
 
     let mut out_file = File::create(out_db_path).unwrap();

+ 24 - 14
spiral-rs/src/client.rs

@@ -426,24 +426,34 @@ impl<'a, T: Rng> Client<'a, T> {
         if params.expand_queries {
             // pack query into single ciphertext
             let mut sigma = PolyMatrixRaw::zero(params, 1, 1);
-            sigma.data[2 * idx_dim0] = scale_k;
-            for i in 0..further_dims as u64 {
-                let bit: u64 = ((idx_further as u64) & (1 << i)) >> i;
-                for j in 0..params.t_gsw {
-                    let val = (1u64 << (bits_per * j)) * bit;
-                    let idx = (i as usize) * params.t_gsw + (j as usize);
-                    sigma.data[2 * idx + 1] = val;
-                }
-            }
             let inv_2_g_first = invert_uint_mod(1 << params.g(), params.modulus).unwrap();
             let inv_2_g_rest =
                 invert_uint_mod(1 << (params.stop_round() + 1), params.modulus).unwrap();
 
-            for i in 0..params.poly_len / 2 {
-                sigma.data[2 * i] =
-                    multiply_uint_mod(sigma.data[2 * i], inv_2_g_first, params.modulus);
-                sigma.data[2 * i + 1] =
-                    multiply_uint_mod(sigma.data[2 * i + 1], inv_2_g_rest, params.modulus);
+            if params.db_dim_2 == 0 {
+                sigma.data[idx_dim0] = scale_k;
+                for i in 0..params.poly_len {
+                    sigma.data[i] =
+                        multiply_uint_mod(sigma.data[i], inv_2_g_first, params.modulus);
+                }
+            } else {
+                sigma.data[2 * idx_dim0] = scale_k;
+            
+                for i in 0..further_dims as u64 {
+                    let bit: u64 = ((idx_further as u64) & (1 << i)) >> i;
+                    for j in 0..params.t_gsw {
+                        let val = (1u64 << (bits_per * j)) * bit;
+                        let idx = (i as usize) * params.t_gsw + (j as usize);
+                        sigma.data[2 * idx + 1] = val;
+                    }
+                }
+
+                for i in 0..params.poly_len / 2 {
+                    sigma.data[2 * i] =
+                        multiply_uint_mod(sigma.data[2 * i], inv_2_g_first, params.modulus);
+                    sigma.data[2 * i + 1] =
+                        multiply_uint_mod(sigma.data[2 * i + 1], inv_2_g_rest, params.modulus);
+                }
             }
 
             query.ct = Some(from_ntt_alloc(

+ 11 - 0
spiral-rs/src/params.rs

@@ -4,6 +4,7 @@ use crate::{arith::*, ntt::*, number_theory::*, poly::*};
 
 pub const MAX_MODULI: usize = 4;
 
+pub static MIN_Q2_BITS: u64 = 14;
 pub static Q2_VALUES: [u64; 37] = [
     0,
     0,
@@ -114,6 +115,10 @@ impl Params {
         1 << self.db_dim_1
     }
 
+    pub fn num_items(&self) -> usize {
+        (1 << self.db_dim_1) * (1 << self.db_dim_2)
+    }
+
     pub fn g(&self) -> usize {
         let num_bits_to_gen = self.t_gsw * self.db_dim_2 + self.num_expanded();
         log2_ceil_usize(num_bits_to_gen)
@@ -123,6 +128,10 @@ impl Params {
         log2_ceil_usize(self.t_gsw * self.db_dim_2)
     }
 
+    pub fn factor_on_first_dim(&self) -> usize {
+        if self.db_dim_2 == 0 { 1 } else { 2 }
+    }
+
     pub fn setup_bytes(&self) -> usize {
         let mut sz_polys = 0;
 
@@ -213,6 +222,8 @@ impl Params {
         instances: usize,
         db_item_size: usize,
     ) -> Self {
+        assert!(q2_bits >= MIN_Q2_BITS);
+
         let poly_len_log2 = log2(poly_len as u64) as usize;
         let crt_count = moduli.len();
         assert!(crt_count <= MAX_MODULI);

+ 57 - 38
spiral-rs/src/server.rs

@@ -59,14 +59,14 @@ pub fn coefficient_expansion(
             let mut ginv_ct_right = PolyMatrixRaw::zero(params, params.t_exp_right, 1);
             let mut ginv_ct_right_ntt = PolyMatrixNTT::zero(params, params.t_exp_right, 1);
 
-            let (w, _gadget_dim, gi_ct, gi_ct_ntt) = match i % 2 {
-                0 => (
+            let (w, _gadget_dim, gi_ct, gi_ct_ntt) = match (r != 0) && (i % 2 == 0) {
+                true => (
                     &v_w_left[r],
                     params.t_exp_left,
                     &mut ginv_ct_left,
                     &mut ginv_ct_left_ntt,
                 ),
-                1 | _ => (
+                false => (
                     &v_w_right[r],
                     params.t_exp_right,
                     &mut ginv_ct_right,
@@ -169,7 +169,7 @@ pub fn multiply_reg_by_database(
     let pt_rows = 1;
     let pt_cols = 1;
 
-    assert!(dim0 * ct_rows >= MAX_SUMMED);
+    // assert!(dim0 * ct_rows >= MAX_SUMMED);
 
     let mut sums_out_n0_u64 = AlignedMemory64::new(4);
     let mut sums_out_n2_u64 = AlignedMemory64::new(4);
@@ -180,8 +180,12 @@ pub fn multiply_reg_by_database(
 
         for i in 0..num_per {
             for c in 0..pt_cols {
-                let inner_limit = MAX_SUMMED;
-                let outer_limit = dim0 * ct_rows / inner_limit;
+                let mut inner_limit = MAX_SUMMED;
+                let mut outer_limit = dim0 * ct_rows / inner_limit;
+                if MAX_SUMMED > dim0 * ct_rows {
+                    inner_limit = dim0 * ct_rows;
+                    outer_limit = 1;
+                }
 
                 let mut sums_out_n0_u64_acc = [0u64, 0, 0, 0];
                 let mut sums_out_n2_u64_acc = [0u64, 0, 0, 0];
@@ -351,12 +355,10 @@ pub fn generate_random_db_and_get_item<'a>(
     let db_size_words = instances * trials * num_items * params.poly_len;
     let mut v = AlignedMemory64::new(db_size_words);
 
-    let mut item = PolyMatrixRaw::zero(params, params.n, params.n);
+    let mut item = PolyMatrixRaw::zero(params, params.instances * params.n, params.n);
 
     for instance in 0..instances {
-        println!("Instance {:?}", instance);
         for trial in 0..trials {
-            println!("Trial {:?}", trial);
             for i in 0..num_items {
                 let ii = i % num_per;
                 let j = i / num_per;
@@ -364,8 +366,8 @@ pub fn generate_random_db_and_get_item<'a>(
                 let mut db_item = PolyMatrixRaw::random_rng(params, 1, 1, &mut rng);
                 db_item.reduce_mod(params.pt_modulus);
 
-                if i == item_idx && instance == 0 {
-                    item.copy_into(&db_item, trial / params.n, trial % params.n);
+                if i == item_idx {
+                    item.copy_into(&db_item, instance * params.n + trial / params.n, trial % params.n);
                 }
 
                 for z in 0..params.poly_len {
@@ -389,9 +391,9 @@ pub fn generate_random_db_and_get_item<'a>(
     (item, v)
 }
 
-pub fn load_item_from_file<'a>(
+pub fn load_item_from_seek<'a, T: Seek + Read>(
     params: &'a Params,
-    file: &mut File,
+    seekable: &mut T,
     instance: usize,
     trial: usize,
     item_idx: usize,
@@ -412,12 +414,12 @@ pub fn load_item_from_file<'a>(
 
     let mut out = PolyMatrixRaw::zero(params, 1, 1);
 
-    let seek_result = file.seek(SeekFrom::Start(idx_poly_in_file as u64));
+    let seek_result = seekable.seek(SeekFrom::Start(idx_poly_in_file as u64));
     if seek_result.is_err() {
         return out;
     }
     let mut data = vec![0u8; 2 * bytes_per_chunk];
-    let bytes_read = file
+    let bytes_read = seekable
         .read(&mut data.as_mut_slice()[0..bytes_per_chunk])
         .unwrap();
 
@@ -432,7 +434,7 @@ pub fn load_item_from_file<'a>(
     out
 }
 
-pub fn load_db_from_file(params: &Params, file: &mut File) -> AlignedMemory64 {
+pub fn load_db_from_seek<T: Seek + Read>(params: &Params, seekable: &mut T) -> AlignedMemory64 {
     let instances = params.instances;
     let trials = params.n * params.n;
     let dim0 = 1 << params.db_dim_1;
@@ -442,17 +444,12 @@ pub fn load_db_from_file(params: &Params, file: &mut File) -> AlignedMemory64 {
     let mut v = AlignedMemory64::new(db_size_words);
 
     for instance in 0..instances {
-        println!("Instance {:?}", instance);
         for trial in 0..trials {
-            println!("Trial {:?}", trial);
             for i in 0..num_items {
-                if i % 8192 == 0 {
-                    println!("item {:?}", i);
-                }
                 let ii = i % num_per;
                 let j = i / num_per;
 
-                let mut db_item = load_item_from_file(params, file, instance, trial, i);
+                let mut db_item = load_item_from_seek(params, seekable, instance, trial, i);
                 // db_item.reduce_mod(params.pt_modulus);
 
                 for z in 0..params.poly_len {
@@ -513,6 +510,10 @@ pub fn fold_ciphertexts(
     v_folding: &Vec<PolyMatrixNTT>,
     v_folding_neg: &Vec<PolyMatrixNTT>,
 ) {
+    if v_cts.len() == 1 {
+        return;
+    }
+    
     let further_dims = log2(v_cts.len() as u64) as usize;
     let ell = v_folding[0].cols / 2;
     let mut ginv_c = PolyMatrixRaw::zero(&params, 2 * ell, 1);
@@ -667,24 +668,40 @@ pub fn expand_query<'a>(
     let v_w_right = public_params.v_expansion_right.as_ref().unwrap();
     let v_neg1 = params.get_v_neg1();
 
-    coefficient_expansion(
-        &mut v,
-        g,
-        stop_round,
-        params,
-        &v_w_left,
-        &v_w_right,
-        &v_neg1,
-        params.t_gsw * params.db_dim_2,
-    );
-
     let mut v_reg_inp = Vec::with_capacity(dim0);
-    for i in 0..dim0 {
-        v_reg_inp.push(v[2 * i].clone());
-    }
     let mut v_gsw_inp = Vec::with_capacity(right_expanded);
-    for i in 0..right_expanded {
-        v_gsw_inp.push(v[2 * i + 1].clone());
+    if further_dims > 0 {
+        coefficient_expansion(
+            &mut v,
+            g,
+            stop_round,
+            params,
+            &v_w_left,
+            &v_w_right,
+            &v_neg1,
+            params.t_gsw * params.db_dim_2,
+        );
+
+        for i in 0..dim0 {
+            v_reg_inp.push(v[2 * i].clone());
+        }
+        for i in 0..right_expanded {
+            v_gsw_inp.push(v[2 * i + 1].clone());
+        }
+    } else {
+        coefficient_expansion(
+            &mut v,
+            g,
+            0,
+            params,
+            &v_w_left,
+            &v_w_left,
+            &v_neg1,
+            0,
+        );
+        for i in 0..dim0 {
+            v_reg_inp.push(v[i].clone());
+        }
     }
 
     let v_reg_sz = dim0 * 2 * params.poly_len;
@@ -716,7 +733,9 @@ pub fn process_query(
     let mut v_reg_reoriented;
     let v_folding;
     if params.expand_queries {
+        let now = Instant::now();
         (v_reg_reoriented, v_folding) = expand_query(params, public_params, query);
+        println!("expansion (took {} us).", now.elapsed().as_micros());
     } else {
         v_reg_reoriented = AlignedMemory64::new(query.v_buf.as_ref().unwrap().len());
         v_reg_reoriented

+ 30 - 2
spiral-rs/src/util.rs

@@ -1,6 +1,7 @@
 use crate::{arith::*, params::*, poly::*};
 use rand::{prelude::SmallRng, thread_rng, Rng, SeedableRng};
 use serde_json::Value;
+use std::fs;
 
 pub const CFG_20_256: &'static str = r#"
         {'n': 2,
@@ -181,18 +182,27 @@ pub const fn get_empty_params() -> Params {
 
 pub fn params_from_json(cfg: &str) -> Params {
     let v: Value = serde_json::from_str(cfg).unwrap();
+    params_from_json_obj(&v)
+}
+
+pub fn params_from_json_obj(v: &Value) -> Params {
     let n = v["n"].as_u64().unwrap() as usize;
     let db_dim_1 = v["nu_1"].as_u64().unwrap() as usize;
     let db_dim_2 = v["nu_2"].as_u64().unwrap() as usize;
     let instances = v["instances"].as_u64().unwrap_or(1) as usize;
-    let db_item_size = v["db_item_size"].as_u64().unwrap_or(1) as usize;
     let p = v["p"].as_u64().unwrap();
-    let q2_bits = v["q2_bits"].as_u64().unwrap();
+    let q2_bits = u64::max(v["q2_bits"].as_u64().unwrap(), MIN_Q2_BITS);
     let t_gsw = v["t_gsw"].as_u64().unwrap() as usize;
     let t_conv = v["t_conv"].as_u64().unwrap() as usize;
     let t_exp_left = v["t_exp_left"].as_u64().unwrap() as usize;
     let t_exp_right = v["t_exp_right"].as_u64().unwrap() as usize;
     let do_expansion = v.get("direct_upload").is_none();
+
+    let mut db_item_size = v["db_item_size"].as_u64().unwrap_or(0) as usize;
+    if db_item_size == 0 {
+        db_item_size = instances * n * n;
+        db_item_size = db_item_size * 2048 * log2_ceil(p) as usize / 8;
+    }
     Params::init(
         2048,
         &vec![268369921u64, 249561089u64],
@@ -212,6 +222,24 @@ pub fn params_from_json(cfg: &str) -> Params {
     )
 }
 
+static ALL_PARAMS_STORE_FNAME: &str = "../params_store.json";
+
+pub fn get_params_from_store(target_num_log2: usize, item_size: usize) -> Params {
+    
+    let params_store_str = fs::read_to_string(ALL_PARAMS_STORE_FNAME).unwrap();
+    let v: Value = serde_json::from_str(&params_store_str).unwrap();
+    let nearest_target_num = target_num_log2;
+    let nearest_item_size = 1 << usize::max(log2_ceil_usize(item_size), 8);
+    println!("{} x {}", nearest_target_num, nearest_item_size);
+    let target = v.as_array().unwrap().iter()
+        .map(|x| x.as_object().unwrap() )
+        .filter(|x| x.get("target_num").unwrap().as_u64().unwrap() == (nearest_target_num as u64))
+        .filter(|x| x.get("item_size").unwrap().as_u64().unwrap() == (nearest_item_size as u64))
+        .map(|x| x.get("params").unwrap())
+        .next().unwrap();
+    params_from_json_obj(target)
+}
+
 pub fn read_arbitrary_bits(data: &[u8], bit_offs: usize, num_bits: usize) -> u64 {
     let word_off = bit_offs / 64;
     let bit_off_within_word = bit_offs % 64;

Some files were not shown because too many files changed in this diff