Samir Menon 2 years ago
parent
commit
d0d1792532
4 changed files with 312 additions and 12 deletions
  1. 4 4
      spiral-rs/src/client.rs
  2. 48 0
      spiral-rs/src/poly.rs
  3. 249 5
      spiral-rs/src/server.rs
  4. 11 3
      spiral-rs/src/util.rs

+ 4 - 4
spiral-rs/src/client.rs

@@ -73,9 +73,9 @@ impl<'a> PublicParameters<'a> {
 }
 
 pub struct Query<'a> {
-    ct: Option<PolyMatrixRaw<'a>>,
-    v_buf: Option<Vec<u64>>,
-    v_ct: Option<Vec<PolyMatrixRaw<'a>>>,
+    pub ct: Option<PolyMatrixRaw<'a>>,
+    pub v_buf: Option<Vec<u64>>,
+    pub v_ct: Option<Vec<PolyMatrixRaw<'a>>>,
 }
 
 impl<'a> Query<'a> {
@@ -498,7 +498,7 @@ mod test {
     #[test]
     fn keygen_is_correct() {
         let params = get_params();
-        let mut seeded_rng = get_seeded_rng();
+        let mut seeded_rng = get_static_seeded_rng();
         let mut client = Client::init(&params, &mut seeded_rng);
 
         let public_params = client.generate_keys();

+ 48 - 0
spiral-rs/src/poly.rs

@@ -140,6 +140,19 @@ impl<'a> PolyMatrix<'a> for PolyMatrixRaw<'a> {
     }
 }
 
+impl<'a> Clone for PolyMatrixRaw<'a> {
+    fn clone(&self) -> Self {
+        let mut data_clone = AlignedMemory64::new(self.data.len());
+        data_clone.as_mut_slice().copy_from_slice(self.data.as_slice());
+        PolyMatrixRaw {
+            params: self.params,
+            rows: self.rows,
+            cols: self.cols,
+            data: data_clone,
+        }
+    }
+}
+
 impl<'a> PolyMatrixRaw<'a> {
     pub fn identity(params: &'a Params, rows: usize, cols: usize) -> PolyMatrixRaw<'a> {
         let num_coeffs = rows * cols * params.poly_len;
@@ -182,6 +195,17 @@ impl<'a> PolyMatrixRaw<'a> {
         }
     }
 
+    pub fn apply_func<F: Fn(u64) -> u64>(&mut self, func: F) {
+        for r in 0..self.rows {
+            for c in 0..self.cols {
+                let pol_mut = self.get_poly_mut(r, c);
+                for el in pol_mut {
+                    *el = func(*el);
+                }
+            }
+        }
+    }
+
     pub fn to_vec(&self, modulus_bits: usize, num_coeffs: usize) -> Vec<u8> {
         let sz_bits = self.rows * self.cols * num_coeffs * modulus_bits;
         let sz_bytes = f64::ceil((sz_bits as f64) / 8f64) as usize + 32;
@@ -288,6 +312,19 @@ impl<'a> PolyMatrix<'a> for PolyMatrixNTT<'a> {
     }
 }
 
+impl<'a> Clone for PolyMatrixNTT<'a> {
+    fn clone(&self) -> Self {
+        let mut data_clone = AlignedMemory64::new(self.data.len());
+        data_clone.as_mut_slice().copy_from_slice(self.data.as_slice());
+        PolyMatrixNTT {
+            params: self.params,
+            rows: self.rows,
+            cols: self.cols,
+            data: data_clone,
+        }
+    }
+}
+
 impl<'a> PolyMatrixNTT<'a> {
     pub fn raw(&self) -> PolyMatrixRaw<'a> {
         from_ntt_alloc(&self)
@@ -457,6 +494,17 @@ pub fn add_into(res: &mut PolyMatrixNTT, a: &PolyMatrixNTT) {
     }
 }
 
+pub fn add_into_at(res: &mut PolyMatrixNTT, a: &PolyMatrixNTT, t_row: usize, t_col: usize) {
+    let params = res.params;
+    for i in 0..a.rows {
+        for j in 0..a.cols {
+            let res_poly = res.get_poly_mut(t_row + i, t_col + j);
+            let pol2 = a.get_poly(i, j);
+            add_poly_into(params, res_poly, pol2);
+        }
+    }
+}
+
 pub fn invert(res: &mut PolyMatrixRaw, a: &PolyMatrixRaw) {
     assert!(res.rows == a.rows);
     assert!(res.cols == a.cols);

+ 249 - 5
spiral-rs/src/server.rs

@@ -5,6 +5,9 @@ use std::arch::x86_64::*;
 use crate::aligned_memory::*;
 
 use crate::arith::*;
+use crate::aligned_memory::*;
+use crate::client::PublicParameters;
+use crate::client::Query;
 use crate::gadget::*;
 use crate::params::*;
 use crate::poly::*;
@@ -13,7 +16,7 @@ use crate::util::*;
 pub fn coefficient_expansion(
     v: &mut Vec<PolyMatrixNTT>,
     g: usize,
-    stopround: usize,
+    stop_round: usize,
     params: &Params,
     v_w_left: &Vec<PolyMatrixNTT>,
     v_w_right: &Vec<PolyMatrixNTT>,
@@ -41,8 +44,8 @@ pub fn coefficient_expansion(
         let neg1 = &v_neg1[r];
 
         for i in 0..num_out {
-            if stopround > 0 && i % 2 == 1 && r > stopround
-                || (r == stopround && i / 2 >= max_bits_to_gen_right)
+            if stop_round > 0 && i % 2 == 1 && r > stop_round
+                || (r == stop_round && i / 2 >= max_bits_to_gen_right)
             {
                 continue;
             }
@@ -312,6 +315,218 @@ pub fn fold_ciphertexts(
     }
 }
 
+pub fn pack<'a>(
+    params: &'a Params,
+    v_ct: &Vec<PolyMatrixRaw>,
+    v_w: &Vec<PolyMatrixNTT>
+) -> PolyMatrixNTT<'a> {
+    assert!(v_ct.len() >= params.n * params.n);
+    assert!(v_w.len() == params.n);
+    assert!(v_ct[0].rows == 2);
+    assert!(v_ct[0].cols == 1);
+    assert!(v_w[0].rows == (params.n + 1));
+    assert!(v_w[0].cols == params.t_conv);
+
+    let mut result = PolyMatrixNTT::zero(params, params.n + 1, params.n);
+
+    let mut ginv = PolyMatrixRaw::zero(params, params.t_conv, 1);
+    let mut ginv_nttd = PolyMatrixNTT::zero(params, params.t_conv, 1);
+    let mut prod = PolyMatrixNTT::zero(params, params.n + 1, 1);
+    let mut ct_1 = PolyMatrixRaw::zero(params, 1, 1);
+    let mut ct_2 = PolyMatrixRaw::zero(params, 1, 1);
+    let mut ct_2_ntt = PolyMatrixNTT::zero(params, 1, 1);
+
+    for c in 0..params.n {
+        let mut v_int = PolyMatrixNTT::zero(&params, params.n + 1, 1);
+        for r in 0..params.n {
+            let w = &v_w[r];
+            let ct = &v_ct[r * params.n + c];
+            ct_1.copy_into(ct, 0, 0);
+            ct_2.copy_into(ct, 1, 0);
+            to_ntt(&mut ct_2_ntt, &ct_2);
+            gadget_invert(&mut ginv, &ct_1);
+            to_ntt(&mut ginv_nttd, &ginv);
+            multiply(&mut prod, &w, &ginv_nttd);
+            add_into_at(&mut v_int, &ct_2_ntt, 1 + r, 0);
+            add_into(&mut v_int, &prod);
+        }
+        result.copy_into(&v_int, 0, c);
+    }
+
+    result
+}
+
+pub fn encode(
+    params: &Params,
+    v_packed_ct: &Vec<PolyMatrixRaw>
+) -> Vec<u8> {
+    let q1 = 4 * params.pt_modulus;
+    let q1_bits = log2_ceil(q1) as usize;
+    let q2 = Q2_VALUES[params.q2_bits as usize];
+    let q2_bits = params.q2_bits as usize;
+
+    let num_bits = params.instances * 
+        (
+            (q2_bits * params.n * params.poly_len) + 
+            (q1_bits * params.n * params.n * params.poly_len)
+        );
+    let round_to = 64;
+    let num_bytes_rounded_up = ((num_bits + round_to - 1) / round_to) * round_to / 8;
+
+    let mut result = vec![0u8; num_bytes_rounded_up];
+    let mut bit_offs = 0;
+    for instance in 0..params.instances {
+        let packed_ct = &v_packed_ct[instance];
+        
+        let mut first_row = packed_ct.submatrix(0, 0, 1, packed_ct.cols);
+        let mut rest_rows = packed_ct.submatrix(1, 0, packed_ct.rows - 1, packed_ct.cols);
+        first_row.apply_func(|x| { rescale(x, params.modulus, q2) });
+        rest_rows.apply_func(|x| { rescale(x, params.modulus, q1) });
+
+        let data = result.as_mut_slice();
+        for i in 0..params.n * params.poly_len {
+            write_arbitrary_bits(data, first_row.data[i], bit_offs, q2_bits);
+            bit_offs += q2_bits;
+        }
+        for i in 0..params.n * params.n * params.poly_len {
+            write_arbitrary_bits(data, rest_rows.data[i], bit_offs, q1_bits);
+            bit_offs += q1_bits;
+        }
+    }
+    result
+}
+
+pub fn expand_query<'a>(
+    params: &'a Params, 
+    public_params: &PublicParameters<'a>, 
+    query: &Query<'a>,
+) -> (AlignedMemory64, Vec<PolyMatrixNTT<'a>>, Vec<PolyMatrixNTT<'a>>) {
+    let dim0 = 1 << params.db_dim_1;
+    let further_dims = params.db_dim_2;
+
+    let mut v_reg_reoriented;
+    let mut v_folding;
+    let mut v_folding_neg;
+
+    let num_bits_to_gen = params.t_gsw * further_dims + dim0;
+    let g = log2_ceil_usize(num_bits_to_gen);
+    let right_expanded = params.t_gsw * further_dims;
+    let stop_round = log2_ceil_usize(right_expanded);
+
+    let mut v = Vec::new();
+    for _ in 0..(1 << g) {
+        v.push(PolyMatrixNTT::zero(params, 2, 1));
+    }
+    v[0].copy_into(&query.ct.as_ref().unwrap().ntt(), 0, 0);
+
+    let v_conversion = &public_params.v_conversion.as_ref().unwrap()[0];
+    let v_w_left = public_params.v_expansion_left.as_ref().unwrap();
+    let v_w_right = public_params.v_expansion_right.as_ref().unwrap();
+    let v_neg1 = params.get_v_neg1();
+
+    coefficient_expansion(
+        &mut v,
+        g,
+        stop_round,
+        params,
+        &v_w_left,
+        &v_w_right,
+        &v_neg1,
+        params.t_gsw * params.db_dim_2,
+    );
+
+    let mut v_reg_inp = Vec::with_capacity(dim0);
+    for i in 0..dim0 {
+        v_reg_inp.push(v[2 * i].clone());
+    }
+    let mut v_gsw_inp = Vec::with_capacity(right_expanded);
+    for i in 0..right_expanded {
+        v_gsw_inp.push(v[2 * i + 1].clone());
+    }
+
+    let v_reg_sz = dim0 * 2 * params.poly_len;
+    v_reg_reoriented = AlignedMemory64::new(v_reg_sz);
+    reorient_reg_ciphertexts(params, v_reg_reoriented.as_mut_slice(), &v_reg_inp);
+
+    v_folding = Vec::new();
+    for _ in 0..params.db_dim_2 {
+        v_folding.push(PolyMatrixNTT::zero(params, 2, 2 * params.t_gsw));
+    }
+
+    regev_to_gsw(&mut v_folding, &v_gsw_inp, &v_conversion, params, 1, 0);
+
+    let gadget_ntt = build_gadget(&params, 2, 2 * params.t_gsw).ntt();
+    v_folding_neg = Vec::new();
+    let mut ct_gsw_inv = PolyMatrixRaw::zero(&params, 2, 2 * params.t_gsw);
+    for i in 0..params.db_dim_2 {
+        invert(&mut ct_gsw_inv, &v_folding[i].raw());
+        let mut ct_gsw_neg = PolyMatrixNTT::zero(&params, 2, 2 * params.t_gsw);
+        add(&mut ct_gsw_neg, &gadget_ntt, &ct_gsw_inv.ntt());
+        v_folding_neg.push(ct_gsw_neg);
+    }
+
+    (v_reg_reoriented, v_folding, v_folding_neg)
+}
+
+#[cfg(target_feature = "avx2")]
+pub fn process_query(
+    params: &Params, 
+    public_params: &PublicParameters, 
+    query: &Query,
+    db: &[u64],
+) -> Vec<u8> {
+    let dim0 = 1 << params.db_dim_1;
+    let num_per = 1 << params.db_dim_2;
+    let further_dims = params.db_dim_2;
+    let db_slice_sz = dim0 * num_per * params.poly_len;
+
+    
+
+    let v_packing = public_params.v_packing.as_ref();
+
+    if params.expand_queries {
+        
+    }
+
+    let mut intermediate = Vec::with_capacity(num_per);
+    let mut intermediate_raw = Vec::with_capacity(num_per);
+    for _ in 0..dim0 {
+        intermediate.push(PolyMatrixNTT::zero(params, 2, 1));
+        intermediate_raw.push(PolyMatrixRaw::zero(params, 2, 1));
+    }
+
+    let mut v_ct = Vec::new();
+    for trial in 0..(params.n * params.n) {
+        let cur_db = &db[(db_slice_sz * trial)..(db_slice_sz * trial + db_slice_sz)];
+
+        multiply_reg_by_database(&mut intermediate, db, v_reg_reoriented.as_slice(), params, dim0, num_per);
+
+        for i in 0..intermediate.len() {
+            from_ntt(&mut intermediate_raw[i], &intermediate[i]);
+        }
+
+        fold_ciphertexts(
+            params,
+            &mut intermediate_raw,
+            &v_folding,
+            &v_folding_neg
+        );
+
+        v_ct.push(intermediate_raw[0]);
+    }
+
+    let packed_ct = pack(
+        params,
+        &v_ct,
+        &v_packing,
+    );
+
+    let mut v_packed_ct = Vec::new();
+    v_packed_ct.push(packed_ct.raw());
+
+    encode(params, &v_packed_ct)
+}
+
 #[cfg(test)]
 mod test {
     use super::*;
@@ -351,12 +566,12 @@ mod test {
         client: &mut Client<'a, StdRng>,
     ) -> u64 {
         let dec = client.decrypt_matrix_reg(ct).raw();
-        let idx = (params.t_gsw - 1) * params.poly_len + params.poly_len; // this offset should encode a large value
+        let idx = 2 * (params.t_gsw - 1) * params.poly_len + params.poly_len; // this offset should encode a large value
         let mut val = dec.data[idx] as i64;
         if val >= (params.modulus / 2) as i64 {
             val -= params.modulus as i64;
         }
-        if val < 100 {
+        if i64::abs(val) < (1i64 << 10) {
             0
         } else {
             1
@@ -559,4 +774,33 @@ mod test {
         // decrypt
         assert_eq!(dec_reg(&params, &v_reg_raw[0].ntt(), &mut client, scale_k), 1);
     }
+
+    #[test]
+    fn full_protocol_is_correct() {
+        let params = get_params();
+        let mut seeded_rng = get_seeded_rng();
+
+        let dim0 = 1 << params.db_dim_1;
+        let num_per = 1 << params.db_dim_2;
+        let scale_k = params.modulus / params.pt_modulus;
+
+        let target_idx = seeded_rng.gen::<usize>() % (dim0 * num_per);
+        let target_idx_dim0 = target_idx / num_per;
+        let target_idx_num_per = target_idx % num_per;
+
+        let mut client = Client::init(&params, &mut seeded_rng);
+        let public_parameters = client.generate_keys();
+        let query = client.generate_query(target_idx);
+
+        let (corr_item, db) = generate_random_db_and_get_item(&params, target_idx);
+
+        let mut v_reg = Vec::new();
+        for i in 0..dim0 {
+            let val = if i == target_idx_dim0 { scale_k } else { 0 };
+            let sigma = PolyMatrixRaw::single_value(&params, val).ntt();
+            v_reg.push(client.encrypt_matrix_reg(&sigma));
+        }
+
+        
+    }
 }

+ 11 - 3
spiral-rs/src/util.rs

@@ -1,5 +1,5 @@
 use crate::{arith::*, params::*, poly::*};
-use rand::{prelude::StdRng, SeedableRng};
+use rand::{prelude::StdRng, SeedableRng, thread_rng, Rng};
 use serde_json::Value;
 
 pub fn calc_index(indices: &[usize], lengths: &[usize]) -> usize {
@@ -73,14 +73,22 @@ pub fn get_expansion_testing_params() -> Params {
 }
 
 pub fn get_seed() -> [u8; 32] {
+    thread_rng().gen::<[u8; 32]>()
+}
+
+pub fn get_seeded_rng() -> StdRng {
+    StdRng::from_seed(get_seed())
+}
+
+pub fn get_static_seed() -> [u8; 32] {
     [
         1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6,
         7, 8,
     ]
 }
 
-pub fn get_seeded_rng() -> StdRng {
-    StdRng::from_seed(get_seed())
+pub fn get_static_seeded_rng() -> StdRng {
+    StdRng::from_seed(get_static_seed())
 }
 
 pub const fn get_empty_params() -> Params {