Procházet zdrojové kódy

In a disjunction branch, convert statements about equality of public Scalars to statements about equality of public Points by multiplying each side by a generator

When we're not in a disjunction, we just emit code to directly check the
equality of the Scalars, as before.  But inside, the equality may not be
true.

So inside a disjunction, change statements like `b = 2*a - 3` into
`b*A = (2*a - 3)*A`, where `A` is one of the declared cind Points.

This commit is not formatted so as to minimize the diff.  The next
commit will run "cargo fmt" with no code changes.
Ian Goldberg před 6 měsíci
rodič
revize
c4c86e28d1
3 změnil soubory, kde provedl 101 přidání a 16 odebrání
  1. 40 13
      sigma_compiler_core/src/pubscalareq.rs
  2. 12 3
      tests/pubscalars.rs
  3. 49 0
      tests/pubscalars_or.rs

+ 40 - 13
sigma_compiler_core/src/pubscalareq.rs

@@ -18,14 +18,15 @@
 use super::codegen::CodeGen;
 use super::sigma::combiners::*;
 use super::sigma::types::{expr_type_tokens, AExprType};
-use super::syntax::taggedvardict_to_vardict;
+use super::syntax::{collect_cind_points, taggedvardict_to_vardict};
 use super::transform::prune_statement_tree;
 use super::{TaggedIdent, TaggedScalar, TaggedVarDict};
 use quote::quote;
-use syn::{parse_quote, Expr, Result};
+use syn::{parse_quote, Error, Expr, Result};
 
 /// Look for, and apply, all of the public scalar equality statements
 /// specified in leaves of the [`StatementTree`].
