1use indexmap::IndexMap;
2use proc_macro2::{Ident, Span, TokenStream};
3use quote::{format_ident, quote, ToTokens};
4use syn::{
5 parse::{Parse, ParseStream},
6 parse_quote,
7 punctuated::Punctuated,
8 spanned::Spanned,
9 token::{Brace, Paren, PathSep},
10 FieldPat, Fields, ItemEnum, LitBool, Member, Pat, PatIdent, PatPath, PatRest, PatStruct,
11 PatTupleStruct, PatWild, Path, PathArguments, PathSegment, Token, Variant,
12};
13
14use proc_macro_crate::{crate_name, FoundCrate};
15
/// Custom keywords recognised inside the attribute arguments handled by this crate.
pub(crate) mod kw {
    // `forward`: argument of `#[fatal(forward)]` (parsed by `ResolutionMode`'s `Parse` impl).
    syn::custom_keyword!(forward);
    // `transparent`: argument of `#[error(transparent)]` (parsed via `Transparent`).
    syn::custom_keyword!(transparent);
    // `splitable`: optional argument of the item-level attribute (parsed via `Attr`).
    syn::custom_keyword!(splitable);
    // `expand`: NOTE(review): not referenced in this part of the file; verify it is used elsewhere.
    syn::custom_keyword!(expand);
}
26
/// How the fatality of an enum variant or a struct is decided, derived from its `#[fatal ...]`
/// attribute.
#[derive(Clone)]
pub(crate) enum ResolutionMode {
    /// No `#[fatal]` attribute was present; expands to `false`.
    NoAnnotation,
    /// A bare `#[fatal]`; expands to `true`.
    Fatal,
    /// `#[fatal(true)]` or `#[fatal(false)]`; expands to the given literal.
    WithExplicitBool(LitBool),
    /// `#[fatal(forward)]`: fatality is delegated to an inner error via `Fatality::is_fatal`.
    /// The `Ident` names the binding (or field) that holds the inner error. It is `None` right
    /// after parsing and is filled in by `to_pattern`.
    Forward(kw::forward, Option<Ident>),
}
38
39impl std::fmt::Debug for ResolutionMode {
40 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
41 match self {
42 Self::NoAnnotation => writeln!(f, "None"),
43 Self::Fatal => writeln!(f, "Fatal"),
44 Self::WithExplicitBool(ref b) => writeln!(f, "Fatal({})", b.value()),
45 Self::Forward(_, ref ident) => writeln!(
46 f,
47 "Fatal(Forward, {})",
48 ident
49 .as_ref()
50 .map(|x| x.to_string())
51 .unwrap_or_else(|| "___".to_string())
52 ),
53 }
54 }
55}
56
57impl Default for ResolutionMode {
58 fn default() -> Self {
59 Self::Fatal
60 }
61}
62
63impl Parse for ResolutionMode {
64 fn parse(input: ParseStream) -> Result<Self, syn::Error> {
65 let lookahead = input.lookahead1();
66
67 if lookahead.peek(kw::forward) {
68 Ok(Self::Forward(input.parse::<kw::forward>()?, None))
69 } else if lookahead.peek(LitBool) {
70 Ok(Self::WithExplicitBool(input.parse::<LitBool>()?))
71 } else {
72 Err(lookahead.error())
73 }
74 }
75}
76
77impl ToTokens for ResolutionMode {
78 fn to_tokens(&self, ts: &mut TokenStream) {
79 let trait_fatality = abs_helper_path(format_ident!("Fatality"), Span::call_site());
80 let tmp = match self {
81 Self::NoAnnotation => quote! { false },
82 Self::Fatal => quote! { true },
83 Self::WithExplicitBool(boolean) => {
84 let value = boolean.value;
85 quote! { #value }
86 }
87 Self::Forward(_, maybe_ident) => {
88 let ident = maybe_ident
89 .as_ref()
90 .expect("Forward must have ident set. qed");
91 quote! {
92 <_ as #trait_fatality >::is_fatal( #ident )
93 }
94 }
95 };
96 ts.extend(tmp)
97 }
98}
99
100fn abs_helper_path(what: impl Into<Path>, loco: Span) -> Path {
101 let what = what.into();
102 let found_crate = if cfg!(test) {
103 FoundCrate::Itself
104 } else {
105 crate_name("fatality").expect("`fatality` must be present in `Cargo.toml` for use. q.e.d")
106 };
107 let path: Path = match found_crate {
108 FoundCrate::Itself => parse_quote!( crate::#what ),
109 FoundCrate::Name(name) => {
110 let ident = Ident::new(&name, loco);
111 parse_quote! { :: #ident :: #what }
112 }
113 };
114 path
115}
116
117fn trait_fatality_impl_for_enum(
119 who: &Ident,
120 pattern_lut: &IndexMap<Variant, Pat>,
121 resolution_lut: &IndexMap<Variant, ResolutionMode>,
122) -> TokenStream {
123 let pat = pattern_lut.values();
124 let resolution = resolution_lut.values();
125
126 let fatality_trait = abs_helper_path(Ident::new("Fatality", who.span()), who.span());
127 quote! {
128 impl #fatality_trait for #who {
129 fn is_fatal(&self) -> bool {
130 match self {
131 #( #pat => #resolution, )*
132 }
133 }
134 }
135 }
136}
137
138fn trait_fatality_impl_for_struct(who: &Ident, resolution: &ResolutionMode) -> TokenStream {
140 let fatality_trait = abs_helper_path(Ident::new("Fatality", who.span()), who.span());
141 let resolution = match resolution {
142 ResolutionMode::Forward(_fwd, field) => {
143 let field = field
144 .as_ref()
145 .expect("Ident must be filled at this point. qed");
146 quote! {
147 #fatality_trait :: is_fatal( & self. #field )
148 }
149 }
150 rm => quote! {
151 #rm
152 },
153 };
154 quote! {
155 impl #fatality_trait for #who {
156 fn is_fatal(&self) -> bool {
157 #resolution
158 }
159 }
160 }
161}
162
163#[derive(Debug, Clone)]
164#[allow(dead_code)]
165struct Transparent(kw::transparent);
166
167impl Parse for Transparent {
168 fn parse(input: ParseStream) -> syn::Result<Self> {
169 let lookahead = input.lookahead1();
170
171 if lookahead.peek(kw::transparent) {
172 Ok(Self(input.parse::<kw::transparent>()?))
173 } else {
174 Err(lookahead.error())
175 }
176 }
177}
178
179fn enum_variant_to_pattern(
186 variant: &Variant,
187 requested_resolution_mode: ResolutionMode,
188) -> Result<(Pat, ResolutionMode), syn::Error> {
189 to_pattern(
190 &variant.ident,
191 &variant.fields,
192 &variant.attrs,
193 requested_resolution_mode,
194 )
195}
196
197fn struct_to_pattern(
198 item: &syn::ItemStruct,
199 requested_resolution_mode: ResolutionMode,
200) -> Result<(Pat, ResolutionMode), syn::Error> {
201 to_pattern(
202 &item.ident,
203 &item.fields,
204 &item.attrs,
205 requested_resolution_mode,
206 )
207}
208
209fn to_pattern(
210 name: &Ident,
211 fields: &Fields,
212 attrs: &Vec<syn::Attribute>,
213 requested_resolution_mode: ResolutionMode,
214) -> Result<(Pat, ResolutionMode), syn::Error> {
215 let span = fields.span();
216 let me = PathSegment {
218 ident: Ident::new("Self", span),
219 arguments: PathArguments::None,
220 };
221 let path = Path {
222 leading_colon: None,
223 segments: Punctuated::<PathSegment, PathSep>::from_iter(vec![me, name.clone().into()]),
224 };
225 let is_transparent = attrs
226 .iter()
227 .find(|attr| {
228 if attr.path().is_ident("error") {
229 attr.parse_args::<Transparent>().is_ok()
230 } else {
231 false
232 }
233 })
234 .is_some();
235
236 let source = Ident::new("source", span);
237 let from = Ident::new("from", span);
238
239 let (pat, resolution) = match fields {
240 Fields::Named(ref fields) => {
241 let (fields, resolution) = match requested_resolution_mode {
242 ResolutionMode::Forward(fwd, _ident) => {
243 let fwd_field = if is_transparent {
244 fields.named.first().ok_or_else(|| syn::Error::new(fields.span(), "Missing inner field, must have exactly one inner field type, but requires one for `#[fatal(forward)]`."))?
245 } else {
246 fields.named.iter().find(|field| {
247 field
248 .attrs
249 .iter()
250 .find(|attr| attr.path().is_ident(&source) || attr.path().is_ident(&from))
251 .is_some()
252 }).ok_or_else(|| syn::Error::new(
253 fields.span(),
254 "No field annotated with `#[source]` or `#[from]`, but requires one for `#[fatal(forward)]`.")
255 )?
256 };
257
258 assert!(matches!(_ident, None));
259
260 let field_name = fwd_field
262 .ident
263 .clone()
264 .expect("Must have member/field name. qed");
265 let fp = FieldPat {
266 attrs: vec![],
267 member: Member::Named(field_name.clone()),
268 colon_token: None,
269 pat: Box::new(Pat::Ident(PatIdent {
270 attrs: vec![],
271 by_ref: Some(Token),
272 mutability: None,
273 ident: field_name.clone(),
274 subpat: None,
275 })),
276 };
277 (
278 Punctuated::<FieldPat, Token![,]>::from_iter([fp]),
279 ResolutionMode::Forward(fwd, fwd_field.ident.clone()),
280 )
281 }
282 rm => (Punctuated::<FieldPat, Token![,]>::new(), rm),
283 };
284
285 (
286 Pat::Struct(PatStruct {
287 attrs: vec![],
288 path,
289 brace_token: Brace(span),
290 fields,
291 qself: None,
292 rest: Some(PatRest {
293 attrs: vec![],
294 dot2_token: Token,
295 }),
296 }),
297 resolution,
298 )
299 }
300 Fields::Unnamed(ref fields) => {
301 let (mut field_pats, resolution) = if let ResolutionMode::Forward(keyword, _) =
302 requested_resolution_mode
303 {
304 let fwd_idx = if is_transparent {
306 if fields.unnamed.iter().count() != 1 {
308 return Err(
309 syn::Error::new(
310 fields.span(),
311 "Must have exactly one parameter when annotated with `#[transparent]` annotated field for `forward` with `fatality`",
312 )
313 );
314 }
315 0_usize
316 } else {
317 fields
318 .unnamed
319 .iter()
320 .enumerate()
321 .find_map(|(idx, field)| {
322 field
323 .attrs
324 .iter()
325 .find(|attr| {
326 attr.path().is_ident(&source) || attr.path().is_ident(&from)
327 })
328 .map(|_attr| idx)
329 })
330 .ok_or_else(|| {
331 syn::Error::new(
332 span,
333 "Must have a `#[source]` or `#[from]` annotated field for `#[fatal(forward)]`",
334 )
335 })?
336 };
337
338 let pat_capture_ident =
339 unnamed_fields_variant_pattern_constructor_binding_name(fwd_idx);
340 let mut field_pats = std::iter::repeat(Pat::Wild(PatWild {
342 attrs: vec![],
343 underscore_token: Token,
344 }))
345 .take(fwd_idx)
346 .collect::<Vec<_>>();
347
348 field_pats.push(Pat::Ident(PatIdent {
349 attrs: vec![],
350 by_ref: Some(Token),
351 mutability: None,
352 ident: pat_capture_ident.clone(),
353 subpat: None,
354 }));
355
356 (
357 field_pats,
358 ResolutionMode::Forward(keyword, Some(pat_capture_ident)),
359 )
360 } else {
361 (vec![], requested_resolution_mode)
362 };
363 field_pats.push(Pat::Rest(PatRest {
364 attrs: vec![],
365 dot2_token: Token,
366 }));
367 (
368 Pat::TupleStruct(PatTupleStruct {
369 attrs: vec![],
370 path,
371 qself: None,
372 paren_token: Paren(span),
373 elems: Punctuated::<Pat, Token![,]>::from_iter(field_pats),
374 }),
375 resolution,
376 )
377 }
378 Fields::Unit => {
379 if let ResolutionMode::Forward(..) = requested_resolution_mode {
380 return Err(syn::Error::new(
381 span,
382 "Cannot forward to a unit item variant",
383 ));
384 }
385 (
386 Pat::Path(PatPath {
387 attrs: vec![],
388 qself: None,
389 path,
390 }),
391 requested_resolution_mode,
392 )
393 }
394 };
395 assert!(
396 !matches!(resolution, ResolutionMode::Forward(_kw, None)),
397 "We always set the resolution identifier _right here_. qed"
398 );
399
400 Ok((pat, resolution))
401}
402
403fn unnamed_fields_variant_pattern_constructor_binding_name(ith: usize) -> Ident {
404 Ident::new(format!("arg_{}", ith).as_str(), Span::call_site())
405}
406
407#[derive(Hash, Debug)]
408struct VariantPattern(Variant);
409
410impl ToTokens for VariantPattern {
411 fn to_tokens(&self, ts: &mut TokenStream) {
412 let variant_name = &self.0.ident;
413 let variant_fields = &self.0.fields;
414
415 match variant_fields {
416 Fields::Unit => {
417 ts.extend(quote! { #variant_name });
418 }
419 Fields::Unnamed(unnamed) => {
420 let pattern = unnamed
421 .unnamed
422 .iter()
423 .enumerate()
424 .map(|(ith, _field)| {
425 Pat::Ident(PatIdent {
426 attrs: vec![],
427 by_ref: None,
428 mutability: None,
429 ident: unnamed_fields_variant_pattern_constructor_binding_name(ith),
430 subpat: None,
431 })
432 })
433 .collect::<Punctuated<Pat, Token![,]>>();
434 ts.extend(quote! { #variant_name(#pattern) });
435 }
436 Fields::Named(named) => {
437 let pattern = named
438 .named
439 .iter()
440 .map(|field| {
441 Pat::Ident(PatIdent {
442 attrs: vec![],
443 by_ref: None,
444 mutability: None,
445 ident: field.ident.clone().expect("Named field has a name. qed"),
446 subpat: None,
447 })
448 })
449 .collect::<Punctuated<Pat, Token![,]>>();
450 ts.extend(quote! { #variant_name{ #pattern } });
451 }
452 };
453 }
454}
455
456#[derive(Hash, Debug)]
458struct VariantConstructor(Variant);
459
460impl ToTokens for VariantConstructor {
461 fn to_tokens(&self, ts: &mut TokenStream) {
462 let variant_name = &self.0.ident;
463 let variant_fields = &self.0.fields;
464 ts.extend(match variant_fields {
465 Fields::Unit => quote! { #variant_name },
466 Fields::Unnamed(unnamed) => {
467 let constructor = unnamed
468 .unnamed
469 .iter()
470 .enumerate()
471 .map(|(ith, _field)| {
472 unnamed_fields_variant_pattern_constructor_binding_name(ith)
473 })
474 .collect::<Punctuated<Ident, Token![,]>>();
475 quote! { #variant_name (#constructor) }
476 }
477 Fields::Named(named) => {
478 let constructor = named
479 .named
480 .iter()
481 .map(|field| {
482 field
483 .ident
484 .clone()
485 .expect("Named must have named fields. qed")
486 })
487 .collect::<Punctuated<Ident, Token![,]>>();
488 quote! { #variant_name { #constructor } }
489 }
490 });
491 }
492}
493
/// Generates the `splitable` machinery for the enum `original` (named `<Name>`).
///
/// Generated items:
/// - `Fatal<Name>`, holding `fatal_variants`, and `Jfyi<Name>`, holding `jfyi_variants`. Both are
///   clones of `original` with the variants and ident replaced, so other attributes, visibility
///   and generics are copied. Both derive `thiserror::Error` and `Debug`.
/// - `From<Fatal<Name>>` and `From<Jfyi<Name>>` for `<Name>`, rebuilding the same variant.
/// - `impl Split for <Name>`, whose `split` returns `Err(Fatal..)` for fatal variants and
///   `Ok(Jfyi..)` for the rest.
///
/// Variants resolved as `Forward` appear in both slices. In `split`, their fatal arm is guarded
/// by `Fatality::is_fatal` on the forwarded binding, so a non-fatal inner error falls through to
/// the JFYI arm, which is emitted after all fatal arms.
///
/// Returns an empty stream when `attr` is `Attr::Empty`. Every return path here is `Ok`.
///
/// # Panics
/// - A `Forward` entry in `resolution_lut` has no binding identifier.
/// - A JFYI variant is missing from `resolution_lut`.
///
/// NOTE(review): the generated `From`/`Split` impls do not carry `original`'s generics. Confirm
/// that generic enums are unsupported with `splitable`.
fn trait_split_impl(
    attr: Attr,
    original: ItemEnum,
    resolution_lut: &IndexMap<Variant, ResolutionMode>,
    jfyi_variants: &[Variant],
    fatal_variants: &[Variant],
) -> Result<TokenStream, syn::Error> {
    if let Attr::Empty = attr {
        return Ok(TokenStream::new());
    }

    let span = original.span();

    let thiserror: Path = parse_quote!(thiserror::Error);
    let thiserror = abs_helper_path(thiserror, span);

    let split_trait = abs_helper_path(Ident::new("Split", span), span);

    let original_ident = original.ident.clone();

    // `Fatal<Name>`: a copy of the original enum restricted to the fatal variants.
    let fatal_ident = Ident::new(format!("Fatal{}", original_ident).as_str(), span);
    let mut fatal = original.clone();
    fatal.variants = fatal_variants.iter().cloned().collect();
    fatal.ident = fatal_ident.clone();

    // `Jfyi<Name>`: a copy of the original enum restricted to the JFYI variants.
    let jfyi_ident = Ident::new(format!("Jfyi{}", original_ident).as_str(), span);
    let mut jfyi = original.clone();
    jfyi.variants = jfyi_variants.iter().cloned().collect();
    jfyi.ident = jfyi_ident.clone();

    // Patterns bind every field by value; constructors reuse those binding names.
    let fatal_patterns = fatal_variants
        .iter()
        .map(|variant| VariantPattern(variant.clone()))
        .collect::<Vec<_>>();
    let jfyi_patterns = jfyi_variants
        .iter()
        .map(|variant| VariantPattern(variant.clone()))
        .collect::<Vec<_>>();

    let fatal_constructors = fatal_variants
        .iter()
        .map(|variant| VariantConstructor(variant.clone()))
        .collect::<Vec<_>>();
    let jfyi_constructors = jfyi_variants
        .iter()
        .map(|variant| VariantConstructor(variant.clone()))
        .collect::<Vec<_>>();

    let mut ts = TokenStream::new();

    ts.extend(quote! {
        impl ::std::convert::From< #fatal_ident> for #original_ident {
            fn from(fatal: #fatal_ident) -> Self {
                match fatal {
                    #( #fatal_ident :: #fatal_patterns => Self:: #fatal_constructors, )*
                }
            }
        }

        impl ::std::convert::From< #jfyi_ident> for #original_ident {
            fn from(jfyi: #jfyi_ident) -> Self {
                match jfyi {
                    #( #jfyi_ident :: #jfyi_patterns => Self:: #jfyi_constructors, )*
                }
            }
        }

        #[derive(#thiserror, Debug)]
        #fatal

        #[derive(#thiserror, Debug)]
        #jfyi
    });

    let trait_fatality = abs_helper_path(format_ident!("Fatality"), Span::call_site());

    // Fatal arms: forwarded variants get an `if is_fatal(&inner)` guard, all others match
    // unconditionally.
    let fatal_patterns_w_if_maybe = fatal_variants
        .iter()
        .map(|variant| {
            let pat = VariantPattern(variant.clone());
            if let Some(ResolutionMode::Forward(_fwd_kw, ident)) = resolution_lut.get(variant) {
                let ident = ident
                    .as_ref()
                    .expect("Forward mode must have an ident at this point. qed");
                quote! { #pat if < _ as #trait_fatality >::is_fatal( & #ident ) }
            } else {
                pat.into_token_stream()
            }
        })
        .collect::<Vec<_>>();

    let jfyi_patterns_w_if_maybe = jfyi_variants
        .iter()
        .map(|variant| {
            let pat = VariantPattern(variant.clone());
            // NOTE(review): this only checks that the variant has a `resolution_lut` entry; the
            // message suggests it was meant to reject fatal-annotated variants. Confirm intent.
            assert!(
                !matches!(resolution_lut.get(variant), None),
                "Cannot be annotated as fatal when in the JFYI slice. qed"
            );
            pat.into_token_stream()
        })
        .collect::<Vec<_>>();

    // Fatal arms come first so that guarded (forwarded) variants fall through to their
    // unguarded JFYI arm when the inner error is not fatal.
    let split_trait_impl = quote! {

        impl #split_trait for #original_ident {
            type Fatal = #fatal_ident;
            type Jfyi = #jfyi_ident;

            fn split(self) -> ::std::result::Result<Self::Jfyi, Self::Fatal> {
                match self {
                    #( Self :: #fatal_patterns_w_if_maybe => Err(#fatal_ident :: #fatal_constructors), )*
                    #( Self :: #jfyi_patterns_w_if_maybe => Ok(#jfyi_ident :: #jfyi_constructors), )*
                }
            }
        }
    };
    ts.extend(split_trait_impl);

    Ok(ts)
}
630
631pub(crate) fn fatality_struct_gen(
632 attr: Attr,
633 mut item: syn::ItemStruct,
634) -> syn::Result<proc_macro2::TokenStream> {
635 let name = item.ident.clone();
636 let mut resolution_mode = ResolutionMode::NoAnnotation;
637
638 while let Some(idx) = item.attrs.iter().enumerate().find_map(|(idx, attr)| {
640 if attr.path().is_ident("fatal") {
641 Some(idx)
642 } else {
643 None
644 }
645 }) {
646 let attr = item.attrs.remove(idx);
647 if let Ok(_) = attr.meta.require_path_only() {
648 resolution_mode = ResolutionMode::Fatal;
650 } else {
651 resolution_mode = attr.parse_args::<ResolutionMode>()?;
653 }
654 }
655
656 let (_pat, resolution_mode) = struct_to_pattern(&item, resolution_mode)?;
657
658 let thiserror: Path = parse_quote!(thiserror::Error);
660 let thiserror = abs_helper_path(thiserror, name.span());
661
662 let original_struct = quote! {
663 #[derive( #thiserror, Debug)]
664 #item
665 };
666
667 let mut ts = TokenStream::new();
668 ts.extend(original_struct);
669 ts.extend(trait_fatality_impl_for_struct(
670 &item.ident,
671 &resolution_mode,
672 ));
673
674 if let Attr::Splitable(kw) = attr {
675 return Err(syn::Error::new(
676 kw.span(),
677 "Cannot use `splitable` on a `struct`",
678 ));
679 }
680
681 Ok(ts)
682}
683
684pub(crate) fn fatality_enum_gen(attr: Attr, item: ItemEnum) -> syn::Result<TokenStream> {
685 let name = item.ident.clone();
686 let mut original = item.clone();
687
688 let mut resolution_lut = IndexMap::new();
689 let mut pattern_lut = IndexMap::new();
690
691 let mut jfyi_variants = Vec::new();
692 let mut fatal_variants = Vec::new();
693
694 for variant in original.variants.iter_mut() {
697 let mut resolution_mode = ResolutionMode::NoAnnotation;
698
699 while let Some(idx) = variant.attrs.iter().enumerate().find_map(|(idx, attr)| {
701 if attr.path().is_ident("fatal") {
702 Some(idx)
703 } else {
704 None
705 }
706 }) {
707 let attr = variant.attrs.remove(idx);
708 if let Ok(_) = attr.meta.require_path_only() {
709 resolution_mode = ResolutionMode::Fatal;
710 } else {
711 resolution_mode = attr.parse_args::<ResolutionMode>()?;
712 }
713 }
714
715 let (pattern, resolution_mode) = enum_variant_to_pattern(variant, resolution_mode)?;
719 match resolution_mode {
720 ResolutionMode::Forward(_, None) => unreachable!("Must have an ident. qed"),
721 ResolutionMode::Forward(_, ref _ident) => {
722 jfyi_variants.push(variant.clone());
723 fatal_variants.push(variant.clone());
724 }
725 ResolutionMode::WithExplicitBool(ref b) if b.value() => {
726 fatal_variants.push(variant.clone())
727 }
728 ResolutionMode::WithExplicitBool(_) => jfyi_variants.push(variant.clone()),
729 ResolutionMode::Fatal => fatal_variants.push(variant.clone()),
730 ResolutionMode::NoAnnotation => jfyi_variants.push(variant.clone()),
731 }
732 resolution_lut.insert(variant.clone(), resolution_mode);
733 pattern_lut.insert(variant.clone(), pattern);
734 }
735
736 let thiserror: Path = parse_quote!(thiserror::Error);
738 let thiserror = abs_helper_path(thiserror, name.span());
739
740 let original_enum = quote! {
741 #[derive( #thiserror, Debug)]
742 #original
743 };
744
745 let mut ts = TokenStream::new();
746 ts.extend(original_enum);
747 ts.extend(trait_fatality_impl_for_enum(
748 &original.ident,
749 &pattern_lut,
750 &resolution_lut,
751 ));
752
753 if let Attr::Splitable(_kw) = attr {
754 ts.extend(trait_split_impl(
755 attr,
756 original,
757 &resolution_lut,
758 &jfyi_variants,
759 &fatal_variants,
760 ));
761 }
762
763 Ok(ts)
764}
765
766#[derive(Clone, Copy, Debug)]
769pub(crate) enum Attr {
770 Splitable(kw::splitable),
771 Empty,
772}
773
774impl Parse for Attr {
775 fn parse(content: ParseStream) -> syn::Result<Self> {
776 let lookahead = content.lookahead1();
777
778 if lookahead.peek(kw::splitable) {
779 Ok(Self::Splitable(content.parse::<kw::splitable>()?))
780 } else if content.is_empty() {
781 Ok(Self::Empty)
782 } else {
783 Err(lookahead.error())
784 }
785 }
786}