codegen.rs 10.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322
  1. //! A module for generating the code that uses the `sigma-rs` crate API.
  2. //!
  3. //! If that crate gets its own macro interface, it can use this module
  4. //! directly.
  5. use super::combiners::StatementTree;
  6. use super::types::{AExprType, VarDict};
  7. use proc_macro2::TokenStream;
  8. use quote::{format_ident, quote, ToTokens};
  9. use syn::Ident;
  10. /// Names and types of fields that might end up in a generated struct
  11. pub enum StructField {
  12. Scalar(Ident),
  13. VecScalar(Ident),
  14. Point(Ident),
  15. VecPoint(Ident),
  16. }
  17. /// A list of StructField items
  18. #[derive(Default)]
  19. pub struct StructFieldList {
  20. pub fields: Vec<StructField>,
  21. }
  22. impl StructFieldList {
  23. pub fn push_scalar(&mut self, s: &Ident) {
  24. self.fields.push(StructField::Scalar(s.clone()));
  25. }
  26. pub fn push_vecscalar(&mut self, s: &Ident) {
  27. self.fields.push(StructField::VecScalar(s.clone()));
  28. }
  29. pub fn push_point(&mut self, s: &Ident) {
  30. self.fields.push(StructField::Point(s.clone()));
  31. }
  32. pub fn push_vecpoint(&mut self, s: &Ident) {
  33. self.fields.push(StructField::VecPoint(s.clone()));
  34. }
  35. pub fn push_vars(&mut self, vars: &VarDict, for_params: bool) {
  36. for (id, ti) in vars.iter() {
  37. match ti {
  38. AExprType::Scalar { is_pub, is_vec, .. } => {
  39. if *is_pub == for_params {
  40. if *is_vec {
  41. self.push_vecscalar(&format_ident!("{}", id))
  42. } else {
  43. self.push_scalar(&format_ident!("{}", id))
  44. }
  45. }
  46. }
  47. AExprType::Point { is_vec, .. } => {
  48. if for_params {
  49. if *is_vec {
  50. self.push_vecpoint(&format_ident!("{}", id))
  51. } else {
  52. self.push_point(&format_ident!("{}", id))
  53. }
  54. }
  55. }
  56. }
  57. }
  58. }
  59. #[cfg(feature = "dump")]
  60. /// Output a ToTokens of the contents of the fields
  61. pub fn dump(&self) -> impl ToTokens {
  62. let dump_chunks = self.fields.iter().map(|f| match f {
  63. StructField::Scalar(id) => quote! {
  64. print!(" {}: ", stringify!(#id));
  65. Params::dump_scalar(&self.#id);
  66. println!("");
  67. },
  68. StructField::VecScalar(id) => quote! {
  69. print!(" {}: [", stringify!(#id));
  70. for s in self.#id.iter() {
  71. print!(" ");
  72. Params::dump_scalar(s);
  73. println!(",");
  74. }
  75. println!(" ]");
  76. },
  77. StructField::Point(id) => quote! {
  78. print!(" {}: ", stringify!(#id));
  79. Params::dump_point(&self.#id);
  80. println!("");
  81. },
  82. StructField::VecPoint(id) => quote! {
  83. print!(" {}: [", stringify!(#id));
  84. for p in self.#id.iter() {
  85. print!(" ");
  86. Params::dump_point(p);
  87. println!(",");
  88. }
  89. println!(" ]");
  90. },
  91. });
  92. quote! { #(#dump_chunks)* }
  93. }
  94. /// Output a ToTokens of the fields as they would appear in a struct
  95. /// definition
  96. pub fn field_decls(&self) -> impl ToTokens {
  97. let decls = self.fields.iter().map(|f| match f {
  98. StructField::Scalar(id) => quote! {
  99. pub #id: Scalar,
  100. },
  101. StructField::VecScalar(id) => quote! {
  102. pub #id: Vec<Scalar>,
  103. },
  104. StructField::Point(id) => quote! {
  105. pub #id: Point,
  106. },
  107. StructField::VecPoint(id) => quote! {
  108. pub #id: Vec<Point>,
  109. },
  110. });
  111. quote! { #(#decls)* }
  112. }
  113. /// Output a ToTokens of the list of fields
  114. pub fn field_list(&self) -> impl ToTokens {
  115. let field_ids = self.fields.iter().map(|f| match f {
  116. StructField::Scalar(id) => quote! {
  117. #id,
  118. },
  119. StructField::VecScalar(id) => quote! {
  120. #id,
  121. },
  122. StructField::Point(id) => quote! {
  123. #id,
  124. },
  125. StructField::VecPoint(id) => quote! {
  126. #id,
  127. },
  128. });
  129. quote! { #(#field_ids)* }
  130. }
  131. }
  132. /// The main struct to handle code generation using the `sigma-rs` API.
  133. pub struct CodeGen<'a> {
  134. proto_name: Ident,
  135. group_name: Ident,
  136. vars: &'a VarDict,
  137. statements: &'a mut StatementTree,
  138. }
  139. impl<'a> CodeGen<'a> {
  140. pub fn new(
  141. proto_name: Ident,
  142. group_name: Ident,
  143. vars: &'a VarDict,
  144. statements: &'a mut StatementTree,
  145. ) -> Self {
  146. Self {
  147. proto_name,
  148. group_name,
  149. vars,
  150. statements,
  151. }
  152. }
  153. /// Generate the code that uses the `sigma-rs` API to prove and
  154. /// verify the statements in the [`CodeGen`].
  155. ///
  156. /// `emit_prover` and `emit_verifier` are as in
  157. /// [`sigma_compiler_core`](super::super::sigma_compiler_core).
  158. pub fn generate(&mut self, emit_prover: bool, emit_verifier: bool) -> TokenStream {
  159. let proto_name = &self.proto_name;
  160. let group_name = &self.group_name;
  161. let group_types = quote! {
  162. use super::group;
  163. pub type Scalar = <super::#group_name as group::Group>::Scalar;
  164. pub type Point = super::#group_name;
  165. };
  166. // Flatten nested "And"s into single "And"s
  167. self.statements.flatten_ands();
  168. println!("Statements = {{");
  169. self.statements.dump();
  170. println!("}}");
  171. let mut pub_params_fields = StructFieldList::default();
  172. pub_params_fields.push_vars(self.vars, true);
  173. // Generate the public params struct definition
  174. let params_def = {
  175. let decls = pub_params_fields.field_decls();
  176. #[cfg(feature = "dump")]
  177. let dump_impl = {
  178. let dump_chunks = pub_params_fields.dump();
  179. quote! {
  180. impl Params {
  181. fn dump_scalar(s: &Scalar) {
  182. let bytes: &[u8] = &s.to_repr();
  183. print!("{:02x?}", bytes);
  184. }
  185. fn dump_point(p: &Point) {
  186. let bytes: &[u8] = &p.to_bytes();
  187. print!("{:02x?}", bytes);
  188. }
  189. pub fn dump(&self) {
  190. #dump_chunks
  191. }
  192. }
  193. }
  194. };
  195. #[cfg(not(feature = "dump"))]
  196. let dump_impl = {
  197. quote! {}
  198. };
  199. quote! {
  200. #[derive(Clone)]
  201. pub struct Params {
  202. #decls
  203. }
  204. #dump_impl
  205. }
  206. };
  207. let mut witness_fields = StructFieldList::default();
  208. witness_fields.push_vars(self.vars, false);
  209. // Generate the witness struct definition
  210. let witness_def = if emit_prover {
  211. let decls = witness_fields.field_decls();
  212. quote! {
  213. #[derive(Clone)]
  214. pub struct Witness {
  215. #decls
  216. }
  217. }
  218. } else {
  219. quote! {}
  220. };
  221. // Generate the (currently dummy) prove function
  222. let prove_func = if emit_prover {
  223. let dumper = if cfg!(feature = "dump") {
  224. quote! {
  225. println!("prover params = {{");
  226. params.dump();
  227. println!("}}");
  228. }
  229. } else {
  230. quote! {}
  231. };
  232. let params_ids = pub_params_fields.field_list();
  233. let witness_ids = witness_fields.field_list();
  234. quote! {
  235. pub fn prove(
  236. params: &Params,
  237. witness: &Witness,
  238. session_id: &[u8],
  239. rng: &mut (impl CryptoRng + RngCore),
  240. ) -> Result<Vec<u8>, SigmaError> {
  241. #dumper
  242. let Params { #params_ids } = params.clone();
  243. let Witness { #witness_ids } = witness.clone();
  244. Ok(Vec::<u8>::default())
  245. }
  246. }
  247. } else {
  248. quote! {}
  249. };
  250. // Generate the (currently dummy) verify function
  251. let verify_func = if emit_verifier {
  252. let dumper = if cfg!(feature = "dump") {
  253. quote! {
  254. println!("verifier params = {{");
  255. params.dump();
  256. println!("}}");
  257. }
  258. } else {
  259. quote! {}
  260. };
  261. let params_ids = pub_params_fields.field_list();
  262. quote! {
  263. pub fn verify(
  264. params: &Params,
  265. proof: &[u8],
  266. session_id: &[u8],
  267. ) -> Result<(), SigmaError> {
  268. #dumper
  269. let Params { #params_ids } = params.clone();
  270. Ok(())
  271. }
  272. }
  273. } else {
  274. quote! {}
  275. };
  276. // Output the generated module for this protocol
  277. let dump_use = if cfg!(feature = "dump") {
  278. quote! {
  279. use group::GroupEncoding;
  280. }
  281. } else {
  282. quote! {}
  283. };
  284. quote! {
  285. #[allow(non_snake_case)]
  286. pub mod #proto_name {
  287. use sigma_compiler::rand::{CryptoRng, RngCore};
  288. use sigma_compiler::group::ff::PrimeField;
  289. use sigma_compiler::sigma_rs::errors::Error as SigmaError;
  290. #dump_use
  291. #group_types
  292. #params_def
  293. #witness_def
  294. #prove_func
  295. #verify_func
  296. }
  297. }
  298. }
  299. }