util.rs 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372
  1. use crate::{arith::*, params::*, poly::*};
  2. use rand::{prelude::SmallRng, thread_rng, Rng, SeedableRng};
  3. use serde_json::Value;
  4. use std::fs;
  /// Parameter config with `n` = 2, `p` = 256, `q2_bits` = 20 and an 8192-byte item size.
  ///
  /// The text uses single quotes, so it is NOT valid JSON as written. The other inline
  /// configs in this file are converted with `.replace("'", "\"")` before
  /// `params_from_json`; presumably callers of this constant do the same (confirm at call sites).
  /// The `s_e` key is not read by `params_from_json_obj`.
  pub const CFG_20_256: &str = r#"
{'n': 2,
'nu_1': 9,
'nu_2': 6,
'p': 256,
'q2_bits': 20,
's_e': 87.62938774292914,
't_gsw': 8,
't_conv': 4,
't_exp_left': 8,
't_exp_right': 56,
'instances': 1,
'db_item_size': 8192 }
"#;
  /// Parameter config with `n` = 2, `p` = 512, 11 instances and a 100000-byte item size.
  ///
  /// Like the other inline configs in this file, the text is single-quoted and must be
  /// converted (e.g. `.replace("'", "\"")`) before it parses as JSON; presumably callers do
  /// this — confirm at call sites. The `s_e` key is not read by `params_from_json_obj`.
  pub const CFG_16_100000: &str = r#"
{'n': 2,
'nu_1': 10,
'nu_2': 6,
'p': 512,
'q2_bits': 21,
's_e': 85.83255142749422,
't_gsw': 10,
't_conv': 4,
't_exp_left': 16,
't_exp_right': 56,
'instances': 11,
'db_item_size': 100000 }
"#;
  /// Flattens a multi-dimensional index into a single row-major offset.
  ///
  /// `indices[k]` is the coordinate along dimension `k` and `lengths[k]` that dimension's
  /// extent; the last dimension varies fastest.
  ///
  /// # Panics
  /// Panics if `lengths` is shorter than `indices`.
  pub fn calc_index(indices: &[usize], lengths: &[usize]) -> usize {
      let (flat, _stride) = indices.iter().enumerate().rev().fold(
          (0usize, 1usize),
          |(flat, stride), (dim, &coord)| (flat + coord * stride, stride * lengths[dim]),
      );
      flat
  }
  42. pub fn get_test_params() -> Params {
  43. Params::init(
  44. 2048,
  45. &vec![268369921u64, 249561089u64],
  46. 6.4,
  47. 2,
  48. 256,
  49. 20,
  50. 4,
  51. 8,
  52. 56,
  53. 8,
  54. true,
  55. 9,
  56. 6,
  57. 1,
  58. 2048,
  59. )
  60. }
  61. pub fn get_short_keygen_params() -> Params {
  62. Params::init(
  63. 2048,
  64. &vec![268369921u64, 249561089u64],
  65. 6.4,
  66. 2,
  67. 256,
  68. 20,
  69. 4,
  70. 4,
  71. 4,
  72. 4,
  73. true,
  74. 9,
  75. 6,
  76. 1,
  77. 2048,
  78. )
  79. }
  80. pub fn get_expansion_testing_params() -> Params {
  81. let cfg = r#"
  82. {'n': 2,
  83. 'nu_1': 9,
  84. 'nu_2': 6,
  85. 'p': 256,
  86. 'q2_bits': 20,
  87. 't_gsw': 8,
  88. 't_conv': 4,
  89. 't_exp_left': 8,
  90. 't_exp_right': 56,
  91. 'instances': 1,
  92. 'db_item_size': 8192 }
  93. "#;
  94. params_from_json(&cfg.replace("'", "\""))
  95. }
  96. pub fn get_fast_expansion_testing_params() -> Params {
  97. let cfg = r#"
  98. {'n': 2,
  99. 'nu_1': 6,
  100. 'nu_2': 2,
  101. 'p': 256,
  102. 'q2_bits': 20,
  103. 't_gsw': 8,
  104. 't_conv': 4,
  105. 't_exp_left': 8,
  106. 't_exp_right': 8,
  107. 'instances': 1,
  108. 'db_item_size': 8192 }
  109. "#;
  110. params_from_json(&cfg.replace("'", "\""))
  111. }
  112. pub fn get_no_expansion_testing_params() -> Params {
  113. let cfg = r#"
  114. {'direct_upload': 1,
  115. 'n': 5,
  116. 'nu_1': 6,
  117. 'nu_2': 3,
  118. 'p': 65536,
  119. 'q2_bits': 27,
  120. 't_gsw': 3,
  121. 't_conv': 56,
  122. 't_exp_left': 56,
  123. 't_exp_right': 56}
  124. "#;
  125. params_from_json(&cfg.replace("'", "\""))
  126. }
  127. pub fn get_seed() -> u64 {
  128. thread_rng().gen::<u64>()
  129. }
  130. pub fn get_seeded_rng() -> SmallRng {
  131. SmallRng::seed_from_u64(get_seed())
  132. }
  /// Returns a fixed seed, so RNGs built from it produce the same stream on every run.
  pub fn get_static_seed() -> u64 {
      const STATIC_SEED: u64 = 0x1_2345_6789;
      STATIC_SEED
  }
  136. pub fn get_static_seeded_rng() -> SmallRng {
  137. SmallRng::seed_from_u64(get_static_seed())
  138. }
  139. pub const fn get_empty_params() -> Params {
  140. Params {
  141. poly_len: 0,
  142. poly_len_log2: 0,
  143. ntt_tables: Vec::new(),
  144. scratch: Vec::new(),
  145. crt_count: 0,
  146. barrett_cr_0_modulus: 0,
  147. barrett_cr_1_modulus: 0,
  148. barrett_cr_0: [0u64; MAX_MODULI],
  149. barrett_cr_1: [0u64; MAX_MODULI],
  150. mod0_inv_mod1: 0,
  151. mod1_inv_mod0: 0,
  152. moduli: [0u64; MAX_MODULI],
  153. modulus: 0,
  154. modulus_log2: 0,
  155. noise_width: 0f64,
  156. n: 0,
  157. pt_modulus: 0,
  158. q2_bits: 0,
  159. t_conv: 0,
  160. t_exp_left: 0,
  161. t_exp_right: 0,
  162. t_gsw: 0,
  163. expand_queries: false,
  164. db_dim_1: 0,
  165. db_dim_2: 0,
  166. instances: 0,
  167. db_item_size: 0,
  168. }
  169. }
  170. pub fn params_from_json(cfg: &str) -> Params {
  171. let v: Value = serde_json::from_str(cfg).unwrap();
  172. params_from_json_obj(&v)
  173. }
  174. pub fn params_from_json_obj(v: &Value) -> Params {
  175. let n = v["n"].as_u64().unwrap() as usize;
  176. let db_dim_1 = v["nu_1"].as_u64().unwrap() as usize;
  177. let db_dim_2 = v["nu_2"].as_u64().unwrap() as usize;
  178. let instances = v["instances"].as_u64().unwrap_or(1) as usize;
  179. let p = v["p"].as_u64().unwrap();
  180. let q2_bits = u64::max(v["q2_bits"].as_u64().unwrap(), MIN_Q2_BITS);
  181. let t_gsw = v["t_gsw"].as_u64().unwrap() as usize;
  182. let t_conv = v["t_conv"].as_u64().unwrap() as usize;
  183. let t_exp_left = v["t_exp_left"].as_u64().unwrap() as usize;
  184. let t_exp_right = v["t_exp_right"].as_u64().unwrap() as usize;
  185. let do_expansion = v.get("direct_upload").is_none();
  186. let mut db_item_size = v["db_item_size"].as_u64().unwrap_or(0) as usize;
  187. if db_item_size == 0 {
  188. db_item_size = instances * n * n;
  189. db_item_size = db_item_size * 2048 * log2_ceil(p) as usize / 8;
  190. }
  191. Params::init(
  192. 2048,
  193. &vec![268369921u64, 249561089u64],
  194. 6.4,
  195. n,
  196. p,
  197. q2_bits,
  198. t_conv,
  199. t_exp_left,
  200. t_exp_right,
  201. t_gsw,
  202. do_expansion,
  203. db_dim_1,
  204. db_dim_2,
  205. instances,
  206. db_item_size,
  207. )
  208. }
  209. static ALL_PARAMS_STORE_FNAME: &str = "../params_store.json";
  210. pub fn get_params_from_store(target_num_log2: usize, item_size: usize) -> Params {
  211. let params_store_str = fs::read_to_string(ALL_PARAMS_STORE_FNAME).unwrap();
  212. let v: Value = serde_json::from_str(&params_store_str).unwrap();
  213. let nearest_target_num = target_num_log2;
  214. let nearest_item_size = 1 << usize::max(log2_ceil_usize(item_size), 8);
  215. println!("{} x {}", nearest_target_num, nearest_item_size);
  216. let target = v.as_array().unwrap().iter()
  217. .map(|x| x.as_object().unwrap() )
  218. .filter(|x| x.get("target_num").unwrap().as_u64().unwrap() == (nearest_target_num as u64))
  219. .filter(|x| x.get("item_size").unwrap().as_u64().unwrap() == (nearest_item_size as u64))
  220. .map(|x| x.get("params").unwrap())
  221. .next().unwrap();
  222. params_from_json_obj(target)
  223. }
  /// Reads `num_bits` (at most 64) bits starting at bit `bit_offs` of `data`.
  ///
  /// `data` is treated as a sequence of native-endian 64-bit words. A field that crosses a
  /// word boundary is read through a native-endian 128-bit window.
  ///
  /// # Panics
  /// Panics if `data` does not contain the full 8-byte window (or 16-byte window, for a
  /// boundary-crossing field) starting at the containing word. This can happen even when
  /// the requested bits themselves lie within `data`.
  pub fn read_arbitrary_bits(data: &[u8], bit_offs: usize, num_bits: usize) -> u64 {
      debug_assert!(num_bits <= 64, "cannot read more than 64 bits at once");
      let word_off = bit_offs / 64;
      let bit_off_within_word = bit_offs % 64;
      let idx = word_off * 8;
      // `1u64 << 64` overflows (panics in debug, yields 0 in release), so a full 64-bit
      // read needs the all-ones mask from the `None` case of checked_shl.
      let mask = 1u64
          .checked_shl(num_bits as u32)
          .map_or(u64::MAX, |bit| bit - 1);
      if bit_off_within_word + num_bits <= 64 {
          let val = u64::from_ne_bytes(data[idx..idx + 8].try_into().unwrap());
          (val >> bit_off_within_word) & mask
      } else {
          let val = u128::from_ne_bytes(data[idx..idx + 16].try_into().unwrap());
          ((val >> bit_off_within_word) as u64) & mask
      }
  }
  /// Writes the low `num_bits` (at most 64) bits of `val` into `data` at bit `bit_offs`.
  ///
  /// Surrounding bits are left unchanged. The layout matches `read_arbitrary_bits`:
  /// native-endian 64-bit words, with a native-endian 128-bit window for fields that cross
  /// a word boundary. Higher bits of `val` are discarded.
  ///
  /// # Panics
  /// Panics if `data` does not contain the full 8-byte window (or 16-byte window, for a
  /// boundary-crossing field) starting at the containing word.
  pub fn write_arbitrary_bits(data: &mut [u8], mut val: u64, bit_offs: usize, num_bits: usize) {
      debug_assert!(num_bits <= 64, "cannot write more than 64 bits at once");
      let word_off = bit_offs / 64;
      let bit_off_within_word = bit_offs % 64;
      let idx = word_off * 8;
      // `1u64 << 64` overflows (panics in debug, yields 0 in release), so a full 64-bit
      // write needs the all-ones mask from the `None` case of checked_shl.
      let mask = 1u64
          .checked_shl(num_bits as u32)
          .map_or(u64::MAX, |bit| bit - 1);
      val &= mask;
      if bit_off_within_word + num_bits <= 64 {
          let mut cur_val = u64::from_ne_bytes(data[idx..idx + 8].try_into().unwrap());
          cur_val &= !(mask << bit_off_within_word);
          cur_val |= val << bit_off_within_word;
          data[idx..idx + 8].copy_from_slice(&cur_val.to_ne_bytes());
      } else {
          let mut cur_val = u128::from_ne_bytes(data[idx..idx + 16].try_into().unwrap());
          cur_val &= !(u128::from(mask) << bit_off_within_word);
          cur_val |= u128::from(val) << bit_off_within_word;
          data[idx..idx + 16].copy_from_slice(&cur_val.to_ne_bytes());
      }
  }
  256. pub fn reorient_reg_ciphertexts(params: &Params, out: &mut [u64], v_reg: &Vec<PolyMatrixNTT>) {
  257. let poly_len = params.poly_len;
  258. let crt_count = params.crt_count;
  259. assert_eq!(crt_count, 2);
  260. assert!(log2(params.moduli[0]) <= 32);
  261. let num_reg_expanded = 1 << params.db_dim_1;
  262. let ct_rows = v_reg[0].rows;
  263. let ct_cols = v_reg[0].cols;
  264. assert_eq!(ct_rows, 2);
  265. assert_eq!(ct_cols, 1);
  266. for j in 0..num_reg_expanded {
  267. for r in 0..ct_rows {
  268. for m in 0..ct_cols {
  269. for z in 0..params.poly_len {
  270. let idx_a_in =
  271. r * (ct_cols * crt_count * poly_len) + m * (crt_count * poly_len);
  272. let idx_a_out = z * (num_reg_expanded * ct_cols * ct_rows)
  273. + j * (ct_cols * ct_rows)
  274. + m * (ct_rows)
  275. + r;
  276. let val1 = v_reg[j].data[idx_a_in + z] % params.moduli[0];
  277. let val2 = v_reg[j].data[idx_a_in + params.poly_len + z] % params.moduli[1];
  278. out[idx_a_out] = val1 | (val2 << 32);
  279. }
  280. }
  281. }
  282. }
  283. }
  #[cfg(test)]
  mod test {
      use super::*;

      /// Parsing a single-quoted config must give the same `Params` as calling
      /// `Params::init` with the equivalent arguments.
      #[test]
      fn params_from_json_correct() {
          let json = r#"
{'n': 2,
'nu_1': 9,
'nu_2': 6,
'p': 256,
'q2_bits': 20,
's_e': 87.62938774292914,
't_gsw': 8,
't_conv': 4,
't_exp_left': 8,
't_exp_right': 56,
'instances': 1,
'db_item_size': 2048 }
"#
          .replace("'", "\"");
          let parsed = params_from_json(&json);
          let expected = Params::init(
              2048,
              &vec![268369921u64, 249561089u64],
              6.4,
              2,
              256,
              20,
              4,
              8,
              56,
              8,
              true,
              9,
              6,
              1,
              2048,
          );
          assert_eq!(parsed, expected);
      }

      /// Round-trips a dense run of 9-bit fields through write/read.
      #[test]
      fn test_read_write_arbitrary_bits() {
          const LEN: usize = 4096;
          const NUM_BITS: usize = 9;
          // Stop 64 fields short of the end: each access loads an 8- or 16-byte window
          // starting at the containing 64-bit word, which would run past the buffer.
          let count = LEN * 8 / NUM_BITS - 64;
          let pattern = |i: usize| -> u64 { ((i * 7 + 13) % (1 << NUM_BITS)) as u64 };
          let mut buf = vec![0u8; LEN];
          for (i, offs) in (0..count).map(|i| (i, i * NUM_BITS)) {
              write_arbitrary_bits(&mut buf, pattern(i), offs, NUM_BITS);
          }
          for i in 0..count {
              assert_eq!(read_arbitrary_bits(&buf, i * NUM_BITS, NUM_BITS), pattern(i));
          }
      }
  }