util.rs 8.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344
  1. use crate::{arith::*, params::*, poly::*};
  2. use rand::{prelude::SmallRng, thread_rng, Rng, SeedableRng};
  3. use serde_json::Value;
  /// Parameter set with `p` = 256, `q2_bits` = 20, `nu_1`/`nu_2` = 9/6 and
  /// `db_item_size` = 8192.
  ///
  /// Keys are single-quoted, so replace `'` with `"` before handing the text to
  /// `params_from_json`. Keys that function does not read (such as `s_e`) are ignored.
  pub const CFG_20_256: &str = r#"
{'n': 2,
'nu_1': 9,
'nu_2': 6,
'p': 256,
'q2_bits': 20,
's_e': 87.62938774292914,
't_gsw': 8,
't_conv': 4,
't_exp_left': 8,
't_exp_right': 56,
'instances': 1,
'db_item_size': 8192 }
"#;
  /// Parameter set with `p` = 512, `q2_bits` = 21, `nu_1`/`nu_2` = 10/6,
  /// 11 instances and `db_item_size` = 100000.
  ///
  /// Keys are single-quoted, so replace `'` with `"` before handing the text to
  /// `params_from_json`. Keys that function does not read (such as `s_e`) are ignored.
  pub const CFG_16_100000: &str = r#"
{'n': 2,
'nu_1': 10,
'nu_2': 6,
'p': 512,
'q2_bits': 21,
's_e': 85.83255142749422,
't_gsw': 10,
't_conv': 4,
't_exp_left': 16,
't_exp_right': 56,
'instances': 11,
'db_item_size': 100000 }
"#;
  /// Flattens a multi-dimensional index into a single row-major offset.
  ///
  /// `indices[d]` is the coordinate along dimension `d` and `lengths[d]` is that
  /// dimension's size. The last dimension varies fastest.
  ///
  /// # Panics
  ///
  /// Panics if `lengths` is shorter than `indices`, or on arithmetic overflow in
  /// debug builds.
  pub fn calc_index(indices: &[usize], lengths: &[usize]) -> usize {
      // Walk from the innermost (last) dimension outwards, accumulating the
      // offset and growing the stride as we go.
      let (flat, _stride) = (0..indices.len())
          .rev()
          .fold((0usize, 1usize), |(acc, stride), d| {
              (acc + indices[d] * stride, stride * lengths[d])
          });
      flat
  }
  41. pub fn get_test_params() -> Params {
  42. Params::init(
  43. 2048,
  44. &vec![268369921u64, 249561089u64],
  45. 6.4,
  46. 2,
  47. 256,
  48. 20,
  49. 4,
  50. 8,
  51. 56,
  52. 8,
  53. true,
  54. 9,
  55. 6,
  56. 1,
  57. 2048,
  58. )
  59. }
  60. pub fn get_short_keygen_params() -> Params {
  61. Params::init(
  62. 2048,
  63. &vec![268369921u64, 249561089u64],
  64. 6.4,
  65. 2,
  66. 256,
  67. 20,
  68. 4,
  69. 4,
  70. 4,
  71. 4,
  72. true,
  73. 9,
  74. 6,
  75. 1,
  76. 2048,
  77. )
  78. }
  79. pub fn get_expansion_testing_params() -> Params {
  80. let cfg = r#"
  81. {'n': 2,
  82. 'nu_1': 9,
  83. 'nu_2': 6,
  84. 'p': 256,
  85. 'q2_bits': 20,
  86. 't_gsw': 8,
  87. 't_conv': 4,
  88. 't_exp_left': 8,
  89. 't_exp_right': 56,
  90. 'instances': 1,
  91. 'db_item_size': 8192 }
  92. "#;
  93. params_from_json(&cfg.replace("'", "\""))
  94. }
  95. pub fn get_fast_expansion_testing_params() -> Params {
  96. let cfg = r#"
  97. {'n': 2,
  98. 'nu_1': 6,
  99. 'nu_2': 2,
  100. 'p': 256,
  101. 'q2_bits': 20,
  102. 't_gsw': 8,
  103. 't_conv': 4,
  104. 't_exp_left': 8,
  105. 't_exp_right': 8,
  106. 'instances': 1,
  107. 'db_item_size': 8192 }
  108. "#;
  109. params_from_json(&cfg.replace("'", "\""))
  110. }
  111. pub fn get_no_expansion_testing_params() -> Params {
  112. let cfg = r#"
  113. {'direct_upload': 1,
  114. 'n': 5,
  115. 'nu_1': 6,
  116. 'nu_2': 3,
  117. 'p': 65536,
  118. 'q2_bits': 27,
  119. 't_gsw': 3,
  120. 't_conv': 56,
  121. 't_exp_left': 56,
  122. 't_exp_right': 56}
  123. "#;
  124. params_from_json(&cfg.replace("'", "\""))
  125. }
  126. pub fn get_seed() -> u64 {
  127. thread_rng().gen::<u64>()
  128. }
  129. pub fn get_seeded_rng() -> SmallRng {
  130. SmallRng::seed_from_u64(get_seed())
  131. }
  /// Returns the fixed seed `0x123456789`, for reproducible randomness.
  pub fn get_static_seed() -> u64 {
      0x1_2345_6789
  }
  135. pub fn get_static_seeded_rng() -> SmallRng {
  136. SmallRng::seed_from_u64(get_static_seed())
  137. }
  138. pub const fn get_empty_params() -> Params {
  139. Params {
  140. poly_len: 0,
  141. poly_len_log2: 0,
  142. ntt_tables: Vec::new(),
  143. scratch: Vec::new(),
  144. crt_count: 0,
  145. barrett_cr_0_modulus: 0,
  146. barrett_cr_1_modulus: 0,
  147. barrett_cr_0: [0u64; MAX_MODULI],
  148. barrett_cr_1: [0u64; MAX_MODULI],
  149. mod0_inv_mod1: 0,
  150. mod1_inv_mod0: 0,
  151. moduli: [0u64; MAX_MODULI],
  152. modulus: 0,
  153. modulus_log2: 0,
  154. noise_width: 0f64,
  155. n: 0,
  156. pt_modulus: 0,
  157. q2_bits: 0,
  158. t_conv: 0,
  159. t_exp_left: 0,
  160. t_exp_right: 0,
  161. t_gsw: 0,
  162. expand_queries: false,
  163. db_dim_1: 0,
  164. db_dim_2: 0,
  165. instances: 0,
  166. db_item_size: 0,
  167. }
  168. }
  169. pub fn params_from_json(cfg: &str) -> Params {
  170. let v: Value = serde_json::from_str(cfg).unwrap();
  171. let n = v["n"].as_u64().unwrap() as usize;
  172. let db_dim_1 = v["nu_1"].as_u64().unwrap() as usize;
  173. let db_dim_2 = v["nu_2"].as_u64().unwrap() as usize;
  174. let instances = v["instances"].as_u64().unwrap_or(1) as usize;
  175. let db_item_size = v["db_item_size"].as_u64().unwrap_or(1) as usize;
  176. let p = v["p"].as_u64().unwrap();
  177. let q2_bits = v["q2_bits"].as_u64().unwrap();
  178. let t_gsw = v["t_gsw"].as_u64().unwrap() as usize;
  179. let t_conv = v["t_conv"].as_u64().unwrap() as usize;
  180. let t_exp_left = v["t_exp_left"].as_u64().unwrap() as usize;
  181. let t_exp_right = v["t_exp_right"].as_u64().unwrap() as usize;
  182. let do_expansion = v.get("direct_upload").is_none();
  183. Params::init(
  184. 2048,
  185. &vec![268369921u64, 249561089u64],
  186. 6.4,
  187. n,
  188. p,
  189. q2_bits,
  190. t_conv,
  191. t_exp_left,
  192. t_exp_right,
  193. t_gsw,
  194. do_expansion,
  195. db_dim_1,
  196. db_dim_2,
  197. instances,
  198. db_item_size,
  199. )
  200. }
  /// Returns a mask with the low `num_bits` bits set.
  ///
  /// `num_bits` must be in `0..=64`. The 64 case is handled explicitly because
  /// `1u64 << 64` overflows.
  fn low_bits_mask(num_bits: usize) -> u64 {
      if num_bits >= 64 {
          u64::MAX
      } else {
          (1u64 << num_bits) - 1
      }
  }

  /// Reads a `num_bits`-wide unsigned field starting at bit `bit_offs` of `data`.
  ///
  /// `data` is viewed as a sequence of little-endian 64-bit words. Bit `k` of the
  /// stream is bit `k % 64` of word `k / 64`. A field that crosses a word boundary
  /// is read through a 128-bit little-endian load, so the layout is the same on
  /// every platform.
  ///
  /// # Panics
  ///
  /// Panics if `num_bits > 64`. It also panics if `data` does not contain the
  /// 8 bytes (aligned field) or 16 bytes (word-crossing field) starting at
  /// `(bit_offs / 64) * 8`.
  pub fn read_arbitrary_bits(data: &[u8], bit_offs: usize, num_bits: usize) -> u64 {
      assert!(num_bits <= 64, "num_bits must be at most 64");
      let idx = (bit_offs / 64) * 8;
      let shift = bit_offs % 64;
      let mask = low_bits_mask(num_bits);
      if shift + num_bits <= 64 {
          let word = u64::from_le_bytes(data[idx..idx + 8].try_into().unwrap());
          (word >> shift) & mask
      } else {
          // Field straddles two words: load both at once.
          let word = u128::from_le_bytes(data[idx..idx + 16].try_into().unwrap());
          ((word >> shift) as u64) & mask
      }
  }

  /// Writes the low `num_bits` bits of `val` at bit `bit_offs` of `data`.
  ///
  /// Uses the same little-endian word layout as `read_arbitrary_bits`. Bits of
  /// `data` outside the target field are preserved, and bits of `val` above
  /// `num_bits` are discarded.
  ///
  /// # Panics
  ///
  /// Panics if `num_bits > 64`. It also panics if `data` does not contain the
  /// 8 or 16 bytes starting at `(bit_offs / 64) * 8`.
  pub fn write_arbitrary_bits(data: &mut [u8], val: u64, bit_offs: usize, num_bits: usize) {
      assert!(num_bits <= 64, "num_bits must be at most 64");
      let idx = (bit_offs / 64) * 8;
      let shift = bit_offs % 64;
      let mask = low_bits_mask(num_bits);
      let val = val & mask;
      if shift + num_bits <= 64 {
          let mut word = u64::from_le_bytes(data[idx..idx + 8].try_into().unwrap());
          word &= !(mask << shift);
          word |= val << shift;
          data[idx..idx + 8].copy_from_slice(&word.to_le_bytes());
      } else {
          // Field straddles two words: read-modify-write both at once.
          let mut word = u128::from_le_bytes(data[idx..idx + 16].try_into().unwrap());
          word &= !((mask as u128) << shift);
          word |= (val as u128) << shift;
          data[idx..idx + 16].copy_from_slice(&word.to_le_bytes());
      }
  }
  233. pub fn reorient_reg_ciphertexts(params: &Params, out: &mut [u64], v_reg: &Vec<PolyMatrixNTT>) {
  234. let poly_len = params.poly_len;
  235. let crt_count = params.crt_count;
  236. assert_eq!(crt_count, 2);
  237. assert!(log2(params.moduli[0]) <= 32);
  238. let num_reg_expanded = 1 << params.db_dim_1;
  239. let ct_rows = v_reg[0].rows;
  240. let ct_cols = v_reg[0].cols;
  241. assert_eq!(ct_rows, 2);
  242. assert_eq!(ct_cols, 1);
  243. for j in 0..num_reg_expanded {
  244. for r in 0..ct_rows {
  245. for m in 0..ct_cols {
  246. for z in 0..params.poly_len {
  247. let idx_a_in =
  248. r * (ct_cols * crt_count * poly_len) + m * (crt_count * poly_len);
  249. let idx_a_out = z * (num_reg_expanded * ct_cols * ct_rows)
  250. + j * (ct_cols * ct_rows)
  251. + m * (ct_rows)
  252. + r;
  253. let val1 = v_reg[j].data[idx_a_in + z] % params.moduli[0];
  254. let val2 = v_reg[j].data[idx_a_in + params.poly_len + z] % params.moduli[1];
  255. out[idx_a_out] = val1 | (val2 << 32);
  256. }
  257. }
  258. }
  259. }
  260. }
  #[cfg(test)]
  mod test {
      use super::*;

      /// `params_from_json` must produce exactly the same `Params` as an explicit
      /// `Params::init` call with the equivalent values.
      #[test]
      fn params_from_json_correct() {
          let json = r#"
{'n': 2,
'nu_1': 9,
'nu_2': 6,
'p': 256,
'q2_bits': 20,
's_e': 87.62938774292914,
't_gsw': 8,
't_conv': 4,
't_exp_left': 8,
't_exp_right': 56,
'instances': 1,
'db_item_size': 2048 }
"#
          .replace('\'', "\"");
          let parsed = params_from_json(&json);
          let expected = Params::init(
              2048,
              &vec![268369921u64, 249561089u64],
              6.4,
              2,
              256,
              20,
              4,
              8,
              56,
              8,
              true,
              9,
              6,
              1,
              2048,
          );
          assert_eq!(parsed, expected);
      }

      /// Packs a stream of 9-bit values back to back and checks that each one
      /// reads back unchanged.
      #[test]
      fn test_read_write_arbitrary_bits() {
          const LEN: usize = 4096;
          const NUM_BITS: usize = 9;
          let expected_at = |i: usize| -> u64 { ((i * 7 + 13) % (1 << NUM_BITS)) as u64 };
          // Stay clear of the buffer tail so word-crossing 16-byte accesses remain in bounds.
          let count = LEN * 8 / NUM_BITS - 64;

          let mut buf = vec![0u8; LEN];
          for i in 0..count {
              write_arbitrary_bits(&mut buf, expected_at(i), i * NUM_BITS, NUM_BITS);
          }
          for i in 0..count {
              assert_eq!(read_arbitrary_bits(&buf, i * NUM_BITS, NUM_BITS), expected_at(i));
          }
      }
  }