Browse Source

In a disjunction branch, convert statements about equality of public Scalars to statements about equality of public Points by multiplying each side by a generator

When we're not in a disjunction, we just emit code to directly check the
equality of the Scalars, as before.  But inside, the equality may not be
true.

So inside a disjunction, change statements like `b = 2*a - 3` into
`b*A = (2*a - 3)*A`, where `A` is one of the declared cind Points.

This commit is not formatted so as to minimize the diff.  The next
commit will run "cargo fmt" with no code changes.
Ian Goldberg 3 months ago
parent
commit
c4c86e28d1
3 changed files with 101 additions and 16 deletions
  1. 40 13
      sigma_compiler_core/src/pubscalareq.rs
  2. 12 3
      tests/pubscalars.rs
  3. 49 0
      tests/pubscalars_or.rs

+ 40 - 13
sigma_compiler_core/src/pubscalareq.rs

@@ -18,14 +18,15 @@
 use super::codegen::CodeGen;
 use super::sigma::combiners::*;
 use super::sigma::types::{expr_type_tokens, AExprType};
-use super::syntax::taggedvardict_to_vardict;
+use super::syntax::{collect_cind_points, taggedvardict_to_vardict};
 use super::transform::prune_statement_tree;
 use super::{TaggedIdent, TaggedScalar, TaggedVarDict};
 use quote::quote;
-use syn::{parse_quote, Expr, Result};
+use syn::{parse_quote, Error, Expr, Result};
 
 /// Look for, and apply, all of the public scalar equality statements
 /// specified in leaves of the [`StatementTree`].
