From 9aa14d0e24dcd16c28c67031e7aec1882e4e6e26 Mon Sep 17 00:00:00 2001 From: Matthew Pomes Date: Mon, 9 May 2022 11:38:49 -0500 Subject: [PATCH] Add support for checking which route routed a request - Adds RouteType and CatcherType traits to identify routes and catchers - RouteType and CatcherType are implemented via codegen for attribute macros - Adds routed_by and caught_by methods to local client response - Adds catcher to RequestState - Updates route in RequestState to None if a catcher is run - examples/hello tests now also check which route generated the reponse - Adds DefaultCatcher type to represent Rocket's default catcher - FileServer now implements RouteType --- core/codegen/src/attribute/catch/mod.rs | 3 ++ core/codegen/src/attribute/route/mod.rs | 3 ++ core/codegen/src/exports.rs | 2 + core/lib/src/catcher/catcher.rs | 23 ++++++++++- core/lib/src/fs/server.rs | 7 +++- core/lib/src/local/asynchronous/response.rs | 6 +++ core/lib/src/local/blocking/response.rs | 8 +++- core/lib/src/local/response.rs | 44 +++++++++++++++++++++ core/lib/src/request/request.rs | 32 +++++++++++++-- core/lib/src/route/route.rs | 42 +++++++++++++++++--- core/lib/src/router/router.rs | 1 + core/lib/src/server.rs | 6 ++- examples/hello/src/tests.rs | 6 +++ 13 files changed, 169 insertions(+), 14 deletions(-) diff --git a/core/codegen/src/attribute/catch/mod.rs b/core/codegen/src/attribute/catch/mod.rs index 09528c71..0911936e 100644 --- a/core/codegen/src/attribute/catch/mod.rs +++ b/core/codegen/src/attribute/catch/mod.rs @@ -62,6 +62,8 @@ pub fn _catch( /// Rocket code generated proxy structure. #deprecated #vis struct #user_catcher_fn_name { } + impl #CatcherType for #user_catcher_fn_name { } + /// Rocket code generated proxy static conversion implementations. #[allow(nonstandard_style, deprecated, clippy::style)] impl #user_catcher_fn_name { @@ -83,6 +85,7 @@ pub fn _catch( name: stringify!(#user_catcher_fn_name), code: #status_code, handler: monomorphized_function, + route_type: #_Box::new(self), } } diff --git a/core/codegen/src/attribute/route/mod.rs b/core/codegen/src/attribute/route/mod.rs index 1079f306..c4611a06 100644 --- a/core/codegen/src/attribute/route/mod.rs +++ b/core/codegen/src/attribute/route/mod.rs @@ -342,6 +342,8 @@ fn codegen_route(route: Route) -> Result { /// Rocket code generated proxy structure. #deprecated #vis struct #handler_fn_name { } + impl #RouteType for #handler_fn_name {} + /// Rocket code generated proxy static conversion implementations. #[allow(nonstandard_style, deprecated, clippy::style)] impl #handler_fn_name { @@ -368,6 +370,7 @@ fn codegen_route(route: Route) -> Result { format: #format, rank: #rank, sentinels: #sentinels, + route_type: #_Box::new(self), } } diff --git a/core/codegen/src/exports.rs b/core/codegen/src/exports.rs index d6c1f4d9..ee8d553e 100644 --- a/core/codegen/src/exports.rs +++ b/core/codegen/src/exports.rs @@ -98,7 +98,9 @@ define_exported_paths! { StaticRouteInfo => ::rocket::StaticRouteInfo, StaticCatcherInfo => ::rocket::StaticCatcherInfo, Route => ::rocket::Route, + RouteType => ::rocket::route::RouteType, Catcher => ::rocket::Catcher, + CatcherType => ::rocket::catcher::CatcherType, SmallVec => ::rocket::http::private::SmallVec, Status => ::rocket::http::Status, } diff --git a/core/lib/src/catcher/catcher.rs b/core/lib/src/catcher/catcher.rs index 3ed8ff32..1407bd0f 100644 --- a/core/lib/src/catcher/catcher.rs +++ b/core/lib/src/catcher/catcher.rs @@ -1,3 +1,4 @@ +use std::any::{TypeId, Any}; use std::fmt; use std::io::Cursor; @@ -10,6 +11,14 @@ use crate::catcher::{Handler, BoxFuture}; use yansi::Paint; +// We could also choose to require a Debug impl? +/// A generic trait for route types. This should be automatically implemented on the structs +/// generated by the codegen for each route. +/// +/// It may also be desirable to add an option for other routes to define a RouteType. This +/// would likely just be a case of adding an alternate constructor to the Route type. +pub trait CatcherType: Any + Send + Sync + 'static { } + /// An error catching route. /// /// Catchers are routes that run when errors are produced by the application. @@ -127,6 +136,9 @@ pub struct Catcher { /// /// This is -(number of nonempty segments in base). pub(crate) rank: isize, + + /// A unique route type to identify this route + pub(crate) catcher_type: Option, } // The rank is computed as -(number of nonempty segments in base) => catchers @@ -185,7 +197,8 @@ impl Catcher { base: uri::Origin::ROOT, handler: Box::new(handler), rank: rank(uri::Origin::ROOT.path()), - code + code, + catcher_type: None, } } @@ -307,6 +320,10 @@ impl Catcher { } } +/// Catcher type of the default catcher created by Rocket +pub struct DefaultCatcher { _priv: () } +impl CatcherType for DefaultCatcher {} + impl Default for Catcher { fn default() -> Self { fn handler<'r>(s: Status, req: &'r Request<'_>) -> BoxFuture<'r> { @@ -315,6 +332,7 @@ impl Default for Catcher { let mut catcher = Catcher::new(None, handler); catcher.name = Some("".into()); + catcher.catcher_type = Some(TypeId::of::()); catcher } } @@ -328,6 +346,8 @@ pub struct StaticInfo { pub code: Option, /// The catcher's handler, i.e, the annotated function. pub handler: for<'r> fn(Status, &'r Request<'_>) -> BoxFuture<'r>, + /// A unique route type to identify this route + pub catcher_type: Box, } #[doc(hidden)] @@ -336,6 +356,7 @@ impl From for Catcher { fn from(info: StaticInfo) -> Catcher { let mut catcher = Catcher::new(info.code, info.handler); catcher.name = Some(info.name.into()); + catcher.catcher_type = Some(info.catcher_type.as_ref().type_id()); catcher } } diff --git a/core/lib/src/fs/server.rs b/core/lib/src/fs/server.rs index da78ec33..e1b66e04 100644 --- a/core/lib/src/fs/server.rs +++ b/core/lib/src/fs/server.rs @@ -2,7 +2,7 @@ use std::path::{PathBuf, Path}; use crate::{Request, Data}; use crate::http::{Method, uri::Segments, ext::IntoOwned}; -use crate::route::{Route, Handler, Outcome}; +use crate::route::{Route, Handler, Outcome, RouteType}; use crate::response::Redirect; use crate::fs::NamedFile; @@ -180,10 +180,13 @@ impl FileServer { } } +impl RouteType for FileServer {} + impl From for Vec { fn from(server: FileServer) -> Self { let source = figment::Source::File(server.root.clone()); - let mut route = Route::ranked(server.rank, Method::Get, "/", server); + let mut route = Route::ranked(server.rank, Method::Get, "/", server) + .with_type::(); route.name = Some(format!("FileServer: {}", source).into()); vec![route] } diff --git a/core/lib/src/local/asynchronous/response.rs b/core/lib/src/local/asynchronous/response.rs index cabbdccc..8ba00eb7 100644 --- a/core/lib/src/local/asynchronous/response.rs +++ b/core/lib/src/local/asynchronous/response.rs @@ -98,6 +98,12 @@ impl<'c> LocalResponse<'c> { } } +impl<'r> LocalResponse<'r> { + pub(crate) fn _request(&self) -> &Request<'r> { + &self._request + } +} + impl LocalResponse<'_> { pub(crate) fn _response(&self) -> &Response<'_> { &self.response diff --git a/core/lib/src/local/blocking/response.rs b/core/lib/src/local/blocking/response.rs index fc009398..86dbdfce 100644 --- a/core/lib/src/local/blocking/response.rs +++ b/core/lib/src/local/blocking/response.rs @@ -1,7 +1,7 @@ use std::io; use tokio::io::AsyncReadExt; -use crate::{Response, local::asynchronous, http::CookieJar}; +use crate::{Response, local::asynchronous, http::CookieJar, Request}; use super::Client; @@ -54,6 +54,12 @@ pub struct LocalResponse<'c> { pub(in super) client: &'c Client, } +impl<'r> LocalResponse<'r> { + pub(crate) fn _request(&self) -> &Request<'r> { + &self.inner._request() + } +} + impl LocalResponse<'_> { fn _response(&self) -> &Response<'_> { &self.inner._response() diff --git a/core/lib/src/local/response.rs b/core/lib/src/local/response.rs index 411be73f..38d972a9 100644 --- a/core/lib/src/local/response.rs +++ b/core/lib/src/local/response.rs @@ -180,6 +180,50 @@ macro_rules! pub_response_impl { self._into_msgpack() $(.$suffix)? } + /// Checks if a route was routed by a specific route type + /// + /// # Example + /// + /// ```rust + /// # use rocket::get; + /// #[get("/")] + /// fn index() -> &'static str { "Hello World" } + #[doc = $doc_prelude] + /// # Client::_test(|_, _, response| { + /// let response: LocalResponse = response; + /// assert!(response.routed_by::()) + /// # }); + /// ``` + pub fn routed_by(&self) -> bool { + if let Some(route_type) = self._request().route().map(|r| r.route_type).flatten() { + route_type == std::any::TypeId::of::() + } else { + false + } + } + + /// Checks if a route was caught by a specific route type + /// + /// # Example + /// + /// ```rust + /// # use rocket::get; + /// #[get("/")] + /// fn index() -> &'static str { "Hello World" } + #[doc = $doc_prelude] + /// # Client::_test(|_, _, response| { + /// let response: LocalResponse = response; + /// assert!(response.routed_by::()) + /// # }); + /// ``` + pub fn caught_by(&self) -> bool { + if let Some(catcher_type) = self._request().catcher().map(|r| r.catcher_type).flatten() { + catcher_type == std::any::TypeId::of::() + } else { + false + } + } + #[cfg(test)] #[allow(dead_code)] fn _ensure_impls_exist() { diff --git a/core/lib/src/request/request.rs b/core/lib/src/request/request.rs index 7f7e50e7..324fd943 100644 --- a/core/lib/src/request/request.rs +++ b/core/lib/src/request/request.rs @@ -8,7 +8,7 @@ use state::{TypeMap, InitCell}; use futures::future::BoxFuture; use atomic::{Atomic, Ordering}; -use crate::{Rocket, Route, Orbit}; +use crate::{Rocket, Route, Orbit, Catcher}; use crate::request::{FromParam, FromSegments, FromRequest, Outcome}; use crate::form::{self, ValueField, FromForm}; use crate::data::Limits; @@ -45,6 +45,7 @@ pub(crate) struct ConnectionMeta { pub(crate) struct RequestState<'r> { pub rocket: &'r Rocket, pub route: Atomic>, + pub catcher: Atomic>, pub cookies: CookieJar<'r>, pub accept: InitCell>, pub content_type: InitCell>, @@ -69,6 +70,7 @@ impl RequestState<'_> { RequestState { rocket: self.rocket, route: Atomic::new(self.route.load(Ordering::Acquire)), + catcher: Atomic::new(self.catcher.load(Ordering::Acquire)), cookies: self.cookies.clone(), accept: self.accept.clone(), content_type: self.content_type.clone(), @@ -97,6 +99,7 @@ impl<'r> Request<'r> { state: RequestState { rocket, route: Atomic::new(None), + catcher: Atomic::new(None), cookies: CookieJar::new(rocket.config()), accept: InitCell::new(), content_type: InitCell::new(), @@ -691,6 +694,22 @@ impl<'r> Request<'r> { self.state.route.load(Ordering::Acquire) } + /// Get the presently matched catcher, if any. + /// + /// This method returns `Some` while a catcher is running. + /// + /// # Example + /// + /// ```rust + /// # let c = rocket::local::blocking::Client::debug_with(vec![]).unwrap(); + /// # let request = c.get("/"); + /// let catcher = request.catcher(); + /// ``` + #[inline(always)] + pub fn catcher(&self) -> Option<&'r Catcher> { + self.state.catcher.load(Ordering::Acquire) + } + /// Invokes the request guard implementation for `T`, returning its outcome. /// /// # Example @@ -969,8 +988,15 @@ impl<'r> Request<'r> { /// Set `self`'s parameters given that the route used to reach this request /// was `route`. Use during routing when attempting a given route. #[inline(always)] - pub(crate) fn set_route(&self, route: &'r Route) { - self.state.route.store(Some(route), Ordering::Release) + pub(crate) fn set_route(&self, route: Option<&'r Route>) { + self.state.route.store(route, Ordering::Release) + } + + /// Set `self`'s parameters given that the route used to reach this request + /// was `catcher`. Use during routing when attempting a given catcher. + #[inline(always)] + pub(crate) fn set_catcher(&self, catcher: Option<&'r Catcher>) { + self.state.catcher.store(catcher, Ordering::Release) } /// Set the method of `self`, even when `self` is a shared reference. Used diff --git a/core/lib/src/route/route.rs b/core/lib/src/route/route.rs index 24853d95..f51e4dc9 100644 --- a/core/lib/src/route/route.rs +++ b/core/lib/src/route/route.rs @@ -1,12 +1,22 @@ use std::fmt; +use std::any::{Any, TypeId}; use std::borrow::Cow; +use std::convert::From; use yansi::Paint; -use crate::http::{uri, Method, MediaType}; -use crate::route::{Handler, RouteUri, BoxFuture}; +use crate::http::{uri, MediaType, Method}; +use crate::route::{BoxFuture, Handler, RouteUri}; use crate::sentinel::Sentry; +// We could also choose to require a Debug impl? +/// A generic trait for route types. This should be automatically implemented on the structs +/// generated by the codegen for each route. +/// +/// It may also be desirable to add an option for other routes to define a RouteType. This +/// would likely just be a case of adding an alternate constructor to the Route type. +pub trait RouteType: Any + Send + Sync + 'static { } + /// A request handling route. /// /// A route consists of exactly the information in its fields. While a `Route` @@ -16,7 +26,8 @@ use crate::sentinel::Sentry; /// /// ```rust /// # #[macro_use] extern crate rocket; -/// # use std::path::PathBuf; +/// # +/// use std::path::PathBuf; /// #[get("/route/?query", rank = 2, format = "json")] /// fn route_name(path: PathBuf) { /* handler procedure */ } /// @@ -178,6 +189,8 @@ pub struct Route { pub format: Option, /// The discovered sentinels. pub(crate) sentinels: Vec, + /// A unique route type to identify this route + pub(crate) route_type: Option, } impl Route { @@ -243,7 +256,9 @@ impl Route { /// ``` #[track_caller] pub fn ranked(rank: R, method: Method, uri: &str, handler: H) -> Route - where H: Handler + 'static, R: Into>, + where + H: Handler + 'static, + R: Into>, { let uri = RouteUri::new("/", uri); let rank = rank.into().unwrap_or_else(|| uri.default_rank()); @@ -252,7 +267,10 @@ impl Route { format: None, sentinels: Vec::new(), handler: Box::new(handler), - rank, uri, method, + rank, + uri, + method, + route_type: None, } } @@ -297,6 +315,12 @@ impl Route { self } + /// Marks this route with the specified type + pub fn with_type(mut self) -> Self { + self.route_type = Some(TypeId::of::()); + self + } + /// Maps the `base` of this route using `mapper`, returning a new `Route` /// with the returned base. /// @@ -335,7 +359,8 @@ impl Route { /// assert_eq!(rebased.uri.path(), "/boo/foo/bar"); /// ``` pub fn map_base<'a, F>(mut self, mapper: F) -> Result> - where F: FnOnce(uri::Origin<'a>) -> String + where + F: FnOnce(uri::Origin<'a>) -> String, { let base = mapper(self.uri.base); self.uri = RouteUri::try_new(&base, &self.uri.unmounted_origin.to_string())?; @@ -394,6 +419,8 @@ pub struct StaticInfo { /// Route-derived sentinels, if any. /// This isn't `&'static [SentryInfo]` because `type_name()` isn't `const`. pub sentinels: Vec, + /// A unique route type to identify this route + pub route_type: Box, } #[doc(hidden)] @@ -410,6 +437,9 @@ impl From for Route { format: info.format, sentinels: info.sentinels.into_iter().collect(), uri, + // Uses `.as_ref()` to get the type id if the internal type, rather than the type id of + // the box + route_type: Some(info.route_type.as_ref().type_id()), } } } diff --git a/core/lib/src/router/router.rs b/core/lib/src/router/router.rs index 5617f4fb..a0429307 100644 --- a/core/lib/src/router/router.rs +++ b/core/lib/src/router/router.rs @@ -10,6 +10,7 @@ use crate::router::Collide; pub(crate) struct Router { routes: HashMap>, catchers: HashMap, Vec>, + pub default_catcher: Catcher, } #[derive(Debug)] diff --git a/core/lib/src/server.rs b/core/lib/src/server.rs index e3836984..666685a1 100644 --- a/core/lib/src/server.rs +++ b/core/lib/src/server.rs @@ -325,7 +325,7 @@ impl Rocket { for route in self.router.route(request) { // Retrieve and set the requests parameters. info_!("Matched: {}", route); - request.set_route(route); + request.set_route(Some(route)); let name = route.name.as_deref(); let outcome = handle(name, || route.handler.handle(request, data)).await @@ -364,8 +364,10 @@ impl Rocket { // from earlier, unsuccessful paths from being reflected in error // response. We may wish to relax this in the future. req.cookies().reset_delta(); + req.set_route(None); if let Some(catcher) = self.router.catch(status, req) { + req.set_catcher(Some(catcher)); warn_!("Responding with registered {} catcher.", catcher); let name = catcher.name.as_deref(); handle(name, || catcher.handler.handle(status, req)).await @@ -374,6 +376,7 @@ impl Rocket { } else { let code = status.code.blue().bold(); warn_!("No {} catcher registered. Using Rocket default.", code); + req.set_catcher(Some(&self.router.default_catcher)); Ok(crate::catcher::default_handler(status, req)) } } @@ -401,6 +404,7 @@ impl Rocket { } } + req.set_catcher(Some(&self.router.default_catcher)); // If it failed again or if it was already a 500, use Rocket's default. error_!("{} catcher failed. Using Rocket default 500.", status.code); crate::catcher::default_handler(Status::InternalServerError, req) diff --git a/examples/hello/src/tests.rs b/examples/hello/src/tests.rs index fd5b628d..f089a1f9 100644 --- a/examples/hello/src/tests.rs +++ b/examples/hello/src/tests.rs @@ -30,10 +30,12 @@ fn hello() { let uri = format!("/?{}{}{}", q("lang", lang), q("emoji", emoji), q("name", name)); let response = client.get(uri).dispatch(); + assert!(response.routed_by::(), "Response was not generated by the `hello` route"); assert_eq!(response.into_string().unwrap(), expected); let uri = format!("/?{}{}{}", q("emoji", emoji), q("name", name), q("lang", lang)); let response = client.get(uri).dispatch(); + assert!(response.routed_by::(), "Response was not generated by the `hello` route"); assert_eq!(response.into_string().unwrap(), expected); } } @@ -42,6 +44,7 @@ fn hello() { fn hello_world() { let client = Client::tracked(super::rocket()).unwrap(); let response = client.get("/hello/world").dispatch(); + assert!(response.routed_by::(), "Response was not generated by the `world` route"); assert_eq!(response.into_string(), Some("Hello, world!".into())); } @@ -49,6 +52,7 @@ fn hello_world() { fn hello_mir() { let client = Client::tracked(super::rocket()).unwrap(); let response = client.get("/hello/%D0%BC%D0%B8%D1%80").dispatch(); + assert!(response.routed_by::(), "Response was not generated by the `mir` route"); assert_eq!(response.into_string(), Some("Привет, мир!".into())); } @@ -60,11 +64,13 @@ fn wave() { let real_name = RawStr::new(name).percent_decode_lossy(); let expected = format!("👋 Hello, {} year old named {}!", age, real_name); let response = client.get(uri).dispatch(); + assert!(response.routed_by::(), "Response was not generated by the `wave` route"); assert_eq!(response.into_string().unwrap(), expected); for bad_age in &["1000", "-1", "bird", "?"] { let bad_uri = format!("/wave/{}/{}", name, bad_age); let response = client.get(bad_uri).dispatch(); + assert!(response.caught_by::(), "Response was not generated by the default catcher"); assert_eq!(response.status(), Status::NotFound); } }