arith.rs 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517
  1. use crate::params::*;
  2. use std::mem;
  3. use std::slice;
  /// Computes `(a * b) mod modulus` without intermediate overflow.
  ///
  /// Both factors are widened to `u128` before multiplying, so the full
  /// 128-bit product is what gets reduced.
  ///
  /// # Panics
  ///
  /// Panics if `modulus` is zero (remainder by zero).
  pub fn multiply_uint_mod(a: u64, b: u64, modulus: u64) -> u64 {
      let wide_product = u128::from(a) * u128::from(b);
      let remainder = wide_product % u128::from(modulus);
      remainder as u64
  }
  /// Returns the index of the highest set bit of `a`, i.e. `floor(log2(a))`.
  ///
  /// For `a == 0` the subtraction underflows, which panics when overflow
  /// checks are enabled.
  pub const fn log2(a: u64) -> u64 {
      let total_bits = u64::BITS as u64;
      total_bits - a.leading_zeros() as u64 - 1
  }
  /// Returns `ceil(log2(a))`: the smallest `k` such that `2^k >= a`.
  ///
  /// The result is computed with integer bit operations so that it is exact
  /// over the whole `u64` range; a floating-point logarithm cannot represent
  /// every input above 2^53 and rounds values just above a power of two down.
  ///
  /// Inputs `0` and `1` both yield `0`.
  pub fn log2_ceil(a: u64) -> u64 {
      if a <= 1 {
          return 0;
      }
      // For a >= 2, ceil(log2(a)) equals the bit length of (a - 1).
      u64::from(u64::BITS - (a - 1).leading_zeros())
  }
  /// Returns `ceil(log2(a))` for a `usize`: the smallest `k` with `2^k >= a`.
  ///
  /// Uses integer bit operations so the answer is exact for every input; a
  /// floating-point logarithm rounds large values just above a power of two
  /// down to that power.
  ///
  /// Inputs `0` and `1` both yield `0`.
  pub fn log2_ceil_usize(a: usize) -> usize {
      if a <= 1 {
          return 0;
      }
      // For a >= 2, ceil(log2(a)) equals the bit length of (a - 1).
      (usize::BITS - (a - 1).leading_zeros()) as usize
  }
  16. pub fn multiply_modular(params: &Params, a: u64, b: u64, c: usize) -> u64 {
  17. barrett_coeff_u64(params, a * b, c)
  18. }
  19. pub fn multiply_add_modular(params: &Params, a: u64, b: u64, x: u64, c: usize) -> u64 {
  20. barrett_coeff_u64(params, a * b + x, c)
  21. }
  22. pub fn add_modular(params: &Params, a: u64, b: u64, c: usize) -> u64 {
  23. barrett_coeff_u64(params, a + b, c)
  24. }
  25. pub fn invert_modular(params: &Params, a: u64, c: usize) -> u64 {
  26. params.moduli[c] - a
  27. }
  28. pub fn modular_reduce(params: &Params, x: u64, c: usize) -> u64 {
  29. barrett_coeff_u64(params, x, c)
  30. }
  31. pub fn exponentiate_uint_mod(operand: u64, mut exponent: u64, modulus: u64) -> u64 {
  32. if exponent == 0 {
  33. return 1;
  34. }
  35. if exponent == 1 {
  36. return operand;
  37. }
  38. let mut power = operand;
  39. let mut product;
  40. let mut intermediate = 1u64;
  41. loop {
  42. if (exponent % 2) == 1 {
  43. product = multiply_uint_mod(power, intermediate, modulus);
  44. mem::swap(&mut product, &mut intermediate);
  45. }
  46. exponent >>= 1;
  47. if exponent == 0 {
  48. break;
  49. }
  50. product = multiply_uint_mod(power, power, modulus);
  51. mem::swap(&mut product, &mut power);
  52. }
  53. intermediate
  54. }
  /// Reverses the lowest `bit_count` bits of `x`.
  ///
  /// The whole 64-bit word is reversed and then shifted down so that only the
  /// mirrored low `bit_count` bits remain. A `bit_count` of zero yields `0`.
  pub fn reverse_bits(x: u64, bit_count: usize) -> u64 {
      match bit_count {
          0 => 0,
          width => x.reverse_bits() >> (u64::BITS as usize - width),
      }
  }
  /// Halves `operand` modulo an odd `modulus`.
  ///
  /// Even operands are simply shifted right. For odd operands the modulus is
  /// added first to make the value even; if that addition carries out of
  /// 64 bits, the carry is re-inserted as the top bit after the shift.
  pub fn div2_uint_mod(operand: u64, modulus: u64) -> u64 {
      if operand & 1 == 0 {
          return operand >> 1;
      }
      let (sum, carried) = operand.overflowing_add(modulus);
      let halved = sum >> 1;
      if carried {
          halved | (1u64 << 63)
      } else {
          halved
      }
  }
  /// Maps `val`, a residue modulo `from_modulus`, to a residue modulo
  /// `to_modulus` via its centered (signed) representative.
  ///
  /// Values at or above `from_modulus / 2` are treated as negative. A multiple
  /// of `to_modulus` is added before the final `%` so the dividend is
  /// non-negative.
  ///
  /// # Panics
  ///
  /// Panics if `from_modulus < to_modulus`.
  pub fn recenter(val: u64, from_modulus: u64, to_modulus: u64) -> u64 {
      assert!(from_modulus >= to_modulus);
      let from = from_modulus as i64;
      let to = to_modulus as i64;

      let centered = if val >= from_modulus / 2 {
          val as i64 - from
      } else {
          val as i64
      };

      let lifted = centered + (from / to) * to + 2 * to;
      (lifted % to) as u64
  }
  86. pub fn get_barrett_crs(modulus: u64) -> (u64, u64) {
  87. let numerator = [0, 0, 1];
  88. let (_, quotient) = divide_uint192_inplace(numerator, modulus);
  89. (quotient[0], quotient[1])
  90. }
  91. pub fn get_barrett(moduli: &[u64]) -> ([u64; MAX_MODULI], [u64; MAX_MODULI]) {
  92. let mut cr0 = [0u64; MAX_MODULI];
  93. let mut cr1 = [0u64; MAX_MODULI];
  94. for i in 0..moduli.len() {
  95. (cr0[i], cr1[i]) = get_barrett_crs(moduli[i]);
  96. }
  97. (cr0, cr1)
  98. }
  /// Barrett-reduces a 64-bit `input` modulo `modulus`.
  ///
  /// `const_ratio_1` is the precomputed ratio word for `modulus`; the file's
  /// own test uses `floor(2^64 / modulus)`. The quotient is estimated from the
  /// high half of `input * const_ratio_1`, and a single conditional
  /// subtraction corrects the remainder.
  pub fn barrett_raw_u64(input: u64, const_ratio_1: u64, modulus: u64) -> u64 {
      let wide = u128::from(input) * u128::from(const_ratio_1);
      let quotient_estimate = (wide >> 64) as u64;

      // Barrett subtraction.
      let remainder = input - quotient_estimate * modulus;

      // One more subtraction is enough.
      if remainder < modulus {
          remainder
      } else {
          remainder - modulus
      }
  }
  110. pub fn barrett_coeff_u64(params: &Params, val: u64, n: usize) -> u64 {
  111. barrett_raw_u64(val, params.barrett_cr_1[n], params.moduli[n])
  112. }
  /// Splits a 128-bit value into its `(low, high)` 64-bit halves.
  fn split(x: u128) -> (u64, u64) {
      // The truncating cast keeps exactly the low 64 bits.
      let low = x as u64;
      let high = (x >> 64) as u64;
      (low, high)
  }
  118. fn mul_u128(a: u64, b: u64) -> (u64, u64) {
  119. let prod = (a as u128) * (b as u128);
  120. split(prod)
  121. }
  /// Adds `op1` and `op2`, stores the wrapped 64-bit sum in `out`, and returns
  /// the carry (`1` if the addition overflowed, `0` otherwise).
  ///
  /// `out` is written on every call, including when the addition overflows,
  /// so callers can chain the low word into a following addition.
  fn add_u64(op1: u64, op2: u64, out: &mut u64) -> u64 {
      let (sum, overflowed) = op1.overflowing_add(op2);
      *out = sum;
      u64::from(overflowed)
  }
  131. fn barrett_raw_u128(val: u128, cr0: u64, cr1: u64, modulus: u64) -> u64 {
  132. let (zx, zy) = split(val);
  133. let mut tmp1 = 0;
  134. let mut tmp3;
  135. let mut carry;
  136. let (_, prody) = mul_u128(zx, cr0);
  137. carry = prody;
  138. let (mut tmp2x, mut tmp2y) = mul_u128(zx, cr1);
  139. tmp3 = tmp2y + add_u64(tmp2x, carry, &mut tmp1);
  140. (tmp2x, tmp2y) = mul_u128(zy, cr0);
  141. carry = tmp2y + add_u64(tmp1, tmp2x, &mut tmp1);
  142. tmp1 = zy * cr1 + tmp3 + carry;
  143. tmp3 = zx.wrapping_sub(tmp1.wrapping_mul(modulus));
  144. tmp3
  145. // uint64_t zx = val & (((__uint128_t)1 << 64) - 1);
  146. // uint64_t zy = val >> 64;
  147. // uint64_t tmp1, tmp3, carry;
  148. // ulonglong2_h prod = umul64wide(zx, const_ratio_0);
  149. // carry = prod.y;
  150. // ulonglong2_h tmp2 = umul64wide(zx, const_ratio_1);
  151. // tmp3 = tmp2.y + cpu_add_u64(tmp2.x, carry, &tmp1);
  152. // tmp2 = umul64wide(zy, const_ratio_0);
  153. // carry = tmp2.y + cpu_add_u64(tmp1, tmp2.x, &tmp1);
  154. // tmp1 = zy * const_ratio_1 + tmp3 + carry;
  155. // tmp3 = zx - tmp1 * modulus;
  156. // return tmp3;
  157. }
  158. fn barrett_reduction_u128_raw(modulus: u64, cr0: u64, cr1: u64, val: u128) -> u64 {
  159. let mut reduced_val = barrett_raw_u128(val, cr0, cr1, modulus);
  160. reduced_val -= (modulus) * ((reduced_val >= modulus) as u64);
  161. reduced_val
  162. }
  163. pub fn barrett_reduction_u128(params: &Params, val: u128) -> u64 {
  164. let modulus = params.modulus;
  165. let cr0 = params.barrett_cr_0_modulus;
  166. let cr1 = params.barrett_cr_1_modulus;
  167. barrett_reduction_u128_raw(modulus, cr0, cr1, val)
  168. }
  169. // Following code is ported from SEAL (github.com/microsoft/SEAL)
  /// Returns the number of significant bits of a little-endian multi-word
  /// unsigned integer: the 1-based position of its highest set bit, or `0`
  /// when every word is zero (or the slice is empty).
  pub fn get_significant_bit_count(val: &[u64]) -> usize {
      match val.iter().rposition(|&word| word != 0) {
          Some(top) => top * 64 + (64 - val[top].leading_zeros() as usize),
          None => 0,
      }
  }
  /// Divides `num` by `denom`, rounding the result up.
  fn divide_round_up(num: usize, denom: usize) -> usize {
      let bump = denom - 1;
      (num + bump) / denom
  }
  /// Number of bits in one `u64` word.
  const BITS_PER_U64: usize = mem::size_of::<u64>() * 8;
  184. fn left_shift_uint192(operand: [u64; 3], shift_amount: usize) -> [u64; 3] {
  185. let mut result = [0u64; 3];
  186. if (shift_amount & (BITS_PER_U64 << 1)) != 0 {
  187. result[2] = operand[0];
  188. result[1] = 0;
  189. result[0] = 0;
  190. } else if (shift_amount & BITS_PER_U64) != 0 {
  191. result[2] = operand[1];
  192. result[1] = operand[0];
  193. result[0] = 0;
  194. } else {
  195. result[2] = operand[2];
  196. result[1] = operand[1];
  197. result[0] = operand[0];
  198. }
  199. let bit_shift_amount = shift_amount & (BITS_PER_U64 - 1);
  200. if bit_shift_amount != 0 {
  201. let neg_bit_shift_amount = BITS_PER_U64 - bit_shift_amount;
  202. result[2] = (result[2] << bit_shift_amount) | (result[1] >> neg_bit_shift_amount);
  203. result[1] = (result[1] << bit_shift_amount) | (result[0] >> neg_bit_shift_amount);
  204. result[0] = result[0] << bit_shift_amount;
  205. }
  206. result
  207. }
  208. fn right_shift_uint192(operand: [u64; 3], shift_amount: usize) -> [u64; 3] {
  209. let mut result = [0u64; 3];
  210. if (shift_amount & (BITS_PER_U64 << 1)) != 0 {
  211. result[0] = operand[2];
  212. result[1] = 0;
  213. result[2] = 0;
  214. } else if (shift_amount & BITS_PER_U64) != 0 {
  215. result[0] = operand[1];
  216. result[1] = operand[2];
  217. result[2] = 0;
  218. } else {
  219. result[2] = operand[2];
  220. result[1] = operand[1];
  221. result[0] = operand[0];
  222. }
  223. let bit_shift_amount = shift_amount & (BITS_PER_U64 - 1);
  224. if bit_shift_amount != 0 {
  225. let neg_bit_shift_amount = BITS_PER_U64 - bit_shift_amount;
  226. result[0] = (result[0] >> bit_shift_amount) | (result[1] << neg_bit_shift_amount);
  227. result[1] = (result[1] >> bit_shift_amount) | (result[2] << neg_bit_shift_amount);
  228. result[2] = result[2] >> bit_shift_amount;
  229. }
  230. result
  231. }
  /// Adds two words, stores the wrapped sum in `result`, and returns the carry
  /// out (`1` on overflow, otherwise `0`).
  fn add_uint64(operand1: u64, operand2: u64, result: &mut u64) -> u8 {
      let (sum, carried) = operand1.overflowing_add(operand2);
      *result = sum;
      u8::from(carried)
  }
  /// Adds two words plus an incoming `carry`, stores the wrapped sum in
  /// `result`, and returns the carry out (`1` if either addition overflowed).
  fn add_uint64_carry(operand1: u64, operand2: u64, carry: u8, result: &mut u64) -> u8 {
      let (partial, first_overflow) = operand1.overflowing_add(operand2);
      let (total, second_overflow) = partial.overflowing_add(u64::from(carry));
      *result = total;
      u8::from(first_overflow || second_overflow)
  }
  /// Subtracts `operand2` from `operand1`, stores the wrapped difference in
  /// `result`, and returns the borrow (`1` if `operand2 > operand1`).
  fn sub_uint64(operand1: u64, operand2: u64, result: &mut u64) -> u8 {
      let (difference, borrowed) = operand1.overflowing_sub(operand2);
      *result = difference;
      u8::from(borrowed)
  }
  /// Subtracts `operand2` and an incoming `borrow` from `operand1`, stores the
  /// wrapped difference in `result`, and returns the borrow out.
  ///
  /// Any non-zero `borrow` subtracts exactly one from the difference.
  fn sub_uint64_borrow(operand1: u64, operand2: u64, borrow: u8, result: &mut u64) -> u8 {
      let (diff, underflowed) = operand1.overflowing_sub(operand2);
      *result = diff.wrapping_sub(u64::from(borrow != 0));
      u8::from(underflowed || diff < u64::from(borrow))
  }
  250. pub fn sub_uint(operand1: &[u64], operand2: &[u64], uint64_count: usize, result: &mut [u64]) -> u8 {
  251. let mut borrow = sub_uint64(operand1[0], operand2[0], &mut result[0]);
  252. for i in 0..uint64_count - 1 {
  253. let mut temp_result = 0u64;
  254. borrow = sub_uint64_borrow(operand1[1 + i], operand2[1 + i], borrow, &mut temp_result);
  255. result[1 + i] = temp_result;
  256. }
  257. borrow
  258. }
  259. pub fn add_uint(operand1: &[u64], operand2: &[u64], uint64_count: usize, result: &mut [u64]) -> u8 {
  260. let mut carry = add_uint64(operand1[0], operand2[0], &mut result[0]);
  261. for i in 0..uint64_count - 1 {
  262. let mut temp_result = 0u64;
  263. carry = add_uint64_carry(operand1[1 + i], operand2[1 + i], carry, &mut temp_result);
  264. result[1 + i] = temp_result;
  265. }
  266. carry
  267. }
  268. pub fn divide_uint192_inplace(mut numerator: [u64; 3], denominator: u64) -> ([u64; 3], [u64; 3]) {
  269. let mut numerator_bits = get_significant_bit_count(&numerator);
  270. let mut denominator_bits = get_significant_bit_count(slice::from_ref(&denominator));
  271. let mut quotient = [0u64; 3];
  272. if numerator_bits < denominator_bits {
  273. return (numerator, quotient);
  274. }
  275. let uint64_count = divide_round_up(numerator_bits, BITS_PER_U64);
  276. if uint64_count == 1 {
  277. quotient[0] = numerator[0] / denominator;
  278. numerator[0] -= quotient[0] * denominator;
  279. return (numerator, quotient);
  280. }
  281. let mut shifted_denominator = [0u64; 3];
  282. shifted_denominator[0] = denominator;
  283. let mut difference = [0u64; 3];
  284. let denominator_shift = numerator_bits - denominator_bits;
  285. let shifted_denominator = left_shift_uint192(shifted_denominator, denominator_shift);
  286. denominator_bits += denominator_shift;
  287. let mut remaining_shifts = denominator_shift;
  288. while numerator_bits == denominator_bits {
  289. if (sub_uint(
  290. &numerator,
  291. &shifted_denominator,
  292. uint64_count,
  293. &mut difference,
  294. )) != 0
  295. {
  296. if remaining_shifts == 0 {
  297. break;
  298. }
  299. add_uint(
  300. &difference.clone(),
  301. &numerator,
  302. uint64_count,
  303. &mut difference,
  304. );
  305. quotient = left_shift_uint192(quotient, 1);
  306. remaining_shifts -= 1;
  307. }
  308. quotient[0] |= 1;
  309. numerator_bits = get_significant_bit_count(&difference);
  310. let mut numerator_shift = denominator_bits - numerator_bits;
  311. if numerator_shift > remaining_shifts {
  312. numerator_shift = remaining_shifts;
  313. }
  314. if numerator_bits > 0 {
  315. numerator = left_shift_uint192(difference, numerator_shift);
  316. numerator_bits += numerator_shift;
  317. } else {
  318. for w in 0..uint64_count {
  319. numerator[w] = 0;
  320. }
  321. }
  322. quotient = left_shift_uint192(quotient, numerator_shift);
  323. remaining_shifts -= numerator_shift;
  324. }
  325. if numerator_bits > 0 {
  326. numerator = right_shift_uint192(numerator, denominator_shift);
  327. }
  328. (numerator, quotient)
  329. }
  /// Re-expresses `val`, a residue modulo `small_modulus`, as a residue modulo
  /// `large_modulus` while preserving its centered (signed) value.
  ///
  /// Values strictly above `small_modulus / 2` are treated as negative and
  /// wrap around to the top of the `large_modulus` range.
  ///
  /// # Panics
  ///
  /// Panics if `val >= small_modulus`.
  pub fn recenter_mod(val: u64, small_modulus: u64, large_modulus: u64) -> u64 {
      assert!(val < small_modulus);
      let small = small_modulus as i64;
      let large = large_modulus as i64;
      let signed = val as i64;

      let centered = if signed > small / 2 {
          signed - small
      } else {
          signed
      };
      let lifted = if centered < 0 {
          centered + large
      } else {
          centered
      };
      lifted as u64
  }
  /// Rescales `a` from modulus `inp_mod` to modulus `out_mod`, i.e. computes
  /// `round(a * out_mod / inp_mod) mod out_mod` on the centered representative
  /// of `a`.
  ///
  /// Residues at or above `inp_mod / 2` are treated as negative; rounding adds
  /// half of `inp_mod` with the sign of the centered value before the
  /// truncating division.
  pub fn rescale(a: u64, inp_mod: u64, out_mod: u64) -> u64 {
      let inp_signed = inp_mod as i64;
      let out_wide = out_mod as i128;
      let half_inp = inp_signed / 2;

      // Centered representative of `a` modulo `inp_mod`.
      let residue = (a % inp_mod) as i64;
      let centered = if residue >= half_inp {
          residue - inp_signed
      } else {
          residue
      };

      // Round to nearest by biasing away from zero before dividing.
      let rounding: i64 = if centered >= 0 { half_inp } else { -half_inp };
      let scaled = (centered as i128) * out_wide;
      let quotient = (scaled + rounding as i128) / (inp_mod as i128);

      // Add a multiple of `out_mod` so the dividend of `%` is non-negative.
      let multiple = ((inp_mod / out_mod) * out_mod) as i128;
      let result = (quotient + multiple + 2 * out_wide) % out_wide;
      assert!(result >= 0);
      ((result + out_wide) % out_wide) as u64
  }
  #[cfg(test)]
  mod test {
      use super::*;
      use crate::util::get_seeded_rng;
      use rand::Rng;

      /// Assembles a 128-bit value from its low and high 64-bit halves.
      ///
      /// The halves occupy disjoint bit ranges, so they must be OR-ed together;
      /// AND-ing them would always produce zero.
      fn combine(lo: u64, hi: u64) -> u128 {
          (lo as u128) | ((hi as u128) << 64)
      }

      #[test]
      fn div2_uint_mod_correct() {
          assert_eq!(div2_uint_mod(3, 7), 5);
      }

      #[test]
      fn divide_uint192_inplace_correct() {
          assert_eq!(
              divide_uint192_inplace([35, 0, 0], 7),
              ([0, 0, 0], [5, 0, 0])
          );
          assert_eq!(
              divide_uint192_inplace([0x10101010, 0x2B2B2B2B, 0xF1F1F1F1], 0x1000),
              (
                  [0x10, 0, 0],
                  [0xB2B0000000010101, 0x1F1000000002B2B2, 0xF1F1F]
              )
          );
      }

      #[test]
      fn get_barrett_crs_correct() {
          assert_eq!(
              get_barrett_crs(268369921u64),
              (16144578669088582089u64, 68736257792u64)
          );
          assert_eq!(
              get_barrett_crs(249561089u64),
              (10966983149909726427u64, 73916747789u64)
          );
          assert_eq!(
              get_barrett_crs(66974689739603969u64),
              (7906011006380390721u64, 275u64)
          );
      }

      #[test]
      fn barrett_reduction_u128_raw_correct() {
          let modulus = 66974689739603969u64;
          let modulus_u128 = modulus as u128;
          let exec = |val| {
              barrett_reduction_u128_raw(66974689739603969u64, 7906011006380390721u64, 275u64, val)
          };
          assert_eq!(exec(modulus_u128), 0);
          assert_eq!(exec(modulus_u128 + 1), 1);
          assert_eq!(exec(modulus_u128 * 7 + 5), 5);

          // Random full-width 128-bit inputs, checked against native `%`.
          let mut rng = get_seeded_rng();
          for _ in 0..100 {
              let val = combine(rng.gen(), rng.gen());
              assert_eq!(exec(val), (val % modulus_u128) as u64);
          }
      }

      #[test]
      fn barrett_raw_u64_correct() {
          let modulus = 66974689739603969u64;
          let cr1 = 275u64;

          let mut rng = get_seeded_rng();
          for _ in 0..100 {
              let val = rng.gen();
              assert_eq!(barrett_raw_u64(val, cr1, modulus), val % modulus);
          }
      }
  }