arith.rs 1.8 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182
  1. use crate::params::*;
  2. use std::mem;
  3. pub fn multiply_uint_mod(a: u64, b: u64, modulus: u64) -> u64 {
  4. (((a as u128) * (b as u128)) % (modulus as u128)) as u64
  5. }
  6. pub const fn log2(a: u64) -> u64 {
  7. std::mem::size_of::<u64>() as u64 * 8 - a.leading_zeros() as u64 - 1
  8. }
  9. pub fn multiply_modular(params: &Params, a: u64, b: u64, c: usize) -> u64 {
  10. (a * b) % params.moduli[c]
  11. }
  12. pub fn multiply_add_modular(params: &Params, a: u64, b: u64, x: u64, c: usize) -> u64 {
  13. (a * b + x) % params.moduli[c]
  14. }
  15. fn swap(a: u64, b: u64) -> (u64, u64) {
  16. (b, a)
  17. }
  18. pub fn exponentiate_uint_mod(operand: u64, mut exponent: u64, modulus: u64) -> u64 {
  19. if exponent == 0 {
  20. return 1;
  21. }
  22. if exponent == 1 {
  23. return operand;
  24. }
  25. let mut power = operand;
  26. let mut product;
  27. let mut intermediate = 1u64;
  28. loop {
  29. if (exponent % 2) == 1 {
  30. product = multiply_uint_mod(power, intermediate, modulus);
  31. mem::swap(&mut product, &mut intermediate);
  32. }
  33. exponent >>= 1;
  34. if exponent == 0 {
  35. break;
  36. }
  37. product = multiply_uint_mod(power, power, modulus);
  38. mem::swap(&mut product, &mut power);
  39. }
  40. intermediate
  41. }
  42. pub fn reverse_bits(x: u64, bit_count: usize) -> u64 {
  43. if bit_count == 0 {
  44. return 0;
  45. }
  46. let r = x.reverse_bits();
  47. r >> (mem::size_of::<u64>() * 8 - bit_count)
  48. }
  49. pub fn div2_uint_mod(operand: u64, modulus: u64) -> u64 {
  50. if operand & 1 == 1 {
  51. let res = operand.overflowing_add(modulus);
  52. if res.1 {
  53. return (res.0 >> 1) | (1u64 << 63);
  54. } else {
  55. return res.0 >> 1;
  56. }
  57. } else {
  58. return operand >> 1;
  59. }
  60. }
  61. #[cfg(test)]
  62. mod test {
  63. use super::*;
  64. #[test]
  65. fn div2_uint_mod_correct() {
  66. assert_eq!(div2_uint_mod(3, 7), 5);
  67. }
  68. }