Browse Source

StatementTree::for_each_disjunction_branch method

Ian Goldberg 4 months ago
parent
commit
9985078308
1 changed files with 229 additions and 0 deletions
  1. 229 0
      sigma_compiler_core/src/sigma/combiners.rs

+ 229 - 0
sigma_compiler_core/src/sigma/combiners.rs

@@ -354,6 +354,46 @@ impl StatementTree {
         Ok(next_node)
     }
 
+    /// Call the supplied closure for each [disjunction branch] of the
+    /// given [`StatementTree`] (including the root).
+    ///
+    /// The calls are in preorder traversal (parents before children).
+    /// Abort and return `Err` if any call to the closure returns `Err`.
+    ///
+    /// [disjunction branch]: StatementTree::check_disjunction_invariant
+    pub fn for_each_disjunction_branch(
+        &mut self,
+        closure: &mut dyn FnMut(&mut StatementTree) -> Result<()>,
+    ) -> Result<()> {
+        self.for_each_disjunction_branch_rec(closure, true)
+    }
+
+    /// Internal recursive helper for
+    /// [`for_each_disjunction_branch`](StatementTree::for_each_disjunction_branch)
+    fn for_each_disjunction_branch_rec(
+        &mut self,
+        closure: &mut dyn FnMut(&mut StatementTree) -> Result<()>,
+        is_new_branch: bool,
+    ) -> Result<()> {
+        if is_new_branch {
+            (closure)(self)?;
+        }
+        match self {
+            StatementTree::Leaf(_) => {}
+            StatementTree::And(stvec) => {
+                stvec
+                    .iter_mut()
+                    .try_for_each(|st| st.for_each_disjunction_branch_rec(closure, false))?;
+            }
+            StatementTree::Or(stvec) | StatementTree::Thresh(_, stvec) => {
+                stvec
+                    .iter_mut()
+                    .try_for_each(|st| st.for_each_disjunction_branch_rec(closure, true))?;
+            }
+        }
+        Ok(())
+    }
+
     #[cfg(not(doctest))]
     /// Flatten nested `And` nodes in a [`StatementTree`].
     ///
@@ -709,6 +749,195 @@ mod test {
         st_nok3.check_disjunction_invariant(&vars).unwrap_err();
     }
 
+    fn disjunction_branch_tester(e: Expr, expected: Vec<Expr>) {
+        let mut output: Vec<StatementTree> = Vec::new();
+        let expected_st: Vec<StatementTree> = expected
+            .iter()
+            .map(|ex| StatementTree::parse(&ex).unwrap())
+            .collect();
+        let mut st = StatementTree::parse(&e).unwrap();
+        st.for_each_disjunction_branch(&mut |db| {
+            output.push(db.clone());
+            Ok(())
+        })
+        .unwrap();
+        assert_eq!(output, expected_st);
+    }
+
+    fn disjunction_branch_abort_tester(e: Expr, expected: Vec<Expr>) {
+        let mut output: Vec<StatementTree> = Vec::new();
+        let expected_st: Vec<StatementTree> = expected
+            .iter()
+            .map(|ex| StatementTree::parse(&ex).unwrap())
+            .collect();
+        let mut st = StatementTree::parse(&e).unwrap();
+        st.for_each_disjunction_branch(&mut |st| {
+            if st.is_leaf_true() {
+                return Err(syn::Error::new(proc_macro2::Span::call_site(), "true leaf"));
+            }
+            output.push(st.clone());
+            Ok(())
+        })
+        .unwrap_err();
+        assert_eq!(output, expected_st);
+    }
+
+    #[test]
+    fn disjunction_branch_test() {
+        disjunction_branch_tester(
+            parse_quote! {
+                C = c*B + r*A
+            },
+            vec![parse_quote! {
+                C = c*B + r*A
+            }],
+        );
+
+        disjunction_branch_tester(
+            parse_quote! {
+               AND (
+                   C = c*B + r*A,
+                   D = d*B + s*A,
+                   OR (
+                       c = d,
+                       c = d + 1,
+                   )
+               )
+            },
+            vec![
+                parse_quote! {
+                   AND (
+                       C = c*B + r*A,
+                       D = d*B + s*A,
+                       OR (
+                           c = d,
+                           c = d + 1,
+                       )
+                   )
+                },
+                parse_quote! {
+                    c = d
+                },
+                parse_quote! {
+                    c = d + 1
+                },
+            ],
+        );
+
+        disjunction_branch_tester(
+            parse_quote! {
+                AND (
+                    C = c*B + r*A,
+                    D = d*B + s*A,
+                    OR (
+                        AND (
+                            c = d,
+                            D = a*B + b*A,
+                            OR (
+                                d = 5,
+                                d = 6,
+                            )
+                        ),
+                        c = d + 1,
+                    )
+                )
+            },
+            vec![
+                parse_quote! {
+                    AND (
+                        C = c*B + r*A,
+                        D = d*B + s*A,
+                        OR (
+                            AND (
+                                c = d,
+                                D = a*B + b*A,
+                                OR (
+                                    d = 5,
+                                    d = 6,
+                                )
+                            ),
+                            c = d + 1,
+                        )
+                    )
+                },
+                parse_quote! {
+                    AND (
+                        c = d,
+                        D = a*B + b*A,
+                        OR (
+                            d = 5,
+                            d = 6,
+                        )
+                    )
+                },
+                parse_quote! {
+                    d = 5
+                },
+                parse_quote! {
+                    d = 6
+                },
+                parse_quote! {
+                    c = d + 1
+                },
+            ],
+        );
+
+        disjunction_branch_abort_tester(
+            parse_quote! {
+                AND (
+                    C = c*B + r*A,
+                    D = d*B + s*A,
+                    OR (
+                        AND (
+                            c = d,
+                            D = a*B + b*A,
+                            OR (
+                                d = 5,
+                                true,
+                                d = 6,
+                            )
+                        ),
+                        c = d + 1,
+                    )
+                )
+            },
+            vec![
+                parse_quote! {
+                    AND (
+                        C = c*B + r*A,
+                        D = d*B + s*A,
+                        OR (
+                            AND (
+                                c = d,
+                                D = a*B + b*A,
+                                OR (
+                                    d = 5,
+                                    true,
+                                    d = 6,
+                                )
+                            ),
+                            c = d + 1,
+                        )
+                    )
+                },
+                parse_quote! {
+                    AND (
+                        c = d,
+                        D = a*B + b*A,
+                        OR (
+                            d = 5,
+                            true,
+                            d = 6,
+                        )
+                    )
+                },
+                parse_quote! {
+                    d = 5
+                },
+            ],
+        );
+    }
+
     fn flatten_ands_tester(e: Expr, flattened_e: Expr) {
         let mut st = StatementTree::parse(&e).unwrap();
         st.flatten_ands();