Browse Source

Add a recognizer for LinScalar expressions

Ian Goldberg 4 months ago
parent
commit
676304ef47
1 changed files with 79 additions and 3 deletions
  1. 79 3
      sigma_compiler_core/src/pedersen.rs

+ 79 - 3
sigma_compiler_core/src/pedersen.rs

@@ -735,7 +735,7 @@ impl<'a> AExprFold<PedersenExpr> for RecognizeFold<'a> {
 
 /// Parse the right-hand side of the = in an [`Expr`] to see if we
 /// recognize it as a Pedersen commitment
-pub fn recognize(
+pub fn recognize_pedersen(
     vars: &TaggedVarDict,
     randoms: &HashSet<String>,
     vardict: &VarDict,
@@ -762,6 +762,26 @@ pub fn recognize(
     Some(pedersen)
 }
 
+/// Parse an [`Expr`] to see if we recognize it as a [`LinScalar`]
+pub fn recognize_linscalar(
+    vars: &TaggedVarDict,
+    vardict: &VarDict,
+    expr: &Expr,
+) -> Option<LinScalar> {
+    let mut fold = RecognizeFold {
+        vars,
+        randoms: &HashSet::new(),
+    };
+    let Ok((aetype, PedersenExpr::LinScalar(linscalar))) = fold.fold(vardict, expr) else {
+        return None;
+    };
+    // A 0 coefficient is not allowed
+    if linscalar.coeff == 0 {
+        return None;
+    }
+    Some(linscalar)
+}
+
 #[cfg(test)]
 mod test {
     use super::*;
@@ -1164,12 +1184,23 @@ mod test {
         for r in randoms {
             randoms_hash.insert(r.to_string());
         }
-        let output = recognize(&taggedvardict, &randoms_hash, &vardict, &e);
+        let output = recognize_pedersen(&taggedvardict, &randoms_hash, &vardict, &e);
+        assert_eq!(output, expected_out);
+    }
+
+    fn recognize_linscalar_tester(
+        vars: (&[&str], &[&str]),
+        e: Expr,
+        expected_out: Option<LinScalar>,
+    ) {
+        let taggedvardict = taggedvardict_from_strs(vars);
+        let vardict = taggedvardict_to_vardict(&taggedvardict);
+        let output = recognize_linscalar(&taggedvardict, &vardict, &e);
         assert_eq!(output, expected_out);
     }
 
     #[test]
-    fn recognize_test() {
+    fn recognize_pedersen_test() {
         let vars = (
             [
                 "x", "y", "z", "pub a", "pub b", "pub c", "rand r", "rand s", "rand t",
@@ -1263,4 +1294,49 @@ mod test {
             }),
         );
     }
+
+    #[test]
+    fn recognize_linscalar_test() {
+        let vars = (
+            [
+                "x", "y", "z", "pub a", "pub b", "pub c", "rand r", "rand s", "rand t",
+            ]
+            .as_slice(),
+            ["C", "cind A", "cind B"].as_slice(),
+        );
+
+        recognize_linscalar_tester(
+            vars,
+            parse_quote! {
+                3*(x + a + 1)
+            },
+            Some(LinScalar {
+                coeff: 3,
+                pub_scalar_expr: Some(parse_quote! { ( a + 1i128 ) * 3i128 }),
+                id: parse_quote! {x},
+                is_vec: false,
+            }),
+        );
+
+        recognize_linscalar_tester(
+            vars,
+            parse_quote! {
+                3*(x + a + 1) - x*4
+            },
+            Some(LinScalar {
+                coeff: -1,
+                pub_scalar_expr: Some(parse_quote! { ( a + 1i128 ) * 3i128 }),
+                id: parse_quote! {x},
+                is_vec: false,
+            }),
+        );
+
+        recognize_linscalar_tester(
+            vars,
+            parse_quote! {
+                3*(x + a + 1) - x*4 + x
+            },
+            None,
+        );
+    }
 }