|
@@ -21,10 +21,20 @@ use syn::{Error, Expr};
|
|
|
/// Note that while an individual variable cannot be a private `Point`,
|
|
|
/// it is common to construct an arithmetic expression of that type, for
|
|
|
/// example by multiplying a private `Scalar` by a public `Point`.
|
|
|
+/// In addition, an [`AExprType`] that represents a constant `Scalar`
|
|
|
+/// value (that fits in an [`i128`]) will have that constant value in
|
|
|
+/// `val`.
|
|
|
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
|
|
|
pub enum AExprType {
|
|
|
- Scalar { is_pub: bool, is_vec: bool },
|
|
|
- Point { is_pub: bool, is_vec: bool },
|
|
|
+ Scalar {
|
|
|
+ is_pub: bool,
|
|
|
+ is_vec: bool,
|
|
|
+ val: Option<i128>,
|
|
|
+ },
|
|
|
+ Point {
|
|
|
+ is_pub: bool,
|
|
|
+ is_vec: bool,
|
|
|
+ },
|
|
|
}
|
|
|
|
|
|
impl From<&str> for AExprType {
|
|
@@ -43,18 +53,22 @@ impl From<&str> for AExprType {
|
|
|
"Scalar" | "S" => Self::Scalar {
|
|
|
is_pub: false,
|
|
|
is_vec: false,
|
|
|
+ val: None,
|
|
|
},
|
|
|
"pub Scalar" | "pS" => Self::Scalar {
|
|
|
is_pub: true,
|
|
|
is_vec: false,
|
|
|
+ val: None,
|
|
|
},
|
|
|
"vec Scalar" | "vS" => Self::Scalar {
|
|
|
is_pub: false,
|
|
|
is_vec: true,
|
|
|
+ val: None,
|
|
|
},
|
|
|
"pub vec Scalar" | "pvS" => Self::Scalar {
|
|
|
is_pub: true,
|
|
|
is_vec: true,
|
|
|
+ val: None,
|
|
|
},
|
|
|
"Point" | "P" => Self::Point {
|
|
|
is_pub: false,
|
|
@@ -103,21 +117,35 @@ pub fn vardict_from_strs(strs: &[(&str, &str)]) -> VarDict {
|
|
|
/// - variables that are in the [`VarDict`]
|
|
|
/// - integer constants
|
|
|
/// - the operations `*`, `+`, `-` (binary or unary)
|
|
|
+/// - the operation `<<` where both operands are expressions with no
|
|
|
+/// variables
|
|
|
/// - parens
|
|
|
pub fn expr_type(vars: &VarDict, expr: &Expr) -> Result<AExprType> {
|
|
|
match expr {
|
|
|
Expr::Lit(syn::ExprLit {
|
|
|
- lit: syn::Lit::Int(_),
|
|
|
+ lit: syn::Lit::Int(litint),
|
|
|
..
|
|
|
}) => Ok(AExprType::Scalar {
|
|
|
is_pub: true,
|
|
|
is_vec: false,
|
|
|
+ val: litint.base10_parse::<i128>().ok(),
|
|
|
}),
|
|
|
Expr::Unary(syn::ExprUnary {
|
|
|
op: syn::UnOp::Neg(_),
|
|
|
expr,
|
|
|
..
|
|
|
- }) => expr_type(vars, expr.as_ref()),
|
|
|
+ }) => match expr_type(vars, expr.as_ref()) {
|
|
|
+ Ok(AExprType::Scalar {
|
|
|
+ is_pub: true,
|
|
|
+ is_vec: false,
|
|
|
+ val: Some(v),
|
|
|
+ }) => Ok(AExprType::Scalar {
|
|
|
+ is_pub: true,
|
|
|
+ is_vec: false,
|
|
|
+ val: v.checked_neg(),
|
|
|
+ }),
|
|
|
+ other => other,
|
|
|
+ },
|
|
|
Expr::Paren(syn::ExprParen { expr, .. }) => expr_type(vars, expr.as_ref()),
|
|
|
Expr::Path(syn::ExprPath { path, .. }) => {
|
|
|
if let Some(id) = path.get_ident() {
|
|
@@ -144,15 +172,30 @@ pub fn expr_type(vars: &VarDict, expr: &Expr) -> Result<AExprType> {
|
|
|
AExprType::Scalar {
|
|
|
is_pub: lpub,
|
|
|
is_vec: lvec,
|
|
|
+ val: lval,
|
|
|
},
|
|
|
AExprType::Scalar {
|
|
|
is_pub: rpub,
|
|
|
is_vec: rvec,
|
|
|
+ val: rval,
|
|
|
},
|
|
|
) => {
|
|
|
return Ok(AExprType::Scalar {
|
|
|
is_pub: lpub && rpub,
|
|
|
is_vec: lvec || rvec,
|
|
|
+ val: if let (Some(lv), Some(rv)) = (lval, rval) {
|
|
|
+ match op {
|
|
|
+ syn::BinOp::Add(_) => lv.checked_add(rv),
|
|
|
+ syn::BinOp::Sub(_) => lv.checked_sub(rv),
|
|
|
+ // The default match can't
|
|
|
+ // happen because we're already
|
|
|
+ // inside a match on op, but the
|
|
|
+ // compiler requires it anyway
|
|
|
+ _ => None,
|
|
|
+ }
|
|
|
+ } else {
|
|
|
+ None
|
|
|
+ },
|
|
|
});
|
|
|
}
|
|
|
(
|
|
@@ -190,10 +233,12 @@ pub fn expr_type(vars: &VarDict, expr: &Expr) -> Result<AExprType> {
|
|
|
AExprType::Scalar {
|
|
|
is_pub: lpub,
|
|
|
is_vec: lvec,
|
|
|
+ val: lval,
|
|
|
},
|
|
|
AExprType::Scalar {
|
|
|
is_pub: rpub,
|
|
|
is_vec: rvec,
|
|
|
+ val: rval,
|
|
|
},
|
|
|
) => {
|
|
|
if !lpub && !rpub {
|
|
@@ -205,12 +250,18 @@ pub fn expr_type(vars: &VarDict, expr: &Expr) -> Result<AExprType> {
|
|
|
return Ok(AExprType::Scalar {
|
|
|
is_pub: lpub && rpub,
|
|
|
is_vec: lvec || rvec,
|
|
|
+ val: if let (Some(lv), Some(rv)) = (lval, rval) {
|
|
|
+ lv.checked_mul(rv)
|
|
|
+ } else {
|
|
|
+ None
|
|
|
+ },
|
|
|
});
|
|
|
}
|
|
|
(
|
|
|
AExprType::Scalar {
|
|
|
is_pub: lpub,
|
|
|
is_vec: lvec,
|
|
|
+ ..
|
|
|
},
|
|
|
AExprType::Point {
|
|
|
is_pub: rpub,
|
|
@@ -225,6 +276,7 @@ pub fn expr_type(vars: &VarDict, expr: &Expr) -> Result<AExprType> {
|
|
|
AExprType::Scalar {
|
|
|
is_pub: rpub,
|
|
|
is_vec: rvec,
|
|
|
+ ..
|
|
|
},
|
|
|
) => {
|
|
|
if !lpub && !rpub {
|
|
@@ -245,6 +297,39 @@ pub fn expr_type(vars: &VarDict, expr: &Expr) -> Result<AExprType> {
|
|
|
"cannot multiply a Point and a Point",
|
|
|
));
|
|
|
}
|
|
|
+ syn::BinOp::Shl(_) => {
|
|
|
+ let lt = expr_type(vars, left.as_ref())?;
|
|
|
+ let rt = expr_type(vars, right.as_ref())?;
|
|
|
+ // You can << only when both operands are constant
|
|
|
+ // Scalar expressions
|
|
|
+ if let (
|
|
|
+ AExprType::Scalar {
|
|
|
+ is_pub: true,
|
|
|
+ is_vec: false,
|
|
|
+ val: Some(lv),
|
|
|
+ },
|
|
|
+ AExprType::Scalar {
|
|
|
+ is_pub: true,
|
|
|
+ is_vec: false,
|
|
|
+ val: Some(rv),
|
|
|
+ },
|
|
|
+ ) = (lt, rt)
|
|
|
+ {
|
|
|
+ let rvu32: Option<u32> = rv.try_into().ok();
|
|
|
+ if let Some(shift_amt) = rvu32 {
|
|
|
+ return Ok(AExprType::Scalar {
|
|
|
+ is_pub: true,
|
|
|
+ is_vec: false,
|
|
|
+ val: lv.checked_shl(shift_amt),
|
|
|
+ });
|
|
|
+ }
|
|
|
+ }
|
|
|
+ return Err(Error::new(
|
|
|
+ expr.span(),
|
|
|
+ "can shift left only on constant
|
|
|
+ expressions",
|
|
|
+ ));
|
|
|
+ }
|
|
|
_ => {}
|
|
|
}
|
|
|
Err(Error::new(
|
|
@@ -265,6 +350,17 @@ mod tests {
|
|
|
assert_eq!(expr_type(vars, &expr).unwrap(), AExprType::from(expect));
|
|
|
}
|
|
|
|
|
|
+ fn check_const(vars: &VarDict, expr: Expr, expect: i128) {
|
|
|
+ assert_eq!(
|
|
|
+ expr_type(vars, &expr).unwrap(),
|
|
|
+ AExprType::Scalar {
|
|
|
+ is_pub: true,
|
|
|
+ is_vec: false,
|
|
|
+ val: Some(expect),
|
|
|
+ }
|
|
|
+ );
|
|
|
+ }
|
|
|
+
|
|
|
fn check_fail(vars: &VarDict, expr: Expr) {
|
|
|
expr_type(vars, &expr).unwrap_err();
|
|
|
}
|
|
@@ -272,9 +368,11 @@ mod tests {
|
|
|
#[test]
|
|
|
fn expr_type_test() {
|
|
|
let vars: VarDict = vardict_from_strs(&[("a", "S"), ("A", "pP"), ("v", "vS")]);
|
|
|
- check(&vars, parse_quote! {2}, "pS");
|
|
|
- check(&vars, parse_quote! {-4}, "pS");
|
|
|
- check(&vars, parse_quote! {(2)}, "pS");
|
|
|
+ check_const(&vars, parse_quote! {2}, 2);
|
|
|
+ check_const(&vars, parse_quote! {-4}, -4);
|
|
|
+ check_const(&vars, parse_quote! {(2)}, 2);
|
|
|
+ check_const(&vars, parse_quote! {1<<20}, 1048576);
|
|
|
+ check_const(&vars, parse_quote! {(3-2)<<(4*5)}, 1048576);
|
|
|
check(&vars, parse_quote! {A}, "pP");
|
|
|
check(&vars, parse_quote! {a*A}, "P");
|
|
|
check(&vars, parse_quote! {A*3}, "pP");
|
|
@@ -296,5 +394,8 @@ mod tests {
|
|
|
// multiplying two private expressions (two ways)
|
|
|
check_fail(&vars, parse_quote! {a*a});
|
|
|
check_fail(&vars, parse_quote! {a*(a*A)});
|
|
|
+ // Shifting non-constant expressions
|
|
|
+ check_fail(&vars, parse_quote! {a<<2});
|
|
|
+ check_fail(&vars, parse_quote! {1<<a});
|
|
|
}
|
|
|
}
|