combiners.rs 27 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809
  1. //! This module creates and manipulates trees of basic statements
  2. //! combined with `AND`, `OR`, and `THRESH`.
  3. use super::types::*;
  4. use quote::quote;
  5. use std::collections::HashMap;
  6. use syn::parse::Result;
  7. use syn::visit::Visit;
  8. use syn::{parse_quote, Expr};
  9. /// For each [`Ident`](struct@syn::Ident) representing a private
  10. /// `Scalar` (as listed in a [`VarDict`]) that appears in an [`Expr`],
  11. /// call a given closure.
  12. pub struct PrivScalarMap<'a> {
  13. /// The [`VarDict`] that maps variable names to their types
  14. pub vars: &'a VarDict,
  15. /// The closure that is called for each [`Ident`](struct@syn::Ident)
  16. /// found in the [`Expr`] (provided in the call to
  17. /// [`visit_expr`](PrivScalarMap::visit_expr)) that represents a
  18. /// private `Scalar`
  19. pub closure: &'a mut dyn FnMut(&syn::Ident) -> Result<()>,
  20. /// The accumulated result. This will be the first
  21. /// [`Err`](Result::Err) returned from the closure, or
  22. /// [`Ok(())`](Result::Ok) if all calls to the closure succeeded.
  23. pub result: Result<()>,
  24. }
  25. impl<'a> Visit<'a> for PrivScalarMap<'a> {
  26. fn visit_path(&mut self, path: &'a syn::Path) {
  27. // Whenever we see a `Path`, check first if it's just a bare
  28. // `Ident`
  29. let Some(id) = path.get_ident() else {
  30. return;
  31. };
  32. // Then check if that `Ident` appears in the `VarDict`
  33. let Some(vartype) = self.vars.get(&id.to_string()) else {
  34. return;
  35. };
  36. // If so, and the `Ident` represents a private Scalar,
  37. // call the closure if we haven't seen an `Err` returned from
  38. // the closure yet.
  39. if let AExprType::Scalar { is_pub: false, .. } = vartype {
  40. if self.result.is_ok() {
  41. self.result = (self.closure)(id);
  42. }
  43. }
  44. }
  45. }
  46. /// The statements in the ZKP form a tree. The leaves are basic
  47. /// statements of various kinds; for example, equations or inequalities
  48. /// about Scalars and Points. The interior nodes are combiners: `And`,
  49. /// `Or`, or `Thresh` (with a given constant threshold). A leaf is true
  50. /// if the basic statement it contains is true. An `And` node is true
  51. /// if all of its children are true. An `Or` node is true if at least
  52. /// one of its children is true. A `Thresh` node (with threshold `k`) is
  53. /// true if at least `k` of its children are true.
  54. #[derive(Clone, Debug, Eq, PartialEq)]
  55. pub enum StatementTree {
  56. Leaf(Expr),
  57. And(Vec<StatementTree>),
  58. Or(Vec<StatementTree>),
  59. Thresh(usize, Vec<StatementTree>),
  60. }
  61. impl StatementTree {
  62. #[cfg(not(doctest))]
  63. /// Parse an [`Expr`] (which may contain nested `AND`, `OR`, or
  64. /// `THRESH`) into a [`StatementTree`]. For example, the
  65. /// [`Expr`] obtained from:
  66. /// ```
  67. /// parse_quote! {
  68. /// AND (
  69. /// C = c*B + r*A,
  70. /// D = d*B + s*A,
  71. /// OR (
  72. /// AND (
  73. /// C = c0*B + r0*A,
  74. /// D = d0*B + s0*A,
  75. /// c0 = d0,
  76. /// ),
  77. /// AND (
  78. /// C = c1*B + r1*A,
  79. /// D = d1*B + s1*A,
  80. /// c1 = d1 + 1,
  81. /// ),
  82. /// )
  83. /// )
  84. /// }
  85. /// ```
  86. ///
  87. /// would yield a [`StatementTree::And`] containing a 3-element
  88. /// vector. The first two elements are [`StatementTree::Leaf`], and
  89. /// the third is [`StatementTree::Or`] containing a 2-element
  90. /// vector. Each element is an [`StatementTree::And`] with a vector
  91. /// containing 3 [`StatementTree::Leaf`]s.
  92. ///
  93. /// Note that `AND`, `OR`, and `THRESH` in the expression are
  94. /// case-insensitive.
  95. pub fn parse(expr: &Expr) -> Result<Self> {
  96. // See if the expression describes a combiner
  97. if let Expr::Call(syn::ExprCall { func, args, .. }) = expr {
  98. if let Expr::Path(syn::ExprPath { path, .. }) = func.as_ref() {
  99. if let Some(funcname) = path.get_ident() {
  100. match funcname.to_string().to_lowercase().as_str() {
  101. "and" => {
  102. let children: Result<Vec<StatementTree>> =
  103. args.iter().map(Self::parse).collect();
  104. return Ok(Self::And(children?));
  105. }
  106. "or" => {
  107. let children: Result<Vec<StatementTree>> =
  108. args.iter().map(Self::parse).collect();
  109. return Ok(Self::Or(children?));
  110. }
  111. "thresh" => {
  112. if let Some(Expr::Lit(syn::ExprLit {
  113. lit: syn::Lit::Int(litint),
  114. ..
  115. })) = args.first()
  116. {
  117. let thresh = litint.base10_parse::<usize>()?;
  118. // Remember that args.len() is one more
  119. // than the number of expressions,
  120. // because the first arg is the
  121. // threshold
  122. if thresh < 1 || thresh >= args.len() {
  123. return Err(syn::Error::new(
  124. litint.span(),
  125. "threshold out of range",
  126. ));
  127. }
  128. let children: Result<Vec<StatementTree>> =
  129. args.iter().skip(1).map(Self::parse).collect();
  130. return Ok(Self::Thresh(thresh, children?));
  131. }
  132. }
  133. _ => {}
  134. }
  135. }
  136. }
  137. }
  138. Ok(StatementTree::Leaf(expr.clone()))
  139. }
  140. /// A convenience function that takes a list of [`Expr`]s, and
  141. /// returns the [`StatementTree`] that implicitly puts `AND` around
  142. /// the [`Expr`]s. This is useful because a common thing to do is
  143. /// to just write a list of [`Expr`]s in the top-level macro
  144. /// invocation, having the semantics of "all of these must be true".
  145. pub fn parse_andlist(exprlist: &[Expr]) -> Result<Self> {
  146. let children: Result<Vec<StatementTree>> = exprlist.iter().map(Self::parse).collect();
  147. Ok(StatementTree::And(children?))
  148. }
  149. /// Return a vector of references to all of the leaf expressions in
  150. /// the [`StatementTree`]
  151. pub fn leaves(&self) -> Vec<&Expr> {
  152. match self {
  153. StatementTree::Leaf(ref e) => vec![e],
  154. StatementTree::And(v) | StatementTree::Or(v) | StatementTree::Thresh(_, v) => {
  155. v.iter().fold(Vec::<&Expr>::new(), |mut b, st| {
  156. b.extend(st.leaves());
  157. b
  158. })
  159. }
  160. }
  161. }
  162. /// Return a vector of mutable references to all of the leaf
  163. /// expressions in the [`StatementTree`]
  164. pub fn leaves_mut(&mut self) -> Vec<&mut Expr> {
  165. match self {
  166. StatementTree::Leaf(ref mut e) => vec![e],
  167. StatementTree::And(v) | StatementTree::Or(v) | StatementTree::Thresh(_, v) => {
  168. v.iter_mut().fold(Vec::<&mut Expr>::new(), |mut b, st| {
  169. b.extend(st.leaves_mut());
  170. b
  171. })
  172. }
  173. }
  174. }
  175. /// Return a vector of mutable references to all of the leaves in
  176. /// the [`StatementTree`]
  177. pub fn leaves_st_mut(&mut self) -> Vec<&mut StatementTree> {
  178. match self {
  179. StatementTree::Leaf(_) => vec![self],
  180. StatementTree::And(v) | StatementTree::Or(v) | StatementTree::Thresh(_, v) => v
  181. .iter_mut()
  182. .fold(Vec::<&mut StatementTree>::new(), |mut b, st| {
  183. b.extend(st.leaves_st_mut());
  184. b
  185. }),
  186. }
  187. }
  188. #[cfg(not(doctest))]
  189. /// Verify whether the [`StatementTree`] satisfies the disjunction
  190. /// invariant.
  191. ///
  192. /// A _disjunction node_ is an [`Or`](StatementTree::Or) or
  193. /// [`Thresh`](StatementTree::Thresh) node in the [`StatementTree`].
  194. ///
  195. /// A _disjunction branch_ is a subtree rooted at the child of a
  196. /// disjunction node, or at the root of the [`StatementTree`].
  197. ///
  198. /// The _disjunction invariant_ is that a private variable (which is
  199. /// necessarily a `Scalar` since there are no private `Point`
  200. /// variables) that appears in a disjunction branch cannot also
  201. /// appear outside of that disjunction branch.
  202. ///
  203. /// For example, if all of the lowercase variables are private
  204. /// `Scalar`s, the [`StatementTree`] created from:
  205. ///
  206. /// ```
  207. /// AND (
  208. /// C = c*B + r*A,
  209. /// D = d*B + s*A,
  210. /// OR (
  211. /// AND (
  212. /// C = c0*B + r0*A,
  213. /// D = d0*B + s0*A,
  214. /// c0 = d0,
  215. /// ),
  216. /// AND (
  217. /// C = c1*B + r1*A,
  218. /// D = d1*B + s1*A,
  219. /// c1 = d1 + 1,
  220. /// ),
  221. /// )
  222. /// )
  223. /// ```
  224. ///
  225. /// satisfies the disjunction invariant, but
  226. ///
  227. /// ```
  228. /// AND (
  229. /// C = c*B + r*A,
  230. /// D = d*B + s*A,
  231. /// OR (
  232. /// AND (
  233. /// D = d0*B + s0*A,
  234. /// c = d0,
  235. /// ),
  236. /// AND (
  237. /// C = c1*B + r1*A,
  238. /// D = d1*B + s1*A,
  239. /// c1 = d1 + 1,
  240. /// ),
  241. /// )
  242. /// )
  243. /// ```
  244. ///
  245. /// does not, because `c` appears in the first child of the `OR` and
  246. /// also outside of the `OR` entirely. Indeed, the reason to write
  247. /// the first expression above rather than the more natural
  248. ///
  249. /// ```
  250. /// AND (
  251. /// C = c*B + r*A,
  252. /// D = d*B + s*A,
  253. /// OR (
  254. /// c = d,
  255. /// c = d + 1,
  256. /// )
  257. /// )
  258. /// ```
  259. ///
  260. /// is exactly that the invariant must be satisfied.
  261. ///
  262. /// (In the future, it is possible we may provide a transformer that
  263. /// will automatically convert [`StatementTree`]s to ones that
  264. /// satisfy the invariant, but for now, the user of the macro must
  265. /// manually write the statements in a form that satisfies the
  266. /// disjunction invariant.
  267. pub fn check_disjunction_invariant(&self, vars: &VarDict) -> Result<()> {
  268. let mut disjunct_map: HashMap<String, usize> = HashMap::new();
  269. // If the recursive call returns Err, return that Err.
  270. // Otherwise, we don't care about the Ok(usize) returned, so
  271. // just return Ok(())
  272. self.check_disjunction_invariant_rec(vars, &mut disjunct_map, 0, 0)?;
  273. Ok(())
  274. }
  275. /// Internal recursive helper for
  276. /// [`check_disjunction_invariant`](StatementTree::check_disjunction_invariant).
  277. ///
  278. /// The `disjunct_map` is a [`HashMap`] that maps the names of
  279. /// variables to an identifier of which child of a disjunction node
  280. /// the variable appears in (or the root if none). In the case of
  281. /// nested disjunction node, the closest one to the leaf is what
  282. /// matters. Nodes are numbered in pre-order fashion, starting at 0
  283. /// for the root, 1 for the first child of the root, 2 for the first
  284. /// child of node 1, etc. `cur_node` is the node id of `self`, and
  285. /// `cur_disjunct_child` is the node id of the closest child of a
  286. /// disjunction node (or 0 for the root if none). Returns the next
  287. /// node id to use in the preorder traversal.
  288. fn check_disjunction_invariant_rec(
  289. &self,
  290. vars: &VarDict,
  291. disjunct_map: &mut HashMap<String, usize>,
  292. cur_node: usize,
  293. cur_disjunct_child: usize,
  294. ) -> Result<usize> {
  295. let mut next_node = cur_node;
  296. match self {
  297. Self::And(v) => {
  298. for st in v {
  299. next_node = st.check_disjunction_invariant_rec(
  300. vars,
  301. disjunct_map,
  302. next_node + 1,
  303. cur_disjunct_child,
  304. )?;
  305. }
  306. }
  307. Self::Or(v) | Self::Thresh(_, v) => {
  308. for st in v {
  309. next_node = st.check_disjunction_invariant_rec(
  310. vars,
  311. disjunct_map,
  312. next_node + 1,
  313. next_node + 1,
  314. )?;
  315. }
  316. }
  317. Self::Leaf(e) => {
  318. let mut psmap = PrivScalarMap {
  319. vars,
  320. closure: &mut |ident| {
  321. let varname = ident.to_string();
  322. if let Some(dis_id) = disjunct_map.get(&varname) {
  323. if *dis_id != cur_disjunct_child {
  324. return Err(syn::Error::new(
  325. ident.span(),
  326. "Disjunction invariant violation: a private variable cannot appear both inside and outside a single term of an OR or THRESH"));
  327. }
  328. } else {
  329. disjunct_map.insert(varname, cur_disjunct_child);
  330. }
  331. Ok(())
  332. },
  333. result: Ok(()),
  334. };
  335. psmap.visit_expr(e);
  336. psmap.result?;
  337. }
  338. }
  339. Ok(next_node)
  340. }
  341. #[cfg(not(doctest))]
  342. /// Flatten nested `And` nodes in a [`StatementTree`].
  343. ///
  344. /// The underlying `sigma-rs` crate can share `Scalars` across
  345. /// statements that are direct children of the same `And` node, but
  346. /// not in nested `And` nodes.
  347. ///
  348. /// So a [`StatementTree`] like this:
  349. ///
  350. /// ```
  351. /// AND (
  352. /// C = x*B + r*A,
  353. /// AND (
  354. /// D = x*B + s*A,
  355. /// E = x*B + t*A,
  356. /// ),
  357. /// )
  358. /// ```
  359. ///
  360. /// Needs to be flattened to:
  361. ///
  362. /// ```
  363. /// AND (
  364. /// C = x*B + r*A,
  365. /// D = x*B + s*A,
  366. /// E = x*B + t*A,
  367. /// )
  368. /// ```
  369. pub fn flatten_ands(&mut self) {
  370. match self {
  371. StatementTree::Leaf(_) => {}
  372. StatementTree::Or(svec) | StatementTree::Thresh(_, svec) => {
  373. // Flatten each child
  374. svec.iter_mut().for_each(|st| st.flatten_ands());
  375. }
  376. StatementTree::And(svec) => {
  377. // Flatten each child, and if any of the children are
  378. // `And`s, replace that child with the list of its
  379. // children
  380. let old_svec = std::mem::take(svec);
  381. let mut new_svec: Vec<StatementTree> = Vec::new();
  382. for mut st in old_svec {
  383. st.flatten_ands();
  384. match st {
  385. StatementTree::And(mut child_svec) => {
  386. new_svec.append(&mut child_svec);
  387. }
  388. _ => {
  389. new_svec.push(st);
  390. }
  391. }
  392. }
  393. *self = StatementTree::And(new_svec);
  394. }
  395. }
  396. }
  397. /// Produce a [`StatementTree`] that represents the constant `true`
  398. pub fn leaf_true() -> StatementTree {
  399. StatementTree::Leaf(parse_quote! { true })
  400. }
  401. /// Test if the given [`StatementTree`] represents the constant `true`
  402. pub fn is_leaf_true(&self) -> bool {
  403. if let StatementTree::Leaf(Expr::Lit(exprlit)) = self {
  404. if let syn::Lit::Bool(syn::LitBool { value: true, .. }) = exprlit.lit {
  405. return true;
  406. }
  407. }
  408. false
  409. }
  410. fn dump_int(&self, depth: usize) {
  411. match self {
  412. StatementTree::Leaf(e) => {
  413. println!(
  414. "{:1$}{2},",
  415. "",
  416. depth * 2,
  417. quote! { #e }.to_string().replace('\n', " ")
  418. )
  419. }
  420. StatementTree::And(v) => {
  421. println!("{:1$}And (", "", depth * 2);
  422. v.iter().for_each(|n| n.dump_int(depth + 1));
  423. println!("{:1$})", "", depth * 2);
  424. }
  425. StatementTree::Or(v) => {
  426. println!("{:1$}Or (", "", depth * 2);
  427. v.iter().for_each(|n| n.dump_int(depth + 1));
  428. println!("{:1$})", "", depth * 2);
  429. }
  430. StatementTree::Thresh(thresh, v) => {
  431. println!("{:1$}Thresh ({2}", "", depth * 2, thresh);
  432. v.iter().for_each(|n| n.dump_int(depth + 1));
  433. println!("{:1$})", "", depth * 2);
  434. }
  435. }
  436. }
  437. pub fn dump(&self) {
  438. self.dump_int(0);
  439. }
  440. }
  441. #[cfg(test)]
  442. mod test {
  443. use super::StatementTree::*;
  444. use super::*;
  445. use quote::quote;
  446. #[test]
  447. fn leaf_true_test() {
  448. assert!(StatementTree::leaf_true().is_leaf_true());
  449. assert!(!StatementTree::Leaf(parse_quote! { false }).is_leaf_true());
  450. assert!(!StatementTree::Leaf(parse_quote! { 1 }).is_leaf_true());
  451. assert!(!StatementTree::parse(&parse_quote! {
  452. OR(1=1, a=b)
  453. })
  454. .unwrap()
  455. .is_leaf_true());
  456. }
  457. #[test]
  458. fn combiners_simple_test() {
  459. let exprlist: Vec<Expr> = vec![
  460. parse_quote! { C = c*B + r*A },
  461. parse_quote! { D = d*B + s*A },
  462. parse_quote! { c = d },
  463. ];
  464. let statementtree = StatementTree::parse_andlist(&exprlist).unwrap();
  465. let And(v) = statementtree else {
  466. panic!("Incorrect result");
  467. };
  468. let [Leaf(l0), Leaf(l1), Leaf(l2)] = v.as_slice() else {
  469. panic!("Incorrect result");
  470. };
  471. assert_eq!(quote! {#l0}.to_string(), "C = c * B + r * A");
  472. assert_eq!(quote! {#l1}.to_string(), "D = d * B + s * A");
  473. assert_eq!(quote! {#l2}.to_string(), "c = d");
  474. }
  475. #[test]
  476. fn combiners_nested_test() {
  477. let exprlist: Vec<Expr> = vec![
  478. parse_quote! { C = c*B + r*A },
  479. parse_quote! { D = d*B + s*A },
  480. parse_quote! {
  481. OR (
  482. AND (
  483. C = c0*B + r0*A,
  484. D = d0*B + s0*A,
  485. c0 = d0,
  486. ),
  487. AND (
  488. C = c1*B + r1*A,
  489. D = d1*B + s1*A,
  490. c1 = d1 + 1,
  491. ),
  492. ) },
  493. ];
  494. let statementtree = StatementTree::parse_andlist(&exprlist).unwrap();
  495. let And(v0) = statementtree else {
  496. panic!("Incorrect result");
  497. };
  498. let [Leaf(l0), Leaf(l1), Or(v1)] = v0.as_slice() else {
  499. panic!("Incorrect result");
  500. };
  501. assert_eq!(quote! {#l0}.to_string(), "C = c * B + r * A");
  502. assert_eq!(quote! {#l1}.to_string(), "D = d * B + s * A");
  503. let [And(v2), And(v3)] = v1.as_slice() else {
  504. panic!("Incorrect result");
  505. };
  506. let [Leaf(l20), Leaf(l21), Leaf(l22)] = v2.as_slice() else {
  507. panic!("Incorrect result");
  508. };
  509. assert_eq!(quote! {#l20}.to_string(), "C = c0 * B + r0 * A");
  510. assert_eq!(quote! {#l21}.to_string(), "D = d0 * B + s0 * A");
  511. assert_eq!(quote! {#l22}.to_string(), "c0 = d0");
  512. let [Leaf(l30), Leaf(l31), Leaf(l32)] = v3.as_slice() else {
  513. panic!("Incorrect result");
  514. };
  515. assert_eq!(quote! {#l30}.to_string(), "C = c1 * B + r1 * A");
  516. assert_eq!(quote! {#l31}.to_string(), "D = d1 * B + s1 * A");
  517. assert_eq!(quote! {#l32}.to_string(), "c1 = d1 + 1");
  518. }
  519. #[test]
  520. fn combiners_thresh_test() {
  521. let exprlist: Vec<Expr> = vec![
  522. parse_quote! { C = c*B + r*A },
  523. parse_quote! { D = d*B + s*A },
  524. parse_quote! {
  525. THRESH (1,
  526. AND (
  527. C = c0*B + r0*A,
  528. D = d0*B + s0*A,
  529. c0 = d0,
  530. ),
  531. AND (
  532. C = c1*B + r1*A,
  533. D = d1*B + s1*A,
  534. c1 = d1 + 1,
  535. ),
  536. ) },
  537. ];
  538. let statementtree = StatementTree::parse_andlist(&exprlist).unwrap();
  539. let And(v0) = statementtree else {
  540. panic!("Incorrect result");
  541. };
  542. let [Leaf(l0), Leaf(l1), Thresh(thresh, v1)] = v0.as_slice() else {
  543. panic!("Incorrect result");
  544. };
  545. assert_eq!(*thresh, 1);
  546. assert_eq!(quote! {#l0}.to_string(), "C = c * B + r * A");
  547. assert_eq!(quote! {#l1}.to_string(), "D = d * B + s * A");
  548. let [And(v2), And(v3)] = v1.as_slice() else {
  549. panic!("Incorrect result");
  550. };
  551. let [Leaf(l20), Leaf(l21), Leaf(l22)] = v2.as_slice() else {
  552. panic!("Incorrect result");
  553. };
  554. assert_eq!(quote! {#l20}.to_string(), "C = c0 * B + r0 * A");
  555. assert_eq!(quote! {#l21}.to_string(), "D = d0 * B + s0 * A");
  556. assert_eq!(quote! {#l22}.to_string(), "c0 = d0");
  557. let [Leaf(l30), Leaf(l31), Leaf(l32)] = v3.as_slice() else {
  558. panic!("Incorrect result");
  559. };
  560. assert_eq!(quote! {#l30}.to_string(), "C = c1 * B + r1 * A");
  561. assert_eq!(quote! {#l31}.to_string(), "D = d1 * B + s1 * A");
  562. assert_eq!(quote! {#l32}.to_string(), "c1 = d1 + 1");
  563. }
  564. #[test]
  565. #[should_panic]
  566. fn combiners_bad_thresh_test() {
  567. // The threshold is out of range
  568. let exprlist: Vec<Expr> = vec![
  569. parse_quote! { C = c*B + r*A },
  570. parse_quote! { D = d*B + s*A },
  571. parse_quote! {
  572. THRESH (3,
  573. AND (
  574. C = c0*B + r0*A,
  575. D = d0*B + s0*A,
  576. c0 = d0,
  577. ),
  578. AND (
  579. C = c1*B + r1*A,
  580. D = d1*B + s1*A,
  581. c1 = d1 + 1,
  582. ),
  583. ) },
  584. ];
  585. StatementTree::parse_andlist(&exprlist).unwrap();
  586. }
  587. #[test]
  588. // Test the disjunction invariant checker
  589. fn disjunction_invariant_test() {
  590. let vars: VarDict = vardict_from_strs(&[
  591. ("c", "S"),
  592. ("d", "S"),
  593. ("c0", "S"),
  594. ("c1", "S"),
  595. ("d0", "S"),
  596. ("d1", "S"),
  597. ("A", "pP"),
  598. ("B", "pP"),
  599. ("C", "pP"),
  600. ("D", "pP"),
  601. ]);
  602. // This one is OK
  603. let st_ok = StatementTree::parse(&parse_quote! {
  604. AND (
  605. C = c*B + r*A,
  606. D = d*B + s*A,
  607. OR (
  608. AND (
  609. C = c0*B + r0*A,
  610. D = d0*B + s0*A,
  611. c0 = d0,
  612. ),
  613. AND (
  614. C = c1*B + r1*A,
  615. D = d1*B + s1*A,
  616. c1 = d1 + 1,
  617. ),
  618. )
  619. )
  620. })
  621. .unwrap();
  622. // not OK: c0 appears in two branches of the OR
  623. let st_nok1 = StatementTree::parse(&parse_quote! {
  624. AND (
  625. C = c*B + r*A,
  626. D = d*B + s*A,
  627. OR (
  628. AND (
  629. C = c0*B + r0*A,
  630. D = d0*B + s0*A,
  631. c0 = d0,
  632. ),
  633. AND (
  634. C = c0*B + r0*A,
  635. D = d1*B + s1*A,
  636. c0 = d1 + 1,
  637. ),
  638. )
  639. )
  640. })
  641. .unwrap();
  642. // not OK: c appears in one branch of the OR and also outside
  643. // the OR
  644. let st_nok2 = StatementTree::parse(&parse_quote! {
  645. AND (
  646. C = c*B + r*A,
  647. D = d*B + s*A,
  648. OR (
  649. AND (
  650. D = d0*B + s0*A,
  651. c = d0,
  652. ),
  653. AND (
  654. C = c1*B + r1*A,
  655. D = d1*B + s1*A,
  656. c1 = d1 + 1,
  657. ),
  658. )
  659. )
  660. })
  661. .unwrap();
  662. // not OK: c and d appear in both branches of the OR, and also
  663. // outside it
  664. let st_nok3 = StatementTree::parse(&parse_quote! {
  665. AND (
  666. C = c*B + r*A,
  667. D = d*B + s*A,
  668. OR (
  669. c = d,
  670. c = d + 1,
  671. )
  672. )
  673. })
  674. .unwrap();
  675. st_ok.check_disjunction_invariant(&vars).unwrap();
  676. st_nok1.check_disjunction_invariant(&vars).unwrap_err();
  677. st_nok2.check_disjunction_invariant(&vars).unwrap_err();
  678. st_nok3.check_disjunction_invariant(&vars).unwrap_err();
  679. }
  680. fn flatten_ands_tester(e: Expr, flattened_e: Expr) {
  681. let mut st = StatementTree::parse(&e).unwrap();
  682. st.flatten_ands();
  683. assert_eq!(st, StatementTree::parse(&flattened_e).unwrap());
  684. }
  685. #[test]
  686. // Test flatten_ands
  687. fn flatten_ands_test() {
  688. flatten_ands_tester(
  689. parse_quote! {
  690. C = x*B + r*A
  691. },
  692. parse_quote! {
  693. C = x*B + r*A
  694. },
  695. );
  696. flatten_ands_tester(
  697. parse_quote! {
  698. AND (
  699. C = x*B + r*A,
  700. AND (
  701. D = x*B + s*A,
  702. E = x*B + t*A,
  703. ),
  704. )
  705. },
  706. parse_quote! {
  707. AND (
  708. C = x*B + r*A,
  709. D = x*B + s*A,
  710. E = x*B + t*A,
  711. )
  712. },
  713. );
  714. flatten_ands_tester(
  715. parse_quote! {
  716. AND (
  717. AND (
  718. OR (
  719. D = B + s*A,
  720. D = s*A,
  721. ),
  722. D = x*B + t*A,
  723. ),
  724. C = x*B + r*A,
  725. )
  726. },
  727. parse_quote! {
  728. AND (
  729. OR (
  730. D = B + s*A,
  731. D = s*A,
  732. ),
  733. D = x*B + t*A,
  734. C = x*B + r*A,
  735. )
  736. },
  737. );
  738. flatten_ands_tester(
  739. parse_quote! {
  740. AND (
  741. AND (
  742. OR (
  743. D = B + s*A,
  744. AND (
  745. D = s*A,
  746. AND (
  747. E = s*B,
  748. F = s*C,
  749. ),
  750. ),
  751. ),
  752. D = x*B + t*A,
  753. ),
  754. C = x*B + r*A,
  755. )
  756. },
  757. parse_quote! {
  758. AND (
  759. OR (
  760. D = B + s*A,
  761. AND (
  762. D = s*A,
  763. E = s*B,
  764. F = s*C,
  765. )
  766. ),
  767. D = x*B + t*A,
  768. C = x*B + r*A,
  769. )
  770. },
  771. );
  772. }
  773. }