Browse Source

cargo fmt

Samir Menon 2 years ago
parent
commit
96e7ce01be

+ 4 - 5
client/src/lib.rs

@@ -1,7 +1,7 @@
 mod utils;
 
+use spiral_rs::{client::*, discrete_gaussian::*, params::*, util::*};
 use wasm_bindgen::prelude::*;
-use spiral_rs::{params::*, util::*, client::*, discrete_gaussian::*};
 
 const UUID_V4_LEN: usize = 36;
 
@@ -19,7 +19,7 @@ macro_rules! console_log {
 // Avoids a lifetime in the return signature of bound Rust functions
 #[wasm_bindgen]
 pub struct WrappedClient {
-    client: Client<'static>
+    client: Client<'static>,
 }
 
 // Unsafe global with a static lifetime
@@ -74,8 +74,8 @@ pub fn initialize(json_params: Option<String>) -> WrappedClient {
         PARAMS = params_from_json(&cfg);
         client = Client::init(&PARAMS);
     }
-    
-    WrappedClient { client } 
+
+    WrappedClient { client }
 }
 
 #[wasm_bindgen]
@@ -93,7 +93,6 @@ pub fn generate_query(c: &mut WrappedClient, id: &str, idx_target: usize) -> Box
     full_query_buf.into_boxed_slice()
 }
 
-
 #[wasm_bindgen]
 pub fn decode_response(c: &mut WrappedClient, data: Box<[u8]>) -> Box<[u8]> {
     c.client.decode_response(&*data).into_boxed_slice()

+ 8 - 4
spiral-rs/benches/ntt.rs

@@ -1,7 +1,7 @@
+use criterion::{black_box, criterion_group, criterion_main, Criterion};
+use rand::Rng;
 use spiral_rs::ntt::*;
 use spiral_rs::util::*;
-use rand::Rng;
-use criterion::{black_box, criterion_group, criterion_main, Criterion};
 
 fn criterion_benchmark(c: &mut Criterion) {
     let params = get_test_params();
@@ -14,8 +14,12 @@ fn criterion_benchmark(c: &mut Criterion) {
             v1[idx] = val % params.moduli[i];
         }
     }
-    c.bench_function("nttf 2048", |b| b.iter(|| ntt_forward(black_box(&params), black_box(&mut v1))));
-    c.bench_function("ntti 2048", |b| b.iter(|| ntt_inverse(black_box(&params), black_box(&mut v1))));
+    c.bench_function("nttf 2048", |b| {
+        b.iter(|| ntt_forward(black_box(&params), black_box(&mut v1)))
+    });
+    c.bench_function("ntti 2048", |b| {
+        b.iter(|| ntt_inverse(black_box(&params), black_box(&mut v1)))
+    });
 }
 
 criterion_group!(benches, criterion_benchmark);

+ 4 - 2
spiral-rs/benches/poly.rs

@@ -1,13 +1,15 @@
+use criterion::{black_box, criterion_group, criterion_main, Criterion};
 use spiral_rs::poly::*;
 use spiral_rs::util::*;
-use criterion::{black_box, criterion_group, criterion_main, Criterion};
 
 fn criterion_benchmark(c: &mut Criterion) {
     let params = get_test_params();
     let m1 = PolyMatrixNTT::random(&params, 2, 1);
     let m2 = PolyMatrixNTT::random(&params, 3, 2);
     let mut m3 = PolyMatrixNTT::zero(&params, 2, 2);
-    c.bench_function("nttf 2048", |b| b.iter(|| multiply(black_box(&mut m3), black_box(&m1), black_box(&m2))));
+    c.bench_function("nttf 2048", |b| {
+        b.iter(|| multiply(black_box(&mut m3), black_box(&m1), black_box(&m2)))
+    });
 }
 
 criterion_group!(benches, criterion_benchmark);

+ 2 - 2
spiral-rs/src/arith.rs

