diff --git a/benchmarks/src/routing.rs b/benchmarks/src/routing.rs index 37b9273b..d84437fa 100644 --- a/benchmarks/src/routing.rs +++ b/benchmarks/src/routing.rs @@ -58,9 +58,9 @@ fn generate_matching_requests<'c>(client: &'c Client, routes: &[Route]) -> Vec TokenStream { // Generate a unique macro name based on the route's metadata. let macro_name = route.handler.sig.ident.prepend(crate::URI_MACRO_PREFIX); let inner_macro_name = macro_name.uniqueify_with(|mut hasher| { - route.attr.method.0.hash(&mut hasher); + route.attr.method.as_ref().map(|m| m.0.hash(&mut hasher)); route.attr.uri.path().hash(&mut hasher); route.attr.uri.query().hash(&mut hasher); route.attr.data.as_ref().map(|d| d.value.hash(&mut hasher)); @@ -395,7 +395,7 @@ fn codegen_route(route: Route) -> Result { let internal_uri_macro = internal_uri_macro_decl(&route); let responder_outcome = responder_outcome_expr(&route); - let method = &route.attr.method; + let method = Optional(route.attr.method.clone()); let uri = route.attr.uri.to_string(); let rank = Optional(route.attr.rank); let format = Optional(route.attr.format.as_ref()); @@ -480,9 +480,12 @@ fn incomplete_route( let method_attribute = MethodAttribute::from_meta(&syn::parse2(full_attr)?)?; let attribute = Attribute { - method: SpanWrapped { - full_span: method_span, key_span: None, span: method_span, value: Method(method) - }, + method: Some(SpanWrapped { + full_span: method_span, + key_span: None, + span: method_span, + value: Method(method), + }), uri: method_attribute.uri, data: method_attribute.data, format: method_attribute.format, diff --git a/core/codegen/src/attribute/route/parse.rs b/core/codegen/src/attribute/route/parse.rs index 13f3b93d..782f3420 100644 --- a/core/codegen/src/attribute/route/parse.rs +++ b/core/codegen/src/attribute/route/parse.rs @@ -43,8 +43,8 @@ pub struct Arguments { #[derive(Debug, FromMeta)] pub struct Attribute { #[meta(naked)] - pub method: SpanWrapped, pub uri: RouteUri, + pub method: Option>, pub data: Option>, pub format: Option, pub rank: Option, @@ -129,17 +129,23 @@ impl Route { // Emit a warning if a `data` param was supplied for non-payload methods. if let Some(ref data) = attr.data { let lint = Lint::DubiousPayload; - match attr.method.0.allows_request_body() { - None if lint.enabled(handler.span()) => { - data.full_span.warning("`data` used with non-payload-supporting method") - .note(format!("'{}' does not typically support payloads", attr.method.0)) - .note(lint.how_to_suppress()) - .emit_as_item_tokens(); - } - Some(false) => { + match attr.method.as_ref() { + Some(m) if m.0.allows_request_body() == Some(false) => { diags.push(data.full_span .error("`data` cannot be used on this route") - .span_note(attr.method.span, "method does not support request payloads")) + .span_note(m.span, "method does not support request payloads")) + }, + Some(m) if m.0.allows_request_body().is_none() && lint.enabled(handler.span()) => { + data.full_span.warning("`data` used with non-payload-supporting method") + .span_note(m.span, format!("'{}' does not typically support payloads", m.0)) + .note(lint.how_to_suppress()) + .emit_as_item_tokens(); + }, + None if lint.enabled(handler.span()) => { + data.full_span.warning("`data` used on route with wildcard method") + .note("some methods may not support request payloads") + .note(lint.how_to_suppress()) + .emit_as_item_tokens(); } _ => { /* okay */ }, } diff --git a/core/codegen/src/lib.rs b/core/codegen/src/lib.rs index 07d4fc80..3b41af8b 100644 --- a/core/codegen/src/lib.rs +++ b/core/codegen/src/lib.rs @@ -119,15 +119,20 @@ macro_rules! route_attribute { /// * [`patch`] - `PATCH` specific route /// /// Additionally, [`route`] allows the method and uri to be explicitly - /// specified: + /// specified, and for the method to be omitted entirely, to match any + /// method: /// /// ```rust /// # #[macro_use] extern crate rocket; - /// # - /// #[route(GET, uri = "/")] - /// fn index() -> &'static str { - /// "Hello, world!" - /// } + /// + /// #[route("/", method = GET)] + /// fn get_index() { /* ... */ } + /// + /// #[route("/", method = "VERSION-CONTROL")] + /// fn versioned_index() { /* ... */ } + /// + /// #[route("/")] + /// fn index() { /* ... */ } /// ``` /// /// [`get`]: attr.get.html @@ -171,7 +176,9 @@ macro_rules! route_attribute { /// The generic route attribute is defined as: /// /// ```text - /// generic-route := METHOD ',' 'uri' '=' route + /// generic-route := route (',' method)? + /// + /// method := 'method' '=' METHOD /// ``` /// /// # Typing Requirements @@ -1161,12 +1168,12 @@ pub fn derive_uri_display_path(input: TokenStream) -> TokenStream { /// assert_eq!(my_routes.len(), 2); /// /// let index_route = &my_routes[0]; -/// assert_eq!(index_route.method, Method::Get); +/// assert_eq!(index_route.method, Some(Method::Get)); /// assert_eq!(index_route.name.as_ref().unwrap(), "index"); /// assert_eq!(index_route.uri.path(), "/"); /// /// let hello_route = &my_routes[1]; -/// assert_eq!(hello_route.method, Method::Post); +/// assert_eq!(hello_route.method, Some(Method::Post)); /// assert_eq!(hello_route.name.as_ref().unwrap(), "hello"); /// assert_eq!(hello_route.uri.path(), "/hi/"); /// ``` diff --git a/core/codegen/tests/route.rs b/core/codegen/tests/route.rs index de65f2fa..f9a1bf66 100644 --- a/core/codegen/tests/route.rs +++ b/core/codegen/tests/route.rs @@ -54,8 +54,8 @@ fn post1( } #[route( - POST, - uri = "///name/?sky=blue&&", + "///name/?sky=blue&&", + method = POST, format = "json", data = "", rank = 138 diff --git a/core/http/src/method.rs b/core/http/src/method.rs index 2b273181..bbfaa9f4 100644 --- a/core/http/src/method.rs +++ b/core/http/src/method.rs @@ -124,6 +124,10 @@ macro_rules! define_methods { #[doc(hidden)] pub const ALL: &'static [&'static str] = &[$($name),*]; + /// A slice containing every defined method variant. + #[doc(hidden)] + pub const ALL_VARIANTS: &'static [Method] = &[$(Self::$V),*]; + /// Whether the method is considered "safe". /// /// From [RFC9110 ยง9.2.1](https://www.rfc-editor.org/rfc/rfc9110#section-9.2.1): diff --git a/core/lib/fuzz/targets/collision-matching.rs b/core/lib/fuzz/targets/collision-matching.rs index 1fb035ef..44ffd8af 100644 --- a/core/lib/fuzz/targets/collision-matching.rs +++ b/core/lib/fuzz/targets/collision-matching.rs @@ -16,7 +16,7 @@ struct ArbitraryRequestData<'a> { #[derive(Arbitrary)] struct ArbitraryRouteData<'a> { - method: ArbitraryMethod, + method: Option, uri: ArbitraryRouteUri<'a>, format: Option, } @@ -24,7 +24,7 @@ struct ArbitraryRouteData<'a> { impl std::fmt::Debug for ArbitraryRouteData<'_> { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("ArbitraryRouteData") - .field("method", &self.method.0) + .field("method", &self.method.map(|v| v.0)) .field("base", &self.uri.0.base()) .field("unmounted", &self.uri.0.unmounted().to_string()) .field("uri", &self.uri.0.to_string()) @@ -59,12 +59,14 @@ impl<'c, 'a: 'c> ArbitraryRequestData<'a> { impl<'a> ArbitraryRouteData<'a> { fn into_route(self) -> Route { - let mut r = Route::ranked(0, self.method.0, &self.uri.0.to_string(), dummy_handler); + let method = self.method.map(|m| m.0); + let mut r = Route::ranked(0, method, &self.uri.0.to_string(), dummy_handler); r.format = self.format.map(|f| f.0); r } } +#[derive(Clone, Copy)] struct ArbitraryMethod(Method); struct ArbitraryOrigin<'a>(Origin<'a>); @@ -79,12 +81,7 @@ struct ArbitraryRouteUri<'a>(RouteUri<'a>); impl<'a> Arbitrary<'a> for ArbitraryMethod { fn arbitrary(u: &mut Unstructured<'a>) -> Result { - let all_methods = &[ - Method::Get, Method::Put, Method::Post, Method::Delete, Method::Options, - Method::Head, Method::Trace, Method::Connect, Method::Patch - ]; - - Ok(ArbitraryMethod(*u.choose(all_methods)?)) + Ok(ArbitraryMethod(*u.choose(Method::ALL_VARIANTS)?)) } fn size_hint(_: usize) -> (usize, Option) { diff --git a/core/lib/src/phase.rs b/core/lib/src/phase.rs index 19986382..8e05411b 100644 --- a/core/lib/src/phase.rs +++ b/core/lib/src/phase.rs @@ -4,7 +4,7 @@ use figment::Figment; use crate::listener::Endpoint; use crate::shutdown::Stages; use crate::{Catcher, Config, Rocket, Route}; -use crate::router::Router; +use crate::router::{Router, Finalized}; use crate::fairing::Fairings; mod private { @@ -100,7 +100,7 @@ phases! { /// represents a fully built and finalized application server ready for /// launch into orbit. See [`Rocket#ignite`] for full details. Ignite (#[derive(Debug)] Igniting) { - pub(crate) router: Router, + pub(crate) router: Router, pub(crate) fairings: Fairings, pub(crate) figment: Figment, pub(crate) config: Config, @@ -114,7 +114,7 @@ phases! { /// An instance of `Rocket` in this phase is typed as [`Rocket`] and /// represents a running application. Orbit (#[derive(Debug)] Orbiting) { - pub(crate) router: Router, + pub(crate) router: Router, pub(crate) fairings: Fairings, pub(crate) figment: Figment, pub(crate) config: Config, diff --git a/core/lib/src/rocket.rs b/core/lib/src/rocket.rs index aabd76dc..5762ed40 100644 --- a/core/lib/src/rocket.rs +++ b/core/lib/src/rocket.rs @@ -557,9 +557,10 @@ impl Rocket { // Initialize the router; check for collisions. let mut router = Router::new(); - self.routes.clone().into_iter().for_each(|r| router.add_route(r)); - self.catchers.clone().into_iter().for_each(|c| router.add_catcher(c)); - router.finalize().map_err(|(r, c)| ErrorKind::Collisions { routes: r, catchers: c, })?; + self.routes.clone().into_iter().for_each(|r| router.routes.push(r)); + self.catchers.clone().into_iter().for_each(|c| router.catchers.push(c)); + let router = router.finalize() + .map_err(|(r, c)| ErrorKind::Collisions { routes: r, catchers: c, })?; // Finally, freeze managed state for faster access later. self.state.freeze(); @@ -840,8 +841,8 @@ impl Rocket

{ pub fn routes(&self) -> impl Iterator { match self.0.as_ref() { StateRef::Build(p) => Either::Left(p.routes.iter()), - StateRef::Ignite(p) => Either::Right(p.router.routes()), - StateRef::Orbit(p) => Either::Right(p.router.routes()), + StateRef::Ignite(p) => Either::Right(p.router.routes.iter()), + StateRef::Orbit(p) => Either::Right(p.router.routes.iter()), } } @@ -871,8 +872,8 @@ impl Rocket

{ pub fn catchers(&self) -> impl Iterator { match self.0.as_ref() { StateRef::Build(p) => Either::Left(p.catchers.iter()), - StateRef::Ignite(p) => Either::Right(p.router.catchers()), - StateRef::Orbit(p) => Either::Right(p.router.catchers()), + StateRef::Ignite(p) => Either::Right(p.router.catchers.iter()), + StateRef::Orbit(p) => Either::Right(p.router.catchers.iter()), } } diff --git a/core/lib/src/route/route.rs b/core/lib/src/route/route.rs index 2305ea2c..8c28395b 100644 --- a/core/lib/src/route/route.rs +++ b/core/lib/src/route/route.rs @@ -22,7 +22,7 @@ use crate::sentinel::Sentry; /// /// let route = routes![route_name].remove(0); /// assert_eq!(route.name.unwrap(), "route_name"); -/// assert_eq!(route.method, Method::Get); +/// assert_eq!(route.method, Some(Method::Get)); /// assert_eq!(route.uri, "/route/?query"); /// assert_eq!(route.rank, 2); /// assert_eq!(route.format.unwrap(), MediaType::JSON); @@ -164,8 +164,8 @@ use crate::sentinel::Sentry; pub struct Route { /// The name of this route, if one was given. pub name: Option>, - /// The method this route matches against. - pub method: Method, + /// The method this route matches, or `None` to match any method. + pub method: Option, /// The function that should be called when the route matches. pub handler: Box, /// The route URI. @@ -203,12 +203,12 @@ impl Route { /// // this is a route matching requests to `GET /` /// let index = Route::new(Method::Get, "/", handler); /// assert_eq!(index.rank, -9); - /// assert_eq!(index.method, Method::Get); + /// assert_eq!(index.method, Some(Method::Get)); /// assert_eq!(index.uri, "/"); /// ``` #[track_caller] - pub fn new(method: Method, uri: &str, handler: H) -> Route { - Route::ranked(None, method, uri, handler) + pub fn new>, H: Handler>(method: M, uri: &str, handler: H) -> Route { + Route::ranked(None, method.into(), uri, handler) } /// Creates a new route with the given rank, method, path, and handler with @@ -233,17 +233,19 @@ impl Route { /// /// let foo = Route::ranked(1, Method::Post, "/foo?bar", handler); /// assert_eq!(foo.rank, 1); - /// assert_eq!(foo.method, Method::Post); + /// assert_eq!(foo.method, Some(Method::Post)); /// assert_eq!(foo.uri, "/foo?bar"); /// /// let foo = Route::ranked(None, Method::Post, "/foo?bar", handler); /// assert_eq!(foo.rank, -12); - /// assert_eq!(foo.method, Method::Post); + /// assert_eq!(foo.method, Some(Method::Post)); /// assert_eq!(foo.uri, "/foo?bar"); /// ``` #[track_caller] - pub fn ranked(rank: R, method: Method, uri: &str, handler: H) -> Route - where H: Handler + 'static, R: Into>, + pub fn ranked(rank: R, method: M, uri: &str, handler: H) -> Route + where M: Into>, + H: Handler + 'static, + R: Into>, { let uri = RouteUri::new("/", uri); let rank = rank.into().unwrap_or_else(|| uri.default_rank()); @@ -253,7 +255,9 @@ impl Route { sentinels: Vec::new(), handler: Box::new(handler), location: None, - rank, uri, method, + method: method.into(), + rank, + uri, } } @@ -362,7 +366,7 @@ pub struct StaticInfo { /// The route's name, i.e, the name of the function. pub name: &'static str, /// The route's method. - pub method: Method, + pub method: Option, /// The route's URi, without the base mount point. pub uri: &'static str, /// The route's format, if any. diff --git a/core/lib/src/router/collider.rs b/core/lib/src/router/collider.rs index d0e15ae4..d55ded8e 100644 --- a/core/lib/src/router/collider.rs +++ b/core/lib/src/router/collider.rs @@ -1,7 +1,7 @@ use crate::catcher::Catcher; use crate::route::{Route, Segment, RouteUri}; -use crate::http::MediaType; +use crate::http::{MediaType, Method}; pub trait Collide { fn collides_with(&self, other: &T) -> bool; @@ -87,7 +87,7 @@ impl Route { /// assert!(a.collides_with(&b)); /// ``` pub fn collides_with(&self, other: &Route) -> bool { - self.method == other.method + methods_collide(self, other) && self.rank == other.rank && self.uri.collides_with(&other.uri) && formats_collide(self, other) @@ -190,8 +190,16 @@ impl Collide for MediaType { } } +fn methods_collide(route: &Route, other: &Route) -> bool { + match (route.method, other.method) { + (Some(a), Some(b)) => a == b, + (None, _) | (_, None) => true, + } +} + fn formats_collide(route: &Route, other: &Route) -> bool { - match (route.method.allows_request_body(), other.method.allows_request_body()) { + let payload_support = |m: &Option| m.and_then(|m| m.allows_request_body()); + match (payload_support(&route.method), payload_support(&other.method)) { // Payload supporting methods match against `Content-Type` which must be // fully specified, so the request cannot contain a format that matches // more than one route format as long as those formats don't collide. diff --git a/core/lib/src/router/matcher.rs b/core/lib/src/router/matcher.rs index e0dd66d3..5cb5b918 100644 --- a/core/lib/src/router/matcher.rs +++ b/core/lib/src/router/matcher.rs @@ -67,8 +67,7 @@ impl Route { /// ``` #[tracing::instrument(level = "trace", name = "matching", skip_all, ret)] pub fn matches(&self, request: &Request<'_>) -> bool { - trace!(route.method = %self.method, request.method = %request.method()); - self.method == request.method() + methods_match(self, request) && paths_match(self, request) && queries_match(self, request) && formats_match(self, request) @@ -140,6 +139,11 @@ impl Catcher { } } +fn methods_match(route: &Route, req: &Request<'_>) -> bool { + trace!(?route.method, request.method = %req.method()); + route.method.map_or(true, |method| method == req.method()) +} + fn paths_match(route: &Route, req: &Request<'_>) -> bool { trace!(route.uri = %route.uri, request.uri = %req.uri()); let route_segments = &route.uri.metadata.uri_segments; @@ -208,7 +212,7 @@ fn formats_match(route: &Route, req: &Request<'_>) -> bool { None => return true, }; - match route.method.allows_request_body() { + match route.method.and_then(|m| m.allows_request_body()) { Some(true) => match req.format() { Some(f) if f.specificity() == 2 => route_format.collides_with(f), _ => false diff --git a/core/lib/src/router/router.rs b/core/lib/src/router/router.rs index 84da1b98..017486a6 100644 --- a/core/lib/src/router/router.rs +++ b/core/lib/src/router/router.rs @@ -1,64 +1,120 @@ +use std::ops::{Deref, DerefMut}; use std::collections::HashMap; use crate::request::Request; use crate::http::{Method, Status}; - use crate::{Route, Catcher}; use crate::router::Collide; +#[derive(Debug)] +pub(crate) struct Router(T); + #[derive(Debug, Default)] -pub(crate) struct Router { - routes: HashMap>, - catchers: HashMap, Vec>, +pub struct Pending { + pub routes: Vec, + pub catchers: Vec, } -pub type Collisions = Vec<(T, T)>; +#[derive(Debug, Default)] +pub struct Finalized { + pub routes: Vec, + pub catchers: Vec, + route_map: HashMap>, + catcher_map: HashMap, Vec>, +} -impl Router { +pub type Pair = (T, T); + +pub type Collisions = (Vec>, Vec>); + +pub type Result = std::result::Result; + +impl Router { pub fn new() -> Self { - Self::default() + Router(Pending::default()) } - pub fn add_route(&mut self, route: Route) { - let routes = self.routes.entry(route.method).or_default(); - routes.push(route); - routes.sort_by_key(|r| r.rank); - } + pub fn finalize(self) -> Result, Collisions> { + fn collisions<'a, T>(items: &'a [T]) -> impl Iterator + 'a + where T: Collide + Clone + 'a, + { + items.iter() + .enumerate() + .flat_map(move |(i, a)| { + items.iter() + .skip(i + 1) + .filter(move |b| a.collides_with(b)) + .map(move |b| (a.clone(), b.clone())) + }) + } - pub fn add_catcher(&mut self, catcher: Catcher) { - let catchers = self.catchers.entry(catcher.code).or_default(); - catchers.push(catcher); - catchers.sort_by_key(|c| c.rank); - } + let route_collisions: Vec<_> = collisions(&self.routes).collect(); + let catcher_collisions: Vec<_> = collisions(&self.catchers).collect(); - #[inline] - pub fn routes(&self) -> impl Iterator + Clone { - self.routes.values().flat_map(|v| v.iter()) - } + if !route_collisions.is_empty() || !catcher_collisions.is_empty() { + return Err((route_collisions, catcher_collisions)) + } - #[inline] - pub fn catchers(&self) -> impl Iterator + Clone { - self.catchers.values().flat_map(|v| v.iter()) - } + // create the route map + let mut route_map: HashMap> = HashMap::new(); + for (i, route) in self.routes.iter().enumerate() { + match route.method { + Some(method) => route_map.entry(method).or_default().push(i), + None => for method in Method::ALL_VARIANTS { + route_map.entry(*method).or_default().push(i); + } + } + } + // create the catcher map + let mut catcher_map: HashMap, Vec> = HashMap::new(); + for (i, catcher) in self.catchers.iter().enumerate() { + catcher_map.entry(catcher.code).or_default().push(i); + } + + // sort routes by rank + for routes in route_map.values_mut() { + routes.sort_by_key(|&i| &self.routes[i].rank); + } + + // sort catchers by rank + for catchers in catcher_map.values_mut() { + catchers.sort_by_key(|&i| &self.catchers[i].rank); + } + + Ok(Router(Finalized { + routes: self.0.routes, + catchers: self.0.catchers, + route_map, catcher_map + })) + } +} + +impl Router { + #[track_caller] pub fn route<'r, 'a: 'r>( &'a self, req: &'r Request<'r> ) -> impl Iterator + 'r { - // Note that routes are presorted by ascending rank on each `add`. - self.routes.get(&req.method()) + // Note that routes are presorted by ascending rank on each `add` and + // that all routes with `None` methods have been cloned into all methods. + self.route_map.get(&req.method()) .into_iter() - .flat_map(move |routes| routes.iter().filter(move |r| r.matches(req))) + .flat_map(move |routes| routes.iter().map(move |&i| &self.routes[i])) + .filter(move |r| r.matches(req)) } // For many catchers, using aho-corasick or similar should be much faster. + #[track_caller] pub fn catch<'r>(&self, status: Status, req: &'r Request<'r>) -> Option<&Catcher> { // Note that catchers are presorted by descending base length. - let explicit = self.catchers.get(&Some(status.code)) - .and_then(|c| c.iter().find(|c| c.matches(status, req))); + let explicit = self.catcher_map.get(&Some(status.code)) + .map(|catchers| catchers.iter().map(|&i| &self.catchers[i])) + .and_then(|mut catchers| catchers.find(|c| c.matches(status, req))); - let default = self.catchers.get(&None) - .and_then(|c| c.iter().find(|c| c.matches(status, req))); + let default = self.catcher_map.get(&None) + .map(|catchers| catchers.iter().map(|&i| &self.catchers[i])) + .and_then(|mut catchers| catchers.find(|c| c.matches(status, req))); match (explicit, default) { (None, None) => None, @@ -67,28 +123,19 @@ impl Router { (Some(_), Some(b)) => Some(b), } } +} - fn collisions<'a, I, T>(&self, items: I) -> impl Iterator + 'a - where I: Iterator + Clone + 'a, T: Collide + Clone + 'a, - { - items.clone().enumerate() - .flat_map(move |(i, a)| { - items.clone() - .skip(i + 1) - .filter(move |b| a.collides_with(b)) - .map(move |b| (a.clone(), b.clone())) - }) +impl Deref for Router { + type Target = T; + + fn deref(&self) -> &Self::Target { + &self.0 } +} - pub fn finalize(&self) -> Result<(), (Collisions, Collisions)> { - let routes: Vec<_> = self.collisions(self.routes()).collect(); - let catchers: Vec<_> = self.collisions(self.catchers()).collect(); - - if !routes.is_empty() || !catchers.is_empty() { - return Err((routes, catchers)) - } - - Ok(()) +impl DerefMut for Router { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 } } @@ -100,50 +147,32 @@ mod test { use crate::local::blocking::Client; use crate::http::{Method::*, uri::Origin}; - impl Router { - fn has_collisions(&self) -> bool { - self.finalize().is_err() - } - } - - fn router_with_routes(routes: &[&'static str]) -> Router { + fn make_router(routes: I) -> Result, Collisions> + where I: Iterator, &'static str)> + { let mut router = Router::new(); - for route in routes { - let route = Route::new(Get, route, dummy_handler); - router.add_route(route); - } - - router - } - - fn router_with_ranked_routes(routes: &[(isize, &'static str)]) -> Router { - let mut router = Router::new(); - for &(rank, route) in routes { + for (rank, route) in routes { let route = Route::ranked(rank, Get, route, dummy_handler); - router.add_route(route); + router.routes.push(route); } - router + router.finalize() } - fn router_with_rankless_routes(routes: &[&'static str]) -> Router { - let mut router = Router::new(); - for route in routes { - let route = Route::ranked(0, Get, route, dummy_handler); - router.add_route(route); - } + fn router_with_routes(routes: &[&'static str]) -> Router { + make_router(routes.iter().map(|r| (None, *r))).unwrap() + } - router + fn router_with_ranked_routes(routes: &[(isize, &'static str)]) -> Router { + make_router(routes.iter().map(|r| (Some(r.0), r.1))).unwrap() } fn rankless_route_collisions(routes: &[&'static str]) -> bool { - let router = router_with_rankless_routes(routes); - router.has_collisions() + make_router(routes.iter().map(|r| (Some(0), *r))).is_err() } fn default_rank_route_collisions(routes: &[&'static str]) -> bool { - let router = router_with_routes(routes); - router.has_collisions() + make_router(routes.iter().map(|r| (None, *r))).is_err() } #[test] @@ -280,13 +309,15 @@ mod test { assert!(!default_rank_route_collisions(&["/?a=b", "/?c=d&"])); } - fn matches<'a>(router: &'a Router, method: Method, uri: &'a str) -> Vec<&'a Route> { + #[track_caller] + fn matches<'a>(router: &'a Router, method: Method, uri: &'a str) -> Vec<&'a Route> { let client = Client::debug_with(vec![]).expect("client"); let request = client.req(method, Origin::parse(uri).unwrap()); router.route(&request).collect() } - fn route<'a>(router: &'a Router, method: Method, uri: &'a str) -> Option<&'a Route> { + #[track_caller] + fn route<'a>(router: &'a Router, method: Method, uri: &'a str) -> Option<&'a Route> { matches(router, method, uri).into_iter().next() } @@ -309,9 +340,10 @@ mod test { assert!(route(&router, Get, "/a/").is_some()); let mut router = Router::new(); - router.add_route(Route::new(Put, "/hello", dummy_handler)); - router.add_route(Route::new(Post, "/hello", dummy_handler)); - router.add_route(Route::new(Delete, "/hello", dummy_handler)); + router.routes.push(Route::new(Put, "/hello", dummy_handler)); + router.routes.push(Route::new(Post, "/hello", dummy_handler)); + router.routes.push(Route::new(Delete, "/hello", dummy_handler)); + let router = router.finalize().unwrap(); assert!(route(&router, Put, "/hello").is_some()); assert!(route(&router, Post, "/hello").is_some()); assert!(route(&router, Delete, "/hello").is_some()); @@ -368,7 +400,6 @@ mod test { macro_rules! assert_ranked_match { ($routes:expr, $to:expr => $want:expr) => ({ let router = router_with_routes($routes); - assert!(!router.has_collisions()); let route_path = route(&router, Get, $to).unwrap().uri.to_string(); assert_eq!(route_path, $want.to_string(), "\nmatched {} with {}, wanted {} in {:#?}", $to, route_path, $want, router); @@ -401,8 +432,7 @@ mod test { } fn ranked_collisions(routes: &[(isize, &'static str)]) -> bool { - let router = router_with_ranked_routes(routes); - router.has_collisions() + make_router(routes.iter().map(|r| (Some(r.0), r.1))).is_err() } #[test] @@ -429,7 +459,7 @@ mod test { let router = router_with_ranked_routes(&$routes); let routed_to = matches(&router, Get, $to); let expected = &[$($want),+]; - assert!(routed_to.len() == expected.len()); + assert_eq!(routed_to.len(), expected.len()); for (got, expected) in routed_to.iter().zip(expected.iter()) { assert_eq!(got.rank, expected.0); assert_eq!(got.uri.to_string(), expected.1.to_string()); @@ -545,20 +575,21 @@ mod test { ); } - fn router_with_catchers(catchers: &[(Option, &str)]) -> Router { + fn router_with_catchers(catchers: &[(Option, &str)]) -> Result> { let mut router = Router::new(); for (code, base) in catchers { let catcher = Catcher::new(*code, crate::catcher::dummy_handler); - router.add_catcher(catcher.map_base(|_| base.to_string()).unwrap()); + router.catchers.push(catcher.map_base(|_| base.to_string()).unwrap()); } - router + router.finalize() } - fn catcher<'a>(router: &'a Router, status: Status, uri: &str) -> Option<&'a Catcher> { + #[track_caller] + fn catcher<'a>(r: &'a Router, status: Status, uri: &str) -> Option<&'a Catcher> { let client = Client::debug_with(vec![]).expect("client"); let request = client.get(Origin::parse(uri).unwrap()); - router.catch(status, &request) + r.catch(status, &request) } macro_rules! assert_catcher_routing { @@ -571,7 +602,7 @@ mod test { let requests = vec![$($r),+]; let expected = vec![$(($ecode.into(), $euri)),+]; - let router = router_with_catchers(&catchers); + let router = router_with_catchers(&catchers).expect("valid router"); for (req, expected) in requests.iter().zip(expected.iter()) { let req_status = Status::from_code(req.0).expect("valid status"); let catcher = catcher(&router, req_status, req.1).expect("some catcher"); diff --git a/core/lib/src/trace/traceable.rs b/core/lib/src/trace/traceable.rs index fa3f6850..9ef1d428 100644 --- a/core/lib/src/trace/traceable.rs +++ b/core/lib/src/trace/traceable.rs @@ -142,7 +142,10 @@ impl Trace for Route { event! { level, "route", name = self.name.as_ref().map(|n| &**n), rank = self.rank, - method = %self.method, + method = %Formatter(|f| match self.method { + Some(method) => write!(f, "{}", method), + None => write!(f, "[any]"), + }), uri = %self.uri, uri.base = %self.uri.base(), uri.unmounted = %self.uri.unmounted(), diff --git a/core/lib/tests/form_method-issue-45.rs b/core/lib/tests/form_method-issue-45.rs index 009aa164..4e57a48b 100644 --- a/core/lib/tests/form_method-issue-45.rs +++ b/core/lib/tests/form_method-issue-45.rs @@ -13,13 +13,13 @@ fn patch(form_data: Form) -> &'static str { "PATCH OK" } -#[route(UPDATEREDIRECTREF, uri = "/", data = "")] +#[route("/", method = UPDATEREDIRECTREF, data = "")] fn urr(form_data: Form) -> &'static str { assert_eq!("Form data", form_data.into_inner().form_data); "UPDATEREDIRECTREF OK" } -#[route("VERSION-CONTROL", uri = "/", data = "")] +#[route("/", method = "VERSION-CONTROL", data = "")] fn vc(form_data: Form) -> &'static str { assert_eq!("Form data", form_data.into_inner().form_data); "VERSION-CONTROL OK" diff --git a/docs/guide/05-requests.md b/docs/guide/05-requests.md index 08a0329d..d3ad9792 100644 --- a/docs/guide/05-requests.md +++ b/docs/guide/05-requests.md @@ -37,21 +37,55 @@ these properties and more. ## Methods -A Rocket route attribute can be any one of `get`, `put`, `post`, `delete`, -`head`, `patch`, or `options`, each corresponding to the HTTP method to match -against. For example, the following attribute will match against `POST` requests -to the root path: +A Rocket route attribute can either be method-specific, any one of `get`, `put`, +`post`, `delete`, `head`, `patch`, or `options`, or the generic [`route`], which +allows explicitly specifying any valid HTTP [`Method`] or no method at all, to +match again _any_ method. Consider the following examples: -```rust -# #[macro_use] extern crate rocket; -# fn main() {} + * Match a `POST` request to `/`: -#[post("/")] -# fn handler() {} -``` + ```rust + # use rocket::post; + #[post("/")] + # fn handler() {} + ``` + + * Match a `PATCH` request to `/fix`: + + ```rust + # use rocket::patch; + #[patch("/fix")] + # fn handler() {} + ``` + + * Match a `PROPFIND` request to `/collection`: + + ```rust + # use rocket::route; + #[route("/collection", method = PROPFIND)] + # fn handler() {} + ``` + + * Match a `VERSION-CONTROL` request to `/collection`: + + ```rust + # use rocket::route; + #[route("/resource", method = "VERSION-CONTROL")] + # fn handler() {} + ``` + + * Match a request to `/page` with _any_ method: + + ```rust + # use rocket::route; + #[route("/page")] + # fn handler() {} + ``` The grammar for these attributes is defined formally in the [`route`] API docs. +[`Method`]: @api/master/rocket/http/enum.Method.html + ### HEAD Requests Rocket handles `HEAD` requests automatically when there exists a `GET` route diff --git a/examples/hello/src/main.rs b/examples/hello/src/main.rs index 0f8c55cb..7cd60a50 100644 --- a/examples/hello/src/main.rs +++ b/examples/hello/src/main.rs @@ -33,7 +33,7 @@ fn mir() -> &'static str { // Try visiting: // http://127.0.0.1:8000/wave/Rocketeer/100 -#[get("//")] +#[get("//", rank = 2)] fn wave(name: &str, age: u8) -> String { format!("๐Ÿ‘‹ Hello, {} year old named {}!", age, name) } diff --git a/examples/tls/src/redirector.rs b/examples/tls/src/redirector.rs index ae08185a..b8155317 100644 --- a/examples/tls/src/redirector.rs +++ b/examples/tls/src/redirector.rs @@ -2,9 +2,9 @@ use std::net::SocketAddr; -use rocket::http::Status; +use rocket::http::uri::{Origin, Host}; use rocket::tracing::{self, Instrument}; -use rocket::{route, Error, Request, Data, Route, Orbit, Rocket, Ignite}; +use rocket::{Rocket, Ignite, Orbit, State, Error}; use rocket::fairing::{Fairing, Info, Kind}; use rocket::response::Redirect; use rocket::listener::tcp::TcpListener; @@ -19,43 +19,33 @@ pub struct Config { tls_addr: SocketAddr, } +#[route("/<_..>")] +fn redirect(config: &State, uri: &Origin<'_>, host: &Host<'_>) -> Redirect { + // FIXME: Check the host against a whitelist! + let domain = host.domain(); + let https_uri = match config.tls_addr.port() { + 443 => format!("https://{domain}{uri}"), + port => format!("https://{domain}:{port}{uri}"), + }; + + Redirect::permanent(https_uri) +} + impl Redirector { pub fn on(port: u16) -> Self { Redirector(port) } - // Route function that gets called on every single request. - fn redirect<'r>(req: &'r Request, _: Data<'r>) -> route::BoxFuture<'r> { - // FIXME: Check the host against a whitelist! - let config = req.rocket().state::().expect("managed Self"); - if let Some(host) = req.host() { - let domain = host.domain(); - let https_uri = match config.tls_addr.port() { - 443 => format!("https://{domain}{}", req.uri()), - port => format!("https://{domain}:{port}{}", req.uri()), - }; - - route::Outcome::from(req, Redirect::permanent(https_uri)).pin() - } else { - route::Outcome::from(req, Status::BadRequest).pin() - } - } - // Launch an instance of Rocket than handles redirection on `self.port`. pub async fn try_launch(self, config: Config) -> Result, Error> { - use rocket::http::Method::*; + rocket::span_info!("HTTP -> HTTPS Redirector" => { + info!(from = self.0, to = config.tls_addr.port(), "redirecting"); + }); - // Build a vector of routes to `redirect` on `` for each method. - let redirects = [Get, Put, Post, Delete, Options, Head, Trace, Connect, Patch] - .into_iter() - .map(|m| Route::new(m, "/", Self::redirect)) - .collect::>(); - - info!(from = self.0, to = config.tls_addr.port(), "redirecting"); let addr = SocketAddr::new(config.tls_addr.ip(), self.0); rocket::custom(&config.server) .manage(config) - .mount("/", redirects) + .mount("/", routes![redirect]) .try_launch_on(TcpListener::bind(addr)) .await } diff --git a/testbench/src/servers/http_extensions.rs b/testbench/src/servers/http_extensions.rs index 42b990c7..86fd69b4 100644 --- a/testbench/src/servers/http_extensions.rs +++ b/testbench/src/servers/http_extensions.rs @@ -4,7 +4,7 @@ use crate::prelude::*; use rocket::http::Method; -#[route(PROPFIND, uri = "/")] +#[route("/", method = PROPFIND)] fn route() -> &'static str { "Hello, World!" }