123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517 |
- use crate::params::*;
- use std::mem;
- use std::slice;
- pub fn multiply_uint_mod(a: u64, b: u64, modulus: u64) -> u64 {
- (((a as u128) * (b as u128)) % (modulus as u128)) as u64
- }
- pub const fn log2(a: u64) -> u64 {
- std::mem::size_of::<u64>() as u64 * 8 - a.leading_zeros() as u64 - 1
- }
- pub fn log2_ceil(a: u64) -> u64 {
- f64::ceil(f64::log2(a as f64)) as u64
- }
- pub fn log2_ceil_usize(a: usize) -> usize {
- f64::ceil(f64::log2(a as f64)) as usize
- }
- pub fn multiply_modular(params: &Params, a: u64, b: u64, c: usize) -> u64 {
- barrett_coeff_u64(params, a * b, c)
- }
- pub fn multiply_add_modular(params: &Params, a: u64, b: u64, x: u64, c: usize) -> u64 {
- barrett_coeff_u64(params, a * b + x, c)
- }
- pub fn add_modular(params: &Params, a: u64, b: u64, c: usize) -> u64 {
- barrett_coeff_u64(params, a + b, c)
- }
- pub fn invert_modular(params: &Params, a: u64, c: usize) -> u64 {
- params.moduli[c] - a
- }
- pub fn modular_reduce(params: &Params, x: u64, c: usize) -> u64 {
- barrett_coeff_u64(params, x, c)
- }
- pub fn exponentiate_uint_mod(operand: u64, mut exponent: u64, modulus: u64) -> u64 {
- if exponent == 0 {
- return 1;
- }
- if exponent == 1 {
- return operand;
- }
- let mut power = operand;
- let mut product;
- let mut intermediate = 1u64;
- loop {
- if (exponent % 2) == 1 {
- product = multiply_uint_mod(power, intermediate, modulus);
- mem::swap(&mut product, &mut intermediate);
- }
- exponent >>= 1;
- if exponent == 0 {
- break;
- }
- product = multiply_uint_mod(power, power, modulus);
- mem::swap(&mut product, &mut power);
- }
- intermediate
- }
- pub fn reverse_bits(x: u64, bit_count: usize) -> u64 {
- if bit_count == 0 {
- return 0;
- }
- let r = x.reverse_bits();
- r >> (mem::size_of::<u64>() * 8 - bit_count)
- }
- pub fn div2_uint_mod(operand: u64, modulus: u64) -> u64 {
- if operand & 1 == 1 {
- let res = operand.overflowing_add(modulus);
- if res.1 {
- return (res.0 >> 1) | (1u64 << 63);
- } else {
- return res.0 >> 1;
- }
- } else {
- return operand >> 1;
- }
- }
- pub fn recenter(val: u64, from_modulus: u64, to_modulus: u64) -> u64 {
- assert!(from_modulus >= to_modulus);
- let from_modulus_i64 = from_modulus as i64;
- let to_modulus_i64 = to_modulus as i64;
- let mut a_val = val as i64;
- if val >= from_modulus / 2 {
- a_val -= from_modulus_i64;
- }
- a_val = a_val + (from_modulus_i64 / to_modulus_i64) * to_modulus_i64 + 2 * to_modulus_i64;
- a_val %= to_modulus_i64;
- a_val as u64
- }
- pub fn get_barrett_crs(modulus: u64) -> (u64, u64) {
- let numerator = [0, 0, 1];
- let (_, quotient) = divide_uint192_inplace(numerator, modulus);
- (quotient[0], quotient[1])
- }
- pub fn get_barrett(moduli: &[u64]) -> ([u64; MAX_MODULI], [u64; MAX_MODULI]) {
- let mut cr0 = [0u64; MAX_MODULI];
- let mut cr1 = [0u64; MAX_MODULI];
- for i in 0..moduli.len() {
- (cr0[i], cr1[i]) = get_barrett_crs(moduli[i]);
- }
- (cr0, cr1)
- }
- pub fn barrett_raw_u64(input: u64, const_ratio_1: u64, modulus: u64) -> u64 {
- let tmp = (((input as u128) * (const_ratio_1 as u128)) >> 64) as u64;
- // Barrett subtraction
- let res = input - tmp * modulus;
- // One more subtraction is enough
- if res >= modulus {
- res - modulus
- } else {
- res
- }
- }
- pub fn barrett_coeff_u64(params: &Params, val: u64, n: usize) -> u64 {
- barrett_raw_u64(val, params.barrett_cr_1[n], params.moduli[n])
- }
- fn split(x: u128) -> (u64, u64) {
- let lo = x & ((1u128 << 64) - 1);
- let hi = x >> 64;
- (lo as u64, hi as u64)
- }
- fn mul_u128(a: u64, b: u64) -> (u64, u64) {
- let prod = (a as u128) * (b as u128);
- split(prod)
- }
- fn add_u64(op1: u64, op2: u64, out: &mut u64) -> u64 {
- match op1.checked_add(op2) {
- Some(x) => {
- *out = x;
- 0
- }
- None => 1,
- }
- }
- fn barrett_raw_u128(val: u128, cr0: u64, cr1: u64, modulus: u64) -> u64 {
- let (zx, zy) = split(val);
- let mut tmp1 = 0;
- let mut tmp3;
- let mut carry;
- let (_, prody) = mul_u128(zx, cr0);
- carry = prody;
- let (mut tmp2x, mut tmp2y) = mul_u128(zx, cr1);
- tmp3 = tmp2y + add_u64(tmp2x, carry, &mut tmp1);
- (tmp2x, tmp2y) = mul_u128(zy, cr0);
- carry = tmp2y + add_u64(tmp1, tmp2x, &mut tmp1);
- tmp1 = zy * cr1 + tmp3 + carry;
- tmp3 = zx.wrapping_sub(tmp1.wrapping_mul(modulus));
- tmp3
- // uint64_t zx = val & (((__uint128_t)1 << 64) - 1);
- // uint64_t zy = val >> 64;
- // uint64_t tmp1, tmp3, carry;
- // ulonglong2_h prod = umul64wide(zx, const_ratio_0);
- // carry = prod.y;
- // ulonglong2_h tmp2 = umul64wide(zx, const_ratio_1);
- // tmp3 = tmp2.y + cpu_add_u64(tmp2.x, carry, &tmp1);
- // tmp2 = umul64wide(zy, const_ratio_0);
- // carry = tmp2.y + cpu_add_u64(tmp1, tmp2.x, &tmp1);
- // tmp1 = zy * const_ratio_1 + tmp3 + carry;
- // tmp3 = zx - tmp1 * modulus;
- // return tmp3;
- }
- fn barrett_reduction_u128_raw(modulus: u64, cr0: u64, cr1: u64, val: u128) -> u64 {
- let mut reduced_val = barrett_raw_u128(val, cr0, cr1, modulus);
- reduced_val -= (modulus) * ((reduced_val >= modulus) as u64);
- reduced_val
- }
- pub fn barrett_reduction_u128(params: &Params, val: u128) -> u64 {
- let modulus = params.modulus;
- let cr0 = params.barrett_cr_0_modulus;
- let cr1 = params.barrett_cr_1_modulus;
- barrett_reduction_u128_raw(modulus, cr0, cr1, val)
- }
- // Following code is ported from SEAL (github.com/microsoft/SEAL)
- pub fn get_significant_bit_count(val: &[u64]) -> usize {
- for i in (0..val.len()).rev() {
- for j in (0..64).rev() {
- if (val[i] & (1u64 << j)) != 0 {
- return i * 64 + j + 1;
- }
- }
- }
- 0
- }
- fn divide_round_up(num: usize, denom: usize) -> usize {
- (num + (denom - 1)) / denom
- }
- const BITS_PER_U64: usize = u64::BITS as usize;
- fn left_shift_uint192(operand: [u64; 3], shift_amount: usize) -> [u64; 3] {
- let mut result = [0u64; 3];
- if (shift_amount & (BITS_PER_U64 << 1)) != 0 {
- result[2] = operand[0];
- result[1] = 0;
- result[0] = 0;
- } else if (shift_amount & BITS_PER_U64) != 0 {
- result[2] = operand[1];
- result[1] = operand[0];
- result[0] = 0;
- } else {
- result[2] = operand[2];
- result[1] = operand[1];
- result[0] = operand[0];
- }
- let bit_shift_amount = shift_amount & (BITS_PER_U64 - 1);
- if bit_shift_amount != 0 {
- let neg_bit_shift_amount = BITS_PER_U64 - bit_shift_amount;
- result[2] = (result[2] << bit_shift_amount) | (result[1] >> neg_bit_shift_amount);
- result[1] = (result[1] << bit_shift_amount) | (result[0] >> neg_bit_shift_amount);
- result[0] = result[0] << bit_shift_amount;
- }
- result
- }
- fn right_shift_uint192(operand: [u64; 3], shift_amount: usize) -> [u64; 3] {
- let mut result = [0u64; 3];
- if (shift_amount & (BITS_PER_U64 << 1)) != 0 {
- result[0] = operand[2];
- result[1] = 0;
- result[2] = 0;
- } else if (shift_amount & BITS_PER_U64) != 0 {
- result[0] = operand[1];
- result[1] = operand[2];
- result[2] = 0;
- } else {
- result[2] = operand[2];
- result[1] = operand[1];
- result[0] = operand[0];
- }
- let bit_shift_amount = shift_amount & (BITS_PER_U64 - 1);
- if bit_shift_amount != 0 {
- let neg_bit_shift_amount = BITS_PER_U64 - bit_shift_amount;
- result[0] = (result[0] >> bit_shift_amount) | (result[1] << neg_bit_shift_amount);
- result[1] = (result[1] >> bit_shift_amount) | (result[2] << neg_bit_shift_amount);
- result[2] = result[2] >> bit_shift_amount;
- }
- result
- }
- fn add_uint64(operand1: u64, operand2: u64, result: &mut u64) -> u8 {
- *result = operand1.wrapping_add(operand2);
- (*result < operand1) as u8
- }
- fn add_uint64_carry(operand1: u64, operand2: u64, carry: u8, result: &mut u64) -> u8 {
- let operand1 = operand1.wrapping_add(operand2);
- *result = operand1.wrapping_add(carry as u64);
- ((operand1 < operand2) || (!operand1 < (carry as u64))) as u8
- }
- fn sub_uint64(operand1: u64, operand2: u64, result: &mut u64) -> u8 {
- *result = operand1.wrapping_sub(operand2);
- (operand2 > operand1) as u8
- }
- fn sub_uint64_borrow(operand1: u64, operand2: u64, borrow: u8, result: &mut u64) -> u8 {
- let diff = operand1.wrapping_sub(operand2);
- *result = diff.wrapping_sub((borrow != 0) as u64);
- ((diff > operand1) || (diff < (borrow as u64))) as u8
- }
- pub fn sub_uint(operand1: &[u64], operand2: &[u64], uint64_count: usize, result: &mut [u64]) -> u8 {
- let mut borrow = sub_uint64(operand1[0], operand2[0], &mut result[0]);
- for i in 0..uint64_count - 1 {
- let mut temp_result = 0u64;
- borrow = sub_uint64_borrow(operand1[1 + i], operand2[1 + i], borrow, &mut temp_result);
- result[1 + i] = temp_result;
- }
- borrow
- }
- pub fn add_uint(operand1: &[u64], operand2: &[u64], uint64_count: usize, result: &mut [u64]) -> u8 {
- let mut carry = add_uint64(operand1[0], operand2[0], &mut result[0]);
- for i in 0..uint64_count - 1 {
- let mut temp_result = 0u64;
- carry = add_uint64_carry(operand1[1 + i], operand2[1 + i], carry, &mut temp_result);
- result[1 + i] = temp_result;
- }
- carry
- }
- pub fn divide_uint192_inplace(mut numerator: [u64; 3], denominator: u64) -> ([u64; 3], [u64; 3]) {
- let mut numerator_bits = get_significant_bit_count(&numerator);
- let mut denominator_bits = get_significant_bit_count(slice::from_ref(&denominator));
- let mut quotient = [0u64; 3];
- if numerator_bits < denominator_bits {
- return (numerator, quotient);
- }
- let uint64_count = divide_round_up(numerator_bits, BITS_PER_U64);
- if uint64_count == 1 {
- quotient[0] = numerator[0] / denominator;
- numerator[0] -= quotient[0] * denominator;
- return (numerator, quotient);
- }
- let mut shifted_denominator = [0u64; 3];
- shifted_denominator[0] = denominator;
- let mut difference = [0u64; 3];
- let denominator_shift = numerator_bits - denominator_bits;
- let shifted_denominator = left_shift_uint192(shifted_denominator, denominator_shift);
- denominator_bits += denominator_shift;
- let mut remaining_shifts = denominator_shift;
- while numerator_bits == denominator_bits {
- if (sub_uint(
- &numerator,
- &shifted_denominator,
- uint64_count,
- &mut difference,
- )) != 0
- {
- if remaining_shifts == 0 {
- break;
- }
- add_uint(
- &difference.clone(),
- &numerator,
- uint64_count,
- &mut difference,
- );
- quotient = left_shift_uint192(quotient, 1);
- remaining_shifts -= 1;
- }
- quotient[0] |= 1;
- numerator_bits = get_significant_bit_count(&difference);
- let mut numerator_shift = denominator_bits - numerator_bits;
- if numerator_shift > remaining_shifts {
- numerator_shift = remaining_shifts;
- }
- if numerator_bits > 0 {
- numerator = left_shift_uint192(difference, numerator_shift);
- numerator_bits += numerator_shift;
- } else {
- for w in 0..uint64_count {
- numerator[w] = 0;
- }
- }
- quotient = left_shift_uint192(quotient, numerator_shift);
- remaining_shifts -= numerator_shift;
- }
- if numerator_bits > 0 {
- numerator = right_shift_uint192(numerator, denominator_shift);
- }
- (numerator, quotient)
- }
- pub fn recenter_mod(val: u64, small_modulus: u64, large_modulus: u64) -> u64 {
- assert!(val < small_modulus);
- let mut val_i64 = val as i64;
- let small_modulus_i64 = small_modulus as i64;
- let large_modulus_i64 = large_modulus as i64;
- if val_i64 > small_modulus_i64 / 2 {
- val_i64 -= small_modulus_i64;
- }
- if val_i64 < 0 {
- val_i64 += large_modulus_i64;
- }
- val_i64 as u64
- }
- pub fn rescale(a: u64, inp_mod: u64, out_mod: u64) -> u64 {
- let inp_mod_i64 = inp_mod as i64;
- let out_mod_i128 = out_mod as i128;
- let mut inp_val = (a % inp_mod) as i64;
- if inp_val >= (inp_mod_i64 / 2) {
- inp_val -= inp_mod_i64;
- }
- let sign: i64 = if inp_val >= 0 { 1 } else { -1 };
- let val = (inp_val as i128) * (out_mod as i128);
- let mut result = (val + (sign * (inp_mod_i64 / 2)) as i128) / (inp_mod as i128);
- result = (result + ((inp_mod / out_mod) * out_mod) as i128 + (2 * out_mod_i128)) % out_mod_i128;
- assert!(result >= 0);
- ((result + out_mod_i128) % out_mod_i128) as u64
- }
- #[cfg(test)]
- mod test {
- use super::*;
- use crate::util::get_seeded_rng;
- use rand::Rng;
- fn combine(lo: u64, hi: u64) -> u128 {
- (lo as u128) & ((hi as u128) << 64)
- }
- #[test]
- fn div2_uint_mod_correct() {
- assert_eq!(div2_uint_mod(3, 7), 5);
- }
- #[test]
- fn divide_uint192_inplace_correct() {
- assert_eq!(
- divide_uint192_inplace([35, 0, 0], 7),
- ([0, 0, 0], [5, 0, 0])
- );
- assert_eq!(
- divide_uint192_inplace([0x10101010, 0x2B2B2B2B, 0xF1F1F1F1], 0x1000),
- (
- [0x10, 0, 0],
- [0xB2B0000000010101, 0x1F1000000002B2B2, 0xF1F1F]
- )
- );
- }
- #[test]
- fn get_barrett_crs_correct() {
- assert_eq!(
- get_barrett_crs(268369921u64),
- (16144578669088582089u64, 68736257792u64)
- );
- assert_eq!(
- get_barrett_crs(249561089u64),
- (10966983149909726427u64, 73916747789u64)
- );
- assert_eq!(
- get_barrett_crs(66974689739603969u64),
- (7906011006380390721u64, 275u64)
- );
- }
- #[test]
- fn barrett_reduction_u128_raw_correct() {
- let modulus = 66974689739603969u64;
- let modulus_u128 = modulus as u128;
- let exec = |val| {
- barrett_reduction_u128_raw(66974689739603969u64, 7906011006380390721u64, 275u64, val)
- };
- assert_eq!(exec(modulus_u128), 0);
- assert_eq!(exec(modulus_u128 + 1), 1);
- assert_eq!(exec(modulus_u128 * 7 + 5), 5);
- let mut rng = get_seeded_rng();
- for _ in 0..100 {
- let val = combine(rng.gen(), rng.gen());
- assert_eq!(exec(val), (val % modulus_u128) as u64);
- }
- }
- #[test]
- fn barrett_raw_u64_correct() {
- let modulus = 66974689739603969u64;
- let cr1 = 275u64;
- let mut rng = get_seeded_rng();
- for _ in 0..100 {
- let val = rng.gen();
- assert_eq!(barrett_raw_u64(val, cr1, modulus), val % modulus);
- }
- }
- }
|