Browse Source

StatementTree::for_each_disjunction_branch_leaf method

Ian Goldberg 4 months ago
parent
commit
081b1c8369
1 changed files with 170 additions and 0 deletions
  1. 170 0
      sigma_compiler_core/src/sigma/combiners.rs

+ 170 - 0
sigma_compiler_core/src/sigma/combiners.rs

@@ -394,6 +394,34 @@ impl StatementTree {
         Ok(())
     }
 
+    /// Call the supplied closure for each [`StatementTree::Leaf`] of
+    /// the given [disjunction branch].
+    ///
+    /// Abort and return `Err` if any call to the closure returns `Err`.
+    ///
+    /// [disjunction branch]: StatementTree::check_disjunction_invariant
+    pub fn for_each_disjunction_branch_leaf(
+        &mut self,
+        closure: &mut dyn FnMut(&mut StatementTree) -> Result<()>,
+    ) -> Result<()> {
+        match self {
+            StatementTree::Leaf(_) => {
+                (closure)(self)?;
+            }
+            StatementTree::And(stvec) => {
+                stvec
+                    .iter_mut()
+                    .try_for_each(|st| st.for_each_disjunction_branch_leaf(closure))?;
+            }
+            StatementTree::Or(_) | StatementTree::Thresh(_, _) => {
+                // Don't recurse into Or or Thresh nodes, since the
+                // children of those nodes are in different disjunction
+                // branches.
+            }
+        }
+        Ok(())
+    }
+
     #[cfg(not(doctest))]
     /// Flatten nested `And` nodes in a [`StatementTree`].
     ///
@@ -938,6 +966,148 @@ mod test {
         );
     }
 
+    fn disjunction_branch_leaf_tester(e: Expr, expected: Vec<Vec<Expr>>) {
+        let mut output: Vec<Vec<StatementTree>> = Vec::new();
+        let expected_st: Vec<Vec<StatementTree>> = expected
+            .iter()
+            .map(|vex| {
+                vex.iter()
+                    .map(|ex| StatementTree::parse(&ex).unwrap())
+                    .collect()
+            })
+            .collect();
+        let mut st = StatementTree::parse(&e).unwrap();
+        st.for_each_disjunction_branch(&mut |db| {
+            let mut dis_branch_output: Vec<StatementTree> = Vec::new();
+            db.for_each_disjunction_branch_leaf(&mut |leaf| {
+                dis_branch_output.push(leaf.clone());
+                Ok(())
+            })
+            .unwrap();
+            output.push(dis_branch_output);
+            Ok(())
+        })
+        .unwrap();
+        assert_eq!(output, expected_st);
+    }
+
+    fn disjunction_branch_leaf_abort_tester(e: Expr, expected: Vec<Vec<Expr>>) {
+        let mut output: Vec<Vec<StatementTree>> = Vec::new();
+        let expected_st: Vec<Vec<StatementTree>> = expected
+            .iter()
+            .map(|vex| {
+                vex.iter()
+                    .map(|ex| StatementTree::parse(&ex).unwrap())
+                    .collect()
+            })
+            .collect();
+        let mut st = StatementTree::parse(&e).unwrap();
+        st.for_each_disjunction_branch(&mut |db| {
+            let mut dis_branch_output: Vec<StatementTree> = Vec::new();
+            db.for_each_disjunction_branch_leaf(&mut |leaf| {
+                if leaf.is_leaf_true() {
+                    return Err(syn::Error::new(proc_macro2::Span::call_site(), "true leaf"));
+                }
+                dis_branch_output.push(leaf.clone());
+                Ok(())
+            })?;
+            output.push(dis_branch_output);
+            Ok(())
+        })
+        .unwrap_err();
+        assert_eq!(output, expected_st);
+    }
+
+    #[test]
+    fn disjunction_branch_leaf_test() {
+        disjunction_branch_leaf_tester(
+            parse_quote! {
+                C = c*B + r*A
+            },
+            vec![vec![parse_quote! { C = c*B + r*A }]],
+        );
+
+        disjunction_branch_leaf_tester(
+            parse_quote! {
+               AND (
+                   C = c*B + r*A,
+                   D = d*B + s*A,
+                   OR (
+                       c = d,
+                       c = d + 1,
+                   )
+               )
+            },
+            vec![
+                vec![
+                    parse_quote! { C = c*B + r*A },
+                    parse_quote! { D = d*B + s*A },
+                ],
+                vec![parse_quote! { c = d }],
+                vec![parse_quote! { c = d + 1 }],
+            ],
+        );
+
+        disjunction_branch_leaf_tester(
+            parse_quote! {
+                AND (
+                    C = c*B + r*A,
+                    D = d*B + s*A,
+                    OR (
+                        AND (
+                            c = d,
+                            D = a*B + b*A,
+                            OR (
+                                d = 5,
+                                d = 6,
+                            )
+                        ),
+                        c = d + 1,
+                    )
+                )
+            },
+            vec![
+                vec![
+                    parse_quote! { C = c*B + r*A },
+                    parse_quote! { D = d*B + s*A },
+                ],
+                vec![parse_quote! { c = d }, parse_quote! { D = a*B + b*A }],
+                vec![parse_quote! { d = 5 }],
+                vec![parse_quote! { d = 6 }],
+                vec![parse_quote! { c = d + 1 }],
+            ],
+        );
+
+        disjunction_branch_leaf_abort_tester(
+            parse_quote! {
+                AND (
+                    C = c*B + r*A,
+                    D = d*B + s*A,
+                    OR (
+                        AND (
+                            c = d,
+                            D = a*B + b*A,
+                            OR (
+                                d = 5,
+                                true,
+                                d = 6,
+                            )
+                        ),
+                        c = d + 1,
+                    )
+                )
+            },
+            vec![
+                vec![
+                    parse_quote! { C = c*B + r*A },
+                    parse_quote! { D = d*B + s*A },
+                ],
+                vec![parse_quote! { c = d }, parse_quote! { D = a*B + b*A }],
+                vec![parse_quote! { d = 5 }],
+            ],
+        );
+    }
+
     fn flatten_ands_tester(e: Expr, flattened_e: Expr) {
         let mut st = StatementTree::parse(&e).unwrap();
         st.flatten_ands();