Browse Source

Allow the left side of linear combination statements to be expressions evaluating to a public Point

Previously, they had to be exactly a single Point variable
Ian Goldberg 3 months ago
parent
commit
26db0ee8a2
3 changed files with 129 additions and 27 deletions
  1. 41 27
      sigma_compiler_core/src/sigma/codegen.rs
  2. 6 0
      sigma_compiler_derive/src/lib.rs
  3. 82 0
      tests/left_expr.rs

+ 41 - 27
sigma_compiler_core/src/sigma/codegen.rs

@@ -204,39 +204,53 @@ impl<'a> CodeGen<'a> {
             let eq_id = format_ident!("{}eq{}", self.unique_prefix, i + 1);
             let vec_index_var = format_ident!("{}i", self.unique_prefix);
             let vec_len_var = format_ident!("{}veclen{}", self.unique_prefix, i + 1);
+
+            // Record any vector variables we encountered in this
+            // expression
+            let mut vec_param_vars: HashSet<Ident> = HashSet::new();
+            let mut vec_witness_vars: HashSet<Ident> = HashSet::new();
+
             // Ensure the `Expr` is of a type we recognize.  In
-            // particular, it must be an assignment (C = something)
-            // where the variable on the left is a public Point, and the
-            // something on the right is an arithmetic expression that
-            // evaluates to a private Point.  It is allowed for neither
-            // or both Points to be vector variables.
+            // particular, it must be an assignment (left = right) where
+            // the expression on the left is an arithmetic expression
+            // that evaluates to a public Point, and the expression on
+            // the right is an arithmetic expression that evaluates to a
+            // Point.  It is allowed for neither or both Points to be
+            // vector variables.
             let Expr::Assign(syn::ExprAssign { left, right, .. }) = expr else {
                 let expr_str = quote! { #expr }.to_string();
                 panic!("Unrecognized expression: {expr_str}");
             };
-            let Expr::Path(syn::ExprPath { path, .. }) = left.as_ref() else {
-                let expr_str = quote! { #expr }.to_string();
-                panic!("Left side of = is not a variable: {expr_str}");
-            };
-            let Some(left_id) = path.get_ident() else {
-                let expr_str = quote! { #expr }.to_string();
-                panic!("Left side of = is not a variable: {expr_str}");
-            };
-            let Some(AExprType::Point {
-                is_vec: left_is_vec,
+            let (left_type, left_tokens) =
+                expr_type_tokens_id_closure(self.vars, left, &mut |id, id_type| match id_type {
+                    AExprType::Scalar { is_pub: false, .. } => {
+                        panic!("Left side of = contains a private Scalar");
+                    }
+                    AExprType::Scalar {
+                        is_vec: false,
+                        is_pub: true,
+                        ..
+                    }
+                    | AExprType::Point { is_vec: false, .. } => Ok(quote! {#instance_var.#id}),
+                    AExprType::Scalar {
+                        is_vec: true,
+                        is_pub: true,
+                        ..
+                    }
+                    | AExprType::Point { is_vec: true, .. } => {
+                        vec_param_vars.insert(id.clone());
+                        Ok(quote! {#instance_var.#id[#vec_index_var]})
+                    }
+                })
+                .unwrap();
+            let AExprType::Point {
                 is_pub: true,
-            }) = self.vars.get(&left_id.to_string())
+                is_vec: left_is_vec,
+            } = left_type
             else {
                 let expr_str = quote! { #expr }.to_string();
-                panic!("Left side of = is not a public point: {expr_str}");
+                panic!("Left side of = does not evaluate to a public point: {expr_str}");
             };
-            // Record any vector variables we encountered in this
-            // expression
-            let mut vec_param_vars: HashSet<Ident> = HashSet::new();
-            let mut vec_witness_vars: HashSet<Ident> = HashSet::new();
-            if *left_is_vec {
-                vec_param_vars.insert(left_id.clone());
-            }
             let Ok((right_type, right_tokens)) =
                 expr_type_tokens_id_closure(self.vars, right, &mut |id, id_type| match id_type {
                     AExprType::Scalar {
@@ -336,7 +350,7 @@ impl<'a> CodeGen<'a> {
                 let expr_str = quote! { #expr }.to_string();
                 panic!("Right side of = does not evaluate to a Point: {expr_str}");
             };
-            if *left_is_vec != right_is_vec {
+            if left_is_vec != right_is_vec {
                 let expr_str = quote! { #expr }.to_string();
                 panic!("Only one side of = is a vector expression: {expr_str}");
             }
@@ -394,7 +408,7 @@ impl<'a> CodeGen<'a> {
                     for #vec_index_var in 0..#vec_len_var {
                         #lr_var.set_element(
                             #eq_id[#vec_index_var],
-                            #instance_var.#left_id[#vec_index_var],
+                            #left_tokens,
                         );
                     }
                 };
@@ -405,7 +419,7 @@ impl<'a> CodeGen<'a> {
                 };
                 element_assigns = quote! {
                     #element_assigns
-                    #lr_var.set_element(#eq_id, #instance_var.#left_id);
+                    #lr_var.set_element(#eq_id, #left_tokens);
                 }
             }
         }

+ 6 - 0
sigma_compiler_derive/src/lib.rs

@@ -101,6 +101,12 @@ use syn::parse_macro_input;
 ///        hold component-wise.  Any non-vector variable in the
 ///        statement is considered equivalent to a vector variable, all
 ///        of whose entries have the same value.
+///
+///        As an extension, you can also use an arithmetic expression
+///        evaluating to a _public_ `Point` in place of `C` on the left
+///        side of the `=`.  For example, if `a` is a `Scalar` tagged
+///        `pub`, and `C` is a `Point`, then the expression `(2*a+1)*C =
+///        arith_expr` is a valid linear combination statement.
 ///      - `a = arith_expr`, where `a` is a variable representing a
 ///        private `Scalar`.  This is a _substitution statement_.  Its
 ///        meaning is to say that the private `Scalar` `a` has the value

+ 82 - 0
tests/left_expr.rs

@@ -0,0 +1,82 @@
+#![allow(non_snake_case)]
+use curve25519_dalek::ristretto::RistrettoPoint as G;
+use group::ff::PrimeField;
+use group::Group;
+use sha2::Sha512;
+use sigma_compiler::*;
+
+#[test]
+fn left_expr_test() -> Result<(), sigma_rs::errors::Error> {
+    sigma_compiler! { proof,
+        (x, y, pub a, rand r, rand s),
+        (C, D, const cind A, const cind B),
+        D = y*A + s*B,
+        (2*a-1)*C + D = x*A + r*B,
+    }
+
+    type Scalar = <G as Group>::Scalar;
+    let mut rng = rand::thread_rng();
+    let A = G::hash_from_bytes::<Sha512>(b"Generator A");
+    let B = G::generator();
+    let r = Scalar::random(&mut rng);
+    let s = Scalar::random(&mut rng);
+    let a = Scalar::from_u128(9);
+    let x = Scalar::from_u128(5);
+    let y = Scalar::from_u128(12);
+    let D = y * A + s * B;
+    let C = (x * A + r * B - D) * (a + a - Scalar::ONE).invert();
+
+    let instance = proof::Instance { C, D, A, B, a };
+    let witness = proof::Witness { x, y, r, s };
+
+    let proof = proof::prove(&instance, &witness, b"left_expr_test", &mut rng)?;
+    proof::verify(&instance, &proof, b"left_expr_test")
+}
+
+#[test]
+fn left_expr_vec_test() -> Result<(), sigma_rs::errors::Error> {
+    sigma_compiler! { proof,
+        (vec x, vec y, z, pub vec a, pub b, rand vec r, rand vec s, rand t),
+        (vec C, vec D, E, const cind A, const cind B),
+        E = z*A + t*B,
+        D = y*A + s*B,
+        (2*a-1)*C + b*D + E = x*A + r*B,
+    }
+
+    type Scalar = <G as Group>::Scalar;
+    let mut rng = rand::thread_rng();
+    let A = G::hash_from_bytes::<Sha512>(b"Generator A");
+    let B = G::generator();
+    let vlen = 5usize;
+    let r: Vec<Scalar> = (0..vlen).map(|_| Scalar::random(&mut rng)).collect();
+    let s: Vec<Scalar> = (0..vlen).map(|_| Scalar::random(&mut rng)).collect();
+    let t = Scalar::random(&mut rng);
+    let a: Vec<Scalar> = (0..vlen).map(|i| Scalar::from_u128(i as u128)).collect();
+    let b = Scalar::from_u128(17);
+    let x: Vec<Scalar> = (0..vlen)
+        .map(|i| Scalar::from_u128((2 * i) as u128))
+        .collect();
+    let y: Vec<Scalar> = (0..vlen)
+        .map(|i| Scalar::from_u128((3 * i) as u128))
+        .collect();
+    let z = Scalar::from_u128(12);
+    let E = z * A + t * B;
+    let D: Vec<G> = (0..vlen).map(|i| y[i] * A + s[i] * B).collect();
+    let C: Vec<G> = (0..vlen)
+        .map(|i| (x[i] * A + r[i] * B - b * D[i] - E) * (a[i] + a[i] - Scalar::ONE).invert())
+        .collect();
+
+    let instance = proof::Instance {
+        C,
+        D,
+        E,
+        A,
+        B,
+        a,
+        b,
+    };
+    let witness = proof::Witness { x, y, z, r, s, t };
+
+    let proof = proof::prove(&instance, &witness, b"left_expr_vec_test", &mut rng)?;
+    proof::verify(&instance, &proof, b"left_expr_vec_test")
+}