1use proc_macro2::TokenStream;
2use quote::quote;
3use syn::{spanned::Spanned as _, Error, Result};
4
5use crate::utils::{
6 self, AttrParams, DeriveType, FullMetaInfo, HashSet, MetaInfo, MultiFieldData,
7 State,
8};
9
10pub fn expand(
11 input: &syn::DeriveInput,
12 trait_name: &'static str,
13) -> Result<TokenStream> {
14 let syn::DeriveInput {
15 ident, generics, ..
16 } = input;
17
18 let state = State::with_attr_params(
19 input,
20 trait_name,
21 trait_name.to_lowercase(),
22 allowed_attr_params(),
23 )?;
24
25 let type_params: HashSet<_> = generics
26 .params
27 .iter()
28 .filter_map(|generic| match generic {
29 syn::GenericParam::Type(ty) => Some(ty.ident.clone()),
30 _ => None,
31 })
32 .collect();
33
34 let (bounds, source, provide) = match state.derive_type {
35 DeriveType::Named | DeriveType::Unnamed => render_struct(&type_params, &state)?,
36 DeriveType::Enum => render_enum(&type_params, &state)?,
37 };
38
39 let source = source.map(|source| {
40 quote! {
43 fn source(&self) -> Option<&(dyn derive_more::core::error::Error + 'static)> {
44 use derive_more::__private::AsDynError as _;
45 #source
46 }
47 }
48 });
49
50 let provide = provide.map(|provide| {
51 quote! {
54 fn provide<'_request>(
55 &'_request self,
56 request: &mut derive_more::core::error::Request<'_request>,
57 ) {
58 #provide
59 }
60 }
61 });
62
63 let mut generics = generics.clone();
64
65 if !type_params.is_empty() {
66 let (_, ty_generics, _) = generics.split_for_impl();
67 generics = utils::add_extra_where_clauses(
68 &generics,
69 quote! {
70 where
71 #ident #ty_generics: derive_more::core::fmt::Debug
72 + derive_more::core::fmt::Display
73 },
74 );
75 }
76
77 if !bounds.is_empty() {
78 let bounds = bounds.iter();
79 generics = utils::add_extra_where_clauses(
80 &generics,
81 quote! {
82 where #(
83 #bounds: derive_more::core::fmt::Debug
84 + derive_more::core::fmt::Display
85 + derive_more::core::error::Error
86 + 'static
87 ),*
88 },
89 );
90 }
91
92 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
93
94 let render = quote! {
95 #[automatically_derived]
96 impl #impl_generics derive_more::core::error::Error for #ident #ty_generics #where_clause {
97 #source
98 #provide
99 }
100 };
101
102 Ok(render)
103}
104
105fn render_struct(
106 type_params: &HashSet<syn::Ident>,
107 state: &State,
108) -> Result<(HashSet<syn::Type>, Option<TokenStream>, Option<TokenStream>)> {
109 let parsed_fields = parse_fields(type_params, state)?;
110
111 let source = parsed_fields.render_source_as_struct();
112 let provide = cfg!(error_generic_member_access)
113 .then(|| parsed_fields.render_provide_as_struct())
114 .flatten();
115
116 Ok((parsed_fields.bounds, source, provide))
117}
118
119fn render_enum(
120 type_params: &HashSet<syn::Ident>,
121 state: &State,
122) -> Result<(HashSet<syn::Type>, Option<TokenStream>, Option<TokenStream>)> {
123 let mut bounds = HashSet::default();
124 let mut source_match_arms = Vec::new();
125 let mut provide_match_arms = Vec::new();
126
127 for variant in state.enabled_variant_data().variants {
128 let default_info = FullMetaInfo {
129 enabled: true,
130 ..FullMetaInfo::default()
131 };
132
133 let state = State::from_variant(
134 state.input,
135 state.trait_name,
136 state.trait_attr.clone(),
137 allowed_attr_params(),
138 variant,
139 default_info,
140 )?;
141
142 let parsed_fields = parse_fields(type_params, &state)?;
143
144 if let Some(expr) = parsed_fields.render_source_as_enum_variant_match_arm() {
145 source_match_arms.push(expr);
146 }
147
148 if let Some(expr) = parsed_fields.render_provide_as_enum_variant_match_arm() {
149 provide_match_arms.push(expr);
150 }
151
152 bounds.extend(parsed_fields.bounds.into_iter());
153 }
154
155 let render = |match_arms: &mut Vec<TokenStream>, unmatched| {
156 if !match_arms.is_empty() && match_arms.len() < state.variants.len() {
157 match_arms.push(quote! { _ => #unmatched });
158 }
159
160 (!match_arms.is_empty()).then(|| {
161 quote! {
162 match self {
163 #(#match_arms),*
164 }
165 }
166 })
167 };
168
169 let source = render(&mut source_match_arms, quote! { None });
170 let provide = render(&mut provide_match_arms, quote! { () });
171
172 Ok((bounds, source, provide))
173}
174
175fn allowed_attr_params() -> AttrParams {
176 AttrParams {
177 enum_: vec!["ignore"],
178 struct_: vec!["ignore"],
179 variant: vec!["ignore"],
180 field: vec!["ignore", "source", "optional", "backtrace"],
181 }
182}
183
184struct ParsedFields<'input, 'state> {
185 data: MultiFieldData<'input, 'state>,
186 source: Option<usize>,
187 backtrace: Option<usize>,
188 bounds: HashSet<syn::Type>,
189}
190
191impl<'input, 'state> ParsedFields<'input, 'state> {
192 fn new(data: MultiFieldData<'input, 'state>) -> Self {
193 Self {
194 data,
195 source: None,
196 backtrace: None,
197 bounds: HashSet::default(),
198 }
199 }
200}
201
202impl ParsedFields<'_, '_> {
203 fn render_source_as_struct(&self) -> Option<TokenStream> {
204 let source = self.source?;
205 let ident = &self.data.members[source];
206 let is_optional = self.data.infos[source].info.source_optional == Some(true)
207 || self.data.field_types[source].is_option();
208
209 Some(render_some(quote! { (&#ident) }, is_optional))
210 }
211
212 fn render_source_as_enum_variant_match_arm(&self) -> Option<TokenStream> {
213 let source = self.source?;
214 let pattern = self.data.matcher(&[source], &[quote! { source }]);
215 let is_optional = self.data.infos[source].info.source_optional == Some(true)
216 || self.data.field_types[source].is_option();
217
218 let expr = render_some(quote! { source }, is_optional);
219 Some(quote! { #pattern => #expr })
220 }
221
222 fn render_provide_as_struct(&self) -> Option<TokenStream> {
223 let backtrace = self.backtrace?;
224
225 let source_provider = self.source.map(|source| {
226 let source_expr = &self.data.members[source];
227 quote! {
228 derive_more::core::error::Error::provide(&#source_expr, request);
229 }
230 });
231 let backtrace_provider = self
232 .source
233 .filter(|source| *source == backtrace)
234 .is_none()
235 .then(|| {
236 let backtrace_expr = &self.data.members[backtrace];
237 quote! {
238 request.provide_ref::<::std::backtrace::Backtrace>(&#backtrace_expr);
239 }
240 });
241
242 (source_provider.is_some() || backtrace_provider.is_some()).then(|| {
243 quote! {
244 #backtrace_provider
245 #source_provider
246 }
247 })
248 }
249
250 fn render_provide_as_enum_variant_match_arm(&self) -> Option<TokenStream> {
251 let backtrace = self.backtrace?;
252
253 match self.source {
254 Some(source) if source == backtrace => {
255 let pattern = self.data.matcher(&[source], &[quote! { source }]);
256 Some(quote! {
257 #pattern => {
258 derive_more::core::error::Error::provide(source, request);
259 }
260 })
261 }
262 Some(source) => {
263 let pattern = self.data.matcher(
264 &[source, backtrace],
265 &[quote! { source }, quote! { backtrace }],
266 );
267 Some(quote! {
268 #pattern => {
269 request.provide_ref::<::std::backtrace::Backtrace>(backtrace);
270 derive_more::core::error::Error::provide(source, request);
271 }
272 })
273 }
274 None => {
275 let pattern = self.data.matcher(&[backtrace], &[quote! { backtrace }]);
276 Some(quote! {
277 #pattern => {
278 request.provide_ref::<::std::backtrace::Backtrace>(backtrace);
279 }
280 })
281 }
282 }
283 }
284}
285
286fn render_some(mut expr: TokenStream, unpack: bool) -> TokenStream {
287 if unpack {
288 expr = quote! { derive_more::core::option::Option::as_ref(#expr)? }
289 }
290 quote! { Some(#expr.as_dyn_error()) }
291}
292
293fn parse_fields<'input, 'state>(
294 type_params: &HashSet<syn::Ident>,
295 state: &'state State<'input>,
296) -> Result<ParsedFields<'input, 'state>> {
297 let mut parsed_fields = match state.derive_type {
298 DeriveType::Named => {
299 parse_fields_impl(state, |attr, field, _| {
300 let ident = field.ident.as_ref().unwrap();
303
304 match attr {
305 "source" => ident == "source",
306 "backtrace" => {
307 ident == "backtrace"
308 || is_type_path_ends_with_segment(&field.ty, "Backtrace")
309 }
310 _ => unreachable!(),
311 }
312 })
313 }
314
315 DeriveType::Unnamed => {
316 let mut parsed_fields =
317 parse_fields_impl(state, |attr, field, len| match attr {
318 "source" => {
319 len == 1
320 && !is_type_path_ends_with_segment(&field.ty, "Backtrace")
321 }
322 "backtrace" => {
323 is_type_path_ends_with_segment(&field.ty, "Backtrace")
324 }
325 _ => unreachable!(),
326 })?;
327
328 parsed_fields.source = parsed_fields
329 .source
330 .or_else(|| infer_source_field(&state.fields, &parsed_fields));
331
332 Ok(parsed_fields)
333 }
334
335 _ => unreachable!(),
336 }?;
337
338 if let Some(source) = parsed_fields.source {
339 let is_optional = parsed_fields.data.infos[source].info.source_optional
340 == Some(true)
341 || state.fields[source].ty.is_option();
342
343 add_bound_if_type_parameter_used_in_type(
344 &mut parsed_fields.bounds,
345 type_params,
346 &state.fields[source].ty,
347 is_optional,
348 );
349 }
350
351 Ok(parsed_fields)
352}
353
354fn is_type_path_ends_with_segment(ty: &syn::Type, tail: &str) -> bool {
357 let syn::Type::Path(ty) = ty else {
358 return false;
359 };
360
361 let segment = ty.path.segments.last().unwrap();
364
365 if !matches!(segment.arguments, syn::PathArguments::None) {
366 return false;
367 }
368
369 segment.ident == tail
370}
371
372fn infer_source_field(
373 fields: &[&syn::Field],
374 parsed_fields: &ParsedFields,
375) -> Option<usize> {
376 if fields.len() != 2 {
378 return None;
379 }
380
381 if parsed_fields.source.is_some() {
383 return None;
384 }
385
386 if let Some(backtrace) = parsed_fields.backtrace {
388 let source = (backtrace + 1) % 2;
390 if parsed_fields.data.infos[source].info.source != Some(false) {
392 return Some(source);
393 }
394 }
395
396 None
397}
398
399fn parse_fields_impl<'input, 'state, P>(
400 state: &'state State<'input>,
401 is_valid_default_field_for_attr: P,
402) -> Result<ParsedFields<'input, 'state>>
403where
404 P: Fn(&str, &syn::Field, usize) -> bool,
405{
406 let MultiFieldData { fields, infos, .. } = state.enabled_fields_data();
407
408 let iter = fields
409 .iter()
410 .zip(infos.iter().map(|info| &info.info))
411 .enumerate()
412 .map(|(index, (field, info))| (index, *field, info));
413
414 let source = parse_field_impl(
415 &is_valid_default_field_for_attr,
416 state.fields.len(),
417 iter.clone(),
418 "source",
419 |info| info.source,
420 )?;
421
422 let backtrace = parse_field_impl(
423 &is_valid_default_field_for_attr,
424 state.fields.len(),
425 iter.clone(),
426 "backtrace",
427 |info| info.backtrace,
428 )?;
429
430 let mut parsed_fields = ParsedFields::new(state.enabled_fields_data());
431
432 if let Some((index, _, _)) = source {
433 parsed_fields.source = Some(index);
434 }
435
436 if let Some((index, _, _)) = backtrace {
437 parsed_fields.backtrace = Some(index);
438 }
439
440 Ok(parsed_fields)
441}
442
443fn parse_field_impl<'a, P, V>(
444 is_valid_default_field_for_attr: &P,
445 len: usize,
446 iter: impl Iterator<Item = (usize, &'a syn::Field, &'a MetaInfo)> + Clone,
447 attr: &str,
448 value: V,
449) -> Result<Option<(usize, &'a syn::Field, &'a MetaInfo)>>
450where
451 P: Fn(&str, &syn::Field, usize) -> bool,
452 V: Fn(&MetaInfo) -> Option<bool>,
453{
454 let explicit_fields = iter
455 .clone()
456 .filter(|(_, _, info)| matches!(value(info), Some(true)));
457
458 let inferred_fields = iter.filter(|(_, field, info)| match value(info) {
459 None => is_valid_default_field_for_attr(attr, field, len),
460 _ => false,
461 });
462
463 let field = assert_iter_contains_zero_or_one_item(
464 explicit_fields,
465 &format!(
466 "Multiple `{attr}` attributes specified. \
467 Single attribute per struct/enum variant allowed.",
468 ),
469 )?;
470
471 let field = match field {
472 field @ Some(_) => field,
473 None => assert_iter_contains_zero_or_one_item(
474 inferred_fields,
475 "Conflicting fields found. Consider specifying some \
476 `#[error(...)]` attributes to resolve conflict.",
477 )?,
478 };
479
480 Ok(field)
481}
482
483fn assert_iter_contains_zero_or_one_item<'a>(
484 mut iter: impl Iterator<Item = (usize, &'a syn::Field, &'a MetaInfo)>,
485 error_msg: &str,
486) -> Result<Option<(usize, &'a syn::Field, &'a MetaInfo)>> {
487 let Some(item) = iter.next() else {
488 return Ok(None);
489 };
490
491 if let Some((_, field, _)) = iter.next() {
492 return Err(Error::new(field.span(), error_msg));
493 }
494
495 Ok(Some(item))
496}
497
498fn add_bound_if_type_parameter_used_in_type(
499 bounds: &mut HashSet<syn::Type>,
500 type_params: &HashSet<syn::Ident>,
501 ty: &syn::Type,
502 unpack: bool,
503) {
504 if let Some(ty) = utils::get_if_type_parameter_used_in_type(type_params, ty) {
505 bounds.insert(
506 unpack
507 .then(|| ty.get_inner())
508 .flatten()
509 .cloned()
510 .unwrap_or(ty),
511 );
512 }
513}
514
515trait TypeExt {
517 fn is_option(&self) -> bool;
519
520 fn get_inner_if(&self, filter: impl Fn(&syn::Ident) -> bool) -> Option<&Self>;
524
525 fn get_inner(&self) -> Option<&Self> {
527 self.get_inner_if(|_| true)
528 }
529
530 fn get_option_inner(&self) -> Option<&Self> {
532 self.get_inner_if(|ident| ident == "Option")
533 }
534}
535
536impl TypeExt for syn::Type {
537 fn is_option(&self) -> bool {
538 self.get_option_inner().is_some()
539 }
540
541 fn get_inner_if(&self, filter: impl Fn(&syn::Ident) -> bool) -> Option<&Self> {
542 match self {
543 Self::Group(g) => g.elem.get_option_inner(),
544 Self::Paren(p) => p.elem.get_option_inner(),
545 Self::Path(p) => p
546 .path
547 .segments
548 .last()
549 .filter(|s| filter(&s.ident))
550 .and_then(|s| {
551 if let syn::PathArguments::AngleBracketed(a) = &s.arguments {
552 if let Some(syn::GenericArgument::Type(ty)) = a.args.first() {
553 return Some(ty);
554 }
555 }
556 None
557 }),
558 _ => None,
559 }
560 }
561}