From b34085392d8d2fd44e863221a3aaab62b4cd55b3 Mon Sep 17 00:00:00 2001 From: Sergio Benitez Date: Wed, 24 Apr 2024 18:31:39 -0700 Subject: [PATCH] Add 'Method' variants for all registered methods. This commit allow routes to be declared for methods outside of the standard HTTP method set. Specifically, it enables declaring routes for any method in the IANA Method Registry: ```rust #[route(LINK, uri = "/")] fn link() { ... } #[route("VERSION-CONTROL", uri = "/")] fn version_control() { ... } ``` The `Method` type has gained variants for each registered method. Breaking changes: - `Method::from_str()` no longer parses mixed-case method names. - `Method` is marked as non-exhaustive. - `Method::supports_payload()` removed in favor of `Method::allows_request_body()`. Resolves #232. --- benchmarks/src/routing.rs | 2 +- core/codegen/src/attribute/route/parse.rs | 19 +- core/codegen/src/http_codegen.rs | 49 +-- core/http/src/method.rs | 441 ++++++++++++++++------ core/lib/src/lifecycle.rs | 7 +- core/lib/src/request/atomic_method.rs | 23 +- core/lib/src/request/request.rs | 20 +- core/lib/src/router/collider.rs | 27 +- core/lib/src/router/matcher.rs | 11 +- core/lib/tests/form_method-issue-45.rs | 45 ++- core/lib/tests/http_serde.rs | 2 +- 11 files changed, 445 insertions(+), 201 deletions(-) diff --git a/benchmarks/src/routing.rs b/benchmarks/src/routing.rs index 417615cd..48e7b165 100644 --- a/benchmarks/src/routing.rs +++ b/benchmarks/src/routing.rs @@ -62,7 +62,7 @@ fn generate_matching_requests<'c>(client: &'c Client, routes: &[Route]) -> Vec { + data.full_span.warning("`data` used with non-payload-supporting method") + .note(format!("'{}' does not typically support payloads", attr.method.0)) + .note(lint.how_to_suppress()) + .emit_as_item_tokens(); + } + Some(false) => { + diags.push(data.full_span + .error("`data` cannot be used on this route") + .span_note(attr.method.span, "method does not support request payloads")) + } + _ => { /* okay */ }, } } diff --git a/core/codegen/src/http_codegen.rs b/core/codegen/src/http_codegen.rs index 896f1d6d..e234705a 100644 --- a/core/codegen/src/http_codegen.rs +++ b/core/codegen/src/http_codegen.rs @@ -1,5 +1,5 @@ use quote::ToTokens; -use devise::{FromMeta, MetaItem, Result, ext::{Split2, PathExt, SpanDiagnosticExt}}; +use devise::{FromMeta, MetaItem, Result, ext::{Split2, SpanDiagnosticExt}}; use proc_macro2::{TokenStream, Span}; use crate::{http, attribute::suppress::Lint}; @@ -97,47 +97,34 @@ impl ToTokens for MediaType { } } -const VALID_METHODS_STR: &str = "`GET`, `PUT`, `POST`, `DELETE`, `HEAD`, \ - `PATCH`, `OPTIONS`"; - -const VALID_METHODS: &[http::Method] = &[ - http::Method::Get, http::Method::Put, http::Method::Post, - http::Method::Delete, http::Method::Head, http::Method::Patch, - http::Method::Options, -]; - impl FromMeta for Method { fn from_meta(meta: &MetaItem) -> Result { let span = meta.value_span(); - let help_text = format!("method must be one of: {VALID_METHODS_STR}"); + let help = format!("known methods: {}", http::Method::ALL.join(", ")); - if let MetaItem::Path(path) = meta { - if let Some(ident) = path.last_ident() { - let method = ident.to_string().parse() - .map_err(|_| span.error("invalid HTTP method").help(&*help_text))?; + let string = meta.path().ok() + .and_then(|p| p.get_ident().cloned()) + .map(|ident| (ident.span(), ident.to_string())) + .or_else(|| match meta.lit() { + Ok(syn::Lit::Str(s)) => Some((s.span(), s.value())), + _ => None + }); - if !VALID_METHODS.contains(&method) { - return Err(span.error("invalid HTTP method for route handlers") - .help(&*help_text)); - } - - return Ok(Method(method)); - } + if let Some((span, string)) = string { + string.to_ascii_uppercase() + .parse() + .map(Method) + .map_err(|_| span.error("invalid or unknown HTTP method").help(help)) + } else { + let err = format!("expected method ident or string, found {}", meta.description()); + Err(span.error(err).help(help)) } - - Err(span.error(format!("expected identifier, found {}", meta.description())) - .help(&*help_text)) } } impl ToTokens for Method { fn to_tokens(&self, tokens: &mut TokenStream) { - let mut chars = self.0.as_str().chars(); - let variant_str = chars.next() - .map(|c| c.to_ascii_uppercase().to_string() + &chars.as_str().to_lowercase()) - .unwrap_or_default(); - - let variant = syn::Ident::new(&variant_str, Span::call_site()); + let variant = syn::Ident::new(self.0.variant_str(), Span::call_site()); tokens.extend(quote!(::rocket::http::Method::#variant)); } } diff --git a/core/http/src/method.rs b/core/http/src/method.rs index e47461ac..0c455020 100644 --- a/core/http/src/method.rs +++ b/core/http/src/method.rs @@ -1,138 +1,361 @@ use std::fmt; use std::str::FromStr; -use self::Method::*; +self::define_methods! { + // enum variant method name body safe idempotent [RFC,section] + Get "GET" maybe yes yes [9110,9.3.1] + Head "HEAD" maybe yes yes [9110,9.3.2] + Post "POST" yes no no [9110,9.3.3] + Put "PUT" yes no yes [9110,9.3.4] + Delete "DELETE" maybe no yes [9110,9.3.5] + Connect "CONNECT" maybe no no [9110,9.3.6] + Options "OPTIONS" maybe yes yes [9110,9.3.7] + Trace "TRACE" no yes yes [9110,9.3.8] + Patch "PATCH" yes no no [5789,2] -// TODO: Support non-standard methods, here and in codegen? + Acl "ACL" yes no yes [3744,8.1] + BaselineControl "BASELINE-CONTROL" yes no yes [3253,12.6] + Bind "BIND" yes no yes [5842,4] + CheckIn "CHECKIN" yes no yes [3253,4.4] + CheckOut "CHECKOUT" maybe no yes [3253,4.3] + Copy "COPY" maybe no yes [4918,9.8] + Label "LABEL" yes no yes [3253,8.2] + Link "LINK" maybe no yes [2068,19.6.1.2] + Lock "LOCK" yes no no [4918,9.10] + Merge "MERGE" yes no yes [3253,11.2] + MkActivity "MKACTIVITY" yes no yes [3253,13.5] + MkCalendar "MKCALENDAR" yes no yes [4791,5.3.1][8144,2.3] + MkCol "MKCOL" yes no yes [4918,9.3][5689,3][8144,2.3] + MkRedirectRef "MKREDIRECTREF" yes no yes [4437,6] + MkWorkspace "MKWORKSPACE" yes no yes [3253,6.3] + Move "MOVE" maybe no yes [4918,9.9] + OrderPatch "ORDERPATCH" yes no yes [3648,7] + PropFind "PROPFIND" yes yes yes [4918,9.1][8144,2.1] + PropPatch "PROPPATCH" yes no yes [4918,9.2][8144,2.2] + Rebind "REBIND" yes no yes [5842,6] + Report "REPORT" yes yes yes [3253,3.6][8144,2.1] + Search "SEARCH" yes yes yes [5323,2] + Unbind "UNBIND" yes no yes [5842,5] + Uncheckout "UNCHECKOUT" maybe no yes [3253,4.5] + Unlink "UNLINK" maybe no yes [2068,19.6.1.3] + Unlock "UNLOCK" maybe no yes [4918,9.11] + Update "UPDATE" yes no yes [3253,7.1] + UpdateRedirectRef "UPDATEREDIRECTREF" yes no yes [4437,7] + VersionControl "VERSION-CONTROL" yes no yes [3253,3.5] +} -/// Representation of HTTP methods. -/// -/// # (De)serialization -/// -/// `Method` is both `Serialize` and `Deserialize`, represented as an -/// [uncased](crate::uncased) string. For example, [`Method::Get`] serializes to -/// `"GET"` and deserializes from any casing of `"GET"` including `"get"`, -/// `"GeT"`, and `"GET"`. -/// -/// ```rust -/// # #[cfg(feature = "serde")] mod serde { -/// # use serde_ as serde; -/// use serde::{Serialize, Deserialize}; -/// use rocket::http::Method; -/// -/// #[derive(Deserialize, Serialize)] -/// # #[serde(crate = "serde_")] -/// struct Foo { -/// method: Method, -/// } -/// # } -/// ``` -#[repr(u8)] -#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)] -pub enum Method { - /// The `GET` variant. - Get, - /// The `PUT` variant. - Put, - /// The `POST` variant. - Post, - /// The `DELETE` variant. - Delete, - /// The `OPTIONS` variant. - Options, - /// The `HEAD` variant. - Head, - /// The `TRACE` variant. - Trace, - /// The `CONNECT` variant. - Connect, - /// The `PATCH` variant. - Patch +#[doc(hidden)] +#[macro_export] +macro_rules! define_methods { + ($($V:ident $name:tt $body:ident $safe:ident $idem:ident $([$n:expr,$s:expr])+)*) => { + /// An HTTP method. + /// + /// Each variant corresponds to a method in the [HTTP Method Registry]. + /// The string form of the method can be obtained via + /// [`Method::as_str()`] and parsed via the `FromStr` or + /// `TryFrom<&[u8]>` implementations. The parse implementations parse + /// both the case-sensitive string form as well as a lowercase version + /// of the string, but _not_ mixed-case versions. + /// + /// [HTTP Method Registry]: https://www.iana.org/assignments/http-methods/http-methods.xhtml + /// + /// # (De)Serialization + /// + /// `Method` is both `Serialize` and `Deserialize`. + /// + /// - `Method` _serializes_ as the specification-defined string form + /// of the method, equivalent to the value returned from + /// [`Method::as_str()`]. + /// - `Method` _deserializes_ from method's string form _or_ from a + /// lowercased string, equivalent to the `FromStr` implementation. + /// + /// For example, [`Method::Get`] serializes to `"GET"` and deserializes + /// from either `"GET"` or `"get"` but not `"GeT"`. + /// + /// ```rust + /// # #[cfg(feature = "serde")] mod serde { + /// # use serde_ as serde; + /// use serde::{Serialize, Deserialize}; + /// use rocket::http::Method; + /// + /// #[derive(Deserialize, Serialize)] + /// # #[serde(crate = "serde_")] + /// struct Foo { + /// method: Method, + /// } + /// # } + /// ``` + #[non_exhaustive] + #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] + pub enum Method {$( + #[doc = concat!("The `", $name, "` method.")] + #[doc = concat!("Defined in" $(, + " [RFC", stringify!($n), " §", stringify!($s), "]", + "(https://www.rfc-editor.org/rfc/rfc", stringify!($n), ".html", + "#section-", stringify!($s), ")", + )","+ ".")] + /// + #[doc = concat!("* safe: `", stringify!($safe), "`")] + #[doc = concat!("* idempotent: `", stringify!($idem), "`")] + #[doc = concat!("* request body: `", stringify!($body), "`")] + $V + ),*} + + macro_rules! lowercase { + ($str:literal) => {{ + const BYTES: [u8; $str.len()] = { + let mut i = 0; + let _: &str = $str; + let mut result = [0; $str.len()]; + while i < $str.len() { + result[i] = $str.as_bytes()[i].to_ascii_lowercase(); + i += 1; + } + + result + }; + + unsafe { std::str::from_utf8_unchecked(&BYTES) } + }}; + } + + #[allow(non_upper_case_globals)] + impl Method { + /// A slice containing every defined method string. + #[doc(hidden)] + pub const ALL: &'static [&'static str] = &[$($name),*]; + + /// Whether the method is considered "safe". + /// + /// From [RFC9110 §9.2.1](https://www.rfc-editor.org/rfc/rfc9110#section-9.2.1): + /// + /// > Request methods are considered "safe" if their defined + /// semantics are essentially read-only; i.e., the client does not + /// request, and does not expect, any state change on the origin server + /// as a result of applying a safe method to a target resource. + /// Likewise, reasonable use of a safe method is not expected to cause + /// any harm, loss of property, or unusual burden on the origin server. + /// Of the request methods defined by this specification, the GET, + /// HEAD, OPTIONS, and TRACE methods are defined to be safe. + /// + /// # Example + /// + /// ```rust + /// use rocket::http::Method; + /// + /// assert!(Method::Get.is_safe()); + /// assert!(Method::Head.is_safe()); + /// + /// assert!(!Method::Put.is_safe()); + /// assert!(!Method::Post.is_safe()); + /// ``` + pub const fn is_safe(&self) -> bool { + const yes: bool = true; + const no: bool = false; + + match self { + $(Self::$V => $safe),* + } + } + + /// Whether the method is considered "idempotent". + /// + /// From [RFC9110 §9.2.2](https://www.rfc-editor.org/rfc/rfc9110#section-9.2.2): + /// + /// > A request method is considered "idempotent" if the intended + /// effect on the server of multiple identical requests with that method + /// is the same as the effect for a single such request. Of the request + /// methods defined by this specification, PUT, DELETE, and safe request + /// methods are idempotent. + /// + /// # Example + /// + /// ```rust + /// use rocket::http::Method; + /// + /// assert!(Method::Get.is_idempotent()); + /// assert!(Method::Head.is_idempotent()); + /// assert!(Method::Put.is_idempotent()); + /// + /// assert!(!Method::Post.is_idempotent()); + /// assert!(!Method::Patch.is_idempotent()); + /// ``` + pub const fn is_idempotent(&self) -> bool { + const yes: bool = true; + const no: bool = false; + + match self { + $(Self::$V => $idem),* + } + } + + /// Whether requests with this method are allowed to have a body. + /// + /// Returns: + /// * `Some(true)` if a request body is _always_ allowed. + /// * `Some(false)` if a request body is **never** allowed. + /// * `None` if a request body is discouraged or has no defined semantics. + /// + /// # Example + /// + /// ```rust + /// use rocket::http::Method; + /// + /// assert_eq!(Method::Post.allows_request_body(), Some(true)); + /// assert_eq!(Method::Put.allows_request_body(), Some(true)); + /// + /// assert_eq!(Method::Trace.allows_request_body(), Some(false)); + /// + /// assert_eq!(Method::Get.allows_request_body(), None); + /// assert_eq!(Method::Head.allows_request_body(), None); + /// ``` + pub const fn allows_request_body(self) -> Option { + const yes: Option = Some(true); + const no: Option = Some(false); + const maybe: Option = None; + + match self { + $(Self::$V => $body),* + } + } + + /// Returns the method's string representation. + /// + /// # Example + /// + /// ```rust + /// use rocket::http::Method; + /// + /// assert_eq!(Method::Get.as_str(), "GET"); + /// assert_eq!(Method::Put.as_str(), "PUT"); + /// assert_eq!(Method::BaselineControl.as_str(), "BASELINE-CONTROL"); + /// ``` + pub const fn as_str(self) -> &'static str { + match self { + $(Self::$V => $name),* + } + } + + /// Returns a static reference to the method. + /// + /// # Example + /// + /// ```rust + /// use rocket::http::Method; + /// + /// assert_eq!(Method::Get.as_ref(), &Method::Get); + /// ``` + pub const fn as_ref(self) -> &'static Method { + match self { + $(Self::$V => &Self::$V),* + } + } + + #[doc(hidden)] + pub const fn variant_str(self) -> &'static str { + match self { + $(Self::$V => stringify!($V)),* + } + } + } + + #[cfg(test)] + mod tests { + use super::*; + + #[test] + #[allow(non_upper_case_globals)] + fn test_properties_and_parsing() { + const yes: bool = true; + const no: bool = false; + + $( + assert_eq!(Method::$V.is_idempotent(), $idem); + assert_eq!(Method::$V.is_safe(), $safe); + assert_eq!(Method::from_str($name).unwrap(), Method::$V); + assert_eq!(Method::from_str(lowercase!($name)).unwrap(), Method::$V); + assert_eq!(Method::$V.as_ref(), Method::$V); + )* + } + } + + impl TryFrom<&[u8]> for Method { + type Error = ParseMethodError; + + #[inline] + #[allow(non_upper_case_globals)] + fn try_from(value: &[u8]) -> Result { + mod upper { $(pub const $V: &[u8] = $name.as_bytes();)* } + mod lower { $(pub const $V: &[u8] = lowercase!($name).as_bytes();)* } + + match value { + $(upper::$V | lower::$V => Ok(Self::$V),)* + _ => Err(ParseMethodError) + } + } + } + }; } impl Method { - /// Returns `true` if an HTTP request with the method represented by `self` - /// always supports a payload. + /// Deprecated. Returns `self.allows_request_body() == Some(true)`. /// - /// The following methods always support payloads: - /// - /// * `PUT`, `POST`, `DELETE`, `PATCH` - /// - /// The following methods _do not_ always support payloads: - /// - /// * `GET`, `HEAD`, `CONNECT`, `TRACE`, `OPTIONS` - /// - /// # Example - /// - /// ```rust - /// # extern crate rocket; - /// use rocket::http::Method; - /// - /// assert_eq!(Method::Get.supports_payload(), false); - /// assert_eq!(Method::Post.supports_payload(), true); - /// ``` - #[inline] - pub fn supports_payload(self) -> bool { - match self { - Put | Post | Delete | Patch => true, - Get | Head | Connect | Trace | Options => false, + /// Use [`Method::allows_request_body()`] instead. + #[deprecated(since = "0.6", note = "use Self::allows_request_body()")] + pub const fn supports_payload(self) -> bool { + match self.allows_request_body() { + Some(v) => v, + None => false, } } +} - /// Returns the string representation of `self`. - /// - /// # Example - /// - /// ```rust - /// # extern crate rocket; - /// use rocket::http::Method; - /// - /// assert_eq!(Method::Get.as_str(), "GET"); - /// ``` - #[inline] - pub fn as_str(self) -> &'static str { - match self { - Get => "GET", - Put => "PUT", - Post => "POST", - Delete => "DELETE", - Options => "OPTIONS", - Head => "HEAD", - Trace => "TRACE", - Connect => "CONNECT", - Patch => "PATCH", - } +use define_methods as define_methods; + +#[derive(Debug, PartialEq, Eq)] +pub struct ParseMethodError; + +impl std::error::Error for ParseMethodError { } + +impl fmt::Display for ParseMethodError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str("invalid HTTP method") } } impl FromStr for Method { - type Err = (); + type Err = ParseMethodError; - // According to the RFC, method names are case-sensitive. But some old - // clients don't follow this, so we just do a case-insensitive match here. - fn from_str(s: &str) -> Result { - match s { - x if uncased::eq(x, Get.as_str()) => Ok(Get), - x if uncased::eq(x, Put.as_str()) => Ok(Put), - x if uncased::eq(x, Post.as_str()) => Ok(Post), - x if uncased::eq(x, Delete.as_str()) => Ok(Delete), - x if uncased::eq(x, Options.as_str()) => Ok(Options), - x if uncased::eq(x, Head.as_str()) => Ok(Head), - x if uncased::eq(x, Trace.as_str()) => Ok(Trace), - x if uncased::eq(x, Connect.as_str()) => Ok(Connect), - x if uncased::eq(x, Patch.as_str()) => Ok(Patch), - _ => Err(()), - } + #[inline(always)] + fn from_str(s: &str) -> Result { + Self::try_from(s.as_bytes()) + } +} + +impl AsRef for Method { + fn as_ref(&self) -> &str { + self.as_str() } } impl fmt::Display for Method { - #[inline(always)] fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { self.as_str().fmt(f) } } +impl PartialEq<&Method> for Method { + fn eq(&self, other: &&Method) -> bool { + self == *other + } +} + +impl PartialEq for &Method { + fn eq(&self, other: &Method) -> bool { + *self == other + } +} + #[cfg(feature = "serde")] mod serde { use super::*; diff --git a/core/lib/src/lifecycle.rs b/core/lib/src/lifecycle.rs index 6300f15e..3508e1bf 100644 --- a/core/lib/src/lifecycle.rs +++ b/core/lib/src/lifecycle.rs @@ -59,11 +59,8 @@ impl Rocket { ) -> RequestToken { // Check if this is a form and if the form contains the special _method // field which we use to reinterpret the request's method. - let (min_len, max_len) = ("_method=get".len(), "_method=delete".len()); - let peek_buffer = data.peek(max_len).await; - let is_form = req.content_type().map_or(false, |ct| ct.is_form()); - - if is_form && req.method() == Method::Post && peek_buffer.len() >= min_len { + if req.method() == Method::Post && req.content_type().map_or(false, |v| v.is_form()) { + let peek_buffer = data.peek(32).await; let method = std::str::from_utf8(peek_buffer).ok() .and_then(|raw_form| Form::values(raw_form).next()) .filter(|field| field.name == "_method") diff --git a/core/lib/src/request/atomic_method.rs b/core/lib/src/request/atomic_method.rs index 6d49f603..74bc8679 100644 --- a/core/lib/src/request/atomic_method.rs +++ b/core/lib/src/request/atomic_method.rs @@ -2,36 +2,25 @@ use crate::http::Method; pub struct AtomicMethod(ref_swap::RefSwap<'static, Method>); -#[inline(always)] -const fn makeref(method: Method) -> &'static Method { - match method { - Method::Get => &Method::Get, - Method::Put => &Method::Put, - Method::Post => &Method::Post, - Method::Delete => &Method::Delete, - Method::Options => &Method::Options, - Method::Head => &Method::Head, - Method::Trace => &Method::Trace, - Method::Connect => &Method::Connect, - Method::Patch => &Method::Patch, - } -} - impl AtomicMethod { + #[inline] pub fn new(value: Method) -> Self { - Self(ref_swap::RefSwap::new(makeref(value))) + Self(ref_swap::RefSwap::new(value.as_ref())) } + #[inline] pub fn load(&self) -> Method { *self.0.load(std::sync::atomic::Ordering::Acquire) } + #[inline] pub fn set(&mut self, new: Method) { *self = Self::new(new); } + #[inline] pub fn store(&self, new: Method) { - self.0.store(makeref(new), std::sync::atomic::Ordering::Release) + self.0.store(new.as_ref(), std::sync::atomic::Ordering::Release) } } diff --git a/core/lib/src/request/request.rs b/core/lib/src/request/request.rs index b0d52829..989d7106 100644 --- a/core/lib/src/request/request.rs +++ b/core/lib/src/request/request.rs @@ -656,10 +656,20 @@ impl<'r> Request<'r> { /// Returns the media type "format" of the request. /// - /// The "format" of a request is either the Content-Type, if the request - /// methods indicates support for a payload, or the preferred media type in - /// the Accept header otherwise. If the method indicates no payload and no - /// Accept header is specified, a media type of `Any` is returned. + /// The returned `MediaType` is derived from either the `Content-Type` or + /// the `Accept` header of the request, based on whether the request's + /// method allows a body (see [`Method::allows_request_body()`]). The table + /// below summarized this: + /// + /// | Method Allows Body | Returned Format | + /// |--------------------|---------------------------------| + /// | Always | `Option` | + /// | Maybe or Never | `Some(Preferred Accept or Any)` | + /// + /// In short, if the request's method indicates support for a payload, the + /// request's `Content-Type` header value, if any, is returned. Otherwise + /// the [preferred](Accept::preferred()) `Accept` header value is returned, + /// or if none is present, [`Accept::Any`]. /// /// The media type returned from this method is used to match against the /// `format` route attribute. @@ -691,7 +701,7 @@ impl<'r> Request<'r> { /// ``` pub fn format(&self) -> Option<&MediaType> { static ANY: MediaType = MediaType::Any; - if self.method().supports_payload() { + if self.method().allows_request_body().unwrap_or(false) { self.content_type().map(|ct| ct.media_type()) } else { // TODO: Should we be using `accept_first` or `preferred`? Or diff --git a/core/lib/src/router/collider.rs b/core/lib/src/router/collider.rs index 1c2aa3a1..f0978e4d 100644 --- a/core/lib/src/router/collider.rs +++ b/core/lib/src/router/collider.rs @@ -191,20 +191,19 @@ impl Collide for MediaType { } fn formats_collide(route: &Route, other: &Route) -> bool { - // If the routes' method doesn't support a payload, then format matching - // considers the `Accept` header. The client can always provide a media type - // that will cause a collision through non-specificity, i.e, `*/*`. - if !route.method.supports_payload() && !other.method.supports_payload() { - return true; - } - - // Payload supporting methods match against `Content-Type`. We only - // consider requests as having a `Content-Type` if they're fully - // specified. A route without a `format` accepts all `Content-Type`s. A - // request without a format only matches routes without a format. - match (route.format.as_ref(), other.format.as_ref()) { - (Some(a), Some(b)) => a.collides_with(b), - _ => true + match (route.method.allows_request_body(), other.method.allows_request_body()) { + // Payload supporting methods match against `Content-Type` which must be + // fully specified, so the request cannot contain a format that matches + // more than one route format as long as those formats don't collide. + (Some(true), Some(true)) => match (route.format.as_ref(), other.format.as_ref()) { + (Some(a), Some(b)) => a.collides_with(b), + // A route without a `format` accepts all `Content-Type`s. + _ => true + }, + // When a request method may not support a payload, the `Accept` header + // is considered during matching. The header can always be `*/*`, which + // would match any format. Thus two such routes would always collide. + _ => true, } } diff --git a/core/lib/src/router/matcher.rs b/core/lib/src/router/matcher.rs index b8468740..2d1ee14f 100644 --- a/core/lib/src/router/matcher.rs +++ b/core/lib/src/router/matcher.rs @@ -199,13 +199,12 @@ fn formats_match(route: &Route, req: &Request<'_>) -> bool { None => return true, }; - if route.method.supports_payload() { - match req.format() { + match route.method.allows_request_body() { + Some(true) => match req.format() { Some(f) if f.specificity() == 2 => route_format.collides_with(f), _ => false - } - } else { - match req.format() { + }, + _ => match req.format() { Some(f) => route_format.collides_with(f), None => true } @@ -287,7 +286,7 @@ mod tests { let client = Client::debug_with(vec![]).expect("client"); let mut req = client.req(m, "/"); if let Some(mt_str) = mt1.into() { - if m.supports_payload() { + if m.allows_request_body() == Some(true) { req.replace_header(mt_str.parse::().unwrap()); } else { req.replace_header(mt_str.parse::().unwrap()); diff --git a/core/lib/tests/form_method-issue-45.rs b/core/lib/tests/form_method-issue-45.rs index ac704619..009aa164 100644 --- a/core/lib/tests/form_method-issue-45.rs +++ b/core/lib/tests/form_method-issue-45.rs @@ -8,30 +8,63 @@ struct FormData { } #[patch("/", data = "")] -fn bug(form_data: Form) -> &'static str { +fn patch(form_data: Form) -> &'static str { assert_eq!("Form data", form_data.into_inner().form_data); - "OK" + "PATCH OK" +} + +#[route(UPDATEREDIRECTREF, uri = "/", data = "")] +fn urr(form_data: Form) -> &'static str { + assert_eq!("Form data", form_data.into_inner().form_data); + "UPDATEREDIRECTREF OK" +} + +#[route("VERSION-CONTROL", uri = "/", data = "")] +fn vc(form_data: Form) -> &'static str { + assert_eq!("Form data", form_data.into_inner().form_data); + "VERSION-CONTROL OK" } mod tests { use super::*; use rocket::local::blocking::Client; - use rocket::http::{Status, ContentType}; + use rocket::http::{Status, ContentType, Method}; #[test] fn method_eval() { - let client = Client::debug_with(routes![bug]).unwrap(); + let client = Client::debug_with(routes![patch, urr, vc]).unwrap(); let response = client.post("/") .header(ContentType::Form) .body("_method=patch&form_data=Form+data") .dispatch(); - assert_eq!(response.into_string(), Some("OK".into())); + assert_eq!(response.into_string(), Some("PATCH OK".into())); + + let response = client.post("/") + .header(ContentType::Form) + .body("_method=updateredirectref&form_data=Form+data") + .dispatch(); + + assert_eq!(response.into_string(), Some("UPDATEREDIRECTREF OK".into())); + + let response = client.req(Method::UpdateRedirectRef, "/") + .header(ContentType::Form) + .body("form_data=Form+data") + .dispatch(); + + assert_eq!(response.into_string(), Some("UPDATEREDIRECTREF OK".into())); + + let response = client.post("/") + .header(ContentType::Form) + .body("_method=version-control&form_data=Form+data") + .dispatch(); + + assert_eq!(response.into_string(), Some("VERSION-CONTROL OK".into())); } #[test] fn get_passes_through() { - let client = Client::debug_with(routes![bug]).unwrap(); + let client = Client::debug_with(routes![patch, urr, vc]).unwrap(); let response = client.get("/") .header(ContentType::Form) .body("_method=patch&form_data=Form+data") diff --git a/core/lib/tests/http_serde.rs b/core/lib/tests/http_serde.rs index 3c5388d5..c12173b1 100644 --- a/core/lib/tests/http_serde.rs +++ b/core/lib/tests/http_serde.rs @@ -133,7 +133,7 @@ fn method_serde() { jail.create_file("Rocket.toml", r#" [default] mget = "GET" - mput = "PuT" + mput = "PUT" mpost = "post" "#)?;