derive_more_impl/
from.rs

1//! Implementation of a [`From`] derive macro.
2
3use std::{
4    any::{Any, TypeId},
5    iter,
6};
7
8use proc_macro2::{Span, TokenStream};
9use quote::{format_ident, quote, ToTokens as _, TokenStreamExt as _};
10use syn::{
11    parse::{Parse, ParseStream},
12    parse_quote,
13    spanned::Spanned as _,
14    token,
15};
16
17use crate::utils::{
18    attr::{self, ParseMultiple as _},
19    polyfill, Either, Spanning,
20};
21
22/// Expands a [`From`] derive macro.
23pub fn expand(input: &syn::DeriveInput, _: &'static str) -> syn::Result<TokenStream> {
24    let attr_name = format_ident!("from");
25
26    match &input.data {
27        syn::Data::Struct(data) => Expansion {
28            attrs: StructAttribute::parse_attrs_with(
29                &input.attrs,
30                &attr_name,
31                &ConsiderLegacySyntax {
32                    fields: &data.fields,
33                },
34            )?
35            .map(|attr| attr.into_inner().into())
36            .as_ref(),
37            ident: &input.ident,
38            variant: None,
39            fields: &data.fields,
40            generics: &input.generics,
41            has_explicit_from: false,
42        }
43        .expand(),
44        syn::Data::Enum(data) => {
45            let mut has_explicit_from = false;
46            let attrs = data
47                .variants
48                .iter()
49                .map(|variant| {
50                    let attr = VariantAttribute::parse_attrs_with(
51                        &variant.attrs,
52                        &attr_name,
53                        &ConsiderLegacySyntax {
54                            fields: &variant.fields,
55                        },
56                    )?
57                    .map(Spanning::into_inner);
58                    if matches!(
59                        attr,
60                        Some(
61                            VariantAttribute::Empty(_)
62                                | VariantAttribute::Types(_)
63                                | VariantAttribute::Forward(_)
64                        ),
65                    ) {
66                        has_explicit_from = true;
67                    }
68                    Ok(attr)
69                })
70                .collect::<syn::Result<Vec<_>>>()?;
71
72            data.variants
73                .iter()
74                .zip(&attrs)
75                .map(|(variant, attrs)| {
76                    Expansion {
77                        attrs: attrs.as_ref(),
78                        ident: &input.ident,
79                        variant: Some(&variant.ident),
80                        fields: &variant.fields,
81                        generics: &input.generics,
82                        has_explicit_from,
83                    }
84                    .expand()
85                })
86                .collect()
87        }
88        syn::Data::Union(data) => Err(syn::Error::new(
89            data.union_token.span(),
90            "`From` cannot be derived for unions",
91        )),
92    }
93}
94
95/// Representation of a [`From`] derive macro struct container attribute.
96///
97/// ```rust,ignore
98/// #[from(forward)]
99/// #[from(<types>)]
100/// ```
101type StructAttribute = attr::Conversion;
102
103/// Representation of a [`From`] derive macro enum variant attribute.
104///
105/// ```rust,ignore
106/// #[from]
107/// #[from(skip)] #[from(ignore)]
108/// #[from(forward)]
109/// #[from(<types>)]
110/// ```
111type VariantAttribute = attr::FieldConversion;
112
113/// Expansion of a macro for generating [`From`] implementation of a struct or
114/// enum.
115struct Expansion<'a> {
116    /// [`From`] attributes.
117    ///
118    /// As a [`VariantAttribute`] is superset of a [`StructAttribute`], we use
119    /// it for both derives.
120    attrs: Option<&'a VariantAttribute>,
121
122    /// Struct or enum [`syn::Ident`].
123    ///
124    /// [`syn::Ident`]: struct@syn::Ident
125    ident: &'a syn::Ident,
126
127    /// Variant [`syn::Ident`] in case of enum expansion.
128    ///
129    /// [`syn::Ident`]: struct@syn::Ident
130    variant: Option<&'a syn::Ident>,
131
132    /// Struct or variant [`syn::Fields`].
133    fields: &'a syn::Fields,
134
135    /// Struct or enum [`syn::Generics`].
136    generics: &'a syn::Generics,
137
138    /// Indicator whether one of the enum variants has
139    /// [`VariantAttribute::Empty`], [`VariantAttribute::Types`] or
140    /// [`VariantAttribute::Forward`].
141    ///
142    /// Always [`false`] for structs.
143    has_explicit_from: bool,
144}
145
146impl Expansion<'_> {
147    /// Expands [`From`] implementations for a struct or an enum variant.
148    fn expand(&self) -> syn::Result<TokenStream> {
149        use crate::utils::FieldsExt as _;
150
151        let ident = self.ident;
152        let field_tys = self.fields.iter().map(|f| &f.ty).collect::<Vec<_>>();
153        let (impl_gens, ty_gens, where_clause) = self.generics.split_for_impl();
154
155        let skip_variant = self.has_explicit_from
156            || (self.variant.is_some() && self.fields.is_empty());
157        match (self.attrs, skip_variant) {
158            (Some(VariantAttribute::Types(tys)), _) => {
159                tys.0.iter().map(|ty| {
160                    let variant = self.variant.iter();
161
162                    let mut from_tys = self.fields.validate_type(ty)?;
163                    let init = self.expand_fields(|ident, ty, index| {
164                        let ident = ident.into_iter();
165                        let index = index.into_iter();
166                        let from_ty = from_tys.next().unwrap_or_else(|| unreachable!());
167                        quote! {
168                            #( #ident: )* <#ty as derive_more::core::convert::From<#from_ty>>::from(
169                                value #( .#index )*
170                            ),
171                        }
172                    });
173
174                    Ok(quote! {
175                        #[allow(deprecated)] // omit warnings on deprecated fields/variants
176                        #[allow(unreachable_code)] // omit warnings for `!` and unreachable types
177                        #[automatically_derived]
178                        impl #impl_gens derive_more::core::convert::From<#ty>
179                         for #ident #ty_gens #where_clause {
180                            #[inline]
181                            fn from(value: #ty) -> Self {
182                                #ident #( :: #variant )* #init
183                            }
184                        }
185                    })
186                })
187                .collect()
188            }
189            (Some(VariantAttribute::Empty(_)), _) | (None, false) => {
190                let variant = self.variant.iter();
191                let init = self.expand_fields(|ident, _, index| {
192                    let ident = ident.into_iter();
193                    let index = index.into_iter();
194                    quote! { #( #ident: )* value #( . #index )*, }
195                });
196
197                Ok(quote! {
198                    #[allow(deprecated)] // omit warnings on deprecated fields/variants
199                    #[allow(unreachable_code)] // omit warnings for `!` and other unreachable types
200                    #[automatically_derived]
201                    impl #impl_gens derive_more::core::convert::From<(#( #field_tys ),*)>
202                     for #ident #ty_gens #where_clause {
203                        #[inline]
204                        fn from(value: (#( #field_tys ),*)) -> Self {
205                            #ident #( :: #variant )* #init
206                        }
207                    }
208                })
209            }
210            (Some(VariantAttribute::Forward(_)), _) => {
211                let mut i = 0;
212                let mut gen_idents = Vec::with_capacity(self.fields.len());
213                let init = self.expand_fields(|ident, ty, index| {
214                    let ident = ident.into_iter();
215                    let index = index.into_iter();
216                    let gen_ident = format_ident!("__FromT{i}");
217                    let out = quote! {
218                        #( #ident: )* <#ty as derive_more::core::convert::From<#gen_ident>>::from(
219                            value #( .#index )*
220                        ),
221                    };
222                    gen_idents.push(gen_ident);
223                    i += 1;
224                    out
225                });
226
227                let variant = self.variant.iter();
228                let generics = {
229                    let mut generics = self.generics.clone();
230                    for (ty, ident) in field_tys.iter().zip(&gen_idents) {
231                        generics
232                            .make_where_clause()
233                            .predicates
234                            .push(parse_quote! { #ty: derive_more::core::convert::From<#ident> });
235                        generics
236                            .params
237                            .push(syn::TypeParam::from(ident.clone()).into());
238                    }
239                    generics
240                };
241                let (impl_gens, _, where_clause) = generics.split_for_impl();
242
243                Ok(quote! {
244                    #[allow(deprecated)] // omit warnings on deprecated fields/variants
245                    #[allow(unreachable_code)] // omit warnings for `!` and other unreachable types
246                    #[automatically_derived]
247                    impl #impl_gens derive_more::core::convert::From<(#( #gen_idents ),*)>
248                     for #ident #ty_gens #where_clause {
249                        #[inline]
250                        fn from(value: (#( #gen_idents ),*)) -> Self {
251                            #ident #(:: #variant)* #init
252                        }
253                    }
254                })
255            }
256            (Some(VariantAttribute::Skip(_)), _) | (None, true) => {
257                Ok(TokenStream::new())
258            }
259        }
260    }
261
262    /// Expands fields initialization wrapped into [`token::Brace`]s in case of
263    /// [`syn::FieldsNamed`], or [`token::Paren`] in case of
264    /// [`syn::FieldsUnnamed`].
265    ///
266    /// [`token::Brace`]: struct@token::Brace
267    /// [`token::Paren`]: struct@token::Paren
268    fn expand_fields(
269        &self,
270        mut wrap: impl FnMut(
271            Option<&syn::Ident>,
272            &syn::Type,
273            Option<syn::Index>,
274        ) -> TokenStream,
275    ) -> TokenStream {
276        let surround = match self.fields {
277            syn::Fields::Named(_) | syn::Fields::Unnamed(_) => {
278                Some(|tokens| match self.fields {
279                    syn::Fields::Named(named) => {
280                        let mut out = TokenStream::new();
281                        named
282                            .brace_token
283                            .surround(&mut out, |out| out.append_all(tokens));
284                        out
285                    }
286                    syn::Fields::Unnamed(unnamed) => {
287                        let mut out = TokenStream::new();
288                        unnamed
289                            .paren_token
290                            .surround(&mut out, |out| out.append_all(tokens));
291                        out
292                    }
293                    syn::Fields::Unit => unreachable!(),
294                })
295            }
296            syn::Fields::Unit => None,
297        };
298
299        surround
300            .map(|surround| {
301                surround(if self.fields.len() == 1 {
302                    let field = self
303                        .fields
304                        .iter()
305                        .next()
306                        .unwrap_or_else(|| unreachable!("self.fields.len() == 1"));
307                    wrap(field.ident.as_ref(), &field.ty, None)
308                } else {
309                    self.fields
310                        .iter()
311                        .enumerate()
312                        .map(|(i, field)| {
313                            wrap(field.ident.as_ref(), &field.ty, Some(i.into()))
314                        })
315                        .collect()
316                })
317            })
318            .unwrap_or_default()
319    }
320}
321
322/// [`attr::Parser`] considering legacy syntax for [`attr::Types`] and emitting [`legacy_error`], if
323/// any occurs.
324struct ConsiderLegacySyntax<'a> {
325    /// [`syn::Fields`] of a struct or enum variant, the attribute is parsed for.
326    fields: &'a syn::Fields,
327}
328
329impl attr::Parser for ConsiderLegacySyntax<'_> {
330    fn parse<T: Parse + Any>(&self, input: ParseStream<'_>) -> syn::Result<T> {
331        if TypeId::of::<T>() == TypeId::of::<attr::Types>() {
332            let ahead = input.fork();
333            if let Ok(p) = ahead.parse::<syn::Path>() {
334                if p.is_ident("types") {
335                    return legacy_error(&ahead, input.span(), self.fields);
336                }
337            }
338        }
339        T::parse(input)
340    }
341}
342
343/// Constructs a [`syn::Error`] for legacy syntax: `#[from(types(i32, "&str"))]`.
344fn legacy_error<T>(
345    tokens: ParseStream<'_>,
346    span: Span,
347    fields: &syn::Fields,
348) -> syn::Result<T> {
349    let content;
350    syn::parenthesized!(content in tokens);
351
352    let types = content
353        .parse_terminated(polyfill::NestedMeta::parse, token::Comma)?
354        .into_iter()
355        .map(|meta| {
356            let value = match meta {
357                polyfill::NestedMeta::Meta(meta) => {
358                    meta.into_token_stream().to_string()
359                }
360                polyfill::NestedMeta::Lit(syn::Lit::Str(str)) => str.value(),
361                polyfill::NestedMeta::Lit(_) => unreachable!(),
362            };
363            if fields.len() > 1 {
364                format!(
365                    "({})",
366                    fields
367                        .iter()
368                        .map(|_| value.clone())
369                        .collect::<Vec<_>>()
370                        .join(", "),
371                )
372            } else {
373                value
374            }
375        })
376        .chain(match fields.len() {
377            0 => Either::Left(iter::empty()),
378            1 => Either::Right(iter::once(
379                fields
380                    .iter()
381                    .next()
382                    .unwrap_or_else(|| unreachable!("fields.len() == 1"))
383                    .ty
384                    .to_token_stream()
385                    .to_string(),
386            )),
387            _ => Either::Right(iter::once(format!(
388                "({})",
389                fields
390                    .iter()
391                    .map(|f| f.ty.to_token_stream().to_string())
392                    .collect::<Vec<_>>()
393                    .join(", ")
394            ))),
395        })
396        .collect::<Vec<_>>()
397        .join(", ");
398
399    Err(syn::Error::new(
400        span,
401        format!("legacy syntax, remove `types` and use `{types}` instead"),
402    ))
403}