Samir Menon 2 years ago
parent
commit
4c06d05cbb
4 changed files with 57 additions and 30 deletions
  1. 2 2
      .cargo/config.toml
  2. 38 22
      src/client.rs
  3. 1 0
      src/ntt.rs
  4. 16 6
      src/poly.rs

+ 2 - 2
.cargo/config.toml

@@ -1,3 +1,3 @@
 [build]
-target = "x86_64-unknown-linux-gnu"
-rustflags = ["-C", "target-feature=+avx2"]
+# target = "x86_64-unknown-linux-gnu"
+# rustflags = ["-C", "target-feature=+avx2"]

+ 38 - 22
src/client.rs

@@ -6,7 +6,7 @@ pub struct PublicParameters<'a> {
     v_packing: Vec<PolyMatrixNTT<'a>>,            // Ws
     v_expansion_left: Vec<PolyMatrixNTT<'a>>,
     v_expansion_right: Vec<PolyMatrixNTT<'a>>,
-    v_conversion: PolyMatrixNTT<'a>,              // V
+    conversion: PolyMatrixNTT<'a>,              // V
 }
 
 impl<'a> PublicParameters<'a> {
@@ -15,11 +15,16 @@ impl<'a> PublicParameters<'a> {
             v_packing: Vec::new(), 
             v_expansion_left: Vec::new(), 
             v_expansion_right: Vec::new(), 
-            v_conversion: PolyMatrixNTT::zero(params, 2, 2 * params.m_conv()) 
+            conversion: PolyMatrixNTT::zero(params, 2, 2 * params.m_conv()) 
         }
     }
 }
 
+pub struct Query<'a> {
+    ct: PolyMatrixNTT<'a>,
+    v_ct: Vec<PolyMatrixNTT<'a>>,
+}
+
 pub struct Client<'a> {
     params: &'a Params,
     sk_gsw: PolyMatrixRaw<'a>,
@@ -143,30 +148,41 @@ impl<'a> Client<'a> {
             pp.v_packing.push(w);
         }
 
-        // Params for expansion
-        let further_dims = 1usize << params.db_dim_2;
-        let num_expanded = 1usize << params.db_dim_1;
-        let num_bits_to_gen = params.t_gsw * further_dims + num_expanded;
-        let g = log2(num_bits_to_gen as u64) as usize;
-        let stop_round = log2((params.t_gsw * further_dims) as u64) as usize;
-        pp.v_expansion_left = self.generate_expansion_params(g, params.t_exp_left);
-        pp.v_expansion_right = self.generate_expansion_params(stop_round + 1, params.t_exp_right);
-
-        // Params for converison
-        let g_conv = build_gadget(params, 2, 2 * m_conv);
-        let sk_reg_squared_ntt = &self.sk_reg.ntt() * &self.sk_reg.ntt();
-        pp.v_conversion = PolyMatrixNTT::zero(params, 2, 2 * m_conv);
-        for i in 0..2*m_conv {
-            if i % 2 == 0 {
-                let val = g_conv.get_poly(0, i)[0];
-                let sigma = &sk_reg_squared_ntt * &single_poly(params, val).ntt();
-                let ct = self.encrypt_matrix_reg(sigma);
-                pp.v_conversion.copy_into(&ct, 0, i);
+        if params.expand_queries {
+            // Params for expansion
+            let further_dims = 1usize << params.db_dim_2;
+            let num_expanded = 1usize << params.db_dim_1;
+            let num_bits_to_gen = params.t_gsw * further_dims + num_expanded;
+            let g = log2(num_bits_to_gen as u64) as usize;
+            let stop_round = log2((params.t_gsw * further_dims) as u64) as usize;
+            pp.v_expansion_left = self.generate_expansion_params(g, params.t_exp_left);
+            pp.v_expansion_right = self.generate_expansion_params(stop_round + 1, params.t_exp_right);
+
+            // Params for converison
+            let g_conv = build_gadget(params, 2, 2 * m_conv);
+            let sk_reg_squared_ntt = &self.sk_reg.ntt() * &self.sk_reg.ntt();
+            pp.conversion = PolyMatrixNTT::zero(params, 2, 2 * m_conv);
+            for i in 0..2*m_conv {
+                if i % 2 == 0 {
+                    let val = g_conv.get_poly(0, i)[0];
+                    let sigma = &sk_reg_squared_ntt * &single_poly(params, val).ntt();
+                    let ct = self.encrypt_matrix_reg(sigma);
+                    pp.conversion.copy_into(&ct, 0, i);
+                }
             }
         }
 
         pp
     }
-    // fn generate_query(&self) -> Query<'a, Params>;
+
+    // fn generate_query(&self) -> Query<'a> {
+    //     let params = self.params;
+    //     let mut query = Query { ct: PolyMatrixNTT::zero(params, 1, 1), v_ct: Vec::new() }
+    //     if params.expand_queries {
+            
+    //     } else {
+
+    //     }
+    // }
     
 }

