Browse Source

First cut at enforce_disjunction_invariant

Automatically transform a StatementTree into a semantically equivalent
one that satisfies the disjunction invariant
Ian Goldberg 3 months ago
parent
commit
6cda94defa

+ 7 - 1
sigma_compiler_core/src/lib.rs

@@ -15,7 +15,7 @@ mod pubscalareq;
 mod rangeproof;
 mod substitution;
 mod syntax;
-mod transform;
+pub mod transform;
 
 pub use syntax::{SigmaCompSpec, TaggedIdent, TaggedPoint, TaggedScalar, TaggedVarDict};
 
@@ -38,6 +38,12 @@ pub fn sigma_compiler_core(
 ) -> TokenStream {
     let mut codegen = codegen::CodeGen::new(spec);
 
+    // Enforce the disjunction invariant (do this before any other
+    // transformations, since they assume the invariant holds, and will
+    // maintain it)
+    transform::enforce_disjunction_invariant(&mut codegen, &mut spec.statements, &mut spec.vars)
+        .unwrap();
+
     // Apply any substitution transformations
     substitution::transform(&mut codegen, &mut spec.statements, &mut spec.vars).unwrap();
 

+ 33 - 7
sigma_compiler_core/src/sigma/combiners.rs

@@ -3,10 +3,10 @@
 
 use super::types::*;
 use quote::quote;
-use std::collections::HashMap;
+use std::collections::{HashMap, HashSet};
 use syn::parse::Result;
 use syn::visit::Visit;
-use syn::{parse_quote, Expr};
+use syn::{parse_quote, Expr, Ident};
 
 /// For each [`Ident`](struct@syn::Ident) representing a private
 /// `Scalar` (as listed in a [`VarDict`]) that appears in an [`Expr`],
@@ -273,11 +273,13 @@ impl StatementTree {
     ///
     /// is exactly that the invariant must be satisfied.
     ///
-    /// (In the future, it is possible we may provide a transformer that
-    /// will automatically convert [`StatementTree`]s to ones that
-    /// satisfy the invariant, but for now, the user of the macro must
-    /// manually write the statements in a form that satisfies the
-    /// disjunction invariant.
+    /// If you don't know that your [`StatementTree`] already satisfies
+    /// the invariant, call
+    /// [`enforce_disjunction_invariant`](super::super::transform::enforce_disjunction_invariant),
+    /// which will transform the [`StatementTree`] so that it does (and
+    /// also call this
+    /// [`check_disjunction_invariant`](StatementTree::check_disjunction_invariant)
+    /// function as a sanity check).
     pub fn check_disjunction_invariant(&self, vars: &VarDict) -> Result<()> {
         let mut disjunct_map: HashMap<String, usize> = HashMap::new();
 
@@ -461,6 +463,30 @@ impl StatementTree {
         Ok(())
     }
 
+    /// Produce a [`HashSet`] of the private Scalars that appear in any
+    /// leaf of the given [disjunction branch].
+    ///
+    /// [disjunction branch]: StatementTree::check_disjunction_invariant
+    pub fn disjunction_branch_priv_scalars(&mut self, vars: &VarDict) -> HashSet<Ident> {
+        let mut priv_scalars: HashSet<Ident> = HashSet::new();
+        self.for_each_disjunction_branch_leaf(&mut |leaf| {
+            if let StatementTree::Leaf(leafexpr) = leaf {
+                let mut psmap = PrivScalarMap {
+                    vars,
+                    closure: &mut |ident| {
+                        priv_scalars.insert(ident.clone());
+                        Ok(())
+                    },
+                    result: Ok(()),
+                };
+                psmap.visit_expr(leafexpr);
+            }
+            Ok(())
+        })
+        .unwrap();
+        priv_scalars
+    }
+
     #[cfg(not(doctest))]
     /// Flatten nested `And` nodes in a [`StatementTree`].
     ///

+ 0 - 2
sigma_compiler_core/src/syntax.rs

@@ -264,8 +264,6 @@ impl Parse for SigmaCompSpec {
             input.parse_terminated(Expr::parse, Token![,])?;
         let statementlist: Vec<Expr> = statementpunc.into_iter().collect();
         let statements = StatementTree::parse_andlist(&statementlist)?;
-        let vardict = taggedvardict_to_vardict(&vars);
-        statements.check_disjunction_invariant(&vardict)?;
 
         Ok(SigmaCompSpec {
             proto_name,

+ 498 - 1
sigma_compiler_core/src/transform.rs

@@ -3,8 +3,18 @@
 //!
 //! [disjunction invariant]: StatementTree::check_disjunction_invariant
 
+use super::codegen::CodeGen;
+use super::pedersen::{
+    convert_commitment, convert_randomness, recognize_pedersen_assignment, unique_random_scalars,
+    LinScalar, PedersenAssignment,
+};
 use super::sigma::combiners::*;
-use syn::{parse_quote, Expr};
+use super::syntax::{collect_cind_points, taggedvardict_to_vardict};
+use super::{TaggedIdent, TaggedScalar, TaggedVarDict};
+use quote::{format_ident, quote};
+use std::collections::{HashMap, HashSet};
+use syn::visit_mut::{self, VisitMut};
+use syn::{parse_quote, Error, Expr, Ident, Result};
 
 /// Simplify a [`StatementTree`] by pruning leaves that are the constant
 /// `true`, and simplifying `And`, `Or`, and `Thresh` combiners that
@@ -109,8 +119,362 @@ pub fn paren_if_needed(expr: Expr) -> Expr {
     }
 }
 
+/// Transform the [`StatementTree`] so that it satisfies the
+/// [disjunction invariant].
+///
+/// [disjunction invariant]: StatementTree::check_disjunction_invariant
+#[allow(non_snake_case)] // so that Points can be capital letters
+pub fn enforce_disjunction_invariant(
+    codegen: &mut CodeGen,
+    st: &mut StatementTree,
+    vars: &mut TaggedVarDict,
+) -> Result<()> {
+    // Make the VarDict version of the variable dictionary
+    let mut vardict = taggedvardict_to_vardict(vars);
+
+    // A HashSet of the unique random Scalars in the macro input
+    let mut randoms = unique_random_scalars(vars, st);
+
+    // A list of the computationally independent (non-vector) Points in
+    // the macro input.  If we need to do any transformations, there
+    // must be at least two of them in order to create Pedersen
+    // commitments.
+
+    // If we're testing, sort cind_points so that we get a deterministic
+    // choice of cind_A and cind_B
+    #[cfg(not(test))]
+    let cind_points = collect_cind_points(vars);
+    #[cfg(test)]
+    let mut cind_points = collect_cind_points(vars);
+    #[cfg(test)]
+    cind_points.sort_unstable();
+
+    // Extra statements to be added to the root disjunction branch
+    let mut root_extra_statements: Vec<StatementTree> = Vec::new();
+
+    // The generated variable name for the rng
+    let rng_var = codegen.gen_ident(&format_ident!("rng"));
+
+    // Find any statements that look like Pedersen commitments in the
+    // root disjunction branch of the StatementTree, and make a HashMap
+    // mapping the committed private variable to the parsed commitment.
+    let mut root_pedersens: HashMap<Ident, PedersenAssignment> = HashMap::new();
+    st.for_each_disjunction_branch_leaf(&mut |leaf| {
+        // See if we recognize this leaf expression as a
+        // PedersenAssignment, and if so, map its variable to the
+        // PedersenAssignment.
+        if let StatementTree::Leaf(leafexpr) = leaf {
+            if let Some(ped_assign) =
+                recognize_pedersen_assignment(vars, &randoms, &vardict, leafexpr)
+            {
+                root_pedersens.insert(ped_assign.var(), ped_assign);
+            }
+        }
+        Ok(())
+    })?;
+
+    // Count how many disjunction branches contain each private Scalar
+    let mut branch_count: HashMap<Ident, usize> = HashMap::new();
+    st.for_each_disjunction_branch(&mut |branch, _path| {
+        branch
+            .disjunction_branch_priv_scalars(&vardict)
+            .drain()
+            .for_each(|id| {
+                if let Some(n) = branch_count.get(&id) {
+                    branch_count.insert(id, n + 1);
+                } else {
+                    branch_count.insert(id, 1);
+                }
+            });
+        Ok(())
+    })?;
+
+    // Make a HashSet of any of those private Scalars whose count is
+    // strictly larger than 1.  (Those private Scalars are the ones
+    // that are in violation of the disjunction invariant.)
+    let mut invariant_violators: HashSet<Ident> = branch_count
+        .drain()
+        .filter_map(|(id, n)| if n > 1 { Some(id) } else { None })
+        .collect();
+
+    // If there are no invariant violators, we're done.
+    if invariant_violators.is_empty() {
+        return Ok(());
+    }
+
+    // Otherwise, ensure there are at least two computationally
+    // independent points, since we'll need to construct Pedersen
+    // commitments.
+    if cind_points.len() < 2 {
+        return Err(Error::new(
+            proc_macro2::Span::call_site(),
+            "At least two cind Points must be declared to support Pedersen commitments",
+        ));
+    }
+    let cind_A = &cind_points[0];
+    let cind_B = &cind_points[1];
+
+    // For each invariant violator, find (or create) a Pedersen
+    // commitment in the root disjunction branch for it.
+    let invariant_violator_pedersens: HashMap<Ident, PedersenAssignment> = invariant_violators
+        .drain()
+        .map(|id| {
+            // Check if the private Scalar is a vector variable or
+            // not
+            let is_vec = if let Some(TaggedIdent::Scalar(TaggedScalar { is_vec, .. })) =
+                vars.get(&id.to_string())
+            {
+                *is_vec
+            } else {
+                false
+            };
+
+            // See if we already have a PedersenAssignment in the
+            // root disjunction branch for this private Scalar
+            let ped_assign = if let Some(ped_assign) = root_pedersens.get(&id) {
+                ped_assign.clone()
+            } else {
+                // Create new variables for the Pedersen commitment and its
+                // random Scalar.
+                let commitment_var = codegen.gen_point(
+                    vars,
+                    &format_ident!("disj_{}_genC", id),
+                    is_vec, // is_vec
+                    true,   // send_to_verifier
+                );
+                let rand_var = codegen.gen_scalar(
+                    vars,
+                    &format_ident!("disj_{}_genr", id),
+                    true,   // is_rand
+                    is_vec, // is_vec
+                );
+
+                // Update vardict and randoms with the new vars
+                vardict = taggedvardict_to_vardict(vars);
+                randoms.insert(rand_var.to_string());
+
+                let ped_assign_expr: Expr = parse_quote! {
+                    #commitment_var = #id * #cind_A + #rand_var * #cind_B
+                };
+                let ped_assign =
+                    recognize_pedersen_assignment(vars, &randoms, &vardict, &ped_assign_expr)
+                        .unwrap();
+
+                if is_vec {
+                    codegen.prove_append(quote! {
+                        let #rand_var: Vec<Scalar> = #id
+                            .map(|_| Scalar::random(#rng_var))
+                            .collect();
+                        let #commitment_var = (0..#id.len())
+                            .map(|i| {
+                                #id[i] * #cind_A + #rand_var[i] * #cind_B
+                            })
+                            .collect();
+                    });
+                } else {
+                    codegen.prove_append(quote! {
+                        let #rand_var = Scalar::random(#rng_var);
+                        let #ped_assign_expr;
+                    });
+                }
+
+                root_extra_statements.push(StatementTree::Leaf(ped_assign_expr));
+
+                ped_assign
+            };
+
+            // At this point, we have a Pedersen commitment for some linear
+            // function of id (given by
+            // ped_assign.pedersen.var_term.coeff), using some linear
+            // function of rand_var (given by
+            // ped_assign.pedersen.rand_term.coeff) as the randomness.  But
+            // what we need is a Pedersen commitment for id itself.
+            // So we output runtime code for both the prover and the
+            // verifier that converts the commitment, and code for just
+            // the prover that converts the randomness.
+
+            // Make new runtime variables to hold the converted
+            // commitment and randomness
+            let commitment_var = codegen.gen_point(
+                vars,
+                &format_ident!("disj_{}_C", id),
+                is_vec, // is_vec
+                false,  // send_to_verifier
+            );
+            let rand_var = codegen.gen_ident(&format_ident!("disj_{}_r", id));
+
+            // Update vardict and randoms with the new vars
+            vardict = taggedvardict_to_vardict(vars);
+            randoms.insert(rand_var.to_string());
+
+            // The identity LinScalar for this id
+            let id_linscalar = LinScalar {
+                coeff: 1i128,
+                pub_scalar_expr: None,
+                id: id.clone(),
+                is_vec,
+            };
+
+            codegen.prove_verify_append(
+                convert_commitment(&commitment_var, &ped_assign, &id_linscalar, &vardict).unwrap(),
+            );
+            codegen.prove_append(
+                convert_randomness(&rand_var, &ped_assign, &id_linscalar, &vardict).unwrap(),
+            );
+
+            (id, ped_assign)
+        })
+        .collect();
+
+    // Do another pass over each disjunction branch (other than the
+    // root).  In each non-root branch, if there are any instances of an
+    // invariant violator, then change all instances of that violating
+    // identifier to a fresh identifier, and insert a Pedersen
+    // commitment (to the same commitment variable that exists in the
+    // root disjunction branch) to bind the new identifier to the
+    // original.
+    let mut disjunction_branch_num = 0usize;
+    st.for_each_disjunction_branch(&mut |branch, path| {
+        // Skip the root disjunction branch, which is represented by an
+        // empty path
+        if path.is_empty() {
+            return Ok(());
+        }
+
+        disjunction_branch_num += 1;
+
+        // Keep track of the ids in invariant_violator_pedersens
+        // that we encounter and rename in this disjunction branch
+        let mut ids_renamed: HashSet<Ident> = HashSet::new();
+
+        // Extra statements to be added to this disjunction branch
+        let mut branch_extra_statements: Vec<StatementTree> = Vec::new();
+
+        struct Renamer<'a> {
+            codegen: &'a CodeGen,
+            disjunction_branch_num: usize,
+            invariant_violators: &'a HashMap<Ident, PedersenAssignment>,
+            ids_renamed: &'a mut HashSet<Ident>,
+        }
+
+        impl<'a> VisitMut for Renamer<'a> {
+            fn visit_expr_mut(&mut self, node: &mut Expr) {
+                if let Expr::Path(expath) = node {
+                    if let Some(id) = expath.path.get_ident() {
+                        if self.invariant_violators.contains_key(id) {
+                            let replacement_ident = self.codegen.gen_ident(&format_ident!(
+                                "disj{}_{}",
+                                self.disjunction_branch_num,
+                                id
+                            ));
+                            self.ids_renamed.insert(id.clone());
+                            *node = parse_quote! { #replacement_ident };
+                            return;
+                        }
+                    }
+                }
+                // Unless we bailed out above, continue with the default
+                // traversal
+                visit_mut::visit_expr_mut(self, node);
+            }
+        }
+        let mut renamer = Renamer {
+            codegen,
+            disjunction_branch_num,
+            invariant_violators: &invariant_violator_pedersens,
+            ids_renamed: &mut ids_renamed,
+        };
+
+        branch.for_each_disjunction_branch_leaf(&mut |leaf| {
+            let StatementTree::Leaf(ref mut leafexpr) = leaf else {
+                panic!(
+                    "Should not happen: leaf {:?} is not a StatementTree::Leaf",
+                    leaf
+                );
+            };
+            renamer.visit_expr_mut(leafexpr);
+            Ok(())
+        })?;
+
+        // For each id we renamed, insert a Pedersen commitment to the
+        // new name (using the _same_ commitment value we computed in
+        // the root Pedersen commitment) into this disjunction branch.
+        // This binds the new name to the old name.
+        for id in ids_renamed {
+            // Is it a vector variable?
+            let is_vec = if let Some(TaggedIdent::Scalar(TaggedScalar { is_vec, .. })) =
+                vars.get(&id.to_string())
+            {
+                *is_vec
+            } else {
+                false
+            };
+
+            // Variables for the renamed private Scalar and the randomness
+            let id_var = codegen.gen_scalar(
+                vars,
+                &format_ident!("disj{}_{}", disjunction_branch_num, id,),
+                false,  // is_rand
+                is_vec, // is_vec
+            );
+            let rand_var = codegen.gen_scalar(
+                vars,
+                &format_ident!("disj{}_{}_r", disjunction_branch_num, id,),
+                true,   // is_rand
+                is_vec, // is_vec
+            );
+            let root_commitment_var = codegen.gen_ident(&format_ident!("disj_{}_C", id));
+            let root_rand_var = codegen.gen_ident(&format_ident!("disj_{}_r", id));
+            if is_vec {
+                codegen.prove_append(quote! {
+                    let #id_var = #id.clone();
+                    let #rand_var = #root_rand_var.clone();
+                });
+            } else {
+                codegen.prove_append(quote! {
+                    let #id_var = #id;
+                    let #rand_var = #root_rand_var;
+                });
+            }
+            branch_extra_statements.push(StatementTree::Leaf(parse_quote! {
+                #root_commitment_var = #id_var * #cind_A + #rand_var * #cind_B
+            }));
+        }
+
+        // Now add the branch_extra_statements to the top node of this
+        // disjunction branch.  If it's already an And node, just add
+        // them to the vector.  Otherwise, make a new And node
+        // containing the old node and the branch_extra_statements.
+        if let StatementTree::And(ref mut stvec) = branch {
+            stvec.append(&mut root_extra_statements);
+        } else {
+            let old_branch = std::mem::replace(branch, StatementTree::leaf_true());
+            branch_extra_statements.push(old_branch);
+            *branch = StatementTree::And(branch_extra_statements);
+        }
+
+        Ok(())
+    })?;
+
+    // Add the root_extra_statements to the root of the StatementTree.
+    // If it's already an And node, just add them to the vector.
+    // Otherwise, make a new And node containing the old root and the
+    // root_extra_statements
+    if let StatementTree::And(ref mut stvec) = st {
+        stvec.append(&mut root_extra_statements);
+    } else {
+        let old_st = std::mem::replace(st, StatementTree::leaf_true());
+        root_extra_statements.push(old_st);
+        *st = StatementTree::And(root_extra_statements);
+    }
+
+    // Sanity check
+    st.check_disjunction_invariant(&vardict)
+}
+
 #[cfg(test)]
 mod tests {
+    use super::super::syntax::taggedvardict_from_strs;
     use super::*;
 
     fn prune_tester(e: Expr, pruned_e: Expr) {
@@ -224,4 +588,137 @@ mod tests {
             },
         );
     }
+
+    fn enforce_disjunction_invariant_tester(vars: (&[&str], &[&str]), e: Expr, expect: Expr) {
+        let mut codegen = CodeGen::new_empty();
+        let mut st = StatementTree::parse(&e).unwrap();
+        let mut vars = taggedvardict_from_strs(vars);
+        enforce_disjunction_invariant(&mut codegen, &mut st, &mut vars).unwrap();
+        assert_eq!(st, StatementTree::parse(&expect).unwrap());
+    }
+
+    #[test]
+    fn enforce_disjunction_invariant_test() {
+        let vars = (
+            [
+                "x", "y", "z", "pub a", "pub b", "pub c", "rand r", "rand s", "rand t",
+            ]
+            .as_slice(),
+            ["C", "D", "cind A", "cind B"].as_slice(),
+        );
+
+        enforce_disjunction_invariant_tester(
+            vars,
+            parse_quote! {
+                C = x*A
+            },
+            parse_quote! {
+                C = x*A
+            },
+        );
+
+        enforce_disjunction_invariant_tester(
+            vars,
+            parse_quote! {
+                AND (
+                    C = x*A + r*B,
+                    OR (
+                        y=1,
+                        z=2,
+                    )
+                )
+            },
+            parse_quote! {
+                AND (
+                    C = x*A + r*B,
+                    OR (
+                        y=1,
+                        z=2,
+                    )
+                )
+            },
+        );
+
+        enforce_disjunction_invariant_tester(
+            vars,
+            parse_quote! {
+                AND (
+                    C = x*A + r*B,
+                    OR (
+                        x=1,
+                        x=2,
+                    )
+                )
+            },
+            parse_quote! {
+                AND (
+                    C = x*A + r*B,
+                    OR (
+                        AND (
+                            gen__disj_x_C = gen__disj1_x * A + gen__disj1_x_r * B,
+                            gen__disj1_x=1,
+                        ),
+                        AND (
+                            gen__disj_x_C = gen__disj2_x * A + gen__disj2_x_r * B,
+                            gen__disj2_x=2,
+                        ),
+                    )
+                )
+            },
+        );
+
+        enforce_disjunction_invariant_tester(
+            vars,
+            parse_quote! {
+                AND (
+                    C = x*A,
+                    OR (
+                        x=1,
+                        x=2,
+                    )
+                )
+            },
+            parse_quote! {
+                AND (
+                    C = x*A,
+                    OR (
+                        AND (
+                            gen__disj_x_C = gen__disj1_x * A + gen__disj1_x_r * B,
+                            gen__disj1_x=1,
+                        ),
+                        AND (
+                            gen__disj_x_C = gen__disj2_x * A + gen__disj2_x_r * B,
+                            gen__disj2_x=2,
+                        ),
+                    ),
+                    gen__disj_x_genC = x*A + gen__disj_x_genr*B,
+                )
+            },
+        );
+
+        enforce_disjunction_invariant_tester(
+            vars,
+            parse_quote! {
+                OR (
+                    x=1,
+                    x=2,
+                )
+            },
+            parse_quote! {
+                AND (
+                    gen__disj_x_genC = x*A + gen__disj_x_genr*B,
+                    OR (
+                        AND (
+                            gen__disj_x_C = gen__disj1_x * A + gen__disj1_x_r * B,
+                            gen__disj1_x=1,
+                        ),
+                        AND (
+                            gen__disj_x_C = gen__disj2_x * A + gen__disj2_x_r * B,
+                            gen__disj2_x=2,
+                        ),
+                    ),
+                )
+            },
+        );
+    }
 }

+ 33 - 0
tests/disj.rs

@@ -0,0 +1,33 @@
+#![allow(non_snake_case)]
+use curve25519_dalek::ristretto::RistrettoPoint as G;
+use group::ff::PrimeField;
+use group::Group;
+use sha2::Sha512;
+use sigma_compiler::*;
+
+#[test]
+fn range_test() -> Result<(), sigma_rs::errors::Error> {
+    sigma_compiler! { proof,
+        (x, rand r),
+        (C, const cind A, const cind B),
+        C = (3*x+1)*A + (2*r+3)*B,
+        OR(
+            x=1,
+            x=2,
+        )
+    }
+
+    type Scalar = <G as Group>::Scalar;
+    let mut rng = rand::thread_rng();
+    let A = G::hash_from_bytes::<Sha512>(b"Generator A");
+    let B = G::generator();
+    let r = Scalar::random(&mut rng);
+    let x = Scalar::from_u128(1);
+    let C = (x + x + x + Scalar::ONE) * A + (r + r + Scalar::from_u128(3)) * B;
+
+    let instance = proof::Instance { C, A, B };
+    let witness = proof::Witness { x, r };
+
+    let proof = proof::prove(&instance, &witness, b"disj_test", &mut rng)?;
+    proof::verify(&instance, &proof, b"disj_test")
+}