Browse Source

Refactor some code to generate TokenStreams of arithmetic possibly involving vectors

Ian Goldberg 3 months ago
parent
commit
f9832a0cb8
1 changed files with 82 additions and 39 deletions
  1. 82 39
      sigma_compiler_core/src/sigma/types.rs

+ 82 - 39
sigma_compiler_core/src/sigma/types.rs

@@ -13,7 +13,7 @@ use quote::quote;
 use std::collections::HashMap;
 use syn::parse::Result;
 use syn::spanned::Spanned;
-use syn::{Error, Expr, Ident};
+use syn::{parse_quote, Error, Expr, Ident};
 
 /// The possible types of an arithmetic expression over `Scalar`s and
 /// `Point`s.  Each expression has type either
@@ -613,6 +613,79 @@ pub fn expr_type(vars: &VarDict, expr: &Expr) -> Result<AExprType> {
     Ok(fold.fold(vars, expr)?.0)
 }
 
+/// Add parentheses around a [`TokenStream`] (which represents an
+/// [arithmetic expression]) if needed.
+///
+/// The parentheses are needed if the [`TokenStream`] would parse as
+/// multiple tokens.  For example, `a+b` turns into `(a+b)`, but `c`
+/// remains `c` and `(a+b)` remains `(a+b)`.
+///
+/// [arithmetic expression]: expr_type
+pub fn tokens_paren_if_needed(tok: TokenStream) -> TokenStream {
+    let expr: Expr = parse_quote! { #tok };
+    match expr {
+        Expr::Unary(_) | Expr::Binary(_) => quote! { (#tok) },
+        _ => tok,
+    }
+}
+
+/// Given [`TokenStream`]s of two arguments, and [`bool`]s saying whether
+/// those arguments are vectors or not, construct a [`TokenStream`]
+/// representing their sum.
+pub fn tokens_add_maybe_vec(
+    left: TokenStream,
+    left_is_vec: bool,
+    right: TokenStream,
+    right_is_vec: bool,
+) -> TokenStream {
+    let lp = tokens_paren_if_needed(left);
+    let rp = tokens_paren_if_needed(right);
+    match (left_is_vec, right_is_vec) {
+        (true, true) => quote! { add_vecs(&#lp, &#rp) },
+        (false, true) => quote! { add_nv_vec(&#lp, &#rp) },
+        (true, false) => quote! { add_vec_nv(&#lp, &#rp) },
+        (false, false) => quote! { #lp + #rp },
+    }
+}
+
+/// Given [`TokenStream`]s of two arguments, and [`bool`]s saying whether
+/// those arguments are vectors or not, construct a [`TokenStream`]
+/// representing their difference.
+pub fn tokens_sub_maybe_vec(
+    left: TokenStream,
+    left_is_vec: bool,
+    right: TokenStream,
+    right_is_vec: bool,
+) -> TokenStream {
+    let lp = tokens_paren_if_needed(left);
+    let rp = tokens_paren_if_needed(right);
+    match (left_is_vec, right_is_vec) {
+        (true, true) => quote! { sub_vecs(&#lp, &#rp) },
+        (false, true) => quote! { sub_nv_vec(&#lp, &#rp) },
+        (true, false) => quote! { sub_vec_nv(&#lp, &#rp) },
+        (false, false) => quote! { #lp - #rp },
+    }
+}
+
+/// Given [`TokenStream`]s of two arguments, and [`bool`]s saying whether
+/// those arguments are vectors or not, construct a [`TokenStream`]
+/// representing their product.
+pub fn tokens_mul_maybe_vec(
+    left: TokenStream,
+    left_is_vec: bool,
+    right: TokenStream,
+    right_is_vec: bool,
+) -> TokenStream {
+    let lp = tokens_paren_if_needed(left);
+    let rp = tokens_paren_if_needed(right);
+    match (left_is_vec, right_is_vec) {
+        (true, true) => quote! { mul_vecs(&#lp, &#rp) },
+        (false, true) => quote! { mul_nv_vec(&#lp, &#rp) },
+        (true, false) => quote! { mul_vec_nv(&#lp, &#rp) },
+        (false, false) => quote! { #lp * #rp },
+    }
+}
+
 pub struct AExprTokenFold<'a> {
     ident_closure: &'a mut dyn FnMut(&Ident, AExprType) -> Result<TokenStream>,
 }
@@ -669,12 +742,7 @@ impl<'a> AExprFold<TokenStream> for AExprTokenFold<'a> {
         else {
             panic!("Should not happen: non-Scalar passed to add_scalars");
         };
-        match (l_is_vec, r_is_vec) {
-            (true, true) => Ok(quote! { add_vecs(&(#le), &(#re)) }),
-            (false, true) => Ok(quote! { add_nv_vec(&(#le), &(#re)) }),
-            (true, false) => Ok(quote! { add_vec_nv(&(#le), &(#re)) }),
-            (false, false) => Ok(quote! { #le + #re }),
-        }
+        Ok(tokens_add_maybe_vec(le, l_is_vec, re, r_is_vec))
     }
 
     /// Called when adding two `Point`s
@@ -698,12 +766,7 @@ impl<'a> AExprFold<TokenStream> for AExprTokenFold<'a> {
         else {
             panic!("Should not happen: non-Point passed to add_points");
         };
-        match (l_is_vec, r_is_vec) {
-            (true, true) => Ok(quote! { add_vecs(&(#le), &(#re)) }),
-            (false, true) => Ok(quote! { add_nv_vec(&(#le), &(#re)) }),
-            (true, false) => Ok(quote! { add_vec_nv(&(#le), &(#re)) }),
-            (false, false) => Ok(quote! { #le + #re }),
-        }
+        Ok(tokens_add_maybe_vec(le, l_is_vec, re, r_is_vec))
     }
 
     /// Called when subtracting two `Scalar`s
@@ -727,12 +790,7 @@ impl<'a> AExprFold<TokenStream> for AExprTokenFold<'a> {
         else {
             panic!("Should not happen: non-Scalar passed to sub_scalars");
         };
-        match (l_is_vec, r_is_vec) {
-            (true, true) => Ok(quote! { sub_vecs(&(#le), &(#re)) }),
-            (false, true) => Ok(quote! { sub_nv_vec(&(#le), &(#re)) }),
-            (true, false) => Ok(quote! { sub_vec_nv(&(#le), &(#re)) }),
-            (false, false) => Ok(quote! { #le + (-#re) }),
-        }
+        Ok(tokens_sub_maybe_vec(le, l_is_vec, re, r_is_vec))
     }
 
     /// Called when subtracting two `Point`s
@@ -756,12 +814,7 @@ impl<'a> AExprFold<TokenStream> for AExprTokenFold<'a> {
         else {
             panic!("Should not happen: non-Point passed to sub_points");
         };
-        match (l_is_vec, r_is_vec) {
-            (true, true) => Ok(quote! { sub_vecs(&(#le), &(#re)) }),
-            (false, true) => Ok(quote! { sub_nv_vec(&(#le), &(#re)) }),
-            (true, false) => Ok(quote! { sub_vec_nv(&(#le), &(#re)) }),
-            (false, false) => Ok(quote! { #le - #re }),
-        }
+        Ok(tokens_sub_maybe_vec(le, l_is_vec, re, r_is_vec))
     }
 
     /// Called when multiplying two `Scalar`s
@@ -794,12 +847,7 @@ impl<'a> AExprFold<TokenStream> for AExprTokenFold<'a> {
         } else {
             (l_is_vec, larg.1, r_is_vec, rarg.1)
         };
-        match (lv, rv) {
-            (true, true) => Ok(quote! { mul_vecs(&(#le), &(#re)) }),
-            (false, true) => Ok(quote! { mul_nv_vec(&(#le), &(#re)) }),
-            (true, false) => Ok(quote! { mul_vec_nv(&(#le), &(#re)) }),
-            (false, false) => Ok(quote! { #le * #re }),
-        }
+        Ok(tokens_mul_maybe_vec(le, lv, re, rv))
     }
 
     /// Called when multiplying a `Scalar` and a `Point` (the `Scalar`
@@ -830,12 +878,7 @@ impl<'a> AExprFold<TokenStream> for AExprTokenFold<'a> {
         } else {
             (s_is_vec, sarg.1, p_is_vec, parg.1)
         };
-        match (lv, rv) {
-            (true, true) => Ok(quote! { mul_vecs(&(#le), &(#re)) }),
-            (false, true) => Ok(quote! { mul_nv_vec(&(#le), &(#re)) }),
-            (true, false) => Ok(quote! { mul_vec_nv(&(#le), &(#re)) }),
-            (false, false) => Ok(quote! { #le * #re }),
-        }
+        Ok(tokens_mul_maybe_vec(le, lv, re, rv))
     }
 }
 
@@ -963,13 +1006,13 @@ mod tests {
             &vars,
             parse_quote! {(a-(2-3))*(A+(3*4)*A)},
             quote! {
-            (a+(-(Scalar::from_u128(1u128).neg())))*(A+A*(Scalar::from_u128(12u128))) },
+            (a-(Scalar::from_u128(1u128).neg()))*(A+(A*(Scalar::from_u128(12u128)))) },
         );
         check_tokens(
             &vars,
             parse_quote! {(a-(2-3))*(A+A*(3*4))},
             quote! {
-            (a+(-(Scalar::from_u128(1u128).neg())))*(A+A*(Scalar::from_u128(12u128))) },
+            (a-(Scalar::from_u128(1u128).neg()))*(A+(A*(Scalar::from_u128(12u128)))) },
         );
 
         // Tests that should fail