tokio_macros/
entry.rs

1use proc_macro2::{Span, TokenStream, TokenTree};
2use quote::{quote, quote_spanned, ToTokens};
3use syn::parse::{Parse, ParseStream, Parser};
4use syn::{braced, Attribute, Ident, Path, Signature, Visibility};
5
6// syn::AttributeArgs does not implement syn::Parse
7type AttributeArgs = syn::punctuated::Punctuated<syn::Meta, syn::Token![,]>;
8
/// Which Tokio runtime the generated code constructs.
#[derive(Clone, Copy, PartialEq)]
enum RuntimeFlavor {
    /// `flavor = "current_thread"`.
    CurrentThread,
    /// `flavor = "multi_thread"`.
    Threaded,
    /// `flavor = "local"`.
    Local,
}

impl RuntimeFlavor {
    /// Parses the value of a `flavor = "..."` argument.
    ///
    /// Accepts `current_thread`, `multi_thread` and `local`. A few legacy or
    /// commonly mistaken names get a dedicated hint. Anything else produces a
    /// generic message listing the valid flavors.
    fn from_str(s: &str) -> Result<RuntimeFlavor, String> {
        let hint = match s {
            "current_thread" => return Ok(Self::CurrentThread),
            "multi_thread" => return Ok(Self::Threaded),
            "local" => return Ok(Self::Local),
            "single_thread" => "The single threaded runtime flavor is called `current_thread`.",
            "basic_scheduler" => {
                "The `basic_scheduler` runtime flavor has been renamed to `current_thread`."
            }
            "threaded_scheduler" => {
                "The `threaded_scheduler` runtime flavor has been renamed to `multi_thread`."
            }
            _ => {
                return Err(format!(
                    "No such runtime flavor `{s}`. The runtime flavors are `current_thread`, `local`, and `multi_thread`."
                ))
            }
        };
        Err(hint.to_owned())
    }
}
29
30#[derive(Clone, Copy, PartialEq)]
31enum UnhandledPanic {
32    Ignore,
33    ShutdownRuntime,
34}
35
36impl UnhandledPanic {
37    fn from_str(s: &str) -> Result<UnhandledPanic, String> {
38        match s {
39            "ignore" => Ok(UnhandledPanic::Ignore),
40            "shutdown_runtime" => Ok(UnhandledPanic::ShutdownRuntime),
41            _ => Err(format!("No such unhandled panic behavior `{s}`. The unhandled panic behaviors are `ignore` and `shutdown_runtime`.")),
42        }
43    }
44
45    fn into_tokens(self, crate_path: &TokenStream) -> TokenStream {
46        match self {
47            UnhandledPanic::Ignore => quote! { #crate_path::runtime::UnhandledPanic::Ignore },
48            UnhandledPanic::ShutdownRuntime => {
49                quote! { #crate_path::runtime::UnhandledPanic::ShutdownRuntime }
50            }
51        }
52    }
53}
54
55struct FinalConfig {
56    name: Option<String>,
57    flavor: RuntimeFlavor,
58    worker_threads: Option<usize>,
59    start_paused: Option<bool>,
60    crate_name: Option<Path>,
61    unhandled_panic: Option<UnhandledPanic>,
62}
63
64/// Config used in case of the attribute not being able to build a valid config
65const DEFAULT_ERROR_CONFIG: FinalConfig = FinalConfig {
66    name: None,
67    flavor: RuntimeFlavor::CurrentThread,
68    worker_threads: None,
69    start_paused: None,
70    crate_name: None,
71    unhandled_panic: None,
72};
73
74struct Configuration {
75    name: Option<String>,
76    rt_multi_thread_available: bool,
77    default_flavor: RuntimeFlavor,
78    flavor: Option<RuntimeFlavor>,
79    worker_threads: Option<(usize, Span)>,
80    start_paused: Option<(bool, Span)>,
81    is_test: bool,
82    crate_name: Option<Path>,
83    unhandled_panic: Option<(UnhandledPanic, Span)>,
84}
85
86impl Configuration {
87    fn new(is_test: bool, rt_multi_thread: bool) -> Self {
88        Configuration {
89            name: None,
90            rt_multi_thread_available: rt_multi_thread,
91            default_flavor: match is_test {
92                true => RuntimeFlavor::CurrentThread,
93                false => RuntimeFlavor::Threaded,
94            },
95            flavor: None,
96            worker_threads: None,
97            start_paused: None,
98            is_test,
99            crate_name: None,
100            unhandled_panic: None,
101        }
102    }
103
104    fn set_name(&mut self, name: syn::Lit, span: Span) -> Result<(), syn::Error> {
105        if self.name.is_some() {
106            return Err(syn::Error::new(span, "`name` set multiple times."));
107        }
108
109        let runtime_name = parse_string(name, span, "name")?;
110        self.name = Some(runtime_name);
111        Ok(())
112    }
113
114    fn set_flavor(&mut self, runtime: syn::Lit, span: Span) -> Result<(), syn::Error> {
115        if self.flavor.is_some() {
116            return Err(syn::Error::new(span, "`flavor` set multiple times."));
117        }
118
119        let runtime_str = parse_string(runtime, span, "flavor")?;
120        let runtime =
121            RuntimeFlavor::from_str(&runtime_str).map_err(|err| syn::Error::new(span, err))?;
122        self.flavor = Some(runtime);
123        Ok(())
124    }
125
126    fn set_worker_threads(
127        &mut self,
128        worker_threads: syn::Lit,
129        span: Span,
130    ) -> Result<(), syn::Error> {
131        if self.worker_threads.is_some() {
132            return Err(syn::Error::new(
133                span,
134                "`worker_threads` set multiple times.",
135            ));
136        }
137
138        let worker_threads = parse_int(worker_threads, span, "worker_threads")?;
139        if worker_threads == 0 {
140            return Err(syn::Error::new(span, "`worker_threads` may not be 0."));
141        }
142        self.worker_threads = Some((worker_threads, span));
143        Ok(())
144    }
145
146    fn set_start_paused(&mut self, start_paused: syn::Lit, span: Span) -> Result<(), syn::Error> {
147        if self.start_paused.is_some() {
148            return Err(syn::Error::new(span, "`start_paused` set multiple times."));
149        }
150
151        let start_paused = parse_bool(start_paused, span, "start_paused")?;
152        self.start_paused = Some((start_paused, span));
153        Ok(())
154    }
155
156    fn set_crate_name(&mut self, name: syn::Lit, span: Span) -> Result<(), syn::Error> {
157        if self.crate_name.is_some() {
158            return Err(syn::Error::new(span, "`crate` set multiple times."));
159        }
160        let name_path = parse_path(name, span, "crate")?;
161        self.crate_name = Some(name_path);
162        Ok(())
163    }
164
165    fn set_unhandled_panic(
166        &mut self,
167        unhandled_panic: syn::Lit,
168        span: Span,
169    ) -> Result<(), syn::Error> {
170        if self.unhandled_panic.is_some() {
171            return Err(syn::Error::new(
172                span,
173                "`unhandled_panic` set multiple times.",
174            ));
175        }
176
177        let unhandled_panic = parse_string(unhandled_panic, span, "unhandled_panic")?;
178        let unhandled_panic =
179            UnhandledPanic::from_str(&unhandled_panic).map_err(|err| syn::Error::new(span, err))?;
180        self.unhandled_panic = Some((unhandled_panic, span));
181        Ok(())
182    }
183
184    fn macro_name(&self) -> &'static str {
185        if self.is_test {
186            "tokio::test"
187        } else {
188            "tokio::main"
189        }
190    }
191
192    fn build(&self) -> Result<FinalConfig, syn::Error> {
193        use RuntimeFlavor as F;
194
195        let flavor = self.flavor.unwrap_or(self.default_flavor);
196
197        let worker_threads = match (flavor, self.worker_threads) {
198            (F::CurrentThread | F::Local, Some((_, worker_threads_span))) => {
199                let msg = format!(
200                    "The `worker_threads` option requires the `multi_thread` runtime flavor. Use `#[{}(flavor = \"multi_thread\")]`",
201                    self.macro_name(),
202                );
203                return Err(syn::Error::new(worker_threads_span, msg));
204            }
205            (F::CurrentThread | F::Local, None) => None,
206            (F::Threaded, worker_threads) if self.rt_multi_thread_available => {
207                worker_threads.map(|(val, _span)| val)
208            }
209            (F::Threaded, _) => {
210                let msg = if self.flavor.is_none() {
211                    "The default runtime flavor is `multi_thread`, but the `rt-multi-thread` feature is disabled."
212                } else {
213                    "The runtime flavor `multi_thread` requires the `rt-multi-thread` feature."
214                };
215                return Err(syn::Error::new(Span::call_site(), msg));
216            }
217        };
218
219        let start_paused = match (flavor, self.start_paused) {
220            (F::Threaded, Some((_, start_paused_span))) => {
221                let msg = format!(
222                    "The `start_paused` option requires the `current_thread` runtime flavor. Use `#[{}(flavor = \"current_thread\")]`",
223                    self.macro_name(),
224                );
225                return Err(syn::Error::new(start_paused_span, msg));
226            }
227            (F::CurrentThread | F::Local, Some((start_paused, _))) => Some(start_paused),
228            (_, None) => None,
229        };
230
231        let unhandled_panic = match (flavor, self.unhandled_panic) {
232            (F::Threaded, Some((_, unhandled_panic_span))) => {
233                let msg = format!(
234                    "The `unhandled_panic` option requires the `current_thread` runtime flavor. Use `#[{}(flavor = \"current_thread\")]`",
235                    self.macro_name(),
236                );
237                return Err(syn::Error::new(unhandled_panic_span, msg));
238            }
239            (F::CurrentThread | F::Local, Some((unhandled_panic, _))) => Some(unhandled_panic),
240            (_, None) => None,
241        };
242
243        Ok(FinalConfig {
244            name: self.name.clone(),
245            crate_name: self.crate_name.clone(),
246            flavor,
247            worker_threads,
248            start_paused,
249            unhandled_panic,
250        })
251    }
252}
253
254fn parse_int(int: syn::Lit, span: Span, field: &str) -> Result<usize, syn::Error> {
255    match int {
256        syn::Lit::Int(lit) => match lit.base10_parse::<usize>() {
257            Ok(value) => Ok(value),
258            Err(e) => Err(syn::Error::new(
259                span,
260                format!("Failed to parse value of `{field}` as integer: {e}"),
261            )),
262        },
263        _ => Err(syn::Error::new(
264            span,
265            format!("Failed to parse value of `{field}` as integer."),
266        )),
267    }
268}
269
270fn parse_string(int: syn::Lit, span: Span, field: &str) -> Result<String, syn::Error> {
271    match int {
272        syn::Lit::Str(s) => Ok(s.value()),
273        syn::Lit::Verbatim(s) => Ok(s.to_string()),
274        _ => Err(syn::Error::new(
275            span,
276            format!("Failed to parse value of `{field}` as string."),
277        )),
278    }
279}
280
281fn parse_path(lit: syn::Lit, span: Span, field: &str) -> Result<Path, syn::Error> {
282    match lit {
283        syn::Lit::Str(s) => {
284            let err = syn::Error::new(
285                span,
286                format!(
287                    "Failed to parse value of `{}` as path: \"{}\"",
288                    field,
289                    s.value()
290                ),
291            );
292            s.parse::<syn::Path>().map_err(|_| err.clone())
293        }
294        _ => Err(syn::Error::new(
295            span,
296            format!("Failed to parse value of `{field}` as path."),
297        )),
298    }
299}
300
301fn parse_bool(bool: syn::Lit, span: Span, field: &str) -> Result<bool, syn::Error> {
302    match bool {
303        syn::Lit::Bool(b) => Ok(b.value),
304        _ => Err(syn::Error::new(
305            span,
306            format!("Failed to parse value of `{field}` as bool."),
307        )),
308    }
309}
310
311fn contains_impl_trait(ty: &syn::Type) -> bool {
312    match ty {
313        syn::Type::ImplTrait(_) => true,
314        syn::Type::Array(t) => contains_impl_trait(&t.elem),
315        syn::Type::Ptr(t) => contains_impl_trait(&t.elem),
316        syn::Type::Reference(t) => contains_impl_trait(&t.elem),
317        syn::Type::Slice(t) => contains_impl_trait(&t.elem),
318        syn::Type::Tuple(t) => t.elems.iter().any(contains_impl_trait),
319        syn::Type::Paren(t) => contains_impl_trait(&t.elem),
320        syn::Type::Group(t) => contains_impl_trait(&t.elem),
321        syn::Type::Path(t) => match t.path.segments.last() {
322            Some(segment) => match &segment.arguments {
323                syn::PathArguments::AngleBracketed(args) => args.args.iter().any(|arg| match arg {
324                    syn::GenericArgument::Type(t) => contains_impl_trait(t),
325                    syn::GenericArgument::AssocType(t) => contains_impl_trait(&t.ty),
326                    _ => false,
327                }),
328                syn::PathArguments::Parenthesized(args) => {
329                    args.inputs.iter().any(contains_impl_trait)
330                        || matches!(&args.output, syn::ReturnType::Type(_, t) if contains_impl_trait(t))
331                }
332                syn::PathArguments::None => false,
333            },
334            None => false,
335        },
336        _ => false,
337    }
338}
339
340fn build_config(
341    input: &ItemFn,
342    args: AttributeArgs,
343    is_test: bool,
344    rt_multi_thread: bool,
345) -> Result<FinalConfig, syn::Error> {
346    if input.sig.asyncness.is_none() {
347        let msg = "the `async` keyword is missing from the function declaration";
348        return Err(syn::Error::new_spanned(input.sig.fn_token, msg));
349    }
350
351    let mut config = Configuration::new(is_test, rt_multi_thread);
352    let macro_name = config.macro_name();
353
354    for arg in args {
355        match arg {
356            syn::Meta::NameValue(namevalue) => {
357                let ident = namevalue
358                    .path
359                    .get_ident()
360                    .ok_or_else(|| {
361                        syn::Error::new_spanned(&namevalue, "Must have specified ident")
362                    })?
363                    .to_string()
364                    .to_lowercase();
365                let lit = match &namevalue.value {
366                    syn::Expr::Lit(syn::ExprLit { lit, .. }) => lit,
367                    expr => return Err(syn::Error::new_spanned(expr, "Must be a literal")),
368                };
369                match ident.as_str() {
370                    "worker_threads" => {
371                        config.set_worker_threads(lit.clone(), syn::spanned::Spanned::span(lit))?;
372                    }
373                    "flavor" => {
374                        config.set_flavor(lit.clone(), syn::spanned::Spanned::span(lit))?;
375                    }
376                    "start_paused" => {
377                        config.set_start_paused(lit.clone(), syn::spanned::Spanned::span(lit))?;
378                    }
379                    "core_threads" => {
380                        let msg = "Attribute `core_threads` is renamed to `worker_threads`";
381                        return Err(syn::Error::new_spanned(namevalue, msg));
382                    }
383                    "crate" => {
384                        config.set_crate_name(lit.clone(), syn::spanned::Spanned::span(lit))?;
385                    }
386                    "unhandled_panic" => {
387                        config
388                            .set_unhandled_panic(lit.clone(), syn::spanned::Spanned::span(lit))?;
389                    }
390                    "name" => {
391                        config.set_name(lit.clone(), syn::spanned::Spanned::span(lit))?;
392                    }
393                    name => {
394                        let msg = format!(
395                            "Unknown attribute {name} is specified; expected one of: `flavor`, `worker_threads`, `start_paused`, `crate`, `unhandled_panic`, `name`.",
396                        );
397                        return Err(syn::Error::new_spanned(namevalue, msg));
398                    }
399                }
400            }
401            syn::Meta::Path(path) => {
402                let name = path
403                    .get_ident()
404                    .ok_or_else(|| syn::Error::new_spanned(&path, "Must have specified ident"))?
405                    .to_string()
406                    .to_lowercase();
407                let msg = match name.as_str() {
408                    "threaded_scheduler" | "multi_thread" => {
409                        format!(
410                            "Set the runtime flavor with #[{macro_name}(flavor = \"multi_thread\")]."
411                        )
412                    }
413                    "basic_scheduler" | "current_thread" | "single_threaded" => {
414                        format!(
415                            "Set the runtime flavor with #[{macro_name}(flavor = \"current_thread\")]."
416                        )
417                    }
418                    "flavor" | "worker_threads" | "start_paused" | "crate" | "unhandled_panic"
419                    | "name" => {
420                        format!("The `{name}` attribute requires an argument.")
421                    }
422                    name => {
423                        format!("Unknown attribute {name} is specified; expected one of: `flavor`, `worker_threads`, `start_paused`, `crate`, `unhandled_panic`, `name`.")
424                    }
425                };
426                return Err(syn::Error::new_spanned(path, msg));
427            }
428            other => {
429                return Err(syn::Error::new_spanned(
430                    other,
431                    "Unknown attribute inside the macro",
432                ));
433            }
434        }
435    }
436
437    config.build()
438}
439
440fn parse_knobs(mut input: ItemFn, is_test: bool, config: FinalConfig) -> TokenStream {
441    input.sig.asyncness = None;
442
443    // If type mismatch occurs, the current rustc points to the last statement.
444    let (last_stmt_start_span, last_stmt_end_span) = {
445        let mut last_stmt = input.stmts.last().cloned().unwrap_or_default().into_iter();
446
447        // `Span` on stable Rust has a limitation that only points to the first
448        // token, not the whole tokens. We can work around this limitation by
449        // using the first/last span of the tokens like
450        // `syn::Error::new_spanned` does.
451        let start = last_stmt.next().map_or_else(Span::call_site, |t| t.span());
452        let end = last_stmt.last().map_or(start, |t| t.span());
453        (start, end)
454    };
455
456    let crate_path = config
457        .crate_name
458        .map(ToTokens::into_token_stream)
459        .unwrap_or_else(|| {
460            Ident::new("tokio", Span::call_site().located_at(last_stmt_start_span))
461                .into_token_stream()
462        });
463
464    let use_builder = quote_spanned! {Span::call_site().located_at(last_stmt_start_span)=>
465        use #crate_path::runtime::Builder;
466    };
467
468    let mut rt = match config.flavor {
469        RuntimeFlavor::CurrentThread | RuntimeFlavor::Local => {
470            quote_spanned! {last_stmt_start_span=>
471                Builder::new_current_thread()
472            }
473        }
474        RuntimeFlavor::Threaded => quote_spanned! {last_stmt_start_span=>
475            Builder::new_multi_thread()
476        },
477    };
478
479    let build = if let RuntimeFlavor::Local = config.flavor {
480        quote_spanned! {last_stmt_start_span=> build_local(Default::default())}
481    } else {
482        quote_spanned! {last_stmt_start_span=> build()}
483    };
484
485    if let Some(v) = config.worker_threads {
486        rt = quote_spanned! {last_stmt_start_span=> #rt.worker_threads(#v) };
487    }
488    if let Some(v) = config.start_paused {
489        rt = quote_spanned! {last_stmt_start_span=> #rt.start_paused(#v) };
490    }
491    if let Some(v) = config.unhandled_panic {
492        let unhandled_panic = v.into_tokens(&crate_path);
493        rt = quote_spanned! {last_stmt_start_span=> #rt.unhandled_panic(#unhandled_panic) };
494    }
495    if let Some(v) = config.name {
496        rt = quote_spanned! {last_stmt_start_span=> #rt.name(#v) };
497    }
498
499    let generated_attrs = if is_test {
500        quote! {
501            #[::core::prelude::v1::test]
502        }
503    } else {
504        quote! {}
505    };
506
507    let body_ident = quote! { body };
508    // This explicit `return` is intentional. See tokio-rs/tokio#4636
509    let last_block = quote_spanned! {last_stmt_end_span=>
510
511        #[allow(clippy::expect_used, clippy::diverging_sub_expression, clippy::needless_return, clippy::unwrap_in_result)]
512        {
513            #use_builder
514
515            return #rt
516                .enable_all()
517                .#build
518                .expect("Failed building the Runtime")
519                .block_on(#body_ident);
520        }
521
522    };
523
524    let body = input.body();
525
526    // For test functions pin the body to the stack and use `Pin<&mut dyn
527    // Future>` to reduce the amount of `Runtime::block_on` (and related
528    // functions) copies we generate during compilation due to the generic
529    // parameter `F` (the future to block on). This could have an impact on
530    // performance, but because it's only for testing it's unlikely to be very
531    // large.
532    //
533    // We don't do this for the main function as it should only be used once so
534    // there will be no benefit.
535    let output_type = match &input.sig.output {
536        // For functions with no return value syn doesn't print anything,
537        // but that doesn't work as `Output` for our boxed `Future`, so
538        // default to `()` (the same type as the function output).
539        syn::ReturnType::Default => quote! { () },
540        syn::ReturnType::Type(_, ret_type) => quote! { #ret_type },
541    };
542
543    let body = if is_test {
544        quote! {
545            let body = async #body;
546            #crate_path::pin!(body);
547            let body: ::core::pin::Pin<&mut dyn ::core::future::Future<Output = #output_type>> = body;
548        }
549    } else {
550        // force typecheck without runtime overhead
551        let check_block = match &input.sig.output {
552            syn::ReturnType::Type(_, t)
553                if matches!(**t, syn::Type::Never(_)) || contains_impl_trait(t) =>
554            {
555                quote! {}
556            }
557            _ => quote! {
558                if false {
559                    let _: &dyn ::core::future::Future<Output = #output_type> = &body;
560                }
561            },
562        };
563
564        quote! {
565            let body = async #body;
566            // Compile-time assertion that the future's output matches the return type.
567            let body = {
568                #check_block
569                body
570            };
571        }
572    };
573
574    input.into_tokens(generated_attrs, body, last_block)
575}
576
577fn token_stream_with_error(mut tokens: TokenStream, error: syn::Error) -> TokenStream {
578    tokens.extend(error.into_compile_error());
579    tokens
580}
581
582pub(crate) fn main(args: TokenStream, item: TokenStream, rt_multi_thread: bool) -> TokenStream {
583    // If any of the steps for this macro fail, we still want to expand to an item that is as close
584    // to the expected output as possible. This helps out IDEs such that completions and other
585    // related features keep working.
586    let input: ItemFn = match syn::parse2(item.clone()) {
587        Ok(it) => it,
588        Err(e) => return token_stream_with_error(item, e),
589    };
590
591    let config = if input.sig.ident == "main" && !input.sig.inputs.is_empty() {
592        let msg = "the main function cannot accept arguments";
593        Err(syn::Error::new_spanned(&input.sig.ident, msg))
594    } else {
595        AttributeArgs::parse_terminated
596            .parse2(args)
597            .and_then(|args| build_config(&input, args, false, rt_multi_thread))
598    };
599
600    match config {
601        Ok(config) => parse_knobs(input, false, config),
602        Err(e) => token_stream_with_error(parse_knobs(input, false, DEFAULT_ERROR_CONFIG), e),
603    }
604}
605
606// Check whether given attribute is a test attribute of forms:
607// * `#[test]`
608// * `#[core::prelude::*::test]` or `#[::core::prelude::*::test]`
609// * `#[std::prelude::*::test]` or `#[::std::prelude::*::test]`
610fn is_test_attribute(attr: &Attribute) -> bool {
611    let path = match &attr.meta {
612        syn::Meta::Path(path) => path,
613        _ => return false,
614    };
615    let candidates = [
616        ["core", "prelude", "*", "test"],
617        ["std", "prelude", "*", "test"],
618    ];
619    if path.leading_colon.is_none()
620        && path.segments.len() == 1
621        && path.segments[0].arguments.is_none()
622        && path.segments[0].ident == "test"
623    {
624        return true;
625    } else if path.segments.len() != candidates[0].len() {
626        return false;
627    }
628    candidates.into_iter().any(|segments| {
629        path.segments.iter().zip(segments).all(|(segment, path)| {
630            segment.arguments.is_none() && (path == "*" || segment.ident == path)
631        })
632    })
633}
634
635pub(crate) fn test(args: TokenStream, item: TokenStream, rt_multi_thread: bool) -> TokenStream {
636    // If any of the steps for this macro fail, we still want to expand to an item that is as close
637    // to the expected output as possible. This helps out IDEs such that completions and other
638    // related features keep working.
639    let input: ItemFn = match syn::parse2(item.clone()) {
640        Ok(it) => it,
641        Err(e) => return token_stream_with_error(item, e),
642    };
643    let config = if let Some(attr) = input.attrs().find(|attr| is_test_attribute(attr)) {
644        let msg = "second test attribute is supplied, consider removing or changing the order of your test attributes";
645        Err(syn::Error::new_spanned(attr, msg))
646    } else {
647        AttributeArgs::parse_terminated
648            .parse2(args)
649            .and_then(|args| build_config(&input, args, true, rt_multi_thread))
650    };
651
652    match config {
653        Ok(config) => parse_knobs(input, true, config),
654        Err(e) => token_stream_with_error(parse_knobs(input, true, DEFAULT_ERROR_CONFIG), e),
655    }
656}
657
658struct ItemFn {
659    outer_attrs: Vec<Attribute>,
660    vis: Visibility,
661    sig: Signature,
662    brace_token: syn::token::Brace,
663    inner_attrs: Vec<Attribute>,
664    stmts: Vec<proc_macro2::TokenStream>,
665}
666
667impl ItemFn {
668    /// Access all attributes of the function item.
669    fn attrs(&self) -> impl Iterator<Item = &Attribute> {
670        self.outer_attrs.iter().chain(self.inner_attrs.iter())
671    }
672
673    /// Get the body of the function item in a manner so that it can be
674    /// conveniently used with the `quote!` macro.
675    fn body(&self) -> Body<'_> {
676        Body {
677            brace_token: self.brace_token,
678            stmts: &self.stmts,
679        }
680    }
681
682    /// Convert our local function item into a token stream.
683    fn into_tokens(
684        self,
685        generated_attrs: proc_macro2::TokenStream,
686        body: proc_macro2::TokenStream,
687        last_block: proc_macro2::TokenStream,
688    ) -> TokenStream {
689        let mut tokens = proc_macro2::TokenStream::new();
690        // Outer attributes are simply streamed as-is.
691        for attr in self.outer_attrs {
692            attr.to_tokens(&mut tokens);
693        }
694
695        // Inner attributes require extra care, since they're not supported on
696        // blocks (which is what we're expanded into) we instead lift them
697        // outside of the function. This matches the behavior of `syn`.
698        for mut attr in self.inner_attrs {
699            attr.style = syn::AttrStyle::Outer;
700            attr.to_tokens(&mut tokens);
701        }
702
703        // Add generated macros at the end, so macros processed later are aware of them.
704        generated_attrs.to_tokens(&mut tokens);
705
706        self.vis.to_tokens(&mut tokens);
707        self.sig.to_tokens(&mut tokens);
708
709        self.brace_token.surround(&mut tokens, |tokens| {
710            body.to_tokens(tokens);
711            last_block.to_tokens(tokens);
712        });
713
714        tokens
715    }
716}
717
718impl Parse for ItemFn {
719    #[inline]
720    fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
721        // This parse implementation has been largely lifted from `syn`, with
722        // the exception of:
723        // * We don't have access to the plumbing necessary to parse inner
724        //   attributes in-place.
725        // * We do our own statements parsing to avoid recursively parsing
726        //   entire statements and only look for the parts we're interested in.
727
728        let outer_attrs = input.call(Attribute::parse_outer)?;
729        let vis: Visibility = input.parse()?;
730        let sig: Signature = input.parse()?;
731
732        let content;
733        let brace_token = braced!(content in input);
734        let inner_attrs = Attribute::parse_inner(&content)?;
735
736        let mut buf = proc_macro2::TokenStream::new();
737        let mut stmts = Vec::new();
738
739        while !content.is_empty() {
740            if let Some(semi) = content.parse::<Option<syn::Token![;]>>()? {
741                semi.to_tokens(&mut buf);
742                stmts.push(buf);
743                buf = proc_macro2::TokenStream::new();
744                continue;
745            }
746
747            // Parse a single token tree and extend our current buffer with it.
748            // This avoids parsing the entire content of the sub-tree.
749            buf.extend([content.parse::<TokenTree>()?]);
750        }
751
752        if !buf.is_empty() {
753            stmts.push(buf);
754        }
755
756        Ok(Self {
757            outer_attrs,
758            vis,
759            sig,
760            brace_token,
761            inner_attrs,
762            stmts,
763        })
764    }
765}
766
767struct Body<'a> {
768    brace_token: syn::token::Brace,
769    // Statements, with terminating `;`.
770    stmts: &'a [TokenStream],
771}
772
773impl ToTokens for Body<'_> {
774    fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
775        self.brace_token.surround(tokens, |tokens| {
776            for stmt in self.stmts {
777                stmt.to_tokens(tokens);
778            }
779        });
780    }
781}