+#[allow(non_snake_case)] // so that Points can be capital letters
 pub fn transform(
     codegen: &mut CodeGen,
     st: &mut StatementTree,
@@ -34,15 +35,19 @@ pub fn transform(
     // Construct the VarDict corresponding to vars
     let vardict = taggedvardict_to_vardict(vars);
 
-    // Gather mutable references to all Exprs in the leaves of the
-    // StatementTree.  Note that this ignores the combiner structure in
-    // the StatementTree, but that's fine.
-    let mut leaves = st.leaves_mut();
+    // A list of the computationally independent (non-vector) Points in
+    // the macro input.  There must be at least one of them in order to
+    // handle public scalar equality statements inside disjunctions.
+    let cind_points = collect_cind_points(vars);
+
+    st.for_each_disjunction_branch(&mut |branch, path| {
+        // Are we in the root disjunction branch?  (path is empty)
+        let in_root_disjunction_branch = path.is_empty();
 
     // For each leaf expression, see if it looks like a public Scalar
     // equality statement
-    for leafexpr in leaves.iter_mut() {
-        if let Expr::Assign(syn::ExprAssign { left, right, .. }) = *leafexpr {
+        branch.for_each_disjunction_branch_leaf(&mut |leaf| {
+        if let StatementTree::Leaf(Expr::Assign(syn::ExprAssign { left, right, .. })) = leaf {
             if let Expr::Path(syn::ExprPath { path, .. }) = left.as_ref() {
                 if let Some(id) = path.get_ident() {
                     let idstr = id.to_string();
@@ -62,8 +67,11 @@ pub fn transform(
                         ) = expr_type_tokens(&vardict, right)?
                         {
                             // We found a public Scalar equality
-                            // statement.  Add code to both the prover
-                            // and the verifier to check the statement.
+                            // statement.
+                            if in_root_disjunction_branch {
+                            // If we're in the root disjunction branch,
+                            // add code to both the prover and the
+                            // verifier to directly check the statement.
                             codegen.prove_verify_append(quote! {
                                 if #id != #right_tokens {
                                     return Err(SigmaError::VerificationFailure);
@@ -73,14 +81,33 @@ pub fn transform(
                             // Remove the statement from the
                             // [`StatementTree`] by replacing it with
                             // leaf_true (which will be pruned below).
-                            let mut expr: Expr = parse_quote! { true };
-                            std::mem::swap(&mut expr, *leafexpr);
+                            *leaf = StatementTree::leaf_true();
+                            } else {
+                                // If we're not in the root disjunction
+                                // branch, replace the statement
+                                // `left_id = right_side` with the
+                                // statement `left_id*A =
+                                // (right_side)*A` for a cind Point A.
+                                if cind_points.is_empty() {
+                                    return Err(Error::new(
+                                        proc_macro2::Span::call_site(),
+                                        "At least one cind Point must be declared to support public Scalar equality statements inside disjunctions",
+                                    ));
+                                }
+                                let cind_A = &cind_points[0];
+
+                                *leaf = StatementTree::Leaf(parse_quote! {
+                                    #id * #cind_A = (#right) * #cind_A
+                                });
+                            }
                         }
                     }
                 }
             }
         }
-    }
+        Ok(())
+    })
+    })?;
 
     // Now prune the StatementTree
     prune_statement_tree(st);

+ 12 - 3
tests/pubscalars.rs

@@ -5,8 +5,7 @@ use group::Group;
 use sha2::Sha512;
 use sigma_compiler::*;
 
-#[test]
-fn pubscalars_test() -> Result<(), sigma_rs::errors::Error> {
+fn pubscalars_test_val(b_val: u128) -> Result<(), sigma_rs::errors::Error> {
     sigma_compiler! { proof,
         (x, z, rand r, rand s, pub a, pub b),
         (C, D, const cind A, const cind B),
@@ -25,7 +24,7 @@ fn pubscalars_test() -> Result<(), sigma_rs::errors::Error> {
     let x = Scalar::from_u128(5);
     let z = Scalar::from_u128(17);
     let a = Scalar::from_u128(7);
-    let b = Scalar::from_u128(11);
+    let b = Scalar::from_u128(b_val);
     let C = x * A + r * B;
     let D = z * A + s * B;
 
@@ -35,3 +34,13 @@ fn pubscalars_test() -> Result<(), sigma_rs::errors::Error> {
     let proof = proof::prove(&instance, &witness, b"pubscalars_test", &mut rng)?;
     proof::verify(&instance, &proof, b"pubscalars_test")
 }
+
+#[test]
+fn pubscalars_test() {
+    pubscalars_test_val(10u128).unwrap_err();
+    pubscalars_test_val(11u128).unwrap();
+    pubscalars_test_val(12u128).unwrap_err();
+    pubscalars_test_val(13u128).unwrap_err();
+    pubscalars_test_val(14u128).unwrap_err();
+    pubscalars_test_val(15u128).unwrap_err();
+}

+ 49 - 0
tests/pubscalars_or.rs

@@ -0,0 +1,49 @@
+#![allow(non_snake_case)]
+use curve25519_dalek::ristretto::RistrettoPoint as G;
+use group::ff::PrimeField;
+use group::Group;
+use sha2::Sha512;
+use sigma_compiler::*;
+
+fn pubscalars_or_test_val(b_val: u128) -> Result<(), sigma_rs::errors::Error> {
+    sigma_compiler! { proof,
+        (x, z, rand r, rand s, pub a, pub b),
+        (C, D, const cind A, const cind B),
+        C = x*A + r*B,
+        D = z*A + s*B,
+        z = 2*x + a,
+        OR (
+            b = 2*a,
+            b = 2*a - 3,
+        )
+    }
+
+    type Scalar = <G as Group>::Scalar;
+    let mut rng = rand::thread_rng();
+    let A = G::hash_from_bytes::<Sha512>(b"Generator A");
+    let B = G::generator();
+    let r = Scalar::random(&mut rng);
+    let s = Scalar::random(&mut rng);
+    let x = Scalar::from_u128(5);
+    let z = Scalar::from_u128(17);
+    let a = Scalar::from_u128(7);
+    let b = Scalar::from_u128(b_val);
+    let C = x * A + r * B;
+    let D = z * A + s * B;
+
+    let instance = proof::Instance { C, D, A, B, a, b };
+    let witness = proof::Witness { x, z, r, s };
+
+    let proof = proof::prove(&instance, &witness, b"pubscalars_or_test", &mut rng)?;
+    proof::verify(&instance, &proof, b"pubscalars_or_test")
+}
+
+#[test]
+fn pubscalars_or_test() {
+    pubscalars_or_test_val(10u128).unwrap_err();
+    pubscalars_or_test_val(11u128).unwrap();
+    pubscalars_or_test_val(12u128).unwrap_err();
+    pubscalars_or_test_val(13u128).unwrap_err();
+    pubscalars_or_test_val(14u128).unwrap();
+    pubscalars_or_test_val(15u128).unwrap_err();
+}