derive_more_impl/
error.rs

1use proc_macro2::TokenStream;
2use quote::quote;
3use syn::{spanned::Spanned as _, Error, Result};
4
5use crate::utils::{
6    self, AttrParams, DeriveType, FullMetaInfo, HashSet, MetaInfo, MultiFieldData,
7    State,
8};
9
10pub fn expand(
11    input: &syn::DeriveInput,
12    trait_name: &'static str,
13) -> Result<TokenStream> {
14    let syn::DeriveInput {
15        ident, generics, ..
16    } = input;
17
18    let state = State::with_attr_params(
19        input,
20        trait_name,
21        trait_name.to_lowercase(),
22        allowed_attr_params(),
23    )?;
24
25    let type_params: HashSet<_> = generics
26        .params
27        .iter()
28        .filter_map(|generic| match generic {
29            syn::GenericParam::Type(ty) => Some(ty.ident.clone()),
30            _ => None,
31        })
32        .collect();
33
34    let (bounds, source, provide) = match state.derive_type {
35        DeriveType::Named | DeriveType::Unnamed => render_struct(&type_params, &state)?,
36        DeriveType::Enum => render_enum(&type_params, &state)?,
37    };
38
39    let source = source.map(|source| {
40        // Not using `#[inline]` here on purpose, since this is almost never part
41        // of a hot codepath.
42        quote! {
43            fn source(&self) -> Option<&(dyn derive_more::core::error::Error + 'static)> {
44                use derive_more::__private::AsDynError as _;
45                #source
46            }
47        }
48    });
49
50    let provide = provide.map(|provide| {
51        // Not using `#[inline]` here on purpose, since this is almost never part
52        // of a hot codepath.
53        quote! {
54            fn provide<'_request>(
55                &'_request self,
56                request: &mut derive_more::core::error::Request<'_request>,
57            ) {
58                #provide
59            }
60        }
61    });
62
63    let mut generics = generics.clone();
64
65    if !type_params.is_empty() {
66        let (_, ty_generics, _) = generics.split_for_impl();
67        generics = utils::add_extra_where_clauses(
68            &generics,
69            quote! {
70                where
71                    #ident #ty_generics: derive_more::core::fmt::Debug
72                                         + derive_more::core::fmt::Display
73            },
74        );
75    }
76
77    if !bounds.is_empty() {
78        let bounds = bounds.iter();
79        generics = utils::add_extra_where_clauses(
80            &generics,
81            quote! {
82                where #(
83                    #bounds: derive_more::core::fmt::Debug
84                             + derive_more::core::fmt::Display
85                             + derive_more::core::error::Error
86                             + 'static
87                ),*
88            },
89        );
90    }
91
92    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
93
94    let render = quote! {
95        #[automatically_derived]
96        impl #impl_generics derive_more::core::error::Error for #ident #ty_generics #where_clause {
97            #source
98            #provide
99        }
100    };
101
102    Ok(render)
103}
104
105fn render_struct(
106    type_params: &HashSet<syn::Ident>,
107    state: &State,
108) -> Result<(HashSet<syn::Type>, Option<TokenStream>, Option<TokenStream>)> {
109    let parsed_fields = parse_fields(type_params, state)?;
110
111    let source = parsed_fields.render_source_as_struct();
112    let provide = cfg!(error_generic_member_access)
113        .then(|| parsed_fields.render_provide_as_struct())
114        .flatten();
115
116    Ok((parsed_fields.bounds, source, provide))
117}
118
119fn render_enum(
120    type_params: &HashSet<syn::Ident>,
121    state: &State,
122) -> Result<(HashSet<syn::Type>, Option<TokenStream>, Option<TokenStream>)> {
123    let mut bounds = HashSet::default();
124    let mut source_match_arms = Vec::new();
125    let mut provide_match_arms = Vec::new();
126
127    for variant in state.enabled_variant_data().variants {
128        let default_info = FullMetaInfo {
129            enabled: true,
130            ..FullMetaInfo::default()
131        };
132
133        let state = State::from_variant(
134            state.input,
135            state.trait_name,
136            state.trait_attr.clone(),
137            allowed_attr_params(),
138            variant,
139            default_info,
140        )?;
141
142        let parsed_fields = parse_fields(type_params, &state)?;
143
144        if let Some(expr) = parsed_fields.render_source_as_enum_variant_match_arm() {
145            source_match_arms.push(expr);
146        }
147
148        if let Some(expr) = parsed_fields.render_provide_as_enum_variant_match_arm() {
149            provide_match_arms.push(expr);
150        }
151
152        bounds.extend(parsed_fields.bounds.into_iter());
153    }
154
155    let render = |match_arms: &mut Vec<TokenStream>, unmatched| {
156        if !match_arms.is_empty() && match_arms.len() < state.variants.len() {
157            match_arms.push(quote! { _ => #unmatched });
158        }
159
160        (!match_arms.is_empty()).then(|| {
161            quote! {
162                match self {
163                    #(#match_arms),*
164                }
165            }
166        })
167    };
168
169    let source = render(&mut source_match_arms, quote! { None });
170    let provide = render(&mut provide_match_arms, quote! { () });
171
172    Ok((bounds, source, provide))
173}
174
175fn allowed_attr_params() -> AttrParams {
176    AttrParams {
177        enum_: vec!["ignore"],
178        struct_: vec!["ignore"],
179        variant: vec!["ignore"],
180        field: vec!["ignore", "source", "optional", "backtrace"],
181    }
182}
183
184struct ParsedFields<'input, 'state> {
185    data: MultiFieldData<'input, 'state>,
186    source: Option<usize>,
187    backtrace: Option<usize>,
188    bounds: HashSet<syn::Type>,
189}
190
191impl<'input, 'state> ParsedFields<'input, 'state> {
192    fn new(data: MultiFieldData<'input, 'state>) -> Self {
193        Self {
194            data,
195            source: None,
196            backtrace: None,
197            bounds: HashSet::default(),
198        }
199    }
200}
201
202impl ParsedFields<'_, '_> {
203    fn render_source_as_struct(&self) -> Option<TokenStream> {
204        let source = self.source?;
205        let ident = &self.data.members[source];
206        let is_optional = self.data.infos[source].info.source_optional == Some(true)
207            || self.data.field_types[source].is_option();
208
209        Some(render_some(quote! { (&#ident) }, is_optional))
210    }
211
212    fn render_source_as_enum_variant_match_arm(&self) -> Option<TokenStream> {
213        let source = self.source?;
214        let pattern = self.data.matcher(&[source], &[quote! { source }]);
215        let is_optional = self.data.infos[source].info.source_optional == Some(true)
216            || self.data.field_types[source].is_option();
217
218        let expr = render_some(quote! { source }, is_optional);
219        Some(quote! { #pattern => #expr })
220    }
221
222    fn render_provide_as_struct(&self) -> Option<TokenStream> {
223        let backtrace = self.backtrace?;
224
225        let source_provider = self.source.map(|source| {
226            let source_expr = &self.data.members[source];
227            quote! {
228                derive_more::core::error::Error::provide(&#source_expr, request);
229            }
230        });
231        let backtrace_provider = self
232            .source
233            .filter(|source| *source == backtrace)
234            .is_none()
235            .then(|| {
236                let backtrace_expr = &self.data.members[backtrace];
237                quote! {
238                    request.provide_ref::<::std::backtrace::Backtrace>(&#backtrace_expr);
239                }
240            });
241
242        (source_provider.is_some() || backtrace_provider.is_some()).then(|| {
243            quote! {
244                #backtrace_provider
245                #source_provider
246            }
247        })
248    }
249
250    fn render_provide_as_enum_variant_match_arm(&self) -> Option<TokenStream> {
251        let backtrace = self.backtrace?;
252
253        match self.source {
254            Some(source) if source == backtrace => {
255                let pattern = self.data.matcher(&[source], &[quote! { source }]);
256                Some(quote! {
257                    #pattern => {
258                        derive_more::core::error::Error::provide(source, request);
259                    }
260                })
261            }
262            Some(source) => {
263                let pattern = self.data.matcher(
264                    &[source, backtrace],
265                    &[quote! { source }, quote! { backtrace }],
266                );
267                Some(quote! {
268                    #pattern => {
269                        request.provide_ref::<::std::backtrace::Backtrace>(backtrace);
270                        derive_more::core::error::Error::provide(source, request);
271                    }
272                })
273            }
274            None => {
275                let pattern = self.data.matcher(&[backtrace], &[quote! { backtrace }]);
276                Some(quote! {
277                    #pattern => {
278                        request.provide_ref::<::std::backtrace::Backtrace>(backtrace);
279                    }
280                })
281            }
282        }
283    }
284}
285
286fn render_some(mut expr: TokenStream, unpack: bool) -> TokenStream {
287    if unpack {
288        expr = quote! { derive_more::core::option::Option::as_ref(#expr)? }
289    }
290    quote! { Some(#expr.as_dyn_error()) }
291}
292
293fn parse_fields<'input, 'state>(
294    type_params: &HashSet<syn::Ident>,
295    state: &'state State<'input>,
296) -> Result<ParsedFields<'input, 'state>> {
297    let mut parsed_fields = match state.derive_type {
298        DeriveType::Named => {
299            parse_fields_impl(state, |attr, field, _| {
300                // Unwrapping is safe, cause fields in named struct
301                // always have an ident
302                let ident = field.ident.as_ref().unwrap();
303
304                match attr {
305                    "source" => ident == "source",
306                    "backtrace" => {
307                        ident == "backtrace"
308                            || is_type_path_ends_with_segment(&field.ty, "Backtrace")
309                    }
310                    _ => unreachable!(),
311                }
312            })
313        }
314
315        DeriveType::Unnamed => {
316            let mut parsed_fields =
317                parse_fields_impl(state, |attr, field, len| match attr {
318                    "source" => {
319                        len == 1
320                            && !is_type_path_ends_with_segment(&field.ty, "Backtrace")
321                    }
322                    "backtrace" => {
323                        is_type_path_ends_with_segment(&field.ty, "Backtrace")
324                    }
325                    _ => unreachable!(),
326                })?;
327
328            parsed_fields.source = parsed_fields
329                .source
330                .or_else(|| infer_source_field(&state.fields, &parsed_fields));
331
332            Ok(parsed_fields)
333        }
334
335        _ => unreachable!(),
336    }?;
337
338    if let Some(source) = parsed_fields.source {
339        let is_optional = parsed_fields.data.infos[source].info.source_optional
340            == Some(true)
341            || state.fields[source].ty.is_option();
342
343        add_bound_if_type_parameter_used_in_type(
344            &mut parsed_fields.bounds,
345            type_params,
346            &state.fields[source].ty,
347            is_optional,
348        );
349    }
350
351    Ok(parsed_fields)
352}
353
354/// Checks if `ty` is [`syn::Type::Path`] and ends with segment matching `tail`
355/// and doesn't contain any generic parameters.
356fn is_type_path_ends_with_segment(ty: &syn::Type, tail: &str) -> bool {
357    let syn::Type::Path(ty) = ty else {
358        return false;
359    };
360
361    // Unwrapping is safe, cause 'syn::TypePath.path.segments'
362    // have to have at least one segment
363    let segment = ty.path.segments.last().unwrap();
364
365    if !matches!(segment.arguments, syn::PathArguments::None) {
366        return false;
367    }
368
369    segment.ident == tail
370}
371
372fn infer_source_field(
373    fields: &[&syn::Field],
374    parsed_fields: &ParsedFields,
375) -> Option<usize> {
376    // if we have exactly two fields
377    if fields.len() != 2 {
378        return None;
379    }
380
381    // no source field was specified/inferred
382    if parsed_fields.source.is_some() {
383        return None;
384    }
385
386    // but one of the fields was specified/inferred as backtrace field
387    if let Some(backtrace) = parsed_fields.backtrace {
388        // then infer *other field* as source field
389        let source = (backtrace + 1) % 2;
390        // unless it was explicitly marked as non-source
391        if parsed_fields.data.infos[source].info.source != Some(false) {
392            return Some(source);
393        }
394    }
395
396    None
397}
398
399fn parse_fields_impl<'input, 'state, P>(
400    state: &'state State<'input>,
401    is_valid_default_field_for_attr: P,
402) -> Result<ParsedFields<'input, 'state>>
403where
404    P: Fn(&str, &syn::Field, usize) -> bool,
405{
406    let MultiFieldData { fields, infos, .. } = state.enabled_fields_data();
407
408    let iter = fields
409        .iter()
410        .zip(infos.iter().map(|info| &info.info))
411        .enumerate()
412        .map(|(index, (field, info))| (index, *field, info));
413
414    let source = parse_field_impl(
415        &is_valid_default_field_for_attr,
416        state.fields.len(),
417        iter.clone(),
418        "source",
419        |info| info.source,
420    )?;
421
422    let backtrace = parse_field_impl(
423        &is_valid_default_field_for_attr,
424        state.fields.len(),
425        iter.clone(),
426        "backtrace",
427        |info| info.backtrace,
428    )?;
429
430    let mut parsed_fields = ParsedFields::new(state.enabled_fields_data());
431
432    if let Some((index, _, _)) = source {
433        parsed_fields.source = Some(index);
434    }
435
436    if let Some((index, _, _)) = backtrace {
437        parsed_fields.backtrace = Some(index);
438    }
439
440    Ok(parsed_fields)
441}
442
443fn parse_field_impl<'a, P, V>(
444    is_valid_default_field_for_attr: &P,
445    len: usize,
446    iter: impl Iterator<Item = (usize, &'a syn::Field, &'a MetaInfo)> + Clone,
447    attr: &str,
448    value: V,
449) -> Result<Option<(usize, &'a syn::Field, &'a MetaInfo)>>
450where
451    P: Fn(&str, &syn::Field, usize) -> bool,
452    V: Fn(&MetaInfo) -> Option<bool>,
453{
454    let explicit_fields = iter
455        .clone()
456        .filter(|(_, _, info)| matches!(value(info), Some(true)));
457
458    let inferred_fields = iter.filter(|(_, field, info)| match value(info) {
459        None => is_valid_default_field_for_attr(attr, field, len),
460        _ => false,
461    });
462
463    let field = assert_iter_contains_zero_or_one_item(
464        explicit_fields,
465        &format!(
466            "Multiple `{attr}` attributes specified. \
467             Single attribute per struct/enum variant allowed.",
468        ),
469    )?;
470
471    let field = match field {
472        field @ Some(_) => field,
473        None => assert_iter_contains_zero_or_one_item(
474            inferred_fields,
475            "Conflicting fields found. Consider specifying some \
476             `#[error(...)]` attributes to resolve conflict.",
477        )?,
478    };
479
480    Ok(field)
481}
482
483fn assert_iter_contains_zero_or_one_item<'a>(
484    mut iter: impl Iterator<Item = (usize, &'a syn::Field, &'a MetaInfo)>,
485    error_msg: &str,
486) -> Result<Option<(usize, &'a syn::Field, &'a MetaInfo)>> {
487    let Some(item) = iter.next() else {
488        return Ok(None);
489    };
490
491    if let Some((_, field, _)) = iter.next() {
492        return Err(Error::new(field.span(), error_msg));
493    }
494
495    Ok(Some(item))
496}
497
498fn add_bound_if_type_parameter_used_in_type(
499    bounds: &mut HashSet<syn::Type>,
500    type_params: &HashSet<syn::Ident>,
501    ty: &syn::Type,
502    unpack: bool,
503) {
504    if let Some(ty) = utils::get_if_type_parameter_used_in_type(type_params, ty) {
505        bounds.insert(
506            unpack
507                .then(|| ty.get_inner())
508                .flatten()
509                .cloned()
510                .unwrap_or(ty),
511        );
512    }
513}
514
515/// Extension of a [`syn::Type`] used by this expansion.
516trait TypeExt {
517    /// Checks syntactically whether this [`syn::Type`] represents an [`Option`].
518    fn is_option(&self) -> bool;
519
520    /// Returns the inner [`syn::Type`] if this one represents a wrapper.
521    ///
522    /// `filter` filters out this [`syn::Type`] by its name.
523    fn get_inner_if(&self, filter: impl Fn(&syn::Ident) -> bool) -> Option<&Self>;
524
525    /// Returns the inner [`syn::Type`] if this one represents a wrapper.
526    fn get_inner(&self) -> Option<&Self> {
527        self.get_inner_if(|_| true)
528    }
529
530    /// Returns the inner [`syn::Type`] if this one represents an [`Option`].
531    fn get_option_inner(&self) -> Option<&Self> {
532        self.get_inner_if(|ident| ident == "Option")
533    }
534}
535
536impl TypeExt for syn::Type {
537    fn is_option(&self) -> bool {
538        self.get_option_inner().is_some()
539    }
540
541    fn get_inner_if(&self, filter: impl Fn(&syn::Ident) -> bool) -> Option<&Self> {
542        match self {
543            Self::Group(g) => g.elem.get_option_inner(),
544            Self::Paren(p) => p.elem.get_option_inner(),
545            Self::Path(p) => p
546                .path
547                .segments
548                .last()
549                .filter(|s| filter(&s.ident))
550                .and_then(|s| {
551                    if let syn::PathArguments::AngleBracketed(a) = &s.arguments {
552                        if let Some(syn::GenericArgument::Type(ty)) = a.args.first() {
553                            return Some(ty);
554                        }
555                    }
556                    None
557                }),
558            _ => None,
559        }
560    }
561}