Browse Source

Start on range statements

Currently, just a parser for statements that look like

(0..100).contains(x)

with a number of variants.
Ian Goldberg 4 months ago
parent
commit
65afb16310

+ 1 - 0
sigma_compiler_core/src/lib.rs

@@ -10,6 +10,7 @@ pub mod sigma {
 }
 mod codegen;
 mod pedersen;
+mod rangeproof;
 mod syntax;
 mod transform;
 

+ 24 - 12
sigma_compiler_core/src/pedersen.rs

@@ -407,7 +407,9 @@ pub enum PedersenExpr {
     Pedersen(Pedersen),
 }
 
-/// A struct that implements [`AExprFold`] in service of [`recognize`]
+/// A struct that implements [`AExprFold`] in service of
+/// [`recognize_pedersen`], [`recognize_linscalar`], and
+/// [`recognize_pubscalar`]
 struct RecognizeFold<'a> {
     /// The [`TaggedVarDict`] that maps variable names to their types
     vars: &'a TaggedVarDict,
@@ -786,21 +788,23 @@ pub fn recognize_linscalar(
 /// evaluates to a public `Scalar`.
 ///
 /// The returned [`bool`] is true if the expression evaluates to a
-/// vector
+/// vector.  The [`i128`] is the value of the expression if it is a
+/// constant.
 pub fn recognize_pubscalar(
     vars: &TaggedVarDict,
     vardict: &VarDict,
     expr: &Expr,
-) -> Option<bool> {
+) -> Option<(bool, Option<i128>)> {
     let mut fold = RecognizeFold {
         vars,
         randoms: &HashSet::new(),
     };
-    let Ok((AExprType::Scalar{is_vec, ..}, PedersenExpr::PubScalarExpr(_)))
-        = fold.fold(vardict, expr) else {
+    let Ok((AExprType::Scalar { is_vec, val, .. }, PedersenExpr::PubScalarExpr(_))) =
+        fold.fold(vardict, expr)
+    else {
         return None;
     };
-    Some(is_vec)
+    Some((is_vec, val))
 }
 
 #[cfg(test)]
@@ -1364,7 +1368,7 @@ mod test {
     fn recognize_pubscalar_tester(
         vars: (&[&str], &[&str]),
         e: Expr,
-        expected_out: Option<bool>,
+        expected_out: Option<(bool, Option<i128>)>,
     ) {
         let taggedvardict = taggedvardict_from_strs(vars);
         let vardict = taggedvardict_to_vardict(&taggedvardict);
@@ -1376,7 +1380,15 @@ mod test {
     fn recognize_pubscalar_test() {
         let vars = (
             [
-                "x", "y", "z", "pub a", "pub vec b", "pub c", "rand r", "rand s", "rand t",
+                "x",
+                "y",
+                "z",
+                "pub a",
+                "pub vec b",
+                "pub c",
+                "rand r",
+                "rand s",
+                "rand t",
             ]
             .as_slice(),
             ["C", "cind A", "cind B"].as_slice(),
@@ -1395,7 +1407,7 @@ mod test {
             parse_quote! {
                 3
             },
-            Some(false),
+            Some((false, Some(3))),
         );
 
         recognize_pubscalar_tester(
@@ -1403,7 +1415,7 @@ mod test {
             parse_quote! {
                 a
             },
-            Some(false),
+            Some((false, None)),
         );
 
         recognize_pubscalar_tester(
@@ -1411,7 +1423,7 @@ mod test {
             parse_quote! {
                 3*(a + 1)
             },
-            Some(false),
+            Some((false, None)),
         );
 
         recognize_pubscalar_tester(
@@ -1419,7 +1431,7 @@ mod test {
             parse_quote! {
                 3*(a + b)
             },
-            Some(true),
+            Some((true, None)),
         );
     }
 }

+ 360 - 0
sigma_compiler_core/src/rangeproof.rs

@@ -0,0 +1,360 @@
+//! A module to transform range statements about `Scalar`s into
+//! statements about linear combinations of `Point`s.
+//!
+//! A range statement looks like `(a..b).contains(x-8)`, where `a` and
+//! `b` are expressions involving only _public_ `Scalar`s and constants
+//! and `x-8` is a private `Scalar`, possibly offset by a public
+//! `Scalar` or constant.  At this time, none of the variables can be
+//! vector variables.
+//!
+//! As usual for Rust notation, the range `a..b` includes `a` but
+//! _excludes_ `b`.  You can also write `a..=b` to include both
+//! endpoints.  It is allowed for the range to "wrap around" 0, so
+//! that `L-50..100` is a valid range, and equivalent to `-50..100`,
+//! where `L` is the order of the group you are using.
+//!
+//! The size of the range (`b-a`) will be known at run time, but not
+//! necessarily at compile time.  The size must fit in an [`i128`].
+//! Note that the range (and its size) are public, but the value you
+//! are stating is in the range will be private.
+
+use super::codegen::CodeGen;
+use super::pedersen::{recognize_linscalar, recognize_pubscalar, LinScalar};
+use super::sigma::combiners::*;
+use super::sigma::types::VarDict;
+use super::syntax::taggedvardict_to_vardict;
+use super::transform::paren_if_needed;
+use super::TaggedVarDict;
+use syn::{parse_quote, Expr, Result};
+
+/// A struct representing a normalized parsed range statement.
+///
+/// Here, "normalized" means that the range is adjusted so that the
+/// lower bound is 0.  This is accomplished by subtracting the stated
+/// lower bound from both the upper bound and the expression that is
+/// being asserting that it is in the range.
+#[derive(Clone, Debug, PartialEq, Eq)]
+struct RangeStatement {
+    /// The upper bound of the range (exclusive).  This must evaluate to
+    /// a public Scalar.
+    upper: Expr,
+    /// The expression that is being asserted that it is in the range.
+    /// This must be a [`LinScalar`]
+    expr: LinScalar,
+}
+
+/// Subtract the Expr `lower` (with constant value `lowerval`, if
+/// present) from the Expr `expr` (with constant value `exprval`, if
+/// present).  Return the resulting expression, as well as its constant
+/// value, if there is one.  Do the subtraction numerically if possible,
+/// but otherwise symbolically.
+fn subtract_expr(
+    expr: Option<&Expr>,
+    exprval: Option<i128>,
+    lower: &Expr,
+    lowerval: Option<i128>,
+) -> (Expr, Option<i128>) {
+    // Note that if expr is None, then exprval is Some(0)
+    if let (Some(ev), Some(lv)) = (exprval, lowerval) {
+        if let Some(diffv) = ev.checked_sub(lv) {
+            // We can do the subtraction numerically
+            return (parse_quote! { #diffv }, Some(diffv));
+        }
+    }
+    let paren_lower = paren_if_needed(lower.clone());
+    // Return the difference symbolically
+    (
+        if let Some(e) = expr {
+            parse_quote! { #e - #paren_lower }
+        } else {
+            parse_quote! { -#paren_lower }
+        },
+        None,
+    )
+}
+
+/// Try to parse the given `Expr` as a range statement
+fn parse(vars: &TaggedVarDict, vardict: &VarDict, expr: &Expr) -> Option<RangeStatement> {
+    // The expression needs to be of the form
+    // (lower..upper).contains(expr)
+    // The "top level" must be the method call ".contains"
+    if let Expr::MethodCall(syn::ExprMethodCall {
+        receiver,
+        method,
+        turbofish: None,
+        args,
+        ..
+    }) = expr
+    {
+        if &method.to_string() != "contains" {
+            // Wasn't ".contains"
+            return None;
+        }
+        // Remove parens around the range, if present
+        let mut range_expr = receiver.as_ref();
+        if let Expr::Paren(syn::ExprParen {
+            expr: parened_expr, ..
+        }) = range_expr
+        {
+            range_expr = parened_expr;
+        }
+        // Parse the range
+        if let Expr::Range(syn::ExprRange {
+            start, limits, end, ..
+        }) = range_expr
+        {
+            // The endpoints of the range need to be non-vector public
+            // Scalar expressions
+            // The first as_ref() turns &Option<Box<Expr>> into
+            // Option<&Box<Expr>>.  The ? removes the Option, and the
+            // second as_ref() turns &Box<Expr> into &Expr.
+            let lower = start.as_ref()?.as_ref().clone();
+            let mut upper = end.as_ref()?.as_ref().clone();
+            let Some((false, lowerval)) = recognize_pubscalar(vars, vardict, &lower) else {
+                return None;
+            };
+            let Some((false, mut upperval)) = recognize_pubscalar(vars, vardict, &upper) else {
+                return None;
+            };
+            let inclusive_upper = matches!(limits, syn::RangeLimits::Closed(_));
+            // There needs to be exactly one argument of .contains()
+            if args.len() != 1 {
+                return None;
+            }
+            // The private expression needs to be a LinScalar
+            let priv_expr = args.first().unwrap();
+            let mut linscalar = recognize_linscalar(vars, vardict, priv_expr)?;
+            // It is.  See if the pub_scalar_expr in the LinScalar has a
+            // constant value
+            let linscalar_pubscalar_val = if let Some(ref pse) = linscalar.pub_scalar_expr {
+                let Some((false, pubscalar_val)) = recognize_pubscalar(vars, vardict, pse) else {
+                    return None;
+                };
+                pubscalar_val
+            } else {
+                Some(0)
+            };
+
+            // We have a valid range statement.  Normalize it by forcing
+            // the upper bound to be exclusive, and the lower bound to
+            // be 0.
+
+            // If the range was inclusive of the upper bound (e.g.,
+            // `0..=100`), add 1 to the upper bound to make it exclusive
+            // (e.g., `0..101`).
+            if inclusive_upper {
+                // Add 1 to the upper bound, numerically if possible,
+                // but otherwise symbolically
+                let mut added_numerically = false;
+                if let Some(uv) = upperval {
+                    if let Some(new_uv) = uv.checked_add(1) {
+                        upper = parse_quote! { #new_uv };
+                        upperval = Some(new_uv);
+                        added_numerically = true;
+                    }
+                }
+                if !added_numerically {
+                    upper = parse_quote! { #upper + 1 };
+                    upperval = None;
+                }
+            }
+
+            // If the lower bound is not 0, subtract it from both the
+            // upper bound and the pubscalar_expr in the LinScalar.  Do
+            // this numericaly if possibly, but otherwise symbolically.
+            if lowerval != Some(0) {
+                (upper, _) = subtract_expr(Some(&upper), upperval, &lower, lowerval);
+                let pubscalar_expr;
+                (pubscalar_expr, _) = subtract_expr(
+                    linscalar.pub_scalar_expr.as_ref(),
+                    linscalar_pubscalar_val,
+                    &lower,
+                    lowerval,
+                );
+                linscalar.pub_scalar_expr = Some(pubscalar_expr);
+            }
+
+            return Some(RangeStatement {
+                upper,
+                expr: linscalar,
+            });
+        }
+    }
+    None
+}
+
+/// Look for, and transform, range statements specified in the
+/// [`StatementTree`] into basic statements about linear combinations of
+/// `Point`s.
+pub fn transform(
+    codegen: &mut CodeGen,
+    st: &mut StatementTree,
+    vars: &mut TaggedVarDict,
+) -> Result<()> {
+    // Make the VarDict version of the variable dictionary
+    let vardict = taggedvardict_to_vardict(vars);
+
+    // Gather mutable references to all Exprs in the leaves of the
+    // StatementTree.  Note that this ignores the combiner structure in
+    // the StatementTree, but that's fine.
+    let mut leaves = st.leaves_mut();
+
+    // For each leaf expression, see if it looks like a range statement
+    for leafexpr in leaves.iter_mut() {
+        let is_range = parse(vars, &vardict, leafexpr);
+    }
+
+    Ok(())
+}
+
+#[cfg(test)]
+mod tests {
+    use super::super::syntax::taggedvardict_from_strs;
+    use super::*;
+
+    fn parse_tester(vars: (&[&str], &[&str]), expr: Expr, expect: Option<RangeStatement>) {
+        let taggedvardict = taggedvardict_from_strs(vars);
+        let vardict = taggedvardict_to_vardict(&taggedvardict);
+        let output = parse(&taggedvardict, &vardict, &expr);
+        assert_eq!(output, expect);
+    }
+
+    #[test]
+    fn parse_test() {
+        let vars = (
+            [
+                "x", "y", "z", "pub a", "pub b", "pub c", "rand r", "rand s", "rand t",
+            ]
+            .as_slice(),
+            ["C", "cind A", "cind B"].as_slice(),
+        );
+
+        parse_tester(
+            vars,
+            parse_quote! {
+                (0..100).contains(x)
+            },
+            Some(RangeStatement {
+                upper: parse_quote! { 100 },
+                expr: LinScalar {
+                    coeff: 1,
+                    pub_scalar_expr: None,
+                    id: parse_quote! {x},
+                    is_vec: false,
+                },
+            }),
+        );
+
+        parse_tester(
+            vars,
+            parse_quote! {
+                (0..=100).contains(x)
+            },
+            Some(RangeStatement {
+                upper: parse_quote! { 101i128 },
+                expr: LinScalar {
+                    coeff: 1,
+                    pub_scalar_expr: None,
+                    id: parse_quote! {x},
+                    is_vec: false,
+                },
+            }),
+        );
+
+        parse_tester(
+            vars,
+            parse_quote! {
+                (-12..100).contains(x)
+            },
+            Some(RangeStatement {
+                upper: parse_quote! { 112i128 },
+                expr: LinScalar {
+                    coeff: 1,
+                    pub_scalar_expr: Some(parse_quote! { 12i128 }),
+                    id: parse_quote! {x},
+                    is_vec: false,
+                },
+            }),
+        );
+
+        parse_tester(
+            vars,
+            parse_quote! {
+                (-12..(1<<20)).contains(x)
+            },
+            Some(RangeStatement {
+                upper: parse_quote! { 1048588i128 },
+                expr: LinScalar {
+                    coeff: 1,
+                    pub_scalar_expr: Some(parse_quote! { 12i128 }),
+                    id: parse_quote! {x},
+                    is_vec: false,
+                },
+            }),
+        );
+
+        parse_tester(
+            vars,
+            parse_quote! {
+                (12..(1<<20)).contains(x+7)
+            },
+            Some(RangeStatement {
+                upper: parse_quote! { 1048564i128 },
+                expr: LinScalar {
+                    coeff: 1,
+                    pub_scalar_expr: Some(parse_quote! { -5i128 }),
+                    id: parse_quote! {x},
+                    is_vec: false,
+                },
+            }),
+        );
+
+        parse_tester(
+            vars,
+            parse_quote! {
+                (12..(1<<20)).contains(2*x+7)
+            },
+            Some(RangeStatement {
+                upper: parse_quote! { 1048564i128 },
+                expr: LinScalar {
+                    coeff: 2,
+                    pub_scalar_expr: Some(parse_quote! { -5i128 }),
+                    id: parse_quote! {x},
+                    is_vec: false,
+                },
+            }),
+        );
+
+        parse_tester(
+            vars,
+            parse_quote! {
+                (-1..(((1<<126)-1)*2)).contains(x)
+            },
+            Some(RangeStatement {
+                upper: parse_quote! { 170141183460469231731687303715884105727i128 },
+                expr: LinScalar {
+                    coeff: 1,
+                    pub_scalar_expr: Some(parse_quote! { 1i128 }),
+                    id: parse_quote! {x},
+                    is_vec: false,
+                },
+            }),
+        );
+
+        parse_tester(
+            vars,
+            parse_quote! {
+                (-2..(((1<<126)-1)*2)).contains(x)
+            },
+            Some(RangeStatement {
+                upper: parse_quote! { (((1<<126)-1)*2)-(-2) },
+                expr: LinScalar {
+                    coeff: 1,
+                    pub_scalar_expr: Some(parse_quote! { 2i128 }),
+                    id: parse_quote! {x},
+                    is_vec: false,
+                },
+            }),
+        );
+    }
+}