Browse Source

Allow statements that are equality of a public Scalar variable to a public Scalar expression.

Vector variables are not currently supported in this kind of statement.
Ian Goldberg 4 months ago
parent
commit
4641a3cabc

+ 4 - 0
sigma_compiler_core/src/lib.rs

@@ -10,6 +10,7 @@ pub mod sigma {
 }
 mod codegen;
 mod pedersen;
+mod pubscalareq;
 mod rangeproof;
 mod substitution;
 mod syntax;
@@ -42,5 +43,8 @@ pub fn sigma_compiler_core(
     // Apply any range statement transformations
     rangeproof::transform(&mut codegen, &mut spec.statements, &mut spec.vars).unwrap();
 
+    // Apply any public scalar equality transformations
+    pubscalareq::transform(&mut codegen, &mut spec.statements, &mut spec.vars).unwrap();
+
     codegen.generate(spec, emit_prover, emit_verifier)
 }

+ 85 - 0
sigma_compiler_core/src/pubscalareq.rs

@@ -0,0 +1,85 @@
+//! A module to look for, and apply, any statement involving the
+//! equality of _public_ `Scalar`s.
+//!
+//! Such a statement is of the form `a = 2*(c+1)` where `a` and `c` are
+//! public `Scalar`s.  That is, it is a single variable name (which must
+//! be a public `Scalar`, as specified in the provided
+//! [`TaggedVarDict`]), an equal sign, and an [arithmetic expression]
+//! involving other public `Scalar` variables, constants, parens, and
+//! the operators `+`, `-`, and `*`.
+//!
+//! The statement is simply removed from the list of statements to be
+//! proven in the zero-knowledge sigma protocol, and code is emitted for
+//! the prover and verifier to each just check that the statement is
+//! satisfied.
+//!
+//! [arithmetic expression]: super::sigma::types::expr_type
+
+use super::codegen::CodeGen;
+use super::sigma::combiners::*;
+use super::sigma::types::{expr_type_tokens, AExprType};
+use super::syntax::taggedvardict_to_vardict;
+use super::transform::prune_statement_tree;
+use super::{TaggedIdent, TaggedScalar, TaggedVarDict};
+use quote::quote;
+use syn::{parse_quote, Expr, Result};
+
+/// Look for, and apply, all of the public scalar equality statements
+/// specified in leaves of the [`StatementTree`].
+pub fn transform(
+    codegen: &mut CodeGen,
+    st: &mut StatementTree,
+    vars: &mut TaggedVarDict,
+) -> Result<()> {
+    // Construct the VarDict corresponding to vars
+    let vardict = taggedvardict_to_vardict(vars);
+
+    // Gather mutable references to all Exprs in the leaves of the
+    // StatementTree.  Note that this ignores the combiner structure in
+    // the StatementTree, but that's fine.
+    let mut leaves = st.leaves_mut();
+
+    // For each leaf expression, see if it looks like a public Scalar
+    // equality statement
+    for leafexpr in leaves.iter_mut() {
+        if let Expr::Assign(syn::ExprAssign { left, right, .. }) = *leafexpr {
+            if let Expr::Path(syn::ExprPath { path, .. }) = left.as_ref() {
+                if let Some(id) = path.get_ident() {
+                    let idstr = id.to_string();
+                    if let Some(TaggedIdent::Scalar(TaggedScalar { is_pub: true, .. })) =
+                        vars.get(&idstr)
+                    {
+                        if let (AExprType::Scalar { is_pub: true, .. }, right_tokens) =
+                            expr_type_tokens(&vardict, right)?
+                        {
+                            // We found a public Scalar equality
+                            // statement.  Add code to both the prover
+                            // and the verifier to check the statement.
+                            codegen.prove_append(quote! {
+                                if #id != #right_tokens {
+                                    return Err(SigmaError::VerificationFailure);
+                                }
+                            });
+                            codegen.verify_append(quote! {
+                                if #id != #right_tokens {
+                                    return Err(SigmaError::VerificationFailure);
+                                }
+                            });
+
+                            // Remove the statement from the
+                            // [`StatementTree`] by replacing it with
+                            // leaf_true (which will be pruned below).
+                            let mut expr: Expr = parse_quote! { true };
+                            std::mem::swap(&mut expr, *leafexpr);
+                        }
+                    }
+                }
+            }
+        }
+    }
+
+    // Now prune the StatementTree
+    prune_statement_tree(st);
+
+    Ok(())
+}

+ 10 - 0
sigma_compiler_derive/src/lib.rs

@@ -112,6 +112,16 @@ use syn::parse_macro_input;
 ///        directly, or after other substitutions.  For example, the
 ///        statement `a = a + b` is not allowed, nor is the combination
 ///        of substitutions `a = b + 1, b = c + 2, c = 2*a`.
+///      - `a = arith_expr`, where `a` is a variable representing a
+///        public `Scalar`.  This is a _public Scalar equality
+///        statement_.  Its meaning is to say that the public `Scalar`
+///        `a` has the value given by the arithmetic expression, which
+///        must evaluate to a public `Scalar`.  The statement is simply
+///        removed from the list of statements to be proven in the
+///        zero-knowledge sigma protocol, and code is emitted for the
+///        prover and verifier to each just check that the statement is
+///        satisfied.  Currently, there can be no vector variables in
+///        this kind of statement.
 ///      - `(a..b).contains(x)`, where `a` and `b` are _public_
 ///        `Scalar`s (or arithmetic expressions evaluating to public
 ///        `Scalar`s), and `x` is a private `Scalar`, possibly

+ 7 - 7
tests/pubscalars.rs

@@ -8,13 +8,13 @@ use sigma_compiler::*;
 #[test]
 fn pubscalars_test() -> Result<(), sigma_rs::errors::Error> {
     sigma_compiler! { proof,
-            (x, z, rand r, rand s, pub a, pub b),
-            (C, D, const cind A, const cind B),
-            C = x*A + r*B,
-            D = z*A + s*B,
-            z = 2*x + a,
-    //        b = 2*a - 3,
-        }
+        (x, z, rand r, rand s, pub a, pub b),
+        (C, D, const cind A, const cind B),
+        C = x*A + r*B,
+        D = z*A + s*B,
+        z = 2*x + a,
+        b = 2*a - 3,
+    }
 
     type Scalar = <G as Group>::Scalar;
     let mut rng = rand::thread_rng();