Parcourir la source

Add a recognizer for expressions that evaluate to public Scalars

Ian Goldberg il y a 6 mois
Parent
commit
22dd8f33fd
1 fichiers modifiés avec 95 ajouts et 12 suppressions
  1. 95 12
      sigma_compiler_core/src/pedersen.rs

+ 95 - 12
sigma_compiler_core/src/pedersen.rs

@@ -772,7 +772,7 @@ pub fn recognize_linscalar(
         vars,
         randoms: &HashSet::new(),
     };
-    let Ok((aetype, PedersenExpr::LinScalar(linscalar))) = fold.fold(vardict, expr) else {
+    let Ok((_, PedersenExpr::LinScalar(linscalar))) = fold.fold(vardict, expr) else {
         return None;
     };
     // A 0 coefficient is not allowed
@@ -782,6 +782,27 @@ pub fn recognize_linscalar(
     Some(linscalar)
 }
 
+/// Parse an [`Expr`] to see if we recognize it as an expression that
+/// evaluates to a public `Scalar`.
+///
+/// The returned [`bool`] is true if the expression evaluates to a
+/// vector
+pub fn recognize_pubscalar(
+    vars: &TaggedVarDict,
+    vardict: &VarDict,
+    expr: &Expr,
+) -> Option<bool> {
+    let mut fold = RecognizeFold {
+        vars,
+        randoms: &HashSet::new(),
+    };
+    let Ok((AExprType::Scalar{is_vec, ..}, PedersenExpr::PubScalarExpr(_)))
+        = fold.fold(vardict, expr) else {
+        return None;
+    };
+    Some(is_vec)
+}
+
 #[cfg(test)]
 mod test {
     use super::*;
@@ -1188,17 +1209,6 @@ mod test {
         assert_eq!(output, expected_out);
     }
 
-    fn recognize_linscalar_tester(
-        vars: (&[&str], &[&str]),
-        e: Expr,
-        expected_out: Option<LinScalar>,
-    ) {
-        let taggedvardict = taggedvardict_from_strs(vars);
-        let vardict = taggedvardict_to_vardict(&taggedvardict);
-        let output = recognize_linscalar(&taggedvardict, &vardict, &e);
-        assert_eq!(output, expected_out);
-    }
-
     #[test]
     fn recognize_pedersen_test() {
         let vars = (
@@ -1295,6 +1305,17 @@ mod test {
         );
     }
 
+    fn recognize_linscalar_tester(
+        vars: (&[&str], &[&str]),
+        e: Expr,
+        expected_out: Option<LinScalar>,
+    ) {
+        let taggedvardict = taggedvardict_from_strs(vars);
+        let vardict = taggedvardict_to_vardict(&taggedvardict);
+        let output = recognize_linscalar(&taggedvardict, &vardict, &e);
+        assert_eq!(output, expected_out);
+    }
+
     #[test]
     fn recognize_linscalar_test() {
         let vars = (
@@ -1339,4 +1360,66 @@ mod test {
             None,
         );
     }
+
+    fn recognize_pubscalar_tester(
+        vars: (&[&str], &[&str]),
+        e: Expr,
+        expected_out: Option<bool>,
+    ) {
+        let taggedvardict = taggedvardict_from_strs(vars);
+        let vardict = taggedvardict_to_vardict(&taggedvardict);
+        let output = recognize_pubscalar(&taggedvardict, &vardict, &e);
+        assert_eq!(output, expected_out);
+    }
+
+    #[test]
+    fn recognize_pubscalar_test() {
+        let vars = (
+            [
+                "x", "y", "z", "pub a", "pub vec b", "pub c", "rand r", "rand s", "rand t",
+            ]
+            .as_slice(),
+            ["C", "cind A", "cind B"].as_slice(),
+        );
+
+        recognize_pubscalar_tester(
+            vars,
+            parse_quote! {
+                3*(x + a + 1)
+            },
+            None,
+        );
+
+        recognize_pubscalar_tester(
+            vars,
+            parse_quote! {
+                3
+            },
+            Some(false),
+        );
+
+        recognize_pubscalar_tester(
+            vars,
+            parse_quote! {
+                a
+            },
+            Some(false),
+        );
+
+        recognize_pubscalar_tester(
+            vars,
+            parse_quote! {
+                3*(a + 1)
+            },
+            Some(false),
+        );
+
+        recognize_pubscalar_tester(
+            vars,
+            parse_quote! {
+                3*(a + b)
+            },
+            Some(true),
+        );
+    }
 }