|
|
@@ -13,7 +13,7 @@ use quote::quote;
|
|
|
use std::collections::HashMap;
|
|
|
use syn::parse::Result;
|
|
|
use syn::spanned::Spanned;
|
|
|
-use syn::{Error, Expr, Ident};
|
|
|
+use syn::{parse_quote, Error, Expr, Ident};
|
|
|
|
|
|
/// The possible types of an arithmetic expression over `Scalar`s and
|
|
|
/// `Point`s. Each expression has type either
|
|
|
@@ -613,6 +613,79 @@ pub fn expr_type(vars: &VarDict, expr: &Expr) -> Result<AExprType> {
|
|
|
Ok(fold.fold(vars, expr)?.0)
|
|
|
}
|
|
|
|
|
|
+/// Add parentheses around a [`TokenStream`] (which represents an
|
|
|
+/// [arithmetic expression]) if needed.
|
|
|
+///
|
|
|
+/// The parentheses are needed if the [`TokenStream`] would parse as
|
|
|
+/// multiple tokens. For example, `a+b` turns into `(a+b)`, but `c`
|
|
|
+/// remains `c` and `(a+b)` remains `(a+b)`.
|
|
|
+///
|
|
|
+/// [arithmetic expression]: expr_type
|
|
|
+pub fn tokens_paren_if_needed(tok: TokenStream) -> TokenStream {
|
|
|
+ let expr: Expr = parse_quote! { #tok };
|
|
|
+ match expr {
|
|
|
+ Expr::Unary(_) | Expr::Binary(_) => quote! { (#tok) },
|
|
|
+ _ => tok,
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+/// Given [`TokenStream`]s of two arguments, and [`bool`]s saying whether
|
|
|
+/// those arguments are vectors or not, construct a [`TokenStream`]
|
|
|
+/// representing their sum.
|
|
|
+pub fn tokens_add_maybe_vec(
|
|
|
+ left: TokenStream,
|
|
|
+ left_is_vec: bool,
|
|
|
+ right: TokenStream,
|
|
|
+ right_is_vec: bool,
|
|
|
+) -> TokenStream {
|
|
|
+ let lp = tokens_paren_if_needed(left);
|
|
|
+ let rp = tokens_paren_if_needed(right);
|
|
|
+ match (left_is_vec, right_is_vec) {
|
|
|
+ (true, true) => quote! { add_vecs(&#lp, &#rp) },
|
|
|
+ (false, true) => quote! { add_nv_vec(&#lp, &#rp) },
|
|
|
+ (true, false) => quote! { add_vec_nv(&#lp, &#rp) },
|
|
|
+ (false, false) => quote! { #lp + #rp },
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+/// Given [`TokenStream`]s of two arguments, and [`bool`]s saying whether
|
|
|
+/// those arguments are vectors or not, construct a [`TokenStream`]
|
|
|
+/// representing their difference.
|
|
|
+pub fn tokens_sub_maybe_vec(
|
|
|
+ left: TokenStream,
|
|
|
+ left_is_vec: bool,
|
|
|
+ right: TokenStream,
|
|
|
+ right_is_vec: bool,
|
|
|
+) -> TokenStream {
|
|
|
+ let lp = tokens_paren_if_needed(left);
|
|
|
+ let rp = tokens_paren_if_needed(right);
|
|
|
+ match (left_is_vec, right_is_vec) {
|
|
|
+ (true, true) => quote! { sub_vecs(&#lp, &#rp) },
|
|
|
+ (false, true) => quote! { sub_nv_vec(&#lp, &#rp) },
|
|
|
+ (true, false) => quote! { sub_vec_nv(&#lp, &#rp) },
|
|
|
+ (false, false) => quote! { #lp - #rp },
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+/// Given [`TokenStream`]s of two arguments, and [`bool`]s saying whether
|
|
|
+/// those arguments are vectors or not, construct a [`TokenStream`]
|
|
|
+/// representing their product.
|
|
|
+pub fn tokens_mul_maybe_vec(
|
|
|
+ left: TokenStream,
|
|
|
+ left_is_vec: bool,
|
|
|
+ right: TokenStream,
|
|
|
+ right_is_vec: bool,
|
|
|
+) -> TokenStream {
|
|
|
+ let lp = tokens_paren_if_needed(left);
|
|
|
+ let rp = tokens_paren_if_needed(right);
|
|
|
+ match (left_is_vec, right_is_vec) {
|
|
|
+ (true, true) => quote! { mul_vecs(&#lp, &#rp) },
|
|
|
+ (false, true) => quote! { mul_nv_vec(&#lp, &#rp) },
|
|
|
+ (true, false) => quote! { mul_vec_nv(&#lp, &#rp) },
|
|
|
+ (false, false) => quote! { #lp * #rp },
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
pub struct AExprTokenFold<'a> {
|
|
|
ident_closure: &'a mut dyn FnMut(&Ident, AExprType) -> Result<TokenStream>,
|
|
|
}
|
|
|
@@ -669,12 +742,7 @@ impl<'a> AExprFold<TokenStream> for AExprTokenFold<'a> {
|
|
|
else {
|
|
|
panic!("Should not happen: non-Scalar passed to add_scalars");
|
|
|
};
|
|
|
- match (l_is_vec, r_is_vec) {
|
|
|
- (true, true) => Ok(quote! { add_vecs(&(#le), &(#re)) }),
|
|
|
- (false, true) => Ok(quote! { add_nv_vec(&(#le), &(#re)) }),
|
|
|
- (true, false) => Ok(quote! { add_vec_nv(&(#le), &(#re)) }),
|
|
|
- (false, false) => Ok(quote! { #le + #re }),
|
|
|
- }
|
|
|
+ Ok(tokens_add_maybe_vec(le, l_is_vec, re, r_is_vec))
|
|
|
}
|
|
|
|
|
|
/// Called when adding two `Point`s
|
|
|
@@ -698,12 +766,7 @@ impl<'a> AExprFold<TokenStream> for AExprTokenFold<'a> {
|
|
|
else {
|
|
|
panic!("Should not happen: non-Point passed to add_points");
|
|
|
};
|
|
|
- match (l_is_vec, r_is_vec) {
|
|
|
- (true, true) => Ok(quote! { add_vecs(&(#le), &(#re)) }),
|
|
|
- (false, true) => Ok(quote! { add_nv_vec(&(#le), &(#re)) }),
|
|
|
- (true, false) => Ok(quote! { add_vec_nv(&(#le), &(#re)) }),
|
|
|
- (false, false) => Ok(quote! { #le + #re }),
|
|
|
- }
|
|
|
+ Ok(tokens_add_maybe_vec(le, l_is_vec, re, r_is_vec))
|
|
|
}
|
|
|
|
|
|
/// Called when subtracting two `Scalar`s
|
|
|
@@ -727,12 +790,7 @@ impl<'a> AExprFold<TokenStream> for AExprTokenFold<'a> {
|
|
|
else {
|
|
|
panic!("Should not happen: non-Scalar passed to sub_scalars");
|
|
|
};
|
|
|
- match (l_is_vec, r_is_vec) {
|
|
|
- (true, true) => Ok(quote! { sub_vecs(&(#le), &(#re)) }),
|
|
|
- (false, true) => Ok(quote! { sub_nv_vec(&(#le), &(#re)) }),
|
|
|
- (true, false) => Ok(quote! { sub_vec_nv(&(#le), &(#re)) }),
|
|
|
- (false, false) => Ok(quote! { #le + (-#re) }),
|
|
|
- }
|
|
|
+ Ok(tokens_sub_maybe_vec(le, l_is_vec, re, r_is_vec))
|
|
|
}
|
|
|
|
|
|
/// Called when subtracting two `Point`s
|
|
|
@@ -756,12 +814,7 @@ impl<'a> AExprFold<TokenStream> for AExprTokenFold<'a> {
|
|
|
else {
|
|
|
panic!("Should not happen: non-Point passed to sub_points");
|
|
|
};
|
|
|
- match (l_is_vec, r_is_vec) {
|
|
|
- (true, true) => Ok(quote! { sub_vecs(&(#le), &(#re)) }),
|
|
|
- (false, true) => Ok(quote! { sub_nv_vec(&(#le), &(#re)) }),
|
|
|
- (true, false) => Ok(quote! { sub_vec_nv(&(#le), &(#re)) }),
|
|
|
- (false, false) => Ok(quote! { #le - #re }),
|
|
|
- }
|
|
|
+ Ok(tokens_sub_maybe_vec(le, l_is_vec, re, r_is_vec))
|
|
|
}
|
|
|
|
|
|
/// Called when multiplying two `Scalar`s
|
|
|
@@ -794,12 +847,7 @@ impl<'a> AExprFold<TokenStream> for AExprTokenFold<'a> {
|
|
|
} else {
|
|
|
(l_is_vec, larg.1, r_is_vec, rarg.1)
|
|
|
};
|
|
|
- match (lv, rv) {
|
|
|
- (true, true) => Ok(quote! { mul_vecs(&(#le), &(#re)) }),
|
|
|
- (false, true) => Ok(quote! { mul_nv_vec(&(#le), &(#re)) }),
|
|
|
- (true, false) => Ok(quote! { mul_vec_nv(&(#le), &(#re)) }),
|
|
|
- (false, false) => Ok(quote! { #le * #re }),
|
|
|
- }
|
|
|
+ Ok(tokens_mul_maybe_vec(le, lv, re, rv))
|
|
|
}
|
|
|
|
|
|
/// Called when multiplying a `Scalar` and a `Point` (the `Scalar`
|
|
|
@@ -830,12 +878,7 @@ impl<'a> AExprFold<TokenStream> for AExprTokenFold<'a> {
|
|
|
} else {
|
|
|
(s_is_vec, sarg.1, p_is_vec, parg.1)
|
|
|
};
|
|
|
- match (lv, rv) {
|
|
|
- (true, true) => Ok(quote! { mul_vecs(&(#le), &(#re)) }),
|
|
|
- (false, true) => Ok(quote! { mul_nv_vec(&(#le), &(#re)) }),
|
|
|
- (true, false) => Ok(quote! { mul_vec_nv(&(#le), &(#re)) }),
|
|
|
- (false, false) => Ok(quote! { #le * #re }),
|
|
|
- }
|
|
|
+ Ok(tokens_mul_maybe_vec(le, lv, re, rv))
|
|
|
}
|
|
|
}
|
|
|
|
|
|
@@ -963,13 +1006,13 @@ mod tests {
|
|
|
&vars,
|
|
|
parse_quote! {(a-(2-3))*(A+(3*4)*A)},
|
|
|
quote! {
|
|
|
- (a+(-(Scalar::from_u128(1u128).neg())))*(A+A*(Scalar::from_u128(12u128))) },
|
|
|
+ (a-(Scalar::from_u128(1u128).neg()))*(A+(A*(Scalar::from_u128(12u128)))) },
|
|
|
);
|
|
|
check_tokens(
|
|
|
&vars,
|
|
|
parse_quote! {(a-(2-3))*(A+A*(3*4))},
|
|
|
quote! {
|
|
|
- (a+(-(Scalar::from_u128(1u128).neg())))*(A+A*(Scalar::from_u128(12u128))) },
|
|
|
+ (a-(Scalar::from_u128(1u128).neg()))*(A+(A*(Scalar::from_u128(12u128)))) },
|
|
|
);
|
|
|
|
|
|
// Tests that should fail
|