Browse Source

cargo fmt and move

Samir Menon 2 years ago
parent
commit
d3f0dbc366

+ 5 - 5
spiral-rs/benches/server.rs

@@ -1,5 +1,5 @@
-use criterion::BenchmarkGroup;
 use criterion::measurement::WallTime;
+use criterion::BenchmarkGroup;
 use criterion::{black_box, criterion_group, criterion_main, Criterion};
 use pprof::criterion::{Output, PProfProfiler};
 
@@ -41,7 +41,7 @@ fn test_full_processing(group: &mut BenchmarkGroup<WallTime>) {
         let mut seeded_rng = get_seeded_rng();
 
         let target_idx = seeded_rng.gen::<usize>() % (params.db_dim_1 + params.db_dim_2);
-        
+
         let mut client = Client::init(&params, &mut seeded_rng);
         let public_params = client.generate_keys();
         let query = client.generate_query(target_idx);
@@ -113,7 +113,7 @@ fn criterion_benchmark(c: &mut Criterion) {
     for i in 0..db_size_words {
         db[i] = seeded_rng.gen();
     }
-    
+
     let v_reg_sz = dim0 * 2 * params.poly_len;
     let mut v_reg_reoriented = AlignedMemory64::new(v_reg_sz);
     for i in 0..v_reg_sz {
@@ -132,8 +132,8 @@ fn criterion_benchmark(c: &mut Criterion) {
                 black_box(db.as_slice()),
                 black_box(v_reg_reoriented.as_slice()),
                 black_box(&params),
-                black_box(dim0), 
-                black_box(num_per)
+                black_box(dim0),
+                black_box(num_per),
             )
         });
     });

+ 3 - 3
spiral-rs/src/arith.rs

@@ -431,11 +431,11 @@ pub fn rescale(a: u64, inp_mod: u64, out_mod: u64) -> u64 {
     }
     let sign: i64 = if inp_val >= 0 { 1 } else { -1 };
     let val = (inp_val as i128) * (out_mod as i128);
-    let mut result = (val + (sign*(inp_mod_i64/2)) as i128) / (inp_mod as i128);
-    result = (result + ((inp_mod/out_mod)*out_mod) as i128 + (2*out_mod_i128)) % out_mod_i128;
+    let mut result = (val + (sign * (inp_mod_i64 / 2)) as i128) / (inp_mod as i128);
+    result = (result + ((inp_mod / out_mod) * out_mod) as i128 + (2 * out_mod_i128)) % out_mod_i128;
 
     assert!(result >= 0);
-    
+
     ((result + out_mod_i128) % out_mod_i128) as u64
 }
 

+ 0 - 0
spiral-rs/src/main.rs → spiral-rs/src/bin/client.rs


+ 1 - 1
spiral-rs/src/bin/preprocess_db.rs

@@ -23,4 +23,4 @@ fn main() {
         let coeff = db_slice[i];
         out_file.write_all(&coeff.to_ne_bytes()).unwrap();
     }
-}
+}

+ 12 - 2
spiral-rs/src/client.rs

@@ -507,13 +507,23 @@ mod test {
         assert_first8(
             public_params.v_conversion.unwrap()[0].data.as_slice(),
             [
-                122680182, 165987256, 137892309, 95732358, 221787731, 13233184, 156136764, 259944211,
+                122680182, 165987256, 137892309, 95732358, 221787731, 13233184, 156136764,
+                259944211,
             ],
         );
 
         assert_first8(
             client.sk_gsw.data.as_slice(),
-            [66974689739603965, 66974689739603965, 0, 1, 0, 5, 66974689739603967, 2],
+            [
+                66974689739603965,
+                66974689739603965,
+                0,
+                1,
+                0,
+                5,
+                66974689739603967,
+                2,
+            ],
         );
     }
 }

+ 6 - 2
spiral-rs/src/poly.rs

