mirror of https://github.com/rwf2/Rocket.git
Add support for checking which route routed a request
- Adds RouteType and CatcherType traits to identify routes and catchers - RouteType and CatcherType are implemented via codegen for attribute macros - Adds routed_by and caught_by methods to local client response - Adds catcher to RequestState - Updates route in RequestState to None if a catcher is run - examples/hello tests now also check which route generated the reponse - Adds DefaultCatcher type to represent Rocket's default catcher - FileServer now implements RouteType
This commit is contained in:
parent
07fe79796f
commit
9aa14d0e24
|
@ -62,6 +62,8 @@ pub fn _catch(
|
|||
/// Rocket code generated proxy structure.
|
||||
#deprecated #vis struct #user_catcher_fn_name { }
|
||||
|
||||
impl #CatcherType for #user_catcher_fn_name { }
|
||||
|
||||
/// Rocket code generated proxy static conversion implementations.
|
||||
#[allow(nonstandard_style, deprecated, clippy::style)]
|
||||
impl #user_catcher_fn_name {
|
||||
|
@ -83,6 +85,7 @@ pub fn _catch(
|
|||
name: stringify!(#user_catcher_fn_name),
|
||||
code: #status_code,
|
||||
handler: monomorphized_function,
|
||||
route_type: #_Box::new(self),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -342,6 +342,8 @@ fn codegen_route(route: Route) -> Result<TokenStream> {
|
|||
/// Rocket code generated proxy structure.
|
||||
#deprecated #vis struct #handler_fn_name { }
|
||||
|
||||
impl #RouteType for #handler_fn_name {}
|
||||
|
||||
/// Rocket code generated proxy static conversion implementations.
|
||||
#[allow(nonstandard_style, deprecated, clippy::style)]
|
||||
impl #handler_fn_name {
|
||||
|
@ -368,6 +370,7 @@ fn codegen_route(route: Route) -> Result<TokenStream> {
|
|||
format: #format,
|
||||
rank: #rank,
|
||||
sentinels: #sentinels,
|
||||
route_type: #_Box::new(self),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -98,7 +98,9 @@ define_exported_paths! {
|
|||
StaticRouteInfo => ::rocket::StaticRouteInfo,
|
||||
StaticCatcherInfo => ::rocket::StaticCatcherInfo,
|
||||
Route => ::rocket::Route,
|
||||
RouteType => ::rocket::route::RouteType,
|
||||
Catcher => ::rocket::Catcher,
|
||||
CatcherType => ::rocket::catcher::CatcherType,
|
||||
SmallVec => ::rocket::http::private::SmallVec,
|
||||
Status => ::rocket::http::Status,
|
||||
}
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
use std::any::{TypeId, Any};
|
||||
use std::fmt;
|
||||
use std::io::Cursor;
|
||||
|
||||
|
@ -10,6 +11,14 @@ use crate::catcher::{Handler, BoxFuture};
|
|||
|
||||
use yansi::Paint;
|
||||
|
||||
// We could also choose to require a Debug impl?
|
||||
/// A generic trait for route types. This should be automatically implemented on the structs
|
||||
/// generated by the codegen for each route.
|
||||
///
|
||||
/// It may also be desirable to add an option for other routes to define a RouteType. This
|
||||
/// would likely just be a case of adding an alternate constructor to the Route type.
|
||||
pub trait CatcherType: Any + Send + Sync + 'static { }
|
||||
|
||||
/// An error catching route.
|
||||
///
|
||||
/// Catchers are routes that run when errors are produced by the application.
|
||||
|
@ -127,6 +136,9 @@ pub struct Catcher {
|
|||
///
|
||||
/// This is -(number of nonempty segments in base).
|
||||
pub(crate) rank: isize,
|
||||
|
||||
/// A unique route type to identify this route
|
||||
pub(crate) catcher_type: Option<TypeId>,
|
||||
}
|
||||
|
||||
// The rank is computed as -(number of nonempty segments in base) => catchers
|
||||
|
@ -185,7 +197,8 @@ impl Catcher {
|
|||
base: uri::Origin::ROOT,
|
||||
handler: Box::new(handler),
|
||||
rank: rank(uri::Origin::ROOT.path()),
|
||||
code
|
||||
code,
|
||||
catcher_type: None,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -307,6 +320,10 @@ impl Catcher {
|
|||
}
|
||||
}
|
||||
|
||||
/// Catcher type of the default catcher created by Rocket
|
||||
pub struct DefaultCatcher { _priv: () }
|
||||
impl CatcherType for DefaultCatcher {}
|
||||
|
||||
impl Default for Catcher {
|
||||
fn default() -> Self {
|
||||
fn handler<'r>(s: Status, req: &'r Request<'_>) -> BoxFuture<'r> {
|
||||
|
@ -315,6 +332,7 @@ impl Default for Catcher {
|
|||
|
||||
let mut catcher = Catcher::new(None, handler);
|
||||
catcher.name = Some("<Rocket Catcher>".into());
|
||||
catcher.catcher_type = Some(TypeId::of::<DefaultCatcher>());
|
||||
catcher
|
||||
}
|
||||
}
|
||||
|
@ -328,6 +346,8 @@ pub struct StaticInfo {
|
|||
pub code: Option<u16>,
|
||||
/// The catcher's handler, i.e, the annotated function.
|
||||
pub handler: for<'r> fn(Status, &'r Request<'_>) -> BoxFuture<'r>,
|
||||
/// A unique route type to identify this route
|
||||
pub catcher_type: Box<dyn CatcherType>,
|
||||
}
|
||||
|
||||
#[doc(hidden)]
|
||||
|
@ -336,6 +356,7 @@ impl From<StaticInfo> for Catcher {
|
|||
fn from(info: StaticInfo) -> Catcher {
|
||||
let mut catcher = Catcher::new(info.code, info.handler);
|
||||
catcher.name = Some(info.name.into());
|
||||
catcher.catcher_type = Some(info.catcher_type.as_ref().type_id());
|
||||
catcher
|
||||
}
|
||||
}
|
||||
|
|
|
@ -2,7 +2,7 @@ use std::path::{PathBuf, Path};
|
|||
|
||||
use crate::{Request, Data};
|
||||
use crate::http::{Method, uri::Segments, ext::IntoOwned};
|
||||
use crate::route::{Route, Handler, Outcome};
|
||||
use crate::route::{Route, Handler, Outcome, RouteType};
|
||||
use crate::response::Redirect;
|
||||
use crate::fs::NamedFile;
|
||||
|
||||
|
@ -180,10 +180,13 @@ impl FileServer {
|
|||
}
|
||||
}
|
||||
|
||||
impl RouteType for FileServer {}
|
||||
|
||||
impl From<FileServer> for Vec<Route> {
|
||||
fn from(server: FileServer) -> Self {
|
||||
let source = figment::Source::File(server.root.clone());
|
||||
let mut route = Route::ranked(server.rank, Method::Get, "/<path..>", server);
|
||||
let mut route = Route::ranked(server.rank, Method::Get, "/<path..>", server)
|
||||
.with_type::<FileServer>();
|
||||
route.name = Some(format!("FileServer: {}", source).into());
|
||||
vec![route]
|
||||
}
|
||||
|
|
|
@ -98,6 +98,12 @@ impl<'c> LocalResponse<'c> {
|
|||
}
|
||||
}
|
||||
|
||||
impl<'r> LocalResponse<'r> {
|
||||
pub(crate) fn _request(&self) -> &Request<'r> {
|
||||
&self._request
|
||||
}
|
||||
}
|
||||
|
||||
impl LocalResponse<'_> {
|
||||
pub(crate) fn _response(&self) -> &Response<'_> {
|
||||
&self.response
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
use std::io;
|
||||
use tokio::io::AsyncReadExt;
|
||||
|
||||
use crate::{Response, local::asynchronous, http::CookieJar};
|
||||
use crate::{Response, local::asynchronous, http::CookieJar, Request};
|
||||
|
||||
use super::Client;
|
||||
|
||||
|
@ -54,6 +54,12 @@ pub struct LocalResponse<'c> {
|
|||
pub(in super) client: &'c Client,
|
||||
}
|
||||
|
||||
impl<'r> LocalResponse<'r> {
|
||||
pub(crate) fn _request(&self) -> &Request<'r> {
|
||||
&self.inner._request()
|
||||
}
|
||||
}
|
||||
|
||||
impl LocalResponse<'_> {
|
||||
fn _response(&self) -> &Response<'_> {
|
||||
&self.inner._response()
|
||||
|
|
|
@ -180,6 +180,50 @@ macro_rules! pub_response_impl {
|
|||
self._into_msgpack() $(.$suffix)?
|
||||
}
|
||||
|
||||
/// Checks if a route was routed by a specific route type
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust
|
||||
/// # use rocket::get;
|
||||
/// #[get("/")]
|
||||
/// fn index() -> &'static str { "Hello World" }
|
||||
#[doc = $doc_prelude]
|
||||
/// # Client::_test(|_, _, response| {
|
||||
/// let response: LocalResponse = response;
|
||||
/// assert!(response.routed_by::<index>())
|
||||
/// # });
|
||||
/// ```
|
||||
pub fn routed_by<T: crate::route::RouteType>(&self) -> bool {
|
||||
if let Some(route_type) = self._request().route().map(|r| r.route_type).flatten() {
|
||||
route_type == std::any::TypeId::of::<T>()
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
/// Checks if a route was caught by a specific route type
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust
|
||||
/// # use rocket::get;
|
||||
/// #[get("/")]
|
||||
/// fn index() -> &'static str { "Hello World" }
|
||||
#[doc = $doc_prelude]
|
||||
/// # Client::_test(|_, _, response| {
|
||||
/// let response: LocalResponse = response;
|
||||
/// assert!(response.routed_by::<index>())
|
||||
/// # });
|
||||
/// ```
|
||||
pub fn caught_by<T: crate::catcher::CatcherType>(&self) -> bool {
|
||||
if let Some(catcher_type) = self._request().catcher().map(|r| r.catcher_type).flatten() {
|
||||
catcher_type == std::any::TypeId::of::<T>()
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
#[allow(dead_code)]
|
||||
fn _ensure_impls_exist() {
|
||||
|
|
|
@ -8,7 +8,7 @@ use state::{TypeMap, InitCell};
|
|||
use futures::future::BoxFuture;
|
||||
use atomic::{Atomic, Ordering};
|
||||
|
||||
use crate::{Rocket, Route, Orbit};
|
||||
use crate::{Rocket, Route, Orbit, Catcher};
|
||||
use crate::request::{FromParam, FromSegments, FromRequest, Outcome};
|
||||
use crate::form::{self, ValueField, FromForm};
|
||||
use crate::data::Limits;
|
||||
|
@ -45,6 +45,7 @@ pub(crate) struct ConnectionMeta {
|
|||
pub(crate) struct RequestState<'r> {
|
||||
pub rocket: &'r Rocket<Orbit>,
|
||||
pub route: Atomic<Option<&'r Route>>,
|
||||
pub catcher: Atomic<Option<&'r Catcher>>,
|
||||
pub cookies: CookieJar<'r>,
|
||||
pub accept: InitCell<Option<Accept>>,
|
||||
pub content_type: InitCell<Option<ContentType>>,
|
||||
|
@ -69,6 +70,7 @@ impl RequestState<'_> {
|
|||
RequestState {
|
||||
rocket: self.rocket,
|
||||
route: Atomic::new(self.route.load(Ordering::Acquire)),
|
||||
catcher: Atomic::new(self.catcher.load(Ordering::Acquire)),
|
||||
cookies: self.cookies.clone(),
|
||||
accept: self.accept.clone(),
|
||||
content_type: self.content_type.clone(),
|
||||
|
@ -97,6 +99,7 @@ impl<'r> Request<'r> {
|
|||
state: RequestState {
|
||||
rocket,
|
||||
route: Atomic::new(None),
|
||||
catcher: Atomic::new(None),
|
||||
cookies: CookieJar::new(rocket.config()),
|
||||
accept: InitCell::new(),
|
||||
content_type: InitCell::new(),
|
||||
|
@ -691,6 +694,22 @@ impl<'r> Request<'r> {
|
|||
self.state.route.load(Ordering::Acquire)
|
||||
}
|
||||
|
||||
/// Get the presently matched catcher, if any.
|
||||
///
|
||||
/// This method returns `Some` while a catcher is running.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust
|
||||
/// # let c = rocket::local::blocking::Client::debug_with(vec![]).unwrap();
|
||||
/// # let request = c.get("/");
|
||||
/// let catcher = request.catcher();
|
||||
/// ```
|
||||
#[inline(always)]
|
||||
pub fn catcher(&self) -> Option<&'r Catcher> {
|
||||
self.state.catcher.load(Ordering::Acquire)
|
||||
}
|
||||
|
||||
/// Invokes the request guard implementation for `T`, returning its outcome.
|
||||
///
|
||||
/// # Example
|
||||
|
@ -969,8 +988,15 @@ impl<'r> Request<'r> {
|
|||
/// Set `self`'s parameters given that the route used to reach this request
|
||||
/// was `route`. Use during routing when attempting a given route.
|
||||
#[inline(always)]
|
||||
pub(crate) fn set_route(&self, route: &'r Route) {
|
||||
self.state.route.store(Some(route), Ordering::Release)
|
||||
pub(crate) fn set_route(&self, route: Option<&'r Route>) {
|
||||
self.state.route.store(route, Ordering::Release)
|
||||
}
|
||||
|
||||
/// Set `self`'s parameters given that the route used to reach this request
|
||||
/// was `catcher`. Use during routing when attempting a given catcher.
|
||||
#[inline(always)]
|
||||
pub(crate) fn set_catcher(&self, catcher: Option<&'r Catcher>) {
|
||||
self.state.catcher.store(catcher, Ordering::Release)
|
||||
}
|
||||
|
||||
/// Set the method of `self`, even when `self` is a shared reference. Used
|
||||
|
|
|
@ -1,12 +1,22 @@
|
|||
use std::fmt;
|
||||
use std::any::{Any, TypeId};
|
||||
use std::borrow::Cow;
|
||||
use std::convert::From;
|
||||
|
||||
use yansi::Paint;
|
||||
|
||||
use crate::http::{uri, Method, MediaType};
|
||||
use crate::route::{Handler, RouteUri, BoxFuture};
|
||||
use crate::http::{uri, MediaType, Method};
|
||||
use crate::route::{BoxFuture, Handler, RouteUri};
|
||||
use crate::sentinel::Sentry;
|
||||
|
||||
// We could also choose to require a Debug impl?
|
||||
/// A generic trait for route types. This should be automatically implemented on the structs
|
||||
/// generated by the codegen for each route.
|
||||
///
|
||||
/// It may also be desirable to add an option for other routes to define a RouteType. This
|
||||
/// would likely just be a case of adding an alternate constructor to the Route type.
|
||||
pub trait RouteType: Any + Send + Sync + 'static { }
|
||||
|
||||
/// A request handling route.
|
||||
///
|
||||
/// A route consists of exactly the information in its fields. While a `Route`
|
||||
|
@ -16,7 +26,8 @@ use crate::sentinel::Sentry;
|
|||
///
|
||||
/// ```rust
|
||||
/// # #[macro_use] extern crate rocket;
|
||||
/// # use std::path::PathBuf;
|
||||
/// #
|
||||
/// use std::path::PathBuf;
|
||||
/// #[get("/route/<path..>?query", rank = 2, format = "json")]
|
||||
/// fn route_name(path: PathBuf) { /* handler procedure */ }
|
||||
///
|
||||
|
@ -178,6 +189,8 @@ pub struct Route {
|
|||
pub format: Option<MediaType>,
|
||||
/// The discovered sentinels.
|
||||
pub(crate) sentinels: Vec<Sentry>,
|
||||
/// A unique route type to identify this route
|
||||
pub(crate) route_type: Option<TypeId>,
|
||||
}
|
||||
|
||||
impl Route {
|
||||
|
@ -243,7 +256,9 @@ impl Route {
|
|||
/// ```
|
||||
#[track_caller]
|
||||
pub fn ranked<H, R>(rank: R, method: Method, uri: &str, handler: H) -> Route
|
||||
where H: Handler + 'static, R: Into<Option<isize>>,
|
||||
where
|
||||
H: Handler + 'static,
|
||||
R: Into<Option<isize>>,
|
||||
{
|
||||
let uri = RouteUri::new("/", uri);
|
||||
let rank = rank.into().unwrap_or_else(|| uri.default_rank());
|
||||
|
@ -252,7 +267,10 @@ impl Route {
|
|||
format: None,
|
||||
sentinels: Vec::new(),
|
||||
handler: Box::new(handler),
|
||||
rank, uri, method,
|
||||
rank,
|
||||
uri,
|
||||
method,
|
||||
route_type: None,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -297,6 +315,12 @@ impl Route {
|
|||
self
|
||||
}
|
||||
|
||||
/// Marks this route with the specified type
|
||||
pub fn with_type<T: RouteType>(mut self) -> Self {
|
||||
self.route_type = Some(TypeId::of::<T>());
|
||||
self
|
||||
}
|
||||
|
||||
/// Maps the `base` of this route using `mapper`, returning a new `Route`
|
||||
/// with the returned base.
|
||||
///
|
||||
|
@ -335,7 +359,8 @@ impl Route {
|
|||
/// assert_eq!(rebased.uri.path(), "/boo/foo/bar");
|
||||
/// ```
|
||||
pub fn map_base<'a, F>(mut self, mapper: F) -> Result<Self, uri::Error<'static>>
|
||||
where F: FnOnce(uri::Origin<'a>) -> String
|
||||
where
|
||||
F: FnOnce(uri::Origin<'a>) -> String,
|
||||
{
|
||||
let base = mapper(self.uri.base);
|
||||
self.uri = RouteUri::try_new(&base, &self.uri.unmounted_origin.to_string())?;
|
||||
|
@ -394,6 +419,8 @@ pub struct StaticInfo {
|
|||
/// Route-derived sentinels, if any.
|
||||
/// This isn't `&'static [SentryInfo]` because `type_name()` isn't `const`.
|
||||
pub sentinels: Vec<Sentry>,
|
||||
/// A unique route type to identify this route
|
||||
pub route_type: Box<dyn RouteType>,
|
||||
}
|
||||
|
||||
#[doc(hidden)]
|
||||
|
@ -410,6 +437,9 @@ impl From<StaticInfo> for Route {
|
|||
format: info.format,
|
||||
sentinels: info.sentinels.into_iter().collect(),
|
||||
uri,
|
||||
// Uses `.as_ref()` to get the type id if the internal type, rather than the type id of
|
||||
// the box
|
||||
route_type: Some(info.route_type.as_ref().type_id()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -10,6 +10,7 @@ use crate::router::Collide;
|
|||
pub(crate) struct Router {
|
||||
routes: HashMap<Method, Vec<Route>>,
|
||||
catchers: HashMap<Option<u16>, Vec<Catcher>>,
|
||||
pub default_catcher: Catcher,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
|
|
|
@ -325,7 +325,7 @@ impl Rocket<Orbit> {
|
|||
for route in self.router.route(request) {
|
||||
// Retrieve and set the requests parameters.
|
||||
info_!("Matched: {}", route);
|
||||
request.set_route(route);
|
||||
request.set_route(Some(route));
|
||||
|
||||
let name = route.name.as_deref();
|
||||
let outcome = handle(name, || route.handler.handle(request, data)).await
|
||||
|
@ -364,8 +364,10 @@ impl Rocket<Orbit> {
|
|||
// from earlier, unsuccessful paths from being reflected in error
|
||||
// response. We may wish to relax this in the future.
|
||||
req.cookies().reset_delta();
|
||||
req.set_route(None);
|
||||
|
||||
if let Some(catcher) = self.router.catch(status, req) {
|
||||
req.set_catcher(Some(catcher));
|
||||
warn_!("Responding with registered {} catcher.", catcher);
|
||||
let name = catcher.name.as_deref();
|
||||
handle(name, || catcher.handler.handle(status, req)).await
|
||||
|
@ -374,6 +376,7 @@ impl Rocket<Orbit> {
|
|||
} else {
|
||||
let code = status.code.blue().bold();
|
||||
warn_!("No {} catcher registered. Using Rocket default.", code);
|
||||
req.set_catcher(Some(&self.router.default_catcher));
|
||||
Ok(crate::catcher::default_handler(status, req))
|
||||
}
|
||||
}
|
||||
|
@ -401,6 +404,7 @@ impl Rocket<Orbit> {
|
|||
}
|
||||
}
|
||||
|
||||
req.set_catcher(Some(&self.router.default_catcher));
|
||||
// If it failed again or if it was already a 500, use Rocket's default.
|
||||
error_!("{} catcher failed. Using Rocket default 500.", status.code);
|
||||
crate::catcher::default_handler(Status::InternalServerError, req)
|
||||
|
|
|
@ -30,10 +30,12 @@ fn hello() {
|
|||
|
||||
let uri = format!("/?{}{}{}", q("lang", lang), q("emoji", emoji), q("name", name));
|
||||
let response = client.get(uri).dispatch();
|
||||
assert!(response.routed_by::<super::hello>(), "Response was not generated by the `hello` route");
|
||||
assert_eq!(response.into_string().unwrap(), expected);
|
||||
|
||||
let uri = format!("/?{}{}{}", q("emoji", emoji), q("name", name), q("lang", lang));
|
||||
let response = client.get(uri).dispatch();
|
||||
assert!(response.routed_by::<super::hello>(), "Response was not generated by the `hello` route");
|
||||
assert_eq!(response.into_string().unwrap(), expected);
|
||||
}
|
||||
}
|
||||
|
@ -42,6 +44,7 @@ fn hello() {
|
|||
fn hello_world() {
|
||||
let client = Client::tracked(super::rocket()).unwrap();
|
||||
let response = client.get("/hello/world").dispatch();
|
||||
assert!(response.routed_by::<super::world>(), "Response was not generated by the `world` route");
|
||||
assert_eq!(response.into_string(), Some("Hello, world!".into()));
|
||||
}
|
||||
|
||||
|
@ -49,6 +52,7 @@ fn hello_world() {
|
|||
fn hello_mir() {
|
||||
let client = Client::tracked(super::rocket()).unwrap();
|
||||
let response = client.get("/hello/%D0%BC%D0%B8%D1%80").dispatch();
|
||||
assert!(response.routed_by::<super::mir>(), "Response was not generated by the `mir` route");
|
||||
assert_eq!(response.into_string(), Some("Привет, мир!".into()));
|
||||
}
|
||||
|
||||
|
@ -60,11 +64,13 @@ fn wave() {
|
|||
let real_name = RawStr::new(name).percent_decode_lossy();
|
||||
let expected = format!("👋 Hello, {} year old named {}!", age, real_name);
|
||||
let response = client.get(uri).dispatch();
|
||||
assert!(response.routed_by::<super::wave>(), "Response was not generated by the `wave` route");
|
||||
assert_eq!(response.into_string().unwrap(), expected);
|
||||
|
||||
for bad_age in &["1000", "-1", "bird", "?"] {
|
||||
let bad_uri = format!("/wave/{}/{}", name, bad_age);
|
||||
let response = client.get(bad_uri).dispatch();
|
||||
assert!(response.caught_by::<rocket::catcher::DefaultCatcher>(), "Response was not generated by the default catcher");
|
||||
assert_eq!(response.status(), Status::NotFound);
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue