Samir Menon 2 years ago
parent
commit
ad2700b24a
2 changed files with 60 additions and 44 deletions
  1. 17 5
      spiral-rs/src/client.rs
  2. 43 39
      spiral-rs/src/server.rs

+ 17 - 5
spiral-rs/src/client.rs

@@ -159,7 +159,7 @@ impl<'a> Query<'a> {
         }
     }
 
-    pub fn serialize(&self) -> Vec<u8> {        
+    pub fn serialize(&self) -> Vec<u8> {
         let mut data = Vec::new();
         if self.ct.is_some() {
             let ct = self.ct.as_ref().unwrap();
@@ -631,10 +631,22 @@ mod test {
         let serialized2 = deserialized1.serialize();
 
         assert_eq!(serialized1, serialized2);
-        assert_eq!(get_vec(&pub_params.v_packing), get_vec(&deserialized1.v_packing));
-        assert_eq!(get_vec(&pub_params.v_conversion.unwrap()), get_vec(&deserialized1.v_conversion.unwrap()));
-        assert_eq!(get_vec(&pub_params.v_expansion_left.unwrap()), get_vec(&deserialized1.v_expansion_left.unwrap()));
-        assert_eq!(get_vec(&pub_params.v_expansion_right.unwrap()), get_vec(&deserialized1.v_expansion_right.unwrap()));
+        assert_eq!(
+            get_vec(&pub_params.v_packing),
+            get_vec(&deserialized1.v_packing)
+        );
+        assert_eq!(
+            get_vec(&pub_params.v_conversion.unwrap()),
+            get_vec(&deserialized1.v_conversion.unwrap())
+        );
+        assert_eq!(
+            get_vec(&pub_params.v_expansion_left.unwrap()),
+            get_vec(&deserialized1.v_expansion_left.unwrap())
+        );
+        assert_eq!(
+            get_vec(&pub_params.v_expansion_right.unwrap()),
+            get_vec(&deserialized1.v_expansion_right.unwrap())
+        );
     }
 
     #[test]

+ 43 - 39
spiral-rs/src/server.rs

@@ -50,7 +50,10 @@ pub fn coefficient_expansion(
 
         for i in 0..num_out {
             if (stop_round > 0 && r > stop_round && (i % 2) == 1)
-                || (stop_round > 0 && r == stop_round && (i % 2) == 1 && (i / 2) >= max_bits_to_gen_right)
+                || (stop_round > 0
+                    && r == stop_round
+                    && (i % 2) == 1
+                    && (i / 2) >= max_bits_to_gen_right)
             {
                 continue;
             }
@@ -457,14 +460,12 @@ pub fn load_db_from_file(params: &Params, file: &mut File) -> AlignedMemory64 {
     v
 }
 
-pub fn load_file_unsafe(data: &mut[u64], file: &mut File) {
-    let data_as_u8_mut = unsafe {
-        data.align_to_mut::<u8>().1
-    };
+pub fn load_file_unsafe(data: &mut [u64], file: &mut File) {
+    let data_as_u8_mut = unsafe { data.align_to_mut::<u8>().1 };
     file.read_exact(data_as_u8_mut).unwrap();
 }
 
-pub fn load_file(data: &mut[u64], file: &mut File) {
+pub fn load_file(data: &mut [u64], file: &mut File) {
     let mut reader = BufReader::with_capacity(1 << 24, file);
     let mut buf = [0u8; 8];
     for i in 0..data.len() {
@@ -482,7 +483,7 @@ pub fn load_preprocessed_db_from_file(params: &Params, file: &mut File) -> Align
     let db_size_words = instances * trials * num_items * params.poly_len;
     let mut v = AlignedMemory64::new(db_size_words);
     let v_mut_slice = v.as_mut_slice();
-    
+
     let now = Instant::now();
     load_file(v_mut_slice, file);
     println!("Done loading ({} ms).", now.elapsed().as_millis());
@@ -713,42 +714,45 @@ pub fn process_query(
     }
     let v_folding_neg = get_v_folding_neg(params, &v_folding);
 
-    let v_packed_ct = (0..params.instances).into_par_iter().map(|instance| {
-        let mut intermediate = Vec::with_capacity(num_per);
-        let mut intermediate_raw = Vec::with_capacity(num_per);
-        for _ in 0..num_per {
-            intermediate.push(PolyMatrixNTT::zero(params, 2, 1));
-            intermediate_raw.push(PolyMatrixRaw::zero(params, 2, 1));
-        }
-    
-        let mut v_ct = Vec::new();
-
-        for trial in 0..(params.n * params.n) {
-            let idx = (instance * (params.n * params.n) + trial) * db_slice_sz;
-            let cur_db = &db[idx..(idx + db_slice_sz)];
-
-            multiply_reg_by_database(
-                &mut intermediate,
-                cur_db,
-                v_reg_reoriented.as_slice(),
-                params,
-                dim0,
-                num_per,
-            );
-
-            for i in 0..intermediate.len() {
-                from_ntt(&mut intermediate_raw[i], &intermediate[i]);
+    let v_packed_ct = (0..params.instances)
+        .into_par_iter()
+        .map(|instance| {
+            let mut intermediate = Vec::with_capacity(num_per);
+            let mut intermediate_raw = Vec::with_capacity(num_per);
+            for _ in 0..num_per {
+                intermediate.push(PolyMatrixNTT::zero(params, 2, 1));
+                intermediate_raw.push(PolyMatrixRaw::zero(params, 2, 1));
             }
 
-            fold_ciphertexts(params, &mut intermediate_raw, &v_folding, &v_folding_neg);
+            let mut v_ct = Vec::new();
 
-            v_ct.push(intermediate_raw[0].clone());
-        }
+            for trial in 0..(params.n * params.n) {
+                let idx = (instance * (params.n * params.n) + trial) * db_slice_sz;
+                let cur_db = &db[idx..(idx + db_slice_sz)];
+
+                multiply_reg_by_database(
+                    &mut intermediate,
+                    cur_db,
+                    v_reg_reoriented.as_slice(),
+                    params,
+                    dim0,
+                    num_per,
+                );
+
+                for i in 0..intermediate.len() {
+                    from_ntt(&mut intermediate_raw[i], &intermediate[i]);
+                }
+
+                fold_ciphertexts(params, &mut intermediate_raw, &v_folding, &v_folding_neg);
+
+                v_ct.push(intermediate_raw[0].clone());
+            }
 
-        let packed_ct = pack(params, &v_ct, &v_packing);
+            let packed_ct = pack(params, &v_ct, &v_packing);
 
-        packed_ct.raw()
-    }).collect();
+            packed_ct.raw()
+        })
+        .collect();
 
     encode(params, &v_packed_ct)
 }
@@ -1007,7 +1011,7 @@ mod test {
     fn full_protocol_is_correct_for_params(params: &Params) {
         let mut seeded_rng = get_seeded_rng();
 
-        let target_idx = 22456;//22456;//seeded_rng.gen::<usize>() % (params.db_dim_1 + params.db_dim_2);
+        let target_idx = 22456; //22456;//seeded_rng.gen::<usize>() % (params.db_dim_1 + params.db_dim_2);
 
         let mut client = Client::init(params, &mut seeded_rng);