|
@@ -8,6 +8,8 @@
|
|
|
//! function for determining the type of arithmetic expressions
|
|
|
//! involving such variables.
|
|
|
|
|
|
+use proc_macro2::TokenStream;
|
|
|
+use quote::quote;
|
|
|
use std::collections::HashMap;
|
|
|
use syn::parse::Result;
|
|
|
use syn::spanned::Spanned;
|
|
@@ -109,6 +111,18 @@ pub fn vardict_from_strs(strs: &[(&str, &str)]) -> VarDict {
|
|
|
VarDict::from_iter(c)
|
|
|
}
|
|
|
|
|
|
+/// Given an [`i128`] value, output a [`TokenStream`] representing a
|
|
|
+/// valid Rust expression that evaluates to a `Scalar` having that
|
|
|
+/// value.
|
|
|
+fn const_i128_tokens(val: i128) -> TokenStream {
|
|
|
+ let uval = val.unsigned_abs();
|
|
|
+ if val >= 0 {
|
|
|
+ quote! { Scalar::from_u128(#uval) }
|
|
|
+ } else {
|
|
|
+ quote! { Scalar::from_u128(#uval).neg() }
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
/// Given a [`VarDict`] and an [`Expr`] representing an arithmetic
|
|
|
/// expression using the variables in the [`VarDict`], compute the
|
|
|
/// [`AExprType`] of the expression.
|
|
@@ -121,36 +135,86 @@ pub fn vardict_from_strs(strs: &[(&str, &str)]) -> VarDict {
|
|
|
/// variables
|
|
|
/// - parens
|
|
|
pub fn expr_type(vars: &VarDict, expr: &Expr) -> Result<AExprType> {
|
|
|
+ Ok(expr_type_tokens(vars, expr)?.0)
|
|
|
+}
|
|
|
+
|
|
|
+/// Given a [`VarDict`] and an [`Expr`] representing an arithmetic
|
|
|
+/// expression using the variables in the [`VarDict`], compute the
|
|
|
+/// [`AExprType`] of the expression and also a valid Rust
|
|
|
+/// [`TokenStream`] that evaluates the expression.
|
|
|
+///
|
|
|
+/// An arithmetic expression can consist of:
|
|
|
+/// - variables that are in the [`VarDict`]
|
|
|
+/// - integer constants
|
|
|
+/// - the operations `*`, `+`, `-` (binary or unary)
|
|
|
+/// - the operation `<<` where both operands are expressions with no
|
|
|
+/// variables
|
|
|
+/// - parens
|
|
|
+pub fn expr_type_tokens(vars: &VarDict, expr: &Expr) -> Result<(AExprType, TokenStream)> {
|
|
|
match expr {
|
|
|
Expr::Lit(syn::ExprLit {
|
|
|
lit: syn::Lit::Int(litint),
|
|
|
..
|
|
|
- }) => Ok(AExprType::Scalar {
|
|
|
- is_pub: true,
|
|
|
- is_vec: false,
|
|
|
- val: litint.base10_parse::<i128>().ok(),
|
|
|
- }),
|
|
|
+ }) => {
|
|
|
+ let val = litint.base10_parse::<i128>().ok();
|
|
|
+ if let Some(val_i128) = val {
|
|
|
+ Ok((
|
|
|
+ AExprType::Scalar {
|
|
|
+ is_pub: true,
|
|
|
+ is_vec: false,
|
|
|
+ val,
|
|
|
+ },
|
|
|
+ const_i128_tokens(val_i128),
|
|
|
+ ))
|
|
|
+ } else {
|
|
|
+ Err(Error::new(expr.span(), "int literal does not fit in i128"))
|
|
|
+ }
|
|
|
+ }
|
|
|
Expr::Unary(syn::ExprUnary {
|
|
|
op: syn::UnOp::Neg(_),
|
|
|
expr,
|
|
|
..
|
|
|
- }) => match expr_type(vars, expr.as_ref()) {
|
|
|
- Ok(AExprType::Scalar {
|
|
|
- is_pub: true,
|
|
|
- is_vec: false,
|
|
|
- val: Some(v),
|
|
|
- }) => Ok(AExprType::Scalar {
|
|
|
- is_pub: true,
|
|
|
- is_vec: false,
|
|
|
- val: v.checked_neg(),
|
|
|
- }),
|
|
|
- other => other,
|
|
|
+ }) => match expr_type_tokens(vars, expr.as_ref()) {
|
|
|
+ Ok((
|
|
|
+ AExprType::Scalar {
|
|
|
+ is_pub: true,
|
|
|
+ is_vec: false,
|
|
|
+ val: Some(v),
|
|
|
+ },
|
|
|
+ le,
|
|
|
+ )) => {
|
|
|
+ // If v happens to be i128::MIN, then -v isn't an i128.
|
|
|
+ if let Some(negv) = v.checked_neg() {
|
|
|
+ Ok((
|
|
|
+ AExprType::Scalar {
|
|
|
+ is_pub: true,
|
|
|
+ is_vec: false,
|
|
|
+ val: Some(negv),
|
|
|
+ },
|
|
|
+ const_i128_tokens(negv),
|
|
|
+ ))
|
|
|
+ } else {
|
|
|
+ Ok((
|
|
|
+ AExprType::Scalar {
|
|
|
+ is_pub: true,
|
|
|
+ is_vec: false,
|
|
|
+ val: None,
|
|
|
+ },
|
|
|
+ quote! { -#le },
|
|
|
+ ))
|
|
|
+ }
|
|
|
+ }
|
|
|
+ Ok((other, le)) => Ok((other, quote! { -#le })),
|
|
|
+ Err(err) => Err(err),
|
|
|
+ },
|
|
|
+ Expr::Paren(syn::ExprParen { expr, .. }) => match expr_type_tokens(vars, expr.as_ref()) {
|
|
|
+ Ok((aetype, ex)) => Ok((aetype, quote! { (#ex) })),
|
|
|
+ Err(err) => Err(err),
|
|
|
},
|
|
|
- Expr::Paren(syn::ExprParen { expr, .. }) => expr_type(vars, expr.as_ref()),
|
|
|
Expr::Path(syn::ExprPath { path, .. }) => {
|
|
|
if let Some(id) = path.get_ident() {
|
|
|
if let Some(&vt) = vars.get(&id.to_string()) {
|
|
|
- return Ok(vt);
|
|
|
+ return Ok((vt, quote! { #id }));
|
|
|
}
|
|
|
}
|
|
|
Err(Error::new(expr.span(), "not a known variable"))
|
|
@@ -160,8 +224,17 @@ pub fn expr_type(vars: &VarDict, expr: &Expr) -> Result<AExprType> {
|
|
|
}) => {
|
|
|
match op {
|
|
|
syn::BinOp::Add(_) | syn::BinOp::Sub(_) => {
|
|
|
- let lt = expr_type(vars, left.as_ref())?;
|
|
|
- let rt = expr_type(vars, right.as_ref())?;
|
|
|
+ let (lt, le) = expr_type_tokens(vars, left.as_ref())?;
|
|
|
+ let (rt, re) = expr_type_tokens(vars, right.as_ref())?;
|
|
|
+ let default_tokens = match op {
|
|
|
+ syn::BinOp::Add(_) => quote! { #le + #re },
|
|
|
+ syn::BinOp::Sub(_) => quote! { #le - #re },
|
|
|
+ // The default match can't happen
|
|
|
+ // because we're already inside a match
|
|
|
+ // on op, but the compiler requires it
|
|
|
+ // anyway
|
|
|
+ _ => quote! {0},
|
|
|
+ };
|
|
|
// You can add or subtract two Scalars or two
|
|
|
// Points, but not a Scalar and a Point. The result
|
|
|
// is public if both arguments are public. The
|
|
@@ -180,23 +253,31 @@ pub fn expr_type(vars: &VarDict, expr: &Expr) -> Result<AExprType> {
|
|
|
val: rval,
|
|
|
},
|
|
|
) => {
|
|
|
- return Ok(AExprType::Scalar {
|
|
|
- is_pub: lpub && rpub,
|
|
|
- is_vec: lvec || rvec,
|
|
|
- val: if let (Some(lv), Some(rv)) = (lval, rval) {
|
|
|
- match op {
|
|
|
- syn::BinOp::Add(_) => lv.checked_add(rv),
|
|
|
- syn::BinOp::Sub(_) => lv.checked_sub(rv),
|
|
|
- // The default match can't
|
|
|
- // happen because we're already
|
|
|
- // inside a match on op, but the
|
|
|
- // compiler requires it anyway
|
|
|
- _ => None,
|
|
|
- }
|
|
|
+ let val = if let (Some(lv), Some(rv)) = (lval, rval) {
|
|
|
+ match op {
|
|
|
+ syn::BinOp::Add(_) => lv.checked_add(rv),
|
|
|
+ syn::BinOp::Sub(_) => lv.checked_sub(rv),
|
|
|
+ // The default match can't
|
|
|
+ // happen because we're already
|
|
|
+ // inside a match on op, but the
|
|
|
+ // compiler requires it anyway
|
|
|
+ _ => None,
|
|
|
+ }
|
|
|
+ } else {
|
|
|
+ None
|
|
|
+ };
|
|
|
+ return Ok((
|
|
|
+ AExprType::Scalar {
|
|
|
+ is_pub: lpub && rpub,
|
|
|
+ is_vec: lvec || rvec,
|
|
|
+ val,
|
|
|
+ },
|
|
|
+ if let Some(v) = val {
|
|
|
+ const_i128_tokens(v)
|
|
|
} else {
|
|
|
- None
|
|
|
+ default_tokens
|
|
|
},
|
|
|
- });
|
|
|
+ ));
|
|
|
}
|
|
|
(
|
|
|
AExprType::Point {
|
|
@@ -208,10 +289,13 @@ pub fn expr_type(vars: &VarDict, expr: &Expr) -> Result<AExprType> {
|
|
|
is_vec: rvec,
|
|
|
},
|
|
|
) => {
|
|
|
- return Ok(AExprType::Point {
|
|
|
- is_pub: lpub && rpub,
|
|
|
- is_vec: lvec || rvec,
|
|
|
- });
|
|
|
+ return Ok((
|
|
|
+ AExprType::Point {
|
|
|
+ is_pub: lpub && rpub,
|
|
|
+ is_vec: lvec || rvec,
|
|
|
+ },
|
|
|
+ default_tokens,
|
|
|
+ ));
|
|
|
}
|
|
|
_ => {}
|
|
|
}
|
|
@@ -221,8 +305,9 @@ pub fn expr_type(vars: &VarDict, expr: &Expr) -> Result<AExprType> {
|
|
|
));
|
|
|
}
|
|
|
syn::BinOp::Mul(_) => {
|
|
|
- let lt = expr_type(vars, left.as_ref())?;
|
|
|
- let rt = expr_type(vars, right.as_ref())?;
|
|
|
+ let (lt, le) = expr_type_tokens(vars, left.as_ref())?;
|
|
|
+ let (rt, re) = expr_type_tokens(vars, right.as_ref())?;
|
|
|
+ let default_tokens = quote! { #le * #re };
|
|
|
// You can multiply two Scalars or a Scalar and a
|
|
|
// Point, but not two Points. You can also not
|
|
|
// multiply two private expressions. The result is
|
|
@@ -247,15 +332,23 @@ pub fn expr_type(vars: &VarDict, expr: &Expr) -> Result<AExprType> {
|
|
|
"cannot multiply two private expressions",
|
|
|
));
|
|
|
}
|
|
|
- return Ok(AExprType::Scalar {
|
|
|
- is_pub: lpub && rpub,
|
|
|
- is_vec: lvec || rvec,
|
|
|
- val: if let (Some(lv), Some(rv)) = (lval, rval) {
|
|
|
- lv.checked_mul(rv)
|
|
|
+ let val = if let (Some(lv), Some(rv)) = (lval, rval) {
|
|
|
+ lv.checked_mul(rv)
|
|
|
+ } else {
|
|
|
+ None
|
|
|
+ };
|
|
|
+ return Ok((
|
|
|
+ AExprType::Scalar {
|
|
|
+ is_pub: lpub && rpub,
|
|
|
+ is_vec: lvec || rvec,
|
|
|
+ val,
|
|
|
+ },
|
|
|
+ if let Some(v) = val {
|
|
|
+ const_i128_tokens(v)
|
|
|
} else {
|
|
|
- None
|
|
|
+ default_tokens
|
|
|
},
|
|
|
- });
|
|
|
+ ));
|
|
|
}
|
|
|
(
|
|
|
AExprType::Scalar {
|
|
@@ -285,10 +378,13 @@ pub fn expr_type(vars: &VarDict, expr: &Expr) -> Result<AExprType> {
|
|
|
"cannot multiply two private expressions",
|
|
|
));
|
|
|
}
|
|
|
- return Ok(AExprType::Point {
|
|
|
- is_pub: lpub && rpub,
|
|
|
- is_vec: lvec || rvec,
|
|
|
- });
|
|
|
+ return Ok((
|
|
|
+ AExprType::Point {
|
|
|
+ is_pub: lpub && rpub,
|
|
|
+ is_vec: lvec || rvec,
|
|
|
+ },
|
|
|
+ default_tokens,
|
|
|
+ ));
|
|
|
}
|
|
|
_ => {}
|
|
|
}
|
|
@@ -317,17 +413,21 @@ pub fn expr_type(vars: &VarDict, expr: &Expr) -> Result<AExprType> {
|
|
|
{
|
|
|
let rvu32: Option<u32> = rv.try_into().ok();
|
|
|
if let Some(shift_amt) = rvu32 {
|
|
|
- return Ok(AExprType::Scalar {
|
|
|
- is_pub: true,
|
|
|
- is_vec: false,
|
|
|
- val: lv.checked_shl(shift_amt),
|
|
|
- });
|
|
|
+ if let Some(v) = lv.checked_shl(shift_amt) {
|
|
|
+ return Ok((
|
|
|
+ AExprType::Scalar {
|
|
|
+ is_pub: true,
|
|
|
+ is_vec: false,
|
|
|
+ val: Some(v),
|
|
|
+ },
|
|
|
+ const_i128_tokens(v),
|
|
|
+ ));
|
|
|
+ }
|
|
|
}
|
|
|
}
|
|
|
return Err(Error::new(
|
|
|
expr.span(),
|
|
|
- "can shift left only on constant
|
|
|
- expressions",
|
|
|
+ "can shift left only on constant i128 expressions",
|
|
|
));
|
|
|
}
|
|
|
_ => {}
|
|
@@ -361,6 +461,13 @@ mod tests {
|
|
|
);
|
|
|
}
|
|
|
|
|
|
+ fn check_tokens(vars: &VarDict, expr: Expr, expect: TokenStream) {
|
|
|
+ assert_eq!(
|
|
|
+ expr_type_tokens(vars, &expr).unwrap().1.to_string(),
|
|
|
+ expect.to_string()
|
|
|
+ );
|
|
|
+ }
|
|
|
+
|
|
|
fn check_fail(vars: &VarDict, expr: Expr) {
|
|
|
expr_type(vars, &expr).unwrap_err();
|
|
|
}
|
|
@@ -378,6 +485,58 @@ mod tests {
|
|
|
check(&vars, parse_quote! {A*3}, "pP");
|
|
|
check(&vars, parse_quote! {(a-1)*(A+A)}, "P");
|
|
|
check(&vars, parse_quote! {(v-1)*(A+A)}, "vP");
|
|
|
+ check_tokens(
|
|
|
+ &vars,
|
|
|
+ parse_quote! { 0 },
|
|
|
+ quote! { Scalar::from_u128(0u128) },
|
|
|
+ );
|
|
|
+ check_tokens(
|
|
|
+ &vars,
|
|
|
+ parse_quote! { 5 },
|
|
|
+ quote! { Scalar::from_u128(5u128) },
|
|
|
+ );
|
|
|
+ check_tokens(
|
|
|
+ &vars,
|
|
|
+ parse_quote! { -77 },
|
|
|
+ quote! { Scalar::from_u128(77u128).neg() },
|
|
|
+ );
|
|
|
+ check_tokens(
|
|
|
+ &vars,
|
|
|
+ parse_quote! { 1<<20 },
|
|
|
+ quote! {
|
|
|
+ Scalar::from_u128(1048576u128) },
|
|
|
+ );
|
|
|
+ check_tokens(
|
|
|
+ &vars,
|
|
|
+ parse_quote! { (3-2)<<(4*5) },
|
|
|
+ quote! {
|
|
|
+ Scalar::from_u128(1048576u128) },
|
|
|
+ );
|
|
|
+ check_tokens(
|
|
|
+ &vars,
|
|
|
+ parse_quote! { 127<<120 },
|
|
|
+ quote! {
|
|
|
+ Scalar::from_u128(168811955464684315858783496655603761152u128) },
|
|
|
+ );
|
|
|
+ check_tokens(
|
|
|
+ &vars,
|
|
|
+ parse_quote! { -(-170141183460469231731687303715884105727) },
|
|
|
+ quote! {
|
|
|
+ Scalar::from_u128(170141183460469231731687303715884105727u128) },
|
|
|
+ );
|
|
|
+ // -2^127 fits in an i128, but the negative of that does not
|
|
|
+ check_tokens(
|
|
|
+ &vars,
|
|
|
+ parse_quote! { -(-170141183460469231731687303715884105727-1) },
|
|
|
+ quote! {
|
|
|
+ -(Scalar::from_u128(170141183460469231731687303715884105728u128).neg()) },
|
|
|
+ );
|
|
|
+ check_tokens(
|
|
|
+ &vars,
|
|
|
+ parse_quote! {(a-(2-3))*(A+(3*4)*A)},
|
|
|
+ quote! {
|
|
|
+ (a-(Scalar::from_u128(1u128).neg()))*(A+(Scalar::from_u128(12u128))*A) },
|
|
|
+ );
|
|
|
|
|
|
// Tests that should fail
|
|
|
|