Bladeren bron

Have for_each_disjunction_branch pass the path to each disjunction branch to the closure

The path is specified by a &[usize]
Ian Goldberg 4 maanden geleden
bovenliggende
commit
646f409856
1 gewijzigde bestanden met toevoegingen van 371 en 151 verwijderingen
  1. 371 151
      sigma_compiler_core/src/sigma/combiners.rs

+ 371 - 151
sigma_compiler_core/src/sigma/combiners.rs

@@ -206,7 +206,7 @@ impl StatementTree {
     /// [`Thresh`](StatementTree::Thresh) node in the [`StatementTree`].
     ///
     /// A _disjunction branch_ is a subtree rooted at a non-disjuction
-    /// node that is the child of a disjunction node, or at the root of
+    /// node that is the child of a disjunction node or at the root of
     /// the [`StatementTree`].
     ///
     /// The _disjunction invariant_ is that a private variable (which is
@@ -356,33 +356,55 @@ impl StatementTree {
     }
 
     /// Call the supplied closure for each [disjunction branch] of the
-    /// given [`StatementTree`] (including the root).
+    /// given [`StatementTree`] (including the root, if the root is a
+    /// non-disjunction node).
     ///
     /// The calls are in preorder traversal (parents before children).
+    /// The given `closure` will be called with the root of each
+    /// [disjunction branch] as well as a slice of [`usize`] indicating
+    /// the path through the [`StatementTree`] to that disjunction
+    /// branch.  The disjunction branch at the root has path `[]`.
+    /// The disjunction branch rooted at, say, the 2nd child of an `Or`
+    /// node in the root disjunction branch will have path `[2]`.  The
+    /// disjunction branch rooted at the 1st child of an `Or` node in
+    /// that disjunction branch will have path `[2,1]`, and so on.
+    ///
     /// Abort and return `Err` if any call to the closure returns `Err`.
     ///
     /// [disjunction branch]: StatementTree::check_disjunction_invariant
     pub fn for_each_disjunction_branch(
         &mut self,
-        closure: &mut dyn FnMut(&mut StatementTree) -> Result<()>,
+        closure: &mut dyn FnMut(&mut StatementTree, &[usize]) -> Result<()>,
     ) -> Result<()> {
-        self.for_each_disjunction_branch_rec(closure, true)
+        let mut path: Vec<usize> = Vec::new();
+        self.for_each_disjunction_branch_rec(closure, &mut path, 0, true)?;
+        Ok(())
     }
 
     /// Internal recursive helper for
-    /// [`for_each_disjunction_branch`](StatementTree::for_each_disjunction_branch)
+    /// [`for_each_disjunction_branch`](StatementTree::for_each_disjunction_branch).
+    ///
+    ///   - `path` is the path to this disjunction branch
+    ///   - `last_index` is the last index used for a child of this
+    ///     disjunction branch
+    ///   - `is_new_branch` is `true` if this node is the start of a new
+    ///     disjunction branch
+    ///
+    /// The return value (if `Ok`) is the updated value of `last_index`.
     fn for_each_disjunction_branch_rec(
         &mut self,
-        closure: &mut dyn FnMut(&mut StatementTree) -> Result<()>,
+        closure: &mut dyn FnMut(&mut StatementTree, &[usize]) -> Result<()>,
+        path: &mut Vec<usize>,
+        mut last_index: usize,
         is_new_branch: bool,
-    ) -> Result<()> {
+    ) -> Result<usize> {
         // We're starting a new branch (and should call the closure) if
         // and only if both is_new_branch is true, and also we're at a
         // non-disjuction node
         match self {
             StatementTree::Leaf(_) | StatementTree::And(_) => {
                 if is_new_branch {
-                    (closure)(self)?;
+                    (closure)(self, path)?;
                 }
             }
             _ => {}
@@ -390,17 +412,25 @@ impl StatementTree {
         match self {
             StatementTree::Leaf(_) => {}
             StatementTree::And(stvec) => {
-                stvec
-                    .iter_mut()
-                    .try_for_each(|st| st.for_each_disjunction_branch_rec(closure, false))?;
+                stvec.iter_mut().try_for_each(|st| -> Result<()> {
+                    last_index =
+                        st.for_each_disjunction_branch_rec(closure, path, last_index, false)?;
+                    Ok(())
+                })?;
             }
             StatementTree::Or(stvec) | StatementTree::Thresh(_, stvec) => {
-                stvec
-                    .iter_mut()
-                    .try_for_each(|st| st.for_each_disjunction_branch_rec(closure, true))?;
+                path.push(last_index);
+                let pathlen = path.len();
+                stvec.iter_mut().try_for_each(|st| -> Result<()> {
+                    last_index += 1;
+                    path[pathlen - 1] = last_index;
+                    st.for_each_disjunction_branch_rec(closure, path, 0, true)?;
+                    Ok(())
+                })?;
+                path.pop();
             }
         }
-        Ok(())
+        Ok(last_index)
     }
 
     /// Call the supplied closure for each [`StatementTree::Leaf`] of
@@ -786,33 +816,33 @@ mod test {
         st_nok3.check_disjunction_invariant(&vars).unwrap_err();
     }
 
-    fn disjunction_branch_tester(e: Expr, expected: Vec<Expr>) {
-        let mut output: Vec<StatementTree> = Vec::new();
-        let expected_st: Vec<StatementTree> = expected
+    fn disjunction_branch_tester(e: Expr, expected: Vec<(Vec<usize>, Expr)>) {
+        let mut output: Vec<(Vec<usize>, StatementTree)> = Vec::new();
+        let expected_st: Vec<(Vec<usize>, StatementTree)> = expected
             .iter()
-            .map(|ex| StatementTree::parse(&ex).unwrap())
+            .map(|(path, ex)| (path.clone(), StatementTree::parse(&ex).unwrap()))
             .collect();
         let mut st = StatementTree::parse(&e).unwrap();
-        st.for_each_disjunction_branch(&mut |db| {
-            output.push(db.clone());
+        st.for_each_disjunction_branch(&mut |db, path| {
+            output.push((path.to_vec(), db.clone()));
             Ok(())
         })
         .unwrap();
         assert_eq!(output, expected_st);
     }
 
-    fn disjunction_branch_abort_tester(e: Expr, expected: Vec<Expr>) {
-        let mut output: Vec<StatementTree> = Vec::new();
-        let expected_st: Vec<StatementTree> = expected
+    fn disjunction_branch_abort_tester(e: Expr, expected: Vec<(Vec<usize>, Expr)>) {
+        let mut output: Vec<(Vec<usize>, StatementTree)> = Vec::new();
+        let expected_st: Vec<(Vec<usize>, StatementTree)> = expected
             .iter()
-            .map(|ex| StatementTree::parse(&ex).unwrap())
+            .map(|(path, ex)| (path.clone(), StatementTree::parse(&ex).unwrap()))
             .collect();
         let mut st = StatementTree::parse(&e).unwrap();
-        st.for_each_disjunction_branch(&mut |st| {
+        st.for_each_disjunction_branch(&mut |st, path| {
             if st.is_leaf_true() {
                 return Err(syn::Error::new(proc_macro2::Span::call_site(), "true leaf"));
             }
-            output.push(st.clone());
+            output.push((path.to_vec(), st.clone()));
             Ok(())
         })
         .unwrap_err();
@@ -825,9 +855,12 @@ mod test {
             parse_quote! {
                 C = c*B + r*A
             },
-            vec![parse_quote! {
-                C = c*B + r*A
-            }],
+            vec![(
+                vec![],
+                parse_quote! {
+                    C = c*B + r*A
+                },
+            )],
         );
 
         disjunction_branch_tester(
@@ -842,22 +875,44 @@ mod test {
                )
             },
             vec![
-                parse_quote! {
-                   AND (
-                       C = c*B + r*A,
-                       D = d*B + s*A,
-                       OR (
-                           c = d,
-                           c = d + 1,
+                (
+                    vec![],
+                    parse_quote! {
+                       AND (
+                           C = c*B + r*A,
+                           D = d*B + s*A,
+                           OR (
+                               c = d,
+                               c = d + 1,
+                           )
                        )
-                   )
-                },
-                parse_quote! {
-                    c = d
-                },
-                parse_quote! {
-                    c = d + 1
-                },
+                    },
+                ),
+                (
+                    vec![1],
+                    parse_quote! {
+                        c = d
+                    },
+                ),
+                (
+                    vec![2],
+                    parse_quote! {
+                        c = d + 1
+                    },
+                ),
+            ],
+        );
+
+        disjunction_branch_tester(
+            parse_quote! {
+                OR (
+                    C = c*B + r*A,
+                    D = c*B + r*A,
+                )
+            },
+            vec![
+                (vec![1], parse_quote! { C = c*B + r*A }),
+                (vec![2], parse_quote! { D = c*B + r*A }),
             ],
         );
 
@@ -880,42 +935,166 @@ mod test {
                 )
             },
             vec![
-                parse_quote! {
+                (
+                    vec![],
+                    parse_quote! {
+                        AND (
+                            C = c*B + r*A,
+                            D = d*B + s*A,
+                            OR (
+                                AND (
+                                    c = d,
+                                    D = a*B + b*A,
+                                    OR (
+                                        d = 5,
+                                        d = 6,
+                                    )
+                                ),
+                                c = d + 1,
+                            )
+                        )
+                    },
+                ),
+                (
+                    vec![1],
+                    parse_quote! {
+                        AND (
+                            c = d,
+                            D = a*B + b*A,
+                            OR (
+                                d = 5,
+                                d = 6,
+                            )
+                        )
+                    },
+                ),
+                (
+                    vec![1, 1],
+                    parse_quote! {
+                        d = 5
+                    },
+                ),
+                (
+                    vec![1, 2],
+                    parse_quote! {
+                        d = 6
+                    },
+                ),
+                (
+                    vec![2],
+                    parse_quote! {
+                        c = d + 1
+                    },
+                ),
+            ],
+        );
+
+        disjunction_branch_tester(
+            parse_quote! {
+                AND (
+                    C = c*B + r*A,
+                    D = d*B + s*A,
                     AND (
-                        C = c*B + r*A,
-                        D = d*B + s*A,
-                        OR (
+                        c = d + 1,
+                        AND (
+                            s = r,
+                            OR (
+                                d = 1,
+                                AND (
+                                    d = 2,
+                                    s = 1,
+                                )
+                            )
+                        )
+                    ),
+                    OR (
+                        AND (
+                            c = d,
+                            D = a*B + b*A,
+                            OR (
+                                d = 5,
+                                d = 6,
+                            )
+                        ),
+                        c = d + 1,
+                    )
+                )
+            },
+            vec![
+                (
+                    vec![],
+                    parse_quote! {
+                        AND (
+                            C = c*B + r*A,
+                            D = d*B + s*A,
                             AND (
-                                c = d,
-                                D = a*B + b*A,
-                                OR (
-                                    d = 5,
-                                    d = 6,
+                                c = d + 1,
+                                AND (
+                                    s = r,
+                                    OR (
+                                        d = 1,
+                                        AND (
+                                            d = 2,
+                                            s = 1,
+                                        )
+                                    )
                                 )
                             ),
-                            c = d + 1,
+                            OR (
+                                AND (
+                                    c = d,
+                                    D = a*B + b*A,
+                                    OR (
+                                        d = 5,
+                                        d = 6,
+                                    )
+                                ),
+                                c = d + 1,
+                            )
                         )
-                    )
-                },
-                parse_quote! {
-                    AND (
-                        c = d,
-                        D = a*B + b*A,
-                        OR (
-                            d = 5,
-                            d = 6,
+                    },
+                ),
+                (vec![1], parse_quote! { d = 1 }),
+                (
+                    vec![2],
+                    parse_quote! {
+                        AND (
+                            d = 2,
+                            s = 1,
                         )
-                    )
-                },
-                parse_quote! {
-                    d = 5
-                },
-                parse_quote! {
-                    d = 6
-                },
-                parse_quote! {
-                    c = d + 1
-                },
+                    },
+                ),
+                (
+                    vec![3],
+                    parse_quote! {
+                        AND (
+                            c = d,
+                            D = a*B + b*A,
+                            OR (
+                                d = 5,
+                                d = 6,
+                            )
+                        )
+                    },
+                ),
+                (
+                    vec![3, 1],
+                    parse_quote! {
+                        d = 5
+                    },
+                ),
+                (
+                    vec![3, 2],
+                    parse_quote! {
+                        d = 6
+                    },
+                ),
+                (
+                    vec![4],
+                    parse_quote! {
+                        c = d + 1
+                    },
+                ),
             ],
         );
 
@@ -939,79 +1118,94 @@ mod test {
                 )
             },
             vec![
-                parse_quote! {
-                    AND (
-                        C = c*B + r*A,
-                        D = d*B + s*A,
-                        OR (
-                            AND (
-                                c = d,
-                                D = a*B + b*A,
-                                OR (
-                                    d = 5,
-                                    true,
-                                    d = 6,
-                                )
-                            ),
-                            c = d + 1,
+                (
+                    vec![],
+                    parse_quote! {
+                        AND (
+                            C = c*B + r*A,
+                            D = d*B + s*A,
+                            OR (
+                                AND (
+                                    c = d,
+                                    D = a*B + b*A,
+                                    OR (
+                                        d = 5,
+                                        true,
+                                        d = 6,
+                                    )
+                                ),
+                                c = d + 1,
+                            )
                         )
-                    )
-                },
-                parse_quote! {
-                    AND (
-                        c = d,
-                        D = a*B + b*A,
-                        OR (
-                            d = 5,
-                            true,
-                            d = 6,
+                    },
+                ),
+                (
+                    vec![1],
+                    parse_quote! {
+                        AND (
+                            c = d,
+                            D = a*B + b*A,
+                            OR (
+                                d = 5,
+                                true,
+                                d = 6,
+                            )
                         )
-                    )
-                },
-                parse_quote! {
-                    d = 5
-                },
+                    },
+                ),
+                (
+                    vec![1, 1],
+                    parse_quote! {
+                        d = 5
+                    },
+                ),
             ],
         );
     }
 
-    fn disjunction_branch_leaf_tester(e: Expr, expected: Vec<Vec<Expr>>) {
-        let mut output: Vec<Vec<StatementTree>> = Vec::new();
-        let expected_st: Vec<Vec<StatementTree>> = expected
+    fn disjunction_branch_leaf_tester(e: Expr, expected: Vec<(Vec<usize>, Vec<Expr>)>) {
+        let mut output: Vec<(Vec<usize>, Vec<StatementTree>)> = Vec::new();
+        let expected_st: Vec<(Vec<usize>, Vec<StatementTree>)> = expected
             .iter()
-            .map(|vex| {
-                vex.iter()
-                    .map(|ex| StatementTree::parse(&ex).unwrap())
-                    .collect()
+            .map(|(path, vex)| {
+                (
+                    path.clone(),
+                    vex.iter()
+                        .map(|ex| StatementTree::parse(&ex).unwrap())
+                        .collect(),
+                )
             })
             .collect();
         let mut st = StatementTree::parse(&e).unwrap();
-        st.for_each_disjunction_branch(&mut |db| {
+        st.for_each_disjunction_branch(&mut |db, path| {
             let mut dis_branch_output: Vec<StatementTree> = Vec::new();
             db.for_each_disjunction_branch_leaf(&mut |leaf| {
                 dis_branch_output.push(leaf.clone());
                 Ok(())
             })
             .unwrap();
-            output.push(dis_branch_output);
+            output.push((path.to_vec(), dis_branch_output));
             Ok(())
         })
         .unwrap();
         assert_eq!(output, expected_st);
     }
 
-    fn disjunction_branch_leaf_abort_tester(e: Expr, expected: Vec<Vec<Expr>>) {
-        let mut output: Vec<Vec<StatementTree>> = Vec::new();
-        let expected_st: Vec<Vec<StatementTree>> = expected
+    fn disjunction_branch_leaf_abort_tester(e: Expr, expected: Vec<(Vec<usize>, Vec<Expr>)>) {
+        let mut output: Vec<(Vec<usize>, Vec<StatementTree>)> = Vec::new();
+        let expected_st: Vec<(Vec<usize>, Vec<StatementTree>)> = expected
             .iter()
-            .map(|vex| {
-                vex.iter()
-                    .map(|ex| StatementTree::parse(&ex).unwrap())
-                    .collect()
+            .map(|(path, vex)| {
+                (
+                    path.clone(),
+                    vex.iter()
+                        .map(|ex| StatementTree::parse(&ex).unwrap())
+                        .collect(),
+                )
             })
             .collect();
         let mut st = StatementTree::parse(&e).unwrap();
-        st.for_each_disjunction_branch(&mut |db| {
+        st.for_each_disjunction_branch(&mut |db, path| {
             let mut dis_branch_output: Vec<StatementTree> = Vec::new();
             db.for_each_disjunction_branch_leaf(&mut |leaf| {
                 if leaf.is_leaf_true() {
@@ -1020,7 +1214,7 @@ mod test {
                 dis_branch_output.push(leaf.clone());
                 Ok(())
             })?;
-            output.push(dis_branch_output);
+            output.push((path.to_vec(), dis_branch_output));
             Ok(())
         })
         .unwrap_err();
@@ -1033,7 +1227,7 @@ mod test {
             parse_quote! {
                 C = c*B + r*A
             },
-            vec![vec![parse_quote! { C = c*B + r*A }]],
+            vec![(vec![], vec![parse_quote! { C = c*B + r*A }])],
         );
 
         disjunction_branch_leaf_tester(
@@ -1048,12 +1242,15 @@ mod test {
                )
             },
             vec![
-                vec![
-                    parse_quote! { C = c*B + r*A },
-                    parse_quote! { D = d*B + s*A },
-                ],
-                vec![parse_quote! { c = d }],
-                vec![parse_quote! { c = d + 1 }],
+                (
+                    vec![],
+                    vec![
+                        parse_quote! { C = c*B + r*A },
+                        parse_quote! { D = d*B + s*A },
+                    ],
+                ),
+                (vec![1], vec![parse_quote! { c = d }]),
+                (vec![2], vec![parse_quote! { c = d + 1 }]),
             ],
         );
 
@@ -1072,13 +1269,16 @@ mod test {
                )
             },
             vec![
-                vec![
-                    parse_quote! { C = c*B + r*A },
-                    parse_quote! { D = d*B + s*A },
-                ],
-                vec![parse_quote! { c = d }],
-                vec![parse_quote! { c = d + 1 }],
-                vec![parse_quote! { c = d + 2 }],
+                (
+                    vec![],
+                    vec![
+                        parse_quote! { C = c*B + r*A },
+                        parse_quote! { D = d*B + s*A },
+                    ],
+                ),
+                (vec![1], vec![parse_quote! { c = d }]),
+                (vec![2, 1], vec![parse_quote! { c = d + 1 }]),
+                (vec![2, 2], vec![parse_quote! { c = d + 2 }]),
             ],
         );
 
@@ -1101,14 +1301,24 @@ mod test {
                 )
             },
             vec![
-                vec![
-                    parse_quote! { C = c*B + r*A },
-                    parse_quote! { D = d*B + s*A },
-                ],
-                vec![parse_quote! { c = d }, parse_quote! { D = a*B + b*A }],
-                vec![parse_quote! { d = 5 }],
-                vec![parse_quote! { d = 6 }],
-                vec![parse_quote! { c = d + 1 }],
+                (
+                    vec![],
+                    vec![
+                        parse_quote! { C = c*B + r*A },
+                        parse_quote! { D = d*B + s*A },
+                    ],
+                ),
+                (
+                    vec![1],
+                    vec![
+                        parse_quote! { c = d },
+                        parse_quote! { D
+                        = a*B + b*A },
+                    ],
+                ),
+                (vec![1, 1], vec![parse_quote! { d = 5 }]),
+                (vec![1, 2], vec![parse_quote! { d = 6 }]),
+                (vec![2], vec![parse_quote! { c = d + 1 }]),
             ],
         );
 
@@ -1132,12 +1342,22 @@ mod test {
                 )
             },
             vec![
-                vec![
-                    parse_quote! { C = c*B + r*A },
-                    parse_quote! { D = d*B + s*A },
-                ],
-                vec![parse_quote! { c = d }, parse_quote! { D = a*B + b*A }],
-                vec![parse_quote! { d = 5 }],
+                (
+                    vec![],
+                    vec![
+                        parse_quote! { C = c*B + r*A },
+                        parse_quote! { D = d*B + s*A },
+                    ],
+                ),
+                (
+                    vec![1],
+                    vec![
+                        parse_quote! { c = d },
+                        parse_quote! { D
+                        = a*B + b*A },
+                    ],
+                ),
+                (vec![1, 1], vec![parse_quote! { d = 5 }]),
             ],
         );
     }