arith.rs 2.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113
  1. use crate::params::*;
  2. use std::mem;
  /// Returns `(a * b) mod modulus`.
  ///
  /// Both operands are widened to `u128` before multiplying, so the product
  /// cannot overflow for any pair of `u64` inputs. The reduced value is always
  /// below `modulus`, so narrowing it back to `u64` is lossless.
  ///
  /// # Panics
  /// Panics if `modulus` is zero.
  pub fn multiply_uint_mod(a: u64, b: u64, modulus: u64) -> u64 {
      let wide_product = u128::from(a) * u128::from(b);
      let reduced = wide_product % u128::from(modulus);
      reduced as u64
  }
  /// Returns `floor(log2(a))`, i.e. the index of the highest set bit of `a`.
  ///
  /// This is a `const fn`, so it can be used to compute constants.
  ///
  /// # Panics
  /// Panics if `a` is zero. The check is explicit because the bit arithmetic
  /// would otherwise underflow. That underflow panics in debug builds but
  /// silently yields `u64::MAX` in release builds.
  pub const fn log2(a: u64) -> u64 {
      assert!(a != 0, "log2 of zero is undefined");
      (u64::BITS - 1 - a.leading_zeros()) as u64
  }
  /// Returns `ceil(log2(a))`, the smallest `k` such that `2^k >= a`.
  ///
  /// The result is computed with integer bit operations, so it is exact across
  /// the whole `u64` range. A floating-point `log2` rounds inputs above `2^53`
  /// and can be off by one there: for example, `2^53 + 1` converts to `2^53`.
  /// Inputs `0` and `1` both return `0`.
  pub fn log2_ceil(a: u64) -> u64 {
      if a <= 1 {
          return 0;
      }
      // For a >= 2, the bit length of (a - 1) equals ceil(log2(a)):
      // a == 2^k gives a - 1 with k bits, and 2^k < a <= 2^(k+1) gives k + 1 bits.
      (u64::BITS - (a - 1).leading_zeros()) as u64
  }
  /// Returns `ceil(log2(a))` for a `usize`, the smallest `k` such that `2^k >= a`.
  ///
  /// The result is computed with integer bit operations, so it is exact for
  /// every `usize`. Going through `f64` would round inputs above `2^53` on
  /// 64-bit targets. Inputs `0` and `1` both return `0`.
  pub fn log2_ceil_usize(a: usize) -> usize {
      if a <= 1 {
          return 0;
      }
      // For a >= 2, the bit length of (a - 1) equals ceil(log2(a)).
      (usize::BITS - (a - 1).leading_zeros()) as usize
  }
  15. pub fn multiply_modular(params: &Params, a: u64, b: u64, c: usize) -> u64 {
  16. (a * b) % params.moduli[c]
  17. }
  18. pub fn multiply_add_modular(params: &Params, a: u64, b: u64, x: u64, c: usize) -> u64 {
  19. (a * b + x) % params.moduli[c]
  20. }
  21. pub fn add_modular(params: &Params, a: u64, b: u64, c: usize) -> u64 {
  22. (a + b) % params.moduli[c]
  23. }
  24. pub fn invert_modular(params: &Params, a: u64, c: usize) -> u64 {
  25. params.moduli[c] - a
  26. }
  27. pub fn modular_reduce(params: &Params, x: u64, c: usize) -> u64 {
  28. (x) % params.moduli[c]
  29. }
  30. pub fn exponentiate_uint_mod(operand: u64, mut exponent: u64, modulus: u64) -> u64 {
  31. if exponent == 0 {
  32. return 1;
  33. }
  34. if exponent == 1 {
  35. return operand;
  36. }
  37. let mut power = operand;
  38. let mut product;
  39. let mut intermediate = 1u64;
  40. loop {
  41. if (exponent % 2) == 1 {
  42. product = multiply_uint_mod(power, intermediate, modulus);
  43. mem::swap(&mut product, &mut intermediate);
  44. }
  45. exponent >>= 1;
  46. if exponent == 0 {
  47. break;
  48. }
  49. product = multiply_uint_mod(power, power, modulus);
  50. mem::swap(&mut product, &mut power);
  51. }
  52. intermediate
  53. }
  /// Reverses the order of the low `bit_count` bits of `x`.
  ///
  /// Bits of `x` at positions `>= bit_count` are discarded, so the result has
  /// at most `bit_count` significant bits. A `bit_count` of zero yields `0`.
  /// A `bit_count` above 64 is not supported: the shift-amount subtraction
  /// underflows.
  pub fn reverse_bits(x: u64, bit_count: usize) -> u64 {
      let width = 8 * mem::size_of::<u64>();
      match bit_count {
          0 => 0,
          // Reverse all 64 bits, then shift the reversed low bits down into place.
          n => x.reverse_bits() >> (width - n),
      }
  }
  /// Halves `operand` modulo `modulus`.
  ///
  /// Even operands are simply shifted right. For an odd operand, `modulus` is
  /// added first; with an odd `modulus` this makes the sum even, so the shift
  /// is exact. Any carry out of that 64-bit addition is restored as the top
  /// bit after shifting. As a result, the outcome is correct even when
  /// `operand + modulus` exceeds `u64::MAX`.
  ///
  /// For odd `modulus` and `operand < modulus`, this computes
  /// `operand * 2^-1 mod modulus`. Neither condition is checked here.
  pub fn div2_uint_mod(operand: u64, modulus: u64) -> u64 {
      if operand & 1 == 0 {
          return operand >> 1;
      }
      let (sum, carried) = operand.overflowing_add(modulus);
      let carry_bit = if carried { 1u64 << 63 } else { 0 };
      (sum >> 1) | carry_bit
  }
  /// Maps `val`, a residue modulo `from_modulus`, to a residue modulo
  /// `to_modulus` by way of its centered (signed) representative.
  ///
  /// Values `>= from_modulus / 2` are treated as negative
  /// (`val - from_modulus`). The resulting signed integer is then reduced into
  /// `[0, to_modulus)`.
  ///
  /// The arithmetic is done in `i128`, so every `u64` modulus is handled without
  /// overflow. The previous i64 formulation added `(from/to)*to + 2*to` before
  /// reducing, which overflows once the moduli approach `2^61`.
  ///
  /// # Panics
  /// Panics if `from_modulus < to_modulus`, or if `to_modulus` is zero.
  pub fn recenter(val: u64, from_modulus: u64, to_modulus: u64) -> u64 {
      assert!(from_modulus >= to_modulus);
      let from = i128::from(from_modulus);
      let to = i128::from(to_modulus);
      let mut centered = i128::from(val);
      if val >= from_modulus / 2 {
          centered -= from;
      }
      // rem_euclid always lands in [0, to), which fits in u64.
      centered.rem_euclid(to) as u64
  }
  #[cfg(test)]
  mod test {
      use super::*;

      #[test]
      fn div2_uint_mod_correct() {
          assert_eq!(div2_uint_mod(3, 7), 5);
          // Even operand: plain halving.
          assert_eq!(div2_uint_mod(4, 7), 2);
          // Odd operand whose sum with the modulus carries out of 64 bits.
          assert_eq!(div2_uint_mod(1, u64::MAX), 1u64 << 63);
      }

      #[test]
      fn multiply_uint_mod_uses_wide_product() {
          // 2^80 overflows u64; 2^80 mod 1000 = 176.
          assert_eq!(multiply_uint_mod(1 << 40, 1 << 40, 1000), 176);
      }

      #[test]
      fn exponentiate_uint_mod_correct() {
          assert_eq!(exponentiate_uint_mod(3, 0, 7), 1);
          assert_eq!(exponentiate_uint_mod(3, 1, 7), 3);
          assert_eq!(exponentiate_uint_mod(3, 5, 7), 5);
          assert_eq!(exponentiate_uint_mod(2, 10, 1000), 24);
      }

      #[test]
      fn reverse_bits_correct() {
          assert_eq!(reverse_bits(0b0011, 4), 0b1100);
          assert_eq!(reverse_bits(5, 0), 0);
          assert_eq!(reverse_bits(1, 64), 1u64 << 63);
      }

      #[test]
      fn log2_helpers_correct() {
          assert_eq!(log2(1), 0);
          assert_eq!(log2(1024), 10);
          assert_eq!(log2_ceil(1000), 10);
          assert_eq!(log2_ceil_usize(5), 3);
      }

      #[test]
      fn recenter_correct() {
          assert_eq!(recenter(2, 7, 5), 2);
          // 6 is in the upper half of Z_7, so it is treated as -1, giving 4 mod 5.
          assert_eq!(recenter(6, 7, 5), 4);
      }
  }