util.rs 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220
  1. use crate::params::*;
  2. use rand::{prelude::StdRng, SeedableRng};
  3. use serde_json::Value;
  4. pub fn calc_index(indices: &[usize], lengths: &[usize]) -> usize {
  5. let mut idx = 0usize;
  6. let mut prod = 1usize;
  7. for i in (0..indices.len()).rev() {
  8. idx += indices[i] * prod;
  9. prod *= lengths[i];
  10. }
  11. idx
  12. }
  13. pub fn get_test_params() -> Params {
  14. Params::init(
  15. 2048,
  16. &vec![268369921u64, 249561089u64],
  17. 6.4,
  18. 2,
  19. 256,
  20. 20,
  21. 4,
  22. 8,
  23. 56,
  24. 8,
  25. true,
  26. 9,
  27. 6,
  28. 1,
  29. 2048,
  30. )
  31. }
  32. pub fn get_short_keygen_params() -> Params {
  33. Params::init(
  34. 2048,
  35. &vec![268369921u64, 249561089u64],
  36. 6.4,
  37. 2,
  38. 256,
  39. 20,
  40. 4,
  41. 4,
  42. 4,
  43. 4,
  44. true,
  45. 9,
  46. 6,
  47. 1,
  48. 2048,
  49. )
  50. }
  51. pub fn get_seed() -> [u8; 32] {
  52. [
  53. 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6,
  54. 7, 8,
  55. ]
  56. }
  57. pub fn get_seeded_rng() -> StdRng {
  58. StdRng::from_seed(get_seed())
  59. }
  60. pub const fn get_empty_params() -> Params {
  61. Params {
  62. poly_len: 0,
  63. poly_len_log2: 0,
  64. ntt_tables: Vec::new(),
  65. scratch: Vec::new(),
  66. crt_count: 0,
  67. moduli: Vec::new(),
  68. modulus: 0,
  69. modulus_log2: 0,
  70. noise_width: 0f64,
  71. n: 0,
  72. pt_modulus: 0,
  73. q2_bits: 0,
  74. t_conv: 0,
  75. t_exp_left: 0,
  76. t_exp_right: 0,
  77. t_gsw: 0,
  78. expand_queries: false,
  79. db_dim_1: 0,
  80. db_dim_2: 0,
  81. instances: 0,
  82. db_item_size: 0,
  83. }
  84. }
  85. pub fn params_from_json(cfg: &str) -> Params {
  86. let v: Value = serde_json::from_str(cfg).unwrap();
  87. let n = v["n"].as_u64().unwrap() as usize;
  88. let db_dim_1 = v["nu_1"].as_u64().unwrap() as usize;
  89. let db_dim_2 = v["nu_2"].as_u64().unwrap() as usize;
  90. let instances = v["instances"].as_u64().unwrap_or(1) as usize;
  91. let db_item_size = v["db_item_size"].as_u64().unwrap_or(1) as usize;
  92. let p = v["p"].as_u64().unwrap();
  93. let q2_bits = v["q_prime_bits"].as_u64().unwrap();
  94. let t_gsw = v["t_GSW"].as_u64().unwrap() as usize;
  95. let t_conv = v["t_conv"].as_u64().unwrap() as usize;
  96. let t_exp_left = v["t_exp"].as_u64().unwrap() as usize;
  97. let t_exp_right = v["t_exp_right"].as_u64().unwrap() as usize;
  98. let do_expansion = v.get("kinda_direct_upload").is_none();
  99. Params::init(
  100. 2048,
  101. &vec![268369921u64, 249561089u64],
  102. 6.4,
  103. n,
  104. p,
  105. q2_bits,
  106. t_conv,
  107. t_exp_left,
  108. t_exp_right,
  109. t_gsw,
  110. do_expansion,
  111. db_dim_1,
  112. db_dim_2,
  113. instances,
  114. db_item_size,
  115. )
  116. }
  117. pub fn read_arbitrary_bits(data: &[u8], bit_offs: usize, num_bits: usize) -> u64 {
  118. let word_off = bit_offs / 64;
  119. let bit_off_within_word = bit_offs % 64;
  120. if (bit_off_within_word + num_bits) <= 64 {
  121. let idx = word_off * 8;
  122. let val = u64::from_ne_bytes(data[idx..idx + 8].try_into().unwrap());
  123. (val >> bit_off_within_word) & ((1u64 << num_bits) - 1)
  124. } else {
  125. let idx = word_off * 8;
  126. let val = u128::from_ne_bytes(data[idx..idx + 16].try_into().unwrap());
  127. ((val >> bit_off_within_word) & ((1u128 << num_bits) - 1)) as u64
  128. }
  129. }
  130. pub fn write_arbitrary_bits(data: &mut [u8], mut val: u64, bit_offs: usize, num_bits: usize) {
  131. let word_off = bit_offs / 64;
  132. let bit_off_within_word = bit_offs % 64;
  133. val = val & ((1u64 << num_bits) - 1);
  134. if (bit_off_within_word + num_bits) <= 64 {
  135. let idx = word_off * 8;
  136. let mut cur_val = u64::from_ne_bytes(data[idx..idx + 8].try_into().unwrap());
  137. cur_val &= !(((1u64 << num_bits) - 1) << bit_off_within_word);
  138. cur_val |= val << bit_off_within_word;
  139. data[idx..idx + 8].copy_from_slice(&u64::to_ne_bytes(cur_val));
  140. } else {
  141. let idx = word_off * 8;
  142. let mut cur_val = u128::from_ne_bytes(data[idx..idx + 16].try_into().unwrap());
  143. let mask = !(((1u128 << num_bits) - 1) << bit_off_within_word);
  144. cur_val &= mask;
  145. cur_val |= (val as u128) << bit_off_within_word;
  146. data[idx..idx + 16].copy_from_slice(&u128::to_ne_bytes(cur_val));
  147. }
  148. }
  149. #[cfg(test)]
  150. mod test {
  151. use super::*;
  152. #[test]
  153. fn params_from_json_correct() {
  154. let cfg = r#"
  155. {'n': 2,
  156. 'nu_1': 9,
  157. 'nu_2': 6,
  158. 'p': 256,
  159. 'q_prime_bits': 20,
  160. 's_e': 87.62938774292914,
  161. 't_GSW': 8,
  162. 't_conv': 4,
  163. 't_exp': 8,
  164. 't_exp_right': 56,
  165. 'instances': 1,
  166. 'db_item_size': 2048 }
  167. "#;
  168. let cfg = cfg.replace("'", "\"");
  169. let b = params_from_json(&cfg);
  170. let c = Params::init(
  171. 2048,
  172. &vec![268369921u64, 249561089u64],
  173. 6.4,
  174. 2,
  175. 256,
  176. 20,
  177. 4,
  178. 8,
  179. 56,
  180. 8,
  181. true,
  182. 9,
  183. 6,
  184. 1,
  185. 2048,
  186. );
  187. assert_eq!(b, c);
  188. }
  189. #[test]
  190. fn test_read_write_arbitrary_bits() {
  191. let len = 4096;
  192. let num_bits = 9;
  193. let mut data = vec![0u8; len];
  194. let scaled_len = len * 8 / num_bits - 64;
  195. let mut bit_offs = 0;
  196. let get_from = |i: usize| -> u64 { ((i * 7 + 13) % (1 << num_bits)) as u64 };
  197. for i in 0..scaled_len {
  198. write_arbitrary_bits(data.as_mut_slice(), get_from(i), bit_offs, num_bits);
  199. bit_offs += num_bits;
  200. }
  201. bit_offs = 0;
  202. for i in 0..scaled_len {
  203. let val = read_arbitrary_bits(data.as_slice(), bit_offs, num_bits);
  204. assert_eq!(val, get_from(i));
  205. bit_offs += num_bits;
  206. }
  207. }
  208. }