Rework catch attribute

See catch attribute docs for the new syntax.
This commit is contained in:
Matthew Pomes 2024-07-02 00:25:33 -05:00
parent 1308c1903d
commit f8c8bb87e6
No known key found for this signature in database
GPG Key ID: B8C0D93B8D8FBDB7
6 changed files with 260 additions and 94 deletions

View File

@ -1,28 +1,64 @@
mod parse;
use devise::ext::SpanDiagnosticExt;
use devise::{Diagnostic, Level, Result, Spanned};
use devise::{Result, Spanned};
use proc_macro2::{TokenStream, Span};
use crate::http_codegen::Optional;
use crate::syn_ext::ReturnTypeExt;
use crate::syn_ext::{IdentExt, ReturnTypeExt};
use crate::exports::*;
fn error_arg_ty(arg: &syn::FnArg) -> Result<&syn::Type> {
match arg {
syn::FnArg::Receiver(_) => Err(Diagnostic::spanned(
arg.span(),
Level::Error,
"Catcher cannot have self as a parameter",
)),
syn::FnArg::Typed(syn::PatType { ty, .. }) => match ty.as_ref() {
syn::Type::Reference(syn::TypeReference { elem, .. }) => Ok(elem.as_ref()),
_ => Err(Diagnostic::spanned(
ty.span(),
Level::Error,
"Error type must be a reference",
)),
},
use self::parse::ErrorGuard;
use super::param::Guard;
fn error_type(guard: &ErrorGuard) -> TokenStream {
let ty = &guard.ty;
quote! {
(#_catcher::TypeId::of::<#ty>(), ::std::any::type_name::<#ty>())
}
}
fn error_guard_decl(guard: &ErrorGuard) -> TokenStream {
let (ident, ty) = (guard.ident.rocketized(), &guard.ty);
quote_spanned! { ty.span() =>
let #ident: &#ty = match #_catcher::downcast(__error_init.as_ref()) {
Some(v) => v,
None => return #_Result::Err((#__status, __error_init)),
};
}
}
fn request_guard_decl(guard: &Guard) -> TokenStream {
let (ident, ty) = (guard.fn_ident.rocketized(), &guard.ty);
quote_spanned! { ty.span() =>
let #ident: #ty = match <#ty as #FromRequest>::from_request(#__req).await {
#Outcome::Success(__v) => __v,
#Outcome::Forward(__e) => {
::rocket::trace::info!(
name: "forward",
target: concat!("rocket::codegen::catch::", module_path!()),
parameter = stringify!(#ident),
type_name = stringify!(#ty),
status = __e.code,
"request guard forwarding; trying next catcher"
);
return #_Err((#__status, __error_init));
},
#[allow(unreachable_code)]
#Outcome::Error((__c, __e)) => {
::rocket::trace::info!(
name: "failure",
target: concat!("rocket::codegen::catch::", module_path!()),
parameter = stringify!(#ident),
type_name = stringify!(#ty),
reason = %#display_hack!(&__e),
"request guard failed; forwarding to 500 handler"
);
return #_Err((#Status::InternalServerError, __error_init));
}
};
}
}
@ -31,7 +67,7 @@ pub fn _catch(
input: proc_macro::TokenStream
) -> Result<TokenStream> {
// Parse and validate all of the user's input.
let catch = parse::Attribute::parse(args.into(), input)?;
let catch = parse::Attribute::parse(args.into(), input.into())?;
// Gather everything we'll need to generate the catcher.
let user_catcher_fn = &catch.function;
@ -40,48 +76,27 @@ pub fn _catch(
let status_code = Optional(catch.status.map(|s| s.code));
let deprecated = catch.function.attrs.iter().find(|a| a.path().is_ident("deprecated"));
// Determine the number of parameters that will be passed in.
if catch.function.sig.inputs.len() > 3 {
return Err(catch.function.sig.paren_token.span.join()
.error("invalid number of arguments: must be zero, one, or two")
.help("catchers optionally take `&Request` or `Status, &Request`"));
}
// This ensures that "Responder not implemented" points to the return type.
let return_type_span = catch.function.sig.output.ty()
.map(|ty| ty.span())
.unwrap_or_else(Span::call_site);
// TODO: how to handle request?
// - Right now: (), (&Req), (Status, &Req) allowed
// Set the `req` and `status` spans to that of their respective function
// arguments for a more correct `wrong type` error span. `rev` to be cute.
let codegen_args = &[__req, __status, __error];
let inputs = catch.function.sig.inputs.iter().rev()
.zip(codegen_args.iter())
.map(|(fn_arg, codegen_arg)| match fn_arg {
syn::FnArg::Receiver(_) => codegen_arg.respanned(fn_arg.span()),
syn::FnArg::Typed(a) => codegen_arg.respanned(a.ty.span())
}).rev();
let (make_error, error_type) = if catch.function.sig.inputs.len() >= 3 {
let arg = catch.function.sig.inputs.first().unwrap();
let ty = error_arg_ty(arg)?;
(quote_spanned!(arg.span() =>
let #__error: &#ty = match ::rocket::catcher::downcast(__error_init.as_ref()) {
Some(v) => v,
None => return #_Result::Err((#__status, __error_init)),
};
), quote! {Some((#_catcher::TypeId::of::<#ty>(), ::std::any::type_name::<#ty>()))})
} else {
(quote! {}, quote! {None})
};
let status_guard = catch.status_guard.as_ref().map(|(_, s)| {
let ident = s.rocketized();
quote! { let #ident = #__status; }
});
let error_guard = catch.error_guard.as_ref().map(error_guard_decl);
let error_type = Optional(catch.error_guard.as_ref().map(error_type));
let request_guards = catch.request_guards.iter().map(request_guard_decl);
let parameter_names = catch.arguments.map.values()
.map(|(ident, _)| ident.rocketized());
// We append `.await` to the function call if this is `async`.
let dot_await = catch.function.sig.asyncness
.map(|a| quote_spanned!(a.span() => .await));
let catcher_response = quote_spanned!(return_type_span => {
let ___responder = #user_catcher_fn_name(#(#inputs),*) #dot_await;
let ___responder = #user_catcher_fn_name(#(#parameter_names),*) #dot_await;
#_response::Responder::respond_to(___responder, #__req).map_err(|s| (s, __error_init))?
});
@ -104,7 +119,9 @@ pub fn _catch(
__error_init: #ErasedError<'__r>,
) -> #_catcher::BoxFuture<'__r> {
#_Box::pin(async move {
#make_error
#error_guard
#status_guard
#(#request_guards)*
let __response = #catcher_response;
#_Result::Ok(
#Response::build()

View File

@ -1,8 +1,12 @@
use devise::ext::SpanDiagnosticExt;
use devise::ext::{SpanDiagnosticExt, TypeExt};
use devise::{Diagnostic, FromMeta, MetaItem, Result, SpanWrapped, Spanned};
use proc_macro2::TokenStream;
use proc_macro2::{Span, TokenStream, Ident};
use quote::ToTokens;
use crate::attribute::param::Dynamic;
use crate::attribute::param::{Dynamic, Guard};
use crate::name::{ArgumentMap, Arguments, Name};
use crate::proc_macro_ext::Diagnostics;
use crate::syn_ext::FnArgExt;
use crate::{http, http_codegen};
/// This structure represents the parsed `catch` attribute and associated items.
@ -11,7 +15,57 @@ pub struct Attribute {
pub status: Option<http::Status>,
/// The function that was decorated with the `catch` attribute.
pub function: syn::ItemFn,
pub error: Option<SpanWrapped<Dynamic>>,
pub arguments: Arguments,
pub error_guard: Option<ErrorGuard>,
pub status_guard: Option<(Name, syn::Ident)>,
pub request_guards: Vec<Guard>,
}
pub struct ErrorGuard {
pub span: Span,
pub name: Name,
pub ident: syn::Ident,
pub ty: syn::Type,
}
impl ErrorGuard {
fn new(param: SpanWrapped<Dynamic>, args: &Arguments) -> Result<Self> {
if let Some((ident, ty)) = args.map.get(&param.name) {
match ty {
syn::Type::Reference(syn::TypeReference { elem, .. }) => Ok(Self {
span: param.span(),
name: param.name.clone(),
ident: ident.clone(),
ty: elem.as_ref().clone(),
}),
ty => {
let msg = format!(
"Error argument must be a reference, found `{}`",
ty.to_token_stream()
);
let diag = param.span()
.error("invalid type")
.span_note(ty.span(), msg)
.help(format!("Perhaps use `&{}` instead", ty.to_token_stream()));
Err(diag)
}
}
} else {
let msg = format!("expected argument named `{}` here", param.name);
let diag = param.span().error("unused parameter").span_note(args.span, msg);
Err(diag)
}
}
}
fn status_guard(param: SpanWrapped<Dynamic>, args: &Arguments) -> Result<(Name, Ident)> {
if let Some((ident, _)) = args.map.get(&param.name) {
Ok((param.name.clone(), ident.clone()))
} else {
let msg = format!("expected argument named `{}` here", param.name);
let diag = param.span().error("unused parameter").span_note(args.span, msg);
Err(diag)
}
}
/// We generate a full parser for the meta-item for great error messages.
@ -19,7 +73,8 @@ pub struct Attribute {
struct Meta {
#[meta(naked)]
code: Code,
// error: Option<SpanWrapped<Dynamic>>,
error: Option<SpanWrapped<Dynamic>>,
status: Option<SpanWrapped<Dynamic>>,
}
/// `Some` if there's a code, `None` if it's `default`.
@ -46,16 +101,66 @@ impl FromMeta for Code {
impl Attribute {
pub fn parse(args: TokenStream, input: proc_macro::TokenStream) -> Result<Self> {
let mut diags = Diagnostics::new();
let function: syn::ItemFn = syn::parse(input)
.map_err(Diagnostic::from)
.map_err(|diag| diag.help("`#[catch]` can only be used on functions"))?;
let attr: MetaItem = syn::parse2(quote!(catch(#args)))?;
let status = Meta::from_meta(&attr)
let attr = Meta::from_meta(&attr)
.map(|meta| meta)
.map_err(|diag| diag.help("`#[catch]` expects a status code int or `default`: \
`#[catch(404)]` or `#[catch(default)]`"))?;
Ok(Attribute { status: status.code.0, function, error: None })
let span = function.sig.paren_token.span.join();
let mut arguments = Arguments { map: ArgumentMap::new(), span };
for arg in function.sig.inputs.iter() {
if let Some((ident, ty)) = arg.typed() {
let value = (ident.clone(), ty.with_stripped_lifetimes());
arguments.map.insert(Name::from(ident), value);
} else {
let span = arg.span();
let diag = if arg.wild().is_some() {
span.error("handler arguments must be named")
.help("to name an ignored handler argument, use `_name`")
} else {
span.error("handler arguments must be of the form `ident: Type`")
};
diags.push(diag);
}
}
// let mut error_guard = None;
let error_guard = attr.error.clone()
.map(|p| ErrorGuard::new(p, &arguments))
.and_then(|p| p.map_err(|e| diags.push(e)).ok());
let status_guard = attr.status.clone()
.map(|n| status_guard(n, &arguments))
.and_then(|p| p.map_err(|e| diags.push(e)).ok());
let request_guards = arguments.map.iter()
.filter(|(name, _)| {
let mut all_other_guards = error_guard.iter()
.map(|g| &g.name)
.chain(status_guard.iter().map(|(n, _)| n));
all_other_guards.all(|n| n != *name)
})
.enumerate()
.map(|(index, (name, (ident, ty)))| Guard {
source: Dynamic { index, name: name.clone(), trailing: false },
fn_ident: ident.clone(),
ty: ty.clone(),
})
.collect();
diags.head_err_or(Attribute {
status: attr.code.0,
function,
arguments,
error_guard,
status_guard,
request_guards,
})
}
}

View File

@ -8,7 +8,7 @@ use crate::proc_macro_ext::Diagnostics;
use crate::http_codegen::{Method, MediaType};
use crate::attribute::param::{Parameter, Dynamic, Guard};
use crate::syn_ext::FnArgExt;
use crate::name::Name;
use crate::name::{ArgumentMap, Arguments, Name};
use crate::http::ext::IntoOwned;
use crate::http::uri::{Origin, fmt};
@ -31,14 +31,6 @@ pub struct Route {
pub arguments: Arguments,
}
type ArgumentMap = IndexMap<Name, (syn::Ident, syn::Type)>;
#[derive(Debug)]
pub struct Arguments {
pub span: Span,
pub map: ArgumentMap
}
/// The parsed `#[route(..)]` attribute.
#[derive(Debug, FromMeta)]
pub struct Attribute {

View File

@ -294,17 +294,16 @@ route_attribute!(options => Method::Options);
/// ```rust
/// # #[macro_use] extern crate rocket;
/// #
/// use rocket::Request;
/// use rocket::http::Status;
/// use rocket::http::{Status, uri::Origin};
///
/// #[catch(404)]
/// fn not_found(req: &Request) -> String {
/// format!("Sorry, {} does not exist.", req.uri())
/// fn not_found(uri: &Origin) -> String {
/// format!("Sorry, {} does not exist.", uri)
/// }
///
/// #[catch(default)]
/// fn default(status: Status, req: &Request) -> String {
/// format!("{} ({})", status, req.uri())
/// #[catch(default, status = "<status>")]
/// fn default(status: Status, uri: &Origin) -> String {
/// format!("{} ({})", status, uri)
/// }
/// ```
///
@ -313,19 +312,59 @@ route_attribute!(options => Method::Options);
/// The grammar for the `#[catch]` attributes is defined as:
///
/// ```text
/// catch := STATUS | 'default'
/// catch := STATUS | 'default' (',' parameter)*
///
/// STATUS := valid HTTP status code (integer in [200, 599])
/// parameter := 'rank' '=' INTEGER
/// | 'status' '=' '"' SINGLE_PARAM '"'
/// | 'error' '=' '"' SINGLE_PARAM '"'
/// SINGLE_PARAM := '<' IDENT '>'
/// ```
///
/// # Typing Requirements
///
/// The decorated function may take zero, one, or two arguments. It's type
/// signature must be one of the following, where `R:`[`Responder`]:
/// Every identifier, except for `_`, that appears in a dynamic parameter, must appear
/// as an argument to the function.
///
/// * `fn() -> R`
/// * `fn(`[`&Request`]`) -> R`
/// * `fn(`[`Status`]`, `[`&Request`]`) -> R`
/// The type of each function argument corresponding to a dynamic parameter is required to
/// meet specific requirements.
///
/// - `status`: Must be [`Status`].
/// - `error`: Must be a reference to a type that implements `Transient`. See
/// [Typed catchers](Self#Typed-catchers) for more info.
///
/// All other arguments must implement [`FromRequest`].
///
/// A route argument declared a `_` must not appear in the function argument list and has no typing requirements.
///
/// The return type of the decorated function must implement the [`Responder`] trait.
///
/// # Typed catchers
///
/// To make catchers more expressive and powerful, they can catch specific
/// error types. This is accomplished using the [`transient`] crate as a
/// replacement for [`std::any::Any`]. When a [`FromRequest`], [`FromParam`],
/// [`FromSegments`], [`FromForm`], or [`FromData`] implementation fails or
/// forwards, Rocket will convert to the error type to `dyn Any<Co<'r>>`, if the
/// error type implements `Transient`.
///
/// Only a single error type can be carried by a request - if a route forwards,
/// and another route is attempted, any error produced by the second route
/// overwrites the first.
///
/// ## Custom error types
///
/// All[^transient-impls] error types that Rocket itself produces implement
/// `Transient`, and can therefore be caught by a typed catcher. If you have
/// a custom guard of any type, you can implement `Transient` using the derive
/// macro provided by the `transient` crate. If the error type has lifetimes,
/// please read the documentation for the `Transient` derive macro - although it
/// prevents any unsafe implementation, it's not the easiest to use. Note that
/// Rocket upcasts the type to `dyn Any<Co<'r>>`, where `'r` is the lifetime of
/// the `Request`, so any `Transient` impl must be able to trancend to `Co<'r>`,
/// and desend from `Co<'r>` at the catcher.
///
/// [^transient-impls]: As of writing, this is a WIP.
///
/// # Semantics
///
@ -333,10 +372,12 @@ route_attribute!(options => Method::Options);
///
/// 1. An error [`Handler`].
///
/// The generated handler calls the decorated function, passing in the
/// [`Status`] and [`&Request`] values if requested. The returned value is
/// used to generate a [`Response`] via the type's [`Responder`]
/// implementation.
/// The generated handler validates and generates all arguments for the generated function according
/// to their specific requirements. The order in which arguments are processed is:
///
/// 1. The `error` type. This means no other guards will be evaluated if the error type does not match.
/// 2. Request guards, from left to right. If a Request guards forwards, the next catcher will be tried.
/// If the Request guard fails, the error is instead routed to the `500` catcher.
///
/// 2. A static structure used by [`catchers!`] to generate a [`Catcher`].
///
@ -351,6 +392,7 @@ route_attribute!(options => Method::Options);
/// [`Catcher`]: ../rocket/struct.Catcher.html
/// [`Response`]: ../rocket/struct.Response.html
/// [`Responder`]: ../rocket/response/trait.Responder.html
/// [`FromRequest`]: ../rocket/request/trait.FromRequest.html
#[proc_macro_attribute]
pub fn catch(args: TokenStream, input: TokenStream) -> TokenStream {
emit!(attribute::catch::catch_attribute(args, input))

View File

@ -1,8 +1,18 @@
use crate::http::uncased::UncasedStr;
use indexmap::IndexMap;
use syn::{Ident, ext::IdentExt};
use proc_macro2::{Span, TokenStream};
pub type ArgumentMap = IndexMap<Name, (syn::Ident, syn::Type)>;
#[derive(Debug)]
pub struct Arguments {
pub span: Span,
pub map: ArgumentMap
}
/// A "name" read by codegen, which may or may not be an identifier. A `Name` is
/// typically constructed indirectly via FromMeta, or From<Ident> or directly
/// from a string via `Name::new()`. A name is tokenized as a string.

View File

@ -2,9 +2,9 @@
#[cfg(test)] mod tests;
use rocket::{Rocket, Request, Build};
use rocket::{Rocket, Build};
use rocket::response::{content, status};
use rocket::http::Status;
use rocket::http::{Status, uri::Origin};
// Custom impl so I can implement Static (or Transient) ---
// We should upstream implementations for most common error types
@ -46,11 +46,11 @@ fn general_not_found() -> content::RawHtml<&'static str> {
}
#[catch(404)]
fn hello_not_found(req: &Request<'_>) -> content::RawHtml<String> {
fn hello_not_found(uri: &Origin<'_>) -> content::RawHtml<String> {
content::RawHtml(format!("\
<p>Sorry, but '{}' is not a valid path!</p>\
<p>Try visiting /hello/&lt;name&gt;/&lt;age&gt; instead.</p>",
req.uri()))
uri))
}
// Demonstrates a downcast error from `hello`
@ -58,14 +58,14 @@ fn hello_not_found(req: &Request<'_>) -> content::RawHtml<String> {
// be present. I'm thinking about adding a param to the macro to indicate which (and whether)
// param is a downcast error.
// `error` and `status` type. All other params must be `FromRequest`?
// `error` and `status` type. All other params must be `FromOrigin`?
#[catch(422, error = "<e>" /*, status = "<_s>"*/)]
fn param_error(e: &IntErr, _s: Status, req: &Request<'_>) -> content::RawHtml<String> {
fn param_error(e: &IntErr, uri: &Origin<'_>) -> content::RawHtml<String> {
content::RawHtml(format!("\
<p>Sorry, but '{}' is not a valid path!</p>\
<p>Try visiting /hello/&lt;name&gt;/&lt;age&gt; instead.</p>\
<p>Error: {e:?}</p>",
req.uri()))
uri))
}
#[catch(default)]
@ -73,9 +73,9 @@ fn sergio_error() -> &'static str {
"I...don't know what to say."
}
#[catch(default)]
fn default_catcher(status: Status, req: &Request<'_>) -> status::Custom<String> {
let msg = format!("{} ({})", status, req.uri());
#[catch(default, status = "<status>")]
fn default_catcher(status: Status, uri: &Origin<'_>) -> status::Custom<String> {
let msg = format!("{} ({})", status, uri);
status::Custom(status, msg)
}