util.rs 7.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288
  1. use crate::{arith::*, params::*, poly::*};
  2. use rand::{prelude::StdRng, SeedableRng, thread_rng, Rng};
  3. use serde_json::Value;
  4. pub fn calc_index(indices: &[usize], lengths: &[usize]) -> usize {
  5. let mut idx = 0usize;
  6. let mut prod = 1usize;
  7. for i in (0..indices.len()).rev() {
  8. idx += indices[i] * prod;
  9. prod *= lengths[i];
  10. }
  11. idx
  12. }
  13. pub fn get_test_params() -> Params {
  14. Params::init(
  15. 2048,
  16. &vec![268369921u64, 249561089u64],
  17. 6.4,
  18. 2,
  19. 256,
  20. 20,
  21. 4,
  22. 8,
  23. 56,
  24. 8,
  25. true,
  26. 9,
  27. 6,
  28. 1,
  29. 2048,
  30. )
  31. }
  32. pub fn get_short_keygen_params() -> Params {
  33. Params::init(
  34. 2048,
  35. &vec![268369921u64, 249561089u64],
  36. 6.4,
  37. 2,
  38. 256,
  39. 20,
  40. 4,
  41. 4,
  42. 4,
  43. 4,
  44. true,
  45. 9,
  46. 6,
  47. 1,
  48. 2048,
  49. )
  50. }
  51. pub fn get_expansion_testing_params() -> Params {
  52. let cfg = r#"
  53. {'n': 2,
  54. 'nu_1': 9,
  55. 'nu_2': 6,
  56. 'p': 256,
  57. 'q_prime_bits': 20,
  58. 's_e': 87.62938774292914,
  59. 't_GSW': 8,
  60. 't_conv': 4,
  61. 't_exp': 8,
  62. 't_exp_right': 56,
  63. 'instances': 1,
  64. 'db_item_size': 8192 }
  65. "#;
  66. let cfg = cfg.replace("'", "\"");
  67. let b = params_from_json(&cfg);
  68. b
  69. }
  70. pub fn get_seed() -> [u8; 32] {
  71. thread_rng().gen::<[u8; 32]>()
  72. }
  73. pub fn get_seeded_rng() -> StdRng {
  74. StdRng::from_seed(get_seed())
  75. }
  76. pub fn get_static_seed() -> [u8; 32] {
  77. [
  78. 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6,
  79. 7, 8,
  80. ]
  81. }
  82. pub fn get_static_seeded_rng() -> StdRng {
  83. StdRng::from_seed(get_static_seed())
  84. }
  85. pub const fn get_empty_params() -> Params {
  86. Params {
  87. poly_len: 0,
  88. poly_len_log2: 0,
  89. ntt_tables: Vec::new(),
  90. scratch: Vec::new(),
  91. crt_count: 0,
  92. barrett_cr_0_modulus: 0,
  93. barrett_cr_1_modulus: 0,
  94. barrett_cr_0: [0u64; MAX_MODULI],
  95. barrett_cr_1: [0u64; MAX_MODULI],
  96. mod0_inv_mod1: 0,
  97. mod1_inv_mod0: 0,
  98. moduli: [0u64; MAX_MODULI],
  99. modulus: 0,
  100. modulus_log2: 0,
  101. noise_width: 0f64,
  102. n: 0,
  103. pt_modulus: 0,
  104. q2_bits: 0,
  105. t_conv: 0,
  106. t_exp_left: 0,
  107. t_exp_right: 0,
  108. t_gsw: 0,
  109. expand_queries: false,
  110. db_dim_1: 0,
  111. db_dim_2: 0,
  112. instances: 0,
  113. db_item_size: 0,
  114. }
  115. }
  116. pub fn params_from_json(cfg: &str) -> Params {
  117. let v: Value = serde_json::from_str(cfg).unwrap();
  118. let n = v["n"].as_u64().unwrap() as usize;
  119. let db_dim_1 = v["nu_1"].as_u64().unwrap() as usize;
  120. let db_dim_2 = v["nu_2"].as_u64().unwrap() as usize;
  121. let instances = v["instances"].as_u64().unwrap_or(1) as usize;
  122. let db_item_size = v["db_item_size"].as_u64().unwrap_or(1) as usize;
  123. let p = v["p"].as_u64().unwrap();
  124. let q2_bits = v["q_prime_bits"].as_u64().unwrap();
  125. let t_gsw = v["t_GSW"].as_u64().unwrap() as usize;
  126. let t_conv = v["t_conv"].as_u64().unwrap() as usize;
  127. let t_exp_left = v["t_exp"].as_u64().unwrap() as usize;
  128. let t_exp_right = v["t_exp_right"].as_u64().unwrap() as usize;
  129. let do_expansion = v.get("kinda_direct_upload").is_none();
  130. Params::init(
  131. 2048,
  132. &vec![268369921u64, 249561089u64],
  133. 6.4,
  134. n,
  135. p,
  136. q2_bits,
  137. t_conv,
  138. t_exp_left,
  139. t_exp_right,
  140. t_gsw,
  141. do_expansion,
  142. db_dim_1,
  143. db_dim_2,
  144. instances,
  145. db_item_size,
  146. )
  147. }
  148. pub fn read_arbitrary_bits(data: &[u8], bit_offs: usize, num_bits: usize) -> u64 {
  149. let word_off = bit_offs / 64;
  150. let bit_off_within_word = bit_offs % 64;
  151. if (bit_off_within_word + num_bits) <= 64 {
  152. let idx = word_off * 8;
  153. let val = u64::from_ne_bytes(data[idx..idx + 8].try_into().unwrap());
  154. (val >> bit_off_within_word) & ((1u64 << num_bits) - 1)
  155. } else {
  156. let idx = word_off * 8;
  157. let val = u128::from_ne_bytes(data[idx..idx + 16].try_into().unwrap());
  158. ((val >> bit_off_within_word) & ((1u128 << num_bits) - 1)) as u64
  159. }
  160. }
  161. pub fn write_arbitrary_bits(data: &mut [u8], mut val: u64, bit_offs: usize, num_bits: usize) {
  162. let word_off = bit_offs / 64;
  163. let bit_off_within_word = bit_offs % 64;
  164. val = val & ((1u64 << num_bits) - 1);
  165. if (bit_off_within_word + num_bits) <= 64 {
  166. let idx = word_off * 8;
  167. let mut cur_val = u64::from_ne_bytes(data[idx..idx + 8].try_into().unwrap());
  168. cur_val &= !(((1u64 << num_bits) - 1) << bit_off_within_word);
  169. cur_val |= val << bit_off_within_word;
  170. data[idx..idx + 8].copy_from_slice(&u64::to_ne_bytes(cur_val));
  171. } else {
  172. let idx = word_off * 8;
  173. let mut cur_val = u128::from_ne_bytes(data[idx..idx + 16].try_into().unwrap());
  174. let mask = !(((1u128 << num_bits) - 1) << bit_off_within_word);
  175. cur_val &= mask;
  176. cur_val |= (val as u128) << bit_off_within_word;
  177. data[idx..idx + 16].copy_from_slice(&u128::to_ne_bytes(cur_val));
  178. }
  179. }
  180. pub fn reorient_reg_ciphertexts(params: &Params, out: &mut [u64], v_reg: &Vec<PolyMatrixNTT>) {
  181. let poly_len = params.poly_len;
  182. let crt_count = params.crt_count;
  183. assert_eq!(crt_count, 2);
  184. assert!(log2(params.moduli[0]) <= 32);
  185. let num_reg_expanded = 1 << params.db_dim_1;
  186. let ct_rows = v_reg[0].rows;
  187. let ct_cols = v_reg[0].cols;
  188. assert_eq!(ct_rows, 2);
  189. assert_eq!(ct_cols, 1);
  190. for j in 0..num_reg_expanded {
  191. for r in 0..ct_rows {
  192. for m in 0..ct_cols {
  193. for z in 0..params.poly_len {
  194. let idx_a_in =
  195. r * (ct_cols * crt_count * poly_len) + m * (crt_count * poly_len);
  196. let idx_a_out = z * (num_reg_expanded * ct_cols * ct_rows)
  197. + j * (ct_cols * ct_rows)
  198. + m * (ct_rows)
  199. + r;
  200. let val1 = v_reg[j].data[idx_a_in + z] % params.moduli[0];
  201. let val2 = v_reg[j].data[idx_a_in + params.poly_len + z] % params.moduli[1];
  202. out[idx_a_out] = val1 | (val2 << 32);
  203. }
  204. }
  205. }
  206. }
  207. }
  208. #[cfg(test)]
  209. mod test {
  210. use super::*;
  211. #[test]
  212. fn params_from_json_correct() {
  213. let cfg = r#"
  214. {'n': 2,
  215. 'nu_1': 9,
  216. 'nu_2': 6,
  217. 'p': 256,
  218. 'q_prime_bits': 20,
  219. 's_e': 87.62938774292914,
  220. 't_GSW': 8,
  221. 't_conv': 4,
  222. 't_exp': 8,
  223. 't_exp_right': 56,
  224. 'instances': 1,
  225. 'db_item_size': 2048 }
  226. "#;
  227. let cfg = cfg.replace("'", "\"");
  228. let b = params_from_json(&cfg);
  229. let c = Params::init(
  230. 2048,
  231. &vec![268369921u64, 249561089u64],
  232. 6.4,
  233. 2,
  234. 256,
  235. 20,
  236. 4,
  237. 8,
  238. 56,
  239. 8,
  240. true,
  241. 9,
  242. 6,
  243. 1,
  244. 2048,
  245. );
  246. assert_eq!(b, c);
  247. }
  248. #[test]
  249. fn test_read_write_arbitrary_bits() {
  250. let len = 4096;
  251. let num_bits = 9;
  252. let mut data = vec![0u8; len];
  253. let scaled_len = len * 8 / num_bits - 64;
  254. let mut bit_offs = 0;
  255. let get_from = |i: usize| -> u64 { ((i * 7 + 13) % (1 << num_bits)) as u64 };
  256. for i in 0..scaled_len {
  257. write_arbitrary_bits(data.as_mut_slice(), get_from(i), bit_offs, num_bits);
  258. bit_offs += num_bits;
  259. }
  260. bit_offs = 0;
  261. for i in 0..scaled_len {
  262. let val = read_arbitrary_bits(data.as_slice(), bit_offs, num_bits);
  263. assert_eq!(val, get_from(i));
  264. bit_offs += num_bits;
  265. }
  266. }
  267. }