From f8c8bb87e64c1247f4d875cf955a98be37ce6933 Mon Sep 17 00:00:00 2001 From: Matthew Pomes Date: Tue, 2 Jul 2024 00:25:33 -0500 Subject: [PATCH] Rework catch attribute See catch attribute docs for the new syntax. --- core/codegen/src/attribute/catch/mod.rs | 119 ++++++++++++---------- core/codegen/src/attribute/catch/parse.rs | 119 ++++++++++++++++++++-- core/codegen/src/attribute/route/parse.rs | 10 +- core/codegen/src/lib.rs | 76 ++++++++++---- core/codegen/src/name.rs | 10 ++ examples/error-handling/src/main.rs | 20 ++-- 6 files changed, 260 insertions(+), 94 deletions(-) diff --git a/core/codegen/src/attribute/catch/mod.rs b/core/codegen/src/attribute/catch/mod.rs index 03d28686..e768e073 100644 --- a/core/codegen/src/attribute/catch/mod.rs +++ b/core/codegen/src/attribute/catch/mod.rs @@ -1,28 +1,64 @@ mod parse; -use devise::ext::SpanDiagnosticExt; -use devise::{Diagnostic, Level, Result, Spanned}; +use devise::{Result, Spanned}; use proc_macro2::{TokenStream, Span}; use crate::http_codegen::Optional; -use crate::syn_ext::ReturnTypeExt; +use crate::syn_ext::{IdentExt, ReturnTypeExt}; use crate::exports::*; -fn error_arg_ty(arg: &syn::FnArg) -> Result<&syn::Type> { - match arg { - syn::FnArg::Receiver(_) => Err(Diagnostic::spanned( - arg.span(), - Level::Error, - "Catcher cannot have self as a parameter", - )), - syn::FnArg::Typed(syn::PatType { ty, .. }) => match ty.as_ref() { - syn::Type::Reference(syn::TypeReference { elem, .. }) => Ok(elem.as_ref()), - _ => Err(Diagnostic::spanned( - ty.span(), - Level::Error, - "Error type must be a reference", - )), - }, +use self::parse::ErrorGuard; + +use super::param::Guard; + +fn error_type(guard: &ErrorGuard) -> TokenStream { + let ty = &guard.ty; + quote! { + (#_catcher::TypeId::of::<#ty>(), ::std::any::type_name::<#ty>()) + } +} + +fn error_guard_decl(guard: &ErrorGuard) -> TokenStream { + let (ident, ty) = (guard.ident.rocketized(), &guard.ty); + quote_spanned! { ty.span() => + let #ident: &#ty = match #_catcher::downcast(__error_init.as_ref()) { + Some(v) => v, + None => return #_Result::Err((#__status, __error_init)), + }; + } +} + +fn request_guard_decl(guard: &Guard) -> TokenStream { + let (ident, ty) = (guard.fn_ident.rocketized(), &guard.ty); + quote_spanned! { ty.span() => + let #ident: #ty = match <#ty as #FromRequest>::from_request(#__req).await { + #Outcome::Success(__v) => __v, + #Outcome::Forward(__e) => { + ::rocket::trace::info!( + name: "forward", + target: concat!("rocket::codegen::catch::", module_path!()), + parameter = stringify!(#ident), + type_name = stringify!(#ty), + status = __e.code, + "request guard forwarding; trying next catcher" + ); + + return #_Err((#__status, __error_init)); + }, + #[allow(unreachable_code)] + #Outcome::Error((__c, __e)) => { + ::rocket::trace::info!( + name: "failure", + target: concat!("rocket::codegen::catch::", module_path!()), + parameter = stringify!(#ident), + type_name = stringify!(#ty), + reason = %#display_hack!(&__e), + "request guard failed; forwarding to 500 handler" + ); + + return #_Err((#Status::InternalServerError, __error_init)); + } + }; } } @@ -31,7 +67,7 @@ pub fn _catch( input: proc_macro::TokenStream ) -> Result { // Parse and validate all of the user's input. - let catch = parse::Attribute::parse(args.into(), input)?; + let catch = parse::Attribute::parse(args.into(), input.into())?; // Gather everything we'll need to generate the catcher. let user_catcher_fn = &catch.function; @@ -40,48 +76,27 @@ pub fn _catch( let status_code = Optional(catch.status.map(|s| s.code)); let deprecated = catch.function.attrs.iter().find(|a| a.path().is_ident("deprecated")); - // Determine the number of parameters that will be passed in. - if catch.function.sig.inputs.len() > 3 { - return Err(catch.function.sig.paren_token.span.join() - .error("invalid number of arguments: must be zero, one, or two") - .help("catchers optionally take `&Request` or `Status, &Request`")); - } - // This ensures that "Responder not implemented" points to the return type. let return_type_span = catch.function.sig.output.ty() .map(|ty| ty.span()) .unwrap_or_else(Span::call_site); - // TODO: how to handle request? - // - Right now: (), (&Req), (Status, &Req) allowed - // Set the `req` and `status` spans to that of their respective function - // arguments for a more correct `wrong type` error span. `rev` to be cute. - let codegen_args = &[__req, __status, __error]; - let inputs = catch.function.sig.inputs.iter().rev() - .zip(codegen_args.iter()) - .map(|(fn_arg, codegen_arg)| match fn_arg { - syn::FnArg::Receiver(_) => codegen_arg.respanned(fn_arg.span()), - syn::FnArg::Typed(a) => codegen_arg.respanned(a.ty.span()) - }).rev(); - let (make_error, error_type) = if catch.function.sig.inputs.len() >= 3 { - let arg = catch.function.sig.inputs.first().unwrap(); - let ty = error_arg_ty(arg)?; - (quote_spanned!(arg.span() => - let #__error: &#ty = match ::rocket::catcher::downcast(__error_init.as_ref()) { - Some(v) => v, - None => return #_Result::Err((#__status, __error_init)), - }; - ), quote! {Some((#_catcher::TypeId::of::<#ty>(), ::std::any::type_name::<#ty>()))}) - } else { - (quote! {}, quote! {None}) - }; + let status_guard = catch.status_guard.as_ref().map(|(_, s)| { + let ident = s.rocketized(); + quote! { let #ident = #__status; } + }); + let error_guard = catch.error_guard.as_ref().map(error_guard_decl); + let error_type = Optional(catch.error_guard.as_ref().map(error_type)); + let request_guards = catch.request_guards.iter().map(request_guard_decl); + let parameter_names = catch.arguments.map.values() + .map(|(ident, _)| ident.rocketized()); // We append `.await` to the function call if this is `async`. let dot_await = catch.function.sig.asyncness .map(|a| quote_spanned!(a.span() => .await)); let catcher_response = quote_spanned!(return_type_span => { - let ___responder = #user_catcher_fn_name(#(#inputs),*) #dot_await; + let ___responder = #user_catcher_fn_name(#(#parameter_names),*) #dot_await; #_response::Responder::respond_to(___responder, #__req).map_err(|s| (s, __error_init))? }); @@ -104,7 +119,9 @@ pub fn _catch( __error_init: #ErasedError<'__r>, ) -> #_catcher::BoxFuture<'__r> { #_Box::pin(async move { - #make_error + #error_guard + #status_guard + #(#request_guards)* let __response = #catcher_response; #_Result::Ok( #Response::build() diff --git a/core/codegen/src/attribute/catch/parse.rs b/core/codegen/src/attribute/catch/parse.rs index e7a3842a..ddc296c3 100644 --- a/core/codegen/src/attribute/catch/parse.rs +++ b/core/codegen/src/attribute/catch/parse.rs @@ -1,8 +1,12 @@ -use devise::ext::SpanDiagnosticExt; +use devise::ext::{SpanDiagnosticExt, TypeExt}; use devise::{Diagnostic, FromMeta, MetaItem, Result, SpanWrapped, Spanned}; -use proc_macro2::TokenStream; +use proc_macro2::{Span, TokenStream, Ident}; +use quote::ToTokens; -use crate::attribute::param::Dynamic; +use crate::attribute::param::{Dynamic, Guard}; +use crate::name::{ArgumentMap, Arguments, Name}; +use crate::proc_macro_ext::Diagnostics; +use crate::syn_ext::FnArgExt; use crate::{http, http_codegen}; /// This structure represents the parsed `catch` attribute and associated items. @@ -11,7 +15,57 @@ pub struct Attribute { pub status: Option, /// The function that was decorated with the `catch` attribute. pub function: syn::ItemFn, - pub error: Option>, + pub arguments: Arguments, + pub error_guard: Option, + pub status_guard: Option<(Name, syn::Ident)>, + pub request_guards: Vec, +} + +pub struct ErrorGuard { + pub span: Span, + pub name: Name, + pub ident: syn::Ident, + pub ty: syn::Type, +} + +impl ErrorGuard { + fn new(param: SpanWrapped, args: &Arguments) -> Result { + if let Some((ident, ty)) = args.map.get(¶m.name) { + match ty { + syn::Type::Reference(syn::TypeReference { elem, .. }) => Ok(Self { + span: param.span(), + name: param.name.clone(), + ident: ident.clone(), + ty: elem.as_ref().clone(), + }), + ty => { + let msg = format!( + "Error argument must be a reference, found `{}`", + ty.to_token_stream() + ); + let diag = param.span() + .error("invalid type") + .span_note(ty.span(), msg) + .help(format!("Perhaps use `&{}` instead", ty.to_token_stream())); + Err(diag) + } + } + } else { + let msg = format!("expected argument named `{}` here", param.name); + let diag = param.span().error("unused parameter").span_note(args.span, msg); + Err(diag) + } + } +} + +fn status_guard(param: SpanWrapped, args: &Arguments) -> Result<(Name, Ident)> { + if let Some((ident, _)) = args.map.get(¶m.name) { + Ok((param.name.clone(), ident.clone())) + } else { + let msg = format!("expected argument named `{}` here", param.name); + let diag = param.span().error("unused parameter").span_note(args.span, msg); + Err(diag) + } } /// We generate a full parser for the meta-item for great error messages. @@ -19,7 +73,8 @@ pub struct Attribute { struct Meta { #[meta(naked)] code: Code, - // error: Option>, + error: Option>, + status: Option>, } /// `Some` if there's a code, `None` if it's `default`. @@ -46,16 +101,66 @@ impl FromMeta for Code { impl Attribute { pub fn parse(args: TokenStream, input: proc_macro::TokenStream) -> Result { + let mut diags = Diagnostics::new(); + let function: syn::ItemFn = syn::parse(input) .map_err(Diagnostic::from) .map_err(|diag| diag.help("`#[catch]` can only be used on functions"))?; let attr: MetaItem = syn::parse2(quote!(catch(#args)))?; - let status = Meta::from_meta(&attr) + let attr = Meta::from_meta(&attr) .map(|meta| meta) .map_err(|diag| diag.help("`#[catch]` expects a status code int or `default`: \ `#[catch(404)]` or `#[catch(default)]`"))?; - Ok(Attribute { status: status.code.0, function, error: None }) + let span = function.sig.paren_token.span.join(); + let mut arguments = Arguments { map: ArgumentMap::new(), span }; + for arg in function.sig.inputs.iter() { + if let Some((ident, ty)) = arg.typed() { + let value = (ident.clone(), ty.with_stripped_lifetimes()); + arguments.map.insert(Name::from(ident), value); + } else { + let span = arg.span(); + let diag = if arg.wild().is_some() { + span.error("handler arguments must be named") + .help("to name an ignored handler argument, use `_name`") + } else { + span.error("handler arguments must be of the form `ident: Type`") + }; + + diags.push(diag); + } + } + // let mut error_guard = None; + let error_guard = attr.error.clone() + .map(|p| ErrorGuard::new(p, &arguments)) + .and_then(|p| p.map_err(|e| diags.push(e)).ok()); + let status_guard = attr.status.clone() + .map(|n| status_guard(n, &arguments)) + .and_then(|p| p.map_err(|e| diags.push(e)).ok()); + let request_guards = arguments.map.iter() + .filter(|(name, _)| { + let mut all_other_guards = error_guard.iter() + .map(|g| &g.name) + .chain(status_guard.iter().map(|(n, _)| n)); + + all_other_guards.all(|n| n != *name) + }) + .enumerate() + .map(|(index, (name, (ident, ty)))| Guard { + source: Dynamic { index, name: name.clone(), trailing: false }, + fn_ident: ident.clone(), + ty: ty.clone(), + }) + .collect(); + + diags.head_err_or(Attribute { + status: attr.code.0, + function, + arguments, + error_guard, + status_guard, + request_guards, + }) } } diff --git a/core/codegen/src/attribute/route/parse.rs b/core/codegen/src/attribute/route/parse.rs index 13f3b93d..4df278a5 100644 --- a/core/codegen/src/attribute/route/parse.rs +++ b/core/codegen/src/attribute/route/parse.rs @@ -8,7 +8,7 @@ use crate::proc_macro_ext::Diagnostics; use crate::http_codegen::{Method, MediaType}; use crate::attribute::param::{Parameter, Dynamic, Guard}; use crate::syn_ext::FnArgExt; -use crate::name::Name; +use crate::name::{ArgumentMap, Arguments, Name}; use crate::http::ext::IntoOwned; use crate::http::uri::{Origin, fmt}; @@ -31,14 +31,6 @@ pub struct Route { pub arguments: Arguments, } -type ArgumentMap = IndexMap; - -#[derive(Debug)] -pub struct Arguments { - pub span: Span, - pub map: ArgumentMap -} - /// The parsed `#[route(..)]` attribute. #[derive(Debug, FromMeta)] pub struct Attribute { diff --git a/core/codegen/src/lib.rs b/core/codegen/src/lib.rs index 39401f1c..8bd0f699 100644 --- a/core/codegen/src/lib.rs +++ b/core/codegen/src/lib.rs @@ -294,17 +294,16 @@ route_attribute!(options => Method::Options); /// ```rust /// # #[macro_use] extern crate rocket; /// # -/// use rocket::Request; -/// use rocket::http::Status; +/// use rocket::http::{Status, uri::Origin}; /// /// #[catch(404)] -/// fn not_found(req: &Request) -> String { -/// format!("Sorry, {} does not exist.", req.uri()) +/// fn not_found(uri: &Origin) -> String { +/// format!("Sorry, {} does not exist.", uri) /// } /// -/// #[catch(default)] -/// fn default(status: Status, req: &Request) -> String { -/// format!("{} ({})", status, req.uri()) +/// #[catch(default, status = "")] +/// fn default(status: Status, uri: &Origin) -> String { +/// format!("{} ({})", status, uri) /// } /// ``` /// @@ -313,19 +312,59 @@ route_attribute!(options => Method::Options); /// The grammar for the `#[catch]` attributes is defined as: /// /// ```text -/// catch := STATUS | 'default' +/// catch := STATUS | 'default' (',' parameter)* /// /// STATUS := valid HTTP status code (integer in [200, 599]) +/// parameter := 'rank' '=' INTEGER +/// | 'status' '=' '"' SINGLE_PARAM '"' +/// | 'error' '=' '"' SINGLE_PARAM '"' +/// SINGLE_PARAM := '<' IDENT '>' /// ``` /// /// # Typing Requirements /// -/// The decorated function may take zero, one, or two arguments. It's type -/// signature must be one of the following, where `R:`[`Responder`]: +/// Every identifier, except for `_`, that appears in a dynamic parameter, must appear +/// as an argument to the function. /// -/// * `fn() -> R` -/// * `fn(`[`&Request`]`) -> R` -/// * `fn(`[`Status`]`, `[`&Request`]`) -> R` +/// The type of each function argument corresponding to a dynamic parameter is required to +/// meet specific requirements. +/// +/// - `status`: Must be [`Status`]. +/// - `error`: Must be a reference to a type that implements `Transient`. See +/// [Typed catchers](Self#Typed-catchers) for more info. +/// +/// All other arguments must implement [`FromRequest`]. +/// +/// A route argument declared a `_` must not appear in the function argument list and has no typing requirements. +/// +/// The return type of the decorated function must implement the [`Responder`] trait. +/// +/// # Typed catchers +/// +/// To make catchers more expressive and powerful, they can catch specific +/// error types. This is accomplished using the [`transient`] crate as a +/// replacement for [`std::any::Any`]. When a [`FromRequest`], [`FromParam`], +/// [`FromSegments`], [`FromForm`], or [`FromData`] implementation fails or +/// forwards, Rocket will convert to the error type to `dyn Any>`, if the +/// error type implements `Transient`. +/// +/// Only a single error type can be carried by a request - if a route forwards, +/// and another route is attempted, any error produced by the second route +/// overwrites the first. +/// +/// ## Custom error types +/// +/// All[^transient-impls] error types that Rocket itself produces implement +/// `Transient`, and can therefore be caught by a typed catcher. If you have +/// a custom guard of any type, you can implement `Transient` using the derive +/// macro provided by the `transient` crate. If the error type has lifetimes, +/// please read the documentation for the `Transient` derive macro - although it +/// prevents any unsafe implementation, it's not the easiest to use. Note that +/// Rocket upcasts the type to `dyn Any>`, where `'r` is the lifetime of +/// the `Request`, so any `Transient` impl must be able to trancend to `Co<'r>`, +/// and desend from `Co<'r>` at the catcher. +/// +/// [^transient-impls]: As of writing, this is a WIP. /// /// # Semantics /// @@ -333,10 +372,12 @@ route_attribute!(options => Method::Options); /// /// 1. An error [`Handler`]. /// -/// The generated handler calls the decorated function, passing in the -/// [`Status`] and [`&Request`] values if requested. The returned value is -/// used to generate a [`Response`] via the type's [`Responder`] -/// implementation. +/// The generated handler validates and generates all arguments for the generated function according +/// to their specific requirements. The order in which arguments are processed is: +/// +/// 1. The `error` type. This means no other guards will be evaluated if the error type does not match. +/// 2. Request guards, from left to right. If a Request guards forwards, the next catcher will be tried. +/// If the Request guard fails, the error is instead routed to the `500` catcher. /// /// 2. A static structure used by [`catchers!`] to generate a [`Catcher`]. /// @@ -351,6 +392,7 @@ route_attribute!(options => Method::Options); /// [`Catcher`]: ../rocket/struct.Catcher.html /// [`Response`]: ../rocket/struct.Response.html /// [`Responder`]: ../rocket/response/trait.Responder.html +/// [`FromRequest`]: ../rocket/request/trait.FromRequest.html #[proc_macro_attribute] pub fn catch(args: TokenStream, input: TokenStream) -> TokenStream { emit!(attribute::catch::catch_attribute(args, input)) diff --git a/core/codegen/src/name.rs b/core/codegen/src/name.rs index c5b8e2b1..bed617aa 100644 --- a/core/codegen/src/name.rs +++ b/core/codegen/src/name.rs @@ -1,8 +1,18 @@ use crate::http::uncased::UncasedStr; +use indexmap::IndexMap; use syn::{Ident, ext::IdentExt}; use proc_macro2::{Span, TokenStream}; +pub type ArgumentMap = IndexMap; + +#[derive(Debug)] +pub struct Arguments { + pub span: Span, + pub map: ArgumentMap +} + + /// A "name" read by codegen, which may or may not be an identifier. A `Name` is /// typically constructed indirectly via FromMeta, or From or directly /// from a string via `Name::new()`. A name is tokenized as a string. diff --git a/examples/error-handling/src/main.rs b/examples/error-handling/src/main.rs index e4e03421..cdde474b 100644 --- a/examples/error-handling/src/main.rs +++ b/examples/error-handling/src/main.rs @@ -2,9 +2,9 @@ #[cfg(test)] mod tests; -use rocket::{Rocket, Request, Build}; +use rocket::{Rocket, Build}; use rocket::response::{content, status}; -use rocket::http::Status; +use rocket::http::{Status, uri::Origin}; // Custom impl so I can implement Static (or Transient) --- // We should upstream implementations for most common error types @@ -46,11 +46,11 @@ fn general_not_found() -> content::RawHtml<&'static str> { } #[catch(404)] -fn hello_not_found(req: &Request<'_>) -> content::RawHtml { +fn hello_not_found(uri: &Origin<'_>) -> content::RawHtml { content::RawHtml(format!("\

Sorry, but '{}' is not a valid path!

\

Try visiting /hello/<name>/<age> instead.

", - req.uri())) + uri)) } // Demonstrates a downcast error from `hello` @@ -58,14 +58,14 @@ fn hello_not_found(req: &Request<'_>) -> content::RawHtml { // be present. I'm thinking about adding a param to the macro to indicate which (and whether) // param is a downcast error. -// `error` and `status` type. All other params must be `FromRequest`? +// `error` and `status` type. All other params must be `FromOrigin`? #[catch(422, error = "" /*, status = "<_s>"*/)] -fn param_error(e: &IntErr, _s: Status, req: &Request<'_>) -> content::RawHtml { +fn param_error(e: &IntErr, uri: &Origin<'_>) -> content::RawHtml { content::RawHtml(format!("\

Sorry, but '{}' is not a valid path!

\

Try visiting /hello/<name>/<age> instead.

\

Error: {e:?}

", - req.uri())) + uri)) } #[catch(default)] @@ -73,9 +73,9 @@ fn sergio_error() -> &'static str { "I...don't know what to say." } -#[catch(default)] -fn default_catcher(status: Status, req: &Request<'_>) -> status::Custom { - let msg = format!("{} ({})", status, req.uri()); +#[catch(default, status = "")] +fn default_catcher(status: Status, uri: &Origin<'_>) -> status::Custom { + let msg = format!("{} ({})", status, uri); status::Custom(status, msg) }