From a0784b4b1502a915e78b19b544cc58c64fa21bd4 Mon Sep 17 00:00:00 2001 From: Jeb Rosen Date: Sun, 26 Jan 2020 13:52:37 -0800 Subject: [PATCH] Catch and gracefully handle panics in routes and catchers. --- core/lib/src/server.rs | 27 +++++++++++++++++------- core/lib/tests/panic-handling.rs | 36 ++++++++++++++++++++++++++++++++ 2 files changed, 56 insertions(+), 7 deletions(-) create mode 100644 core/lib/tests/panic-handling.rs diff --git a/core/lib/src/server.rs b/core/lib/src/server.rs index ca7e8420..718a97ae 100644 --- a/core/lib/src/server.rs +++ b/core/lib/src/server.rs @@ -2,7 +2,7 @@ use std::io; use std::sync::Arc; use futures::stream::StreamExt; -use futures::future::{Future, BoxFuture}; +use futures::future::{Future, FutureExt, BoxFuture}; use tokio::sync::oneshot; use yansi::Paint; @@ -266,7 +266,14 @@ impl Rocket { request.set_route(route); // Dispatch the request to the handler. - let outcome = route.handler.handle(request, data).await; + let outcome = std::panic::AssertUnwindSafe(route.handler.handle(request, data)) + .catch_unwind() + .await + .unwrap_or_else(|_| { + error_!("A request handler panicked."); + warn_!("Handling as a 500 error."); + Outcome::Failure(Status::InternalServerError) + }); // Check if the request processing completed (Some) or if the // request needs to be forwarded. If it does, continue the loop @@ -304,24 +311,30 @@ impl Rocket { // Try to get the active catcher but fallback to user's 500 catcher. let code = Paint::red(status.code); let response = if let Some(catcher) = self.catchers.get(&status.code) { - catcher.handler.handle(status, req).await + std::panic::AssertUnwindSafe(catcher.handler.handle(status, req)).catch_unwind().await } else if let Some(ref default) = self.default_catcher { warn_!("No {} catcher found. Using default catcher.", code); - default.handler.handle(status, req).await + std::panic::AssertUnwindSafe(default.handler.handle(status, req)).catch_unwind().await } else { warn_!("No {} or default catcher found. Using Rocket default catcher.", code); - crate::catcher::default(status, req) + Ok(crate::catcher::default(status, req)) }; // Dispatch to the catcher. If it fails, use the Rocket default 500. match response { - Ok(r) => r, - Err(err_status) => { + Ok(Ok(r)) => r, + Ok(Err(err_status)) => { error_!("Catcher unexpectedly failed with {}.", err_status); warn_!("Using Rocket's default 500 error catcher."); let default = crate::catcher::default(Status::InternalServerError, req); default.expect("Rocket has default 500 response") } + Err(_) => { + error_!("Catcher panicked!"); + warn_!("Using Rocket's default 500 error catcher."); + let default = crate::catcher::default(Status::InternalServerError, req); + default.expect("Rocket has default 500 response") + } } } } diff --git a/core/lib/tests/panic-handling.rs b/core/lib/tests/panic-handling.rs new file mode 100644 index 00000000..f2675a9c --- /dev/null +++ b/core/lib/tests/panic-handling.rs @@ -0,0 +1,36 @@ +#[macro_use] extern crate rocket; + +use rocket::Rocket; +use rocket::http::Status; +use rocket::local::blocking::Client; + +#[get("/panic")] +fn panic_route() -> &'static str { + panic!("Panic in route") +} + +#[catch(404)] +fn panic_catcher() -> &'static str { + panic!("Panic in catcher") +} + +fn rocket() -> Rocket { + rocket::ignite() + .mount("/", routes![panic_route]) + .register(catchers![panic_catcher]) +} + +#[test] +fn catches_route_panic() { + let client = Client::tracked(rocket()).unwrap(); + let response = client.get("/panic").dispatch(); + assert_eq!(response.status(), Status::InternalServerError); + +} + +#[test] +fn catches_catcher_panic() { + let client = Client::tracked(rocket()).unwrap(); + let response = client.get("/noroute").dispatch(); + assert_eq!(response.status(), Status::InternalServerError); +}