Browse Source

Functions to compute a bit representation for a range proof

Ian Goldberg 4 months ago
parent
commit
110a9eacb1
1 changed files with 130 additions and 9 deletions
  1. 130 9
      src/rangeutils.rs

+ 130 - 9
src/rangeutils.rs

@@ -2,18 +2,19 @@
 //! processing of range proofs.
 
 use group::ff::PrimeField;
+use sigma_rs::errors::Error;
 use subtle::Choice;
 
-/// Convert a [`Scalar`] to an [`i128`], assuming it fits and is
-/// nonnegative.  Also output the number of bits of the [`Scalar`].
-/// This version assumes that `s` is public, and so does not need to run
-/// in constant time.
+/// Convert a [`Scalar`] to an [`u128`], assuming it fits in an [`i128`]
+/// and is nonnegative.  Also output the number of bits of the
+/// [`Scalar`].  This version assumes that `s` is public, and so does
+/// not need to run in constant time.
 ///
 /// [`Scalar`]: https://docs.rs/group/0.13.0/group/trait.Group.html#associatedtype.Scalar
-pub fn bit_decomp_vartime<S: PrimeField>(mut s: S) -> Option<(i128, u32)> {
-    let mut val = 0i128;
+pub fn bit_decomp_vartime<S: PrimeField>(mut s: S) -> Option<(u128, u32)> {
+    let mut val = 0u128;
     let mut bitnum = 0u32;
-    let mut bitval = 1i128; // Invariant: bitval = 2^bitnum
+    let mut bitval = 1u128; // Invariant: bitval = 2^bitnum
     while bitnum < 127 && !s.is_zero_vartime() {
         if s.is_odd().into() {
             val += bitval;
@@ -48,6 +49,90 @@ pub fn bit_decomp<S: PrimeField>(mut s: S, nbits: u32) -> Vec<Choice> {
     bits
 }
 
+/// Given a [`Scalar`] `upper` (strictly greater than 1), make a vector
+/// of [`Scalar`]s with the property that a [`Scalar`] `x` can be
+/// written as a sum of zero or more (distinct) elements of this vector
+/// if and only if `0 <= x < upper`.
+///
+/// The strategy is to write x as a sequence of `nbits` bits, with one
+/// twist: the low bits represent 2^0, 2^1, 2^2, etc., as usual.  But
+/// the highest bit represents `upper-2^{nbits-1}` instead of the usual
+/// `2^{nbits-1}`.  `nbits` will be the largest value for which
+/// `2^{nbits-1}` is strictly less than `upper`.  For example, if
+/// `upper` is 100, the bits represent 1, 2, 4, 8, 16, 32, 36.  A number
+/// x can be represented as a sum of 0 or more elements of this sequence
+/// if and only if `0 <= x < upper`.
+///
+/// It is assumed that `upper` is public, and so this function is not
+/// constant time.
+///
+/// [`Scalar`]: https://docs.rs/group/0.13.0/group/trait.Group.html#associatedtype.Scalar
+pub fn bitrep_scalars_vartime<S: PrimeField>(upper: S) -> Result<Vec<S>, Error> {
+    // Get the `u128` value of `upper`, and its number of bits `nbits`
+    let (upper_val, mut nbits) = bit_decomp_vartime(upper).ok_or(Error::VerificationFailure)?;
+
+    // Ensure `nbits` is at least 2.
+    if nbits < 2 {
+        return Err(Error::VerificationFailure);
+    }
+
+    // If upper is exactly a power of 2, use one fewer bit
+    if upper_val == 1u128 << (nbits - 1) {
+        nbits -= 1;
+    }
+
+    // Make the vector of Scalars containing the represented value of
+    // the bits
+    Ok((0..nbits)
+        .map(|i| {
+            if i < nbits - 1 {
+                S::from_u128(1u128 << i)
+            } else {
+                // Compute the represented value of the highest bit
+                S::from_u128(upper_val - (1u128 << (nbits - 1)))
+            }
+        })
+        .collect())
+}
+
+/// Given a vector of [`Scalar`]s as output by
+/// [`bitrep_scalars_vartime`] and a private [`Scalar`] `x`, output a
+/// vector of [`Choice`] (of the same length as the given
+/// `bitrep_scalars` vector) such that `x` is the sum of the chosen
+/// elements of `bitrep_scalars`.  This function should be constant time
+/// in the value of `x`.  If `x` is not less than the `upper` used by
+/// [`bitrep_scalars_vartime`] to generate `bitrep_scalars`, then `x`
+/// will not (and indeed cannot) equal the sum of the chosen elements of
+/// `bitrep_scalars`.
+///
+/// [`Scalar`]: https://docs.rs/group/0.13.0/group/trait.Group.html#associatedtype.Scalar
+pub fn compute_bitrep<S: PrimeField>(mut x: S, bitrep_scalars: &[S]) -> Vec<Choice> {
+    // We know the length of bitrep_scalars is at most 127.
+    let nbits: u32 = bitrep_scalars.len().try_into().unwrap();
+
+    // Decompose `x` as a normal `nbit`-bit vector.  This only looks at
+    // the low `nbits` bits of `x`, so the resulting bit vector forces
+    // `x < 2^{nbits}`.
+    let x_raw_bits = bit_decomp(x, nbits);
+    let high_bit = x_raw_bits[(nbits as usize) - 1];
+
+    // Conditionally subtract the last represented value in the
+    // vector, depending on whether the high bit of x is set.  That is,
+    // if `x < 2^{nbits-1}`, then we don't subtract from x.  If `x >=
+    // 2^{nbits-1}`, then we will subtract `upper - 2^{nbits-1}` from
+    // `x`.  In either case, the remaining value is non-negative, and
+    // strictly less than 2^{nbits-1}.
+    x -= S::conditional_select(&S::ZERO, &bitrep_scalars[(nbits as usize) - 1], high_bit);
+
+    // Now get the `nbits-1` bits of the result in the usual way
+    let mut x_bits = bit_decomp(x, nbits - 1);
+
+    // and tack on the high bit
+    x_bits.push(high_bit);
+
+    x_bits
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
@@ -81,11 +166,11 @@ mod tests {
         assert_eq!(bit_decomp_vartime(Scalar::from(1u32).neg()), None);
         assert_eq!(
             bit_decomp_vartime(Scalar::from((1u128 << 127) - 2)),
-            Some((i128::MAX - 1, 127))
+            Some(((i128::MAX - 1) as u128, 127))
         );
         assert_eq!(
             bit_decomp_vartime(Scalar::from((1u128 << 127) - 1)),
-            Some((i128::MAX, 127))
+            Some((i128::MAX as u128, 127))
         );
         assert_eq!(bit_decomp_vartime(Scalar::from(1u128 << 127)), None);
 
@@ -118,4 +203,40 @@ mod tests {
         "0000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"
         );
     }
+
+    // Obliviously test whether x is in 0..upper (that is, 0 <= x <
+    // upper) using bit decomposition.  `upper` is considered public,
+    // but `x` is private.  `upper` must be at least 2.
+    fn bitrep_tester(upper: Scalar, x: Scalar, expected: bool) -> Result<(), Error> {
+        let rep_scalars = bitrep_scalars_vartime(upper)?;
+        let bitrep = compute_bitrep(x, &rep_scalars);
+
+        let nbits = bitrep.len();
+        assert!(nbits == rep_scalars.len());
+        let mut x_out = Scalar::ZERO;
+        for i in 0..nbits {
+            x_out += Scalar::conditional_select(&Scalar::ZERO, &rep_scalars[i], bitrep[i]);
+        }
+
+        if (x == x_out) != expected {
+            return Err(Error::VerificationFailure);
+        }
+
+        Ok(())
+    }
+
+    #[test]
+    fn bitrep_test() {
+        bitrep_tester(Scalar::from(0u32), Scalar::from(0u32), false).unwrap_err();
+        bitrep_tester(Scalar::from(1u32), Scalar::from(0u32), true).unwrap_err();
+        bitrep_tester(Scalar::from(2u32), Scalar::from(1u32), true).unwrap();
+        bitrep_tester(Scalar::from(3u32), Scalar::from(1u32), true).unwrap();
+        bitrep_tester(Scalar::from(100u32), Scalar::from(99u32), true).unwrap();
+        bitrep_tester(Scalar::from(127u32), Scalar::from(126u32), true).unwrap();
+        bitrep_tester(Scalar::from(128u32), Scalar::from(127u32), true).unwrap();
+        bitrep_tester(Scalar::from(128u32), Scalar::from(128u32), false).unwrap();
+        bitrep_tester(Scalar::from(129u32), Scalar::from(128u32), true).unwrap();
+        bitrep_tester(Scalar::from(129u32), Scalar::from(0u32), true).unwrap();
+        bitrep_tester(Scalar::from(129u32), Scalar::from(129u32), false).unwrap();
+    }
 }