Introduce scoped catchers.

Catchers can now be scoped to paths, with preference given to the
longest-prefix, then the status code. This a breaking change for all
applications that register catchers:

  * `Rocket::register()` takes a base path to scope catchers under.
    - The previous behavior is recovered with `::register("/", ...)`.
  * Catchers now fallibly, instead of silently, collide.
  * `ErrorKind::Collision` is now `ErrorKind::Collisions`.

Related changes:

  * `Origin` implements `TryFrom<String>`, `TryFrom<&str>`.
  * All URI variants implement `TryFrom<Uri>`.
  * Added `Segments::prefix_of()`.
  * `Rocket::mount()` takes a  `TryInto<Origin<'_>>` instead of `&str`
    for the base mount point.
  * Extended `errors` example with scoped catchers.
  * Added scoped sections to catchers guide.

Internal changes:

  * Moved router code to `router/router.rs`.
This commit is contained in:
Sergio Benitez 2021-03-25 21:36:00 -07:00
parent c3bad3a287
commit 2893ce754d
25 changed files with 1166 additions and 695 deletions

View File

@ -46,7 +46,7 @@ fn test_rank_collision() {
let rocket = rocket::ignite().mount("/", routes![get0, get0b]);
let client_result = Client::debug(rocket);
match client_result.as_ref().map_err(|e| e.kind()) {
Err(ErrorKind::Collision(..)) => { /* o.k. */ },
Err(ErrorKind::Collisions(..)) => { /* o.k. */ },
Ok(_) => panic!("client succeeded unexpectedly"),
Err(e) => panic!("expected collision, got {}", e)
}

View File

@ -23,7 +23,7 @@ fn catch(r#raw: &rocket::Request) -> String {
fn test_raw_ident() {
let rocket = rocket::ignite()
.mount("/", routes![get, swap])
.register(catchers![catch]);
.register("/", catchers![catch]);
let client = Client::debug(rocket).unwrap();

View File

@ -1,5 +1,6 @@
use std::fmt::{self, Display};
use std::borrow::Cow;
use std::convert::TryFrom;
use std::fmt::{self, Display};
use crate::ext::IntoOwned;
use crate::parse::{Indexed, Extent, IndexedStr};
@ -608,6 +609,22 @@ impl<'a> Origin<'a> {
}
}
impl TryFrom<String> for Origin<'static> {
type Error = Error<'static>;
fn try_from(value: String) -> Result<Self, Self::Error> {
Origin::parse_owned(value)
}
}
impl<'a> TryFrom<&'a str> for Origin<'a> {
type Error = Error<'a>;
fn try_from(value: &'a str) -> Result<Self, Self::Error> {
Origin::parse(value)
}
}
impl Display for Origin<'_> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{}", self.path())?;

View File

@ -75,6 +75,16 @@ impl<'o> Segments<'o> {
.map(|i| i.from_source(Some(self.source.as_str())))
}
/// Returns `true` if `self` is a prefix of `other`.
#[inline]
pub fn prefix_of<'b>(self, other: Segments<'b>) -> bool {
if self.len() > other.len() {
return false;
}
self.zip(other).all(|(a, b)| a == b)
}
/// Creates a `PathBuf` from `self`. The returned `PathBuf` is
/// percent-decoded. If a segment is equal to "..", the previous segment (if
/// any) is skipped.

View File

@ -282,6 +282,16 @@ impl Display for Uri<'_> {
}
}
/// The error type returned when a URI conversion fails.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub struct TryFromUriError(());
impl fmt::Display for TryFromUriError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
"invalid conversion from general to specific URI variant".fmt(f)
}
}
macro_rules! impl_uri_from {
($type:ident) => (
impl<'a> From<$type<'a>> for Uri<'a> {
@ -289,6 +299,17 @@ macro_rules! impl_uri_from {
Uri::$type(other)
}
}
impl<'a> TryFrom<Uri<'a>> for $type<'a> {
type Error = TryFromUriError;
fn try_from(uri: Uri<'a>) -> Result<Self, Self::Error> {
match uri {
Uri::$type(inner) => Ok(inner),
_ => Err(TryFromUriError(()))
}
}
}
)
}

View File

