Add support for checking which route routed a request

- Adds RouteType and CatcherType traits to identify routes and catchers
- RouteType and CatcherType are implemented via codegen for attribute
macros
- Adds routed_by and caught_by methods to local client response
- Adds catcher to RequestState
- Updates route in RequestState to None if a catcher is run
- examples/hello tests now also check which route generated the reponse
- Adds DefaultCatcher type to represent Rocket's default catcher
- FileServer now implements RouteType
This commit is contained in:
Matthew Pomes 2022-05-09 11:38:49 -05:00 committed by Sergio Benitez
parent 07fe79796f
commit 9aa14d0e24
13 changed files with 169 additions and 14 deletions

View File

@ -62,6 +62,8 @@ pub fn _catch(
/// Rocket code generated proxy structure.
#deprecated #vis struct #user_catcher_fn_name { }
impl #CatcherType for #user_catcher_fn_name { }
/// Rocket code generated proxy static conversion implementations.
#[allow(nonstandard_style, deprecated, clippy::style)]
impl #user_catcher_fn_name {
@ -83,6 +85,7 @@ pub fn _catch(
name: stringify!(#user_catcher_fn_name),
code: #status_code,
handler: monomorphized_function,
route_type: #_Box::new(self),
}
}

View File

@ -342,6 +342,8 @@ fn codegen_route(route: Route) -> Result<TokenStream> {
/// Rocket code generated proxy structure.
#deprecated #vis struct #handler_fn_name { }
impl #RouteType for #handler_fn_name {}
/// Rocket code generated proxy static conversion implementations.
#[allow(nonstandard_style, deprecated, clippy::style)]
impl #handler_fn_name {
@ -368,6 +370,7 @@ fn codegen_route(route: Route) -> Result<TokenStream> {
format: #format,
rank: #rank,
sentinels: #sentinels,
route_type: #_Box::new(self),
}
}

View File

@ -98,7 +98,9 @@ define_exported_paths! {
StaticRouteInfo => ::rocket::StaticRouteInfo,
StaticCatcherInfo => ::rocket::StaticCatcherInfo,
Route => ::rocket::Route,
RouteType => ::rocket::route::RouteType,
Catcher => ::rocket::Catcher,
CatcherType => ::rocket::catcher::CatcherType,
SmallVec => ::rocket::http::private::SmallVec,
Status => ::rocket::http::Status,
}

View File

@ -1,3 +1,4 @@
use std::any::{TypeId, Any};
use std::fmt;
use std::io::Cursor;
@ -10,6 +11,14 @@ use crate::catcher::{Handler, BoxFuture};
use yansi::Paint;
// We could also choose to require a Debug impl?
/// A generic trait for route types. This should be automatically implemented on the structs
/// generated by the codegen for each route.
///
/// It may also be desirable to add an option for other routes to define a RouteType. This
/// would likely just be a case of adding an alternate constructor to the Route type.
pub trait CatcherType: Any + Send + Sync + 'static { }
/// An error catching route.
///
/// Catchers are routes that run when errors are produced by the application.
@ -127,6 +136,9 @@ pub struct Catcher {
///
/// This is -(number of nonempty segments in base).
pub(crate) rank: isize,
/// A unique route type to identify this route
pub(crate) catcher_type: Option<TypeId>,
}
// The rank is computed as -(number of nonempty segments in base) => catchers
@ -185,7 +197,8 @@ impl Catcher {
base: uri::Origin::ROOT,
handler: Box::new(handler),
rank: rank(uri::Origin::ROOT.path()),
code
code,
catcher_type: None,
}
}
@ -307,6 +320,10 @@ impl Catcher {
}
}
/// Catcher type of the default catcher created by Rocket
pub struct DefaultCatcher { _priv: () }
impl CatcherType for DefaultCatcher {}
impl Default for Catcher {
fn default() -> Self {
fn handler<'r>(s: Status, req: &'r Request<'_>) -> BoxFuture<'r> {
@ -315,6 +332,7 @@ impl Default for Catcher {
let mut catcher = Catcher::new(None, handler);
catcher.name = Some("<Rocket Catcher>".into());
catcher.catcher_type = Some(TypeId::of::<DefaultCatcher>());
catcher
}
}
@ -328,6 +346,8 @@ pub struct StaticInfo {
pub code: Option<u16>,
/// The catcher's handler, i.e, the annotated function.
pub handler: for<'r> fn(Status, &'r Request<'_>) -> BoxFuture<'r>,
/// A unique route type to identify this route
pub catcher_type: Box<dyn CatcherType>,
}
#[doc(hidden)]
@ -336,6 +356,7 @@ impl From<StaticInfo> for Catcher {
fn from(info: StaticInfo) -> Catcher {
let mut catcher = Catcher::new(info.code, info.handler);
catcher.name = Some(info.name.into());
catcher.catcher_type = Some(info.catcher_type.as_ref().type_id());
catcher
}
}

View File

@ -2,7 +2,7 @@ use std::path::{PathBuf, Path};
use crate::{Request, Data};
use crate::http::{Method, uri::Segments, ext::IntoOwned};
use crate::route::{Route, Handler, Outcome};
use crate::route::{Route, Handler, Outcome, RouteType};
use crate::response::Redirect;
use crate::fs::NamedFile;
@ -180,10 +180,13 @@ impl FileServer {
}
}
impl RouteType for FileServer {}
impl From<FileServer> for Vec<Route> {
fn from(server: FileServer) -> Self {
let source = figment::Source::File(server.root.clone());
let mut route = Route::ranked(server.rank, Method::Get, "/<path..>", server);
let mut route = Route::ranked(server.rank, Method::Get, "/<path..>", server)
.with_type::<FileServer>();
route.name = Some(format!("FileServer: {}", source).into());
vec![route]
}

View File

@ -98,6 +98,12 @@ impl<'c> LocalResponse<'c> {
}
}
impl<'r> LocalResponse<'r> {
pub(crate) fn _request(&self) -> &Request<'r> {
&self._request
}
}
impl LocalResponse<'_> {
pub(crate) fn _response(&self) -> &Response<'_> {
&self.response

