Introduce scoped catchers.

Catchers can now be scoped to paths, with preference given to the
longest-prefix, then the status code. This a breaking change for all
applications that register catchers:

  * `Rocket::register()` takes a base path to scope catchers under.
    - The previous behavior is recovered with `::register("/", ...)`.
  * Catchers now fallibly, instead of silently, collide.
  * `ErrorKind::Collision` is now `ErrorKind::Collisions`.

Related changes:

  * `Origin` implements `TryFrom<String>`, `TryFrom<&str>`.
  * All URI variants implement `TryFrom<Uri>`.
  * Added `Segments::prefix_of()`.
  * `Rocket::mount()` takes a  `TryInto<Origin<'_>>` instead of `&str`
    for the base mount point.
  * Extended `errors` example with scoped catchers.
  * Added scoped sections to catchers guide.

Internal changes:

  * Moved router code to `router/router.rs`.
This commit is contained in:
Sergio Benitez 2021-03-25 21:36:00 -07:00
parent c3bad3a287
commit 2893ce754d
25 changed files with 1166 additions and 695 deletions

View File

@ -46,7 +46,7 @@ fn test_rank_collision() {
let rocket = rocket::ignite().mount("/", routes![get0, get0b]); let rocket = rocket::ignite().mount("/", routes![get0, get0b]);
let client_result = Client::debug(rocket); let client_result = Client::debug(rocket);
match client_result.as_ref().map_err(|e| e.kind()) { match client_result.as_ref().map_err(|e| e.kind()) {
Err(ErrorKind::Collision(..)) => { /* o.k. */ }, Err(ErrorKind::Collisions(..)) => { /* o.k. */ },
Ok(_) => panic!("client succeeded unexpectedly"), Ok(_) => panic!("client succeeded unexpectedly"),
Err(e) => panic!("expected collision, got {}", e) Err(e) => panic!("expected collision, got {}", e)
} }

View File

@ -23,7 +23,7 @@ fn catch(r#raw: &rocket::Request) -> String {
fn test_raw_ident() { fn test_raw_ident() {
let rocket = rocket::ignite() let rocket = rocket::ignite()
.mount("/", routes![get, swap]) .mount("/", routes![get, swap])
.register(catchers![catch]); .register("/", catchers![catch]);
let client = Client::debug(rocket).unwrap(); let client = Client::debug(rocket).unwrap();

View File

@ -1,5 +1,6 @@
use std::fmt::{self, Display};
use std::borrow::Cow; use std::borrow::Cow;
use std::convert::TryFrom;
use std::fmt::{self, Display};
use crate::ext::IntoOwned; use crate::ext::IntoOwned;
use crate::parse::{Indexed, Extent, IndexedStr}; use crate::parse::{Indexed, Extent, IndexedStr};
@ -608,6 +609,22 @@ impl<'a> Origin<'a> {
} }
} }
impl TryFrom<String> for Origin<'static> {
type Error = Error<'static>;
fn try_from(value: String) -> Result<Self, Self::Error> {
Origin::parse_owned(value)
}
}
impl<'a> TryFrom<&'a str> for Origin<'a> {
type Error = Error<'a>;
fn try_from(value: &'a str) -> Result<Self, Self::Error> {
Origin::parse(value)
}
}
impl Display for Origin<'_> { impl Display for Origin<'_> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{}", self.path())?; write!(f, "{}", self.path())?;

View File

@ -75,6 +75,16 @@ impl<'o> Segments<'o> {
.map(|i| i.from_source(Some(self.source.as_str()))) .map(|i| i.from_source(Some(self.source.as_str())))
} }
/// Returns `true` if `self` is a prefix of `other`.
#[inline]
pub fn prefix_of<'b>(self, other: Segments<'b>) -> bool {
if self.len() > other.len() {
return false;
}
self.zip(other).all(|(a, b)| a == b)
}
/// Creates a `PathBuf` from `self`. The returned `PathBuf` is /// Creates a `PathBuf` from `self`. The returned `PathBuf` is
/// percent-decoded. If a segment is equal to "..", the previous segment (if /// percent-decoded. If a segment is equal to "..", the previous segment (if
/// any) is skipped. /// any) is skipped.

View File

@ -282,6 +282,16 @@ impl Display for Uri<'_> {
} }
} }
/// The error type returned when a URI conversion fails.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub struct TryFromUriError(());
impl fmt::Display for TryFromUriError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
"invalid conversion from general to specific URI variant".fmt(f)
}
}
macro_rules! impl_uri_from { macro_rules! impl_uri_from {
($type:ident) => ( ($type:ident) => (
impl<'a> From<$type<'a>> for Uri<'a> { impl<'a> From<$type<'a>> for Uri<'a> {
@ -289,6 +299,17 @@ macro_rules! impl_uri_from {
Uri::$type(other) Uri::$type(other)
} }
} }
impl<'a> TryFrom<Uri<'a>> for $type<'a> {
type Error = TryFromUriError;
fn try_from(uri: Uri<'a>) -> Result<Self, Self::Error> {
match uri {
Uri::$type(inner) => Ok(inner),
_ => Err(TryFromUriError(()))
}
}
}
) )
} }

View File