@@ -143,7 +143,9 @@ impl<'a> PolyMatrix<'a> for PolyMatrixRaw<'a> {
 impl<'a> Clone for PolyMatrixRaw<'a> {
     fn clone(&self) -> Self {
         let mut data_clone = AlignedMemory64::new(self.data.len());
-        data_clone.as_mut_slice().copy_from_slice(self.data.as_slice());
+        data_clone
+            .as_mut_slice()
+            .copy_from_slice(self.data.as_slice());
         PolyMatrixRaw {
             params: self.params,
             rows: self.rows,
@@ -315,7 +317,9 @@ impl<'a> PolyMatrix<'a> for PolyMatrixNTT<'a> {
 impl<'a> Clone for PolyMatrixNTT<'a> {
     fn clone(&self) -> Self {
         let mut data_clone = AlignedMemory64::new(self.data.len());
-        data_clone.as_mut_slice().copy_from_slice(self.data.as_slice());
+        data_clone
+            .as_mut_slice()
+            .copy_from_slice(self.data.as_slice());
         PolyMatrixNTT {
             params: self.params,
             rows: self.rows,

+ 93 - 76
spiral-rs/src/server.rs

@@ -7,8 +7,8 @@ use std::io::Seek;
 use std::io::SeekFrom;
 use std::mem::size_of;
 
-use crate::arith::*;
 use crate::aligned_memory::*;
+use crate::arith::*;
 use crate::client::PublicParameters;
 use crate::client::Query;
 use crate::gadget::*;
@@ -206,18 +206,22 @@ pub fn multiply_reg_by_database(
 
                         for idx in 0..4 {
                             let val = sums_out_n0_u64[idx];
-                            sums_out_n0_u64_acc[idx] = barrett_coeff_u64(params, val + sums_out_n0_u64_acc[idx], 0);
+                            sums_out_n0_u64_acc[idx] =
+                                barrett_coeff_u64(params, val + sums_out_n0_u64_acc[idx], 0);
                         }
                         for idx in 0..4 {
                             let val = sums_out_n2_u64[idx];
-                            sums_out_n2_u64_acc[idx] = barrett_coeff_u64(params, val + sums_out_n2_u64_acc[idx], 1);
+                            sums_out_n2_u64_acc[idx] =
+                                barrett_coeff_u64(params, val + sums_out_n2_u64_acc[idx], 1);
                         }
                     }
                 }
 
                 for idx in 0..4 {
-                    sums_out_n0_u64_acc[idx] = barrett_coeff_u64(params, sums_out_n0_u64_acc[idx], 0);
-                    sums_out_n2_u64_acc[idx] = barrett_coeff_u64(params, sums_out_n2_u64_acc[idx], 1);
+                    sums_out_n0_u64_acc[idx] =
+                        barrett_coeff_u64(params, sums_out_n0_u64_acc[idx], 0);
+                    sums_out_n2_u64_acc[idx] =
+                        barrett_coeff_u64(params, sums_out_n2_u64_acc[idx], 1);
                 }
 
                 // output n0
@@ -271,7 +275,7 @@ pub fn multiply_reg_by_database(
                 for jm in 0..(dim0 * pt_rows) {
                     let b = db[idx_b_base];
                     idx_b_base += 1;
-                        
+
                     let v_a0 = v_firstdim[idx_a_base + jm * ct_rows];
                     let v_a1 = v_firstdim[idx_a_base + jm * ct_rows + 1];
 
@@ -283,7 +287,7 @@ pub fn multiply_reg_by_database(
 
                     let v_a1_lo = v_a1 as u32;
                     let v_a1_hi = (v_a1 >> 32) as u32;
-                    
+
                     // do n0
                     sums_out_n0_0 += ((v_a0_lo as u64) * (b_lo as u64)) as u128;
                     sums_out_n0_1 += ((v_a1_lo as u64) * (b_lo as u64)) as u128;
@@ -339,13 +343,14 @@ pub fn generate_random_db_and_get_item<'a>(
 
                 let mut db_item = PolyMatrixRaw::random_rng(params, 1, 1, &mut rng);
                 db_item.reduce_mod(params.pt_modulus);
-                
+
                 if i == item_idx && instance == 0 {
                     item.copy_into(&db_item, trial / params.n, trial % params.n);
                 }
 
                 for z in 0..params.poly_len {
-                    db_item.data[z] = recenter_mod(db_item.data[z], params.pt_modulus, params.modulus);
+                    db_item.data[z] =
+                        recenter_mod(db_item.data[z], params.pt_modulus, params.modulus);
                 }
 
                 let db_item_ntt = db_item.ntt();
@@ -367,9 +372,9 @@ pub fn generate_random_db_and_get_item<'a>(
 pub fn load_item_from_file<'a>(
     params: &'a Params,
     file: &mut File,
-    instance: usize, 
+    instance: usize,
     trial: usize,
-    item_idx: usize
+    item_idx: usize,
 ) -> PolyMatrixRaw<'a> {
     let db_item_size = params.db_item_size;
     let instances = params.instances;
@@ -395,7 +400,9 @@ pub fn load_item_from_file<'a>(
         return out;
     }
     let mut data = vec![0u8; 2 * bytes_per_chunk];
-    let bytes_read = file.read(&mut data.as_mut_slice()[0..bytes_per_chunk]).unwrap();
+    let bytes_read = file
+        .read(&mut data.as_mut_slice()[0..bytes_per_chunk])
+        .unwrap();
 
     let modp_words_read = f64::ceil((bytes_read * 8) as f64 / logp as f64) as usize;
     assert!(modp_words_read <= params.poly_len);
@@ -404,14 +411,11 @@ pub fn load_item_from_file<'a>(
         out.data[i] = read_arbitrary_bits(&data, i * logp, logp);
         assert!(out.data[i] <= params.pt_modulus);
     }
-    
+
     out
 }
 
-pub fn load_db_from_file(
-    params: &Params,
-    file: &mut File
-) -> AlignedMemory64 {
+pub fn load_db_from_file(params: &Params, file: &mut File) -> AlignedMemory64 {
     let instances = params.instances;
     let trials = params.n * params.n;
     let dim0 = 1 << params.db_dim_1;
@@ -433,9 +437,10 @@ pub fn load_db_from_file(
 
                 let mut db_item = load_item_from_file(params, file, instance, trial, i);
                 // db_item.reduce_mod(params.pt_modulus);
-                
+
                 for z in 0..params.poly_len {
-                    db_item.data[z] = recenter_mod(db_item.data[z], params.pt_modulus, params.modulus);
+                    db_item.data[z] =
+                        recenter_mod(db_item.data[z], params.pt_modulus, params.modulus);
                 }
 
                 let db_item_ntt = db_item.ntt();
@@ -454,10 +459,7 @@ pub fn load_db_from_file(
     v
 }
 
-pub fn load_preprocessed_db_from_file(
-    params: &Params,
-    file: &mut File
-) -> AlignedMemory64 {
+pub fn load_preprocessed_db_from_file(params: &Params, file: &mut File) -> AlignedMemory64 {
     let instances = params.instances;
     let trials = params.n * params.n;
     let dim0 = 1 << params.db_dim_1;
@@ -484,7 +486,7 @@ pub fn fold_ciphertexts(
     params: &Params,
     v_cts: &mut Vec<PolyMatrixRaw>,
     v_folding: &Vec<PolyMatrixNTT>,
-    v_folding_neg: &Vec<PolyMatrixNTT>
+    v_folding_neg: &Vec<PolyMatrixNTT>,
 ) {
     let further_dims = log2(v_cts.len() as u64) as usize;
     let ell = v_folding[0].cols / 2;
@@ -499,10 +501,18 @@ pub fn fold_ciphertexts(
         for i in 0..num_per {
             gadget_invert(&mut ginv_c, &v_cts[i]);
             to_ntt(&mut ginv_c_ntt, &ginv_c);
-            multiply(&mut prod, &v_folding_neg[further_dims - 1 - cur_dim], &ginv_c_ntt);
+            multiply(
+                &mut prod,
+                &v_folding_neg[further_dims - 1 - cur_dim],
+                &ginv_c_ntt,
+            );
             gadget_invert(&mut ginv_c, &v_cts[num_per + i]);
             to_ntt(&mut ginv_c_ntt, &ginv_c);
-            multiply(&mut sum, &v_folding[further_dims - 1 - cur_dim], &ginv_c_ntt);
+            multiply(
+                &mut sum,
+                &v_folding[further_dims - 1 - cur_dim],
+                &ginv_c_ntt,
+            );
             add_into(&mut sum, &prod);
             from_ntt(&mut v_cts[i], &sum);
         }
@@ -512,7 +522,7 @@ pub fn fold_ciphertexts(
 pub fn pack<'a>(
     params: &'a Params,
     v_ct: &Vec<PolyMatrixRaw>,
-    v_w: &Vec<PolyMatrixNTT>
+    v_w: &Vec<PolyMatrixNTT>,
 ) -> PolyMatrixNTT<'a> {
     assert!(v_ct.len() >= params.n * params.n);
     assert!(v_w.len() == params.n);
@@ -550,20 +560,15 @@ pub fn pack<'a>(
     result
 }
 
-pub fn encode(
-    params: &Params,
-    v_packed_ct: &Vec<PolyMatrixRaw>
-) -> Vec<u8> {
+pub fn encode(params: &Params, v_packed_ct: &Vec<PolyMatrixRaw>) -> Vec<u8> {
     let q1 = 4 * params.pt_modulus;
     let q1_bits = log2_ceil(q1) as usize;
     let q2 = Q2_VALUES[params.q2_bits as usize];
     let q2_bits = params.q2_bits as usize;
 
-    let num_bits = params.instances * 
-        (
-            (q2_bits * params.n * params.poly_len) + 
-            (q1_bits * params.n * params.n * params.poly_len)
-        );
+    let num_bits = params.instances
+        * ((q2_bits * params.n * params.poly_len)
+            + (q1_bits * params.n * params.n * params.poly_len));
     let round_to = 64;
     let num_bytes_rounded_up = ((num_bits + round_to - 1) / round_to) * round_to / 8;
 
@@ -571,11 +576,11 @@ pub fn encode(
     let mut bit_offs = 0;
     for instance in 0..params.instances {
         let packed_ct = &v_packed_ct[instance];
-        
+
         let mut first_row = packed_ct.submatrix(0, 0, 1, packed_ct.cols);
         let mut rest_rows = packed_ct.submatrix(1, 0, packed_ct.rows - 1, packed_ct.cols);
-        first_row.apply_func(|x| { rescale(x, params.modulus, q2) });
-        rest_rows.apply_func(|x| { rescale(x, params.modulus, q1) });
+        first_row.apply_func(|x| rescale(x, params.modulus, q2));
+        rest_rows.apply_func(|x| rescale(x, params.modulus, q1));
 
         let data = result.as_mut_slice();
         for i in 0..params.n * params.poly_len {
@@ -591,7 +596,7 @@ pub fn encode(
 }
 
 pub fn get_v_folding_neg<'a>(
-    params: &'a Params, 
+    params: &'a Params,
     v_folding: &Vec<PolyMatrixNTT>,
 ) -> Vec<PolyMatrixNTT<'a>> {
     let gadget_ntt = build_gadget(&params, 2, 2 * params.t_gsw).ntt(); // TODO: make this better
@@ -608,8 +613,8 @@ pub fn get_v_folding_neg<'a>(
 }
 
 pub fn expand_query<'a>(
-    params: &'a Params, 
-    public_params: &PublicParameters<'a>, 
+    params: &'a Params,
+    public_params: &PublicParameters<'a>,
     query: &Query<'a>,
 ) -> (AlignedMemory64, Vec<PolyMatrixNTT<'a>>) {
     let dim0 = 1 << params.db_dim_1;
@@ -664,13 +669,13 @@ pub fn expand_query<'a>(
     }
 
     regev_to_gsw(&mut v_folding, &v_gsw_inp, &v_conversion, params, 1, 0);
-    
+
     (v_reg_reoriented, v_folding)
 }
 
 pub fn process_query(
-    params: &Params, 
-    public_params: &PublicParameters, 
+    params: &Params,
+    public_params: &PublicParameters,
     query: &Query,
     db: &[u64],
 ) -> Vec<u8> {
@@ -683,14 +688,20 @@ pub fn process_query(
     let mut v_reg_reoriented;
     let v_folding;
     if params.expand_queries {
-        (v_reg_reoriented, v_folding) = 
-            expand_query(params, public_params, query);
+        (v_reg_reoriented, v_folding) = expand_query(params, public_params, query);
     } else {
         v_reg_reoriented = AlignedMemory64::new(query.v_buf.as_ref().unwrap().len());
-        v_reg_reoriented.as_mut_slice().copy_from_slice(query.v_buf.as_ref().unwrap());
-
-        v_folding = query.v_ct.as_ref().unwrap().clone().iter()
-            .map(|x| { x.ntt() })
+        v_reg_reoriented
+            .as_mut_slice()
+            .copy_from_slice(query.v_buf.as_ref().unwrap());
+
+        v_folding = query
+            .v_ct
+            .as_ref()
+            .unwrap()
+            .clone()
+            .iter()
+            .map(|x| x.ntt())
             .collect();
     }
     let v_folding_neg = get_v_folding_neg(params, &v_folding);
@@ -711,27 +722,25 @@ pub fn process_query(
             let idx = (instance * (params.n * params.n) + trial) * db_slice_sz;
             let cur_db = &db[idx..(idx + db_slice_sz)];
 
-            multiply_reg_by_database(&mut intermediate, cur_db, v_reg_reoriented.as_slice(), params, dim0, num_per);
+            multiply_reg_by_database(
+                &mut intermediate,
+                cur_db,
+                v_reg_reoriented.as_slice(),
+                params,
+                dim0,
+                num_per,
+            );
 
             for i in 0..intermediate.len() {
                 from_ntt(&mut intermediate_raw[i], &intermediate[i]);
             }
 
-            fold_ciphertexts(
-                params,
-                &mut intermediate_raw,
-                &v_folding,
-                &v_folding_neg
-            );
+            fold_ciphertexts(params, &mut intermediate_raw, &v_folding, &v_folding_neg);
 
             v_ct.push(intermediate_raw[0].clone());
         }
 
-        let packed_ct = pack(
-            params,
-            &v_ct,
-            &v_packing,
-        );
+        let packed_ct = pack(params, &v_ct, &v_packing);
 
         v_packed_ct.push(packed_ct.raw());
     }
@@ -742,10 +751,10 @@ pub fn process_query(
 #[cfg(test)]
 mod test {
     use super::*;
-    use crate::{client::*};
+    use crate::client::*;
     use rand::{prelude::SmallRng, Rng};
 
-    const TEST_PREPROCESSED_DB_PATH: &'static str = "/home/samir/wiki/enwiki-20220320.dbp"; 
+    const TEST_PREPROCESSED_DB_PATH: &'static str = "/home/samir/wiki/enwiki-20220320.dbp";
 
     fn get_params() -> Params {
         let mut params = get_expansion_testing_params();
@@ -906,7 +915,14 @@ mod test {
         for _ in 0..dim0 {
             out.push(PolyMatrixNTT::zero(&params, 2, 1));
         }
-        multiply_reg_by_database(&mut out, db.as_slice(), v_reg_reoriented.as_slice(), &params, dim0, num_per);
+        multiply_reg_by_database(
+            &mut out,
+            db.as_slice(),
+            v_reg_reoriented.as_slice(),
+            &params,
+            dim0,
+            num_per,
+        );
 
         // decrypt
         let dec = client.decrypt_matrix_reg(&out[target_idx_num_per]).raw();
@@ -978,15 +994,13 @@ mod test {
             v_folding_neg.push(ct_gsw_neg);
         }
 
-        fold_ciphertexts(
-            &params,
-            &mut v_reg_raw,
-            &v_folding,
-            &v_folding_neg
-        );
-        
+        fold_ciphertexts(&params, &mut v_reg_raw, &v_folding, &v_folding_neg);
+
         // decrypt
-        assert_eq!(dec_reg(&params, &v_reg_raw[0].ntt(), &mut client, scale_k), 1);
+        assert_eq!(
+            dec_reg(&params, &v_reg_raw[0].ntt(), &mut client, scale_k),
+            1
+        );
     }
 
     fn full_protocol_is_correct_for_params(params: &Params) {
@@ -1037,7 +1051,7 @@ mod test {
             assert_eq!(result[z], corr_result[z]);
         }
     }
-    
+
     #[test]
     fn full_protocol_is_correct() {
         full_protocol_is_correct_for_params(&get_params());
@@ -1054,7 +1068,10 @@ mod test {
     // }
 
     #[test]
+    #[ignore]
     fn full_protocol_is_correct_real_db_16_100000() {
-        full_protocol_is_correct_for_params_real_db(&params_from_json(&CFG_16_100000.replace("'", "\"")));
+        full_protocol_is_correct_for_params_real_db(&params_from_json(
+            &CFG_16_100000.replace("'", "\""),
+        ));
     }
 }

+ 1 - 1
spiral-rs/src/util.rs

@@ -1,5 +1,5 @@
 use crate::{arith::*, params::*, poly::*};
-use rand::{prelude::{SmallRng}, SeedableRng, thread_rng, Rng};
+use rand::{prelude::SmallRng, thread_rng, Rng, SeedableRng};
 use serde_json::Value;
 
 pub const CFG_20_256: &'static str = r#"