123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347 |
- use crate::{arith::*, params::*, poly::*};
- use rand::{prelude::SmallRng, thread_rng, Rng, SeedableRng};
- use serde_json::Value;
/// Parameter configuration written as a Python-style dict (single-quoted keys):
/// `n = 2`, `nu_1 = 9`, `nu_2 = 6`, `p = 256`, `q2_bits = 20`, `t_gsw = 8`, `t_conv = 4`,
/// `t_exp_left = 8`, `t_exp_right = 56`, one instance, `db_item_size = 8192`.
///
/// Because of the single quotes this is not valid JSON as written; replace `'` with `"`
/// before passing it to [`params_from_json`]. The `s_e` entry is not read by
/// [`params_from_json`].
pub const CFG_20_256: &str = r#"
{'n': 2,
'nu_1': 9,
'nu_2': 6,
'p': 256,
'q2_bits': 20,
's_e': 87.62938774292914,
't_gsw': 8,
't_conv': 4,
't_exp_left': 8,
't_exp_right': 56,
'instances': 1,
'db_item_size': 8192 }
"#;
/// Parameter configuration written as a Python-style dict (single-quoted keys):
/// `n = 2`, `nu_1 = 10`, `nu_2 = 6`, `p = 512`, `q2_bits = 21`, `t_gsw = 10`, `t_conv = 4`,
/// `t_exp_left = 16`, `t_exp_right = 56`, 11 instances, `db_item_size = 100000`.
///
/// Because of the single quotes this is not valid JSON as written; replace `'` with `"`
/// before passing it to [`params_from_json`]. The `s_e` entry is not read by
/// [`params_from_json`].
pub const CFG_16_100000: &str = r#"
{'n': 2,
'nu_1': 10,
'nu_2': 6,
'p': 512,
'q2_bits': 21,
's_e': 85.83255142749422,
't_gsw': 10,
't_conv': 4,
't_exp_left': 16,
't_exp_right': 56,
'instances': 11,
'db_item_size': 100000 }
"#;
/// Flattens a multi-dimensional index into a single row-major offset.
///
/// The last coordinate varies fastest: for `indices = [a, b, c]` and
/// `lengths = [la, lb, lc]` the result is `a * lb * lc + b * lc + c`.
/// Only the first `indices.len()` entries of `lengths` are used.
///
/// # Panics
///
/// Panics if `lengths` is shorter than `indices`.
pub fn calc_index(indices: &[usize], lengths: &[usize]) -> usize {
    let dims = &lengths[..indices.len()];
    let mut stride = 1usize;
    let mut flat = 0usize;
    // Walk from the innermost (last) dimension outwards, growing the stride as we go.
    for (&coord, &extent) in indices.iter().zip(dims).rev() {
        flat += coord * stride;
        stride *= extent;
    }
    flat
}
- pub fn get_test_params() -> Params {
- Params::init(
- 2048,
- &vec![268369921u64, 249561089u64],
- 6.4,
- 2,
- 256,
- 20,
- 4,
- 8,
- 56,
- 8,
- true,
- 9,
- 6,
- 1,
- 2048,
- )
- }
- pub fn get_short_keygen_params() -> Params {
- Params::init(
- 2048,
- &vec![268369921u64, 249561089u64],
- 6.4,
- 2,
- 256,
- 20,
- 4,
- 4,
- 4,
- 4,
- true,
- 9,
- 6,
- 1,
- 2048,
- )
- }
- pub fn get_expansion_testing_params() -> Params {
- let cfg = r#"
- {'n': 2,
- 'nu_1': 9,
- 'nu_2': 6,
- 'p': 256,
- 'q2_bits': 20,
- 't_gsw': 8,
- 't_conv': 4,
- 't_exp_left': 8,
- 't_exp_right': 56,
- 'instances': 1,
- 'db_item_size': 8192 }
- "#;
- params_from_json(&cfg.replace("'", "\""))
- }
- pub fn get_fast_expansion_testing_params() -> Params {
- let cfg = r#"
- {'n': 2,
- 'nu_1': 6,
- 'nu_2': 2,
- 'p': 256,
- 'q2_bits': 20,
- 't_gsw': 8,
- 't_conv': 4,
- 't_exp_left': 8,
- 't_exp_right': 8,
- 'instances': 1,
- 'db_item_size': 8192 }
- "#;
- params_from_json(&cfg.replace("'", "\""))
- }
- pub fn get_no_expansion_testing_params() -> Params {
- let cfg = r#"
- {'direct_upload': 1,
- 'n': 5,
- 'nu_1': 6,
- 'nu_2': 3,
- 'p': 65536,
- 'q2_bits': 27,
- 't_gsw': 3,
- 't_conv': 56,
- 't_exp_left': 56,
- 't_exp_right': 56}
- "#;
- params_from_json(&cfg.replace("'", "\""))
- }
- pub fn get_seed() -> [u8; 32] {
- thread_rng().gen::<[u8; 32]>()
- }
- pub fn get_seeded_rng() -> SmallRng {
- SmallRng::from_seed(get_seed())
- }
/// Returns a fixed, reproducible 32-byte seed: the bytes `1..=8` repeated four times.
pub fn get_static_seed() -> [u8; 32] {
    let mut seed = [0u8; 32];
    for (slot, value) in seed.iter_mut().zip((1u8..=8).cycle()) {
        *slot = value;
    }
    seed
}
- pub fn get_static_seeded_rng() -> SmallRng {
- SmallRng::from_seed(get_static_seed())
- }
- pub const fn get_empty_params() -> Params {
- Params {
- poly_len: 0,
- poly_len_log2: 0,
- ntt_tables: Vec::new(),
- scratch: Vec::new(),
- crt_count: 0,
- barrett_cr_0_modulus: 0,
- barrett_cr_1_modulus: 0,
- barrett_cr_0: [0u64; MAX_MODULI],
- barrett_cr_1: [0u64; MAX_MODULI],
- mod0_inv_mod1: 0,
- mod1_inv_mod0: 0,
- moduli: [0u64; MAX_MODULI],
- modulus: 0,
- modulus_log2: 0,
- noise_width: 0f64,
- n: 0,
- pt_modulus: 0,
- q2_bits: 0,
- t_conv: 0,
- t_exp_left: 0,
- t_exp_right: 0,
- t_gsw: 0,
- expand_queries: false,
- db_dim_1: 0,
- db_dim_2: 0,
- instances: 0,
- db_item_size: 0,
- }
- }
- pub fn params_from_json(cfg: &str) -> Params {
- let v: Value = serde_json::from_str(cfg).unwrap();
- let n = v["n"].as_u64().unwrap() as usize;
- let db_dim_1 = v["nu_1"].as_u64().unwrap() as usize;
- let db_dim_2 = v["nu_2"].as_u64().unwrap() as usize;
- let instances = v["instances"].as_u64().unwrap_or(1) as usize;
- let db_item_size = v["db_item_size"].as_u64().unwrap_or(1) as usize;
- let p = v["p"].as_u64().unwrap();
- let q2_bits = v["q2_bits"].as_u64().unwrap();
- let t_gsw = v["t_gsw"].as_u64().unwrap() as usize;
- let t_conv = v["t_conv"].as_u64().unwrap() as usize;
- let t_exp_left = v["t_exp_left"].as_u64().unwrap() as usize;
- let t_exp_right = v["t_exp_right"].as_u64().unwrap() as usize;
- let do_expansion = v.get("direct_upload").is_none();
- Params::init(
- 2048,
- &vec![268369921u64, 249561089u64],
- 6.4,
- n,
- p,
- q2_bits,
- t_conv,
- t_exp_left,
- t_exp_right,
- t_gsw,
- do_expansion,
- db_dim_1,
- db_dim_2,
- instances,
- db_item_size,
- )
- }
/// Reads a `num_bits`-wide unsigned field that starts at bit `bit_offs` of `data`.
///
/// `data` is treated as a sequence of little-endian 64-bit words, so bit `k` of the stream
/// is bit `k % 64` of word `k / 64`. A field may straddle two consecutive words.
/// `num_bits` may be anywhere from 0 to 64 inclusive.
///
/// Using little-endian words (rather than host order) keeps the byte layout identical on
/// every platform. It also makes a straddling field contiguous in the 16-byte read on
/// big-endian hosts as well.
///
/// # Panics
///
/// Panics if `data` does not contain the whole 8-byte word holding the field's first bit.
/// When the field straddles a word boundary, the following word must be present too
/// (16 bytes in total). In debug builds it also panics if `num_bits > 64`.
pub fn read_arbitrary_bits(data: &[u8], bit_offs: usize, num_bits: usize) -> u64 {
    debug_assert!(num_bits <= 64, "num_bits must be at most 64, got {}", num_bits);
    let word_off = bit_offs / 64;
    let bit_off_within_word = bit_offs % 64;
    let idx = word_off * 8;
    // `1u64 << 64` overflows (a panic in debug builds), so the full-width mask is a special case.
    let mask = if num_bits >= 64 {
        u64::MAX
    } else {
        (1u64 << num_bits) - 1
    };
    if bit_off_within_word + num_bits <= 64 {
        let word = u64::from_le_bytes(data[idx..idx + 8].try_into().unwrap());
        (word >> bit_off_within_word) & mask
    } else {
        // Read as little-endian, the 16 bytes equal `word0 | (word1 << 64)`, so the field
        // is contiguous in `pair`.
        let pair = u128::from_le_bytes(data[idx..idx + 16].try_into().unwrap());
        ((pair >> bit_off_within_word) as u64) & mask
    }
}
/// Writes the low `num_bits` bits of `val` into `data`, starting at bit `bit_offs`.
///
/// `data` is treated as a sequence of little-endian 64-bit words, so bit `k` of the stream
/// is bit `k % 64` of word `k / 64`. A field may straddle two consecutive words.
/// Bits of `val` above `num_bits` are discarded. Bits of `data` outside the target field
/// are left unchanged. `num_bits` may be anywhere from 0 to 64 inclusive.
///
/// Using little-endian words (rather than host order) keeps the byte layout identical on
/// every platform. It also stops a straddling write from landing in the wrong word on
/// big-endian hosts.
///
/// # Panics
///
/// Panics if `data` does not contain the whole 8-byte word holding the field's first bit.
/// When the field straddles a word boundary, the following word must be present too
/// (16 bytes in total). In debug builds it also panics if `num_bits > 64`.
pub fn write_arbitrary_bits(data: &mut [u8], mut val: u64, bit_offs: usize, num_bits: usize) {
    debug_assert!(num_bits <= 64, "num_bits must be at most 64, got {}", num_bits);
    let word_off = bit_offs / 64;
    let bit_off_within_word = bit_offs % 64;
    let idx = word_off * 8;
    // `1u64 << 64` overflows (a panic in debug builds), so the full-width mask is a special case.
    let mask = if num_bits >= 64 {
        u64::MAX
    } else {
        (1u64 << num_bits) - 1
    };
    val &= mask;
    if bit_off_within_word + num_bits <= 64 {
        let mut cur_val = u64::from_le_bytes(data[idx..idx + 8].try_into().unwrap());
        cur_val &= !(mask << bit_off_within_word);
        cur_val |= val << bit_off_within_word;
        data[idx..idx + 8].copy_from_slice(&cur_val.to_le_bytes());
    } else {
        // Read as little-endian, the 16 bytes equal `word0 | (word1 << 64)`, so the field
        // is contiguous in `cur_val`.
        let mut cur_val = u128::from_le_bytes(data[idx..idx + 16].try_into().unwrap());
        cur_val &= !((mask as u128) << bit_off_within_word);
        cur_val |= (val as u128) << bit_off_within_word;
        data[idx..idx + 16].copy_from_slice(&cur_val.to_le_bytes());
    }
}
- pub fn reorient_reg_ciphertexts(params: &Params, out: &mut [u64], v_reg: &Vec<PolyMatrixNTT>) {
- let poly_len = params.poly_len;
- let crt_count = params.crt_count;
- assert_eq!(crt_count, 2);
- assert!(log2(params.moduli[0]) <= 32);
- let num_reg_expanded = 1 << params.db_dim_1;
- let ct_rows = v_reg[0].rows;
- let ct_cols = v_reg[0].cols;
- assert_eq!(ct_rows, 2);
- assert_eq!(ct_cols, 1);
- for j in 0..num_reg_expanded {
- for r in 0..ct_rows {
- for m in 0..ct_cols {
- for z in 0..params.poly_len {
- let idx_a_in =
- r * (ct_cols * crt_count * poly_len) + m * (crt_count * poly_len);
- let idx_a_out = z * (num_reg_expanded * ct_cols * ct_rows)
- + j * (ct_cols * ct_rows)
- + m * (ct_rows)
- + r;
- let val1 = v_reg[j].data[idx_a_in + z] % params.moduli[0];
- let val2 = v_reg[j].data[idx_a_in + params.poly_len + z] % params.moduli[1];
- out[idx_a_out] = val1 | (val2 << 32);
- }
- }
- }
- }
- }
#[cfg(test)]
mod test {
    use super::*;

