util.rs 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394
  1. use crate::{arith::*, client::Seed, params::*, poly::*};
  2. use rand::{prelude::SmallRng, thread_rng, Rng, SeedableRng};
  3. use rand_chacha::ChaCha20Rng;
  4. use serde_json::Value;
  5. use std::fs;
  /// Parameter configuration with `p = 256`, `q2_bits = 20`, one instance and
  /// `db_item_size = 8192`.
  ///
  /// Keys are single-quoted, so this is not strict JSON: replace `'` with `"` before passing it
  /// to `params_from_json`.
  pub const CFG_20_256: &str = r#"
{'n': 2,
'nu_1': 9,
'nu_2': 6,
'p': 256,
'q2_bits': 20,
's_e': 87.62938774292914,
't_gsw': 8,
't_conv': 4,
't_exp_left': 8,
't_exp_right': 56,
'instances': 1,
'db_item_size': 8192 }
"#;
  /// Parameter configuration with `p = 512`, `q2_bits = 21`, 11 instances and
  /// `db_item_size = 100000`.
  ///
  /// Keys are single-quoted, so this is not strict JSON: replace `'` with `"` before passing it
  /// to `params_from_json`.
  pub const CFG_16_100000: &str = r#"
{'n': 2,
'nu_1': 10,
'nu_2': 6,
'p': 512,
'q2_bits': 21,
's_e': 85.83255142749422,
't_gsw': 10,
't_conv': 4,
't_exp_left': 16,
't_exp_right': 56,
'instances': 11,
'db_item_size': 100000 }
"#;
  /// Flattens a multi-dimensional index into a row-major offset.
  ///
  /// `indices[i]` is the coordinate along dimension `i` and `lengths[i]` is that dimension's
  /// size; the last dimension varies fastest. Only the first `indices.len()` entries of
  /// `lengths` are used.
  ///
  /// # Panics
  /// Panics if `lengths` is shorter than `indices`.
  pub fn calc_index(indices: &[usize], lengths: &[usize]) -> usize {
      let dims = &lengths[..indices.len()];
      let (flat, _stride) = indices
          .iter()
          .zip(dims)
          .rev()
          .fold((0usize, 1usize), |(flat, stride), (&coord, &len)| {
              (flat + coord * stride, stride * len)
          });
      flat
  }
  43. pub fn get_test_params() -> Params {
  44. Params::init(
  45. 2048,
  46. &vec![268369921u64, 249561089u64],
  47. 6.4,
  48. 2,
  49. 256,
  50. 20,
  51. 4,
  52. 8,
  53. 56,
  54. 8,
  55. true,
  56. 9,
  57. 6,
  58. 1,
  59. 2048,
  60. )
  61. }
  62. pub fn get_short_keygen_params() -> Params {
  63. Params::init(
  64. 2048,
  65. &vec![268369921u64, 249561089u64],
  66. 6.4,
  67. 2,
  68. 256,
  69. 20,
  70. 4,
  71. 4,
  72. 4,
  73. 4,
  74. true,
  75. 9,
  76. 6,
  77. 1,
  78. 2048,
  79. )
  80. }
  81. pub fn get_expansion_testing_params() -> Params {
  82. let cfg = r#"
  83. {'n': 2,
  84. 'nu_1': 9,
  85. 'nu_2': 6,
  86. 'p': 256,
  87. 'q2_bits': 20,
  88. 't_gsw': 8,
  89. 't_conv': 4,
  90. 't_exp_left': 8,
  91. 't_exp_right': 56,
  92. 'instances': 1,
  93. 'db_item_size': 8192 }
  94. "#;
  95. params_from_json(&cfg.replace("'", "\""))
  96. }
  97. pub fn get_fast_expansion_testing_params() -> Params {
  98. let cfg = r#"
  99. {'n': 2,
  100. 'nu_1': 6,
  101. 'nu_2': 2,
  102. 'p': 256,
  103. 'q2_bits': 20,
  104. 't_gsw': 8,
  105. 't_conv': 4,
  106. 't_exp_left': 8,
  107. 't_exp_right': 8,
  108. 'instances': 1,
  109. 'db_item_size': 8192 }
  110. "#;
  111. params_from_json(&cfg.replace("'", "\""))
  112. }
  113. pub fn get_no_expansion_testing_params() -> Params {
  114. let cfg = r#"
  115. {'direct_upload': 1,
  116. 'n': 5,
  117. 'nu_1': 6,
  118. 'nu_2': 3,
  119. 'p': 65536,
  120. 'q2_bits': 27,
  121. 't_gsw': 3,
  122. 't_conv': 56,
  123. 't_exp_left': 56,
  124. 't_exp_right': 56}
  125. "#;
  126. params_from_json(&cfg.replace("'", "\""))
  127. }
  128. pub fn get_seed() -> u64 {
  129. thread_rng().gen::<u64>()
  130. }
  131. pub fn get_seeded_rng() -> SmallRng {
  132. SmallRng::seed_from_u64(get_seed())
  133. }
  134. pub fn get_chacha_seed() -> Seed {
  135. thread_rng().gen::<[u8; 32]>()
  136. }
  137. pub fn get_chacha_rng() -> ChaCha20Rng {
  138. ChaCha20Rng::from_seed(get_chacha_seed())
  139. }
  /// Fixed `u64` seed for deterministic RNG construction.
  pub fn get_static_seed() -> u64 {
      0x1_2345_6789
  }
  143. pub fn get_chacha_static_seed() -> Seed {
  144. [
  145. 0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8, 0x9, 0xa, 0xb, 0xc, 0xd, 0xe, 0xf, 0x0, 0x1,
  146. 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8, 0x9, 0xa, 0xb, 0xc, 0xd, 0xe, 0xf,
  147. ]
  148. }
  149. pub fn get_static_seeded_rng() -> SmallRng {
  150. SmallRng::seed_from_u64(get_static_seed())
  151. }
  152. pub const fn get_empty_params() -> Params {
  153. Params {
  154. poly_len: 0,
  155. poly_len_log2: 0,
  156. ntt_tables: Vec::new(),
  157. scratch: Vec::new(),
  158. crt_count: 0,
  159. barrett_cr_0_modulus: 0,
  160. barrett_cr_1_modulus: 0,
  161. barrett_cr_0: [0u64; MAX_MODULI],
  162. barrett_cr_1: [0u64; MAX_MODULI],
  163. mod0_inv_mod1: 0,
  164. mod1_inv_mod0: 0,
  165. moduli: [0u64; MAX_MODULI],
  166. modulus: 0,
  167. modulus_log2: 0,
  168. noise_width: 0f64,
  169. n: 0,
  170. pt_modulus: 0,
  171. q2_bits: 0,
  172. t_conv: 0,
  173. t_exp_left: 0,
  174. t_exp_right: 0,
  175. t_gsw: 0,
  176. expand_queries: false,
  177. db_dim_1: 0,
  178. db_dim_2: 0,
  179. instances: 0,
  180. db_item_size: 0,
  181. }
  182. }
  183. pub fn params_from_json(cfg: &str) -> Params {
  184. let v: Value = serde_json::from_str(cfg).unwrap();
  185. params_from_json_obj(&v)
  186. }
  187. pub fn params_from_json_obj(v: &Value) -> Params {
  188. let n = v["n"].as_u64().unwrap() as usize;
  189. let db_dim_1 = v["nu_1"].as_u64().unwrap() as usize;
  190. let db_dim_2 = v["nu_2"].as_u64().unwrap() as usize;
  191. let instances = v["instances"].as_u64().unwrap_or(1) as usize;
  192. let p = v["p"].as_u64().unwrap();
  193. let q2_bits = u64::max(v["q2_bits"].as_u64().unwrap(), MIN_Q2_BITS);
  194. let t_gsw = v["t_gsw"].as_u64().unwrap() as usize;
  195. let t_conv = v["t_conv"].as_u64().unwrap() as usize;
  196. let t_exp_left = v["t_exp_left"].as_u64().unwrap() as usize;
  197. let t_exp_right = v["t_exp_right"].as_u64().unwrap() as usize;
  198. let do_expansion = v.get("direct_upload").is_none();
  199. let mut db_item_size = v["db_item_size"].as_u64().unwrap_or(0) as usize;
  200. if db_item_size == 0 {
  201. db_item_size = instances * n * n;
  202. db_item_size = db_item_size * 2048 * log2_ceil(p) as usize / 8;
  203. }
  204. Params::init(
  205. 2048,
  206. &vec![268369921u64, 249561089u64],
  207. 6.4,
  208. n,
  209. p,
  210. q2_bits,
  211. t_conv,
  212. t_exp_left,
  213. t_exp_right,
  214. t_gsw,
  215. do_expansion,
  216. db_dim_1,
  217. db_dim_2,
  218. instances,
  219. db_item_size,
  220. )
  221. }
  222. static ALL_PARAMS_STORE_FNAME: &str = "../params_store.json";
  223. pub fn get_params_from_store(target_num_log2: usize, item_size: usize) -> Params {
  224. let params_store_str = fs::read_to_string(ALL_PARAMS_STORE_FNAME).unwrap();
  225. let v: Value = serde_json::from_str(&params_store_str).unwrap();
  226. let nearest_target_num = target_num_log2;
  227. let nearest_item_size = 1 << usize::max(log2_ceil_usize(item_size), 8);
  228. println!(
  229. "Starting with parameters for 2^{} x {} bytes...",
  230. nearest_target_num, nearest_item_size
  231. );
  232. let target = v
  233. .as_array()
  234. .unwrap()
  235. .iter()
  236. .map(|x| x.as_object().unwrap())
  237. .filter(|x| x.get("target_num").unwrap().as_u64().unwrap() == (nearest_target_num as u64))
  238. .filter(|x| x.get("item_size").unwrap().as_u64().unwrap() == (nearest_item_size as u64))
  239. .map(|x| x.get("params").unwrap())
  240. .next()
  241. .unwrap();
  242. params_from_json_obj(target)
  243. }
  /// Returns `(start, end, shift)` for the bit run `bit_offs .. bit_offs + num_bits`.
  ///
  /// `start..end` is the smallest byte range covering the run, and `shift` is the position of
  /// `bit_offs` within byte `start`. For `num_bits <= 64` the range spans at most 9 bytes.
  fn arbitrary_bits_span(bit_offs: usize, num_bits: usize) -> (usize, usize, usize) {
      let start = bit_offs / 8;
      let end = (bit_offs + num_bits + 7) / 8;
      (start, end, bit_offs % 8)
  }

  /// Reads `num_bits` bits starting at bit offset `bit_offs` of `data`, as the low bits of a `u64`.
  ///
  /// Bit `k` of the stream is bit `k % 8` of byte `k / 8` (little-endian bit order). On
  /// little-endian targets this is the layout of the earlier native-endian word-based code. Only
  /// the bytes that actually contain the requested bits are touched. `num_bits == 0` returns 0.
  ///
  /// # Panics
  /// Panics if `num_bits > 64` or if the requested bits extend past the end of `data`.
  pub fn read_arbitrary_bits(data: &[u8], bit_offs: usize, num_bits: usize) -> u64 {
      assert!(num_bits <= 64, "cannot read {} bits into a u64", num_bits);
      if num_bits == 0 {
          return 0;
      }
      let (start, end, shift) = arbitrary_bits_span(bit_offs, num_bits);
      let mut window = [0u8; 16];
      window[..end - start].copy_from_slice(&data[start..end]);
      let word = u128::from_le_bytes(window);
      // Mask in u128 so that `num_bits == 64` does not overflow the shift.
      ((word >> shift) & ((1u128 << num_bits) - 1)) as u64
  }

  /// Writes the low `num_bits` bits of `val` into `data`, starting at bit offset `bit_offs`.
  ///
  /// Bits outside the target run are preserved, and higher bits of `val` are discarded. The bit
  /// order is little-endian: bit `k` of the stream is bit `k % 8` of byte `k / 8`. `num_bits == 0`
  /// is a no-op.
  ///
  /// # Panics
  /// Panics if `num_bits > 64` or if the target bits extend past the end of `data`.
  pub fn write_arbitrary_bits(data: &mut [u8], val: u64, bit_offs: usize, num_bits: usize) {
      assert!(num_bits <= 64, "cannot write {} bits from a u64", num_bits);
      if num_bits == 0 {
          return;
      }
      let (start, end, shift) = arbitrary_bits_span(bit_offs, num_bits);
      let len = end - start;
      // Mask in u128 so that `num_bits == 64` does not overflow the shift.
      let mask = ((1u128 << num_bits) - 1) << shift;
      let mut window = [0u8; 16];
      window[..len].copy_from_slice(&data[start..end]);
      let word = u128::from_le_bytes(window);
      let word = (word & !mask) | ((u128::from(val) << shift) & mask);
      data[start..end].copy_from_slice(&word.to_le_bytes()[..len]);
  }
  276. pub fn reorient_reg_ciphertexts(params: &Params, out: &mut [u64], v_reg: &Vec<PolyMatrixNTT>) {
  277. let poly_len = params.poly_len;
  278. let crt_count = params.crt_count;
  279. assert_eq!(crt_count, 2);
  280. assert!(log2(params.moduli[0]) <= 32);
  281. let num_reg_expanded = 1 << params.db_dim_1;
  282. let ct_rows = v_reg[0].rows;
  283. let ct_cols = v_reg[0].cols;
  284. assert_eq!(ct_rows, 2);
  285. assert_eq!(ct_cols, 1);
  286. for j in 0..num_reg_expanded {
  287. for r in 0..ct_rows {
  288. for m in 0..ct_cols {
  289. for z in 0..params.poly_len {
  290. let idx_a_in =
  291. r * (ct_cols * crt_count * poly_len) + m * (crt_count * poly_len);
  292. let idx_a_out = z * (num_reg_expanded * ct_cols * ct_rows)
  293. + j * (ct_cols * ct_rows)
  294. + m * (ct_rows)
  295. + r;
  296. let val1 = v_reg[j].data[idx_a_in + z] % params.moduli[0];
  297. let val2 = v_reg[j].data[idx_a_in + params.poly_len + z] % params.moduli[1];
  298. out[idx_a_out] = val1 | (val2 << 32);
  299. }
  300. }
  301. }
  302. }
  303. }
  #[cfg(test)]
  mod test {
      use super::*;

      /// After the quotes are normalised, a single-quoted config must parse to the same `Params`
      /// as a direct `Params::init` call with the same values.
      #[test]
      fn params_from_json_correct() {
          let json = r#"
{'n': 2,
'nu_1': 9,
'nu_2': 6,
'p': 256,
'q2_bits': 20,
's_e': 87.62938774292914,
't_gsw': 8,
't_conv': 4,
't_exp_left': 8,
't_exp_right': 56,
'instances': 1,
'db_item_size': 2048 }
"#
          .replace('\'', "\"");
          let from_json = params_from_json(&json);

          let moduli = vec![268369921u64, 249561089u64];
          let direct = Params::init(
              2048, &moduli, 6.4, // poly_len, moduli, noise width
              2, 256, 20, // n, p, q2_bits
              4, 8, 56, 8, // t_conv, t_exp_left, t_exp_right, t_gsw
              true, // do_expansion
              9, 6, // nu_1, nu_2
              1, 2048, // instances, db_item_size
          );
          assert_eq!(from_json, direct);
      }

      /// Packs a run of 9-bit values back to back, then reads every one of them back.
      #[test]
      fn test_read_write_arbitrary_bits() {
          const LEN: usize = 4096;
          const NUM_BITS: usize = 9;
          // Stop 64 values short of the end of the buffer.
          let count = LEN * 8 / NUM_BITS - 64;
          let value_at = |i: usize| -> u64 { ((i * 7 + 13) % (1 << NUM_BITS)) as u64 };

          let mut buf = vec![0u8; LEN];
          for i in 0..count {
              write_arbitrary_bits(&mut buf, value_at(i), i * NUM_BITS, NUM_BITS);
          }
          for i in 0..count {
              assert_eq!(read_arbitrary_bits(&buf, i * NUM_BITS, NUM_BITS), value_at(i));
          }
      }
  }