소스 검색

Abstract arithmetic expression parsing into an AExprFold trait

Ian Goldberg 8 달 전
부모
커밋
f0a54a8538
1개의 변경된 파일564개의 추가작업 그리고 262개의 파일을 삭제
  1. 564 262
      sigma_compiler_core/src/sigma/types.rs

+ 564 - 262
sigma_compiler_core/src/sigma/types.rs

@@ -13,7 +13,7 @@ use quote::quote;
 use std::collections::HashMap;
 use syn::parse::Result;
 use syn::spanned::Spanned;
-use syn::{Error, Expr};
+use syn::{Error, Expr, Ident};
 
 /// The possible types of an arithmetic expression over `Scalar`s and
 /// `Point`s.  Each expression has type either
@@ -123,25 +123,12 @@ fn const_i128_tokens(val: i128) -> TokenStream {
     }
 }
 
-/// Given a [`VarDict`] and an [`Expr`] representing an arithmetic
-/// expression using the variables in the [`VarDict`], compute the
-/// [`AExprType`] of the expression.
+/// A trait for fold-style evaluations over arithmetic expressions.
 ///
-/// An arithmetic expression can consist of:
-///   - variables that are in the [`VarDict`]
-///   - integer constants
-///   - the operations `*`, `+`, `-` (binary or unary)
-///   - the operation `<<` where both operands are expressions with no
-///     variables
-///   - parens
-pub fn expr_type(vars: &VarDict, expr: &Expr) -> Result<AExprType> {
-    Ok(expr_type_tokens(vars, expr)?.0)
-}
-
-/// Given a [`VarDict`] and an [`Expr`] representing an arithmetic
-/// expression using the variables in the [`VarDict`], compute the
-/// [`AExprType`] of the expression and also a valid Rust
-/// [`TokenStream`] that evaluates the expression.
+/// The parameter `T` is the type you want to return.
+/// All functions take `(AExprType, T)` for each of the components of
+/// the arithmetic expression node, as well as the [`AExprType`] for the
+/// resulting value.
 ///
 /// An arithmetic expression can consist of:
 ///   - variables that are in the [`VarDict`]
