Browse Source

Add support for THRESH

Ian Goldberg 1 month ago
parent
commit
adf7102ec6
4 changed files with 165 additions and 9 deletions
  1. 8 7
      README.md
  2. 17 2
      sigma-compiler-core/src/sigma/codegen.rs
  3. 71 0
      tests/threshold.rs
  4. 69 0
      tests/threshold_pubscalars.rs

+ 8 - 7
README.md

@@ -154,13 +154,14 @@ The pieces are as follows:
        typical example.  This is a _not-equals statement_, and it
        means that the value of the expression on the left is not
        equal to the value of the expression on the right.
-   - Statements can also be combined with `AND(st1,st2,...,stn)` and
-     `OR(st1,st2,...,stn)`.  The list of statements in the macro
-     invocation are implicitly put into a top-level `AND`.  `AND`s
-     and `OR`s can be arbitrarily nested.  As usual, an `AND`
-     statement is true when all of its component statements are
-     true; an `OR` statement is true when at least one of its
-     component statements is true.
+   - Statements can also be combined with `AND(st1,st2,...,stn)`,
+     `OR(st1,st2,...,stn)`, or `THRESH(t,st1,st2,...,stn)`.  The list of
+     statements in the macro invocation are implicitly put into a
+     top-level `AND`.  `AND`s, `OR`s, and `THRESH`s can be arbitrarily
+     nested.  As usual, an `AND` statement is true when all of its
+     component statements are true; an `OR` statement is true when at
+     least one of its component statements is true; a `THRESH` statement
+     is true when at least `t` of its component statements are true.
    
 
 The macro creates a submodule with the name specified by

+ 17 - 2
sigma-compiler-core/src/sigma/codegen.rs

@@ -548,8 +548,23 @@ impl<'a> CodeGen<'a> {
                     },
                 )
             }
-            StatementTree::Thresh(_thresh, _stvec) => {
-                todo! {"Thresh not yet implemented"};
+            StatementTree::Thresh(thresh, stvec) => {
+                let (proto, witness): (Vec<TokenStream>, Vec<TokenStream>) = stvec
+                    .iter()
+                    .map(|st| self.proto_witness_codegen(st))
+                    .unzip();
+                (
+                    quote! {
+                        SigmaOk(ComposedRelation::threshold(#thresh, [
+                            #(#proto?,)*
+                        ]))
+                    },
+                    quote! {
+                        SigmaOk(ComposedWitness::threshold([
+                            #(#witness?,)*
+                        ]))
+                    },
+                )
             }
         }
     }

+ 71 - 0
tests/threshold.rs

@@ -0,0 +1,71 @@
+#![allow(non_snake_case)]
+use curve25519_dalek::ristretto::RistrettoPoint as G;
+use group::Group;
+use sha2::Sha512;
+use sigma_compiler::*;
+
+#[test]
+fn threshold_test() -> sigma_proofs::errors::Result<()> {
+    sigma_compiler! { thresh3,
+        (x1, x2, x3, x4, x5, rand r),
+        (C, const cind G0, const cind G1, const cind G2, const cind G3,
+            const cind G4, const cind G5),
+        C = r*G0 + x1*G1 + x2*G2 + x3*G3 + x4*G4 + x5*G5,
+        THRESH ( 3, x1 = 1, x2 = 2, x3 = 3, x4 = 4, x5 = 5 )
+    }
+
+    type Scalar = <G as Group>::Scalar;
+    let mut rng = rand::thread_rng();
+    let G0 = G::generator();
+    let G1 = G::hash_from_bytes::<Sha512>(b"Generator G1");
+    let G2 = G::hash_from_bytes::<Sha512>(b"Generator G2");
+    let G3 = G::hash_from_bytes::<Sha512>(b"Generator G3");
+    let G4 = G::hash_from_bytes::<Sha512>(b"Generator G4");
+    let G5 = G::hash_from_bytes::<Sha512>(b"Generator G5");
+    let r = Scalar::random(&mut rng);
+    let y = Scalar::random(&mut rng);
+
+    // Iterate over all combinations of 5 bits
+    for true_pattern in 0u32..32 {
+        let x1 = Scalar::from(if true_pattern & 1 == 0 { 2u32 } else { 1u32 });
+        let x2 = Scalar::from(if true_pattern & 2 == 0 { 3u32 } else { 2u32 });
+        let x3 = Scalar::from(if true_pattern & 4 == 0 { 4u32 } else { 3u32 });
+        let x4 = Scalar::from(if true_pattern & 8 == 0 { 5u32 } else { 4u32 });
+        let x5 = Scalar::from(if true_pattern & 16 == 0 { 6u32 } else { 5u32 });
+        let C = r * G0 + x1 * G1 + x2 * G2 + x3 * G3 + x4 * G4 + x5 * G5;
+
+        let num_true = true_pattern.count_ones();
+
+        let instance = thresh3::Instance {
+            C,
+            G0,
+            G1,
+            G2,
+            G3,
+            G4,
+            G5,
+        };
+        let witness = thresh3::Witness {
+            x1,
+            x2,
+            x3,
+            x4,
+            x5,
+            r,
+        };
+
+        match thresh3::prove(&instance, &witness, b"thresh_test", &mut rng) {
+            Ok(_) if num_true < 3 => {
+                panic!("THRESH passed when it should have failed (true_pattern = {true_pattern})")
+            }
+            Err(_) if num_true >= 3 => {
+                panic!("THRESH failed when it should have passed (true_pattern = {true_pattern})")
+            }
+            Ok(proof) => {
+                thresh3::verify(&instance, &proof, b"thresh_test")?;
+            }
+            Err(_) => {}
+        }
+    }
+    Ok(())
+}