+ 1 - 0
src/ntt.rs

@@ -1,3 +1,4 @@
+#[cfg(target_feature = "avx2")]
 use std::arch::x86_64::*;
 
 use crate::{

+ 16 - 6
src/poly.rs

@@ -1,4 +1,6 @@
+#[cfg(target_feature = "avx2")]
 use std::arch::x86_64::*;
+
 use std::ops::{Add, Mul, Neg};
 use std::cell::RefCell;
 use rand::Rng;
@@ -210,7 +212,8 @@ impl<'a> PolyMatrixNTT<'a> {
 pub fn multiply_poly(params: &Params, res: &mut [u64], a: &[u64], b: &[u64]) {
     for c in 0..params.crt_count {
         for i in 0..params.poly_len {
-            res[i] = multiply_modular(params, a[i], b[i], c);
+            let idx = c * params.poly_len + i;
+            res[idx] = multiply_modular(params, a[idx], b[idx], c);
         }
     }
 }
@@ -218,7 +221,8 @@ pub fn multiply_poly(params: &Params, res: &mut [u64], a: &[u64], b: &[u64]) {
 pub fn multiply_add_poly(params: &Params, res: &mut [u64], a: &[u64], b: &[u64]) {
     for c in 0..params.crt_count {
         for i in 0..params.poly_len {
-            res[i] = multiply_add_modular(params, a[i], b[i], res[i], c);
+            let idx = c * params.poly_len + i;
+            res[idx] = multiply_add_modular(params, a[idx], b[idx], res[idx], c);
         }
     }
 }
@@ -226,7 +230,8 @@ pub fn multiply_add_poly(params: &Params, res: &mut [u64], a: &[u64], b: &[u64])
 pub fn add_poly(params: &Params, res: &mut [u64], a: &[u64], b: &[u64]) {
     for c in 0..params.crt_count {
         for i in 0..params.poly_len {
-            res[i] = add_modular(params, a[i], b[i], c);
+            let idx = c * params.poly_len + i;
+            res[idx] = add_modular(params, a[idx], b[idx], c);
         }
     }
 }
@@ -234,7 +239,8 @@ pub fn add_poly(params: &Params, res: &mut [u64], a: &[u64], b: &[u64]) {
 pub fn invert_poly(params: &Params, res: &mut [u64], a: &[u64]) {
     for c in 0..params.crt_count {
         for i in 0..params.poly_len {
-            res[i] = invert_modular(params, a[i], c);
+            let idx = c * params.poly_len + i;
+            res[idx] = invert_modular(params, a[idx], c);
         }
     }
 }
@@ -253,6 +259,7 @@ pub fn automorph_poly(params: &Params, res: &mut [u64], a: &[u64], t: usize) {
     }
 }
 
+#[cfg(target_feature = "avx2")]
 pub fn multiply_add_poly_avx(params: &Params, res: &mut [u64], a: &[u64], b: &[u64]) {
     for c in 0..params.crt_count {
         for i in (0..params.poly_len).step_by(4) {
@@ -283,11 +290,14 @@ pub fn modular_reduce(params: &Params, res: &mut [u64]) {
 
 #[cfg(not(target_feature = "avx2"))]
 pub fn multiply(res: &mut PolyMatrixNTT, a: &PolyMatrixNTT, b: &PolyMatrixNTT) {
+    assert!(res.rows == a.rows);
+    assert!(res.cols == b.cols);
     assert!(a.cols == b.rows);
 
+    let params = res.params;
     for i in 0..a.rows {
         for j in 0..b.cols {
-            for z in 0..res.params.poly_len {
+            for z in 0..params.poly_len*params.crt_count {
                 res.get_poly_mut(i, j)[z] = 0;
             }
             for k in 0..a.cols {
@@ -310,7 +320,7 @@ pub fn multiply(res: &mut PolyMatrixNTT, a: &PolyMatrixNTT, b: &PolyMatrixNTT) {
     let params = res.params;
     for i in 0..a.rows {
         for j in 0..b.cols {
-            for z in 0..res.params.poly_len {
+            for z in 0..params.poly_len*params.crt_count {
                 res.get_poly_mut(i, j)[z] = 0;
             }
             let res_poly = res.get_poly_mut(i, j);