util.rs 5.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188
  1. use crate::params::*;
  2. use serde_json::{Value};
  /// Flattens a multi-dimensional index into a single linear offset.
  ///
  /// The last dimension varies fastest: `indices[i]` is weighted by the product of
  /// `lengths[i + 1..indices.len()]`. Only the first `indices.len()` entries of `lengths`
  /// are consulted, and `lengths[0]` does not contribute to the returned value.
  ///
  /// # Panics
  ///
  /// Panics if `lengths` has fewer entries than `indices`.
  pub fn calc_index(indices: &[usize], lengths: &[usize]) -> usize {
      // Pair each index with the length of its own dimension; slicing enforces that a length
      // exists for every index.
      let dims = &lengths[..indices.len()];

      // Walk from the innermost dimension outwards, carrying the running stride alongside
      // the accumulated offset.
      let (offset, _stride) = indices
          .iter()
          .zip(dims)
          .rev()
          .fold((0usize, 1usize), |(offset, stride), (&index, &len)| {
              (offset + index * stride, stride * len)
          });

      offset
  }
  12. pub fn get_test_params() -> Params {
  13. Params::init(
  14. 2048,
  15. &vec![268369921u64, 249561089u64],
  16. 6.4,
  17. 2,
  18. 256,
  19. 20,
  20. 4,
  21. 8,
  22. 56,
  23. 8,
  24. true,
  25. 9,
  26. 6,
  27. 1,
  28. 2048
  29. )
  30. }
  31. pub const fn get_empty_params() -> Params {
  32. Params {
  33. poly_len: 0,
  34. poly_len_log2: 0,
  35. ntt_tables: Vec::new(),
  36. scratch: Vec::new(),
  37. crt_count: 0,
  38. moduli: Vec::new(),
  39. modulus: 0,
  40. modulus_log2: 0,
  41. noise_width: 0f64,
  42. n: 0,
  43. pt_modulus: 0,
  44. q2_bits: 0,
  45. t_conv: 0,
  46. t_exp_left: 0,
  47. t_exp_right: 0,
  48. t_gsw: 0,
  49. expand_queries: false,
  50. db_dim_1: 0,
  51. db_dim_2: 0,
  52. instances: 0,
  53. db_item_size: 0,
  54. }
  55. }
  56. pub fn params_from_json(cfg: &str) -> Params {
  57. let v: Value = serde_json::from_str(cfg).unwrap();
  58. let n = v["n"].as_u64().unwrap() as usize;
  59. let db_dim_1 = v["nu_1"].as_u64().unwrap() as usize;
  60. let db_dim_2 = v["nu_2"].as_u64().unwrap() as usize;
  61. let instances = v["instances"].as_u64().unwrap_or(1) as usize;
  62. let db_item_size = v["db_item_size"].as_u64().unwrap_or(1) as usize;
  63. let p = v["p"].as_u64().unwrap();
  64. let q2_bits = v["q_prime_bits"].as_u64().unwrap();
  65. let t_gsw = v["t_GSW"].as_u64().unwrap() as usize;
  66. let t_conv = v["t_conv"].as_u64().unwrap() as usize;
  67. let t_exp_left = v["t_exp"].as_u64().unwrap() as usize;
  68. let t_exp_right = v["t_exp_right"].as_u64().unwrap() as usize;
  69. let do_expansion = v.get("kinda_direct_upload").is_none();
  70. Params::init(
  71. 2048,
  72. &vec![268369921u64, 249561089u64],
  73. 6.4,
  74. n,
  75. p,
  76. q2_bits,
  77. t_conv,
  78. t_exp_left,
  79. t_exp_right,
  80. t_gsw,
  81. do_expansion,
  82. db_dim_1,
  83. db_dim_2,
  84. instances,
  85. db_item_size,
  86. )
  87. }
  /// Reads a `num_bits`-wide unsigned value that starts `bit_offs` bits into `data`.
  ///
  /// `data` is treated as a sequence of native-endian 64-bit words; bit 0 of the stream is
  /// the least significant bit of the first word. A value may straddle two adjacent words.
  ///
  /// `num_bits` may be anything from 0 to 64 inclusive. Larger values are treated as 64,
  /// since the return type cannot hold more.
  ///
  /// The two words of a straddling value are combined arithmetically instead of
  /// reinterpreting 16 bytes as one `u128`, so the straddling path agrees with the
  /// single-word path regardless of the host byte order.
  ///
  /// # Panics
  ///
  /// Panics if the 8-byte word containing `bit_offs` (and, for a straddling value, the
  /// following 8-byte word) does not lie entirely within `data`.
  pub fn read_arbitrary_bits(data: &[u8], bit_offs: usize, num_bits: usize) -> u64 {
      // Loads the native-endian 64-bit word stored at byte offset `at`.
      fn load_word(data: &[u8], at: usize) -> u64 {
          let mut bytes = [0u8; 8];
          bytes.copy_from_slice(&data[at..at + 8]);
          u64::from_ne_bytes(bytes)
      }

      let width = num_bits.min(64);
      // `1u64 << 64` is an overflowing shift, so the full-width mask is special-cased.
      let mask = if width == 64 {
          u64::MAX
      } else {
          (1u64 << width) - 1
      };

      let shift = bit_offs % 64;
      let idx = (bit_offs / 64) * 8;

      let low = load_word(data, idx) >> shift;
      if shift + width <= 64 {
          low & mask
      } else {
          // Straddling implies `shift >= 1`, so `64 - shift` is a valid shift in 1..=63.
          let high = load_word(data, idx + 8) << (64 - shift);
          (low | high) & mask
      }
  }
  /// Stores the low `num_bits` bits of `val` into `data`, starting `bit_offs` bits in.
  ///
  /// `data` is treated as a sequence of native-endian 64-bit words; bit 0 of the stream is
  /// the least significant bit of the first word. Bits of `data` outside the addressed
  /// field are preserved, and bits of `val` above `num_bits` are ignored.
  ///
  /// `num_bits` may be anything from 0 to 64 inclusive. Larger values are treated as 64,
  /// since `val` cannot carry more.
  ///
  /// A field that straddles two words is written as two separate word updates instead of
  /// one `u128` update, so the layout agrees with the single-word path regardless of the
  /// host byte order.
  ///
  /// # Panics
  ///
  /// Panics if the 8-byte word containing `bit_offs` (and, for a straddling field, the
  /// following 8-byte word) does not lie entirely within `data`.
  pub fn write_arbitrary_bits(data: &mut [u8], val: u64, bit_offs: usize, num_bits: usize) {
      // Loads the native-endian 64-bit word stored at byte offset `at`.
      fn load_word(data: &[u8], at: usize) -> u64 {
          let mut bytes = [0u8; 8];
          bytes.copy_from_slice(&data[at..at + 8]);
          u64::from_ne_bytes(bytes)
      }

      // Stores `word` as a native-endian 64-bit word at byte offset `at`.
      fn store_word(data: &mut [u8], at: usize, word: u64) {
          data[at..at + 8].copy_from_slice(&word.to_ne_bytes());
      }

      let width = num_bits.min(64);
      // `1u64 << 64` is an overflowing shift, so the full-width mask is special-cased.
      let mask = if width == 64 {
          u64::MAX
      } else {
          (1u64 << width) - 1
      };
      let val = val & mask;

      let shift = bit_offs % 64;
      let idx = (bit_offs / 64) * 8;

      // Low word: bits shifted past bit 63 simply fall off and are handled below.
      let low = load_word(data, idx);
      store_word(data, idx, (low & !(mask << shift)) | (val << shift));

      if shift + width > 64 {
          // Straddling implies `shift >= 1`, so `fits` is a valid shift in 1..=63; it is the
          // number of field bits that already landed in the low word.
          let fits = 64 - shift;
          let high = load_word(data, idx + 8);
          store_word(data, idx + 8, (high & !(mask >> fits)) | (val >> fits));
      }
  }
  #[cfg(test)]
  mod test {
      use super::*;

      /// Parsing the reference configuration must produce the same `Params` as calling
      /// `Params::init` directly with the equivalent arguments.
      #[test]
      fn params_from_json_correct() {
          // The fixture uses single quotes; they are swapped for double quotes to obtain
          // valid JSON before parsing.
          let json = r#"
{'n': 2,
'nu_1': 9,
'nu_2': 6,
'p': 256,
'q_prime_bits': 20,
's_e': 87.62938774292914,
't_GSW': 8,
't_conv': 4,
't_exp': 8,
't_exp_right': 56,
'instances': 1,
'db_item_size': 2048 }
"#
          .replace("'", "\"");
          let parsed = params_from_json(&json);

          let moduli = vec![268369921u64, 249561089u64];
          let expected = Params::init(
              2048, &moduli, 6.4, 2, 256, 20, 4, 8, 56, 8, true, 9, 6, 1, 2048,
          );

          assert_eq!(parsed, expected);
      }

      /// Packs a run of 9-bit values back to back, then reads every one of them again.
      #[test]
      fn test_read_write_arbitrary_bits() {
          const BUF_BYTES: usize = 4096;
          const FIELD_BITS: usize = 9;

          // Stop 64 fields short of the number that would fit in the buffer.
          let field_count = BUF_BYTES * 8 / FIELD_BITS - 64;
          // Deterministic value stored in field `i`; always fits in `FIELD_BITS` bits.
          let expected_at = |i: usize| -> u64 { ((i * 7 + 13) % (1 << FIELD_BITS)) as u64 };

          let mut buf = vec![0u8; BUF_BYTES];
          for i in 0..field_count {
              write_arbitrary_bits(&mut buf, expected_at(i), i * FIELD_BITS, FIELD_BITS);
          }

          for i in 0..field_count {
              let got = read_arbitrary_bits(&buf, i * FIELD_BITS, FIELD_BITS);
              assert_eq!(got, expected_at(i));
          }
      }
  }