util.rs 9.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347
  1. use crate::{arith::*, params::*, poly::*};
  2. use rand::{prelude::SmallRng, thread_rng, Rng, SeedableRng};
  3. use serde_json::Value;
  /// Parameter set with `n = 2`, `p = 256`, `q2_bits = 20`, `instances = 1` and an
  /// 8192-byte `db_item_size`.
  ///
  /// Keys are written with single quotes, so this is not valid JSON as-is: replace
  /// `'` with `"` before parsing it (e.g. with `params_from_json`, which also ignores
  /// the extra `s_e` entry).
  pub const CFG_20_256: &str = r#"
{'n': 2,
'nu_1': 9,
'nu_2': 6,
'p': 256,
'q2_bits': 20,
's_e': 87.62938774292914,
't_gsw': 8,
't_conv': 4,
't_exp_left': 8,
't_exp_right': 56,
'instances': 1,
'db_item_size': 8192 }
"#;
  /// Parameter set with `n = 2`, `p = 512`, `q2_bits = 21`, `instances = 11` and a
  /// 100000-byte `db_item_size`.
  ///
  /// Keys are written with single quotes, so this is not valid JSON as-is: replace
  /// `'` with `"` before parsing it (e.g. with `params_from_json`, which also ignores
  /// the extra `s_e` entry).
  pub const CFG_16_100000: &str = r#"
{'n': 2,
'nu_1': 10,
'nu_2': 6,
'p': 512,
'q2_bits': 21,
's_e': 85.83255142749422,
't_gsw': 10,
't_conv': 4,
't_exp_left': 16,
't_exp_right': 56,
'instances': 11,
'db_item_size': 100000 }
"#;
  /// Flattens a multi-dimensional coordinate into a row-major linear offset.
  ///
  /// `indices[d]` is the coordinate along dimension `d` and `lengths[d]` is that
  /// dimension's size. The last dimension has stride 1, so it varies fastest. Only
  /// the first `indices.len()` entries of `lengths` are read.
  ///
  /// # Panics
  /// Panics if `lengths` is shorter than `indices`, or on arithmetic overflow when
  /// overflow checks are enabled.
  pub fn calc_index(indices: &[usize], lengths: &[usize]) -> usize {
      let (flat, _stride) = (0..indices.len())
          .rev()
          .fold((0usize, 1usize), |(flat, stride), dim| {
              (flat + indices[dim] * stride, stride * lengths[dim])
          });
      flat
  }
  41. pub fn get_test_params() -> Params {
  42. Params::init(
  43. 2048,
  44. &vec![268369921u64, 249561089u64],
  45. 6.4,
  46. 2,
  47. 256,
  48. 20,
  49. 4,
  50. 8,
  51. 56,
  52. 8,
  53. true,
  54. 9,
  55. 6,
  56. 1,
  57. 2048,
  58. )
  59. }
  60. pub fn get_short_keygen_params() -> Params {
  61. Params::init(
  62. 2048,
  63. &vec![268369921u64, 249561089u64],
  64. 6.4,
  65. 2,
  66. 256,
  67. 20,
  68. 4,
  69. 4,
  70. 4,
  71. 4,
  72. true,
  73. 9,
  74. 6,
  75. 1,
  76. 2048,
  77. )
  78. }
  79. pub fn get_expansion_testing_params() -> Params {
  80. let cfg = r#"
  81. {'n': 2,
  82. 'nu_1': 9,
  83. 'nu_2': 6,
  84. 'p': 256,
  85. 'q2_bits': 20,
  86. 't_gsw': 8,
  87. 't_conv': 4,
  88. 't_exp_left': 8,
  89. 't_exp_right': 56,
  90. 'instances': 1,
  91. 'db_item_size': 8192 }
  92. "#;
  93. params_from_json(&cfg.replace("'", "\""))
  94. }
  95. pub fn get_fast_expansion_testing_params() -> Params {
  96. let cfg = r#"
  97. {'n': 2,
  98. 'nu_1': 6,
  99. 'nu_2': 2,
  100. 'p': 256,
  101. 'q2_bits': 20,
  102. 't_gsw': 8,
  103. 't_conv': 4,
  104. 't_exp_left': 8,
  105. 't_exp_right': 8,
  106. 'instances': 1,
  107. 'db_item_size': 8192 }
  108. "#;
  109. params_from_json(&cfg.replace("'", "\""))
  110. }
  111. pub fn get_no_expansion_testing_params() -> Params {
  112. let cfg = r#"
  113. {'direct_upload': 1,
  114. 'n': 5,
  115. 'nu_1': 6,
  116. 'nu_2': 3,
  117. 'p': 65536,
  118. 'q2_bits': 27,
  119. 't_gsw': 3,
  120. 't_conv': 56,
  121. 't_exp_left': 56,
  122. 't_exp_right': 56}
  123. "#;
  124. params_from_json(&cfg.replace("'", "\""))
  125. }
  126. pub fn get_seed() -> [u8; 32] {
  127. thread_rng().gen::<[u8; 32]>()
  128. }
  129. pub fn get_seeded_rng() -> SmallRng {
  130. SmallRng::from_seed(get_seed())
  131. }
  /// Returns a fixed 32-byte seed: the bytes `1..=8` repeated four times.
  ///
  /// Useful for reproducible RNG streams in tests.
  pub fn get_static_seed() -> [u8; 32] {
      const PATTERN: [u8; 8] = [1, 2, 3, 4, 5, 6, 7, 8];
      let mut seed = [0u8; 32];
      for chunk in seed.chunks_exact_mut(PATTERN.len()) {
          chunk.copy_from_slice(&PATTERN);
      }
      seed
  }
  138. pub fn get_static_seeded_rng() -> SmallRng {
  139. SmallRng::from_seed(get_static_seed())
  140. }
  141. pub const fn get_empty_params() -> Params {
  142. Params {
  143. poly_len: 0,
  144. poly_len_log2: 0,
  145. ntt_tables: Vec::new(),
  146. scratch: Vec::new(),
  147. crt_count: 0,
  148. barrett_cr_0_modulus: 0,
  149. barrett_cr_1_modulus: 0,
  150. barrett_cr_0: [0u64; MAX_MODULI],
  151. barrett_cr_1: [0u64; MAX_MODULI],
  152. mod0_inv_mod1: 0,
  153. mod1_inv_mod0: 0,
  154. moduli: [0u64; MAX_MODULI],
  155. modulus: 0,
  156. modulus_log2: 0,
  157. noise_width: 0f64,
  158. n: 0,
  159. pt_modulus: 0,
  160. q2_bits: 0,
  161. t_conv: 0,
  162. t_exp_left: 0,
  163. t_exp_right: 0,
  164. t_gsw: 0,
  165. expand_queries: false,
  166. db_dim_1: 0,
  167. db_dim_2: 0,
  168. instances: 0,
  169. db_item_size: 0,
  170. }
  171. }
  172. pub fn params_from_json(cfg: &str) -> Params {
  173. let v: Value = serde_json::from_str(cfg).unwrap();
  174. let n = v["n"].as_u64().unwrap() as usize;
  175. let db_dim_1 = v["nu_1"].as_u64().unwrap() as usize;
  176. let db_dim_2 = v["nu_2"].as_u64().unwrap() as usize;
  177. let instances = v["instances"].as_u64().unwrap_or(1) as usize;
  178. let db_item_size = v["db_item_size"].as_u64().unwrap_or(1) as usize;
  179. let p = v["p"].as_u64().unwrap();
  180. let q2_bits = v["q2_bits"].as_u64().unwrap();
  181. let t_gsw = v["t_gsw"].as_u64().unwrap() as usize;
  182. let t_conv = v["t_conv"].as_u64().unwrap() as usize;
  183. let t_exp_left = v["t_exp_left"].as_u64().unwrap() as usize;
  184. let t_exp_right = v["t_exp_right"].as_u64().unwrap() as usize;
  185. let do_expansion = v.get("direct_upload").is_none();
  186. Params::init(
  187. 2048,
  188. &vec![268369921u64, 249561089u64],
  189. 6.4,
  190. n,
  191. p,
  192. q2_bits,
  193. t_conv,
  194. t_exp_left,
  195. t_exp_right,
  196. t_gsw,
  197. do_expansion,
  198. db_dim_1,
  199. db_dim_2,
  200. instances,
  201. db_item_size,
  202. )
  203. }
  /// Returns a mask with the low `num_bits` bits set, for `num_bits` in `0..=64`.
  ///
  /// `num_bits == 64` is special-cased because `(1u64 << 64) - 1` overflows the shift.
  fn low_bits_mask(num_bits: usize) -> u64 {
      if num_bits >= 64 {
          u64::MAX
      } else {
          (1u64 << num_bits) - 1
      }
  }

  /// Loads the 8 bytes at `data[idx..idx + 8]` as a little-endian `u64`.
  fn le_u64_at(data: &[u8], idx: usize) -> u64 {
      let mut bytes = [0u8; 8];
      bytes.copy_from_slice(&data[idx..idx + 8]);
      u64::from_le_bytes(bytes)
  }

  /// Loads the 16 bytes at `data[idx..idx + 16]` as a little-endian `u128`.
  fn le_u128_at(data: &[u8], idx: usize) -> u128 {
      let mut bytes = [0u8; 16];
      bytes.copy_from_slice(&data[idx..idx + 16]);
      u128::from_le_bytes(bytes)
  }

  /// Reads the `num_bits`-wide unsigned value that starts at bit `bit_offs` of `data`.
  ///
  /// `data` is treated as a sequence of little-endian 64-bit words, so stream bit `k`
  /// is bit `k % 8` of byte `k / 8`. On every target this matches the layout produced
  /// by [`write_arbitrary_bits`]. A value that straddles a 64-bit word boundary is
  /// read through a 128-bit load of two consecutive words.
  ///
  /// # Panics
  /// Panics if `num_bits > 64`. Also panics if `data` does not contain the whole
  /// 8-byte word holding the value, or 16 bytes when the value straddles two words.
  pub fn read_arbitrary_bits(data: &[u8], bit_offs: usize, num_bits: usize) -> u64 {
      assert!(
          num_bits <= 64,
          "read_arbitrary_bits: num_bits = {} exceeds 64",
          num_bits
      );
      let idx = (bit_offs / 64) * 8;
      let bit_off_within_word = bit_offs % 64;
      let mask = low_bits_mask(num_bits);
      if bit_off_within_word + num_bits <= 64 {
          (le_u64_at(data, idx) >> bit_off_within_word) & mask
      } else {
          // Spans two words: shift the 128-bit pair and keep the low 64 bits.
          ((le_u128_at(data, idx) >> bit_off_within_word) as u64) & mask
      }
  }

  /// Stores the low `num_bits` bits of `val` at bit offset `bit_offs` of `data`.
  ///
  /// Uses the same little-endian word layout as [`read_arbitrary_bits`]. Bits of
  /// `val` above `num_bits` are discarded, and all surrounding bits of `data` are
  /// preserved.
  ///
  /// # Panics
  /// Panics if `num_bits > 64`. Also panics if `data` does not contain the whole
  /// 8-byte word (or 16 bytes for a straddling value) that the value lands in.
  pub fn write_arbitrary_bits(data: &mut [u8], mut val: u64, bit_offs: usize, num_bits: usize) {
      assert!(
          num_bits <= 64,
          "write_arbitrary_bits: num_bits = {} exceeds 64",
          num_bits
      );
      let idx = (bit_offs / 64) * 8;
      let bit_off_within_word = bit_offs % 64;
      let mask = low_bits_mask(num_bits);
      // Drop stray high bits so they cannot clobber neighbouring fields.
      val &= mask;
      if bit_off_within_word + num_bits <= 64 {
          let mut cur_val = le_u64_at(data, idx);
          cur_val &= !(mask << bit_off_within_word);
          cur_val |= val << bit_off_within_word;
          data[idx..idx + 8].copy_from_slice(&cur_val.to_le_bytes());
      } else {
          // Spans two words: read-modify-write both through one 128-bit value.
          let mut cur_val = le_u128_at(data, idx);
          cur_val &= !(u128::from(mask) << bit_off_within_word);
          cur_val |= u128::from(val) << bit_off_within_word;
          data[idx..idx + 16].copy_from_slice(&cur_val.to_le_bytes());
      }
  }
  236. pub fn reorient_reg_ciphertexts(params: &Params, out: &mut [u64], v_reg: &Vec<PolyMatrixNTT>) {
  237. let poly_len = params.poly_len;
  238. let crt_count = params.crt_count;
  239. assert_eq!(crt_count, 2);
  240. assert!(log2(params.moduli[0]) <= 32);
  241. let num_reg_expanded = 1 << params.db_dim_1;
  242. let ct_rows = v_reg[0].rows;
  243. let ct_cols = v_reg[0].cols;
  244. assert_eq!(ct_rows, 2);
  245. assert_eq!(ct_cols, 1);
  246. for j in 0..num_reg_expanded {
  247. for r in 0..ct_rows {
  248. for m in 0..ct_cols {
  249. for z in 0..params.poly_len {
  250. let idx_a_in =
  251. r * (ct_cols * crt_count * poly_len) + m * (crt_count * poly_len);
  252. let idx_a_out = z * (num_reg_expanded * ct_cols * ct_rows)
  253. + j * (ct_cols * ct_rows)
  254. + m * (ct_rows)
  255. + r;
  256. let val1 = v_reg[j].data[idx_a_in + z] % params.moduli[0];
  257. let val2 = v_reg[j].data[idx_a_in + params.poly_len + z] % params.moduli[1];
  258. out[idx_a_out] = val1 | (val2 << 32);
  259. }
  260. }
  261. }
  262. }
  263. }
  #[cfg(test)]
  mod test {
      use super::*;

      /// `params_from_json` must agree with the equivalent explicit `Params::init` call.
      #[test]
      fn params_from_json_correct() {
          let json = r#"
{'n': 2,
'nu_1': 9,
'nu_2': 6,
'p': 256,
'q2_bits': 20,
's_e': 87.62938774292914,
't_gsw': 8,
't_conv': 4,
't_exp_left': 8,
't_exp_right': 56,
'instances': 1,
'db_item_size': 2048 }
"#
          .replace('\'', "\"");
          let parsed = params_from_json(&json);
          let moduli = vec![268369921u64, 249561089u64];
          let expected = Params::init(
              2048, &moduli, 6.4, 2, 256, 20, 4, 8, 56, 8, true, 9, 6, 1, 2048,
          );
          assert_eq!(parsed, expected);
      }

      /// 9-bit values packed back to back must read back unchanged.
      #[test]
      fn test_read_write_arbitrary_bits() {
          let len = 4096;
          let num_bits = 9;
          // Stop 64 values short of the end, presumably so the 8/16-byte word loads
          // never run past the buffer.
          let count = len * 8 / num_bits - 64;
          let expected = |i: usize| -> u64 { ((i * 7 + 13) % (1 << num_bits)) as u64 };
          let mut data = vec![0u8; len];
          for i in 0..count {
              write_arbitrary_bits(&mut data, expected(i), i * num_bits, num_bits);
          }
          for i in 0..count {
              assert_eq!(read_arbitrary_bits(&data, i * num_bits, num_bits), expected(i));
          }
      }
  }