@ -1,5 +1,5 @@
//! Types and traits for error catchers, error handlers, and their return
//! values.
//! types.
use std::fmt;
use std::io::Cursor;
@ -7,7 +7,7 @@ use std::io::Cursor;
use crate::response::Response;
use crate::codegen::StaticCatcherInfo;
use crate::request::Request;
use crate::http::ContentType;
use crate::http::{Status, ContentType, uri};
use futures::future::BoxFuture;
use yansi::Paint;
@ -19,6 +19,12 @@ pub type Result<'r> = std::result::Result<Response<'r>, crate::http::Status>;
/// Type alias for the unwieldy [`ErrorHandler::handle()`] return type.
pub type ErrorHandlerFuture<'r> = BoxFuture<'r, Result<'r>>;
// A handler to use when one is needed temporarily. Don't use outside of Rocket!
#[cfg(test)]
pub(crate) fn dummy<'r>(_: Status, _: &'r Request<'_>) -> ErrorHandlerFuture<'r> {
Box::pin(async move { Ok(Response::new()) })
}
/// An error catching route.
///
/// # Overview
@ -77,12 +83,12 @@ pub type ErrorHandlerFuture<'r> = BoxFuture<'r, Result<'r>>;
///
/// #[catch(default)]
/// fn default(status: Status, req: &Request) -> String {
/// format!("{} - {} ({})", status.code, status.reason, req.uri())
/// format!("{} ({})", status, req.uri())
/// }
///
/// #[launch]
/// fn rocket() -> rocket::Rocket {
/// rocket::ignite().register(catchers![internal_error, not_found, default])
/// rocket::ignite().register("/", catchers![internal_error, not_found, default])
/// }
/// ```
///
@ -104,7 +110,10 @@ pub struct Catcher {
/// The name of this catcher, if one was given.
pub name: Option<Cow<'static, str>>,
/// The HTTP status code to match against if this route is not `default`.
/// The mount point.
pub base: uri::Origin<'static>,
/// The HTTP status to match against if this route is not `default`.
pub code: Option<u16>,
/// The catcher's associated error handler.
@ -112,8 +121,8 @@ pub struct Catcher {
}
impl Catcher {
/// Creates a catcher for the given status code, or a default catcher if
/// `code` is `None`, using the given error handler. This should only be
/// Creates a catcher for the given `status`, or a default catcher if
/// `status` is `None`, using the given error handler. This should only be
/// used when routing manually.
///
/// # Examples
@ -121,11 +130,11 @@ impl Catcher {
/// ```rust
/// use rocket::request::Request;
/// use rocket::catcher::{Catcher, ErrorHandlerFuture};
/// use rocket::response::{Result, Responder, status::Custom};
/// use rocket::response::Responder;
/// use rocket::http::Status;
///
/// fn handle_404<'r>(status: Status, req: &'r Request<'_>) -> ErrorHandlerFuture<'r> {
/// let res = Custom(status, format!("404: {}", req.uri()));
/// let res = (status, format!("404: {}", req.uri()));
/// Box::pin(async move { res.respond_to(req) })
/// }
///
@ -134,7 +143,7 @@ impl Catcher {
/// }
///
/// fn handle_default<'r>(status: Status, req: &'r Request<'_>) -> ErrorHandlerFuture<'r> {
/// let res = Custom(status, format!("{}: {}", status, req.uri()));
/// let res = (status, format!("{}: {}", status, req.uri()));
/// Box::pin(async move { res.respond_to(req) })
/// }
///
@ -142,22 +151,82 @@ impl Catcher {
/// let internal_server_error_catcher = Catcher::new(500, handle_500);
/// let default_error_catcher = Catcher::new(None, handle_default);
/// ```
///
/// # Panics
///
/// Panics if `code` is not in the HTTP status code error range `[400,
/// 600)`.
#[inline(always)]
pub fn new<C, H>(code: C, handler: H) -> Catcher
where C: Into<Option<u16>>, H: ErrorHandler
pub fn new<S, H>(code: S, handler: H) -> Catcher
where S: Into<Option<u16>>, H: ErrorHandler
{
Catcher { name: None, code: code.into(), handler: Box::new(handler) }
let code = code.into();
if let Some(code) = code {
assert!(code >= 400 && code < 600);
}
Catcher {
name: None,
base: uri::Origin::new("/", None::<&str>),
handler: Box::new(handler),
code,
}
}
/// Maps the `base` of this catcher using `mapper`, returning a new
/// `Catcher` with the returned base.
///
/// `mapper` is called with the current base. The returned `String` is used
/// as the new base if it is a valid URI. If the returned base URI contains
/// a query, it is ignored. Returns an error if the base produced by
/// `mapper` is not a valid origin URI.
///
/// # Example
///
/// ```rust
/// use rocket::request::Request;
/// use rocket::catcher::{Catcher, ErrorHandlerFuture};
/// use rocket::response::Responder;
/// use rocket::http::Status;
///
/// fn handle_404<'r>(status: Status, req: &'r Request<'_>) -> ErrorHandlerFuture<'r> {
/// let res = (status, format!("404: {}", req.uri()));
/// Box::pin(async move { res.respond_to(req) })
/// }
///
/// let catcher = Catcher::new(404, handle_404);
/// assert_eq!(catcher.base.path(), "/");
///
/// let catcher = catcher.map_base(|_| format!("/bar")).unwrap();
/// assert_eq!(catcher.base.path(), "/bar");
///
/// let catcher = catcher.map_base(|base| format!("/foo{}", base)).unwrap();
/// assert_eq!(catcher.base.path(), "/foo/bar");
///
/// let catcher = catcher.map_base(|base| format!("/foo ? {}", base));
/// assert!(catcher.is_err());
/// ```
pub fn map_base<'a, F>(
mut self,
mapper: F
) -> std::result::Result<Self, uri::Error<'static>>
where F: FnOnce(uri::Origin<'a>) -> String
{
self.base = uri::Origin::parse_owned(mapper(self.base))?.into_normalized();
self.base.clear_query();
Ok(self)
}
}
impl Default for Catcher {
fn default() -> Self {
fn async_default<'r>(status: Status, request: &'r Request<'_>) -> ErrorHandlerFuture<'r> {
Box::pin(async move { Ok(default(status, request)) })
fn handler<'r>(s: Status, req: &'r Request<'_>) -> ErrorHandlerFuture<'r> {
Box::pin(async move { Ok(default(s, req)) })
}
let name = Some("<Rocket Catcher>".into());
Catcher { name, code: None, handler: Box::new(async_default) }
let mut catcher = Catcher::new(None, handler);
catcher.name = Some("<Rocket Catcher>".into());
catcher
}
}
@ -226,9 +295,9 @@ impl Default for Catcher {
/// fn rocket() -> rocket::Rocket {
/// rocket::ignite()
/// // to handle only `404`
/// .register(CustomHandler::catch(Status::NotFound, Kind::Simple))
/// .register("/", CustomHandler::catch(Status::NotFound, Kind::Simple))
/// // or to register as the default
/// .register(CustomHandler::default(Kind::Simple))
/// .register("/", CustomHandler::default(Kind::Simple))
/// }
/// ```
///
@ -239,7 +308,7 @@ impl Default for Catcher {
/// trait serves no other purpose but to ensure that every `ErrorHandler`
/// can be cloned, allowing `Catcher`s to be cloned.
/// 2. `CustomHandler`'s methods return `Vec<Route>`, allowing for use
/// directly as the parameter to `rocket.register()`.
/// directly as the parameter to `rocket.register("/", )`.
/// 3. Unlike static-function-based handlers, this custom handler can make use
/// of internal state.
#[crate::async_trait]
@ -288,6 +357,10 @@ impl fmt::Display for Catcher {
write!(f, "{}{}{} ", Paint::cyan("("), Paint::white(n), Paint::cyan(")"))?;
}
if self.base.path() != "/" {
write!(f, "{} ", Paint::green(self.base.path()))?;
}
match self.code {
Some(code) => write!(f, "{}", Paint::blue(code)),
None => write!(f, "{}", Paint::blue("default"))
@ -298,6 +371,8 @@ impl fmt::Display for Catcher {
impl fmt::Debug for Catcher {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("Catcher")
.field("name", &self.name)
.field("base", &self.base)
.field("code", &self.code)
.finish()
}
@ -359,7 +434,6 @@ r#"{{
macro_rules! default_catcher_fn {
($($code:expr, $reason:expr, $description:expr),+) => (
use std::borrow::Cow;
use crate::http::Status;
pub(crate) fn default<'r>(status: Status, req: &'r Request<'_>) -> Response<'r> {
let preferred = req.accept().map(|a| a.preferred());

View File

@ -6,8 +6,6 @@ use std::sync::atomic::{Ordering, AtomicBool};
use yansi::Paint;
use figment::Profile;
use crate::router::Route;
/// An error that occurs during launch.
///
/// An `Error` is returned by [`launch()`](crate::Rocket::launch()) when
@ -80,7 +78,7 @@ pub enum ErrorKind {
/// An I/O error occurred in the runtime.
Runtime(Box<dyn std::error::Error + Send + Sync>),
/// Route collisions were detected.
Collision(Vec<(Route, Route)>),
Collisions(crate::router::Collisions),
/// Launch fairing(s) failed.
FailedFairings(Vec<crate::fairing::Info>),
/// The configuration profile is not debug but not secret key is configured.
@ -140,7 +138,7 @@ impl fmt::Display for ErrorKind {
match self {
ErrorKind::Bind(e) => write!(f, "binding failed: {}", e),
ErrorKind::Io(e) => write!(f, "I/O error: {}", e),
ErrorKind::Collision(_) => "route collisions detected".fmt(f),
ErrorKind::Collisions(_) => "collisions detected".fmt(f),
ErrorKind::FailedFairings(_) => "a launch fairing failed".fmt(f),
ErrorKind::Runtime(e) => write!(f, "runtime error: {}", e),
ErrorKind::InsecureSecretKey(_) => "insecure secret key config".fmt(f),
@ -181,14 +179,21 @@ impl Drop for Error {
info_!("{}", e);
panic!("aborting due to i/o error");
}
ErrorKind::Collision(ref collisions) => {
error!("Rocket failed to launch due to the following routing collisions:");
for &(ref a, ref b) in collisions {
info_!("{} {} {}", a, Paint::red("collides with").italic(), b)
ErrorKind::Collisions(ref collisions) => {
fn log_collisions<T: fmt::Display>(kind: &str, collisions: &[(T, T)]) {
if collisions.is_empty() { return }
error!("Rocket failed to launch due to the following {} collisions:", kind);
for &(ref a, ref b) in collisions {
info_!("{} {} {}", a, Paint::red("collides with").italic(), b)
}
}
info_!("Note: Collisions can usually be resolved by ranking routes.");
panic!("route collisions detected");
log_collisions("route", &collisions.routes);
log_collisions("catcher", &collisions.catchers);
info_!("Note: Route collisions can usually be resolved by ranking routes.");
panic!("routing collisions detected");
}
ErrorKind::FailedFairings(ref failures) => {
error!("Rocket failed to launch due to failing fairings:");

View File

@ -1,4 +1,5 @@
use std::collections::HashMap;
use std::fmt::Display;
use std::convert::TryInto;
use yansi::Paint;
use state::Container;
@ -13,7 +14,7 @@ use crate::router::{Router, Route};
use crate::fairing::{Fairing, Fairings};
use crate::logger::PaintExt;
use crate::shutdown::Shutdown;
use crate::http::uri::Origin;
use crate::http::{uri::Origin, ext::IntoOwned};
use crate::error::{Error, ErrorKind};
/// The main `Rocket` type: used to mount routes and catchers and launch the
@ -24,8 +25,6 @@ pub struct Rocket {
pub(crate) figment: Figment,
pub(crate) managed_state: Container![Send + Sync],
pub(crate) router: Router,
pub(crate) default_catcher: Option<Catcher>,
pub(crate) catchers: HashMap<u16, Catcher>,
pub(crate) fairings: Fairings,
pub(crate) shutdown_receiver: Option<mpsc::Receiver<()>>,
pub(crate) shutdown_handle: Shutdown,
@ -95,8 +94,6 @@ impl Rocket {
config, figment, managed_state,
shutdown_handle: Shutdown(shutdown_sender),
router: Router::new(),
default_catcher: None,
catchers: HashMap::new(),
fairings: Fairings::new(),
shutdown_receiver: Some(shutdown_receiver),
}
@ -203,10 +200,15 @@ impl Rocket {
/// # .launch().await;
/// # };
/// ```
pub fn mount<R: Into<Vec<Route>>>(mut self, base: &str, routes: R) -> Self {
let base_uri = Origin::parse_owned(base.to_string())
pub fn mount<'a, B, R>(mut self, base: B, routes: R) -> Self
where B: TryInto<Origin<'a>> + Clone + Display,
B::Error: Display,
R: Into<Vec<Route>>
{
let base_uri = base.clone().try_into()
.map(|origin| origin.into_owned())
.unwrap_or_else(|e| {
error!("Invalid mount point URI: {}.", Paint::white(base));
error!("Invalid route base: {}.", Paint::white(&base));
panic!("Error: {}", e);
});
@ -215,11 +217,11 @@ impl Rocket {
panic!("Invalid mount point.");
}
info!("{}{} {}{}",
info!("{}{} {} {}",
Paint::emoji("🛰 "),
Paint::magenta("Mounting"),
Paint::blue(&base_uri),
Paint::magenta(":"));
Paint::magenta("routes:"));
for route in routes.into() {
let mounted_route = route.clone()
@ -231,13 +233,18 @@ impl Rocket {
});
info_!("{}", mounted_route);
self.router.add(mounted_route);
self.router.add_route(mounted_route);
}
self
}
/// Registers all of the catchers in the supplied vector.
/// Registers all of the catchers in the supplied vector, scoped to `base`.
///
/// # Panics
///
/// Panics if `base` is not a valid static path: a valid origin URI without
/// dynamic parameters.
///
/// # Examples
///
@ -257,23 +264,31 @@ impl Rocket {
///
/// #[launch]
/// fn rocket() -> rocket::Rocket {
/// rocket::ignite().register(catchers![internal_error, not_found])
/// rocket::ignite().register("/", catchers![internal_error, not_found])
/// }
/// ```
pub fn register(mut self, catchers: Vec<Catcher>) -> Self {
info!("{}{}", Paint::emoji("👾 "), Paint::magenta("Catchers:"));
pub fn register<'a, B, C>(mut self, base: B, catchers: C) -> Self
where B: TryInto<Origin<'a>> + Clone + Display,
B::Error: Display,
C: Into<Vec<Catcher>>
{
info!("{}{} {} {}",
Paint::emoji("👾 "),
Paint::magenta("Registering"),
Paint::blue(&base),
Paint::magenta("catchers:"));
for catcher in catchers {
info_!("{}", catcher);
for catcher in catchers.into() {
let mounted_catcher = catcher.clone()
.map_base(|old| format!("{}{}", base, old))
.unwrap_or_else(|e| {
error_!("Catcher `{}` has a malformed URI.", catcher);
error_!("{}", e);
panic!("Invalid catcher URI.");
});
let existing = match catcher.code {
Some(code) => self.catchers.insert(code, catcher),
None => self.default_catcher.replace(catcher)
};
if let Some(existing) = existing {
warn_!("Replacing existing '{}' catcher.", existing);
}
info_!("{}", mounted_catcher);
self.router.add_catcher(mounted_catcher);
}
self
@ -444,7 +459,7 @@ impl Rocket {
/// }
/// ```
#[inline(always)]
pub fn routes(&self) -> impl Iterator<Item = &Route> + '_ {
pub fn routes(&self) -> impl Iterator<Item = &Route> {
self.router.routes()
}
@ -464,7 +479,7 @@ impl Rocket {
///
/// fn main() {
/// let mut rocket = rocket::ignite()
/// .register(catchers![not_found, just_500, some_default]);
/// .register("/", catchers![not_found, just_500, some_default]);
///
/// let mut codes: Vec<_> = rocket.catchers().map(|c| c.code).collect();
/// codes.sort();
@ -473,8 +488,8 @@ impl Rocket {
/// }
/// ```
#[inline(always)]
pub fn catchers(&self) -> impl Iterator<Item = &Catcher> + '_ {
self.catchers.values().chain(self.default_catcher.as_ref())
pub fn catchers(&self) -> impl Iterator<Item = &Catcher> {
self.router.catchers()
}
/// Returns `Some` of the managed state value for the type `T` if it is
@ -525,10 +540,8 @@ impl Rocket {
/// * there were no fairing failures
/// * a secret key, if needed, is securely configured
pub(crate) async fn prelaunch_check(&mut self) -> Result<(), Error> {
let collisions: Vec<_> = self.router.collisions().collect();
if !collisions.is_empty() {
let owned = collisions.into_iter().map(|(a, b)| (a.clone(), b.clone()));
return Err(Error::new(ErrorKind::Collision(owned.collect())));
if let Err(collisions) = self.router.finalize() {
return Err(Error::new(ErrorKind::Collisions(collisions)));
}
if let Some(failures) = self.fairings.failures() {

View File

@ -1,46 +1,23 @@
use super::{Route, uri::Color};
use crate::catcher::Catcher;
use crate::http::MediaType;
use crate::http::{MediaType, Status};
use crate::request::Request;
impl Route {
/// Determines if two routes can match against some request. That is, if two
/// routes `collide`, there exists a request that can match against both
/// routes.
///
/// This implementation is used at initialization to check if two user
/// routes collide before launching. Format collisions works like this:
///
/// * If route specifies a format, it only gets requests for that format.
/// * If route doesn't specify a format, it gets requests for any format.
///
/// Because query parsing is lenient, and dynamic query parameters can be
/// missing, queries do not impact whether two routes collide.
pub(crate) fn collides_with(&self, other: &Route) -> bool {
self.method == other.method
&& self.rank == other.rank
&& paths_collide(self, other)
&& formats_collide(self, other)
}
pub trait Collide<T = Self> {
fn collides_with(&self, other: &T) -> bool;
}
/// Determines if this route matches against the given request.
///
/// This means that:
///
/// * The route's method matches that of the incoming request.
/// * The route's format (if any) matches that of the incoming request.
/// - If route specifies format, it only gets requests for that format.
/// - If route doesn't specify format, it gets requests for any format.
/// * All static components in the route's path match the corresponding
/// components in the same position in the incoming request.
/// * All static components in the route's query string are also in the
/// request query string, though in any position. If there is no query
/// in the route, requests with/without queries match.
pub(crate) fn matches(&self, req: &Request<'_>) -> bool {
self.method == req.method()
&& paths_match(self, req)
&& queries_match(self, req)
&& formats_match(self, req)
impl<'a, 'b, T: Collide> Collide<&T> for &T {
fn collides_with(&self, other: &&T) -> bool {
T::collides_with(*self, *other)
}
}
impl Collide for MediaType {
fn collides_with(&self, other: &Self) -> bool {
let collide = |a, b| a == "*" || b == "*" || a == b;
collide(self.top(), other.top()) && collide(self.sub(), other.sub())
}
}
@ -66,6 +43,68 @@ fn paths_collide(route: &Route, other: &Route) -> bool {
|| a_segments.len() == b_segments.len()
}
fn formats_collide(route: &Route, other: &Route) -> bool {
// When matching against the `Accept` header, the client can always
// provide a media type that will cause a collision through
// non-specificity.
if !route.method.supports_payload() {
return true;
}
// When matching against the `Content-Type` header, we'll only
// consider requests as having a `Content-Type` if they're fully
// specified. If a route doesn't have a `format`, it accepts all
// `Content-Type`s. If a request doesn't have a format, it only
// matches routes without a format.
match (route.format.as_ref(), other.format.as_ref()) {
(Some(a), Some(b)) => a.collides_with(b),
_ => true
}
}
impl Collide for Route {
/// Determines if two routes can match against some request. That is, if two
/// routes `collide`, there exists a request that can match against both
/// routes.
///
/// This implementation is used at initialization to check if two user
/// routes collide before launching. Format collisions works like this:
///
/// * If route specifies a format, it only gets requests for that format.
/// * If route doesn't specify a format, it gets requests for any format.
///
/// Because query parsing is lenient, and dynamic query parameters can be
/// missing, queries do not impact whether two routes collide.
fn collides_with(&self, other: &Route) -> bool {
self.method == other.method
&& self.rank == other.rank
&& paths_collide(self, other)
&& formats_collide(self, other)
}
}
impl Route {
/// Determines if this route matches against the given request.
///
/// This means that:
///
/// * The route's method matches that of the incoming request.
/// * The route's format (if any) matches that of the incoming request.
/// - If route specifies format, it only gets requests for that format.
/// - If route doesn't specify format, it gets requests for any format.
/// * All static components in the route's path match the corresponding
/// components in the same position in the incoming request.
/// * All static components in the route's query string are also in the
/// request query string, though in any position. If there is no query
/// in the route, requests with/without queries match.
pub(crate) fn matches(&self, req: &Request<'_>) -> bool {
self.method == req.method()
&& paths_match(self, req)
&& queries_match(self, req)
&& formats_match(self, req)
}
}
fn paths_match(route: &Route, req: &Request<'_>) -> bool {
let route_segments = &route.uri.metadata.path_segs;
let req_segments = req.uri().path_segments();
@ -90,7 +129,7 @@ fn paths_match(route: &Route, req: &Request<'_>) -> bool {
return true;
}
if !route_seg.dynamic && route_seg.value != req_seg {
if !(route_seg.dynamic || route_seg.value == req_seg) {
return false;
}
}
@ -116,33 +155,16 @@ fn queries_match(route: &Route, req: &Request<'_>) -> bool {
true
}
fn formats_collide(route: &Route, other: &Route) -> bool {
// When matching against the `Accept` header, the client can always provide
// a media type that will cause a collision through non-specificity.
if !route.method.supports_payload() {
return true;
}
// When matching against the `Content-Type` header, we'll only consider
// requests as having a `Content-Type` if they're fully specified. If a
// route doesn't have a `format`, it accepts all `Content-Type`s. If a
// request doesn't have a format, it only matches routes without a format.
match (route.format.as_ref(), other.format.as_ref()) {
(Some(a), Some(b)) => media_types_collide(a, b),
_ => true
}
}
fn formats_match(route: &Route, request: &Request<'_>) -> bool {
if !route.method.supports_payload() {
route.format.as_ref()
.and_then(|a| request.format().map(|b| (a, b)))
.map(|(a, b)| media_types_collide(a, b))
.map(|(a, b)| a.collides_with(b))
.unwrap_or(true)
} else {
match route.format.as_ref() {
Some(a) => match request.format() {
Some(b) if b.specificity() == 2 => media_types_collide(a, b),
Some(b) if b.specificity() == 2 => a.collides_with(b),
_ => false
}
None => true
@ -150,9 +172,30 @@ fn formats_match(route: &Route, request: &Request<'_>) -> bool {
}
}
fn media_types_collide(first: &MediaType, other: &MediaType) -> bool {
let collide = |a, b| a == "*" || b == "*" || a == b;
collide(first.top(), other.top()) && collide(first.sub(), other.sub())
impl Collide for Catcher {
/// Determines if two catchers are in conflict: there exists a request for
/// which there exist no rule to determine _which_ of the two catchers to
/// use. This means that the catchers:
///
/// * Have the same base.
/// * Have the same status code or are both defaults.
fn collides_with(&self, other: &Self) -> bool {
self.code == other.code
&& self.base.path_segments().eq(other.base.path_segments())
}
}
impl Catcher {
/// Determines if this catcher is responsible for handling the error with
/// `status` that occurred during request `req`. A catcher matches if:
///
/// * It is a default catcher _or_ has a code of `status`.
/// * Its base is a prefix of the normalized/decoded `req.path()`.
pub(crate) fn matches(&self, status: Status, req: &Request<'_>) -> bool {
self.code.map_or(true, |code| code == status.code)
&& self.base.path_segments().prefix_of(req.uri().path_segments())
}
}
#[cfg(test)]
@ -335,7 +378,7 @@ mod tests {
fn mt_mt_collide(mt1: &str, mt2: &str) -> bool {
let mt_a = MediaType::from_str(mt1).expect(mt1);
let mt_b = MediaType::from_str(mt2).expect(mt2);
media_types_collide(&mt_a, &mt_b)
mt_a.collides_with(&mt_b)
}
#[test]
@ -525,4 +568,35 @@ mod tests {
assert!(!req_route_path_match("/a/b", "/a/b?foo&<rest..>"));
assert!(!req_route_path_match("/a/b", "/a/b?<a>&b&<rest..>"));
}
fn catchers_collide<A, B>(a: A, ap: &str, b: B, bp: &str) -> bool
where A: Into<Option<u16>>, B: Into<Option<u16>>
{
let a = Catcher::new(a, crate::catcher::dummy).map_base(|_| ap.into()).unwrap();
let b = Catcher::new(b, crate::catcher::dummy).map_base(|_| bp.into()).unwrap();
a.collides_with(&b)
}
#[test]
fn catcher_collisions() {
for path in &["/a", "/foo", "/a/b/c", "/a/b/c/d/e"] {
assert!(catchers_collide(404, path, 404, path));
assert!(catchers_collide(500, path, 500, path));
assert!(catchers_collide(None, path, None, path));
}
}
#[test]
fn catcher_non_collisions() {
assert!(!catchers_collide(404, "/foo", 405, "/foo"));
assert!(!catchers_collide(404, "/", None, "/foo"));
assert!(!catchers_collide(404, "/", None, "/"));
assert!(!catchers_collide(404, "/a/b", None, "/a/b"));
assert!(!catchers_collide(404, "/a/b", 404, "/a/b/c"));
assert!(!catchers_collide(None, "/a/b", None, "/a/b/c"));
assert!(!catchers_collide(None, "/b", None, "/a/b/c"));
assert!(!catchers_collide(None, "/", None, "/a/b/c"));
}
}

View File

@ -1,514 +1,13 @@
//! Routing types: [`Route`] and [`RouteUri`].
mod collider;
mod route;
mod segment;
mod uri;
mod router;
mod collider;
use std::collections::HashMap;
pub(crate) use router::*;
use crate::request::Request;
use crate::http::Method;
pub use self::route::Route;
pub use self::uri::RouteUri;
// type Selector = (Method, usize);
type Selector = Method;
#[derive(Debug, Default)]
pub(crate) struct Router {
routes: HashMap<Selector, Vec<Route>>,
}
impl Router {
pub fn new() -> Router {
Router { routes: HashMap::new() }
}
pub fn add(&mut self, route: Route) {
let selector = route.method;
let entries = self.routes.entry(selector).or_insert_with(|| vec![]);
let i = entries.binary_search_by_key(&route.rank, |r| r.rank)
.unwrap_or_else(|i| i);
entries.insert(i, route);
}
pub fn route<'r, 'a: 'r>(&'a self, req: &'r Request<'r>) -> impl Iterator<Item = &'a Route> + 'r {
// Note that routes are presorted by rank on each `add`.
self.routes.get(&req.method())
.into_iter()
.flat_map(move |routes| routes.iter().filter(move |r| r.matches(req)))
}
pub(crate) fn collisions(&self) -> impl Iterator<Item = (&Route, &Route)> {
let all_routes = self.routes.values().flat_map(|v| v.iter());
all_routes.clone().enumerate()
.flat_map(move |(i, a)| {
all_routes.clone()
.skip(i + 1)
.filter(move |b| b.collides_with(a))
.map(move |b| (a, b))
})
}
#[inline]
pub fn routes(&self) -> impl Iterator<Item = &Route> {
self.routes.values().flat_map(|v| v.iter())
}
// This is slow. Don't expose this publicly; only for tests.
#[cfg(test)]
fn has_collisions(&self) -> bool {
self.collisions().next().is_some()
}
}
#[cfg(test)]
mod test {
use super::{Router, Route};
use crate::rocket::Rocket;
use crate::config::Config;
use crate::http::{Method, Method::*};
use crate::http::uri::Origin;
use crate::request::Request;
use crate::handler::dummy;
fn router_with_routes(routes: &[&'static str]) -> Router {
let mut router = Router::new();
for route in routes {
let route = dbg!(Route::new(Get, route, dummy));
router.add(route);
}
router
}
fn router_with_ranked_routes(routes: &[(isize, &'static str)]) -> Router {
let mut router = Router::new();
for &(rank, route) in routes {
let route = Route::ranked(rank, Get, route, dummy);
router.add(route);
}
router
}
fn router_with_rankless_routes(routes: &[&'static str]) -> Router {
let mut router = Router::new();
for route in routes {
let route = Route::ranked(0, Get, route, dummy);
router.add(route);
}
router
}
fn rankless_route_collisions(routes: &[&'static str]) -> bool {
let router = router_with_rankless_routes(routes);
router.has_collisions()
}
fn default_rank_route_collisions(routes: &[&'static str]) -> bool {
let router = router_with_routes(routes);
router.has_collisions()
}
#[test]
fn test_rankless_collisions() {
assert!(rankless_route_collisions(&["/hello", "/hello"]));
assert!(rankless_route_collisions(&["/<a>", "/hello"]));
assert!(rankless_route_collisions(&["/<a>", "/<b>"]));
assert!(rankless_route_collisions(&["/hello/bob", "/hello/<b>"]));
assert!(rankless_route_collisions(&["/a/b/<c>/d", "/<a>/<b>/c/d"]));
assert!(rankless_route_collisions(&["/a/b", "/<a..>"]));
assert!(rankless_route_collisions(&["/a/b/c", "/a/<a..>"]));
assert!(rankless_route_collisions(&["/<a>/b", "/a/<a..>"]));
assert!(rankless_route_collisions(&["/a/<b>", "/a/<a..>"]));
assert!(rankless_route_collisions(&["/a/b/<c>", "/a/<a..>"]));
assert!(rankless_route_collisions(&["/<a..>", "/a/<a..>"]));
assert!(rankless_route_collisions(&["/a/<a..>", "/a/<a..>"]));
assert!(rankless_route_collisions(&["/a/b/<a..>", "/a/<a..>"]));
assert!(rankless_route_collisions(&["/a/b/c/d", "/a/<a..>"]));
assert!(rankless_route_collisions(&["/", "/<a..>"]));
assert!(rankless_route_collisions(&["/a/<_>", "/a/<a..>"]));
assert!(rankless_route_collisions(&["/a/<_>", "/a/<_..>"]));
assert!(rankless_route_collisions(&["/<_>", "/a/<_..>"]));
assert!(rankless_route_collisions(&["/foo", "/foo/<_..>"]));
assert!(rankless_route_collisions(&["/foo/bar/baz", "/foo/<_..>"]));
assert!(rankless_route_collisions(&["/a/d/<b..>", "/a/d"]));
assert!(rankless_route_collisions(&["/a/<_..>", "/<_>"]));
assert!(rankless_route_collisions(&["/a/<_..>", "/a"]));
assert!(rankless_route_collisions(&["/<a>", "/a/<a..>"]));
assert!(rankless_route_collisions(&["/<_>", "/<_>"]));
assert!(rankless_route_collisions(&["/a/<_>", "/a/b"]));
assert!(rankless_route_collisions(&["/a/<_>", "/a/<b>"]));
assert!(rankless_route_collisions(&["/<_..>", "/a/b"]));
assert!(rankless_route_collisions(&["/<_..>", "/<_>"]));
assert!(rankless_route_collisions(&["/<_>/b", "/a/b"]));
assert!(rankless_route_collisions(&["/", "/<foo..>"]));
}
#[test]
fn test_collisions_normalize() {
assert!(rankless_route_collisions(&["/hello/", "/hello"]));
assert!(rankless_route_collisions(&["//hello/", "/hello"]));
assert!(rankless_route_collisions(&["//hello/", "/hello//"]));
assert!(rankless_route_collisions(&["/<a>", "/hello//"]));
assert!(rankless_route_collisions(&["/<a>", "/hello///"]));
assert!(rankless_route_collisions(&["/hello///bob", "/hello/<b>"]));
assert!(rankless_route_collisions(&["/<a..>//", "/a//<a..>"]));
assert!(rankless_route_collisions(&["/a/<a..>//", "/a/<a..>"]));
assert!(rankless_route_collisions(&["/a/<a..>//", "/a/b//c//d/"]));
assert!(rankless_route_collisions(&["/a/<a..>/", "/a/bd/e/"]));
assert!(rankless_route_collisions(&["/<a..>/", "/a/bd/e/"]));
assert!(rankless_route_collisions(&["//", "/<foo..>"]));
assert!(rankless_route_collisions(&["/a/<a..>//", "/a/b//c//d/e/"]));
assert!(rankless_route_collisions(&["/a//<a..>//", "/a/b//c//d/e/"]));
assert!(rankless_route_collisions(&["///<_>", "/<_>"]));
assert!(rankless_route_collisions(&["/a/<_>", "///a//b"]));
assert!(rankless_route_collisions(&["//a///<_>", "/a//<b>"]));
assert!(rankless_route_collisions(&["//<_..>", "/a/b"]));
assert!(rankless_route_collisions(&["//<_..>", "/<_>"]));
assert!(rankless_route_collisions(&["///<a>/", "/a/<a..>"]));
assert!(rankless_route_collisions(&["///<a..>/", "/a/<a..>"]));
assert!(rankless_route_collisions(&["/<a..>", "/hello"]));
}
#[test]
fn test_collisions_query() {
// Query shouldn't affect things when rankless.
assert!(rankless_route_collisions(&["/hello?<foo>", "/hello"]));
assert!(rankless_route_collisions(&["/<a>?foo=bar", "/hello?foo=bar&cat=fat"]));
assert!(rankless_route_collisions(&["/<a>?foo=bar", "/hello?foo=bar&cat=fat"]));
assert!(rankless_route_collisions(&["/<a>", "/<b>?<foo>"]));
assert!(rankless_route_collisions(&["/hello/bob?a=b", "/hello/<b>?d=e"]));
assert!(rankless_route_collisions(&["/<foo>?a=b", "/foo?d=e"]));
assert!(rankless_route_collisions(&["/<foo>?a=b&<c>", "/<foo>?d=e&<c>"]));
assert!(rankless_route_collisions(&["/<foo>?a=b&<c>", "/<foo>?d=e"]));
}
#[test]
fn test_no_collisions() {
assert!(!rankless_route_collisions(&["/a/b", "/a/b/c"]));
assert!(!rankless_route_collisions(&["/a/b/c/d", "/a/b/c/<d>/e"]));
assert!(!rankless_route_collisions(&["/a/d/<b..>", "/a/b/c"]));
assert!(!rankless_route_collisions(&["/<_>", "/"]));
assert!(!rankless_route_collisions(&["/a/<_>", "/a"]));
assert!(!rankless_route_collisions(&["/a/<_>", "/<_>"]));
}
#[test]
fn test_no_collision_when_ranked() {
assert!(!default_rank_route_collisions(&["/<a>", "/hello"]));
assert!(!default_rank_route_collisions(&["/hello/bob", "/hello/<b>"]));
assert!(!default_rank_route_collisions(&["/a/b/c/d", "/<a>/<b>/c/d"]));
assert!(!default_rank_route_collisions(&["/hi", "/<hi>"]));
assert!(!default_rank_route_collisions(&["/a", "/a/<path..>"]));
assert!(!default_rank_route_collisions(&["/", "/<path..>"]));
assert!(!default_rank_route_collisions(&["/a/b", "/a/b/<c..>"]));
assert!(!default_rank_route_collisions(&["/<_>", "/static"]));
assert!(!default_rank_route_collisions(&["/<_..>", "/static"]));
assert!(!default_rank_route_collisions(&["/<path..>", "/"]));
assert!(!default_rank_route_collisions(&["/<_>/<_>", "/foo/bar"]));
assert!(!default_rank_route_collisions(&["/foo/<_>", "/foo/bar"]));
assert!(!default_rank_route_collisions(&["/<a>/<b>", "/hello/<b>"]));
assert!(!default_rank_route_collisions(&["/<a>/<b..>", "/hello/<b>"]));
assert!(!default_rank_route_collisions(&["/<a..>", "/hello/<b>"]));
assert!(!default_rank_route_collisions(&["/<a..>", "/hello"]));
assert!(!default_rank_route_collisions(&["/<a>", "/a/<path..>"]));
assert!(!default_rank_route_collisions(&["/a/<b>/c", "/<d>/<c..>"]));
}
#[test]
fn test_collision_when_ranked() {
assert!(default_rank_route_collisions(&["/a/<b>/<c..>", "/a/<c>"]));
assert!(default_rank_route_collisions(&["/<a>/b", "/a/<b>"]));
}
#[test]
fn test_collision_when_ranked_query() {
assert!(default_rank_route_collisions(&["/a?a=b", "/a?c=d"]));
assert!(default_rank_route_collisions(&["/a?a=b&<b>", "/a?<c>&c=d"]));
assert!(default_rank_route_collisions(&["/a?a=b&<b..>", "/a?<c>&c=d"]));
}
#[test]
fn test_no_collision_when_ranked_query() {
assert!(!default_rank_route_collisions(&["/", "/?<c..>"]));
assert!(!default_rank_route_collisions(&["/hi", "/hi?<c>"]));
assert!(!default_rank_route_collisions(&["/hi", "/hi?c"]));
assert!(!default_rank_route_collisions(&["/hi?<c>", "/hi?c"]));
assert!(!default_rank_route_collisions(&["/<foo>?a=b", "/<foo>?c=d&<d>"]));
}
fn route<'a>(router: &'a Router, method: Method, uri: &'a str) -> Option<&'a Route> {
let rocket = Rocket::custom(Config::default());
let request = Request::new(&rocket, method, Origin::parse(uri).unwrap());
let route = router.route(&request).next();
route
}
fn matches<'a>(router: &'a Router, method: Method, uri: &'a str) -> Vec<&'a Route> {
let rocket = Rocket::custom(Config::default());
let request = Request::new(&rocket, method, Origin::parse(uri).unwrap());
router.route(&request).collect()
}
#[test]
fn test_ok_routing() {
let router = router_with_routes(&["/hello"]);
assert!(route(&router, Get, "/hello").is_some());
let router = router_with_routes(&["/<a>"]);
assert!(route(&router, Get, "/hello").is_some());
assert!(route(&router, Get, "/hi").is_some());
assert!(route(&router, Get, "/bobbbbbbbbbby").is_some());
assert!(route(&router, Get, "/dsfhjasdf").is_some());
let router = router_with_routes(&["/<a>/<b>"]);
assert!(route(&router, Get, "/hello/hi").is_some());
assert!(route(&router, Get, "/a/b/").is_some());
assert!(route(&router, Get, "/i/a").is_some());
assert!(route(&router, Get, "/jdlk/asdij").is_some());
let mut router = Router::new();
router.add(Route::new(Put, "/hello", dummy));
router.add(Route::new(Post, "/hello", dummy));
router.add(Route::new(Delete, "/hello", dummy));
assert!(route(&router, Put, "/hello").is_some());
assert!(route(&router, Post, "/hello").is_some());
assert!(route(&router, Delete, "/hello").is_some());
let router = router_with_routes(&["/<a..>"]);
assert!(route(&router, Get, "/").is_some());
assert!(route(&router, Get, "//").is_some());
assert!(route(&router, Get, "/hi").is_some());
assert!(route(&router, Get, "/hello/hi").is_some());
assert!(route(&router, Get, "/a/b/").is_some());
assert!(route(&router, Get, "/i/a").is_some());
assert!(route(&router, Get, "/a/b/c/d/e/f").is_some());
let router = router_with_routes(&["/foo/<a..>"]);
assert!(route(&router, Get, "/foo").is_some());
assert!(route(&router, Get, "/foo/").is_some());
assert!(route(&router, Get, "/foo///bar").is_some());
}
#[test]
fn test_err_routing() {
let router = router_with_routes(&["/hello"]);
assert!(route(&router, Put, "/hello").is_none());
assert!(route(&router, Post, "/hello").is_none());
assert!(route(&router, Options, "/hello").is_none());
assert!(route(&router, Get, "/hell").is_none());
assert!(route(&router, Get, "/hi").is_none());
assert!(route(&router, Get, "/hello/there").is_none());
assert!(route(&router, Get, "/hello/i").is_none());
assert!(route(&router, Get, "/hillo").is_none());
let router = router_with_routes(&["/<a>"]);
assert!(route(&router, Put, "/hello").is_none());
assert!(route(&router, Post, "/hello").is_none());
assert!(route(&router, Options, "/hello").is_none());
assert!(route(&router, Get, "/hello/there").is_none());
assert!(route(&router, Get, "/hello/i").is_none());
let router = router_with_routes(&["/<a>/<b>"]);
assert!(route(&router, Get, "/a/b/c").is_none());
assert!(route(&router, Get, "/a").is_none());
assert!(route(&router, Get, "/a/").is_none());
assert!(route(&router, Get, "/a/b/c/d").is_none());
assert!(route(&router, Put, "/hello/hi").is_none());
assert!(route(&router, Put, "/a/b").is_none());
assert!(route(&router, Put, "/a/b").is_none());
let router = router_with_routes(&["/prefix/<a..>"]);
assert!(route(&router, Get, "/").is_none());
assert!(route(&router, Get, "/prefi/").is_none());
}
macro_rules! assert_ranked_match {
($routes:expr, $to:expr => $want:expr) => ({
let router = router_with_routes($routes);
assert!(!router.has_collisions());
let route_path = route(&router, Get, $to).unwrap().uri.to_string();
assert_eq!(route_path, $want.to_string());
})
}
#[test]
fn test_default_ranking() {
assert_ranked_match!(&["/hello", "/<name>"], "/hello" => "/hello");
assert_ranked_match!(&["/<name>", "/hello"], "/hello" => "/hello");
assert_ranked_match!(&["/<a>", "/hi", "/hi/<b>"], "/hi" => "/hi");
assert_ranked_match!(&["/<a>/b", "/hi/c"], "/hi/c" => "/hi/c");
assert_ranked_match!(&["/<a>/<b>", "/hi/a"], "/hi/c" => "/<a>/<b>");
assert_ranked_match!(&["/hi/a", "/hi/<c>"], "/hi/c" => "/hi/<c>");
assert_ranked_match!(&["/a", "/a?<b>"], "/a?b=c" => "/a?<b>");
assert_ranked_match!(&["/a", "/a?<b>"], "/a" => "/a?<b>");
assert_ranked_match!(&["/a", "/<a>", "/a?<b>", "/<a>?<b>"], "/a" => "/a?<b>");
assert_ranked_match!(&["/a", "/<a>", "/a?<b>", "/<a>?<b>"], "/b" => "/<a>?<b>");
assert_ranked_match!(&["/a", "/<a>", "/a?<b>", "/<a>?<b>"], "/b?v=1" => "/<a>?<b>");
assert_ranked_match!(&["/a", "/<a>", "/a?<b>", "/<a>?<b>"], "/a?b=c" => "/a?<b>");
assert_ranked_match!(&["/a", "/a?b"], "/a?b" => "/a?b");
assert_ranked_match!(&["/<a>", "/a?b"], "/a?b" => "/a?b");
assert_ranked_match!(&["/a", "/<a>?b"], "/a?b" => "/a");
assert_ranked_match!(&["/a?<c>&b", "/a?<b>"], "/a" => "/a?<b>");
assert_ranked_match!(&["/a?<c>&b", "/a?<b>"], "/a?b" => "/a?<c>&b");
assert_ranked_match!(&["/a?<c>&b", "/a?<b>"], "/a?c" => "/a?<b>");
assert_ranked_match!(&["/", "/<foo..>"], "/" => "/");
assert_ranked_match!(&["/", "/<foo..>"], "/hi" => "/<foo..>");
assert_ranked_match!(&["/hi", "/<foo..>"], "/hi" => "/hi");
}
fn ranked_collisions(routes: &[(isize, &'static str)]) -> bool {
let router = router_with_ranked_routes(routes);
router.has_collisions()
}
#[test]
fn test_no_manual_ranked_collisions() {
assert!(!ranked_collisions(&[(1, "/a/<b>"), (2, "/a/<b>")]));
assert!(!ranked_collisions(&[(0, "/a/<b>"), (2, "/a/<b>")]));
assert!(!ranked_collisions(&[(5, "/a/<b>"), (2, "/a/<b>")]));
assert!(!ranked_collisions(&[(1, "/a/<b>"), (1, "/b/<b>")]));
assert!(!ranked_collisions(&[(1, "/a/<b..>"), (2, "/a/<b..>")]));
assert!(!ranked_collisions(&[(0, "/a/<b..>"), (2, "/a/<b..>")]));
assert!(!ranked_collisions(&[(5, "/a/<b..>"), (2, "/a/<b..>")]));
assert!(!ranked_collisions(&[(1, "/<a..>"), (2, "/<a..>")]));
}
#[test]
fn test_ranked_collisions() {
assert!(ranked_collisions(&[(2, "/a/<b..>"), (2, "/a/<b..>")]));
assert!(ranked_collisions(&[(2, "/a/c/<b..>"), (2, "/a/<b..>")]));
assert!(ranked_collisions(&[(2, "/<b..>"), (2, "/a/<b..>")]));
}
macro_rules! assert_ranked_routing {
(to: $to:expr, with: $routes:expr, expect: $($want:expr),+) => ({
let router = router_with_ranked_routes(&$routes);
let routed_to = matches(&router, Get, $to);
let expected = &[$($want),+];
assert!(routed_to.len() == expected.len());
for (got, expected) in routed_to.iter().zip(expected.iter()) {
assert_eq!(got.rank, expected.0);
assert_eq!(got.uri.to_string(), expected.1.to_string());
}
})
}
#[test]
fn test_ranked_routing() {
assert_ranked_routing!(
to: "/a/b",
with: [(1, "/a/<b>"), (2, "/a/<b>")],
expect: (1, "/a/<b>"), (2, "/a/<b>")
);
assert_ranked_routing!(
to: "/b/b",
with: [(1, "/a/<b>"), (2, "/b/<b>"), (3, "/b/b")],
expect: (2, "/b/<b>"), (3, "/b/b")
);
assert_ranked_routing!(
to: "/b/b",
with: [(2, "/b/<b>"), (1, "/a/<b>"), (3, "/b/b")],
expect: (2, "/b/<b>"), (3, "/b/b")
);
assert_ranked_routing!(
to: "/b/b",
with: [(3, "/b/b"), (2, "/b/<b>"), (1, "/a/<b>")],
expect: (2, "/b/<b>"), (3, "/b/b")
);
assert_ranked_routing!(
to: "/b/b",
with: [(1, "/a/<b>"), (2, "/b/<b>"), (0, "/b/b")],
expect: (0, "/b/b"), (2, "/b/<b>")
);
assert_ranked_routing!(
to: "/profile/sergio/edit",
with: [(1, "/<a>/<b>/edit"), (2, "/profile/<d>"), (0, "/<a>/<b>/<c>")],
expect: (0, "/<a>/<b>/<c>"), (1, "/<a>/<b>/edit")
);
assert_ranked_routing!(
to: "/profile/sergio/edit",
with: [(0, "/<a>/<b>/edit"), (2, "/profile/<d>"), (5, "/<a>/<b>/<c>")],
expect: (0, "/<a>/<b>/edit"), (5, "/<a>/<b>/<c>")
);
assert_ranked_routing!(
to: "/a/b",
with: [(0, "/a/b"), (1, "/a/<b..>")],
expect: (0, "/a/b"), (1, "/a/<b..>")
);
assert_ranked_routing!(
to: "/a/b/c/d/e/f",
with: [(1, "/a/<b..>"), (2, "/a/b/<c..>")],
expect: (1, "/a/<b..>"), (2, "/a/b/<c..>")
);
assert_ranked_routing!(
to: "/hi",
with: [(1, "/hi/<foo..>"), (0, "/hi/<foo>")],
expect: (1, "/hi/<foo..>")
);
}
macro_rules! assert_default_ranked_routing {
(to: $to:expr, with: $routes:expr, expect: $($want:expr),+) => ({
let router = router_with_routes(&$routes);
let routed_to = matches(&router, Get, $to);
let expected = &[$($want),+];
assert!(routed_to.len() == expected.len());
for (got, expected) in routed_to.iter().zip(expected.iter()) {
assert_eq!(got.uri.to_string(), expected.to_string());
}
})
}
#[test]
fn test_default_ranked_routing() {
assert_default_ranked_routing!(
to: "/a/b?v=1",
with: ["/a/<b>", "/a/b"],
expect: "/a/b", "/a/<b>"
);
assert_default_ranked_routing!(
to: "/a/b?v=1",
with: ["/a/<b>", "/a/b", "/a/b?<v>"],
expect: "/a/b?<v>", "/a/b", "/a/<b>"
);
assert_default_ranked_routing!(
to: "/a/b?v=1",
with: ["/a/<b>", "/a/b", "/a/b?<v>", "/a/<b>?<v>"],
expect: "/a/b?<v>", "/a/b", "/a/<b>?<v>", "/a/<b>"
);
assert_default_ranked_routing!(
to: "/a/b",
with: ["/a/<b>", "/a/b", "/a/b?<v>", "/a/<b>?<v>"],
expect: "/a/b?<v>", "/a/b", "/a/<b>?<v>", "/a/<b>"
);
assert_default_ranked_routing!(
to: "/a/b?c",
with: ["/a/b", "/a/b?<c>", "/a/b?c", "/a/<b>?c", "/a/<b>?<c>", "/<a>/<b>"],
expect: "/a/b?c", "/a/b?<c>", "/a/b", "/a/<b>?c", "/a/<b>?<c>", "/<a>/<b>"
);
}
}
pub use route::Route;
pub use collider::Collide;
pub use uri::RouteUri;

View File

@ -162,7 +162,6 @@ impl Route {
}
}
/// Maps the `base` of this route using `mapper`, returning a new `Route`
/// with the returned base.
///

View File

@ -0,0 +1,670 @@
use std::collections::HashMap;
use crate::request::Request;
use crate::http::{Method, Status};
pub use crate::router::{Route, RouteUri};
pub use crate::router::collider::Collide;
pub use crate::catcher::Catcher;
#[derive(Debug, Default)]
pub(crate) struct Router {
routes: HashMap<Method, Vec<Route>>,
catchers: HashMap<Option<u16>, Vec<Catcher>>,
}
#[derive(Debug)]
pub struct Collisions {
pub routes: Vec<(Route, Route)>,
pub catchers: Vec<(Catcher, Catcher)>,
}
impl Router {
pub fn new() -> Self {
Self::default()
}
pub fn add_route(&mut self, route: Route) {
let routes = self.routes.entry(route.method).or_default();
routes.push(route);
routes.sort_by_key(|r| r.rank);
}
pub fn add_catcher(&mut self, catcher: Catcher) {
let catchers = self.catchers.entry(catcher.code).or_default();
catchers.push(catcher);
catchers.sort_by(|a, b| b.base.path_segments().len().cmp(&a.base.path_segments().len()))
}
#[inline]
pub fn routes(&self) -> impl Iterator<Item = &Route> + Clone {
self.routes.values().flat_map(|v| v.iter())
}
#[inline]
pub fn catchers(&self) -> impl Iterator<Item = &Catcher> + Clone {
self.catchers.values().flat_map(|v| v.iter())
}
pub fn route<'r, 'a: 'r>(
&'a self,
req: &'r Request<'r>
) -> impl Iterator<Item = &'a Route> + 'r {
// Note that routes are presorted by ascending rank on each `add`.
self.routes.get(&req.method())
.into_iter()
.flat_map(move |routes| routes.iter().filter(move |r| r.matches(req)))
}
// For many catchers, using aho-corasick or similar should be much faster.
pub fn catch<'r>(&self, status: Status, req: &'r Request<'r>) -> Option<&Catcher> {
// Note that catchers are presorted by descending base length.
let explicit = self.catchers.get(&Some(status.code))
.and_then(|c| c.iter().find(|c| c.matches(status, req)));
let default = self.catchers.get(&None)
.and_then(|c| c.iter().find(|c| c.matches(status, req)));
match (explicit, default) {
(None, None) => None,
(None, c@Some(_)) | (c@Some(_), None) => c,
(Some(a), Some(b)) => {
if b.base.path_segments().len() > a.base.path_segments().len() {
Some(b)
} else {
Some(a)
}
}
}
}
fn collisions<'a, I: 'a, T: 'a>(&self, items: I) -> impl Iterator<Item = (T, T)> + 'a
where I: Iterator<Item = &'a T> + Clone, T: Collide + Clone,
{
items.clone().enumerate()
.flat_map(move |(i, a)| {
items.clone()
.skip(i + 1)
.filter(move |b| a.collides_with(b))
.map(move |b| (a.clone(), b.clone()))
})
}
pub fn finalize(&self) -> Result<(), Collisions> {
let routes: Vec<_> = self.collisions(self.routes()).collect();
let catchers: Vec<_> = self.collisions(self.catchers()).collect();
if !routes.is_empty() || !catchers.is_empty() {
return Err(Collisions { routes, catchers })
}
Ok(())
}
}
#[cfg(test)]
mod test {
use super::*;
use crate::rocket::Rocket;
use crate::config::Config;
use crate::http::{Method, Method::*};
use crate::http::uri::Origin;
use crate::request::Request;
use crate::handler::dummy;
impl Router {
fn has_collisions(&self) -> bool {
self.finalize().is_err()
}
}
fn router_with_routes(routes: &[&'static str]) -> Router {
let mut router = Router::new();
for route in routes {
let route = Route::new(Get, route, dummy);
router.add_route(route);
}
router
}
fn router_with_ranked_routes(routes: &[(isize, &'static str)]) -> Router {
let mut router = Router::new();
for &(rank, route) in routes {
let route = Route::ranked(rank, Get, route, dummy);
router.add_route(route);
}
router
}
fn router_with_rankless_routes(routes: &[&'static str]) -> Router {
let mut router = Router::new();
for route in routes {
let route = Route::ranked(0, Get, route, dummy);
router.add_route(route);
}
router
}
fn rankless_route_collisions(routes: &[&'static str]) -> bool {
let router = router_with_rankless_routes(routes);
router.has_collisions()
}
fn default_rank_route_collisions(routes: &[&'static str]) -> bool {
let router = router_with_routes(routes);
router.has_collisions()
}
#[test]
fn test_rankless_collisions() {
assert!(rankless_route_collisions(&["/hello", "/hello"]));
assert!(rankless_route_collisions(&["/<a>", "/hello"]));
assert!(rankless_route_collisions(&["/<a>", "/<b>"]));
assert!(rankless_route_collisions(&["/hello/bob", "/hello/<b>"]));
assert!(rankless_route_collisions(&["/a/b/<c>/d", "/<a>/<b>/c/d"]));
assert!(rankless_route_collisions(&["/a/b", "/<a..>"]));
assert!(rankless_route_collisions(&["/a/b/c", "/a/<a..>"]));
assert!(rankless_route_collisions(&["/<a>/b", "/a/<a..>"]));
assert!(rankless_route_collisions(&["/a/<b>", "/a/<a..>"]));
assert!(rankless_route_collisions(&["/a/b/<c>", "/a/<a..>"]));
assert!(rankless_route_collisions(&["/<a..>", "/a/<a..>"]));
assert!(rankless_route_collisions(&["/a/<a..>", "/a/<a..>"]));
assert!(rankless_route_collisions(&["/a/b/<a..>", "/a/<a..>"]));
assert!(rankless_route_collisions(&["/a/b/c/d", "/a/<a..>"]));
assert!(rankless_route_collisions(&["/", "/<a..>"]));
assert!(rankless_route_collisions(&["/a/<_>", "/a/<a..>"]));
assert!(rankless_route_collisions(&["/a/<_>", "/a/<_..>"]));
assert!(rankless_route_collisions(&["/<_>", "/a/<_..>"]));
assert!(rankless_route_collisions(&["/foo", "/foo/<_..>"]));
assert!(rankless_route_collisions(&["/foo/bar/baz", "/foo/<_..>"]));
assert!(rankless_route_collisions(&["/a/d/<b..>", "/a/d"]));
assert!(rankless_route_collisions(&["/a/<_..>", "/<_>"]));
assert!(rankless_route_collisions(&["/a/<_..>", "/a"]));
assert!(rankless_route_collisions(&["/<a>", "/a/<a..>"]));
assert!(rankless_route_collisions(&["/<_>", "/<_>"]));
assert!(rankless_route_collisions(&["/a/<_>", "/a/b"]));
assert!(rankless_route_collisions(&["/a/<_>", "/a/<b>"]));
assert!(rankless_route_collisions(&["/<_..>", "/a/b"]));
assert!(rankless_route_collisions(&["/<_..>", "/<_>"]));
assert!(rankless_route_collisions(&["/<_>/b", "/a/b"]));
assert!(rankless_route_collisions(&["/", "/<foo..>"]));
}
#[test]
fn test_collisions_normalize() {
assert!(rankless_route_collisions(&["/hello/", "/hello"]));
assert!(rankless_route_collisions(&["//hello/", "/hello"]));
assert!(rankless_route_collisions(&["//hello/", "/hello//"]));
assert!(rankless_route_collisions(&["/<a>", "/hello//"]));
assert!(rankless_route_collisions(&["/<a>", "/hello///"]));
assert!(rankless_route_collisions(&["/hello///bob", "/hello/<b>"]));
assert!(rankless_route_collisions(&["/<a..>//", "/a//<a..>"]));
assert!(rankless_route_collisions(&["/a/<a..>//", "/a/<a..>"]));
assert!(rankless_route_collisions(&["/a/<a..>//", "/a/b//c//d/"]));
assert!(rankless_route_collisions(&["/a/<a..>/", "/a/bd/e/"]));
assert!(rankless_route_collisions(&["/<a..>/", "/a/bd/e/"]));
assert!(rankless_route_collisions(&["//", "/<foo..>"]));
assert!(rankless_route_collisions(&["/a/<a..>//", "/a/b//c//d/e/"]));
assert!(rankless_route_collisions(&["/a//<a..>//", "/a/b//c//d/e/"]));
assert!(rankless_route_collisions(&["///<_>", "/<_>"]));
assert!(rankless_route_collisions(&["/a/<_>", "///a//b"]));
assert!(rankless_route_collisions(&["//a///<_>", "/a//<b>"]));
assert!(rankless_route_collisions(&["//<_..>", "/a/b"]));
assert!(rankless_route_collisions(&["//<_..>", "/<_>"]));
assert!(rankless_route_collisions(&["///<a>/", "/a/<a..>"]));
assert!(rankless_route_collisions(&["///<a..>/", "/a/<a..>"]));
assert!(rankless_route_collisions(&["/<a..>", "/hello"]));
}
#[test]
fn test_collisions_query() {
// Query shouldn't affect things when rankless.
assert!(rankless_route_collisions(&["/hello?<foo>", "/hello"]));
assert!(rankless_route_collisions(&["/<a>?foo=bar", "/hello?foo=bar&cat=fat"]));
assert!(rankless_route_collisions(&["/<a>?foo=bar", "/hello?foo=bar&cat=fat"]));
assert!(rankless_route_collisions(&["/<a>", "/<b>?<foo>"]));
assert!(rankless_route_collisions(&["/hello/bob?a=b", "/hello/<b>?d=e"]));
assert!(rankless_route_collisions(&["/<foo>?a=b", "/foo?d=e"]));
assert!(rankless_route_collisions(&["/<foo>?a=b&<c>", "/<foo>?d=e&<c>"]));
assert!(rankless_route_collisions(&["/<foo>?a=b&<c>", "/<foo>?d=e"]));
}
#[test]
fn test_no_collisions() {
assert!(!rankless_route_collisions(&["/a/b", "/a/b/c"]));
assert!(!rankless_route_collisions(&["/a/b/c/d", "/a/b/c/<d>/e"]));
assert!(!rankless_route_collisions(&["/a/d/<b..>", "/a/b/c"]));
assert!(!rankless_route_collisions(&["/<_>", "/"]));
assert!(!rankless_route_collisions(&["/a/<_>", "/a"]));
assert!(!rankless_route_collisions(&["/a/<_>", "/<_>"]));
}
#[test]
fn test_no_collision_when_ranked() {
assert!(!default_rank_route_collisions(&["/<a>", "/hello"]));
assert!(!default_rank_route_collisions(&["/hello/bob", "/hello/<b>"]));
assert!(!default_rank_route_collisions(&["/a/b/c/d", "/<a>/<b>/c/d"]));
assert!(!default_rank_route_collisions(&["/hi", "/<hi>"]));
assert!(!default_rank_route_collisions(&["/a", "/a/<path..>"]));
assert!(!default_rank_route_collisions(&["/", "/<path..>"]));
assert!(!default_rank_route_collisions(&["/a/b", "/a/b/<c..>"]));
assert!(!default_rank_route_collisions(&["/<_>", "/static"]));
assert!(!default_rank_route_collisions(&["/<_..>", "/static"]));
assert!(!default_rank_route_collisions(&["/<path..>", "/"]));
assert!(!default_rank_route_collisions(&["/<_>/<_>", "/foo/bar"]));
assert!(!default_rank_route_collisions(&["/foo/<_>", "/foo/bar"]));
assert!(!default_rank_route_collisions(&["/<a>/<b>", "/hello/<b>"]));
assert!(!default_rank_route_collisions(&["/<a>/<b..>", "/hello/<b>"]));
assert!(!default_rank_route_collisions(&["/<a..>", "/hello/<b>"]));
assert!(!default_rank_route_collisions(&["/<a..>", "/hello"]));
assert!(!default_rank_route_collisions(&["/<a>", "/a/<path..>"]));
assert!(!default_rank_route_collisions(&["/a/<b>/c", "/<d>/<c..>"]));
}
#[test]
fn test_collision_when_ranked() {
assert!(default_rank_route_collisions(&["/a/<b>/<c..>", "/a/<c>"]));
assert!(default_rank_route_collisions(&["/<a>/b", "/a/<b>"]));
}
#[test]
fn test_collision_when_ranked_query() {
assert!(default_rank_route_collisions(&["/a?a=b", "/a?c=d"]));
assert!(default_rank_route_collisions(&["/a?a=b&<b>", "/a?<c>&c=d"]));
assert!(default_rank_route_collisions(&["/a?a=b&<b..>", "/a?<c>&c=d"]));
}
#[test]
fn test_no_collision_when_ranked_query() {
assert!(!default_rank_route_collisions(&["/", "/?<c..>"]));
assert!(!default_rank_route_collisions(&["/hi", "/hi?<c>"]));
assert!(!default_rank_route_collisions(&["/hi", "/hi?c"]));
assert!(!default_rank_route_collisions(&["/hi?<c>", "/hi?c"]));
assert!(!default_rank_route_collisions(&["/<foo>?a=b", "/<foo>?c=d&<d>"]));
}
fn route<'a>(router: &'a Router, method: Method, uri: &'a str) -> Option<&'a Route> {
let rocket = Rocket::custom(Config::default());
let request = Request::new(&rocket, method, Origin::parse(uri).unwrap());
let route = router.route(&request).next();
route
}
fn matches<'a>(router: &'a Router, method: Method, uri: &'a str) -> Vec<&'a Route> {
let rocket = Rocket::custom(Config::default());
let request = Request::new(&rocket, method, Origin::parse(uri).unwrap());
router.route(&request).collect()
}
#[test]
fn test_ok_routing() {
let router = router_with_routes(&["/hello"]);
assert!(route(&router, Get, "/hello").is_some());
let router = router_with_routes(&["/<a>"]);
assert!(route(&router, Get, "/hello").is_some());
assert!(route(&router, Get, "/hi").is_some());
assert!(route(&router, Get, "/bobbbbbbbbbby").is_some());
assert!(route(&router, Get, "/dsfhjasdf").is_some());
let router = router_with_routes(&["/<a>/<b>"]);
assert!(route(&router, Get, "/hello/hi").is_some());
assert!(route(&router, Get, "/a/b/").is_some());
assert!(route(&router, Get, "/i/a").is_some());
assert!(route(&router, Get, "/jdlk/asdij").is_some());
let mut router = Router::new();
router.add_route(Route::new(Put, "/hello", dummy));
router.add_route(Route::new(Post, "/hello", dummy));
router.add_route(Route::new(Delete, "/hello", dummy));
assert!(route(&router, Put, "/hello").is_some());
assert!(route(&router, Post, "/hello").is_some());
assert!(route(&router, Delete, "/hello").is_some());
let router = router_with_routes(&["/<a..>"]);
assert!(route(&router, Get, "/").is_some());
assert!(route(&router, Get, "//").is_some());
assert!(route(&router, Get, "/hi").is_some());
assert!(route(&router, Get, "/hello/hi").is_some());
assert!(route(&router, Get, "/a/b/").is_some());
assert!(route(&router, Get, "/i/a").is_some());
assert!(route(&router, Get, "/a/b/c/d/e/f").is_some());
let router = router_with_routes(&["/foo/<a..>"]);
assert!(route(&router, Get, "/foo").is_some());
assert!(route(&router, Get, "/foo/").is_some());
assert!(route(&router, Get, "/foo///bar").is_some());
}
#[test]
fn test_err_routing() {
let router = router_with_routes(&["/hello"]);
assert!(route(&router, Put, "/hello").is_none());
assert!(route(&router, Post, "/hello").is_none());
assert!(route(&router, Options, "/hello").is_none());
assert!(route(&router, Get, "/hell").is_none());
assert!(route(&router, Get, "/hi").is_none());
assert!(route(&router, Get, "/hello/there").is_none());
assert!(route(&router, Get, "/hello/i").is_none());
assert!(route(&router, Get, "/hillo").is_none());
let router = router_with_routes(&["/<a>"]);
assert!(route(&router, Put, "/hello").is_none());
assert!(route(&router, Post, "/hello").is_none());
assert!(route(&router, Options, "/hello").is_none());
assert!(route(&router, Get, "/hello/there").is_none());
assert!(route(&router, Get, "/hello/i").is_none());
let router = router_with_routes(&["/<a>/<b>"]);
assert!(route(&router, Get, "/a/b/c").is_none());
assert!(route(&router, Get, "/a").is_none());
assert!(route(&router, Get, "/a/").is_none());
assert!(route(&router, Get, "/a/b/c/d").is_none());
assert!(route(&router, Put, "/hello/hi").is_none());
assert!(route(&router, Put, "/a/b").is_none());
assert!(route(&router, Put, "/a/b").is_none());
let router = router_with_routes(&["/prefix/<a..>"]);
assert!(route(&router, Get, "/").is_none());
assert!(route(&router, Get, "/prefi/").is_none());
}
macro_rules! assert_ranked_match {
($routes:expr, $to:expr => $want:expr) => ({
let router = router_with_routes($routes);
assert!(!router.has_collisions());
let route_path = route(&router, Get, $to).unwrap().uri.to_string();
assert_eq!(route_path, $want.to_string());
})
}
#[test]
fn test_default_ranking() {
assert_ranked_match!(&["/hello", "/<name>"], "/hello" => "/hello");
assert_ranked_match!(&["/<name>", "/hello"], "/hello" => "/hello");
assert_ranked_match!(&["/<a>", "/hi", "/hi/<b>"], "/hi" => "/hi");
assert_ranked_match!(&["/<a>/b", "/hi/c"], "/hi/c" => "/hi/c");
assert_ranked_match!(&["/<a>/<b>", "/hi/a"], "/hi/c" => "/<a>/<b>");
assert_ranked_match!(&["/hi/a", "/hi/<c>"], "/hi/c" => "/hi/<c>");
assert_ranked_match!(&["/a", "/a?<b>"], "/a?b=c" => "/a?<b>");
assert_ranked_match!(&["/a", "/a?<b>"], "/a" => "/a?<b>");
assert_ranked_match!(&["/a", "/<a>", "/a?<b>", "/<a>?<b>"], "/a" => "/a?<b>");
assert_ranked_match!(&["/a", "/<a>", "/a?<b>", "/<a>?<b>"], "/b" => "/<a>?<b>");
assert_ranked_match!(&["/a", "/<a>", "/a?<b>", "/<a>?<b>"], "/b?v=1" => "/<a>?<b>");
assert_ranked_match!(&["/a", "/<a>", "/a?<b>", "/<a>?<b>"], "/a?b=c" => "/a?<b>");
assert_ranked_match!(&["/a", "/a?b"], "/a?b" => "/a?b");
assert_ranked_match!(&["/<a>", "/a?b"], "/a?b" => "/a?b");
assert_ranked_match!(&["/a", "/<a>?b"], "/a?b" => "/a");
assert_ranked_match!(&["/a?<c>&b", "/a?<b>"], "/a" => "/a?<b>");
assert_ranked_match!(&["/a?<c>&b", "/a?<b>"], "/a?b" => "/a?<c>&b");
assert_ranked_match!(&["/a?<c>&b", "/a?<b>"], "/a?c" => "/a?<b>");
assert_ranked_match!(&["/", "/<foo..>"], "/" => "/");
assert_ranked_match!(&["/", "/<foo..>"], "/hi" => "/<foo..>");
assert_ranked_match!(&["/hi", "/<foo..>"], "/hi" => "/hi");
}
fn ranked_collisions(routes: &[(isize, &'static str)]) -> bool {
let router = router_with_ranked_routes(routes);
router.has_collisions()
}
#[test]
fn test_no_manual_ranked_collisions() {
assert!(!ranked_collisions(&[(1, "/a/<b>"), (2, "/a/<b>")]));
assert!(!ranked_collisions(&[(0, "/a/<b>"), (2, "/a/<b>")]));
assert!(!ranked_collisions(&[(5, "/a/<b>"), (2, "/a/<b>")]));
assert!(!ranked_collisions(&[(1, "/a/<b>"), (1, "/b/<b>")]));
assert!(!ranked_collisions(&[(1, "/a/<b..>"), (2, "/a/<b..>")]));
assert!(!ranked_collisions(&[(0, "/a/<b..>"), (2, "/a/<b..>")]));
assert!(!ranked_collisions(&[(5, "/a/<b..>"), (2, "/a/<b..>")]));
assert!(!ranked_collisions(&[(1, "/<a..>"), (2, "/<a..>")]));
}
#[test]
fn test_ranked_collisions() {
assert!(ranked_collisions(&[(2, "/a/<b..>"), (2, "/a/<b..>")]));
assert!(ranked_collisions(&[(2, "/a/c/<b..>"), (2, "/a/<b..>")]));
assert!(ranked_collisions(&[(2, "/<b..>"), (2, "/a/<b..>")]));
}
macro_rules! assert_ranked_routing {
(to: $to:expr, with: $routes:expr, expect: $($want:expr),+) => ({
let router = router_with_ranked_routes(&$routes);
let routed_to = matches(&router, Get, $to);
let expected = &[$($want),+];
assert!(routed_to.len() == expected.len());
for (got, expected) in routed_to.iter().zip(expected.iter()) {
assert_eq!(got.rank, expected.0);
assert_eq!(got.uri.to_string(), expected.1.to_string());
}
})
}
#[test]
fn test_ranked_routing() {
assert_ranked_routing!(
to: "/a/b",
with: [(1, "/a/<b>"), (2, "/a/<b>")],
expect: (1, "/a/<b>"), (2, "/a/<b>")
);
assert_ranked_routing!(
to: "/b/b",
with: [(1, "/a/<b>"), (2, "/b/<b>"), (3, "/b/b")],
expect: (2, "/b/<b>"), (3, "/b/b")
);
assert_ranked_routing!(
to: "/b/b",
with: [(2, "/b/<b>"), (1, "/a/<b>"), (3, "/b/b")],
expect: (2, "/b/<b>"), (3, "/b/b")
);
assert_ranked_routing!(
to: "/b/b",
with: [(3, "/b/b"), (2, "/b/<b>"), (1, "/a/<b>")],
expect: (2, "/b/<b>"), (3, "/b/b")
);
assert_ranked_routing!(
to: "/b/b",
with: [(1, "/a/<b>"), (2, "/b/<b>"), (0, "/b/b")],
expect: (0, "/b/b"), (2, "/b/<b>")
);
assert_ranked_routing!(
to: "/profile/sergio/edit",
with: [(1, "/<a>/<b>/edit"), (2, "/profile/<d>"), (0, "/<a>/<b>/<c>")],
expect: (0, "/<a>/<b>/<c>"), (1, "/<a>/<b>/edit")
);
assert_ranked_routing!(
to: "/profile/sergio/edit",
with: [(0, "/<a>/<b>/edit"), (2, "/profile/<d>"), (5, "/<a>/<b>/<c>")],
expect: (0, "/<a>/<b>/edit"), (5, "/<a>/<b>/<c>")
);
assert_ranked_routing!(
to: "/a/b",
with: [(0, "/a/b"), (1, "/a/<b..>")],
expect: (0, "/a/b"), (1, "/a/<b..>")
);
assert_ranked_routing!(
to: "/a/b/c/d/e/f",
with: [(1, "/a/<b..>"), (2, "/a/b/<c..>")],
expect: (1, "/a/<b..>"), (2, "/a/b/<c..>")
);
assert_ranked_routing!(
to: "/hi",
with: [(1, "/hi/<foo..>"), (0, "/hi/<foo>")],
expect: (1, "/hi/<foo..>")
);
}
macro_rules! assert_default_ranked_routing {
(to: $to:expr, with: $routes:expr, expect: $($want:expr),+) => ({
let router = router_with_routes(&$routes);
let routed_to = matches(&router, Get, $to);
let expected = &[$($want),+];
assert!(routed_to.len() == expected.len());
for (got, expected) in routed_to.iter().zip(expected.iter()) {
assert_eq!(got.uri.to_string(), expected.to_string());
}
})
}
#[test]
fn test_default_ranked_routing() {
assert_default_ranked_routing!(
to: "/a/b?v=1",
with: ["/a/<b>", "/a/b"],
expect: "/a/b", "/a/<b>"
);
assert_default_ranked_routing!(
to: "/a/b?v=1",
with: ["/a/<b>", "/a/b", "/a/b?<v>"],
expect: "/a/b?<v>", "/a/b", "/a/<b>"
);
assert_default_ranked_routing!(
to: "/a/b?v=1",
with: ["/a/<b>", "/a/b", "/a/b?<v>", "/a/<b>?<v>"],
expect: "/a/b?<v>", "/a/b", "/a/<b>?<v>", "/a/<b>"
);
assert_default_ranked_routing!(
to: "/a/b",
with: ["/a/<b>", "/a/b", "/a/b?<v>", "/a/<b>?<v>"],
expect: "/a/b?<v>", "/a/b", "/a/<b>?<v>", "/a/<b>"
);
assert_default_ranked_routing!(
to: "/a/b?c",
with: ["/a/b", "/a/b?<c>", "/a/b?c", "/a/<b>?c", "/a/<b>?<c>", "/<a>/<b>"],
expect: "/a/b?c", "/a/b?<c>", "/a/b", "/a/<b>?c", "/a/<b>?<c>", "/<a>/<b>"
);
}
fn router_with_catchers(catchers: &[(Option<u16>, &str)]) -> Router {
let mut router = Router::new();
for (code, base) in catchers {
let catcher = Catcher::new(*code, crate::catcher::dummy);
router.add_catcher(catcher.map_base(|_| base.to_string()).unwrap());
}
router
}
fn catcher<'a>(router: &'a Router, status: Status, uri: &str) -> Option<&'a Catcher> {
let rocket = Rocket::custom(Config::default());
let request = Request::new(&rocket, Method::Get, Origin::parse(uri).unwrap());
router.catch(status, &request)
}
macro_rules! assert_catcher_routing {
(
catch: [$(($code:expr, $uri:expr)),+],
reqs: [$($r:expr),+],
with: [$(($ecode:expr, $euri:expr)),+]
) => ({
let catchers = vec![$(($code.into(), $uri)),+];
let requests = vec![$($r),+];
let expected = vec![$(($ecode.into(), $euri)),+];
let router = router_with_catchers(&catchers);
for (req, expected) in requests.iter().zip(expected.iter()) {
let req_status = Status::from_code(req.0).expect("valid status");
let catcher = catcher(&router, req_status, req.1).expect("some catcher");
assert_eq!(catcher.code, expected.0, "<- got, expected ->");
assert_eq!(catcher.base.path(), expected.1, "<- got, expected ->");
}
})
}
#[test]
fn test_catcher_routing() {
// Check that the default `/` catcher catches everything.
assert_catcher_routing! {
catch: [(None, "/")],
reqs: [(404, "/a/b/c"), (500, "/a/b"), (415, "/a/b/d"), (422, "/a/b/c/d?foo")],
with: [(None, "/"), (None, "/"), (None, "/"), (None, "/")]
}
// Check prefixes when they're exact.
assert_catcher_routing! {
catch: [(None, "/"), (None, "/a"), (None, "/a/b")],
reqs: [
(404, "/"), (500, "/"),
(404, "/a"), (500, "/a"),
(404, "/a/b"), (500, "/a/b")
],
with: [
(None, "/"), (None, "/"),
(None, "/a"), (None, "/a"),
(None, "/a/b"), (None, "/a/b")
]
}
// Check prefixes when they're not exact.
assert_catcher_routing! {
catch: [(None, "/"), (None, "/a"), (None, "/a/b")],
reqs: [
(404, "/foo"), (500, "/bar"), (422, "/baz/bar"), (418, "/poodle?yes"),
(404, "/a/foo"), (500, "/a/bar/baz"), (510, "/a/c"), (423, "/a/c/b"),
(404, "/a/b/c"), (500, "/a/b/c/d"), (500, "/a/b?foo"), (400, "/a/b/yes")
],
with: [
(None, "/"), (None, "/"), (None, "/"), (None, "/"),
(None, "/a"), (None, "/a"), (None, "/a"), (None, "/a"),
(None, "/a/b"), (None, "/a/b"), (None, "/a/b"), (None, "/a/b")
]
}
// Check that we prefer specific to default.
assert_catcher_routing! {
catch: [(400, "/"), (404, "/"), (None, "/")],
reqs: [
(400, "/"), (400, "/bar"), (400, "/foo/bar"),
(404, "/"), (404, "/bar"), (404, "/foo/bar"),
(405, "/"), (405, "/bar"), (406, "/foo/bar")
],
with: [
(400, "/"), (400, "/"), (400, "/"),
(404, "/"), (404, "/"), (404, "/"),
(None, "/"), (None, "/"), (None, "/")
]
}
// Check that we prefer longer prefixes over specific.
assert_catcher_routing! {
catch: [(None, "/a/b"), (404, "/a"), (422, "/a")],
reqs: [
(404, "/a/b"), (404, "/a/b/c"), (422, "/a/b/c"),
(404, "/a"), (404, "/a/c"), (404, "/a/cat/bar"),
(422, "/a"), (422, "/a/c"), (422, "/a/cat/bar")
],
with: [
(None, "/a/b"), (None, "/a/b"), (None, "/a/b"),
(404, "/a"), (404, "/a"), (404, "/a"),
(422, "/a"), (422, "/a"), (422, "/a")
]
}
// Just a fun one.
assert_catcher_routing! {
catch: [(None, "/"), (None, "/a/b"), (500, "/a/b/c"), (500, "/a/b")],
reqs: [(404, "/a/b/c"), (500, "/a/b"), (400, "/a/b/d"), (500, "/a/b/c/d?foo")],
with: [(None, "/a/b"), (500, "/a/b"), (None, "/a/b"), (500, "/a/b/c")]
}
}
}

View File

@ -329,11 +329,7 @@ impl Rocket {
// response. We may wish to relax this in the future.
req.cookies().reset_delta();
// Try to get the active catcher
let catcher = self.catchers.get(&status.code)
.or_else(|| self.default_catcher.as_ref());
if let Some(catcher) = catcher {
if let Some(catcher) = self.router.catch(status, req) {
warn_!("Responding with registered {} catcher.", catcher);
let name = catcher.name.as_deref();
handle(name, || catcher.handler.handle(status, req)).await

View File

@ -24,7 +24,7 @@ mod tests {
fn error_catcher_sets_cookies() {
let rocket = rocket::ignite()
.mount("/", routes![index])
.register(catchers![not_found])
.register("/", catchers![not_found])
.attach(AdHoc::on_request("Add Cookie", |req, _| Box::pin(async move {
req.cookies().add(Cookie::new("fairing", "woo"));
})));

View File

@ -22,30 +22,20 @@ fn ise() -> &'static str {
"Hey, sorry! :("
}
#[catch(500)]
fn double_panic() {
panic!("so, so sorry...")
}
fn pre_future_route<'r>(_: &'r Request<'_>, _: Data) -> HandlerFuture<'r> {
panic!("hey now...");
}
fn pre_future_catcher<'r>(_: Status, _: &'r Request) -> ErrorHandlerFuture<'r> {
panic!("a panicking pre-future catcher")
}
fn rocket() -> Rocket {
let pre_future_panic = Route::new(Method::Get, "/pre", pre_future_route);
rocket::ignite()
.mount("/", routes![panic_route])
.mount("/", vec![pre_future_panic])
.register(catchers![panic_catcher, ise])
.mount("/", vec![Route::new(Method::Get, "/pre", pre_future_route)])
}
#[test]
fn catches_route_panic() {
let client = Client::debug(rocket()).unwrap();
let rocket = rocket().register("/", catchers![panic_catcher, ise]);
let client = Client::debug(rocket).unwrap();
let response = client.get("/panic").dispatch();
assert_eq!(response.status(), Status::InternalServerError);
assert_eq!(response.into_string().unwrap(), "Hey, sorry! :(");
@ -53,7 +43,8 @@ fn catches_route_panic() {
#[test]
fn catches_catcher_panic() {
let client = Client::debug(rocket()).unwrap();
let rocket = rocket().register("/", catchers![panic_catcher, ise]);
let client = Client::debug(rocket).unwrap();
let response = client.get("/noroute").dispatch();
assert_eq!(response.status(), Status::InternalServerError);
assert_eq!(response.into_string().unwrap(), "Hey, sorry! :(");
@ -61,7 +52,12 @@ fn catches_catcher_panic() {
#[test]
fn catches_double_panic() {
let rocket = rocket().register(catchers![double_panic]);
#[catch(500)]
fn double_panic() {
panic!("so, so sorry...")
}
let rocket = rocket().register("/", catchers![panic_catcher, double_panic]);
let client = Client::debug(rocket).unwrap();
let response = client.get("/noroute").dispatch();
assert_eq!(response.status(), Status::InternalServerError);
@ -70,7 +66,8 @@ fn catches_double_panic() {
#[test]
fn catches_early_route_panic() {
let client = Client::debug(rocket()).unwrap();
let rocket = rocket().register("/", catchers![panic_catcher, ise]);
let client = Client::debug(rocket).unwrap();
let response = client.get("/pre").dispatch();
assert_eq!(response.status(), Status::InternalServerError);
assert_eq!(response.into_string().unwrap(), "Hey, sorry! :(");
@ -78,9 +75,15 @@ fn catches_early_route_panic() {
#[test]
fn catches_early_catcher_panic() {
let panic_catcher = Catcher::new(404, pre_future_catcher);
fn pre_future_catcher<'r>(_: Status, _: &'r Request) -> ErrorHandlerFuture<'r> {
panic!("a panicking pre-future catcher")
}
let client = Client::debug(rocket().register(vec![panic_catcher])).unwrap();
let rocket = rocket()
.register("/", vec![Catcher::new(404, pre_future_catcher)])
.register("/", catchers![ise]);
let client = Client::debug(rocket).unwrap();
let response = client.get("/idontexist").dispatch();
assert_eq!(response.status(), Status::InternalServerError);
assert_eq!(response.into_string().unwrap(), "Hey, sorry! :(");

View File

@ -14,7 +14,7 @@ mod tests {
#[test]
fn error_catcher_redirect() {
let client = Client::debug(rocket::ignite().register(catchers![not_found])).unwrap();
let client = Client::debug(rocket::ignite().register("/", catchers![not_found])).unwrap();
let response = client.get("/unknown").dispatch();
let location: Vec<_> = response.headers().get("location").collect();

View File

@ -63,5 +63,5 @@ fn not_found(request: &Request<'_>) -> Html<String> {
fn rocket() -> rocket::Rocket {
rocket::ignite()
.mount("/hello", routes![get_hello, post_hello])
.register(catchers![not_found])
.register("/", catchers![not_found])
}

View File

@ -7,7 +7,7 @@ use rocket::response::{content, status};
use rocket::http::Status;
#[get("/hello/<name>/<age>")]
fn hello(name: String, age: i8) -> String {
fn hello(name: &str, age: i8) -> String {
format!("Hello, {} year old named {}!", age, name)
}
@ -17,10 +17,24 @@ fn forced_error(code: u16) -> Status {
}
#[catch(404)]
fn not_found(req: &Request<'_>) -> content::Html<String> {
content::Html(format!("<p>Sorry, but '{}' is not a valid path!</p>
<p>Try visiting /hello/&lt;name&gt;/&lt;age&gt; instead.</p>",
req.uri()))
fn general_not_found() -> content::Html<&'static str> {
content::Html(r#"
<p>Hmm... What are you looking for?</p>
Say <a href="/hello/Sergio/100">hello!</a>
"#)
}
#[catch(404)]
fn hello_not_found(req: &Request<'_>) -> content::Html<String> {
content::Html(format!("\
<p>Sorry, but '{}' is not a valid path!</p>\
<p>Try visiting /hello/&lt;name&gt;/&lt;age&gt; instead.</p>",
req.uri()))
}
#[catch(default)]
fn sergio_error() -> &'static str {
"I...don't know what to say."
}
#[catch(default)]
@ -33,7 +47,9 @@ fn rocket() -> rocket::Rocket {
rocket::ignite()
// .mount("/", routes![hello, hello]) // uncoment this to get an error
.mount("/", routes![hello, forced_error])
.register(catchers![not_found, default_catcher])
.register("/", catchers![general_not_found, default_catcher])
.register("/hello", catchers![hello_not_found])
.register("/hello/Sergio", catchers![sergio_error])
}
#[rocket::main]

View File

@ -14,11 +14,11 @@ fn test_hello() {
}
#[test]
fn forced_error_and_default_catcher() {
fn forced_error() {
let client = Client::tracked(super::rocket()).unwrap();
let request = client.get("/404");
let expected = super::not_found(request.inner());
let expected = super::general_not_found();
let response = request.dispatch();
assert_eq!(response.status(), Status::NotFound);
assert_eq!(response.into_string().unwrap(), expected.0);
@ -46,11 +46,24 @@ fn forced_error_and_default_catcher() {
fn test_hello_invalid_age() {
let client = Client::tracked(super::rocket()).unwrap();
for &(name, age) in &[("Ford", -129), ("Trillian", 128)] {
let request = client.get(format!("/hello/{}/{}", name, age));
let expected = super::not_found(request.inner());
for path in &["Ford/-129", "Trillian/128", "foo/bar/baz"] {
let request = client.get(format!("/hello/{}", path));
let expected = super::hello_not_found(request.inner());
let response = request.dispatch();
assert_eq!(response.status(), Status::NotFound);
assert_eq!(response.into_string().unwrap(), expected.0);
}
}
#[test]
fn test_hello_sergio() {
let client = Client::tracked(super::rocket()).unwrap();
for path in &["oops", "-129", "foo/bar", "/foo/bar/baz"] {
let request = client.get(format!("/hello/Sergio/{}", path));
let expected = super::sergio_error();
let response = request.dispatch();
assert_eq!(response.status(), Status::NotFound);
assert_eq!(response.into_string().unwrap(), expected);
}
}

View File

@ -69,7 +69,7 @@ fn wow_helper(
fn rocket() -> rocket::Rocket {
rocket::ignite()
.mount("/", routes![index, hello, about])
.register(catchers![not_found])
.register("/", catchers![not_found])
.attach(Template::custom(|engines| {
engines.handlebars.register_helper("wow", Box::new(wow_helper));
}))

View File

@ -10,7 +10,7 @@ fn hello() -> &'static str {
fn rocket() -> rocket::Rocket {
rocket::ignite()
.mount("/", rocket::routes![hello])
.register(rocket::catchers![not_found])
.register("/", rocket::catchers![not_found])
}
#[rocket::catch(404)]

View File

@ -75,6 +75,6 @@ fn not_found() -> JsonValue {
fn rocket() -> _ {
rocket::ignite()
.mount("/message", routes![new, update, get, echo])
.register(catchers![not_found])
.register("/", catchers![not_found])
.manage(Mutex::new(HashMap::<Id, String>::new()))
}

View File

@ -110,5 +110,5 @@ fn rocket() -> rocket::Rocket {
.mount("/hello", vec![name.clone()])
.mount("/hi", vec![name])
.mount("/custom", CustomHandler::new("some data here"))
.register(vec![not_found_catcher])
.register("/", vec![not_found_catcher])
}

View File

@ -37,5 +37,5 @@ fn rocket() -> rocket::Rocket {
rocket::ignite()
.mount("/", routes![index, get])
.attach(Template::fairing())
.register(catchers![not_found])
.register("/", catchers![not_found])
}

View File

@ -1716,8 +1716,8 @@ Application processing is fallible. Errors arise from the following sources:
* A routing failure.
If any of these occur, Rocket returns an error to the client. To generate the
error, Rocket invokes the _catcher_ corresponding to the error's status code.
Catchers are similar to routes except in that:
error, Rocket invokes the _catcher_ corresponding to the error's status code and
scope. Catchers are similar to routes except in that:
1. Catchers are only invoked on error conditions.
2. Catchers are declared with the `catch` attribute.
@ -1725,6 +1725,7 @@ Catchers are similar to routes except in that:
4. Any modifications to cookies are cleared before a catcher is invoked.
5. Error catchers cannot invoke guards.
6. Error catchers should not fail to produce a response.
7. Catchers are scoped to a path prefix.
To declare a catcher for a given status code, use the [`catch`] attribute, which
takes a single integer corresponding to the HTTP status code to catch. For
@ -1770,36 +1771,96 @@ looks like:
# #[catch(404)] fn not_found(req: &Request) { /* .. */ }
fn main() {
rocket::ignite().register(catchers![not_found]);
rocket::ignite().register("/", catchers![not_found]);
}
```
### Default Catchers
### Scoping
If no catcher for a given status code has been registered, Rocket calls the
_default_ catcher. Rocket provides a default catcher for all applications
automatically, so providing one is usually unnecessary. Rocket's built-in
default catcher can handle all errors. It produces HTML or JSON, depending on
the value of the `Accept` header. As such, a default catcher, or catchers in
general, only need to be registered if an error needs to be handled in a custom
fashion.
The first argument to `register()` is a path to scope the catcher under called
the catcher's _base_. A catcher's base determines which requests it will handle
errors for. Specifically, a catcher's base must be a prefix of the erroring
request for it to be invoked. When multiple catchers can be invoked, the catcher
with the longest base takes precedence.
Declaring a default catcher is done with `#[catch(default)]`:
As an example, consider the following application:
```rust
# #[macro_use] extern crate rocket;
#[catch(404)]
fn general_not_found() -> &'static str {
"General 404"
}
#[catch(404)]
fn foo_not_found() -> &'static str {
"Foo 404"
}
#[launch]
fn rocket() -> _ {
rocket::ignite()
.register("/", catchers![general_not_found])
.register("/foo", catchers![foo_not_found])
}
# let client = rocket::local::blocking::Client::debug(rocket()).unwrap();
#
# let response = client.get("/").dispatch();
# assert_eq!(response.into_string().unwrap(), "General 404");
#
# let response = client.get("/bar").dispatch();
# assert_eq!(response.into_string().unwrap(), "General 404");
#
# let response = client.get("/bar/baz").dispatch();
# assert_eq!(response.into_string().unwrap(), "General 404");
#
# let response = client.get("/foo").dispatch();
# assert_eq!(response.into_string().unwrap(), "Foo 404");
#
# let response = client.get("/foo/bar").dispatch();
# assert_eq!(response.into_string().unwrap(), "Foo 404");
```
Since there are no mounted routes, all requests will `404`. Any request whose
path begins with `/foo` (i.e, `GET /foo`, `GET /foo/bar`, etc) will be handled
by the `foo_not_found` catcher while all other requests will be handled by the
`general_not_found` catcher.
### Default Catchers
A _default_ catcher is a catcher that handles _all_ status codes. They are
invoked as a fallback if no status-specific catcher is registered for a given
error. Declaring a default catcher is done with `#[catch(default)]` and must
similarly be registered with [`register()`]:
```rust
# #[macro_use] extern crate rocket;
# fn main() {}
use rocket::Request;
use rocket::http::Status;
#[catch(default)]
fn default_catcher(status: Status, request: &Request) { /* .. */ }
#[launch]
fn rocket() -> _ {
rocket::ignite().register("/", catchers![default_catcher])
}
```
It must similarly be registered with [`register()`].
Catchers with longer bases are preferred, even when there is a status-specific
catcher. In other words, a default catcher with a longer matching base than a
status-specific catcher takes precedence.
The [error catcher example](@example/errors) illustrates their use in full,
### Built-In Catcher
Rocket provides a built-in default catcher. It produces HTML or JSON, depending
on the value of the `Accept` header. As such, custom catchers only need to be
registered for custom error handling.
The [error catcher example](@example/errors) illustrates catcher use in full,
while the [`Catcher`] API documentation provides further details.
[`catch`]: @api/rocket/attr.catch.html