Browse Source

Allow << on constant expressions in statements

It will be useful to be able to specify things like (1<<20) as the size
of a range.
Ian Goldberg 3 months ago
parent
commit
f501f2c7f3
2 changed files with 109 additions and 7 deletions
  1. 108 7
      sigma_compiler_core/src/sigma/types.rs
  2. 1 0
      sigma_compiler_core/src/syntax.rs

+ 108 - 7
sigma_compiler_core/src/sigma/types.rs

@@ -21,10 +21,20 @@ use syn::{Error, Expr};
 /// Note that while an individual variable cannot be a private `Point`,
 /// it is common to construct an arithmetic expression of that type, for
 /// example by multiplying a private `Scalar` by a public `Point`.
+/// In addition, an [`AExprType`] that represents a constant `Scalar`
+/// value (that fits in an [`i128`]) will have that constant value in
+/// `val`.
 #[derive(Copy, Clone, Debug, PartialEq, Eq)]
 pub enum AExprType {
-    Scalar { is_pub: bool, is_vec: bool },
-    Point { is_pub: bool, is_vec: bool },
+    Scalar {
+        is_pub: bool,
+        is_vec: bool,
+        val: Option<i128>,
+    },
+    Point {
+        is_pub: bool,
+        is_vec: bool,
+    },
 }
 
 impl From<&str> for AExprType {
@@ -43,18 +53,22 @@ impl From<&str> for AExprType {
             "Scalar" | "S" => Self::Scalar {
                 is_pub: false,
                 is_vec: false,
+                val: None,
             },
             "pub Scalar" | "pS" => Self::Scalar {
                 is_pub: true,
                 is_vec: false,
+                val: None,
             },
             "vec Scalar" | "vS" => Self::Scalar {
                 is_pub: false,
                 is_vec: true,
+                val: None,
             },
             "pub vec Scalar" | "pvS" => Self::Scalar {
                 is_pub: true,
                 is_vec: true,
+                val: None,
             },
             "Point" | "P" => Self::Point {
                 is_pub: false,
@@ -103,21 +117,35 @@ pub fn vardict_from_strs(strs: &[(&str, &str)]) -> VarDict {
 ///   - variables that are in the [`VarDict`]
 ///   - integer constants
 ///   - the operations `*`, `+`, `-` (binary or unary)
+///   - the operation `<<` where both operands are expressions with no
+///     variables
 ///   - parens
 pub fn expr_type(vars: &VarDict, expr: &Expr) -> Result<AExprType> {
     match expr {
         Expr::Lit(syn::ExprLit {
-            lit: syn::Lit::Int(_),
+            lit: syn::Lit::Int(litint),
             ..
         }) => Ok(AExprType::Scalar {
             is_pub: true,
             is_vec: false,
+            val: litint.base10_parse::<i128>().ok(),
         }),
         Expr::Unary(syn::ExprUnary {
             op: syn::UnOp::Neg(_),
             expr,
             ..
-        }) => expr_type(vars, expr.as_ref()),
+        }) => match expr_type(vars, expr.as_ref()) {
+            Ok(AExprType::Scalar {
+                is_pub: true,
+                is_vec: false,
+                val: Some(v),
+            }) => Ok(AExprType::Scalar {
+                is_pub: true,
+                is_vec: false,
+                val: v.checked_neg(),
+            }),
+            other => other,
+        },
         Expr::Paren(syn::ExprParen { expr, .. }) => expr_type(vars, expr.as_ref()),
         Expr::Path(syn::ExprPath { path, .. }) => {
             if let Some(id) = path.get_ident() {
@@ -144,15 +172,30 @@ pub fn expr_type(vars: &VarDict, expr: &Expr) -> Result<AExprType> {
                             AExprType::Scalar {
                                 is_pub: lpub,
                                 is_vec: lvec,
+                                val: lval,
                             },
                             AExprType::Scalar {
                                 is_pub: rpub,
                                 is_vec: rvec,
+                                val: rval,
                             },
                         ) => {
                             return Ok(AExprType::Scalar {
                                 is_pub: lpub && rpub,
                                 is_vec: lvec || rvec,
+                                val: if let (Some(lv), Some(rv)) = (lval, rval) {
+                                    match op {
+                                        syn::BinOp::Add(_) => lv.checked_add(rv),
+                                        syn::BinOp::Sub(_) => lv.checked_sub(rv),
+                                        // The default match can't
+                                        // happen because we're already
+                                        // inside a match on op, but the
+                                        // compiler requires it anyway
+                                        _ => None,
+                                    }
+                                } else {
+                                    None
+                                },
                             });
                         }
                         (
@@ -190,10 +233,12 @@ pub fn expr_type(vars: &VarDict, expr: &Expr) -> Result<AExprType> {
                             AExprType::Scalar {
                                 is_pub: lpub,
                                 is_vec: lvec,
+                                val: lval,
                             },
                             AExprType::Scalar {
                                 is_pub: rpub,
                                 is_vec: rvec,
+                                val: rval,
                             },
                         ) => {
                             if !lpub && !rpub {
@@ -205,12 +250,18 @@ pub fn expr_type(vars: &VarDict, expr: &Expr) -> Result<AExprType> {
                             return Ok(AExprType::Scalar {
                                 is_pub: lpub && rpub,
                                 is_vec: lvec || rvec,
+                                val: if let (Some(lv), Some(rv)) = (lval, rval) {
+                                    lv.checked_mul(rv)
+                                } else {
+                                    None
+                                },
                             });
                         }
                         (
                             AExprType::Scalar {
                                 is_pub: lpub,
                                 is_vec: lvec,
+                                ..
                             },
                             AExprType::Point {
                                 is_pub: rpub,
@@ -225,6 +276,7 @@ pub fn expr_type(vars: &VarDict, expr: &Expr) -> Result<AExprType> {
                             AExprType::Scalar {
                                 is_pub: rpub,
                                 is_vec: rvec,
+                                ..
                             },
                         ) => {
                             if !lpub && !rpub {
@@ -245,6 +297,39 @@ pub fn expr_type(vars: &VarDict, expr: &Expr) -> Result<AExprType> {
                         "cannot multiply a Point and a Point",
                     ));
                 }
+                syn::BinOp::Shl(_) => {
+                    let lt = expr_type(vars, left.as_ref())?;
+                    let rt = expr_type(vars, right.as_ref())?;
+                    // You can << only when both operands are constant
+                    // Scalar expressions
+                    if let (
+                        AExprType::Scalar {
+                            is_pub: true,
+                            is_vec: false,
+                            val: Some(lv),
+                        },
+                        AExprType::Scalar {
+                            is_pub: true,
+                            is_vec: false,
+                            val: Some(rv),
+                        },
+                    ) = (lt, rt)
+                    {
+                        let rvu32: Option<u32> = rv.try_into().ok();
+                        if let Some(shift_amt) = rvu32 {
+                            return Ok(AExprType::Scalar {
+                                is_pub: true,
+                                is_vec: false,
+                                val: lv.checked_shl(shift_amt),
+                            });
+                        }
+                    }
+                    return Err(Error::new(
+                        expr.span(),
+                        "can shift left only on constant
+                            expressions",
+                    ));
+                }
                 _ => {}
             }
             Err(Error::new(
@@ -265,6 +350,17 @@ mod tests {
         assert_eq!(expr_type(vars, &expr).unwrap(), AExprType::from(expect));
     }
 
+    fn check_const(vars: &VarDict, expr: Expr, expect: i128) {
+        assert_eq!(
+            expr_type(vars, &expr).unwrap(),
+            AExprType::Scalar {
+                is_pub: true,
+                is_vec: false,
+                val: Some(expect),
+            }
+        );
+    }
+
     fn check_fail(vars: &VarDict, expr: Expr) {
         expr_type(vars, &expr).unwrap_err();
     }
@@ -272,9 +368,11 @@ mod tests {
     #[test]
     fn expr_type_test() {
         let vars: VarDict = vardict_from_strs(&[("a", "S"), ("A", "pP"), ("v", "vS")]);
-        check(&vars, parse_quote! {2}, "pS");
-        check(&vars, parse_quote! {-4}, "pS");
-        check(&vars, parse_quote! {(2)}, "pS");
+        check_const(&vars, parse_quote! {2}, 2);
+        check_const(&vars, parse_quote! {-4}, -4);
+        check_const(&vars, parse_quote! {(2)}, 2);
+        check_const(&vars, parse_quote! {1<<20}, 1048576);
+        check_const(&vars, parse_quote! {(3-2)<<(4*5)}, 1048576);
         check(&vars, parse_quote! {A}, "pP");
         check(&vars, parse_quote! {a*A}, "P");
         check(&vars, parse_quote! {A*3}, "pP");
@@ -296,5 +394,8 @@ mod tests {
         // multiplying two private expressions (two ways)
         check_fail(&vars, parse_quote! {a*a});
         check_fail(&vars, parse_quote! {a*(a*A)});
+        // Shifting non-constant expressions
+        check_fail(&vars, parse_quote! {a<<2});
+        check_fail(&vars, parse_quote! {1<<a});
     }
 }

+ 1 - 0
sigma_compiler_core/src/syntax.rs

@@ -132,6 +132,7 @@ impl From<&TaggedIdent> for AExprType {
             TaggedIdent::Scalar(ts) => Self::Scalar {
                 is_pub: ts.is_pub,
                 is_vec: ts.is_vec,
+                val: None,
             },
             TaggedIdent::Point(tp) => Self::Point {
                 is_pub: true,