Browse Source

Compute all the variables needed for range proofs

Ian Goldberg 4 months ago
parent
commit
46e240d6d4
2 changed files with 235 additions and 53 deletions
  1. 15 7
      sigma_compiler_core/src/codegen.rs
  2. 220 46
      sigma_compiler_core/src/rangeproof.rs

+ 15 - 7
sigma_compiler_core/src/codegen.rs

@@ -165,6 +165,11 @@ impl CodeGen {
         id
     }
 
+    /// Create a new identifier, using the unique prefix
+    pub fn gen_ident(&self, base: &Ident) -> Ident {
+        format_ident!("{}{}", self.unique_prefix, base)
+    }
+
     /// Append some code to the generated `prove` function
     pub fn prove_append(&mut self, code: TokenStream) {
         let prove_code = &self.prove_code;
@@ -194,16 +199,17 @@ impl CodeGen {
     }
 
     /// Append some code to both the generated `prove` and `verify`
-    /// functions
-    pub fn prove_verify_append(&mut self, code: TokenStream) {
+    /// functions, the latter to be run before the `sent_params` are
+    /// deserialized
+    pub fn prove_verify_pre_params_append(&mut self, code: TokenStream) {
         let prove_code = &self.prove_code;
         self.prove_code = quote! {
             #prove_code
             #code
         };
-        let verify_code = &self.verify_code;
-        self.verify_code = quote! {
-            #verify_code
+        let verify_pre_params_code = &self.verify_pre_params_code;
+        self.verify_pre_params_code = quote! {
+            #verify_pre_params_code
             #code
         };
     }
@@ -404,7 +410,7 @@ impl CodeGen {
 
                 let chunks = self.sent_params.fields.iter().map(|sf| match sf {
                     StructField::Point(id) => quote! {
-                        let #id = sigma_rs::serialization::deserialize_elements(
+                        let #id: Point = sigma_rs::serialization::deserialize_elements(
                                 &proof[#offset_var..],
                                 1,
                             ).ok_or(SigmaError::VerificationFailure)?[0];
@@ -455,8 +461,10 @@ impl CodeGen {
         quote! {
             #[allow(non_snake_case)]
             pub mod #proto_name {
+                use sigma_compiler::group::Group;
+                use sigma_compiler::group::ff::{Field, PrimeField};
+                use sigma_compiler::group::ff::derive::subtle::ConditionallySelectable;
                 use sigma_compiler::rand::{CryptoRng, RngCore};
-                use sigma_compiler::group::ff::PrimeField;
                 use sigma_compiler::sigma_rs::errors::Error as SigmaError;
                 #dump_use
 

+ 220 - 46
sigma_compiler_core/src/rangeproof.rs

@@ -47,7 +47,7 @@ struct RangeStatement {
     upper: Expr,
     /// The expression that is being asserted that it is in the range.
     /// This must be a [`LinScalar`]
-    expr: LinScalar,
+    linscalar: LinScalar,
 }
 
 /// Subtract the Expr `lower` (with constant value `lowerval`, if
@@ -181,10 +181,7 @@ fn parse(vars: &TaggedVarDict, vardict: &VarDict, expr: &Expr) -> Option<RangeSt
                 linscalar.pub_scalar_expr = Some(pubscalar_expr);
             }
 
-            return Some(RangeStatement {
-                upper,
-                expr: linscalar,
-            });
+            return Some(RangeStatement { upper, linscalar });
         }
     }
     None
@@ -216,7 +213,7 @@ fn convert_commitment(
     if ped_assign_linscalar.coeff != 1 {
         let coeff_tokens = const_i128_tokens(ped_assign_linscalar.coeff);
         generated_code = quote! {
-            #coeff_tokens.inverse().unwrap() * #generated_code
+            <Scalar as Field>::invert(&#coeff_tokens).unwrap() * #generated_code
         };
     }
     // Now multiply by the coeff in new_linscalar, if present
@@ -259,7 +256,7 @@ fn convert_randomness(
     if ped_assign_linscalar.coeff != 1 {
         let coeff_tokens = const_i128_tokens(ped_assign_linscalar.coeff);
         generated_code = quote! {
-            #coeff_tokens.inverse().unwrap() * #generated_code
+            <Scalar as Field>::invert(&#coeff_tokens).unwrap() * #generated_code
         };
     }
     // Now multiply by the coeff in new_linscalar, if present
@@ -347,7 +344,7 @@ pub fn transform(
 
         // We'll need a Pedersen commitment to the variable in the range
         // statement.  See if there already is one.
-        let range_id = &range_stmt.expr.id;
+        let range_id = &range_stmt.linscalar.id;
         let ped_assign = if let Some(ped_assign) = pedersens.get(range_id) {
             ped_assign.clone()
         } else {
@@ -367,14 +364,14 @@ pub fn transform(
             let commitment_var = codegen.gen_point(
                 vars,
                 &format_ident!("range{}_{}_genC", range_stmt_index, range_id),
-                false,
-                true,
+                false, // is_vec
+                true,  // send_to_verifier
             );
             let rand_var = codegen.gen_scalar(
                 vars,
                 &format_ident!("range{}_{}_genr", range_stmt_index, range_id),
-                true,
-                false,
+                true,  // is_rand
+                false, // is_vec
             );
 
             // Update vardict and randoms with the new vars
@@ -407,35 +404,212 @@ pub fn transform(
         // code for just the prover that converts the randomness.
 
         // Make a new runtime variable to hold the converted commitment
-        let commitment_var = codegen.gen_point(
-            vars,
-            &format_ident!("range{}_{}_C", range_stmt_index, range_id),
-            false,
-            false,
-        );
-        let rand_var = codegen.gen_scalar(
-            vars,
-            &format_ident!("range{}_{}_r", range_stmt_index, range_id),
-            true,
-            false,
-        );
+        let commitment_var =
+            codegen.gen_ident(&format_ident!("range{}_{}_C", range_stmt_index, range_id));
+        let rand_var =
+            codegen.gen_ident(&format_ident!("range{}_{}_r", range_stmt_index, range_id));
 
         // Update vardict and randoms with the new vars
         vardict = taggedvardict_to_vardict(vars);
         randoms.insert(rand_var.to_string());
 
-        codegen.prove_verify_append(convert_commitment(
+        codegen.verify_append(convert_commitment(
             &commitment_var,
             &ped_assign,
-            &range_stmt.expr,
+            &range_stmt.linscalar,
             &vardict,
         )?);
         codegen.prove_append(convert_randomness(
             &rand_var,
             &ped_assign,
-            &range_stmt.expr,
+            &range_stmt.linscalar,
             &vardict,
         )?);
+
+        // Have both the prover and verifier compute the upper bound of
+        // the range, and generate the bitrep_scalar vector based on
+        // that upper bound.  The key to the range proof is that this
+        // bitrep_scalar vector has the property that you can write a
+        // Scalar x as a sum of (different) elements of this vector if
+        // and only if 0 <= x < upper.  The prover and verifier both
+        // know this vector (it depends only on upper, which is public).
+        // Then the prover will generate private bits that indicate
+        // which elements of the vector add up to x, and output
+        // commitments to those bits, along with proofs that each of
+        // those commitments indeed commits to a bit (0 or 1).  The
+        // verifier will check that the linear combination of the
+        // commitments to those bits with the elements of the
+        // bitrep_scalar vector yields the known commitment to x.
+        //
+        // As a small optimization, the commitment to the first bit
+        // (which always has a bitrep_scalar entry of 1) is not actually
+        // sent; instead of the verifier checking that the linear
+        // combination of the commitments equals the known commitment to
+        // x, it _computes_ the missing commitment to the first bit as
+        // the difference between the known commitment to x and the
+        // linear combination of the remaining commitments.  The prover
+        // still needs to prove that the value committed in that
+        // computed commitment is a bit, but does not need to send the
+        // commitment itself, saving a small bit of communication.
+
+        let upper_var = codegen.gen_ident(&format_ident!(
+            "range{}_{}_upper",
+            range_stmt_index,
+            range_id
+        ));
+        let upper_code = expr_type_tokens(&vardict, &range_stmt.upper)?.1;
+        let bitrep_scalars_var = codegen.gen_ident(&format_ident!(
+            "range{}_{}_bitrep_scalars",
+            range_stmt_index,
+            range_id
+        ));
+        let nbits_var = codegen.gen_ident(&format_ident!(
+            "range{}_{}_nbits",
+            range_stmt_index,
+            range_id
+        ));
+
+        codegen.prove_verify_pre_params_append(quote! {
+            let #upper_var = #upper_code;
+            let #bitrep_scalars_var =
+                sigma_compiler::rangeutils::bitrep_scalars_vartime(#upper_var)?;
+            if #bitrep_scalars_var.is_empty() {
+                // The upper bound was either less than 2, or more than
+                // i128::MAX
+                return Err(SigmaError::VerificationFailure);
+            }
+            let #nbits_var = #bitrep_scalars_var.len();
+        });
+
+        // The prover will compute the bit representation (which
+        // elements of the bitrep_scalars vector add up to x).  This
+        // should be done (in the prover code at runtime) in constant
+        // time.
+        let x_var = codegen.gen_ident(&format_ident!("range{}_{}_var", range_stmt_index, range_id));
+        let bitrep_var = codegen.gen_ident(&format_ident!(
+            "range{}_{}_bitrep",
+            range_stmt_index,
+            range_id
+        ));
+        let x_code = expr_type_tokens(&vardict, &range_stmt.linscalar.to_expr())?.1;
+        codegen.prove_append(quote! {
+            let #x_var = #x_code;
+            let #bitrep_var =
+                sigma_compiler::rangeutils::compute_bitrep(#x_var, &#bitrep_scalars_var);
+        });
+
+        // As mentioned above, we treat the first bit specially.  Make a
+        // vector of commitments to the rest of the bits to send to the
+        // verifier, and also a vector of the committed bits and a
+        // vector of randomnesses for the commitments, both for the
+        // witness, again not putting those for the first bit into the
+        // vectors.  Do make separate witness elements for the committed
+        // first bit and the randomness for it.
+        let bitcomm_var = codegen.gen_point(
+            vars,
+            &format_ident!("range{}_{}_bitC", range_stmt_index, range_id),
+            true, // is_vec
+            true, // send_to_verifier
+        );
+        let bits_var = codegen.gen_scalar(
+            vars,
+            &format_ident!("range{}_{}_bit", range_stmt_index, range_id),
+            false, // is_rand
+            true,  // is_vec
+        );
+        let bitrand_var = codegen.gen_scalar(
+            vars,
+            &format_ident!("range{}_{}_bitrand", range_stmt_index, range_id),
+            true, // is_rand
+            true, // is_vec
+        );
+        let firstbitcomm_var = codegen.gen_point(
+            vars,
+            &format_ident!("range{}_{}_firstbitC", range_stmt_index, range_id),
+            false, // is_vec
+            false, // send_to_verifier
+        );
+        let firstbit_var = codegen.gen_scalar(
+            vars,
+            &format_ident!("range{}_{}_firstbit", range_stmt_index, range_id),
+            false, // is_rand
+            false, // is_vec
+        );
+        let firstbitrand_var = codegen.gen_scalar(
+            vars,
+            &format_ident!("range{}_{}_firstbitrand", range_stmt_index, range_id),
+            true,  // is_rand
+            false, // is_vec
+        );
+
+        // Update vardict and randoms with the new vars
+        vardict = taggedvardict_to_vardict(vars);
+        randoms.insert(bitrand_var.to_string());
+        randoms.insert(firstbitrand_var.to_string());
+
+        // The generators used in the Pedersen commitment
+        let commit_generator = &ped_assign.pedersen.var_term.id;
+        let rand_generator = &ped_assign.pedersen.rand_term.id;
+
+        codegen.verify_pre_params_append(quote! {
+            let mut #bitcomm_var = Vec::<Point>::new();
+            #bitcomm_var.resize(#nbits_var - 1, Point::default());
+        });
+        // The prover code
+        codegen.prove_append(quote! {
+            // Map the bit representation to a vector of Scalar(0) and
+            // Scalar(1), but skip the first bit, as described above.
+            let #bits_var: Vec<Scalar> =
+                #bitrep_var
+                    .iter()
+                    .skip(1)
+                    .map(|b| Scalar::conditional_select(
+                        &Scalar::ZERO,
+                        &Scalar::ONE,
+                        *b,
+                    ))
+                    .collect();
+            // Choose randomizers for the commitments randomly
+            let #bitrand_var: Vec<Scalar> =
+                (0..(#nbits_var-1))
+                    .map(|_| Scalar::random(rng))
+                    .collect();
+            // Compute the commitments
+            let #bitcomm_var: Vec<Point> =
+                (0..(#nbits_var-1))
+                    .map(|i| #bits_var[i] * #commit_generator +
+                        #bitrand_var[i] * #rand_generator)
+                    .collect();
+            // The same as above, for for the first bit
+            let #firstbit_var =
+                Scalar::conditional_select(
+                    &Scalar::ZERO,
+                    &Scalar::ONE,
+                    #bitrep_var[0],
+                );
+            // Compute the randomness that would be needed in the first
+            // bit commitment so that the linear combination of all the
+            // bit commitments (with the scalars in bitrep_scalars) adds
+            // up to commitment_var.
+            let mut #firstbitrand_var = #rand_var;
+            for i in 0..(#nbits_var-1) {
+                #firstbitrand_var -=
+                    #bitrand_var[i] * #bitrep_scalars_var[i+1];
+            }
+            // Compute the first bit commitment
+            let #firstbitcomm_var =
+                #firstbit_var * #commit_generator +
+                #firstbitrand_var * #rand_generator;
+        });
+
+        // The verifier also needs to compute the first commitment
+        codegen.verify_append(quote! {
+            let mut #firstbitcomm_var = #commitment_var;
+            for i in 0..(#nbits_var-1) {
+                #firstbitcomm_var -=
+                    #bitcomm_var[i] * #bitrep_scalars_var[i+1];
+            }
+        });
     }
 
     Ok(())
@@ -471,7 +645,7 @@ mod tests {
             },
             Some(RangeStatement {
                 upper: parse_quote! { 100 },
-                expr: LinScalar {
+                linscalar: LinScalar {
                     coeff: 1,
                     pub_scalar_expr: None,
                     id: parse_quote! {x},
@@ -487,7 +661,7 @@ mod tests {
             },
             Some(RangeStatement {
                 upper: parse_quote! { 101i128 },
-                expr: LinScalar {
+                linscalar: LinScalar {
                     coeff: 1,
                     pub_scalar_expr: None,
                     id: parse_quote! {x},
@@ -503,7 +677,7 @@ mod tests {
             },
             Some(RangeStatement {
                 upper: parse_quote! { 112i128 },
-                expr: LinScalar {
+                linscalar: LinScalar {
                     coeff: 1,
                     pub_scalar_expr: Some(parse_quote! { 12i128 }),
                     id: parse_quote! {x},
@@ -519,7 +693,7 @@ mod tests {
             },
             Some(RangeStatement {
                 upper: parse_quote! { 1048588i128 },
-                expr: LinScalar {
+                linscalar: LinScalar {
                     coeff: 1,
                     pub_scalar_expr: Some(parse_quote! { 12i128 }),
                     id: parse_quote! {x},
@@ -535,7 +709,7 @@ mod tests {
             },
             Some(RangeStatement {
                 upper: parse_quote! { 1048564i128 },
-                expr: LinScalar {
+                linscalar: LinScalar {
                     coeff: 1,
                     pub_scalar_expr: Some(parse_quote! { -5i128 }),
                     id: parse_quote! {x},
@@ -551,7 +725,7 @@ mod tests {
             },
             Some(RangeStatement {
                 upper: parse_quote! { 1048564i128 },
-                expr: LinScalar {
+                linscalar: LinScalar {
                     coeff: 2,
                     pub_scalar_expr: Some(parse_quote! { -5i128 }),
                     id: parse_quote! {x},
@@ -567,7 +741,7 @@ mod tests {
             },
             Some(RangeStatement {
                 upper: parse_quote! { 170141183460469231731687303715884105727i128 },
-                expr: LinScalar {
+                linscalar: LinScalar {
                     coeff: 1,
                     pub_scalar_expr: Some(parse_quote! { 1i128 }),
                     id: parse_quote! {x},
@@ -583,7 +757,7 @@ mod tests {
             },
             Some(RangeStatement {
                 upper: parse_quote! { (((1<<126)-1)*2)-(-2) },
-                expr: LinScalar {
+                linscalar: LinScalar {
                     coeff: 1,
                     pub_scalar_expr: Some(parse_quote! { 2i128 }),
                     id: parse_quote! {x},
@@ -599,7 +773,7 @@ mod tests {
             },
             Some(RangeStatement {
                 upper: parse_quote! { b+c*c+7-(a*b) },
-                expr: LinScalar {
+                linscalar: LinScalar {
                     coeff: 3,
                     pub_scalar_expr: Some(parse_quote! { c*(a+b+2i128)-(a*b) }),
                     id: parse_quote! {x},
@@ -702,10 +876,10 @@ mod tests {
             parse_quote! { C = 3*x*A + r*B },
             parse_quote! { 2 * x + 12 + a },
             quote! { let out = Scalar::from_u128(2u128) *
-            Scalar::from_u128(3u128).inverse().unwrap() * C +
+            <Scalar as Field>::invert(&Scalar::from_u128(3u128)).unwrap() * C +
             (Scalar::from_u128(12u128) + a) * A; },
             quote! { let out_rand = Scalar::from_u128(2u128) *
-            Scalar::from_u128(3u128).inverse().unwrap() * r; },
+            <Scalar as Field>::invert(&Scalar::from_u128(3u128)).unwrap() * r; },
         );
 
         convert_commitment_randomness_tester(
@@ -714,10 +888,10 @@ mod tests {
             parse_quote! { C = -3*x*A + r*B },
             parse_quote! { 2 * x + 12 + a },
             quote! { let out = Scalar::from_u128(2u128) *
-            Scalar::from_u128(3u128).neg().inverse().unwrap() * C +
+            <Scalar as Field>::invert(&Scalar::from_u128(3u128).neg()).unwrap() * C +
             (Scalar::from_u128(12u128) + a) * A; },
             quote! { let out_rand = Scalar::from_u128(2u128) *
-            Scalar::from_u128(3u128).neg().inverse().unwrap() * r; },
+            <Scalar as Field>::invert(&Scalar::from_u128(3u128).neg()).unwrap() * r; },
         );
 
         convert_commitment_randomness_tester(
@@ -726,11 +900,11 @@ mod tests {
             parse_quote! { C = (-3*x+4+b)*A + r*B },
             parse_quote! { 2 * x + 12 + a },
             quote! { let out = Scalar::from_u128(2u128) *
-            Scalar::from_u128(3u128).neg().inverse().unwrap() *
+            <Scalar as Field>::invert(&Scalar::from_u128(3u128).neg()).unwrap() *
             (C - (Scalar::from_u128(4u128) + b) * A) +
             (Scalar::from_u128(12u128) + a) * A; },
             quote! { let out_rand = Scalar::from_u128(2u128) *
-            Scalar::from_u128(3u128).neg().inverse().unwrap() * r; },
+            <Scalar as Field>::invert(&Scalar::from_u128(3u128).neg()).unwrap() * r; },
         );
 
         convert_commitment_randomness_tester(
@@ -739,11 +913,11 @@ mod tests {
             parse_quote! { C = (-3*x+4+b)*A + 2*r*B },
             parse_quote! { 2 * x + 12 + a },
             quote! { let out = Scalar::from_u128(2u128) *
-            Scalar::from_u128(3u128).neg().inverse().unwrap() *
+            <Scalar as Field>::invert(&Scalar::from_u128(3u128).neg()).unwrap() *
             (C - (Scalar::from_u128(4u128) + b) * A) +
             (Scalar::from_u128(12u128) + a) * A; },
             quote! { let out_rand = Scalar::from_u128(2u128) *
-            Scalar::from_u128(3u128).neg().inverse().unwrap() *
+            <Scalar as Field>::invert(&Scalar::from_u128(3u128).neg()).unwrap() *
             (Scalar::from_u128(2u128) * r); },
         );
 
@@ -753,11 +927,11 @@ mod tests {
             parse_quote! { C = (-3*x+4+b)*A + (2*r+c-3)*B },
             parse_quote! { 2 * x + 12 + a },
             quote! { let out = Scalar::from_u128(2u128) *
-            Scalar::from_u128(3u128).neg().inverse().unwrap() *
+            <Scalar as Field>::invert(&Scalar::from_u128(3u128).neg()).unwrap() *
             (C - (Scalar::from_u128(4u128) + b) * A) +
             (Scalar::from_u128(12u128) + a) * A; },
             quote! { let out_rand = Scalar::from_u128(2u128) *
-            Scalar::from_u128(3u128).neg().inverse().unwrap() *
+            <Scalar as Field>::invert(&Scalar::from_u128(3u128).neg()).unwrap() *
             (Scalar::from_u128(2u128) * r +
             (c + (Scalar::from_u128(3u128).neg()))); },
         );