frame_support_procedural/
match_and_insert.rs1use proc_macro2::{Group, Span, TokenStream, TokenTree};
21use std::iter::once;
22use syn::spanned::Spanned;
23
24mod keyword {
25 syn::custom_keyword!(target);
26 syn::custom_keyword!(pattern);
27 syn::custom_keyword!(tokens);
28}
29
30pub fn match_and_insert(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
31 let MatchAndInsertDef { pattern, tokens, target } =
32 syn::parse_macro_input!(input as MatchAndInsertDef);
33
34 match expand_in_stream(&pattern, &mut Some(tokens), target) {
35 Ok(stream) => stream.into(),
36 Err(err) => err.to_compile_error().into(),
37 }
38}
39
40struct MatchAndInsertDef {
41 target: TokenStream,
43 pattern: Vec<TokenTree>,
46 tokens: TokenStream,
48}
49
50impl syn::parse::Parse for MatchAndInsertDef {
51 fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
52 let mut target;
53 input.parse::<keyword::target>()?;
54 input.parse::<syn::Token![=]>()?;
55 let _replace_with_bracket: syn::token::Bracket = syn::bracketed!(target in input);
56 let _replace_with_brace: syn::token::Brace = syn::braced!(target in target);
57 let target = target.parse()?;
58
59 let mut pattern;
60 input.parse::<keyword::pattern>()?;
61 input.parse::<syn::Token![=]>()?;
62 let _replace_with_bracket: syn::token::Bracket = syn::bracketed!(pattern in input);
63 let _replace_with_brace: syn::token::Brace = syn::braced!(pattern in pattern);
64 let pattern = pattern.parse::<TokenStream>()?.into_iter().collect::<Vec<TokenTree>>();
65
66 if let Some(t) = pattern.iter().find(|t| matches!(t, TokenTree::Group(_))) {
67 return Err(syn::Error::new(t.span(), "Unexpected group token tree"))
68 }
69 if let Some(t) = pattern.iter().find(|t| matches!(t, TokenTree::Literal(_))) {
70 return Err(syn::Error::new(t.span(), "Unexpected literal token tree"))
71 }
72
73 if pattern.is_empty() {
74 return Err(syn::Error::new(Span::call_site(), "empty match pattern is invalid"))
75 }
76
77 let mut tokens;
78 input.parse::<keyword::tokens>()?;
79 input.parse::<syn::Token![=]>()?;
80 let _replace_with_bracket: syn::token::Bracket = syn::bracketed!(tokens in input);
81 let _replace_with_brace: syn::token::Brace = syn::braced!(tokens in tokens);
82 let tokens = tokens.parse()?;
83
84 Ok(Self { tokens, pattern, target })
85 }
86}
87
88fn expand_in_stream(
92 pattern: &[TokenTree],
93 tokens: &mut Option<TokenStream>,
94 stream: TokenStream,
95) -> syn::Result<TokenStream> {
96 assert!(
97 tokens.is_some(),
98 "`tokens` must be some, Option is used because `tokens` is used only once"
99 );
100 assert!(
101 !pattern.is_empty(),
102 "`pattern` must not be empty, otherwise there is nothing to match against"
103 );
104
105 let stream_span = stream.span();
106 let mut stream = stream.into_iter();
107 let mut extended = TokenStream::new();
108 let mut match_cursor = 0;
109
110 while let Some(token) = stream.next() {
111 match token {
112 TokenTree::Group(group) => {
113 match_cursor = 0;
114 let group_stream = group.stream();
115 match expand_in_stream(pattern, tokens, group_stream) {
116 Ok(s) => {
117 extended.extend(once(TokenTree::Group(Group::new(group.delimiter(), s))));
118 extended.extend(stream);
119 return Ok(extended)
120 },
121 Err(_) => {
122 extended.extend(once(TokenTree::Group(group)));
123 },
124 }
125 },
126 other => {
127 advance_match_cursor(&other, pattern, &mut match_cursor);
128
129 extended.extend(once(other));
130
131 if match_cursor == pattern.len() {
132 extended
133 .extend(once(tokens.take().expect("tokens is used to replace only once")));
134 extended.extend(stream);
135 return Ok(extended)
136 }
137 },
138 }
139 }
140 let msg = format!("Cannot find pattern `{:?}` in given token stream", pattern);
142 Err(syn::Error::new(stream_span, msg))
143}
144
145fn advance_match_cursor(other: &TokenTree, pattern: &[TokenTree], match_cursor: &mut usize) {
146 use TokenTree::{Ident, Punct};
147
148 let does_match_other_pattern = match (other, &pattern[*match_cursor]) {
149 (Ident(i1), Ident(i2)) => i1 == i2,
150 (Punct(p1), Punct(p2)) => p1.as_char() == p2.as_char(),
151 _ => false,
152 };
153
154 if does_match_other_pattern {
155 *match_cursor += 1;
156 } else {
157 *match_cursor = 0;
158 }
159}