mirror of https://github.com/rwf2/Rocket.git
Support routes that match any method.
This commit introduces support for method-less routes and route attributes, which match _any_ valid method: `#[route("/")]`. The `Route` structure's `method` field is now accordingly of type `Option<Route>`. The syntax for the `route` attribute has changed in a breaking manner. To set a method, a key/value of `method = NAME` must be introduced: ```rust #[route("/", method = GET)] ``` If the method's name is a valid identifier, it can be used without quotes. Otherwise it must be quoted: ```rust // `GET` is a valid identifier, but `VERSION-CONTROL` is not #[route("/", method = "VERSION-CONTROL")] ``` Closes #2731.
This commit is contained in:
parent
9496b70e8c
commit
72c91958b7
|
@ -58,9 +58,9 @@ fn generate_matching_requests<'c>(client: &'c Client, routes: &[Route]) -> Vec<L
|
|||
.join("&");
|
||||
|
||||
let uri = format!("/{}?{}", path, query);
|
||||
let mut req = client.req(route.method, uri);
|
||||
let mut req = client.req(route.method.unwrap(), uri);
|
||||
if let Some(ref format) = route.format {
|
||||
if let Some(true) = route.method.allows_request_body() {
|
||||
if let Some(true) = route.method.and_then(|m| m.allows_request_body()) {
|
||||
req.add_header(ContentType::from(format.clone()));
|
||||
} else {
|
||||
req.add_header(Accept::from(format.clone()));
|
||||
|
|
|
@ -263,7 +263,7 @@ fn internal_uri_macro_decl(route: &Route) -> TokenStream {
|
|||
// Generate a unique macro name based on the route's metadata.
|
||||
let macro_name = route.handler.sig.ident.prepend(crate::URI_MACRO_PREFIX);
|
||||
let inner_macro_name = macro_name.uniqueify_with(|mut hasher| {
|
||||
route.attr.method.0.hash(&mut hasher);
|
||||
route.attr.method.as_ref().map(|m| m.0.hash(&mut hasher));
|
||||
route.attr.uri.path().hash(&mut hasher);
|
||||
route.attr.uri.query().hash(&mut hasher);
|
||||
route.attr.data.as_ref().map(|d| d.value.hash(&mut hasher));
|
||||
|
@ -395,7 +395,7 @@ fn codegen_route(route: Route) -> Result<TokenStream> {
|
|||
let internal_uri_macro = internal_uri_macro_decl(&route);
|
||||
let responder_outcome = responder_outcome_expr(&route);
|
||||
|
||||
let method = &route.attr.method;
|
||||
let method = Optional(route.attr.method.clone());
|
||||
let uri = route.attr.uri.to_string();
|
||||
let rank = Optional(route.attr.rank);
|
||||
let format = Optional(route.attr.format.as_ref());
|
||||
|
@ -480,9 +480,12 @@ fn incomplete_route(
|
|||
let method_attribute = MethodAttribute::from_meta(&syn::parse2(full_attr)?)?;
|
||||
|
||||
let attribute = Attribute {
|
||||
method: SpanWrapped {
|
||||
full_span: method_span, key_span: None, span: method_span, value: Method(method)
|
||||
},
|
||||
method: Some(SpanWrapped {
|
||||
full_span: method_span,
|
||||
key_span: None,
|
||||
span: method_span,
|
||||
value: Method(method),
|
||||
}),
|
||||
uri: method_attribute.uri,
|
||||
data: method_attribute.data,
|
||||
format: method_attribute.format,
|
||||
|
|
|
@ -43,8 +43,8 @@ pub struct Arguments {
|
|||
#[derive(Debug, FromMeta)]
|
||||
pub struct Attribute {
|
||||
#[meta(naked)]
|
||||
pub method: SpanWrapped<Method>,
|
||||
pub uri: RouteUri,
|
||||
pub method: Option<SpanWrapped<Method>>,
|
||||
pub data: Option<SpanWrapped<Dynamic>>,
|
||||
pub format: Option<MediaType>,
|
||||
pub rank: Option<isize>,
|
||||
|
@ -129,17 +129,23 @@ impl Route {
|
|||
// Emit a warning if a `data` param was supplied for non-payload methods.
|
||||
if let Some(ref data) = attr.data {
|
||||
let lint = Lint::DubiousPayload;
|
||||
match attr.method.0.allows_request_body() {
|
||||
None if lint.enabled(handler.span()) => {
|
||||
data.full_span.warning("`data` used with non-payload-supporting method")
|
||||
.note(format!("'{}' does not typically support payloads", attr.method.0))
|
||||
.note(lint.how_to_suppress())
|
||||
.emit_as_item_tokens();
|
||||
}
|
||||
Some(false) => {
|
||||
match attr.method.as_ref() {
|
||||
Some(m) if m.0.allows_request_body() == Some(false) => {
|
||||
diags.push(data.full_span
|
||||
.error("`data` cannot be used on this route")
|
||||
.span_note(attr.method.span, "method does not support request payloads"))
|
||||
.span_note(m.span, "method does not support request payloads"))
|
||||
},
|
||||
Some(m) if m.0.allows_request_body().is_none() && lint.enabled(handler.span()) => {
|
||||
data.full_span.warning("`data` used with non-payload-supporting method")
|
||||
.span_note(m.span, format!("'{}' does not typically support payloads", m.0))
|
||||
.note(lint.how_to_suppress())
|
||||
.emit_as_item_tokens();
|
||||
},
|
||||
None if lint.enabled(handler.span()) => {
|
||||
data.full_span.warning("`data` used on route with wildcard method")
|
||||
.note("some methods may not support request payloads")
|
||||
.note(lint.how_to_suppress())
|
||||
.emit_as_item_tokens();
|
||||
}
|
||||
_ => { /* okay */ },
|
||||
}
|
||||
|
|
|
@ -119,15 +119,20 @@ macro_rules! route_attribute {
|
|||
/// * [`patch`] - `PATCH` specific route
|
||||
///
|
||||
/// Additionally, [`route`] allows the method and uri to be explicitly
|
||||
/// specified:
|
||||
/// specified, and for the method to be omitted entirely, to match any
|
||||
/// method:
|
||||
///
|
||||
/// ```rust
|
||||
/// # #[macro_use] extern crate rocket;
|
||||
/// #
|
||||
/// #[route(GET, uri = "/")]
|
||||
/// fn index() -> &'static str {
|
||||
/// "Hello, world!"
|
||||
/// }
|
||||
///
|
||||
/// #[route("/", method = GET)]
|
||||
/// fn get_index() { /* ... */ }
|
||||
///
|
||||
/// #[route("/", method = "VERSION-CONTROL")]
|
||||
/// fn versioned_index() { /* ... */ }
|
||||
///
|
||||
/// #[route("/")]
|
||||
/// fn index() { /* ... */ }
|
||||
/// ```
|
||||
///
|
||||
/// [`get`]: attr.get.html
|
||||
|
@ -171,7 +176,9 @@ macro_rules! route_attribute {
|
|||
/// The generic route attribute is defined as:
|
||||
///
|
||||
/// ```text
|
||||
/// generic-route := METHOD ',' 'uri' '=' route
|
||||
/// generic-route := route (',' method)?
|
||||
///
|
||||
/// method := 'method' '=' METHOD
|
||||
/// ```
|
||||
///
|
||||
/// # Typing Requirements
|
||||
|
@ -1161,12 +1168,12 @@ pub fn derive_uri_display_path(input: TokenStream) -> TokenStream {
|
|||
/// assert_eq!(my_routes.len(), 2);
|
||||
///
|
||||
/// let index_route = &my_routes[0];
|
||||
/// assert_eq!(index_route.method, Method::Get);
|
||||
/// assert_eq!(index_route.method, Some(Method::Get));
|
||||
/// assert_eq!(index_route.name.as_ref().unwrap(), "index");
|
||||
/// assert_eq!(index_route.uri.path(), "/");
|
||||
///
|
||||
/// let hello_route = &my_routes[1];
|
||||
/// assert_eq!(hello_route.method, Method::Post);
|
||||
/// assert_eq!(hello_route.method, Some(Method::Post));
|
||||
/// assert_eq!(hello_route.name.as_ref().unwrap(), "hello");
|
||||
/// assert_eq!(hello_route.uri.path(), "/hi/<person>");
|
||||
/// ```
|
||||
|
|
|
@ -54,8 +54,8 @@ fn post1(
|
|||
}
|
||||
|
||||
#[route(
|
||||
POST,
|
||||
uri = "/<a>/<name>/name/<path..>?sky=blue&<sky>&<query..>",
|
||||
"/<a>/<name>/name/<path..>?sky=blue&<sky>&<query..>",
|
||||
method = POST,
|
||||
format = "json",
|
||||
data = "<simple>",
|
||||
rank = 138
|
||||
|
|
|
@ -124,6 +124,10 @@ macro_rules! define_methods {
|
|||
#[doc(hidden)]
|
||||
pub const ALL: &'static [&'static str] = &[$($name),*];
|
||||
|
||||
/// A slice containing every defined method variant.
|
||||
#[doc(hidden)]
|
||||
pub const ALL_VARIANTS: &'static [Method] = &[$(Self::$V),*];
|
||||
|
||||
/// Whether the method is considered "safe".
|
||||
///
|
||||
/// From [RFC9110 §9.2.1](https://www.rfc-editor.org/rfc/rfc9110#section-9.2.1):
|
||||
|
|
|
@ -16,7 +16,7 @@ struct ArbitraryRequestData<'a> {
|
|||
|
||||
#[derive(Arbitrary)]
|
||||
struct ArbitraryRouteData<'a> {
|
||||
method: ArbitraryMethod,
|
||||
method: Option<ArbitraryMethod>,
|
||||
uri: ArbitraryRouteUri<'a>,
|
||||
format: Option<ArbitraryMediaType>,
|
||||
}
|
||||
|
@ -24,7 +24,7 @@ struct ArbitraryRouteData<'a> {
|
|||
impl std::fmt::Debug for ArbitraryRouteData<'_> {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("ArbitraryRouteData")
|
||||
.field("method", &self.method.0)
|
||||
.field("method", &self.method.map(|v| v.0))
|
||||
.field("base", &self.uri.0.base())
|
||||
.field("unmounted", &self.uri.0.unmounted().to_string())
|
||||
.field("uri", &self.uri.0.to_string())
|
||||
|
@ -59,12 +59,14 @@ impl<'c, 'a: 'c> ArbitraryRequestData<'a> {
|
|||
|
||||
impl<'a> ArbitraryRouteData<'a> {
|
||||
fn into_route(self) -> Route {
|
||||
let mut r = Route::ranked(0, self.method.0, &self.uri.0.to_string(), dummy_handler);
|
||||
let method = self.method.map(|m| m.0);
|
||||
let mut r = Route::ranked(0, method, &self.uri.0.to_string(), dummy_handler);
|
||||
r.format = self.format.map(|f| f.0);
|
||||
r
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy)]
|
||||
struct ArbitraryMethod(Method);
|
||||
|
||||
struct ArbitraryOrigin<'a>(Origin<'a>);
|
||||
|
@ -79,12 +81,7 @@ struct ArbitraryRouteUri<'a>(RouteUri<'a>);
|
|||
|
||||
impl<'a> Arbitrary<'a> for ArbitraryMethod {
|
||||
fn arbitrary(u: &mut Unstructured<'a>) -> Result<Self> {
|
||||
let all_methods = &[
|
||||
Method::Get, Method::Put, Method::Post, Method::Delete, Method::Options,
|
||||
Method::Head, Method::Trace, Method::Connect, Method::Patch
|
||||
];
|
||||
|
||||
Ok(ArbitraryMethod(*u.choose(all_methods)?))
|
||||
Ok(ArbitraryMethod(*u.choose(Method::ALL_VARIANTS)?))
|
||||
}
|
||||
|
||||
fn size_hint(_: usize) -> (usize, Option<usize>) {
|
||||
|
|
|
@ -4,7 +4,7 @@ use figment::Figment;
|
|||
use crate::listener::Endpoint;
|
||||
use crate::shutdown::Stages;
|
||||
use crate::{Catcher, Config, Rocket, Route};
|
||||
use crate::router::Router;
|
||||
use crate::router::{Router, Finalized};
|
||||
use crate::fairing::Fairings;
|
||||
|
||||
mod private {
|
||||
|
@ -100,7 +100,7 @@ phases! {
|
|||
/// represents a fully built and finalized application server ready for
|
||||
/// launch into orbit. See [`Rocket#ignite`] for full details.
|
||||
Ignite (#[derive(Debug)] Igniting) {
|
||||
pub(crate) router: Router,
|
||||
pub(crate) router: Router<Finalized>,
|
||||
pub(crate) fairings: Fairings,
|
||||
pub(crate) figment: Figment,
|
||||
pub(crate) config: Config,
|
||||
|
@ -114,7 +114,7 @@ phases! {
|
|||
/// An instance of `Rocket` in this phase is typed as [`Rocket<Orbit>`] and
|
||||
/// represents a running application.
|
||||
Orbit (#[derive(Debug)] Orbiting) {
|
||||
pub(crate) router: Router,
|
||||
pub(crate) router: Router<Finalized>,
|
||||
pub(crate) fairings: Fairings,
|
||||
pub(crate) figment: Figment,
|
||||
pub(crate) config: Config,
|
||||
|
|
|
@ -557,9 +557,10 @@ impl Rocket<Build> {
|
|||
|
||||
// Initialize the router; check for collisions.
|
||||
let mut router = Router::new();
|
||||
self.routes.clone().into_iter().for_each(|r| router.add_route(r));
|
||||
self.catchers.clone().into_iter().for_each(|c| router.add_catcher(c));
|
||||
router.finalize().map_err(|(r, c)| ErrorKind::Collisions { routes: r, catchers: c, })?;
|
||||
self.routes.clone().into_iter().for_each(|r| router.routes.push(r));
|
||||
self.catchers.clone().into_iter().for_each(|c| router.catchers.push(c));
|
||||
let router = router.finalize()
|
||||
.map_err(|(r, c)| ErrorKind::Collisions { routes: r, catchers: c, })?;
|
||||
|
||||
// Finally, freeze managed state for faster access later.
|
||||
self.state.freeze();
|
||||
|
@ -840,8 +841,8 @@ impl<P: Phase> Rocket<P> {
|
|||
pub fn routes(&self) -> impl Iterator<Item = &Route> {
|
||||
match self.0.as_ref() {
|
||||
StateRef::Build(p) => Either::Left(p.routes.iter()),
|
||||
StateRef::Ignite(p) => Either::Right(p.router.routes()),
|
||||
StateRef::Orbit(p) => Either::Right(p.router.routes()),
|
||||
StateRef::Ignite(p) => Either::Right(p.router.routes.iter()),
|
||||
StateRef::Orbit(p) => Either::Right(p.router.routes.iter()),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -871,8 +872,8 @@ impl<P: Phase> Rocket<P> {
|
|||
pub fn catchers(&self) -> impl Iterator<Item = &Catcher> {
|
||||
match self.0.as_ref() {
|
||||
StateRef::Build(p) => Either::Left(p.catchers.iter()),
|
||||
StateRef::Ignite(p) => Either::Right(p.router.catchers()),
|
||||
StateRef::Orbit(p) => Either::Right(p.router.catchers()),
|
||||
StateRef::Ignite(p) => Either::Right(p.router.catchers.iter()),
|
||||
StateRef::Orbit(p) => Either::Right(p.router.catchers.iter()),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -22,7 +22,7 @@ use crate::sentinel::Sentry;
|
|||
///
|
||||
/// let route = routes![route_name].remove(0);
|
||||
/// assert_eq!(route.name.unwrap(), "route_name");
|
||||
/// assert_eq!(route.method, Method::Get);
|
||||
/// assert_eq!(route.method, Some(Method::Get));
|
||||
/// assert_eq!(route.uri, "/route/<path..>?query");
|
||||
/// assert_eq!(route.rank, 2);
|
||||
/// assert_eq!(route.format.unwrap(), MediaType::JSON);
|
||||
|
@ -164,8 +164,8 @@ use crate::sentinel::Sentry;
|
|||
pub struct Route {
|
||||
/// The name of this route, if one was given.
|
||||
pub name: Option<Cow<'static, str>>,
|
||||
/// The method this route matches against.
|
||||
pub method: Method,
|
||||
/// The method this route matches, or `None` to match any method.
|
||||
pub method: Option<Method>,
|
||||
/// The function that should be called when the route matches.
|
||||
pub handler: Box<dyn Handler>,
|
||||
/// The route URI.
|
||||
|
@ -203,12 +203,12 @@ impl Route {
|
|||
/// // this is a route matching requests to `GET /`
|
||||
/// let index = Route::new(Method::Get, "/", handler);
|
||||
/// assert_eq!(index.rank, -9);
|
||||
/// assert_eq!(index.method, Method::Get);
|
||||
/// assert_eq!(index.method, Some(Method::Get));
|
||||
/// assert_eq!(index.uri, "/");
|
||||
/// ```
|
||||
#[track_caller]
|
||||
pub fn new<H: Handler>(method: Method, uri: &str, handler: H) -> Route {
|
||||
Route::ranked(None, method, uri, handler)
|
||||
pub fn new<M: Into<Option<Method>>, H: Handler>(method: M, uri: &str, handler: H) -> Route {
|
||||
Route::ranked(None, method.into(), uri, handler)
|
||||
}
|
||||
|
||||
/// Creates a new route with the given rank, method, path, and handler with
|
||||
|
@ -233,17 +233,19 @@ impl Route {
|
|||
///
|
||||
/// let foo = Route::ranked(1, Method::Post, "/foo?bar", handler);
|
||||
/// assert_eq!(foo.rank, 1);
|
||||
/// assert_eq!(foo.method, Method::Post);
|
||||
/// assert_eq!(foo.method, Some(Method::Post));
|
||||
/// assert_eq!(foo.uri, "/foo?bar");
|
||||
///
|
||||
/// let foo = Route::ranked(None, Method::Post, "/foo?bar", handler);
|
||||
/// assert_eq!(foo.rank, -12);
|
||||
/// assert_eq!(foo.method, Method::Post);
|
||||
/// assert_eq!(foo.method, Some(Method::Post));
|
||||
/// assert_eq!(foo.uri, "/foo?bar");
|
||||
/// ```
|
||||
#[track_caller]
|
||||
pub fn ranked<H, R>(rank: R, method: Method, uri: &str, handler: H) -> Route
|
||||
where H: Handler + 'static, R: Into<Option<isize>>,
|
||||
pub fn ranked<M, H, R>(rank: R, method: M, uri: &str, handler: H) -> Route
|
||||
where M: Into<Option<Method>>,
|
||||
H: Handler + 'static,
|
||||
R: Into<Option<isize>>,
|
||||
{
|
||||
let uri = RouteUri::new("/", uri);
|
||||
let rank = rank.into().unwrap_or_else(|| uri.default_rank());
|
||||
|
@ -253,7 +255,9 @@ impl Route {
|
|||
sentinels: Vec::new(),
|
||||
handler: Box::new(handler),
|
||||
location: None,
|
||||
rank, uri, method,
|
||||
method: method.into(),
|
||||
rank,
|
||||
uri,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -362,7 +366,7 @@ pub struct StaticInfo {
|
|||
/// The route's name, i.e, the name of the function.
|
||||
pub name: &'static str,
|
||||
/// The route's method.
|
||||
pub method: Method,
|
||||
pub method: Option<Method>,
|
||||
/// The route's URi, without the base mount point.
|
||||
pub uri: &'static str,
|
||||
/// The route's format, if any.
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
use crate::catcher::Catcher;
|
||||
use crate::route::{Route, Segment, RouteUri};
|
||||
|
||||
use crate::http::MediaType;
|
||||
use crate::http::{MediaType, Method};
|
||||
|
||||
pub trait Collide<T = Self> {
|
||||
fn collides_with(&self, other: &T) -> bool;
|
||||
|
@ -87,7 +87,7 @@ impl Route {
|
|||
/// assert!(a.collides_with(&b));
|
||||
/// ```
|
||||
pub fn collides_with(&self, other: &Route) -> bool {
|
||||
self.method == other.method
|
||||
methods_collide(self, other)
|
||||
&& self.rank == other.rank
|
||||
&& self.uri.collides_with(&other.uri)
|
||||
&& formats_collide(self, other)
|
||||
|
@ -190,8 +190,16 @@ impl Collide for MediaType {
|
|||
}
|
||||
}
|
||||
|
||||
fn methods_collide(route: &Route, other: &Route) -> bool {
|
||||
match (route.method, other.method) {
|
||||
(Some(a), Some(b)) => a == b,
|
||||
(None, _) | (_, None) => true,
|
||||
}
|
||||
}
|
||||
|
||||
fn formats_collide(route: &Route, other: &Route) -> bool {
|
||||
match (route.method.allows_request_body(), other.method.allows_request_body()) {
|
||||
let payload_support = |m: &Option<Method>| m.and_then(|m| m.allows_request_body());
|
||||
match (payload_support(&route.method), payload_support(&other.method)) {
|
||||
// Payload supporting methods match against `Content-Type` which must be
|
||||
// fully specified, so the request cannot contain a format that matches
|
||||
// more than one route format as long as those formats don't collide.
|
||||
|
|
|
@ -67,8 +67,7 @@ impl Route {
|
|||
/// ```
|
||||
#[tracing::instrument(level = "trace", name = "matching", skip_all, ret)]
|
||||
pub fn matches(&self, request: &Request<'_>) -> bool {
|
||||
trace!(route.method = %self.method, request.method = %request.method());
|
||||
self.method == request.method()
|
||||
methods_match(self, request)
|
||||
&& paths_match(self, request)
|
||||
&& queries_match(self, request)
|
||||
&& formats_match(self, request)
|
||||
|
@ -140,6 +139,11 @@ impl Catcher {
|
|||
}
|
||||
}
|
||||
|
||||
fn methods_match(route: &Route, req: &Request<'_>) -> bool {
|
||||
trace!(?route.method, request.method = %req.method());
|
||||
route.method.map_or(true, |method| method == req.method())
|
||||
}
|
||||
|
||||
fn paths_match(route: &Route, req: &Request<'_>) -> bool {
|
||||
trace!(route.uri = %route.uri, request.uri = %req.uri());
|
||||
let route_segments = &route.uri.metadata.uri_segments;
|
||||
|
@ -208,7 +212,7 @@ fn formats_match(route: &Route, req: &Request<'_>) -> bool {
|
|||
None => return true,
|
||||
};
|
||||
|
||||
match route.method.allows_request_body() {
|
||||
match route.method.and_then(|m| m.allows_request_body()) {
|
||||
Some(true) => match req.format() {
|
||||
Some(f) if f.specificity() == 2 => route_format.collides_with(f),
|
||||
_ => false
|
||||
|
|
|
@ -1,64 +1,120 @@
|
|||
use std::ops::{Deref, DerefMut};
|
||||
use std::collections::HashMap;
|
||||
|
||||
use crate::request::Request;
|
||||
use crate::http::{Method, Status};
|
||||
|
||||
use crate::{Route, Catcher};
|
||||
use crate::router::Collide;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub(crate) struct Router<T>(T);
|
||||
|
||||
#[derive(Debug, Default)]
|
||||
pub(crate) struct Router {
|
||||
routes: HashMap<Method, Vec<Route>>,
|
||||
catchers: HashMap<Option<u16>, Vec<Catcher>>,
|
||||
pub struct Pending {
|
||||
pub routes: Vec<Route>,
|
||||
pub catchers: Vec<Catcher>,
|
||||
}
|
||||
|
||||
pub type Collisions<T> = Vec<(T, T)>;
|
||||
#[derive(Debug, Default)]
|
||||
pub struct Finalized {
|
||||
pub routes: Vec<Route>,
|
||||
pub catchers: Vec<Catcher>,
|
||||
route_map: HashMap<Method, Vec<usize>>,
|
||||
catcher_map: HashMap<Option<u16>, Vec<usize>>,
|
||||
}
|
||||
|
||||
impl Router {
|
||||
pub type Pair<T> = (T, T);
|
||||
|
||||
pub type Collisions = (Vec<Pair<Route>>, Vec<Pair<Catcher>>);
|
||||
|
||||
pub type Result<T, E = Collisions> = std::result::Result<T, E>;
|
||||
|
||||
impl Router<Pending> {
|
||||
pub fn new() -> Self {
|
||||
Self::default()
|
||||
Router(Pending::default())
|
||||
}
|
||||
|
||||
pub fn add_route(&mut self, route: Route) {
|
||||
let routes = self.routes.entry(route.method).or_default();
|
||||
routes.push(route);
|
||||
routes.sort_by_key(|r| r.rank);
|
||||
pub fn finalize(self) -> Result<Router<Finalized>, Collisions> {
|
||||
fn collisions<'a, T>(items: &'a [T]) -> impl Iterator<Item = (T, T)> + 'a
|
||||
where T: Collide + Clone + 'a,
|
||||
{
|
||||
items.iter()
|
||||
.enumerate()
|
||||
.flat_map(move |(i, a)| {
|
||||
items.iter()
|
||||
.skip(i + 1)
|
||||
.filter(move |b| a.collides_with(b))
|
||||
.map(move |b| (a.clone(), b.clone()))
|
||||
})
|
||||
}
|
||||
|
||||
pub fn add_catcher(&mut self, catcher: Catcher) {
|
||||
let catchers = self.catchers.entry(catcher.code).or_default();
|
||||
catchers.push(catcher);
|
||||
catchers.sort_by_key(|c| c.rank);
|
||||
let route_collisions: Vec<_> = collisions(&self.routes).collect();
|
||||
let catcher_collisions: Vec<_> = collisions(&self.catchers).collect();
|
||||
|
||||
if !route_collisions.is_empty() || !catcher_collisions.is_empty() {
|
||||
return Err((route_collisions, catcher_collisions))
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn routes(&self) -> impl Iterator<Item = &Route> + Clone {
|
||||
self.routes.values().flat_map(|v| v.iter())
|
||||
// create the route map
|
||||
let mut route_map: HashMap<Method, Vec<usize>> = HashMap::new();
|
||||
for (i, route) in self.routes.iter().enumerate() {
|
||||
match route.method {
|
||||
Some(method) => route_map.entry(method).or_default().push(i),
|
||||
None => for method in Method::ALL_VARIANTS {
|
||||
route_map.entry(*method).or_default().push(i);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn catchers(&self) -> impl Iterator<Item = &Catcher> + Clone {
|
||||
self.catchers.values().flat_map(|v| v.iter())
|
||||
// create the catcher map
|
||||
let mut catcher_map: HashMap<Option<u16>, Vec<usize>> = HashMap::new();
|
||||
for (i, catcher) in self.catchers.iter().enumerate() {
|
||||
catcher_map.entry(catcher.code).or_default().push(i);
|
||||
}
|
||||
|
||||
// sort routes by rank
|
||||
for routes in route_map.values_mut() {
|
||||
routes.sort_by_key(|&i| &self.routes[i].rank);
|
||||
}
|
||||
|
||||
// sort catchers by rank
|
||||
for catchers in catcher_map.values_mut() {
|
||||
catchers.sort_by_key(|&i| &self.catchers[i].rank);
|
||||
}
|
||||
|
||||
Ok(Router(Finalized {
|
||||
routes: self.0.routes,
|
||||
catchers: self.0.catchers,
|
||||
route_map, catcher_map
|
||||
}))
|
||||
}
|
||||
}
|
||||
|
||||
impl Router<Finalized> {
|
||||
#[track_caller]
|
||||
pub fn route<'r, 'a: 'r>(
|
||||
&'a self,
|
||||
req: &'r Request<'r>
|
||||
) -> impl Iterator<Item = &'a Route> + 'r {
|
||||
// Note that routes are presorted by ascending rank on each `add`.
|
||||
self.routes.get(&req.method())
|
||||
// Note that routes are presorted by ascending rank on each `add` and
|
||||
// that all routes with `None` methods have been cloned into all methods.
|
||||
self.route_map.get(&req.method())
|
||||
.into_iter()
|
||||
.flat_map(move |routes| routes.iter().filter(move |r| r.matches(req)))
|
||||
.flat_map(move |routes| routes.iter().map(move |&i| &self.routes[i]))
|
||||
.filter(move |r| r.matches(req))
|
||||
}
|
||||
|
||||
// For many catchers, using aho-corasick or similar should be much faster.
|
||||
#[track_caller]
|
||||
pub fn catch<'r>(&self, status: Status, req: &'r Request<'r>) -> Option<&Catcher> {
|
||||
// Note that catchers are presorted by descending base length.
|
||||
let explicit = self.catchers.get(&Some(status.code))
|
||||
.and_then(|c| c.iter().find(|c| c.matches(status, req)));
|
||||
let explicit = self.catcher_map.get(&Some(status.code))
|
||||
.map(|catchers| catchers.iter().map(|&i| &self.catchers[i]))
|
||||
.and_then(|mut catchers| catchers.find(|c| c.matches(status, req)));
|
||||
|
||||
let default = self.catchers.get(&None)
|
||||
.and_then(|c| c.iter().find(|c| c.matches(status, req)));
|
||||
let default = self.catcher_map.get(&None)
|
||||
.map(|catchers| catchers.iter().map(|&i| &self.catchers[i]))
|
||||
.and_then(|mut catchers| catchers.find(|c| c.matches(status, req)));
|
||||
|
||||
match (explicit, default) {
|
||||
(None, None) => None,
|
||||
|
@ -67,28 +123,19 @@ impl Router {
|
|||
(Some(_), Some(b)) => Some(b),
|
||||
}
|
||||
}
|
||||
|
||||
fn collisions<'a, I, T>(&self, items: I) -> impl Iterator<Item = (T, T)> + 'a
|
||||
where I: Iterator<Item = &'a T> + Clone + 'a, T: Collide + Clone + 'a,
|
||||
{
|
||||
items.clone().enumerate()
|
||||
.flat_map(move |(i, a)| {
|
||||
items.clone()
|
||||
.skip(i + 1)
|
||||
.filter(move |b| a.collides_with(b))
|
||||
.map(move |b| (a.clone(), b.clone()))
|
||||
})
|
||||
}
|
||||
|
||||
pub fn finalize(&self) -> Result<(), (Collisions<Route>, Collisions<Catcher>)> {
|
||||
let routes: Vec<_> = self.collisions(self.routes()).collect();
|
||||
let catchers: Vec<_> = self.collisions(self.catchers()).collect();
|
||||
impl<T> Deref for Router<T> {
|
||||
type Target = T;
|
||||
|
||||
if !routes.is_empty() || !catchers.is_empty() {
|
||||
return Err((routes, catchers))
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.0
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
impl DerefMut for Router<Pending> {
|
||||
fn deref_mut(&mut self) -> &mut Self::Target {
|
||||
&mut self.0
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -100,50 +147,32 @@ mod test {
|
|||
use crate::local::blocking::Client;
|
||||
use crate::http::{Method::*, uri::Origin};
|
||||
|
||||
impl Router {
|
||||
fn has_collisions(&self) -> bool {
|
||||
self.finalize().is_err()
|
||||
}
|
||||
}
|
||||
|
||||
fn router_with_routes(routes: &[&'static str]) -> Router {
|
||||
fn make_router<I>(routes: I) -> Result<Router<Finalized>, Collisions>
|
||||
where I: Iterator<Item = (Option<isize>, &'static str)>
|
||||
{
|
||||
let mut router = Router::new();
|
||||
for route in routes {
|
||||
let route = Route::new(Get, route, dummy_handler);
|
||||
router.add_route(route);
|
||||
}
|
||||
|
||||
router
|
||||
}
|
||||
|
||||
fn router_with_ranked_routes(routes: &[(isize, &'static str)]) -> Router {
|
||||
let mut router = Router::new();
|
||||
for &(rank, route) in routes {
|
||||
for (rank, route) in routes {
|
||||
let route = Route::ranked(rank, Get, route, dummy_handler);
|
||||
router.add_route(route);
|
||||
router.routes.push(route);
|
||||
}
|
||||
|
||||
router
|
||||
router.finalize()
|
||||
}
|
||||
|
||||
fn router_with_rankless_routes(routes: &[&'static str]) -> Router {
|
||||
let mut router = Router::new();
|
||||
for route in routes {
|
||||
let route = Route::ranked(0, Get, route, dummy_handler);
|
||||
router.add_route(route);
|
||||
fn router_with_routes(routes: &[&'static str]) -> Router<Finalized> {
|
||||
make_router(routes.iter().map(|r| (None, *r))).unwrap()
|
||||
}
|
||||
|
||||
router
|
||||
fn router_with_ranked_routes(routes: &[(isize, &'static str)]) -> Router<Finalized> {
|
||||
make_router(routes.iter().map(|r| (Some(r.0), r.1))).unwrap()
|
||||
}
|
||||
|
||||
fn rankless_route_collisions(routes: &[&'static str]) -> bool {
|
||||
let router = router_with_rankless_routes(routes);
|
||||
router.has_collisions()
|
||||
make_router(routes.iter().map(|r| (Some(0), *r))).is_err()
|
||||
}
|
||||
|
||||
fn default_rank_route_collisions(routes: &[&'static str]) -> bool {
|
||||
let router = router_with_routes(routes);
|
||||
router.has_collisions()
|
||||
make_router(routes.iter().map(|r| (None, *r))).is_err()
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
@ -280,13 +309,15 @@ mod test {
|
|||
assert!(!default_rank_route_collisions(&["/<foo>?a=b", "/<foo>?c=d&<d>"]));
|
||||
}
|
||||
|
||||
fn matches<'a>(router: &'a Router, method: Method, uri: &'a str) -> Vec<&'a Route> {
|
||||
#[track_caller]
|
||||
fn matches<'a>(router: &'a Router<Finalized>, method: Method, uri: &'a str) -> Vec<&'a Route> {
|
||||
let client = Client::debug_with(vec![]).expect("client");
|
||||
let request = client.req(method, Origin::parse(uri).unwrap());
|
||||
router.route(&request).collect()
|
||||
}
|
||||
|
||||
fn route<'a>(router: &'a Router, method: Method, uri: &'a str) -> Option<&'a Route> {
|
||||
#[track_caller]
|
||||
fn route<'a>(router: &'a Router<Finalized>, method: Method, uri: &'a str) -> Option<&'a Route> {
|
||||
matches(router, method, uri).into_iter().next()
|
||||
}
|
||||
|
||||
|
@ -309,9 +340,10 @@ mod test {
|
|||
assert!(route(&router, Get, "/a/").is_some());
|
||||
|
||||
let mut router = Router::new();
|
||||
router.add_route(Route::new(Put, "/hello", dummy_handler));
|
||||
router.add_route(Route::new(Post, "/hello", dummy_handler));
|
||||
router.add_route(Route::new(Delete, "/hello", dummy_handler));
|
||||
router.routes.push(Route::new(Put, "/hello", dummy_handler));
|
||||
router.routes.push(Route::new(Post, "/hello", dummy_handler));
|
||||
router.routes.push(Route::new(Delete, "/hello", dummy_handler));
|
||||
let router = router.finalize().unwrap();
|
||||
assert!(route(&router, Put, "/hello").is_some());
|
||||
assert!(route(&router, Post, "/hello").is_some());
|
||||
assert!(route(&router, Delete, "/hello").is_some());
|
||||
|
@ -368,7 +400,6 @@ mod test {
|
|||
macro_rules! assert_ranked_match {
|
||||
($routes:expr, $to:expr => $want:expr) => ({
|
||||
let router = router_with_routes($routes);
|
||||
assert!(!router.has_collisions());
|
||||
let route_path = route(&router, Get, $to).unwrap().uri.to_string();
|
||||
assert_eq!(route_path, $want.to_string(),
|
||||
"\nmatched {} with {}, wanted {} in {:#?}", $to, route_path, $want, router);
|
||||
|
@ -401,8 +432,7 @@ mod test {
|
|||
}
|
||||
|
||||
fn ranked_collisions(routes: &[(isize, &'static str)]) -> bool {
|
||||
let router = router_with_ranked_routes(routes);
|
||||
router.has_collisions()
|
||||
make_router(routes.iter().map(|r| (Some(r.0), r.1))).is_err()
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
@ -429,7 +459,7 @@ mod test {
|
|||
let router = router_with_ranked_routes(&$routes);
|
||||
let routed_to = matches(&router, Get, $to);
|
||||
let expected = &[$($want),+];
|
||||
assert!(routed_to.len() == expected.len());
|
||||
assert_eq!(routed_to.len(), expected.len());
|
||||
for (got, expected) in routed_to.iter().zip(expected.iter()) {
|
||||
assert_eq!(got.rank, expected.0);
|
||||
assert_eq!(got.uri.to_string(), expected.1.to_string());
|
||||
|
@ -545,20 +575,21 @@ mod test {
|
|||
);
|
||||
}
|
||||
|
||||
fn router_with_catchers(catchers: &[(Option<u16>, &str)]) -> Router {
|
||||
fn router_with_catchers(catchers: &[(Option<u16>, &str)]) -> Result<Router<Finalized>> {
|
||||
let mut router = Router::new();
|
||||
for (code, base) in catchers {
|
||||
let catcher = Catcher::new(*code, crate::catcher::dummy_handler);
|
||||
router.add_catcher(catcher.map_base(|_| base.to_string()).unwrap());
|
||||
router.catchers.push(catcher.map_base(|_| base.to_string()).unwrap());
|
||||
}
|
||||
|
||||
router
|
||||
router.finalize()
|
||||
}
|
||||
|
||||
fn catcher<'a>(router: &'a Router, status: Status, uri: &str) -> Option<&'a Catcher> {
|
||||
#[track_caller]
|
||||
fn catcher<'a>(r: &'a Router<Finalized>, status: Status, uri: &str) -> Option<&'a Catcher> {
|
||||
let client = Client::debug_with(vec![]).expect("client");
|
||||
let request = client.get(Origin::parse(uri).unwrap());
|
||||
router.catch(status, &request)
|
||||
r.catch(status, &request)
|
||||
}
|
||||
|
||||
macro_rules! assert_catcher_routing {
|
||||
|
@ -571,7 +602,7 @@ mod test {
|
|||
let requests = vec![$($r),+];
|
||||
let expected = vec![$(($ecode.into(), $euri)),+];
|
||||
|
||||
let router = router_with_catchers(&catchers);
|
||||
let router = router_with_catchers(&catchers).expect("valid router");
|
||||
for (req, expected) in requests.iter().zip(expected.iter()) {
|
||||
let req_status = Status::from_code(req.0).expect("valid status");
|
||||
let catcher = catcher(&router, req_status, req.1).expect("some catcher");
|
||||
|
|
|
@ -142,7 +142,10 @@ impl Trace for Route {
|
|||
event! { level, "route",
|
||||
name = self.name.as_ref().map(|n| &**n),
|
||||
rank = self.rank,
|
||||
method = %self.method,
|
||||
method = %Formatter(|f| match self.method {
|
||||
Some(method) => write!(f, "{}", method),
|
||||
None => write!(f, "[any]"),
|
||||
}),
|
||||
uri = %self.uri,
|
||||
uri.base = %self.uri.base(),
|
||||
uri.unmounted = %self.uri.unmounted(),
|
||||
|
|
|
@ -13,13 +13,13 @@ fn patch(form_data: Form<FormData>) -> &'static str {
|
|||
"PATCH OK"
|
||||
}
|
||||
|
||||
#[route(UPDATEREDIRECTREF, uri = "/", data = "<form_data>")]
|
||||
#[route("/", method = UPDATEREDIRECTREF, data = "<form_data>")]
|
||||
fn urr(form_data: Form<FormData>) -> &'static str {
|
||||
assert_eq!("Form data", form_data.into_inner().form_data);
|
||||
"UPDATEREDIRECTREF OK"
|
||||
}
|
||||
|
||||
#[route("VERSION-CONTROL", uri = "/", data = "<form_data>")]
|
||||
#[route("/", method = "VERSION-CONTROL", data = "<form_data>")]
|
||||
fn vc(form_data: Form<FormData>) -> &'static str {
|
||||
assert_eq!("Form data", form_data.into_inner().form_data);
|
||||
"VERSION-CONTROL OK"
|
||||
|
|
|
@ -37,21 +37,55 @@ these properties and more.
|
|||
|
||||
## Methods
|
||||
|
||||
A Rocket route attribute can be any one of `get`, `put`, `post`, `delete`,
|
||||
`head`, `patch`, or `options`, each corresponding to the HTTP method to match
|
||||
against. For example, the following attribute will match against `POST` requests
|
||||
to the root path:
|
||||
A Rocket route attribute can either be method-specific, any one of `get`, `put`,
|
||||
`post`, `delete`, `head`, `patch`, or `options`, or the generic [`route`], which
|
||||
allows explicitly specifying any valid HTTP [`Method`] or no method at all, to
|
||||
match again _any_ method. Consider the following examples:
|
||||
|
||||
* Match a `POST` request to `/`:
|
||||
|
||||
```rust
|
||||
# #[macro_use] extern crate rocket;
|
||||
# fn main() {}
|
||||
|
||||
# use rocket::post;
|
||||
#[post("/")]
|
||||
# fn handler() {}
|
||||
```
|
||||
|
||||
* Match a `PATCH` request to `/fix`:
|
||||
|
||||
```rust
|
||||
# use rocket::patch;
|
||||
#[patch("/fix")]
|
||||
# fn handler() {}
|
||||
```
|
||||
|
||||
* Match a `PROPFIND` request to `/collection`:
|
||||
|
||||
```rust
|
||||
# use rocket::route;
|
||||
#[route("/collection", method = PROPFIND)]
|
||||
# fn handler() {}
|
||||
```
|
||||
|
||||
* Match a `VERSION-CONTROL` request to `/collection`:
|
||||
|
||||
```rust
|
||||
# use rocket::route;
|
||||
#[route("/resource", method = "VERSION-CONTROL")]
|
||||
# fn handler() {}
|
||||
```
|
||||
|
||||
* Match a request to `/page` with _any_ method:
|
||||
|
||||
```rust
|
||||
# use rocket::route;
|
||||
#[route("/page")]
|
||||
# fn handler() {}
|
||||
```
|
||||
|
||||
The grammar for these attributes is defined formally in the [`route`] API docs.
|
||||
|
||||
[`Method`]: @api/master/rocket/http/enum.Method.html
|
||||
|
||||
### HEAD Requests
|
||||
|
||||
Rocket handles `HEAD` requests automatically when there exists a `GET` route
|
||||
|
|
|
@ -33,7 +33,7 @@ fn mir() -> &'static str {
|
|||
|
||||
// Try visiting:
|
||||
// http://127.0.0.1:8000/wave/Rocketeer/100
|
||||
#[get("/<name>/<age>")]
|
||||
#[get("/<name>/<age>", rank = 2)]
|
||||
fn wave(name: &str, age: u8) -> String {
|
||||
format!("👋 Hello, {} year old named {}!", age, name)
|
||||
}
|
||||
|
|
|
@ -2,9 +2,9 @@
|
|||
|
||||
use std::net::SocketAddr;
|
||||
|
||||
use rocket::http::Status;
|
||||
use rocket::http::uri::{Origin, Host};
|
||||
use rocket::tracing::{self, Instrument};
|
||||
use rocket::{route, Error, Request, Data, Route, Orbit, Rocket, Ignite};
|
||||
use rocket::{Rocket, Ignite, Orbit, State, Error};
|
||||
use rocket::fairing::{Fairing, Info, Kind};
|
||||
use rocket::response::Redirect;
|
||||
use rocket::listener::tcp::TcpListener;
|
||||
|
@ -19,43 +19,33 @@ pub struct Config {
|
|||
tls_addr: SocketAddr,
|
||||
}
|
||||
|
||||
#[route("/<_..>")]
|
||||
fn redirect(config: &State<Config>, uri: &Origin<'_>, host: &Host<'_>) -> Redirect {
|
||||
// FIXME: Check the host against a whitelist!
|
||||
let domain = host.domain();
|
||||
let https_uri = match config.tls_addr.port() {
|
||||
443 => format!("https://{domain}{uri}"),
|
||||
port => format!("https://{domain}:{port}{uri}"),
|
||||
};
|
||||
|
||||
Redirect::permanent(https_uri)
|
||||
}
|
||||
|
||||
impl Redirector {
|
||||
pub fn on(port: u16) -> Self {
|
||||
Redirector(port)
|
||||
}
|
||||
|
||||
// Route function that gets called on every single request.
|
||||
fn redirect<'r>(req: &'r Request, _: Data<'r>) -> route::BoxFuture<'r> {
|
||||
// FIXME: Check the host against a whitelist!
|
||||
let config = req.rocket().state::<Config>().expect("managed Self");
|
||||
if let Some(host) = req.host() {
|
||||
let domain = host.domain();
|
||||
let https_uri = match config.tls_addr.port() {
|
||||
443 => format!("https://{domain}{}", req.uri()),
|
||||
port => format!("https://{domain}:{port}{}", req.uri()),
|
||||
};
|
||||
|
||||
route::Outcome::from(req, Redirect::permanent(https_uri)).pin()
|
||||
} else {
|
||||
route::Outcome::from(req, Status::BadRequest).pin()
|
||||
}
|
||||
}
|
||||
|
||||
// Launch an instance of Rocket than handles redirection on `self.port`.
|
||||
pub async fn try_launch(self, config: Config) -> Result<Rocket<Ignite>, Error> {
|
||||
use rocket::http::Method::*;
|
||||
|
||||
// Build a vector of routes to `redirect` on `<path..>` for each method.
|
||||
let redirects = [Get, Put, Post, Delete, Options, Head, Trace, Connect, Patch]
|
||||
.into_iter()
|
||||
.map(|m| Route::new(m, "/<path..>", Self::redirect))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
rocket::span_info!("HTTP -> HTTPS Redirector" => {
|
||||
info!(from = self.0, to = config.tls_addr.port(), "redirecting");
|
||||
});
|
||||
|
||||
let addr = SocketAddr::new(config.tls_addr.ip(), self.0);
|
||||
rocket::custom(&config.server)
|
||||
.manage(config)
|
||||
.mount("/", redirects)
|
||||
.mount("/", routes![redirect])
|
||||
.try_launch_on(TcpListener::bind(addr))
|
||||
.await
|
||||
}
|
||||
|
|
|
@ -4,7 +4,7 @@ use crate::prelude::*;
|
|||
|
||||
use rocket::http::Method;
|
||||
|
||||
#[route(PROPFIND, uri = "/")]
|
||||
#[route("/", method = PROPFIND)]
|
||||
fn route() -> &'static str {
|
||||
"Hello, World!"
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue