Support routes that match any method.

This commit introduces support for method-less routes and route
attributes, which match _any_ valid method: `#[route("/")]`. The `Route`
structure's `method` field is now accordingly of type `Option<Route>`.

The syntax for the `route` attribute has changed in a breaking manner.
To set a method, a key/value of `method = NAME` must be introduced:

```rust
#[route("/", method = GET)]
```

If the method's name is a valid identifier, it can be used without
quotes. Otherwise it must be quoted:

```rust
// `GET` is a valid identifier, but `VERSION-CONTROL` is not
#[route("/", method = "VERSION-CONTROL")]
```

Closes #2731.
This commit is contained in:
Sergio Benitez 2024-06-03 21:54:07 -07:00
parent 9496b70e8c
commit 72c91958b7
19 changed files with 297 additions and 205 deletions

View File

@ -58,9 +58,9 @@ fn generate_matching_requests<'c>(client: &'c Client, routes: &[Route]) -> Vec<L
.join("&"); .join("&");
let uri = format!("/{}?{}", path, query); let uri = format!("/{}?{}", path, query);
let mut req = client.req(route.method, uri); let mut req = client.req(route.method.unwrap(), uri);
if let Some(ref format) = route.format { if let Some(ref format) = route.format {
if let Some(true) = route.method.allows_request_body() { if let Some(true) = route.method.and_then(|m| m.allows_request_body()) {
req.add_header(ContentType::from(format.clone())); req.add_header(ContentType::from(format.clone()));
} else { } else {
req.add_header(Accept::from(format.clone())); req.add_header(Accept::from(format.clone()));

View File

@ -263,7 +263,7 @@ fn internal_uri_macro_decl(route: &Route) -> TokenStream {
// Generate a unique macro name based on the route's metadata. // Generate a unique macro name based on the route's metadata.
let macro_name = route.handler.sig.ident.prepend(crate::URI_MACRO_PREFIX); let macro_name = route.handler.sig.ident.prepend(crate::URI_MACRO_PREFIX);
let inner_macro_name = macro_name.uniqueify_with(|mut hasher| { let inner_macro_name = macro_name.uniqueify_with(|mut hasher| {
route.attr.method.0.hash(&mut hasher); route.attr.method.as_ref().map(|m| m.0.hash(&mut hasher));
route.attr.uri.path().hash(&mut hasher); route.attr.uri.path().hash(&mut hasher);
route.attr.uri.query().hash(&mut hasher); route.attr.uri.query().hash(&mut hasher);
route.attr.data.as_ref().map(|d| d.value.hash(&mut hasher)); route.attr.data.as_ref().map(|d| d.value.hash(&mut hasher));
@ -395,7 +395,7 @@ fn codegen_route(route: Route) -> Result<TokenStream> {
let internal_uri_macro = internal_uri_macro_decl(&route); let internal_uri_macro = internal_uri_macro_decl(&route);
let responder_outcome = responder_outcome_expr(&route); let responder_outcome = responder_outcome_expr(&route);
let method = &route.attr.method; let method = Optional(route.attr.method.clone());
let uri = route.attr.uri.to_string(); let uri = route.attr.uri.to_string();
let rank = Optional(route.attr.rank); let rank = Optional(route.attr.rank);
let format = Optional(route.attr.format.as_ref()); let format = Optional(route.attr.format.as_ref());
@ -480,9 +480,12 @@ fn incomplete_route(
let method_attribute = MethodAttribute::from_meta(&syn::parse2(full_attr)?)?; let method_attribute = MethodAttribute::from_meta(&syn::parse2(full_attr)?)?;
let attribute = Attribute { let attribute = Attribute {
method: SpanWrapped { method: Some(SpanWrapped {
full_span: method_span, key_span: None, span: method_span, value: Method(method) full_span: method_span,
}, key_span: None,
span: method_span,
value: Method(method),
}),
uri: method_attribute.uri, uri: method_attribute.uri,
data: method_attribute.data, data: method_attribute.data,
format: method_attribute.format, format: method_attribute.format,

View File

@ -43,8 +43,8 @@ pub struct Arguments {
#[derive(Debug, FromMeta)] #[derive(Debug, FromMeta)]
pub struct Attribute { pub struct Attribute {
#[meta(naked)] #[meta(naked)]
pub method: SpanWrapped<Method>,
pub uri: RouteUri, pub uri: RouteUri,
pub method: Option<SpanWrapped<Method>>,
pub data: Option<SpanWrapped<Dynamic>>, pub data: Option<SpanWrapped<Dynamic>>,
pub format: Option<MediaType>, pub format: Option<MediaType>,
pub rank: Option<isize>, pub rank: Option<isize>,
@ -129,17 +129,23 @@ impl Route {
// Emit a warning if a `data` param was supplied for non-payload methods. // Emit a warning if a `data` param was supplied for non-payload methods.
if let Some(ref data) = attr.data { if let Some(ref data) = attr.data {
let lint = Lint::DubiousPayload; let lint = Lint::DubiousPayload;
match attr.method.0.allows_request_body() { match attr.method.as_ref() {
None if lint.enabled(handler.span()) => { Some(m) if m.0.allows_request_body() == Some(false) => {
data.full_span.warning("`data` used with non-payload-supporting method")
.note(format!("'{}' does not typically support payloads", attr.method.0))
.note(lint.how_to_suppress())
.emit_as_item_tokens();
}
Some(false) => {
diags.push(data.full_span diags.push(data.full_span
.error("`data` cannot be used on this route") .error("`data` cannot be used on this route")
.span_note(attr.method.span, "method does not support request payloads")) .span_note(m.span, "method does not support request payloads"))
},
Some(m) if m.0.allows_request_body().is_none() && lint.enabled(handler.span()) => {
data.full_span.warning("`data` used with non-payload-supporting method")
.span_note(m.span, format!("'{}' does not typically support payloads", m.0))
.note(lint.how_to_suppress())
.emit_as_item_tokens();
},
None if lint.enabled(handler.span()) => {
data.full_span.warning("`data` used on route with wildcard method")
.note("some methods may not support request payloads")
.note(lint.how_to_suppress())
.emit_as_item_tokens();
} }
_ => { /* okay */ }, _ => { /* okay */ },
} }

View File

@ -119,15 +119,20 @@ macro_rules! route_attribute {
/// * [`patch`] - `PATCH` specific route /// * [`patch`] - `PATCH` specific route
/// ///
/// Additionally, [`route`] allows the method and uri to be explicitly /// Additionally, [`route`] allows the method and uri to be explicitly
/// specified: /// specified, and for the method to be omitted entirely, to match any
/// method:
/// ///
/// ```rust /// ```rust
/// # #[macro_use] extern crate rocket; /// # #[macro_use] extern crate rocket;
/// # ///
/// #[route(GET, uri = "/")] /// #[route("/", method = GET)]
/// fn index() -> &'static str { /// fn get_index() { /* ... */ }
/// "Hello, world!" ///
/// } /// #[route("/", method = "VERSION-CONTROL")]
/// fn versioned_index() { /* ... */ }
///
/// #[route("/")]
/// fn index() { /* ... */ }
/// ``` /// ```
/// ///
/// [`get`]: attr.get.html /// [`get`]: attr.get.html
@ -171,7 +176,9 @@ macro_rules! route_attribute {
/// The generic route attribute is defined as: /// The generic route attribute is defined as:
/// ///
/// ```text /// ```text
/// generic-route := METHOD ',' 'uri' '=' route /// generic-route := route (',' method)?
///
/// method := 'method' '=' METHOD
/// ``` /// ```
/// ///
/// # Typing Requirements /// # Typing Requirements
@ -1161,12 +1168,12 @@ pub fn derive_uri_display_path(input: TokenStream) -> TokenStream {
/// assert_eq!(my_routes.len(), 2); /// assert_eq!(my_routes.len(), 2);
/// ///
/// let index_route = &my_routes[0]; /// let index_route = &my_routes[0];
/// assert_eq!(index_route.method, Method::Get); /// assert_eq!(index_route.method, Some(Method::Get));
/// assert_eq!(index_route.name.as_ref().unwrap(), "index"); /// assert_eq!(index_route.name.as_ref().unwrap(), "index");
/// assert_eq!(index_route.uri.path(), "/"); /// assert_eq!(index_route.uri.path(), "/");
/// ///
/// let hello_route = &my_routes[1]; /// let hello_route = &my_routes[1];
/// assert_eq!(hello_route.method, Method::Post); /// assert_eq!(hello_route.method, Some(Method::Post));
/// assert_eq!(hello_route.name.as_ref().unwrap(), "hello"); /// assert_eq!(hello_route.name.as_ref().unwrap(), "hello");
/// assert_eq!(hello_route.uri.path(), "/hi/<person>"); /// assert_eq!(hello_route.uri.path(), "/hi/<person>");
/// ``` /// ```

View File

@ -54,8 +54,8 @@ fn post1(
} }
#[route( #[route(
POST, "/<a>/<name>/name/<path..>?sky=blue&<sky>&<query..>",
uri = "/<a>/<name>/name/<path..>?sky=blue&<sky>&<query..>", method = POST,
format = "json", format = "json",
data = "<simple>", data = "<simple>",
rank = 138 rank = 138

View File

@ -124,6 +124,10 @@ macro_rules! define_methods {
#[doc(hidden)] #[doc(hidden)]
pub const ALL: &'static [&'static str] = &[$($name),*]; pub const ALL: &'static [&'static str] = &[$($name),*];
/// A slice containing every defined method variant.
#[doc(hidden)]
pub const ALL_VARIANTS: &'static [Method] = &[$(Self::$V),*];
/// Whether the method is considered "safe". /// Whether the method is considered "safe".
/// ///
/// From [RFC9110 §9.2.1](https://www.rfc-editor.org/rfc/rfc9110#section-9.2.1): /// From [RFC9110 §9.2.1](https://www.rfc-editor.org/rfc/rfc9110#section-9.2.1):

View File

@ -16,7 +16,7 @@ struct ArbitraryRequestData<'a> {
#[derive(Arbitrary)] #[derive(Arbitrary)]
struct ArbitraryRouteData<'a> { struct ArbitraryRouteData<'a> {
method: ArbitraryMethod, method: Option<ArbitraryMethod>,
uri: ArbitraryRouteUri<'a>, uri: ArbitraryRouteUri<'a>,
format: Option<ArbitraryMediaType>, format: Option<ArbitraryMediaType>,
} }
@ -24,7 +24,7 @@ struct ArbitraryRouteData<'a> {
impl std::fmt::Debug for ArbitraryRouteData<'_> { impl std::fmt::Debug for ArbitraryRouteData<'_> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("ArbitraryRouteData") f.debug_struct("ArbitraryRouteData")
.field("method", &self.method.0) .field("method", &self.method.map(|v| v.0))
.field("base", &self.uri.0.base()) .field("base", &self.uri.0.base())
.field("unmounted", &self.uri.0.unmounted().to_string()) .field("unmounted", &self.uri.0.unmounted().to_string())
.field("uri", &self.uri.0.to_string()) .field("uri", &self.uri.0.to_string())
@ -59,12 +59,14 @@ impl<'c, 'a: 'c> ArbitraryRequestData<'a> {
impl<'a> ArbitraryRouteData<'a> { impl<'a> ArbitraryRouteData<'a> {
fn into_route(self) -> Route { fn into_route(self) -> Route {
let mut r = Route::ranked(0, self.method.0, &self.uri.0.to_string(), dummy_handler); let method = self.method.map(|m| m.0);
let mut r = Route::ranked(0, method, &self.uri.0.to_string(), dummy_handler);
r.format = self.format.map(|f| f.0); r.format = self.format.map(|f| f.0);
r r
} }
} }
#[derive(Clone, Copy)]
struct ArbitraryMethod(Method); struct ArbitraryMethod(Method);
struct ArbitraryOrigin<'a>(Origin<'a>); struct ArbitraryOrigin<'a>(Origin<'a>);
@ -79,12 +81,7 @@ struct ArbitraryRouteUri<'a>(RouteUri<'a>);
impl<'a> Arbitrary<'a> for ArbitraryMethod { impl<'a> Arbitrary<'a> for ArbitraryMethod {
fn arbitrary(u: &mut Unstructured<'a>) -> Result<Self> { fn arbitrary(u: &mut Unstructured<'a>) -> Result<Self> {
let all_methods = &[ Ok(ArbitraryMethod(*u.choose(Method::ALL_VARIANTS)?))
Method::Get, Method::Put, Method::Post, Method::Delete, Method::Options,
Method::Head, Method::Trace, Method::Connect, Method::Patch
];
Ok(ArbitraryMethod(*u.choose(all_methods)?))
} }
fn size_hint(_: usize) -> (usize, Option<usize>) { fn size_hint(_: usize) -> (usize, Option<usize>) {

View File

@ -4,7 +4,7 @@ use figment::Figment;
use crate::listener::Endpoint; use crate::listener::Endpoint;
use crate::shutdown::Stages; use crate::shutdown::Stages;
use crate::{Catcher, Config, Rocket, Route}; use crate::{Catcher, Config, Rocket, Route};
use crate::router::Router; use crate::router::{Router, Finalized};
use crate::fairing::Fairings; use crate::fairing::Fairings;
mod private { mod private {
@ -100,7 +100,7 @@ phases! {
/// represents a fully built and finalized application server ready for /// represents a fully built and finalized application server ready for
/// launch into orbit. See [`Rocket#ignite`] for full details. /// launch into orbit. See [`Rocket#ignite`] for full details.
Ignite (#[derive(Debug)] Igniting) { Ignite (#[derive(Debug)] Igniting) {
pub(crate) router: Router, pub(crate) router: Router<Finalized>,
pub(crate) fairings: Fairings, pub(crate) fairings: Fairings,
pub(crate) figment: Figment, pub(crate) figment: Figment,
pub(crate) config: Config, pub(crate) config: Config,
@ -114,7 +114,7 @@ phases! {
/// An instance of `Rocket` in this phase is typed as [`Rocket<Orbit>`] and /// An instance of `Rocket` in this phase is typed as [`Rocket<Orbit>`] and
/// represents a running application. /// represents a running application.
Orbit (#[derive(Debug)] Orbiting) { Orbit (#[derive(Debug)] Orbiting) {
pub(crate) router: Router, pub(crate) router: Router<Finalized>,
pub(crate) fairings: Fairings, pub(crate) fairings: Fairings,
pub(crate) figment: Figment, pub(crate) figment: Figment,
pub(crate) config: Config, pub(crate) config: Config,

View File

@ -557,9 +557,10 @@ impl Rocket<Build> {
// Initialize the router; check for collisions. // Initialize the router; check for collisions.
let mut router = Router::new(); let mut router = Router::new();
self.routes.clone().into_iter().for_each(|r| router.add_route(r)); self.routes.clone().into_iter().for_each(|r| router.routes.push(r));
self.catchers.clone().into_iter().for_each(|c| router.add_catcher(c)); self.catchers.clone().into_iter().for_each(|c| router.catchers.push(c));
router.finalize().map_err(|(r, c)| ErrorKind::Collisions { routes: r, catchers: c, })?; let router = router.finalize()
.map_err(|(r, c)| ErrorKind::Collisions { routes: r, catchers: c, })?;
// Finally, freeze managed state for faster access later. // Finally, freeze managed state for faster access later.
self.state.freeze(); self.state.freeze();
@ -840,8 +841,8 @@ impl<P: Phase> Rocket<P> {
pub fn routes(&self) -> impl Iterator<Item = &Route> { pub fn routes(&self) -> impl Iterator<Item = &Route> {
match self.0.as_ref() { match self.0.as_ref() {
StateRef::Build(p) => Either::Left(p.routes.iter()), StateRef::Build(p) => Either::Left(p.routes.iter()),
StateRef::Ignite(p) => Either::Right(p.router.routes()), StateRef::Ignite(p) => Either::Right(p.router.routes.iter()),
StateRef::Orbit(p) => Either::Right(p.router.routes()), StateRef::Orbit(p) => Either::Right(p.router.routes.iter()),
} }
} }
@ -871,8 +872,8 @@ impl<P: Phase> Rocket<P> {
pub fn catchers(&self) -> impl Iterator<Item = &Catcher> { pub fn catchers(&self) -> impl Iterator<Item = &Catcher> {
match self.0.as_ref() { match self.0.as_ref() {
StateRef::Build(p) => Either::Left(p.catchers.iter()), StateRef::Build(p) => Either::Left(p.catchers.iter()),
StateRef::Ignite(p) => Either::Right(p.router.catchers()), StateRef::Ignite(p) => Either::Right(p.router.catchers.iter()),
StateRef::Orbit(p) => Either::Right(p.router.catchers()), StateRef::Orbit(p) => Either::Right(p.router.catchers.iter()),
} }
} }

View File

@ -22,7 +22,7 @@ use crate::sentinel::Sentry;
/// ///
/// let route = routes![route_name].remove(0); /// let route = routes![route_name].remove(0);
/// assert_eq!(route.name.unwrap(), "route_name"); /// assert_eq!(route.name.unwrap(), "route_name");
/// assert_eq!(route.method, Method::Get); /// assert_eq!(route.method, Some(Method::Get));
/// assert_eq!(route.uri, "/route/<path..>?query"); /// assert_eq!(route.uri, "/route/<path..>?query");
/// assert_eq!(route.rank, 2); /// assert_eq!(route.rank, 2);
/// assert_eq!(route.format.unwrap(), MediaType::JSON); /// assert_eq!(route.format.unwrap(), MediaType::JSON);
@ -164,8 +164,8 @@ use crate::sentinel::Sentry;
pub struct Route { pub struct Route {
/// The name of this route, if one was given. /// The name of this route, if one was given.
pub name: Option<Cow<'static, str>>, pub name: Option<Cow<'static, str>>,
/// The method this route matches against. /// The method this route matches, or `None` to match any method.
pub method: Method, pub method: Option<Method>,
/// The function that should be called when the route matches. /// The function that should be called when the route matches.
pub handler: Box<dyn Handler>, pub handler: Box<dyn Handler>,
/// The route URI. /// The route URI.
@ -203,12 +203,12 @@ impl Route {
/// // this is a route matching requests to `GET /` /// // this is a route matching requests to `GET /`
/// let index = Route::new(Method::Get, "/", handler); /// let index = Route::new(Method::Get, "/", handler);
/// assert_eq!(index.rank, -9); /// assert_eq!(index.rank, -9);
/// assert_eq!(index.method, Method::Get); /// assert_eq!(index.method, Some(Method::Get));
/// assert_eq!(index.uri, "/"); /// assert_eq!(index.uri, "/");
/// ``` /// ```
#[track_caller] #[track_caller]
pub fn new<H: Handler>(method: Method, uri: &str, handler: H) -> Route { pub fn new<M: Into<Option<Method>>, H: Handler>(method: M, uri: &str, handler: H) -> Route {
Route::ranked(None, method, uri, handler) Route::ranked(None, method.into(), uri, handler)
} }
/// Creates a new route with the given rank, method, path, and handler with /// Creates a new route with the given rank, method, path, and handler with
@ -233,17 +233,19 @@ impl Route {
/// ///
/// let foo = Route::ranked(1, Method::Post, "/foo?bar", handler); /// let foo = Route::ranked(1, Method::Post, "/foo?bar", handler);
/// assert_eq!(foo.rank, 1); /// assert_eq!(foo.rank, 1);
/// assert_eq!(foo.method, Method::Post); /// assert_eq!(foo.method, Some(Method::Post));
/// assert_eq!(foo.uri, "/foo?bar"); /// assert_eq!(foo.uri, "/foo?bar");
/// ///
/// let foo = Route::ranked(None, Method::Post, "/foo?bar", handler); /// let foo = Route::ranked(None, Method::Post, "/foo?bar", handler);
/// assert_eq!(foo.rank, -12); /// assert_eq!(foo.rank, -12);
/// assert_eq!(foo.method, Method::Post); /// assert_eq!(foo.method, Some(Method::Post));
/// assert_eq!(foo.uri, "/foo?bar"); /// assert_eq!(foo.uri, "/foo?bar");
/// ``` /// ```
#[track_caller] #[track_caller]
pub fn ranked<H, R>(rank: R, method: Method, uri: &str, handler: H) -> Route pub fn ranked<M, H, R>(rank: R, method: M, uri: &str, handler: H) -> Route
where H: Handler + 'static, R: Into<Option<isize>>, where M: Into<Option<Method>>,
H: Handler + 'static,
R: Into<Option<isize>>,
{ {
let uri = RouteUri::new("/", uri); let uri = RouteUri::new("/", uri);
let rank = rank.into().unwrap_or_else(|| uri.default_rank()); let rank = rank.into().unwrap_or_else(|| uri.default_rank());
@ -253,7 +255,9 @@ impl Route {
sentinels: Vec::new(), sentinels: Vec::new(),
handler: Box::new(handler), handler: Box::new(handler),
location: None, location: None,
rank, uri, method, method: method.into(),
rank,
uri,
} }
} }
@ -362,7 +366,7 @@ pub struct StaticInfo {
/// The route's name, i.e, the name of the function. /// The route's name, i.e, the name of the function.
pub name: &'static str, pub name: &'static str,
/// The route's method. /// The route's method.
pub method: Method, pub method: Option<Method>,
/// The route's URi, without the base mount point. /// The route's URi, without the base mount point.
pub uri: &'static str, pub uri: &'static str,
/// The route's format, if any. /// The route's format, if any.

View File

@ -1,7 +1,7 @@
use crate::catcher::Catcher; use crate::catcher::Catcher;
use crate::route::{Route, Segment, RouteUri}; use crate::route::{Route, Segment, RouteUri};
use crate::http::MediaType; use crate::http::{MediaType, Method};
pub trait Collide<T = Self> { pub trait Collide<T = Self> {
fn collides_with(&self, other: &T) -> bool; fn collides_with(&self, other: &T) -> bool;
@ -87,7 +87,7 @@ impl Route {
/// assert!(a.collides_with(&b)); /// assert!(a.collides_with(&b));
/// ``` /// ```
pub fn collides_with(&self, other: &Route) -> bool { pub fn collides_with(&self, other: &Route) -> bool {
self.method == other.method methods_collide(self, other)
&& self.rank == other.rank && self.rank == other.rank
&& self.uri.collides_with(&other.uri) && self.uri.collides_with(&other.uri)
&& formats_collide(self, other) && formats_collide(self, other)
@ -190,8 +190,16 @@ impl Collide for MediaType {
} }
} }
fn methods_collide(route: &Route, other: &Route) -> bool {
match (route.method, other.method) {
(Some(a), Some(b)) => a == b,
(None, _) | (_, None) => true,
}
}
fn formats_collide(route: &Route, other: &Route) -> bool { fn formats_collide(route: &Route, other: &Route) -> bool {
match (route.method.allows_request_body(), other.method.allows_request_body()) { let payload_support = |m: &Option<Method>| m.and_then(|m| m.allows_request_body());
match (payload_support(&route.method), payload_support(&other.method)) {
// Payload supporting methods match against `Content-Type` which must be // Payload supporting methods match against `Content-Type` which must be
// fully specified, so the request cannot contain a format that matches // fully specified, so the request cannot contain a format that matches
// more than one route format as long as those formats don't collide. // more than one route format as long as those formats don't collide.

View File

@ -67,8 +67,7 @@ impl Route {
/// ``` /// ```
#[tracing::instrument(level = "trace", name = "matching", skip_all, ret)] #[tracing::instrument(level = "trace", name = "matching", skip_all, ret)]
pub fn matches(&self, request: &Request<'_>) -> bool { pub fn matches(&self, request: &Request<'_>) -> bool {
trace!(route.method = %self.method, request.method = %request.method()); methods_match(self, request)
self.method == request.method()
&& paths_match(self, request) && paths_match(self, request)
&& queries_match(self, request) && queries_match(self, request)
&& formats_match(self, request) && formats_match(self, request)
@ -140,6 +139,11 @@ impl Catcher {
} }
} }
fn methods_match(route: &Route, req: &Request<'_>) -> bool {
trace!(?route.method, request.method = %req.method());
route.method.map_or(true, |method| method == req.method())
}
fn paths_match(route: &Route, req: &Request<'_>) -> bool { fn paths_match(route: &Route, req: &Request<'_>) -> bool {
trace!(route.uri = %route.uri, request.uri = %req.uri()); trace!(route.uri = %route.uri, request.uri = %req.uri());
let route_segments = &route.uri.metadata.uri_segments; let route_segments = &route.uri.metadata.uri_segments;
@ -208,7 +212,7 @@ fn formats_match(route: &Route, req: &Request<'_>) -> bool {
None => return true, None => return true,
}; };
match route.method.allows_request_body() { match route.method.and_then(|m| m.allows_request_body()) {
Some(true) => match req.format() { Some(true) => match req.format() {
Some(f) if f.specificity() == 2 => route_format.collides_with(f), Some(f) if f.specificity() == 2 => route_format.collides_with(f),
_ => false _ => false

View File

@ -1,64 +1,120 @@
use std::ops::{Deref, DerefMut};
use std::collections::HashMap; use std::collections::HashMap;
use crate::request::Request; use crate::request::Request;
use crate::http::{Method, Status}; use crate::http::{Method, Status};
use crate::{Route, Catcher}; use crate::{Route, Catcher};
use crate::router::Collide; use crate::router::Collide;
#[derive(Debug)]
pub(crate) struct Router<T>(T);
#[derive(Debug, Default)] #[derive(Debug, Default)]
pub(crate) struct Router { pub struct Pending {
routes: HashMap<Method, Vec<Route>>, pub routes: Vec<Route>,
catchers: HashMap<Option<u16>, Vec<Catcher>>, pub catchers: Vec<Catcher>,
} }
pub type Collisions<T> = Vec<(T, T)>; #[derive(Debug, Default)]
pub struct Finalized {
pub routes: Vec<Route>,
pub catchers: Vec<Catcher>,
route_map: HashMap<Method, Vec<usize>>,
catcher_map: HashMap<Option<u16>, Vec<usize>>,
}
impl Router { pub type Pair<T> = (T, T);
pub type Collisions = (Vec<Pair<Route>>, Vec<Pair<Catcher>>);
pub type Result<T, E = Collisions> = std::result::Result<T, E>;
impl Router<Pending> {
pub fn new() -> Self { pub fn new() -> Self {
Self::default() Router(Pending::default())
} }
pub fn add_route(&mut self, route: Route) { pub fn finalize(self) -> Result<Router<Finalized>, Collisions> {
let routes = self.routes.entry(route.method).or_default(); fn collisions<'a, T>(items: &'a [T]) -> impl Iterator<Item = (T, T)> + 'a
routes.push(route); where T: Collide + Clone + 'a,
routes.sort_by_key(|r| r.rank); {
} items.iter()
.enumerate()
.flat_map(move |(i, a)| {
items.iter()
.skip(i + 1)
.filter(move |b| a.collides_with(b))
.map(move |b| (a.clone(), b.clone()))
})
}
pub fn add_catcher(&mut self, catcher: Catcher) { let route_collisions: Vec<_> = collisions(&self.routes).collect();
let catchers = self.catchers.entry(catcher.code).or_default(); let catcher_collisions: Vec<_> = collisions(&self.catchers).collect();
catchers.push(catcher);
catchers.sort_by_key(|c| c.rank);
}
#[inline] if !route_collisions.is_empty() || !catcher_collisions.is_empty() {
pub fn routes(&self) -> impl Iterator<Item = &Route> + Clone { return Err((route_collisions, catcher_collisions))
self.routes.values().flat_map(|v| v.iter()) }
}
#[inline] // create the route map
pub fn catchers(&self) -> impl Iterator<Item = &Catcher> + Clone { let mut route_map: HashMap<Method, Vec<usize>> = HashMap::new();
self.catchers.values().flat_map(|v| v.iter()) for (i, route) in self.routes.iter().enumerate() {
} match route.method {
Some(method) => route_map.entry(method).or_default().push(i),
None => for method in Method::ALL_VARIANTS {
route_map.entry(*method).or_default().push(i);
}
}
}
// create the catcher map
let mut catcher_map: HashMap<Option<u16>, Vec<usize>> = HashMap::new();
for (i, catcher) in self.catchers.iter().enumerate() {
catcher_map.entry(catcher.code).or_default().push(i);
}
// sort routes by rank
for routes in route_map.values_mut() {
routes.sort_by_key(|&i| &self.routes[i].rank);
}
// sort catchers by rank
for catchers in catcher_map.values_mut() {
catchers.sort_by_key(|&i| &self.catchers[i].rank);
}
Ok(Router(Finalized {
routes: self.0.routes,
catchers: self.0.catchers,
route_map, catcher_map
}))
}
}
impl Router<Finalized> {
#[track_caller]
pub fn route<'r, 'a: 'r>( pub fn route<'r, 'a: 'r>(
&'a self, &'a self,
req: &'r Request<'r> req: &'r Request<'r>
) -> impl Iterator<Item = &'a Route> + 'r { ) -> impl Iterator<Item = &'a Route> + 'r {
// Note that routes are presorted by ascending rank on each `add`. // Note that routes are presorted by ascending rank on each `add` and
self.routes.get(&req.method()) // that all routes with `None` methods have been cloned into all methods.
self.route_map.get(&req.method())
.into_iter() .into_iter()
.flat_map(move |routes| routes.iter().filter(move |r| r.matches(req))) .flat_map(move |routes| routes.iter().map(move |&i| &self.routes[i]))
.filter(move |r| r.matches(req))
} }
// For many catchers, using aho-corasick or similar should be much faster. // For many catchers, using aho-corasick or similar should be much faster.
#[track_caller]
pub fn catch<'r>(&self, status: Status, req: &'r Request<'r>) -> Option<&Catcher> { pub fn catch<'r>(&self, status: Status, req: &'r Request<'r>) -> Option<&Catcher> {
// Note that catchers are presorted by descending base length. // Note that catchers are presorted by descending base length.
let explicit = self.catchers.get(&Some(status.code)) let explicit = self.catcher_map.get(&Some(status.code))
.and_then(|c| c.iter().find(|c| c.matches(status, req))); .map(|catchers| catchers.iter().map(|&i| &self.catchers[i]))
.and_then(|mut catchers| catchers.find(|c| c.matches(status, req)));
let default = self.catchers.get(&None) let default = self.catcher_map.get(&None)
.and_then(|c| c.iter().find(|c| c.matches(status, req))); .map(|catchers| catchers.iter().map(|&i| &self.catchers[i]))
.and_then(|mut catchers| catchers.find(|c| c.matches(status, req)));
match (explicit, default) { match (explicit, default) {
(None, None) => None, (None, None) => None,
@ -67,28 +123,19 @@ impl Router {
(Some(_), Some(b)) => Some(b), (Some(_), Some(b)) => Some(b),
} }
} }
}
fn collisions<'a, I, T>(&self, items: I) -> impl Iterator<Item = (T, T)> + 'a impl<T> Deref for Router<T> {
where I: Iterator<Item = &'a T> + Clone + 'a, T: Collide + Clone + 'a, type Target = T;
{
items.clone().enumerate() fn deref(&self) -> &Self::Target {
.flat_map(move |(i, a)| { &self.0
items.clone()
.skip(i + 1)
.filter(move |b| a.collides_with(b))
.map(move |b| (a.clone(), b.clone()))
})
} }
}
pub fn finalize(&self) -> Result<(), (Collisions<Route>, Collisions<Catcher>)> { impl DerefMut for Router<Pending> {
let routes: Vec<_> = self.collisions(self.routes()).collect(); fn deref_mut(&mut self) -> &mut Self::Target {
let catchers: Vec<_> = self.collisions(self.catchers()).collect(); &mut self.0
if !routes.is_empty() || !catchers.is_empty() {
return Err((routes, catchers))
}
Ok(())
} }
} }
@ -100,50 +147,32 @@ mod test {
use crate::local::blocking::Client; use crate::local::blocking::Client;
use crate::http::{Method::*, uri::Origin}; use crate::http::{Method::*, uri::Origin};
impl Router { fn make_router<I>(routes: I) -> Result<Router<Finalized>, Collisions>
fn has_collisions(&self) -> bool { where I: Iterator<Item = (Option<isize>, &'static str)>
self.finalize().is_err() {
}
}
fn router_with_routes(routes: &[&'static str]) -> Router {
let mut router = Router::new(); let mut router = Router::new();
for route in routes { for (rank, route) in routes {
let route = Route::new(Get, route, dummy_handler);
router.add_route(route);
}
router
}
fn router_with_ranked_routes(routes: &[(isize, &'static str)]) -> Router {
let mut router = Router::new();
for &(rank, route) in routes {
let route = Route::ranked(rank, Get, route, dummy_handler); let route = Route::ranked(rank, Get, route, dummy_handler);
router.add_route(route); router.routes.push(route);
} }
router router.finalize()
} }
fn router_with_rankless_routes(routes: &[&'static str]) -> Router { fn router_with_routes(routes: &[&'static str]) -> Router<Finalized> {
let mut router = Router::new(); make_router(routes.iter().map(|r| (None, *r))).unwrap()
for route in routes { }
let route = Route::ranked(0, Get, route, dummy_handler);
router.add_route(route);
}
router fn router_with_ranked_routes(routes: &[(isize, &'static str)]) -> Router<Finalized> {
make_router(routes.iter().map(|r| (Some(r.0), r.1))).unwrap()
} }
fn rankless_route_collisions(routes: &[&'static str]) -> bool { fn rankless_route_collisions(routes: &[&'static str]) -> bool {
let router = router_with_rankless_routes(routes); make_router(routes.iter().map(|r| (Some(0), *r))).is_err()
router.has_collisions()
} }
fn default_rank_route_collisions(routes: &[&'static str]) -> bool { fn default_rank_route_collisions(routes: &[&'static str]) -> bool {
let router = router_with_routes(routes); make_router(routes.iter().map(|r| (None, *r))).is_err()
router.has_collisions()
} }
#[test] #[test]
@ -280,13 +309,15 @@ mod test {
assert!(!default_rank_route_collisions(&["/<foo>?a=b", "/<foo>?c=d&<d>"])); assert!(!default_rank_route_collisions(&["/<foo>?a=b", "/<foo>?c=d&<d>"]));
} }
fn matches<'a>(router: &'a Router, method: Method, uri: &'a str) -> Vec<&'a Route> { #[track_caller]
fn matches<'a>(router: &'a Router<Finalized>, method: Method, uri: &'a str) -> Vec<&'a Route> {
let client = Client::debug_with(vec![]).expect("client"); let client = Client::debug_with(vec![]).expect("client");
let request = client.req(method, Origin::parse(uri).unwrap()); let request = client.req(method, Origin::parse(uri).unwrap());
router.route(&request).collect() router.route(&request).collect()
} }
fn route<'a>(router: &'a Router, method: Method, uri: &'a str) -> Option<&'a Route> { #[track_caller]
fn route<'a>(router: &'a Router<Finalized>, method: Method, uri: &'a str) -> Option<&'a Route> {
matches(router, method, uri).into_iter().next() matches(router, method, uri).into_iter().next()
} }
@ -309,9 +340,10 @@ mod test {
assert!(route(&router, Get, "/a/").is_some()); assert!(route(&router, Get, "/a/").is_some());
let mut router = Router::new(); let mut router = Router::new();
router.add_route(Route::new(Put, "/hello", dummy_handler)); router.routes.push(Route::new(Put, "/hello", dummy_handler));
router.add_route(Route::new(Post, "/hello", dummy_handler)); router.routes.push(Route::new(Post, "/hello", dummy_handler));
router.add_route(Route::new(Delete, "/hello", dummy_handler)); router.routes.push(Route::new(Delete, "/hello", dummy_handler));
let router = router.finalize().unwrap();
assert!(route(&router, Put, "/hello").is_some()); assert!(route(&router, Put, "/hello").is_some());
assert!(route(&router, Post, "/hello").is_some()); assert!(route(&router, Post, "/hello").is_some());
assert!(route(&router, Delete, "/hello").is_some()); assert!(route(&router, Delete, "/hello").is_some());
@ -368,7 +400,6 @@ mod test {
macro_rules! assert_ranked_match { macro_rules! assert_ranked_match {
($routes:expr, $to:expr => $want:expr) => ({ ($routes:expr, $to:expr => $want:expr) => ({
let router = router_with_routes($routes); let router = router_with_routes($routes);
assert!(!router.has_collisions());
let route_path = route(&router, Get, $to).unwrap().uri.to_string(); let route_path = route(&router, Get, $to).unwrap().uri.to_string();
assert_eq!(route_path, $want.to_string(), assert_eq!(route_path, $want.to_string(),
"\nmatched {} with {}, wanted {} in {:#?}", $to, route_path, $want, router); "\nmatched {} with {}, wanted {} in {:#?}", $to, route_path, $want, router);
@ -401,8 +432,7 @@ mod test {
} }
fn ranked_collisions(routes: &[(isize, &'static str)]) -> bool { fn ranked_collisions(routes: &[(isize, &'static str)]) -> bool {
let router = router_with_ranked_routes(routes); make_router(routes.iter().map(|r| (Some(r.0), r.1))).is_err()
router.has_collisions()
} }
#[test] #[test]
@ -429,7 +459,7 @@ mod test {
let router = router_with_ranked_routes(&$routes); let router = router_with_ranked_routes(&$routes);
let routed_to = matches(&router, Get, $to); let routed_to = matches(&router, Get, $to);
let expected = &[$($want),+]; let expected = &[$($want),+];
assert!(routed_to.len() == expected.len()); assert_eq!(routed_to.len(), expected.len());
for (got, expected) in routed_to.iter().zip(expected.iter()) { for (got, expected) in routed_to.iter().zip(expected.iter()) {
assert_eq!(got.rank, expected.0); assert_eq!(got.rank, expected.0);
assert_eq!(got.uri.to_string(), expected.1.to_string()); assert_eq!(got.uri.to_string(), expected.1.to_string());
@ -545,20 +575,21 @@ mod test {
); );
} }
fn router_with_catchers(catchers: &[(Option<u16>, &str)]) -> Router { fn router_with_catchers(catchers: &[(Option<u16>, &str)]) -> Result<Router<Finalized>> {
let mut router = Router::new(); let mut router = Router::new();
for (code, base) in catchers { for (code, base) in catchers {
let catcher = Catcher::new(*code, crate::catcher::dummy_handler); let catcher = Catcher::new(*code, crate::catcher::dummy_handler);
router.add_catcher(catcher.map_base(|_| base.to_string()).unwrap()); router.catchers.push(catcher.map_base(|_| base.to_string()).unwrap());
} }
router router.finalize()
} }
fn catcher<'a>(router: &'a Router, status: Status, uri: &str) -> Option<&'a Catcher> { #[track_caller]
fn catcher<'a>(r: &'a Router<Finalized>, status: Status, uri: &str) -> Option<&'a Catcher> {
let client = Client::debug_with(vec![]).expect("client"); let client = Client::debug_with(vec![]).expect("client");
let request = client.get(Origin::parse(uri).unwrap()); let request = client.get(Origin::parse(uri).unwrap());
router.catch(status, &request) r.catch(status, &request)
} }
macro_rules! assert_catcher_routing { macro_rules! assert_catcher_routing {
@ -571,7 +602,7 @@ mod test {
let requests = vec![$($r),+]; let requests = vec![$($r),+];
let expected = vec![$(($ecode.into(), $euri)),+]; let expected = vec![$(($ecode.into(), $euri)),+];
let router = router_with_catchers(&catchers); let router = router_with_catchers(&catchers).expect("valid router");
for (req, expected) in requests.iter().zip(expected.iter()) { for (req, expected) in requests.iter().zip(expected.iter()) {
let req_status = Status::from_code(req.0).expect("valid status"); let req_status = Status::from_code(req.0).expect("valid status");
let catcher = catcher(&router, req_status, req.1).expect("some catcher"); let catcher = catcher(&router, req_status, req.1).expect("some catcher");

View File

@ -142,7 +142,10 @@ impl Trace for Route {
event! { level, "route", event! { level, "route",
name = self.name.as_ref().map(|n| &**n), name = self.name.as_ref().map(|n| &**n),
rank = self.rank, rank = self.rank,
method = %self.method, method = %Formatter(|f| match self.method {
Some(method) => write!(f, "{}", method),
None => write!(f, "[any]"),
}),
uri = %self.uri, uri = %self.uri,
uri.base = %self.uri.base(), uri.base = %self.uri.base(),
uri.unmounted = %self.uri.unmounted(), uri.unmounted = %self.uri.unmounted(),

View File

@ -13,13 +13,13 @@ fn patch(form_data: Form<FormData>) -> &'static str {
"PATCH OK" "PATCH OK"
} }
#[route(UPDATEREDIRECTREF, uri = "/", data = "<form_data>")] #[route("/", method = UPDATEREDIRECTREF, data = "<form_data>")]
fn urr(form_data: Form<FormData>) -> &'static str { fn urr(form_data: Form<FormData>) -> &'static str {
assert_eq!("Form data", form_data.into_inner().form_data); assert_eq!("Form data", form_data.into_inner().form_data);
"UPDATEREDIRECTREF OK" "UPDATEREDIRECTREF OK"
} }
#[route("VERSION-CONTROL", uri = "/", data = "<form_data>")] #[route("/", method = "VERSION-CONTROL", data = "<form_data>")]
fn vc(form_data: Form<FormData>) -> &'static str { fn vc(form_data: Form<FormData>) -> &'static str {
assert_eq!("Form data", form_data.into_inner().form_data); assert_eq!("Form data", form_data.into_inner().form_data);
"VERSION-CONTROL OK" "VERSION-CONTROL OK"

View File

@ -37,21 +37,55 @@ these properties and more.
## Methods ## Methods
A Rocket route attribute can be any one of `get`, `put`, `post`, `delete`, A Rocket route attribute can either be method-specific, any one of `get`, `put`,
`head`, `patch`, or `options`, each corresponding to the HTTP method to match `post`, `delete`, `head`, `patch`, or `options`, or the generic [`route`], which
against. For example, the following attribute will match against `POST` requests allows explicitly specifying any valid HTTP [`Method`] or no method at all, to
to the root path: match again _any_ method. Consider the following examples:
```rust * Match a `POST` request to `/`:
# #[macro_use] extern crate rocket;
# fn main() {}
#[post("/")] ```rust
# fn handler() {} # use rocket::post;
``` #[post("/")]
# fn handler() {}
```
* Match a `PATCH` request to `/fix`:
```rust
# use rocket::patch;
#[patch("/fix")]
# fn handler() {}
```
* Match a `PROPFIND` request to `/collection`:
```rust
# use rocket::route;
#[route("/collection", method = PROPFIND)]
# fn handler() {}
```
* Match a `VERSION-CONTROL` request to `/collection`:
```rust
# use rocket::route;
#[route("/resource", method = "VERSION-CONTROL")]
# fn handler() {}
```
* Match a request to `/page` with _any_ method:
```rust
# use rocket::route;
#[route("/page")]
# fn handler() {}
```
The grammar for these attributes is defined formally in the [`route`] API docs. The grammar for these attributes is defined formally in the [`route`] API docs.
[`Method`]: @api/master/rocket/http/enum.Method.html
### HEAD Requests ### HEAD Requests
Rocket handles `HEAD` requests automatically when there exists a `GET` route Rocket handles `HEAD` requests automatically when there exists a `GET` route

View File

@ -33,7 +33,7 @@ fn mir() -> &'static str {
// Try visiting: // Try visiting:
// http://127.0.0.1:8000/wave/Rocketeer/100 // http://127.0.0.1:8000/wave/Rocketeer/100
#[get("/<name>/<age>")] #[get("/<name>/<age>", rank = 2)]
fn wave(name: &str, age: u8) -> String { fn wave(name: &str, age: u8) -> String {
format!("👋 Hello, {} year old named {}!", age, name) format!("👋 Hello, {} year old named {}!", age, name)
} }

View File

@ -2,9 +2,9 @@
use std::net::SocketAddr; use std::net::SocketAddr;
use rocket::http::Status; use rocket::http::uri::{Origin, Host};
use rocket::tracing::{self, Instrument}; use rocket::tracing::{self, Instrument};
use rocket::{route, Error, Request, Data, Route, Orbit, Rocket, Ignite}; use rocket::{Rocket, Ignite, Orbit, State, Error};
use rocket::fairing::{Fairing, Info, Kind}; use rocket::fairing::{Fairing, Info, Kind};
use rocket::response::Redirect; use rocket::response::Redirect;
use rocket::listener::tcp::TcpListener; use rocket::listener::tcp::TcpListener;
@ -19,43 +19,33 @@ pub struct Config {
tls_addr: SocketAddr, tls_addr: SocketAddr,
} }
#[route("/<_..>")]
fn redirect(config: &State<Config>, uri: &Origin<'_>, host: &Host<'_>) -> Redirect {
// FIXME: Check the host against a whitelist!
let domain = host.domain();
let https_uri = match config.tls_addr.port() {
443 => format!("https://{domain}{uri}"),
port => format!("https://{domain}:{port}{uri}"),
};
Redirect::permanent(https_uri)
}
impl Redirector { impl Redirector {
pub fn on(port: u16) -> Self { pub fn on(port: u16) -> Self {
Redirector(port) Redirector(port)
} }
// Route function that gets called on every single request.
fn redirect<'r>(req: &'r Request, _: Data<'r>) -> route::BoxFuture<'r> {
// FIXME: Check the host against a whitelist!
let config = req.rocket().state::<Config>().expect("managed Self");
if let Some(host) = req.host() {
let domain = host.domain();
let https_uri = match config.tls_addr.port() {
443 => format!("https://{domain}{}", req.uri()),
port => format!("https://{domain}:{port}{}", req.uri()),
};
route::Outcome::from(req, Redirect::permanent(https_uri)).pin()
} else {
route::Outcome::from(req, Status::BadRequest).pin()
}
}
// Launch an instance of Rocket than handles redirection on `self.port`. // Launch an instance of Rocket than handles redirection on `self.port`.
pub async fn try_launch(self, config: Config) -> Result<Rocket<Ignite>, Error> { pub async fn try_launch(self, config: Config) -> Result<Rocket<Ignite>, Error> {
use rocket::http::Method::*; rocket::span_info!("HTTP -> HTTPS Redirector" => {
info!(from = self.0, to = config.tls_addr.port(), "redirecting");
});
// Build a vector of routes to `redirect` on `<path..>` for each method.
let redirects = [Get, Put, Post, Delete, Options, Head, Trace, Connect, Patch]
.into_iter()
.map(|m| Route::new(m, "/<path..>", Self::redirect))
.collect::<Vec<_>>();
info!(from = self.0, to = config.tls_addr.port(), "redirecting");
let addr = SocketAddr::new(config.tls_addr.ip(), self.0); let addr = SocketAddr::new(config.tls_addr.ip(), self.0);
rocket::custom(&config.server) rocket::custom(&config.server)
.manage(config) .manage(config)
.mount("/", redirects) .mount("/", routes![redirect])
.try_launch_on(TcpListener::bind(addr)) .try_launch_on(TcpListener::bind(addr))
.await .await
} }

View File

@ -4,7 +4,7 @@ use crate::prelude::*;
use rocket::http::Method; use rocket::http::Method;
#[route(PROPFIND, uri = "/")] #[route("/", method = PROPFIND)]
fn route() -> &'static str { fn route() -> &'static str {
"Hello, World!" "Hello, World!"
} }