| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
703704705706707708709710711712713714715716717718719720721722 |
- //! A module for operations that transform a [`StatementTree`].
- //! Every transformation must maintain the [disjunction invariant].
- //!
- //! [disjunction invariant]: StatementTree::check_disjunction_invariant
- use super::codegen::CodeGen;
- use super::pedersen::{
- convert_commitment, convert_randomness, recognize_pedersen_assignment, unique_random_scalars,
- LinScalar, PedersenAssignment,
- };
- use super::sigma::combiners::*;
- use super::syntax::{collect_cind_points, taggedvardict_to_vardict};
- use super::{TaggedIdent, TaggedScalar, TaggedVarDict};
- use quote::{format_ident, quote};
- use std::collections::{HashMap, HashSet};
- use syn::visit_mut::{self, VisitMut};
- use syn::{parse_quote, Error, Expr, Ident, Result};
- /// Simplify a [`StatementTree`] by pruning leaves that are the constant
- /// `true`, and simplifying `And`, `Or`, and `Thresh` combiners that
- /// have fewer than two children.
- pub fn prune_statement_tree(st: &mut StatementTree) {
- match st {
- // If the StatementTree is just a Leaf, just keep it unmodified,
- // even if it is leaf_true.
- StatementTree::Leaf(_) => {}
- // For the And combiner, recursively simplify each child, and then
- // prune the child if it is leaf_true. If we end up with 1
- // child replace ourselves with that child. If we end up with 0
- // children, replace ourselves with leaf_true.
- StatementTree::And(v) => {
- let mut i: usize = 0;
- // Note that v.len _can change_ during this loop
- while i < v.len() {
- prune_statement_tree(&mut v[i]);
- if v[i].is_leaf_true() {
- // Remove this child, and _do not_ increment i
- v.remove(i);
- } else {
- i += 1;
- }
- }
- if v.is_empty() {
- *st = StatementTree::leaf_true();
- } else if v.len() == 1 {
- let child = v.remove(0);
- *st = child;
- }
- }
- // For the Or combiner, recursively simplify each child, and if
- // it ends up leaf_true, replace ourselves with leaf_true.
- // If we end up with 1 child, we must have started wth 1 child.
- // Replace ourselves with that child anyway.
- StatementTree::Or(v) => {
- let mut i: usize = 0;
- // Note that v.len _can change_ during this loop
- while i < v.len() {
- prune_statement_tree(&mut v[i]);
- if v[i].is_leaf_true() {
- *st = StatementTree::leaf_true();
- return;
- } else {
- i += 1;
- }
- }
- if v.len() == 1 {
- let child = v.remove(0);
- *st = child;
- }
- }
- // For the Thresh combiner, recursively simplify each child, and
- // if it ends up leaf_true, prune it, and subtract 1 from the
- // thresh. If the thresh hits 0, replace ourselves with
- // leaf_true. If we end up with 1 child and thresh is 1,
- // replace ourselves with that child.
- StatementTree::Thresh(thresh, v) => {
- let mut i: usize = 0;
- // Note that v.len _can change_ during this loop
- while i < v.len() {
- prune_statement_tree(&mut v[i]);
- if v[i].is_leaf_true() {
- // Remove this child, and _do not_ increment i
- v.remove(i);
- // But decrement thresh
- *thresh -= 1;
- if *thresh == 0 {
- *st = StatementTree::leaf_true();
- return;
- }
- } else {
- i += 1;
- }
- }
- if v.len() == 1 {
- // If thresh == 0, we would have exited above
- assert!(*thresh == 1);
- let child = v.remove(0);
- *st = child;
- }
- }
- }
- }
- /// Add parentheses around an [`Expr`] (which represents an [arithmetic
- /// expression]) if needed.
- ///
- /// The parentheses are needed if the [`Expr`] would parse as multiple
- /// tokens. For example, `a+b` turns into `(a+b)`, but `c`
- /// remains `c` and `(a+b)` remains `(a+b)`.
- ///
- /// [arithmetic expression]: super::sigma::types::expr_type
- pub fn paren_if_needed(expr: Expr) -> Expr {
- match expr {
- Expr::Unary(_) | Expr::Binary(_) => parse_quote! { (#expr) },
- _ => expr,
- }
- }
- /// Transform the [`StatementTree`] so that it satisfies the
- /// [disjunction invariant].
- ///
- /// [disjunction invariant]: StatementTree::check_disjunction_invariant
- #[allow(non_snake_case)] // so that Points can be capital letters
- pub fn enforce_disjunction_invariant(
- codegen: &mut CodeGen,
- st: &mut StatementTree,
- vars: &mut TaggedVarDict,
- ) -> Result<()> {
- // Make the VarDict version of the variable dictionary
- let mut vardict = taggedvardict_to_vardict(vars);
- // A HashSet of the unique random Scalars in the macro input
- let mut randoms = unique_random_scalars(vars, st);
- // A list of the computationally independent (non-vector) Points in
- // the macro input. If we need to do any transformations, there
- // must be at least two of them in order to create Pedersen
- // commitments.
- let cind_points = collect_cind_points(vars);
- // Extra statements to be added to the root disjunction branch
- let mut root_extra_statements: Vec<StatementTree> = Vec::new();
- // The generated variable name for the rng
- let rng_var = codegen.gen_ident(&format_ident!("rng"));
- // Find any statements that look like Pedersen commitments in the
- // root disjunction branch of the StatementTree, and make a HashMap
- // mapping the committed private variable to the parsed commitment.
- let mut root_pedersens: HashMap<Ident, PedersenAssignment> = HashMap::new();
- st.for_each_disjunction_branch_leaf(&mut |leaf| {
- // See if we recognize this leaf expression as a
- // PedersenAssignment, and if so, map its variable to the
- // PedersenAssignment.
- if let StatementTree::Leaf(leafexpr) = leaf {
- if let Some(ped_assign) =
- recognize_pedersen_assignment(vars, &randoms, &vardict, leafexpr)
- {
- root_pedersens.insert(ped_assign.var(), ped_assign);
- }
- }
- Ok(())
- })?;
- // Count how many disjunction branches contain each private Scalar
- let mut branch_count: HashMap<Ident, usize> = HashMap::new();
- st.for_each_disjunction_branch(&mut |branch, _path| {
- branch
- .disjunction_branch_priv_scalars(&vardict)
- .drain()
- .for_each(|id| {
- if let Some(n) = branch_count.get(&id) {
- branch_count.insert(id, n + 1);
- } else {
- branch_count.insert(id, 1);
- }
- });
- Ok(())
- })?;
- // Make a HashSet of any of those private Scalars whose count is
- // strictly larger than 1. (Those private Scalars are the ones
- // that are in violation of the disjunction invariant.)
- let mut invariant_violators: HashSet<Ident> = branch_count
- .drain()
- .filter_map(|(id, n)| if n > 1 { Some(id) } else { None })
- .collect();
- // If there are no invariant violators, we're done.
- if invariant_violators.is_empty() {
- return Ok(());
- }
- // Otherwise, ensure there are at least two computationally
- // independent points, since we'll need to construct Pedersen
- // commitments.
- if cind_points.len() < 2 {
- return Err(Error::new(
- proc_macro2::Span::call_site(),
- "At least two cind Points must be declared to support Pedersen commitments",
- ));
- }
- let cind_A = &cind_points[0];
- let cind_B = &cind_points[1];
- // For each invariant violator, find (or create) a Pedersen
- // commitment in the root disjunction branch for it.
- let invariant_violator_pedersens: HashMap<Ident, PedersenAssignment> = invariant_violators
- .drain()
- .map(|id| {
- // Check if the private Scalar is a vector variable or
- // not
- let is_vec = if let Some(TaggedIdent::Scalar(TaggedScalar { is_vec, .. })) =
- vars.get(&id.to_string())
- {
- *is_vec
- } else {
- false
- };
- // See if we already have a PedersenAssignment in the
- // root disjunction branch for this private Scalar
- let ped_assign = if let Some(ped_assign) = root_pedersens.get(&id) {
- ped_assign.clone()
- } else {
- // Create new variables for the Pedersen commitment and its
- // random Scalar.
- let commitment_var = codegen.gen_point(
- vars,
- &format_ident!("disj_{}_genC", id),
- is_vec, // is_vec
- true, // send_to_verifier
- );
- let rand_var = codegen.gen_scalar(
- vars,
- &format_ident!("disj_{}_genr", id),
- true, // is_rand
- is_vec, // is_vec
- );
- // Update vardict and randoms with the new vars
- vardict = taggedvardict_to_vardict(vars);
- randoms.insert(rand_var.to_string());
- let ped_assign_expr: Expr = parse_quote! {
- #commitment_var = #id * #cind_A + #rand_var * #cind_B
- };
- let ped_assign =
- recognize_pedersen_assignment(vars, &randoms, &vardict, &ped_assign_expr)
- .unwrap();
- if is_vec {
- codegen.prove_append(quote! {
- let #rand_var: Vec<Scalar> = #id
- .map(|_| Scalar::random(#rng_var))
- .collect();
- let #commitment_var = (0..#id.len())
- .map(|i| {
- #id[i] * #cind_A + #rand_var[i] * #cind_B
- })
- .collect();
- });
- } else {
- codegen.prove_append(quote! {
- let #rand_var = Scalar::random(#rng_var);
- let #ped_assign_expr;
- });
- }
- root_extra_statements.push(StatementTree::Leaf(ped_assign_expr));
- ped_assign
- };
- // At this point, we have a Pedersen commitment for some linear
- // function of id (given by
- // ped_assign.pedersen.var_term.coeff), using some linear
- // function of rand_var (given by
- // ped_assign.pedersen.rand_term.coeff) as the randomness. But
- // what we need is a Pedersen commitment for id itself.
- // So we output runtime code for both the prover and the
- // verifier that converts the commitment, and code for just
- // the prover that converts the randomness.
- // Make new runtime variables to hold the converted
- // commitment and randomness
- let commitment_var = codegen.gen_point(
- vars,
- &format_ident!("disj_{}_C", id),
- is_vec, // is_vec
- false, // send_to_verifier
- );
- let rand_var = codegen.gen_ident(&format_ident!("disj_{}_r", id));
- // Update vardict and randoms with the new vars
- vardict = taggedvardict_to_vardict(vars);
- randoms.insert(rand_var.to_string());
- // The identity LinScalar for this id
- let id_linscalar = LinScalar {
- coeff: 1i128,
- pub_scalar_expr: None,
- id: id.clone(),
- is_vec,
- };
- codegen.prove_verify_append(
- convert_commitment(&commitment_var, &ped_assign, &id_linscalar, &vardict).unwrap(),
- );
- codegen.prove_append(
- convert_randomness(&rand_var, &ped_assign, &id_linscalar, &vardict).unwrap(),
- );
- (id, ped_assign)
- })
- .collect();
- // Do another pass over each disjunction branch (other than the
- // root). In each non-root branch, if there are any instances of an
- // invariant violator, then change all instances of that violating
- // identifier to a fresh identifier, and insert a Pedersen
- // commitment (to the same commitment variable that exists in the
- // root disjunction branch) to bind the new identifier to the
- // original.
- let mut disjunction_branch_num = 0usize;
- st.for_each_disjunction_branch(&mut |branch, path| {
- // Skip the root disjunction branch, which is represented by an
- // empty path
- if path.is_empty() {
- return Ok(());
- }
- disjunction_branch_num += 1;
- // Keep track of the ids in invariant_violator_pedersens
- // that we encounter and rename in this disjunction branch
- let mut ids_renamed: HashSet<Ident> = HashSet::new();
- // Extra statements to be added to this disjunction branch
- let mut branch_extra_statements: Vec<StatementTree> = Vec::new();
- struct Renamer<'a> {
- codegen: &'a CodeGen,
- disjunction_branch_num: usize,
- invariant_violators: &'a HashMap<Ident, PedersenAssignment>,
- ids_renamed: &'a mut HashSet<Ident>,
- }
- impl<'a> VisitMut for Renamer<'a> {
- fn visit_expr_mut(&mut self, node: &mut Expr) {
- if let Expr::Path(expath) = node {
- if let Some(id) = expath.path.get_ident() {
- if self.invariant_violators.contains_key(id) {
- let replacement_ident = self.codegen.gen_ident(&format_ident!(
- "disj{}_{}",
- self.disjunction_branch_num,
- id
- ));
- self.ids_renamed.insert(id.clone());
- *node = parse_quote! { #replacement_ident };
- return;
- }
- }
- }
- // Unless we bailed out above, continue with the default
- // traversal
- visit_mut::visit_expr_mut(self, node);
- }
- }
- let mut renamer = Renamer {
- codegen,
- disjunction_branch_num,
- invariant_violators: &invariant_violator_pedersens,
- ids_renamed: &mut ids_renamed,
- };
- branch.for_each_disjunction_branch_leaf(&mut |leaf| {
- let StatementTree::Leaf(ref mut leafexpr) = leaf else {
- panic!(
- "Should not happen: leaf {:?} is not a StatementTree::Leaf",
- leaf
- );
- };
- renamer.visit_expr_mut(leafexpr);
- Ok(())
- })?;
- // For each id we renamed, insert a Pedersen commitment to the
- // new name (using the _same_ commitment value we computed in
- // the root Pedersen commitment) into this disjunction branch.
- // This binds the new name to the old name.
- for id in ids_renamed {
- // Is it a vector variable?
- let is_vec = if let Some(TaggedIdent::Scalar(TaggedScalar { is_vec, .. })) =
- vars.get(&id.to_string())
- {
- *is_vec
- } else {
- false
- };
- // Variables for the renamed private Scalar and the randomness
- let id_var = codegen.gen_scalar(
- vars,
- &format_ident!("disj{}_{}", disjunction_branch_num, id,),
- false, // is_rand
- is_vec, // is_vec
- );
- let rand_var = codegen.gen_scalar(
- vars,
- &format_ident!("disj{}_{}_r", disjunction_branch_num, id,),
- true, // is_rand
- is_vec, // is_vec
- );
- let root_commitment_var = codegen.gen_ident(&format_ident!("disj_{}_C", id));
- let root_rand_var = codegen.gen_ident(&format_ident!("disj_{}_r", id));
- if is_vec {
- codegen.prove_append(quote! {
- let #id_var = #id.clone();
- let #rand_var = #root_rand_var.clone();
- });
- } else {
- codegen.prove_append(quote! {
- let #id_var = #id;
- let #rand_var = #root_rand_var;
- });
- }
- // The generators for the Pedersen commitment for this id
- let ped_assign = invariant_violator_pedersens.get(&id).unwrap();
- let var_generator = &ped_assign.pedersen.var_term.id;
- let rand_generator = &ped_assign.pedersen.rand_term.id;
- branch_extra_statements.push(StatementTree::Leaf(parse_quote! {
- #root_commitment_var = #id_var * #var_generator + #rand_var * #rand_generator
- }));
- }
- // Now add the branch_extra_statements to the top node of this
- // disjunction branch. If it's already an And node, just add
- // them to the vector. Otherwise, make a new And node
- // containing the old node and the branch_extra_statements.
- if let StatementTree::And(ref mut stvec) = branch {
- stvec.append(&mut branch_extra_statements);
- } else {
- let old_branch = std::mem::replace(branch, StatementTree::leaf_true());
- branch_extra_statements.push(old_branch);
- *branch = StatementTree::And(branch_extra_statements);
- }
- Ok(())
- })?;
- // Add the root_extra_statements to the root of the StatementTree.
- // If it's already an And node, just add them to the vector.
- // Otherwise, make a new And node containing the old root and the
- // root_extra_statements
- if let StatementTree::And(ref mut stvec) = st {
- stvec.append(&mut root_extra_statements);
- } else {
- let old_st = std::mem::replace(st, StatementTree::leaf_true());
- root_extra_statements.push(old_st);
- *st = StatementTree::And(root_extra_statements);
- }
- // Sanity check
- st.check_disjunction_invariant(&vardict)
- }
#[cfg(test)]
mod tests {
    use super::super::syntax::taggedvardict_from_strs;
    use super::*;

    /// Parse `e`, prune the resulting tree, and check that it equals
    /// the tree parsed from `pruned_e`.
    fn prune_tester(e: Expr, pruned_e: Expr) {
        let mut actual = StatementTree::parse(&e).unwrap();
        prune_statement_tree(&mut actual);
        let expected = StatementTree::parse(&pruned_e).unwrap();
        assert_eq!(actual, expected);
    }

    #[test]
    fn prune_statement_tree_test() {
        // Each entry is (input statement, expected pruned statement)
        let cases: Vec<(Expr, Expr)> = vec![
            // A leading `true` is dropped from an AND, which then
            // collapses to its one remaining child
            (
                parse_quote! {
                    AND (
                        true,
                        e = f,
                    )
                },
                parse_quote! {
                    e = f
                },
            ),
            // The same, with the `true` in the trailing position
            (
                parse_quote! {
                    AND (
                        e = f,
                        true,
                    )
                },
                parse_quote! {
                    e = f
                },
            ),
            // An AND that keeps two children stays an AND
            (
                parse_quote! {
                    AND (
                        e = f,
                        true,
                        b = c,
                    )
                },
                parse_quote! {
                    AND (
                        e = f,
                        b = c,
                    )
                },
            ),
            // A `true` child makes the whole OR `true`
            (
                parse_quote! {
                    OR (
                        true,
                        e = f,
                    )
                },
                parse_quote! {
                    true
                },
            ),
            // A nested OR that becomes `true` is itself dropped from
            // the enclosing AND
            (
                parse_quote! {
                    AND (
                        a = b,
                        true,
                        OR (
                            c = d,
                            true,
                            e = f
                        )
                    )
                },
                parse_quote! {
                    a = b
                },
            ),
            // A nested THRESH whose threshold drops to 0 becomes
            // `true`, which lowers the outer threshold in turn
            (
                parse_quote! {
                    THRESH (3,
                        a = b,
                        true,
                        THRESH (1,
                            c = d,
                            true,
                            e = f
                        )
                    )
                },
                parse_quote! {
                    a = b
                },
            ),
            // Thresholds are lowered by one for each `true` removed
            (
                parse_quote! {
                    THRESH (3,
                        a = b,
                        true,
                        THRESH (2,
                            c = d,
                            true,
                            e = f
                        )
                    )
                },
                parse_quote! {
                    THRESH (2,
                        a = b,
                        THRESH (1,
                            c = d,
                            e = f
                        )
                    )
                },
            ),
        ];

        for (input, expected) in cases {
            prune_tester(input, expected);
        }
    }

    /// Parse `e`, run [`enforce_disjunction_invariant`] on it with a
    /// fresh `CodeGen` and the given variable declarations, and check
    /// that the transformed tree equals the tree parsed from `expect`.
    fn enforce_disjunction_invariant_tester(vars: (&[&str], &[&str]), e: Expr, expect: Expr) {
        let mut codegen = CodeGen::new_empty();
        let mut tree = StatementTree::parse(&e).unwrap();
        let mut tagged_vars = taggedvardict_from_strs(vars);
        enforce_disjunction_invariant(&mut codegen, &mut tree, &mut tagged_vars).unwrap();
        let expected_tree = StatementTree::parse(&expect).unwrap();
        assert_eq!(tree, expected_tree);
    }

    #[test]
    fn enforce_disjunction_invariant_test() {
        let scalars: &[&str] = &[
            "x", "y", "z", "pub a", "pub b", "pub c", "rand r", "rand s", "rand t",
        ];
        let points: &[&str] = &["C", "D", "cind A", "cind B"];
        let vars = (scalars, points);

        // Each entry is (input statement, expected transformed statement)
        let cases: Vec<(Expr, Expr)> = vec![
            // A single leaf is left as it is
            (
                parse_quote! {
                    C = x*A
                },
                parse_quote! {
                    C = x*A
                },
            ),
            // No private Scalar is shared between branches, so nothing
            // changes
            (
                parse_quote! {
                    AND (
                        C = x*A + r*B,
                        OR (
                            y=1,
                            z=2,
                        )
                    )
                },
                parse_quote! {
                    AND (
                        C = x*A + r*B,
                        OR (
                            y=1,
                            z=2,
                        )
                    )
                },
            ),
            // x is shared, and the root already holds a Pedersen
            // commitment to it
            (
                parse_quote! {
                    AND (
                        C = x*A + r*B,
                        OR (
                            x=1,
                            x=2,
                        )
                    )
                },
                parse_quote! {
                    AND (
                        C = x*A + r*B,
                        OR (
                            AND (
                                gen__disj_x_C = gen__disj1_x * A + gen__disj1_x_r * B,
                                gen__disj1_x=1,
                            ),
                            AND (
                                gen__disj_x_C = gen__disj2_x * A + gen__disj2_x_r * B,
                                gen__disj2_x=2,
                            ),
                        )
                    )
                },
            ),
            // x is shared, but the root statement about x is not a
            // Pedersen commitment, so one is added to the root
            (
                parse_quote! {
                    AND (
                        C = x*A,
                        OR (
                            x=1,
                            x=2,
                        )
                    )
                },
                parse_quote! {
                    AND (
                        C = x*A,
                        OR (
                            AND (
                                gen__disj_x_C = gen__disj1_x * A + gen__disj1_x_r * B,
                                gen__disj1_x=1,
                            ),
                            AND (
                                gen__disj_x_C = gen__disj2_x * A + gen__disj2_x_r * B,
                                gen__disj2_x=2,
                            ),
                        ),
                        gen__disj_x_genC = x*A + gen__disj_x_genr*B,
                    )
                },
            ),
            // The root is an OR, so it gets wrapped in a new AND that
            // also holds the added commitment
            (
                parse_quote! {
                    OR (
                        x=1,
                        x=2,
                    )
                },
                parse_quote! {
                    AND (
                        gen__disj_x_genC = x*A + gen__disj_x_genr*B,
                        OR (
                            AND (
                                gen__disj_x_C = gen__disj1_x * A + gen__disj1_x_r * B,
                                gen__disj1_x=1,
                            ),
                            AND (
                                gen__disj_x_C = gen__disj2_x * A + gen__disj2_x_r * B,
                                gen__disj2_x=2,
                            ),
                        ),
                    )
                },
            ),
        ];

        for (input, expected) in cases {
            enforce_disjunction_invariant_tester(vars, input, expected);
        }
    }
}
|