Catch and gracefully handle panics in routes and catchers.

This commit is contained in:
Jeb Rosen 2020-01-26 13:52:37 -08:00 committed by Sergio Benitez
parent 7784cc982a
commit a0784b4b15
2 changed files with 56 additions and 7 deletions

View File

@ -2,7 +2,7 @@ use std::io;
use std::sync::Arc; use std::sync::Arc;
use futures::stream::StreamExt; use futures::stream::StreamExt;
use futures::future::{Future, BoxFuture}; use futures::future::{Future, FutureExt, BoxFuture};
use tokio::sync::oneshot; use tokio::sync::oneshot;
use yansi::Paint; use yansi::Paint;
@ -266,7 +266,14 @@ impl Rocket {
request.set_route(route); request.set_route(route);
// Dispatch the request to the handler. // Dispatch the request to the handler.
let outcome = route.handler.handle(request, data).await; let outcome = std::panic::AssertUnwindSafe(route.handler.handle(request, data))
.catch_unwind()
.await
.unwrap_or_else(|_| {
error_!("A request handler panicked.");
warn_!("Handling as a 500 error.");
Outcome::Failure(Status::InternalServerError)
});
// Check if the request processing completed (Some) or if the // Check if the request processing completed (Some) or if the
// request needs to be forwarded. If it does, continue the loop // request needs to be forwarded. If it does, continue the loop
@ -304,24 +311,30 @@ impl Rocket {
// Try to get the active catcher but fallback to user's 500 catcher. // Try to get the active catcher but fallback to user's 500 catcher.
let code = Paint::red(status.code); let code = Paint::red(status.code);
let response = if let Some(catcher) = self.catchers.get(&status.code) { let response = if let Some(catcher) = self.catchers.get(&status.code) {
catcher.handler.handle(status, req).await std::panic::AssertUnwindSafe(catcher.handler.handle(status, req)).catch_unwind().await
} else if let Some(ref default) = self.default_catcher { } else if let Some(ref default) = self.default_catcher {
warn_!("No {} catcher found. Using default catcher.", code); warn_!("No {} catcher found. Using default catcher.", code);
default.handler.handle(status, req).await std::panic::AssertUnwindSafe(default.handler.handle(status, req)).catch_unwind().await
} else { } else {
warn_!("No {} or default catcher found. Using Rocket default catcher.", code); warn_!("No {} or default catcher found. Using Rocket default catcher.", code);
crate::catcher::default(status, req) Ok(crate::catcher::default(status, req))
}; };
// Dispatch to the catcher. If it fails, use the Rocket default 500. // Dispatch to the catcher. If it fails, use the Rocket default 500.
match response { match response {
Ok(r) => r, Ok(Ok(r)) => r,
Err(err_status) => { Ok(Err(err_status)) => {
error_!("Catcher unexpectedly failed with {}.", err_status); error_!("Catcher unexpectedly failed with {}.", err_status);
warn_!("Using Rocket's default 500 error catcher."); warn_!("Using Rocket's default 500 error catcher.");
let default = crate::catcher::default(Status::InternalServerError, req); let default = crate::catcher::default(Status::InternalServerError, req);
default.expect("Rocket has default 500 response") default.expect("Rocket has default 500 response")
} }
Err(_) => {
error_!("Catcher panicked!");
warn_!("Using Rocket's default 500 error catcher.");
let default = crate::catcher::default(Status::InternalServerError, req);
default.expect("Rocket has default 500 response")
}
} }
} }
} }

View File

@ -0,0 +1,36 @@
#[macro_use] extern crate rocket;
use rocket::Rocket;
use rocket::http::Status;
use rocket::local::blocking::Client;
#[get("/panic")]
fn panic_route() -> &'static str {
panic!("Panic in route")
}
#[catch(404)]
fn panic_catcher() -> &'static str {
panic!("Panic in catcher")
}
fn rocket() -> Rocket {
rocket::ignite()
.mount("/", routes![panic_route])
.register(catchers![panic_catcher])
}
#[test]
fn catches_route_panic() {
let client = Client::tracked(rocket()).unwrap();
let response = client.get("/panic").dispatch();
assert_eq!(response.status(), Status::InternalServerError);
}
#[test]
fn catches_catcher_panic() {
let client = Client::tracked(rocket()).unwrap();
let response = client.get("/noroute").dispatch();
assert_eq!(response.status(), Status::InternalServerError);
}