// derive_more_impl/deref.rs

use crate::utils::{
    add_extra_where_clauses, numbered_vars, panic_one_field, SingleFieldData, State,
};
use proc_macro2::TokenStream;
use quote::quote;
use syn::{parse::Result, Data, DeriveInput, Error};

8/// Provides the hook to expand `#[derive(Deref)]` into an implementation of `Deref`
9pub fn expand(input: &DeriveInput, trait_name: &'static str) -> Result<TokenStream> {
10    match input.data {
11        Data::Struct(_) => expand_struct(input, trait_name),
12        Data::Enum(_) => expand_enum(input, trait_name),
13        _ => panic!("only structs and enums can use `derive({trait_name})`"),
14    }
15}
16
17fn expand_enum(input: &DeriveInput, trait_name: &'static str) -> Result<TokenStream> {
18    let state = State::with_field_ignore_and_forward(
19        input,
20        trait_name,
21        trait_name.to_lowercase(),
22    )?;
23
24    let trait_path = &state.trait_path;
25    let enum_name = &input.ident;
26
27    let mut target = None;
28    let mut match_arms = vec![];
29    let mut predicates = vec![];
30
31    for variant_state in state.enabled_variant_data().variant_states.into_iter() {
32        let data = variant_state.enabled_fields_data();
33        if data.fields.len() != 1 {
34            panic_one_field(variant_state.trait_name, &variant_state.trait_attr);
35        };
36
37        let vars = numbered_vars(variant_state.fields.len(), "");
38        let matcher = data.matcher(&data.field_indexes, &vars);
39
40        let info = data.infos[0].clone();
41        let field_type = data.field_types[0];
42
43        let (target_, var) = if info.forward {
44            let casted_trait = data.casted_traits[0].clone();
45            predicates.push(quote! { #field_type: #trait_path });
46            (
47                quote! { #casted_trait::Target },
48                quote! { #casted_trait::deref(__0) },
49            )
50        } else {
51            (quote! { #field_type }, quote! { __0 })
52        };
53
54        if target.is_none() {
55            target = Some(target_);
56        }
57
58        match_arms.push(quote! {
59            #matcher => #var
60        });
61    }
62
63    let target = target.unwrap();
64
65    let generics = if predicates.is_empty() {
66        &input.generics
67    } else {
68        &add_extra_where_clauses(&input.generics, quote! { where #(#predicates),* })
69    };
70
71    let (imp_generics, type_generics, where_clause) = generics.split_for_impl();
72
73    Ok(quote! {
74        #[allow(deprecated)] // omit warnings on deprecated fields/variants
75        #[allow(unreachable_code)] // omit warnings for `!` and other unreachable types
76        #[automatically_derived]
77        impl #imp_generics #trait_path for #enum_name #type_generics #where_clause {
78            type Target = #target;
79
80            #[inline]
81            fn deref(&self) -> &Self::Target {
82                match self {
83                    #(#match_arms),*
84                }
85            }
86        }
87    })
88}
89
90fn expand_struct(input: &DeriveInput, trait_name: &'static str) -> Result<TokenStream> {
91    let state = State::with_field_ignore_and_forward(
92        input,
93        trait_name,
94        trait_name.to_lowercase(),
95    )?;
96    let SingleFieldData {
97        input_type,
98        field_type,
99        trait_path,
100        casted_trait,
101        ty_generics,
102        member,
103        info,
104        ..
105    } = state.assert_single_enabled_field();
106
107    let (target, body, generics) = if info.forward {
108        (
109            quote! { #casted_trait::Target },
110            quote! { #casted_trait::deref(&#member) },
111            add_extra_where_clauses(
112                &input.generics,
113                quote! {
114                    where #field_type: #trait_path
115                },
116            ),
117        )
118    } else {
119        (
120            quote! { #field_type },
121            quote! { &#member },
122            input.generics.clone(),
123        )
124    };
125    let (impl_generics, _, where_clause) = generics.split_for_impl();
126
127    Ok(quote! {
128        #[allow(deprecated)] // omit warnings on deprecated fields/variants
129        #[allow(unreachable_code)] // omit warnings for `!` and other unreachable types
130        #[automatically_derived]
131        impl #impl_generics #trait_path for #input_type #ty_generics #where_clause {
132            type Target = #target;
133
134            #[inline]
135            fn deref(&self) -> &Self::Target {
136                #body
137            }
138        }
139    })
140}