codegen.rs 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531
  1. //! A module for generating the code produced by this macro. This code
  2. //! will interact with the underlying `sigma` macro.
  3. use super::sigma::codegen::{StructField, StructFieldList};
  4. use super::syntax::*;
  5. use proc_macro2::TokenStream;
  6. use quote::{format_ident, quote};
  7. #[cfg(test)]
  8. use syn::parse_quote;
  9. use syn::Ident;
  10. /// The main struct to handle code generation for this macro.
  11. ///
  12. /// Initialize a [`CodeGen`] with the [`SigmaCompSpec`] you get by
  13. /// parsing the macro input. Pass it to the various transformations and
  14. /// statement handlers, which will both update the code it will
  15. /// generate, and modify the [`SigmaCompSpec`]. Then at the end, call
  16. /// [`CodeGen::generate`] with the modified [`SigmaCompSpec`] to generate the
  17. /// code output by this macro.
  18. pub struct CodeGen {
  19. /// The protocol name specified in the `sigma_compiler` macro
  20. /// invocation
  21. proto_name: Ident,
  22. /// The group name specified in the `sigma_compiler` macro
  23. /// invocation
  24. group_name: Ident,
  25. /// The variables that were explicitly listed in the
  26. /// `sigma_compiler` macro invocation
  27. vars: TaggedVarDict,
  28. /// A prefix that does not appear at the beginning of any variable
  29. /// name in `vars`
  30. unique_prefix: String,
  31. /// Variables (not necessarily appearing in `vars`, since they may
  32. /// be generated by the sigma_compiler itself) that the prover needs
  33. /// to send to the verifier along with the proof. These could
  34. /// include commitments to bits in range proofs, for example.
  35. sent_instance: StructFieldList,
  36. /// Extra code that will be emitted in the `prove` function
  37. prove_code: TokenStream,
  38. /// Extra code that will be emitted in the `verify` function
  39. verify_code: TokenStream,
  40. /// Extra code that will be emitted in the `verify` function before
  41. /// the `sent_instance` are deserialized. This is where the verifier
  42. /// sets the lengths of vector variables in the `sent_instance`.
  43. verify_pre_instance_code: TokenStream,
  44. }
  45. impl CodeGen {
  46. /// Find a prefix that does not appear at the beginning of any
  47. /// variable name in `vars`
  48. fn unique_prefix(vars: &TaggedVarDict) -> String {
  49. 'outer: for tag in 0usize.. {
  50. let try_prefix = if tag == 0 {
  51. "gen__".to_string()
  52. } else {
  53. format!("gen{}__", tag)
  54. };
  55. for v in vars.keys() {
  56. if v.starts_with(&try_prefix) {
  57. continue 'outer;
  58. }
  59. }
  60. return try_prefix;
  61. }
  62. // The compiler complains if this isn't here, but it will only
  63. // get hit if vars contains at least usize::MAX entries, which
  64. // isn't going to happen.
  65. String::new()
  66. }
  67. /// Create a new [`CodeGen`] given the [`SigmaCompSpec`] you get by
  68. /// parsing the macro input.
  69. pub fn new(spec: &SigmaCompSpec) -> Self {
  70. Self {
  71. proto_name: spec.proto_name.clone(),
  72. group_name: spec.group_name.clone(),
  73. vars: spec.vars.clone(),
  74. unique_prefix: Self::unique_prefix(&spec.vars),
  75. sent_instance: StructFieldList::default(),
  76. prove_code: quote! {},
  77. verify_code: quote! {},
  78. verify_pre_instance_code: quote! {},
  79. }
  80. }
  81. #[cfg(test)]
  82. /// Create an empty [`CodeGen`]. Primarily useful in testing.
  83. pub fn new_empty() -> Self {
  84. Self {
  85. proto_name: parse_quote! { proto },
  86. group_name: parse_quote! { G },
  87. vars: TaggedVarDict::default(),
  88. unique_prefix: "gen__".into(),
  89. sent_instance: StructFieldList::default(),
  90. prove_code: quote! {},
  91. verify_code: quote! {},
  92. verify_pre_instance_code: quote! {},
  93. }
  94. }
  95. /// Create a new generated private Scalar variable to put in the
  96. /// Witness.
  97. ///
  98. /// If you call this, you should also call
  99. /// [`prove_append`](Self::prove_append) with code like `quote!{ let
  100. /// #id = ... }` where `id` is the [`struct@Ident`] returned from
  101. /// this function.
  102. pub fn gen_scalar(
  103. &self,
  104. vars: &mut TaggedVarDict,
  105. base: &Ident,
  106. is_rand: bool,
  107. is_vec: bool,
  108. ) -> Ident {
  109. let id = format_ident!("{}{}", self.unique_prefix, base);
  110. vars.insert(
  111. id.to_string(),
  112. TaggedIdent::Scalar(TaggedScalar {
  113. id: id.clone(),
  114. is_pub: false,
  115. is_rand,
  116. is_vec,
  117. }),
  118. );
  119. id
  120. }
  121. /// Create a new public Point variable to put in the Instance,
  122. /// optionally marking it as needing to be sent from the prover to
  123. /// the verifier along with the proof.
  124. ///
  125. /// If you call this function, you should also call
  126. /// [`prove_append`](Self::prove_append) with code like `quote!{ let
  127. /// #id = ... }` where `id` is the [`struct@Ident`] returned from
  128. /// this function. If `is_vec` is `true`, then you should also call
  129. /// [`verify_pre_instance_append`](Self::verify_pre_instance_append)
  130. /// with code like `quote!{ let mut #id = Vec::<Point>::new();
  131. /// #id.resize(#len, Point::default()); }` where `len` is the number
  132. /// of elements you expect to have in the vector (computed at
  133. /// runtime, perhaps based on the values of public parameters).
  134. pub fn gen_point(
  135. &mut self,
  136. vars: &mut TaggedVarDict,
  137. base: &Ident,
  138. is_vec: bool,
  139. send_to_verifier: bool,
  140. ) -> Ident {
  141. let id = format_ident!("{}{}", self.unique_prefix, base);
  142. vars.insert(
  143. id.to_string(),
  144. TaggedIdent::Point(TaggedPoint {
  145. id: id.clone(),
  146. is_cind: false,
  147. is_const: false,
  148. is_vec,
  149. }),
  150. );
  151. if send_to_verifier {
  152. if is_vec {
  153. self.sent_instance.push_vecpoint(&id);
  154. } else {
  155. self.sent_instance.push_point(&id);
  156. }
  157. }
  158. id
  159. }
  160. /// Create a new identifier, using the unique prefix
  161. pub fn gen_ident(&self, base: &Ident) -> Ident {
  162. format_ident!("{}{}", self.unique_prefix, base)
  163. }
  164. /// Append some code to the generated `prove` function
  165. pub fn prove_append(&mut self, code: TokenStream) {
  166. let prove_code = &self.prove_code;
  167. self.prove_code = quote! {
  168. #prove_code
  169. #code
  170. };
  171. }
  172. /// Append some code to the generated `verify` function
  173. pub fn verify_append(&mut self, code: TokenStream) {
  174. let verify_code = &self.verify_code;
  175. self.verify_code = quote! {
  176. #verify_code
  177. #code
  178. };
  179. }
  180. /// Append some code to the generated `verify` function to be run
  181. /// before the `sent_instance` are deserialized
  182. pub fn verify_pre_instance_append(&mut self, code: TokenStream) {
  183. let verify_pre_instance_code = &self.verify_pre_instance_code;
  184. self.verify_pre_instance_code = quote! {
  185. #verify_pre_instance_code
  186. #code
  187. };
  188. }
  189. /// Append some code to both the generated `prove` and `verify`
  190. /// functions
  191. pub fn prove_verify_append(&mut self, code: TokenStream) {
  192. let prove_code = &self.prove_code;
  193. self.prove_code = quote! {
  194. #prove_code
  195. #code
  196. };
  197. let verify_code = &self.verify_code;
  198. self.verify_code = quote! {
  199. #verify_code
  200. #code
  201. };
  202. }
  203. /// Append some code to both the generated `prove` and `verify`
  204. /// functions, the latter to be run before the `sent_instance` are
  205. /// deserialized
  206. pub fn prove_verify_pre_instance_append(&mut self, code: TokenStream) {
  207. let prove_code = &self.prove_code;
  208. self.prove_code = quote! {
  209. #prove_code
  210. #code
  211. };
  212. let verify_pre_instance_code = &self.verify_pre_instance_code;
  213. self.verify_pre_instance_code = quote! {
  214. #verify_pre_instance_code
  215. #code
  216. };
  217. }
  218. /// Extract (as [`String`]s) the code inserted by
  219. /// [`prove_append`](Self::prove_append),
  220. /// [`verify_append`](Self::verify_append), and
  221. /// [`verify_pre_instance_append`](Self::verify_pre_instance_append).
  222. pub fn code_strings(&self) -> (String, String, String) {
  223. (
  224. self.prove_code.to_string(),
  225. self.verify_code.to_string(),
  226. self.verify_pre_instance_code.to_string(),
  227. )
  228. }
  229. /// Generate the code to be output by this macro.
  230. ///
  231. /// `emit_prover` and `emit_verifier` are as in
  232. /// [`sigma_compiler_core`](super::sigma_compiler_core).
  233. pub fn generate(
  234. &self,
  235. spec: &mut SigmaCompSpec,
  236. emit_prover: bool,
  237. emit_verifier: bool,
  238. ) -> TokenStream {
  239. let proto_name = &self.proto_name;
  240. let group_name = &self.group_name;
  241. let group_types = quote! {
  242. use super::group;
  243. pub type Scalar = <super::#group_name as group::Group>::Scalar;
  244. pub type Point = super::#group_name;
  245. };
  246. // vardict contains the variables that were defined in the macro
  247. // call to [`sigma_compiler`]
  248. let vardict = taggedvardict_to_vardict(&self.vars);
  249. // sigma_proofs_vardict contains the variables that we are passing
  250. // to sigma_proofs. We may have removed some via substitution, and
  251. // we may have added some when compiling statements like range
  252. // assertions into underlying linear combination assertions.
  253. let sigma_proofs_vardict = taggedvardict_to_vardict(&spec.vars);
  254. // Generate the code that uses the underlying sigma_proofs API
  255. let mut sigma_proofs_codegen = super::sigma::codegen::CodeGen::new(
  256. format_ident!("sigma"),
  257. format_ident!("Point"),
  258. &sigma_proofs_vardict,
  259. &mut spec.statements,
  260. );
  261. let sigma_proofs_code = sigma_proofs_codegen.generate(emit_prover, emit_verifier);
  262. let mut pub_instance_fields = StructFieldList::default();
  263. pub_instance_fields.push_vars(&vardict, true);
  264. let mut witness_fields = StructFieldList::default();
  265. witness_fields.push_vars(&vardict, false);
  266. let mut sigma_proofs_instance_fields = StructFieldList::default();
  267. sigma_proofs_instance_fields.push_vars(&sigma_proofs_vardict, true);
  268. let mut sigma_proofs_witness_fields = StructFieldList::default();
  269. sigma_proofs_witness_fields.push_vars(&sigma_proofs_vardict, false);
  270. // Generate the public instance struct definition
  271. let instance_def = {
  272. let decls = pub_instance_fields.field_decls();
  273. #[cfg(feature = "dump")]
  274. let dump_impl = {
  275. let dump_chunks = pub_instance_fields.dump();
  276. quote! {
  277. impl Instance {
  278. fn dump_scalar(s: &Scalar) {
  279. let bytes: &[u8] = &s.to_repr();
  280. print!("{:02x?}", bytes);
  281. }
  282. fn dump_point(p: &Point) {
  283. let bytes: &[u8] = &p.to_bytes();
  284. print!("{:02x?}", bytes);
  285. }
  286. pub fn dump(&self) {
  287. #dump_chunks
  288. }
  289. }
  290. }
  291. };
  292. #[cfg(not(feature = "dump"))]
  293. let dump_impl = {
  294. quote! {}
  295. };
  296. quote! {
  297. #[derive(Clone)]
  298. pub struct Instance {
  299. #decls
  300. }
  301. #dump_impl
  302. }
  303. };
  304. // Generate the witness struct definition
  305. let witness_def = if emit_prover {
  306. let decls = witness_fields.field_decls();
  307. quote! {
  308. #[derive(Clone)]
  309. pub struct Witness {
  310. #decls
  311. }
  312. }
  313. } else {
  314. quote! {}
  315. };
  316. // Generate the prove function
  317. let prove_func = if emit_prover {
  318. let instance_ids = pub_instance_fields.field_list();
  319. let witness_ids = witness_fields.field_list();
  320. let sigma_proofs_instance_ids = sigma_proofs_instance_fields.field_list();
  321. let sigma_proofs_witness_ids = sigma_proofs_witness_fields.field_list();
  322. let prove_code = &self.prove_code;
  323. let codegen_instance_var = format_ident!("{}sigma_instance", self.unique_prefix);
  324. let codegen_witness_var = format_ident!("{}sigma_witness", self.unique_prefix);
  325. let instance_var = format_ident!("{}instance", self.unique_prefix);
  326. let witness_var = format_ident!("{}witness", self.unique_prefix);
  327. let rng_var = format_ident!("{}rng", self.unique_prefix);
  328. let proof_var = format_ident!("{}proof", self.unique_prefix);
  329. let sid_var = format_ident!("{}session_id", self.unique_prefix);
  330. let sent_instance_code = {
  331. let chunks = self.sent_instance.fields.iter().map(|sf| match sf {
  332. StructField::Point(id) => quote! {
  333. #proof_var.extend(sigma_proofs::serialization::serialize_elements(
  334. std::slice::from_ref(&#codegen_instance_var.#id)
  335. ));
  336. },
  337. StructField::VecPoint(id) => quote! {
  338. #proof_var.extend(sigma_proofs::serialization::serialize_elements(
  339. &#codegen_instance_var.#id
  340. ));
  341. },
  342. _ => quote! {},
  343. });
  344. quote! { #(#chunks)* }
  345. };
  346. let dumper = if cfg!(feature = "dump") {
  347. quote! {
  348. println!("prover instance = {{");
  349. #instance_var.dump();
  350. println!("}}");
  351. }
  352. } else {
  353. quote! {}
  354. };
  355. quote! {
  356. pub fn prove(
  357. #instance_var: &Instance,
  358. #witness_var: &Witness,
  359. #sid_var: &[u8],
  360. #rng_var: &mut (impl CryptoRng + RngCore),
  361. ) -> Result<Vec<u8>, SigmaError> {
  362. #dumper
  363. let Instance { #instance_ids } = #instance_var.clone();
  364. let Witness { #witness_ids } = #witness_var.clone();
  365. #prove_code
  366. let mut #proof_var = Vec::<u8>::new();
  367. let #codegen_instance_var = sigma::Instance {
  368. #sigma_proofs_instance_ids
  369. };
  370. let #codegen_witness_var = sigma::Witness {
  371. #sigma_proofs_witness_ids
  372. };
  373. #sent_instance_code
  374. #proof_var.extend(
  375. sigma::prove(
  376. &#codegen_instance_var,
  377. &#codegen_witness_var,
  378. #sid_var,
  379. #rng_var,
  380. )?
  381. );
  382. Ok(#proof_var)
  383. }
  384. }
  385. } else {
  386. quote! {}
  387. };
  388. // Generate the verify function
  389. let verify_func = if emit_verifier {
  390. let instance_ids = pub_instance_fields.field_list();
  391. let sigma_proofs_instance_ids = sigma_proofs_instance_fields.field_list();
  392. let verify_pre_instance_code = &self.verify_pre_instance_code;
  393. let verify_code = &self.verify_code;
  394. let codegen_instance_var = format_ident!("{}sigma_instance", self.unique_prefix);
  395. let element_len_var = format_ident!("{}element_len", self.unique_prefix);
  396. let offset_var = format_ident!("{}proof_offset", self.unique_prefix);
  397. let instance_var = format_ident!("{}instance", self.unique_prefix);
  398. let proof_var = format_ident!("{}proof", self.unique_prefix);
  399. let sid_var = format_ident!("{}session_id", self.unique_prefix);
  400. let sent_instance_code = {
  401. let element_len_code = if self.sent_instance.fields.is_empty() {
  402. quote! {}
  403. } else {
  404. quote! {
  405. let #element_len_var =
  406. <Point as group::GroupEncoding>::Repr::default().as_ref().len();
  407. }
  408. };
  409. let chunks = self.sent_instance.fields.iter().map(|sf| match sf {
  410. StructField::Point(id) => quote! {
  411. let #id: Point = sigma_proofs::serialization::deserialize_elements(
  412. &#proof_var[#offset_var..],
  413. 1,
  414. ).ok_or(SigmaError::VerificationFailure)?[0];
  415. #offset_var += #element_len_var;
  416. },
  417. StructField::VecPoint(id) => quote! {
  418. #id = sigma_proofs::serialization::deserialize_elements(
  419. &#proof_var[#offset_var..],
  420. #id.len(),
  421. ).ok_or(SigmaError::VerificationFailure)?;
  422. #offset_var += #element_len_var * #id.len();
  423. },
  424. _ => quote! {},
  425. });
  426. quote! {
  427. let mut #offset_var = 0usize;
  428. #element_len_code
  429. #(#chunks)*
  430. }
  431. };
  432. let dumper = if cfg!(feature = "dump") {
  433. quote! {
  434. println!("verifier instance = {{");
  435. #instance_var.dump();
  436. println!("}}");
  437. }
  438. } else {
  439. quote! {}
  440. };
  441. quote! {
  442. pub fn verify(
  443. #instance_var: &Instance,
  444. #proof_var: &[u8],
  445. #sid_var: &[u8],
  446. ) -> Result<(), SigmaError> {
  447. #dumper
  448. let Instance { #instance_ids } = #instance_var.clone();
  449. #verify_pre_instance_code
  450. #sent_instance_code
  451. #verify_code
  452. let #codegen_instance_var = sigma::Instance {
  453. #sigma_proofs_instance_ids
  454. };
  455. sigma::verify(
  456. &#codegen_instance_var,
  457. &#proof_var[#offset_var..],
  458. #sid_var,
  459. )
  460. }
  461. }
  462. } else {
  463. quote! {}
  464. };
  465. // Output the generated module for this protocol
  466. let dump_use = if cfg!(feature = "dump") {
  467. quote! {
  468. use group::GroupEncoding;
  469. }
  470. } else {
  471. quote! {}
  472. };
  473. quote! {
  474. #[allow(non_snake_case)]
  475. pub mod #proto_name {
  476. use sigma_compiler::group::Group;
  477. use sigma_compiler::group::ff::{Field, PrimeField};
  478. use sigma_compiler::group::ff::derive::subtle::ConditionallySelectable;
  479. use sigma_compiler::rand::{CryptoRng, RngCore};
  480. use sigma_compiler::sigma_proofs;
  481. use sigma_compiler::sigma_proofs::errors::Error as SigmaError;
  482. use sigma_compiler::vecutils::*;
  483. use std::ops::Neg;
  484. #dump_use
  485. #group_types
  486. #sigma_proofs_code
  487. #instance_def
  488. #witness_def
  489. #prove_func
  490. #verify_func
  491. }
  492. }
  493. }
  494. }