|
|
@@ -4,10 +4,11 @@
|
|
|
//! directly.
|
|
|
|
|
|
use super::combiners::StatementTree;
|
|
|
-use super::types::{AExprType, VarDict};
|
|
|
+use super::types::{expr_type_tokens_id_closure, AExprType, VarDict};
|
|
|
use proc_macro2::TokenStream;
|
|
|
use quote::{format_ident, quote, ToTokens};
|
|
|
-use syn::Ident;
|
|
|
+use std::collections::HashSet;
|
|
|
+use syn::{Expr, Ident};
|
|
|
|
|
|
/// Names and types of fields that might end up in a generated struct
|
|
|
pub enum StructField {
|
|
|
@@ -183,23 +184,345 @@ impl<'a> CodeGen<'a> {
|
|
|
|
|
|
/// Generate the code for the `protocol` and `protocol_witness`
|
|
|
/// functions that create the `Protocol` and `ProtocolWitness`
|
|
|
- /// structs, respectively, given a [`VarDict`] and a
|
|
|
- /// [`StatementTree`] describing the statements to be proven. The
|
|
|
- /// output components are the code for the `protocol` and
|
|
|
- /// `protocol_witness` functions, respectively. The `protocol` code
|
|
|
+ /// structs, respectively, given a slice of [`Expr`]s that will be
|
|
|
+ /// bundled into a single `LinearRelation`. The `protocol` code
|
|
|
/// must evaluate to a `Result<Protocol>` and the `protocol_witness`
|
|
|
/// code must evaluate to a `Result<ProtocolWitness>`.
|
|
|
- fn proto_witness_codegen(&self, statement: &StatementTree) -> (TokenStream, TokenStream) {
|
|
|
+ fn linear_relation_codegen(&self, exprs: &[&Expr]) -> (TokenStream, TokenStream) {
|
|
|
+ let params_var = format_ident!("{}params", self.unique_prefix);
|
|
|
+ let lr_var = format_ident!("{}lr", self.unique_prefix);
|
|
|
+ let mut allocated_vars: HashSet<Ident> = HashSet::new();
|
|
|
+ let mut param_vec_code = quote! {};
|
|
|
+ let mut witness_vec_code = quote! {};
|
|
|
+ let mut witness_code = quote! {};
|
|
|
+ let mut scalar_allocs = quote! {};
|
|
|
+ let mut element_allocs = quote! {};
|
|
|
+ let mut eq_code = quote! {};
|
|
|
+ let mut element_assigns = quote! {};
|
|
|
+
|
|
|
+ for (i, expr) in exprs.iter().enumerate() {
|
|
|
+ let eq_id = format_ident!("{}eq{}", self.unique_prefix, i + 1);
|
|
|
+ let vec_index_var = format_ident!("{}i", self.unique_prefix);
|
|
|
+ let vec_len_var = format_ident!("{}veclen{}", self.unique_prefix, i + 1);
|
|
|
+ // Ensure the `Expr` is of a type we recognize. In
|
|
|
+ // particular, it must be an assignment (C = something)
|
|
|
+ // where the variable on the left is a public Point, and the
|
|
|
+ // something on the right is an arithmetic expression that
|
|
|
+ // evaluates to a private Point. It is allowed for neither
|
|
|
+ // or both Points to be vector variables.
|
|
|
+ let Expr::Assign(syn::ExprAssign { left, right, .. }) = expr else {
|
|
|
+ let expr_str = quote! { #expr }.to_string();
|
|
|
+ panic!("Unrecognized expression: {expr_str}");
|
|
|
+ };
|
|
|
+ let Expr::Path(syn::ExprPath { path, .. }) = left.as_ref() else {
|
|
|
+ let expr_str = quote! { #expr }.to_string();
|
|
|
+ panic!("Left side of = is not a variable: {expr_str}");
|
|
|
+ };
|
|
|
+ let Some(left_id) = path.get_ident() else {
|
|
|
+ let expr_str = quote! { #expr }.to_string();
|
|
|
+ panic!("Left side of = is not a variable: {expr_str}");
|
|
|
+ };
|
|
|
+ let Some(AExprType::Point {
|
|
|
+ is_vec: left_is_vec,
|
|
|
+ is_pub: true,
|
|
|
+ }) = self.vars.get(&left_id.to_string())
|
|
|
+ else {
|
|
|
+ let expr_str = quote! { #expr }.to_string();
|
|
|
+ panic!("Left side of = is not a public point: {expr_str}");
|
|
|
+ };
|
|
|
+ // Record any vector variables we encountered in this
|
|
|
+ // expression
|
|
|
+ let mut vec_param_vars: HashSet<Ident> = HashSet::new();
|
|
|
+ let mut vec_witness_vars: HashSet<Ident> = HashSet::new();
|
|
|
+ if *left_is_vec {
|
|
|
+ vec_param_vars.insert(left_id.clone());
|
|
|
+ }
|
|
|
+ let Ok((right_type, right_tokens)) =
|
|
|
+ expr_type_tokens_id_closure(self.vars, right, &mut |id, id_type| match id_type {
|
|
|
+ AExprType::Scalar {
|
|
|
+ is_vec: false,
|
|
|
+ is_pub: false,
|
|
|
+ ..
|
|
|
+ } => {
|
|
|
+ if allocated_vars.insert(id.clone()) {
|
|
|
+ scalar_allocs = quote! {
|
|
|
+ #scalar_allocs
|
|
|
+ let #id = #lr_var.allocate_scalar();
|
|
|
+ };
|
|
|
+ witness_code = quote! {
|
|
|
+ #witness_code
|
|
|
+ witnessvec.push(witness.#id);
|
|
|
+ };
|
|
|
+ }
|
|
|
+ Ok(quote! {#id})
|
|
|
+ }
|
|
|
+ AExprType::Scalar {
|
|
|
+ is_vec: false,
|
|
|
+ is_pub: true,
|
|
|
+ ..
|
|
|
+ } => Ok(quote! {#params_var.#id}),
|
|
|
+ AExprType::Scalar {
|
|
|
+ is_vec: true,
|
|
|
+ is_pub: false,
|
|
|
+ ..
|
|
|
+ } => {
|
|
|
+ vec_witness_vars.insert(id.clone());
|
|
|
+ if allocated_vars.insert(id.clone()) {
|
|
|
+ scalar_allocs = quote! {
|
|
|
+ #scalar_allocs
|
|
|
+ let #id = (0..#vec_len_var)
|
|
|
+ .map(|i| #lr_var.allocate_scalar())
|
|
|
+ .collect::<Vec<_>>();
|
|
|
+ };
|
|
|
+ witness_code = quote! {
|
|
|
+ #witness_code
|
|
|
+ witnessvec.extend(witness.#id.clone());
|
|
|
+ };
|
|
|
+ }
|
|
|
+ Ok(quote! {#id[#vec_index_var]})
|
|
|
+ }
|
|
|
+ AExprType::Scalar {
|
|
|
+ is_vec: true,
|
|
|
+ is_pub: true,
|
|
|
+ ..
|
|
|
+ } => {
|
|
|
+ vec_param_vars.insert(id.clone());
|
|
|
+ Ok(quote! {#params_var.#id[#vec_index_var]})
|
|
|
+ }
|
|
|
+ AExprType::Point { is_vec: false, .. } => {
|
|
|
+ if allocated_vars.insert(id.clone()) {
|
|
|
+ element_allocs = quote! {
|
|
|
+ #element_allocs
|
|
|
+ let #id = #lr_var.allocate_element();
|
|
|
+ };
|
|
|
+ element_assigns = quote! {
|
|
|
+ #element_assigns
|
|
|
+ #lr_var.set_element(#id, #params_var.#id);
|
|
|
+ };
|
|
|
+ }
|
|
|
+ Ok(quote! {#id})
|
|
|
+ }
|
|
|
+ AExprType::Point { is_vec: true, .. } => {
|
|
|
+ vec_param_vars.insert(id.clone());
|
|
|
+ if allocated_vars.insert(id.clone()) {
|
|
|
+ element_allocs = quote! {
|
|
|
+ #element_allocs
|
|
|
+ let #id = (0..#vec_len_var)
|
|
|
+ .map(|#vec_index_var| #lr_var.allocate_element())
|
|
|
+ .collect::<Vec<_>>();
|
|
|
+ };
|
|
|
+ element_assigns = quote! {
|
|
|
+ #element_assigns
|
|
|
+ for #vec_index_var in 0..#vec_len_var {
|
|
|
+ #lr_var.set_element(
|
|
|
+ #id[#vec_index_var],
|
|
|
+ #params_var.#id[#vec_index_var],
|
|
|
+ );
|
|
|
+ }
|
|
|
+ };
|
|
|
+ }
|
|
|
+ Ok(quote! {#id[#vec_index_var]})
|
|
|
+ }
|
|
|
+ })
|
|
|
+ else {
|
|
|
+ let expr_str = quote! { #expr }.to_string();
|
|
|
+ panic!("Right side of = is not a valid arithmetic expression: {expr_str}");
|
|
|
+ };
|
|
|
+ let AExprType::Point {
|
|
|
+ is_vec: right_is_vec,
|
|
|
+ is_pub: false,
|
|
|
+ } = right_type
|
|
|
+ else {
|
|
|
+ let expr_str = quote! { #expr }.to_string();
|
|
|
+ panic!("Right side of = does not evaluate to a private Point: {expr_str}");
|
|
|
+ };
|
|
|
+ if *left_is_vec != right_is_vec {
|
|
|
+ let expr_str = quote! { #expr }.to_string();
|
|
|
+ panic!("Only one side of = is a vector expression: {expr_str}");
|
|
|
+ }
|
|
|
+ let vec_param_varvec = Vec::from_iter(vec_param_vars);
|
|
|
+ let vec_witness_varvec = Vec::from_iter(vec_witness_vars);
|
|
|
+
|
|
|
+ if !vec_param_varvec.is_empty() {
|
|
|
+ let firstvar = &vec_param_varvec[0];
|
|
|
+ param_vec_code = quote! {
|
|
|
+ #param_vec_code
|
|
|
+ let #vec_len_var = #params_var.#firstvar.len();
|
|
|
+ };
|
|
|
+ for thisvar in vec_param_varvec.iter().skip(1) {
|
|
|
+ param_vec_code = quote! {
|
|
|
+ #param_vec_code
|
|
|
+ if #vec_len_var != #params_var.#thisvar.len() {
|
|
|
+ eprintln!(
|
|
|
+ "Params {} and {} must have the same length",
|
|
|
+ stringify!(#firstvar),
|
|
|
+ stringify!(#thisvar),
|
|
|
+ );
|
|
|
+ return Err(SigmaError::VerificationFailure);
|
|
|
+ }
|
|
|
+ };
|
|
|
+ }
|
|
|
+ if !vec_witness_varvec.is_empty() {
|
|
|
+ witness_vec_code = quote! {
|
|
|
+ #witness_vec_code
|
|
|
+ let #vec_len_var = params.#firstvar.len();
|
|
|
+ };
|
|
|
+ }
|
|
|
+ for witvar in vec_witness_varvec {
|
|
|
+ witness_vec_code = quote! {
|
|
|
+ #witness_vec_code
|
|
|
+ if #vec_len_var != witness.#witvar.len() {
|
|
|
+ eprintln!(
|
|
|
+ "Params {} and {} must have the same length",
|
|
|
+ stringify!(#firstvar),
|
|
|
+ stringify!(#witvar),
|
|
|
+ );
|
|
|
+ return Err(SigmaError::VerificationFailure);
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+ };
|
|
|
+ if right_is_vec {
|
|
|
+ eq_code = quote! {
|
|
|
+ #eq_code
|
|
|
+ let #eq_id = (0..#vec_len_var)
|
|
|
+ .map(|#vec_index_var| #lr_var.allocate_eq(#right_tokens))
|
|
|
+ .collect::<Vec<_>>();
|
|
|
+ };
|
|
|
+ element_assigns = quote! {
|
|
|
+ #element_assigns
|
|
|
+ for #vec_index_var in 0..#vec_len_var {
|
|
|
+ #lr_var.set_element(
|
|
|
+ #eq_id[#vec_index_var],
|
|
|
+ #params_var.#left_id[#vec_index_var],
|
|
|
+ );
|
|
|
+ }
|
|
|
+ };
|
|
|
+ } else {
|
|
|
+ eq_code = quote! {
|
|
|
+ #eq_code
|
|
|
+ let #eq_id = #lr_var.allocate_eq(#right_tokens);
|
|
|
+ };
|
|
|
+ element_assigns = quote! {
|
|
|
+ #element_assigns
|
|
|
+ #lr_var.set_element(#eq_id, #params_var.#left_id);
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
(
|
|
|
quote! {
|
|
|
- Ok(Protocol::from(LinearRelation::<Point>::new()))
|
|
|
+ {
|
|
|
+ let mut #lr_var = LinearRelation::<Point>::new();
|
|
|
+ #param_vec_code
|
|
|
+ #scalar_allocs
|
|
|
+ #element_allocs
|
|
|
+ #eq_code
|
|
|
+ #element_assigns
|
|
|
+
|
|
|
+ Ok(Protocol::from(#lr_var))
|
|
|
+ }
|
|
|
},
|
|
|
quote! {
|
|
|
- Ok(ProtocolWitness::Simple(vec![]))
|
|
|
+ {
|
|
|
+ #witness_vec_code
|
|
|
+ let mut witnessvec = Vec::new();
|
|
|
+ #witness_code
|
|
|
+ Ok(ProtocolWitness::Simple(witnessvec))
|
|
|
+ }
|
|
|
},
|
|
|
)
|
|
|
}
|
|
|
|
|
|
+ /// Generate the code for the `protocol` and `protocol_witness`
|
|
|
+ /// functions that create the `Protocol` and `ProtocolWitness`
|
|
|
+ /// structs, respectively, given a [`StatementTree`] describing the
|
|
|
+ /// statements to be proven. The output components are the code for
|
|
|
+ /// the `protocol` and `protocol_witness` functions, respectively.
|
|
|
+ /// The `protocol` code must evaluate to a `Result<Protocol>` and
|
|
|
+ /// the `protocol_witness` code must evaluate to a
|
|
|
+ /// `Result<ProtocolWitness>`.
|
|
|
+ fn proto_witness_codegen(&self, statement: &StatementTree) -> (TokenStream, TokenStream) {
|
|
|
+ match statement {
|
|
|
+ // The StatementTree has no statements (it's just the single
|
|
|
+ // leaf "true")
|
|
|
+ StatementTree::Leaf(_) if statement.is_leaf_true() => (
|
|
|
+ quote! {
|
|
|
+ Ok(Protocol::from(LinearRelation::<Point>::new()))
|
|
|
+ },
|
|
|
+ quote! {
|
|
|
+ Ok(ProtocolWitness::Simple(vec![]))
|
|
|
+ },
|
|
|
+ ),
|
|
|
+ // The StatementTree is a single statement. Generate a
|
|
|
+ // single LinearRelation from it.
|
|
|
+ StatementTree::Leaf(leafexpr) => {
|
|
|
+ self.linear_relation_codegen(std::slice::from_ref(&leafexpr))
|
|
|
+ }
|
|
|
+ // The StatementTree is an And. Separate out the leaf
|
|
|
+ // statements, and generate a single LinearRelation from
|
|
|
+ // them. Then if there are non-leaf nodes as well, And them
|
|
|
+ // together.
|
|
|
+ StatementTree::And(stvec) => {
|
|
|
+ let mut leaves: Vec<&Expr> = Vec::new();
|
|
|
+ let mut others: Vec<&StatementTree> = Vec::new();
|
|
|
+ for st in stvec {
|
|
|
+ match st {
|
|
|
+ StatementTree::Leaf(le) => leaves.push(le),
|
|
|
+ _ => others.push(st),
|
|
|
+ }
|
|
|
+ }
|
|
|
+ let (proto_code, witness_code) = self.linear_relation_codegen(&leaves);
|
|
|
+ if others.is_empty() {
|
|
|
+ (proto_code, witness_code)
|
|
|
+ } else {
|
|
|
+ let (others_proto, others_witness): (Vec<TokenStream>, Vec<TokenStream>) =
|
|
|
+ others
|
|
|
+ .iter()
|
|
|
+ .map(|st| self.proto_witness_codegen(st))
|
|
|
+ .unzip();
|
|
|
+ (
|
|
|
+ quote! {
|
|
|
+ Ok(Protocol::And(vec![
|
|
|
+ #proto_code?,
|
|
|
+ #(#others_proto?,)*
|
|
|
+ ]))
|
|
|
+ },
|
|
|
+ quote! {
|
|
|
+ Ok(ProtocolWitness::And(vec![
|
|
|
+ #witness_code?,
|
|
|
+ #(#others_witness?,)*
|
|
|
+ ]))
|
|
|
+ },
|
|
|
+ )
|
|
|
+ }
|
|
|
+ }
|
|
|
+ StatementTree::Or(stvec) => {
|
|
|
+ let (proto, witness): (Vec<TokenStream>, Vec<TokenStream>) = stvec
|
|
|
+ .iter()
|
|
|
+ .map(|st| self.proto_witness_codegen(st))
|
|
|
+ .unzip();
|
|
|
+ (
|
|
|
+ quote! {
|
|
|
+ Ok(Protocol::Or(vec![
|
|
|
+ #(#proto?,)*
|
|
|
+ ]))
|
|
|
+ },
|
|
|
+ // TODO: Choose the correct branch for the witness
|
|
|
+ // (currently hardcoded at 0)
|
|
|
+ quote! {
|
|
|
+ Ok(ProtocolWitness::Or(0, vec![
|
|
|
+ #(#witness?,)*
|
|
|
+ ]))
|
|
|
+ },
|
|
|
+ )
|
|
|
+ }
|
|
|
+ StatementTree::Thresh(_thresh, _stvec) => {
|
|
|
+ todo! {"Thresh not yet implemented"};
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
/// Generate the code that uses the `sigma-rs` API to prove and
|
|
|
/// verify the statements in the [`CodeGen`].
|
|
|
///
|
|
|
@@ -283,14 +606,12 @@ impl<'a> CodeGen<'a> {
|
|
|
|
|
|
// Generate the function that creates the sigma-rs Protocol
|
|
|
let protocol_func = {
|
|
|
- let params_ids = pub_params_fields.field_list();
|
|
|
let params_var = format_ident!("{}params", self.unique_prefix);
|
|
|
|
|
|
quote! {
|
|
|
fn protocol(
|
|
|
#params_var: &Params,
|
|
|
) -> Result<Protocol<Point>, SigmaError> {
|
|
|
- let Params { #params_ids } = #params_var.clone();
|
|
|
#protocol_code
|
|
|
}
|
|
|
}
|
|
|
@@ -298,18 +619,11 @@ impl<'a> CodeGen<'a> {
|
|
|
|
|
|
// Generate the function that creates the sigma-rs ProtocolWitness
|
|
|
let witness_func = {
|
|
|
- let params_ids = pub_params_fields.field_list();
|
|
|
- let witness_ids = witness_fields.field_list();
|
|
|
- let params_var = format_ident!("{}params", self.unique_prefix);
|
|
|
- let witness_var = format_ident!("{}witness", self.unique_prefix);
|
|
|
-
|
|
|
quote! {
|
|
|
fn protocol_witness(
|
|
|
- #params_var: &Params,
|
|
|
- #witness_var: &Witness,
|
|
|
+ params: &Params,
|
|
|
+ witness: &Witness,
|
|
|
) -> Result<ProtocolWitness<Point>, SigmaError> {
|
|
|
- let Params { #params_ids } = #params_var.clone();
|
|
|
- let Witness { #witness_ids } = #witness_var.clone();
|
|
|
#witness_code
|
|
|
}
|
|
|
}
|