tracing_attributes/
expand.rs

1use std::iter;
2
3use proc_macro2::TokenStream;
4use quote::{quote, quote_spanned, ToTokens};
5use syn::visit_mut::VisitMut;
6use syn::{
7    punctuated::Punctuated, spanned::Spanned, Block, Expr, ExprAsync, ExprCall, FieldPat, FnArg,
8    Ident, Item, ItemFn, Pat, PatIdent, PatReference, PatStruct, PatTuple, PatTupleStruct, PatType,
9    Path, ReturnType, Signature, Stmt, Token, Type, TypePath,
10};
11
12use crate::{
13    attr::{Field, Fields, FormatMode, InstrumentArgs, Level},
14    MaybeItemFn, MaybeItemFnRef,
15};
16
17/// Given an existing function, generate an instrumented version of that function
18pub(crate) fn gen_function<'a, B: ToTokens + 'a>(
19    input: MaybeItemFnRef<'a, B>,
20    args: InstrumentArgs,
21    instrumented_function_name: &str,
22    self_type: Option<&TypePath>,
23) -> proc_macro2::TokenStream {
24    // these are needed ahead of time, as ItemFn contains the function body _and_
25    // isn't representable inside a quote!/quote_spanned! macro
26    // (Syn's ToTokens isn't implemented for ItemFn)
27    let MaybeItemFnRef {
28        outer_attrs,
29        inner_attrs,
30        vis,
31        sig,
32        block,
33    } = input;
34
35    let Signature {
36        output,
37        inputs: params,
38        unsafety,
39        asyncness,
40        constness,
41        abi,
42        ident,
43        generics:
44            syn::Generics {
45                params: gen_params,
46                where_clause,
47                ..
48            },
49        ..
50    } = sig;
51
52    let warnings = args.warnings();
53
54    let (return_type, return_span) = if let ReturnType::Type(_, return_type) = &output {
55        (erase_impl_trait(return_type), return_type.span())
56    } else {
57        // Point at function name if we don't have an explicit return type
58        (syn::parse_quote! { () }, ident.span())
59    };
60    // Install a fake return statement as the first thing in the function
61    // body, so that we eagerly infer that the return type is what we
62    // declared in the async fn signature.
63    // The `#[allow(..)]` is given because the return statement is
64    // unreachable, but does affect inference, so it needs to be written
65    // exactly that way for it to do its magic.
66    let fake_return_edge = quote_spanned! {return_span=>
67        #[allow(
68            unknown_lints, unreachable_code, clippy::diverging_sub_expression,
69            clippy::let_unit_value, clippy::unreachable, clippy::let_with_type_underscore,
70            clippy::empty_loop
71        )]
72        if false {
73            let __tracing_attr_fake_return: #return_type = loop {};
74            return __tracing_attr_fake_return;
75        }
76    };
77    let block = quote! {
78        {
79            #fake_return_edge
80            #block
81        }
82    };
83
84    let body = gen_block(
85        &block,
86        params,
87        asyncness.is_some(),
88        args,
89        instrumented_function_name,
90        self_type,
91    );
92
93    quote!(
94        #(#outer_attrs) *
95        #vis #constness #unsafety #asyncness #abi fn #ident<#gen_params>(#params) #output
96        #where_clause
97        {
98            #(#inner_attrs) *
99            #warnings
100            #body
101        }
102    )
103}
104
/// Instrument a block
///
/// Returns the tokens that replace a function body (or an async block's body): a
/// `tracing::span!` invocation built from the function's parameters and `args`, plus
/// `block` wrapped so that it runs inside that span.
///
/// - `block`: the body to instrument; it is interpolated directly after `async move` or
///   `move ||` in the generated code.
/// - `params`: the parameters whose bindings become span fields, unless skipped via
///   `skip(..)`/`skip_all` or shadowed by a custom field with the same name.
/// - `async_context`: when `true`, `block` becomes an `async move` block that is awaited and,
///   unless the span is disabled, instrumented with the span; when `false`, the span is
///   entered around the synchronous body only when `level_enabled!`/`if_log_enabled!` report
///   the level as enabled.
/// - `args`: the parsed `#[instrument(..)]` arguments (name, level, target, parent,
///   follows_from, fields, skips, `err`/`ret` settings).
/// - `instrumented_function_name`: span name used when `args.name` is not set.
/// - `self_type`: when `Some`, a parameter named `_self` is exposed to the user as `self`,
///   and `Self` inside custom field expressions is rewritten to this type
///   (async-trait <= 0.1.43 compatibility).
///
/// If `skip(..)` names a parameter that does not exist, the span expression is replaced by a
/// `compile_error!` spanned at that name.
fn gen_block<B: ToTokens>(
    block: &B,
    params: &Punctuated<FnArg, Token![,]>,
    async_context: bool,
    mut args: InstrumentArgs,
    instrumented_function_name: &str,
    self_type: Option<&TypePath>,
) -> proc_macro2::TokenStream {
    // generate the span's name
    let span_name = args
        // did the user override the span's name?
        .name
        .as_ref()
        .map(|name| quote!(#name))
        .unwrap_or_else(|| quote!(#instrumented_function_name));

    // `args_level` is consumed later as the fallback level of the `ret` event;
    // `level` is the copy interpolated into the span and the enablement check.
    let args_level = args.level();
    let level = args_level.clone();

    // For each `follows_from` expression, emit a loop that links the span to every cause it
    // yields. These tokens are spliced in right after `__tracing_attr_span` is assigned.
    let follows_from = args.follows_from.iter();
    let follows_from = quote! {
        #(for cause in #follows_from {
            __tracing_attr_span.follows_from(cause);
        })*
    };

    // generate this inside a closure, so we can return early on errors.
    let span = (|| {
        // Pull out the arguments-to-be-skipped first, so we can filter results
        // below.
        // Each entry is (user-facing name, (real binding name, how to record it)).
        let param_names: Vec<(Ident, (Ident, RecordType))> = params
            .clone()
            .into_iter()
            .flat_map(|param| match param {
                FnArg::Typed(PatType { pat, ty, .. }) => {
                    param_names(*pat, RecordType::parse_from_ty(&ty))
                }
                FnArg::Receiver(_) => Box::new(iter::once((
                    Ident::new("self", param.span()),
                    RecordType::Debug,
                ))),
            })
            // Little dance with new (user-exposed) names and old (internal)
            // names of identifiers. That way, we could do the following
            // even though async_trait (<=0.1.43) rewrites "self" as "_self":
            // ```
            // #[async_trait]
            // impl Foo for FooImpl {
            //     #[instrument(skip(self))]
            //     async fn foo(&self, v: usize) {}
            // }
            // ```
            .map(|(x, record_type)| {
                // if we are inside a function generated by async-trait <=0.1.43, we need to
                // take care to rewrite "_self" as "self" for 'user convenience'
                if self_type.is_some() && x == "_self" {
                    (Ident::new("self", x.span()), (x, record_type))
                } else {
                    (x.clone(), (x, record_type))
                }
            })
            .collect();

        // Every `skip(..)` entry must name an existing parameter (by its user-facing name);
        // otherwise the whole span expression becomes a compile error at that entry.
        for skip in &args.skips {
            if !param_names.iter().map(|(user, _)| user).any(|y| y == skip) {
                return quote_spanned! {skip.span()=>
                    compile_error!("attempting to skip non-existent parameter")
                };
            }
        }

        let target = args.target();

        let parent = args.parent.iter();

        // filter out skipped fields
        let quoted_fields: Vec<_> = param_names
            .iter()
            .filter(|(param, _)| {
                if args.skip_all || args.skips.contains(param) {
                    return false;
                }

                // If any parameters have the same name as a custom field, skip
                // and allow them to be formatted by the custom field.
                // A custom field only shadows a parameter when the first and last segments
                // of its name compare equal (presumably: a single-segment name) and that
                // segment equals the parameter's user-facing name.
                if let Some(ref fields) = args.fields {
                    fields.0.iter().all(|Field { ref name, .. }| {
                        let first = name.first();
                        first != name.last() || !first.iter().any(|name| name == &param)
                    })
                } else {
                    true
                }
            })
            // Value-recordable bindings are passed as-is; everything else is wrapped in
            // `tracing::field::debug` (borrowed, so the binding is not moved).
            .map(|(user_name, (real_name, record_type))| match record_type {
                RecordType::Value => quote!(#user_name = #real_name),
                RecordType::Debug => quote!(#user_name = tracing::field::debug(&#real_name)),
            })
            .collect();

        // replace every use of a variable with its original name
        if let Some(Fields(ref mut fields)) = args.fields {
            let mut replacer = IdentAndTypesRenamer {
                idents: param_names.into_iter().map(|(a, (b, _))| (a, b)).collect(),
                types: Vec::new(),
            };

            // when async-trait <=0.1.43 is in use, replace instances
            // of the "Self" type inside the fields values
            if let Some(self_type) = self_type {
                replacer.types.push(("Self", self_type.clone()));
            }

            for e in fields.iter_mut().filter_map(|f| f.value.as_mut()) {
                syn::visit_mut::visit_expr_mut(&mut replacer, e);
            }
        }

        let custom_fields = &args.fields;

        quote!(tracing::span!(
            target: #target,
            #(parent: #parent,)*
            #level,
            #span_name,
            #(#quoted_fields,)*
            #custom_fields

        ))
    })();

    let target = args.target();

    // Event emitted for `Err(e)` when `err` is requested; `Level::Error` is passed as the
    // fallback level and `Display` formatting is used unless `Debug` was asked for.
    let err_event = match args.err_args {
        Some(event_args) => {
            let level_tokens = event_args.level(Level::Error);
            match event_args.mode {
                FormatMode::Default | FormatMode::Display => Some(quote!(
                    tracing::event!(target: #target, #level_tokens, error = %e)
                )),
                FormatMode::Debug => Some(quote!(
                    tracing::event!(target: #target, #level_tokens, error = ?e)
                )),
            }
        }
        _ => None,
    };

    // Event emitted for the returned value when `ret` is requested; the span's level is
    // passed as the fallback level and `Debug` formatting is used unless `Display` was
    // asked for.
    let ret_event = match args.ret_args {
        Some(event_args) => {
            let level_tokens = event_args.level(args_level);
            match event_args.mode {
                FormatMode::Display => Some(quote!(
                    tracing::event!(target: #target, #level_tokens, return = %x)
                )),
                FormatMode::Default | FormatMode::Debug => Some(quote!(
                    tracing::event!(target: #target, #level_tokens, return = ?x)
                )),
            }
        }
        _ => None,
    };

    // Generate the instrumented function body.
    // If the function is an `async fn`, this will wrap it in an async block,
    // which is `instrument`ed using `tracing-futures`. Otherwise, this will
    // enter the span and then perform the rest of the body.
    // If `err` is in args, instrument any resulting `Err`s.
    // If `ret` is in args, instrument any resulting `Ok`s when the function
    // returns `Result`s, otherwise instrument any resulting values.
    if async_context {
        let mk_fut = match (err_event, ret_event) {
            (Some(err_event), Some(ret_event)) => quote_spanned!(block.span()=>
                async move {
                    match async move #block.await {
                        #[allow(clippy::unit_arg)]
                        Ok(x) => {
                            #ret_event;
                            Ok(x)
                        },
                        Err(e) => {
                            #err_event;
                            Err(e)
                        }
                    }
                }
            ),
            (Some(err_event), None) => quote_spanned!(block.span()=>
                async move {
                    match async move #block.await {
                        #[allow(clippy::unit_arg)]
                        Ok(x) => Ok(x),
                        Err(e) => {
                            #err_event;
                            Err(e)
                        }
                    }
                }
            ),
            (None, Some(ret_event)) => quote_spanned!(block.span()=>
                async move {
                    let x = async move #block.await;
                    #ret_event;
                    x
                }
            ),
            (None, None) => quote_spanned!(block.span()=>
                async move #block
            ),
        };

        // The span is always constructed; when it is disabled the future is awaited
        // directly, without instrumentation and without running the `follows_from` links.
        return quote!(
            let __tracing_attr_span = #span;
            let __tracing_instrument_future = #mk_fut;
            if !__tracing_attr_span.is_disabled() {
                #follows_from
                tracing::Instrument::instrument(
                    __tracing_instrument_future,
                    __tracing_attr_span
                )
                .await
            } else {
                __tracing_instrument_future.await
            }
        );
    }

    let span = quote!(
        // These variables are left uninitialized and initialized only
        // if the tracing level is statically enabled at this point.
        // While the tracing level is also checked at span creation
        // time, that will still create a dummy span, and a dummy guard
        // and drop the dummy guard later. By lazily initializing these
        // variables, Rust will generate a drop flag for them and thus
        // only drop the guard if it was created. This creates code that
        // is very straightforward for LLVM to optimize out if the tracing
        // level is statically disabled, while not causing any performance
        // regression in case the level is enabled.
        let __tracing_attr_span;
        let __tracing_attr_guard;
        if tracing::level_enabled!(#level) || tracing::if_log_enabled!(#level, {true} else {false}) {
            __tracing_attr_span = #span;
            #follows_from
            __tracing_attr_guard = __tracing_attr_span.enter();
        }
    );

    // Synchronous bodies run inside an immediately-invoked `move` closure whenever the
    // result has to be inspected, so that an early `return` in the user's body still
    // produces a value that the `err`/`ret` events can observe.
    match (err_event, ret_event) {
        (Some(err_event), Some(ret_event)) => quote_spanned! {block.span()=>
            #span
            #[allow(clippy::redundant_closure_call)]
            match (move || #block)() {
                #[allow(clippy::unit_arg)]
                Ok(x) => {
                    #ret_event;
                    Ok(x)
                },
                Err(e) => {
                    #err_event;
                    Err(e)
                }
            }
        },
        (Some(err_event), None) => quote_spanned!(block.span()=>
            #span
            #[allow(clippy::redundant_closure_call)]
            match (move || #block)() {
                #[allow(clippy::unit_arg)]
                Ok(x) => Ok(x),
                Err(e) => {
                    #err_event;
                    Err(e)
                }
            }
        ),
        (None, Some(ret_event)) => quote_spanned!(block.span()=>
            #span
            #[allow(clippy::redundant_closure_call)]
            let x = (move || #block)();
            #ret_event;
            x
        ),
        (None, None) => quote_spanned!(block.span() =>
            // Because `quote` produces a stream of tokens _without_ whitespace, the
            // `if` and the block will appear directly next to each other. This
            // generates a clippy lint about suspicious `if/else` formatting.
            // Therefore, suppress the lint inside the generated code...
            #[allow(clippy::suspicious_else_formatting)]
            {
                #span
                // ...but turn the lint back on inside the function body.
                #[warn(clippy::suspicious_else_formatting)]
                #block
            }
        ),
    }
}
403
/// How a parameter's binding is attached to the span as a field.
enum RecordType {
    /// Emit the binding directly, relying on its `tracing::Value` implementation.
    Value,
    /// Wrap the binding in `tracing::field::debug()` so it is recorded through `fmt::Debug`.
    Debug,
}
411
412impl RecordType {
413    /// Array of primitive types which should be recorded as [RecordType::Value].
414    const TYPES_FOR_VALUE: &'static [&'static str] = &[
415        "bool",
416        "str",
417        "u8",
418        "i8",
419        "u16",
420        "i16",
421        "u32",
422        "i32",
423        "u64",
424        "i64",
425        "f32",
426        "f64",
427        "usize",
428        "isize",
429        "NonZeroU8",
430        "NonZeroI8",
431        "NonZeroU16",
432        "NonZeroI16",
433        "NonZeroU32",
434        "NonZeroI32",
435        "NonZeroU64",
436        "NonZeroI64",
437        "NonZeroUsize",
438        "NonZeroIsize",
439        "Wrapping",
440    ];
441
442    /// Parse `RecordType` from [Type] by looking up
443    /// the [RecordType::TYPES_FOR_VALUE] array.
444    fn parse_from_ty(ty: &Type) -> Self {
445        match ty {
446            Type::Path(TypePath { path, .. })
447                if path
448                    .segments
449                    .iter()
450                    .last()
451                    .map(|path_segment| {
452                        let ident = path_segment.ident.to_string();
453                        Self::TYPES_FOR_VALUE.iter().any(|&t| t == ident)
454                    })
455                    .unwrap_or(false) =>
456            {
457                RecordType::Value
458            }
459            Type::Reference(syn::TypeReference { elem, .. }) => RecordType::parse_from_ty(elem),
460            _ => RecordType::Debug,
461        }
462    }
463}
464
465fn param_names(pat: Pat, record_type: RecordType) -> Box<dyn Iterator<Item = (Ident, RecordType)>> {
466    match pat {
467        Pat::Ident(PatIdent { ident, .. }) => Box::new(iter::once((ident, record_type))),
468        Pat::Reference(PatReference { pat, .. }) => param_names(*pat, record_type),
469        // We can't get the concrete type of fields in the struct/tuple
470        // patterns by using `syn`. e.g. `fn foo(Foo { x, y }: Foo) {}`.
471        // Therefore, the struct/tuple patterns in the arguments will just
472        // always be recorded as `RecordType::Debug`.
473        Pat::Struct(PatStruct { fields, .. }) => Box::new(
474            fields
475                .into_iter()
476                .flat_map(|FieldPat { pat, .. }| param_names(*pat, RecordType::Debug)),
477        ),
478        Pat::Tuple(PatTuple { elems, .. }) => Box::new(
479            elems
480                .into_iter()
481                .flat_map(|p| param_names(p, RecordType::Debug)),
482        ),
483        Pat::TupleStruct(PatTupleStruct { elems, .. }) => Box::new(
484            elems
485                .into_iter()
486                .flat_map(|p| param_names(p, RecordType::Debug)),
487        ),
488
489        // The above *should* cover all cases of irrefutable patterns,
490        // but we purposefully don't do any funny business here
491        // (such as panicking) because that would obscure rustc's
492        // much more informative error message.
493        _ => Box::new(iter::empty()),
494    }
495}
496
497/// The specific async code pattern that was detected
498enum AsyncKind<'a> {
499    /// Immediately-invoked async fn, as generated by `async-trait <= 0.1.43`:
500    /// `async fn foo<...>(...) {...}; Box::pin(foo<...>(...))`
501    Function(&'a ItemFn),
502    /// A function returning an async (move) block, optionally `Box::pin`-ed,
503    /// as generated by `async-trait >= 0.1.44`:
504    /// `Box::pin(async move { ... })`
505    Async {
506        async_expr: &'a ExprAsync,
507        pinned_box: bool,
508    },
509}
510
511pub(crate) struct AsyncInfo<'block> {
512    // statement that must be patched
513    source_stmt: &'block Stmt,
514    kind: AsyncKind<'block>,
515    self_type: Option<TypePath>,
516    input: &'block ItemFn,
517}
518
519impl<'block> AsyncInfo<'block> {
520    /// Get the AST of the inner function we need to hook, if it looks like a
521    /// manual future implementation.
522    ///
523    /// When we are given a function that returns a (pinned) future containing the
524    /// user logic, it is that (pinned) future that needs to be instrumented.
525    /// Were we to instrument its parent, we would only collect information
526    /// regarding the allocation of that future, and not its own span of execution.
527    ///
528    /// We inspect the block of the function to find if it matches any of the
529    /// following patterns:
530    ///
531    /// - Immediately-invoked async fn, as generated by `async-trait <= 0.1.43`:
532    ///   `async fn foo<...>(...) {...}; Box::pin(foo<...>(...))`
533    ///
534    /// - A function returning an async (move) block, optionally `Box::pin`-ed,
535    ///   as generated by `async-trait >= 0.1.44`:
536    ///   `Box::pin(async move { ... })`
537    ///
538    /// We the return the statement that must be instrumented, along with some
539    /// other information.
540    /// 'gen_body' will then be able to use that information to instrument the
541    /// proper function/future.
542    ///
543    /// (this follows the approach suggested in
544    /// https://github.com/dtolnay/async-trait/issues/45#issuecomment-571245673)
545    pub(crate) fn from_fn(input: &'block ItemFn) -> Option<Self> {
546        // are we in an async context? If yes, this isn't a manual async-like pattern
547        if input.sig.asyncness.is_some() {
548            return None;
549        }
550
551        let block = &input.block;
552
553        // list of async functions declared inside the block
554        let inside_funs = block.stmts.iter().filter_map(|stmt| {
555            if let Stmt::Item(Item::Fn(fun)) = &stmt {
556                // If the function is async, this is a candidate
557                if fun.sig.asyncness.is_some() {
558                    return Some((stmt, fun));
559                }
560            }
561            None
562        });
563
564        // last expression of the block: it determines the return value of the
565        // block, this is quite likely a `Box::pin` statement or an async block
566        let (last_expr_stmt, last_expr) = block.stmts.iter().rev().find_map(|stmt| {
567            if let Stmt::Expr(expr, _semi) = stmt {
568                Some((stmt, expr))
569            } else {
570                None
571            }
572        })?;
573
574        // is the last expression an async block?
575        if let Expr::Async(async_expr) = last_expr {
576            return Some(AsyncInfo {
577                source_stmt: last_expr_stmt,
578                kind: AsyncKind::Async {
579                    async_expr,
580                    pinned_box: false,
581                },
582                self_type: None,
583                input,
584            });
585        }
586
587        // is the last expression a function call?
588        let (outside_func, outside_args) = match last_expr {
589            Expr::Call(ExprCall { func, args, .. }) => (func, args),
590            _ => return None,
591        };
592
593        // is it a call to `Box::pin()`?
594        let path = match outside_func.as_ref() {
595            Expr::Path(path) => &path.path,
596            _ => return None,
597        };
598        if !path_to_string(path).ends_with("Box::pin") {
599            return None;
600        }
601
602        // Does the call take an argument? If it doesn't,
603        // it's not gonna compile anyway, but that's no reason
604        // to (try to) perform an out of bounds access
605        if outside_args.is_empty() {
606            return None;
607        }
608
609        // Is the argument to Box::pin an async block that
610        // captures its arguments?
611        if let Expr::Async(async_expr) = &outside_args[0] {
612            return Some(AsyncInfo {
613                source_stmt: last_expr_stmt,
614                kind: AsyncKind::Async {
615                    async_expr,
616                    pinned_box: true,
617                },
618                self_type: None,
619                input,
620            });
621        }
622
623        // Is the argument to Box::pin a function call itself?
624        let func = match &outside_args[0] {
625            Expr::Call(ExprCall { func, .. }) => func,
626            _ => return None,
627        };
628
629        // "stringify" the path of the function called
630        let func_name = match **func {
631            Expr::Path(ref func_path) => path_to_string(&func_path.path),
632            _ => return None,
633        };
634
635        // Was that function defined inside of the current block?
636        // If so, retrieve the statement where it was declared and the function itself
637        let (stmt_func_declaration, func) = inside_funs
638            .into_iter()
639            .find(|(_, fun)| fun.sig.ident == func_name)?;
640
641        // If "_self" is present as an argument, we store its type to be able to rewrite "Self" (the
642        // parameter type) with the type of "_self"
643        let mut self_type = None;
644        for arg in &func.sig.inputs {
645            if let FnArg::Typed(ty) = arg {
646                if let Pat::Ident(PatIdent { ref ident, .. }) = *ty.pat {
647                    if ident == "_self" {
648                        let mut ty = *ty.ty.clone();
649                        // extract the inner type if the argument is "&self" or "&mut self"
650                        if let Type::Reference(syn::TypeReference { elem, .. }) = ty {
651                            ty = *elem;
652                        }
653
654                        if let Type::Path(tp) = ty {
655                            self_type = Some(tp);
656                            break;
657                        }
658                    }
659                }
660            }
661        }
662
663        Some(AsyncInfo {
664            source_stmt: stmt_func_declaration,
665            kind: AsyncKind::Function(func),
666            self_type,
667            input,
668        })
669    }
670
671    pub(crate) fn gen_async(
672        self,
673        args: InstrumentArgs,
674        instrumented_function_name: &str,
675    ) -> Result<proc_macro::TokenStream, syn::Error> {
676        // let's rewrite some statements!
677        let mut out_stmts: Vec<TokenStream> = self
678            .input
679            .block
680            .stmts
681            .iter()
682            .map(|stmt| stmt.to_token_stream())
683            .collect();
684
685        if let Some((iter, _stmt)) = self
686            .input
687            .block
688            .stmts
689            .iter()
690            .enumerate()
691            .find(|(_iter, stmt)| *stmt == self.source_stmt)
692        {
693            // instrument the future by rewriting the corresponding statement
694            out_stmts[iter] = match self.kind {
695                // `Box::pin(immediately_invoked_async_fn())`
696                AsyncKind::Function(fun) => {
697                    let fun = MaybeItemFn::from(fun.clone());
698                    gen_function(
699                        fun.as_ref(),
700                        args,
701                        instrumented_function_name,
702                        self.self_type.as_ref(),
703                    )
704                }
705                // `async move { ... }`, optionally pinned
706                AsyncKind::Async {
707                    async_expr,
708                    pinned_box,
709                } => {
710                    let instrumented_block = gen_block(
711                        &async_expr.block,
712                        &self.input.sig.inputs,
713                        true,
714                        args,
715                        instrumented_function_name,
716                        None,
717                    );
718                    let async_attrs = &async_expr.attrs;
719                    if pinned_box {
720                        quote! {
721                            Box::pin(#(#async_attrs) * async move { #instrumented_block })
722                        }
723                    } else {
724                        quote! {
725                            #(#async_attrs) * async move { #instrumented_block }
726                        }
727                    }
728                }
729            };
730        }
731
732        let vis = &self.input.vis;
733        let sig = &self.input.sig;
734        let attrs = &self.input.attrs;
735        Ok(quote!(
736            #(#attrs) *
737            #vis #sig {
738                #(#out_stmts) *
739            }
740        )
741        .into())
742    }
743}
744
745// Return a path as a String
746fn path_to_string(path: &Path) -> String {
747    use std::fmt::Write;
748    // some heuristic to prevent too many allocations
749    let mut res = String::with_capacity(path.segments.len() * 5);
750    for i in 0..path.segments.len() {
751        write!(&mut res, "{}", path.segments[i].ident)
752            .expect("writing to a String should never fail");
753        if i < path.segments.len() - 1 {
754            res.push_str("::");
755        }
756    }
757    res
758}
759
760/// A visitor struct to replace idents and types in some piece
761/// of code (e.g. the "self" and "Self" tokens in user-supplied
762/// fields expressions when the function is generated by an old
763/// version of async-trait).
764struct IdentAndTypesRenamer<'a> {
765    types: Vec<(&'a str, TypePath)>,
766    idents: Vec<(Ident, Ident)>,
767}
768
769impl<'a> VisitMut for IdentAndTypesRenamer<'a> {
770    // we deliberately compare strings because we want to ignore the spans
771    // If we apply clippy's lint, the behavior changes
772    #[allow(clippy::cmp_owned)]
773    fn visit_ident_mut(&mut self, id: &mut Ident) {
774        for (old_ident, new_ident) in &self.idents {
775            if id.to_string() == old_ident.to_string() {
776                *id = new_ident.clone();
777            }
778        }
779    }
780
781    fn visit_type_mut(&mut self, ty: &mut Type) {
782        for (type_name, new_type) in &self.types {
783            if let Type::Path(TypePath { path, .. }) = ty {
784                if path_to_string(path) == *type_name {
785                    *ty = Type::Path(new_type.clone());
786                }
787            }
788        }
789    }
790}
791
792// A visitor struct that replace an async block by its patched version
793struct AsyncTraitBlockReplacer<'a> {
794    block: &'a Block,
795    patched_block: Block,
796}
797
798impl<'a> VisitMut for AsyncTraitBlockReplacer<'a> {
799    fn visit_block_mut(&mut self, i: &mut Block) {
800        if i == self.block {
801            *i = self.patched_block.clone();
802        }
803    }
804}
805
// Visitor that turns every `impl Trait` type into the inferred type `_`, so that a
// return type can be used as the annotated type on the LHS of a `let` statement.
struct ImplTraitEraser;
809
810impl VisitMut for ImplTraitEraser {
811    fn visit_type_mut(&mut self, t: &mut Type) {
812        if let Type::ImplTrait(..) = t {
813            *t = syn::TypeInfer {
814                underscore_token: Token![_](t.span()),
815            }
816            .into();
817        } else {
818            syn::visit_mut::visit_type_mut(self, t);
819        }
820    }
821}
822
823fn erase_impl_trait(ty: &Type) -> Type {
824    let mut ty = ty.clone();
825    ImplTraitEraser.visit_type_mut(&mut ty);
826    ty
827}