Sfoglia il codice sorgente

Convert arithmetic expressions of integer constants and (Scalar or Point) vars into a valid Rust expression

Doesn't work for vec variables yet.
Ian Goldberg 3 mesi fa
parent
commit
673157e12c

+ 1 - 1
sigma_compiler_core/src/codegen.rs

@@ -337,7 +337,6 @@ impl CodeGen {
         // Output the generated module for this protocol
         let dump_use = if cfg!(feature = "dump") {
             quote! {
-                use ff::PrimeField;
                 use group::GroupEncoding;
             }
         } else {
@@ -346,6 +345,7 @@ impl CodeGen {
         quote! {
             #[allow(non_snake_case)]
             pub mod #proto_name {
+                use ff::PrimeField;
                 #dump_use
 
                 #group_types

+ 218 - 59
sigma_compiler_core/src/sigma/types.rs

@@ -8,6 +8,8 @@
 //! function for determining the type of arithmetic expressions
 //! involving such variables.
 
+use proc_macro2::TokenStream;
+use quote::quote;
 use std::collections::HashMap;
 use syn::parse::Result;
 use syn::spanned::Spanned;
@@ -109,6 +111,18 @@ pub fn vardict_from_strs(strs: &[(&str, &str)]) -> VarDict {
     VarDict::from_iter(c)
 }
 
+/// Given an [`i128`] value, output a [`TokenStream`] representing a
+/// valid Rust expression that evaluates to a `Scalar` having that
+/// value.
+fn const_i128_tokens(val: i128) -> TokenStream {
+    let uval = val.unsigned_abs();
+    if val >= 0 {
+        quote! { Scalar::from_u128(#uval) }
+    } else {
+        quote! { Scalar::from_u128(#uval).neg() }
+    }
+}
+
 /// Given a [`VarDict`] and an [`Expr`] representing an arithmetic
 /// expression using the variables in the [`VarDict`], compute the
 /// [`AExprType`] of the expression.
@@ -121,36 +135,86 @@ pub fn vardict_from_strs(strs: &[(&str, &str)]) -> VarDict {
 ///     variables
 ///   - parens
 pub fn expr_type(vars: &VarDict, expr: &Expr) -> Result<AExprType> {
+    Ok(expr_type_tokens(vars, expr)?.0)
+}
+
+/// Given a [`VarDict`] and an [`Expr`] representing an arithmetic
+/// expression using the variables in the [`VarDict`], compute the
+/// [`AExprType`] of the expression and also a valid Rust
+/// [`TokenStream`] that evaluates the expression.
+///
+/// An arithmetic expression can consist of:
+///   - variables that are in the [`VarDict`]
+///   - integer constants
+///   - the operations `*`, `+`, `-` (binary or unary)
+///   - the operation `<<` where both operands are expressions with no
+///     variables
+///   - parens
+pub fn expr_type_tokens(vars: &VarDict, expr: &Expr) -> Result<(AExprType, TokenStream)> {
     match expr {
         Expr::Lit(syn::ExprLit {
             lit: syn::Lit::Int(litint),
             ..
-        }) => Ok(AExprType::Scalar {
-            is_pub: true,
-            is_vec: false,
-            val: litint.base10_parse::<i128>().ok(),
-        }),
+        }) => {
+            let val = litint.base10_parse::<i128>().ok();
+            if let Some(val_i128) = val {
+                Ok((
+                    AExprType::Scalar {
+                        is_pub: true,
+                        is_vec: false,
+                        val,
+                    },
+                    const_i128_tokens(val_i128),
+                ))
+            } else {
+                Err(Error::new(expr.span(), "int literal does not fit in i128"))
+            }
+        }
         Expr::Unary(syn::ExprUnary {
             op: syn::UnOp::Neg(_),
             expr,
             ..
-        }) => match expr_type(vars, expr.as_ref()) {
-            Ok(AExprType::Scalar {
-                is_pub: true,
-                is_vec: false,
-                val: Some(v),
-            }) => Ok(AExprType::Scalar {
-                is_pub: true,
-                is_vec: false,
-                val: v.checked_neg(),
-            }),
-            other => other,
+        }) => match expr_type_tokens(vars, expr.as_ref()) {
+            Ok((
+                AExprType::Scalar {
+                    is_pub: true,
+                    is_vec: false,
+                    val: Some(v),
+                },
+                le,
+            )) => {
+                // If v happens to be i128::MIN, then -v isn't an i128.
+                if let Some(negv) = v.checked_neg() {
+                    Ok((
+                        AExprType::Scalar {
+                            is_pub: true,
+                            is_vec: false,
+                            val: Some(negv),
+                        },
+                        const_i128_tokens(negv),
+                    ))
+                } else {
+                    Ok((
+                        AExprType::Scalar {
+                            is_pub: true,
+                            is_vec: false,
+                            val: None,
+                        },
+                        quote! { -#le },
+                    ))
+                }
+            }
+            Ok((other, le)) => Ok((other, quote! { -#le })),
+            Err(err) => Err(err),
+        },
+        Expr::Paren(syn::ExprParen { expr, .. }) => match expr_type_tokens(vars, expr.as_ref()) {
+            Ok((aetype, ex)) => Ok((aetype, quote! { (#ex) })),
+            Err(err) => Err(err),
         },
-        Expr::Paren(syn::ExprParen { expr, .. }) => expr_type(vars, expr.as_ref()),
         Expr::Path(syn::ExprPath { path, .. }) => {
             if let Some(id) = path.get_ident() {
                 if let Some(&vt) = vars.get(&id.to_string()) {
-                    return Ok(vt);
+                    return Ok((vt, quote! { #id }));
                 }
             }
             Err(Error::new(expr.span(), "not a known variable"))
@@ -160,8 +224,17 @@ pub fn expr_type(vars: &VarDict, expr: &Expr) -> Result<AExprType> {
         }) => {
             match op {
                 syn::BinOp::Add(_) | syn::BinOp::Sub(_) => {
-                    let lt = expr_type(vars, left.as_ref())?;
-                    let rt = expr_type(vars, right.as_ref())?;
+                    let (lt, le) = expr_type_tokens(vars, left.as_ref())?;
+                    let (rt, re) = expr_type_tokens(vars, right.as_ref())?;
+                    let default_tokens = match op {
+                        syn::BinOp::Add(_) => quote! { #le + #re },
+                        syn::BinOp::Sub(_) => quote! { #le - #re },
+                        // The default match can't happen
+                        // because we're already inside a match
+                        // on op, but the compiler requires it
+                        // anyway
+                        _ => quote! {0},
+                    };
                     // You can add or subtract two Scalars or two
                     // Points, but not a Scalar and a Point.  The result
                     // is public if both arguments are public.  The
@@ -180,23 +253,31 @@ pub fn expr_type(vars: &VarDict, expr: &Expr) -> Result<AExprType> {
                                 val: rval,
                             },
                         ) => {
-                            return Ok(AExprType::Scalar {
-                                is_pub: lpub && rpub,
-                                is_vec: lvec || rvec,
-                                val: if let (Some(lv), Some(rv)) = (lval, rval) {
-                                    match op {
-                                        syn::BinOp::Add(_) => lv.checked_add(rv),
-                                        syn::BinOp::Sub(_) => lv.checked_sub(rv),
-                                        // The default match can't
-                                        // happen because we're already
-                                        // inside a match on op, but the
-                                        // compiler requires it anyway
-                                        _ => None,
-                                    }
+                            let val = if let (Some(lv), Some(rv)) = (lval, rval) {
+                                match op {
+                                    syn::BinOp::Add(_) => lv.checked_add(rv),
+                                    syn::BinOp::Sub(_) => lv.checked_sub(rv),
+                                    // The default match can't
+                                    // happen because we're already
+                                    // inside a match on op, but the
+                                    // compiler requires it anyway
+                                    _ => None,
+                                }
+                            } else {
+                                None
+                            };
+                            return Ok((
+                                AExprType::Scalar {
+                                    is_pub: lpub && rpub,
+                                    is_vec: lvec || rvec,
+                                    val,
+                                },
+                                if let Some(v) = val {
+                                    const_i128_tokens(v)
                                 } else {
-                                    None
+                                    default_tokens
                                 },
-                            });
+                            ));
                         }
                         (
                             AExprType::Point {
@@ -208,10 +289,13 @@ pub fn expr_type(vars: &VarDict, expr: &Expr) -> Result<AExprType> {
                                 is_vec: rvec,
                             },
                         ) => {
-                            return Ok(AExprType::Point {
-                                is_pub: lpub && rpub,
-                                is_vec: lvec || rvec,
-                            });
+                            return Ok((
+                                AExprType::Point {
+                                    is_pub: lpub && rpub,
+                                    is_vec: lvec || rvec,
+                                },
+                                default_tokens,
+                            ));
                         }
                         _ => {}
                     }
@@ -221,8 +305,9 @@ pub fn expr_type(vars: &VarDict, expr: &Expr) -> Result<AExprType> {
                     ));
                 }
                 syn::BinOp::Mul(_) => {
-                    let lt = expr_type(vars, left.as_ref())?;
-                    let rt = expr_type(vars, right.as_ref())?;
+                    let (lt, le) = expr_type_tokens(vars, left.as_ref())?;
+                    let (rt, re) = expr_type_tokens(vars, right.as_ref())?;
+                    let default_tokens = quote! { #le * #re };
                     // You can multiply two Scalars or a Scalar and a
                     // Point, but not two Points.  You can also not
                     // multiply two private expressions.  The result is
@@ -247,15 +332,23 @@ pub fn expr_type(vars: &VarDict, expr: &Expr) -> Result<AExprType> {
                                     "cannot multiply two private expressions",
                                 ));
                             }
-                            return Ok(AExprType::Scalar {
-                                is_pub: lpub && rpub,
-                                is_vec: lvec || rvec,
-                                val: if let (Some(lv), Some(rv)) = (lval, rval) {
-                                    lv.checked_mul(rv)
+                            let val = if let (Some(lv), Some(rv)) = (lval, rval) {
+                                lv.checked_mul(rv)
+                            } else {
+                                None
+                            };
+                            return Ok((
+                                AExprType::Scalar {
+                                    is_pub: lpub && rpub,
+                                    is_vec: lvec || rvec,
+                                    val,
+                                },
+                                if let Some(v) = val {
+                                    const_i128_tokens(v)
                                 } else {
-                                    None
+                                    default_tokens
                                 },
-                            });
+                            ));
                         }
                         (
                             AExprType::Scalar {
@@ -285,10 +378,13 @@ pub fn expr_type(vars: &VarDict, expr: &Expr) -> Result<AExprType> {
                                     "cannot multiply two private expressions",
                                 ));
                             }
-                            return Ok(AExprType::Point {
-                                is_pub: lpub && rpub,
-                                is_vec: lvec || rvec,
-                            });
+                            return Ok((
+                                AExprType::Point {
+                                    is_pub: lpub && rpub,
+                                    is_vec: lvec || rvec,
+                                },
+                                default_tokens,
+                            ));
                         }
                         _ => {}
                     }
@@ -317,17 +413,21 @@ pub fn expr_type(vars: &VarDict, expr: &Expr) -> Result<AExprType> {
                     {
                         let rvu32: Option<u32> = rv.try_into().ok();
                         if let Some(shift_amt) = rvu32 {
-                            return Ok(AExprType::Scalar {
-                                is_pub: true,
-                                is_vec: false,
-                                val: lv.checked_shl(shift_amt),
-                            });
+                            if let Some(v) = lv.checked_shl(shift_amt) {
+                                return Ok((
+                                    AExprType::Scalar {
+                                        is_pub: true,
+                                        is_vec: false,
+                                        val: Some(v),
+                                    },
+                                    const_i128_tokens(v),
+                                ));
+                            }
                         }
                     }
                     return Err(Error::new(
                         expr.span(),
-                        "can shift left only on constant
-                            expressions",
+                        "can shift left only on constant i128 expressions",
                     ));
                 }
                 _ => {}
@@ -361,6 +461,13 @@ mod tests {
         );
     }
 
+    fn check_tokens(vars: &VarDict, expr: Expr, expect: TokenStream) {
+        assert_eq!(
+            expr_type_tokens(vars, &expr).unwrap().1.to_string(),
+            expect.to_string()
+        );
+    }
+
     fn check_fail(vars: &VarDict, expr: Expr) {
         expr_type(vars, &expr).unwrap_err();
     }
@@ -378,6 +485,58 @@ mod tests {
         check(&vars, parse_quote! {A*3}, "pP");
         check(&vars, parse_quote! {(a-1)*(A+A)}, "P");
         check(&vars, parse_quote! {(v-1)*(A+A)}, "vP");
+        check_tokens(
+            &vars,
+            parse_quote! { 0 },
+            quote! { Scalar::from_u128(0u128) },
+        );
+        check_tokens(
+            &vars,
+            parse_quote! { 5 },
+            quote! { Scalar::from_u128(5u128) },
+        );
+        check_tokens(
+            &vars,
+            parse_quote! { -77 },
+            quote! { Scalar::from_u128(77u128).neg() },
+        );
+        check_tokens(
+            &vars,
+            parse_quote! { 1<<20 },
+            quote! {
+            Scalar::from_u128(1048576u128) },
+        );
+        check_tokens(
+            &vars,
+            parse_quote! { (3-2)<<(4*5) },
+            quote! {
+            Scalar::from_u128(1048576u128) },
+        );
+        check_tokens(
+            &vars,
+            parse_quote! { 127<<120 },
+            quote! {
+            Scalar::from_u128(168811955464684315858783496655603761152u128) },
+        );
+        check_tokens(
+            &vars,
+            parse_quote! { -(-170141183460469231731687303715884105727) },
+            quote! {
+            Scalar::from_u128(170141183460469231731687303715884105727u128) },
+        );
+        // -2^127 fits in an i128, but the negative of that does not
+        check_tokens(
+            &vars,
+            parse_quote! { -(-170141183460469231731687303715884105727-1) },
+            quote! {
+            -(Scalar::from_u128(170141183460469231731687303715884105728u128).neg()) },
+        );
+        check_tokens(
+            &vars,
+            parse_quote! {(a-(2-3))*(A+(3*4)*A)},
+            quote! {
+            (a-(Scalar::from_u128(1u128).neg()))*(A+(Scalar::from_u128(12u128))*A) },
+        );
 
         // Tests that should fail
 

+ 14 - 8
sigma_compiler_core/src/transform.rs

@@ -5,6 +5,7 @@
 
 use super::codegen::CodeGen;
 use super::sigma::combiners::*;
+use super::sigma::types::expr_type_tokens;
 use super::syntax::taggedvardict_to_vardict;
 use super::{TaggedIdent, TaggedScalar, TaggedVarDict};
 use quote::quote;
@@ -219,6 +220,9 @@ pub fn apply_substitutions(
     st: &mut StatementTree,
     vars: &mut TaggedVarDict,
 ) -> Result<()> {
+    // Construct the VarDict corresponding to vars
+    let vardict = taggedvardict_to_vardict(vars);
+
     // Gather mutable references to all Exprs in the leaves of the
     // StatementTree.  Note that this ignores the combiner structure in
     // the StatementTree, but that's fine.
@@ -251,15 +255,17 @@ pub fn apply_substitutions(
             std::mem::swap(&mut expr, *leafexpr);
             // This "if let" is guaranteed to succeed
             if let Expr::Assign(syn::ExprAssign { right, .. }) = expr {
-                let used_priv_scalars = priv_scalar_set(&right, vars);
-                if !subs_vars.insert(id.to_string()) {
-                    return Err(Error::new(id.span(), "variable substituted multiple times"));
+                if let Ok((_, right_tokens)) = expr_type_tokens(&vardict, &right) {
+                    let used_priv_scalars = priv_scalar_set(&right, vars);
+                    if !subs_vars.insert(id.to_string()) {
+                        return Err(Error::new(id.span(), "variable substituted multiple times"));
+                    }
+                    codegen.prove_append(quote! {
+                        assert!(#id == #right_tokens);
+                    });
+                    let right = paren_if_needed(*right);
+                    subs.push_back((id, right, used_priv_scalars));
                 }
-                let right = paren_if_needed(*right);
-                codegen.prove_append(quote! {
-                    assert!(#id == #right);
-                });
-                subs.push_back((id, right, used_priv_scalars));
             }
         }
     }