Browse Source

A function to evaluate the types of arithmetic expressions on Scalars and Points

Ian Goldberg 4 months ago
parent
commit
ed8a9d43ed

+ 3 - 2
sigma_compiler_core/src/lib.rs

@@ -9,10 +9,11 @@ use syn::{parse_quote, Expr, Ident, Token};
 /// module
 /// module
 mod sigma {
 mod sigma {
     pub mod combiners;
     pub mod combiners;
+    pub mod types;
 }
 }
 mod syntax;
 mod syntax;
 
 
-pub use syntax::{SigmaCompSpec, TaggedIdent, TaggedPoint, TaggedScalar, VarDict};
+pub use syntax::{SigmaCompSpec, TaggedIdent, TaggedPoint, TaggedScalar, TaggedVarDict};
 
 
 // Names and types of fields that might end up in a generated struct
 // Names and types of fields that might end up in a generated struct
 enum StructField {
 enum StructField {
@@ -41,7 +42,7 @@ impl StructFieldList {
     pub fn push_vecpoint(&mut self, s: &Ident) {
     pub fn push_vecpoint(&mut self, s: &Ident) {
         self.fields.push(StructField::VecPoint(s.clone()));
         self.fields.push(StructField::VecPoint(s.clone()));
     }
     }
-    pub fn push_vars(&mut self, vardict: &VarDict, is_pub: bool) {
+    pub fn push_vars(&mut self, vardict: &TaggedVarDict, is_pub: bool) {
         for (_, ti) in vardict.iter() {
         for (_, ti) in vardict.iter() {
             match ti {
             match ti {
                 TaggedIdent::Scalar(st) => {
                 TaggedIdent::Scalar(st) => {

+ 289 - 0
sigma_compiler_core/src/sigma/types.rs

@@ -0,0 +1,289 @@
+//! At the `sigma` level, each variable can be a private `Scalar`, a
+//! public `Scalar`, or a public `Point`, and each variable can be
+//! either a vector or not.  Arithmetic expressions of those variables
+//! can be of any of those types, and also private `Point`s (vector or
+//! not).  This module defines an enum [`AExprType`] for
+//! the possible types, as well as a dictionary type that maps
+//! [`String`]s (the name of the variable) to [`AExprType`], and a
+//! function for determining the type of arithmetic expressions
+//! involving such variables.
+
+use std::collections::HashMap;
+use syn::parse::Result;
+use syn::spanned::Spanned;
+use syn::{Error, Expr};
+
+#[derive(Copy, Clone, Debug, PartialEq)]
+pub enum AExprType {
+    Scalar { is_pub: bool, is_vec: bool },
+    Point { is_pub: bool, is_vec: bool },
+}
+
+impl From<&str> for AExprType {
+    /// A convenience function for creating a [`AExprType`] from a
+    /// [`&str`].  Pass one of (or their short forms):
+    ///   - `"Scalar"` (`"S"`)
+    ///   - `"pub Scalar"` (`"pS"`)
+    ///   - `"vec Scalar"` (`"vS"`)
+    ///   - `"pub vec Scalar"` (`"pvS"`)
+    ///   - `"Point"` (`"P"`)
+    ///   - `"pub Point"` (`"pP"`)
+    ///   - `"vec Point"` (`"vP"`)
+    ///   - `"pub vec Point"` (`"pvP"`)
+    fn from(s: &str) -> Self {
+        match s {
+            "Scalar" | "S" => Self::Scalar {
+                is_pub: false,
+                is_vec: false,
+            },
+            "pub Scalar" | "pS" => Self::Scalar {
+                is_pub: true,
+                is_vec: false,
+            },
+            "vec Scalar" | "vS" => Self::Scalar {
+                is_pub: false,
+                is_vec: true,
+            },
+            "pub vec Scalar" | "pvS" => Self::Scalar {
+                is_pub: true,
+                is_vec: true,
+            },
+            "Point" | "P" => Self::Point {
+                is_pub: false,
+                is_vec: false,
+            },
+            "vec Point" | "vP" => Self::Point {
+                is_pub: false,
+                is_vec: true,
+            },
+            "pub Point" | "pP" => Self::Point {
+                is_pub: true,
+                is_vec: false,
+            },
+            "pub vec Point" | "pvP" => Self::Point {
+                is_pub: true,
+                is_vec: true,
+            },
+            _ => {
+                panic!("Illegal string passed to AExprType::from");
+            }
+        }
+    }
+}
+
+/// A dictionary of known variables (given by [`String`]s), mapping each
+/// to their [`AExprType`]
+pub type VarDict = HashMap<String, AExprType>;
+
+/// Pass a slice of pairs of strings.  The first element of each
+/// pair is the variable name; the second is the [`AExprType`], as
+/// listed in the [`AExprType::from`] function
+pub fn vardict_from_strs(strs: &[(&str, &str)]) -> VarDict {
+    let c = strs
+        .iter()
+        .map(|(k, v)| (String::from(*k), AExprType::from(*v)));
+    VarDict::from_iter(c)
+}
+
+/// Given a [`VarDict`] and an [`Expr`] representing an arithmetic
+/// expression using the variables in the [`VarDict`], compute the
+/// [`AExprType`] of the expression.  An arithmetic expression can consist
+/// of:
+///   - variables that are in the [`VarDict`]
+///   - integer constants
+///   - the operations `*`, `+`, `-` (binary or unary)
+///   - parens
+pub fn expr_type(vars: &VarDict, expr: &Expr) -> Result<AExprType> {
+    match expr {
+        Expr::Lit(syn::ExprLit {
+            lit: syn::Lit::Int(_),
+            ..
+        }) => Ok(AExprType::Scalar {
+            is_pub: true,
+            is_vec: false,
+        }),
+        Expr::Unary(syn::ExprUnary {
+            op: syn::UnOp::Neg(_),
+            expr,
+            ..
+        }) => expr_type(vars, expr.as_ref()),
+        Expr::Paren(syn::ExprParen { expr, .. }) => expr_type(vars, expr.as_ref()),
+        Expr::Path(syn::ExprPath { path, .. }) => {
+            if let Some(id) = path.get_ident() {
+                if let Some(&vt) = vars.get(&id.to_string()) {
+                    return Ok(vt);
+                }
+            }
+            Err(Error::new(expr.span(), "not a known variable"))
+        }
+        Expr::Binary(syn::ExprBinary {
+            left, op, right, ..
+        }) => {
+            match op {
+                syn::BinOp::Add(_) | syn::BinOp::Sub(_) => {
+                    let lt = expr_type(vars, left.as_ref())?;
+                    let rt = expr_type(vars, right.as_ref())?;
+                    // You can add or subtract two Scalars or two
+                    // Points, but not a Scalar and a Point.  The result
+                    // is public if both arguments are public.  The
+                    // result is a vector if either argument is a
+                    // vector.
+                    match (lt, rt) {
+                        (
+                            AExprType::Scalar {
+                                is_pub: lpub,
+                                is_vec: lvec,
+                            },
+                            AExprType::Scalar {
+                                is_pub: rpub,
+                                is_vec: rvec,
+                            },
+                        ) => {
+                            return Ok(AExprType::Scalar {
+                                is_pub: lpub && rpub,
+                                is_vec: lvec || rvec,
+                            });
+                        }
+                        (
+                            AExprType::Point {
+                                is_pub: lpub,
+                                is_vec: lvec,
+                            },
+                            AExprType::Point {
+                                is_pub: rpub,
+                                is_vec: rvec,
+                            },
+                        ) => {
+                            return Ok(AExprType::Point {
+                                is_pub: lpub && rpub,
+                                is_vec: lvec || rvec,
+                            });
+                        }
+                        _ => {}
+                    }
+                    return Err(Error::new(
+                        expr.span(),
+                        "cannot add/subtract a Scalar and a Point",
+                    ));
+                }
+                syn::BinOp::Mul(_) => {
+                    let lt = expr_type(vars, left.as_ref())?;
+                    let rt = expr_type(vars, right.as_ref())?;
+                    // You can multiply two Scalars or a Scalar and a
+                    // Point, but not two Points.  You can also not
+                    // multiply two private expressions.  The result is
+                    // public if both arguments are public.  The result
+                    // is a vector if either argument is a vector.
+                    match (lt, rt) {
+                        (
+                            AExprType::Scalar {
+                                is_pub: lpub,
+                                is_vec: lvec,
+                            },
+                            AExprType::Scalar {
+                                is_pub: rpub,
+                                is_vec: rvec,
+                            },
+                        ) => {
+                            if !lpub && !rpub {
+                                return Err(Error::new(
+                                    expr.span(),
+                                    "cannot multiply two private expressions",
+                                ));
+                            }
+                            return Ok(AExprType::Scalar {
+                                is_pub: lpub && rpub,
+                                is_vec: lvec || rvec,
+                            });
+                        }
+                        (
+                            AExprType::Scalar {
+                                is_pub: lpub,
+                                is_vec: lvec,
+                            },
+                            AExprType::Point {
+                                is_pub: rpub,
+                                is_vec: rvec,
+                            },
+                        )
+                        | (
+                            AExprType::Point {
+                                is_pub: lpub,
+                                is_vec: lvec,
+                            },
+                            AExprType::Scalar {
+                                is_pub: rpub,
+                                is_vec: rvec,
+                            },
+                        ) => {
+                            if !lpub && !rpub {
+                                return Err(Error::new(
+                                    expr.span(),
+                                    "cannot multiply two private expressions",
+                                ));
+                            }
+                            return Ok(AExprType::Point {
+                                is_pub: lpub && rpub,
+                                is_vec: lvec || rvec,
+                            });
+                        }
+                        _ => {}
+                    }
+                    return Err(Error::new(
+                        expr.span(),
+                        "cannot multiply a Point and a Point",
+                    ));
+                }
+                _ => {}
+            }
+            Err(Error::new(
+                op.span(),
+                "invalid operation for arithmetic expression",
+            ))
+        }
+        _ => Err(Error::new(expr.span(), "not a valid arithmetic expression")),
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use syn::parse_quote;
+
+    fn check(vars: &VarDict, expr: Expr, expect: &str) {
+        assert_eq!(expr_type(vars, &expr).unwrap(), AExprType::from(expect));
+    }
+
+    fn check_fail(vars: &VarDict, expr: Expr) {
+        expr_type(vars, &expr).unwrap_err();
+    }
+
+    #[test]
+    fn test_expr_type() {
+        let vars: VarDict = vardict_from_strs(&[("a", "S"), ("A", "pP"), ("v", "vS")]);
+        check(&vars, parse_quote! {2}, "pS");
+        check(&vars, parse_quote! {-4}, "pS");
+        check(&vars, parse_quote! {(2)}, "pS");
+        check(&vars, parse_quote! {A}, "pP");
+        check(&vars, parse_quote! {a*A}, "P");
+        check(&vars, parse_quote! {A*3}, "pP");
+        check(&vars, parse_quote! {(a-1)*(A+A)}, "P");
+        check(&vars, parse_quote! {(v-1)*(A+A)}, "vP");
+
+        // Tests that should fail
+
+        // unknown variable
+        check_fail(&vars, parse_quote! {B});
+        // adding a Scalar to a Point
+        check_fail(&vars, parse_quote! {a+A});
+        // multiplying two Points
+        check_fail(&vars, parse_quote! {A*A});
+        // invalid operation
+        check_fail(&vars, parse_quote! {A/A});
+        // invalid expression
+        check_fail(&vars, parse_quote! {A.size});
+        // multiplying two private expressions (two ways)
+        check_fail(&vars, parse_quote! {a*a});
+        check_fail(&vars, parse_quote! {a*(a*A)});
+    }
+}

+ 32 - 6
sigma_compiler_core/src/syntax.rs

@@ -1,4 +1,5 @@
 use super::sigma::combiners::StatementTree;
 use super::sigma::combiners::StatementTree;
+use super::sigma::types::*;
 use quote::format_ident;
 use quote::format_ident;
 use std::collections::HashMap;
 use std::collections::HashMap;
 use syn::ext::IdentExt;
 use syn::ext::IdentExt;
@@ -87,10 +88,35 @@ pub enum TaggedIdent {
     Point(TaggedPoint),
     Point(TaggedPoint),
 }
 }
 
 
-/// A `VarDict` is a dictionary of the available variables, mapping
-/// the string version of `Ident`s to `TaggedIdent`, which includes
-/// their type (`Scalar` or `Point`)
-pub type VarDict = HashMap<String, TaggedIdent>;
+/// Convert a [`TaggedIdent`] to its underlying [`AExprType`]
+impl From<&TaggedIdent> for AExprType {
+    fn from(ti: &TaggedIdent) -> Self {
+        match ti {
+            TaggedIdent::Scalar(ts) => Self::Scalar {
+                is_pub: ts.is_pub,
+                is_vec: ts.is_vec,
+            },
+            TaggedIdent::Point(tp) => Self::Point {
+                is_pub: true,
+                is_vec: tp.is_vec,
+            },
+        }
+    }
+}
+
+/// A `TaggedVarDict` is a dictionary of the available variables,
+/// mapping the string version of `Ident`s to `TaggedIdent`, which
+/// includes their type (`Scalar` or `Point`)
+pub type TaggedVarDict = HashMap<String, TaggedIdent>;
+
+/// Convert a [`TaggedVarDict`] (a map from [`String`] to
+/// [`TaggedIdent`]) into the equivalent [`VarDict`] (a map from
+/// [`String`] to [`AExprType`])
+pub fn taggedvardict_to_vardict(vd: &TaggedVarDict) -> VarDict {
+    vd.iter()
+        .map(|(k, v)| (k.clone(), AExprType::from(v)))
+        .collect()
+}
 
 
 impl Parse for TaggedPoint {
 impl Parse for TaggedPoint {
     fn parse(input: ParseStream) -> Result<Self> {
     fn parse(input: ParseStream) -> Result<Self> {
@@ -129,7 +155,7 @@ impl Parse for TaggedPoint {
 pub struct SigmaCompSpec {
 pub struct SigmaCompSpec {
     pub proto_name: Ident,
     pub proto_name: Ident,
     pub group_name: Ident,
     pub group_name: Ident,
-    pub vars: VarDict,
+    pub vars: TaggedVarDict,
     pub statements: StatementTree,
     pub statements: StatementTree,
 }
 }
 
 
@@ -155,7 +181,7 @@ impl Parse for SigmaCompSpec {
         };
         };
         input.parse::<Token![,]>()?;
         input.parse::<Token![,]>()?;
 
 
-        let mut vars: VarDict = HashMap::new();
+        let mut vars: TaggedVarDict = HashMap::new();
 
 
         let scalars = paren_taggedidents::<TaggedScalar>(input)?;
         let scalars = paren_taggedidents::<TaggedScalar>(input)?;
         vars.extend(
         vars.extend(