Browse Source

suppport direct upload

Samir Menon 2 years ago
parent
commit
496f4aac30
4 changed files with 162 additions and 32 deletions
  1. 10 9
      client/src/lib.rs
  2. 126 22
      spiral-rs/src/client.rs
  3. 20 1
      spiral-rs/src/main.rs
  4. 6 0
      spiral-rs/src/poly.rs

+ 10 - 9
client/src/lib.rs

@@ -50,15 +50,16 @@ pub fn initialize(json_params: Option<String>) -> WrappedClient {
     dg_seems_okay();
     // spiral_rs::ntt::test::ntt_correct();
     let cfg = r#"
-        {'n': 2,
-        'nu_1': 9,
-        'nu_2': 6,
-        'p': 256,
-        'q_prime_bits': 20,
-        's_e': 87.62938774292914,
-        't_GSW': 8,
-        't_conv': 4,
-        't_exp': 8,
+        {'kinda_direct_upload': 1,
+        'n': 5,
+        'nu_1': 11,
+        'nu_2': 3,
+        'p': 65536,
+        'q_prime_bits': 27,
+        's_e': 57.793748020122216,
+        't_GSW': 3,
+        't_conv': 56,
+        't_exp': 56,
         't_exp_right': 56}
     "#;
     let mut cfg = cfg.replace("'", "\"");

+ 126 - 22
spiral-rs/src/client.rs

@@ -1,4 +1,5 @@
 use crate::{poly::*, params::*, discrete_gaussian::*, gadget::*, arith::*, util::*, number_theory::*};
+use std::iter::once;
 
 fn serialize_polymatrix(vec: &mut Vec<u8>, a: &PolyMatrixRaw) {
     for i in 0..a.rows * a.cols * a.params.poly_len {
@@ -13,35 +14,54 @@ fn serialize_vec_polymatrix(vec: &mut Vec<u8>, a: &Vec<PolyMatrixRaw>) {
 }
 
 pub struct PublicParameters<'a> {
-    v_packing: Vec<PolyMatrixNTT<'a>>,            // Ws
-    v_expansion_left: Vec<PolyMatrixNTT<'a>>,
-    v_expansion_right: Vec<PolyMatrixNTT<'a>>,
-    conversion: PolyMatrixNTT<'a>,              // V
+    v_packing: Vec<PolyMatrixNTT<'a>>,                  // Ws
+    v_expansion_left: Option<Vec<PolyMatrixNTT<'a>>>,
+    v_expansion_right: Option<Vec<PolyMatrixNTT<'a>>>,
+    v_conversion: Option<Vec<PolyMatrixNTT<'a>>>,              // V
 }
 
 impl<'a> PublicParameters<'a> {
     pub fn init(params: &'a Params) -> Self {
-        PublicParameters { 
-            v_packing: Vec::new(), 
-            v_expansion_left: Vec::new(), 
-            v_expansion_right: Vec::new(), 
-            conversion: PolyMatrixNTT::zero(params, 2, 2 * params.m_conv()) 
+        if params.expand_queries {
+            PublicParameters { 
+                v_packing: Vec::new(), 
+                v_expansion_left: Some(Vec::new()), 
+                v_expansion_right: Some(Vec::new()), 
+                v_conversion: Some(Vec::new())
+            }
+        } else {
+            PublicParameters { 
+                v_packing: Vec::new(), 
+                v_expansion_left: None,
+                v_expansion_right: None,
+                v_conversion: None,
+            }
         }
     }
 
-    pub fn to_raw(&self) -> Vec<Vec<PolyMatrixRaw>> {
+    fn from_ntt_alloc_vec(v: &Vec<PolyMatrixNTT<'a>>) -> Option<Vec<PolyMatrixRaw<'a>>> {
+        Some(v.iter().map(from_ntt_alloc).collect())
+    }
+
+    fn from_ntt_alloc_opt_vec(v: &Option<Vec<PolyMatrixNTT<'a>>>) -> Option<Vec<PolyMatrixRaw<'a>>> {
+        Some(v.as_ref()?.iter().map(from_ntt_alloc).collect())
+    }
+
+    pub fn to_raw(&self) -> Vec<Option<Vec<PolyMatrixRaw>>> {
         vec![
-            self.v_packing.iter().map(from_ntt_alloc).collect(),
-            self.v_expansion_left.iter().map(from_ntt_alloc).collect(),
-            self.v_expansion_right.iter().map(from_ntt_alloc).collect(),
-            vec![from_ntt_alloc(&self.conversion)]
+            Self::from_ntt_alloc_vec(&self.v_packing),
+            Self::from_ntt_alloc_opt_vec(&self.v_expansion_left),
+            Self::from_ntt_alloc_opt_vec(&self.v_expansion_right),
+            Self::from_ntt_alloc_opt_vec(&self.v_conversion),
         ]
     }
 
     pub fn serialize(&self) -> Vec<u8> {
         let mut data = Vec::new();
         for v in self.to_raw().iter() {
-            serialize_vec_polymatrix(&mut data, v);
+            if v.is_some() {
+                serialize_vec_polymatrix(&mut data, v.as_ref().unwrap());
+            }
         }
         data
     }
@@ -49,16 +69,31 @@ impl<'a> PublicParameters<'a> {
 
 pub struct Query<'a> {
     ct: Option<PolyMatrixRaw<'a>>,
-    // v_ct: Option<Vec<PolyMatrixRaw<'a>>>,
+    v_buf: Option<Vec<u64>>,
+    v_ct: Option<Vec<PolyMatrixRaw<'a>>>,
 }
 
 impl<'a> Query<'a> {
+    pub fn empty() -> Self {
+        Query { ct: None, v_ct: None, v_buf: None }
+    }
+
     pub fn serialize(&self) -> Vec<u8> {
         let mut data = Vec::new();
         if self.ct.is_some() {
             let ct = self.ct.as_ref().unwrap();
             serialize_polymatrix(&mut data, &ct);
         }
+        if self.v_buf.is_some() {
+            let v_buf = self.v_buf.as_ref().unwrap();
+            data.extend(v_buf.iter().map(|x| { x.to_ne_bytes() }).flatten());
+        }
+        if self.v_ct.is_some() {
+            let v_ct = self.v_ct.as_ref().unwrap();
+            for x in v_ct {
+                serialize_polymatrix(&mut data, x);
+            }
+        }
         data
     }
 }
@@ -217,14 +252,14 @@ impl<'a> Client<'a> {
         if params.expand_queries {
             // Params for expansion
             
-            pp.v_expansion_left = self.generate_expansion_params(self.g, params.t_exp_left);
-            pp.v_expansion_right = self.generate_expansion_params(self.stop_round + 1, params.t_exp_right);
+            pp.v_expansion_left = Some(self.generate_expansion_params(self.g, params.t_exp_left));
+            pp.v_expansion_right = Some(self.generate_expansion_params(self.stop_round + 1, params.t_exp_right));
 
             // Params for converison
             let g_conv = build_gadget(params, 2, 2 * m_conv);
             let sk_reg_ntt = self.sk_reg.ntt();
             let sk_reg_squared_ntt = &sk_reg_ntt * &sk_reg_ntt;
-            pp.conversion = PolyMatrixNTT::zero(params, 2, 2 * m_conv);
+            pp.v_conversion = Some(Vec::from_iter(once(PolyMatrixNTT::zero(params, 2, 2 * m_conv))));
             for i in 0..2*m_conv {
                 let sigma;
                 if i % 2 == 0 {
@@ -235,13 +270,48 @@ impl<'a> Client<'a> {
                     sigma = &sk_reg_ntt * &single_poly(params, val).ntt();
                 }
                 let ct = self.encrypt_matrix_reg(&sigma);
-                pp.conversion.copy_into(&ct, 0, i);
+                pp.v_conversion.as_mut().unwrap()[0].copy_into(&ct, 0, i);
             }
         }
 
         pp
     }
 
+    // reindexes a vector of regev ciphertexts, to help server
+    fn reorient_reg_ciphertexts(&self, out: &mut [u64], v_reg: &Vec<PolyMatrixNTT>) {
+        let params = self.params;
+        let poly_len = params.poly_len;
+        let crt_count = params.crt_count;
+
+        assert_eq!(crt_count, 2);
+        assert!(log2(params.moduli[0]) <= 32);
+
+        let num_reg_expanded = 1<<params.db_dim_1;
+        let ct_rows = v_reg[0].rows;
+        let ct_cols = v_reg[0].cols;
+
+        assert_eq!(ct_rows, 2);
+        assert_eq!(ct_cols, 1);
+
+        for j in 0..num_reg_expanded {
+            for r in 0..ct_rows {
+                for m in 0.. ct_cols {
+                    for z in 0..params.poly_len {
+                        let idx_a_in = r * (ct_cols*crt_count*poly_len) + m * (crt_count*poly_len);
+                        let idx_a_out =     z * (num_reg_expanded*ct_cols*ct_rows)
+                                                + j * (ct_cols*ct_rows)
+                                                + m * (ct_rows)
+                                                + r;
+                        let val1 = v_reg[j].data[idx_a_in + z] % params.moduli[0];
+                        let val2 = v_reg[j].data[idx_a_in + params.poly_len + z] % params.moduli[1];
+
+                        out[idx_a_out] = val1 | (val2 << 32);
+                    }
+                }
+            }
+        }
+    }
+
     pub fn generate_query(&mut self, idx_target: usize) -> Query<'a> {
         let params = self.params;
         let further_dims = params.db_dim_2;
@@ -250,7 +320,7 @@ impl<'a> Client<'a> {
         let scale_k = params.modulus / params.pt_modulus;
         let bits_per = get_bits_per(params, params.t_gsw);
 
-        let mut query = Query { ct: None };
+        let mut query = Query::empty();
         if params.expand_queries {
             // pack query into single ciphertext
             let mut sigma = PolyMatrixRaw::zero(params, 1, 1);
@@ -273,7 +343,41 @@ impl<'a> Client<'a> {
 
             query.ct = Some(from_ntt_alloc(&self.encrypt_matrix_reg(&to_ntt_alloc(&sigma))));
         } else {
-            assert!(false);
+            let num_expanded = 1 << params.db_dim_1;
+            let mut sigma_v = Vec::<PolyMatrixNTT>::new();
+
+            // generate regev ciphertexts
+            let reg_cts_buf_words = num_expanded * 2 * params.poly_len;
+            let mut reg_cts_buf = vec![0u64; reg_cts_buf_words];
+            let mut reg_cts = Vec::<PolyMatrixNTT>::new();
+            for i in 0..num_expanded {
+                let value = ((i == idx_dim0) as u64) * scale_k;
+                let sigma = PolyMatrixRaw::single_value(&params, value);
+                reg_cts.push(self.encrypt_matrix_reg(&to_ntt_alloc(&sigma)));
+            }
+            // reorient into server's preferred indexing
+            self.reorient_reg_ciphertexts(reg_cts_buf.as_mut_slice(), &reg_cts);
+
+            // generate GSW ciphertexts
+            for i in 0..further_dims {
+                let bit = ((idx_further as u64) & (1 << (i as u64))) >> (i as u64);
+                let mut ct_gsw = PolyMatrixNTT::zero(&params, 2, 2 * params.t_gsw);
+
+                for j in 0..params.t_gsw {
+                    let value = (1u64 << (bits_per * j)) * bit;
+                    let sigma = PolyMatrixRaw::single_value(&params, value);
+                    let sigma_ntt = to_ntt_alloc(&sigma);
+                    let ct = &self.encrypt_matrix_reg(&sigma_ntt);
+                    ct_gsw.copy_into(ct, 0, 2*j + 1);
+                    let prod = &to_ntt_alloc(&self.sk_reg) * &sigma_ntt;
+                    let ct = &self.encrypt_matrix_reg(&prod);
+                    ct_gsw.copy_into(ct, 0, 2*j);
+                }
+                sigma_v.push(ct_gsw);
+            }
+
+            query.v_buf = Some(reg_cts_buf);
+            query.v_ct = Some(sigma_v.iter().map(|x| { from_ntt_alloc(x) }).collect());
         }
         query
     }

+ 20 - 1
spiral-rs/src/main.rs

@@ -2,6 +2,7 @@ use serde_json::Value;
 use spiral_rs::util::*;
 use spiral_rs::client::*;
 use std::env;
+use std::time::Instant;
 
 fn send_api_req_text(path: &str, data: Vec<u8>) -> Option<String> {
     let client = reqwest::blocking::Client::builder()
@@ -29,7 +30,7 @@ fn send_api_req_vec(path: &str, data: Vec<u8>) -> Option<Vec<u8>> {
 }
 
 fn main() {
-    let cfg = r#"
+    let cfg_expand = r#"
         {'n': 2,
         'nu_1': 9,
         'nu_2': 6,
@@ -41,6 +42,20 @@ fn main() {
         't_exp': 8,
         't_exp_right': 56}
     "#;
+    let cfg_direct = r#"
+        {'kinda_direct_upload': 1,
+        'n': 5,
+        'nu_1': 11,
+        'nu_2': 3,
+        'p': 65536,
+        'q_prime_bits': 27,
+        's_e': 57.793748020122216,
+        't_GSW': 3,
+        't_conv': 56,
+        't_exp': 56,
+        't_exp_right': 56}
+    "#;
+    let cfg = cfg_direct;
     let cfg = cfg.replace("'", "\"");
     let params = params_from_json(&cfg);
 
@@ -63,7 +78,11 @@ fn main() {
     let id = resp_json["id"].as_str().unwrap();
     let mut full_query_buf = id.as_bytes().to_vec();
     full_query_buf.append(&mut query_buf);
+
+    let now = Instant::now();
     let query_resp = send_api_req_vec("/query", full_query_buf).unwrap();
+    let duration = now.elapsed().as_millis();
+    println!("duration of query processing is {} ms", duration);
     println!("query_resp len {}", query_resp.len());
 
     let _result = c.decode_response(query_resp.as_slice());

+ 6 - 0
spiral-rs/src/poly.rs

@@ -159,6 +159,12 @@ impl<'a> PolyMatrixRaw<'a> {
         }
         data
     }
+
+    pub fn single_value(params: &'a Params, value: u64) -> PolyMatrixRaw<'a> {
+        let mut out = Self::zero(params, 1, 1);
+        out.data[0] = value;
+        out
+    }
 }
 
 impl<'a> PolyMatrix<'a> for PolyMatrixNTT<'a> {