diff --git a/core/lib/src/catcher.rs b/core/lib/src/catcher.rs index e2f50279..8a27813e 100644 --- a/core/lib/src/catcher.rs +++ b/core/lib/src/catcher.rs @@ -2,10 +2,12 @@ //! values. use std::fmt; +use std::io::Cursor; use crate::response::Response; use crate::codegen::StaticCatcherInfo; use crate::request::Request; +use crate::http::ContentType; use futures::future::BoxFuture; use yansi::Paint; @@ -148,7 +150,7 @@ impl Catcher { impl Default for Catcher { fn default() -> Self { fn async_default<'r>(status: Status, request: &'r Request<'_>) -> ErrorHandlerFuture<'r> { - Box::pin(async move { default(status, request) }) + Box::pin(async move { Ok(default(status, request)) }) } Catcher { code: None, handler: Box::new(async_default) } @@ -340,17 +342,17 @@ macro_rules! default_catcher_fn { ($($code:expr, $reason:expr, $description:expr),+) => ( use std::borrow::Cow; use crate::http::Status; - use crate::response::{content, status, Responder}; - pub(crate) fn default<'r>(status: Status, req: &'r Request<'_>) -> Result<'r> { - if req.accept().map(|a| a.preferred().is_json()).unwrap_or(false) { + pub(crate) fn default<'r>(status: Status, req: &'r Request<'_>) -> Response<'r> { + let preferred = req.accept().map(|a| a.preferred()); + let (mime, text) = if preferred.map_or(false, |a| a.is_json()) { let json: Cow<'_, str> = match status.code { $($code => json_error_template!($code, $reason, $description).into(),)* code => format!(json_error_fmt_template!("{}", "Unknown Error", "An unknown error has occurred."), code).into() }; - status::Custom(status, content::Json(json)).respond_to(req) + (ContentType::JSON, json) } else { let html: Cow<'_, str> = match status.code { $($code => html_error_template!($code, $reason, $description).into(),)* @@ -358,8 +360,16 @@ macro_rules! default_catcher_fn { "An unknown error has occurred."), code, code).into(), }; - status::Custom(status, content::Html(html)).respond_to(req) - } + (ContentType::HTML, html) + }; + + let mut r = Response::build().status(status).header(mime).finalize(); + match text { + Cow::Owned(v) => r.set_sized_body(v.len(), Cursor::new(v)), + Cow::Borrowed(v) => r.set_sized_body(v.len(), Cursor::new(v)), + }; + + r } ) } diff --git a/core/lib/src/server.rs b/core/lib/src/server.rs index 718a97ae..dc1247b4 100644 --- a/core/lib/src/server.rs +++ b/core/lib/src/server.rs @@ -1,5 +1,6 @@ use std::io; use std::sync::Arc; +use std::panic::AssertUnwindSafe; use futures::stream::StreamExt; use futures::future::{Future, FutureExt, BoxFuture}; @@ -21,6 +22,20 @@ use crate::http::uri::Origin; // A token returned to force the execution of one method before another. pub(crate) struct Token; +macro_rules! info_panic { + ($e:expr) => {{ + error_!("A handler has panicked. This is an application bug."); + info_!("A panic in Rust must be treated as an exceptional event."); + info_!("Panicking is not a suitable error handling mechanism."); + info_!("Unwinding, the result of a panic, is an expensive operation."); + info_!("Panics will severely degrade application performance."); + info_!("Instead of panicking, return `Option` and/or `Result`."); + info_!("Values of either type can be returned directly from handlers."); + warn_!("Forwarding to the {} error catcher.", Paint::blue(500).bold()); + $e + }} +} + // This function tries to hide all of the Hyper-ness from Rocket. It essentially // converts Hyper types into Rocket types, then calls the `dispatch` function, // which knows nothing about Hyper. Because responding depends on the @@ -266,14 +281,9 @@ impl Rocket { request.set_route(route); // Dispatch the request to the handler. - let outcome = std::panic::AssertUnwindSafe(route.handler.handle(request, data)) - .catch_unwind() - .await - .unwrap_or_else(|_| { - error_!("A request handler panicked."); - warn_!("Handling as a 500 error."); - Outcome::Failure(Status::InternalServerError) - }); + let outcome = AssertUnwindSafe(route.handler.handle(request, data)) + .catch_unwind().await + .unwrap_or_else(|_| info_panic!(Outcome::Failure(Status::InternalServerError))); // Check if the request processing completed (Some) or if the // request needs to be forwarded. If it does, continue the loop @@ -290,53 +300,69 @@ impl Rocket { } } - // Finds the error catcher for the status `status` and executes it for the - // given request `req`. If a user has registered a catcher for `status`, the - // catcher is called. If the catcher fails to return a good response, the - // 500 catcher is executed. If there is no registered catcher for `status`, - // the default catcher is used. - pub(crate) fn handle_error<'s, 'r: 's>( + /// Invokes the handler with `req` for catcher with status `status`. + /// + /// In order of preference, invoked handler is: + /// * the user's registered handler for `status` + /// * the user's registered `default` handler + /// * Rocket's default handler for `status` + /// + /// Return `Ok(result)` if the handler succeeded. Returns `Ok(Some(Status))` + /// if the handler ran to completion but failed. Returns `Ok(None)` if the + /// handler panicked while executing. + async fn invoke_catcher<'s, 'r: 's>( &'s self, status: Status, req: &'r Request<'s> - ) -> impl Future> + 's { - async move { - warn_!("Responding with {} catcher.", Paint::red(&status)); + ) -> Result, Option> { + // For now, we reset the delta state to prevent any modifications + // from earlier, unsuccessful paths from being reflected in error + // response. We may wish to relax this in the future. + req.cookies().reset_delta(); - // For now, we reset the delta state to prevent any modifications - // from earlier, unsuccessful paths from being reflected in error - // response. We may wish to relax this in the future. - req.cookies().reset_delta(); + // Try to get the active catcher + let catcher = self.catchers.get(&status.code) + .or_else(|| self.default_catcher.as_ref()); - // Try to get the active catcher but fallback to user's 500 catcher. - let code = Paint::red(status.code); - let response = if let Some(catcher) = self.catchers.get(&status.code) { - std::panic::AssertUnwindSafe(catcher.handler.handle(status, req)).catch_unwind().await - } else if let Some(ref default) = self.default_catcher { - warn_!("No {} catcher found. Using default catcher.", code); - std::panic::AssertUnwindSafe(default.handler.handle(status, req)).catch_unwind().await - } else { - warn_!("No {} or default catcher found. Using Rocket default catcher.", code); - Ok(crate::catcher::default(status, req)) - }; + if let Some(catcher) = catcher { + warn_!("Responding with registered {} catcher.", catcher); + let handler = AssertUnwindSafe(catcher.handler.handle(status, req)); + handler.catch_unwind().await + .map(|result| result.map_err(Some)) + .unwrap_or_else(|_| info_panic!(Err(None))) + } else { + let code = Paint::blue(status.code).bold(); + warn_!("No {} catcher registered. Using Rocket default.", code); + Ok(crate::catcher::default(status, req)) + } + } - // Dispatch to the catcher. If it fails, use the Rocket default 500. - match response { - Ok(Ok(r)) => r, - Ok(Err(err_status)) => { - error_!("Catcher unexpectedly failed with {}.", err_status); - warn_!("Using Rocket's default 500 error catcher."); - let default = crate::catcher::default(Status::InternalServerError, req); - default.expect("Rocket has default 500 response") - } - Err(_) => { - error_!("Catcher panicked!"); - warn_!("Using Rocket's default 500 error catcher."); - let default = crate::catcher::default(Status::InternalServerError, req); - default.expect("Rocket has default 500 response") - } + // Invokes the catcher for `status`. Returns the response on success. + // + // On catcher failure, the 500 error catcher is attempted. If _that_ fails, + // the (infallible) default 500 error cather is used. + pub(crate) async fn handle_error<'s, 'r: 's>( + &'s self, + mut status: Status, + req: &'r Request<'s> + ) -> Response<'r> { + // Dispatch to the `status` catcher. + if let Ok(r) = self.invoke_catcher(status, req).await { + return r; + } + + // If it fails and it's not a 500, try the 500 catcher. + if status != Status::InternalServerError { + error_!("Catcher failed. Attemping 500 error catcher."); + status = Status::InternalServerError; + if let Ok(r) = self.invoke_catcher(status, req).await { + return r; } } + + // If it failed again or if it was already a 500, use Rocket's default. + error_!("{} catcher failed. Using Rocket default 500.", status.code); + crate::catcher::default(Status::InternalServerError, req) } // TODO.async: Solidify the Listener APIs and make this function public diff --git a/core/lib/tests/panic-handling.rs b/core/lib/tests/panic-handling.rs index f2675a9c..1c41915d 100644 --- a/core/lib/tests/panic-handling.rs +++ b/core/lib/tests/panic-handling.rs @@ -14,23 +14,43 @@ fn panic_catcher() -> &'static str { panic!("Panic in catcher") } +#[catch(500)] +fn ise() -> &'static str { + "Hey, sorry! :(" +} + +#[catch(500)] +fn double_panic() { + panic!("so, so sorry...") +} + fn rocket() -> Rocket { rocket::ignite() .mount("/", routes![panic_route]) - .register(catchers![panic_catcher]) + .register(catchers![panic_catcher, ise]) } #[test] fn catches_route_panic() { - let client = Client::tracked(rocket()).unwrap(); + let client = Client::untracked(rocket()).unwrap(); let response = client.get("/panic").dispatch(); assert_eq!(response.status(), Status::InternalServerError); + assert_eq!(response.into_string().unwrap(), "Hey, sorry! :("); } +#[test] +fn catches_catcher_panic() { + let client = Client::untracked(rocket()).unwrap(); + let response = client.get("/noroute").dispatch(); + assert_eq!(response.status(), Status::InternalServerError); + assert_eq!(response.into_string().unwrap(), "Hey, sorry! :("); +} #[test] -fn catches_catcher_panic() { - let client = Client::tracked(rocket()).unwrap(); +fn catches_double_panic() { + let rocket = rocket().register(catchers![double_panic]); + let client = Client::untracked(rocket).unwrap(); let response = client.get("/noroute").dispatch(); assert_eq!(response.status(), Status::InternalServerError); + assert!(!response.into_string().unwrap().contains(":(")); }