From 2893ce754d6535e0a752586e60d7e292343016c0 Mon Sep 17 00:00:00 2001 From: Sergio Benitez Date: Thu, 25 Mar 2021 21:36:00 -0700 Subject: [PATCH] Introduce scoped catchers. Catchers can now be scoped to paths, with preference given to the longest-prefix, then the status code. This a breaking change for all applications that register catchers: * `Rocket::register()` takes a base path to scope catchers under. - The previous behavior is recovered with `::register("/", ...)`. * Catchers now fallibly, instead of silently, collide. * `ErrorKind::Collision` is now `ErrorKind::Collisions`. Related changes: * `Origin` implements `TryFrom`, `TryFrom<&str>`. * All URI variants implement `TryFrom`. * Added `Segments::prefix_of()`. * `Rocket::mount()` takes a `TryInto>` instead of `&str` for the base mount point. * Extended `errors` example with scoped catchers. * Added scoped sections to catchers guide. Internal changes: * Moved router code to `router/router.rs`. --- core/codegen/tests/route-ranking.rs | 2 +- core/codegen/tests/route-raw.rs | 2 +- core/http/src/uri/origin.rs | 19 +- core/http/src/uri/segments.rs | 10 + core/http/src/uri/uri.rs | 21 + core/lib/src/catcher.rs | 116 ++- core/lib/src/error.rs | 25 +- core/lib/src/rocket.rs | 81 ++- core/lib/src/router/collider.rs | 198 ++++-- core/lib/src/router/mod.rs | 513 +------------- core/lib/src/router/route.rs | 1 - core/lib/src/router/router.rs | 670 ++++++++++++++++++ core/lib/src/server.rs | 6 +- core/lib/tests/catcher-cookies-1213.rs | 2 +- core/lib/tests/panic-handling.rs | 39 +- .../tests/redirect_from_catcher-issue-113.rs | 2 +- examples/content_types/src/main.rs | 2 +- examples/errors/src/main.rs | 28 +- examples/errors/src/tests.rs | 23 +- examples/handlebars_templates/src/main.rs | 2 +- examples/hello_2018/src/main.rs | 2 +- examples/json/src/main.rs | 2 +- examples/manual_routes/src/main.rs | 2 +- examples/tera_templates/src/main.rs | 2 +- site/guide/4-requests.md | 91 ++- 25 files changed, 1166 insertions(+), 695 deletions(-) create mode 100644 core/lib/src/router/router.rs diff --git a/core/codegen/tests/route-ranking.rs b/core/codegen/tests/route-ranking.rs index 20b8ca43..585fb243 100644 --- a/core/codegen/tests/route-ranking.rs +++ b/core/codegen/tests/route-ranking.rs @@ -46,7 +46,7 @@ fn test_rank_collision() { let rocket = rocket::ignite().mount("/", routes![get0, get0b]); let client_result = Client::debug(rocket); match client_result.as_ref().map_err(|e| e.kind()) { - Err(ErrorKind::Collision(..)) => { /* o.k. */ }, + Err(ErrorKind::Collisions(..)) => { /* o.k. */ }, Ok(_) => panic!("client succeeded unexpectedly"), Err(e) => panic!("expected collision, got {}", e) } diff --git a/core/codegen/tests/route-raw.rs b/core/codegen/tests/route-raw.rs index d191973d..043afa3c 100644 --- a/core/codegen/tests/route-raw.rs +++ b/core/codegen/tests/route-raw.rs @@ -23,7 +23,7 @@ fn catch(r#raw: &rocket::Request) -> String { fn test_raw_ident() { let rocket = rocket::ignite() .mount("/", routes![get, swap]) - .register(catchers![catch]); + .register("/", catchers![catch]); let client = Client::debug(rocket).unwrap(); diff --git a/core/http/src/uri/origin.rs b/core/http/src/uri/origin.rs index c39cbd73..4c9273c9 100644 --- a/core/http/src/uri/origin.rs +++ b/core/http/src/uri/origin.rs @@ -1,5 +1,6 @@ -use std::fmt::{self, Display}; use std::borrow::Cow; +use std::convert::TryFrom; +use std::fmt::{self, Display}; use crate::ext::IntoOwned; use crate::parse::{Indexed, Extent, IndexedStr}; @@ -608,6 +609,22 @@ impl<'a> Origin<'a> { } } +impl TryFrom for Origin<'static> { + type Error = Error<'static>; + + fn try_from(value: String) -> Result { + Origin::parse_owned(value) + } +} + +impl<'a> TryFrom<&'a str> for Origin<'a> { + type Error = Error<'a>; + + fn try_from(value: &'a str) -> Result { + Origin::parse(value) + } +} + impl Display for Origin<'_> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "{}", self.path())?; diff --git a/core/http/src/uri/segments.rs b/core/http/src/uri/segments.rs index 86368262..cab8148b 100644 --- a/core/http/src/uri/segments.rs +++ b/core/http/src/uri/segments.rs @@ -75,6 +75,16 @@ impl<'o> Segments<'o> { .map(|i| i.from_source(Some(self.source.as_str()))) } + /// Returns `true` if `self` is a prefix of `other`. + #[inline] + pub fn prefix_of<'b>(self, other: Segments<'b>) -> bool { + if self.len() > other.len() { + return false; + } + + self.zip(other).all(|(a, b)| a == b) + } + /// Creates a `PathBuf` from `self`. The returned `PathBuf` is /// percent-decoded. If a segment is equal to "..", the previous segment (if /// any) is skipped. diff --git a/core/http/src/uri/uri.rs b/core/http/src/uri/uri.rs index a4bc2b9d..421d4d2b 100644 --- a/core/http/src/uri/uri.rs +++ b/core/http/src/uri/uri.rs @@ -282,6 +282,16 @@ impl Display for Uri<'_> { } } +/// The error type returned when a URI conversion fails. +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +pub struct TryFromUriError(()); + +impl fmt::Display for TryFromUriError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + "invalid conversion from general to specific URI variant".fmt(f) + } +} + macro_rules! impl_uri_from { ($type:ident) => ( impl<'a> From<$type<'a>> for Uri<'a> { @@ -289,6 +299,17 @@ macro_rules! impl_uri_from { Uri::$type(other) } } + + impl<'a> TryFrom> for $type<'a> { + type Error = TryFromUriError; + + fn try_from(uri: Uri<'a>) -> Result { + match uri { + Uri::$type(inner) => Ok(inner), + _ => Err(TryFromUriError(())) + } + } + } ) } diff --git a/core/lib/src/catcher.rs b/core/lib/src/catcher.rs index 070e7545..0d29d619 100644 --- a/core/lib/src/catcher.rs +++ b/core/lib/src/catcher.rs @@ -1,5 +1,5 @@ //! Types and traits for error catchers, error handlers, and their return -//! values. +//! types. use std::fmt; use std::io::Cursor; @@ -7,7 +7,7 @@ use std::io::Cursor; use crate::response::Response; use crate::codegen::StaticCatcherInfo; use crate::request::Request; -use crate::http::ContentType; +use crate::http::{Status, ContentType, uri}; use futures::future::BoxFuture; use yansi::Paint; @@ -19,6 +19,12 @@ pub type Result<'r> = std::result::Result, crate::http::Status>; /// Type alias for the unwieldy [`ErrorHandler::handle()`] return type. pub type ErrorHandlerFuture<'r> = BoxFuture<'r, Result<'r>>; +// A handler to use when one is needed temporarily. Don't use outside of Rocket! +#[cfg(test)] +pub(crate) fn dummy<'r>(_: Status, _: &'r Request<'_>) -> ErrorHandlerFuture<'r> { + Box::pin(async move { Ok(Response::new()) }) +} + /// An error catching route. /// /// # Overview @@ -77,12 +83,12 @@ pub type ErrorHandlerFuture<'r> = BoxFuture<'r, Result<'r>>; /// /// #[catch(default)] /// fn default(status: Status, req: &Request) -> String { -/// format!("{} - {} ({})", status.code, status.reason, req.uri()) +/// format!("{} ({})", status, req.uri()) /// } /// /// #[launch] /// fn rocket() -> rocket::Rocket { -/// rocket::ignite().register(catchers![internal_error, not_found, default]) +/// rocket::ignite().register("/", catchers![internal_error, not_found, default]) /// } /// ``` /// @@ -104,7 +110,10 @@ pub struct Catcher { /// The name of this catcher, if one was given. pub name: Option>, - /// The HTTP status code to match against if this route is not `default`. + /// The mount point. + pub base: uri::Origin<'static>, + + /// The HTTP status to match against if this route is not `default`. pub code: Option, /// The catcher's associated error handler. @@ -112,8 +121,8 @@ pub struct Catcher { } impl Catcher { - /// Creates a catcher for the given status code, or a default catcher if - /// `code` is `None`, using the given error handler. This should only be + /// Creates a catcher for the given `status`, or a default catcher if + /// `status` is `None`, using the given error handler. This should only be /// used when routing manually. /// /// # Examples @@ -121,11 +130,11 @@ impl Catcher { /// ```rust /// use rocket::request::Request; /// use rocket::catcher::{Catcher, ErrorHandlerFuture}; - /// use rocket::response::{Result, Responder, status::Custom}; + /// use rocket::response::Responder; /// use rocket::http::Status; /// /// fn handle_404<'r>(status: Status, req: &'r Request<'_>) -> ErrorHandlerFuture<'r> { - /// let res = Custom(status, format!("404: {}", req.uri())); + /// let res = (status, format!("404: {}", req.uri())); /// Box::pin(async move { res.respond_to(req) }) /// } /// @@ -134,7 +143,7 @@ impl Catcher { /// } /// /// fn handle_default<'r>(status: Status, req: &'r Request<'_>) -> ErrorHandlerFuture<'r> { - /// let res = Custom(status, format!("{}: {}", status, req.uri())); + /// let res = (status, format!("{}: {}", status, req.uri())); /// Box::pin(async move { res.respond_to(req) }) /// } /// @@ -142,22 +151,82 @@ impl Catcher { /// let internal_server_error_catcher = Catcher::new(500, handle_500); /// let default_error_catcher = Catcher::new(None, handle_default); /// ``` + /// + /// # Panics + /// + /// Panics if `code` is not in the HTTP status code error range `[400, + /// 600)`. #[inline(always)] - pub fn new(code: C, handler: H) -> Catcher - where C: Into>, H: ErrorHandler + pub fn new(code: S, handler: H) -> Catcher + where S: Into>, H: ErrorHandler { - Catcher { name: None, code: code.into(), handler: Box::new(handler) } + let code = code.into(); + if let Some(code) = code { + assert!(code >= 400 && code < 600); + } + + Catcher { + name: None, + base: uri::Origin::new("/", None::<&str>), + handler: Box::new(handler), + code, + } + } + + /// Maps the `base` of this catcher using `mapper`, returning a new + /// `Catcher` with the returned base. + /// + /// `mapper` is called with the current base. The returned `String` is used + /// as the new base if it is a valid URI. If the returned base URI contains + /// a query, it is ignored. Returns an error if the base produced by + /// `mapper` is not a valid origin URI. + /// + /// # Example + /// + /// ```rust + /// use rocket::request::Request; + /// use rocket::catcher::{Catcher, ErrorHandlerFuture}; + /// use rocket::response::Responder; + /// use rocket::http::Status; + /// + /// fn handle_404<'r>(status: Status, req: &'r Request<'_>) -> ErrorHandlerFuture<'r> { + /// let res = (status, format!("404: {}", req.uri())); + /// Box::pin(async move { res.respond_to(req) }) + /// } + /// + /// let catcher = Catcher::new(404, handle_404); + /// assert_eq!(catcher.base.path(), "/"); + /// + /// let catcher = catcher.map_base(|_| format!("/bar")).unwrap(); + /// assert_eq!(catcher.base.path(), "/bar"); + /// + /// let catcher = catcher.map_base(|base| format!("/foo{}", base)).unwrap(); + /// assert_eq!(catcher.base.path(), "/foo/bar"); + /// + /// let catcher = catcher.map_base(|base| format!("/foo ? {}", base)); + /// assert!(catcher.is_err()); + /// ``` + pub fn map_base<'a, F>( + mut self, + mapper: F + ) -> std::result::Result> + where F: FnOnce(uri::Origin<'a>) -> String + { + self.base = uri::Origin::parse_owned(mapper(self.base))?.into_normalized(); + self.base.clear_query(); + Ok(self) } } impl Default for Catcher { fn default() -> Self { - fn async_default<'r>(status: Status, request: &'r Request<'_>) -> ErrorHandlerFuture<'r> { - Box::pin(async move { Ok(default(status, request)) }) + fn handler<'r>(s: Status, req: &'r Request<'_>) -> ErrorHandlerFuture<'r> { + Box::pin(async move { Ok(default(s, req)) }) } - let name = Some("".into()); - Catcher { name, code: None, handler: Box::new(async_default) } + let mut catcher = Catcher::new(None, handler); + catcher.name = Some("".into()); + catcher } } @@ -226,9 +295,9 @@ impl Default for Catcher { /// fn rocket() -> rocket::Rocket { /// rocket::ignite() /// // to handle only `404` -/// .register(CustomHandler::catch(Status::NotFound, Kind::Simple)) +/// .register("/", CustomHandler::catch(Status::NotFound, Kind::Simple)) /// // or to register as the default -/// .register(CustomHandler::default(Kind::Simple)) +/// .register("/", CustomHandler::default(Kind::Simple)) /// } /// ``` /// @@ -239,7 +308,7 @@ impl Default for Catcher { /// trait serves no other purpose but to ensure that every `ErrorHandler` /// can be cloned, allowing `Catcher`s to be cloned. /// 2. `CustomHandler`'s methods return `Vec`, allowing for use -/// directly as the parameter to `rocket.register()`. +/// directly as the parameter to `rocket.register("/", )`. /// 3. Unlike static-function-based handlers, this custom handler can make use /// of internal state. #[crate::async_trait] @@ -288,6 +357,10 @@ impl fmt::Display for Catcher { write!(f, "{}{}{} ", Paint::cyan("("), Paint::white(n), Paint::cyan(")"))?; } + if self.base.path() != "/" { + write!(f, "{} ", Paint::green(self.base.path()))?; + } + match self.code { Some(code) => write!(f, "{}", Paint::blue(code)), None => write!(f, "{}", Paint::blue("default")) @@ -298,6 +371,8 @@ impl fmt::Display for Catcher { impl fmt::Debug for Catcher { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("Catcher") + .field("name", &self.name) + .field("base", &self.base) .field("code", &self.code) .finish() } @@ -359,7 +434,6 @@ r#"{{ macro_rules! default_catcher_fn { ($($code:expr, $reason:expr, $description:expr),+) => ( use std::borrow::Cow; - use crate::http::Status; pub(crate) fn default<'r>(status: Status, req: &'r Request<'_>) -> Response<'r> { let preferred = req.accept().map(|a| a.preferred()); diff --git a/core/lib/src/error.rs b/core/lib/src/error.rs index adcac955..09a5be22 100644 --- a/core/lib/src/error.rs +++ b/core/lib/src/error.rs @@ -6,8 +6,6 @@ use std::sync::atomic::{Ordering, AtomicBool}; use yansi::Paint; use figment::Profile; -use crate::router::Route; - /// An error that occurs during launch. /// /// An `Error` is returned by [`launch()`](crate::Rocket::launch()) when @@ -80,7 +78,7 @@ pub enum ErrorKind { /// An I/O error occurred in the runtime. Runtime(Box), /// Route collisions were detected. - Collision(Vec<(Route, Route)>), + Collisions(crate::router::Collisions), /// Launch fairing(s) failed. FailedFairings(Vec), /// The configuration profile is not debug but not secret key is configured. @@ -140,7 +138,7 @@ impl fmt::Display for ErrorKind { match self { ErrorKind::Bind(e) => write!(f, "binding failed: {}", e), ErrorKind::Io(e) => write!(f, "I/O error: {}", e), - ErrorKind::Collision(_) => "route collisions detected".fmt(f), + ErrorKind::Collisions(_) => "collisions detected".fmt(f), ErrorKind::FailedFairings(_) => "a launch fairing failed".fmt(f), ErrorKind::Runtime(e) => write!(f, "runtime error: {}", e), ErrorKind::InsecureSecretKey(_) => "insecure secret key config".fmt(f), @@ -181,14 +179,21 @@ impl Drop for Error { info_!("{}", e); panic!("aborting due to i/o error"); } - ErrorKind::Collision(ref collisions) => { - error!("Rocket failed to launch due to the following routing collisions:"); - for &(ref a, ref b) in collisions { - info_!("{} {} {}", a, Paint::red("collides with").italic(), b) + ErrorKind::Collisions(ref collisions) => { + fn log_collisions(kind: &str, collisions: &[(T, T)]) { + if collisions.is_empty() { return } + + error!("Rocket failed to launch due to the following {} collisions:", kind); + for &(ref a, ref b) in collisions { + info_!("{} {} {}", a, Paint::red("collides with").italic(), b) + } } - info_!("Note: Collisions can usually be resolved by ranking routes."); - panic!("route collisions detected"); + log_collisions("route", &collisions.routes); + log_collisions("catcher", &collisions.catchers); + + info_!("Note: Route collisions can usually be resolved by ranking routes."); + panic!("routing collisions detected"); } ErrorKind::FailedFairings(ref failures) => { error!("Rocket failed to launch due to failing fairings:"); diff --git a/core/lib/src/rocket.rs b/core/lib/src/rocket.rs index 6c8e0148..e4444acf 100644 --- a/core/lib/src/rocket.rs +++ b/core/lib/src/rocket.rs @@ -1,4 +1,5 @@ -use std::collections::HashMap; +use std::fmt::Display; +use std::convert::TryInto; use yansi::Paint; use state::Container; @@ -13,7 +14,7 @@ use crate::router::{Router, Route}; use crate::fairing::{Fairing, Fairings}; use crate::logger::PaintExt; use crate::shutdown::Shutdown; -use crate::http::uri::Origin; +use crate::http::{uri::Origin, ext::IntoOwned}; use crate::error::{Error, ErrorKind}; /// The main `Rocket` type: used to mount routes and catchers and launch the @@ -24,8 +25,6 @@ pub struct Rocket { pub(crate) figment: Figment, pub(crate) managed_state: Container![Send + Sync], pub(crate) router: Router, - pub(crate) default_catcher: Option, - pub(crate) catchers: HashMap, pub(crate) fairings: Fairings, pub(crate) shutdown_receiver: Option>, pub(crate) shutdown_handle: Shutdown, @@ -95,8 +94,6 @@ impl Rocket { config, figment, managed_state, shutdown_handle: Shutdown(shutdown_sender), router: Router::new(), - default_catcher: None, - catchers: HashMap::new(), fairings: Fairings::new(), shutdown_receiver: Some(shutdown_receiver), } @@ -203,10 +200,15 @@ impl Rocket { /// # .launch().await; /// # }; /// ``` - pub fn mount>>(mut self, base: &str, routes: R) -> Self { - let base_uri = Origin::parse_owned(base.to_string()) + pub fn mount<'a, B, R>(mut self, base: B, routes: R) -> Self + where B: TryInto> + Clone + Display, + B::Error: Display, + R: Into> + { + let base_uri = base.clone().try_into() + .map(|origin| origin.into_owned()) .unwrap_or_else(|e| { - error!("Invalid mount point URI: {}.", Paint::white(base)); + error!("Invalid route base: {}.", Paint::white(&base)); panic!("Error: {}", e); }); @@ -215,11 +217,11 @@ impl Rocket { panic!("Invalid mount point."); } - info!("{}{} {}{}", + info!("{}{} {} {}", Paint::emoji("🛰 "), Paint::magenta("Mounting"), Paint::blue(&base_uri), - Paint::magenta(":")); + Paint::magenta("routes:")); for route in routes.into() { let mounted_route = route.clone() @@ -231,13 +233,18 @@ impl Rocket { }); info_!("{}", mounted_route); - self.router.add(mounted_route); + self.router.add_route(mounted_route); } self } - /// Registers all of the catchers in the supplied vector. + /// Registers all of the catchers in the supplied vector, scoped to `base`. + /// + /// # Panics + /// + /// Panics if `base` is not a valid static path: a valid origin URI without + /// dynamic parameters. /// /// # Examples /// @@ -257,23 +264,31 @@ impl Rocket { /// /// #[launch] /// fn rocket() -> rocket::Rocket { - /// rocket::ignite().register(catchers![internal_error, not_found]) + /// rocket::ignite().register("/", catchers![internal_error, not_found]) /// } /// ``` - pub fn register(mut self, catchers: Vec) -> Self { - info!("{}{}", Paint::emoji("👾 "), Paint::magenta("Catchers:")); + pub fn register<'a, B, C>(mut self, base: B, catchers: C) -> Self + where B: TryInto> + Clone + Display, + B::Error: Display, + C: Into> + { + info!("{}{} {} {}", + Paint::emoji("👾 "), + Paint::magenta("Registering"), + Paint::blue(&base), + Paint::magenta("catchers:")); - for catcher in catchers { - info_!("{}", catcher); + for catcher in catchers.into() { + let mounted_catcher = catcher.clone() + .map_base(|old| format!("{}{}", base, old)) + .unwrap_or_else(|e| { + error_!("Catcher `{}` has a malformed URI.", catcher); + error_!("{}", e); + panic!("Invalid catcher URI."); + }); - let existing = match catcher.code { - Some(code) => self.catchers.insert(code, catcher), - None => self.default_catcher.replace(catcher) - }; - - if let Some(existing) = existing { - warn_!("Replacing existing '{}' catcher.", existing); - } + info_!("{}", mounted_catcher); + self.router.add_catcher(mounted_catcher); } self @@ -444,7 +459,7 @@ impl Rocket { /// } /// ``` #[inline(always)] - pub fn routes(&self) -> impl Iterator + '_ { + pub fn routes(&self) -> impl Iterator { self.router.routes() } @@ -464,7 +479,7 @@ impl Rocket { /// /// fn main() { /// let mut rocket = rocket::ignite() - /// .register(catchers![not_found, just_500, some_default]); + /// .register("/", catchers![not_found, just_500, some_default]); /// /// let mut codes: Vec<_> = rocket.catchers().map(|c| c.code).collect(); /// codes.sort(); @@ -473,8 +488,8 @@ impl Rocket { /// } /// ``` #[inline(always)] - pub fn catchers(&self) -> impl Iterator + '_ { - self.catchers.values().chain(self.default_catcher.as_ref()) + pub fn catchers(&self) -> impl Iterator { + self.router.catchers() } /// Returns `Some` of the managed state value for the type `T` if it is @@ -525,10 +540,8 @@ impl Rocket { /// * there were no fairing failures /// * a secret key, if needed, is securely configured pub(crate) async fn prelaunch_check(&mut self) -> Result<(), Error> { - let collisions: Vec<_> = self.router.collisions().collect(); - if !collisions.is_empty() { - let owned = collisions.into_iter().map(|(a, b)| (a.clone(), b.clone())); - return Err(Error::new(ErrorKind::Collision(owned.collect()))); + if let Err(collisions) = self.router.finalize() { + return Err(Error::new(ErrorKind::Collisions(collisions))); } if let Some(failures) = self.fairings.failures() { diff --git a/core/lib/src/router/collider.rs b/core/lib/src/router/collider.rs index 6cb9d2df..05f80a21 100644 --- a/core/lib/src/router/collider.rs +++ b/core/lib/src/router/collider.rs @@ -1,46 +1,23 @@ use super::{Route, uri::Color}; +use crate::catcher::Catcher; -use crate::http::MediaType; +use crate::http::{MediaType, Status}; use crate::request::Request; -impl Route { - /// Determines if two routes can match against some request. That is, if two - /// routes `collide`, there exists a request that can match against both - /// routes. - /// - /// This implementation is used at initialization to check if two user - /// routes collide before launching. Format collisions works like this: - /// - /// * If route specifies a format, it only gets requests for that format. - /// * If route doesn't specify a format, it gets requests for any format. - /// - /// Because query parsing is lenient, and dynamic query parameters can be - /// missing, queries do not impact whether two routes collide. - pub(crate) fn collides_with(&self, other: &Route) -> bool { - self.method == other.method - && self.rank == other.rank - && paths_collide(self, other) - && formats_collide(self, other) - } +pub trait Collide { + fn collides_with(&self, other: &T) -> bool; +} - /// Determines if this route matches against the given request. - /// - /// This means that: - /// - /// * The route's method matches that of the incoming request. - /// * The route's format (if any) matches that of the incoming request. - /// - If route specifies format, it only gets requests for that format. - /// - If route doesn't specify format, it gets requests for any format. - /// * All static components in the route's path match the corresponding - /// components in the same position in the incoming request. - /// * All static components in the route's query string are also in the - /// request query string, though in any position. If there is no query - /// in the route, requests with/without queries match. - pub(crate) fn matches(&self, req: &Request<'_>) -> bool { - self.method == req.method() - && paths_match(self, req) - && queries_match(self, req) - && formats_match(self, req) +impl<'a, 'b, T: Collide> Collide<&T> for &T { + fn collides_with(&self, other: &&T) -> bool { + T::collides_with(*self, *other) + } +} + +impl Collide for MediaType { + fn collides_with(&self, other: &Self) -> bool { + let collide = |a, b| a == "*" || b == "*" || a == b; + collide(self.top(), other.top()) && collide(self.sub(), other.sub()) } } @@ -66,6 +43,68 @@ fn paths_collide(route: &Route, other: &Route) -> bool { || a_segments.len() == b_segments.len() } +fn formats_collide(route: &Route, other: &Route) -> bool { + // When matching against the `Accept` header, the client can always + // provide a media type that will cause a collision through + // non-specificity. + if !route.method.supports_payload() { + return true; + } + + // When matching against the `Content-Type` header, we'll only + // consider requests as having a `Content-Type` if they're fully + // specified. If a route doesn't have a `format`, it accepts all + // `Content-Type`s. If a request doesn't have a format, it only + // matches routes without a format. + match (route.format.as_ref(), other.format.as_ref()) { + (Some(a), Some(b)) => a.collides_with(b), + _ => true + } +} + +impl Collide for Route { + /// Determines if two routes can match against some request. That is, if two + /// routes `collide`, there exists a request that can match against both + /// routes. + /// + /// This implementation is used at initialization to check if two user + /// routes collide before launching. Format collisions works like this: + /// + /// * If route specifies a format, it only gets requests for that format. + /// * If route doesn't specify a format, it gets requests for any format. + /// + /// Because query parsing is lenient, and dynamic query parameters can be + /// missing, queries do not impact whether two routes collide. + fn collides_with(&self, other: &Route) -> bool { + self.method == other.method + && self.rank == other.rank + && paths_collide(self, other) + && formats_collide(self, other) + } +} + +impl Route { + /// Determines if this route matches against the given request. + /// + /// This means that: + /// + /// * The route's method matches that of the incoming request. + /// * The route's format (if any) matches that of the incoming request. + /// - If route specifies format, it only gets requests for that format. + /// - If route doesn't specify format, it gets requests for any format. + /// * All static components in the route's path match the corresponding + /// components in the same position in the incoming request. + /// * All static components in the route's query string are also in the + /// request query string, though in any position. If there is no query + /// in the route, requests with/without queries match. + pub(crate) fn matches(&self, req: &Request<'_>) -> bool { + self.method == req.method() + && paths_match(self, req) + && queries_match(self, req) + && formats_match(self, req) + } +} + fn paths_match(route: &Route, req: &Request<'_>) -> bool { let route_segments = &route.uri.metadata.path_segs; let req_segments = req.uri().path_segments(); @@ -90,7 +129,7 @@ fn paths_match(route: &Route, req: &Request<'_>) -> bool { return true; } - if !route_seg.dynamic && route_seg.value != req_seg { + if !(route_seg.dynamic || route_seg.value == req_seg) { return false; } } @@ -116,33 +155,16 @@ fn queries_match(route: &Route, req: &Request<'_>) -> bool { true } -fn formats_collide(route: &Route, other: &Route) -> bool { - // When matching against the `Accept` header, the client can always provide - // a media type that will cause a collision through non-specificity. - if !route.method.supports_payload() { - return true; - } - - // When matching against the `Content-Type` header, we'll only consider - // requests as having a `Content-Type` if they're fully specified. If a - // route doesn't have a `format`, it accepts all `Content-Type`s. If a - // request doesn't have a format, it only matches routes without a format. - match (route.format.as_ref(), other.format.as_ref()) { - (Some(a), Some(b)) => media_types_collide(a, b), - _ => true - } -} - fn formats_match(route: &Route, request: &Request<'_>) -> bool { if !route.method.supports_payload() { route.format.as_ref() .and_then(|a| request.format().map(|b| (a, b))) - .map(|(a, b)| media_types_collide(a, b)) + .map(|(a, b)| a.collides_with(b)) .unwrap_or(true) } else { match route.format.as_ref() { Some(a) => match request.format() { - Some(b) if b.specificity() == 2 => media_types_collide(a, b), + Some(b) if b.specificity() == 2 => a.collides_with(b), _ => false } None => true @@ -150,9 +172,30 @@ fn formats_match(route: &Route, request: &Request<'_>) -> bool { } } -fn media_types_collide(first: &MediaType, other: &MediaType) -> bool { - let collide = |a, b| a == "*" || b == "*" || a == b; - collide(first.top(), other.top()) && collide(first.sub(), other.sub()) + +impl Collide for Catcher { + /// Determines if two catchers are in conflict: there exists a request for + /// which there exist no rule to determine _which_ of the two catchers to + /// use. This means that the catchers: + /// + /// * Have the same base. + /// * Have the same status code or are both defaults. + fn collides_with(&self, other: &Self) -> bool { + self.code == other.code + && self.base.path_segments().eq(other.base.path_segments()) + } +} + +impl Catcher { + /// Determines if this catcher is responsible for handling the error with + /// `status` that occurred during request `req`. A catcher matches if: + /// + /// * It is a default catcher _or_ has a code of `status`. + /// * Its base is a prefix of the normalized/decoded `req.path()`. + pub(crate) fn matches(&self, status: Status, req: &Request<'_>) -> bool { + self.code.map_or(true, |code| code == status.code) + && self.base.path_segments().prefix_of(req.uri().path_segments()) + } } #[cfg(test)] @@ -335,7 +378,7 @@ mod tests { fn mt_mt_collide(mt1: &str, mt2: &str) -> bool { let mt_a = MediaType::from_str(mt1).expect(mt1); let mt_b = MediaType::from_str(mt2).expect(mt2); - media_types_collide(&mt_a, &mt_b) + mt_a.collides_with(&mt_b) } #[test] @@ -525,4 +568,35 @@ mod tests { assert!(!req_route_path_match("/a/b", "/a/b?foo&")); assert!(!req_route_path_match("/a/b", "/a/b?&b&")); } + + + fn catchers_collide(a: A, ap: &str, b: B, bp: &str) -> bool + where A: Into>, B: Into> + { + let a = Catcher::new(a, crate::catcher::dummy).map_base(|_| ap.into()).unwrap(); + let b = Catcher::new(b, crate::catcher::dummy).map_base(|_| bp.into()).unwrap(); + a.collides_with(&b) + } + + #[test] + fn catcher_collisions() { + for path in &["/a", "/foo", "/a/b/c", "/a/b/c/d/e"] { + assert!(catchers_collide(404, path, 404, path)); + assert!(catchers_collide(500, path, 500, path)); + assert!(catchers_collide(None, path, None, path)); + } + } + + #[test] + fn catcher_non_collisions() { + assert!(!catchers_collide(404, "/foo", 405, "/foo")); + assert!(!catchers_collide(404, "/", None, "/foo")); + assert!(!catchers_collide(404, "/", None, "/")); + assert!(!catchers_collide(404, "/a/b", None, "/a/b")); + assert!(!catchers_collide(404, "/a/b", 404, "/a/b/c")); + + assert!(!catchers_collide(None, "/a/b", None, "/a/b/c")); + assert!(!catchers_collide(None, "/b", None, "/a/b/c")); + assert!(!catchers_collide(None, "/", None, "/a/b/c")); + } } diff --git a/core/lib/src/router/mod.rs b/core/lib/src/router/mod.rs index faa0945e..77b672b3 100644 --- a/core/lib/src/router/mod.rs +++ b/core/lib/src/router/mod.rs @@ -1,514 +1,13 @@ //! Routing types: [`Route`] and [`RouteUri`]. -mod collider; mod route; mod segment; mod uri; +mod router; +mod collider; -use std::collections::HashMap; +pub(crate) use router::*; -use crate::request::Request; -use crate::http::Method; - -pub use self::route::Route; -pub use self::uri::RouteUri; - -// type Selector = (Method, usize); -type Selector = Method; - -#[derive(Debug, Default)] -pub(crate) struct Router { - routes: HashMap>, -} - -impl Router { - pub fn new() -> Router { - Router { routes: HashMap::new() } - } - - pub fn add(&mut self, route: Route) { - let selector = route.method; - let entries = self.routes.entry(selector).or_insert_with(|| vec![]); - let i = entries.binary_search_by_key(&route.rank, |r| r.rank) - .unwrap_or_else(|i| i); - - entries.insert(i, route); - } - - pub fn route<'r, 'a: 'r>(&'a self, req: &'r Request<'r>) -> impl Iterator + 'r { - // Note that routes are presorted by rank on each `add`. - self.routes.get(&req.method()) - .into_iter() - .flat_map(move |routes| routes.iter().filter(move |r| r.matches(req))) - } - - pub(crate) fn collisions(&self) -> impl Iterator { - let all_routes = self.routes.values().flat_map(|v| v.iter()); - all_routes.clone().enumerate() - .flat_map(move |(i, a)| { - all_routes.clone() - .skip(i + 1) - .filter(move |b| b.collides_with(a)) - .map(move |b| (a, b)) - }) - } - - #[inline] - pub fn routes(&self) -> impl Iterator { - self.routes.values().flat_map(|v| v.iter()) - } - - // This is slow. Don't expose this publicly; only for tests. - #[cfg(test)] - fn has_collisions(&self) -> bool { - self.collisions().next().is_some() - } -} - -#[cfg(test)] -mod test { - use super::{Router, Route}; - - use crate::rocket::Rocket; - use crate::config::Config; - use crate::http::{Method, Method::*}; - use crate::http::uri::Origin; - use crate::request::Request; - use crate::handler::dummy; - - fn router_with_routes(routes: &[&'static str]) -> Router { - let mut router = Router::new(); - for route in routes { - let route = dbg!(Route::new(Get, route, dummy)); - router.add(route); - } - - router - } - - fn router_with_ranked_routes(routes: &[(isize, &'static str)]) -> Router { - let mut router = Router::new(); - for &(rank, route) in routes { - let route = Route::ranked(rank, Get, route, dummy); - router.add(route); - } - - router - } - - fn router_with_rankless_routes(routes: &[&'static str]) -> Router { - let mut router = Router::new(); - for route in routes { - let route = Route::ranked(0, Get, route, dummy); - router.add(route); - } - - router - } - - fn rankless_route_collisions(routes: &[&'static str]) -> bool { - let router = router_with_rankless_routes(routes); - router.has_collisions() - } - - fn default_rank_route_collisions(routes: &[&'static str]) -> bool { - let router = router_with_routes(routes); - router.has_collisions() - } - - #[test] - fn test_rankless_collisions() { - assert!(rankless_route_collisions(&["/hello", "/hello"])); - assert!(rankless_route_collisions(&["/", "/hello"])); - assert!(rankless_route_collisions(&["/", "/"])); - assert!(rankless_route_collisions(&["/hello/bob", "/hello/"])); - assert!(rankless_route_collisions(&["/a/b//d", "///c/d"])); - - assert!(rankless_route_collisions(&["/a/b", "/"])); - assert!(rankless_route_collisions(&["/a/b/c", "/a/"])); - assert!(rankless_route_collisions(&["//b", "/a/"])); - assert!(rankless_route_collisions(&["/a/", "/a/"])); - assert!(rankless_route_collisions(&["/a/b/", "/a/"])); - assert!(rankless_route_collisions(&["/", "/a/"])); - assert!(rankless_route_collisions(&["/a/", "/a/"])); - assert!(rankless_route_collisions(&["/a/b/", "/a/"])); - assert!(rankless_route_collisions(&["/a/b/c/d", "/a/"])); - assert!(rankless_route_collisions(&["/", "/"])); - assert!(rankless_route_collisions(&["/a/<_>", "/a/"])); - assert!(rankless_route_collisions(&["/a/<_>", "/a/<_..>"])); - assert!(rankless_route_collisions(&["/<_>", "/a/<_..>"])); - assert!(rankless_route_collisions(&["/foo", "/foo/<_..>"])); - assert!(rankless_route_collisions(&["/foo/bar/baz", "/foo/<_..>"])); - assert!(rankless_route_collisions(&["/a/d/", "/a/d"])); - assert!(rankless_route_collisions(&["/a/<_..>", "/<_>"])); - assert!(rankless_route_collisions(&["/a/<_..>", "/a"])); - assert!(rankless_route_collisions(&["/", "/a/"])); - - assert!(rankless_route_collisions(&["/<_>", "/<_>"])); - assert!(rankless_route_collisions(&["/a/<_>", "/a/b"])); - assert!(rankless_route_collisions(&["/a/<_>", "/a/"])); - assert!(rankless_route_collisions(&["/<_..>", "/a/b"])); - assert!(rankless_route_collisions(&["/<_..>", "/<_>"])); - assert!(rankless_route_collisions(&["/<_>/b", "/a/b"])); - assert!(rankless_route_collisions(&["/", "/"])); - } - - #[test] - fn test_collisions_normalize() { - assert!(rankless_route_collisions(&["/hello/", "/hello"])); - assert!(rankless_route_collisions(&["//hello/", "/hello"])); - assert!(rankless_route_collisions(&["//hello/", "/hello//"])); - assert!(rankless_route_collisions(&["/", "/hello//"])); - assert!(rankless_route_collisions(&["/", "/hello///"])); - assert!(rankless_route_collisions(&["/hello///bob", "/hello/"])); - assert!(rankless_route_collisions(&["///", "/a//"])); - assert!(rankless_route_collisions(&["/a///", "/a/"])); - assert!(rankless_route_collisions(&["/a///", "/a/b//c//d/"])); - assert!(rankless_route_collisions(&["/a//", "/a/bd/e/"])); - assert!(rankless_route_collisions(&["//", "/a/bd/e/"])); - assert!(rankless_route_collisions(&["//", "/"])); - assert!(rankless_route_collisions(&["/a///", "/a/b//c//d/e/"])); - assert!(rankless_route_collisions(&["/a////", "/a/b//c//d/e/"])); - assert!(rankless_route_collisions(&["///<_>", "/<_>"])); - assert!(rankless_route_collisions(&["/a/<_>", "///a//b"])); - assert!(rankless_route_collisions(&["//a///<_>", "/a//"])); - assert!(rankless_route_collisions(&["//<_..>", "/a/b"])); - assert!(rankless_route_collisions(&["//<_..>", "/<_>"])); - assert!(rankless_route_collisions(&["////", "/a/"])); - assert!(rankless_route_collisions(&["////", "/a/"])); - assert!(rankless_route_collisions(&["/", "/hello"])); - } - - #[test] - fn test_collisions_query() { - // Query shouldn't affect things when rankless. - assert!(rankless_route_collisions(&["/hello?", "/hello"])); - assert!(rankless_route_collisions(&["/?foo=bar", "/hello?foo=bar&cat=fat"])); - assert!(rankless_route_collisions(&["/?foo=bar", "/hello?foo=bar&cat=fat"])); - assert!(rankless_route_collisions(&["/", "/?"])); - assert!(rankless_route_collisions(&["/hello/bob?a=b", "/hello/?d=e"])); - assert!(rankless_route_collisions(&["/?a=b", "/foo?d=e"])); - assert!(rankless_route_collisions(&["/?a=b&", "/?d=e&"])); - assert!(rankless_route_collisions(&["/?a=b&", "/?d=e"])); - } - - #[test] - fn test_no_collisions() { - assert!(!rankless_route_collisions(&["/a/b", "/a/b/c"])); - assert!(!rankless_route_collisions(&["/a/b/c/d", "/a/b/c//e"])); - assert!(!rankless_route_collisions(&["/a/d/", "/a/b/c"])); - assert!(!rankless_route_collisions(&["/<_>", "/"])); - assert!(!rankless_route_collisions(&["/a/<_>", "/a"])); - assert!(!rankless_route_collisions(&["/a/<_>", "/<_>"])); - } - - #[test] - fn test_no_collision_when_ranked() { - assert!(!default_rank_route_collisions(&["/", "/hello"])); - assert!(!default_rank_route_collisions(&["/hello/bob", "/hello/"])); - assert!(!default_rank_route_collisions(&["/a/b/c/d", "///c/d"])); - assert!(!default_rank_route_collisions(&["/hi", "/"])); - assert!(!default_rank_route_collisions(&["/a", "/a/"])); - assert!(!default_rank_route_collisions(&["/", "/"])); - assert!(!default_rank_route_collisions(&["/a/b", "/a/b/"])); - assert!(!default_rank_route_collisions(&["/<_>", "/static"])); - assert!(!default_rank_route_collisions(&["/<_..>", "/static"])); - assert!(!default_rank_route_collisions(&["/", "/"])); - assert!(!default_rank_route_collisions(&["/<_>/<_>", "/foo/bar"])); - assert!(!default_rank_route_collisions(&["/foo/<_>", "/foo/bar"])); - - assert!(!default_rank_route_collisions(&["//", "/hello/"])); - assert!(!default_rank_route_collisions(&["//", "/hello/"])); - assert!(!default_rank_route_collisions(&["/", "/hello/"])); - assert!(!default_rank_route_collisions(&["/", "/hello"])); - assert!(!default_rank_route_collisions(&["/", "/a/"])); - assert!(!default_rank_route_collisions(&["/a//c", "//"])); - } - - #[test] - fn test_collision_when_ranked() { - assert!(default_rank_route_collisions(&["/a//", "/a/"])); - assert!(default_rank_route_collisions(&["//b", "/a/"])); - } - - #[test] - fn test_collision_when_ranked_query() { - assert!(default_rank_route_collisions(&["/a?a=b", "/a?c=d"])); - assert!(default_rank_route_collisions(&["/a?a=b&", "/a?&c=d"])); - assert!(default_rank_route_collisions(&["/a?a=b&", "/a?&c=d"])); - } - - #[test] - fn test_no_collision_when_ranked_query() { - assert!(!default_rank_route_collisions(&["/", "/?"])); - assert!(!default_rank_route_collisions(&["/hi", "/hi?"])); - assert!(!default_rank_route_collisions(&["/hi", "/hi?c"])); - assert!(!default_rank_route_collisions(&["/hi?", "/hi?c"])); - assert!(!default_rank_route_collisions(&["/?a=b", "/?c=d&"])); - } - - fn route<'a>(router: &'a Router, method: Method, uri: &'a str) -> Option<&'a Route> { - let rocket = Rocket::custom(Config::default()); - let request = Request::new(&rocket, method, Origin::parse(uri).unwrap()); - let route = router.route(&request).next(); - route - } - - fn matches<'a>(router: &'a Router, method: Method, uri: &'a str) -> Vec<&'a Route> { - let rocket = Rocket::custom(Config::default()); - let request = Request::new(&rocket, method, Origin::parse(uri).unwrap()); - router.route(&request).collect() - } - - #[test] - fn test_ok_routing() { - let router = router_with_routes(&["/hello"]); - assert!(route(&router, Get, "/hello").is_some()); - - let router = router_with_routes(&["/"]); - assert!(route(&router, Get, "/hello").is_some()); - assert!(route(&router, Get, "/hi").is_some()); - assert!(route(&router, Get, "/bobbbbbbbbbby").is_some()); - assert!(route(&router, Get, "/dsfhjasdf").is_some()); - - let router = router_with_routes(&["//"]); - assert!(route(&router, Get, "/hello/hi").is_some()); - assert!(route(&router, Get, "/a/b/").is_some()); - assert!(route(&router, Get, "/i/a").is_some()); - assert!(route(&router, Get, "/jdlk/asdij").is_some()); - - let mut router = Router::new(); - router.add(Route::new(Put, "/hello", dummy)); - router.add(Route::new(Post, "/hello", dummy)); - router.add(Route::new(Delete, "/hello", dummy)); - assert!(route(&router, Put, "/hello").is_some()); - assert!(route(&router, Post, "/hello").is_some()); - assert!(route(&router, Delete, "/hello").is_some()); - - let router = router_with_routes(&["/"]); - assert!(route(&router, Get, "/").is_some()); - assert!(route(&router, Get, "//").is_some()); - assert!(route(&router, Get, "/hi").is_some()); - assert!(route(&router, Get, "/hello/hi").is_some()); - assert!(route(&router, Get, "/a/b/").is_some()); - assert!(route(&router, Get, "/i/a").is_some()); - assert!(route(&router, Get, "/a/b/c/d/e/f").is_some()); - - let router = router_with_routes(&["/foo/"]); - assert!(route(&router, Get, "/foo").is_some()); - assert!(route(&router, Get, "/foo/").is_some()); - assert!(route(&router, Get, "/foo///bar").is_some()); - } - - #[test] - fn test_err_routing() { - let router = router_with_routes(&["/hello"]); - assert!(route(&router, Put, "/hello").is_none()); - assert!(route(&router, Post, "/hello").is_none()); - assert!(route(&router, Options, "/hello").is_none()); - assert!(route(&router, Get, "/hell").is_none()); - assert!(route(&router, Get, "/hi").is_none()); - assert!(route(&router, Get, "/hello/there").is_none()); - assert!(route(&router, Get, "/hello/i").is_none()); - assert!(route(&router, Get, "/hillo").is_none()); - - let router = router_with_routes(&["/"]); - assert!(route(&router, Put, "/hello").is_none()); - assert!(route(&router, Post, "/hello").is_none()); - assert!(route(&router, Options, "/hello").is_none()); - assert!(route(&router, Get, "/hello/there").is_none()); - assert!(route(&router, Get, "/hello/i").is_none()); - - let router = router_with_routes(&["//"]); - assert!(route(&router, Get, "/a/b/c").is_none()); - assert!(route(&router, Get, "/a").is_none()); - assert!(route(&router, Get, "/a/").is_none()); - assert!(route(&router, Get, "/a/b/c/d").is_none()); - assert!(route(&router, Put, "/hello/hi").is_none()); - assert!(route(&router, Put, "/a/b").is_none()); - assert!(route(&router, Put, "/a/b").is_none()); - - let router = router_with_routes(&["/prefix/"]); - assert!(route(&router, Get, "/").is_none()); - assert!(route(&router, Get, "/prefi/").is_none()); - } - - macro_rules! assert_ranked_match { - ($routes:expr, $to:expr => $want:expr) => ({ - let router = router_with_routes($routes); - assert!(!router.has_collisions()); - let route_path = route(&router, Get, $to).unwrap().uri.to_string(); - assert_eq!(route_path, $want.to_string()); - }) - } - - #[test] - fn test_default_ranking() { - assert_ranked_match!(&["/hello", "/"], "/hello" => "/hello"); - assert_ranked_match!(&["/", "/hello"], "/hello" => "/hello"); - assert_ranked_match!(&["/", "/hi", "/hi/"], "/hi" => "/hi"); - assert_ranked_match!(&["//b", "/hi/c"], "/hi/c" => "/hi/c"); - assert_ranked_match!(&["//", "/hi/a"], "/hi/c" => "//"); - assert_ranked_match!(&["/hi/a", "/hi/"], "/hi/c" => "/hi/"); - assert_ranked_match!(&["/a", "/a?"], "/a?b=c" => "/a?"); - assert_ranked_match!(&["/a", "/a?"], "/a" => "/a?"); - assert_ranked_match!(&["/a", "/", "/a?", "/?"], "/a" => "/a?"); - assert_ranked_match!(&["/a", "/", "/a?", "/?"], "/b" => "/?"); - assert_ranked_match!(&["/a", "/", "/a?", "/?"], "/b?v=1" => "/?"); - assert_ranked_match!(&["/a", "/", "/a?", "/?"], "/a?b=c" => "/a?"); - assert_ranked_match!(&["/a", "/a?b"], "/a?b" => "/a?b"); - assert_ranked_match!(&["/", "/a?b"], "/a?b" => "/a?b"); - assert_ranked_match!(&["/a", "/?b"], "/a?b" => "/a"); - assert_ranked_match!(&["/a?&b", "/a?"], "/a" => "/a?"); - assert_ranked_match!(&["/a?&b", "/a?"], "/a?b" => "/a?&b"); - assert_ranked_match!(&["/a?&b", "/a?"], "/a?c" => "/a?"); - assert_ranked_match!(&["/", "/"], "/" => "/"); - assert_ranked_match!(&["/", "/"], "/hi" => "/"); - assert_ranked_match!(&["/hi", "/"], "/hi" => "/hi"); - } - - fn ranked_collisions(routes: &[(isize, &'static str)]) -> bool { - let router = router_with_ranked_routes(routes); - router.has_collisions() - } - - #[test] - fn test_no_manual_ranked_collisions() { - assert!(!ranked_collisions(&[(1, "/a/"), (2, "/a/")])); - assert!(!ranked_collisions(&[(0, "/a/"), (2, "/a/")])); - assert!(!ranked_collisions(&[(5, "/a/"), (2, "/a/")])); - assert!(!ranked_collisions(&[(1, "/a/"), (1, "/b/")])); - assert!(!ranked_collisions(&[(1, "/a/"), (2, "/a/")])); - assert!(!ranked_collisions(&[(0, "/a/"), (2, "/a/")])); - assert!(!ranked_collisions(&[(5, "/a/"), (2, "/a/")])); - assert!(!ranked_collisions(&[(1, "/"), (2, "/")])); - } - - #[test] - fn test_ranked_collisions() { - assert!(ranked_collisions(&[(2, "/a/"), (2, "/a/")])); - assert!(ranked_collisions(&[(2, "/a/c/"), (2, "/a/")])); - assert!(ranked_collisions(&[(2, "/"), (2, "/a/")])); - } - - macro_rules! assert_ranked_routing { - (to: $to:expr, with: $routes:expr, expect: $($want:expr),+) => ({ - let router = router_with_ranked_routes(&$routes); - let routed_to = matches(&router, Get, $to); - let expected = &[$($want),+]; - assert!(routed_to.len() == expected.len()); - for (got, expected) in routed_to.iter().zip(expected.iter()) { - assert_eq!(got.rank, expected.0); - assert_eq!(got.uri.to_string(), expected.1.to_string()); - } - }) - } - - #[test] - fn test_ranked_routing() { - assert_ranked_routing!( - to: "/a/b", - with: [(1, "/a/"), (2, "/a/")], - expect: (1, "/a/"), (2, "/a/") - ); - - assert_ranked_routing!( - to: "/b/b", - with: [(1, "/a/"), (2, "/b/"), (3, "/b/b")], - expect: (2, "/b/"), (3, "/b/b") - ); - - assert_ranked_routing!( - to: "/b/b", - with: [(2, "/b/"), (1, "/a/"), (3, "/b/b")], - expect: (2, "/b/"), (3, "/b/b") - ); - - assert_ranked_routing!( - to: "/b/b", - with: [(3, "/b/b"), (2, "/b/"), (1, "/a/")], - expect: (2, "/b/"), (3, "/b/b") - ); - - assert_ranked_routing!( - to: "/b/b", - with: [(1, "/a/"), (2, "/b/"), (0, "/b/b")], - expect: (0, "/b/b"), (2, "/b/") - ); - - assert_ranked_routing!( - to: "/profile/sergio/edit", - with: [(1, "///edit"), (2, "/profile/"), (0, "///")], - expect: (0, "///"), (1, "///edit") - ); - - assert_ranked_routing!( - to: "/profile/sergio/edit", - with: [(0, "///edit"), (2, "/profile/"), (5, "///")], - expect: (0, "///edit"), (5, "///") - ); - - assert_ranked_routing!( - to: "/a/b", - with: [(0, "/a/b"), (1, "/a/")], - expect: (0, "/a/b"), (1, "/a/") - ); - - assert_ranked_routing!( - to: "/a/b/c/d/e/f", - with: [(1, "/a/"), (2, "/a/b/")], - expect: (1, "/a/"), (2, "/a/b/") - ); - - assert_ranked_routing!( - to: "/hi", - with: [(1, "/hi/"), (0, "/hi/")], - expect: (1, "/hi/") - ); - } - - macro_rules! assert_default_ranked_routing { - (to: $to:expr, with: $routes:expr, expect: $($want:expr),+) => ({ - let router = router_with_routes(&$routes); - let routed_to = matches(&router, Get, $to); - let expected = &[$($want),+]; - assert!(routed_to.len() == expected.len()); - for (got, expected) in routed_to.iter().zip(expected.iter()) { - assert_eq!(got.uri.to_string(), expected.to_string()); - } - }) - } - - #[test] - fn test_default_ranked_routing() { - assert_default_ranked_routing!( - to: "/a/b?v=1", - with: ["/a/", "/a/b"], - expect: "/a/b", "/a/" - ); - - assert_default_ranked_routing!( - to: "/a/b?v=1", - with: ["/a/", "/a/b", "/a/b?"], - expect: "/a/b?", "/a/b", "/a/" - ); - - assert_default_ranked_routing!( - to: "/a/b?v=1", - with: ["/a/", "/a/b", "/a/b?", "/a/?"], - expect: "/a/b?", "/a/b", "/a/?", "/a/" - ); - - assert_default_ranked_routing!( - to: "/a/b", - with: ["/a/", "/a/b", "/a/b?", "/a/?"], - expect: "/a/b?", "/a/b", "/a/?", "/a/" - ); - - assert_default_ranked_routing!( - to: "/a/b?c", - with: ["/a/b", "/a/b?", "/a/b?c", "/a/?c", "/a/?", "//"], - expect: "/a/b?c", "/a/b?", "/a/b", "/a/?c", "/a/?", "//" - ); - } -} +pub use route::Route; +pub use collider::Collide; +pub use uri::RouteUri; diff --git a/core/lib/src/router/route.rs b/core/lib/src/router/route.rs index 67f96768..4800d3c9 100644 --- a/core/lib/src/router/route.rs +++ b/core/lib/src/router/route.rs @@ -162,7 +162,6 @@ impl Route { } } - /// Maps the `base` of this route using `mapper`, returning a new `Route` /// with the returned base. /// diff --git a/core/lib/src/router/router.rs b/core/lib/src/router/router.rs new file mode 100644 index 00000000..01c3a3e7 --- /dev/null +++ b/core/lib/src/router/router.rs @@ -0,0 +1,670 @@ +use std::collections::HashMap; + +use crate::request::Request; +use crate::http::{Method, Status}; + +pub use crate::router::{Route, RouteUri}; +pub use crate::router::collider::Collide; +pub use crate::catcher::Catcher; + +#[derive(Debug, Default)] +pub(crate) struct Router { + routes: HashMap>, + catchers: HashMap, Vec>, +} + +#[derive(Debug)] +pub struct Collisions { + pub routes: Vec<(Route, Route)>, + pub catchers: Vec<(Catcher, Catcher)>, +} + +impl Router { + pub fn new() -> Self { + Self::default() + } + + pub fn add_route(&mut self, route: Route) { + let routes = self.routes.entry(route.method).or_default(); + routes.push(route); + routes.sort_by_key(|r| r.rank); + } + + pub fn add_catcher(&mut self, catcher: Catcher) { + let catchers = self.catchers.entry(catcher.code).or_default(); + catchers.push(catcher); + catchers.sort_by(|a, b| b.base.path_segments().len().cmp(&a.base.path_segments().len())) + } + + #[inline] + pub fn routes(&self) -> impl Iterator + Clone { + self.routes.values().flat_map(|v| v.iter()) + } + + #[inline] + pub fn catchers(&self) -> impl Iterator + Clone { + self.catchers.values().flat_map(|v| v.iter()) + } + + pub fn route<'r, 'a: 'r>( + &'a self, + req: &'r Request<'r> + ) -> impl Iterator + 'r { + // Note that routes are presorted by ascending rank on each `add`. + self.routes.get(&req.method()) + .into_iter() + .flat_map(move |routes| routes.iter().filter(move |r| r.matches(req))) + } + + // For many catchers, using aho-corasick or similar should be much faster. + pub fn catch<'r>(&self, status: Status, req: &'r Request<'r>) -> Option<&Catcher> { + // Note that catchers are presorted by descending base length. + let explicit = self.catchers.get(&Some(status.code)) + .and_then(|c| c.iter().find(|c| c.matches(status, req))); + + let default = self.catchers.get(&None) + .and_then(|c| c.iter().find(|c| c.matches(status, req))); + + match (explicit, default) { + (None, None) => None, + (None, c@Some(_)) | (c@Some(_), None) => c, + (Some(a), Some(b)) => { + if b.base.path_segments().len() > a.base.path_segments().len() { + Some(b) + } else { + Some(a) + } + } + } + } + + fn collisions<'a, I: 'a, T: 'a>(&self, items: I) -> impl Iterator + 'a + where I: Iterator + Clone, T: Collide + Clone, + { + items.clone().enumerate() + .flat_map(move |(i, a)| { + items.clone() + .skip(i + 1) + .filter(move |b| a.collides_with(b)) + .map(move |b| (a.clone(), b.clone())) + }) + } + + pub fn finalize(&self) -> Result<(), Collisions> { + let routes: Vec<_> = self.collisions(self.routes()).collect(); + let catchers: Vec<_> = self.collisions(self.catchers()).collect(); + + if !routes.is_empty() || !catchers.is_empty() { + return Err(Collisions { routes, catchers }) + } + + Ok(()) + } +} + +#[cfg(test)] +mod test { + use super::*; + + use crate::rocket::Rocket; + use crate::config::Config; + use crate::http::{Method, Method::*}; + use crate::http::uri::Origin; + use crate::request::Request; + use crate::handler::dummy; + + impl Router { + fn has_collisions(&self) -> bool { + self.finalize().is_err() + } + } + + fn router_with_routes(routes: &[&'static str]) -> Router { + let mut router = Router::new(); + for route in routes { + let route = Route::new(Get, route, dummy); + router.add_route(route); + } + + router + } + + fn router_with_ranked_routes(routes: &[(isize, &'static str)]) -> Router { + let mut router = Router::new(); + for &(rank, route) in routes { + let route = Route::ranked(rank, Get, route, dummy); + router.add_route(route); + } + + router + } + + fn router_with_rankless_routes(routes: &[&'static str]) -> Router { + let mut router = Router::new(); + for route in routes { + let route = Route::ranked(0, Get, route, dummy); + router.add_route(route); + } + + router + } + + fn rankless_route_collisions(routes: &[&'static str]) -> bool { + let router = router_with_rankless_routes(routes); + router.has_collisions() + } + + fn default_rank_route_collisions(routes: &[&'static str]) -> bool { + let router = router_with_routes(routes); + router.has_collisions() + } + + #[test] + fn test_rankless_collisions() { + assert!(rankless_route_collisions(&["/hello", "/hello"])); + assert!(rankless_route_collisions(&["/", "/hello"])); + assert!(rankless_route_collisions(&["/", "/"])); + assert!(rankless_route_collisions(&["/hello/bob", "/hello/"])); + assert!(rankless_route_collisions(&["/a/b//d", "///c/d"])); + + assert!(rankless_route_collisions(&["/a/b", "/"])); + assert!(rankless_route_collisions(&["/a/b/c", "/a/"])); + assert!(rankless_route_collisions(&["//b", "/a/"])); + assert!(rankless_route_collisions(&["/a/", "/a/"])); + assert!(rankless_route_collisions(&["/a/b/", "/a/"])); + assert!(rankless_route_collisions(&["/", "/a/"])); + assert!(rankless_route_collisions(&["/a/", "/a/"])); + assert!(rankless_route_collisions(&["/a/b/", "/a/"])); + assert!(rankless_route_collisions(&["/a/b/c/d", "/a/"])); + assert!(rankless_route_collisions(&["/", "/"])); + assert!(rankless_route_collisions(&["/a/<_>", "/a/"])); + assert!(rankless_route_collisions(&["/a/<_>", "/a/<_..>"])); + assert!(rankless_route_collisions(&["/<_>", "/a/<_..>"])); + assert!(rankless_route_collisions(&["/foo", "/foo/<_..>"])); + assert!(rankless_route_collisions(&["/foo/bar/baz", "/foo/<_..>"])); + assert!(rankless_route_collisions(&["/a/d/", "/a/d"])); + assert!(rankless_route_collisions(&["/a/<_..>", "/<_>"])); + assert!(rankless_route_collisions(&["/a/<_..>", "/a"])); + assert!(rankless_route_collisions(&["/", "/a/"])); + + assert!(rankless_route_collisions(&["/<_>", "/<_>"])); + assert!(rankless_route_collisions(&["/a/<_>", "/a/b"])); + assert!(rankless_route_collisions(&["/a/<_>", "/a/"])); + assert!(rankless_route_collisions(&["/<_..>", "/a/b"])); + assert!(rankless_route_collisions(&["/<_..>", "/<_>"])); + assert!(rankless_route_collisions(&["/<_>/b", "/a/b"])); + assert!(rankless_route_collisions(&["/", "/"])); + } + + #[test] + fn test_collisions_normalize() { + assert!(rankless_route_collisions(&["/hello/", "/hello"])); + assert!(rankless_route_collisions(&["//hello/", "/hello"])); + assert!(rankless_route_collisions(&["//hello/", "/hello//"])); + assert!(rankless_route_collisions(&["/", "/hello//"])); + assert!(rankless_route_collisions(&["/", "/hello///"])); + assert!(rankless_route_collisions(&["/hello///bob", "/hello/"])); + assert!(rankless_route_collisions(&["///", "/a//"])); + assert!(rankless_route_collisions(&["/a///", "/a/"])); + assert!(rankless_route_collisions(&["/a///", "/a/b//c//d/"])); + assert!(rankless_route_collisions(&["/a//", "/a/bd/e/"])); + assert!(rankless_route_collisions(&["//", "/a/bd/e/"])); + assert!(rankless_route_collisions(&["//", "/"])); + assert!(rankless_route_collisions(&["/a///", "/a/b//c//d/e/"])); + assert!(rankless_route_collisions(&["/a////", "/a/b//c//d/e/"])); + assert!(rankless_route_collisions(&["///<_>", "/<_>"])); + assert!(rankless_route_collisions(&["/a/<_>", "///a//b"])); + assert!(rankless_route_collisions(&["//a///<_>", "/a//"])); + assert!(rankless_route_collisions(&["//<_..>", "/a/b"])); + assert!(rankless_route_collisions(&["//<_..>", "/<_>"])); + assert!(rankless_route_collisions(&["////", "/a/"])); + assert!(rankless_route_collisions(&["////", "/a/"])); + assert!(rankless_route_collisions(&["/", "/hello"])); + } + + #[test] + fn test_collisions_query() { + // Query shouldn't affect things when rankless. + assert!(rankless_route_collisions(&["/hello?", "/hello"])); + assert!(rankless_route_collisions(&["/?foo=bar", "/hello?foo=bar&cat=fat"])); + assert!(rankless_route_collisions(&["/?foo=bar", "/hello?foo=bar&cat=fat"])); + assert!(rankless_route_collisions(&["/", "/?"])); + assert!(rankless_route_collisions(&["/hello/bob?a=b", "/hello/?d=e"])); + assert!(rankless_route_collisions(&["/?a=b", "/foo?d=e"])); + assert!(rankless_route_collisions(&["/?a=b&", "/?d=e&"])); + assert!(rankless_route_collisions(&["/?a=b&", "/?d=e"])); + } + + #[test] + fn test_no_collisions() { + assert!(!rankless_route_collisions(&["/a/b", "/a/b/c"])); + assert!(!rankless_route_collisions(&["/a/b/c/d", "/a/b/c//e"])); + assert!(!rankless_route_collisions(&["/a/d/", "/a/b/c"])); + assert!(!rankless_route_collisions(&["/<_>", "/"])); + assert!(!rankless_route_collisions(&["/a/<_>", "/a"])); + assert!(!rankless_route_collisions(&["/a/<_>", "/<_>"])); + } + + #[test] + fn test_no_collision_when_ranked() { + assert!(!default_rank_route_collisions(&["/", "/hello"])); + assert!(!default_rank_route_collisions(&["/hello/bob", "/hello/"])); + assert!(!default_rank_route_collisions(&["/a/b/c/d", "///c/d"])); + assert!(!default_rank_route_collisions(&["/hi", "/"])); + assert!(!default_rank_route_collisions(&["/a", "/a/"])); + assert!(!default_rank_route_collisions(&["/", "/"])); + assert!(!default_rank_route_collisions(&["/a/b", "/a/b/"])); + assert!(!default_rank_route_collisions(&["/<_>", "/static"])); + assert!(!default_rank_route_collisions(&["/<_..>", "/static"])); + assert!(!default_rank_route_collisions(&["/", "/"])); + assert!(!default_rank_route_collisions(&["/<_>/<_>", "/foo/bar"])); + assert!(!default_rank_route_collisions(&["/foo/<_>", "/foo/bar"])); + + assert!(!default_rank_route_collisions(&["//", "/hello/"])); + assert!(!default_rank_route_collisions(&["//", "/hello/"])); + assert!(!default_rank_route_collisions(&["/", "/hello/"])); + assert!(!default_rank_route_collisions(&["/", "/hello"])); + assert!(!default_rank_route_collisions(&["/", "/a/"])); + assert!(!default_rank_route_collisions(&["/a//c", "//"])); + } + + #[test] + fn test_collision_when_ranked() { + assert!(default_rank_route_collisions(&["/a//", "/a/"])); + assert!(default_rank_route_collisions(&["//b", "/a/"])); + } + + #[test] + fn test_collision_when_ranked_query() { + assert!(default_rank_route_collisions(&["/a?a=b", "/a?c=d"])); + assert!(default_rank_route_collisions(&["/a?a=b&", "/a?&c=d"])); + assert!(default_rank_route_collisions(&["/a?a=b&", "/a?&c=d"])); + } + + #[test] + fn test_no_collision_when_ranked_query() { + assert!(!default_rank_route_collisions(&["/", "/?"])); + assert!(!default_rank_route_collisions(&["/hi", "/hi?"])); + assert!(!default_rank_route_collisions(&["/hi", "/hi?c"])); + assert!(!default_rank_route_collisions(&["/hi?", "/hi?c"])); + assert!(!default_rank_route_collisions(&["/?a=b", "/?c=d&"])); + } + + fn route<'a>(router: &'a Router, method: Method, uri: &'a str) -> Option<&'a Route> { + let rocket = Rocket::custom(Config::default()); + let request = Request::new(&rocket, method, Origin::parse(uri).unwrap()); + let route = router.route(&request).next(); + route + } + + fn matches<'a>(router: &'a Router, method: Method, uri: &'a str) -> Vec<&'a Route> { + let rocket = Rocket::custom(Config::default()); + let request = Request::new(&rocket, method, Origin::parse(uri).unwrap()); + router.route(&request).collect() + } + + #[test] + fn test_ok_routing() { + let router = router_with_routes(&["/hello"]); + assert!(route(&router, Get, "/hello").is_some()); + + let router = router_with_routes(&["/"]); + assert!(route(&router, Get, "/hello").is_some()); + assert!(route(&router, Get, "/hi").is_some()); + assert!(route(&router, Get, "/bobbbbbbbbbby").is_some()); + assert!(route(&router, Get, "/dsfhjasdf").is_some()); + + let router = router_with_routes(&["//"]); + assert!(route(&router, Get, "/hello/hi").is_some()); + assert!(route(&router, Get, "/a/b/").is_some()); + assert!(route(&router, Get, "/i/a").is_some()); + assert!(route(&router, Get, "/jdlk/asdij").is_some()); + + let mut router = Router::new(); + router.add_route(Route::new(Put, "/hello", dummy)); + router.add_route(Route::new(Post, "/hello", dummy)); + router.add_route(Route::new(Delete, "/hello", dummy)); + assert!(route(&router, Put, "/hello").is_some()); + assert!(route(&router, Post, "/hello").is_some()); + assert!(route(&router, Delete, "/hello").is_some()); + + let router = router_with_routes(&["/"]); + assert!(route(&router, Get, "/").is_some()); + assert!(route(&router, Get, "//").is_some()); + assert!(route(&router, Get, "/hi").is_some()); + assert!(route(&router, Get, "/hello/hi").is_some()); + assert!(route(&router, Get, "/a/b/").is_some()); + assert!(route(&router, Get, "/i/a").is_some()); + assert!(route(&router, Get, "/a/b/c/d/e/f").is_some()); + + let router = router_with_routes(&["/foo/"]); + assert!(route(&router, Get, "/foo").is_some()); + assert!(route(&router, Get, "/foo/").is_some()); + assert!(route(&router, Get, "/foo///bar").is_some()); + } + + #[test] + fn test_err_routing() { + let router = router_with_routes(&["/hello"]); + assert!(route(&router, Put, "/hello").is_none()); + assert!(route(&router, Post, "/hello").is_none()); + assert!(route(&router, Options, "/hello").is_none()); + assert!(route(&router, Get, "/hell").is_none()); + assert!(route(&router, Get, "/hi").is_none()); + assert!(route(&router, Get, "/hello/there").is_none()); + assert!(route(&router, Get, "/hello/i").is_none()); + assert!(route(&router, Get, "/hillo").is_none()); + + let router = router_with_routes(&["/"]); + assert!(route(&router, Put, "/hello").is_none()); + assert!(route(&router, Post, "/hello").is_none()); + assert!(route(&router, Options, "/hello").is_none()); + assert!(route(&router, Get, "/hello/there").is_none()); + assert!(route(&router, Get, "/hello/i").is_none()); + + let router = router_with_routes(&["//"]); + assert!(route(&router, Get, "/a/b/c").is_none()); + assert!(route(&router, Get, "/a").is_none()); + assert!(route(&router, Get, "/a/").is_none()); + assert!(route(&router, Get, "/a/b/c/d").is_none()); + assert!(route(&router, Put, "/hello/hi").is_none()); + assert!(route(&router, Put, "/a/b").is_none()); + assert!(route(&router, Put, "/a/b").is_none()); + + let router = router_with_routes(&["/prefix/"]); + assert!(route(&router, Get, "/").is_none()); + assert!(route(&router, Get, "/prefi/").is_none()); + } + + macro_rules! assert_ranked_match { + ($routes:expr, $to:expr => $want:expr) => ({ + let router = router_with_routes($routes); + assert!(!router.has_collisions()); + let route_path = route(&router, Get, $to).unwrap().uri.to_string(); + assert_eq!(route_path, $want.to_string()); + }) + } + + #[test] + fn test_default_ranking() { + assert_ranked_match!(&["/hello", "/"], "/hello" => "/hello"); + assert_ranked_match!(&["/", "/hello"], "/hello" => "/hello"); + assert_ranked_match!(&["/", "/hi", "/hi/"], "/hi" => "/hi"); + assert_ranked_match!(&["//b", "/hi/c"], "/hi/c" => "/hi/c"); + assert_ranked_match!(&["//", "/hi/a"], "/hi/c" => "//"); + assert_ranked_match!(&["/hi/a", "/hi/"], "/hi/c" => "/hi/"); + assert_ranked_match!(&["/a", "/a?"], "/a?b=c" => "/a?"); + assert_ranked_match!(&["/a", "/a?"], "/a" => "/a?"); + assert_ranked_match!(&["/a", "/", "/a?", "/?"], "/a" => "/a?"); + assert_ranked_match!(&["/a", "/", "/a?", "/?"], "/b" => "/?"); + assert_ranked_match!(&["/a", "/", "/a?", "/?"], "/b?v=1" => "/?"); + assert_ranked_match!(&["/a", "/", "/a?", "/?"], "/a?b=c" => "/a?"); + assert_ranked_match!(&["/a", "/a?b"], "/a?b" => "/a?b"); + assert_ranked_match!(&["/", "/a?b"], "/a?b" => "/a?b"); + assert_ranked_match!(&["/a", "/?b"], "/a?b" => "/a"); + assert_ranked_match!(&["/a?&b", "/a?"], "/a" => "/a?"); + assert_ranked_match!(&["/a?&b", "/a?"], "/a?b" => "/a?&b"); + assert_ranked_match!(&["/a?&b", "/a?"], "/a?c" => "/a?"); + assert_ranked_match!(&["/", "/"], "/" => "/"); + assert_ranked_match!(&["/", "/"], "/hi" => "/"); + assert_ranked_match!(&["/hi", "/"], "/hi" => "/hi"); + } + + fn ranked_collisions(routes: &[(isize, &'static str)]) -> bool { + let router = router_with_ranked_routes(routes); + router.has_collisions() + } + + #[test] + fn test_no_manual_ranked_collisions() { + assert!(!ranked_collisions(&[(1, "/a/"), (2, "/a/")])); + assert!(!ranked_collisions(&[(0, "/a/"), (2, "/a/")])); + assert!(!ranked_collisions(&[(5, "/a/"), (2, "/a/")])); + assert!(!ranked_collisions(&[(1, "/a/"), (1, "/b/")])); + assert!(!ranked_collisions(&[(1, "/a/"), (2, "/a/")])); + assert!(!ranked_collisions(&[(0, "/a/"), (2, "/a/")])); + assert!(!ranked_collisions(&[(5, "/a/"), (2, "/a/")])); + assert!(!ranked_collisions(&[(1, "/"), (2, "/")])); + } + + #[test] + fn test_ranked_collisions() { + assert!(ranked_collisions(&[(2, "/a/"), (2, "/a/")])); + assert!(ranked_collisions(&[(2, "/a/c/"), (2, "/a/")])); + assert!(ranked_collisions(&[(2, "/"), (2, "/a/")])); + } + + macro_rules! assert_ranked_routing { + (to: $to:expr, with: $routes:expr, expect: $($want:expr),+) => ({ + let router = router_with_ranked_routes(&$routes); + let routed_to = matches(&router, Get, $to); + let expected = &[$($want),+]; + assert!(routed_to.len() == expected.len()); + for (got, expected) in routed_to.iter().zip(expected.iter()) { + assert_eq!(got.rank, expected.0); + assert_eq!(got.uri.to_string(), expected.1.to_string()); + } + }) + } + + #[test] + fn test_ranked_routing() { + assert_ranked_routing!( + to: "/a/b", + with: [(1, "/a/"), (2, "/a/")], + expect: (1, "/a/"), (2, "/a/") + ); + + assert_ranked_routing!( + to: "/b/b", + with: [(1, "/a/"), (2, "/b/"), (3, "/b/b")], + expect: (2, "/b/"), (3, "/b/b") + ); + + assert_ranked_routing!( + to: "/b/b", + with: [(2, "/b/"), (1, "/a/"), (3, "/b/b")], + expect: (2, "/b/"), (3, "/b/b") + ); + + assert_ranked_routing!( + to: "/b/b", + with: [(3, "/b/b"), (2, "/b/"), (1, "/a/")], + expect: (2, "/b/"), (3, "/b/b") + ); + + assert_ranked_routing!( + to: "/b/b", + with: [(1, "/a/"), (2, "/b/"), (0, "/b/b")], + expect: (0, "/b/b"), (2, "/b/") + ); + + assert_ranked_routing!( + to: "/profile/sergio/edit", + with: [(1, "///edit"), (2, "/profile/"), (0, "///")], + expect: (0, "///"), (1, "///edit") + ); + + assert_ranked_routing!( + to: "/profile/sergio/edit", + with: [(0, "///edit"), (2, "/profile/"), (5, "///")], + expect: (0, "///edit"), (5, "///") + ); + + assert_ranked_routing!( + to: "/a/b", + with: [(0, "/a/b"), (1, "/a/")], + expect: (0, "/a/b"), (1, "/a/") + ); + + assert_ranked_routing!( + to: "/a/b/c/d/e/f", + with: [(1, "/a/"), (2, "/a/b/")], + expect: (1, "/a/"), (2, "/a/b/") + ); + + assert_ranked_routing!( + to: "/hi", + with: [(1, "/hi/"), (0, "/hi/")], + expect: (1, "/hi/") + ); + } + + macro_rules! assert_default_ranked_routing { + (to: $to:expr, with: $routes:expr, expect: $($want:expr),+) => ({ + let router = router_with_routes(&$routes); + let routed_to = matches(&router, Get, $to); + let expected = &[$($want),+]; + assert!(routed_to.len() == expected.len()); + for (got, expected) in routed_to.iter().zip(expected.iter()) { + assert_eq!(got.uri.to_string(), expected.to_string()); + } + }) + } + + #[test] + fn test_default_ranked_routing() { + assert_default_ranked_routing!( + to: "/a/b?v=1", + with: ["/a/", "/a/b"], + expect: "/a/b", "/a/" + ); + + assert_default_ranked_routing!( + to: "/a/b?v=1", + with: ["/a/", "/a/b", "/a/b?"], + expect: "/a/b?", "/a/b", "/a/" + ); + + assert_default_ranked_routing!( + to: "/a/b?v=1", + with: ["/a/", "/a/b", "/a/b?", "/a/?"], + expect: "/a/b?", "/a/b", "/a/?", "/a/" + ); + + assert_default_ranked_routing!( + to: "/a/b", + with: ["/a/", "/a/b", "/a/b?", "/a/?"], + expect: "/a/b?", "/a/b", "/a/?", "/a/" + ); + + assert_default_ranked_routing!( + to: "/a/b?c", + with: ["/a/b", "/a/b?", "/a/b?c", "/a/?c", "/a/?", "//"], + expect: "/a/b?c", "/a/b?", "/a/b", "/a/?c", "/a/?", "//" + ); + } + + fn router_with_catchers(catchers: &[(Option, &str)]) -> Router { + let mut router = Router::new(); + for (code, base) in catchers { + let catcher = Catcher::new(*code, crate::catcher::dummy); + router.add_catcher(catcher.map_base(|_| base.to_string()).unwrap()); + } + + router + } + + fn catcher<'a>(router: &'a Router, status: Status, uri: &str) -> Option<&'a Catcher> { + let rocket = Rocket::custom(Config::default()); + let request = Request::new(&rocket, Method::Get, Origin::parse(uri).unwrap()); + router.catch(status, &request) + } + + macro_rules! assert_catcher_routing { + ( + catch: [$(($code:expr, $uri:expr)),+], + reqs: [$($r:expr),+], + with: [$(($ecode:expr, $euri:expr)),+] + ) => ({ + let catchers = vec![$(($code.into(), $uri)),+]; + let requests = vec![$($r),+]; + let expected = vec![$(($ecode.into(), $euri)),+]; + + let router = router_with_catchers(&catchers); + for (req, expected) in requests.iter().zip(expected.iter()) { + let req_status = Status::from_code(req.0).expect("valid status"); + let catcher = catcher(&router, req_status, req.1).expect("some catcher"); + assert_eq!(catcher.code, expected.0, "<- got, expected ->"); + assert_eq!(catcher.base.path(), expected.1, "<- got, expected ->"); + } + }) + } + + #[test] + fn test_catcher_routing() { + // Check that the default `/` catcher catches everything. + assert_catcher_routing! { + catch: [(None, "/")], + reqs: [(404, "/a/b/c"), (500, "/a/b"), (415, "/a/b/d"), (422, "/a/b/c/d?foo")], + with: [(None, "/"), (None, "/"), (None, "/"), (None, "/")] + } + + // Check prefixes when they're exact. + assert_catcher_routing! { + catch: [(None, "/"), (None, "/a"), (None, "/a/b")], + reqs: [ + (404, "/"), (500, "/"), + (404, "/a"), (500, "/a"), + (404, "/a/b"), (500, "/a/b") + ], + with: [ + (None, "/"), (None, "/"), + (None, "/a"), (None, "/a"), + (None, "/a/b"), (None, "/a/b") + ] + } + + // Check prefixes when they're not exact. + assert_catcher_routing! { + catch: [(None, "/"), (None, "/a"), (None, "/a/b")], + reqs: [ + (404, "/foo"), (500, "/bar"), (422, "/baz/bar"), (418, "/poodle?yes"), + (404, "/a/foo"), (500, "/a/bar/baz"), (510, "/a/c"), (423, "/a/c/b"), + (404, "/a/b/c"), (500, "/a/b/c/d"), (500, "/a/b?foo"), (400, "/a/b/yes") + ], + with: [ + (None, "/"), (None, "/"), (None, "/"), (None, "/"), + (None, "/a"), (None, "/a"), (None, "/a"), (None, "/a"), + (None, "/a/b"), (None, "/a/b"), (None, "/a/b"), (None, "/a/b") + ] + } + + // Check that we prefer specific to default. + assert_catcher_routing! { + catch: [(400, "/"), (404, "/"), (None, "/")], + reqs: [ + (400, "/"), (400, "/bar"), (400, "/foo/bar"), + (404, "/"), (404, "/bar"), (404, "/foo/bar"), + (405, "/"), (405, "/bar"), (406, "/foo/bar") + ], + with: [ + (400, "/"), (400, "/"), (400, "/"), + (404, "/"), (404, "/"), (404, "/"), + (None, "/"), (None, "/"), (None, "/") + ] + } + + // Check that we prefer longer prefixes over specific. + assert_catcher_routing! { + catch: [(None, "/a/b"), (404, "/a"), (422, "/a")], + reqs: [ + (404, "/a/b"), (404, "/a/b/c"), (422, "/a/b/c"), + (404, "/a"), (404, "/a/c"), (404, "/a/cat/bar"), + (422, "/a"), (422, "/a/c"), (422, "/a/cat/bar") + ], + with: [ + (None, "/a/b"), (None, "/a/b"), (None, "/a/b"), + (404, "/a"), (404, "/a"), (404, "/a"), + (422, "/a"), (422, "/a"), (422, "/a") + ] + } + + // Just a fun one. + assert_catcher_routing! { + catch: [(None, "/"), (None, "/a/b"), (500, "/a/b/c"), (500, "/a/b")], + reqs: [(404, "/a/b/c"), (500, "/a/b"), (400, "/a/b/d"), (500, "/a/b/c/d?foo")], + with: [(None, "/a/b"), (500, "/a/b"), (None, "/a/b"), (500, "/a/b/c")] + } + } +} diff --git a/core/lib/src/server.rs b/core/lib/src/server.rs index 35060755..da4590fe 100644 --- a/core/lib/src/server.rs +++ b/core/lib/src/server.rs @@ -329,11 +329,7 @@ impl Rocket { // response. We may wish to relax this in the future. req.cookies().reset_delta(); - // Try to get the active catcher - let catcher = self.catchers.get(&status.code) - .or_else(|| self.default_catcher.as_ref()); - - if let Some(catcher) = catcher { + if let Some(catcher) = self.router.catch(status, req) { warn_!("Responding with registered {} catcher.", catcher); let name = catcher.name.as_deref(); handle(name, || catcher.handler.handle(status, req)).await diff --git a/core/lib/tests/catcher-cookies-1213.rs b/core/lib/tests/catcher-cookies-1213.rs index 3f491873..58919089 100644 --- a/core/lib/tests/catcher-cookies-1213.rs +++ b/core/lib/tests/catcher-cookies-1213.rs @@ -24,7 +24,7 @@ mod tests { fn error_catcher_sets_cookies() { let rocket = rocket::ignite() .mount("/", routes![index]) - .register(catchers![not_found]) + .register("/", catchers![not_found]) .attach(AdHoc::on_request("Add Cookie", |req, _| Box::pin(async move { req.cookies().add(Cookie::new("fairing", "woo")); }))); diff --git a/core/lib/tests/panic-handling.rs b/core/lib/tests/panic-handling.rs index 339ee398..a089da82 100644 --- a/core/lib/tests/panic-handling.rs +++ b/core/lib/tests/panic-handling.rs @@ -22,30 +22,20 @@ fn ise() -> &'static str { "Hey, sorry! :(" } -#[catch(500)] -fn double_panic() { - panic!("so, so sorry...") -} - fn pre_future_route<'r>(_: &'r Request<'_>, _: Data) -> HandlerFuture<'r> { panic!("hey now..."); } -fn pre_future_catcher<'r>(_: Status, _: &'r Request) -> ErrorHandlerFuture<'r> { - panic!("a panicking pre-future catcher") -} - fn rocket() -> Rocket { - let pre_future_panic = Route::new(Method::Get, "/pre", pre_future_route); rocket::ignite() .mount("/", routes![panic_route]) - .mount("/", vec![pre_future_panic]) - .register(catchers![panic_catcher, ise]) + .mount("/", vec![Route::new(Method::Get, "/pre", pre_future_route)]) } #[test] fn catches_route_panic() { - let client = Client::debug(rocket()).unwrap(); + let rocket = rocket().register("/", catchers![panic_catcher, ise]); + let client = Client::debug(rocket).unwrap(); let response = client.get("/panic").dispatch(); assert_eq!(response.status(), Status::InternalServerError); assert_eq!(response.into_string().unwrap(), "Hey, sorry! :("); @@ -53,7 +43,8 @@ fn catches_route_panic() { #[test] fn catches_catcher_panic() { - let client = Client::debug(rocket()).unwrap(); + let rocket = rocket().register("/", catchers![panic_catcher, ise]); + let client = Client::debug(rocket).unwrap(); let response = client.get("/noroute").dispatch(); assert_eq!(response.status(), Status::InternalServerError); assert_eq!(response.into_string().unwrap(), "Hey, sorry! :("); @@ -61,7 +52,12 @@ fn catches_catcher_panic() { #[test] fn catches_double_panic() { - let rocket = rocket().register(catchers![double_panic]); + #[catch(500)] + fn double_panic() { + panic!("so, so sorry...") + } + + let rocket = rocket().register("/", catchers![panic_catcher, double_panic]); let client = Client::debug(rocket).unwrap(); let response = client.get("/noroute").dispatch(); assert_eq!(response.status(), Status::InternalServerError); @@ -70,7 +66,8 @@ fn catches_double_panic() { #[test] fn catches_early_route_panic() { - let client = Client::debug(rocket()).unwrap(); + let rocket = rocket().register("/", catchers![panic_catcher, ise]); + let client = Client::debug(rocket).unwrap(); let response = client.get("/pre").dispatch(); assert_eq!(response.status(), Status::InternalServerError); assert_eq!(response.into_string().unwrap(), "Hey, sorry! :("); @@ -78,9 +75,15 @@ fn catches_early_route_panic() { #[test] fn catches_early_catcher_panic() { - let panic_catcher = Catcher::new(404, pre_future_catcher); + fn pre_future_catcher<'r>(_: Status, _: &'r Request) -> ErrorHandlerFuture<'r> { + panic!("a panicking pre-future catcher") + } - let client = Client::debug(rocket().register(vec![panic_catcher])).unwrap(); + let rocket = rocket() + .register("/", vec![Catcher::new(404, pre_future_catcher)]) + .register("/", catchers![ise]); + + let client = Client::debug(rocket).unwrap(); let response = client.get("/idontexist").dispatch(); assert_eq!(response.status(), Status::InternalServerError); assert_eq!(response.into_string().unwrap(), "Hey, sorry! :("); diff --git a/core/lib/tests/redirect_from_catcher-issue-113.rs b/core/lib/tests/redirect_from_catcher-issue-113.rs index c34c2a6f..a09c923b 100644 --- a/core/lib/tests/redirect_from_catcher-issue-113.rs +++ b/core/lib/tests/redirect_from_catcher-issue-113.rs @@ -14,7 +14,7 @@ mod tests { #[test] fn error_catcher_redirect() { - let client = Client::debug(rocket::ignite().register(catchers![not_found])).unwrap(); + let client = Client::debug(rocket::ignite().register("/", catchers![not_found])).unwrap(); let response = client.get("/unknown").dispatch(); let location: Vec<_> = response.headers().get("location").collect(); diff --git a/examples/content_types/src/main.rs b/examples/content_types/src/main.rs index ea694f3c..44d2e889 100644 --- a/examples/content_types/src/main.rs +++ b/examples/content_types/src/main.rs @@ -63,5 +63,5 @@ fn not_found(request: &Request<'_>) -> Html { fn rocket() -> rocket::Rocket { rocket::ignite() .mount("/hello", routes![get_hello, post_hello]) - .register(catchers![not_found]) + .register("/", catchers![not_found]) } diff --git a/examples/errors/src/main.rs b/examples/errors/src/main.rs index 1332b8a5..652ab253 100644 --- a/examples/errors/src/main.rs +++ b/examples/errors/src/main.rs @@ -7,7 +7,7 @@ use rocket::response::{content, status}; use rocket::http::Status; #[get("/hello//")] -fn hello(name: String, age: i8) -> String { +fn hello(name: &str, age: i8) -> String { format!("Hello, {} year old named {}!", age, name) } @@ -17,10 +17,24 @@ fn forced_error(code: u16) -> Status { } #[catch(404)] -fn not_found(req: &Request<'_>) -> content::Html { - content::Html(format!("

Sorry, but '{}' is not a valid path!

-

Try visiting /hello/<name>/<age> instead.

", - req.uri())) +fn general_not_found() -> content::Html<&'static str> { + content::Html(r#" +

Hmm... What are you looking for?

+ Say
hello! + "#) +} + +#[catch(404)] +fn hello_not_found(req: &Request<'_>) -> content::Html { + content::Html(format!("\ +

Sorry, but '{}' is not a valid path!

\ +

Try visiting /hello/<name>/<age> instead.

", + req.uri())) +} + +#[catch(default)] +fn sergio_error() -> &'static str { + "I...don't know what to say." } #[catch(default)] @@ -33,7 +47,9 @@ fn rocket() -> rocket::Rocket { rocket::ignite() // .mount("/", routes![hello, hello]) // uncoment this to get an error .mount("/", routes![hello, forced_error]) - .register(catchers![not_found, default_catcher]) + .register("/", catchers![general_not_found, default_catcher]) + .register("/hello", catchers![hello_not_found]) + .register("/hello/Sergio", catchers![sergio_error]) } #[rocket::main] diff --git a/examples/errors/src/tests.rs b/examples/errors/src/tests.rs index 7d5e6c24..3c5195f4 100644 --- a/examples/errors/src/tests.rs +++ b/examples/errors/src/tests.rs @@ -14,11 +14,11 @@ fn test_hello() { } #[test] -fn forced_error_and_default_catcher() { +fn forced_error() { let client = Client::tracked(super::rocket()).unwrap(); let request = client.get("/404"); - let expected = super::not_found(request.inner()); + let expected = super::general_not_found(); let response = request.dispatch(); assert_eq!(response.status(), Status::NotFound); assert_eq!(response.into_string().unwrap(), expected.0); @@ -46,11 +46,24 @@ fn forced_error_and_default_catcher() { fn test_hello_invalid_age() { let client = Client::tracked(super::rocket()).unwrap(); - for &(name, age) in &[("Ford", -129), ("Trillian", 128)] { - let request = client.get(format!("/hello/{}/{}", name, age)); - let expected = super::not_found(request.inner()); + for path in &["Ford/-129", "Trillian/128", "foo/bar/baz"] { + let request = client.get(format!("/hello/{}", path)); + let expected = super::hello_not_found(request.inner()); let response = request.dispatch(); assert_eq!(response.status(), Status::NotFound); assert_eq!(response.into_string().unwrap(), expected.0); } } + +#[test] +fn test_hello_sergio() { + let client = Client::tracked(super::rocket()).unwrap(); + + for path in &["oops", "-129", "foo/bar", "/foo/bar/baz"] { + let request = client.get(format!("/hello/Sergio/{}", path)); + let expected = super::sergio_error(); + let response = request.dispatch(); + assert_eq!(response.status(), Status::NotFound); + assert_eq!(response.into_string().unwrap(), expected); + } +} diff --git a/examples/handlebars_templates/src/main.rs b/examples/handlebars_templates/src/main.rs index 04d3cafe..a2421f5f 100644 --- a/examples/handlebars_templates/src/main.rs +++ b/examples/handlebars_templates/src/main.rs @@ -69,7 +69,7 @@ fn wow_helper( fn rocket() -> rocket::Rocket { rocket::ignite() .mount("/", routes![index, hello, about]) - .register(catchers![not_found]) + .register("/", catchers![not_found]) .attach(Template::custom(|engines| { engines.handlebars.register_helper("wow", Box::new(wow_helper)); })) diff --git a/examples/hello_2018/src/main.rs b/examples/hello_2018/src/main.rs index f4720ca5..51651475 100644 --- a/examples/hello_2018/src/main.rs +++ b/examples/hello_2018/src/main.rs @@ -10,7 +10,7 @@ fn hello() -> &'static str { fn rocket() -> rocket::Rocket { rocket::ignite() .mount("/", rocket::routes![hello]) - .register(rocket::catchers![not_found]) + .register("/", rocket::catchers![not_found]) } #[rocket::catch(404)] diff --git a/examples/json/src/main.rs b/examples/json/src/main.rs index f58c8331..72e849c4 100644 --- a/examples/json/src/main.rs +++ b/examples/json/src/main.rs @@ -75,6 +75,6 @@ fn not_found() -> JsonValue { fn rocket() -> _ { rocket::ignite() .mount("/message", routes![new, update, get, echo]) - .register(catchers![not_found]) + .register("/", catchers![not_found]) .manage(Mutex::new(HashMap::::new())) } diff --git a/examples/manual_routes/src/main.rs b/examples/manual_routes/src/main.rs index 2a67bf6a..d24cdce1 100644 --- a/examples/manual_routes/src/main.rs +++ b/examples/manual_routes/src/main.rs @@ -110,5 +110,5 @@ fn rocket() -> rocket::Rocket { .mount("/hello", vec![name.clone()]) .mount("/hi", vec![name]) .mount("/custom", CustomHandler::new("some data here")) - .register(vec![not_found_catcher]) + .register("/", vec![not_found_catcher]) } diff --git a/examples/tera_templates/src/main.rs b/examples/tera_templates/src/main.rs index 3f617773..0ca25378 100644 --- a/examples/tera_templates/src/main.rs +++ b/examples/tera_templates/src/main.rs @@ -37,5 +37,5 @@ fn rocket() -> rocket::Rocket { rocket::ignite() .mount("/", routes![index, get]) .attach(Template::fairing()) - .register(catchers![not_found]) + .register("/", catchers![not_found]) } diff --git a/site/guide/4-requests.md b/site/guide/4-requests.md index b7e0aa74..09d3de04 100644 --- a/site/guide/4-requests.md +++ b/site/guide/4-requests.md @@ -1716,8 +1716,8 @@ Application processing is fallible. Errors arise from the following sources: * A routing failure. If any of these occur, Rocket returns an error to the client. To generate the -error, Rocket invokes the _catcher_ corresponding to the error's status code. -Catchers are similar to routes except in that: +error, Rocket invokes the _catcher_ corresponding to the error's status code and +scope. Catchers are similar to routes except in that: 1. Catchers are only invoked on error conditions. 2. Catchers are declared with the `catch` attribute. @@ -1725,6 +1725,7 @@ Catchers are similar to routes except in that: 4. Any modifications to cookies are cleared before a catcher is invoked. 5. Error catchers cannot invoke guards. 6. Error catchers should not fail to produce a response. + 7. Catchers are scoped to a path prefix. To declare a catcher for a given status code, use the [`catch`] attribute, which takes a single integer corresponding to the HTTP status code to catch. For @@ -1770,36 +1771,96 @@ looks like: # #[catch(404)] fn not_found(req: &Request) { /* .. */ } fn main() { - rocket::ignite().register(catchers![not_found]); + rocket::ignite().register("/", catchers![not_found]); } ``` -### Default Catchers +### Scoping -If no catcher for a given status code has been registered, Rocket calls the -_default_ catcher. Rocket provides a default catcher for all applications -automatically, so providing one is usually unnecessary. Rocket's built-in -default catcher can handle all errors. It produces HTML or JSON, depending on -the value of the `Accept` header. As such, a default catcher, or catchers in -general, only need to be registered if an error needs to be handled in a custom -fashion. +The first argument to `register()` is a path to scope the catcher under called +the catcher's _base_. A catcher's base determines which requests it will handle +errors for. Specifically, a catcher's base must be a prefix of the erroring +request for it to be invoked. When multiple catchers can be invoked, the catcher +with the longest base takes precedence. -Declaring a default catcher is done with `#[catch(default)]`: +As an example, consider the following application: + +```rust +# #[macro_use] extern crate rocket; + +#[catch(404)] +fn general_not_found() -> &'static str { + "General 404" +} + +#[catch(404)] +fn foo_not_found() -> &'static str { + "Foo 404" +} + +#[launch] +fn rocket() -> _ { + rocket::ignite() + .register("/", catchers![general_not_found]) + .register("/foo", catchers![foo_not_found]) +} + +# let client = rocket::local::blocking::Client::debug(rocket()).unwrap(); +# +# let response = client.get("/").dispatch(); +# assert_eq!(response.into_string().unwrap(), "General 404"); +# +# let response = client.get("/bar").dispatch(); +# assert_eq!(response.into_string().unwrap(), "General 404"); +# +# let response = client.get("/bar/baz").dispatch(); +# assert_eq!(response.into_string().unwrap(), "General 404"); +# +# let response = client.get("/foo").dispatch(); +# assert_eq!(response.into_string().unwrap(), "Foo 404"); +# +# let response = client.get("/foo/bar").dispatch(); +# assert_eq!(response.into_string().unwrap(), "Foo 404"); +``` + +Since there are no mounted routes, all requests will `404`. Any request whose +path begins with `/foo` (i.e, `GET /foo`, `GET /foo/bar`, etc) will be handled +by the `foo_not_found` catcher while all other requests will be handled by the +`general_not_found` catcher. + +### Default Catchers + +A _default_ catcher is a catcher that handles _all_ status codes. They are +invoked as a fallback if no status-specific catcher is registered for a given +error. Declaring a default catcher is done with `#[catch(default)]` and must +similarly be registered with [`register()`]: ```rust # #[macro_use] extern crate rocket; -# fn main() {} use rocket::Request; use rocket::http::Status; #[catch(default)] fn default_catcher(status: Status, request: &Request) { /* .. */ } + +#[launch] +fn rocket() -> _ { + rocket::ignite().register("/", catchers![default_catcher]) +} ``` -It must similarly be registered with [`register()`]. +Catchers with longer bases are preferred, even when there is a status-specific +catcher. In other words, a default catcher with a longer matching base than a +status-specific catcher takes precedence. -The [error catcher example](@example/errors) illustrates their use in full, +### Built-In Catcher + +Rocket provides a built-in default catcher. It produces HTML or JSON, depending +on the value of the `Accept` header. As such, custom catchers only need to be +registered for custom error handling. + +The [error catcher example](@example/errors) illustrates catcher use in full, while the [`Catcher`] API documentation provides further details. [`catch`]: @api/rocket/attr.catch.html