|
@@ -136,6 +136,8 @@ pub fn const_i128_tokens(val: i128) -> TokenStream {
|
|
|
/// - the operations `*`, `+`, `-` (binary or unary)
|
|
|
/// - the operation `<<` where both operands are expressions with no
|
|
|
/// variables
|
|
|
+/// - the function `sum` that takes a single vector argument and
|
|
|
+/// returns the sum of its elements
|
|
|
/// - parens
|
|
|
pub trait AExprFold<T> {
|
|
|
/// Called when an identifier found in the [`VarDict`] is
|
|
@@ -168,6 +170,12 @@ pub trait AExprFold<T> {
|
|
|
restype: AExprType,
|
|
|
) -> Result<T>;
|
|
|
|
|
|
+ /// Called when summing a vector of `Scalar`s
|
|
|
+ fn sum_scalars(&mut self, arg: (AExprType, T), restype: AExprType) -> Result<T>;
|
|
|
+
|
|
|
+ /// Called when summing a vector of `Point`s
|
|
|
+ fn sum_points(&mut self, arg: (AExprType, T), restype: AExprType) -> Result<T>;
|
|
|
+
|
|
|
/// Called when subtracting two `Scalar`s
|
|
|
fn sub_scalars(
|
|
|
&mut self,
|
|
@@ -505,6 +513,53 @@ pub trait AExprFold<T> {
|
|
|
"invalid operation for arithmetic expression",
|
|
|
))
|
|
|
}
|
|
|
+ Expr::Call(syn::ExprCall { func, args, .. }) => {
|
|
|
+ let funcname = match func.as_ref() {
|
|
|
+ Expr::Path(syn::ExprPath { path, .. }) => {
|
|
|
+ path.get_ident().map(|id| id.to_string())
|
|
|
+ }
|
|
|
+ _ => None,
|
|
|
+ };
|
|
|
+ match funcname {
|
|
|
+ Some(ref s) if s == "sum" => {
|
|
|
+ if args.len() != 1 {
|
|
|
+ return Err(Error::new(
|
|
|
+ func.span(),
|
|
|
+ "sum must have exactly one argument",
|
|
|
+ ));
|
|
|
+ }
|
|
|
+ let (at, ae) = self.fold(vars, args.first().unwrap())?;
|
|
|
+ match at {
|
|
|
+ AExprType::Scalar {
|
|
|
+ is_vec: true,
|
|
|
+ is_pub,
|
|
|
+ ..
|
|
|
+ } => {
|
|
|
+ let restype = AExprType::Scalar {
|
|
|
+ is_vec: false,
|
|
|
+ is_pub,
|
|
|
+ val: None,
|
|
|
+ };
|
|
|
+ let res = self.sum_scalars((at, ae), restype)?;
|
|
|
+ Ok((restype, res))
|
|
|
+ }
|
|
|
+ AExprType::Point {
|
|
|
+ is_vec: true,
|
|
|
+ is_pub,
|
|
|
+ } => {
|
|
|
+ let restype = AExprType::Point {
|
|
|
+ is_vec: false,
|
|
|
+ is_pub,
|
|
|
+ };
|
|
|
+ let res = self.sum_points((at, ae), restype)?;
|
|
|
+ Ok((restype, res))
|
|
|
+ }
|
|
|
+ _ => Err(Error::new(args.span(), "argument to sum must be a vector")),
|
|
|
+ }
|
|
|
+ }
|
|
|
+ _ => Err(Error::new(func.span(), "unknown function")),
|
|
|
+ }
|
|
|
+ }
|
|
|
_ => Err(Error::new(expr.span(), "not a valid arithmetic expression")),
|
|
|
}
|
|
|
}
|
|
@@ -555,6 +610,16 @@ impl AExprFold<()> for FoldNoOp {
|
|
|
Ok(())
|
|
|
}
|
|
|
|
|
|
+ /// Called when summing a vector of `Scalar`s
|
|
|
+ fn sum_scalars(&mut self, _arg: (AExprType, ()), _restype: AExprType) -> Result<()> {
|
|
|
+ Ok(())
|
|
|
+ }
|
|
|
+
|
|
|
+ /// Called when summing a vector of `Point`s
|
|
|
+ fn sum_points(&mut self, _arg: (AExprType, ()), _restype: AExprType) -> Result<()> {
|
|
|
+ Ok(())
|
|
|
+ }
|
|
|
+
|
|
|
/// Called when subtracting two `Scalar`s
|
|
|
fn sub_scalars(
|
|
|
&mut self,
|
|
@@ -607,6 +672,8 @@ impl AExprFold<()> for FoldNoOp {
|
|
|
/// - the operations `*`, `+`, `-` (binary or unary)
|
|
|
/// - the operation `<<` where both operands are expressions with no
|
|
|
/// variables
|
|
|
+/// - the function `sum` that takes a single vector argument and
|
|
|
+/// returns the sum of its elements
|
|
|
/// - parens
|
|
|
pub fn expr_type(vars: &VarDict, expr: &Expr) -> Result<AExprType> {
|
|
|
let mut fold = FoldNoOp {};
|
|
@@ -769,6 +836,26 @@ impl<'a> AExprFold<TokenStream> for AExprTokenFold<'a> {
|
|
|
Ok(tokens_add_maybe_vec(le, l_is_vec, re, r_is_vec))
|
|
|
}
|
|
|
|
|
|
+ /// Called when summing a vector of `Scalar`s
|
|
|
+ fn sum_scalars(
|
|
|
+ &mut self,
|
|
|
+ arg: (AExprType, TokenStream),
|
|
|
+ _restype: AExprType,
|
|
|
+ ) -> Result<TokenStream> {
|
|
|
+ let arge = arg.1;
|
|
|
+ Ok(quote! { sigma_compiler::vecutils::sum_vec(&(#arge)) })
|
|
|
+ }
|
|
|
+
|
|
|
+ /// Called when summing a vector of `Point`s
|
|
|
+ fn sum_points(
|
|
|
+ &mut self,
|
|
|
+ arg: (AExprType, TokenStream),
|
|
|
+ _restype: AExprType,
|
|
|
+ ) -> Result<TokenStream> {
|
|
|
+ let arge = arg.1;
|
|
|
+ Ok(quote! { sigma_compiler::vecutils::sum_vec(&(#arge)) })
|
|
|
+ }
|
|
|
+
|
|
|
/// Called when subtracting two `Scalar`s
|
|
|
fn sub_scalars(
|
|
|
&mut self,
|
|
@@ -893,6 +980,8 @@ impl<'a> AExprFold<TokenStream> for AExprTokenFold<'a> {
|
|
|
/// - the operations `*`, `+`, `-` (binary or unary)
|
|
|
/// - the operation `<<` where both operands are expressions with no
|
|
|
/// variables
|
|
|
+/// - the function `sum` that takes a single vector argument and
|
|
|
+/// returns the sum of its elements
|
|
|
/// - parens
|
|
|
pub fn expr_type_tokens(vars: &VarDict, expr: &Expr) -> Result<(AExprType, TokenStream)> {
|
|
|
let mut fold = AExprTokenFold {
|