View File

@ -1,7 +1,7 @@
use std::io;
use tokio::io::AsyncReadExt;
use crate::{Response, local::asynchronous, http::CookieJar};
use crate::{Response, local::asynchronous, http::CookieJar, Request};
use super::Client;
@ -54,6 +54,12 @@ pub struct LocalResponse<'c> {
pub(in super) client: &'c Client,
}
impl<'r> LocalResponse<'r> {
pub(crate) fn _request(&self) -> &Request<'r> {
&self.inner._request()
}
}
impl LocalResponse<'_> {
fn _response(&self) -> &Response<'_> {
&self.inner._response()

View File

@ -180,6 +180,50 @@ macro_rules! pub_response_impl {
self._into_msgpack() $(.$suffix)?
}
/// Checks if a route was routed by a specific route type
///
/// # Example
///
/// ```rust
/// # use rocket::get;
/// #[get("/")]
/// fn index() -> &'static str { "Hello World" }
#[doc = $doc_prelude]
/// # Client::_test(|_, _, response| {
/// let response: LocalResponse = response;
/// assert!(response.routed_by::<index>())
/// # });
/// ```
pub fn routed_by<T: crate::route::RouteType>(&self) -> bool {
if let Some(route_type) = self._request().route().map(|r| r.route_type).flatten() {
route_type == std::any::TypeId::of::<T>()
} else {
false
}
}
/// Checks if a route was caught by a specific route type
///
/// # Example
///
/// ```rust
/// # use rocket::get;
/// #[get("/")]
/// fn index() -> &'static str { "Hello World" }
#[doc = $doc_prelude]
/// # Client::_test(|_, _, response| {
/// let response: LocalResponse = response;
/// assert!(response.routed_by::<index>())
/// # });
/// ```
pub fn caught_by<T: crate::catcher::CatcherType>(&self) -> bool {
if let Some(catcher_type) = self._request().catcher().map(|r| r.catcher_type).flatten() {
catcher_type == std::any::TypeId::of::<T>()
} else {
false
}
}
#[cfg(test)]
#[allow(dead_code)]
fn _ensure_impls_exist() {

View File

@ -8,7 +8,7 @@ use state::{TypeMap, InitCell};
use futures::future::BoxFuture;
use atomic::{Atomic, Ordering};
use crate::{Rocket, Route, Orbit};
use crate::{Rocket, Route, Orbit, Catcher};
use crate::request::{FromParam, FromSegments, FromRequest, Outcome};
use crate::form::{self, ValueField, FromForm};
use crate::data::Limits;
@ -45,6 +45,7 @@ pub(crate) struct ConnectionMeta {
pub(crate) struct RequestState<'r> {
pub rocket: &'r Rocket<Orbit>,
pub route: Atomic<Option<&'r Route>>,
pub catcher: Atomic<Option<&'r Catcher>>,
pub cookies: CookieJar<'r>,
pub accept: InitCell<Option<Accept>>,
pub content_type: InitCell<Option<ContentType>>,
@ -69,6 +70,7 @@ impl RequestState<'_> {
RequestState {
rocket: self.rocket,
route: Atomic::new(self.route.load(Ordering::Acquire)),
catcher: Atomic::new(self.catcher.load(Ordering::Acquire)),
cookies: self.cookies.clone(),
accept: self.accept.clone(),
content_type: self.content_type.clone(),
@ -97,6 +99,7 @@ impl<'r> Request<'r> {
state: RequestState {
rocket,
route: Atomic::new(None),
catcher: Atomic::new(None),
cookies: CookieJar::new(rocket.config()),
accept: InitCell::new(),
content_type: InitCell::new(),
@ -691,6 +694,22 @@ impl<'r> Request<'r> {
self.state.route.load(Ordering::Acquire)
}
/// Get the presently matched catcher, if any.
///
/// This method returns `Some` while a catcher is running.
///
/// # Example
///
/// ```rust
/// # let c = rocket::local::blocking::Client::debug_with(vec![]).unwrap();
/// # let request = c.get("/");
/// let catcher = request.catcher();
/// ```
#[inline(always)]
pub fn catcher(&self) -> Option<&'r Catcher> {
self.state.catcher.load(Ordering::Acquire)
}
/// Invokes the request guard implementation for `T`, returning its outcome.
///
/// # Example
@ -969,8 +988,15 @@ impl<'r> Request<'r> {
/// Set `self`'s parameters given that the route used to reach this request
/// was `route`. Use during routing when attempting a given route.
#[inline(always)]
pub(crate) fn set_route(&self, route: &'r Route) {
self.state.route.store(Some(route), Ordering::Release)
pub(crate) fn set_route(&self, route: Option<&'r Route>) {
self.state.route.store(route, Ordering::Release)
}
/// Set `self`'s parameters given that the route used to reach this request
/// was `catcher`. Use during routing when attempting a given catcher.
#[inline(always)]
pub(crate) fn set_catcher(&self, catcher: Option<&'r Catcher>) {
self.state.catcher.store(catcher, Ordering::Release)
}
/// Set the method of `self`, even when `self` is a shared reference. Used

View File

@ -1,12 +1,22 @@
use std::fmt;
use std::any::{Any, TypeId};
use std::borrow::Cow;
use std::convert::From;
use yansi::Paint;
use crate::http::{uri, Method, MediaType};
use crate::route::{Handler, RouteUri, BoxFuture};
use crate::http::{uri, MediaType, Method};
use crate::route::{BoxFuture, Handler, RouteUri};
use crate::sentinel::Sentry;
// We could also choose to require a Debug impl?
/// A generic trait for route types. This should be automatically implemented on the structs
/// generated by the codegen for each route.
///
/// It may also be desirable to add an option for other routes to define a RouteType. This
/// would likely just be a case of adding an alternate constructor to the Route type.
pub trait RouteType: Any + Send + Sync + 'static { }
/// A request handling route.
///
/// A route consists of exactly the information in its fields. While a `Route`
@ -16,7 +26,8 @@ use crate::sentinel::Sentry;
///
/// ```rust
/// # #[macro_use] extern crate rocket;
/// # use std::path::PathBuf;
/// #
/// use std::path::PathBuf;
/// #[get("/route/<path..>?query", rank = 2, format = "json")]
/// fn route_name(path: PathBuf) { /* handler procedure */ }
///
@ -178,6 +189,8 @@ pub struct Route {
pub format: Option<MediaType>,
/// The discovered sentinels.
pub(crate) sentinels: Vec<Sentry>,
/// A unique route type to identify this route
pub(crate) route_type: Option<TypeId>,
}
impl Route {
@ -243,7 +256,9 @@ impl Route {
/// ```
#[track_caller]
pub fn ranked<H, R>(rank: R, method: Method, uri: &str, handler: H) -> Route
where H: Handler + 'static, R: Into<Option<isize>>,
where
H: Handler + 'static,
R: Into<Option<isize>>,
{
let uri = RouteUri::new("/", uri);
let rank = rank.into().unwrap_or_else(|| uri.default_rank());
@ -252,7 +267,10 @@ impl Route {
format: None,
sentinels: Vec::new(),
handler: Box::new(handler),
rank, uri, method,
rank,
uri,
method,
route_type: None,
}
}
@ -297,6 +315,12 @@ impl Route {
self
}
/// Marks this route with the specified type
pub fn with_type<T: RouteType>(mut self) -> Self {
self.route_type = Some(TypeId::of::<T>());
self
}
/// Maps the `base` of this route using `mapper`, returning a new `Route`
/// with the returned base.
///
@ -335,7 +359,8 @@ impl Route {
/// assert_eq!(rebased.uri.path(), "/boo/foo/bar");
/// ```
pub fn map_base<'a, F>(mut self, mapper: F) -> Result<Self, uri::Error<'static>>
where F: FnOnce(uri::Origin<'a>) -> String
where
F: FnOnce(uri::Origin<'a>) -> String,
{
let base = mapper(self.uri.base);
self.uri = RouteUri::try_new(&base, &self.uri.unmounted_origin.to_string())?;
@ -394,6 +419,8 @@ pub struct StaticInfo {
/// Route-derived sentinels, if any.
/// This isn't `&'static [SentryInfo]` because `type_name()` isn't `const`.
pub sentinels: Vec<Sentry>,
/// A unique route type to identify this route
pub route_type: Box<dyn RouteType>,
}
#[doc(hidden)]
@ -410,6 +437,9 @@ impl From<StaticInfo> for Route {
format: info.format,
sentinels: info.sentinels.into_iter().collect(),
uri,
// Uses `.as_ref()` to get the type id if the internal type, rather than the type id of
// the box
route_type: Some(info.route_type.as_ref().type_id()),
}
}
}

View File

@ -10,6 +10,7 @@ use crate::router::Collide;
pub(crate) struct Router {
routes: HashMap<Method, Vec<Route>>,
catchers: HashMap<Option<u16>, Vec<Catcher>>,
pub default_catcher: Catcher,
}
#[derive(Debug)]

View File

@ -325,7 +325,7 @@ impl Rocket<Orbit> {
for route in self.router.route(request) {
// Retrieve and set the requests parameters.
info_!("Matched: {}", route);
request.set_route(route);
request.set_route(Some(route));
let name = route.name.as_deref();
let outcome = handle(name, || route.handler.handle(request, data)).await
@ -364,8 +364,10 @@ impl Rocket<Orbit> {
// from earlier, unsuccessful paths from being reflected in error
// response. We may wish to relax this in the future.
req.cookies().reset_delta();
req.set_route(None);
if let Some(catcher) = self.router.catch(status, req) {
req.set_catcher(Some(catcher));
warn_!("Responding with registered {} catcher.", catcher);
let name = catcher.name.as_deref();
handle(name, || catcher.handler.handle(status, req)).await
@ -374,6 +376,7 @@ impl Rocket<Orbit> {
} else {
let code = status.code.blue().bold();
warn_!("No {} catcher registered. Using Rocket default.", code);
req.set_catcher(Some(&self.router.default_catcher));
Ok(crate::catcher::default_handler(status, req))
}
}
@ -401,6 +404,7 @@ impl Rocket<Orbit> {
}
}
req.set_catcher(Some(&self.router.default_catcher));
// If it failed again or if it was already a 500, use Rocket's default.
error_!("{} catcher failed. Using Rocket default 500.", status.code);
crate::catcher::default_handler(Status::InternalServerError, req)

View File

@ -30,10 +30,12 @@ fn hello() {
let uri = format!("/?{}{}{}", q("lang", lang), q("emoji", emoji), q("name", name));
let response = client.get(uri).dispatch();
assert!(response.routed_by::<super::hello>(), "Response was not generated by the `hello` route");
assert_eq!(response.into_string().unwrap(), expected);
let uri = format!("/?{}{}{}", q("emoji", emoji), q("name", name), q("lang", lang));
let response = client.get(uri).dispatch();
assert!(response.routed_by::<super::hello>(), "Response was not generated by the `hello` route");
assert_eq!(response.into_string().unwrap(), expected);
}
}
@ -42,6 +44,7 @@ fn hello() {
fn hello_world() {
let client = Client::tracked(super::rocket()).unwrap();
let response = client.get("/hello/world").dispatch();
assert!(response.routed_by::<super::world>(), "Response was not generated by the `world` route");
assert_eq!(response.into_string(), Some("Hello, world!".into()));
}
@ -49,6 +52,7 @@ fn hello_world() {
fn hello_mir() {
let client = Client::tracked(super::rocket()).unwrap();
let response = client.get("/hello/%D0%BC%D0%B8%D1%80").dispatch();
assert!(response.routed_by::<super::mir>(), "Response was not generated by the `mir` route");
assert_eq!(response.into_string(), Some("Привет, мир!".into()));
}
@ -60,11 +64,13 @@ fn wave() {
let real_name = RawStr::new(name).percent_decode_lossy();
let expected = format!("👋 Hello, {} year old named {}!", age, real_name);
let response = client.get(uri).dispatch();
assert!(response.routed_by::<super::wave>(), "Response was not generated by the `wave` route");
assert_eq!(response.into_string().unwrap(), expected);
for bad_age in &["1000", "-1", "bird", "?"] {
let bad_uri = format!("/wave/{}/{}", name, bad_age);
let response = client.get(bad_uri).dispatch();
assert!(response.caught_by::<rocket::catcher::DefaultCatcher>(), "Response was not generated by the default catcher");
assert_eq!(response.status(), Status::NotFound);
}
}