    /// Parsing a single-quoted config must give exactly the same `Params` as calling
    /// `Params::init` with the matching positional arguments.
    #[test]
    fn params_from_json_correct() {
        let dict_literal = r#"
{'n': 2,
'nu_1': 9,
'nu_2': 6,
'p': 256,
'q2_bits': 20,
's_e': 87.62938774292914,
't_gsw': 8,
't_conv': 4,
't_exp_left': 8,
't_exp_right': 56,
'instances': 1,
'db_item_size': 2048 }
"#;
        let parsed = params_from_json(&dict_literal.replace('\'', "\""));
        let expected = Params::init(
            2048,
            &vec![268369921u64, 249561089u64],
            6.4,
            2,
            256,
            20,
            4,
            8,
            56,
            8,
            true,
            9,
            6,
            1,
            2048,
        );
        assert_eq!(parsed, expected);
    }

    /// Packs a run of 9-bit values back to back, then checks that every one reads back intact.
    #[test]
    fn test_read_write_arbitrary_bits() {
        let len = 4096;
        let num_bits = 9;
        let mut data = vec![0u8; len];
        // Stop well short of the end so the 8- and 16-byte word accesses stay inside `data`.
        let count = len * 8 / num_bits - 64;
        let pattern = |i: usize| -> u64 { ((i * 7 + 13) % (1 << num_bits)) as u64 };

        for i in 0..count {
            write_arbitrary_bits(&mut data, pattern(i), i * num_bits, num_bits);
        }
        for i in 0..count {
            let got = read_arbitrary_bits(&data, i * num_bits, num_bits);
            assert_eq!(got, pattern(i));
        }
    }
}
|