+#[allow(non_snake_case)] // so that Points can be capital letters
 pub fn transform(
     codegen: &mut CodeGen,
     st: &mut StatementTree,
@@ -34,15 +35,19 @@ pub fn transform(
     // Construct the VarDict corresponding to vars
     let vardict = taggedvardict_to_vardict(vars);
 
-    // Gather mutable references to all Exprs in the leaves of the
-    // StatementTree.  Note that this ignores the combiner structure in
-    // the StatementTree, but that's fine.
-    let mut leaves = st.leaves_mut();
+    // A list of the computationally independent (non-vector) Points in
+    // the macro input.  There must be at least one of them in order to
+    // handle public scalar equality statements inside disjunctions.
+    let cind_points = collect_cind_points(vars);
+
+    st.for_each_disjunction_branch(&mut |branch, path| {
+        // Are we in the root disjunction branch?  (path is empty)
+        let in_root_disjunction_branch = path.is_empty();
 
     // For each leaf expression, see if it looks like a public Scalar
     // equality statement
-    for leafexpr in leaves.iter_mut() {
-        if let Expr::Assign(syn::ExprAssign { left, right, .. }) = *leafexpr {
+        branch.for_each_disjunction_branch_leaf(&mut |leaf| {
+        if let StatementTree::Leaf(Expr::Assign(syn::ExprAssign { left, right, .. })) = leaf {
             if let Expr::Path(syn::ExprPath { path, .. }) = left.as_ref() {
                 if let Some(id) = path.get_ident() {
                     let idstr = id.to_string();
@@ -62,8 +67,11 @@ pub fn transform(
                         ) = expr_type_tokens(&vardict, right)?
                         {
                             // We found a public Scalar equality
-                            // statement.  Add code to both the prover
-                            // and the verifier to check the statement.
+                            // statement.
+                            if in_root_disjunction_branch {
+                            // If we're in the root disjunction branch,
+                            // add code to both the prover and the
+                            // verifier to directly check the statement.
                             codegen.prove_verify_append(quote! {
                                 if #id != #right_tokens {
                                     return Err(SigmaError::VerificationFailure);
@@ -73,14 +81,33 @@ pub fn transform(
                             // Remove the statement from the
                             // [`StatementTree`] by replacing it with
                             // leaf_true (which will be pruned below).
-                            let mut expr: Expr = parse_quote! { true };
-                            std::mem::swap(&mut expr, *leafexpr);
+                            *leaf = StatementTree::leaf_true();
+                            } else {
+                                // If we're not in the root disjunction
+                                // branch, replace the statement
+                                // `left_id = right_side` with the
+                                // statement `left_id*A =
+                                // (right_side)*A` for a cind Point A.
+                                if cind_points.is_empty() {
+                                    return Err(Error::new(
+                                        proc_macro2::Span::call_site(),
+                                        "At least one cind Point must be declared to support public Scalar equality statements inside disjunctions",
+                                    ));
+                                }
+                                let cind_A = &cind_points[0];
+
+                                *leaf = StatementTree::Leaf(parse_quote! {
+                                    #id * #cind_A = (#right) * #cind_A
+                                });
+                            }
                         }
                     }
                 }
             }
         }
-    }
+        Ok(())
+    })
+    })?;
 
     // Now prune the StatementTree
     prune_statement_tree(st);

+ 12 - 3
tests/pubscalars.rs

@@ -5,8 +5,7 @@ use group::Group;
 use sha2::Sha512;
 use sigma_compiler::*;
 
-#[test]
-fn pubscalars_test() -> Result<(), sigma_rs::errors::Error> {
+fn pubscalars_test_val(b_val: u128) -> Result<(), sigma_rs::errors::Error> {
     sigma_compiler! { proof,
         (x, z, rand r, rand s, pub a, pub b),
         (C, D, const cind A, const cind B),
@@ -25,7 +24,7 @@ fn pubscalars_test() -> Result<(), sigma_rs::errors::Error> {
     let x = Scalar::from_u128(5);
     let z = Scalar::from_u128(17);
     let a = Scalar::from_u128(7);
-    let b = Scalar::from_u128(11);
+    let b = Scalar::from_u128(b_val);
     let C = x * A + r * B;
     let D = z * A + s * B;
 
@@ -35,3 +34,13 @@ fn pubscalars_test() -> Result<(), sigma_rs::errors::Error> {
     let proof = proof::prove(&instance, &witness, b"pubscalars_test", &mut rng)?;
     proof::verify(&instance, &proof, b"pubscalars_test")
 }
+
+#[test]
+fn pubscalars_test() {
+    pubscalars_test_val(10u128).unwrap_err();
+    pubscalars_test_val(11u128).unwrap();
+    pubscalars_test_val(12u128).unwrap_err();
+    pubscalars_test_val(13u128).unwrap_err();
+    pubscalars_test_val(14u128).unwrap_err();
+    pubscalars_test_val(15u128).unwrap_err();
+}

+ 49 - 0
tests/pubscalars_or.rs

@@ -0,0 +1,49 @@
+#![allow(non_snake_case)]
+use curve25519_dalek::ristretto::RistrettoPoint as G;
+use group::ff::PrimeField;
+use group::Group;
+use sha2::Sha512;
+use sigma_compiler::*;
+
+fn pubscalars_or_test_val(b_val: u128) -> Result<(), sigma_rs::errors::Error> {
+    sigma_compiler! { proof,
+        (x, z, rand r, rand s, pub a, pub b),
+        (C, D, const cind A, const cind B),
+        C = x*A + r*B,
+        D = z*A + s*B,
+        z = 2*x + a,
+        OR (
+            b = 2*a,
+            b = 2*a - 3,
+        )
+    }
+
+    type Scalar = <G as Group>::Scalar;
+    let mut rng = rand::thread_rng();
+    let A = G::hash_from_bytes::<Sha512>(b"Generator A");
+    let B = G::generator();
+    let r = Scalar::random(&mut rng);
+    let s = Scalar::random(&mut rng);
+    let x = Scalar::from_u128(5);
+    let z = Scalar::from_u128(17);
+    let a = Scalar::from_u128(7);
+    let b = Scalar::from_u128(b_val);
+    let C = x * A + r * B;
+    let D = z * A + s * B;
+
+    let instance = proof::Instance { C, D, A, B, a, b };
+    let witness = proof::Witness { x, z, r, s };
+
+    let proof = proof::prove(&instance, &witness, b"pubscalars_or_test", &mut rng)?;
+    proof::verify(&instance, &proof, b"pubscalars_or_test")
+}
+
+#[test]
+fn pubscalars_or_test() {
+    pubscalars_or_test_val(10u128).unwrap_err();
+    pubscalars_or_test_val(11u128).unwrap();
+    pubscalars_or_test_val(12u128).unwrap_err();
+    pubscalars_or_test_val(13u128).unwrap_err();
+    pubscalars_or_test_val(14u128).unwrap();
+    pubscalars_or_test_val(15u128).unwrap_err();
+}