Browse Source

Code generation to use the sigma-rs API

Ian Goldberg 4 months ago
parent
commit
dcf9c81a49
1 changed files with 334 additions and 20 deletions
  1. 334 20
      sigma_compiler_core/src/sigma/codegen.rs

+ 334 - 20
sigma_compiler_core/src/sigma/codegen.rs

@@ -4,10 +4,11 @@
 //! directly.
 
 use super::combiners::StatementTree;
-use super::types::{AExprType, VarDict};
+use super::types::{expr_type_tokens_id_closure, AExprType, VarDict};
 use proc_macro2::TokenStream;
 use quote::{format_ident, quote, ToTokens};
-use syn::Ident;
+use std::collections::HashSet;
+use syn::{Expr, Ident};
 
 /// Names and types of fields that might end up in a generated struct
 pub enum StructField {
@@ -183,23 +184,345 @@ impl<'a> CodeGen<'a> {
 
     /// Generate the code for the `protocol` and `protocol_witness`
     /// functions that create the `Protocol` and `ProtocolWitness`
-    /// structs, respectively, given a [`VarDict`] and a
-    /// [`StatementTree`] describing the statements to be proven.  The
-    /// output components are the code for the `protocol` and
-    /// `protocol_witness` functions, respectively.  The `protocol` code
+    /// structs, respectively, given a slice of [`Expr`]s that will be
+    /// bundled into a single `LinearRelation`.  The `protocol` code
     /// must evaluate to a `Result<Protocol>` and the `protocol_witness`
     /// code must evaluate to a `Result<ProtocolWitness>`.
-    fn proto_witness_codegen(&self, statement: &StatementTree) -> (TokenStream, TokenStream) {
+    fn linear_relation_codegen(&self, exprs: &[&Expr]) -> (TokenStream, TokenStream) {
+        let params_var = format_ident!("{}params", self.unique_prefix);
+        let lr_var = format_ident!("{}lr", self.unique_prefix);
+        let mut allocated_vars: HashSet<Ident> = HashSet::new();
+        let mut param_vec_code = quote! {};
+        let mut witness_vec_code = quote! {};
+        let mut witness_code = quote! {};
+        let mut scalar_allocs = quote! {};
+        let mut element_allocs = quote! {};
+        let mut eq_code = quote! {};
+        let mut element_assigns = quote! {};
+
+        for (i, expr) in exprs.iter().enumerate() {
+            let eq_id = format_ident!("{}eq{}", self.unique_prefix, i + 1);
+            let vec_index_var = format_ident!("{}i", self.unique_prefix);
+            let vec_len_var = format_ident!("{}veclen{}", self.unique_prefix, i + 1);
+            // Ensure the `Expr` is of a type we recognize.  In
+            // particular, it must be an assignment (C = something)
+            // where the variable on the left is a public Point, and the
+            // something on the right is an arithmetic expression that
+            // evaluates to a private Point.  It is allowed for neither
+            // or both Points to be vector variables.
+            let Expr::Assign(syn::ExprAssign { left, right, .. }) = expr else {
+                let expr_str = quote! { #expr }.to_string();
+                panic!("Unrecognized expression: {expr_str}");
+            };
+            let Expr::Path(syn::ExprPath { path, .. }) = left.as_ref() else {
+                let expr_str = quote! { #expr }.to_string();
+                panic!("Left side of = is not a variable: {expr_str}");
+            };
+            let Some(left_id) = path.get_ident() else {
+                let expr_str = quote! { #expr }.to_string();
+                panic!("Left side of = is not a variable: {expr_str}");
+            };
+            let Some(AExprType::Point {
+                is_vec: left_is_vec,
+                is_pub: true,
+            }) = self.vars.get(&left_id.to_string())
+            else {
+                let expr_str = quote! { #expr }.to_string();
+                panic!("Left side of = is not a public point: {expr_str}");
+            };
+            // Record any vector variables we encountered in this
+            // expression
+            let mut vec_param_vars: HashSet<Ident> = HashSet::new();
+            let mut vec_witness_vars: HashSet<Ident> = HashSet::new();
+            if *left_is_vec {
+                vec_param_vars.insert(left_id.clone());
+            }
+            let Ok((right_type, right_tokens)) =
+                expr_type_tokens_id_closure(self.vars, right, &mut |id, id_type| match id_type {
+                    AExprType::Scalar {
+                        is_vec: false,
+                        is_pub: false,
+                        ..
+                    } => {
+                        if allocated_vars.insert(id.clone()) {
+                            scalar_allocs = quote! {
+                                #scalar_allocs
+                                let #id = #lr_var.allocate_scalar();
+                            };
+                            witness_code = quote! {
+                                #witness_code
+                                witnessvec.push(witness.#id);
+                            };
+                        }
+                        Ok(quote! {#id})
+                    }
+                    AExprType::Scalar {
+                        is_vec: false,
+                        is_pub: true,
+                        ..
+                    } => Ok(quote! {#params_var.#id}),
+                    AExprType::Scalar {
+                        is_vec: true,
+                        is_pub: false,
+                        ..
+                    } => {
+                        vec_witness_vars.insert(id.clone());
+                        if allocated_vars.insert(id.clone()) {
+                            scalar_allocs = quote! {
+                                #scalar_allocs
+                                let #id = (0..#vec_len_var)
+                                    .map(|i| #lr_var.allocate_scalar())
+                                    .collect::<Vec<_>>();
+                            };
+                            witness_code = quote! {
+                                #witness_code
+                                witnessvec.extend(witness.#id.clone());
+                            };
+                        }
+                        Ok(quote! {#id[#vec_index_var]})
+                    }
+                    AExprType::Scalar {
+                        is_vec: true,
+                        is_pub: true,
+                        ..
+                    } => {
+                        vec_param_vars.insert(id.clone());
+                        Ok(quote! {#params_var.#id[#vec_index_var]})
+                    }
+                    AExprType::Point { is_vec: false, .. } => {
+                        if allocated_vars.insert(id.clone()) {
+                            element_allocs = quote! {
+                                #element_allocs
+                                let #id = #lr_var.allocate_element();
+                            };
+                            element_assigns = quote! {
+                                #element_assigns
+                                #lr_var.set_element(#id, #params_var.#id);
+                            };
+                        }
+                        Ok(quote! {#id})
+                    }
+                    AExprType::Point { is_vec: true, .. } => {
+                        vec_param_vars.insert(id.clone());
+                        if allocated_vars.insert(id.clone()) {
+                            element_allocs = quote! {
+                                #element_allocs
+                                let #id = (0..#vec_len_var)
+                                    .map(|#vec_index_var| #lr_var.allocate_element())
+                                    .collect::<Vec<_>>();
+                            };
+                            element_assigns = quote! {
+                                #element_assigns
+                                for #vec_index_var in 0..#vec_len_var {
+                                    #lr_var.set_element(
+                                        #id[#vec_index_var],
+                                        #params_var.#id[#vec_index_var],
+                                    );
+                                }
+                            };
+                        }
+                        Ok(quote! {#id[#vec_index_var]})
+                    }
+                })
+            else {
+                let expr_str = quote! { #expr }.to_string();
+                panic!("Right side of = is not a valid arithmetic expression: {expr_str}");
+            };
+            let AExprType::Point {
+                is_vec: right_is_vec,
+                is_pub: false,
+            } = right_type
+            else {
+                let expr_str = quote! { #expr }.to_string();
+                panic!("Right side of = does not evaluate to a private Point: {expr_str}");
+            };
+            if *left_is_vec != right_is_vec {
+                let expr_str = quote! { #expr }.to_string();
+                panic!("Only one side of = is a vector expression: {expr_str}");
+            }
+            let vec_param_varvec = Vec::from_iter(vec_param_vars);
+            let vec_witness_varvec = Vec::from_iter(vec_witness_vars);
+
+            if !vec_param_varvec.is_empty() {
+                let firstvar = &vec_param_varvec[0];
+                param_vec_code = quote! {
+                    #param_vec_code
+                    let #vec_len_var = #params_var.#firstvar.len();
+                };
+                for thisvar in vec_param_varvec.iter().skip(1) {
+                    param_vec_code = quote! {
+                        #param_vec_code
+                        if #vec_len_var != #params_var.#thisvar.len() {
+                            eprintln!(
+                                "Params {} and {} must have the same length",
+                                stringify!(#firstvar),
+                                stringify!(#thisvar),
+                            );
+                            return Err(SigmaError::VerificationFailure);
+                        }
+                    };
+                }
+                if !vec_witness_varvec.is_empty() {
+                    witness_vec_code = quote! {
+                        #witness_vec_code
+                        let #vec_len_var = params.#firstvar.len();
+                    };
+                }
+                for witvar in vec_witness_varvec {
+                    witness_vec_code = quote! {
+                        #witness_vec_code
+                        if #vec_len_var != witness.#witvar.len() {
+                            eprintln!(
+                                "Params {} and {} must have the same length",
+                                stringify!(#firstvar),
+                                stringify!(#witvar),
+                            );
+                            return Err(SigmaError::VerificationFailure);
+                        }
+                    }
+                }
+            };
+            if right_is_vec {
+                eq_code = quote! {
+                    #eq_code
+                    let #eq_id = (0..#vec_len_var)
+                        .map(|#vec_index_var| #lr_var.allocate_eq(#right_tokens))
+                        .collect::<Vec<_>>();
+                };
+                element_assigns = quote! {
+                    #element_assigns
+                    for #vec_index_var in 0..#vec_len_var {
+                        #lr_var.set_element(
+                            #eq_id[#vec_index_var],
+                            #params_var.#left_id[#vec_index_var],
+                        );
+                    }
+                };
+            } else {
+                eq_code = quote! {
+                    #eq_code
+                    let #eq_id = #lr_var.allocate_eq(#right_tokens);
+                };
+                element_assigns = quote! {
+                    #element_assigns
+                    #lr_var.set_element(#eq_id, #params_var.#left_id);
+                }
+            }
+        }
+
         (
             quote! {
-                Ok(Protocol::from(LinearRelation::<Point>::new()))
+                {
+                    let mut #lr_var = LinearRelation::<Point>::new();
+                    #param_vec_code
+                    #scalar_allocs
+                    #element_allocs
+                    #eq_code
+                    #element_assigns
+
+                    Ok(Protocol::from(#lr_var))
+                }
             },
             quote! {
-                Ok(ProtocolWitness::Simple(vec![]))
+                {
+                    #witness_vec_code
+                    let mut witnessvec = Vec::new();
+                    #witness_code
+                    Ok(ProtocolWitness::Simple(witnessvec))
+                }
             },
         )
     }
 
+    /// Generate the code for the `protocol` and `protocol_witness`
+    /// functions that create the `Protocol` and `ProtocolWitness`
+    /// structs, respectively, given a [`StatementTree`] describing the
+    /// statements to be proven.  The output components are the code for
+    /// the `protocol` and `protocol_witness` functions, respectively.
+    /// The `protocol` code must evaluate to a `Result<Protocol>` and
+    /// the `protocol_witness` code must evaluate to a
+    /// `Result<ProtocolWitness>`.
+    fn proto_witness_codegen(&self, statement: &StatementTree) -> (TokenStream, TokenStream) {
+        match statement {
+            // The StatementTree has no statements (it's just the single
+            // leaf "true")
+            StatementTree::Leaf(_) if statement.is_leaf_true() => (
+                quote! {
+                    Ok(Protocol::from(LinearRelation::<Point>::new()))
+                },
+                quote! {
+                    Ok(ProtocolWitness::Simple(vec![]))
+                },
+            ),
+            // The StatementTree is a single statement.  Generate a
+            // single LinearRelation from it.
+            StatementTree::Leaf(leafexpr) => {
+                self.linear_relation_codegen(std::slice::from_ref(&leafexpr))
+            }
+            // The StatementTree is an And.  Separate out the leaf
+            // statements, and generate a single LinearRelation from
+            // them.  Then if there are non-leaf nodes as well, And them
+            // together.
+            StatementTree::And(stvec) => {
+                let mut leaves: Vec<&Expr> = Vec::new();
+                let mut others: Vec<&StatementTree> = Vec::new();
+                for st in stvec {
+                    match st {
+                        StatementTree::Leaf(le) => leaves.push(le),
+                        _ => others.push(st),
+                    }
+                }
+                let (proto_code, witness_code) = self.linear_relation_codegen(&leaves);
+                if others.is_empty() {
+                    (proto_code, witness_code)
+                } else {
+                    let (others_proto, others_witness): (Vec<TokenStream>, Vec<TokenStream>) =
+                        others
+                            .iter()
+                            .map(|st| self.proto_witness_codegen(st))
+                            .unzip();
+                    (
+                        quote! {
+                            Ok(Protocol::And(vec![
+                                #proto_code?,
+                                #(#others_proto?,)*
+                            ]))
+                        },
+                        quote! {
+                            Ok(ProtocolWitness::And(vec![
+                                #witness_code?,
+                                #(#others_witness?,)*
+                            ]))
+                        },
+                    )
+                }
+            }
+            StatementTree::Or(stvec) => {
+                let (proto, witness): (Vec<TokenStream>, Vec<TokenStream>) = stvec
+                    .iter()
+                    .map(|st| self.proto_witness_codegen(st))
+                    .unzip();
+                (
+                    quote! {
+                        Ok(Protocol::Or(vec![
+                            #(#proto?,)*
+                        ]))
+                    },
+                    // TODO: Choose the correct branch for the witness
+                    // (currently hardcoded at 0)
+                    quote! {
+                        Ok(ProtocolWitness::Or(0, vec![
+                            #(#witness?,)*
+                        ]))
+                    },
+                )
+            }
+            StatementTree::Thresh(_thresh, _stvec) => {
+                todo! {"Thresh not yet implemented"};
+            }
+        }
+    }
+
     /// Generate the code that uses the `sigma-rs` API to prove and
     /// verify the statements in the [`CodeGen`].
     ///
@@ -283,14 +606,12 @@ impl<'a> CodeGen<'a> {
 
         // Generate the function that creates the sigma-rs Protocol
         let protocol_func = {
-            let params_ids = pub_params_fields.field_list();
             let params_var = format_ident!("{}params", self.unique_prefix);
 
             quote! {
                 fn protocol(
                     #params_var: &Params,
                 ) -> Result<Protocol<Point>, SigmaError> {
-                    let Params { #params_ids } = #params_var.clone();
                     #protocol_code
                 }
             }
@@ -298,18 +619,11 @@ impl<'a> CodeGen<'a> {
 
         // Generate the function that creates the sigma-rs ProtocolWitness
         let witness_func = {
-            let params_ids = pub_params_fields.field_list();
-            let witness_ids = witness_fields.field_list();
-            let params_var = format_ident!("{}params", self.unique_prefix);
-            let witness_var = format_ident!("{}witness", self.unique_prefix);
-
             quote! {
                 fn protocol_witness(
-                    #params_var: &Params,
-                    #witness_var: &Witness,
+                    params: &Params,
+                    witness: &Witness,
                 ) -> Result<ProtocolWitness<Point>, SigmaError> {
-                    let Params { #params_ids } = #params_var.clone();
-                    let Witness { #witness_ids } = #witness_var.clone();
                     #witness_code
                 }
             }