diff --git a/core/codegen/src/derive/uri_display.rs b/core/codegen/src/derive/uri_display.rs index 87e5ca5a..eaa00792 100644 --- a/core/codegen/src/derive/uri_display.rs +++ b/core/codegen/src/derive/uri_display.rs @@ -3,6 +3,7 @@ use devise::{*, ext::SpanDiagnosticExt}; use crate::exports::*; use crate::derive::form_field::{FieldExt, VariantExt}; +use crate::syn_ext::{GenericsExt as _, TypeExt as _}; use crate::http::uri::fmt; const NO_EMPTY_FIELDS: &str = "fieldless structs are not supported"; @@ -11,11 +12,28 @@ const NO_EMPTY_ENUMS: &str = "empty enums are not supported"; const ONLY_ONE_UNNAMED: &str = "tuple structs or variants must have exactly one field"; const EXACTLY_ONE_FIELD: &str = "struct must have exactly one field"; -pub fn derive_uri_display_query(input: proc_macro::TokenStream) -> TokenStream { - const URI_DISPLAY: StaticTokens = quote_static!(#_fmt::UriDisplay<#_fmt::Query>); - const FORMATTER: StaticTokens = quote_static!(#_fmt::Formatter<#_fmt::Query>); +const Q_URI_DISPLAY: StaticTokens = quote_static!(#_fmt::UriDisplay<#_fmt::Query>); +const Q_FORMATTER: StaticTokens = quote_static!(#_fmt::Formatter<#_fmt::Query>); - let uri_display = DeriveGenerator::build_for(input.clone(), quote!(impl #URI_DISPLAY)) +const P_URI_DISPLAY: StaticTokens = quote_static!(#_fmt::UriDisplay<#_fmt::Path>); +const P_FORMATTER: StaticTokens = quote_static!(#_fmt::Formatter<#_fmt::Path>); + +fn generic_bounds_mapper(bound: StaticTokens) -> MapperBuild { + MapperBuild::new() + .try_enum_map(|m, e| mapper::enum_null(m, e)) + .try_fields_map(move |_, fields| { + let generic_idents = fields.parent.input().generics().type_idents(); + + let bounds = fields.iter() + .filter(|f| !f.ty.is_concrete(&generic_idents)) + .map(move |ty| quote_spanned!(ty.span() => #ty: #bound)); + + Ok(quote!(#(#bounds,)*)) + }) +} + +pub fn derive_uri_display_query(input: proc_macro::TokenStream) -> TokenStream { + let uri_display = DeriveGenerator::build_for(input.clone(), quote!(impl #Q_URI_DISPLAY)) .support(Support::Struct | Support::Enum | Support::Type | Support::Lifetime) .validator(ValidatorBuild::new() .enum_validate(|_, data| { @@ -43,10 +61,10 @@ pub fn derive_uri_display_query(input: proc_macro::TokenStream) -> TokenStream { } }) ) - .type_bound(URI_DISPLAY) + .type_bound_mapper(generic_bounds_mapper(Q_URI_DISPLAY)) .inner_mapper(MapperBuild::new() .with_output(|_, output| quote! { - fn fmt(&self, f: &mut #FORMATTER) -> ::std::fmt::Result { + fn fmt(&self, f: &mut #Q_FORMATTER) -> ::std::fmt::Result { #output Ok(()) } @@ -93,12 +111,9 @@ pub fn derive_uri_display_query(input: proc_macro::TokenStream) -> TokenStream { #[allow(non_snake_case)] pub fn derive_uri_display_path(input: proc_macro::TokenStream) -> TokenStream { - const URI_DISPLAY: StaticTokens = quote_static!(#_fmt::UriDisplay<#_fmt::Path>); - const FORMATTER: StaticTokens = quote_static!(#_fmt::Formatter<#_fmt::Path>); - - let uri_display = DeriveGenerator::build_for(input.clone(), quote!(impl #URI_DISPLAY)) + let uri_display = DeriveGenerator::build_for(input.clone(), quote!(impl #P_URI_DISPLAY)) .support(Support::TupleStruct | Support::Type | Support::Lifetime) - .type_bound(URI_DISPLAY) + .type_bound_mapper(generic_bounds_mapper(P_URI_DISPLAY)) .validator(ValidatorBuild::new() .fields_validate(|_, fields| match fields.count() { 1 => Ok(()), @@ -107,7 +122,7 @@ pub fn derive_uri_display_path(input: proc_macro::TokenStream) -> TokenStream { ) .inner_mapper(MapperBuild::new() .with_output(|_, output| quote! { - fn fmt(&self, f: &mut #FORMATTER) -> ::std::fmt::Result { + fn fmt(&self, f: &mut #P_FORMATTER) -> ::std::fmt::Result { #output Ok(()) } @@ -141,6 +156,11 @@ fn from_uri_param(input: proc_macro::TokenStream, ty: TokenStream) fmt::Kind::Query => quote!(#_fmt::Query), }; + let display_trait = match P::KIND { + fmt::Kind::Path => P_URI_DISPLAY, + fmt::Kind::Query => Q_URI_DISPLAY, + }; + let ty: syn::Type = syn::parse2(ty).expect("valid type"); let gen = match ty { syn::Type::Reference(ref r) => r.lifetime.as_ref().map(|l| quote!(<#l>)), @@ -150,7 +170,7 @@ fn from_uri_param(input: proc_macro::TokenStream, ty: TokenStream) let param_trait = quote!(impl #gen #_fmt::FromUriParam<#part, #ty>); DeriveGenerator::build_for(input, param_trait) .support(Support::All) - .type_bound(quote!(#_fmt::UriDisplay<#part>)) + .type_bound_mapper(generic_bounds_mapper(display_trait)) .inner_mapper(MapperBuild::new() .with_output(move |_, _| quote! { type Target = #ty;