arith.rs 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521
  1. use crate::params::*;
  2. use std::mem;
  3. use std::slice;
  4. pub fn multiply_uint_mod(a: u64, b: u64, modulus: u64) -> u64 {
  5. (((a as u128) * (b as u128)) % (modulus as u128)) as u64
  6. }
  7. pub const fn log2(a: u64) -> u64 {
  8. std::mem::size_of::<u64>() as u64 * 8 - a.leading_zeros() as u64 - 1
  9. }
  10. pub fn log2_ceil(a: u64) -> u64 {
  11. f64::ceil(f64::log2(a as f64)) as u64
  12. }
  13. pub fn log2_ceil_usize(a: usize) -> usize {
  14. f64::ceil(f64::log2(a as f64)) as usize
  15. }
  16. pub fn multiply_modular(params: &Params, a: u64, b: u64, c: usize) -> u64 {
  17. barrett_coeff_u64(params, a * b, c)
  18. }
  19. pub fn multiply_add_modular(params: &Params, a: u64, b: u64, x: u64, c: usize) -> u64 {
  20. barrett_coeff_u64(params, a * b + x, c)
  21. }
  22. pub fn add_modular(params: &Params, a: u64, b: u64, c: usize) -> u64 {
  23. barrett_coeff_u64(params, a + b, c)
  24. }
  25. pub fn invert_modular(params: &Params, a: u64, c: usize) -> u64 {
  26. params.moduli[c] - a
  27. }
  28. pub fn modular_reduce(params: &Params, x: u64, c: usize) -> u64 {
  29. barrett_coeff_u64(params, x, c)
  30. }
  31. pub fn exponentiate_uint_mod(operand: u64, mut exponent: u64, modulus: u64) -> u64 {
  32. if exponent == 0 {
  33. return 1;
  34. }
  35. if exponent == 1 {
  36. return operand;
  37. }
  38. let mut power = operand;
  39. let mut product;
  40. let mut intermediate = 1u64;
  41. loop {
  42. if (exponent % 2) == 1 {
  43. product = multiply_uint_mod(power, intermediate, modulus);
  44. mem::swap(&mut product, &mut intermediate);
  45. }
  46. exponent >>= 1;
  47. if exponent == 0 {
  48. break;
  49. }
  50. product = multiply_uint_mod(power, power, modulus);
  51. mem::swap(&mut product, &mut power);
  52. }
  53. intermediate
  54. }
  55. pub fn reverse_bits(x: u64, bit_count: usize) -> u64 {
  56. if bit_count == 0 {
  57. return 0;
  58. }
  59. let r = x.reverse_bits();
  60. r >> (mem::size_of::<u64>() * 8 - bit_count)
  61. }
  62. pub fn div2_uint_mod(operand: u64, modulus: u64) -> u64 {
  63. if operand & 1 == 1 {
  64. let res = operand.overflowing_add(modulus);
  65. if res.1 {
  66. return (res.0 >> 1) | (1u64 << 63);
  67. } else {
  68. return res.0 >> 1;
  69. }
  70. } else {
  71. return operand >> 1;
  72. }
  73. }
  74. pub fn recenter(val: u64, from_modulus: u64, to_modulus: u64) -> u64 {
  75. assert!(from_modulus >= to_modulus);
  76. let from_modulus_i64 = from_modulus as i64;
  77. let to_modulus_i64 = to_modulus as i64;
  78. let mut a_val = val as i64;
  79. if val >= from_modulus / 2 {
  80. a_val -= from_modulus_i64;
  81. }
  82. a_val = a_val + (from_modulus_i64 / to_modulus_i64) * to_modulus_i64 + 2 * to_modulus_i64;
  83. a_val %= to_modulus_i64;
  84. a_val as u64
  85. }
  86. pub fn get_barrett_crs(modulus: u64) -> (u64, u64) {
  87. let numerator = [0, 0, 1];
  88. let (_, quotient) = divide_uint192_inplace(numerator, modulus);
  89. (quotient[0], quotient[1])
  90. }
  91. pub fn get_barrett(moduli: &[u64]) -> ([u64; MAX_MODULI], [u64; MAX_MODULI]) {
  92. let mut cr0 = [0u64; MAX_MODULI];
  93. let mut cr1 = [0u64; MAX_MODULI];
  94. for i in 0..moduli.len() {
  95. (cr0[i], cr1[i]) = get_barrett_crs(moduli[i]);
  96. }
  97. (cr0, cr1)
  98. }
  99. pub fn barrett_raw_u64(input: u64, const_ratio_1: u64, modulus: u64) -> u64 {
  100. let tmp = (((input as u128) * (const_ratio_1 as u128)) >> 64) as u64;
  101. // Barrett subtraction
  102. let res = input - tmp * modulus;
  103. // One more subtraction is enough
  104. if res >= modulus {
  105. res - modulus
  106. } else {
  107. res
  108. }
  109. }
  110. pub fn barrett_u64(params: &Params, val: u64) -> u64 {
  111. barrett_raw_u64(val, params.barrett_cr_1_modulus, params.modulus)
  112. }
  113. pub fn barrett_coeff_u64(params: &Params, val: u64, n: usize) -> u64 {
  114. barrett_raw_u64(val, params.barrett_cr_1[n], params.moduli[n])
  115. }
  116. fn split(x: u128) -> (u64, u64) {
  117. let lo = x & ((1u128 << 64) - 1);
  118. let hi = x >> 64;
  119. (lo as u64, hi as u64)
  120. }
  121. fn mul_u128(a: u64, b: u64) -> (u64, u64) {
  122. let prod = (a as u128) * (b as u128);
  123. split(prod)
  124. }
  125. fn add_u64(op1: u64, op2: u64, out: &mut u64) -> u64 {
  126. match op1.checked_add(op2) {
  127. Some(x) => {
  128. *out = x;
  129. 0
  130. }
  131. None => 1,
  132. }
  133. }
  134. fn barrett_raw_u128(val: u128, cr0: u64, cr1: u64, modulus: u64) -> u64 {
  135. let (zx, zy) = split(val);
  136. let mut tmp1 = 0;
  137. let mut tmp3;
  138. let mut carry;
  139. let (_, prody) = mul_u128(zx, cr0);
  140. carry = prody;
  141. let (mut tmp2x, mut tmp2y) = mul_u128(zx, cr1);
  142. tmp3 = tmp2y + add_u64(tmp2x, carry, &mut tmp1);
  143. (tmp2x, tmp2y) = mul_u128(zy, cr0);
  144. carry = tmp2y + add_u64(tmp1, tmp2x, &mut tmp1);
  145. tmp1 = zy * cr1 + tmp3 + carry;
  146. tmp3 = zx.wrapping_sub(tmp1.wrapping_mul(modulus));
  147. tmp3
  148. // uint64_t zx = val & (((__uint128_t)1 << 64) - 1);
  149. // uint64_t zy = val >> 64;
  150. // uint64_t tmp1, tmp3, carry;
  151. // ulonglong2_h prod = umul64wide(zx, const_ratio_0);
  152. // carry = prod.y;
  153. // ulonglong2_h tmp2 = umul64wide(zx, const_ratio_1);
  154. // tmp3 = tmp2.y + cpu_add_u64(tmp2.x, carry, &tmp1);
  155. // tmp2 = umul64wide(zy, const_ratio_0);
  156. // carry = tmp2.y + cpu_add_u64(tmp1, tmp2.x, &tmp1);
  157. // tmp1 = zy * const_ratio_1 + tmp3 + carry;
  158. // tmp3 = zx - tmp1 * modulus;
  159. // return tmp3;
  160. }
  161. fn barrett_reduction_u128_raw(modulus: u64, cr0: u64, cr1: u64, val: u128) -> u64 {
  162. let mut reduced_val = barrett_raw_u128(val, cr0, cr1, modulus);
  163. reduced_val -= (modulus) * ((reduced_val >= modulus) as u64);
  164. reduced_val
  165. }
  166. pub fn barrett_reduction_u128(params: &Params, val: u128) -> u64 {
  167. let modulus = params.modulus;
  168. let cr0 = params.barrett_cr_0_modulus;
  169. let cr1 = params.barrett_cr_1_modulus;
  170. barrett_reduction_u128_raw(modulus, cr0, cr1, val)
  171. }
  172. // Following code is ported from SEAL (github.com/microsoft/SEAL)
  173. pub fn get_significant_bit_count(val: &[u64]) -> usize {
  174. for i in (0..val.len()).rev() {
  175. for j in (0..64).rev() {
  176. if (val[i] & (1u64 << j)) != 0 {
  177. return i * 64 + j + 1;
  178. }
  179. }
  180. }
  181. 0
  182. }
  183. fn divide_round_up(num: usize, denom: usize) -> usize {
  184. (num + (denom - 1)) / denom
  185. }
  186. const BITS_PER_U64: usize = u64::BITS as usize;
  187. fn left_shift_uint192(operand: [u64; 3], shift_amount: usize) -> [u64; 3] {
  188. let mut result = [0u64; 3];
  189. if (shift_amount & (BITS_PER_U64 << 1)) != 0 {
  190. result[2] = operand[0];
  191. result[1] = 0;
  192. result[0] = 0;
  193. } else if (shift_amount & BITS_PER_U64) != 0 {
  194. result[2] = operand[1];
  195. result[1] = operand[0];
  196. result[0] = 0;
  197. } else {
  198. result[2] = operand[2];
  199. result[1] = operand[1];
  200. result[0] = operand[0];
  201. }
  202. let bit_shift_amount = shift_amount & (BITS_PER_U64 - 1);
  203. if bit_shift_amount != 0 {
  204. let neg_bit_shift_amount = BITS_PER_U64 - bit_shift_amount;
  205. result[2] = (result[2] << bit_shift_amount) | (result[1] >> neg_bit_shift_amount);
  206. result[1] = (result[1] << bit_shift_amount) | (result[0] >> neg_bit_shift_amount);
  207. result[0] = result[0] << bit_shift_amount;
  208. }
  209. result
  210. }
  211. fn right_shift_uint192(operand: [u64; 3], shift_amount: usize) -> [u64; 3] {
  212. let mut result = [0u64; 3];
  213. if (shift_amount & (BITS_PER_U64 << 1)) != 0 {
  214. result[0] = operand[2];
  215. result[1] = 0;
  216. result[2] = 0;
  217. } else if (shift_amount & BITS_PER_U64) != 0 {
  218. result[0] = operand[1];
  219. result[1] = operand[2];
  220. result[2] = 0;
  221. } else {
  222. result[2] = operand[2];
  223. result[1] = operand[1];
  224. result[0] = operand[0];
  225. }
  226. let bit_shift_amount = shift_amount & (BITS_PER_U64 - 1);
  227. if bit_shift_amount != 0 {
  228. let neg_bit_shift_amount = BITS_PER_U64 - bit_shift_amount;
  229. result[0] = (result[0] >> bit_shift_amount) | (result[1] << neg_bit_shift_amount);
  230. result[1] = (result[1] >> bit_shift_amount) | (result[2] << neg_bit_shift_amount);
  231. result[2] = result[2] >> bit_shift_amount;
  232. }
  233. result
  234. }
  235. fn add_uint64(operand1: u64, operand2: u64, result: &mut u64) -> u8 {
  236. *result = operand1.wrapping_add(operand2);
  237. (*result < operand1) as u8
  238. }
  239. fn add_uint64_carry(operand1: u64, operand2: u64, carry: u8, result: &mut u64) -> u8 {
  240. let operand1 = operand1.wrapping_add(operand2);
  241. *result = operand1.wrapping_add(carry as u64);
  242. ((operand1 < operand2) || (!operand1 < (carry as u64))) as u8
  243. }
  244. fn sub_uint64(operand1: u64, operand2: u64, result: &mut u64) -> u8 {
  245. *result = operand1.wrapping_sub(operand2);
  246. (operand2 > operand1) as u8
  247. }
  248. fn sub_uint64_borrow(operand1: u64, operand2: u64, borrow: u8, result: &mut u64) -> u8 {
  249. let diff = operand1.wrapping_sub(operand2);
  250. *result = diff.wrapping_sub((borrow != 0) as u64);
  251. ((diff > operand1) || (diff < (borrow as u64))) as u8
  252. }
  253. pub fn sub_uint(operand1: &[u64], operand2: &[u64], uint64_count: usize, result: &mut [u64]) -> u8 {
  254. let mut borrow = sub_uint64(operand1[0], operand2[0], &mut result[0]);
  255. for i in 0..uint64_count - 1 {
  256. let mut temp_result = 0u64;
  257. borrow = sub_uint64_borrow(operand1[1 + i], operand2[1 + i], borrow, &mut temp_result);
  258. result[1 + i] = temp_result;
  259. }
  260. borrow
  261. }
  262. pub fn add_uint(operand1: &[u64], operand2: &[u64], uint64_count: usize, result: &mut [u64]) -> u8 {
  263. let mut carry = add_uint64(operand1[0], operand2[0], &mut result[0]);
  264. for i in 0..uint64_count - 1 {
  265. let mut temp_result = 0u64;
  266. carry = add_uint64_carry(operand1[1 + i], operand2[1 + i], carry, &mut temp_result);
  267. result[1 + i] = temp_result;
  268. }
  269. carry
  270. }
  271. pub fn divide_uint192_inplace(mut numerator: [u64; 3], denominator: u64) -> ([u64; 3], [u64; 3]) {
  272. let mut numerator_bits = get_significant_bit_count(&numerator);
  273. let mut denominator_bits = get_significant_bit_count(slice::from_ref(&denominator));
  274. let mut quotient = [0u64; 3];
  275. if numerator_bits < denominator_bits {
  276. return (numerator, quotient);
  277. }
  278. let uint64_count = divide_round_up(numerator_bits, BITS_PER_U64);
  279. if uint64_count == 1 {
  280. quotient[0] = numerator[0] / denominator;
  281. numerator[0] -= quotient[0] * denominator;
  282. return (numerator, quotient);
  283. }
  284. let mut shifted_denominator = [0u64; 3];
  285. shifted_denominator[0] = denominator;
  286. let mut difference = [0u64; 3];
  287. let denominator_shift = numerator_bits - denominator_bits;
  288. let shifted_denominator = left_shift_uint192(shifted_denominator, denominator_shift);
  289. denominator_bits += denominator_shift;
  290. let mut remaining_shifts = denominator_shift;
  291. while numerator_bits == denominator_bits {
  292. if (sub_uint(
  293. &numerator,
  294. &shifted_denominator,
  295. uint64_count,
  296. &mut difference,
  297. )) != 0
  298. {
  299. if remaining_shifts == 0 {
  300. break;
  301. }
  302. add_uint(
  303. &difference.clone(),
  304. &numerator,
  305. uint64_count,
  306. &mut difference,
  307. );
  308. quotient = left_shift_uint192(quotient, 1);
  309. remaining_shifts -= 1;
  310. }
  311. quotient[0] |= 1;
  312. numerator_bits = get_significant_bit_count(&difference);
  313. let mut numerator_shift = denominator_bits - numerator_bits;
  314. if numerator_shift > remaining_shifts {
  315. numerator_shift = remaining_shifts;
  316. }
  317. if numerator_bits > 0 {
  318. numerator = left_shift_uint192(difference, numerator_shift);
  319. numerator_bits += numerator_shift;
  320. } else {
  321. for w in 0..uint64_count {
  322. numerator[w] = 0;
  323. }
  324. }
  325. quotient = left_shift_uint192(quotient, numerator_shift);
  326. remaining_shifts -= numerator_shift;
  327. }
  328. if numerator_bits > 0 {
  329. numerator = right_shift_uint192(numerator, denominator_shift);
  330. }
  331. (numerator, quotient)
  332. }
  333. pub fn recenter_mod(val: u64, small_modulus: u64, large_modulus: u64) -> u64 {
  334. assert!(val < small_modulus);
  335. let mut val_i64 = val as i64;
  336. let small_modulus_i64 = small_modulus as i64;
  337. let large_modulus_i64 = large_modulus as i64;
  338. if val_i64 > small_modulus_i64 / 2 {
  339. val_i64 -= small_modulus_i64;
  340. }
  341. if val_i64 < 0 {
  342. val_i64 += large_modulus_i64;
  343. }
  344. val_i64 as u64
  345. }
  346. pub fn rescale(a: u64, inp_mod: u64, out_mod: u64) -> u64 {
  347. let inp_mod_i64 = inp_mod as i64;
  348. let out_mod_i128 = out_mod as i128;
  349. let mut inp_val = (a % inp_mod) as i64;
  350. if inp_val >= (inp_mod_i64 / 2) {
  351. inp_val -= inp_mod_i64;
  352. }
  353. let sign: i64 = if inp_val >= 0 { 1 } else { -1 };
  354. let val = (inp_val as i128) * (out_mod as i128);
  355. let mut result = (val + (sign * (inp_mod_i64 / 2)) as i128) / (inp_mod as i128);
  356. result = (result + ((inp_mod / out_mod) * out_mod) as i128 + (2 * out_mod_i128)) % out_mod_i128;
  357. assert!(result >= 0);
  358. ((result + out_mod_i128) % out_mod_i128) as u64
  359. }
  360. #[cfg(test)]
  361. mod test {
  362. use super::*;
  363. use crate::util::get_seeded_rng;
  364. use rand::Rng;
  365. fn combine(lo: u64, hi: u64) -> u128 {
  366. (lo as u128) & ((hi as u128) << 64)
  367. }
  368. #[test]
  369. fn div2_uint_mod_correct() {
  370. assert_eq!(div2_uint_mod(3, 7), 5);
  371. }
  372. #[test]
  373. fn divide_uint192_inplace_correct() {
  374. assert_eq!(
  375. divide_uint192_inplace([35, 0, 0], 7),
  376. ([0, 0, 0], [5, 0, 0])
  377. );
  378. assert_eq!(
  379. divide_uint192_inplace([0x10101010, 0x2B2B2B2B, 0xF1F1F1F1], 0x1000),
  380. (
  381. [0x10, 0, 0],
  382. [0xB2B0000000010101, 0x1F1000000002B2B2, 0xF1F1F]
  383. )
  384. );
  385. }
  386. #[test]
  387. fn get_barrett_crs_correct() {
  388. assert_eq!(
  389. get_barrett_crs(268369921u64),
  390. (16144578669088582089u64, 68736257792u64)
  391. );
  392. assert_eq!(
  393. get_barrett_crs(249561089u64),
  394. (10966983149909726427u64, 73916747789u64)
  395. );
  396. assert_eq!(
  397. get_barrett_crs(66974689739603969u64),
  398. (7906011006380390721u64, 275u64)
  399. );
  400. }
  401. #[test]
  402. fn barrett_reduction_u128_raw_correct() {
  403. let modulus = 66974689739603969u64;
  404. let modulus_u128 = modulus as u128;
  405. let exec = |val| {
  406. barrett_reduction_u128_raw(66974689739603969u64, 7906011006380390721u64, 275u64, val)
  407. };
  408. assert_eq!(exec(modulus_u128), 0);
  409. assert_eq!(exec(modulus_u128 + 1), 1);
  410. assert_eq!(exec(modulus_u128 * 7 + 5), 5);
  411. let mut rng = get_seeded_rng();
  412. for _ in 0..100 {
  413. let val = combine(rng.gen(), rng.gen());
  414. assert_eq!(exec(val), (val % modulus_u128) as u64);
  415. }
  416. }
  417. #[test]
  418. fn barrett_raw_u64_correct() {
  419. let modulus = 66974689739603969u64;
  420. let cr1 = 275u64;
  421. let mut rng = get_seeded_rng();
  422. for _ in 0..100 {
  423. let val = rng.gen();
  424. assert_eq!(barrett_raw_u64(val, cr1, modulus), val % modulus);
  425. }
  426. }
  427. }