Support routes that match any method.

This commit introduces support for method-less routes and route
attributes, which match _any_ valid method: `#[route("/")]`. The `Route`
structure's `method` field is now accordingly of type `Option<Route>`.

The syntax for the `route` attribute has changed in a breaking manner.
To set a method, a key/value of `method = NAME` must be introduced:

```rust
#[route("/", method = GET)]
```

If the method's name is a valid identifier, it can be used without
quotes. Otherwise it must be quoted:

```rust
// `GET` is a valid identifier, but `VERSION-CONTROL` is not
#[route("/", method = "VERSION-CONTROL")]
```

Closes #2731.
This commit is contained in:
Sergio Benitez 2024-06-03 21:54:07 -07:00
parent 9496b70e8c
commit 72c91958b7
19 changed files with 297 additions and 205 deletions

View File

@ -58,9 +58,9 @@ fn generate_matching_requests<'c>(client: &'c Client, routes: &[Route]) -> Vec<L
.join("&");
let uri = format!("/{}?{}", path, query);
let mut req = client.req(route.method, uri);
let mut req = client.req(route.method.unwrap(), uri);
if let Some(ref format) = route.format {
if let Some(true) = route.method.allows_request_body() {
if let Some(true) = route.method.and_then(|m| m.allows_request_body()) {
req.add_header(ContentType::from(format.clone()));
} else {
req.add_header(Accept::from(format.clone()));

View File

@ -263,7 +263,7 @@ fn internal_uri_macro_decl(route: &Route) -> TokenStream {
// Generate a unique macro name based on the route's metadata.
let macro_name = route.handler.sig.ident.prepend(crate::URI_MACRO_PREFIX);
let inner_macro_name = macro_name.uniqueify_with(|mut hasher| {
route.attr.method.0.hash(&mut hasher);
route.attr.method.as_ref().map(|m| m.0.hash(&mut hasher));
route.attr.uri.path().hash(&mut hasher);
route.attr.uri.query().hash(&mut hasher);
route.attr.data.as_ref().map(|d| d.value.hash(&mut hasher));
@ -395,7 +395,7 @@ fn codegen_route(route: Route) -> Result<TokenStream> {
let internal_uri_macro = internal_uri_macro_decl(&route);
let responder_outcome = responder_outcome_expr(&route);
let method = &route.attr.method;
let method = Optional(route.attr.method.clone());
let uri = route.attr.uri.to_string();
let rank = Optional(route.attr.rank);
let format = Optional(route.attr.format.as_ref());
@ -480,9 +480,12 @@ fn incomplete_route(
let method_attribute = MethodAttribute::from_meta(&syn::parse2(full_attr)?)?;
let attribute = Attribute {
method: SpanWrapped {
full_span: method_span, key_span: None, span: method_span, value: Method(method)
},
method: Some(SpanWrapped {
full_span: method_span,
key_span: None,
span: method_span,
value: Method(method),
}),
uri: method_attribute.uri,
data: method_attribute.data,
format: method_attribute.format,

View File

@ -43,8 +43,8 @@ pub struct Arguments {
#[derive(Debug, FromMeta)]
pub struct Attribute {
#[meta(naked)]
pub method: SpanWrapped<Method>,
pub uri: RouteUri,
pub method: Option<SpanWrapped<Method>>,
pub data: Option<SpanWrapped<Dynamic>>,
pub format: Option<MediaType>,
pub rank: Option<isize>,
@ -129,17 +129,23 @@ impl Route {
// Emit a warning if a `data` param was supplied for non-payload methods.
if let Some(ref data) = attr.data {
let lint = Lint::DubiousPayload;
match attr.method.0.allows_request_body() {
None if lint.enabled(handler.span()) => {
data.full_span.warning("`data` used with non-payload-supporting method")
.note(format!("'{}' does not typically support payloads", attr.method.0))
.note(lint.how_to_suppress())
.emit_as_item_tokens();
}
Some(false) => {
match attr.method.as_ref() {
Some(m) if m.0.allows_request_body() == Some(false) => {
diags.push(data.full_span
.error("`data` cannot be used on this route")
.span_note(attr.method.span, "method does not support request payloads"))
.span_note(m.span, "method does not support request payloads"))
},
Some(m) if m.0.allows_request_body().is_none() && lint.enabled(handler.span()) => {
data.full_span.warning("`data` used with non-payload-supporting method")
.span_note(m.span, format!("'{}' does not typically support payloads", m.0))
.note(lint.how_to_suppress())
.emit_as_item_tokens();
},
None if lint.enabled(handler.span()) => {
data.full_span.warning("`data` used on route with wildcard method")
.note("some methods may not support request payloads")
.note(lint.how_to_suppress())
.emit_as_item_tokens();
}
_ => { /* okay */ },
}

View File

@ -119,15 +119,20 @@ macro_rules! route_attribute {
/// * [`patch`] - `PATCH` specific route
///
/// Additionally, [`route`] allows the method and uri to be explicitly
/// specified:
/// specified, and for the method to be omitted entirely, to match any
/// method:
///
/// ```rust
/// # #[macro_use] extern crate rocket;
/// #
/// #[route(GET, uri = "/")]
/// fn index() -> &'static str {
/// "Hello, world!"
/// }
///
/// #[route("/", method = GET)]
/// fn get_index() { /* ... */ }
///
/// #[route("/", method = "VERSION-CONTROL")]
/// fn versioned_index() { /* ... */ }
///
/// #[route("/")]
/// fn index() { /* ... */ }
/// ```
///
/// [`get`]: attr.get.html
@ -171,7 +176,9 @@ macro_rules! route_attribute {
/// The generic route attribute is defined as:
///
/// ```text
/// generic-route := METHOD ',' 'uri' '=' route
/// generic-route := route (',' method)?
///
/// method := 'method' '=' METHOD
/// ```
///
/// # Typing Requirements
@ -1161,12 +1168,12 @@ pub fn derive_uri_display_path(input: TokenStream) -> TokenStream {
/// assert_eq!(my_routes.len(), 2);
///
/// let index_route = &my_routes[0];
/// assert_eq!(index_route.method, Method::Get);
/// assert_eq!(index_route.method, Some(Method::Get));
/// assert_eq!(index_route.name.as_ref().unwrap(), "index");
/// assert_eq!(index_route.uri.path(), "/");
///
/// let hello_route = &my_routes[1];
/// assert_eq!(hello_route.method, Method::Post);
/// assert_eq!(hello_route.method, Some(Method::Post));
/// assert_eq!(hello_route.name.as_ref().unwrap(), "hello");
/// assert_eq!(hello_route.uri.path(), "/hi/<person>");
/// ```

View File

@ -54,8 +54,8 @@ fn post1(
}
#[route(
POST,
uri = "/<a>/<name>/name/<path..>?sky=blue&<sky>&<query..>",
"/<a>/<name>/name/<path..>?sky=blue&<sky>&<query..>",
method = POST,
format = "json",
data = "<simple>",
rank = 138

View File

@ -124,6 +124,10 @@ macro_rules! define_methods {
#[doc(hidden)]
pub const ALL: &'static [&'static str] = &[$($name),*];
/// A slice containing every defined method variant.
#[doc(hidden)]
pub const ALL_VARIANTS: &'static [Method] = &[$(Self::$V),*];
/// Whether the method is considered "safe".
///
/// From [RFC9110 §9.2.1](https://www.rfc-editor.org/rfc/rfc9110#section-9.2.1):

View File

@ -16,7 +16,7 @@ struct ArbitraryRequestData<'a> {
#[derive(Arbitrary)]
struct ArbitraryRouteData<'a> {
method: ArbitraryMethod,
method: Option<ArbitraryMethod>,
uri: ArbitraryRouteUri<'a>,
format: Option<ArbitraryMediaType>,
}
@ -24,7 +24,7 @@ struct ArbitraryRouteData<'a> {
impl std::fmt::Debug for ArbitraryRouteData<'_> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("ArbitraryRouteData")
.field("method", &self.method.0)
.field("method", &self.method.map(|v| v.0))
.field("base", &self.uri.0.base())
.field("unmounted", &self.uri.0.unmounted().to_string())
.field("uri", &self.uri.0.to_string())
@ -59,12 +59,14 @@ impl<'c, 'a: 'c> ArbitraryRequestData<'a> {
impl<'a> ArbitraryRouteData<'a> {
fn into_route(self) -> Route {
let mut r = Route::ranked(0, self.method.0, &self.uri.0.to_string(), dummy_handler);
let method = self.method.map(|m| m.0);
let mut r = Route::ranked(0, method, &self.uri.0.to_string(), dummy_handler);
r.format = self.format.map(|f| f.0);
r
}
}
#[derive(Clone, Copy)]
struct ArbitraryMethod(Method);
struct ArbitraryOrigin<'a>(Origin<'a>);
@ -79,12 +81,7 @@ struct ArbitraryRouteUri<'a>(RouteUri<'a>);
impl<'a> Arbitrary<'a> for ArbitraryMethod {
fn arbitrary(u: &mut Unstructured<'a>) -> Result<Self> {
let all_methods = &[
Method::Get, Method::Put, Method::Post, Method::Delete, Method::Options,
Method::Head, Method::Trace, Method::Connect, Method::Patch
];
Ok(ArbitraryMethod(*u.choose(all_methods)?))
Ok(ArbitraryMethod(*u.choose(Method::ALL_VARIANTS)?))
}
fn size_hint(_: usize) -> (usize, Option<usize>) {

View File

@ -4,7 +4,7 @@ use figment::Figment;
use crate::listener::Endpoint;
use crate::shutdown::Stages;
use crate::{Catcher, Config, Rocket, Route};
use crate::router::Router;
use crate::router::{Router, Finalized};
use crate::fairing::Fairings;
mod private {
@ -100,7 +100,7 @@ phases! {
/// represents a fully built and finalized application server ready for
/// launch into orbit. See [`Rocket#ignite`] for full details.
Ignite (#[derive(Debug)] Igniting) {
pub(crate) router: Router,
pub(crate) router: Router<Finalized>,
pub(crate) fairings: Fairings,
pub(crate) figment: Figment,
pub(crate) config: Config,
@ -114,7 +114,7 @@ phases! {
/// An instance of `Rocket` in this phase is typed as [`Rocket<Orbit>`] and
/// represents a running application.
Orbit (#[derive(Debug)] Orbiting) {
pub(crate) router: Router,
pub(crate) router: Router<Finalized>,
pub(crate) fairings: Fairings,
pub(crate) figment: Figment,
pub(crate) config: Config,

View File

@ -557,9 +557,10 @@ impl Rocket<Build> {
// Initialize the router; check for collisions.
let mut router = Router::new();
self.routes.clone().into_iter().for_each(|r| router.add_route(r));
self.catchers.clone().into_iter().for_each(|c| router.add_catcher(c));
router.finalize().map_err(|(r, c)| ErrorKind::Collisions { routes: r, catchers: c, })?;
self.routes.clone().into_iter().for_each(|r| router.routes.push(r));
self.catchers.clone().into_iter().for_each(|c| router.catchers.push(c));
let router = router.finalize()
.map_err(|(r, c)| ErrorKind::Collisions { routes: r, catchers: c, })?;
// Finally, freeze managed state for faster access later.
self.state.freeze();
@ -840,8 +841,8 @@ impl<P: Phase> Rocket<P> {
pub fn routes(&self) -> impl Iterator<Item = &Route> {
match self.0.as_ref() {
StateRef::Build(p) => Either::Left(p.routes.iter()),
StateRef::Ignite(p) => Either::Right(p.router.routes()),
StateRef::Orbit(p) => Either::Right(p.router.routes()),
StateRef::Ignite(p) => Either::Right(p.router.routes.iter()),
StateRef::Orbit(p) => Either::Right(p.router.routes.iter()),
}
}
@ -871,8 +872,8 @@ impl<P: Phase> Rocket<P> {
pub fn catchers(&self) -> impl Iterator<Item = &Catcher> {
match self.0.as_ref() {
StateRef::Build(p) => Either::Left(p.catchers.iter()),
StateRef::Ignite(p) => Either::Right(p.router.catchers()),
StateRef::Orbit(p) => Either::Right(p.router.catchers()),
StateRef::Ignite(p) => Either::Right(p.router.catchers.iter()),
StateRef::Orbit(p) => Either::Right(p.router.catchers.iter()),
}
}

View File

@ -22,7 +22,7 @@ use crate::sentinel::Sentry;
///
/// let route = routes![route_name].remove(0);
/// assert_eq!(route.name.unwrap(), "route_name");
/// assert_eq!(route.method, Method::Get);
/// assert_eq!(route.method, Some(Method::Get));
/// assert_eq!(route.uri, "/route/<path..>?query");
/// assert_eq!(route.rank, 2);
/// assert_eq!(route.format.unwrap(), MediaType::JSON);
@ -164,8 +164,8 @@ use crate::sentinel::Sentry;
pub struct Route {
/// The name of this route, if one was given.
pub name: Option<Cow<'static, str>>,
/// The method this route matches against.
pub method: Method,
/// The method this route matches, or `None` to match any method.
pub method: Option<Method>,
/// The function that should be called when the route matches.
pub handler: Box<dyn Handler>,
/// The route URI.
@ -203,12 +203,12 @@ impl Route {
/// // this is a route matching requests to `GET /`
/// let index = Route::new(Method::Get, "/", handler);
/// assert_eq!(index.rank, -9);
/// assert_eq!(index.method, Method::Get);
/// assert_eq!(index.method, Some(Method::Get));
/// assert_eq!(index.uri, "/");
/// ```
#[track_caller]
pub fn new<H: Handler>(method: Method, uri: &str, handler: H) -> Route {
Route::ranked(None, method, uri, handler)
pub fn new<M: Into<Option<Method>>, H: Handler>(method: M, uri: &str, handler: H) -> Route {
Route::ranked(None, method.into(), uri, handler)
}
/// Creates a new route with the given rank, method, path, and handler with
@ -233,17 +233,19 @@ impl Route {
///
/// let foo = Route::ranked(1, Method::Post, "/foo?bar", handler);
/// assert_eq!(foo.rank, 1);
/// assert_eq!(foo.method, Method::Post);
/// assert_eq!(foo.method, Some(Method::Post));
/// assert_eq!(foo.uri, "/foo?bar");
///
/// let foo = Route::ranked(None, Method::Post, "/foo?bar", handler);
/// assert_eq!(foo.rank, -12);
/// assert_eq!(foo.method, Method::Post);
/// assert_eq!(foo.method, Some(Method::Post));
/// assert_eq!(foo.uri, "/foo?bar");
/// ```
#[track_caller]
pub fn ranked<H, R>(rank: R, method: Method, uri: &str, handler: H) -> Route
where H: Handler + 'static, R: Into<Option<isize>>,
pub fn ranked<M, H, R>(rank: R, method: M, uri: &str, handler: H) -> Route
where M: Into<Option<Method>>,
H: Handler + 'static,
R: Into<Option<isize>>,
{
let uri = RouteUri::new("/", uri);
let rank = rank.into().unwrap_or_else(|| uri.default_rank());
@ -253,7 +255,9 @@ impl Route {
sentinels: Vec::new(),
handler: Box::new(handler),
location: None,
rank, uri, method,
method: method.into(),
rank,
uri,
}
}
@ -362,7 +366,7 @@ pub struct StaticInfo {
/// The route's name, i.e, the name of the function.
pub name: &'static str,
/// The route's method.
pub method: Method,
pub method: Option<Method>,
/// The route's URi, without the base mount point.
pub uri: &'static str,
/// The route's format, if any.

View File

@ -1,7 +1,7 @@
use crate::catcher::Catcher;
use crate::route::{Route, Segment, RouteUri};
use crate::http::MediaType;
use crate::http::{MediaType, Method};
pub trait Collide<T = Self> {
fn collides_with(&self, other: &T) -> bool;
@ -87,7 +87,7 @@ impl Route {
/// assert!(a.collides_with(&b));
/// ```
pub fn collides_with(&self, other: &Route) -> bool {
self.method == other.method
methods_collide(self, other)
&& self.rank == other.rank
&& self.uri.collides_with(&other.uri)
&& formats_collide(self, other)
@ -190,8 +190,16 @@ impl Collide for MediaType {
}
}
fn methods_collide(route: &Route, other: &Route) -> bool {
match (route.method, other.method) {
(Some(a), Some(b)) => a == b,
(None, _) | (_, None) => true,
}
}
fn formats_collide(route: &Route, other: &Route) -> bool {
match (route.method.allows_request_body(), other.method.allows_request_body()) {
let payload_support = |m: &Option<Method>| m.and_then(|m| m.allows_request_body());
match (payload_support(&route.method), payload_support(&other.method)) {
// Payload supporting methods match against `Content-Type` which must be
// fully specified, so the request cannot contain a format that matches
// more than one route format as long as those formats don't collide.

View File

@ -67,8 +67,7 @@ impl Route {
/// ```
#[tracing::instrument(level = "trace", name = "matching", skip_all, ret)]
pub fn matches(&self, request: &Request<'_>) -> bool {
trace!(route.method = %self.method, request.method = %request.method());
self.method == request.method()
methods_match(self, request)
&& paths_match(self, request)
&& queries_match(self, request)
&& formats_match(self, request)
@ -140,6 +139,11 @@ impl Catcher {
}
}
fn methods_match(route: &Route, req: &Request<'_>) -> bool {
trace!(?route.method, request.method = %req.method());
route.method.map_or(true, |method| method == req.method())
}
fn paths_match(route: &Route, req: &Request<'_>) -> bool {
trace!(route.uri = %route.uri, request.uri = %req.uri());
let route_segments = &route.uri.metadata.uri_segments;
@ -208,7 +212,7 @@ fn formats_match(route: &Route, req: &Request<'_>) -> bool {
None => return true,
};
match route.method.allows_request_body() {
match route.method.and_then(|m| m.allows_request_body()) {
Some(true) => match req.format() {
Some(f) if f.specificity() == 2 => route_format.collides_with(f),
_ => false

View File

@ -1,64 +1,120 @@
use std::ops::{Deref, DerefMut};
use std::collections::HashMap;
use crate::request::Request;
use crate::http::{Method, Status};
use crate::{Route, Catcher};
use crate::router::Collide;
#[derive(Debug)]
pub(crate) struct Router<T>(T);
#[derive(Debug, Default)]
pub(crate) struct Router {
routes: HashMap<Method, Vec<Route>>,
catchers: HashMap<Option<u16>, Vec<Catcher>>,
pub struct Pending {
pub routes: Vec<Route>,
pub catchers: Vec<Catcher>,
}
pub type Collisions<T> = Vec<(T, T)>;
#[derive(Debug, Default)]
pub struct Finalized {
pub routes: Vec<Route>,
pub catchers: Vec<Catcher>,
route_map: HashMap<Method, Vec<usize>>,
catcher_map: HashMap<Option<u16>, Vec<usize>>,
}
impl Router {
pub type Pair<T> = (T, T);
pub type Collisions = (Vec<Pair<Route>>, Vec<Pair<Catcher>>);
pub type Result<T, E = Collisions> = std::result::Result<T, E>;
impl Router<Pending> {
pub fn new() -> Self {
Self::default()
Router(Pending::default())
}
pub fn add_route(&mut self, route: Route) {
let routes = self.routes.entry(route.method).or_default();
routes.push(route);
routes.sort_by_key(|r| r.rank);
}
pub fn finalize(self) -> Result<Router<Finalized>, Collisions> {
fn collisions<'a, T>(items: &'a [T]) -> impl Iterator<Item = (T, T)> + 'a
where T: Collide + Clone + 'a,
{
items.iter()
.enumerate()
.flat_map(move |(i, a)| {
items.iter()
.skip(i + 1)
.filter(move |b| a.collides_with(b))
.map(move |b| (a.clone(), b.clone()))
})
}
pub fn add_catcher(&mut self, catcher: Catcher) {
let catchers = self.catchers.entry(catcher.code).or_default();
catchers.push(catcher);
catchers.sort_by_key(|c| c.rank);
}
let route_collisions: Vec<_> = collisions(&self.routes).collect();
let catcher_collisions: Vec<_> = collisions(&self.catchers).collect();
#[inline]
pub fn routes(&self) -> impl Iterator<Item = &Route> + Clone {
self.routes.values().flat_map(|v| v.iter())
}
if !route_collisions.is_empty() || !catcher_collisions.is_empty() {
return Err((route_collisions, catcher_collisions))
}
#[inline]
pub fn catchers(&self) -> impl Iterator<Item = &Catcher> + Clone {
self.catchers.values().flat_map(|v| v.iter())
}
// create the route map
let mut route_map: HashMap<Method, Vec<usize>> = HashMap::new();
for (i, route) in self.routes.iter().enumerate() {
match route.method {
Some(method) => route_map.entry(method).or_default().push(i),
None => for method in Method::ALL_VARIANTS {
route_map.entry(*method).or_default().push(i);
}
}
}
// create the catcher map
let mut catcher_map: HashMap<Option<u16>, Vec<usize>> = HashMap::new();
for (i, catcher) in self.catchers.iter().enumerate() {
catcher_map.entry(catcher.code).or_default().push(i);
}
// sort routes by rank
for routes in route_map.values_mut() {
routes.sort_by_key(|&i| &self.routes[i].rank);
}
// sort catchers by rank
for catchers in catcher_map.values_mut() {
catchers.sort_by_key(|&i| &self.catchers[i].rank);
}
Ok(Router(Finalized {
routes: self.0.routes,
catchers: self.0.catchers,
route_map, catcher_map
}))
}
}
impl Router<Finalized> {
#[track_caller]
pub fn route<'r, 'a: 'r>(
&'a self,
req: &'r Request<'r>
) -> impl Iterator<Item = &'a Route> + 'r {
// Note that routes are presorted by ascending rank on each `add`.
self.routes.get(&req.method())
// Note that routes are presorted by ascending rank on each `add` and
// that all routes with `None` methods have been cloned into all methods.
self.route_map.get(&req.method())
.into_iter()
.flat_map(move |routes| routes.iter().filter(move |r| r.matches(req)))
.flat_map(move |routes| routes.iter().map(move |&i| &self.routes[i]))
.filter(move |r| r.matches(req))
}
// For many catchers, using aho-corasick or similar should be much faster.
#[track_caller]
pub fn catch<'r>(&self, status: Status, req: &'r Request<'r>) -> Option<&Catcher> {
// Note that catchers are presorted by descending base length.
let explicit = self.catchers.get(&Some(status.code))
.and_then(|c| c.iter().find(|c| c.matches(status, req)));
let explicit = self.catcher_map.get(&Some(status.code))
.map(|catchers| catchers.iter().map(|&i| &self.catchers[i]))
.and_then(|mut catchers| catchers.find(|c| c.matches(status, req)));
let default = self.catchers.get(&None)
.and_then(|c| c.iter().find(|c| c.matches(status, req)));
let default = self.catcher_map.get(&None)
.map(|catchers| catchers.iter().map(|&i| &self.catchers[i]))
.and_then(|mut catchers| catchers.find(|c| c.matches(status, req)));
match (explicit, default) {
(None, None) => None,
@ -67,28 +123,19 @@ impl Router {
(Some(_), Some(b)) => Some(b),
}
}
}
fn collisions<'a, I, T>(&self, items: I) -> impl Iterator<Item = (T, T)> + 'a
where I: Iterator<Item = &'a T> + Clone + 'a, T: Collide + Clone + 'a,
{
items.clone().enumerate()
.flat_map(move |(i, a)| {
items.clone()
.skip(i + 1)
.filter(move |b| a.collides_with(b))
.map(move |b| (a.clone(), b.clone()))
})
impl<T> Deref for Router<T> {
type Target = T;
fn deref(&self) -> &Self::Target {
&self.0
}
}
pub fn finalize(&self) -> Result<(), (Collisions<Route>, Collisions<Catcher>)> {
let routes: Vec<_> = self.collisions(self.routes()).collect();
let catchers: Vec<_> = self.collisions(self.catchers()).collect();
if !routes.is_empty() || !catchers.is_empty() {
return Err((routes, catchers))
}
Ok(())
impl DerefMut for Router<Pending> {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.0
}
}
@ -100,50 +147,32 @@ mod test {
use crate::local::blocking::Client;
use crate::http::{Method::*, uri::Origin};
impl Router {
fn has_collisions(&self) -> bool {
self.finalize().is_err()
}
}
fn router_with_routes(routes: &[&'static str]) -> Router {
fn make_router<I>(routes: I) -> Result<Router<Finalized>, Collisions>
where I: Iterator<Item = (Option<isize>, &'static str)>
{
let mut router = Router::new();
for route in routes {
let route = Route::new(Get, route, dummy_handler);
router.add_route(route);
}
router
}
fn router_with_ranked_routes(routes: &[(isize, &'static str)]) -> Router {
let mut router = Router::new();
for &(rank, route) in routes {
for (rank, route) in routes {
let route = Route::ranked(rank, Get, route, dummy_handler);
router.add_route(route);
router.routes.push(route);
}
router
router.finalize()
}
fn router_with_rankless_routes(routes: &[&'static str]) -> Router {
let mut router = Router::new();
for route in routes {
let route = Route::ranked(0, Get, route, dummy_handler);
router.add_route(route);
}
fn router_with_routes(routes: &[&'static str]) -> Router<Finalized> {
make_router(routes.iter().map(|r| (None, *r))).unwrap()
}
router
fn router_with_ranked_routes(routes: &[(isize, &'static str)]) -> Router<Finalized> {
make_router(routes.iter().map(|r| (Some(r.0), r.1))).unwrap()
}
fn rankless_route_collisions(routes: &[&'static str]) -> bool {
let router = router_with_rankless_routes(routes);
router.has_collisions()
make_router(routes.iter().map(|r| (Some(0), *r))).is_err()
}
fn default_rank_route_collisions(routes: &[&'static str]) -> bool {
let router = router_with_routes(routes);
router.has_collisions()
make_router(routes.iter().map(|r| (None, *r))).is_err()
}
#[test]
@ -280,13 +309,15 @@ mod test {
assert!(!default_rank_route_collisions(&["/<foo>?a=b", "/<foo>?c=d&<d>"]));
}
fn matches<'a>(router: &'a Router, method: Method, uri: &'a str) -> Vec<&'a Route> {
#[track_caller]
fn matches<'a>(router: &'a Router<Finalized>, method: Method, uri: &'a str) -> Vec<&'a Route> {
let client = Client::debug_with(vec![]).expect("client");
let request = client.req(method, Origin::parse(uri).unwrap());
router.route(&request).collect()
}
fn route<'a>(router: &'a Router, method: Method, uri: &'a str) -> Option<&'a Route> {
#[track_caller]
fn route<'a>(router: &'a Router<Finalized>, method: Method, uri: &'a str) -> Option<&'a Route> {
matches(router, method, uri).into_iter().next()
}
@ -309,9 +340,10 @@ mod test {
assert!(route(&router, Get, "/a/").is_some());
let mut router = Router::new();
router.add_route(Route::new(Put, "/hello", dummy_handler));
router.add_route(Route::new(Post, "/hello", dummy_handler));
router.add_route(Route::new(Delete, "/hello", dummy_handler));
router.routes.push(Route::new(Put, "/hello", dummy_handler));
router.routes.push(Route::new(Post, "/hello", dummy_handler));
router.routes.push(Route::new(Delete, "/hello", dummy_handler));
let router = router.finalize().unwrap();
assert!(route(&router, Put, "/hello").is_some());
assert!(route(&router, Post, "/hello").is_some());
assert!(route(&router, Delete, "/hello").is_some());
@ -368,7 +400,6 @@ mod test {
macro_rules! assert_ranked_match {
($routes:expr, $to:expr => $want:expr) => ({
let router = router_with_routes($routes);
assert!(!router.has_collisions());
let route_path = route(&router, Get, $to).unwrap().uri.to_string();
assert_eq!(route_path, $want.to_string(),
"\nmatched {} with {}, wanted {} in {:#?}", $to, route_path, $want, router);
@ -401,8 +432,7 @@ mod test {
}
fn ranked_collisions(routes: &[(isize, &'static str)]) -> bool {
let router = router_with_ranked_routes(routes);
router.has_collisions()
make_router(routes.iter().map(|r| (Some(r.0), r.1))).is_err()
}
#[test]
@ -429,7 +459,7 @@ mod test {
let router = router_with_ranked_routes(&$routes);
let routed_to = matches(&router, Get, $to);
let expected = &[$($want),+];
assert!(routed_to.len() == expected.len());
assert_eq!(routed_to.len(), expected.len());
for (got, expected) in routed_to.iter().zip(expected.iter()) {
assert_eq!(got.rank, expected.0);
assert_eq!(got.uri.to_string(), expected.1.to_string());
@ -545,20 +575,21 @@ mod test {
);
}
fn router_with_catchers(catchers: &[(Option<u16>, &str)]) -> Router {
fn router_with_catchers(catchers: &[(Option<u16>, &str)]) -> Result<Router<Finalized>> {
let mut router = Router::new();
for (code, base) in catchers {
let catcher = Catcher::new(*code, crate::catcher::dummy_handler);
router.add_catcher(catcher.map_base(|_| base.to_string()).unwrap());
router.catchers.push(catcher.map_base(|_| base.to_string()).unwrap());
}
router
router.finalize()
}
fn catcher<'a>(router: &'a Router, status: Status, uri: &str) -> Option<&'a Catcher> {
#[track_caller]
fn catcher<'a>(r: &'a Router<Finalized>, status: Status, uri: &str) -> Option<&'a Catcher> {
let client = Client::debug_with(vec![]).expect("client");
let request = client.get(Origin::parse(uri).unwrap());
router.catch(status, &request)
r.catch(status, &request)
}
macro_rules! assert_catcher_routing {
@ -571,7 +602,7 @@ mod test {
let requests = vec![$($r),+];
let expected = vec![$(($ecode.into(), $euri)),+];
let router = router_with_catchers(&catchers);
let router = router_with_catchers(&catchers).expect("valid router");
for (req, expected) in requests.iter().zip(expected.iter()) {
let req_status = Status::from_code(req.0).expect("valid status");
let catcher = catcher(&router, req_status, req.1).expect("some catcher");

View File

@ -142,7 +142,10 @@ impl Trace for Route {
event! { level, "route",
name = self.name.as_ref().map(|n| &**n),
rank = self.rank,
method = %self.method,
method = %Formatter(|f| match self.method {
Some(method) => write!(f, "{}", method),
None => write!(f, "[any]"),
}),
uri = %self.uri,
uri.base = %self.uri.base(),
uri.unmounted = %self.uri.unmounted(),

View File

@ -13,13 +13,13 @@ fn patch(form_data: Form<FormData>) -> &'static str {
"PATCH OK"
}
#[route(UPDATEREDIRECTREF, uri = "/", data = "<form_data>")]
#[route("/", method = UPDATEREDIRECTREF, data = "<form_data>")]
fn urr(form_data: Form<FormData>) -> &'static str {
assert_eq!("Form data", form_data.into_inner().form_data);
"UPDATEREDIRECTREF OK"
}
#[route("VERSION-CONTROL", uri = "/", data = "<form_data>")]
#[route("/", method = "VERSION-CONTROL", data = "<form_data>")]
fn vc(form_data: Form<FormData>) -> &'static str {
assert_eq!("Form data", form_data.into_inner().form_data);
"VERSION-CONTROL OK"

View File

@ -37,21 +37,55 @@ these properties and more.
## Methods
A Rocket route attribute can be any one of `get`, `put`, `post`, `delete`,
`head`, `patch`, or `options`, each corresponding to the HTTP method to match
against. For example, the following attribute will match against `POST` requests
to the root path:
A Rocket route attribute can either be method-specific, any one of `get`, `put`,
`post`, `delete`, `head`, `patch`, or `options`, or the generic [`route`], which
allows explicitly specifying any valid HTTP [`Method`] or no method at all, to
match again _any_ method. Consider the following examples:
```rust
# #[macro_use] extern crate rocket;
# fn main() {}
* Match a `POST` request to `/`:
#[post("/")]
# fn handler() {}
```
```rust
# use rocket::post;
#[post("/")]
# fn handler() {}
```
* Match a `PATCH` request to `/fix`:
```rust
# use rocket::patch;
#[patch("/fix")]
# fn handler() {}
```
* Match a `PROPFIND` request to `/collection`:
```rust
# use rocket::route;
#[route("/collection", method = PROPFIND)]
# fn handler() {}
```
* Match a `VERSION-CONTROL` request to `/collection`:
```rust
# use rocket::route;
#[route("/resource", method = "VERSION-CONTROL")]
# fn handler() {}
```
* Match a request to `/page` with _any_ method:
```rust
# use rocket::route;
#[route("/page")]
# fn handler() {}
```
The grammar for these attributes is defined formally in the [`route`] API docs.
[`Method`]: @api/master/rocket/http/enum.Method.html
### HEAD Requests
Rocket handles `HEAD` requests automatically when there exists a `GET` route

View File

@ -33,7 +33,7 @@ fn mir() -> &'static str {
// Try visiting:
// http://127.0.0.1:8000/wave/Rocketeer/100
#[get("/<name>/<age>")]
#[get("/<name>/<age>", rank = 2)]
fn wave(name: &str, age: u8) -> String {
format!("👋 Hello, {} year old named {}!", age, name)
}

View File

@ -2,9 +2,9 @@
use std::net::SocketAddr;
use rocket::http::Status;
use rocket::http::uri::{Origin, Host};
use rocket::tracing::{self, Instrument};
use rocket::{route, Error, Request, Data, Route, Orbit, Rocket, Ignite};
use rocket::{Rocket, Ignite, Orbit, State, Error};
use rocket::fairing::{Fairing, Info, Kind};
use rocket::response::Redirect;
use rocket::listener::tcp::TcpListener;
@ -19,43 +19,33 @@ pub struct Config {
tls_addr: SocketAddr,
}
#[route("/<_..>")]
fn redirect(config: &State<Config>, uri: &Origin<'_>, host: &Host<'_>) -> Redirect {
// FIXME: Check the host against a whitelist!
let domain = host.domain();
let https_uri = match config.tls_addr.port() {
443 => format!("https://{domain}{uri}"),
port => format!("https://{domain}:{port}{uri}"),
};
Redirect::permanent(https_uri)
}
impl Redirector {
pub fn on(port: u16) -> Self {
Redirector(port)
}
// Route function that gets called on every single request.
fn redirect<'r>(req: &'r Request, _: Data<'r>) -> route::BoxFuture<'r> {
// FIXME: Check the host against a whitelist!
let config = req.rocket().state::<Config>().expect("managed Self");
if let Some(host) = req.host() {
let domain = host.domain();
let https_uri = match config.tls_addr.port() {
443 => format!("https://{domain}{}", req.uri()),
port => format!("https://{domain}:{port}{}", req.uri()),
};
route::Outcome::from(req, Redirect::permanent(https_uri)).pin()
} else {
route::Outcome::from(req, Status::BadRequest).pin()
}
}
// Launch an instance of Rocket than handles redirection on `self.port`.
pub async fn try_launch(self, config: Config) -> Result<Rocket<Ignite>, Error> {
use rocket::http::Method::*;
rocket::span_info!("HTTP -> HTTPS Redirector" => {
info!(from = self.0, to = config.tls_addr.port(), "redirecting");
});
// Build a vector of routes to `redirect` on `<path..>` for each method.
let redirects = [Get, Put, Post, Delete, Options, Head, Trace, Connect, Patch]
.into_iter()
.map(|m| Route::new(m, "/<path..>", Self::redirect))
.collect::<Vec<_>>();
info!(from = self.0, to = config.tls_addr.port(), "redirecting");
let addr = SocketAddr::new(config.tls_addr.ip(), self.0);
rocket::custom(&config.server)
.manage(config)
.mount("/", redirects)
.mount("/", routes![redirect])
.try_launch_on(TcpListener::bind(addr))
.await
}

View File

@ -4,7 +4,7 @@ use crate::prelude::*;
use rocket::http::Method;
#[route(PROPFIND, uri = "/")]
#[route("/", method = PROPFIND)]
fn route() -> &'static str {
"Hello, World!"
}