Kaynağa Gözat

Flatten nested `And` nodes in a [`StatementTree`]

Ian Goldberg 5 ay önce
ebeveyn
işleme
1fc9daf16b
2 değiştirilmiş dosya ile 157 ekleme ve 0 silme
  1. 2 0
      Cargo.toml
  2. 155 0
      sigma_compiler_core/src/sigma/combiners.rs

+ 2 - 0
Cargo.toml

@@ -6,6 +6,8 @@ edition = "2021"
 [dependencies]
 sigma_compiler_derive = { path = "sigma_compiler_derive" }
 group = "0.13"
+sigma-rs = { path = "../sigma" }
 
 [dev-dependencies]
 curve25519-dalek = { version = "4", features = [ "group", "rand_core", "digest" ] }
+rand = "0.8.5"

+ 155 - 0
sigma_compiler_core/src/sigma/combiners.rs

@@ -320,6 +320,63 @@ impl StatementTree {
         }
         Ok(next_node)
     }
+
+    #[cfg(not(doctest))]
+    /// Flatten nested `And` nodes in a [`StatementTree`].
+    ///
+    /// The underlying `sigma-rs` crate can share `Scalars` across
+    /// statements that are direct children of the same `And` node, but
+    /// not in nested `And` nodes.
+    ///
+    /// So a [`StatementTree`] like this:
+    ///
+    /// ```
+    ///    AND (
+    ///        C = x*B + r*A,
+    ///        AND (
+    ///            D = x*B + s*A,
+    ///            E = x*B + t*A,
+    ///        ),
+    ///    )
+    /// ```
+    ///
+    /// Needs to be flattened to:
+    ///
+    /// ```
+    ///    AND (
+    ///        C = x*B + r*A,
+    ///        D = x*B + s*A,
+    ///        E = x*B + t*A,
+    ///    )
+    /// ```
+    pub fn flatten_ands(&mut self) {
+        match self {
+            StatementTree::Leaf(_) => {}
+            StatementTree::Or(svec) | StatementTree::Thresh(_, svec) => {
+                // Flatten each child
+                svec.iter_mut().for_each(|st| st.flatten_ands());
+            }
+            StatementTree::And(svec) => {
+                // Flatten each child, and if any of the children are
+                // `And`s, replace that child with the list of its
+                // children
+                let old_svec = std::mem::take(svec);
+                let mut new_svec: Vec<StatementTree> = Vec::new();
+                for mut st in old_svec {
+                    st.flatten_ands();
+                    match st {
+                        StatementTree::And(mut child_svec) => {
+                            new_svec.append(&mut child_svec);
+                        }
+                        _ => {
+                            new_svec.push(st);
+                        }
+                    }
+                }
+                *self = StatementTree::And(new_svec);
+            }
+        }
+    }
 }
 
 #[cfg(test)]
@@ -560,4 +617,102 @@ mod test {
         st_nok2.check_disjunction_invariant(&vars).unwrap_err();
         st_nok3.check_disjunction_invariant(&vars).unwrap_err();
     }
+
+    fn flatten_ands_tester(e: Expr, flattened_e: Expr) {
+        let mut st = StatementTree::parse(&e).unwrap();
+        st.flatten_ands();
+        assert_eq!(st, StatementTree::parse(&flattened_e).unwrap());
+    }
+
+    #[test]
+    // Test flatten_ands
+    fn flatten_ands_test() {
+        flatten_ands_tester(
+            parse_quote! {
+                C = x*B + r*A
+            },
+            parse_quote! {
+                C = x*B + r*A
+            },
+        );
+
+        flatten_ands_tester(
+            parse_quote! {
+                AND (
+                    C = x*B + r*A,
+                    AND (
+                        D = x*B + s*A,
+                        E = x*B + t*A,
+                    ),
+                )
+            },
+            parse_quote! {
+                AND (
+                    C = x*B + r*A,
+                    D = x*B + s*A,
+                    E = x*B + t*A,
+                )
+            },
+        );
+
+        flatten_ands_tester(
+            parse_quote! {
+                AND (
+                    AND (
+                        OR (
+                            D = B + s*A,
+                            D = s*A,
+                        ),
+                        D = x*B + t*A,
+                    ),
+                    C = x*B + r*A,
+                )
+            },
+            parse_quote! {
+                AND (
+                    OR (
+                        D = B + s*A,
+                        D = s*A,
+                    ),
+                    D = x*B + t*A,
+                    C = x*B + r*A,
+                )
+            },
+        );
+
+        flatten_ands_tester(
+            parse_quote! {
+                AND (
+                    AND (
+                        OR (
+                            D = B + s*A,
+                            AND (
+                                D = s*A,
+                                AND (
+                                    E = s*B,
+                                    F = s*C,
+                                ),
+                            ),
+                        ),
+                        D = x*B + t*A,
+                    ),
+                    C = x*B + r*A,
+                )
+            },
+            parse_quote! {
+                AND (
+                    OR (
+                        D = B + s*A,
+                        AND (
+                            D = s*A,
+                            E = s*B,
+                            F = s*C,
+                        )
+                    ),
+                    D = x*B + t*A,
+                    C = x*B + r*A,
+                )
+            },
+        );
+    }
 }