Major improvements

- Catchers now carry `TypeId` and type name for collision detection
- Transient updated to 0.3, with new derive macro
- Added `Transient` or `Static` implementations for error types
- CI should now pass
This commit is contained in:
Matthew Pomes 2024-06-29 18:41:44 -05:00
parent 99e210928d
commit 09c56c79c7
No known key found for this signature in database
GPG Key ID: B8C0D93B8D8FBDB7
17 changed files with 157 additions and 51 deletions

View File

@ -1,13 +1,24 @@
mod parse;
use devise::ext::SpanDiagnosticExt;
use devise::{Spanned, Result};
use devise::{Diagnostic, Level, Result, Spanned};
use proc_macro2::{TokenStream, Span};
use crate::http_codegen::Optional;
use crate::syn_ext::ReturnTypeExt;
use crate::exports::*;
fn arg_ty(arg: &syn::FnArg) -> Result<&syn::Type> {
match arg {
syn::FnArg::Receiver(_) => Err(Diagnostic::spanned(
arg.span(),
Level::Error,
"Catcher cannot have self as a parameter"
)),
syn::FnArg::Typed(syn::PatType {ty, ..})=> Ok(ty.as_ref()),
}
}
pub fn _catch(
args: proc_macro::TokenStream,
input: proc_macro::TokenStream
@ -45,16 +56,17 @@ pub fn _catch(
syn::FnArg::Receiver(_) => codegen_arg.respanned(fn_arg.span()),
syn::FnArg::Typed(a) => codegen_arg.respanned(a.ty.span())
}).rev();
let make_error = if catch.function.sig.inputs.len() >= 3 {
let (make_error, error_type) = if catch.function.sig.inputs.len() >= 3 {
let arg = catch.function.sig.inputs.first().unwrap();
quote_spanned!(arg.span() =>
let ty = arg_ty(arg)?;
(quote_spanned!(arg.span() =>
let #__error = match ::rocket::catcher::downcast(__error_init.as_ref()) {
Some(v) => v,
None => return #_Result::Err((#__status, __error_init)),
};
)
), quote! {Some((#_catcher::TypeId::of::<#ty>(), ::std::any::type_name::<#ty>()))})
} else {
quote! {}
(quote! {}, quote! {None})
};
// We append `.await` to the function call if this is `async`.
@ -99,6 +111,7 @@ pub fn _catch(
#_catcher::StaticInfo {
name: ::core::stringify!(#user_catcher_fn_name),
code: #status_code,
error_type: #error_type,
handler: monomorphized_function,
location: (::core::file!(), ::core::line!(), ::core::column!()),
}

View File

@ -56,6 +56,6 @@ impl Attribute {
.map_err(|diag| diag.help("`#[catch]` expects a status code int or `default`: \
`#[catch(404)]` or `#[catch(default)]`"))?;
Ok(Attribute { status: status.code.0, function, error: status.error })
Ok(Attribute { status: status.code.0, function, error: None })
}
}

View File

@ -115,11 +115,15 @@ fn query_decls(route: &Route) -> Option<TokenStream> {
);
::rocket::trace::info!(
target: concat!("rocket::codegen::route::", module_path!()),
error_type = ::std::any::type_name_of_val(&__error),
error_type = ::std::any::type_name_of_val(&__e),
"Forwarding error"
);
return #Outcome::Forward((#__data, #Status::UnprocessableEntity, #resolve_error!(__e)));
return #Outcome::Forward((
#__data,
#Status::UnprocessableEntity,
#resolve_error!(__e)
));
}
(#(#ident.unwrap()),*)
@ -207,7 +211,11 @@ fn param_guard_decl(guard: &Guard) -> TokenStream {
#i
);
return #Outcome::Forward((#__data, #Status::InternalServerError, #resolve_error!()));
return #Outcome::Forward((
#__data,
#Status::InternalServerError,
#resolve_error!()
));
}
}
},
@ -226,7 +234,8 @@ fn param_guard_decl(guard: &Guard) -> TokenStream {
fn data_guard_decl(guard: &Guard) -> TokenStream {
let (ident, ty) = (guard.fn_ident.rocketized(), &guard.ty);
define_spanned_export!(ty.span() => __req, __data, display_hack, FromData, Outcome, resolve_error);
define_spanned_export!(ty.span() =>
__req, __data, display_hack, FromData, Outcome, resolve_error);
quote_spanned! { ty.span() =>
let #ident: #ty = match <#ty as #FromData>::from_data(#__req, #__data).await {
@ -251,7 +260,7 @@ fn data_guard_decl(guard: &Guard) -> TokenStream {
parameter = stringify!(#ident),
type_name = stringify!(#ty),
reason = %#display_hack!(&__e),
error_type = ::std::any::type_name_of_val(&__error),
error_type = ::std::any::type_name_of_val(&__e),
"data guard failed"
);

View File

@ -58,3 +58,23 @@ fn test_status_param() {
assert_eq!(response.into_string().unwrap(), code.to_string());
}
}
#[catch(404)]
fn bad_req_untyped(_: Status, _: &Request<'_>) -> &'static str { "404" }
#[catch(404)]
fn bad_req_string(_: &String, _: Status, _: &Request<'_>) -> &'static str { "404 String" }
#[catch(404)]
fn bad_req_tuple(_: &(), _: Status, _: &Request<'_>) -> &'static str { "404 ()" }
#[test]
fn test_typed_catchers() {
fn rocket() -> Rocket<Build> {
rocket::build()
.register("/", catchers![bad_req_untyped, bad_req_string, bad_req_tuple])
}
// Assert the catchers do not collide. They are only differentiated by their error type.
let client = Client::debug(rocket()).unwrap();
let response = client.get("/").dispatch();
assert_eq!(response.status(), Status::NotFound);
}

View File

@ -36,7 +36,7 @@ memchr = "2"
stable-pattern = "0.1"
cookie = { version = "0.18", features = ["percent-encode"] }
state = "0.6"
transient = { version = "0.2.1" }
transient = { version = "0.3" }
[dependencies.serde]
version = "1.0"

View File

@ -74,7 +74,7 @@ tokio-stream = { version = "0.1.6", features = ["signal", "time"] }
cookie = { version = "0.18", features = ["percent-encode"] }
futures = { version = "0.3.30", default-features = false, features = ["std"] }
state = "0.6"
transient = { version = "0.2.1" }
transient = { version = "0.3" }
# tracing
tracing = { version = "0.1.40", default-features = false, features = ["std", "attributes"] }

View File

@ -1,6 +1,8 @@
use std::fmt;
use std::io::Cursor;
use transient::TypeId;
use crate::http::uri::Path;
use crate::http::ext::IntoOwned;
use crate::response::Response;
@ -122,6 +124,9 @@ pub struct Catcher {
/// The catcher's associated error handler.
pub handler: Box<dyn Handler>,
/// Catcher error type
pub(crate) error_type: Option<(TypeId, &'static str)>,
/// The mount point.
pub(crate) base: uri::Origin<'static>,
@ -134,10 +139,11 @@ pub struct Catcher {
pub(crate) location: Option<(&'static str, u32, u32)>,
}
// The rank is computed as -(number of nonempty segments in base) => catchers
// The rank is computed as -(number of nonempty segments in base) *2 => catchers
// with more nonempty segments have lower ranks => higher precedence.
// Doubled to provide space between for typed catchers.
fn rank(base: Path<'_>) -> isize {
-(base.segments().filter(|s| !s.is_empty()).count() as isize)
-(base.segments().filter(|s| !s.is_empty()).count() as isize) * 2
}
impl Catcher {
@ -149,22 +155,26 @@ impl Catcher {
///
/// ```rust
/// use rocket::request::Request;
/// use rocket::catcher::{Catcher, BoxFuture, ErasedErrorRef};
/// use rocket::catcher::{Catcher, BoxFuture, ErasedError};
/// use rocket::response::Responder;
/// use rocket::http::Status;
///
/// fn handle_404<'r>(status: Status, req: &'r Request<'_>, _e: &ErasedErrorRef<'r>) -> BoxFuture<'r> {
/// fn handle_404<'r>(status: Status, req: &'r Request<'_>, _e: ErasedError<'r>)
/// -> BoxFuture<'r>
/// {
/// let res = (status, format!("404: {}", req.uri()));
/// Box::pin(async move { res.respond_to(req) })
/// Box::pin(async move { res.respond_to(req).map_err(|s| (s, _e)) })
/// }
///
/// fn handle_500<'r>(_: Status, req: &'r Request<'_>, _e: &ErasedErrorRef<'r>) -> BoxFuture<'r> {
/// Box::pin(async move{ "Whoops, we messed up!".respond_to(req) })
/// fn handle_500<'r>(_: Status, req: &'r Request<'_>, _e: ErasedError<'r>) -> BoxFuture<'r> {
/// Box::pin(async move{ "Whoops, we messed up!".respond_to(req).map_err(|s| (s, _e)) })
/// }
///
/// fn handle_default<'r>(status: Status, req: &'r Request<'_>, _e: &ErasedErrorRef<'r>) -> BoxFuture<'r> {
/// fn handle_default<'r>(status: Status, req: &'r Request<'_>, _e: ErasedError<'r>)
/// -> BoxFuture<'r>
/// {
/// let res = (status, format!("{}: {}", status, req.uri()));
/// Box::pin(async move { res.respond_to(req) })
/// Box::pin(async move { res.respond_to(req).map_err(|s| (s, _e)) })
/// }
///
/// let not_found_catcher = Catcher::new(404, handle_404);
@ -189,6 +199,7 @@ impl Catcher {
name: None,
base: uri::Origin::root().clone(),
handler: Box::new(handler),
error_type: None,
rank: rank(uri::Origin::root().path()),
code,
location: None,
@ -201,13 +212,15 @@ impl Catcher {
///
/// ```rust
/// use rocket::request::Request;
/// use rocket::catcher::{Catcher, BoxFuture, ErasedErrorRef};
/// use rocket::catcher::{Catcher, BoxFuture, ErasedError};
/// use rocket::response::Responder;
/// use rocket::http::Status;
///
/// fn handle_404<'r>(status: Status, req: &'r Request<'_>, _e: &ErasedErrorRef<'r>) -> BoxFuture<'r> {
/// fn handle_404<'r>(status: Status, req: &'r Request<'_>, _e: ErasedError<'r>)
/// -> BoxFuture<'r>
/// {
/// let res = (status, format!("404: {}", req.uri()));
/// Box::pin(async move { res.respond_to(req) })
/// Box::pin(async move { res.respond_to(req).map_err(|s| (s, _e)) })
/// }
///
/// let catcher = Catcher::new(404, handle_404);
@ -227,14 +240,16 @@ impl Catcher {
///
/// ```rust
/// use rocket::request::Request;
/// use rocket::catcher::{Catcher, BoxFuture, ErasedErrorRef};
/// use rocket::catcher::{Catcher, BoxFuture, ErasedError};
/// use rocket::response::Responder;
/// use rocket::http::Status;
/// # use rocket::uri;
///
/// fn handle_404<'r>(status: Status, req: &'r Request<'_>, _e: &ErasedErrorRef<'r>) -> BoxFuture<'r> {
/// fn handle_404<'r>(status: Status, req: &'r Request<'_>, _e: ErasedError<'r>)
/// -> BoxFuture<'r>
/// {
/// let res = (status, format!("404: {}", req.uri()));
/// Box::pin(async move { res.respond_to(req) })
/// Box::pin(async move { res.respond_to(req).map_err(|s| (s, _e)) })
/// }
///
/// let catcher = Catcher::new(404, handle_404);
@ -281,13 +296,15 @@ impl Catcher {
///
/// ```rust
/// use rocket::request::Request;
/// use rocket::catcher::{Catcher, BoxFuture, ErasedErrorRef};
/// use rocket::catcher::{Catcher, BoxFuture, ErasedError};
/// use rocket::response::Responder;
/// use rocket::http::Status;
///
/// fn handle_404<'r>(status: Status, req: &'r Request<'_>, _e: &ErasedErrorRef<'r>) -> BoxFuture<'r> {
/// fn handle_404<'r>(status: Status, req: &'r Request<'_>, _e: ErasedError<'r>)
/// -> BoxFuture<'r>
/// {
/// let res = (status, format!("404: {}", req.uri()));
/// Box::pin(async move { res.respond_to(req) })
/// Box::pin(async move { res.respond_to(req).map_err(|s| (s, _e)) })
/// }
///
/// let catcher = Catcher::new(404, handle_404);
@ -332,6 +349,8 @@ pub struct StaticInfo {
pub name: &'static str,
/// The catcher's status code.
pub code: Option<u16>,
/// The catcher's error type.
pub error_type: Option<(TypeId, &'static str)>,
/// The catcher's handler, i.e, the annotated function.
pub handler: for<'r> fn(Status, &'r Request<'_>, ErasedError<'r>) -> BoxFuture<'r>,
/// The file, line, and column where the catcher was defined.
@ -343,7 +362,13 @@ impl From<StaticInfo> for Catcher {
#[inline]
fn from(info: StaticInfo) -> Catcher {
let mut catcher = Catcher::new(info.code, info.handler);
if info.error_type.is_some() {
// Lower rank if the error_type is defined, to ensure typed catchers
// are always tried first
catcher.rank -= 1;
}
catcher.name = Some(info.name.into());
catcher.error_type = info.error_type;
catcher.location = Some(info.location);
catcher
}
@ -354,6 +379,7 @@ impl fmt::Debug for Catcher {
f.debug_struct("Catcher")
.field("name", &self.name)
.field("base", &self.base)
.field("error_type", &self.error_type.as_ref().map(|(_, n)| n))
.field("code", &self.code)
.field("rank", &self.rank)
.finish()

View File

@ -31,7 +31,7 @@ pub type BoxFuture<'r, T = Result<'r>> = futures::future::BoxFuture<'r, T>;
/// and used as follows:
///
/// ```rust,no_run
/// use rocket::{Request, Catcher, catcher::{self, ErasedErrorRef}};
/// use rocket::{Request, Catcher, catcher::{self, ErasedError}};
/// use rocket::response::{Response, Responder};
/// use rocket::http::Status;
///
@ -47,11 +47,13 @@ pub type BoxFuture<'r, T = Result<'r>> = futures::future::BoxFuture<'r, T>;
///
/// #[rocket::async_trait]
/// impl catcher::Handler for CustomHandler {
/// async fn handle<'r>(&self, status: Status, req: &'r Request<'_>, _e: &ErasedErrorRef<'r>) -> catcher::Result<'r> {
/// async fn handle<'r>(&self, status: Status, req: &'r Request<'_>, _e: ErasedError<'r>)
/// -> catcher::Result<'r>
/// {
/// let inner = match self.0 {
/// Kind::Simple => "simple".respond_to(req)?,
/// Kind::Intermediate => "intermediate".respond_to(req)?,
/// Kind::Complex => "complex".respond_to(req)?,
/// Kind::Simple => "simple".respond_to(req).map_err(|e| (e, _e))?,
/// Kind::Intermediate => "intermediate".respond_to(req).map_err(|e| (e, _e))?,
/// Kind::Complex => "complex".respond_to(req).map_err(|e| (e, _e))?,
/// };
///
/// Response::build_from(inner).status(status).ok()
@ -99,7 +101,8 @@ pub trait Handler: Cloneable + Send + Sync + 'static {
/// Nevertheless, failure is allowed, both for convenience and necessity. If
/// an error handler fails, Rocket's default `500` catcher is invoked. If it
/// succeeds, the returned `Response` is used to respond to the client.
async fn handle<'r>(&self, status: Status, req: &'r Request<'_>, error: ErasedError<'r>) -> Result<'r>;
async fn handle<'r>(&self, status: Status, req: &'r Request<'_>, error: ErasedError<'r>)
-> Result<'r>;
}
// We write this manually to avoid double-boxing.

View File

@ -1,6 +1,6 @@
use transient::{Any, CanRecoverFrom, Co, Downcast};
#[doc(inline)]
pub use transient::{Static, Transient};
pub use transient::{Static, Transient, TypeId};
pub type ErasedError<'r> = Box<dyn Any<Co<'r>> + Send + Sync + 'r>;
pub type ErasedErrorRef<'r> = dyn Any<Co<'r>> + Send + Sync + 'r;

View File

@ -5,6 +5,7 @@ use std::error::Error as StdError;
use std::sync::Arc;
use figment::Profile;
use transient::Static;
use crate::listener::Endpoint;
use crate::{Catcher, Ignite, Orbit, Phase, Rocket, Route};
@ -85,10 +86,14 @@ pub enum ErrorKind {
Shutdown(Arc<Rocket<Orbit>>),
}
impl Static for ErrorKind {}
/// An error that occurs when a value was unexpectedly empty.
#[derive(Clone, Copy, Default, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct Empty;
impl Static for Empty {}
impl Error {
#[inline(always)]
pub(crate) fn new(kind: ErrorKind) -> Error {

View File

@ -8,6 +8,7 @@ use std::net::AddrParseError;
use std::borrow::Cow;
use serde::{Serialize, ser::{Serializer, SerializeStruct}};
use transient::Transient;
use crate::http::Status;
use crate::form::name::{NameBuf, Name};
@ -54,7 +55,8 @@ use crate::data::ByteUnit;
/// Ok(i)
/// }
/// ```
#[derive(Default, Debug, PartialEq, Serialize)]
#[derive(Default, Debug, PartialEq, Serialize, Transient)]
#[variance('v = co)] // TODO: update when Transient v0.4
#[serde(transparent)]
pub struct Errors<'v>(Vec<Error<'v>>);

View File

@ -109,12 +109,16 @@ impl Rocket<Orbit> {
request._set_method(Method::Get);
match self.route(request, data).await {
Outcome::Success(response) => response,
Outcome::Error((status, error)) => self.dispatch_error(status, request, error).await,
Outcome::Forward((_, status, error)) => self.dispatch_error(status, request, error).await,
Outcome::Error((status, error))
=> self.dispatch_error(status, request, error).await,
Outcome::Forward((_, status, error))
=> self.dispatch_error(status, request, error).await,
}
}
Outcome::Forward((_, status, error)) => self.dispatch_error(status, request, error).await,
Outcome::Error((status, error)) => self.dispatch_error(status, request, error).await,
Outcome::Forward((_, status, error))
=> self.dispatch_error(status, request, error).await,
Outcome::Error((status, error))
=> self.dispatch_error(status, request, error).await,
};
// Set the cookies. Note that error responses will only include cookies
@ -274,7 +278,10 @@ impl Rocket<Orbit> {
) -> Result<Response<'r>, (Option<Status>, ErasedError<'r>)> {
if let Some(catcher) = self.router.catch(status, req) {
catcher.trace_info();
catch_handle(catcher.name.as_deref(), || catcher.handler.handle(status, req, error)).await
catch_handle(
catcher.name.as_deref(),
|| catcher.handler.handle(status, req, error)
).await
.map(|result| result.map_err(|(s, e)| (Some(s), e)))
.unwrap_or_else(|| Err((None, default_error_type())))
} else {

View File

@ -2,6 +2,7 @@ use std::fmt;
use std::num::NonZeroUsize;
use crate::mtls::x509::{self, nom};
use transient::Static;
/// An error returned by the [`Certificate`](crate::mtls::Certificate) guard.
///
@ -41,6 +42,8 @@ pub enum Error {
Trailing(usize),
}
impl Static for Error {}
impl fmt::Display for Error {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {

View File

@ -7,7 +7,11 @@ use crate::http::Status;
/// Type alias for the return type of a [`Route`](crate::Route)'s
/// [`Handler::handle()`].
pub type Outcome<'r> = crate::outcome::Outcome<Response<'r>, (Status, ErasedError<'r>), (Data<'r>, Status, ErasedError<'r>)>;
pub type Outcome<'r> = crate::outcome::Outcome<
Response<'r>,
(Status, ErasedError<'r>),
(Data<'r>, Status, ErasedError<'r>)
>;
/// Type alias for the return type of a _raw_ [`Route`](crate::Route)'s
/// [`Handler`].
@ -240,7 +244,7 @@ impl<'r, 'o: 'r> Outcome<'o> {
Outcome::Error((code, Box::new(())))
}
/// Return an `Outcome` of `Error` with the status code `code`. This adds
/// the
/// the value for typed catchers.
///
/// This method exists to be used during manual routing.
///
@ -295,7 +299,9 @@ impl<'r, 'o: 'r> Outcome<'o> {
/// }
/// ```
#[inline(always)]
pub fn forward_val<T: Any<Co<'r>> + Send + Sync + 'r>(data: Data<'r>, status: Status, val: T) -> Outcome<'r> {
pub fn forward_val<T: Any<Co<'r>> + Send + Sync + 'r>(data: Data<'r>, status: Status, val: T)
-> Outcome<'r>
{
Outcome::Forward((data, status, Box::new(val)))
}
}

View File

@ -141,7 +141,9 @@ impl Catcher {
/// assert!(!a.collides_with(&b));
/// ```
pub fn collides_with(&self, other: &Self) -> bool {
self.code == other.code && self.base().segments().eq(other.base().segments())
self.code == other.code &&
types_collide(self, other) &&
self.base().segments().eq(other.base().segments())
}
}
@ -207,6 +209,10 @@ fn formats_collide(route: &Route, other: &Route) -> bool {
}
}
fn types_collide(catcher: &Catcher, other: &Catcher) -> bool {
catcher.error_type.as_ref().map(|(i, _)| i) == other.error_type.as_ref().map(|(i, _)| i)
}
#[cfg(test)]
mod tests {
use std::str::FromStr;

View File

@ -46,7 +46,11 @@ impl Rocket<Orbit> {
|rocket, request, data| Box::pin(rocket.preprocess(request, data)),
|token, rocket, request, data| Box::pin(async move {
if !request.errors.is_empty() {
return rocket.dispatch_error(Status::BadRequest, request, default_error_type()).await;
return rocket.dispatch_error(
Status::BadRequest,
request,
default_error_type()
).await;
}
rocket.dispatch(token, request, data).await

View File

@ -1,6 +1,6 @@
#[macro_use] extern crate rocket;
use rocket::catcher::ErasedErrorRef;
use rocket::catcher::ErasedError;
use rocket::{Request, Rocket, Route, Catcher, Build, route, catcher};
use rocket::data::Data;
use rocket::http::{Method, Status};
@ -74,7 +74,9 @@ fn catches_early_route_panic() {
#[test]
fn catches_early_catcher_panic() {
fn pre_future_catcher<'r>(_: Status, _: &'r Request<'_>, _: &ErasedErrorRef<'r>) -> catcher::BoxFuture<'r> {
fn pre_future_catcher<'r>(_: Status, _: &'r Request<'_>, _: ErasedError<'r>)
-> catcher::BoxFuture<'r>
{
panic!("a panicking pre-future catcher")
}