Browse Source

Implement and test LinScalar::to_expr()

Ian Goldberg 4 months ago
parent
commit
ed1a2ab347
1 changed files with 68 additions and 0 deletions
  1. 68 0
      sigma_compiler_core/src/pedersen.rs

+ 68 - 0
sigma_compiler_core/src/pedersen.rs

@@ -167,6 +167,25 @@ impl LinScalar {
             ..self
         })
     }
+
+    /// Output a [`LinScalar`] as an [`Expr`]
+    pub fn to_expr(&self) -> Expr {
+        let coeff = self.coeff;
+        let id = &self.id;
+        // If there's a non-1 coefficient, multiply it by the id
+        let coeff_var_term: Expr = if coeff == 1 {
+            parse_quote! { #id }
+        } else {
+            parse_quote! { #coeff * #id }
+        };
+        // If there's a pub_scalar_expr, add it to the result
+        if let Some(ref pse) = self.pub_scalar_expr {
+            let ppse = paren_if_needed(pse.clone());
+            parse_quote! { #coeff_var_term + #ppse }
+        } else {
+            coeff_var_term
+        }
+    }
 }
 
 /// A representation of `b*A` where `b` is a public `Scalar` [arithmetic
@@ -1354,11 +1373,15 @@ mod test {
         vars: (&[&str], &[&str]),
         e: Expr,
         expected_out: Option<LinScalar>,
+        expected_expr: Option<Expr>,
     ) {
         let taggedvardict = taggedvardict_from_strs(vars);
         let vardict = taggedvardict_to_vardict(&taggedvardict);
         let output = recognize_linscalar(&taggedvardict, &vardict, &e);
         assert_eq!(output, expected_out);
+        if output.is_some() {
+            assert_eq!(output.unwrap().to_expr(), expected_expr.unwrap());
+        }
     }
 
     #[test]
@@ -1371,6 +1394,48 @@ mod test {
             ["C", "cind A", "cind B"].as_slice(),
         );
 
+        recognize_linscalar_tester(
+            vars,
+            parse_quote! {
+                x
+            },
+            Some(LinScalar {
+                coeff: 1,
+                pub_scalar_expr: None,
+                id: parse_quote! {x},
+                is_vec: false,
+            }),
+            Some(parse_quote! { x }),
+        );
+
+        recognize_linscalar_tester(
+            vars,
+            parse_quote! {
+                x * 7 - x * 3
+            },
+            Some(LinScalar {
+                coeff: 4,
+                pub_scalar_expr: None,
+                id: parse_quote! {x},
+                is_vec: false,
+            }),
+            Some(parse_quote! { 4i128 * x }),
+        );
+
+        recognize_linscalar_tester(
+            vars,
+            parse_quote! {
+                x - (a + 12)
+            },
+            Some(LinScalar {
+                coeff: 1,
+                pub_scalar_expr: Some(parse_quote! {-(a + 12i128)}),
+                id: parse_quote! {x},
+                is_vec: false,
+            }),
+            Some(parse_quote! { x + (-(a + 12i128))}),
+        );
+
         recognize_linscalar_tester(
             vars,
             parse_quote! {
@@ -1382,6 +1447,7 @@ mod test {
                 id: parse_quote! {x},
                 is_vec: false,
             }),
+            Some(parse_quote! { 3i128 * x + ((a + 1i128) * 3i128) }),
         );
 
         recognize_linscalar_tester(
@@ -1395,6 +1461,7 @@ mod test {
                 id: parse_quote! {x},
                 is_vec: false,
             }),
+            Some(parse_quote! { -1i128 * x + ((a + 1i128) * 3i128) }),
         );
 
         recognize_linscalar_tester(
@@ -1403,6 +1470,7 @@ mod test {
                 3*(x + a + 1) - x*4 + x
             },
             None,
+            None,
         );
     }