@@ -150,297 +137,606 @@ pub fn expr_type(vars: &VarDict, expr: &Expr) -> Result<AExprType> {
 ///   - the operation `<<` where both operands are expressions with no
 ///     variables
 ///   - parens
-pub fn expr_type_tokens(vars: &VarDict, expr: &Expr) -> Result<(AExprType, TokenStream)> {
-    match expr {
-        Expr::Lit(syn::ExprLit {
-            lit: syn::Lit::Int(litint),
-            ..
-        }) => {
-            let val = litint.base10_parse::<i128>().ok();
-            if let Some(val_i128) = val {
+pub trait AExprFold<T> {
+    /// Called when an identifier found in the [`VarDict`] is
+    /// encountered in the [`Expr`]
+    fn ident(&mut self, id: &Ident, restype: AExprType) -> Result<T>;
+
+    /// Called when the arithmetic expression evaluates to a constant
+    /// [`i128`] value.
+    fn const_i128(&mut self, restype: AExprType) -> Result<T>;
+
+    /// Called for unary negation
+    fn neg(&mut self, arg: (AExprType, T), restype: AExprType) -> Result<T>;
+
+    /// Called for a parenthesized expression
+    fn paren(&mut self, arg: (AExprType, T), restype: AExprType) -> Result<T>;
+
+    /// Called when adding two `Scalar`s
+    fn add_scalars(
+        &mut self,
+        larg: (AExprType, T),
+        rarg: (AExprType, T),
+        restype: AExprType,
+    ) -> Result<T>;
+
+    /// Called when adding two `Point`s
+    fn add_points(
+        &mut self,
+        larg: (AExprType, T),
+        rarg: (AExprType, T),
+        restype: AExprType,
+    ) -> Result<T>;
+
+    /// Called when subtracting two `Scalar`s
+    fn sub_scalars(
+        &mut self,
+        larg: (AExprType, T),
+        rarg: (AExprType, T),
+        restype: AExprType,
+    ) -> Result<T>;
+
+    /// Called when subtracting two `Point`s
+    fn sub_points(
+        &mut self,
+        larg: (AExprType, T),
+        rarg: (AExprType, T),
+        restype: AExprType,
+    ) -> Result<T>;
+
+    /// Called when multiplying two `Scalar`s
+    fn mul_scalars(
+        &mut self,
+        larg: (AExprType, T),
+        rarg: (AExprType, T),
+        restype: AExprType,
+    ) -> Result<T>;
+
+    /// Called when multiplying a `Scalar` and a `Point` (the `Scalar`
+    /// will always be passed as the first argument)
+    fn mul_scalar_point(
+        &mut self,
+        sarg: (AExprType, T),
+        parg: (AExprType, T),
+        restype: AExprType,
+    ) -> Result<T>;
+
+    /// Recursively process an arithmetic expression given by the
+    /// [`Expr`]
+    fn fold(&mut self, vars: &VarDict, expr: &Expr) -> Result<(AExprType, T)> {
+        match expr {
+            Expr::Lit(syn::ExprLit {
+                lit: syn::Lit::Int(litint),
+                ..
+            }) => {
+                let val = litint.base10_parse::<i128>().ok();
+                if val.is_some() {
+                    let restype = AExprType::Scalar {
+                        is_pub: true,
+                        is_vec: false,
+                        val,
+                    };
+                    let res = self.const_i128(restype)?;
+                    Ok((restype, res))
+                } else {
+                    Err(Error::new(expr.span(), "int literal does not fit in i128"))
+                }
+            }
+            Expr::Unary(syn::ExprUnary {
+                op: syn::UnOp::Neg(_),
+                expr,
+                ..
+            }) => match self.fold(vars, expr.as_ref()) {
                 Ok((
                     AExprType::Scalar {
                         is_pub: true,
                         is_vec: false,
-                        val,
+                        val: Some(v),
                     },
-                    const_i128_tokens(val_i128),
-                ))
-            } else {
-                Err(Error::new(expr.span(), "int literal does not fit in i128"))
-            }
-        }
-        Expr::Unary(syn::ExprUnary {
-            op: syn::UnOp::Neg(_),
-            expr,
-            ..
-        }) => match expr_type_tokens(vars, expr.as_ref()) {
-            Ok((
-                AExprType::Scalar {
-                    is_pub: true,
-                    is_vec: false,
-                    val: Some(v),
-                },
-                le,
-            )) => {
-                // If v happens to be i128::MIN, then -v isn't an i128.
-                if let Some(negv) = v.checked_neg() {
-                    Ok((
-                        AExprType::Scalar {
+                    le,
+                )) => {
+                    // If v happens to be i128::MIN, then -v isn't an i128.
+                    if let Some(negv) = v.checked_neg() {
+                        let restype = AExprType::Scalar {
                             is_pub: true,
                             is_vec: false,
                             val: Some(negv),
-                        },
-                        const_i128_tokens(negv),
-                    ))
-                } else {
-                    Ok((
-                        AExprType::Scalar {
+                        };
+                        let res = self.const_i128(restype)?;
+                        Ok((restype, res))
+                    } else {
+                        let restype = AExprType::Scalar {
                             is_pub: true,
                             is_vec: false,
                             val: None,
-                        },
-                        quote! { -#le },
-                    ))
+                        };
+                        let res = self.neg(
+                            (
+                                AExprType::Scalar {
+                                    is_pub: true,
+                                    is_vec: false,
+                                    val: Some(v),
+                                },
+                                le,
+                            ),
+                            restype,
+                        )?;
+                        Ok((restype, res))
+                    }
                 }
-            }
-            Ok((other, le)) => Ok((other, quote! { -#le })),
-            Err(err) => Err(err),
-        },
-        Expr::Paren(syn::ExprParen { expr, .. }) => match expr_type_tokens(vars, expr.as_ref()) {
-            Ok((aetype, ex)) => Ok((aetype, quote! { (#ex) })),
-            Err(err) => Err(err),
-        },
-        Expr::Path(syn::ExprPath { path, .. }) => {
-            if let Some(id) = path.get_ident() {
-                if let Some(&vt) = vars.get(&id.to_string()) {
-                    return Ok((vt, quote! { #id }));
+                Ok((other, le)) => {
+                    let res = self.neg((other, le), other)?;
+                    Ok((other, res))
+                }
+                Err(err) => Err(err),
+            },
+            Expr::Paren(syn::ExprParen { expr, .. }) => match self.fold(vars, expr.as_ref()) {
+                Ok((aetype, ex)) => {
+                    let res = self.paren((aetype, ex), aetype)?;
+                    Ok((aetype, res))
                 }
+                Err(err) => Err(err),
+            },
+            Expr::Path(syn::ExprPath { path, .. }) => {
+                if let Some(id) = path.get_ident() {
+                    if let Some(&vt) = vars.get(&id.to_string()) {
+                        let res = self.ident(id, vt)?;
+                        return Ok((vt, res));
+                    }
+                }
+                Err(Error::new(expr.span(), "not a known variable"))
             }
-            Err(Error::new(expr.span(), "not a known variable"))
-        }
-        Expr::Binary(syn::ExprBinary {
-            left, op, right, ..
-        }) => {
-            match op {
-                syn::BinOp::Add(_) | syn::BinOp::Sub(_) => {
-                    let (lt, le) = expr_type_tokens(vars, left.as_ref())?;
-                    let (rt, re) = expr_type_tokens(vars, right.as_ref())?;
-                    let default_tokens = match op {
-                        syn::BinOp::Add(_) => quote! { #le + #re },
-                        syn::BinOp::Sub(_) => quote! { #le - #re },
-                        // The default match can't happen
-                        // because we're already inside a match
-                        // on op, but the compiler requires it
-                        // anyway
-                        _ => quote! {0},
-                    };
-                    // You can add or subtract two Scalars or two
-                    // Points, but not a Scalar and a Point.  The result
-                    // is public if both arguments are public.  The
-                    // result is a vector if either argument is a
-                    // vector.
-                    match (lt, rt) {
-                        (
-                            AExprType::Scalar {
-                                is_pub: lpub,
-                                is_vec: lvec,
-                                val: lval,
-                            },
-                            AExprType::Scalar {
-                                is_pub: rpub,
-                                is_vec: rvec,
-                                val: rval,
-                            },
-                        ) => {
-                            let val = if let (Some(lv), Some(rv)) = (lval, rval) {
-                                match op {
-                                    syn::BinOp::Add(_) => lv.checked_add(rv),
-                                    syn::BinOp::Sub(_) => lv.checked_sub(rv),
-                                    // The default match can't
-                                    // happen because we're already
-                                    // inside a match on op, but the
-                                    // compiler requires it anyway
-                                    _ => None,
-                                }
-                            } else {
-                                None
-                            };
-                            return Ok((
+            Expr::Binary(syn::ExprBinary {
+                left, op, right, ..
+            }) => {
+                match op {
+                    syn::BinOp::Add(_) | syn::BinOp::Sub(_) => {
+                        let (lt, le) = self.fold(vars, left.as_ref())?;
+                        let (rt, re) = self.fold(vars, right.as_ref())?;
+                        let is_add = matches!(op, syn::BinOp::Add(_));
+                        // You can add or subtract two Scalars or two
+                        // Points, but not a Scalar and a Point.  The result
+                        // is public if both arguments are public.  The
+                        // result is a vector if either argument is a
+                        // vector.
+                        match (lt, rt) {
+                            (
+                                AExprType::Scalar {
+                                    is_pub: lpub,
+                                    is_vec: lvec,
+                                    val: lval,
+                                },
                                 AExprType::Scalar {
+                                    is_pub: rpub,
+                                    is_vec: rvec,
+                                    val: rval,
+                                },
+                            ) => {
+                                let val = if let (Some(lv), Some(rv)) = (lval, rval) {
+                                    match op {
+                                        syn::BinOp::Add(_) => lv.checked_add(rv),
+                                        syn::BinOp::Sub(_) => lv.checked_sub(rv),
+                                        // The default match can't
+                                        // happen because we're already
+                                        // inside a match on op, but the
+                                        // compiler requires it anyway
+                                        _ => None,
+                                    }
+                                } else {
+                                    None
+                                };
+                                let restype = AExprType::Scalar {
                                     is_pub: lpub && rpub,
                                     is_vec: lvec || rvec,
                                     val,
-                                },
-                                if let Some(v) = val {
-                                    const_i128_tokens(v)
+                                };
+                                let res = if val.is_some() {
+                                    self.const_i128(restype)?
+                                } else if is_add {
+                                    self.add_scalars((lt, le), (rt, re), restype)?
                                 } else {
-                                    default_tokens
+                                    self.sub_scalars((lt, le), (rt, re), restype)?
+                                };
+                                return Ok((restype, res));
+                            }
+                            (
+                                AExprType::Point {
+                                    is_pub: lpub,
+                                    is_vec: lvec,
                                 },
-                            ));
-                        }
-                        (
-                            AExprType::Point {
-                                is_pub: lpub,
-                                is_vec: lvec,
-                            },
-                            AExprType::Point {
-                                is_pub: rpub,
-                                is_vec: rvec,
-                            },
-                        ) => {
-                            return Ok((
                                 AExprType::Point {
+                                    is_pub: rpub,
+                                    is_vec: rvec,
+                                },
+                            ) => {
+                                let restype = AExprType::Point {
                                     is_pub: lpub && rpub,
                                     is_vec: lvec || rvec,
-                                },
-                                default_tokens,
-                            ));
+                                };
+                                let res = if is_add {
+                                    self.add_points((lt, le), (rt, re), restype)?
+                                } else {
+                                    self.sub_points((lt, le), (rt, re), restype)?
+                                };
+                                return Ok((restype, res));
+                            }
+                            _ => {}
                         }
-                        _ => {}
+                        return Err(Error::new(
+                            expr.span(),
+                            "cannot add/subtract a Scalar and a Point",
+                        ));
                     }
-                    return Err(Error::new(
-                        expr.span(),
-                        "cannot add/subtract a Scalar and a Point",
-                    ));
-                }
-                syn::BinOp::Mul(_) => {
-                    let (lt, le) = expr_type_tokens(vars, left.as_ref())?;
-                    let (rt, re) = expr_type_tokens(vars, right.as_ref())?;
-                    let default_tokens = quote! { #le * #re };
-                    // You can multiply two Scalars or a Scalar and a
-                    // Point, but not two Points.  You can also not
-                    // multiply two private expressions.  The result is
-                    // public if both arguments are public.  The result
-                    // is a vector if either argument is a vector.
-                    match (lt, rt) {
-                        (
-                            AExprType::Scalar {
-                                is_pub: lpub,
-                                is_vec: lvec,
-                                val: lval,
-                            },
-                            AExprType::Scalar {
-                                is_pub: rpub,
-                                is_vec: rvec,
-                                val: rval,
-                            },
-                        ) => {
-                            if !lpub && !rpub {
-                                return Err(Error::new(
-                                    expr.span(),
-                                    "cannot multiply two private expressions",
-                                ));
-                            }
-                            let val = if let (Some(lv), Some(rv)) = (lval, rval) {
-                                lv.checked_mul(rv)
-                            } else {
-                                None
-                            };
-                            return Ok((
+                    syn::BinOp::Mul(_) => {
+                        let (lt, le) = self.fold(vars, left.as_ref())?;
+                        let (rt, re) = self.fold(vars, right.as_ref())?;
+                        // You can multiply two Scalars or a Scalar and a
+                        // Point, but not two Points.  You can also not
+                        // multiply two private expressions.  The result is
+                        // public if both arguments are public.  The result
+                        // is a vector if either argument is a vector.
+                        match (lt, rt) {
+                            (
+                                AExprType::Scalar {
+                                    is_pub: lpub,
+                                    is_vec: lvec,
+                                    val: lval,
+                                },
                                 AExprType::Scalar {
+                                    is_pub: rpub,
+                                    is_vec: rvec,
+                                    val: rval,
+                                },
+                            ) => {
+                                if !lpub && !rpub {
+                                    return Err(Error::new(
+                                        expr.span(),
+                                        "cannot multiply two private expressions",
+                                    ));
+                                }
+                                let val = if let (Some(lv), Some(rv)) = (lval, rval) {
+                                    lv.checked_mul(rv)
+                                } else {
+                                    None
+                                };
+                                let restype = AExprType::Scalar {
                                     is_pub: lpub && rpub,
                                     is_vec: lvec || rvec,
                                     val,
-                                },
-                                if let Some(v) = val {
-                                    const_i128_tokens(v)
+                                };
+                                let res = if val.is_some() {
+                                    self.const_i128(restype)?
                                 } else {
-                                    default_tokens
-                                },
-                            ));
-                        }
-                        (
-                            AExprType::Scalar {
-                                is_pub: lpub,
-                                is_vec: lvec,
-                                ..
-                            },
-                            AExprType::Point {
-                                is_pub: rpub,
-                                is_vec: rvec,
-                            },
-                        )
-                        | (
-                            AExprType::Point {
-                                is_pub: lpub,
-                                is_vec: lvec,
-                            },
-                            AExprType::Scalar {
-                                is_pub: rpub,
-                                is_vec: rvec,
-                                ..
-                            },
-                        ) => {
-                            if !lpub && !rpub {
-                                return Err(Error::new(
-                                    expr.span(),
-                                    "cannot multiply two private expressions",
-                                ));
+                                    self.mul_scalars((lt, le), (rt, re), restype)?
+                                };
+                                return Ok((restype, res));
                             }
-                            return Ok((
+                            (
+                                AExprType::Scalar {
+                                    is_pub: lpub,
+                                    is_vec: lvec,
+                                    ..
+                                },
                                 AExprType::Point {
+                                    is_pub: rpub,
+                                    is_vec: rvec,
+                                },
+                            )
+                            | (
+                                AExprType::Point {
+                                    is_pub: lpub,
+                                    is_vec: lvec,
+                                },
+                                AExprType::Scalar {
+                                    is_pub: rpub,
+                                    is_vec: rvec,
+                                    ..
+                                },
+                            ) => {
+                                if !lpub && !rpub {
+                                    return Err(Error::new(
+                                        expr.span(),
+                                        "cannot multiply two private expressions",
+                                    ));
+                                }
+                                // Whichever order we were passed the
+                                // Scalar and the Point, we pass to
+                                // mul_scalar_point the Scalar first and
+                                // the Point second.
+                                let (ste, pte) = if matches!(lt, AExprType::Scalar { .. }) {
+                                    ((lt, le), (rt, re))
+                                } else {
+                                    ((rt, re), (lt, le))
+                                };
+                                let restype = AExprType::Point {
                                     is_pub: lpub && rpub,
                                     is_vec: lvec || rvec,
-                                },
-                                default_tokens,
-                            ));
+                                };
+                                let res = self.mul_scalar_point(ste, pte, restype)?;
+                                return Ok((restype, res));
+                            }
+                            _ => {}
                         }
-                        _ => {}
+                        return Err(Error::new(
+                            expr.span(),
+                            "cannot multiply a Point and a Point",
+                        ));
                     }
-                    return Err(Error::new(
-                        expr.span(),
-                        "cannot multiply a Point and a Point",
-                    ));
-                }
-                syn::BinOp::Shl(_) => {
-                    let lt = expr_type(vars, left.as_ref())?;
-                    let rt = expr_type(vars, right.as_ref())?;
-                    // You can << only when both operands are constant
-                    // Scalar expressions
-                    if let (
-                        AExprType::Scalar {
-                            is_pub: true,
-                            is_vec: false,
-                            val: Some(lv),
-                        },
-                        AExprType::Scalar {
-                            is_pub: true,
-                            is_vec: false,
-                            val: Some(rv),
-                        },
-                    ) = (lt, rt)
-                    {
-                        let rvu32: Option<u32> = rv.try_into().ok();
-                        if let Some(shift_amt) = rvu32 {
-                            if let Some(v) = lv.checked_shl(shift_amt) {
-                                return Ok((
-                                    AExprType::Scalar {
+                    syn::BinOp::Shl(_) => {
+                        let (lt, _) = self.fold(vars, left.as_ref())?;
+                        let (rt, _) = self.fold(vars, right.as_ref())?;
+                        // You can << only when both operands are constant
+                        // Scalar expressions
+                        if let (
+                            AExprType::Scalar {
+                                is_pub: true,
+                                is_vec: false,
+                                val: Some(lv),
+                            },
+                            AExprType::Scalar {
+                                is_pub: true,
+                                is_vec: false,
+                                val: Some(rv),
+                            },
+                        ) = (lt, rt)
+                        {
+                            let rvu32: Option<u32> = rv.try_into().ok();
+                            if let Some(shift_amt) = rvu32 {
+                                if let Some(v) = lv.checked_shl(shift_amt) {
+                                    let restype = AExprType::Scalar {
                                         is_pub: true,
                                         is_vec: false,
                                         val: Some(v),
-                                    },
-                                    const_i128_tokens(v),
-                                ));
+                                    };
+                                    let res = self.const_i128(restype)?;
+                                    return Ok((restype, res));
+                                }
                             }
                         }
+                        return Err(Error::new(
+                            expr.span(),
+                            "can shift left only on constant i128 expressions",
+                        ));
                     }
-                    return Err(Error::new(
-                        expr.span(),
-                        "can shift left only on constant i128 expressions",
-                    ));
+                    _ => {}
                 }
-                _ => {}
+                Err(Error::new(
+                    op.span(),
+                    "invalid operation for arithmetic expression",
+                ))
             }
-            Err(Error::new(
-                op.span(),
-                "invalid operation for arithmetic expression",
-            ))
+            _ => Err(Error::new(expr.span(), "not a valid arithmetic expression")),
         }
-        _ => Err(Error::new(expr.span(), "not a valid arithmetic expression")),
     }
 }
 
+struct FoldNoOp;
+
+impl AExprFold<()> for FoldNoOp {
+    /// Called when an identifier found in the [`VarDict`] is
+    /// encountered in the [`Expr`]
+    fn ident(&mut self, _id: &Ident, _restype: AExprType) -> Result<()> {
+        Ok(())
+    }
+
+    /// Called when the arithmetic expression evaluates to a constant
+    /// [`i128`] value.
+    fn const_i128(&mut self, _restype: AExprType) -> Result<()> {
+        Ok(())
+    }
+
+    /// Called for unary negation
+    fn neg(&mut self, _arg: (AExprType, ()), _restype: AExprType) -> Result<()> {
+        Ok(())
+    }
+
+    /// Called for a parenthesized expression
+    fn paren(&mut self, _arg: (AExprType, ()), _restype: AExprType) -> Result<()> {
+        Ok(())
+    }
+
+    /// Called when adding two `Scalar`s
+    fn add_scalars(
+        &mut self,
+        _larg: (AExprType, ()),
+        _rarg: (AExprType, ()),
+        _restype: AExprType,
+    ) -> Result<()> {
+        Ok(())
+    }
+
+    /// Called when adding two `Point`s
+    fn add_points(
+        &mut self,
+        _larg: (AExprType, ()),
+        _rarg: (AExprType, ()),
+        _restype: AExprType,
+    ) -> Result<()> {
+        Ok(())
+    }
+
+    /// Called when subtracting two `Scalar`s
+    fn sub_scalars(
+        &mut self,
+        _larg: (AExprType, ()),
+        _rarg: (AExprType, ()),
+        _restype: AExprType,
+    ) -> Result<()> {
+        Ok(())
+    }
+
+    /// Called when subtracting two `Point`s
+    fn sub_points(
+        &mut self,
+        _larg: (AExprType, ()),
+        _rarg: (AExprType, ()),
+        _restype: AExprType,
+    ) -> Result<()> {
+        Ok(())
+    }
+
+    /// Called when multiplying two `Scalar`s
+    fn mul_scalars(
+        &mut self,
+        _larg: (AExprType, ()),
+        _rarg: (AExprType, ()),
+        _restype: AExprType,
+    ) -> Result<()> {
+        Ok(())
+    }
+
+    /// Called when multiplying a `Scalar` and a `Point` (the `Scalar`
+    /// will always be passed as the first argument)
+    fn mul_scalar_point(
+        &mut self,
+        _sarg: (AExprType, ()),
+        _parg: (AExprType, ()),
+        _restype: AExprType,
+    ) -> Result<()> {
+        Ok(())
+    }
+}
+
+/// Given a [`VarDict`] and an [`Expr`] representing an arithmetic
+/// expression using the variables in the [`VarDict`], compute the
+/// [`AExprType`] of the expression.
+///
+/// An arithmetic expression can consist of:
+///   - variables that are in the [`VarDict`]
+///   - integer constants
+///   - the operations `*`, `+`, `-` (binary or unary)
+///   - the operation `<<` where both operands are expressions with no
+///     variables
+///   - parens
+pub fn expr_type(vars: &VarDict, expr: &Expr) -> Result<AExprType> {
+    let mut fold = FoldNoOp {};
+    Ok(fold.fold(vars, expr)?.0)
+}
+
+pub struct AExprTokenFold;
+
+impl AExprFold<TokenStream> for AExprTokenFold {
+    /// Called when an identifier found in the [`VarDict`] is
+    /// encountered in the [`Expr`]
+    fn ident(&mut self, id: &Ident, _restype: AExprType) -> Result<TokenStream> {
+        Ok(quote! { #id })
+    }
+
+    /// Called when the arithmetic expression evaluates to a constant
+    /// [`i128`] value.
+    fn const_i128(&mut self, restype: AExprType) -> Result<TokenStream> {
+        let AExprType::Scalar { val: Some(v), .. } = restype else {
+            return Err(Error::new(
+                proc_macro2::Span::call_site(),
+                "BUG: it should not happen that const_i128 is called without a value",
+            ));
+        };
+        Ok(const_i128_tokens(v))
+    }
+
+    /// Called for unary negation
+    fn neg(&mut self, arg: (AExprType, TokenStream), _restype: AExprType) -> Result<TokenStream> {
+        let ae = arg.1;
+        Ok(quote! { -#ae })
+    }
+
+    /// Called for a parenthesized expression
+    fn paren(&mut self, arg: (AExprType, TokenStream), _restype: AExprType) -> Result<TokenStream> {
+        let ae = arg.1;
+        Ok(quote! { (#ae) })
+    }
+
+    /// Called when adding two `Scalar`s
+    fn add_scalars(
+        &mut self,
+        larg: (AExprType, TokenStream),
+        rarg: (AExprType, TokenStream),
+        _restype: AExprType,
+    ) -> Result<TokenStream> {
+        let le = larg.1;
+        let re = rarg.1;
+        Ok(quote! { #le + #re })
+    }
+
+    /// Called when adding two `Point`s
+    fn add_points(
+        &mut self,
+        larg: (AExprType, TokenStream),
+        rarg: (AExprType, TokenStream),
+        _restype: AExprType,
+    ) -> Result<TokenStream> {
+        let le = larg.1;
+        let re = rarg.1;
+        Ok(quote! { #le + #re })
+    }
+
+    /// Called when subtracting two `Scalar`s
+    fn sub_scalars(
+        &mut self,
+        larg: (AExprType, TokenStream),
+        rarg: (AExprType, TokenStream),
+        _restype: AExprType,
+    ) -> Result<TokenStream> {
+        let le = larg.1;
+        let re = rarg.1;
+        Ok(quote! { #le - #re })
+    }
+
+    /// Called when subtracting two `Point`s
+    fn sub_points(
+        &mut self,
+        larg: (AExprType, TokenStream),
+        rarg: (AExprType, TokenStream),
+        _restype: AExprType,
+    ) -> Result<TokenStream> {
+        let le = larg.1;
+        let re = rarg.1;
+        Ok(quote! { #le - #re })
+    }
+
+    /// Called when multiplying two `Scalar`s
+    fn mul_scalars(
+        &mut self,
+        larg: (AExprType, TokenStream),
+        rarg: (AExprType, TokenStream),
+        _restype: AExprType,
+    ) -> Result<TokenStream> {
+        let le = larg.1;
+        let re = rarg.1;
+        Ok(quote! { #le * #re })
+    }
+
+    /// Called when multiplying a `Scalar` and a `Point` (the `Scalar`
+    /// will always be passed as the first argument)
+    fn mul_scalar_point(
+        &mut self,
+        sarg: (AExprType, TokenStream),
+        parg: (AExprType, TokenStream),
+        _restype: AExprType,
+    ) -> Result<TokenStream> {
+        let se = sarg.1;
+        let pe = parg.1;
+        Ok(quote! { #se * #pe })
+    }
+}
+
+/// Given a [`VarDict`] and an [`Expr`] representing an arithmetic
+/// expression using the variables in the [`VarDict`], compute the
+/// [`AExprType`] of the expression and also a valid Rust
+/// [`TokenStream`] that evaluates the expression.
+///
+/// An arithmetic expression can consist of:
+///   - variables that are in the [`VarDict`]
+///   - integer constants
+///   - the operations `*`, `+`, `-` (binary or unary)
+///   - the operation `<<` where both operands are expressions with no
+///     variables
+///   - parens
+pub fn expr_type_tokens(vars: &VarDict, expr: &Expr) -> Result<(AExprType, TokenStream)> {
+    let mut fold = AExprTokenFold {};
+    fold.fold(vars, expr)
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
@@ -537,6 +833,12 @@ mod tests {
             quote! {
             (a-(Scalar::from_u128(1u128).neg()))*(A+(Scalar::from_u128(12u128))*A) },
         );
+        check_tokens(
+            &vars,
+            parse_quote! {(a-(2-3))*(A+A*(3*4))},
+            quote! {
+            (a-(Scalar::from_u128(1u128).neg()))*(A+(Scalar::from_u128(12u128))*A) },
+        );
 
         // Tests that should fail