@ -1,5 +1,5 @@
//! Types and traits for error catchers, error handlers, and their return //! Types and traits for error catchers, error handlers, and their return
//! values. //! types.
use std::fmt; use std::fmt;
use std::io::Cursor; use std::io::Cursor;
@ -7,7 +7,7 @@ use std::io::Cursor;
use crate::response::Response; use crate::response::Response;
use crate::codegen::StaticCatcherInfo; use crate::codegen::StaticCatcherInfo;
use crate::request::Request; use crate::request::Request;
use crate::http::ContentType; use crate::http::{Status, ContentType, uri};
use futures::future::BoxFuture; use futures::future::BoxFuture;
use yansi::Paint; use yansi::Paint;
@ -19,6 +19,12 @@ pub type Result<'r> = std::result::Result<Response<'r>, crate::http::Status>;
/// Type alias for the unwieldy [`ErrorHandler::handle()`] return type. /// Type alias for the unwieldy [`ErrorHandler::handle()`] return type.
pub type ErrorHandlerFuture<'r> = BoxFuture<'r, Result<'r>>; pub type ErrorHandlerFuture<'r> = BoxFuture<'r, Result<'r>>;
// A handler to use when one is needed temporarily. Don't use outside of Rocket!
#[cfg(test)]
pub(crate) fn dummy<'r>(_: Status, _: &'r Request<'_>) -> ErrorHandlerFuture<'r> {
Box::pin(async move { Ok(Response::new()) })
}
/// An error catching route. /// An error catching route.
/// ///
/// # Overview /// # Overview
@ -77,12 +83,12 @@ pub type ErrorHandlerFuture<'r> = BoxFuture<'r, Result<'r>>;
/// ///
/// #[catch(default)] /// #[catch(default)]
/// fn default(status: Status, req: &Request) -> String { /// fn default(status: Status, req: &Request) -> String {
/// format!("{} - {} ({})", status.code, status.reason, req.uri()) /// format!("{} ({})", status, req.uri())
/// } /// }
/// ///
/// #[launch] /// #[launch]
/// fn rocket() -> rocket::Rocket { /// fn rocket() -> rocket::Rocket {
/// rocket::ignite().register(catchers![internal_error, not_found, default]) /// rocket::ignite().register("/", catchers![internal_error, not_found, default])
/// } /// }
/// ``` /// ```
/// ///
@ -104,7 +110,10 @@ pub struct Catcher {
/// The name of this catcher, if one was given. /// The name of this catcher, if one was given.
pub name: Option<Cow<'static, str>>, pub name: Option<Cow<'static, str>>,
/// The HTTP status code to match against if this route is not `default`. /// The mount point.
pub base: uri::Origin<'static>,
/// The HTTP status to match against if this route is not `default`.
pub code: Option<u16>, pub code: Option<u16>,
/// The catcher's associated error handler. /// The catcher's associated error handler.
@ -112,8 +121,8 @@ pub struct Catcher {
} }
impl Catcher { impl Catcher {
/// Creates a catcher for the given status code, or a default catcher if /// Creates a catcher for the given `status`, or a default catcher if
/// `code` is `None`, using the given error handler. This should only be /// `status` is `None`, using the given error handler. This should only be
/// used when routing manually. /// used when routing manually.
/// ///
/// # Examples /// # Examples
@ -121,11 +130,11 @@ impl Catcher {
/// ```rust /// ```rust
/// use rocket::request::Request; /// use rocket::request::Request;
/// use rocket::catcher::{Catcher, ErrorHandlerFuture}; /// use rocket::catcher::{Catcher, ErrorHandlerFuture};
/// use rocket::response::{Result, Responder, status::Custom}; /// use rocket::response::Responder;
/// use rocket::http::Status; /// use rocket::http::Status;
/// ///
/// fn handle_404<'r>(status: Status, req: &'r Request<'_>) -> ErrorHandlerFuture<'r> { /// fn handle_404<'r>(status: Status, req: &'r Request<'_>) -> ErrorHandlerFuture<'r> {
/// let res = Custom(status, format!("404: {}", req.uri())); /// let res = (status, format!("404: {}", req.uri()));
/// Box::pin(async move { res.respond_to(req) }) /// Box::pin(async move { res.respond_to(req) })
/// } /// }
/// ///
@ -134,7 +143,7 @@ impl Catcher {
/// } /// }
/// ///
/// fn handle_default<'r>(status: Status, req: &'r Request<'_>) -> ErrorHandlerFuture<'r> { /// fn handle_default<'r>(status: Status, req: &'r Request<'_>) -> ErrorHandlerFuture<'r> {
/// let res = Custom(status, format!("{}: {}", status, req.uri())); /// let res = (status, format!("{}: {}", status, req.uri()));
/// Box::pin(async move { res.respond_to(req) }) /// Box::pin(async move { res.respond_to(req) })
/// } /// }
/// ///
@ -142,22 +151,82 @@ impl Catcher {
/// let internal_server_error_catcher = Catcher::new(500, handle_500); /// let internal_server_error_catcher = Catcher::new(500, handle_500);
/// let default_error_catcher = Catcher::new(None, handle_default); /// let default_error_catcher = Catcher::new(None, handle_default);
/// ``` /// ```
///
/// # Panics
///
/// Panics if `code` is not in the HTTP status code error range `[400,
/// 600)`.
#[inline(always)] #[inline(always)]
pub fn new<C, H>(code: C, handler: H) -> Catcher pub fn new<S, H>(code: S, handler: H) -> Catcher
where C: Into<Option<u16>>, H: ErrorHandler where S: Into<Option<u16>>, H: ErrorHandler
{ {
Catcher { name: None, code: code.into(), handler: Box::new(handler) } let code = code.into();
if let Some(code) = code {
assert!(code >= 400 && code < 600);
}
Catcher {
name: None,
base: uri::Origin::new("/", None::<&str>),
handler: Box::new(handler),
code,
}
}
/// Maps the `base` of this catcher using `mapper`, returning a new
/// `Catcher` with the returned base.
///
/// `mapper` is called with the current base. The returned `String` is used
/// as the new base if it is a valid URI. If the returned base URI contains
/// a query, it is ignored. Returns an error if the base produced by
/// `mapper` is not a valid origin URI.
///
/// # Example
///
/// ```rust
/// use rocket::request::Request;
/// use rocket::catcher::{Catcher, ErrorHandlerFuture};
/// use rocket::response::Responder;
/// use rocket::http::Status;
///
/// fn handle_404<'r>(status: Status, req: &'r Request<'_>) -> ErrorHandlerFuture<'r> {
/// let res = (status, format!("404: {}", req.uri()));
/// Box::pin(async move { res.respond_to(req) })
/// }
///
/// let catcher = Catcher::new(404, handle_404);
/// assert_eq!(catcher.base.path(), "/");
///
/// let catcher = catcher.map_base(|_| format!("/bar")).unwrap();
/// assert_eq!(catcher.base.path(), "/bar");
///
/// let catcher = catcher.map_base(|base| format!("/foo{}", base)).unwrap();
/// assert_eq!(catcher.base.path(), "/foo/bar");
///
/// let catcher = catcher.map_base(|base| format!("/foo ? {}", base));
/// assert!(catcher.is_err());
/// ```
pub fn map_base<'a, F>(
mut self,
mapper: F
) -> std::result::Result<Self, uri::Error<'static>>
where F: FnOnce(uri::Origin<'a>) -> String
{
self.base = uri::Origin::parse_owned(mapper(self.base))?.into_normalized();
self.base.clear_query();
Ok(self)
} }
} }
impl Default for Catcher { impl Default for Catcher {
fn default() -> Self { fn default() -> Self {
fn async_default<'r>(status: Status, request: &'r Request<'_>) -> ErrorHandlerFuture<'r> { fn handler<'r>(s: Status, req: &'r Request<'_>) -> ErrorHandlerFuture<'r> {
Box::pin(async move { Ok(default(status, request)) }) Box::pin(async move { Ok(default(s, req)) })
} }
let name = Some("<Rocket Catcher>".into()); let mut catcher = Catcher::new(None, handler);
Catcher { name, code: None, handler: Box::new(async_default) } catcher.name = Some("<Rocket Catcher>".into());
catcher
} }
} }
@ -226,9 +295,9 @@ impl Default for Catcher {
/// fn rocket() -> rocket::Rocket { /// fn rocket() -> rocket::Rocket {
/// rocket::ignite() /// rocket::ignite()
/// // to handle only `404` /// // to handle only `404`
/// .register(CustomHandler::catch(Status::NotFound, Kind::Simple)) /// .register("/", CustomHandler::catch(Status::NotFound, Kind::Simple))
/// // or to register as the default /// // or to register as the default
/// .register(CustomHandler::default(Kind::Simple)) /// .register("/", CustomHandler::default(Kind::Simple))
/// } /// }
/// ``` /// ```
/// ///
@ -239,7 +308,7 @@ impl Default for Catcher {
/// trait serves no other purpose but to ensure that every `ErrorHandler` /// trait serves no other purpose but to ensure that every `ErrorHandler`
/// can be cloned, allowing `Catcher`s to be cloned. /// can be cloned, allowing `Catcher`s to be cloned.
/// 2. `CustomHandler`'s methods return `Vec<Route>`, allowing for use /// 2. `CustomHandler`'s methods return `Vec<Route>`, allowing for use
/// directly as the parameter to `rocket.register()`. /// directly as the parameter to `rocket.register("/", )`.
/// 3. Unlike static-function-based handlers, this custom handler can make use /// 3. Unlike static-function-based handlers, this custom handler can make use
/// of internal state. /// of internal state.
#[crate::async_trait] #[crate::async_trait]
@ -288,6 +357,10 @@ impl fmt::Display for Catcher {
write!(f, "{}{}{} ", Paint::cyan("("), Paint::white(n), Paint::cyan(")"))?; write!(f, "{}{}{} ", Paint::cyan("("), Paint::white(n), Paint::cyan(")"))?;
} }
if self.base.path() != "/" {
write!(f, "{} ", Paint::green(self.base.path()))?;
}
match self.code { match self.code {
Some(code) => write!(f, "{}", Paint::blue(code)), Some(code) => write!(f, "{}", Paint::blue(code)),
None => write!(f, "{}", Paint::blue("default")) None => write!(f, "{}", Paint::blue("default"))
@ -298,6 +371,8 @@ impl fmt::Display for Catcher {
impl fmt::Debug for Catcher { impl fmt::Debug for Catcher {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("Catcher") f.debug_struct("Catcher")
.field("name", &self.name)
.field("base", &self.base)
.field("code", &self.code) .field("code", &self.code)
.finish() .finish()
} }
@ -359,7 +434,6 @@ r#"{{
macro_rules! default_catcher_fn { macro_rules! default_catcher_fn {
($($code:expr, $reason:expr, $description:expr),+) => ( ($($code:expr, $reason:expr, $description:expr),+) => (
use std::borrow::Cow; use std::borrow::Cow;
use crate::http::Status;
pub(crate) fn default<'r>(status: Status, req: &'r Request<'_>) -> Response<'r> { pub(crate) fn default<'r>(status: Status, req: &'r Request<'_>) -> Response<'r> {
let preferred = req.accept().map(|a| a.preferred()); let preferred = req.accept().map(|a| a.preferred());

View File

@ -6,8 +6,6 @@ use std::sync::atomic::{Ordering, AtomicBool};
use yansi::Paint; use yansi::Paint;
use figment::Profile; use figment::Profile;
use crate::router::Route;
/// An error that occurs during launch. /// An error that occurs during launch.
/// ///
/// An `Error` is returned by [`launch()`](crate::Rocket::launch()) when /// An `Error` is returned by [`launch()`](crate::Rocket::launch()) when
@ -80,7 +78,7 @@ pub enum ErrorKind {
/// An I/O error occurred in the runtime. /// An I/O error occurred in the runtime.
Runtime(Box<dyn std::error::Error + Send + Sync>), Runtime(Box<dyn std::error::Error + Send + Sync>),
/// Route collisions were detected. /// Route collisions were detected.
Collision(Vec<(Route, Route)>), Collisions(crate::router::Collisions),
/// Launch fairing(s) failed. /// Launch fairing(s) failed.
FailedFairings(Vec<crate::fairing::Info>), FailedFairings(Vec<crate::fairing::Info>),
/// The configuration profile is not debug but not secret key is configured. /// The configuration profile is not debug but not secret key is configured.
@ -140,7 +138,7 @@ impl fmt::Display for ErrorKind {
match self { match self {
ErrorKind::Bind(e) => write!(f, "binding failed: {}", e), ErrorKind::Bind(e) => write!(f, "binding failed: {}", e),
ErrorKind::Io(e) => write!(f, "I/O error: {}", e), ErrorKind::Io(e) => write!(f, "I/O error: {}", e),
ErrorKind::Collision(_) => "route collisions detected".fmt(f), ErrorKind::Collisions(_) => "collisions detected".fmt(f),
ErrorKind::FailedFairings(_) => "a launch fairing failed".fmt(f), ErrorKind::FailedFairings(_) => "a launch fairing failed".fmt(f),
ErrorKind::Runtime(e) => write!(f, "runtime error: {}", e), ErrorKind::Runtime(e) => write!(f, "runtime error: {}", e),
ErrorKind::InsecureSecretKey(_) => "insecure secret key config".fmt(f), ErrorKind::InsecureSecretKey(_) => "insecure secret key config".fmt(f),
@ -181,14 +179,21 @@ impl Drop for Error {
info_!("{}", e); info_!("{}", e);
panic!("aborting due to i/o error"); panic!("aborting due to i/o error");
} }
ErrorKind::Collision(ref collisions) => { ErrorKind::Collisions(ref collisions) => {
error!("Rocket failed to launch due to the following routing collisions:"); fn log_collisions<T: fmt::Display>(kind: &str, collisions: &[(T, T)]) {
if collisions.is_empty() { return }
error!("Rocket failed to launch due to the following {} collisions:", kind);
for &(ref a, ref b) in collisions { for &(ref a, ref b) in collisions {
info_!("{} {} {}", a, Paint::red("collides with").italic(), b) info_!("{} {} {}", a, Paint::red("collides with").italic(), b)
} }
}
info_!("Note: Collisions can usually be resolved by ranking routes."); log_collisions("route", &collisions.routes);
panic!("route collisions detected"); log_collisions("catcher", &collisions.catchers);
info_!("Note: Route collisions can usually be resolved by ranking routes.");
panic!("routing collisions detected");
} }
ErrorKind::FailedFairings(ref failures) => { ErrorKind::FailedFairings(ref failures) => {
error!("Rocket failed to launch due to failing fairings:"); error!("Rocket failed to launch due to failing fairings:");

View File

@ -1,4 +1,5 @@
use std::collections::HashMap; use std::fmt::Display;
use std::convert::TryInto;
use yansi::Paint; use yansi::Paint;
use state::Container; use state::Container;
@ -13,7 +14,7 @@ use crate::router::{Router, Route};
use crate::fairing::{Fairing, Fairings}; use crate::fairing::{Fairing, Fairings};
use crate::logger::PaintExt; use crate::logger::PaintExt;
use crate::shutdown::Shutdown; use crate::shutdown::Shutdown;
use crate::http::uri::Origin; use crate::http::{uri::Origin, ext::IntoOwned};
use crate::error::{Error, ErrorKind}; use crate::error::{Error, ErrorKind};
/// The main `Rocket` type: used to mount routes and catchers and launch the /// The main `Rocket` type: used to mount routes and catchers and launch the
@ -24,8 +25,6 @@ pub struct Rocket {
pub(crate) figment: Figment, pub(crate) figment: Figment,
pub(crate) managed_state: Container![Send + Sync], pub(crate) managed_state: Container![Send + Sync],
pub(crate) router: Router, pub(crate) router: Router,
pub(crate) default_catcher: Option<Catcher>,
pub(crate) catchers: HashMap<u16, Catcher>,
pub(crate) fairings: Fairings, pub(crate) fairings: Fairings,
pub(crate) shutdown_receiver: Option<mpsc::Receiver<()>>, pub(crate) shutdown_receiver: Option<mpsc::Receiver<()>>,
pub(crate) shutdown_handle: Shutdown, pub(crate) shutdown_handle: Shutdown,
@ -95,8 +94,6 @@ impl Rocket {
config, figment, managed_state, config, figment, managed_state,
shutdown_handle: Shutdown(shutdown_sender), shutdown_handle: Shutdown(shutdown_sender),
router: Router::new(), router: Router::new(),
default_catcher: None,
catchers: HashMap::new(),
fairings: Fairings::new(), fairings: Fairings::new(),
shutdown_receiver: Some(shutdown_receiver), shutdown_receiver: Some(shutdown_receiver),
} }
@ -203,10 +200,15 @@ impl Rocket {
/// # .launch().await; /// # .launch().await;
/// # }; /// # };
/// ``` /// ```
pub fn mount<R: Into<Vec<Route>>>(mut self, base: &str, routes: R) -> Self { pub fn mount<'a, B, R>(mut self, base: B, routes: R) -> Self
let base_uri = Origin::parse_owned(base.to_string()) where B: TryInto<Origin<'a>> + Clone + Display,
B::Error: Display,
R: Into<Vec<Route>>
{
let base_uri = base.clone().try_into()
.map(|origin| origin.into_owned())
.unwrap_or_else(|e| { .unwrap_or_else(|e| {
error!("Invalid mount point URI: {}.", Paint::white(base)); error!("Invalid route base: {}.", Paint::white(&base));
panic!("Error: {}", e); panic!("Error: {}", e);
}); });
@ -215,11 +217,11 @@ impl Rocket {
panic!("Invalid mount point."); panic!("Invalid mount point.");
} }
info!("{}{} {}{}", info!("{}{} {} {}",
Paint::emoji("🛰 "), Paint::emoji("🛰 "),
Paint::magenta("Mounting"), Paint::magenta("Mounting"),
Paint::blue(&base_uri), Paint::blue(&base_uri),
Paint::magenta(":")); Paint::magenta("routes:"));
for route in routes.into() { for route in routes.into() {
let mounted_route = route.clone() let mounted_route = route.clone()
@ -231,13 +233,18 @@ impl Rocket {
}); });
info_!("{}", mounted_route); info_!("{}", mounted_route);
self.router.add(mounted_route); self.router.add_route(mounted_route);
} }
self self
} }
/// Registers all of the catchers in the supplied vector. /// Registers all of the catchers in the supplied vector, scoped to `base`.
///
/// # Panics
///
/// Panics if `base` is not a valid static path: a valid origin URI without
/// dynamic parameters.
/// ///
/// # Examples /// # Examples
/// ///
@ -257,23 +264,31 @@ impl Rocket {
/// ///
/// #[launch] /// #[launch]
/// fn rocket() -> rocket::Rocket { /// fn rocket() -> rocket::Rocket {
/// rocket::ignite().register(catchers![internal_error, not_found]) /// rocket::ignite().register("/", catchers![internal_error, not_found])
/// } /// }
/// ``` /// ```
pub fn register(mut self, catchers: Vec<Catcher>) -> Self { pub fn register<'a, B, C>(mut self, base: B, catchers: C) -> Self
info!("{}{}", Paint::emoji("👾 "), Paint::magenta("Catchers:")); where B: TryInto<Origin<'a>> + Clone + Display,
B::Error: Display,
C: Into<Vec<Catcher>>
{
info!("{}{} {} {}",
Paint::emoji("👾 "),
Paint::magenta("Registering"),
Paint::blue(&base),
Paint::magenta("catchers:"));
for catcher in catchers { for catcher in catchers.into() {
info_!("{}", catcher); let mounted_catcher = catcher.clone()
.map_base(|old| format!("{}{}", base, old))
.unwrap_or_else(|e| {
error_!("Catcher `{}` has a malformed URI.", catcher);
error_!("{}", e);
panic!("Invalid catcher URI.");
});
let existing = match catcher.code { info_!("{}", mounted_catcher);
Some(code) => self.catchers.insert(code, catcher), self.router.add_catcher(mounted_catcher);
None => self.default_catcher.replace(catcher)
};
if let Some(existing) = existing {
warn_!("Replacing existing '{}' catcher.", existing);
}
} }
self self
@ -444,7 +459,7 @@ impl Rocket {
/// } /// }
/// ``` /// ```
#[inline(always)] #[inline(always)]
pub fn routes(&self) -> impl Iterator<Item = &Route> + '_ { pub fn routes(&self) -> impl Iterator<Item = &Route> {
self.router.routes() self.router.routes()
} }
@ -464,7 +479,7 @@ impl Rocket {
/// ///
/// fn main() { /// fn main() {
/// let mut rocket = rocket::ignite() /// let mut rocket = rocket::ignite()
/// .register(catchers![not_found, just_500, some_default]); /// .register("/", catchers![not_found, just_500, some_default]);
/// ///
/// let mut codes: Vec<_> = rocket.catchers().map(|c| c.code).collect(); /// let mut codes: Vec<_> = rocket.catchers().map(|c| c.code).collect();
/// codes.sort(); /// codes.sort();
@ -473,8 +488,8 @@ impl Rocket {
/// } /// }
/// ``` /// ```
#[inline(always)] #[inline(always)]
pub fn catchers(&self) -> impl Iterator<Item = &Catcher> + '_ { pub fn catchers(&self) -> impl Iterator<Item = &Catcher> {
self.catchers.values().chain(self.default_catcher.as_ref()) self.router.catchers()
} }
/// Returns `Some` of the managed state value for the type `T` if it is /// Returns `Some` of the managed state value for the type `T` if it is
@ -525,10 +540,8 @@ impl Rocket {
/// * there were no fairing failures /// * there were no fairing failures
/// * a secret key, if needed, is securely configured /// * a secret key, if needed, is securely configured
pub(crate) async fn prelaunch_check(&mut self) -> Result<(), Error> { pub(crate) async fn prelaunch_check(&mut self) -> Result<(), Error> {
let collisions: Vec<_> = self.router.collisions().collect(); if let Err(collisions) = self.router.finalize() {
if !collisions.is_empty() { return Err(Error::new(ErrorKind::Collisions(collisions)));
let owned = collisions.into_iter().map(|(a, b)| (a.clone(), b.clone()));
return Err(Error::new(ErrorKind::Collision(owned.collect())));
} }
if let Some(failures) = self.fairings.failures() { if let Some(failures) = self.fairings.failures() {

View File

@ -1,46 +1,23 @@
use super::{Route, uri::Color}; use super::{Route, uri::Color};
use crate::catcher::Catcher;
use crate::http::MediaType; use crate::http::{MediaType, Status};
use crate::request::Request; use crate::request::Request;
impl Route { pub trait Collide<T = Self> {
/// Determines if two routes can match against some request. That is, if two fn collides_with(&self, other: &T) -> bool;
/// routes `collide`, there exists a request that can match against both }
/// routes.
///
/// This implementation is used at initialization to check if two user
/// routes collide before launching. Format collisions works like this:
///
/// * If route specifies a format, it only gets requests for that format.
/// * If route doesn't specify a format, it gets requests for any format.
///
/// Because query parsing is lenient, and dynamic query parameters can be
/// missing, queries do not impact whether two routes collide.
pub(crate) fn collides_with(&self, other: &Route) -> bool {
self.method == other.method
&& self.rank == other.rank
&& paths_collide(self, other)
&& formats_collide(self, other)
}
/// Determines if this route matches against the given request. impl<'a, 'b, T: Collide> Collide<&T> for &T {
/// fn collides_with(&self, other: &&T) -> bool {
/// This means that: T::collides_with(*self, *other)
/// }
/// * The route's method matches that of the incoming request. }
/// * The route's format (if any) matches that of the incoming request.
/// - If route specifies format, it only gets requests for that format. impl Collide for MediaType {
/// - If route doesn't specify format, it gets requests for any format. fn collides_with(&self, other: &Self) -> bool {
/// * All static components in the route's path match the corresponding let collide = |a, b| a == "*" || b == "*" || a == b;
/// components in the same position in the incoming request. collide(self.top(), other.top()) && collide(self.sub(), other.sub())
/// * All static components in the route's query string are also in the
/// request query string, though in any position. If there is no query
/// in the route, requests with/without queries match.
pub(crate) fn matches(&self, req: &Request<'_>) -> bool {
self.method == req.method()
&& paths_match(self, req)
&& queries_match(self, req)
&& formats_match(self, req)
} }
} }
@ -66,6 +43,68 @@ fn paths_collide(route: &Route, other: &Route) -> bool {
|| a_segments.len() == b_segments.len() || a_segments.len() == b_segments.len()
} }
fn formats_collide(route: &Route, other: &Route) -> bool {
// When matching against the `Accept` header, the client can always
// provide a media type that will cause a collision through
// non-specificity.
if !route.method.supports_payload() {
return true;
}
// When matching against the `Content-Type` header, we'll only
// consider requests as having a `Content-Type` if they're fully
// specified. If a route doesn't have a `format`, it accepts all
// `Content-Type`s. If a request doesn't have a format, it only
// matches routes without a format.
match (route.format.as_ref(), other.format.as_ref()) {
(Some(a), Some(b)) => a.collides_with(b),
_ => true
}
}
impl Collide for Route {
/// Determines if two routes can match against some request. That is, if two
/// routes `collide`, there exists a request that can match against both
/// routes.
///
/// This implementation is used at initialization to check if two user
/// routes collide before launching. Format collisions works like this:
///
/// * If route specifies a format, it only gets requests for that format.
/// * If route doesn't specify a format, it gets requests for any format.
///
/// Because query parsing is lenient, and dynamic query parameters can be
/// missing, queries do not impact whether two routes collide.
fn collides_with(&self, other: &Route) -> bool {
self.method == other.method
&& self.rank == other.rank
&& paths_collide(self, other)
&& formats_collide(self, other)
}
}
impl Route {
/// Determines if this route matches against the given request.
///
/// This means that:
///
/// * The route's method matches that of the incoming request.
/// * The route's format (if any) matches that of the incoming request.
/// - If route specifies format, it only gets requests for that format.
/// - If route doesn't specify format, it gets requests for any format.
/// * All static components in the route's path match the corresponding
/// components in the same position in the incoming request.
/// * All static components in the route's query string are also in the
/// request query string, though in any position. If there is no query
/// in the route, requests with/without queries match.
pub(crate) fn matches(&self, req: &Request<'_>) -> bool {
self.method == req.method()
&& paths_match(self, req)
&& queries_match(self, req)
&& formats_match(self, req)
}
}
fn paths_match(route: &Route, req: &Request<'_>) -> bool { fn paths_match(route: &Route, req: &Request<'_>) -> bool {
let route_segments = &route.uri.metadata.path_segs; let route_segments = &route.uri.metadata.path_segs;
let req_segments = req.uri().path_segments(); let req_segments = req.uri().path_segments();
@ -90,7 +129,7 @@ fn paths_match(route: &Route, req: &Request<'_>) -> bool {
return true; return true;
} }
if !route_seg.dynamic && route_seg.value != req_seg { if !(route_seg.dynamic || route_seg.value == req_seg) {
return false; return false;
} }
} }
@ -116,33 +155,16 @@ fn queries_match(route: &Route, req: &Request<'_>) -> bool {
true true
} }
fn formats_collide(route: &Route, other: &Route) -> bool {
// When matching against the `Accept` header, the client can always provide
// a media type that will cause a collision through non-specificity.
if !route.method.supports_payload() {
return true;
}
// When matching against the `Content-Type` header, we'll only consider
// requests as having a `Content-Type` if they're fully specified. If a
// route doesn't have a `format`, it accepts all `Content-Type`s. If a
// request doesn't have a format, it only matches routes without a format.
match (route.format.as_ref(), other.format.as_ref()) {
(Some(a), Some(b)) => media_types_collide(a, b),
_ => true
}
}
fn formats_match(route: &Route, request: &Request<'_>) -> bool { fn formats_match(route: &Route, request: &Request<'_>) -> bool {
if !route.method.supports_payload() { if !route.method.supports_payload() {
route.format.as_ref() route.format.as_ref()
.and_then(|a| request.format().map(|b| (a, b))) .and_then(|a| request.format().map(|b| (a, b)))
.map(|(a, b)| media_types_collide(a, b)) .map(|(a, b)| a.collides_with(b))
.unwrap_or(true) .unwrap_or(true)
} else { } else {
match route.format.as_ref() { match route.format.as_ref() {
Some(a) => match request.format() { Some(a) => match request.format() {
Some(b) if b.specificity() == 2 => media_types_collide(a, b), Some(b) if b.specificity() == 2 => a.collides_with(b),
_ => false _ => false
} }
None => true None => true
@ -150,9 +172,30 @@ fn formats_match(route: &Route, request: &Request<'_>) -> bool {
} }
} }
fn media_types_collide(first: &MediaType, other: &MediaType) -> bool {
let collide = |a, b| a == "*" || b == "*" || a == b; impl Collide for Catcher {
collide(first.top(), other.top()) && collide(first.sub(), other.sub()) /// Determines if two catchers are in conflict: there exists a request for
/// which there exist no rule to determine _which_ of the two catchers to
/// use. This means that the catchers:
///
/// * Have the same base.
/// * Have the same status code or are both defaults.
fn collides_with(&self, other: &Self) -> bool {
self.code == other.code
&& self.base.path_segments().eq(other.base.path_segments())
}
}
impl Catcher {
/// Determines if this catcher is responsible for handling the error with
/// `status` that occurred during request `req`. A catcher matches if:
///
/// * It is a default catcher _or_ has a code of `status`.
/// * Its base is a prefix of the normalized/decoded `req.path()`.
pub(crate) fn matches(&self, status: Status, req: &Request<'_>) -> bool {
self.code.map_or(true, |code| code == status.code)
&& self.base.path_segments().prefix_of(req.uri().path_segments())
}
} }
#[cfg(test)] #[cfg(test)]
@ -335,7 +378,7 @@ mod tests {
fn mt_mt_collide(mt1: &str, mt2: &str) -> bool { fn mt_mt_collide(mt1: &str, mt2: &str) -> bool {
let mt_a = MediaType::from_str(mt1).expect(mt1); let mt_a = MediaType::from_str(mt1).expect(mt1);
let mt_b = MediaType::from_str(mt2).expect(mt2); let mt_b = MediaType::from_str(mt2).expect(mt2);
media_types_collide(&mt_a, &mt_b) mt_a.collides_with(&mt_b)
} }
#[test] #[test]
@ -525,4 +568,35 @@ mod tests {
assert!(!req_route_path_match("/a/b", "/a/b?foo&<rest..>")); assert!(!req_route_path_match("/a/b", "/a/b?foo&<rest..>"));
assert!(!req_route_path_match("/a/b", "/a/b?<a>&b&<rest..>")); assert!(!req_route_path_match("/a/b", "/a/b?<a>&b&<rest..>"));
} }
fn catchers_collide<A, B>(a: A, ap: &str, b: B, bp: &str) -> bool
where A: Into<Option<u16>>, B: Into<Option<u16>>
{
let a = Catcher::new(a, crate::catcher::dummy).map_base(|_| ap.into()).unwrap();
let b = Catcher::new(b, crate::catcher::dummy).map_base(|_| bp.into()).unwrap();
a.collides_with(&b)
}
#[test]
fn catcher_collisions() {
for path in &["/a", "/foo", "/a/b/c", "/a/b/c/d/e"] {
assert!(catchers_collide(404, path, 404, path));
assert!(catchers_collide(500, path, 500, path));
assert!(catchers_collide(None, path, None, path));
}
}
#[test]
fn catcher_non_collisions() {
assert!(!catchers_collide(404, "/foo", 405, "/foo"));
assert!(!catchers_collide(404, "/", None, "/foo"));
assert!(!catchers_collide(404, "/", None, "/"));
assert!(!catchers_collide(404, "/a/b", None, "/a/b"));
assert!(!catchers_collide(404, "/a/b", 404, "/a/b/c"));
assert!(!catchers_collide(None, "/a/b", None, "/a/b/c"));
assert!(!catchers_collide(None, "/b", None, "/a/b/c"));
assert!(!catchers_collide(None, "/", None, "/a/b/c"));
}
} }

View File

@ -1,514 +1,13 @@
//! Routing types: [`Route`] and [`RouteUri`]. //! Routing types: [`Route`] and [`RouteUri`].
mod collider;
mod route; mod route;
mod segment; mod segment;
mod uri; mod uri;
mod router;
mod collider;
use std::collections::HashMap; pub(crate) use router::*;
use crate::request::Request; pub use route::Route;
use crate::http::Method; pub use collider::Collide;
pub use uri::RouteUri;
pub use self::route::Route;
pub use self::uri::RouteUri;
// type Selector = (Method, usize);
type Selector = Method;
#[derive(Debug, Default)]
pub(crate) struct Router {
routes: HashMap<Selector, Vec<Route>>,
}
impl Router {
pub fn new() -> Router {
Router { routes: HashMap::new() }
}
pub fn add(&mut self, route: Route) {
let selector = route.method;
let entries = self.routes.entry(selector).or_insert_with(|| vec![]);
let i = entries.binary_search_by_key(&route.rank, |r| r.rank)
.unwrap_or_else(|i| i);
entries.insert(i, route);
}
pub fn route<'r, 'a: 'r>(&'a self, req: &'r Request<'r>) -> impl Iterator<Item = &'a Route> + 'r {
// Note that routes are presorted by rank on each `add`.
self.routes.get(&req.method())
.into_iter()
.flat_map(move |routes| routes.iter().filter(move |r| r.matches(req)))
}
pub(crate) fn collisions(&self) -> impl Iterator<Item = (&Route, &Route)> {
let all_routes = self.routes.values().flat_map(|v| v.iter());
all_routes.clone().enumerate()
.flat_map(move |(i, a)| {
all_routes.clone()
.skip(i + 1)
.filter(move |b| b.collides_with(a))
.map(move |b| (a, b))
})
}
#[inline]
pub fn routes(&self) -> impl Iterator<Item = &Route> {
self.routes.values().flat_map(|v| v.iter())
}
// This is slow. Don't expose this publicly; only for tests.
#[cfg(test)]
fn has_collisions(&self) -> bool {
self.collisions().next().is_some()
}
}
#[cfg(test)]
mod test {
use super::{Router, Route};
use crate::rocket::Rocket;
use crate::config::Config;
use crate::http::{Method, Method::*};
use crate::http::uri::Origin;
use crate::request::Request;
use crate::handler::dummy;
fn router_with_routes(routes: &[&'static str]) -> Router {
let mut router = Router::new();
for route in routes {
let route = dbg!(Route::new(Get, route, dummy));
router.add(route);
}
router
}
fn router_with_ranked_routes(routes: &[(isize, &'static str)]) -> Router {
let mut router = Router::new();
for &(rank, route) in routes {
let route = Route::ranked(rank, Get, route, dummy);
router.add(route);
}
router
}
fn router_with_rankless_routes(routes: &[&'static str]) -> Router {
let mut router = Router::new();
for route in routes {
let route = Route::ranked(0, Get, route, dummy);
router.add(route);
}
router
}
fn rankless_route_collisions(routes: &[&'static str]) -> bool {
let router = router_with_rankless_routes(routes);
router.has_collisions()
}
fn default_rank_route_collisions(routes: &[&'static str]) -> bool {
let router = router_with_routes(routes);
router.has_collisions()
}
#[test]
fn test_rankless_collisions() {
assert!(rankless_route_collisions(&["/hello", "/hello"]));
assert!(rankless_route_collisions(&["/<a>", "/hello"]));
assert!(rankless_route_collisions(&["/<a>", "/<b>"]));
assert!(rankless_route_collisions(&["/hello/bob", "/hello/<b>"]));
assert!(rankless_route_collisions(&["/a/b/<c>/d", "/<a>/<b>/c/d"]));
assert!(rankless_route_collisions(&["/a/b", "/<a..>"]));
assert!(rankless_route_collisions(&["/a/b/c", "/a/<a..>"]));
assert!(rankless_route_collisions(&["/<a>/b", "/a/<a..>"]));
assert!(rankless_route_collisions(&["/a/<b>", "/a/<a..>"]));
assert!(rankless_route_collisions(&["/a/b/<c>", "/a/<a..>"]));
assert!(rankless_route_collisions(&["/<a..>", "/a/<a..>"]));
assert!(rankless_route_collisions(&["/a/<a..>", "/a/<a..>"]));
assert!(rankless_route_collisions(&["/a/b/<a..>", "/a/<a..>"]));
assert!(rankless_route_collisions(&["/a/b/c/d", "/a/<a..>"]));
assert!(rankless_route_collisions(&["/", "/<a..>"]));
assert!(rankless_route_collisions(&["/a/<_>", "/a/<a..>"]));
assert!(rankless_route_collisions(&["/a/<_>", "/a/<_..>"]));
assert!(rankless_route_collisions(&["/<_>", "/a/<_..>"]));
assert!(rankless_route_collisions(&["/foo", "/foo/<_..>"]));
assert!(rankless_route_collisions(&["/foo/bar/baz", "/foo/<_..>"]));
assert!(rankless_route_collisions(&["/a/d/<b..>", "/a/d"]));
assert!(rankless_route_collisions(&["/a/<_..>", "/<_>"]));
assert!(rankless_route_collisions(&["/a/<_..>", "/a"]));
assert!(rankless_route_collisions(&["/<a>", "/a/<a..>"]));
assert!(rankless_route_collisions(&["/<_>", "/<_>"]));
assert!(rankless_route_collisions(&["/a/<_>", "/a/b"]));
assert!(rankless_route_collisions(&["/a/<_>", "/a/<b>"]));
assert!(rankless_route_collisions(&["/<_..>", "/a/b"]));
assert!(rankless_route_collisions(&["/<_..>", "/<_>"]));
assert!(rankless_route_collisions(&["/<_>/b", "/a/b"]));
assert!(rankless_route_collisions(&["/", "/<foo..>"]));
}
#[test]
fn test_collisions_normalize() {
assert!(rankless_route_collisions(&["/hello/", "/hello"]));
assert!(rankless_route_collisions(&["//hello/", "/hello"]));
assert!(rankless_route_collisions(&["//hello/", "/hello//"]));
assert!(rankless_route_collisions(&["/<a>", "/hello//"]));
assert!(rankless_route_collisions(&["/<a>", "/hello///"]));
assert!(rankless_route_collisions(&["/hello///bob", "/hello/<b>"]));
assert!(rankless_route_collisions(&["/<a..>//", "/a//<a..>"]));
assert!(rankless_route_collisions(&["/a/<a..>//", "/a/<a..>"]));
assert!(rankless_route_collisions(&["/a/<a..>//", "/a/b//c//d/"]));
assert!(rankless_route_collisions(&["/a/<a..>/", "/a/bd/e/"]));
assert!(rankless_route_collisions(&["/<a..>/", "/a/bd/e/"]));
assert!(rankless_route_collisions(&["//", "/<foo..>"]));
assert!(rankless_route_collisions(&["/a/<a..>//", "/a/b//c//d/e/"]));
assert!(rankless_route_collisions(&["/a//<a..>//", "/a/b//c//d/e/"]));
assert!(rankless_route_collisions(&["///<_>", "/<_>"]));
assert!(rankless_route_collisions(&["/a/<_>", "///a//b"]));
assert!(rankless_route_collisions(&["//a///<_>", "/a//<b>"]));
assert!(rankless_route_collisions(&["//<_..>", "/a/b"]));
assert!(rankless_route_collisions(&["//<_..>", "/<_>"]));
assert!(rankless_route_collisions(&["///<a>/", "/a/<a..>"]));
assert!(rankless_route_collisions(&["///<a..>/", "/a/<a..>"]));
assert!(rankless_route_collisions(&["/<a..>", "/hello"]));
}
#[test]
fn test_collisions_query() {
// Query shouldn't affect things when rankless.
assert!(rankless_route_collisions(&["/hello?<foo>", "/hello"]));
assert!(rankless_route_collisions(&["/<a>?foo=bar", "/hello?foo=bar&cat=fat"]));
assert!(rankless_route_collisions(&["/<a>?foo=bar", "/hello?foo=bar&cat=fat"]));
assert!(rankless_route_collisions(&["/<a>", "/<b>?<foo>"]));
assert!(rankless_route_collisions(&["/hello/bob?a=b", "/hello/<b>?d=e"]));
assert!(rankless_route_collisions(&["/<foo>?a=b", "/foo?d=e"]));
assert!(rankless_route_collisions(&["/<foo>?a=b&<c>", "/<foo>?d=e&<c>"]));
assert!(rankless_route_collisions(&["/<foo>?a=b&<c>", "/<foo>?d=e"]));
}
#[test]
fn test_no_collisions() {
assert!(!rankless_route_collisions(&["/a/b", "/a/b/c"]));
assert!(!rankless_route_collisions(&["/a/b/c/d", "/a/b/c/<d>/e"]));
assert!(!rankless_route_collisions(&["/a/d/<b..>", "/a/b/c"]));
assert!(!rankless_route_collisions(&["/<_>", "/"]));
assert!(!rankless_route_collisions(&["/a/<_>", "/a"]));
assert!(!rankless_route_collisions(&["/a/<_>", "/<_>"]));
}
#[test]
fn test_no_collision_when_ranked() {
assert!(!default_rank_route_collisions(&["/<a>", "/hello"]));
assert!(!default_rank_route_collisions(&["/hello/bob", "/hello/<b>"]));
assert!(!default_rank_route_collisions(&["/a/b/c/d", "/<a>/<b>/c/d"]));
assert!(!default_rank_route_collisions(&["/hi", "/<hi>"]));
assert!(!default_rank_route_collisions(&["/a", "/a/<path..>"]));
assert!(!default_rank_route_collisions(&["/", "/<path..>"]));
assert!(!default_rank_route_collisions(&["/a/b", "/a/b/<c..>"]));
assert!(!default_rank_route_collisions(&["/<_>", "/static"]));
assert!(!default_rank_route_collisions(&["/<_..>", "/static"]));
assert!(!default_rank_route_collisions(&["/<path..>", "/"]));
assert!(!default_rank_route_collisions(&["/<_>/<_>", "/foo/bar"]));
assert!(!default_rank_route_collisions(&["/foo/<_>", "/foo/bar"]));
assert!(!default_rank_route_collisions(&["/<a>/<b>", "/hello/<b>"]));
assert!(!default_rank_route_collisions(&["/<a>/<b..>", "/hello/<b>"]));
assert!(!default_rank_route_collisions(&["/<a..>", "/hello/<b>"]));
assert!(!default_rank_route_collisions(&["/<a..>", "/hello"]));
assert!(!default_rank_route_collisions(&["/<a>", "/a/<path..>"]));
assert!(!default_rank_route_collisions(&["/a/<b>/c", "/<d>/<c..>"]));
}
#[test]
fn test_collision_when_ranked() {
assert!(default_rank_route_collisions(&["/a/<b>/<c..>", "/a/<c>"]));
assert!(default_rank_route_collisions(&["/<a>/b", "/a/<b>"]));
}
#[test]
fn test_collision_when_ranked_query() {
assert!(default_rank_route_collisions(&["/a?a=b", "/a?c=d"]));
assert!(default_rank_route_collisions(&["/a?a=b&<b>", "/a?<c>&c=d"]));
assert!(default_rank_route_collisions(&["/a?a=b&<b..>", "/a?<c>&c=d"]));
}
#[test]
fn test_no_collision_when_ranked_query() {
assert!(!default_rank_route_collisions(&["/", "/?<c..>"]));
assert!(!default_rank_route_collisions(&["/hi", "/hi?<c>"]));
assert!(!default_rank_route_collisions(&["/hi", "/hi?c"]));
assert!(!default_rank_route_collisions(&["/hi?<c>", "/hi?c"]));
assert!(!default_rank_route_collisions(&["/<foo>?a=b", "/<foo>?c=d&<d>"]));
}
fn route<'a>(router: &'a Router, method: Method, uri: &'a str) -> Option<&'a Route> {
let rocket = Rocket::custom(Config::default());
let request = Request::new(&rocket, method, Origin::parse(uri).unwrap());
let route = router.route(&request).next();
route
}
fn matches<'a>(router: &'a Router, method: Method, uri: &'a str) -> Vec<&'a Route> {
let rocket = Rocket::custom(Config::default());
let request = Request::new(&rocket, method, Origin::parse(uri).unwrap());
router.route(&request).collect()
}
#[test]
fn test_ok_routing() {
let router = router_with_routes(&["/hello"]);
assert!(route(&router, Get, "/hello").is_some());
let router = router_with_routes(&["/<a>"]);
assert!(route(&router, Get, "/hello").is_some());
assert!(route(&router, Get, "/hi").is_some());
assert!(route(&router, Get, "/bobbbbbbbbbby").is_some());
assert!(route(&router, Get, "/dsfhjasdf").is_some());
let router = router_with_routes(&["/<a>/<b>"]);
assert!(route(&router, Get, "/hello/hi").is_some());
assert!(route(&router, Get, "/a/b/").is_some());
assert!(route(&router, Get, "/i/a").is_some());
assert!(route(&router, Get, "/jdlk/asdij").is_some());
let mut router = Router::new();
router.add(Route::new(Put, "/hello", dummy));
router.add(Route::new(Post, "/hello", dummy));
router.add(Route::new(Delete, "/hello", dummy));
assert!(route(&router, Put, "/hello").is_some());
assert!(route(&router, Post, "/hello").is_some());
assert!(route(&router, Delete, "/hello").is_some());
let router = router_with_routes(&["/<a..>"]);
assert!(route(&router, Get, "/").is_some());
assert!(route(&router, Get, "//").is_some());
assert!(route(&router, Get, "/hi").is_some());
assert!(route(&router, Get, "/hello/hi").is_some());
assert!(route(&router, Get, "/a/b/").is_some());
assert!(route(&router, Get, "/i/a").is_some());
assert!(route(&router, Get, "/a/b/c/d/e/f").is_some());
let router = router_with_routes(&["/foo/<a..>"]);
assert!(route(&router, Get, "/foo").is_some());
assert!(route(&router, Get, "/foo/").is_some());
assert!(route(&router, Get, "/foo///bar").is_some());
}
#[test]
fn test_err_routing() {
let router = router_with_routes(&["/hello"]);
assert!(route(&router, Put, "/hello").is_none());
assert!(route(&router, Post, "/hello").is_none());
assert!(route(&router, Options, "/hello").is_none());
assert!(route(&router, Get, "/hell").is_none());
assert!(route(&router, Get, "/hi").is_none());
assert!(route(&router, Get, "/hello/there").is_none());
assert!(route(&router, Get, "/hello/i").is_none());
assert!(route(&router, Get, "/hillo").is_none());
let router = router_with_routes(&["/<a>"]);
assert!(route(&router, Put, "/hello").is_none());
assert!(route(&router, Post, "/hello").is_none());
assert!(route(&router, Options, "/hello").is_none());
assert!(route(&router, Get, "/hello/there").is_none());
assert!(route(&router, Get, "/hello/i").is_none());
let router = router_with_routes(&["/<a>/<b>"]);
assert!(route(&router, Get, "/a/b/c").is_none());
assert!(route(&router, Get, "/a").is_none());
assert!(route(&router, Get, "/a/").is_none());
assert!(route(&router, Get, "/a/b/c/d").is_none());
assert!(route(&router, Put, "/hello/hi").is_none());
assert!(route(&router, Put, "/a/b").is_none());
assert!(route(&router, Put, "/a/b").is_none());
let router = router_with_routes(&["/prefix/<a..>"]);
assert!(route(&router, Get, "/").is_none());
assert!(route(&router, Get, "/prefi/").is_none());
}
macro_rules! assert_ranked_match {
($routes:expr, $to:expr => $want:expr) => ({
let router = router_with_routes($routes);
assert!(!router.has_collisions());
let route_path = route(&router, Get, $to).unwrap().uri.to_string();
assert_eq!(route_path, $want.to_string());
})
}
#[test]
fn test_default_ranking() {
assert_ranked_match!(&["/hello", "/<name>"], "/hello" => "/hello");
assert_ranked_match!(&["/<name>", "/hello"], "/hello" => "/hello");
assert_ranked_match!(&["/<a>", "/hi", "/hi/<b>"], "/hi" => "/hi");
assert_ranked_match!(&["/<a>/b", "/hi/c"], "/hi/c" => "/hi/c");
assert_ranked_match!(&["/<a>/<b>", "/hi/a"], "/hi/c" => "/<a>/<b>");
assert_ranked_match!(&["/hi/a", "/hi/<c>"], "/hi/c" => "/hi/<c>");
assert_ranked_match!(&["/a", "/a?<b>"], "/a?b=c" => "/a?<b>");
assert_ranked_match!(&["/a", "/a?<b>"], "/a" => "/a?<b>");
assert_ranked_match!(&["/a", "/<a>", "/a?<b>", "/<a>?<b>"], "/a" => "/a?<b>");
assert_ranked_match!(&["/a", "/<a>", "/a?<b>", "/<a>?<b>"], "/b" => "/<a>?<b>");
assert_ranked_match!(&["/a", "/<a>", "/a?<b>", "/<a>?<b>"], "/b?v=1" => "/<a>?<b>");
assert_ranked_match!(&["/a", "/<a>", "/a?<b>", "/<a>?<b>"], "/a?b=c" => "/a?<b>");
assert_ranked_match!(&["/a", "/a?b"], "/a?b" => "/a?b");
assert_ranked_match!(&["/<a>", "/a?b"], "/a?b" => "/a?b");
assert_ranked_match!(&["/a", "/<a>?b"], "/a?b" => "/a");
assert_ranked_match!(&["/a?<c>&b", "/a?<b>"], "/a" => "/a?<b>");
assert_ranked_match!(&["/a?<c>&b", "/a?<b>"], "/a?b" => "/a?<c>&b");
assert_ranked_match!(&["/a?<c>&b", "/a?<b>"], "/a?c" => "/a?<b>");
assert_ranked_match!(&["/", "/<foo..>"], "/" => "/");
assert_ranked_match!(&["/", "/<foo..>"], "/hi" => "/<foo..>");
assert_ranked_match!(&["/hi", "/<foo..>"], "/hi" => "/hi");
}
fn ranked_collisions(routes: &[(isize, &'static str)]) -> bool {
let router = router_with_ranked_routes(routes);
router.has_collisions()
}
#[test]
fn test_no_manual_ranked_collisions() {
assert!(!ranked_collisions(&[(1, "/a/<b>"), (2, "/a/<b>")]));
assert!(!ranked_collisions(&[(0, "/a/<b>"), (2, "/a/<b>")]));
assert!(!ranked_collisions(&[(5, "/a/<b>"), (2, "/a/<b>")]));
assert!(!ranked_collisions(&[(1, "/a/<b>"), (1, "/b/<b>")]));
assert!(!ranked_collisions(&[(1, "/a/<b..>"), (2, "/a/<b..>")]));
assert!(!ranked_collisions(&[(0, "/a/<b..>"), (2, "/a/<b..>")]));
assert!(!ranked_collisions(&[(5, "/a/<b..>"), (2, "/a/<b..>")]));
assert!(!ranked_collisions(&[(1, "/<a..>"), (2, "/<a..>")]));
}
#[test]
fn test_ranked_collisions() {
assert!(ranked_collisions(&[(2, "/a/<b..>"), (2, "/a/<b..>")]));
assert!(ranked_collisions(&[(2, "/a/c/<b..>"), (2, "/a/<b..>")]));
assert!(ranked_collisions(&[(2, "/<b..>"), (2, "/a/<b..>")]));
}
macro_rules! assert_ranked_routing {
(to: $to:expr, with: $routes:expr, expect: $($want:expr),+) => ({
let router = router_with_ranked_routes(&$routes);
let routed_to = matches(&router, Get, $to);
let expected = &[$($want),+];
assert!(routed_to.len() == expected.len());
for (got, expected) in routed_to.iter().zip(expected.iter()) {
assert_eq!(got.rank, expected.0);
assert_eq!(got.uri.to_string(), expected.1.to_string());
}
})
}
#[test]
fn test_ranked_routing() {
assert_ranked_routing!(
to: "/a/b",
with: [(1, "/a/<b>"), (2, "/a/<b>")],
expect: (1, "/a/<b>"), (2, "/a/<b>")
);
assert_ranked_routing!(
to: "/b/b",
with: [(1, "/a/<b>"), (2, "/b/<b>"), (3, "/b/b")],
expect: (2, "/b/<b>"), (3, "/b/b")
);
assert_ranked_routing!(
to: "/b/b",
with: [(2, "/b/<b>"), (1, "/a/<b>"), (3, "/b/b")],
expect: (2, "/b/<b>"), (3, "/b/b")
);
assert_ranked_routing!(
to: "/b/b",
with: [(3, "/b/b"), (2, "/b/<b>"), (1, "/a/<b>")],
expect: (2, "/b/<b>"), (3, "/b/b")
);
assert_ranked_routing!(
to: "/b/b",
with: [(1, "/a/<b>"), (2, "/b/<b>"), (0, "/b/b")],
expect: (0, "/b/b"), (2, "/b/<b>")
);
assert_ranked_routing!(
to: "/profile/sergio/edit",
with: [(1, "/<a>/<b>/edit"), (2, "/profile/<d>"), (0, "/<a>/<b>/<c>")],
expect: (0, "/<a>/<b>/<c>"), (1, "/<a>/<b>/edit")
);
assert_ranked_routing!(
to: "/profile/sergio/edit",
with: [(0, "/<a>/<b>/edit"), (2, "/profile/<d>"), (5, "/<a>/<b>/<c>")],
expect: (0, "/<a>/<b>/edit"), (5, "/<a>/<b>/<c>")
);
assert_ranked_routing!(
to: "/a/b",
with: [(0, "/a/b"), (1, "/a/<b..>")],
expect: (0, "/a/b"), (1, "/a/<b..>")
);
assert_ranked_routing!(
to: "/a/b/c/d/e/f",
with: [(1, "/a/<b..>"), (2, "/a/b/<c..>")],
expect: (1, "/a/<b..>"), (2, "/a/b/<c..>")
);
assert_ranked_routing!(
to: "/hi",
with: [(1, "/hi/<foo..>"), (0, "/hi/<foo>")],
expect: (1, "/hi/<foo..>")
);
}
macro_rules! assert_default_ranked_routing {
(to: $to:expr, with: $routes:expr, expect: $($want:expr),+) => ({
let router = router_with_routes(&$routes);
let routed_to = matches(&router, Get, $to);
let expected = &[$($want),+];
assert!(routed_to.len() == expected.len());
for (got, expected) in routed_to.iter().zip(expected.iter()) {
assert_eq!(got.uri.to_string(), expected.to_string());
}
})
}
#[test]
fn test_default_ranked_routing() {
assert_default_ranked_routing!(
to: "/a/b?v=1",
with: ["/a/<b>", "/a/b"],
expect: "/a/b", "/a/<b>"
);
assert_default_ranked_routing!(
to: "/a/b?v=1",
with: ["/a/<b>", "/a/b", "/a/b?<v>"],
expect: "/a/b?<v>", "/a/b", "/a/<b>"
);
assert_default_ranked_routing!(
to: "/a/b?v=1",
with: ["/a/<b>", "/a/b", "/a/b?<v>", "/a/<b>?<v>"],
expect: "/a/b?<v>", "/a/b", "/a/<b>?<v>", "/a/<b>"
);
assert_default_ranked_routing!(
to: "/a/b",
with: ["/a/<b>", "/a/b", "/a/b?<v>", "/a/<b>?<v>"],
expect: "/a/b?<v>", "/a/b", "/a/<b>?<v>", "/a/<b>"
);
assert_default_ranked_routing!(
to: "/a/b?c",
with: ["/a/b", "/a/b?<c>", "/a/b?c", "/a/<b>?c", "/a/<b>?<c>", "/<a>/<b>"],
expect: "/a/b?c", "/a/b?<c>", "/a/b", "/a/<b>?c", "/a/<b>?<c>", "/<a>/<b>"
);
}
}

View File

@ -162,7 +162,6 @@ impl Route {
} }
} }
/// Maps the `base` of this route using `mapper`, returning a new `Route` /// Maps the `base` of this route using `mapper`, returning a new `Route`
/// with the returned base. /// with the returned base.
/// ///

View File

@ -0,0 +1,670 @@
use std::collections::HashMap;
use crate::request::Request;
use crate::http::{Method, Status};
pub use crate::router::{Route, RouteUri};
pub use crate::router::collider::Collide;
pub use crate::catcher::Catcher;
#[derive(Debug, Default)]
pub(crate) struct Router {
routes: HashMap<Method, Vec<Route>>,
catchers: HashMap<Option<u16>, Vec<Catcher>>,
}
#[derive(Debug)]
pub struct Collisions {
pub routes: Vec<(Route, Route)>,
pub catchers: Vec<(Catcher, Catcher)>,
}
impl Router {
pub fn new() -> Self {
Self::default()
}
pub fn add_route(&mut self, route: Route) {
let routes = self.routes.entry(route.method).or_default();
routes.push(route);
routes.sort_by_key(|r| r.rank);
}
pub fn add_catcher(&mut self, catcher: Catcher) {
let catchers = self.catchers.entry(catcher.code).or_default();
catchers.push(catcher);
catchers.sort_by(|a, b| b.base.path_segments().len().cmp(&a.base.path_segments().len()))
}
#[inline]
pub fn routes(&self) -> impl Iterator<Item = &Route> + Clone {
self.routes.values().flat_map(|v| v.iter())
}
#[inline]
pub fn catchers(&self) -> impl Iterator<Item = &Catcher> + Clone {
self.catchers.values().flat_map(|v| v.iter())
}
pub fn route<'r, 'a: 'r>(
&'a self,
req: &'r Request<'r>
) -> impl Iterator<Item = &'a Route> + 'r {
// Note that routes are presorted by ascending rank on each `add`.
self.routes.get(&req.method())
.into_iter()
.flat_map(move |routes| routes.iter().filter(move |r| r.matches(req)))
}
// For many catchers, using aho-corasick or similar should be much faster.
pub fn catch<'r>(&self, status: Status, req: &'r Request<'r>) -> Option<&Catcher> {
// Note that catchers are presorted by descending base length.
let explicit = self.catchers.get(&Some(status.code))
.and_then(|c| c.iter().find(|c| c.matches(status, req)));
let default = self.catchers.get(&None)
.and_then(|c| c.iter().find(|c| c.matches(status, req)));
match (explicit, default) {
(None, None) => None,
(None, c@Some(_)) | (c@Some(_), None) => c,
(Some(a), Some(b)) => {
if b.base.path_segments().len() > a.base.path_segments().len() {
Some(b)
} else {
Some(a)
}
}
}
}
fn collisions<'a, I: 'a, T: 'a>(&self, items: I) -> impl Iterator<Item = (T, T)> + 'a
where I: Iterator<Item = &'a T> + Clone, T: Collide + Clone,
{
items.clone().enumerate()
.flat_map(move |(i, a)| {
items.clone()
.skip(i + 1)
.filter(move |b| a.collides_with(b))
.map(move |b| (a.clone(), b.clone()))
})
}
pub fn finalize(&self) -> Result<(), Collisions> {
let routes: Vec<_> = self.collisions(self.routes()).collect();
let catchers: Vec<_> = self.collisions(self.catchers()).collect();
if !routes.is_empty() || !catchers.is_empty() {
return Err(Collisions { routes, catchers })
}
Ok(())
}
}
#[cfg(test)]
mod test {
use super::*;
use crate::rocket::Rocket;
use crate::config::Config;
use crate::http::{Method, Method::*};
use crate::http::uri::Origin;
use crate::request::Request;
use crate::handler::dummy;
impl Router {
fn has_collisions(&self) -> bool {
self.finalize().is_err()
}
}
fn router_with_routes(routes: &[&'static str]) -> Router {
let mut router = Router::new();
for route in routes {
let route = Route::new(Get, route, dummy);
router.add_route(route);
}
router
}
fn router_with_ranked_routes(routes: &[(isize, &'static str)]) -> Router {
let mut router = Router::new();
for &(rank, route) in routes {
let route = Route::ranked(rank, Get, route, dummy);
router.add_route(route);
}
router
}
fn router_with_rankless_routes(routes: &[&'static str]) -> Router {
let mut router = Router::new();
for route in routes {
let route = Route::ranked(0, Get, route, dummy);
router.add_route(route);
}
router
}
fn rankless_route_collisions(routes: &[&'static str]) -> bool {
let router = router_with_rankless_routes(routes);
router.has_collisions()
}
fn default_rank_route_collisions(routes: &[&'static str]) -> bool {
let router = router_with_routes(routes);
router.has_collisions()
}
#[test]
fn test_rankless_collisions() {
assert!(rankless_route_collisions(&["/hello", "/hello"]));
assert!(rankless_route_collisions(&["/<a>", "/hello"]));
assert!(rankless_route_collisions(&["/<a>", "/<b>"]));
assert!(rankless_route_collisions(&["/hello/bob", "/hello/<b>"]));
assert!(rankless_route_collisions(&["/a/b/<c>/d", "/<a>/<b>/c/d"]));
assert!(rankless_route_collisions(&["/a/b", "/<a..>"]));
assert!(rankless_route_collisions(&["/a/b/c", "/a/<a..>"]));
assert!(rankless_route_collisions(&["/<a>/b", "/a/<a..>"]));
assert!(rankless_route_collisions(&["/a/<b>", "/a/<a..>"]));
assert!(rankless_route_collisions(&["/a/b/<c>", "/a/<a..>"]));
assert!(rankless_route_collisions(&["/<a..>", "/a/<a..>"]));
assert!(rankless_route_collisions(&["/a/<a..>", "/a/<a..>"]));
assert!(rankless_route_collisions(&["/a/b/<a..>", "/a/<a..>"]));
assert!(rankless_route_collisions(&["/a/b/c/d", "/a/<a..>"]));
assert!(rankless_route_collisions(&["/", "/<a..>"]));
assert!(rankless_route_collisions(&["/a/<_>", "/a/<a..>"]));
assert!(rankless_route_collisions(&["/a/<_>", "/a/<_..>"]));
assert!(rankless_route_collisions(&["/<_>", "/a/<_..>"]));
assert!(rankless_route_collisions(&["/foo", "/foo/<_..>"]));
assert!(rankless_route_collisions(&["/foo/bar/baz", "/foo/<_..>"]));
assert!(rankless_route_collisions(&["/a/d/<b..>", "/a/d"]));
assert!(rankless_route_collisions(&["/a/<_..>", "/<_>"]));
assert!(rankless_route_collisions(&["/a/<_..>", "/a"]));
assert!(rankless_route_collisions(&["/<a>", "/a/<a..>"]));
assert!(rankless_route_collisions(&["/<_>", "/<_>"]));
assert!(rankless_route_collisions(&["/a/<_>", "/a/b"]));
assert!(rankless_route_collisions(&["/a/<_>", "/a/<b>"]));
assert!(rankless_route_collisions(&["/<_..>", "/a/b"]));
assert!(rankless_route_collisions(&["/<_..>", "/<_>"]));
assert!(rankless_route_collisions(&["/<_>/b", "/a/b"]));
assert!(rankless_route_collisions(&["/", "/<foo..>"]));
}
#[test]
fn test_collisions_normalize() {
assert!(rankless_route_collisions(&["/hello/", "/hello"]));
assert!(rankless_route_collisions(&["//hello/", "/hello"]));
assert!(rankless_route_collisions(&["//hello/", "/hello//"]));
assert!(rankless_route_collisions(&["/<a>", "/hello//"]));
assert!(rankless_route_collisions(&["/<a>", "/hello///"]));
assert!(rankless_route_collisions(&["/hello///bob", "/hello/<b>"]));
assert!(rankless_route_collisions(&["/<a..>//", "/a//<a..>"]));
assert!(rankless_route_collisions(&["/a/<a..>//", "/a/<a..>"]));
assert!(rankless_route_collisions(&["/a/<a..>//", "/a/b//c//d/"]));
assert!(rankless_route_collisions(&["/a/<a..>/", "/a/bd/e/"]));
assert!(rankless_route_collisions(&["/<a..>/", "/a/bd/e/"]));
assert!(rankless_route_collisions(&["//", "/<foo..>"]));
assert!(rankless_route_collisions(&["/a/<a..>//", "/a/b//c//d/e/"]));
assert!(rankless_route_collisions(&["/a//<a..>//", "/a/b//c//d/e/"]));
assert!(rankless_route_collisions(&["///<_>", "/<_>"]));
assert!(rankless_route_collisions(&["/a/<_>", "///a//b"]));
assert!(rankless_route_collisions(&["//a///<_>", "/a//<b>"]));
assert!(rankless_route_collisions(&["//<_..>", "/a/b"]));
assert!(rankless_route_collisions(&["//<_..>", "/<_>"]));
assert!(rankless_route_collisions(&["///<a>/", "/a/<a..>"]));
assert!(rankless_route_collisions(&["///<a..>/", "/a/<a..>"]));
assert!(rankless_route_collisions(&["/<a..>", "/hello"]));
}
#[test]
fn test_collisions_query() {
// Query shouldn't affect things when rankless.
assert!(rankless_route_collisions(&["/hello?<foo>", "/hello"]));
assert!(rankless_route_collisions(&["/<a>?foo=bar", "/hello?foo=bar&cat=fat"]));
assert!(rankless_route_collisions(&["/<a>?foo=bar", "/hello?foo=bar&cat=fat"]));
assert!(rankless_route_collisions(&["/<a>", "/<b>?<foo>"]));
assert!(rankless_route_collisions(&["/hello/bob?a=b", "/hello/<b>?d=e"]));
assert!(rankless_route_collisions(&["/<foo>?a=b", "/foo?d=e"]));
assert!(rankless_route_collisions(&["/<foo>?a=b&<c>", "/<foo>?d=e&<c>"]));
assert!(rankless_route_collisions(&["/<foo>?a=b&<c>", "/<foo>?d=e"]));
}
#[test]
fn test_no_collisions() {
assert!(!rankless_route_collisions(&["/a/b", "/a/b/c"]));
assert!(!rankless_route_collisions(&["/a/b/c/d", "/a/b/c/<d>/e"]));
assert!(!rankless_route_collisions(&["/a/d/<b..>", "/a/b/c"]));
assert!(!rankless_route_collisions(&["/<_>", "/"]));
assert!(!rankless_route_collisions(&["/a/<_>", "/a"]));
assert!(!rankless_route_collisions(&["/a/<_>", "/<_>"]));
}
#[test]
fn test_no_collision_when_ranked() {
assert!(!default_rank_route_collisions(&["/<a>", "/hello"]));
assert!(!default_rank_route_collisions(&["/hello/bob", "/hello/<b>"]));
assert!(!default_rank_route_collisions(&["/a/b/c/d", "/<a>/<b>/c/d"]));
assert!(!default_rank_route_collisions(&["/hi", "/<hi>"]));
assert!(!default_rank_route_collisions(&["/a", "/a/<path..>"]));
assert!(!default_rank_route_collisions(&["/", "/<path..>"]));
assert!(!default_rank_route_collisions(&["/a/b", "/a/b/<c..>"]));
assert!(!default_rank_route_collisions(&["/<_>", "/static"]));
assert!(!default_rank_route_collisions(&["/<_..>", "/static"]));
assert!(!default_rank_route_collisions(&["/<path..>", "/"]));
assert!(!default_rank_route_collisions(&["/<_>/<_>", "/foo/bar"]));
assert!(!default_rank_route_collisions(&["/foo/<_>", "/foo/bar"]));
assert!(!default_rank_route_collisions(&["/<a>/<b>", "/hello/<b>"]));
assert!(!default_rank_route_collisions(&["/<a>/<b..>", "/hello/<b>"]));
assert!(!default_rank_route_collisions(&["/<a..>", "/hello/<b>"]));
assert!(!default_rank_route_collisions(&["/<a..>", "/hello"]));
assert!(!default_rank_route_collisions(&["/<a>", "/a/<path..>"]));
assert!(!default_rank_route_collisions(&["/a/<b>/c", "/<d>/<c..>"]));
}
#[test]
fn test_collision_when_ranked() {
assert!(default_rank_route_collisions(&["/a/<b>/<c..>", "/a/<c>"]));
assert!(default_rank_route_collisions(&["/<a>/b", "/a/<b>"]));
}
#[test]
fn test_collision_when_ranked_query() {
assert!(default_rank_route_collisions(&["/a?a=b", "/a?c=d"]));
assert!(default_rank_route_collisions(&["/a?a=b&<b>", "/a?<c>&c=d"]));
assert!(default_rank_route_collisions(&["/a?a=b&<b..>", "/a?<c>&c=d"]));
}
#[test]
fn test_no_collision_when_ranked_query() {
assert!(!default_rank_route_collisions(&["/", "/?<c..>"]));
assert!(!default_rank_route_collisions(&["/hi", "/hi?<c>"]));
assert!(!default_rank_route_collisions(&["/hi", "/hi?c"]));
assert!(!default_rank_route_collisions(&["/hi?<c>", "/hi?c"]));
assert!(!default_rank_route_collisions(&["/<foo>?a=b", "/<foo>?c=d&<d>"]));
}
fn route<'a>(router: &'a Router, method: Method, uri: &'a str) -> Option<&'a Route> {
let rocket = Rocket::custom(Config::default());
let request = Request::new(&rocket, method, Origin::parse(uri).unwrap());
let route = router.route(&request).next();
route
}
fn matches<'a>(router: &'a Router, method: Method, uri: &'a str) -> Vec<&'a Route> {
let rocket = Rocket::custom(Config::default());
let request = Request::new(&rocket, method, Origin::parse(uri).unwrap());
router.route(&request).collect()
}
#[test]
fn test_ok_routing() {
let router = router_with_routes(&["/hello"]);
assert!(route(&router, Get, "/hello").is_some());
let router = router_with_routes(&["/<a>"]);
assert!(route(&router, Get, "/hello").is_some());
assert!(route(&router, Get, "/hi").is_some());
assert!(route(&router, Get, "/bobbbbbbbbbby").is_some());
assert!(route(&router, Get, "/dsfhjasdf").is_some());
let router = router_with_routes(&["/<a>/<b>"]);
assert!(route(&router, Get, "/hello/hi").is_some());
assert!(route(&router, Get, "/a/b/").is_some());
assert!(route(&router, Get, "/i/a").is_some());
assert!(route(&router, Get, "/jdlk/asdij").is_some());
let mut router = Router::new();
router.add_route(Route::new(Put, "/hello", dummy));
router.add_route(Route::new(Post, "/hello", dummy));
router.add_route(Route::new(Delete, "/hello", dummy));
assert!(route(&router, Put, "/hello").is_some());
assert!(route(&router, Post, "/hello").is_some());
assert!(route(&router, Delete, "/hello").is_some());
let router = router_with_routes(&["/<a..>"]);
assert!(route(&router, Get, "/").is_some());
assert!(route(&router, Get, "//").is_some());
assert!(route(&router, Get, "/hi").is_some());
assert!(route(&router, Get, "/hello/hi").is_some());
assert!(route(&router, Get, "/a/b/").is_some());
assert!(route(&router, Get, "/i/a").is_some());
assert!(route(&router, Get, "/a/b/c/d/e/f").is_some());
let router = router_with_routes(&["/foo/<a..>"]);
assert!(route(&router, Get, "/foo").is_some());
assert!(route(&router, Get, "/foo/").is_some());
assert!(route(&router, Get, "/foo///bar").is_some());
}
#[test]
fn test_err_routing() {
let router = router_with_routes(&["/hello"]);
assert!(route(&router, Put, "/hello").is_none());
assert!(route(&router, Post, "/hello").is_none());
assert!(route(&router, Options, "/hello").is_none());
assert!(route(&router, Get, "/hell").is_none());
assert!(route(&router, Get, "/hi").is_none());
assert!(route(&router, Get, "/hello/there").is_none());
assert!(route(&router, Get, "/hello/i").is_none());
assert!(route(&router, Get, "/hillo").is_none());
let router = router_with_routes(&["/<a>"]);
assert!(route(&router, Put, "/hello").is_none());
assert!(route(&router, Post, "/hello").is_none());
assert!(route(&router, Options, "/hello").is_none());
assert!(route(&router, Get, "/hello/there").is_none());
assert!(route(&router, Get, "/hello/i").is_none());
let router = router_with_routes(&["/<a>/<b>"]);
assert!(route(&router, Get, "/a/b/c").is_none());
assert!(route(&router, Get, "/a").is_none());
assert!(route(&router, Get, "/a/").is_none());
assert!(route(&router, Get, "/a/b/c/d").is_none());
assert!(route(&router, Put, "/hello/hi").is_none());
assert!(route(&router, Put, "/a/b").is_none());
assert!(route(&router, Put, "/a/b").is_none());
let router = router_with_routes(&["/prefix/<a..>"]);
assert!(route(&router, Get, "/").is_none());
assert!(route(&router, Get, "/prefi/").is_none());
}
macro_rules! assert_ranked_match {
($routes:expr, $to:expr => $want:expr) => ({
let router = router_with_routes($routes);
assert!(!router.has_collisions());
let route_path = route(&router, Get, $to).unwrap().uri.to_string();
assert_eq!(route_path, $want.to_string());
})
}
#[test]
fn test_default_ranking() {
assert_ranked_match!(&["/hello", "/<name>"], "/hello" => "/hello");
assert_ranked_match!(&["/<name>", "/hello"], "/hello" => "/hello");
assert_ranked_match!(&["/<a>", "/hi", "/hi/<b>"], "/hi" => "/hi");
assert_ranked_match!(&["/<a>/b", "/hi/c"], "/hi/c" => "/hi/c");
assert_ranked_match!(&["/<a>/<b>", "/hi/a"], "/hi/c" => "/<a>/<b>");
assert_ranked_match!(&["/hi/a", "/hi/<c>"], "/hi/c" => "/hi/<c>");
assert_ranked_match!(&["/a", "/a?<b>"], "/a?b=c" => "/a?<b>");
assert_ranked_match!(&["/a", "/a?<b>"], "/a" => "/a?<b>");
assert_ranked_match!(&["/a", "/<a>", "/a?<b>", "/<a>?<b>"], "/a" => "/a?<b>");
assert_ranked_match!(&["/a", "/<a>", "/a?<b>", "/<a>?<b>"], "/b" => "/<a>?<b>");
assert_ranked_match!(&["/a", "/<a>", "/a?<b>", "/<a>?<b>"], "/b?v=1" => "/<a>?<b>");
assert_ranked_match!(&["/a", "/<a>", "/a?<b>", "/<a>?<b>"], "/a?b=c" => "/a?<b>");
assert_ranked_match!(&["/a", "/a?b"], "/a?b" => "/a?b");
assert_ranked_match!(&["/<a>", "/a?b"], "/a?b" => "/a?b");
assert_ranked_match!(&["/a", "/<a>?b"], "/a?b" => "/a");
assert_ranked_match!(&["/a?<c>&b", "/a?<b>"], "/a" => "/a?<b>");
assert_ranked_match!(&["/a?<c>&b", "/a?<b>"], "/a?b" => "/a?<c>&b");
assert_ranked_match!(&["/a?<c>&b", "/a?<b>"], "/a?c" => "/a?<b>");
assert_ranked_match!(&["/", "/<foo..>"], "/" => "/");
assert_ranked_match!(&["/", "/<foo..>"], "/hi" => "/<foo..>");
assert_ranked_match!(&["/hi", "/<foo..>"], "/hi" => "/hi");
}
fn ranked_collisions(routes: &[(isize, &'static str)]) -> bool {
let router = router_with_ranked_routes(routes);
router.has_collisions()
}
#[test]
fn test_no_manual_ranked_collisions() {
assert!(!ranked_collisions(&[(1, "/a/<b>"), (2, "/a/<b>")]));
assert!(!ranked_collisions(&[(0, "/a/<b>"), (2, "/a/<b>")]));
assert!(!ranked_collisions(&[(5, "/a/<b>"), (2, "/a/<b>")]));
assert!(!ranked_collisions(&[(1, "/a/<b>"), (1, "/b/<b>")]));
assert!(!ranked_collisions(&[(1, "/a/<b..>"), (2, "/a/<b..>")]));
assert!(!ranked_collisions(&[(0, "/a/<b..>"), (2, "/a/<b..>")]));
assert!(!ranked_collisions(&[(5, "/a/<b..>"), (2, "/a/<b..>")]));
assert!(!ranked_collisions(&[(1, "/<a..>"), (2, "/<a..>")]));
}
#[test]
fn test_ranked_collisions() {
assert!(ranked_collisions(&[(2, "/a/<b..>"), (2, "/a/<b..>")]));
assert!(ranked_collisions(&[(2, "/a/c/<b..>"), (2, "/a/<b..>")]));
assert!(ranked_collisions(&[(2, "/<b..>"), (2, "/a/<b..>")]));
}
macro_rules! assert_ranked_routing {
(to: $to:expr, with: $routes:expr, expect: $($want:expr),+) => ({
let router = router_with_ranked_routes(&$routes);
let routed_to = matches(&router, Get, $to);
let expected = &[$($want),+];
assert!(routed_to.len() == expected.len());
for (got, expected) in routed_to.iter().zip(expected.iter()) {
assert_eq!(got.rank, expected.0);
assert_eq!(got.uri.to_string(), expected.1.to_string());
}
})
}
#[test]
fn test_ranked_routing() {
assert_ranked_routing!(
to: "/a/b",
with: [(1, "/a/<b>"), (2, "/a/<b>")],
expect: (1, "/a/<b>"), (2, "/a/<b>")
);
assert_ranked_routing!(
to: "/b/b",
with: [(1, "/a/<b>"), (2, "/b/<b>"), (3, "/b/b")],
expect: (2, "/b/<b>"), (3, "/b/b")
);
assert_ranked_routing!(
to: "/b/b",
with: [(2, "/b/<b>"), (1, "/a/<b>"), (3, "/b/b")],
expect: (2, "/b/<b>"), (3, "/b/b")
);
assert_ranked_routing!(
to: "/b/b",
with: [(3, "/b/b"), (2, "/b/<b>"), (1, "/a/<b>")],
expect: (2, "/b/<b>"), (3, "/b/b")
);
assert_ranked_routing!(
to: "/b/b",
with: [(1, "/a/<b>"), (2, "/b/<b>"), (0, "/b/b")],
expect: (0, "/b/b"), (2, "/b/<b>")
);
assert_ranked_routing!(
to: "/profile/sergio/edit",
with: [(1, "/<a>/<b>/edit"), (2, "/profile/<d>"), (0, "/<a>/<b>/<c>")],
expect: (0, "/<a>/<b>/<c>"), (1, "/<a>/<b>/edit")
);
assert_ranked_routing!(
to: "/profile/sergio/edit",
with: [(0, "/<a>/<b>/edit"), (2, "/profile/<d>"), (5, "/<a>/<b>/<c>")],
expect: (0, "/<a>/<b>/edit"), (5, "/<a>/<b>/<c>")
);
assert_ranked_routing!(
to: "/a/b",
with: [(0, "/a/b"), (1, "/a/<b..>")],
expect: (0, "/a/b"), (1, "/a/<b..>")
);
assert_ranked_routing!(
to: "/a/b/c/d/e/f",
with: [(1, "/a/<b..>"), (2, "/a/b/<c..>")],
expect: (1, "/a/<b..>"), (2, "/a/b/<c..>")
);
assert_ranked_routing!(
to: "/hi",
with: [(1, "/hi/<foo..>"), (0, "/hi/<foo>")],
expect: (1, "/hi/<foo..>")
);
}
macro_rules! assert_default_ranked_routing {
(to: $to:expr, with: $routes:expr, expect: $($want:expr),+) => ({
let router = router_with_routes(&$routes);
let routed_to = matches(&router, Get, $to);
let expected = &[$($want),+];
assert!(routed_to.len() == expected.len());
for (got, expected) in routed_to.iter().zip(expected.iter()) {
assert_eq!(got.uri.to_string(), expected.to_string());
}
})
}
#[test]
fn test_default_ranked_routing() {
assert_default_ranked_routing!(
to: "/a/b?v=1",
with: ["/a/<b>", "/a/b"],
expect: "/a/b", "/a/<b>"
);
assert_default_ranked_routing!(
to: "/a/b?v=1",
with: ["/a/<b>", "/a/b", "/a/b?<v>"],
expect: "/a/b?<v>", "/a/b", "/a/<b>"
);
assert_default_ranked_routing!(
to: "/a/b?v=1",
with: ["/a/<b>", "/a/b", "/a/b?<v>", "/a/<b>?<v>"],
expect: "/a/b?<v>", "/a/b", "/a/<b>?<v>", "/a/<b>"
);
assert_default_ranked_routing!(
to: "/a/b",
with: ["/a/<b>", "/a/b", "/a/b?<v>", "/a/<b>?<v>"],
expect: "/a/b?<v>", "/a/b", "/a/<b>?<v>", "/a/<b>"
);
assert_default_ranked_routing!(
to: "/a/b?c",
with: ["/a/b", "/a/b?<c>", "/a/b?c", "/a/<b>?c", "/a/<b>?<c>", "/<a>/<b>"],
expect: "/a/b?c", "/a/b?<c>", "/a/b", "/a/<b>?c", "/a/<b>?<c>", "/<a>/<b>"
);
}
fn router_with_catchers(catchers: &[(Option<u16>, &str)]) -> Router {
let mut router = Router::new();
for (code, base) in catchers {
let catcher = Catcher::new(*code, crate::catcher::dummy);
router.add_catcher(catcher.map_base(|_| base.to_string()).unwrap());
}
router
}
fn catcher<'a>(router: &'a Router, status: Status, uri: &str) -> Option<&'a Catcher> {
let rocket = Rocket::custom(Config::default());
let request = Request::new(&rocket, Method::Get, Origin::parse(uri).unwrap());
router.catch(status, &request)
}
macro_rules! assert_catcher_routing {
(
catch: [$(($code:expr, $uri:expr)),+],
reqs: [$($r:expr),+],
with: [$(($ecode:expr, $euri:expr)),+]
) => ({
let catchers = vec![$(($code.into(), $uri)),+];
let requests = vec![$($r),+];
let expected = vec![$(($ecode.into(), $euri)),+];
let router = router_with_catchers(&catchers);
for (req, expected) in requests.iter().zip(expected.iter()) {
let req_status = Status::from_code(req.0).expect("valid status");
let catcher = catcher(&router, req_status, req.1).expect("some catcher");
assert_eq!(catcher.code, expected.0, "<- got, expected ->");
assert_eq!(catcher.base.path(), expected.1, "<- got, expected ->");
}
})
}
#[test]
fn test_catcher_routing() {
// Check that the default `/` catcher catches everything.
assert_catcher_routing! {
catch: [(None, "/")],
reqs: [(404, "/a/b/c"), (500, "/a/b"), (415, "/a/b/d"), (422, "/a/b/c/d?foo")],
with: [(None, "/"), (None, "/"), (None, "/"), (None, "/")]
}
// Check prefixes when they're exact.
assert_catcher_routing! {
catch: [(None, "/"), (None, "/a"), (None, "/a/b")],
reqs: [
(404, "/"), (500, "/"),
(404, "/a"), (500, "/a"),
(404, "/a/b"), (500, "/a/b")
],
with: [
(None, "/"), (None, "/"),
(None, "/a"), (None, "/a"),
(None, "/a/b"), (None, "/a/b")
]
}
// Check prefixes when they're not exact.
assert_catcher_routing! {
catch: [(None, "/"), (None, "/a"), (None, "/a/b")],
reqs: [
(404, "/foo"), (500, "/bar"), (422, "/baz/bar"), (418, "/poodle?yes"),
(404, "/a/foo"), (500, "/a/bar/baz"), (510, "/a/c"), (423, "/a/c/b"),
(404, "/a/b/c"), (500, "/a/b/c/d"), (500, "/a/b?foo"), (400, "/a/b/yes")
],
with: [
(None, "/"), (None, "/"), (None, "/"), (None, "/"),
(None, "/a"), (None, "/a"), (None, "/a"), (None, "/a"),
(None, "/a/b"), (None, "/a/b"), (None, "/a/b"), (None, "/a/b")
]
}
// Check that we prefer specific to default.
assert_catcher_routing! {
catch: [(400, "/"), (404, "/"), (None, "/")],
reqs: [
(400, "/"), (400, "/bar"), (400, "/foo/bar"),
(404, "/"), (404, "/bar"), (404, "/foo/bar"),
(405, "/"), (405, "/bar"), (406, "/foo/bar")
],
with: [
(400, "/"), (400, "/"), (400, "/"),
(404, "/"), (404, "/"), (404, "/"),
(None, "/"), (None, "/"), (None, "/")
]
}
// Check that we prefer longer prefixes over specific.
assert_catcher_routing! {
catch: [(None, "/a/b"), (404, "/a"), (422, "/a")],
reqs: [
(404, "/a/b"), (404, "/a/b/c"), (422, "/a/b/c"),
(404, "/a"), (404, "/a/c"), (404, "/a/cat/bar"),
(422, "/a"), (422, "/a/c"), (422, "/a/cat/bar")
],
with: [
(None, "/a/b"), (None, "/a/b"), (None, "/a/b"),
(404, "/a"), (404, "/a"), (404, "/a"),
(422, "/a"), (422, "/a"), (422, "/a")
]
}
// Just a fun one.
assert_catcher_routing! {
catch: [(None, "/"), (None, "/a/b"), (500, "/a/b/c"), (500, "/a/b")],
reqs: [(404, "/a/b/c"), (500, "/a/b"), (400, "/a/b/d"), (500, "/a/b/c/d?foo")],
with: [(None, "/a/b"), (500, "/a/b"), (None, "/a/b"), (500, "/a/b/c")]
}
}
}

View File

@ -329,11 +329,7 @@ impl Rocket {
// response. We may wish to relax this in the future. // response. We may wish to relax this in the future.
req.cookies().reset_delta(); req.cookies().reset_delta();
// Try to get the active catcher if let Some(catcher) = self.router.catch(status, req) {
let catcher = self.catchers.get(&status.code)
.or_else(|| self.default_catcher.as_ref());
if let Some(catcher) = catcher {
warn_!("Responding with registered {} catcher.", catcher); warn_!("Responding with registered {} catcher.", catcher);
let name = catcher.name.as_deref(); let name = catcher.name.as_deref();
handle(name, || catcher.handler.handle(status, req)).await handle(name, || catcher.handler.handle(status, req)).await

View File

@ -24,7 +24,7 @@ mod tests {
fn error_catcher_sets_cookies() { fn error_catcher_sets_cookies() {
let rocket = rocket::ignite() let rocket = rocket::ignite()
.mount("/", routes![index]) .mount("/", routes![index])
.register(catchers![not_found]) .register("/", catchers![not_found])
.attach(AdHoc::on_request("Add Cookie", |req, _| Box::pin(async move { .attach(AdHoc::on_request("Add Cookie", |req, _| Box::pin(async move {
req.cookies().add(Cookie::new("fairing", "woo")); req.cookies().add(Cookie::new("fairing", "woo"));
}))); })));

View File

@ -22,30 +22,20 @@ fn ise() -> &'static str {
"Hey, sorry! :(" "Hey, sorry! :("
} }
#[catch(500)]
fn double_panic() {
panic!("so, so sorry...")
}
fn pre_future_route<'r>(_: &'r Request<'_>, _: Data) -> HandlerFuture<'r> { fn pre_future_route<'r>(_: &'r Request<'_>, _: Data) -> HandlerFuture<'r> {
panic!("hey now..."); panic!("hey now...");
} }
fn pre_future_catcher<'r>(_: Status, _: &'r Request) -> ErrorHandlerFuture<'r> {
panic!("a panicking pre-future catcher")
}
fn rocket() -> Rocket { fn rocket() -> Rocket {
let pre_future_panic = Route::new(Method::Get, "/pre", pre_future_route);
rocket::ignite() rocket::ignite()
.mount("/", routes![panic_route]) .mount("/", routes![panic_route])
.mount("/", vec![pre_future_panic]) .mount("/", vec![Route::new(Method::Get, "/pre", pre_future_route)])
.register(catchers![panic_catcher, ise])
} }
#[test] #[test]
fn catches_route_panic() { fn catches_route_panic() {
let client = Client::debug(rocket()).unwrap(); let rocket = rocket().register("/", catchers![panic_catcher, ise]);
let client = Client::debug(rocket).unwrap();
let response = client.get("/panic").dispatch(); let response = client.get("/panic").dispatch();
assert_eq!(response.status(), Status::InternalServerError); assert_eq!(response.status(), Status::InternalServerError);
assert_eq!(response.into_string().unwrap(), "Hey, sorry! :("); assert_eq!(response.into_string().unwrap(), "Hey, sorry! :(");
@ -53,7 +43,8 @@ fn catches_route_panic() {
#[test] #[test]
fn catches_catcher_panic() { fn catches_catcher_panic() {
let client = Client::debug(rocket()).unwrap(); let rocket = rocket().register("/", catchers![panic_catcher, ise]);
let client = Client::debug(rocket).unwrap();
let response = client.get("/noroute").dispatch(); let response = client.get("/noroute").dispatch();
assert_eq!(response.status(), Status::InternalServerError); assert_eq!(response.status(), Status::InternalServerError);
assert_eq!(response.into_string().unwrap(), "Hey, sorry! :("); assert_eq!(response.into_string().unwrap(), "Hey, sorry! :(");
@ -61,7 +52,12 @@ fn catches_catcher_panic() {
#[test] #[test]
fn catches_double_panic() { fn catches_double_panic() {
let rocket = rocket().register(catchers![double_panic]); #[catch(500)]
fn double_panic() {
panic!("so, so sorry...")
}
let rocket = rocket().register("/", catchers![panic_catcher, double_panic]);
let client = Client::debug(rocket).unwrap(); let client = Client::debug(rocket).unwrap();
let response = client.get("/noroute").dispatch(); let response = client.get("/noroute").dispatch();
assert_eq!(response.status(), Status::InternalServerError); assert_eq!(response.status(), Status::InternalServerError);
@ -70,7 +66,8 @@ fn catches_double_panic() {
#[test] #[test]
fn catches_early_route_panic() { fn catches_early_route_panic() {
let client = Client::debug(rocket()).unwrap(); let rocket = rocket().register("/", catchers![panic_catcher, ise]);
let client = Client::debug(rocket).unwrap();
let response = client.get("/pre").dispatch(); let response = client.get("/pre").dispatch();
assert_eq!(response.status(), Status::InternalServerError); assert_eq!(response.status(), Status::InternalServerError);
assert_eq!(response.into_string().unwrap(), "Hey, sorry! :("); assert_eq!(response.into_string().unwrap(), "Hey, sorry! :(");
@ -78,9 +75,15 @@ fn catches_early_route_panic() {
#[test] #[test]
fn catches_early_catcher_panic() { fn catches_early_catcher_panic() {
let panic_catcher = Catcher::new(404, pre_future_catcher); fn pre_future_catcher<'r>(_: Status, _: &'r Request) -> ErrorHandlerFuture<'r> {
panic!("a panicking pre-future catcher")
}
let client = Client::debug(rocket().register(vec![panic_catcher])).unwrap(); let rocket = rocket()
.register("/", vec![Catcher::new(404, pre_future_catcher)])
.register("/", catchers![ise]);
let client = Client::debug(rocket).unwrap();
let response = client.get("/idontexist").dispatch(); let response = client.get("/idontexist").dispatch();
assert_eq!(response.status(), Status::InternalServerError); assert_eq!(response.status(), Status::InternalServerError);
assert_eq!(response.into_string().unwrap(), "Hey, sorry! :("); assert_eq!(response.into_string().unwrap(), "Hey, sorry! :(");

View File

@ -14,7 +14,7 @@ mod tests {
#[test] #[test]
fn error_catcher_redirect() { fn error_catcher_redirect() {
let client = Client::debug(rocket::ignite().register(catchers![not_found])).unwrap(); let client = Client::debug(rocket::ignite().register("/", catchers![not_found])).unwrap();
let response = client.get("/unknown").dispatch(); let response = client.get("/unknown").dispatch();
let location: Vec<_> = response.headers().get("location").collect(); let location: Vec<_> = response.headers().get("location").collect();

View File

@ -63,5 +63,5 @@ fn not_found(request: &Request<'_>) -> Html<String> {
fn rocket() -> rocket::Rocket { fn rocket() -> rocket::Rocket {
rocket::ignite() rocket::ignite()
.mount("/hello", routes![get_hello, post_hello]) .mount("/hello", routes![get_hello, post_hello])
.register(catchers![not_found]) .register("/", catchers![not_found])
} }

View File

@ -7,7 +7,7 @@ use rocket::response::{content, status};
use rocket::http::Status; use rocket::http::Status;
#[get("/hello/<name>/<age>")] #[get("/hello/<name>/<age>")]
fn hello(name: String, age: i8) -> String { fn hello(name: &str, age: i8) -> String {
format!("Hello, {} year old named {}!", age, name) format!("Hello, {} year old named {}!", age, name)
} }
@ -17,12 +17,26 @@ fn forced_error(code: u16) -> Status {
} }
#[catch(404)] #[catch(404)]
fn not_found(req: &Request<'_>) -> content::Html<String> { fn general_not_found() -> content::Html<&'static str> {
content::Html(format!("<p>Sorry, but '{}' is not a valid path!</p> content::Html(r#"
<p>Hmm... What are you looking for?</p>
Say <a href="/hello/Sergio/100">hello!</a>
"#)
}
#[catch(404)]
fn hello_not_found(req: &Request<'_>) -> content::Html<String> {
content::Html(format!("\
<p>Sorry, but '{}' is not a valid path!</p>\
<p>Try visiting /hello/&lt;name&gt;/&lt;age&gt; instead.</p>", <p>Try visiting /hello/&lt;name&gt;/&lt;age&gt; instead.</p>",
req.uri())) req.uri()))
} }
#[catch(default)]
fn sergio_error() -> &'static str {
"I...don't know what to say."
}
#[catch(default)] #[catch(default)]
fn default_catcher(status: Status, req: &Request<'_>) -> status::Custom<String> { fn default_catcher(status: Status, req: &Request<'_>) -> status::Custom<String> {
let msg = format!("{} - {} ({})", status.code, status.reason, req.uri()); let msg = format!("{} - {} ({})", status.code, status.reason, req.uri());
@ -33,7 +47,9 @@ fn rocket() -> rocket::Rocket {
rocket::ignite() rocket::ignite()
// .mount("/", routes![hello, hello]) // uncoment this to get an error // .mount("/", routes![hello, hello]) // uncoment this to get an error
.mount("/", routes![hello, forced_error]) .mount("/", routes![hello, forced_error])
.register(catchers![not_found, default_catcher]) .register("/", catchers![general_not_found, default_catcher])
.register("/hello", catchers![hello_not_found])
.register("/hello/Sergio", catchers![sergio_error])
} }
#[rocket::main] #[rocket::main]

View File

@ -14,11 +14,11 @@ fn test_hello() {
} }
#[test] #[test]
fn forced_error_and_default_catcher() { fn forced_error() {
let client = Client::tracked(super::rocket()).unwrap(); let client = Client::tracked(super::rocket()).unwrap();
let request = client.get("/404"); let request = client.get("/404");
let expected = super::not_found(request.inner()); let expected = super::general_not_found();
let response = request.dispatch(); let response = request.dispatch();
assert_eq!(response.status(), Status::NotFound); assert_eq!(response.status(), Status::NotFound);
assert_eq!(response.into_string().unwrap(), expected.0); assert_eq!(response.into_string().unwrap(), expected.0);
@ -46,11 +46,24 @@ fn forced_error_and_default_catcher() {
fn test_hello_invalid_age() { fn test_hello_invalid_age() {
let client = Client::tracked(super::rocket()).unwrap(); let client = Client::tracked(super::rocket()).unwrap();
for &(name, age) in &[("Ford", -129), ("Trillian", 128)] { for path in &["Ford/-129", "Trillian/128", "foo/bar/baz"] {
let request = client.get(format!("/hello/{}/{}", name, age)); let request = client.get(format!("/hello/{}", path));
let expected = super::not_found(request.inner()); let expected = super::hello_not_found(request.inner());
let response = request.dispatch(); let response = request.dispatch();
assert_eq!(response.status(), Status::NotFound); assert_eq!(response.status(), Status::NotFound);
assert_eq!(response.into_string().unwrap(), expected.0); assert_eq!(response.into_string().unwrap(), expected.0);
} }
} }
#[test]
fn test_hello_sergio() {
let client = Client::tracked(super::rocket()).unwrap();
for path in &["oops", "-129", "foo/bar", "/foo/bar/baz"] {
let request = client.get(format!("/hello/Sergio/{}", path));
let expected = super::sergio_error();
let response = request.dispatch();
assert_eq!(response.status(), Status::NotFound);
assert_eq!(response.into_string().unwrap(), expected);
}
}

View File

@ -69,7 +69,7 @@ fn wow_helper(
fn rocket() -> rocket::Rocket { fn rocket() -> rocket::Rocket {
rocket::ignite() rocket::ignite()
.mount("/", routes![index, hello, about]) .mount("/", routes![index, hello, about])
.register(catchers![not_found]) .register("/", catchers![not_found])
.attach(Template::custom(|engines| { .attach(Template::custom(|engines| {
engines.handlebars.register_helper("wow", Box::new(wow_helper)); engines.handlebars.register_helper("wow", Box::new(wow_helper));
})) }))

View File

@ -10,7 +10,7 @@ fn hello() -> &'static str {
fn rocket() -> rocket::Rocket { fn rocket() -> rocket::Rocket {
rocket::ignite() rocket::ignite()
.mount("/", rocket::routes![hello]) .mount("/", rocket::routes![hello])
.register(rocket::catchers![not_found]) .register("/", rocket::catchers![not_found])
} }
#[rocket::catch(404)] #[rocket::catch(404)]

View File

@ -75,6 +75,6 @@ fn not_found() -> JsonValue {
fn rocket() -> _ { fn rocket() -> _ {
rocket::ignite() rocket::ignite()
.mount("/message", routes![new, update, get, echo]) .mount("/message", routes![new, update, get, echo])
.register(catchers![not_found]) .register("/", catchers![not_found])
.manage(Mutex::new(HashMap::<Id, String>::new())) .manage(Mutex::new(HashMap::<Id, String>::new()))
} }

View File

@ -110,5 +110,5 @@ fn rocket() -> rocket::Rocket {
.mount("/hello", vec![name.clone()]) .mount("/hello", vec![name.clone()])
.mount("/hi", vec![name]) .mount("/hi", vec![name])
.mount("/custom", CustomHandler::new("some data here")) .mount("/custom", CustomHandler::new("some data here"))
.register(vec![not_found_catcher]) .register("/", vec![not_found_catcher])
} }

View File

@ -37,5 +37,5 @@ fn rocket() -> rocket::Rocket {
rocket::ignite() rocket::ignite()
.mount("/", routes![index, get]) .mount("/", routes![index, get])
.attach(Template::fairing()) .attach(Template::fairing())
.register(catchers![not_found]) .register("/", catchers![not_found])
} }

View File

@ -1716,8 +1716,8 @@ Application processing is fallible. Errors arise from the following sources:
* A routing failure. * A routing failure.
If any of these occur, Rocket returns an error to the client. To generate the If any of these occur, Rocket returns an error to the client. To generate the
error, Rocket invokes the _catcher_ corresponding to the error's status code. error, Rocket invokes the _catcher_ corresponding to the error's status code and
Catchers are similar to routes except in that: scope. Catchers are similar to routes except in that:
1. Catchers are only invoked on error conditions. 1. Catchers are only invoked on error conditions.
2. Catchers are declared with the `catch` attribute. 2. Catchers are declared with the `catch` attribute.
@ -1725,6 +1725,7 @@ Catchers are similar to routes except in that:
4. Any modifications to cookies are cleared before a catcher is invoked. 4. Any modifications to cookies are cleared before a catcher is invoked.
5. Error catchers cannot invoke guards. 5. Error catchers cannot invoke guards.
6. Error catchers should not fail to produce a response. 6. Error catchers should not fail to produce a response.
7. Catchers are scoped to a path prefix.
To declare a catcher for a given status code, use the [`catch`] attribute, which To declare a catcher for a given status code, use the [`catch`] attribute, which
takes a single integer corresponding to the HTTP status code to catch. For takes a single integer corresponding to the HTTP status code to catch. For
@ -1770,36 +1771,96 @@ looks like:
# #[catch(404)] fn not_found(req: &Request) { /* .. */ } # #[catch(404)] fn not_found(req: &Request) { /* .. */ }
fn main() { fn main() {
rocket::ignite().register(catchers![not_found]); rocket::ignite().register("/", catchers![not_found]);
} }
``` ```
### Default Catchers ### Scoping
If no catcher for a given status code has been registered, Rocket calls the The first argument to `register()` is a path to scope the catcher under called
_default_ catcher. Rocket provides a default catcher for all applications the catcher's _base_. A catcher's base determines which requests it will handle
automatically, so providing one is usually unnecessary. Rocket's built-in errors for. Specifically, a catcher's base must be a prefix of the erroring
default catcher can handle all errors. It produces HTML or JSON, depending on request for it to be invoked. When multiple catchers can be invoked, the catcher
the value of the `Accept` header. As such, a default catcher, or catchers in with the longest base takes precedence.
general, only need to be registered if an error needs to be handled in a custom
fashion.
Declaring a default catcher is done with `#[catch(default)]`: As an example, consider the following application:
```rust
# #[macro_use] extern crate rocket;
#[catch(404)]
fn general_not_found() -> &'static str {
"General 404"
}
#[catch(404)]
fn foo_not_found() -> &'static str {
"Foo 404"
}
#[launch]
fn rocket() -> _ {
rocket::ignite()
.register("/", catchers![general_not_found])
.register("/foo", catchers![foo_not_found])
}
# let client = rocket::local::blocking::Client::debug(rocket()).unwrap();
#
# let response = client.get("/").dispatch();
# assert_eq!(response.into_string().unwrap(), "General 404");
#
# let response = client.get("/bar").dispatch();
# assert_eq!(response.into_string().unwrap(), "General 404");
#
# let response = client.get("/bar/baz").dispatch();
# assert_eq!(response.into_string().unwrap(), "General 404");
#
# let response = client.get("/foo").dispatch();
# assert_eq!(response.into_string().unwrap(), "Foo 404");
#
# let response = client.get("/foo/bar").dispatch();
# assert_eq!(response.into_string().unwrap(), "Foo 404");
```
Since there are no mounted routes, all requests will `404`. Any request whose
path begins with `/foo` (i.e, `GET /foo`, `GET /foo/bar`, etc) will be handled
by the `foo_not_found` catcher while all other requests will be handled by the
`general_not_found` catcher.
### Default Catchers
A _default_ catcher is a catcher that handles _all_ status codes. They are
invoked as a fallback if no status-specific catcher is registered for a given
error. Declaring a default catcher is done with `#[catch(default)]` and must
similarly be registered with [`register()`]:
```rust ```rust
# #[macro_use] extern crate rocket; # #[macro_use] extern crate rocket;
# fn main() {}
use rocket::Request; use rocket::Request;
use rocket::http::Status; use rocket::http::Status;
#[catch(default)] #[catch(default)]
fn default_catcher(status: Status, request: &Request) { /* .. */ } fn default_catcher(status: Status, request: &Request) { /* .. */ }
#[launch]
fn rocket() -> _ {
rocket::ignite().register("/", catchers![default_catcher])
}
``` ```
It must similarly be registered with [`register()`]. Catchers with longer bases are preferred, even when there is a status-specific
catcher. In other words, a default catcher with a longer matching base than a
status-specific catcher takes precedence.
The [error catcher example](@example/errors) illustrates their use in full, ### Built-In Catcher
Rocket provides a built-in default catcher. It produces HTML or JSON, depending
on the value of the `Accept` header. As such, custom catchers only need to be
registered for custom error handling.
The [error catcher example](@example/errors) illustrates catcher use in full,
while the [`Catcher`] API documentation provides further details. while the [`Catcher`] API documentation provides further details.
[`catch`]: @api/rocket/attr.catch.html [`catch`]: @api/rocket/attr.catch.html