@@ -94,10 +94,10 @@ pub fn recenter(val: u64, from_modulus: u64, to_modulus: u64) -> u64 {
     let to_modulus_i64 = to_modulus as i64;
 
     let mut a_val = val as i64;
-    if val >= from_modulus/2 {
+    if val >= from_modulus / 2 {
         a_val -= from_modulus_i64;
     }
-    a_val = a_val + (from_modulus_i64/to_modulus_i64)*to_modulus_i64 + 2*to_modulus_i64;
+    a_val = a_val + (from_modulus_i64 / to_modulus_i64) * to_modulus_i64 + 2 * to_modulus_i64;
     a_val %= to_modulus_i64;
     a_val as u64
 }

+ 70 - 48
spiral-rs/src/client.rs

@@ -1,4 +1,6 @@
-use crate::{poly::*, params::*, discrete_gaussian::*, gadget::*, arith::*, util::*, number_theory::*};
+use crate::{
+    arith::*, discrete_gaussian::*, gadget::*, number_theory::*, params::*, poly::*, util::*,
+};
 use std::iter::once;
 
 fn serialize_polymatrix(vec: &mut Vec<u8>, a: &PolyMatrixRaw) {
@@ -14,24 +16,24 @@ fn serialize_vec_polymatrix(vec: &mut Vec<u8>, a: &Vec<PolyMatrixRaw>) {
 }
 
 pub struct PublicParameters<'a> {
-    v_packing: Vec<PolyMatrixNTT<'a>>,                  // Ws
+    v_packing: Vec<PolyMatrixNTT<'a>>, // Ws
     v_expansion_left: Option<Vec<PolyMatrixNTT<'a>>>,
     v_expansion_right: Option<Vec<PolyMatrixNTT<'a>>>,
-    v_conversion: Option<Vec<PolyMatrixNTT<'a>>>,              // V
+    v_conversion: Option<Vec<PolyMatrixNTT<'a>>>, // V
 }
 
 impl<'a> PublicParameters<'a> {
     pub fn init(params: &'a Params) -> Self {
         if params.expand_queries {
-            PublicParameters { 
-                v_packing: Vec::new(), 
-                v_expansion_left: Some(Vec::new()), 
-                v_expansion_right: Some(Vec::new()), 
-                v_conversion: Some(Vec::new())
+            PublicParameters {
+                v_packing: Vec::new(),
+                v_expansion_left: Some(Vec::new()),
+                v_expansion_right: Some(Vec::new()),
+                v_conversion: Some(Vec::new()),
             }
         } else {
-            PublicParameters { 
-                v_packing: Vec::new(), 
+            PublicParameters {
+                v_packing: Vec::new(),
                 v_expansion_left: None,
                 v_expansion_right: None,
                 v_conversion: None,
@@ -43,7 +45,9 @@ impl<'a> PublicParameters<'a> {
         Some(v.iter().map(from_ntt_alloc).collect())
     }
 
-    fn from_ntt_alloc_opt_vec(v: &Option<Vec<PolyMatrixNTT<'a>>>) -> Option<Vec<PolyMatrixRaw<'a>>> {
+    fn from_ntt_alloc_opt_vec(
+        v: &Option<Vec<PolyMatrixNTT<'a>>>,
+    ) -> Option<Vec<PolyMatrixRaw<'a>>> {
         Some(v.as_ref()?.iter().map(from_ntt_alloc).collect())
     }
 
@@ -75,7 +79,11 @@ pub struct Query<'a> {
 
 impl<'a> Query<'a> {
     pub fn empty() -> Self {
-        Query { ct: None, v_ct: None, v_buf: None }
+        Query {
+            ct: None,
+            v_ct: None,
+            v_buf: None,
+        }
     }
 
     pub fn serialize(&self) -> Vec<u8> {
@@ -86,7 +94,7 @@ impl<'a> Query<'a> {
         }
         if self.v_buf.is_some() {
             let v_buf = self.v_buf.as_ref().unwrap();
-            data.extend(v_buf.iter().map(|x| { x.to_ne_bytes() }).flatten());
+            data.extend(v_buf.iter().map(|x| x.to_ne_bytes()).flatten());
         }
         if self.v_ct.is_some() {
             let v_ct = self.v_ct.as_ref().unwrap();
@@ -109,7 +117,7 @@ pub struct Client<'a> {
     stop_round: usize,
 }
 
-fn matrix_with_identity<'a> (p: &PolyMatrixRaw<'a>) -> PolyMatrixRaw<'a> {
+fn matrix_with_identity<'a>(p: &PolyMatrixRaw<'a>) -> PolyMatrixRaw<'a> {
     assert_eq!(p.cols, 1);
     let mut r = PolyMatrixRaw::zero(p.params, p.rows, p.rows + 1);
     r.copy_into(p, 0, 0);
@@ -119,8 +127,8 @@ fn matrix_with_identity<'a> (p: &PolyMatrixRaw<'a>) -> PolyMatrixRaw<'a> {
 
 fn params_with_moduli(params: &Params, moduli: &Vec<u64>) -> Params {
     Params::init(
-        params.poly_len, 
-        moduli, 
+        params.poly_len,
+        moduli,
         params.noise_width,
         params.n,
         params.pt_modulus,
@@ -206,14 +214,18 @@ impl<'a> Client<'a> {
         let res = &(p.ntt()) + &(ag.pad_top(1));
         res
     }
-    
+
     fn encrypt_matrix_reg(&mut self, a: &PolyMatrixNTT<'a>) -> PolyMatrixNTT<'a> {
         let m = a.cols;
         let p = self.get_fresh_reg_public_key(m);
         &p + &a.pad_top(1)
     }
 
-    fn generate_expansion_params(&mut self, num_exp: usize, m_exp: usize) -> Vec<PolyMatrixNTT<'a>> {
+    fn generate_expansion_params(
+        &mut self,
+        num_exp: usize,
+        m_exp: usize,
+    ) -> Vec<PolyMatrixNTT<'a>> {
         let params = self.params;
         let g_exp = build_gadget(params, 1, m_exp);
         let g_exp_ntt = g_exp.ntt();
@@ -239,7 +251,7 @@ impl<'a> Client<'a> {
         let m_conv = params.m_conv();
 
         let mut pp = PublicParameters::init(params);
-        
+
         // Params for packing
         let gadget_conv = build_gadget(params, 1, m_conv);
         let gadget_conv_ntt = to_ntt_alloc(&gadget_conv);
@@ -253,16 +265,21 @@ impl<'a> Client<'a> {
 
         if params.expand_queries {
             // Params for expansion
-            
+
             pp.v_expansion_left = Some(self.generate_expansion_params(self.g, params.t_exp_left));
-            pp.v_expansion_right = Some(self.generate_expansion_params(self.stop_round + 1, params.t_exp_right));
+            pp.v_expansion_right =
+                Some(self.generate_expansion_params(self.stop_round + 1, params.t_exp_right));
 
             // Params for converison
             let g_conv = build_gadget(params, 2, 2 * m_conv);
             let sk_reg_ntt = self.sk_reg.ntt();
             let sk_reg_squared_ntt = &sk_reg_ntt * &sk_reg_ntt;
-            pp.v_conversion = Some(Vec::from_iter(once(PolyMatrixNTT::zero(params, 2, 2 * m_conv))));
-            for i in 0..2*m_conv {
+            pp.v_conversion = Some(Vec::from_iter(once(PolyMatrixNTT::zero(
+                params,
+                2,
+                2 * m_conv,
+            ))));
+            for i in 0..2 * m_conv {
                 let sigma;
                 if i % 2 == 0 {
                     let val = g_conv.get_poly(0, i)[0];
@@ -288,7 +305,7 @@ impl<'a> Client<'a> {
         assert_eq!(crt_count, 2);
         assert!(log2(params.moduli[0]) <= 32);
 
-        let num_reg_expanded = 1<<params.db_dim_1;
+        let num_reg_expanded = 1 << params.db_dim_1;
         let ct_rows = v_reg[0].rows;
         let ct_cols = v_reg[0].cols;
 
@@ -297,13 +314,14 @@ impl<'a> Client<'a> {
 
         for j in 0..num_reg_expanded {
             for r in 0..ct_rows {
-                for m in 0.. ct_cols {
+                for m in 0..ct_cols {
                     for z in 0..params.poly_len {
-                        let idx_a_in = r * (ct_cols*crt_count*poly_len) + m * (crt_count*poly_len);
-                        let idx_a_out =     z * (num_reg_expanded*ct_cols*ct_rows)
-                                                + j * (ct_cols*ct_rows)
-                                                + m * (ct_rows)
-                                                + r;
+                        let idx_a_in =
+                            r * (ct_cols * crt_count * poly_len) + m * (crt_count * poly_len);
+                        let idx_a_out = z * (num_reg_expanded * ct_cols * ct_rows)
+                            + j * (ct_cols * ct_rows)
+                            + m * (ct_rows)
+                            + r;
                         let val1 = v_reg[j].data[idx_a_in + z] % params.moduli[0];
                         let val2 = v_reg[j].data[idx_a_in + params.poly_len + z] % params.moduli[1];
 
@@ -317,8 +335,8 @@ impl<'a> Client<'a> {
     pub fn generate_query(&mut self, idx_target: usize) -> Query<'a> {
         let params = self.params;
         let further_dims = params.db_dim_2;
-        let idx_dim0= idx_target / (1 << further_dims);
-        let idx_further  = idx_target % (1 << further_dims);
+        let idx_dim0 = idx_target / (1 << further_dims);
+        let idx_further = idx_target % (1 << further_dims);
         let scale_k = params.modulus / params.pt_modulus;
         let bits_per = get_bits_per(params, params.t_gsw);
 
@@ -326,24 +344,28 @@ impl<'a> Client<'a> {
         if params.expand_queries {
             // pack query into single ciphertext
             let mut sigma = PolyMatrixRaw::zero(params, 1, 1);
-            sigma.data[2*idx_dim0] = scale_k;
+            sigma.data[2 * idx_dim0] = scale_k;
             for i in 0..further_dims as u64 {
                 let bit: u64 = ((idx_further as u64) & (1 << i)) >> i;
                 for j in 0..params.t_gsw {
                     let val = (1u64 << (bits_per * j)) * bit;
                     let idx = (i as usize) * params.t_gsw + (j as usize);
-                    sigma.data[2*idx + 1] = val;
+                    sigma.data[2 * idx + 1] = val;
                 }
             }
             let inv_2_g_first = invert_uint_mod(1 << self.g, params.modulus).unwrap();
-            let inv_2_g_rest = invert_uint_mod(1 << (self.stop_round+1), params.modulus).unwrap();
+            let inv_2_g_rest = invert_uint_mod(1 << (self.stop_round + 1), params.modulus).unwrap();
 
-            for i in 0..params.poly_len/2 {
-                sigma.data[2*i]   = multiply_uint_mod(sigma.data[2*i], inv_2_g_first, params.modulus);
-                sigma.data[2*i+1] = multiply_uint_mod(sigma.data[2*i+1], inv_2_g_rest, params.modulus);
+            for i in 0..params.poly_len / 2 {
+                sigma.data[2 * i] =
+                    multiply_uint_mod(sigma.data[2 * i], inv_2_g_first, params.modulus);
+                sigma.data[2 * i + 1] =
+                    multiply_uint_mod(sigma.data[2 * i + 1], inv_2_g_rest, params.modulus);
             }
 
-            query.ct = Some(from_ntt_alloc(&self.encrypt_matrix_reg(&to_ntt_alloc(&sigma))));
+            query.ct = Some(from_ntt_alloc(
+                &self.encrypt_matrix_reg(&to_ntt_alloc(&sigma)),
+            ));
         } else {
             let num_expanded = 1 << params.db_dim_1;
             let mut sigma_v = Vec::<PolyMatrixNTT>::new();
@@ -370,20 +392,20 @@ impl<'a> Client<'a> {
                     let sigma = PolyMatrixRaw::single_value(&params, value);
                     let sigma_ntt = to_ntt_alloc(&sigma);
                     let ct = &self.encrypt_matrix_reg(&sigma_ntt);
-                    ct_gsw.copy_into(ct, 0, 2*j + 1);
+                    ct_gsw.copy_into(ct, 0, 2 * j + 1);
                     let prod = &to_ntt_alloc(&self.sk_reg) * &sigma_ntt;
                     let ct = &self.encrypt_matrix_reg(&prod);
-                    ct_gsw.copy_into(ct, 0, 2*j);
+                    ct_gsw.copy_into(ct, 0, 2 * j);
                 }
                 sigma_v.push(ct_gsw);
             }
 
             query.v_buf = Some(reg_cts_buf);
-            query.v_ct = Some(sigma_v.iter().map(|x| { from_ntt_alloc(x) }).collect());
+            query.v_ct = Some(sigma_v.iter().map(|x| from_ntt_alloc(x)).collect());
         }
         query
     }
-    
+
     pub fn decode_response(&self, data: &[u8]) -> Vec<u8> {
         /*
             0. NTT over q2 the secret key
@@ -438,11 +460,11 @@ impl<'a> Client<'a> {
             let p_i128 = p as i128;
             for i in 0..params.n * params.n * params.poly_len {
                 let mut val_first = sk_prod.data[i] as i64;
-                if val_first >= q2_i64/2 {
+                if val_first >= q2_i64 / 2 {
                     val_first -= q2_i64;
                 }
                 let mut val_rest = rest_rows.data[i] as i64;
-                if val_rest >= q1_i64/2 {
+                if val_rest >= q1_i64 / 2 {
                     val_rest -= q1_i64;
                 }
 
@@ -453,8 +475,8 @@ impl<'a> Client<'a> {
 
                 // divide r by q2, rounding
                 let sign: i64 = if r >= 0 { 1 } else { -1 };
-                let mut res = ((r + sign*(denom/2)) as i128) / (denom as i128);
-                res = (res + (denom as i128/p_i128)*(p_i128) + 2*(p_i128)) % (p_i128);
+                let mut res = ((r + sign * (denom / 2)) as i128) / (denom as i128);
+                res = (res + (denom as i128 / p_i128) * (p_i128) + 2 * (p_i128)) % (p_i128);
                 let idx = instance * params.n * params.n * params.poly_len + i;
                 result.data[idx] = res as u64;
             }
@@ -468,4 +490,4 @@ impl<'a> Client<'a> {
         let modp_words_per_chunk = f64::ceil((bytes_per_chunk * 8) as f64 / logp as f64) as usize;
         result.to_vec(p_bits as usize, modp_words_per_chunk)
     }
-}
+}

+ 6 - 6
spiral-rs/src/discrete_gaussian.rs

@@ -1,6 +1,6 @@
-use rand::prelude::Distribution;
-use rand::{thread_rng, rngs::ThreadRng};
 use rand::distributions::WeightedIndex;
+use rand::prelude::Distribution;
+use rand::{rngs::ThreadRng, thread_rng};
 
 use crate::params::*;
 use crate::poly::*;
@@ -11,7 +11,7 @@ pub const NUM_WIDTHS: usize = 8;
 pub struct DiscreteGaussian {
     choices: Vec<i64>,
     dist: WeightedIndex<f64>,
-    rng: ThreadRng
+    rng: ThreadRng,
 }
 
 impl DiscreteGaussian {
@@ -19,7 +19,7 @@ impl DiscreteGaussian {
         let max_val = (params.noise_width * (NUM_WIDTHS as f64)).ceil() as i64;
         let mut choices = Vec::new();
         let mut table = vec![0f64; 0];
-        for i in -max_val..max_val+1 {
+        for i in -max_val..max_val + 1 {
             let p_val = f64::exp(-PI * f64::powi(i as f64, 2) / f64::powi(params.noise_width, 2));
             choices.push(i);
             table.push(p_val);
@@ -29,7 +29,7 @@ impl DiscreteGaussian {
         Self {
             choices,
             dist,
-            rng: thread_rng()
+            rng: thread_rng(),
         }
     }
 
@@ -76,4 +76,4 @@ mod test {
         let std_dev_of_mean = std_dev / f64::sqrt(trials as f64);
         assert!(f64::abs(mean) < std_dev_of_mean * 5f64);
     }
-}
+}

+ 6 - 4
spiral-rs/src/gadget.rs

@@ -1,8 +1,10 @@
-use crate::{poly::*, params::*};
+use crate::{params::*, poly::*};
 
-pub fn get_bits_per(params: &Params, dim: usize) -> usize{
+pub fn get_bits_per(params: &Params, dim: usize) -> usize {
     let modulus_log2 = params.modulus_log2;
-    if dim as u64 == modulus_log2 { return 1; }
+    if dim as u64 == modulus_log2 {
+        return 1;
+    }
     ((modulus_log2 as f64) / (dim as f64)).floor() as usize + 1
 }
 
@@ -27,4 +29,4 @@ pub fn build_gadget(params: &Params, rows: usize, cols: usize) -> PolyMatrixRaw
         }
     }
     g
-}
+}

+ 4 - 4
spiral-rs/src/lib.rs

@@ -1,11 +1,11 @@
-pub mod util;
 pub mod arith;
-pub mod number_theory;
 pub mod discrete_gaussian;
+pub mod number_theory;
+pub mod util;
 
-pub mod ntt;
 pub mod gadget;
+pub mod ntt;
 pub mod params;
 pub mod poly;
 
-pub mod client;
+pub mod client;

+ 17 - 11
spiral-rs/src/main.rs

@@ -1,14 +1,16 @@
 use serde_json::Value;
-use spiral_rs::util::*;
 use spiral_rs::client::*;
+use spiral_rs::util::*;
 use std::env;
 use std::time::Instant;
 
 fn send_api_req_text(path: &str, data: Vec<u8>) -> Option<String> {
     let client = reqwest::blocking::Client::builder()
         .timeout(None)
-        .build().unwrap();
-    client.post(format!("https://spiralwiki.com:8088{}", path))
+        .build()
+        .unwrap();
+    client
+        .post(format!("https://spiralwiki.com:8088{}", path))
         .body(data)
         .send()
         .ok()?
@@ -19,14 +21,18 @@ fn send_api_req_text(path: &str, data: Vec<u8>) -> Option<String> {
 fn send_api_req_vec(path: &str, data: Vec<u8>) -> Option<Vec<u8>> {
     let client = reqwest::blocking::Client::builder()
         .timeout(None)
-        .build().unwrap();
-    Some(client.post(format!("https://spiralwiki.com:8088{}", path))
-        .body(data)
-        .send()
-        .ok()?
-        .bytes()
-        .ok()?
-        .to_vec())
+        .build()
+        .unwrap();
+    Some(
+        client
+            .post(format!("https://spiralwiki.com:8088{}", path))
+            .body(data)
+            .send()
+            .ok()?
+            .bytes()
+            .ok()?
+            .to_vec(),
+    )
 }
 
 fn main() {

+ 42 - 41
spiral-rs/src/ntt.rs

@@ -1,11 +1,7 @@
 #[cfg(target_feature = "avx2")]
 use std::arch::x86_64::*;
 
-use crate::{
-    arith::*,
-    number_theory::*,
-    params::*,
-};
+use crate::{arith::*, number_theory::*, params::*};
 
 pub fn powers_of_primitive_root(root: u64, modulus: u64, poly_len_log2: usize) -> Vec<u64> {
     let poly_len = 1usize << poly_len_log2;
@@ -51,7 +47,7 @@ pub fn build_ntt_tables(poly_len: usize, moduli: &[u64]) -> Vec<Vec<Vec<u64>>> {
 
         let root_powers = powers_of_primitive_root(root, modulus, poly_len_log2);
         let scaled_root_powers = scale_powers_u32(modulus_as_u32, poly_len, root_powers.as_slice());
-        let mut inv_root_powers = powers_of_primitive_root(inv_root, modulus, poly_len_log2);        
+        let mut inv_root_powers = powers_of_primitive_root(inv_root, modulus, poly_len_log2);
         for i in 0..poly_len {
             inv_root_powers[i] = div2_uint_mod(inv_root_powers[i], modulus);
         }
@@ -74,7 +70,7 @@ pub fn ntt_forward(params: &Params, operand_overall: &mut [u64]) {
     let n = 1 << log_n;
 
     for coeff_mod in 0..params.crt_count {
-        let operand = &mut operand_overall[coeff_mod*n..coeff_mod*n+n];
+        let operand = &mut operand_overall[coeff_mod * n..coeff_mod * n + n];
 
         let forward_table = params.get_ntt_forward_table(coeff_mod);
         let forward_table_prime = params.get_ntt_forward_prime_table(coeff_mod);
@@ -88,27 +84,29 @@ pub fn ntt_forward(params: &Params, operand_overall: &mut [u64]) {
             let mut it = operand.chunks_exact_mut(2 * t);
 
             for i in 0..m {
-                let w = forward_table[m+i];
-                let w_prime = forward_table_prime[m+i];
-                
+                let w = forward_table[m + i];
+                let w_prime = forward_table_prime[m + i];
+
                 let op = it.next().unwrap();
 
                 for j in 0..t {
                     let x: u32 = op[j] as u32;
                     let y: u32 = op[t + j] as u32;
 
-                    let curr_x: u32 = x - (two_times_modulus_small * ((x >= two_times_modulus_small) as u32));
+                    let curr_x: u32 =
+                        x - (two_times_modulus_small * ((x >= two_times_modulus_small) as u32));
                     let q_tmp: u64 = ((y as u64) * (w_prime as u64)) >> 32u64;
                     let q_new = w * (y as u64) - q_tmp * (modulus_small as u64);
 
                     op[j] = curr_x as u64 + q_new;
-                    op[t + j] = curr_x as u64 +  ((two_times_modulus_small as u64) - q_new);
+                    op[t + j] = curr_x as u64 + ((two_times_modulus_small as u64) - q_new);
                 }
             }
         }
 
         for i in 0..n {
-            operand[i] -= ((operand[i] >= two_times_modulus_small as u64) as u64) * two_times_modulus_small as u64;
+            operand[i] -= ((operand[i] >= two_times_modulus_small as u64) as u64)
+                * two_times_modulus_small as u64;
             operand[i] -= ((operand[i] >= modulus_small as u64) as u64) * modulus_small as u64;
         }
     }
@@ -120,7 +118,7 @@ pub fn ntt_forward(params: &Params, operand_overall: &mut [u64]) {
     let n = 1 << log_n;
 
     for coeff_mod in 0..params.crt_count {
-        let operand = &mut operand_overall[coeff_mod*n..coeff_mod*n+n];
+        let operand = &mut operand_overall[coeff_mod * n..coeff_mod * n + n];
 
         let forward_table = params.get_ntt_forward_table(coeff_mod);
         let forward_table_prime = params.get_ntt_forward_prime_table(coeff_mod);
@@ -134,22 +132,23 @@ pub fn ntt_forward(params: &Params, operand_overall: &mut [u64]) {
             let mut it = operand.chunks_exact_mut(2 * t);
 
             for i in 0..m {
-                let w = forward_table[m+i];
-                let w_prime = forward_table_prime[m+i];
-                
+                let w = forward_table[m + i];
+                let w_prime = forward_table_prime[m + i];
+
                 let op = it.next().unwrap();
-                
+
                 if t < 4 {
                     for j in 0..t {
                         let x: u32 = op[j] as u32;
                         let y: u32 = op[t + j] as u32;
 
-                        let curr_x: u32 = x - (two_times_modulus_small * ((x >= two_times_modulus_small) as u32));
+                        let curr_x: u32 =
+                            x - (two_times_modulus_small * ((x >= two_times_modulus_small) as u32));
                         let q_tmp: u64 = ((y as u64) * (w_prime as u64)) >> 32u64;
                         let q_new = w * (y as u64) - q_tmp * (modulus_small as u64);
 
                         op[j] = curr_x as u64 + q_new;
-                        op[t + j] = curr_x as u64 +  ((two_times_modulus_small as u64) - q_new);
+                        op[t + j] = curr_x as u64 + ((two_times_modulus_small as u64) - q_new);
                     }
                 } else {
                     unsafe {
@@ -215,7 +214,7 @@ pub fn ntt_inverse(params: &Params, operand_overall: &mut [u64]) {
     for coeff_mod in 0..params.crt_count {
         let n = params.poly_len;
 
-        let operand = &mut operand_overall[coeff_mod*n..coeff_mod*n+n];
+        let operand = &mut operand_overall[coeff_mod * n..coeff_mod * n + n];
 
         let inverse_table = params.get_ntt_inverse_table(coeff_mod);
         let inverse_table_prime = params.get_ntt_inverse_prime_table(coeff_mod);
@@ -229,21 +228,21 @@ pub fn ntt_inverse(params: &Params, operand_overall: &mut [u64]) {
             let mut it = operand.chunks_exact_mut(2 * t);
 
             for i in 0..h {
-                let w = inverse_table[h+i];
-                let w_prime = inverse_table_prime[h+i];
+                let w = inverse_table[h + i];
+                let w_prime = inverse_table_prime[h + i];
 
                 let op = it.next().unwrap();
 
                 for j in 0..t {
                     let x = op[j];
                     let y = op[t + j];
-                    
+
                     let t_tmp = two_times_modulus - y + x;
                     let curr_x = x + y - (two_times_modulus * (((x << 1) >= t_tmp) as u64));
                     let h_tmp = (t_tmp * w_prime) >> 32;
 
-                    let res_x= (curr_x + (modulus * ((t_tmp & 1) as u64))) >> 1;
-                    let res_y= w * t_tmp - h_tmp * modulus;
+                    let res_x = (curr_x + (modulus * ((t_tmp & 1) as u64))) >> 1;
+                    let res_y = w * t_tmp - h_tmp * modulus;
 
                     op[j] = res_x;
                     op[t + j] = res_y;
@@ -263,7 +262,7 @@ pub fn ntt_inverse(params: &Params, operand_overall: &mut [u64]) {
     for coeff_mod in 0..params.crt_count {
         let n = params.poly_len;
 
-        let operand = &mut operand_overall[coeff_mod*n..coeff_mod*n+n];
+        let operand = &mut operand_overall[coeff_mod * n..coeff_mod * n + n];
 
         let inverse_table = params.get_ntt_inverse_table(coeff_mod);
         let inverse_table_prime = params.get_ntt_inverse_prime_table(coeff_mod);
@@ -276,22 +275,22 @@ pub fn ntt_inverse(params: &Params, operand_overall: &mut [u64]) {
             let mut it = operand.chunks_exact_mut(2 * t);
 
             for i in 0..h {
-                let w = inverse_table[h+i];
-                let w_prime = inverse_table_prime[h+i];
+                let w = inverse_table[h + i];
+                let w_prime = inverse_table_prime[h + i];
 
                 let op = it.next().unwrap();
-                
+
                 if t < 4 {
                     for j in 0..t {
                         let x = op[j];
                         let y = op[t + j];
-                        
+
                         let t_tmp = two_times_modulus - y + x;
                         let curr_x = x + y - (two_times_modulus * (((x << 1) >= t_tmp) as u64));
                         let h_tmp = (t_tmp * w_prime) >> 32;
 
-                        let res_x= (curr_x + (modulus * ((t_tmp & 1) as u64))) >> 1;
-                        let res_y= w * t_tmp - h_tmp * modulus;
+                        let res_x = (curr_x + (modulus * ((t_tmp & 1) as u64))) >> 1;
+                        let res_y = w * t_tmp - h_tmp * modulus;
 
                         op[j] = res_x;
                         op[t + j] = res_y;
@@ -304,9 +303,10 @@ pub fn ntt_inverse(params: &Params, operand_overall: &mut [u64]) {
                             let p_y = &mut op[j + t] as *mut u64;
                             let x = _mm256_loadu_si256(p_x as *const __m256i);
                             let y = _mm256_loadu_si256(p_y as *const __m256i);
-                            
+
                             let modulus_vec = _mm256_set1_epi64x(modulus as i64);
-                            let two_times_modulus_vec = _mm256_set1_epi64x(two_times_modulus as i64);
+                            let two_times_modulus_vec =
+                                _mm256_set1_epi64x(two_times_modulus as i64);
                             let mut t_tmp = _mm256_set1_epi64x(two_times_modulus as i64);
                             t_tmp = _mm256_sub_epi64(t_tmp, y);
                             t_tmp = _mm256_add_epi64(t_tmp, x);
@@ -320,7 +320,8 @@ pub fn ntt_inverse(params: &Params, operand_overall: &mut [u64]) {
                             h_tmp = _mm256_srli_epi64(h_tmp, 32);
 
                             let and_mask = _mm256_set_epi64x(1, 1, 1, 1);
-                            let eq_mask = _mm256_cmpeq_epi64(_mm256_and_si256(t_tmp, and_mask), and_mask);
+                            let eq_mask =
+                                _mm256_cmpeq_epi64(_mm256_and_si256(t_tmp, and_mask), and_mask);
                             let to_add = _mm256_and_si256(eq_mask, modulus_vec);
 
                             let new_x = _mm256_srli_epi64(_mm256_add_epi64(curr_x, to_add), 1);
@@ -348,8 +349,8 @@ pub fn ntt_inverse(params: &Params, operand_overall: &mut [u64]) {
 #[cfg(test)]
 mod test {
     use super::*;
-    use rand::Rng;
     use crate::util::*;
+    use rand::Rng;
 
     fn get_params() -> Params {
         get_test_params()
@@ -381,7 +382,7 @@ mod test {
     #[test]
     fn ntt_forward_correct() {
         let params = get_params();
-        let mut v1 = vec![0; 2*2048];
+        let mut v1 = vec![0; 2 * 2048];
         v1[0] = 100;
         v1[2048] = 100;
         ntt_forward(&params, v1.as_mut_slice());
@@ -392,7 +393,7 @@ mod test {
     #[test]
     fn ntt_inverse_correct() {
         let params = get_params();
-        let mut v1 = vec![100; 2*2048];
+        let mut v1 = vec![100; 2 * 2048];
         ntt_inverse(&params, v1.as_mut_slice());
         assert_eq!(v1[0], 100);
         assert_eq!(v1[2048], 100);
@@ -415,7 +416,7 @@ mod test {
         let mut v2 = v1.clone();
         ntt_forward(&params, v2.as_mut_slice());
         ntt_inverse(&params, v2.as_mut_slice());
-        for i in 0..params.crt_count*params.poly_len {
+        for i in 0..params.crt_count * params.poly_len {
             assert_eq!(v1[i], v2[i]);
         }
     }
@@ -425,4 +426,4 @@ mod test {
         assert_eq!(calc_index(&[2, 3, 4], &[10, 10, 100]), 2304);
         assert_eq!(calc_index(&[2, 3, 4], &[3, 5, 7]), 95);
     }
-}
+}

+ 1 - 1
spiral-rs/src/number_theory.rs

@@ -93,4 +93,4 @@ pub fn invert_uint_mod(value: u64, modulus: u64) -> Option<u64> {
     } else {
         return Some(gcd_tuple.1 as u64);
     }
-}
+}

+ 41 - 6
spiral-rs/src/params.rs

@@ -1,17 +1,52 @@
 use crate::{arith::*, ntt::*, number_theory::*};
 
 pub static Q2_VALUES: [u64; 37] = [
-    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 12289, 12289, 61441, 65537, 65537, 520193, 786433, 786433, 3604481, 7340033, 16515073, 33292289, 67043329, 132120577, 268369921, 469762049, 1073479681, 2013265921, 4293918721, 8588886017, 17175674881, 34359214081, 68718428161
+    0,
+    0,
+    0,
+    0,
+    0,
+    0,
+    0,
+    0,
+    0,
+    0,
+    0,
+    0,
+    0,
+    0,
+    12289,
+    12289,
+    61441,
+    65537,
+    65537,
+    520193,
+    786433,
+    786433,
+    3604481,
+    7340033,
+    16515073,
+    33292289,
+    67043329,
+    132120577,
+    268369921,
+    469762049,
+    1073479681,
+    2013265921,
+    4293918721,
+    8588886017,
+    17175674881,
+    34359214081,
+    68718428161,
 ];
 
-#[derive(Debug)]
-#[derive(PartialEq)]
-pub struct Params {    
+#[derive(Debug, PartialEq)]
+pub struct Params {
     pub poly_len: usize,
     pub poly_len_log2: usize,
-    pub ntt_tables: Vec<Vec<Vec<u64>>>,    
+    pub ntt_tables: Vec<Vec<Vec<u64>>>,
     pub scratch: Vec<u64>,
-    
+
     pub crt_count: usize,
     pub moduli: Vec<u64>,
     pub modulus: u64,

+ 21 - 14
spiral-rs/src/poly.rs

@@ -2,12 +2,12 @@ use core::num;
 #[cfg(target_feature = "avx2")]
 use std::arch::x86_64::*;
 
-use std::ops::{Add, Mul, Neg};
-use std::cell::RefCell;
-use rand::Rng;
 use rand::distributions::Standard;
+use rand::Rng;
+use std::cell::RefCell;
+use std::ops::{Add, Mul, Neg};
 
-use crate::{arith::*, params::*, ntt::*, util::*, discrete_gaussian::*};
+use crate::{arith::*, discrete_gaussian::*, ntt::*, params::*, util::*};
 
 const SCRATCH_SPACE: usize = 8192;
 thread_local!(static SCRATCH: RefCell<Vec<u64>> = RefCell::new(vec![0u64; SCRATCH_SPACE]));
@@ -156,7 +156,12 @@ impl<'a> PolyMatrixRaw<'a> {
         for r in 0..self.rows {
             for c in 0..self.cols {
                 for z in 0..num_coeffs {
-                    write_arbitrary_bits(data.as_mut_slice(), self.get_poly(r,c)[z], bit_offs, modulus_bits);
+                    write_arbitrary_bits(
+                        data.as_mut_slice(),
+                        self.get_poly(r, c)[z],
+                        bit_offs,
+                        modulus_bits,
+                    );
                     bit_offs += modulus_bits;
                 }
                 // round bit_offs down to nearest byte boundary
@@ -287,16 +292,16 @@ pub fn multiply_add_poly_avx(params: &Params, res: &mut [u64], a: &[u64], b: &[u
     for c in 0..params.crt_count {
         for i in (0..params.poly_len).step_by(4) {
             unsafe {
-                let p_x = &a[c*params.poly_len + i] as *const u64;
-                let p_y = &b[c*params.poly_len + i] as *const u64;
-                let p_z = &mut res[c*params.poly_len + i] as *mut u64;
+                let p_x = &a[c * params.poly_len + i] as *const u64;
+                let p_y = &b[c * params.poly_len + i] as *const u64;
+                let p_z = &mut res[c * params.poly_len + i] as *mut u64;
                 let x = _mm256_loadu_si256(p_x as *const __m256i);
                 let y = _mm256_loadu_si256(p_y as *const __m256i);
                 let z = _mm256_loadu_si256(p_z as *const __m256i);
 
                 let product = _mm256_mul_epu32(x, y);
                 let out = _mm256_add_epi64(z, product);
-                
+
                 _mm256_storeu_si256(p_z as *mut __m256i, out);
             }
         }
@@ -306,7 +311,7 @@ pub fn multiply_add_poly_avx(params: &Params, res: &mut [u64], a: &[u64], b: &[u
 pub fn modular_reduce(params: &Params, res: &mut [u64]) {
     for c in 0..params.crt_count {
         for i in 0..params.poly_len {
-            res[c*params.poly_len + i] %= params.moduli[c];
+            res[c * params.poly_len + i] %= params.moduli[c];
         }
     }
 }
@@ -320,7 +325,7 @@ pub fn multiply(res: &mut PolyMatrixNTT, a: &PolyMatrixNTT, b: &PolyMatrixNTT) {
     let params = res.params;
     for i in 0..a.rows {
         for j in 0..b.cols {
-            for z in 0..params.poly_len*params.crt_count {
+            for z in 0..params.poly_len * params.crt_count {
                 res.get_poly_mut(i, j)[z] = 0;
             }
             for k in 0..a.cols {
@@ -343,7 +348,7 @@ pub fn multiply(res: &mut PolyMatrixNTT, a: &PolyMatrixNTT, b: &PolyMatrixNTT) {
     let params = res.params;
     for i in 0..a.rows {
         for j in 0..b.cols {
-            for z in 0..params.poly_len*params.crt_count {
+            for z in 0..params.poly_len * params.crt_count {
                 res.get_poly_mut(i, j)[z] = 0;
             }
             let res_poly = res.get_poly_mut(i, j);
@@ -431,7 +436,10 @@ pub fn scalar_multiply(res: &mut PolyMatrixNTT, a: &PolyMatrixNTT, b: &PolyMatri
     }
 }
 
-pub fn scalar_multiply_alloc<'a>(a: &PolyMatrixNTT<'a>, b: &PolyMatrixNTT<'a>) -> PolyMatrixNTT<'a> {
+pub fn scalar_multiply_alloc<'a>(
+    a: &PolyMatrixNTT<'a>,
+    b: &PolyMatrixNTT<'a>,
+) -> PolyMatrixNTT<'a> {
     let mut res = PolyMatrixNTT::zero(b.params, b.rows, b.cols);
     scalar_multiply(&mut res, a, b);
     res
@@ -443,7 +451,6 @@ pub fn single_poly<'a>(params: &'a Params, val: u64) -> PolyMatrixRaw<'a> {
     res
 }
 
-
 pub fn to_ntt(a: &mut PolyMatrixNTT, b: &PolyMatrixRaw) {
     let params = a.params;
     for r in 0..a.rows {

+ 29 - 29
spiral-rs/src/util.rs

@@ -1,5 +1,5 @@
 use crate::params::*;
-use serde_json::{Value};
+use serde_json::Value;
 
 pub fn calc_index(indices: &[usize], lengths: &[usize]) -> usize {
     let mut idx = 0usize;
@@ -13,8 +13,8 @@ pub fn calc_index(indices: &[usize], lengths: &[usize]) -> usize {
 
 pub fn get_test_params() -> Params {
     Params::init(
-        2048, 
-        &vec![268369921u64, 249561089u64], 
+        2048,
+        &vec![268369921u64, 249561089u64],
         6.4,
         2,
         256,
@@ -27,30 +27,30 @@ pub fn get_test_params() -> Params {
         9,
         6,
         1,
-        2048
+        2048,
     )
 }
 
 pub const fn get_empty_params() -> Params {
-    Params { 
-        poly_len: 0, 
-        poly_len_log2: 0, 
-        ntt_tables: Vec::new(), 
-        scratch: Vec::new(), 
-        crt_count: 0, 
-        moduli: Vec::new(), 
-        modulus: 0, 
-        modulus_log2: 0, 
-        noise_width: 0f64, 
-        n: 0, 
-        pt_modulus: 0, 
-        q2_bits: 0, 
-        t_conv: 0, 
-        t_exp_left: 0, 
-        t_exp_right: 0, 
-        t_gsw: 0, 
-        expand_queries: false, 
-        db_dim_1: 0, 
+    Params {
+        poly_len: 0,
+        poly_len_log2: 0,
+        ntt_tables: Vec::new(),
+        scratch: Vec::new(),
+        crt_count: 0,
+        moduli: Vec::new(),
+        modulus: 0,
+        modulus_log2: 0,
+        noise_width: 0f64,
+        n: 0,
+        pt_modulus: 0,
+        q2_bits: 0,
+        t_conv: 0,
+        t_exp_left: 0,
+        t_exp_right: 0,
+        t_gsw: 0,
+        expand_queries: false,
+        db_dim_1: 0,
         db_dim_2: 0,
         instances: 0,
         db_item_size: 0,
@@ -72,8 +72,8 @@ pub fn params_from_json(cfg: &str) -> Params {
     let t_exp_right = v["t_exp_right"].as_u64().unwrap() as usize;
     let do_expansion = v.get("kinda_direct_upload").is_none();
     Params::init(
-        2048, 
-        &vec![268369921u64, 249561089u64], 
+        2048,
+        &vec![268369921u64, 249561089u64],
         6.4,
         n,
         p,
@@ -147,8 +147,8 @@ mod test {
         let cfg = cfg.replace("'", "\"");
         let b = params_from_json(&cfg);
         let c = Params::init(
-            2048, 
-            &vec![268369921u64, 249561089u64], 
+            2048,
+            &vec![268369921u64, 249561089u64],
             6.4,
             2,
             256,
@@ -161,7 +161,7 @@ mod test {
             9,
             6,
             1,
-            2048
+            2048,
         );
         assert_eq!(b, c);
     }
@@ -185,4 +185,4 @@ mod test {
             bit_offs += num_bits;
         }
     }
-}
+}