|
@@ -1,4 +1,6 @@
|
|
|
-use crate::{poly::*, params::*, discrete_gaussian::*, gadget::*, arith::*, util::*, number_theory::*};
|
|
|
+use crate::{
|
|
|
+ arith::*, discrete_gaussian::*, gadget::*, number_theory::*, params::*, poly::*, util::*,
|
|
|
+};
|
|
|
use std::iter::once;
|
|
|
|
|
|
fn serialize_polymatrix(vec: &mut Vec<u8>, a: &PolyMatrixRaw) {
|
|
@@ -14,24 +16,24 @@ fn serialize_vec_polymatrix(vec: &mut Vec<u8>, a: &Vec<PolyMatrixRaw>) {
|
|
|
}
|
|
|
|
|
|
pub struct PublicParameters<'a> {
|
|
|
- v_packing: Vec<PolyMatrixNTT<'a>>, // Ws
|
|
|
+ v_packing: Vec<PolyMatrixNTT<'a>>, // Ws
|
|
|
v_expansion_left: Option<Vec<PolyMatrixNTT<'a>>>,
|
|
|
v_expansion_right: Option<Vec<PolyMatrixNTT<'a>>>,
|
|
|
- v_conversion: Option<Vec<PolyMatrixNTT<'a>>>, // V
|
|
|
+ v_conversion: Option<Vec<PolyMatrixNTT<'a>>>, // V
|
|
|
}
|
|
|
|
|
|
impl<'a> PublicParameters<'a> {
|
|
|
pub fn init(params: &'a Params) -> Self {
|
|
|
if params.expand_queries {
|
|
|
- PublicParameters {
|
|
|
- v_packing: Vec::new(),
|
|
|
- v_expansion_left: Some(Vec::new()),
|
|
|
- v_expansion_right: Some(Vec::new()),
|
|
|
- v_conversion: Some(Vec::new())
|
|
|
+ PublicParameters {
|
|
|
+ v_packing: Vec::new(),
|
|
|
+ v_expansion_left: Some(Vec::new()),
|
|
|
+ v_expansion_right: Some(Vec::new()),
|
|
|
+ v_conversion: Some(Vec::new()),
|
|
|
}
|
|
|
} else {
|
|
|
- PublicParameters {
|
|
|
- v_packing: Vec::new(),
|
|
|
+ PublicParameters {
|
|
|
+ v_packing: Vec::new(),
|
|
|
v_expansion_left: None,
|
|
|
v_expansion_right: None,
|
|
|
v_conversion: None,
|
|
@@ -43,7 +45,9 @@ impl<'a> PublicParameters<'a> {
|
|
|
Some(v.iter().map(from_ntt_alloc).collect())
|
|
|
}
|
|
|
|
|
|
- fn from_ntt_alloc_opt_vec(v: &Option<Vec<PolyMatrixNTT<'a>>>) -> Option<Vec<PolyMatrixRaw<'a>>> {
|
|
|
+ fn from_ntt_alloc_opt_vec(
|
|
|
+ v: &Option<Vec<PolyMatrixNTT<'a>>>,
|
|
|
+ ) -> Option<Vec<PolyMatrixRaw<'a>>> {
|
|
|
Some(v.as_ref()?.iter().map(from_ntt_alloc).collect())
|
|
|
}
|
|
|
|
|
@@ -75,7 +79,11 @@ pub struct Query<'a> {
|
|
|
|
|
|
impl<'a> Query<'a> {
|
|
|
pub fn empty() -> Self {
|
|
|
- Query { ct: None, v_ct: None, v_buf: None }
|
|
|
+ Query {
|
|
|
+ ct: None,
|
|
|
+ v_ct: None,
|
|
|
+ v_buf: None,
|
|
|
+ }
|
|
|
}
|
|
|
|
|
|
pub fn serialize(&self) -> Vec<u8> {
|
|
@@ -86,7 +94,7 @@ impl<'a> Query<'a> {
|
|
|
}
|
|
|
if self.v_buf.is_some() {
|
|
|
let v_buf = self.v_buf.as_ref().unwrap();
|
|
|
- data.extend(v_buf.iter().map(|x| { x.to_ne_bytes() }).flatten());
|
|
|
+ data.extend(v_buf.iter().map(|x| x.to_ne_bytes()).flatten());
|
|
|
}
|
|
|
if self.v_ct.is_some() {
|
|
|
let v_ct = self.v_ct.as_ref().unwrap();
|
|
@@ -109,7 +117,7 @@ pub struct Client<'a> {
|
|
|
stop_round: usize,
|
|
|
}
|
|
|
|
|
|
-fn matrix_with_identity<'a> (p: &PolyMatrixRaw<'a>) -> PolyMatrixRaw<'a> {
|
|
|
+fn matrix_with_identity<'a>(p: &PolyMatrixRaw<'a>) -> PolyMatrixRaw<'a> {
|
|
|
assert_eq!(p.cols, 1);
|
|
|
let mut r = PolyMatrixRaw::zero(p.params, p.rows, p.rows + 1);
|
|
|
r.copy_into(p, 0, 0);
|
|
@@ -119,8 +127,8 @@ fn matrix_with_identity<'a> (p: &PolyMatrixRaw<'a>) -> PolyMatrixRaw<'a> {
|
|
|
|
|
|
fn params_with_moduli(params: &Params, moduli: &Vec<u64>) -> Params {
|
|
|
Params::init(
|
|
|
- params.poly_len,
|
|
|
- moduli,
|
|
|
+ params.poly_len,
|
|
|
+ moduli,
|
|
|
params.noise_width,
|
|
|
params.n,
|
|
|
params.pt_modulus,
|
|
@@ -206,14 +214,18 @@ impl<'a> Client<'a> {
|
|
|
let res = &(p.ntt()) + &(ag.pad_top(1));
|
|
|
res
|
|
|
}
|
|
|
-
|
|
|
+
|
|
|
fn encrypt_matrix_reg(&mut self, a: &PolyMatrixNTT<'a>) -> PolyMatrixNTT<'a> {
|
|
|
let m = a.cols;
|
|
|
let p = self.get_fresh_reg_public_key(m);
|
|
|
&p + &a.pad_top(1)
|
|
|
}
|
|
|
|
|
|
- fn generate_expansion_params(&mut self, num_exp: usize, m_exp: usize) -> Vec<PolyMatrixNTT<'a>> {
|
|
|
+ fn generate_expansion_params(
|
|
|
+ &mut self,
|
|
|
+ num_exp: usize,
|
|
|
+ m_exp: usize,
|
|
|
+ ) -> Vec<PolyMatrixNTT<'a>> {
|
|
|
let params = self.params;
|
|
|
let g_exp = build_gadget(params, 1, m_exp);
|
|
|
let g_exp_ntt = g_exp.ntt();
|
|
@@ -239,7 +251,7 @@ impl<'a> Client<'a> {
|
|
|
let m_conv = params.m_conv();
|
|
|
|
|
|
let mut pp = PublicParameters::init(params);
|
|
|
-
|
|
|
+
|
|
|
// Params for packing
|
|
|
let gadget_conv = build_gadget(params, 1, m_conv);
|
|
|
let gadget_conv_ntt = to_ntt_alloc(&gadget_conv);
|
|
@@ -253,16 +265,21 @@ impl<'a> Client<'a> {
|
|
|
|
|
|
if params.expand_queries {
|
|
|
// Params for expansion
|
|
|
-
|
|
|
+
|
|
|
pp.v_expansion_left = Some(self.generate_expansion_params(self.g, params.t_exp_left));
|
|
|
- pp.v_expansion_right = Some(self.generate_expansion_params(self.stop_round + 1, params.t_exp_right));
|
|
|
+ pp.v_expansion_right =
|
|
|
+ Some(self.generate_expansion_params(self.stop_round + 1, params.t_exp_right));
|
|
|
|
|
|
// Params for converison
|
|
|
let g_conv = build_gadget(params, 2, 2 * m_conv);
|
|
|
let sk_reg_ntt = self.sk_reg.ntt();
|
|
|
let sk_reg_squared_ntt = &sk_reg_ntt * &sk_reg_ntt;
|
|
|
- pp.v_conversion = Some(Vec::from_iter(once(PolyMatrixNTT::zero(params, 2, 2 * m_conv))));
|
|
|
- for i in 0..2*m_conv {
|
|
|
+ pp.v_conversion = Some(Vec::from_iter(once(PolyMatrixNTT::zero(
|
|
|
+ params,
|
|
|
+ 2,
|
|
|
+ 2 * m_conv,
|
|
|
+ ))));
|
|
|
+ for i in 0..2 * m_conv {
|
|
|
let sigma;
|
|
|
if i % 2 == 0 {
|
|
|
let val = g_conv.get_poly(0, i)[0];
|
|
@@ -288,7 +305,7 @@ impl<'a> Client<'a> {
|
|
|
assert_eq!(crt_count, 2);
|
|
|
assert!(log2(params.moduli[0]) <= 32);
|
|
|
|
|
|
- let num_reg_expanded = 1<<params.db_dim_1;
|
|
|
+ let num_reg_expanded = 1 << params.db_dim_1;
|
|
|
let ct_rows = v_reg[0].rows;
|
|
|
let ct_cols = v_reg[0].cols;
|
|
|
|
|
@@ -297,13 +314,14 @@ impl<'a> Client<'a> {
|
|
|
|
|
|
for j in 0..num_reg_expanded {
|
|
|
for r in 0..ct_rows {
|
|
|
- for m in 0.. ct_cols {
|
|
|
+ for m in 0..ct_cols {
|
|
|
for z in 0..params.poly_len {
|
|
|
- let idx_a_in = r * (ct_cols*crt_count*poly_len) + m * (crt_count*poly_len);
|
|
|
- let idx_a_out = z * (num_reg_expanded*ct_cols*ct_rows)
|
|
|
- + j * (ct_cols*ct_rows)
|
|
|
- + m * (ct_rows)
|
|
|
- + r;
|
|
|
+ let idx_a_in =
|
|
|
+ r * (ct_cols * crt_count * poly_len) + m * (crt_count * poly_len);
|
|
|
+ let idx_a_out = z * (num_reg_expanded * ct_cols * ct_rows)
|
|
|
+ + j * (ct_cols * ct_rows)
|
|
|
+ + m * (ct_rows)
|
|
|
+ + r;
|
|
|
let val1 = v_reg[j].data[idx_a_in + z] % params.moduli[0];
|
|
|
let val2 = v_reg[j].data[idx_a_in + params.poly_len + z] % params.moduli[1];
|
|
|
|
|
@@ -317,8 +335,8 @@ impl<'a> Client<'a> {
|
|
|
pub fn generate_query(&mut self, idx_target: usize) -> Query<'a> {
|
|
|
let params = self.params;
|
|
|
let further_dims = params.db_dim_2;
|
|
|
- let idx_dim0= idx_target / (1 << further_dims);
|
|
|
- let idx_further = idx_target % (1 << further_dims);
|
|
|
+ let idx_dim0 = idx_target / (1 << further_dims);
|
|
|
+ let idx_further = idx_target % (1 << further_dims);
|
|
|
let scale_k = params.modulus / params.pt_modulus;
|
|
|
let bits_per = get_bits_per(params, params.t_gsw);
|
|
|
|
|
@@ -326,24 +344,28 @@ impl<'a> Client<'a> {
|
|
|
if params.expand_queries {
|
|
|
// pack query into single ciphertext
|
|
|
let mut sigma = PolyMatrixRaw::zero(params, 1, 1);
|
|
|
- sigma.data[2*idx_dim0] = scale_k;
|
|
|
+ sigma.data[2 * idx_dim0] = scale_k;
|
|
|
for i in 0..further_dims as u64 {
|
|
|
let bit: u64 = ((idx_further as u64) & (1 << i)) >> i;
|
|
|
for j in 0..params.t_gsw {
|
|
|
let val = (1u64 << (bits_per * j)) * bit;
|
|
|
let idx = (i as usize) * params.t_gsw + (j as usize);
|
|
|
- sigma.data[2*idx + 1] = val;
|
|
|
+ sigma.data[2 * idx + 1] = val;
|
|
|
}
|
|
|
}
|
|
|
let inv_2_g_first = invert_uint_mod(1 << self.g, params.modulus).unwrap();
|
|
|
- let inv_2_g_rest = invert_uint_mod(1 << (self.stop_round+1), params.modulus).unwrap();
|
|
|
+ let inv_2_g_rest = invert_uint_mod(1 << (self.stop_round + 1), params.modulus).unwrap();
|
|
|
|
|
|
- for i in 0..params.poly_len/2 {
|
|
|
- sigma.data[2*i] = multiply_uint_mod(sigma.data[2*i], inv_2_g_first, params.modulus);
|
|
|
- sigma.data[2*i+1] = multiply_uint_mod(sigma.data[2*i+1], inv_2_g_rest, params.modulus);
|
|
|
+ for i in 0..params.poly_len / 2 {
|
|
|
+ sigma.data[2 * i] =
|
|
|
+ multiply_uint_mod(sigma.data[2 * i], inv_2_g_first, params.modulus);
|
|
|
+ sigma.data[2 * i + 1] =
|
|
|
+ multiply_uint_mod(sigma.data[2 * i + 1], inv_2_g_rest, params.modulus);
|
|
|
}
|
|
|
|
|
|
- query.ct = Some(from_ntt_alloc(&self.encrypt_matrix_reg(&to_ntt_alloc(&sigma))));
|
|
|
+ query.ct = Some(from_ntt_alloc(
|
|
|
+ &self.encrypt_matrix_reg(&to_ntt_alloc(&sigma)),
|
|
|
+ ));
|
|
|
} else {
|
|
|
let num_expanded = 1 << params.db_dim_1;
|
|
|
let mut sigma_v = Vec::<PolyMatrixNTT>::new();
|
|
@@ -370,20 +392,20 @@ impl<'a> Client<'a> {
|
|
|
let sigma = PolyMatrixRaw::single_value(¶ms, value);
|
|
|
let sigma_ntt = to_ntt_alloc(&sigma);
|
|
|
let ct = &self.encrypt_matrix_reg(&sigma_ntt);
|
|
|
- ct_gsw.copy_into(ct, 0, 2*j + 1);
|
|
|
+ ct_gsw.copy_into(ct, 0, 2 * j + 1);
|
|
|
let prod = &to_ntt_alloc(&self.sk_reg) * &sigma_ntt;
|
|
|
let ct = &self.encrypt_matrix_reg(&prod);
|
|
|
- ct_gsw.copy_into(ct, 0, 2*j);
|
|
|
+ ct_gsw.copy_into(ct, 0, 2 * j);
|
|
|
}
|
|
|
sigma_v.push(ct_gsw);
|
|
|
}
|
|
|
|
|
|
query.v_buf = Some(reg_cts_buf);
|
|
|
- query.v_ct = Some(sigma_v.iter().map(|x| { from_ntt_alloc(x) }).collect());
|
|
|
+ query.v_ct = Some(sigma_v.iter().map(|x| from_ntt_alloc(x)).collect());
|
|
|
}
|
|
|
query
|
|
|
}
|
|
|
-
|
|
|
+
|
|
|
pub fn decode_response(&self, data: &[u8]) -> Vec<u8> {
|
|
|
/*
|
|
|
0. NTT over q2 the secret key
|
|
@@ -438,11 +460,11 @@ impl<'a> Client<'a> {
|
|
|
let p_i128 = p as i128;
|
|
|
for i in 0..params.n * params.n * params.poly_len {
|
|
|
let mut val_first = sk_prod.data[i] as i64;
|
|
|
- if val_first >= q2_i64/2 {
|
|
|
+ if val_first >= q2_i64 / 2 {
|
|
|
val_first -= q2_i64;
|
|
|
}
|
|
|
let mut val_rest = rest_rows.data[i] as i64;
|
|
|
- if val_rest >= q1_i64/2 {
|
|
|
+ if val_rest >= q1_i64 / 2 {
|
|
|
val_rest -= q1_i64;
|
|
|
}
|
|
|
|
|
@@ -453,8 +475,8 @@ impl<'a> Client<'a> {
|
|
|
|
|
|
// divide r by q2, rounding
|
|
|
let sign: i64 = if r >= 0 { 1 } else { -1 };
|
|
|
- let mut res = ((r + sign*(denom/2)) as i128) / (denom as i128);
|
|
|
- res = (res + (denom as i128/p_i128)*(p_i128) + 2*(p_i128)) % (p_i128);
|
|
|
+ let mut res = ((r + sign * (denom / 2)) as i128) / (denom as i128);
|
|
|
+ res = (res + (denom as i128 / p_i128) * (p_i128) + 2 * (p_i128)) % (p_i128);
|
|
|
let idx = instance * params.n * params.n * params.poly_len + i;
|
|
|
result.data[idx] = res as u64;
|
|
|
}
|
|
@@ -468,4 +490,4 @@ impl<'a> Client<'a> {
|
|
|
let modp_words_per_chunk = f64::ceil((bytes_per_chunk * 8) as f64 / logp as f64) as usize;
|
|
|
result.to_vec(p_bits as usize, modp_words_per_chunk)
|
|
|
}
|
|
|
-}
|
|
|
+}
|