|
@@ -1,5 +1,6 @@
|
|
|
use crate::params::*;
|
|
|
use std::mem;
|
|
|
+use std::slice;
|
|
|
|
|
|
pub fn multiply_uint_mod(a: u64, b: u64, modulus: u64) -> u64 {
|
|
|
(((a as u128) * (b as u128)) % (modulus as u128)) as u64
|
|
@@ -18,15 +19,15 @@ pub fn log2_ceil_usize(a: usize) -> usize {
|
|
|
}
|
|
|
|
|
|
pub fn multiply_modular(params: &Params, a: u64, b: u64, c: usize) -> u64 {
|
|
|
- (a * b) % params.moduli[c]
|
|
|
+ barrett_coeff_u64(params, a * b, c)
|
|
|
}
|
|
|
|
|
|
pub fn multiply_add_modular(params: &Params, a: u64, b: u64, x: u64, c: usize) -> u64 {
|
|
|
- (a * b + x) % params.moduli[c]
|
|
|
+ barrett_coeff_u64(params, a * b + x, c)
|
|
|
}
|
|
|
|
|
|
pub fn add_modular(params: &Params, a: u64, b: u64, c: usize) -> u64 {
|
|
|
- (a + b) % params.moduli[c]
|
|
|
+ barrett_coeff_u64(params, a + b, c)
|
|
|
}
|
|
|
|
|
|
pub fn invert_modular(params: &Params, a: u64, c: usize) -> u64 {
|
|
@@ -34,7 +35,7 @@ pub fn invert_modular(params: &Params, a: u64, c: usize) -> u64 {
|
|
|
}
|
|
|
|
|
|
pub fn modular_reduce(params: &Params, x: u64, c: usize) -> u64 {
|
|
|
- (x) % params.moduli[c]
|
|
|
+ barrett_coeff_u64(params, x, c)
|
|
|
}
|
|
|
|
|
|
pub fn exponentiate_uint_mod(operand: u64, mut exponent: u64, modulus: u64) -> u64 {
|
|
@@ -102,12 +103,384 @@ pub fn recenter(val: u64, from_modulus: u64, to_modulus: u64) -> u64 {
|
|
|
a_val as u64
|
|
|
}
|
|
|
|
|
|
+pub fn get_barrett_crs(modulus: u64) -> (u64, u64) {
|
|
|
+ let numerator = [0, 0, 1];
|
|
|
+ let (_, quotient) = divide_uint192_inplace(numerator, modulus);
|
|
|
+
|
|
|
+ (quotient[0], quotient[1])
|
|
|
+}
|
|
|
+
|
|
|
+pub fn get_barrett(moduli: &[u64]) -> ([u64; MAX_MODULI], [u64; MAX_MODULI]) {
|
|
|
+ let mut cr0 = [0u64; MAX_MODULI];
|
|
|
+ let mut cr1 = [0u64; MAX_MODULI];
|
|
|
+ for i in 0..moduli.len() {
|
|
|
+ (cr0[i], cr1[i]) = get_barrett_crs(moduli[i]);
|
|
|
+ }
|
|
|
+ (cr0, cr1)
|
|
|
+}
|
|
|
+
|
|
|
+pub fn barrett_raw_u64(input: u64, const_ratio_1: u64, modulus: u64) -> u64 {
|
|
|
+ let tmp = (((input as u128) * (const_ratio_1 as u128)) >> 64) as u64;
|
|
|
+
|
|
|
+ // Barrett subtraction
|
|
|
+ let res = input - tmp * modulus;
|
|
|
+
|
|
|
+ // One more subtraction is enough
|
|
|
+ if res >= modulus {
|
|
|
+ res - modulus
|
|
|
+ } else {
|
|
|
+ res
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+pub fn barrett_coeff_u64(params: &Params, val: u64, n: usize) -> u64 {
|
|
|
+ barrett_raw_u64(val, params.barrett_cr_1[n], params.moduli[n])
|
|
|
+}
|
|
|
+
|
|
|
+fn split(x: u128) -> (u64, u64) {
|
|
|
+ let lo = x & ((1u128 << 64) - 1);
|
|
|
+ let hi = x >> 64;
|
|
|
+ (lo as u64, hi as u64)
|
|
|
+}
|
|
|
+
|
|
|
+fn mul_u128(a: u64, b: u64) -> (u64, u64) {
|
|
|
+ let prod = (a as u128) * (b as u128);
|
|
|
+ split(prod)
|
|
|
+}
|
|
|
+
|
|
|
+fn add_u64(op1: u64, op2: u64, out: &mut u64) -> u64 {
|
|
|
+ match op1.checked_add(op2) {
|
|
|
+ Some(x) => {
|
|
|
+ *out = x;
|
|
|
+ 0
|
|
|
+ }
|
|
|
+ None => 1,
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+fn barrett_raw_u128(val: u128, cr0: u64, cr1: u64, modulus: u64) -> u64 {
|
|
|
+ let (zx, zy) = split(val);
|
|
|
+
|
|
|
+ let mut tmp1 = 0;
|
|
|
+ let mut tmp3;
|
|
|
+ let mut carry;
|
|
|
+ let (_, prody) = mul_u128(zx, cr0);
|
|
|
+ carry = prody;
|
|
|
+ let (mut tmp2x, mut tmp2y) = mul_u128(zx, cr1);
|
|
|
+ tmp3 = tmp2y + add_u64(tmp2x, carry, &mut tmp1);
|
|
|
+ (tmp2x, tmp2y) = mul_u128(zy, cr0);
|
|
|
+ carry = tmp2y + add_u64(tmp1, tmp2x, &mut tmp1);
|
|
|
+ tmp1 = zy * cr1 + tmp3 + carry;
|
|
|
+ tmp3 = zx.wrapping_sub(tmp1.wrapping_mul(modulus));
|
|
|
+
|
|
|
+ tmp3
|
|
|
+
|
|
|
+ // uint64_t zx = val & (((__uint128_t)1 << 64) - 1);
|
|
|
+ // uint64_t zy = val >> 64;
|
|
|
+
|
|
|
+ // uint64_t tmp1, tmp3, carry;
|
|
|
+ // ulonglong2_h prod = umul64wide(zx, const_ratio_0);
|
|
|
+ // carry = prod.y;
|
|
|
+ // ulonglong2_h tmp2 = umul64wide(zx, const_ratio_1);
|
|
|
+ // tmp3 = tmp2.y + cpu_add_u64(tmp2.x, carry, &tmp1);
|
|
|
+ // tmp2 = umul64wide(zy, const_ratio_0);
|
|
|
+ // carry = tmp2.y + cpu_add_u64(tmp1, tmp2.x, &tmp1);
|
|
|
+ // tmp1 = zy * const_ratio_1 + tmp3 + carry;
|
|
|
+ // tmp3 = zx - tmp1 * modulus;
|
|
|
+
|
|
|
+ // return tmp3;
|
|
|
+}
|
|
|
+
|
|
|
+fn barrett_reduction_u128_raw(modulus: u64, cr0: u64, cr1: u64, val: u128) -> u64 {
|
|
|
+ let mut reduced_val = barrett_raw_u128(val, cr0, cr1, modulus);
|
|
|
+ reduced_val -= (modulus) * ((reduced_val >= modulus) as u64);
|
|
|
+ reduced_val
|
|
|
+}
|
|
|
+
|
|
|
+pub fn barrett_reduction_u128(params: &Params, val: u128) -> u64 {
|
|
|
+ let modulus = params.modulus;
|
|
|
+ let cr0 = params.barrett_cr_0_modulus;
|
|
|
+ let cr1 = params.barrett_cr_1_modulus;
|
|
|
+ barrett_reduction_u128_raw(modulus, cr0, cr1, val)
|
|
|
+}
|
|
|
+
|
|
|
+// Following code is ported from SEAL (github.com/microsoft/SEAL)
|
|
|
+
|
|
|
+pub fn get_significant_bit_count(val: &[u64]) -> usize {
|
|
|
+ for i in (0..val.len()).rev() {
|
|
|
+ for j in (0..64).rev() {
|
|
|
+ if (val[i] & (1u64 << j)) != 0 {
|
|
|
+ return i * 64 + j + 1;
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+ 0
|
|
|
+}
|
|
|
+
|
|
|
+fn divide_round_up(num: usize, denom: usize) -> usize {
|
|
|
+ (num + (denom - 1)) / denom
|
|
|
+}
|
|
|
+
|
|
|
+const BITS_PER_U64: usize = u64::BITS as usize;
|
|
|
+
|
|
|
+fn left_shift_uint192(operand: [u64; 3], shift_amount: usize) -> [u64; 3] {
|
|
|
+ let mut result = [0u64; 3];
|
|
|
+ if (shift_amount & (BITS_PER_U64 << 1)) != 0 {
|
|
|
+ result[2] = operand[0];
|
|
|
+ result[1] = 0;
|
|
|
+ result[0] = 0;
|
|
|
+ } else if (shift_amount & BITS_PER_U64) != 0 {
|
|
|
+ result[2] = operand[1];
|
|
|
+ result[1] = operand[0];
|
|
|
+ result[0] = 0;
|
|
|
+ } else {
|
|
|
+ result[2] = operand[2];
|
|
|
+ result[1] = operand[1];
|
|
|
+ result[0] = operand[0];
|
|
|
+ }
|
|
|
+
|
|
|
+ let bit_shift_amount = shift_amount & (BITS_PER_U64 - 1);
|
|
|
+
|
|
|
+ if bit_shift_amount != 0 {
|
|
|
+ let neg_bit_shift_amount = BITS_PER_U64 - bit_shift_amount;
|
|
|
+
|
|
|
+ result[2] = (result[2] << bit_shift_amount) | (result[1] >> neg_bit_shift_amount);
|
|
|
+ result[1] = (result[1] << bit_shift_amount) | (result[0] >> neg_bit_shift_amount);
|
|
|
+ result[0] = result[0] << bit_shift_amount;
|
|
|
+ }
|
|
|
+
|
|
|
+ result
|
|
|
+}
|
|
|
+
|
|
|
+fn right_shift_uint192(operand: [u64; 3], shift_amount: usize) -> [u64; 3] {
|
|
|
+ let mut result = [0u64; 3];
|
|
|
+
|
|
|
+ if (shift_amount & (BITS_PER_U64 << 1)) != 0 {
|
|
|
+ result[0] = operand[2];
|
|
|
+ result[1] = 0;
|
|
|
+ result[2] = 0;
|
|
|
+ } else if (shift_amount & BITS_PER_U64) != 0 {
|
|
|
+ result[0] = operand[1];
|
|
|
+ result[1] = operand[2];
|
|
|
+ result[2] = 0;
|
|
|
+ } else {
|
|
|
+ result[2] = operand[2];
|
|
|
+ result[1] = operand[1];
|
|
|
+ result[0] = operand[0];
|
|
|
+ }
|
|
|
+
|
|
|
+ let bit_shift_amount = shift_amount & (BITS_PER_U64 - 1);
|
|
|
+
|
|
|
+ if bit_shift_amount != 0 {
|
|
|
+ let neg_bit_shift_amount = BITS_PER_U64 - bit_shift_amount;
|
|
|
+
|
|
|
+ result[0] = (result[0] >> bit_shift_amount) | (result[1] << neg_bit_shift_amount);
|
|
|
+ result[1] = (result[1] >> bit_shift_amount) | (result[2] << neg_bit_shift_amount);
|
|
|
+ result[2] = result[2] >> bit_shift_amount;
|
|
|
+ }
|
|
|
+
|
|
|
+ result
|
|
|
+}
|
|
|
+
|
|
|
+fn add_uint64(operand1: u64, operand2: u64, result: &mut u64) -> u8 {
|
|
|
+ *result = operand1.wrapping_add(operand2);
|
|
|
+ (*result < operand1) as u8
|
|
|
+}
|
|
|
+
|
|
|
+fn add_uint64_carry(operand1: u64, operand2: u64, carry: u8, result: &mut u64) -> u8 {
|
|
|
+ let operand1 = operand1.wrapping_add(operand2);
|
|
|
+ *result = operand1.wrapping_add(carry as u64);
|
|
|
+ ((operand1 < operand2) || (!operand1 < (carry as u64))) as u8
|
|
|
+}
|
|
|
+
|
|
|
+fn sub_uint64(operand1: u64, operand2: u64, result: &mut u64) -> u8 {
|
|
|
+ *result = operand1.wrapping_sub(operand2);
|
|
|
+ (operand2 > operand1) as u8
|
|
|
+}
|
|
|
+
|
|
|
+fn sub_uint64_borrow(operand1: u64, operand2: u64, borrow: u8, result: &mut u64) -> u8 {
|
|
|
+ let diff = operand1.wrapping_sub(operand2);
|
|
|
+ *result = diff.wrapping_sub((borrow != 0) as u64);
|
|
|
+ ((diff > operand1) || (diff < (borrow as u64))) as u8
|
|
|
+}
|
|
|
+
|
|
|
+pub fn sub_uint(operand1: &[u64], operand2: &[u64], uint64_count: usize, result: &mut [u64]) -> u8 {
|
|
|
+ let mut borrow = sub_uint64(operand1[0], operand2[0], &mut result[0]);
|
|
|
+
|
|
|
+ for i in 0..uint64_count - 1 {
|
|
|
+ let mut temp_result = 0u64;
|
|
|
+ borrow = sub_uint64_borrow(operand1[1 + i], operand2[1 + i], borrow, &mut temp_result);
|
|
|
+ result[1 + i] = temp_result;
|
|
|
+ }
|
|
|
+
|
|
|
+ borrow
|
|
|
+}
|
|
|
+
|
|
|
+pub fn add_uint(operand1: &[u64], operand2: &[u64], uint64_count: usize, result: &mut [u64]) -> u8 {
|
|
|
+ let mut carry = add_uint64(operand1[0], operand2[0], &mut result[0]);
|
|
|
+
|
|
|
+ for i in 0..uint64_count - 1 {
|
|
|
+ let mut temp_result = 0u64;
|
|
|
+ carry = add_uint64_carry(operand1[1 + i], operand2[1 + i], carry, &mut temp_result);
|
|
|
+ result[1 + i] = temp_result;
|
|
|
+ }
|
|
|
+
|
|
|
+ carry
|
|
|
+}
|
|
|
+
|
|
|
+pub fn divide_uint192_inplace(mut numerator: [u64; 3], denominator: u64) -> ([u64; 3], [u64; 3]) {
|
|
|
+ let mut numerator_bits = get_significant_bit_count(&numerator);
|
|
|
+ let mut denominator_bits = get_significant_bit_count(slice::from_ref(&denominator));
|
|
|
+
|
|
|
+ let mut quotient = [0u64; 3];
|
|
|
+
|
|
|
+ if numerator_bits < denominator_bits {
|
|
|
+ return (numerator, quotient);
|
|
|
+ }
|
|
|
+
|
|
|
+ let uint64_count = divide_round_up(numerator_bits, BITS_PER_U64);
|
|
|
+
|
|
|
+ if uint64_count == 1 {
|
|
|
+ quotient[0] = numerator[0] / denominator;
|
|
|
+ numerator[0] -= quotient[0] * denominator;
|
|
|
+ return (numerator, quotient);
|
|
|
+ }
|
|
|
+
|
|
|
+ let mut shifted_denominator = [0u64; 3];
|
|
|
+ shifted_denominator[0] = denominator;
|
|
|
+
|
|
|
+ let mut difference = [0u64; 3];
|
|
|
+
|
|
|
+ let denominator_shift = numerator_bits - denominator_bits;
|
|
|
+
|
|
|
+ let shifted_denominator = left_shift_uint192(shifted_denominator, denominator_shift);
|
|
|
+ denominator_bits += denominator_shift;
|
|
|
+
|
|
|
+ let mut remaining_shifts = denominator_shift;
|
|
|
+ while numerator_bits == denominator_bits {
|
|
|
+ if (sub_uint(
|
|
|
+ &numerator,
|
|
|
+ &shifted_denominator,
|
|
|
+ uint64_count,
|
|
|
+ &mut difference,
|
|
|
+ )) != 0
|
|
|
+ {
|
|
|
+ if remaining_shifts == 0 {
|
|
|
+ break;
|
|
|
+ }
|
|
|
+
|
|
|
+ add_uint(
|
|
|
+ &difference.clone(),
|
|
|
+ &numerator,
|
|
|
+ uint64_count,
|
|
|
+ &mut difference,
|
|
|
+ );
|
|
|
+
|
|
|
+ quotient = left_shift_uint192(quotient, 1);
|
|
|
+ remaining_shifts -= 1;
|
|
|
+ }
|
|
|
+
|
|
|
+ quotient[0] |= 1;
|
|
|
+
|
|
|
+ numerator_bits = get_significant_bit_count(&difference);
|
|
|
+ let mut numerator_shift = denominator_bits - numerator_bits;
|
|
|
+ if numerator_shift > remaining_shifts {
|
|
|
+ numerator_shift = remaining_shifts;
|
|
|
+ }
|
|
|
+
|
|
|
+ if numerator_bits > 0 {
|
|
|
+ numerator = left_shift_uint192(difference, numerator_shift);
|
|
|
+ numerator_bits += numerator_shift;
|
|
|
+ } else {
|
|
|
+ for w in 0..uint64_count {
|
|
|
+ numerator[w] = 0;
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ quotient = left_shift_uint192(quotient, numerator_shift);
|
|
|
+ remaining_shifts -= numerator_shift;
|
|
|
+ }
|
|
|
+
|
|
|
+ if numerator_bits > 0 {
|
|
|
+ numerator = right_shift_uint192(numerator, denominator_shift);
|
|
|
+ }
|
|
|
+
|
|
|
+ (numerator, quotient)
|
|
|
+}
|
|
|
+
|
|
|
#[cfg(test)]
|
|
|
mod test {
|
|
|
use super::*;
|
|
|
+ use crate::util::get_seeded_rng;
|
|
|
+ use rand::Rng;
|
|
|
+
|
|
|
+ fn combine(lo: u64, hi: u64) -> u128 {
|
|
|
+ (lo as u128) & ((hi as u128) << 64)
|
|
|
+ }
|
|
|
|
|
|
#[test]
|
|
|
fn div2_uint_mod_correct() {
|
|
|
assert_eq!(div2_uint_mod(3, 7), 5);
|
|
|
}
|
|
|
+
|
|
|
+ #[test]
|
|
|
+ fn divide_uint192_inplace_correct() {
|
|
|
+ assert_eq!(
|
|
|
+ divide_uint192_inplace([35, 0, 0], 7),
|
|
|
+ ([0, 0, 0], [5, 0, 0])
|
|
|
+ );
|
|
|
+ assert_eq!(
|
|
|
+ divide_uint192_inplace([0x10101010, 0x2B2B2B2B, 0xF1F1F1F1], 0x1000),
|
|
|
+ (
|
|
|
+ [0x10, 0, 0],
|
|
|
+ [0xB2B0000000010101, 0x1F1000000002B2B2, 0xF1F1F]
|
|
|
+ )
|
|
|
+ );
|
|
|
+ }
|
|
|
+
|
|
|
+ #[test]
|
|
|
+ fn get_barrett_crs_correct() {
|
|
|
+ assert_eq!(
|
|
|
+ get_barrett_crs(268369921u64),
|
|
|
+ (16144578669088582089u64, 68736257792u64)
|
|
|
+ );
|
|
|
+ assert_eq!(
|
|
|
+ get_barrett_crs(249561089u64),
|
|
|
+ (10966983149909726427u64, 73916747789u64)
|
|
|
+ );
|
|
|
+ assert_eq!(
|
|
|
+ get_barrett_crs(66974689739603969u64),
|
|
|
+ (7906011006380390721u64, 275u64)
|
|
|
+ );
|
|
|
+ }
|
|
|
+
|
|
|
+ #[test]
|
|
|
+ fn barrett_reduction_u128_raw_correct() {
|
|
|
+ let modulus = 66974689739603969u64;
|
|
|
+ let modulus_u128 = modulus as u128;
|
|
|
+ let exec = |val| {
|
|
|
+ barrett_reduction_u128_raw(66974689739603969u64, 7906011006380390721u64, 275u64, val)
|
|
|
+ };
|
|
|
+ assert_eq!(exec(modulus_u128), 0);
|
|
|
+ assert_eq!(exec(modulus_u128 + 1), 1);
|
|
|
+ assert_eq!(exec(modulus_u128 * 7 + 5), 5);
|
|
|
+
|
|
|
+ let mut rng = get_seeded_rng();
|
|
|
+ for _ in 0..100 {
|
|
|
+ let val = combine(rng.gen(), rng.gen());
|
|
|
+ assert_eq!(exec(val), (val % modulus_u128) as u64);
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ #[test]
|
|
|
+ fn barrett_raw_u64_correct() {
|
|
|
+ let modulus = 66974689739603969u64;
|
|
|
+ let cr1 = 275u64;
|
|
|
+
|
|
|
+ let mut rng = get_seeded_rng();
|
|
|
+ for _ in 0..100 {
|
|
|
+ let val = rng.gen();
|
|
|
+ assert_eq!(barrett_raw_u64(val, cr1, modulus), val % modulus);
|
|
|
+ }
|
|
|
+ }
|
|
|
}
|