|
|
@@ -13,7 +13,7 @@ use quote::quote;
|
|
|
use std::collections::HashMap;
|
|
|
use syn::parse::Result;
|
|
|
use syn::spanned::Spanned;
|
|
|
-use syn::{Error, Expr};
|
|
|
+use syn::{Error, Expr, Ident};
|
|
|
|
|
|
/// The possible types of an arithmetic expression over `Scalar`s and
|
|
|
/// `Point`s. Each expression has type either
|
|
|
@@ -123,25 +123,12 @@ fn const_i128_tokens(val: i128) -> TokenStream {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
-/// Given a [`VarDict`] and an [`Expr`] representing an arithmetic
|
|
|
-/// expression using the variables in the [`VarDict`], compute the
|
|
|
-/// [`AExprType`] of the expression.
|
|
|
+/// A trait for fold-style evaluations over arithmetic expressions.
|
|
|
///
|
|
|
-/// An arithmetic expression can consist of:
|
|
|
-/// - variables that are in the [`VarDict`]
|
|
|
-/// - integer constants
|
|
|
-/// - the operations `*`, `+`, `-` (binary or unary)
|
|
|
-/// - the operation `<<` where both operands are expressions with no
|
|
|
-/// variables
|
|
|
-/// - parens
|
|
|
-pub fn expr_type(vars: &VarDict, expr: &Expr) -> Result<AExprType> {
|
|
|
- Ok(expr_type_tokens(vars, expr)?.0)
|
|
|
-}
|
|
|
-
|
|
|
-/// Given a [`VarDict`] and an [`Expr`] representing an arithmetic
|
|
|
-/// expression using the variables in the [`VarDict`], compute the
|
|
|
-/// [`AExprType`] of the expression and also a valid Rust
|
|
|
-/// [`TokenStream`] that evaluates the expression.
|
|
|
+/// The parameter `T` is the type you want to return.
|
|
|
+/// All functions take `(AExprType, T)` for each of the components of
|
|
|
+/// the arithmetic expression node, as well as the [`AExprType`] for the
|
|
|
+/// resulting value.
|
|
|
///
|
|
|
/// An arithmetic expression can consist of:
|
|
|
/// - variables that are in the [`VarDict`]
|
|
|
@@ -150,297 +137,606 @@ pub fn expr_type(vars: &VarDict, expr: &Expr) -> Result<AExprType> {
|
|
|
/// - the operation `<<` where both operands are expressions with no
|
|
|
/// variables
|
|
|
/// - parens
|
|
|
-pub fn expr_type_tokens(vars: &VarDict, expr: &Expr) -> Result<(AExprType, TokenStream)> {
|
|
|
- match expr {
|
|
|
- Expr::Lit(syn::ExprLit {
|
|
|
- lit: syn::Lit::Int(litint),
|
|
|
- ..
|
|
|
- }) => {
|
|
|
- let val = litint.base10_parse::<i128>().ok();
|
|
|
- if let Some(val_i128) = val {
|
|
|
+pub trait AExprFold<T> {
|
|
|
+ /// Called when an identifier found in the [`VarDict`] is
|
|
|
+ /// encountered in the [`Expr`]
|
|
|
+ fn ident(&mut self, id: &Ident, restype: AExprType) -> Result<T>;
|
|
|
+
|
|
|
+ /// Called when the arithmetic expression evaluates to a constant
|
|
|
+ /// [`i128`] value.
|
|
|
+ fn const_i128(&mut self, restype: AExprType) -> Result<T>;
|
|
|
+
|
|
|
+ /// Called for unary negation
|
|
|
+ fn neg(&mut self, arg: (AExprType, T), restype: AExprType) -> Result<T>;
|
|
|
+
|
|
|
+ /// Called for a parenthesized expression
|
|
|
+ fn paren(&mut self, arg: (AExprType, T), restype: AExprType) -> Result<T>;
|
|
|
+
|
|
|
+ /// Called when adding two `Scalar`s
|
|
|
+ fn add_scalars(
|
|
|
+ &mut self,
|
|
|
+ larg: (AExprType, T),
|
|
|
+ rarg: (AExprType, T),
|
|
|
+ restype: AExprType,
|
|
|
+ ) -> Result<T>;
|
|
|
+
|
|
|
+ /// Called when adding two `Point`s
|
|
|
+ fn add_points(
|
|
|
+ &mut self,
|
|
|
+ larg: (AExprType, T),
|
|
|
+ rarg: (AExprType, T),
|
|
|
+ restype: AExprType,
|
|
|
+ ) -> Result<T>;
|
|
|
+
|
|
|
+ /// Called when subtracting two `Scalar`s
|
|
|
+ fn sub_scalars(
|
|
|
+ &mut self,
|
|
|
+ larg: (AExprType, T),
|
|
|
+ rarg: (AExprType, T),
|
|
|
+ restype: AExprType,
|
|
|
+ ) -> Result<T>;
|
|
|
+
|
|
|
+ /// Called when subtracting two `Point`s
|
|
|
+ fn sub_points(
|
|
|
+ &mut self,
|
|
|
+ larg: (AExprType, T),
|
|
|
+ rarg: (AExprType, T),
|
|
|
+ restype: AExprType,
|
|
|
+ ) -> Result<T>;
|
|
|
+
|
|
|
+ /// Called when multiplying two `Scalar`s
|
|
|
+ fn mul_scalars(
|
|
|
+ &mut self,
|
|
|
+ larg: (AExprType, T),
|
|
|
+ rarg: (AExprType, T),
|
|
|
+ restype: AExprType,
|
|
|
+ ) -> Result<T>;
|
|
|
+
|
|
|
+ /// Called when multiplying a `Scalar` and a `Point` (the `Scalar`
|
|
|
+ /// will always be passed as the first argument)
|
|
|
+ fn mul_scalar_point(
|
|
|
+ &mut self,
|
|
|
+ sarg: (AExprType, T),
|
|
|
+ parg: (AExprType, T),
|
|
|
+ restype: AExprType,
|
|
|
+ ) -> Result<T>;
|
|
|
+
|
|
|
+ /// Recursively process an arithmetic expression given by the
|
|
|
+ /// [`Expr`]
|
|
|
+ fn fold(&mut self, vars: &VarDict, expr: &Expr) -> Result<(AExprType, T)> {
|
|
|
+ match expr {
|
|
|
+ Expr::Lit(syn::ExprLit {
|
|
|
+ lit: syn::Lit::Int(litint),
|
|
|
+ ..
|
|
|
+ }) => {
|
|
|
+ let val = litint.base10_parse::<i128>().ok();
|
|
|
+ if val.is_some() {
|
|
|
+ let restype = AExprType::Scalar {
|
|
|
+ is_pub: true,
|
|
|
+ is_vec: false,
|
|
|
+ val,
|
|
|
+ };
|
|
|
+ let res = self.const_i128(restype)?;
|
|
|
+ Ok((restype, res))
|
|
|
+ } else {
|
|
|
+ Err(Error::new(expr.span(), "int literal does not fit in i128"))
|
|
|
+ }
|
|
|
+ }
|
|
|
+ Expr::Unary(syn::ExprUnary {
|
|
|
+ op: syn::UnOp::Neg(_),
|
|
|
+ expr,
|
|
|
+ ..
|
|
|
+ }) => match self.fold(vars, expr.as_ref()) {
|
|
|
Ok((
|
|
|
AExprType::Scalar {
|
|
|
is_pub: true,
|
|
|
is_vec: false,
|
|
|
- val,
|
|
|
+ val: Some(v),
|
|
|
},
|
|
|
- const_i128_tokens(val_i128),
|
|
|
- ))
|
|
|
- } else {
|
|
|
- Err(Error::new(expr.span(), "int literal does not fit in i128"))
|
|
|
- }
|
|
|
- }
|
|
|
- Expr::Unary(syn::ExprUnary {
|
|
|
- op: syn::UnOp::Neg(_),
|
|
|
- expr,
|
|
|
- ..
|
|
|
- }) => match expr_type_tokens(vars, expr.as_ref()) {
|
|
|
- Ok((
|
|
|
- AExprType::Scalar {
|
|
|
- is_pub: true,
|
|
|
- is_vec: false,
|
|
|
- val: Some(v),
|
|
|
- },
|
|
|
- le,
|
|
|
- )) => {
|
|
|
- // If v happens to be i128::MIN, then -v isn't an i128.
|
|
|
- if let Some(negv) = v.checked_neg() {
|
|
|
- Ok((
|
|
|
- AExprType::Scalar {
|
|
|
+ le,
|
|
|
+ )) => {
|
|
|
+ // If v happens to be i128::MIN, then -v isn't an i128.
|
|
|
+ if let Some(negv) = v.checked_neg() {
|
|
|
+ let restype = AExprType::Scalar {
|
|
|
is_pub: true,
|
|
|
is_vec: false,
|
|
|
val: Some(negv),
|
|
|
- },
|
|
|
- const_i128_tokens(negv),
|
|
|
- ))
|
|
|
- } else {
|
|
|
- Ok((
|
|
|
- AExprType::Scalar {
|
|
|
+ };
|
|
|
+ let res = self.const_i128(restype)?;
|
|
|
+ Ok((restype, res))
|
|
|
+ } else {
|
|
|
+ let restype = AExprType::Scalar {
|
|
|
is_pub: true,
|
|
|
is_vec: false,
|
|
|
val: None,
|
|
|
- },
|
|
|
- quote! { -#le },
|
|
|
- ))
|
|
|
+ };
|
|
|
+ let res = self.neg(
|
|
|
+ (
|
|
|
+ AExprType::Scalar {
|
|
|
+ is_pub: true,
|
|
|
+ is_vec: false,
|
|
|
+ val: Some(v),
|
|
|
+ },
|
|
|
+ le,
|
|
|
+ ),
|
|
|
+ restype,
|
|
|
+ )?;
|
|
|
+ Ok((restype, res))
|
|
|
+ }
|
|
|
}
|
|
|
- }
|
|
|
- Ok((other, le)) => Ok((other, quote! { -#le })),
|
|
|
- Err(err) => Err(err),
|
|
|
- },
|
|
|
- Expr::Paren(syn::ExprParen { expr, .. }) => match expr_type_tokens(vars, expr.as_ref()) {
|
|
|
- Ok((aetype, ex)) => Ok((aetype, quote! { (#ex) })),
|
|
|
- Err(err) => Err(err),
|
|
|
- },
|
|
|
- Expr::Path(syn::ExprPath { path, .. }) => {
|
|
|
- if let Some(id) = path.get_ident() {
|
|
|
- if let Some(&vt) = vars.get(&id.to_string()) {
|
|
|
- return Ok((vt, quote! { #id }));
|
|
|
+ Ok((other, le)) => {
|
|
|
+ let res = self.neg((other, le), other)?;
|
|
|
+ Ok((other, res))
|
|
|
+ }
|
|
|
+ Err(err) => Err(err),
|
|
|
+ },
|
|
|
+ Expr::Paren(syn::ExprParen { expr, .. }) => match self.fold(vars, expr.as_ref()) {
|
|
|
+ Ok((aetype, ex)) => {
|
|
|
+ let res = self.paren((aetype, ex), aetype)?;
|
|
|
+ Ok((aetype, res))
|
|
|
}
|
|
|
+ Err(err) => Err(err),
|
|
|
+ },
|
|
|
+ Expr::Path(syn::ExprPath { path, .. }) => {
|
|
|
+ if let Some(id) = path.get_ident() {
|
|
|
+ if let Some(&vt) = vars.get(&id.to_string()) {
|
|
|
+ let res = self.ident(id, vt)?;
|
|
|
+ return Ok((vt, res));
|
|
|
+ }
|
|
|
+ }
|
|
|
+ Err(Error::new(expr.span(), "not a known variable"))
|
|
|
}
|
|
|
- Err(Error::new(expr.span(), "not a known variable"))
|
|
|
- }
|
|
|
- Expr::Binary(syn::ExprBinary {
|
|
|
- left, op, right, ..
|
|
|
- }) => {
|
|
|
- match op {
|
|
|
- syn::BinOp::Add(_) | syn::BinOp::Sub(_) => {
|
|
|
- let (lt, le) = expr_type_tokens(vars, left.as_ref())?;
|
|
|
- let (rt, re) = expr_type_tokens(vars, right.as_ref())?;
|
|
|
- let default_tokens = match op {
|
|
|
- syn::BinOp::Add(_) => quote! { #le + #re },
|
|
|
- syn::BinOp::Sub(_) => quote! { #le - #re },
|
|
|
- // The default match can't happen
|
|
|
- // because we're already inside a match
|
|
|
- // on op, but the compiler requires it
|
|
|
- // anyway
|
|
|
- _ => quote! {0},
|
|
|
- };
|
|
|
- // You can add or subtract two Scalars or two
|
|
|
- // Points, but not a Scalar and a Point. The result
|
|
|
- // is public if both arguments are public. The
|
|
|
- // result is a vector if either argument is a
|
|
|
- // vector.
|
|
|
- match (lt, rt) {
|
|
|
- (
|
|
|
- AExprType::Scalar {
|
|
|
- is_pub: lpub,
|
|
|
- is_vec: lvec,
|
|
|
- val: lval,
|
|
|
- },
|
|
|
- AExprType::Scalar {
|
|
|
- is_pub: rpub,
|
|
|
- is_vec: rvec,
|
|
|
- val: rval,
|
|
|
- },
|
|
|
- ) => {
|
|
|
- let val = if let (Some(lv), Some(rv)) = (lval, rval) {
|
|
|
- match op {
|
|
|
- syn::BinOp::Add(_) => lv.checked_add(rv),
|
|
|
- syn::BinOp::Sub(_) => lv.checked_sub(rv),
|
|
|
- // The default match can't
|
|
|
- // happen because we're already
|
|
|
- // inside a match on op, but the
|
|
|
- // compiler requires it anyway
|
|
|
- _ => None,
|
|
|
- }
|
|
|
- } else {
|
|
|
- None
|
|
|
- };
|
|
|
- return Ok((
|
|
|
+ Expr::Binary(syn::ExprBinary {
|
|
|
+ left, op, right, ..
|
|
|
+ }) => {
|
|
|
+ match op {
|
|
|
+ syn::BinOp::Add(_) | syn::BinOp::Sub(_) => {
|
|
|
+ let (lt, le) = self.fold(vars, left.as_ref())?;
|
|
|
+ let (rt, re) = self.fold(vars, right.as_ref())?;
|
|
|
+ let is_add = matches!(op, syn::BinOp::Add(_));
|
|
|
+ // You can add or subtract two Scalars or two
|
|
|
+ // Points, but not a Scalar and a Point. The result
|
|
|
+ // is public if both arguments are public. The
|
|
|
+ // result is a vector if either argument is a
|
|
|
+ // vector.
|
|
|
+ match (lt, rt) {
|
|
|
+ (
|
|
|
+ AExprType::Scalar {
|
|
|
+ is_pub: lpub,
|
|
|
+ is_vec: lvec,
|
|
|
+ val: lval,
|
|
|
+ },
|
|
|
AExprType::Scalar {
|
|
|
+ is_pub: rpub,
|
|
|
+ is_vec: rvec,
|
|
|
+ val: rval,
|
|
|
+ },
|
|
|
+ ) => {
|
|
|
+ let val = if let (Some(lv), Some(rv)) = (lval, rval) {
|
|
|
+ match op {
|
|
|
+ syn::BinOp::Add(_) => lv.checked_add(rv),
|
|
|
+ syn::BinOp::Sub(_) => lv.checked_sub(rv),
|
|
|
+ // The default match can't
|
|
|
+ // happen because we're already
|
|
|
+ // inside a match on op, but the
|
|
|
+ // compiler requires it anyway
|
|
|
+ _ => None,
|
|
|
+ }
|
|
|
+ } else {
|
|
|
+ None
|
|
|
+ };
|
|
|
+ let restype = AExprType::Scalar {
|
|
|
is_pub: lpub && rpub,
|
|
|
is_vec: lvec || rvec,
|
|
|
val,
|
|
|
- },
|
|
|
- if let Some(v) = val {
|
|
|
- const_i128_tokens(v)
|
|
|
+ };
|
|
|
+ let res = if val.is_some() {
|
|
|
+ self.const_i128(restype)?
|
|
|
+ } else if is_add {
|
|
|
+ self.add_scalars((lt, le), (rt, re), restype)?
|
|
|
} else {
|
|
|
- default_tokens
|
|
|
+ self.sub_scalars((lt, le), (rt, re), restype)?
|
|
|
+ };
|
|
|
+ return Ok((restype, res));
|
|
|
+ }
|
|
|
+ (
|
|
|
+ AExprType::Point {
|
|
|
+ is_pub: lpub,
|
|
|
+ is_vec: lvec,
|
|
|
},
|
|
|
- ));
|
|
|
- }
|
|
|
- (
|
|
|
- AExprType::Point {
|
|
|
- is_pub: lpub,
|
|
|
- is_vec: lvec,
|
|
|
- },
|
|
|
- AExprType::Point {
|
|
|
- is_pub: rpub,
|
|
|
- is_vec: rvec,
|
|
|
- },
|
|
|
- ) => {
|
|
|
- return Ok((
|
|
|
AExprType::Point {
|
|
|
+ is_pub: rpub,
|
|
|
+ is_vec: rvec,
|
|
|
+ },
|
|
|
+ ) => {
|
|
|
+ let restype = AExprType::Point {
|
|
|
is_pub: lpub && rpub,
|
|
|
is_vec: lvec || rvec,
|
|
|
- },
|
|
|
- default_tokens,
|
|
|
- ));
|
|
|
+ };
|
|
|
+ let res = if is_add {
|
|
|
+ self.add_points((lt, le), (rt, re), restype)?
|
|
|
+ } else {
|
|
|
+ self.sub_points((lt, le), (rt, re), restype)?
|
|
|
+ };
|
|
|
+ return Ok((restype, res));
|
|
|
+ }
|
|
|
+ _ => {}
|
|
|
}
|
|
|
- _ => {}
|
|
|
+ return Err(Error::new(
|
|
|
+ expr.span(),
|
|
|
+ "cannot add/subtract a Scalar and a Point",
|
|
|
+ ));
|
|
|
}
|
|
|
- return Err(Error::new(
|
|
|
- expr.span(),
|
|
|
- "cannot add/subtract a Scalar and a Point",
|
|
|
- ));
|
|
|
- }
|
|
|
- syn::BinOp::Mul(_) => {
|
|
|
- let (lt, le) = expr_type_tokens(vars, left.as_ref())?;
|
|
|
- let (rt, re) = expr_type_tokens(vars, right.as_ref())?;
|
|
|
- let default_tokens = quote! { #le * #re };
|
|
|
- // You can multiply two Scalars or a Scalar and a
|
|
|
- // Point, but not two Points. You can also not
|
|
|
- // multiply two private expressions. The result is
|
|
|
- // public if both arguments are public. The result
|
|
|
- // is a vector if either argument is a vector.
|
|
|
- match (lt, rt) {
|
|
|
- (
|
|
|
- AExprType::Scalar {
|
|
|
- is_pub: lpub,
|
|
|
- is_vec: lvec,
|
|
|
- val: lval,
|
|
|
- },
|
|
|
- AExprType::Scalar {
|
|
|
- is_pub: rpub,
|
|
|
- is_vec: rvec,
|
|
|
- val: rval,
|
|
|
- },
|
|
|
- ) => {
|
|
|
- if !lpub && !rpub {
|
|
|
- return Err(Error::new(
|
|
|
- expr.span(),
|
|
|
- "cannot multiply two private expressions",
|
|
|
- ));
|
|
|
- }
|
|
|
- let val = if let (Some(lv), Some(rv)) = (lval, rval) {
|
|
|
- lv.checked_mul(rv)
|
|
|
- } else {
|
|
|
- None
|
|
|
- };
|
|
|
- return Ok((
|
|
|
+ syn::BinOp::Mul(_) => {
|
|
|
+ let (lt, le) = self.fold(vars, left.as_ref())?;
|
|
|
+ let (rt, re) = self.fold(vars, right.as_ref())?;
|
|
|
+ // You can multiply two Scalars or a Scalar and a
|
|
|
+ // Point, but not two Points. You can also not
|
|
|
+ // multiply two private expressions. The result is
|
|
|
+ // public if both arguments are public. The result
|
|
|
+ // is a vector if either argument is a vector.
|
|
|
+ match (lt, rt) {
|
|
|
+ (
|
|
|
+ AExprType::Scalar {
|
|
|
+ is_pub: lpub,
|
|
|
+ is_vec: lvec,
|
|
|
+ val: lval,
|
|
|
+ },
|
|
|
AExprType::Scalar {
|
|
|
+ is_pub: rpub,
|
|
|
+ is_vec: rvec,
|
|
|
+ val: rval,
|
|
|
+ },
|
|
|
+ ) => {
|
|
|
+ if !lpub && !rpub {
|
|
|
+ return Err(Error::new(
|
|
|
+ expr.span(),
|
|
|
+ "cannot multiply two private expressions",
|
|
|
+ ));
|
|
|
+ }
|
|
|
+ let val = if let (Some(lv), Some(rv)) = (lval, rval) {
|
|
|
+ lv.checked_mul(rv)
|
|
|
+ } else {
|
|
|
+ None
|
|
|
+ };
|
|
|
+ let restype = AExprType::Scalar {
|
|
|
is_pub: lpub && rpub,
|
|
|
is_vec: lvec || rvec,
|
|
|
val,
|
|
|
- },
|
|
|
- if let Some(v) = val {
|
|
|
- const_i128_tokens(v)
|
|
|
+ };
|
|
|
+ let res = if val.is_some() {
|
|
|
+ self.const_i128(restype)?
|
|
|
} else {
|
|
|
- default_tokens
|
|
|
- },
|
|
|
- ));
|
|
|
- }
|
|
|
- (
|
|
|
- AExprType::Scalar {
|
|
|
- is_pub: lpub,
|
|
|
- is_vec: lvec,
|
|
|
- ..
|
|
|
- },
|
|
|
- AExprType::Point {
|
|
|
- is_pub: rpub,
|
|
|
- is_vec: rvec,
|
|
|
- },
|
|
|
- )
|
|
|
- | (
|
|
|
- AExprType::Point {
|
|
|
- is_pub: lpub,
|
|
|
- is_vec: lvec,
|
|
|
- },
|
|
|
- AExprType::Scalar {
|
|
|
- is_pub: rpub,
|
|
|
- is_vec: rvec,
|
|
|
- ..
|
|
|
- },
|
|
|
- ) => {
|
|
|
- if !lpub && !rpub {
|
|
|
- return Err(Error::new(
|
|
|
- expr.span(),
|
|
|
- "cannot multiply two private expressions",
|
|
|
- ));
|
|
|
+ self.mul_scalars((lt, le), (rt, re), restype)?
|
|
|
+ };
|
|
|
+ return Ok((restype, res));
|
|
|
}
|
|
|
- return Ok((
|
|
|
+ (
|
|
|
+ AExprType::Scalar {
|
|
|
+ is_pub: lpub,
|
|
|
+ is_vec: lvec,
|
|
|
+ ..
|
|
|
+ },
|
|
|
AExprType::Point {
|
|
|
+ is_pub: rpub,
|
|
|
+ is_vec: rvec,
|
|
|
+ },
|
|
|
+ )
|
|
|
+ | (
|
|
|
+ AExprType::Point {
|
|
|
+ is_pub: lpub,
|
|
|
+ is_vec: lvec,
|
|
|
+ },
|
|
|
+ AExprType::Scalar {
|
|
|
+ is_pub: rpub,
|
|
|
+ is_vec: rvec,
|
|
|
+ ..
|
|
|
+ },
|
|
|
+ ) => {
|
|
|
+ if !lpub && !rpub {
|
|
|
+ return Err(Error::new(
|
|
|
+ expr.span(),
|
|
|
+ "cannot multiply two private expressions",
|
|
|
+ ));
|
|
|
+ }
|
|
|
+ // Whichever order we were passed the
|
|
|
+ // Scalar and the Point, we pass to
|
|
|
+ // mul_scalar_point the Scalar first and
|
|
|
+ // the Point second.
|
|
|
+ let (ste, pte) = if matches!(lt, AExprType::Scalar { .. }) {
|
|
|
+ ((lt, le), (rt, re))
|
|
|
+ } else {
|
|
|
+ ((rt, re), (lt, le))
|
|
|
+ };
|
|
|
+ let restype = AExprType::Point {
|
|
|
is_pub: lpub && rpub,
|
|
|
is_vec: lvec || rvec,
|
|
|
- },
|
|
|
- default_tokens,
|
|
|
- ));
|
|
|
+ };
|
|
|
+ let res = self.mul_scalar_point(ste, pte, restype)?;
|
|
|
+ return Ok((restype, res));
|
|
|
+ }
|
|
|
+ _ => {}
|
|
|
}
|
|
|
- _ => {}
|
|
|
+ return Err(Error::new(
|
|
|
+ expr.span(),
|
|
|
+ "cannot multiply a Point and a Point",
|
|
|
+ ));
|
|
|
}
|
|
|
- return Err(Error::new(
|
|
|
- expr.span(),
|
|
|
- "cannot multiply a Point and a Point",
|
|
|
- ));
|
|
|
- }
|
|
|
- syn::BinOp::Shl(_) => {
|
|
|
- let lt = expr_type(vars, left.as_ref())?;
|
|
|
- let rt = expr_type(vars, right.as_ref())?;
|
|
|
- // You can << only when both operands are constant
|
|
|
- // Scalar expressions
|
|
|
- if let (
|
|
|
- AExprType::Scalar {
|
|
|
- is_pub: true,
|
|
|
- is_vec: false,
|
|
|
- val: Some(lv),
|
|
|
- },
|
|
|
- AExprType::Scalar {
|
|
|
- is_pub: true,
|
|
|
- is_vec: false,
|
|
|
- val: Some(rv),
|
|
|
- },
|
|
|
- ) = (lt, rt)
|
|
|
- {
|
|
|
- let rvu32: Option<u32> = rv.try_into().ok();
|
|
|
- if let Some(shift_amt) = rvu32 {
|
|
|
- if let Some(v) = lv.checked_shl(shift_amt) {
|
|
|
- return Ok((
|
|
|
- AExprType::Scalar {
|
|
|
+ syn::BinOp::Shl(_) => {
|
|
|
+ let (lt, _) = self.fold(vars, left.as_ref())?;
|
|
|
+ let (rt, _) = self.fold(vars, right.as_ref())?;
|
|
|
+ // You can << only when both operands are constant
|
|
|
+ // Scalar expressions
|
|
|
+ if let (
|
|
|
+ AExprType::Scalar {
|
|
|
+ is_pub: true,
|
|
|
+ is_vec: false,
|
|
|
+ val: Some(lv),
|
|
|
+ },
|
|
|
+ AExprType::Scalar {
|
|
|
+ is_pub: true,
|
|
|
+ is_vec: false,
|
|
|
+ val: Some(rv),
|
|
|
+ },
|
|
|
+ ) = (lt, rt)
|
|
|
+ {
|
|
|
+ let rvu32: Option<u32> = rv.try_into().ok();
|
|
|
+ if let Some(shift_amt) = rvu32 {
|
|
|
+ if let Some(v) = lv.checked_shl(shift_amt) {
|
|
|
+ let restype = AExprType::Scalar {
|
|
|
is_pub: true,
|
|
|
is_vec: false,
|
|
|
val: Some(v),
|
|
|
- },
|
|
|
- const_i128_tokens(v),
|
|
|
- ));
|
|
|
+ };
|
|
|
+ let res = self.const_i128(restype)?;
|
|
|
+ return Ok((restype, res));
|
|
|
+ }
|
|
|
}
|
|
|
}
|
|
|
+ return Err(Error::new(
|
|
|
+ expr.span(),
|
|
|
+ "can shift left only on constant i128 expressions",
|
|
|
+ ));
|
|
|
}
|
|
|
- return Err(Error::new(
|
|
|
- expr.span(),
|
|
|
- "can shift left only on constant i128 expressions",
|
|
|
- ));
|
|
|
+ _ => {}
|
|
|
}
|
|
|
- _ => {}
|
|
|
+ Err(Error::new(
|
|
|
+ op.span(),
|
|
|
+ "invalid operation for arithmetic expression",
|
|
|
+ ))
|
|
|
}
|
|
|
- Err(Error::new(
|
|
|
- op.span(),
|
|
|
- "invalid operation for arithmetic expression",
|
|
|
- ))
|
|
|
+ _ => Err(Error::new(expr.span(), "not a valid arithmetic expression")),
|
|
|
}
|
|
|
- _ => Err(Error::new(expr.span(), "not a valid arithmetic expression")),
|
|
|
}
|
|
|
}
|
|
|
|
|
|
+struct FoldNoOp;
|
|
|
+
|
|
|
+impl AExprFold<()> for FoldNoOp {
|
|
|
+ /// Called when an identifier found in the [`VarDict`] is
|
|
|
+ /// encountered in the [`Expr`]
|
|
|
+ fn ident(&mut self, _id: &Ident, _restype: AExprType) -> Result<()> {
|
|
|
+ Ok(())
|
|
|
+ }
|
|
|
+
|
|
|
+ /// Called when the arithmetic expression evaluates to a constant
|
|
|
+ /// [`i128`] value.
|
|
|
+ fn const_i128(&mut self, _restype: AExprType) -> Result<()> {
|
|
|
+ Ok(())
|
|
|
+ }
|
|
|
+
|
|
|
+ /// Called for unary negation
|
|
|
+ fn neg(&mut self, _arg: (AExprType, ()), _restype: AExprType) -> Result<()> {
|
|
|
+ Ok(())
|
|
|
+ }
|
|
|
+
|
|
|
+ /// Called for a parenthesized expression
|
|
|
+ fn paren(&mut self, _arg: (AExprType, ()), _restype: AExprType) -> Result<()> {
|
|
|
+ Ok(())
|
|
|
+ }
|
|
|
+
|
|
|
+ /// Called when adding two `Scalar`s
|
|
|
+ fn add_scalars(
|
|
|
+ &mut self,
|
|
|
+ _larg: (AExprType, ()),
|
|
|
+ _rarg: (AExprType, ()),
|
|
|
+ _restype: AExprType,
|
|
|
+ ) -> Result<()> {
|
|
|
+ Ok(())
|
|
|
+ }
|
|
|
+
|
|
|
+ /// Called when adding two `Point`s
|
|
|
+ fn add_points(
|
|
|
+ &mut self,
|
|
|
+ _larg: (AExprType, ()),
|
|
|
+ _rarg: (AExprType, ()),
|
|
|
+ _restype: AExprType,
|
|
|
+ ) -> Result<()> {
|
|
|
+ Ok(())
|
|
|
+ }
|
|
|
+
|
|
|
+ /// Called when subtracting two `Scalar`s
|
|
|
+ fn sub_scalars(
|
|
|
+ &mut self,
|
|
|
+ _larg: (AExprType, ()),
|
|
|
+ _rarg: (AExprType, ()),
|
|
|
+ _restype: AExprType,
|
|
|
+ ) -> Result<()> {
|
|
|
+ Ok(())
|
|
|
+ }
|
|
|
+
|
|
|
+ /// Called when subtracting two `Point`s
|
|
|
+ fn sub_points(
|
|
|
+ &mut self,
|
|
|
+ _larg: (AExprType, ()),
|
|
|
+ _rarg: (AExprType, ()),
|
|
|
+ _restype: AExprType,
|
|
|
+ ) -> Result<()> {
|
|
|
+ Ok(())
|
|
|
+ }
|
|
|
+
|
|
|
+ /// Called when multiplying two `Scalar`s
|
|
|
+ fn mul_scalars(
|
|
|
+ &mut self,
|
|
|
+ _larg: (AExprType, ()),
|
|
|
+ _rarg: (AExprType, ()),
|
|
|
+ _restype: AExprType,
|
|
|
+ ) -> Result<()> {
|
|
|
+ Ok(())
|
|
|
+ }
|
|
|
+
|
|
|
+ /// Called when multiplying a `Scalar` and a `Point` (the `Scalar`
|
|
|
+ /// will always be passed as the first argument)
|
|
|
+ fn mul_scalar_point(
|
|
|
+ &mut self,
|
|
|
+ _sarg: (AExprType, ()),
|
|
|
+ _parg: (AExprType, ()),
|
|
|
+ _restype: AExprType,
|
|
|
+ ) -> Result<()> {
|
|
|
+ Ok(())
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+/// Given a [`VarDict`] and an [`Expr`] representing an arithmetic
|
|
|
+/// expression using the variables in the [`VarDict`], compute the
|
|
|
+/// [`AExprType`] of the expression.
|
|
|
+///
|
|
|
+/// An arithmetic expression can consist of:
|
|
|
+/// - variables that are in the [`VarDict`]
|
|
|
+/// - integer constants
|
|
|
+/// - the operations `*`, `+`, `-` (binary or unary)
|
|
|
+/// - the operation `<<` where both operands are expressions with no
|
|
|
+/// variables
|
|
|
+/// - parens
|
|
|
+pub fn expr_type(vars: &VarDict, expr: &Expr) -> Result<AExprType> {
|
|
|
+ let mut fold = FoldNoOp {};
|
|
|
+ Ok(fold.fold(vars, expr)?.0)
|
|
|
+}
|
|
|
+
|
|
|
+pub struct AExprTokenFold;
|
|
|
+
|
|
|
+impl AExprFold<TokenStream> for AExprTokenFold {
|
|
|
+ /// Called when an identifier found in the [`VarDict`] is
|
|
|
+ /// encountered in the [`Expr`]
|
|
|
+ fn ident(&mut self, id: &Ident, _restype: AExprType) -> Result<TokenStream> {
|
|
|
+ Ok(quote! { #id })
|
|
|
+ }
|
|
|
+
|
|
|
+ /// Called when the arithmetic expression evaluates to a constant
|
|
|
+ /// [`i128`] value.
|
|
|
+ fn const_i128(&mut self, restype: AExprType) -> Result<TokenStream> {
|
|
|
+ let AExprType::Scalar { val: Some(v), .. } = restype else {
|
|
|
+ return Err(Error::new(
|
|
|
+ proc_macro2::Span::call_site(),
|
|
|
+ "BUG: it should not happen that const_i128 is called without a value",
|
|
|
+ ));
|
|
|
+ };
|
|
|
+ Ok(const_i128_tokens(v))
|
|
|
+ }
|
|
|
+
|
|
|
+ /// Called for unary negation
|
|
|
+ fn neg(&mut self, arg: (AExprType, TokenStream), _restype: AExprType) -> Result<TokenStream> {
|
|
|
+ let ae = arg.1;
|
|
|
+ Ok(quote! { -#ae })
|
|
|
+ }
|
|
|
+
|
|
|
+ /// Called for a parenthesized expression
|
|
|
+ fn paren(&mut self, arg: (AExprType, TokenStream), _restype: AExprType) -> Result<TokenStream> {
|
|
|
+ let ae = arg.1;
|
|
|
+ Ok(quote! { (#ae) })
|
|
|
+ }
|
|
|
+
|
|
|
+ /// Called when adding two `Scalar`s
|
|
|
+ fn add_scalars(
|
|
|
+ &mut self,
|
|
|
+ larg: (AExprType, TokenStream),
|
|
|
+ rarg: (AExprType, TokenStream),
|
|
|
+ _restype: AExprType,
|
|
|
+ ) -> Result<TokenStream> {
|
|
|
+ let le = larg.1;
|
|
|
+ let re = rarg.1;
|
|
|
+ Ok(quote! { #le + #re })
|
|
|
+ }
|
|
|
+
|
|
|
+ /// Called when adding two `Point`s
|
|
|
+ fn add_points(
|
|
|
+ &mut self,
|
|
|
+ larg: (AExprType, TokenStream),
|
|
|
+ rarg: (AExprType, TokenStream),
|
|
|
+ _restype: AExprType,
|
|
|
+ ) -> Result<TokenStream> {
|
|
|
+ let le = larg.1;
|
|
|
+ let re = rarg.1;
|
|
|
+ Ok(quote! { #le + #re })
|
|
|
+ }
|
|
|
+
|
|
|
+ /// Called when subtracting two `Scalar`s
|
|
|
+ fn sub_scalars(
|
|
|
+ &mut self,
|
|
|
+ larg: (AExprType, TokenStream),
|
|
|
+ rarg: (AExprType, TokenStream),
|
|
|
+ _restype: AExprType,
|
|
|
+ ) -> Result<TokenStream> {
|
|
|
+ let le = larg.1;
|
|
|
+ let re = rarg.1;
|
|
|
+ Ok(quote! { #le - #re })
|
|
|
+ }
|
|
|
+
|
|
|
+ /// Called when subtracting two `Point`s
|
|
|
+ fn sub_points(
|
|
|
+ &mut self,
|
|
|
+ larg: (AExprType, TokenStream),
|
|
|
+ rarg: (AExprType, TokenStream),
|
|
|
+ _restype: AExprType,
|
|
|
+ ) -> Result<TokenStream> {
|
|
|
+ let le = larg.1;
|
|
|
+ let re = rarg.1;
|
|
|
+ Ok(quote! { #le - #re })
|
|
|
+ }
|
|
|
+
|
|
|
+ /// Called when multiplying two `Scalar`s
|
|
|
+ fn mul_scalars(
|
|
|
+ &mut self,
|
|
|
+ larg: (AExprType, TokenStream),
|
|
|
+ rarg: (AExprType, TokenStream),
|
|
|
+ _restype: AExprType,
|
|
|
+ ) -> Result<TokenStream> {
|
|
|
+ let le = larg.1;
|
|
|
+ let re = rarg.1;
|
|
|
+ Ok(quote! { #le * #re })
|
|
|
+ }
|
|
|
+
|
|
|
+ /// Called when multiplying a `Scalar` and a `Point` (the `Scalar`
|
|
|
+ /// will always be passed as the first argument)
|
|
|
+ fn mul_scalar_point(
|
|
|
+ &mut self,
|
|
|
+ sarg: (AExprType, TokenStream),
|
|
|
+ parg: (AExprType, TokenStream),
|
|
|
+ _restype: AExprType,
|
|
|
+ ) -> Result<TokenStream> {
|
|
|
+ let se = sarg.1;
|
|
|
+ let pe = parg.1;
|
|
|
+ Ok(quote! { #se * #pe })
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+/// Given a [`VarDict`] and an [`Expr`] representing an arithmetic
|
|
|
+/// expression using the variables in the [`VarDict`], compute the
|
|
|
+/// [`AExprType`] of the expression and also a valid Rust
|
|
|
+/// [`TokenStream`] that evaluates the expression.
|
|
|
+///
|
|
|
+/// An arithmetic expression can consist of:
|
|
|
+/// - variables that are in the [`VarDict`]
|
|
|
+/// - integer constants
|
|
|
+/// - the operations `*`, `+`, `-` (binary or unary)
|
|
|
+/// - the operation `<<` where both operands are expressions with no
|
|
|
+/// variables
|
|
|
+/// - parens
|
|
|
+pub fn expr_type_tokens(vars: &VarDict, expr: &Expr) -> Result<(AExprType, TokenStream)> {
|
|
|
+ let mut fold = AExprTokenFold {};
|
|
|
+ fold.fold(vars, expr)
|
|
|
+}
|
|
|
+
|
|
|
#[cfg(test)]
|
|
|
mod tests {
|
|
|
use super::*;
|
|
|
@@ -537,6 +833,12 @@ mod tests {
|
|
|
quote! {
|
|
|
(a-(Scalar::from_u128(1u128).neg()))*(A+(Scalar::from_u128(12u128))*A) },
|
|
|
);
|
|
|
+ check_tokens(
|
|
|
+ &vars,
|
|
|
+ parse_quote! {(a-(2-3))*(A+A*(3*4))},
|
|
|
+ quote! {
|
|
|
+ (a-(Scalar::from_u128(1u128).neg()))*(A+(Scalar::from_u128(12u128))*A) },
|
|
|
+ );
|
|
|
|
|
|
// Tests that should fail
|
|
|
|