+ 69 - 0
tests/threshold_pubscalars.rs

@@ -0,0 +1,69 @@
+#![allow(non_snake_case)]
+use curve25519_dalek::ristretto::RistrettoPoint as G;
+use group::Group;
+use sha2::Sha512;
+use sigma_compiler::*;
+
+#[test]
+fn threshold_pubscalars_test() -> sigma_proofs::errors::Result<()> {
+    sigma_compiler! { thresh3,
+        (pub x1, pub x2, pub x3, pub x4, pub x5, rand r),
+        (C, const cind G0, const cind G1, const cind G2, const cind G3,
+            const cind G4, const cind G5),
+        C = r*G0 + x1*G1 + x2*G2 + x3*G3 + x4*G4 + x5*G5,
+        THRESH ( 3, x1 = 1, x2 = 2, x3 = 3, x4 = 4, x5 = 5 )
+    }
+
+    type Scalar = <G as Group>::Scalar;
+    let mut rng = rand::thread_rng();
+    let G0 = G::generator();
+    let G1 = G::hash_from_bytes::<Sha512>(b"Generator G1");
+    let G2 = G::hash_from_bytes::<Sha512>(b"Generator G2");
+    let G3 = G::hash_from_bytes::<Sha512>(b"Generator G3");
+    let G4 = G::hash_from_bytes::<Sha512>(b"Generator G4");
+    let G5 = G::hash_from_bytes::<Sha512>(b"Generator G5");
+    let r = Scalar::random(&mut rng);
+    let y = Scalar::random(&mut rng);
+
+    // Iterate over all combinations of 5 bits
+    for true_pattern in 0u32..32 {
+        let x1 = Scalar::from(if true_pattern & 1 == 0 { 2u32 } else { 1u32 });
+        let x2 = Scalar::from(if true_pattern & 2 == 0 { 3u32 } else { 2u32 });
+        let x3 = Scalar::from(if true_pattern & 4 == 0 { 4u32 } else { 3u32 });
+        let x4 = Scalar::from(if true_pattern & 8 == 0 { 5u32 } else { 4u32 });
+        let x5 = Scalar::from(if true_pattern & 16 == 0 { 6u32 } else { 5u32 });
+        let C = r * G0 + x1 * G1 + x2 * G2 + x3 * G3 + x4 * G4 + x5 * G5;
+
+        let num_true = true_pattern.count_ones();
+
+        let instance = thresh3::Instance {
+            C,
+            G0,
+            G1,
+            G2,
+            G3,
+            G4,
+            G5,
+            x1,
+            x2,
+            x3,
+            x4,
+            x5,
+        };
+        let witness = thresh3::Witness { r };
+
+        match thresh3::prove(&instance, &witness, b"thresh_pubscalars_test", &mut rng) {
+            Ok(_) if num_true < 3 => {
+                panic!("THRESH passed when it should have failed (true_pattern = {true_pattern})")
+            }
+            Err(_) if num_true >= 3 => {
+                panic!("THRESH failed when it should have passed (true_pattern = {true_pattern})")
+            }
+            Ok(proof) => {
+                thresh3::verify(&instance, &proof, b"thresh_pubscalars_test")?;
+            }
+            Err(_) => {}
+        }
+    }
+    Ok(())
+}