
add database loading/preprocessing support

Samir Menon · 2 years ago · parent commit 428c521727
2 changed files with 180 additions and 0 deletions:
  1. spiral-rs/src/bin/preprocess_db.rs (+26, -0)
  2. spiral-rs/src/server.rs (+154, -0)

spiral-rs/src/bin/preprocess_db.rs (+26, -0)

@@ -0,0 +1,26 @@
+use std::env;
+use std::fs::File;
+use std::io::Write;
+
+use spiral_rs::server::*;
+use spiral_rs::util::*;
+
+fn main() {
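+    // Use the CFG_16_100000 parameter set bundled with spiral-rs (the stored JSON uses
+    // single quotes, so they are swapped for double quotes before parsing).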
+    let base_params = params_from_json(&CFG_16_100000.replace("'", "\""));
+    let params = &base_params;
+
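+    // Usage: preprocess_db <raw database path> <preprocessed output path>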
+    let args: Vec<String> = env::args().collect();
+    let inp_db_path: &String = &args[1];
+    let out_db_path: &String = &args[2];
+
+    let mut inp_file = File::open(inp_db_path).unwrap();
+
+    let db = load_db_from_file(params, &mut inp_file);
+    let db_slice = db.as_slice();
+
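+    // Write every packed u64 word to disk in native byte order; this is the format
+    // that load_preprocessed_db_from_file in server.rs reads back.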
+    let mut out_file = File::create(out_db_path).unwrap();
+    for i in 0..db.len() {
+        let coeff = db_slice[i];
+        out_file.write_all(&coeff.to_ne_bytes()).unwrap();
+    }
+}
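
For reference, the new binary would presumably be run as "cargo run --release --bin preprocess_db -- <raw db path> <output path>" (a guess based on the argument order above); the output file is the packed, NTT'd database that load_preprocessed_db_from_file in server.rs reads back.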

spiral-rs/src/server.rs (+154, -0)

@@ -1,5 +1,11 @@
 #[cfg(target_feature = "avx2")]
 use std::arch::x86_64::*;
+use std::fs::File;
+use std::io::BufReader;
+use std::io::Read;
+use std::io::Seek;
+use std::io::SeekFrom;
+use std::mem::size_of;
 
 use crate::arith::*;
 use crate::aligned_memory::*;
@@ -358,6 +364,122 @@ pub fn generate_random_db_and_get_item<'a>(
     (item, v)
 }
 
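+/// Reads the chunk of item `item_idx` belonging to (`instance`, `trial`) from the raw
+/// database file and unpacks its logp-bit coefficients into a fresh 1x1 polynomial.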
+pub fn load_item_from_file<'a>(
+    params: &'a Params,
+    file: &mut File,
+    instance: usize, 
+    trial: usize,
+    item_idx: usize
+) -> PolyMatrixRaw<'a> {
+    let db_item_size = params.db_item_size;
+    let instances = params.instances;
+    let trials = params.n * params.n;
+    let dim0 = 1 << params.db_dim_1;
+    let num_per = 1 << params.db_dim_2;
+    let num_items = dim0 * num_per;
+
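+    // Each item occupies db_item_size bytes on disk, split evenly into one chunk per
+    // (instance, trial) pair; a chunk then unpacks into logp-bit plaintext coefficients.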
+    let chunks = instances * trials;
+    let bytes_per_chunk = f64::ceil(db_item_size as f64 / chunks as f64) as usize;
+    let logp = f64::ceil(f64::log2(params.pt_modulus as f64)) as usize;
+    let modp_words_per_chunk = f64::ceil((bytes_per_chunk * 8) as f64 / logp as f64) as usize;
+    assert!(modp_words_per_chunk <= params.poly_len);
+
+    let idx_item_in_file = item_idx * db_item_size;
+    let idx_chunk = instance * trials + trial;
+    let idx_poly_in_file = idx_item_in_file + idx_chunk * bytes_per_chunk;
+
+    let mut out = PolyMatrixRaw::zero(params, 1, 1);
+
+    let seek_result = file.seek(SeekFrom::Start(idx_poly_in_file as u64));
+    if seek_result.is_err() {
+        return out;
+    }
+    let mut data = vec![0u8; 2 * bytes_per_chunk];
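+    // The buffer is over-allocated so the bit reader below can safely look past the
+    // last chunk byte; a short read near the end of the file simply leaves the
+    // remaining coefficients at zero.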
+    let bytes_read = file.read(&mut data.as_mut_slice()[0..bytes_per_chunk]).unwrap();
+
+    let modp_words_read = f64::ceil((bytes_read * 8) as f64 / logp as f64) as usize;
+    assert!(modp_words_read <= params.poly_len);
+
+    for i in 0..modp_words_read {
+        out.data[i] = read_arbitrary_bits(&data, i * logp, logp);
+        assert!(out.data[i] <= params.pt_modulus);
+    }
+    
+    out
+}
+
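+/// Loads a raw database file, converting every item to NTT/CRT form and packing the
+/// two CRT limbs of each coefficient into a single u64 (the second limb shifted up by
+/// PACKED_OFFSET_2), in the layout that process_query expects.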
+pub fn load_db_from_file(
+    params: &Params,
+    file: &mut File
+) -> AlignedMemory64 {
+    let instances = params.instances;
+    let trials = params.n * params.n;
+    let dim0 = 1 << params.db_dim_1;
+    let num_per = 1 << params.db_dim_2;
+    let num_items = dim0 * num_per;
+    let db_size_words = instances * trials * num_items * params.poly_len;
+    let mut v = AlignedMemory64::new(db_size_words);
+
+    for instance in 0..instances {
+        println!("Instance {:?}", instance);
+        for trial in 0..trials {
+            println!("Trial {:?}", trial);
+            for i in 0..num_items {
+                if i % 8192 == 0 {
+                    println!("item {:?}", i);
+                }
+                let ii = i % num_per;
+                let j = i / num_per;
+
+                let mut db_item = load_item_from_file(params, file, instance, trial, i);
+                // db_item.reduce_mod(params.pt_modulus);
+                
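+                // Recenter each coefficient from the plaintext modulus into Z_modulus
+                // before taking the NTT.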
+                for z in 0..params.poly_len {
+                    db_item.data[z] = recenter_mod(db_item.data[z], params.pt_modulus, params.modulus);
+                }
+
+                let db_item_ntt = db_item.ntt();
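+                // Pack both CRT limbs of coefficient z into a single u64 word.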
+                for z in 0..params.poly_len {
+                    let idx_dst = calc_index(
+                        &[instance, trial, z, ii, j],
+                        &[instances, trials, params.poly_len, num_per, dim0],
+                    );
+
+                    v[idx_dst] = db_item_ntt.data[z]
+                        | (db_item_ntt.data[params.poly_len + z] << PACKED_OFFSET_2);
+                }
+            }
+        }
+    }
+    v
+}
+
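+/// Loads a database that was already preprocessed by the preprocess_db binary: the
+/// file is just the packed u64 words of load_db_from_file written with to_ne_bytes,
+/// so they can be copied directly into aligned memory.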
+pub fn load_preprocessed_db_from_file(
+    params: &Params,
+    file: &mut File
+) -> AlignedMemory64 {
+    let instances = params.instances;
+    let trials = params.n * params.n;
+    let dim0 = 1 << params.db_dim_1;
+    let num_per = 1 << params.db_dim_2;
+    let num_items = dim0 * num_per;
+    let db_size_words = instances * trials * num_items * params.poly_len;
+    let mut v = AlignedMemory64::new(db_size_words);
+    let v_mut_slice = v.as_mut_slice();
+
+    let mut reader = BufReader::new(file);
+    let mut buf = [0u8; 8];
+    for i in 0..db_size_words {
+        if i % 1_000_000_000 == 0 {
+            println!("{} GB loaded", (i * size_of::<u64>()) / 1_000_000_000);
+        }
+        reader.read_exact(&mut buf).unwrap();
+        v_mut_slice[i] = u64::from_ne_bytes(buf);
+    }
+
+    v
+}
+
 pub fn fold_ciphertexts(
     params: &Params,
     v_cts: &mut Vec<PolyMatrixRaw>,
@@ -623,6 +745,8 @@ mod test {
     use crate::{client::*};
     use rand::{prelude::SmallRng, Rng};
 
+    const TEST_PREPROCESSED_DB_PATH: &'static str = "/home/samir/wiki/enwiki-20220320.dbp"; 
+
     fn get_params() -> Params {
         let mut params = get_expansion_testing_params();
         params.db_dim_1 = 6;
@@ -888,6 +1012,31 @@ mod test {
             assert_eq!(result[z], corr_result[z]);
         }
     }
+
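+    // Like full_protocol_is_correct_for_params, but runs against a real preprocessed
+    // database on disk (TEST_PREPROCESSED_DB_PATH) instead of a randomly generated one.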
+    fn full_protocol_is_correct_for_params_real_db(params: &Params) {
+        let mut seeded_rng = get_seeded_rng();
+
+        let target_idx = seeded_rng.gen::<usize>() % (params.db_dim_1 + params.db_dim_2);
+
+        let mut client = Client::init(params, &mut seeded_rng);
+
+        let public_params = client.generate_keys();
+        let query = client.generate_query(target_idx);
+
+        let mut file = File::open(TEST_PREPROCESSED_DB_PATH).unwrap();
+
+        let db = load_preprocessed_db_from_file(params, &mut file);
+
+        let response = process_query(params, &public_params, &query, db.as_slice());
+
+        let result = client.decode_response(response.as_slice());
+
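+        // 0x42 0x5a 0x68 is ASCII "BZh", the bzip2 magic; presumably every item in the
+        // test database begins with it, so any target index should decode to this prefix.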
+        let corr_result = vec![0x42, 0x5a, 0x68];
+
+        for z in 0..corr_result.len() {
+            assert_eq!(result[z], corr_result[z]);
+        }
+    }
     
     #[test]
     fn full_protocol_is_correct() {
@@ -903,4 +1052,9 @@ mod test {
     // fn full_protocol_is_correct_16_100000() {
     //     full_protocol_is_correct_for_params(&params_from_json(&CFG_16_100000.replace("'", "\"")));
     // }
+
+    #[test]
+    fn full_protocol_is_correct_real_db_16_100000() {
+        full_protocol_is_correct_for_params_real_db(&params_from_json(&CFG_16_100000.replace("'", "\"")));
+    }
 }
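
As a rough sketch of how these pieces fit together (not part of the commit: the path is a placeholder and the standalone main is hypothetical, but only APIs already used above are called), a server binary could load the preprocessed database at startup roughly like this:

use std::fs::File;

use spiral_rs::server::*;
use spiral_rs::util::*;

fn main() {
    // Same parameter set used by the preprocessing binary and the new test.
    let params = params_from_json(&CFG_16_100000.replace("'", "\""));

    // Open the file produced by the preprocess_db binary (placeholder path).
    let mut file = File::open("/path/to/db.dbp").unwrap();

    // The packed u64 words are copied straight into aligned memory; no NTT or
    // recentering work is needed at serving time.
    let db = load_preprocessed_db_from_file(&params, &mut file);
    println!("loaded {} words", db.as_slice().len());

    // db.as_slice() is then what process_query takes for each incoming client
    // query, exactly as in full_protocol_is_correct_real_db_16_100000 above.
}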