diff --git a/benchmarks/src/routing.rs b/benchmarks/src/routing.rs index c4fb8238..1a6dafcd 100644 --- a/benchmarks/src/routing.rs +++ b/benchmarks/src/routing.rs @@ -45,13 +45,13 @@ fn generate_matching_requests<'c>(client: &'c Client, routes: &[Route]) -> Vec(client: &'c Client, route: &Route) -> LocalRequest<'c> { - let path = route.uri.origin.path() + let path = route.uri.uri.path() .raw_segments() .map(staticify_segment) .collect::>() .join("/"); - let query = route.uri.origin.query() + let query = route.uri.uri.query() .map(|q| q.raw_segments()) .into_iter() .flatten() diff --git a/core/codegen/src/attribute/route/mod.rs b/core/codegen/src/attribute/route/mod.rs index 40797e5d..1b66e9ce 100644 --- a/core/codegen/src/attribute/route/mod.rs +++ b/core/codegen/src/attribute/route/mod.rs @@ -158,7 +158,7 @@ fn param_guard_decl(guard: &Guard) -> TokenStream { #_Err(__error) => return #parse_error, }, #_None => { - #_log::error_!("Internal invariant broken: dyn param not found."); + #_log::error_!("Internal invariant broken: dyn param {} not found.", #i); #_log::error_!("Please report this to the Rocket issue tracker."); #_log::error_!("https://github.com/SergioBenitez/Rocket/issues"); return #Outcome::Forward(#__data); diff --git a/core/codegen/src/attribute/route/parse.rs b/core/codegen/src/attribute/route/parse.rs index 59cbd4e3..88895161 100644 --- a/core/codegen/src/attribute/route/parse.rs +++ b/core/codegen/src/attribute/route/parse.rs @@ -77,9 +77,11 @@ impl FromMeta for RouteUri { .help("expected URI in origin form: \"/path/\"") })?; - if !origin.is_normalized() { - let normalized = origin.clone().into_normalized(); + if !origin.is_normalized_nontrailing() { + let normalized = origin.clone().into_normalized_nontrailing(); let span = origin.path().find("//") + .or_else(|| origin.has_trailing_slash() + .then_some(origin.path().len() - 1)) .or_else(|| origin.query() .and_then(|q| q.find("&&")) .map(|i| origin.path().len() + 1 + i)) diff --git a/core/codegen/src/bang/uri_parsing.rs b/core/codegen/src/bang/uri_parsing.rs index d7c772dc..e6ae1e23 100644 --- a/core/codegen/src/bang/uri_parsing.rs +++ b/core/codegen/src/bang/uri_parsing.rs @@ -309,7 +309,7 @@ impl Parse for InternalUriParams { // Validation should always succeed since this macro can only be called // if the route attribute succeeded, implying a valid route URI. let route_uri = Origin::parse_route(&route_uri_str) - .map(|o| o.into_normalized().into_owned()) + .map(|o| o.into_normalized_nontrailing().into_owned()) .map_err(|_| input.error("internal error: invalid route URI"))?; let content; diff --git a/core/codegen/tests/route.rs b/core/codegen/tests/route.rs index db74a45b..6fe269aa 100644 --- a/core/codegen/tests/route.rs +++ b/core/codegen/tests/route.rs @@ -334,8 +334,9 @@ fn test_inclusive_segments() { assert_eq!(get("/"), "empty+"); assert_eq!(get("//"), "empty+"); - assert_eq!(get("//a/"), "empty+a"); - assert_eq!(get("//a//"), "empty+a"); + assert_eq!(get("//a"), "empty+a"); + assert_eq!(get("//a/"), "empty+a/"); + assert_eq!(get("//a//"), "empty+a/"); assert_eq!(get("//a//c/d"), "empty+a/c/d"); assert_eq!(get("//a/b"), "nonempty+"); @@ -343,4 +344,5 @@ fn test_inclusive_segments() { assert_eq!(get("//a/b//c"), "nonempty+c"); assert_eq!(get("//a//b////c"), "nonempty+c"); assert_eq!(get("//a//b////c/d/e"), "nonempty+c/d/e"); + assert_eq!(get("//a//b////c/d/e/"), "nonempty+c/d/e/"); } diff --git a/core/codegen/tests/typed-uris.rs b/core/codegen/tests/typed-uris.rs index da9eb670..bef5eedf 100644 --- a/core/codegen/tests/typed-uris.rs +++ b/core/codegen/tests/typed-uris.rs @@ -14,7 +14,7 @@ macro_rules! assert_uri_eq { let actual = $uri; let expected = rocket::http::uri::Uri::parse_any($expected).expect("valid URI"); if actual != expected { - panic!("URI mismatch: got {}, expected {}\nGot) {:?}\nExpected) {:?}", + panic!("\nURI mismatch: got {}, expected {}\nGot) {:?}\nExpected) {:?}\n", actual, expected, actual, expected); } )+ @@ -186,6 +186,7 @@ fn check_simple_named() { fn check_route_prefix_suffix() { assert_uri_eq! { uri!(index) => "/", + uri!("/") => "/", uri!("/", index) => "/", uri!("/hi", index) => "/hi", uri!("/", simple3(10)) => "/?id=10", @@ -194,21 +195,33 @@ fn check_route_prefix_suffix() { uri!("/mount", simple(id = 23)) => "/mount/23", uri!("/another", simple(100)) => "/another/100", uri!("/another", simple(id = 23)) => "/another/23", + uri!("/foo") => "/foo", + uri!("/foo/") => "/foo/", + uri!("/foo///") => "/foo/", + uri!("/foo/bar/") => "/foo/bar/", + uri!("/foo/", index) => "/foo/", + uri!("/foo", index) => "/foo", } assert_uri_eq! { uri!("http://rocket.rs", index) => "http://rocket.rs", - uri!("http://rocket.rs/", index) => "http://rocket.rs", - uri!("http://rocket.rs", index) => "http://rocket.rs", + uri!("http://rocket.rs/", index) => "http://rocket.rs/", + uri!("http://rocket.rs/foo", index) => "http://rocket.rs/foo", + uri!("http://rocket.rs/foo/", index) => "http://rocket.rs/foo/", uri!("http://", index) => "http://", uri!("ftp:", index) => "ftp:/", } assert_uri_eq! { uri!("http://rocket.rs", index, "?foo") => "http://rocket.rs?foo", - uri!("http://rocket.rs/", index, "#bar") => "http://rocket.rs#bar", + uri!("http://rocket.rs", index, "?") => "http://rocket.rs?", + uri!("http://rocket.rs", index, "#") => "http://rocket.rs#", + uri!("http://rocket.rs/", index, "?") => "http://rocket.rs/?", + uri!("http://rocket.rs/", index, "#") => "http://rocket.rs/#", + uri!("http://rocket.rs", index, "#bar") => "http://rocket.rs#bar", + uri!("http://rocket.rs/", index, "#bar") => "http://rocket.rs/#bar", uri!("http://rocket.rs", index, "?bar#baz") => "http://rocket.rs?bar#baz", - uri!("http://rocket.rs/", index, "?bar#baz") => "http://rocket.rs?bar#baz", + uri!("http://rocket.rs/", index, "?bar#baz") => "http://rocket.rs/?bar#baz", uri!("http://", index, "?foo") => "http://?foo", uri!("http://rocket.rs", simple3(id = 100), "?foo") => "http://rocket.rs?id=100", uri!("http://rocket.rs", simple3(id = 100), "?foo#bar") => "http://rocket.rs?id=100#bar", @@ -239,8 +252,8 @@ fn check_route_prefix_suffix() { let dyn_abs = uri!("http://rocket.rs?foo"); assert_uri_eq! { uri!(_, index, dyn_abs.clone()) => "/?foo", - uri!("http://rocket.rs/", index, dyn_abs.clone()) => "http://rocket.rs?foo", uri!("http://rocket.rs", index, dyn_abs.clone()) => "http://rocket.rs?foo", + uri!("http://rocket.rs/", index, dyn_abs.clone()) => "http://rocket.rs/?foo", uri!("http://", index, dyn_abs.clone()) => "http://?foo", uri!(_, simple3(id = 123), dyn_abs) => "/?id=123", } @@ -248,8 +261,8 @@ fn check_route_prefix_suffix() { let dyn_ref = uri!("?foo#bar"); assert_uri_eq! { uri!(_, index, dyn_ref.clone()) => "/?foo#bar", - uri!("http://rocket.rs/", index, dyn_ref.clone()) => "http://rocket.rs?foo#bar", uri!("http://rocket.rs", index, dyn_ref.clone()) => "http://rocket.rs?foo#bar", + uri!("http://rocket.rs/", index, dyn_ref.clone()) => "http://rocket.rs/?foo#bar", uri!("http://", index, dyn_ref.clone()) => "http://?foo#bar", uri!(_, simple3(id = 123), dyn_ref) => "/?id=123#bar", } diff --git a/core/codegen/tests/ui-fail-nightly/async-entry.stderr b/core/codegen/tests/ui-fail-nightly/async-entry.stderr index dc7cf612..3ae1393b 100644 --- a/core/codegen/tests/ui-fail-nightly/async-entry.stderr +++ b/core/codegen/tests/ui-fail-nightly/async-entry.stderr @@ -141,8 +141,9 @@ error[E0308]: mismatched types --> tests/ui-fail-nightly/async-entry.rs:24:21 | 24 | async fn main() { - | ^ expected `()` because of default return type - | _____________________| + | ^ + | | + | _____________________expected `()` because of default return type | | 25 | | rocket::build() 26 | | } diff --git a/core/codegen/tests/ui-fail-nightly/route-path-bad-syntax.stderr b/core/codegen/tests/ui-fail-nightly/route-path-bad-syntax.stderr index c59c4ebc..f322538b 100644 --- a/core/codegen/tests/ui-fail-nightly/route-path-bad-syntax.stderr +++ b/core/codegen/tests/ui-fail-nightly/route-path-bad-syntax.stderr @@ -240,3 +240,27 @@ warning: `segment` starts with `<` but does not end with `>` | ^^^^^^^^ | = help: perhaps you meant the dynamic parameter ``? + +error: route URIs cannot contain empty segments + --> tests/ui-fail-nightly/route-path-bad-syntax.rs:107:10 + | +107 | #[get("/a/")] + | ^^ + | + = note: expected "/a", found "/a/" + +error: route URIs cannot contain empty segments + --> tests/ui-fail-nightly/route-path-bad-syntax.rs:110:12 + | +110 | #[get("/a/b/")] + | ^^ + | + = note: expected "/a/b", found "/a/b/" + +error: route URIs cannot contain empty segments + --> tests/ui-fail-nightly/route-path-bad-syntax.rs:113:14 + | +113 | #[get("/a/b/c/")] + | ^^ + | + = note: expected "/a/b/c", found "/a/b/c/" diff --git a/core/codegen/tests/ui-fail-stable/route-path-bad-syntax.stderr b/core/codegen/tests/ui-fail-stable/route-path-bad-syntax.stderr index c48ccadc..bc4b8ccc 100644 --- a/core/codegen/tests/ui-fail-stable/route-path-bad-syntax.stderr +++ b/core/codegen/tests/ui-fail-stable/route-path-bad-syntax.stderr @@ -180,3 +180,24 @@ error: parameters cannot be empty | 93 | #[get("/<>")] | ^^^^^ + +error: route URIs cannot contain empty segments + --- note: expected "/a", found "/a/" + --> tests/ui-fail-stable/route-path-bad-syntax.rs:107:7 + | +107 | #[get("/a/")] + | ^^^^^ + +error: route URIs cannot contain empty segments + --- note: expected "/a/b", found "/a/b/" + --> tests/ui-fail-stable/route-path-bad-syntax.rs:110:7 + | +110 | #[get("/a/b/")] + | ^^^^^^^ + +error: route URIs cannot contain empty segments + --- note: expected "/a/b/c", found "/a/b/c/" + --> tests/ui-fail-stable/route-path-bad-syntax.rs:113:7 + | +113 | #[get("/a/b/c/")] + | ^^^^^^^^^ diff --git a/core/codegen/tests/ui-fail/route-path-bad-syntax.rs b/core/codegen/tests/ui-fail/route-path-bad-syntax.rs index eb7b02e5..9eea869b 100644 --- a/core/codegen/tests/ui-fail/route-path-bad-syntax.rs +++ b/core/codegen/tests/ui-fail/route-path-bad-syntax.rs @@ -54,7 +54,7 @@ fn h3() {} #[get("/<_r>/")] fn h4() {} - +// // Check dynamic parameters are valid idents #[get("/")] @@ -102,4 +102,15 @@ fn m2() {} #[get("/<>name><")] fn m3() {} +// New additions for trailing paths, which we artificially disallow. + +#[get("/a/")] +fn n1() {} + +#[get("/a/b/")] +fn n2() {} + +#[get("/a/b/c/")] +fn n3() {} + fn main() { } diff --git a/core/http/src/ext.rs b/core/http/src/ext.rs index 1e263d4e..b1b40eaa 100644 --- a/core/http/src/ext.rs +++ b/core/http/src/ext.rs @@ -126,7 +126,6 @@ impl IntoOwned for (A, B) { } } - impl IntoOwned for Cow<'_, B> { type Owned = Cow<'static, B>; @@ -149,6 +148,7 @@ macro_rules! impl_into_owned_self { )*) } +impl_into_owned_self!(bool); impl_into_owned_self!(u8, u16, u32, u64, usize); impl_into_owned_self!(i8, i16, i32, i64, isize); diff --git a/core/http/src/raw_str.rs b/core/http/src/raw_str.rs index b22d6c88..1aa37b2b 100644 --- a/core/http/src/raw_str.rs +++ b/core/http/src/raw_str.rs @@ -180,6 +180,11 @@ impl RawStr { /// ``` #[inline(always)] pub fn percent_decode(&self) -> Result, Utf8Error> { + // don't let `percent-encoding` return a random empty string + if self.is_empty() { + return Ok(self.as_str().into()); + } + self._percent_decode().decode_utf8() } @@ -213,6 +218,11 @@ impl RawStr { /// ``` #[inline(always)] pub fn percent_decode_lossy(&self) -> Cow<'_, str> { + // don't let `percent-encoding` return a random empty string + if self.is_empty() { + return self.as_str().into(); + } + self._percent_decode().decode_utf8_lossy() } @@ -658,7 +668,6 @@ impl RawStr { pat.is_suffix_of(self.as_str()) } - /// Returns the byte index of the first character of this string slice that /// matches the pattern. /// @@ -710,8 +719,9 @@ impl RawStr { /// assert_eq!(v, ["Mary", "had", "a", "little", "lamb"]); /// ``` #[inline] - pub fn split<'a, P>(&'a self, pat: P) -> impl Iterator - where P: Pattern<'a> + pub fn split<'a, P>(&'a self, pat: P) -> impl DoubleEndedIterator + where P: Pattern<'a>, +

>::Searcher: stable_pattern::DoubleEndedSearcher<'a> { let split: Split<'_, P> = Split(SplitInternal { start: 0, @@ -837,6 +847,28 @@ impl RawStr { suffix.strip_suffix_of(self.as_str()).map(RawStr::new) } + /// Returns a string slice with leading and trailing whitespace removed. + /// + /// 'Whitespace' is defined according to the terms of the Unicode Derived + /// Core Property `White_Space`, which includes newlines. + /// + /// # Examples + /// + /// Basic usage: + /// + /// ``` + /// # extern crate rocket; + /// use rocket::http::RawStr; + /// + /// let s = RawStr::new("\n Hello\tworld\t\n"); + /// + /// assert_eq!("Hello\tworld", s.trim()); + /// ``` + #[inline] + pub fn trim(&self) -> &RawStr { + RawStr::new(self.as_str().trim_matches(|c: char| c.is_whitespace())) + } + /// Parses this string slice into another type. /// /// Because `parse` is so general, it can cause problems with type diff --git a/core/http/src/uri/absolute.rs b/core/http/src/uri/absolute.rs index 8d5dbaf5..e084e63c 100644 --- a/core/http/src/uri/absolute.rs +++ b/core/http/src/uri/absolute.rs @@ -24,9 +24,9 @@ use crate::uri::{Authority, Path, Query, Data, Error, as_utf8_unchecked, fmt}; /// Rocket prefers _normalized_ absolute URIs, an absolute URI with the /// following properties: /// -/// * The path and query, if any, are normalized with no empty segments. -/// * If there is an authority, the path is empty or absolute with more than -/// one character. +/// * If there is an authority, the path is empty or absolute. +/// * The path and query, if any, are normalized with no empty segments except +/// optionally for one trailing slash. /// /// The [`Absolute::is_normalized()`] method checks for normalization while /// [`Absolute::into_normalized()`] normalizes any absolute URI. @@ -38,8 +38,13 @@ use crate::uri::{Authority, Path, Query, Data, Error, as_utf8_unchecked, fmt}; /// # use rocket::http::uri::Absolute; /// # let valid_uris = [ /// "http://rocket.rs", +/// "http://rocket.rs/", +/// "ftp:/a/b/", +/// "ftp:/a/b/?", /// "scheme:/foo/bar", -/// "scheme:/foo/bar?abc", +/// "scheme:/foo/bar/", +/// "scheme:/foo/bar/?", +/// "scheme:/foo/bar/?abc", /// # ]; /// # for uri in &valid_uris { /// # let uri = Absolute::parse(uri).unwrap(); @@ -53,11 +58,9 @@ use crate::uri::{Authority, Path, Query, Data, Error, as_utf8_unchecked, fmt}; /// # extern crate rocket; /// # use rocket::http::uri::Absolute; /// # let invalid = [ -/// "http://rocket.rs/", // trailing '/' -/// "ftp:/a/b/", // trailing empty segment /// "ftp:/a//c//d", // two empty segments -/// "ftp:/a/b/?", // empty path segment /// "ftp:/?foo&", // trailing empty query segment +/// "ftp:/?fooa&&b", // empty query segment /// # ]; /// # for uri in &invalid { /// # assert!(!Absolute::parse(uri).unwrap().is_normalized()); @@ -263,17 +266,15 @@ impl<'a> Absolute<'a> { /// assert!(Absolute::parse("http://").unwrap().is_normalized()); /// assert!(Absolute::parse("http://foo.rs/foo/bar").unwrap().is_normalized()); /// assert!(Absolute::parse("foo:bar").unwrap().is_normalized()); + /// assert!(Absolute::parse("git://rocket.rs/").unwrap().is_normalized()); /// - /// assert!(!Absolute::parse("git://rocket.rs/").unwrap().is_normalized()); /// assert!(!Absolute::parse("http:/foo//bar").unwrap().is_normalized()); /// assert!(!Absolute::parse("foo:bar?baz&&bop").unwrap().is_normalized()); /// ``` pub fn is_normalized(&self) -> bool { let normalized_query = self.query().map_or(true, |q| q.is_normalized()); if self.authority().is_some() && !self.path().is_empty() { - self.path().is_normalized(true) - && self.path() != "/" - && normalized_query + self.path().is_normalized(true) && normalized_query } else { self.path().is_normalized(false) && normalized_query } @@ -287,9 +288,10 @@ impl<'a> Absolute<'a> { /// ```rust /// use rocket::http::uri::Absolute; /// + /// let mut uri = Absolute::parse("git://rocket.rs").unwrap(); + /// assert!(uri.is_normalized()); + /// /// let mut uri = Absolute::parse("git://rocket.rs/").unwrap(); - /// assert!(!uri.is_normalized()); - /// uri.normalize(); /// assert!(uri.is_normalized()); /// /// let mut uri = Absolute::parse("http:/foo//bar").unwrap(); @@ -304,18 +306,18 @@ impl<'a> Absolute<'a> { /// ``` pub fn normalize(&mut self) { if self.authority().is_some() && !self.path().is_empty() { - if self.path() == "/" { - self.set_path(""); - } else if !self.path().is_normalized(true) { - self.path = self.path().to_normalized(true); + if !self.path().is_normalized(true) { + self.path = self.path().to_normalized(true, true); } } else { - self.path = self.path().to_normalized(false); + if !self.path().is_normalized(false) { + self.path = self.path().to_normalized(false, true); + } } if let Some(query) = self.query() { if !query.is_normalized() { - self.query = query.to_normalized(); + self.query = Some(query.to_normalized()); } } } @@ -328,8 +330,7 @@ impl<'a> Absolute<'a> { /// use rocket::http::uri::Absolute; /// /// let mut uri = Absolute::parse("git://rocket.rs/").unwrap(); - /// assert!(!uri.is_normalized()); - /// assert!(uri.into_normalized().is_normalized()); + /// assert!(uri.is_normalized()); /// /// let mut uri = Absolute::parse("http:/foo//bar").unwrap(); /// assert!(!uri.is_normalized()); diff --git a/core/http/src/uri/origin.rs b/core/http/src/uri/origin.rs index fa2c214e..716dd163 100644 --- a/core/http/src/uri/origin.rs +++ b/core/http/src/uri/origin.rs @@ -27,8 +27,8 @@ use crate::{RawStr, RawStrBuf}; /// # Normalization /// /// Rocket prefers, and will sometimes require, origin URIs to be _normalized_. -/// A normalized origin URI is a valid origin URI that contains zero empty -/// segments except when there are no segments. +/// A normalized origin URI is a valid origin URI that contains no empty +/// segments except optionally a trailing slash. /// /// As an example, the following URIs are all valid, normalized URIs: /// @@ -37,9 +37,14 @@ use crate::{RawStr, RawStrBuf}; /// # use rocket::http::uri::Origin; /// # let valid_uris = [ /// "/", +/// "/?", +/// "/a/b/", /// "/a/b/c", +/// "/a/b/c/", +/// "/a/b/c?", /// "/a/b/c?q", /// "/hello?lang=en", +/// "/hello/?lang=en", /// "/some%20thing?q=foo&lang=fr", /// # ]; /// # for uri in &valid_uris { @@ -53,8 +58,7 @@ use crate::{RawStr, RawStrBuf}; /// # extern crate rocket; /// # use rocket::http::uri::Origin; /// # let invalid = [ -/// "//", // one empty segment -/// "/a/b/", // trailing empty segment +/// "//", // an empty segment /// "/a/ab//c//d", // two empty segments /// "/?a&&b", // empty query segment /// "/?foo&", // trailing empty query segment @@ -72,10 +76,10 @@ use crate::{RawStr, RawStrBuf}; /// # use rocket::http::uri::Origin; /// # let invalid = [ /// // non-normal versions -/// "//", "/a/b/", "/a/ab//c//d", "/a?a&&b&", +/// "//", "/a/b//c", "/a/ab//c//d/", "/a?a&&b&", /// /// // normalized versions -/// "/", "/a/b", "/a/ab/c/d", "/a?a&b", +/// "/", "/a/b/c", "/a/ab/c/d/", "/a?a&b", /// # ]; /// # for i in 0..(invalid.len() / 2) { /// # let abnormal = Origin::parse(invalid[i]).unwrap(); @@ -219,9 +223,11 @@ impl<'a> Origin<'a> { }); } - let (path, query) = RawStr::new(string).split_at_byte(b'?'); - let query = (!query.is_empty()).then(|| query.as_str()); - Ok(Origin::new(path.as_str(), query)) + let (path, query) = string.split_once('?') + .map(|(path, query)| (path, Some(query))) + .unwrap_or((string, None)); + + Ok(Origin::new(path, query)) } /// Parses the string `string` into an `Origin`. Never allocates on success. @@ -376,6 +382,18 @@ impl<'a> Origin<'a> { self.path().is_normalized(true) && self.query().map_or(true, |q| q.is_normalized()) } + fn _normalize(&mut self, allow_trail: bool) { + if !self.path().is_normalized(true) { + self.path = self.path().to_normalized(true, allow_trail); + } + + if let Some(query) = self.query() { + if !query.is_normalized() { + self.query = Some(query.to_normalized()); + } + } + } + /// Normalizes `self`. This is a no-op if `self` is already normalized. /// /// See [Normalization](#normalization) for more information on what it @@ -393,15 +411,7 @@ impl<'a> Origin<'a> { /// assert!(abnormal.is_normalized()); /// ``` pub fn normalize(&mut self) { - if !self.path().is_normalized(true) { - self.path = self.path().to_normalized(true); - } - - if let Some(query) = self.query() { - if !query.is_normalized() { - self.query = query.to_normalized(); - } - } + self._normalize(true); } /// Consumes `self` and returns a normalized version. @@ -424,6 +434,116 @@ impl<'a> Origin<'a> { self.normalize(); self } + + /// Returns `true` if `self` has a _trailing_ slash. + /// + /// This is defined as `path.len() > 1` && `path.ends_with('/')`. This + /// implies that the URI `/` is _not_ considered to have a trailing slash. + /// + /// # Example + /// + /// ```rust + /// # #[macro_use] extern crate rocket; + /// + /// assert!(!uri!("/").has_trailing_slash()); + /// assert!(!uri!("/a").has_trailing_slash()); + /// assert!(!uri!("/foo/bar/baz").has_trailing_slash()); + /// + /// assert!(uri!("/a/").has_trailing_slash()); + /// assert!(uri!("/foo/").has_trailing_slash()); + /// assert!(uri!("/foo/bar/baz/").has_trailing_slash()); + /// ``` + pub fn has_trailing_slash(&self) -> bool { + self.path().len() > 1 && self.path().ends_with('/') + } + + /// Returns `true` if `self` is normalized ([`Origin::is_normalized()`]) and + /// **does not** have a trailing slash ([Origin::has_trailing_slash()]). + /// Otherwise returns `false`. + /// + /// # Example + /// + /// ```rust + /// # #[macro_use] extern crate rocket; + /// use rocket::http::uri::Origin; + /// + /// let origin = Origin::parse("/").unwrap(); + /// assert!(origin.is_normalized_nontrailing()); + /// + /// let origin = Origin::parse("/foo/bar").unwrap(); + /// assert!(origin.is_normalized_nontrailing()); + /// + /// let origin = Origin::parse("//").unwrap(); + /// assert!(!origin.is_normalized_nontrailing()); + /// + /// let origin = Origin::parse("/foo/bar//baz/").unwrap(); + /// assert!(!origin.is_normalized_nontrailing()); + /// + /// let origin = Origin::parse("/foo/bar/").unwrap(); + /// assert!(!origin.is_normalized_nontrailing()); + /// ``` + pub fn is_normalized_nontrailing(&self) -> bool { + self.is_normalized() && !self.has_trailing_slash() + } + + /// Converts `self` into a normalized origin path without a trailing slash. + /// Does nothing is `self` is already [`normalized_nontrailing`]. + /// + /// [`normalized_nontrailing`]: Origin::is_normalized_nontrailing() + /// + /// # Example + /// + /// ```rust + /// # #[macro_use] extern crate rocket; + /// use rocket::http::uri::Origin; + /// + /// let origin = Origin::parse("/").unwrap(); + /// assert!(origin.is_normalized_nontrailing()); + /// + /// let normalized = origin.into_normalized_nontrailing(); + /// assert_eq!(normalized, uri!("/")); + /// + /// let origin = Origin::parse("//").unwrap(); + /// assert!(!origin.is_normalized_nontrailing()); + /// + /// let normalized = origin.into_normalized_nontrailing(); + /// assert_eq!(normalized, uri!("/")); + /// + /// let origin = Origin::parse_owned("/foo/bar//baz/".into()).unwrap(); + /// assert!(!origin.is_normalized_nontrailing()); + /// + /// let normalized = origin.into_normalized_nontrailing(); + /// assert_eq!(normalized, uri!("/foo/bar/baz")); + /// + /// let origin = Origin::parse("/foo/bar/").unwrap(); + /// assert!(!origin.is_normalized_nontrailing()); + /// + /// let normalized = origin.into_normalized_nontrailing(); + /// assert_eq!(normalized, uri!("/foo/bar")); + /// ``` + pub fn into_normalized_nontrailing(mut self) -> Self { + if !self.is_normalized_nontrailing() { + if self.is_normalized() && self.has_trailing_slash() { + let indexed = match self.path.value { + IndexedStr::Indexed(i, j) => IndexedStr::Indexed(i, j - 1), + IndexedStr::Concrete(cow) => IndexedStr::Concrete(match cow { + Cow::Borrowed(s) => Cow::Borrowed(&s[..s.len() - 1]), + Cow::Owned(mut s) => Cow::Owned({ s.pop(); s }), + }) + }; + + self.path = Data { + value: indexed, + decoded_segments: state::Storage::new(), + }; + } else { + self._normalize(false); + } + } + + self + + } } impl_serde!(Origin<'a>, "an origin-form URI"); @@ -448,7 +568,7 @@ mod tests { fn seg_count(path: &str, expected: usize) -> bool { let origin = Origin::parse(path).unwrap(); let segments = origin.path().segments(); - let actual = segments.len(); + let actual = segments.num(); if actual != expected { eprintln!("Count mismatch: expected {}, got {}.", expected, actual); eprintln!("{}", if actual != expected { "lifetime" } else { "buf" }); @@ -479,26 +599,24 @@ mod tests { #[test] fn simple_segment_count() { - assert!(seg_count("/", 0)); + assert!(seg_count("/", 1)); assert!(seg_count("/a", 1)); - assert!(seg_count("/a/", 1)); - assert!(seg_count("/a/", 1)); + assert!(seg_count("/a/", 2)); assert!(seg_count("/a/b", 2)); - assert!(seg_count("/a/b/", 2)); - assert!(seg_count("/a/b/", 2)); - assert!(seg_count("/ab/", 1)); + assert!(seg_count("/a/b/", 3)); + assert!(seg_count("/ab/", 2)); } #[test] fn segment_count() { - assert!(seg_count("////", 0)); - assert!(seg_count("//a//", 1)); - assert!(seg_count("//abc//", 1)); - assert!(seg_count("//abc/def/", 2)); - assert!(seg_count("//////abc///def//////////", 2)); + assert!(seg_count("////", 1)); + assert!(seg_count("//a//", 2)); + assert!(seg_count("//abc//", 2)); + assert!(seg_count("//abc/def/", 3)); + assert!(seg_count("//////abc///def//////////", 3)); assert!(seg_count("/a/b/c/d/e/f/g", 7)); assert!(seg_count("/a/b/c/d/e/f/g", 7)); - assert!(seg_count("/a/b/c/d/e/f/g/", 7)); + assert!(seg_count("/a/b/c/d/e/f/g/", 8)); assert!(seg_count("/a/b/cdjflk/d/e/f/g", 7)); assert!(seg_count("//aaflja/b/cdjflk/d/e/f/g", 7)); assert!(seg_count("/a/b", 2)); @@ -506,18 +624,18 @@ mod tests { #[test] fn single_segments_match() { - assert!(eq_segments("/", &[])); + assert!(eq_segments("/", &[""])); assert!(eq_segments("/a", &["a"])); - assert!(eq_segments("/a/", &["a"])); - assert!(eq_segments("///a/", &["a"])); - assert!(eq_segments("///a///////", &["a"])); - assert!(eq_segments("/a///////", &["a"])); + assert!(eq_segments("/a/", &["a", ""])); + assert!(eq_segments("///a/", &["a", ""])); + assert!(eq_segments("///a///////", &["a", ""])); + assert!(eq_segments("/a///////", &["a", ""])); assert!(eq_segments("//a", &["a"])); assert!(eq_segments("/abc", &["abc"])); - assert!(eq_segments("/abc/", &["abc"])); - assert!(eq_segments("///abc/", &["abc"])); - assert!(eq_segments("///abc///////", &["abc"])); - assert!(eq_segments("/abc///////", &["abc"])); + assert!(eq_segments("/abc/", &["abc", ""])); + assert!(eq_segments("///abc/", &["abc", ""])); + assert!(eq_segments("///abc///////", &["abc", ""])); + assert!(eq_segments("/abc///////", &["abc", ""])); assert!(eq_segments("//abc", &["abc"])); } @@ -529,10 +647,11 @@ mod tests { assert!(eq_segments("/a/b/c/d", &["a", "b", "c", "d"])); assert!(eq_segments("///a///////d////c", &["a", "d", "c"])); assert!(eq_segments("/abc/abc", &["abc", "abc"])); - assert!(eq_segments("/abc/abc/", &["abc", "abc"])); + assert!(eq_segments("/abc/abc/", &["abc", "abc", ""])); assert!(eq_segments("///abc///////a", &["abc", "a"])); assert!(eq_segments("/////abc/b", &["abc", "b"])); assert!(eq_segments("//abc//c////////d", &["abc", "c", "d"])); + assert!(eq_segments("//abc//c////////d/", &["abc", "c", "d", ""])); } #[test] @@ -548,6 +667,8 @@ mod tests { assert!(!eq_segments("/a/b", &["b", "a"])); assert!(!eq_segments("/a/a/b", &["a", "b"])); assert!(!eq_segments("///a/", &[])); + assert!(!eq_segments("///a/", &["a"])); + assert!(!eq_segments("///a/", &["a", "a"])); } fn test_query(uri: &str, query: Option<&str>) { @@ -575,20 +696,4 @@ mod tests { test_query("/?", Some("")); test_query("/?hi", Some("hi")); } - - #[test] - fn normalized() { - let uri_to_string = |s| Origin::parse(s) - .unwrap() - .into_normalized() - .to_string(); - - assert_eq!(uri_to_string("/"), "/".to_string()); - assert_eq!(uri_to_string("//"), "/".to_string()); - assert_eq!(uri_to_string("//////a/"), "/a".to_string()); - assert_eq!(uri_to_string("//ab"), "/ab".to_string()); - assert_eq!(uri_to_string("//a"), "/a".to_string()); - assert_eq!(uri_to_string("/a/b///c"), "/a/b/c".to_string()); - assert_eq!(uri_to_string("/a///b/c/d///"), "/a/b/c/d".to_string()); - } } diff --git a/core/http/src/uri/path_query.rs b/core/http/src/uri/path_query.rs index 54725709..2b9ea23d 100644 --- a/core/http/src/uri/path_query.rs +++ b/core/http/src/uri/path_query.rs @@ -1,6 +1,5 @@ use std::hash::Hash; use std::borrow::Cow; -use std::fmt::Write; use state::Storage; @@ -57,9 +56,9 @@ fn decode_to_indexed_str( match decoded { Cow::Borrowed(b) if indexed.is_indexed() => { - let indexed = IndexedStr::checked_from(b, source.as_str()); - debug_assert!(indexed.is_some()); - indexed.unwrap_or_else(|| IndexedStr::from(Cow::Borrowed(""))) + let checked = IndexedStr::checked_from(b, source.as_str()); + debug_assert!(checked.is_some(), "\nunindexed {:?} in {:?} {:?}", b, indexed, source); + checked.unwrap_or_else(|| IndexedStr::from(Cow::Borrowed(""))) } cow => IndexedStr::from(Cow::Owned(cow.into_owned())), } @@ -94,24 +93,37 @@ impl<'a> Path<'a> { self.raw().as_str() } - /// Whether `self` is normalized, i.e, it has no empty segments. + /// Whether `self` is normalized, i.e, it has no empty segments except the + /// last one. /// /// If `absolute`, then a starting `/` is required. pub(crate) fn is_normalized(&self, absolute: bool) -> bool { - (!absolute || self.raw().starts_with('/')) - && self.raw_segments().all(|s| !s.is_empty()) - } - - /// Normalizes `self`. If `absolute`, a starting `/` is required. - pub(crate) fn to_normalized(self, absolute: bool) -> Data<'static, fmt::Path> { - let mut path = String::with_capacity(self.raw().len()); - let absolute = absolute || self.raw().starts_with('/'); - for (i, seg) in self.raw_segments().filter(|s| !s.is_empty()).enumerate() { - if absolute || i != 0 { path.push('/'); } - let _ = write!(path, "{}", seg); + if absolute && !self.raw().starts_with('/') { + return false; } - if path.is_empty() && absolute { + self.raw_segments() + .rev() + .skip(1) + .all(|s| !s.is_empty()) + } + + /// Normalizes `self`. If `absolute`, a starting `/` is required. If + /// `trail`, a trailing slash is allowed. Otherwise it is not. + pub(crate) fn to_normalized(self, absolute: bool, trail: bool) -> Data<'static, fmt::Path> { + let raw = self.raw().trim(); + let mut path = String::with_capacity(raw.len()); + + if absolute || raw.starts_with('/') { + path.push('/'); + } + + for (i, segment) in self.raw_segments().filter(|s| !s.is_empty()).enumerate() { + if i != 0 { path.push('/'); } + path.push_str(segment.as_str()); + } + + if trail && raw.len() > 1 && raw.ends_with('/') && !path.ends_with('/') { path.push('/'); } @@ -121,8 +133,8 @@ impl<'a> Path<'a> { } } - /// Returns an iterator over the raw, undecoded segments. Segments may be - /// empty. + /// Returns an iterator over the raw, undecoded segments, potentially empty + /// segments. /// /// ### Example /// @@ -131,38 +143,41 @@ impl<'a> Path<'a> { /// use rocket::http::uri::Origin; /// /// let uri = Origin::parse("/").unwrap(); - /// assert_eq!(uri.path().raw_segments().count(), 0); + /// let segments: Vec<_> = uri.path().raw_segments().collect(); + /// assert_eq!(segments, &[""]); /// /// let uri = Origin::parse("//").unwrap(); /// let segments: Vec<_> = uri.path().raw_segments().collect(); /// assert_eq!(segments, &["", ""]); /// + /// let uri = Origin::parse("/foo").unwrap(); + /// let segments: Vec<_> = uri.path().raw_segments().collect(); + /// assert_eq!(segments, &["foo"]); + /// + /// let uri = Origin::parse("/a/").unwrap(); + /// let segments: Vec<_> = uri.path().raw_segments().collect(); + /// assert_eq!(segments, &["a", ""]); + /// /// // Recall that `uri!()` normalizes static inputs. /// let uri = uri!("//"); - /// assert_eq!(uri.path().raw_segments().count(), 0); - /// - /// let uri = Origin::parse("/a").unwrap(); /// let segments: Vec<_> = uri.path().raw_segments().collect(); - /// assert_eq!(segments, &["a"]); + /// assert_eq!(segments, &[""]); /// /// let uri = Origin::parse("/a//b///c/d?query¶m").unwrap(); /// let segments: Vec<_> = uri.path().raw_segments().collect(); /// assert_eq!(segments, &["a", "", "b", "", "", "c", "d"]); /// ``` - #[inline(always)] - pub fn raw_segments(&self) -> impl Iterator { - let path = match self.raw() { - p if p.is_empty() || p == "/" => None, - p if p.starts_with(fmt::Path::DELIMITER) => Some(&p[1..]), - p => Some(p) - }; - - path.map(|p| p.split(fmt::Path::DELIMITER)) - .into_iter() - .flatten() + #[inline] + pub fn raw_segments(&self) -> impl DoubleEndedIterator { + let raw = self.raw().trim(); + raw.strip_prefix(fmt::Path::DELIMITER) + .unwrap_or(raw) + .split(fmt::Path::DELIMITER) } - /// Returns a (smart) iterator over the non-empty, percent-decoded segments. + /// Returns a (smart) iterator over the percent-decoded segments. Empty + /// segments between non-empty segments are skipped. A trailing slash will + /// result in an empty segment emitted as the final item. /// /// # Example /// @@ -170,20 +185,52 @@ impl<'a> Path<'a> { /// # #[macro_use] extern crate rocket; /// use rocket::http::uri::Origin; /// + /// let uri = Origin::parse("/").unwrap(); + /// let path_segs: Vec<&str> = uri.path().segments().collect(); + /// assert_eq!(path_segs, &[""]); + /// + /// let uri = Origin::parse("/a").unwrap(); + /// let path_segs: Vec<&str> = uri.path().segments().collect(); + /// assert_eq!(path_segs, &["a"]); + /// + /// let uri = Origin::parse("/a/").unwrap(); + /// let path_segs: Vec<&str> = uri.path().segments().collect(); + /// assert_eq!(path_segs, &["a", ""]); + /// + /// let uri = Origin::parse("/foo/bar").unwrap(); + /// let path_segs: Vec<&str> = uri.path().segments().collect(); + /// assert_eq!(path_segs, &["foo", "bar"]); + /// + /// let uri = Origin::parse("/foo///bar").unwrap(); + /// let path_segs: Vec<&str> = uri.path().segments().collect(); + /// assert_eq!(path_segs, &["foo", "bar"]); + /// + /// let uri = Origin::parse("/foo///bar//").unwrap(); + /// let path_segs: Vec<&str> = uri.path().segments().collect(); + /// assert_eq!(path_segs, &["foo", "bar", ""]); + /// /// let uri = Origin::parse("/a%20b/b%2Fc/d//e?query=some").unwrap(); /// let path_segs: Vec<&str> = uri.path().segments().collect(); /// assert_eq!(path_segs, &["a b", "b/c", "d", "e"]); /// ``` pub fn segments(&self) -> Segments<'a, fmt::Path> { + let raw = self.raw(); let cached = self.data.decoded_segments.get_or_set(|| { - let (indexed, path) = (&self.data.value, self.raw()); - self.raw_segments() - .filter(|r| !r.is_empty()) - .map(|s| decode_to_indexed_str::(s, (indexed, path))) - .collect() + let mut segments = vec![]; + let mut raw_segments = self.raw_segments().peekable(); + while let Some(s) = raw_segments.next() { + // Only allow an empty segment if it's the last one. + if s.is_empty() && raw_segments.peek().is_some() { + continue; + } + + segments.push(decode_to_indexed_str::(s, (&self.data.value, raw))); + } + + segments }); - Segments::new(self.raw(), cached) + Segments::new(raw, cached) } } @@ -218,30 +265,26 @@ impl<'a> Query<'a> { /// Whether `self` is normalized, i.e, it has no empty segments. pub(crate) fn is_normalized(&self) -> bool { - !self.is_empty() && self.raw_segments().all(|s| !s.is_empty()) + self.raw_segments().all(|s| !s.is_empty()) } /// Normalizes `self`. - pub(crate) fn to_normalized(self) -> Option> { - let mut query = String::with_capacity(self.raw().len()); + pub(crate) fn to_normalized(self) -> Data<'static, fmt::Query> { + let mut query = String::with_capacity(self.raw().trim().len()); for (i, seg) in self.raw_segments().filter(|s| !s.is_empty()).enumerate() { if i != 0 { query.push('&'); } - let _ = write!(query, "{}", seg); + query.push_str(seg.as_str()); } - if query.is_empty() { - return None; - } - - Some(Data { + Data { value: IndexedStr::from(Cow::Owned(query)), decoded_segments: Storage::new(), - }) + } } - /// Returns an iterator over the non-empty, undecoded `(name, value)` pairs - /// of this query. If there is no query, the iterator is empty. Segments may - /// be empty. + /// Returns an iterator over the undecoded, potentially empty `(name, + /// value)` pairs of this query. If there is no query, the iterator is + /// empty. /// /// # Example /// @@ -252,18 +295,26 @@ impl<'a> Query<'a> { /// let uri = Origin::parse("/").unwrap(); /// assert!(uri.query().is_none()); /// + /// let uri = Origin::parse("/?").unwrap(); + /// let query_segs: Vec<_> = uri.query().unwrap().raw_segments().collect(); + /// assert!(query_segs.is_empty()); + /// + /// let uri = Origin::parse("/?foo").unwrap(); + /// let query_segs: Vec<_> = uri.query().unwrap().raw_segments().collect(); + /// assert_eq!(query_segs, &["foo"]); + /// /// let uri = Origin::parse("/?a=b&dog").unwrap(); /// let query_segs: Vec<_> = uri.query().unwrap().raw_segments().collect(); /// assert_eq!(query_segs, &["a=b", "dog"]); /// - /// // This is not normalized, so the query is `""`, the empty string. /// let uri = Origin::parse("/?&").unwrap(); /// let query_segs: Vec<_> = uri.query().unwrap().raw_segments().collect(); /// assert_eq!(query_segs, &["", ""]); /// - /// // Recall that `uri!()` normalizes. + /// // Recall that `uri!()` normalizes, so this is equivalent to `/?`. /// let uri = uri!("/?&"); - /// assert!(uri.query().is_none()); + /// let query_segs: Vec<_> = uri.query().unwrap().raw_segments().collect(); + /// assert!(query_segs.is_empty()); /// /// // These are raw and undecoded. Use `segments()` for decoded variant. /// let uri = Origin::parse("/foo/bar?a+b%2F=some+one%40gmail.com&&%26%3D2").unwrap(); @@ -272,7 +323,7 @@ impl<'a> Query<'a> { /// ``` #[inline] pub fn raw_segments(&self) -> impl Iterator { - let query = match self.raw() { + let query = match self.raw().trim() { q if q.is_empty() => None, q => Some(q) }; diff --git a/core/http/src/uri/reference.rs b/core/http/src/uri/reference.rs index 58eb06fd..c3b2fa2c 100644 --- a/core/http/src/uri/reference.rs +++ b/core/http/src/uri/reference.rs @@ -264,15 +264,17 @@ impl<'a> Reference<'a> { /// /// ```rust /// # #[macro_use] extern crate rocket; + /// let uri = uri!("http://rocket.rs/guide"); + /// assert!(uri.query().is_none()); + /// + /// let uri = uri!("http://rocket.rs/guide?"); + /// assert_eq!(uri.query().unwrap(), ""); + /// /// let uri = uri!("http://rocket.rs/guide?foo#bar"); /// assert_eq!(uri.query().unwrap(), "foo"); /// /// let uri = uri!("http://rocket.rs/guide?q=bar"); /// assert_eq!(uri.query().unwrap(), "q=bar"); - /// - /// // Empty query parts are normalized away by `uri!()`. - /// let uri = uri!("http://rocket.rs/guide?#bar"); - /// assert!(uri.query().is_none()); /// ``` #[inline(always)] pub fn query(&self) -> Option> { @@ -316,23 +318,23 @@ impl<'a> Reference<'a> { /// assert!(Reference::parse("http://foo.rs/foo/bar").unwrap().is_normalized()); /// assert!(Reference::parse("foo:bar#baz").unwrap().is_normalized()); /// assert!(Reference::parse("http://rocket.rs#foo").unwrap().is_normalized()); + /// assert!(Reference::parse("http://?").unwrap().is_normalized()); + /// assert!(Reference::parse("git://rocket.rs/").unwrap().is_normalized()); + /// assert!(Reference::parse("http://rocket.rs?#foo").unwrap().is_normalized()); + /// assert!(Reference::parse("http://rocket.rs#foo").unwrap().is_normalized()); /// - /// assert!(!Reference::parse("http://?").unwrap().is_normalized()); - /// assert!(!Reference::parse("git://rocket.rs/").unwrap().is_normalized()); /// assert!(!Reference::parse("http:/foo//bar").unwrap().is_normalized()); /// assert!(!Reference::parse("foo:bar?baz&&bop#c").unwrap().is_normalized()); - /// assert!(!Reference::parse("http://rocket.rs?#foo").unwrap().is_normalized()); /// /// // Recall that `uri!()` normalizes static input. - /// assert!(uri!("http://rocket.rs#foo").is_normalized()); + /// assert!(uri!("http:/foo//bar").is_normalized()); + /// assert!(uri!("foo:bar?baz&&bop#c").is_normalized()); /// assert!(uri!("http://rocket.rs///foo////bar#cat").is_normalized()); /// ``` pub fn is_normalized(&self) -> bool { let normalized_query = self.query().map_or(true, |q| q.is_normalized()); if self.authority().is_some() && !self.path().is_empty() { - self.path().is_normalized(true) - && self.path() != "/" - && normalized_query + self.path().is_normalized(true) && normalized_query } else { self.path().is_normalized(false) && normalized_query } @@ -347,8 +349,6 @@ impl<'a> Reference<'a> { /// use rocket::http::uri::Reference; /// /// let mut uri = Reference::parse("git://rocket.rs/").unwrap(); - /// assert!(!uri.is_normalized()); - /// uri.normalize(); /// assert!(uri.is_normalized()); /// /// let mut uri = Reference::parse("http:/foo//bar?baz&&#cat").unwrap(); @@ -363,18 +363,18 @@ impl<'a> Reference<'a> { /// ``` pub fn normalize(&mut self) { if self.authority().is_some() && !self.path().is_empty() { - if self.path() == "/" { - self.set_path(""); - } else if !self.path().is_normalized(true) { - self.path = self.path().to_normalized(true); + if !self.path().is_normalized(true) { + self.path = self.path().to_normalized(true, true); } } else { - self.path = self.path().to_normalized(false); + if !self.path().is_normalized(false) { + self.path = self.path().to_normalized(false, true); + } } if let Some(query) = self.query() { if !query.is_normalized() { - self.query = query.to_normalized(); + self.query = Some(query.to_normalized()); } } } @@ -387,7 +387,7 @@ impl<'a> Reference<'a> { /// use rocket::http::uri::Reference; /// /// let mut uri = Reference::parse("git://rocket.rs/").unwrap(); - /// assert!(!uri.is_normalized()); + /// assert!(uri.is_normalized()); /// assert!(uri.into_normalized().is_normalized()); /// /// let mut uri = Reference::parse("http:/foo//bar?baz&&#cat").unwrap(); @@ -403,6 +403,7 @@ impl<'a> Reference<'a> { self } + #[allow(unused)] pub(crate) fn set_path

(&mut self, path: P) where P: Into> { diff --git a/core/http/src/uri/segments.rs b/core/http/src/uri/segments.rs index f42596eb..9d62edc5 100644 --- a/core/http/src/uri/segments.rs +++ b/core/http/src/uri/segments.rs @@ -28,7 +28,7 @@ use crate::uri::error::PathError; /// _ => panic!("only four segments") /// } /// } -/// # assert_eq!(uri.path().segments().len(), 4); +/// # assert_eq!(uri.path().segments().num(), 4); /// # assert_eq!(uri.path().segments().count(), 4); /// # assert_eq!(uri.path().segments().next(), Some("a z")); /// ``` @@ -55,19 +55,19 @@ impl Segments<'_, P> { /// let uri = uri!("/foo/bar?baz&cat&car"); /// /// let mut segments = uri.path().segments(); - /// assert_eq!(segments.len(), 2); + /// assert_eq!(segments.num(), 2); /// /// segments.next(); - /// assert_eq!(segments.len(), 1); + /// assert_eq!(segments.num(), 1); /// /// segments.next(); - /// assert_eq!(segments.len(), 0); + /// assert_eq!(segments.num(), 0); /// /// segments.next(); - /// assert_eq!(segments.len(), 0); + /// assert_eq!(segments.num(), 0); /// ``` #[inline] - pub fn len(&self) -> usize { + pub fn num(&self) -> usize { let max_pos = std::cmp::min(self.pos, self.segments.len()); self.segments.len() - max_pos } @@ -89,7 +89,7 @@ impl Segments<'_, P> { /// ``` #[inline] pub fn is_empty(&self) -> bool { - self.len() == 0 + self.num() == 0 } /// Returns a new `Segments` with `n` segments skipped. @@ -101,11 +101,11 @@ impl Segments<'_, P> { /// let uri = uri!("/foo/bar/baz/cat"); /// /// let mut segments = uri.path().segments(); - /// assert_eq!(segments.len(), 4); + /// assert_eq!(segments.num(), 4); /// assert_eq!(segments.next(), Some("foo")); /// /// let mut segments = segments.skip(2); - /// assert_eq!(segments.len(), 1); + /// assert_eq!(segments.num(), 1); /// assert_eq!(segments.next(), Some("cat")); /// ``` #[inline] @@ -143,6 +143,21 @@ impl<'a> Segments<'a, Path> { /// /// ```rust /// # #[macro_use] extern crate rocket; + /// let a = uri!("/"); + /// let b = uri!("/"); + /// assert!(a.path().segments().prefix_of(b.path().segments())); + /// assert!(b.path().segments().prefix_of(a.path().segments())); + /// + /// let a = uri!("/"); + /// let b = uri!("/foo"); + /// assert!(a.path().segments().prefix_of(b.path().segments())); + /// assert!(!b.path().segments().prefix_of(a.path().segments())); + /// + /// let a = uri!("/foo"); + /// let b = uri!("/foo/"); + /// assert!(a.path().segments().prefix_of(b.path().segments())); + /// assert!(!b.path().segments().prefix_of(a.path().segments())); + /// /// let a = uri!("/foo/bar/baaaz/cat"); /// let b = uri!("/foo/bar"); /// @@ -155,11 +170,11 @@ impl<'a> Segments<'a, Path> { /// ``` #[inline] pub fn prefix_of(self, other: Segments<'_, Path>) -> bool { - if self.len() > other.len() { + if self.num() > other.num() { return false; } - self.zip(other).all(|(a, b)| a == b) + self.zip(other).all(|(a, b)| a.is_empty() || a == b) } /// Creates a `PathBuf` from `self`. The returned `PathBuf` is @@ -271,11 +286,11 @@ macro_rules! impl_iterator { } fn size_hint(&self) -> (usize, Option) { - (self.len(), Some(self.len())) + (self.num(), Some(self.num())) } fn count(self) -> usize { - self.len() + self.num() } } ) diff --git a/core/http/src/uri/uri.rs b/core/http/src/uri/uri.rs index 913bf08a..57e992ea 100644 --- a/core/http/src/uri/uri.rs +++ b/core/http/src/uri/uri.rs @@ -467,3 +467,35 @@ macro_rules! impl_base_traits { } } } + +mod tests { + #[test] + fn normalization() { + fn normalize(uri: &str) -> String { + use crate::uri::Uri; + + match Uri::parse_any(uri).unwrap() { + Uri::Origin(uri) => uri.into_normalized().to_string(), + Uri::Absolute(uri) => uri.into_normalized().to_string(), + Uri::Reference(uri) => uri.into_normalized().to_string(), + uri => uri.to_string(), + } + } + + assert_eq!(normalize("/#"), "/#"); + + assert_eq!(normalize("/"), "/"); + assert_eq!(normalize("//"), "/"); + assert_eq!(normalize("//////a/"), "/a/"); + assert_eq!(normalize("//ab"), "/ab"); + assert_eq!(normalize("//a"), "/a"); + assert_eq!(normalize("/a/b///c"), "/a/b/c"); + assert_eq!(normalize("/a/b///c/"), "/a/b/c/"); + assert_eq!(normalize("/a///b/c/d///"), "/a/b/c/d/"); + + assert_eq!(normalize("/?"), "/?"); + assert_eq!(normalize("/?foo"), "/?foo"); + assert_eq!(normalize("/a/?"), "/a/?"); + assert_eq!(normalize("/a/?foo"), "/a/?foo"); + } +} diff --git a/core/lib/fuzz/Cargo.toml b/core/lib/fuzz/Cargo.toml index 5e12bea4..14a00c61 100644 --- a/core/lib/fuzz/Cargo.toml +++ b/core/lib/fuzz/Cargo.toml @@ -1,4 +1,3 @@ - [package] name = "rocket-fuzz" version = "0.0.0" @@ -30,3 +29,9 @@ name = "uri-roundtrip" path = "targets/uri-roundtrip.rs" test = false doc = false + +[[bin]] +name = "uri-normalization" +path = "targets/uri-normalization.rs" +test = false +doc = false diff --git a/core/lib/fuzz/targets/uri-normalization.rs b/core/lib/fuzz/targets/uri-normalization.rs new file mode 100644 index 00000000..f228a2a0 --- /dev/null +++ b/core/lib/fuzz/targets/uri-normalization.rs @@ -0,0 +1,23 @@ +#![no_main] + +use rocket::http::uri::*; +use libfuzzer_sys::fuzz_target; + +fn fuzz(data: &str) { + if let Ok(uri) = Uri::parse_any(data) { + match uri { + Uri::Origin(uri) if uri.is_normalized() => { + assert_eq!(uri.clone(), uri.into_normalized()); + } + Uri::Absolute(uri) if uri.is_normalized() => { + assert_eq!(uri.clone(), uri.into_normalized()); + } + Uri::Reference(uri) if uri.is_normalized() => { + assert_eq!(uri.clone(), uri.into_normalized()); + } + _ => { /* not normalizable */ }, + } + } +} + +fuzz_target!(|data: &str| { fuzz(data) }); diff --git a/core/lib/src/catcher/catcher.rs b/core/lib/src/catcher/catcher.rs index 95c86b63..a477fac4 100644 --- a/core/lib/src/catcher/catcher.rs +++ b/core/lib/src/catcher/catcher.rs @@ -107,14 +107,25 @@ pub struct Catcher { /// The name of this catcher, if one was given. pub name: Option>, - /// The mount point. - pub base: uri::Origin<'static>, - /// The HTTP status to match against if this route is not `default`. pub code: Option, /// The catcher's associated error handler. pub handler: Box, + + /// The mount point. + pub(crate) base: uri::Origin<'static>, + + /// The catcher's calculated rank. + /// + /// This is [base.segments().len() | base.chars().len()]. + pub(crate) rank: u64, +} + +fn compute_rank(base: &uri::Origin<'_>) -> u64 { + let major = u32::MAX - base.path().segments().num() as u32; + let minor = u32::MAX - base.path().as_str().chars().count() as u32; + ((major as u64) << 32) | (minor as u64) } impl Catcher { @@ -166,10 +177,36 @@ impl Catcher { name: None, base: uri::Origin::ROOT, handler: Box::new(handler), - code, + rank: compute_rank(&uri::Origin::ROOT), + code } } + /// Returns the mount point (base) of the catcher, which defaults to `/`. + /// + /// # Example + /// + /// ```rust + /// use rocket::request::Request; + /// use rocket::catcher::{Catcher, BoxFuture}; + /// use rocket::response::Responder; + /// use rocket::http::Status; + /// + /// fn handle_404<'r>(status: Status, req: &'r Request<'_>) -> BoxFuture<'r> { + /// let res = (status, format!("404: {}", req.uri())); + /// Box::pin(async move { res.respond_to(req) }) + /// } + /// + /// let catcher = Catcher::new(404, handle_404); + /// assert_eq!(catcher.base().path(), "/"); + /// + /// let catcher = catcher.map_base(|base| format!("/foo/bar/{}", base)).unwrap(); + /// assert_eq!(catcher.base().path(), "/foo/bar"); + /// ``` + pub fn base(&self) -> &uri::Origin<'_> { + &self.base + } + /// Maps the `base` of this catcher using `mapper`, returning a new /// `Catcher` with the returned base. /// @@ -192,13 +229,13 @@ impl Catcher { /// } /// /// let catcher = Catcher::new(404, handle_404); - /// assert_eq!(catcher.base.path(), "/"); + /// assert_eq!(catcher.base().path(), "/"); /// /// let catcher = catcher.map_base(|_| format!("/bar")).unwrap(); - /// assert_eq!(catcher.base.path(), "/bar"); + /// assert_eq!(catcher.base().path(), "/bar"); /// /// let catcher = catcher.map_base(|base| format!("/foo{}", base)).unwrap(); - /// assert_eq!(catcher.base.path(), "/foo/bar"); + /// assert_eq!(catcher.base().path(), "/foo/bar"); /// /// let catcher = catcher.map_base(|base| format!("/foo ? {}", base)); /// assert!(catcher.is_err()); @@ -209,8 +246,10 @@ impl Catcher { ) -> std::result::Result> where F: FnOnce(uri::Origin<'a>) -> String { - self.base = uri::Origin::parse_owned(mapper(self.base))?.into_normalized(); + let new_base = uri::Origin::parse_owned(mapper(self.base))?; + self.base = new_base.into_normalized_nontrailing(); self.base.clear_query(); + self.rank = compute_rank(&self.base); Ok(self) } } @@ -254,9 +293,7 @@ impl fmt::Display for Catcher { write!(f, "{}{}{} ", Paint::cyan("("), Paint::white(n), Paint::cyan(")"))?; } - if self.base.path() != "/" { - write!(f, "{} ", Paint::green(self.base.path()))?; - } + write!(f, "{} ", Paint::green(self.base.path()))?; match self.code { Some(code) => write!(f, "{}", Paint::blue(code)), @@ -271,6 +308,7 @@ impl fmt::Debug for Catcher { .field("name", &self.name) .field("base", &self.base) .field("code", &self.code) + .field("rank", &self.rank) .finish() } } diff --git a/core/lib/src/request/request.rs b/core/lib/src/request/request.rs index bf9121fe..7365ec9e 100644 --- a/core/lib/src/request/request.rs +++ b/core/lib/src/request/request.rs @@ -935,11 +935,12 @@ impl<'r> Request<'r> { } } - /// Get the `n`th path segment, 0-indexed, after the mount point for the - /// currently matched route, as a string, if it exists. Used by codegen. + /// Get the `n`th non-empty path segment, 0-indexed, after the mount point + /// for the currently matched route, as a string, if it exists. Used by + /// codegen. #[inline] pub fn routed_segment(&self, n: usize) -> Option<&str> { - self.routed_segments(0..).get(n) + self.routed_segments(0..).get(n).filter(|p| !p.is_empty()) } /// Get the segments beginning at the `n`th, 0-indexed, after the mount @@ -947,9 +948,10 @@ impl<'r> Request<'r> { #[inline] pub fn routed_segments(&self, n: RangeFrom) -> Segments<'_, Path> { let mount_segments = self.route() - .map(|r| r.uri.metadata.base_segs.len()) + .map(|r| r.uri.metadata.base_len) .unwrap_or(0); + trace!("requesting {}.. ({}..) from {}", n.start, mount_segments, self); self.uri().path().segments().skip(mount_segments + n.start) } diff --git a/core/lib/src/rocket.rs b/core/lib/src/rocket.rs index edc1e49f..99bc6642 100644 --- a/core/lib/src/rocket.rs +++ b/core/lib/src/rocket.rs @@ -579,9 +579,9 @@ fn log_items(e: &str, t: &str, items: I, base: B, origin: O) } items.sort_by_key(|i| origin(i).path().as_str().chars().count()); - items.sort_by_key(|i| origin(i).path().segments().len()); + items.sort_by_key(|i| origin(i).path().segments().count()); items.sort_by_key(|i| base(i).path().as_str().chars().count()); - items.sort_by_key(|i| base(i).path().segments().len()); + items.sort_by_key(|i| base(i).path().segments().count()); items.iter().for_each(|i| launch_meta_!("{}", i)); } @@ -794,9 +794,9 @@ impl Rocket

{ /// .register("/", catchers![just_500, some_default]); /// /// assert_eq!(rocket.catchers().count(), 3); - /// assert!(rocket.catchers().any(|c| c.code == Some(404) && c.base == "/foo")); - /// assert!(rocket.catchers().any(|c| c.code == Some(500) && c.base == "/")); - /// assert!(rocket.catchers().any(|c| c.code == None && c.base == "/")); + /// assert!(rocket.catchers().any(|c| c.code == Some(404) && c.base() == "/foo")); + /// assert!(rocket.catchers().any(|c| c.code == Some(500) && c.base() == "/")); + /// assert!(rocket.catchers().any(|c| c.code == None && c.base() == "/")); /// ``` pub fn catchers(&self) -> impl Iterator { match self.0.as_state_ref() { diff --git a/core/lib/src/route/segment.rs b/core/lib/src/route/segment.rs index 3e394410..81573e9f 100644 --- a/core/lib/src/route/segment.rs +++ b/core/lib/src/route/segment.rs @@ -1,28 +1,29 @@ -use crate::http::RawStr; - #[derive(Debug, Clone)] pub struct Segment { + /// The name of the parameter or just the static string. pub value: String, + /// This is a ``. pub dynamic: bool, - pub trailing: bool, + /// This is a ``. + pub dynamic_trail: bool, } impl Segment { - pub fn from(segment: &RawStr) -> Self { + pub fn from(segment: &crate::http::RawStr) -> Self { let mut value = segment; let mut dynamic = false; - let mut trailing = false; + let mut dynamic_trail = false; if segment.starts_with('<') && segment.ends_with('>') { dynamic = true; value = &segment[1..(segment.len() - 1)]; if value.ends_with("..") { - trailing = true; + dynamic_trail = true; value = &value[..(value.len() - 2)]; } } - Segment { value: value.to_string(), dynamic, trailing } + Segment { value: value.to_string(), dynamic, dynamic_trail } } } diff --git a/core/lib/src/route/uri.rs b/core/lib/src/route/uri.rs index d770c50a..a0c885b7 100644 --- a/core/lib/src/route/uri.rs +++ b/core/lib/src/route/uri.rs @@ -62,7 +62,7 @@ pub struct RouteUri<'a> { /// The URI _without_ the `base` mount point. pub unmounted_origin: Origin<'a>, /// The URI _with_ the base mount point. This is the canonical route URI. - pub origin: Origin<'a>, + pub uri: Origin<'a>, /// Cached metadata about this URI. pub(crate) metadata: Metadata, } @@ -79,10 +79,10 @@ pub(crate) enum Color { #[derive(Debug, Clone)] pub(crate) struct Metadata { - /// Segments in the base. - pub base_segs: Vec, - /// Segments in the path, including base. - pub path_segs: Vec, + /// Segments in the route URI, including base. + pub uri_segments: Vec, + /// Numbers of segments in `uri_segments` that belong to the base. + pub base_len: usize, /// `(name, value)` of the query segments that are static. pub static_query_fields: Vec<(String, String)>, /// The "color" of the route path. @@ -90,7 +90,7 @@ pub(crate) struct Metadata { /// The "color" of the route query, if there is query. pub query_color: Option, /// Whether the path has a `` parameter. - pub trailing_path: bool, + pub dynamic_trail: bool, } type Result> = std::result::Result; @@ -103,25 +103,36 @@ impl<'a> RouteUri<'a> { pub(crate) fn try_new(base: &str, uri: &str) -> Result> { let mut base = Origin::parse(base) .map_err(|e| e.into_owned())? - .into_normalized() + .into_normalized_nontrailing() .into_owned(); base.clear_query(); - let unmounted_origin = Origin::parse_route(uri) + let origin = Origin::parse_route(uri) .map_err(|e| e.into_owned())? .into_normalized() .into_owned(); - let origin = Origin::parse_route(&format!("{}/{}", base, unmounted_origin)) + let compiled_uri = match base.path().as_str() { + "/" => origin.to_string(), + base => match (origin.path().as_str(), origin.query()) { + ("/", None) => base.to_string(), + ("/", Some(q)) => format!("{}?{}", base, q), + _ => format!("{}{}", base, origin), + } + }; + + let uri = Origin::parse_route(&compiled_uri) .map_err(|e| e.into_owned())? .into_normalized() .into_owned(); - let source = origin.to_string().into(); - let metadata = Metadata::from(&base, &origin); + dbg!(&base, &origin, &compiled_uri, &uri); - Ok(RouteUri { source, base, unmounted_origin, origin, metadata }) + let source = uri.to_string().into(); + let metadata = Metadata::from(&base, &uri); + + Ok(RouteUri { source, base, unmounted_origin: origin, uri, metadata }) } /// Create a new `RouteUri`. @@ -167,7 +178,7 @@ impl<'a> RouteUri<'a> { /// ``` #[inline(always)] pub fn path(&self) -> &str { - self.origin.path().as_str() + self.uri.path().as_str() } /// The query part of this route URI, if there is one. @@ -184,7 +195,7 @@ impl<'a> RouteUri<'a> { /// /// // Normalization clears the empty '?'. /// let index = Route::new(Method::Get, "/foo/bar?", handler); - /// assert!(index.uri.query().is_none()); + /// assert_eq!(index.uri.query().unwrap(), ""); /// /// let index = Route::new(Method::Get, "/foo/bar?a=1", handler); /// assert_eq!(index.uri.query().unwrap(), "a=1"); @@ -194,7 +205,7 @@ impl<'a> RouteUri<'a> { /// ``` #[inline(always)] pub fn query(&self) -> Option<&str> { - self.origin.query().map(|q| q.as_str()) + self.uri.query().map(|q| q.as_str()) } /// The full URI as an `&str`. @@ -247,16 +258,13 @@ impl<'a> RouteUri<'a> { } impl Metadata { - fn from(base: &Origin<'_>, origin: &Origin<'_>) -> Self { - let base_segs = base.path().raw_segments() + fn from(base: &Origin<'_>, uri: &Origin<'_>) -> Self { + let uri_segments = uri.path() + .raw_segments() .map(Segment::from) .collect::>(); - let path_segs = origin.path().raw_segments() - .map(Segment::from) - .collect::>(); - - let query_segs = origin.query() + let query_segs = uri.query() .map(|q| q.raw_segments().map(Segment::from).collect::>()) .unwrap_or_default(); @@ -265,8 +273,8 @@ impl Metadata { .map(|f| (f.name.source().to_string(), f.value.to_string())) .collect(); - let static_path = path_segs.iter().all(|s| !s.dynamic); - let wild_path = !path_segs.is_empty() && path_segs.iter().all(|s| s.dynamic); + let static_path = uri_segments.iter().all(|s| !s.dynamic); + let wild_path = !uri_segments.is_empty() && uri_segments.iter().all(|s| s.dynamic); let path_color = match (static_path, wild_path) { (true, _) => Color::Static, (_, true) => Color::Wild, @@ -283,11 +291,13 @@ impl Metadata { } }); - let trailing_path = path_segs.last().map_or(false, |p| p.trailing); + let dynamic_trail = uri_segments.last().map_or(false, |p| p.dynamic_trail); + let segments = base.path().segments(); + let num_empty = segments.clone().filter(|s| s.is_empty()).count(); + let base_len = segments.num() - num_empty; Metadata { - base_segs, path_segs, static_query_fields, path_color, query_color, - trailing_path, + uri_segments, base_len, static_query_fields, path_color, query_color, dynamic_trail } } } @@ -296,13 +306,13 @@ impl<'a> std::ops::Deref for RouteUri<'a> { type Target = Origin<'a>; fn deref(&self) -> &Self::Target { - &self.origin + &self.uri } } impl fmt::Display for RouteUri<'_> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - self.origin.fmt(f) + self.uri.fmt(f) } } @@ -311,14 +321,14 @@ impl fmt::Debug for RouteUri<'_> { f.debug_struct("RouteUri") .field("base", &self.base) .field("unmounted_origin", &self.unmounted_origin) - .field("origin", &self.origin) + .field("origin", &self.uri) .field("metadata", &self.metadata) .finish() } } impl<'a, 'b> PartialEq> for RouteUri<'a> { - fn eq(&self, other: &Origin<'b>) -> bool { &self.origin == other } + fn eq(&self, other: &Origin<'b>) -> bool { &self.uri == other } } impl PartialEq for RouteUri<'_> { diff --git a/core/lib/src/router/collider.rs b/core/lib/src/router/collider.rs index ade9e1e6..757a624d 100644 --- a/core/lib/src/router/collider.rs +++ b/core/lib/src/router/collider.rs @@ -1,16 +1,72 @@ use crate::catcher::Catcher; -use crate::route::{Route, Color}; +use crate::route::{Route, Segment, RouteUri}; -use crate::http::{MediaType, Status}; -use crate::request::Request; +use crate::http::MediaType; pub trait Collide { fn collides_with(&self, other: &T) -> bool; } -impl<'a, 'b, T: Collide> Collide<&T> for &T { - fn collides_with(&self, other: &&T) -> bool { - T::collides_with(*self, *other) +impl Collide for Route { + /// Determines if two routes can match against some request. That is, if two + /// routes `collide`, there exists a request that can match against both + /// routes. + /// + /// This implementation is used at initialization to check if two user + /// routes collide before launching. Format collisions works like this: + /// + /// * If route specifies a format, it only gets requests for that format. + /// * If route doesn't specify a format, it gets requests for any format. + /// + /// Because query parsing is lenient, and dynamic query parameters can be + /// missing, queries do not impact whether two routes collide. + fn collides_with(&self, other: &Route) -> bool { + self.method == other.method + && self.rank == other.rank + && self.uri.collides_with(&other.uri) + && formats_collide(self, other) + } +} + +impl Collide for Catcher { + /// Determines if two catchers are in conflict: there exists a request for + /// which there exist no rule to determine _which_ of the two catchers to + /// use. This means that the catchers: + /// + /// * Have the same base. + /// * Have the same status code or are both defaults. + fn collides_with(&self, other: &Self) -> bool { + self.code == other.code + && self.base.path().segments().eq(other.base.path().segments()) + } +} + +impl Collide for RouteUri<'_> { + fn collides_with(&self, other: &Self) -> bool { + let a_segments = &self.metadata.uri_segments; + let b_segments = &other.metadata.uri_segments; + for (seg_a, seg_b) in a_segments.iter().zip(b_segments.iter()) { + if seg_a.dynamic_trail || seg_b.dynamic_trail { + return true; + } + + if !seg_a.collides_with(seg_b) { + return false; + } + } + + // Check for `/a/` vs. `/a`, which should collide. + a_segments.get(b_segments.len()).map_or(false, |s| s.dynamic_trail) + || b_segments.get(a_segments.len()).map_or(false, |s| s.dynamic_trail) + || a_segments.len() == b_segments.len() + } +} + +impl Collide for Segment { + fn collides_with(&self, other: &Self) -> bool { + self.dynamic && !other.value.is_empty() + || other.dynamic && !self.value.is_empty() + || self.value == other.value } } @@ -21,28 +77,6 @@ impl Collide for MediaType { } } -fn paths_collide(route: &Route, other: &Route) -> bool { - let a_segments = &route.uri.metadata.path_segs; - let b_segments = &other.uri.metadata.path_segs; - for (seg_a, seg_b) in a_segments.iter().zip(b_segments.iter()) { - if seg_a.trailing || seg_b.trailing { - return true; - } - - if seg_a.dynamic || seg_b.dynamic { - continue; - } - - if seg_a.value != seg_b.value { - return false; - } - } - - a_segments.get(b_segments.len()).map_or(false, |s| s.trailing) - || b_segments.get(a_segments.len()).map_or(false, |s| s.trailing) - || a_segments.len() == b_segments.len() -} - fn formats_collide(route: &Route, other: &Route) -> bool { // When matching against the `Accept` header, the client can always provide // a media type that will cause a collision through non-specificity, i.e, @@ -61,319 +95,172 @@ fn formats_collide(route: &Route, other: &Route) -> bool { } } -impl Collide for Route { - /// Determines if two routes can match against some request. That is, if two - /// routes `collide`, there exists a request that can match against both - /// routes. - /// - /// This implementation is used at initialization to check if two user - /// routes collide before launching. Format collisions works like this: - /// - /// * If route specifies a format, it only gets requests for that format. - /// * If route doesn't specify a format, it gets requests for any format. - /// - /// Because query parsing is lenient, and dynamic query parameters can be - /// missing, queries do not impact whether two routes collide. - fn collides_with(&self, other: &Route) -> bool { - self.method == other.method - && self.rank == other.rank - && paths_collide(self, other) - && formats_collide(self, other) - } -} - -impl Route { - /// Determines if this route matches against the given request. - /// - /// This means that: - /// - /// * The route's method matches that of the incoming request. - /// * The route's format (if any) matches that of the incoming request. - /// - If route specifies format, it only gets requests for that format. - /// - If route doesn't specify format, it gets requests for any format. - /// * All static components in the route's path match the corresponding - /// components in the same position in the incoming request. - /// * All static components in the route's query string are also in the - /// request query string, though in any position. If there is no query - /// in the route, requests with/without queries match. - pub(crate) fn matches(&self, req: &Request<'_>) -> bool { - self.method == req.method() - && paths_match(self, req) - && queries_match(self, req) - && formats_match(self, req) - } -} - -fn paths_match(route: &Route, req: &Request<'_>) -> bool { - let route_segments = &route.uri.metadata.path_segs; - let req_segments = req.uri().path().segments(); - - if route.uri.metadata.trailing_path { - // The last route segment can be trailing, which is allowed to be empty. - // So we can have one more segment in `route` than in `req` and match. - // ok if: req_segments.len() >= route_segments.len() - 1 - if req_segments.len() + 1 < route_segments.len() { - return false; - } - } else if route_segments.len() != req_segments.len() { - return false; - } - - if route.uri.metadata.path_color == Color::Wild { - return true; - } - - for (route_seg, req_seg) in route_segments.iter().zip(req_segments) { - if route_seg.trailing { - return true; - } - - if !(route_seg.dynamic || route_seg.value == req_seg) { - return false; - } - } - - true -} - -fn queries_match(route: &Route, req: &Request<'_>) -> bool { - if matches!(route.uri.metadata.query_color, None | Some(Color::Wild)) { - return true; - } - - let route_query_fields = route.uri.metadata.static_query_fields.iter() - .map(|(k, v)| (k.as_str(), v.as_str())); - - for route_seg in route_query_fields { - if let Some(query) = req.uri().query() { - if !query.segments().any(|req_seg| req_seg == route_seg) { - trace_!("request {} missing static query {:?}", req, route_seg); - return false; - } - } else { - trace_!("query-less request {} missing static query {:?}", req, route_seg); - return false; - } - } - - true -} - -fn formats_match(route: &Route, request: &Request<'_>) -> bool { - if !route.method.supports_payload() { - route.format.as_ref() - .and_then(|a| request.format().map(|b| (a, b))) - .map(|(a, b)| a.collides_with(b)) - .unwrap_or(true) - } else { - match route.format.as_ref() { - Some(a) => match request.format() { - Some(b) if b.specificity() == 2 => a.collides_with(b), - _ => false - } - None => true - } - } -} - - -impl Collide for Catcher { - /// Determines if two catchers are in conflict: there exists a request for - /// which there exist no rule to determine _which_ of the two catchers to - /// use. This means that the catchers: - /// - /// * Have the same base. - /// * Have the same status code or are both defaults. - fn collides_with(&self, other: &Self) -> bool { - self.code == other.code - && self.base.path().segments().eq(other.base.path().segments()) - } -} - -impl Catcher { - /// Determines if this catcher is responsible for handling the error with - /// `status` that occurred during request `req`. A catcher matches if: - /// - /// * It is a default catcher _or_ has a code of `status`. - /// * Its base is a prefix of the normalized/decoded `req.path()`. - pub(crate) fn matches(&self, status: Status, req: &Request<'_>) -> bool { - self.code.map_or(true, |code| code == status.code) - && self.base.path().segments().prefix_of(req.uri().path().segments()) - } -} - #[cfg(test)] mod tests { use std::str::FromStr; use super::*; use crate::route::{Route, dummy_handler}; - use crate::local::blocking::Client; - use crate::http::{Method, Method::*, MediaType, ContentType, Accept}; - use crate::http::uri::Origin; + use crate::http::{Method, Method::*, MediaType}; - type SimpleRoute = (Method, &'static str); - - fn m_collide(a: SimpleRoute, b: SimpleRoute) -> bool { - let route_a = Route::new(a.0, a.1, dummy_handler); - route_a.collides_with(&Route::new(b.0, b.1, dummy_handler)) + fn dummy_route(ranked: bool, method: impl Into>, uri: &'static str) -> Route { + let method = method.into().unwrap_or(Get); + Route::ranked((!ranked).then(|| 0), method, uri, dummy_handler) } - fn unranked_collide(a: &'static str, b: &'static str) -> bool { - let route_a = Route::ranked(0, Get, a, dummy_handler); - let route_b = Route::ranked(0, Get, b, dummy_handler); - eprintln!("Checking {} against {}.", route_a, route_b); - route_a.collides_with(&route_b) + macro_rules! assert_collision { + ($ranked:expr, $p1:expr, $p2:expr) => (assert_collision!($ranked, None $p1, None $p2)); + ($ranked:expr, $m1:ident $p1:expr, $m2:ident $p2:expr) => { + let (a, b) = (dummy_route($ranked, $m1, $p1), dummy_route($ranked, $m2, $p2)); + assert! { + a.collides_with(&b), + "\nroutes failed to collide:\n{} does not collide with {}\n", a, b + } + }; + (ranked $($t:tt)+) => (assert_collision!(true, $($t)+)); + ($($t:tt)+) => (assert_collision!(false, $($t)+)); } - fn s_s_collide(a: &'static str, b: &'static str) -> bool { - let a = Route::new(Get, a, dummy_handler); - let b = Route::new(Get, b, dummy_handler); - paths_collide(&a, &b) - } - - #[test] - fn simple_collisions() { - assert!(unranked_collide("/a", "/a")); - assert!(unranked_collide("/hello", "/hello")); - assert!(unranked_collide("/hello", "/hello/")); - assert!(unranked_collide("/hello/there/how/ar", "/hello/there/how/ar")); - assert!(unranked_collide("/hello/there", "/hello/there/")); - } - - #[test] - fn simple_param_collisions() { - assert!(unranked_collide("/", "/")); - assert!(unranked_collide("/", "/b")); - assert!(unranked_collide("/hello/", "/hello/")); - assert!(unranked_collide("/hello//hi", "/hello//hi")); - assert!(unranked_collide("/hello//hi/there", "/hello//hi/there")); - assert!(unranked_collide("//hi/there", "//hi/there")); - assert!(unranked_collide("//hi/there", "/dude//there")); - assert!(unranked_collide("///", "///")); - assert!(unranked_collide("////", "////")); - assert!(unranked_collide("/", "/hi")); - assert!(unranked_collide("/", "/hi/hey")); - assert!(unranked_collide("/", "/hi/hey/hayo")); - assert!(unranked_collide("/a/", "/a/hi/hey/hayo")); - assert!(unranked_collide("/a//", "/a/hi/hey/hayo")); - assert!(unranked_collide("/a///", "/a/hi/hey/hayo")); - assert!(unranked_collide("///", "/a/hi/hey/hayo")); - assert!(unranked_collide("///hey/hayo", "/a/hi/hey/hayo")); - assert!(unranked_collide("/", "/foo")); - } - - #[test] - fn medium_param_collisions() { - assert!(unranked_collide("/", "/b")); - assert!(unranked_collide("/hello/", "/hello/bob")); - assert!(unranked_collide("/", "//bob")); - } - - #[test] - fn hard_param_collisions() { - assert!(unranked_collide("/", "///a///")); - assert!(unranked_collide("/", "//a/bcjdklfj//")); - assert!(unranked_collide("/a/", "//a/bcjdklfj//")); - assert!(unranked_collide("/a//", "//a/bcjdklfj//")); - assert!(unranked_collide("/", "/")); - assert!(unranked_collide("/", "/<_..>")); - assert!(unranked_collide("/a/b/", "/a/")); - assert!(unranked_collide("/a/b/", "/a//")); - assert!(unranked_collide("/hi/", "/hi")); - assert!(unranked_collide("/hi/", "/hi/")); - assert!(unranked_collide("/", "//////")); - } - - #[test] - fn query_collisions() { - assert!(unranked_collide("/?", "/?")); - assert!(unranked_collide("/a/?", "/a/?")); - assert!(unranked_collide("/a?", "/a?")); - assert!(unranked_collide("/?", "/?")); - assert!(unranked_collide("/a/b/c?", "/a/b/c?")); - assert!(unranked_collide("//b/c?", "/a/b/?")); - assert!(unranked_collide("/?", "/")); - assert!(unranked_collide("/a?", "/a")); - assert!(unranked_collide("/a?", "/a")); - assert!(unranked_collide("/a/b?", "/a/b")); - assert!(unranked_collide("/a/b", "/a/b?")); + macro_rules! assert_no_collision { + ($ranked:expr, $p1:expr, $p2:expr) => (assert_no_collision!($ranked, None $p1, None $p2)); + ($ranked:expr, $m1:ident $p1:expr, $m2:ident $p2:expr) => { + let (a, b) = (dummy_route($ranked, $m1, $p1), dummy_route($ranked, $m2, $p2)); + assert! { + !a.collides_with(&b), + "\nunexpected collision:\n{} collides with {}\n", a, b + } + }; + (ranked $($t:tt)+) => (assert_no_collision!(true, $($t)+)); + ($($t:tt)+) => (assert_no_collision!(false, $($t)+)); } #[test] fn non_collisions() { - assert!(!unranked_collide("/", "/")); - assert!(!unranked_collide("/a", "/b")); - assert!(!unranked_collide("/a/b", "/a")); - assert!(!unranked_collide("/a/b", "/a/c")); - assert!(!unranked_collide("/a/hello", "/a/c")); - assert!(!unranked_collide("/hello", "/a/c")); - assert!(!unranked_collide("/hello/there", "/hello/there/guy")); - assert!(!unranked_collide("/a/", "/b/")); - assert!(!unranked_collide("/t", "/test")); - assert!(!unranked_collide("/a", "/aa")); - assert!(!unranked_collide("/a", "/aaa")); - assert!(!unranked_collide("/", "/a")); + assert_no_collision!("/", "/"); + assert_no_collision!("/a", "/b"); + assert_no_collision!("/a/b", "/a"); + assert_no_collision!("/a/b", "/a/c"); + assert_no_collision!("/a/hello", "/a/c"); + assert_no_collision!("/hello", "/a/c"); + assert_no_collision!("/hello/there", "/hello/there/guy"); + assert_no_collision!("/a/", "/b/"); + assert_no_collision!("//b", "//a"); + assert_no_collision!("/t", "/test"); + assert_no_collision!("/a", "/aa"); + assert_no_collision!("/a", "/aaa"); + assert_no_collision!("/", "/a"); + + assert_no_collision!("/hello", "/hello/"); + assert_no_collision!("/hello/there", "/hello/there/"); + assert_no_collision!("/hello/", "/hello/"); + + assert_no_collision!("/a?", "/b"); + assert_no_collision!("/a/b", "/a?"); + assert_no_collision!("/a/b/c?", "/a/b/c/d"); + assert_no_collision!("/a/hello", "/a/?"); + assert_no_collision!("/?", "/hi"); + + assert_no_collision!(Get "/", Post "/"); + assert_no_collision!(Post "/", Put "/"); + assert_no_collision!(Put "/a", Put "/"); + assert_no_collision!(Post "/a", Put "/"); + assert_no_collision!(Get "/a", Put "/"); + assert_no_collision!(Get "/hello", Put "/hello"); + assert_no_collision!(Get "/", Post "/"); + + assert_no_collision!("/a", "/b"); + assert_no_collision!("/a/b", "/a"); + assert_no_collision!("/a/b", "/a/c"); + assert_no_collision!("/a/hello", "/a/c"); + assert_no_collision!("/hello", "/a/c"); + assert_no_collision!("/hello/there", "/hello/there/guy"); + assert_no_collision!("/a/", "/b/"); + assert_no_collision!("/a", "/b"); + assert_no_collision!("/a/b", "/a"); + assert_no_collision!("/a/b", "/a/c"); + assert_no_collision!("/a/hello", "/a/c"); + assert_no_collision!("/hello", "/a/c"); + assert_no_collision!("/hello/there", "/hello/there/guy"); + assert_no_collision!("/a/", "/b/"); + assert_no_collision!("/a", "/b"); + assert_no_collision!("/a/b", "/a"); + assert_no_collision!("/a/b", "/a/c"); + assert_no_collision!("/a/hello", "/a/c"); + assert_no_collision!("/hello", "/a/c"); + assert_no_collision!("/hello/there", "/hello/there/guy"); + assert_no_collision!("/a/", "/b/"); + assert_no_collision!("/t", "/test"); + assert_no_collision!("/a", "/aa"); + assert_no_collision!("/a", "/aaa"); + assert_no_collision!("/", "/a"); + + assert_no_collision!(ranked "/", "/?a"); + assert_no_collision!(ranked "/", "/?"); + assert_no_collision!(ranked "/a/", "/a/?d"); } #[test] - fn query_non_collisions() { - assert!(!unranked_collide("/a?", "/b")); - assert!(!unranked_collide("/a/b", "/a?")); - assert!(!unranked_collide("/a/b/c?", "/a/b/c/d")); - assert!(!unranked_collide("/a/hello", "/a/?")); - assert!(!unranked_collide("/?", "/hi")); - } + fn collisions() { + assert_collision!("/a", "/a"); + assert_collision!("/hello", "/hello"); + assert_collision!("/hello/there/how/ar", "/hello/there/how/ar"); - #[test] - fn method_dependent_non_collisions() { - assert!(!m_collide((Get, "/"), (Post, "/"))); - assert!(!m_collide((Post, "/"), (Put, "/"))); - assert!(!m_collide((Put, "/a"), (Put, "/"))); - assert!(!m_collide((Post, "/a"), (Put, "/"))); - assert!(!m_collide((Get, "/a"), (Put, "/"))); - assert!(!m_collide((Get, "/hello"), (Put, "/hello"))); - assert!(!m_collide((Get, "/"), (Post, "/"))); - } + assert_collision!("/", "/"); + assert_collision!("/", "/b"); + assert_collision!("/hello/", "/hello/"); + assert_collision!("/hello//hi", "/hello//hi"); + assert_collision!("/hello//hi/there", "/hello//hi/there"); + assert_collision!("//hi/there", "//hi/there"); + assert_collision!("//hi/there", "/dude//there"); + assert_collision!("///", "///"); + assert_collision!("////", "////"); + assert_collision!("/", "/hi"); + assert_collision!("/", "/hi/hey"); + assert_collision!("/", "/hi/hey/hayo"); + assert_collision!("/a/", "/a/hi/hey/hayo"); + assert_collision!("/a//", "/a/hi/hey/hayo"); + assert_collision!("/a///", "/a/hi/hey/hayo"); + assert_collision!("///", "/a/hi/hey/hayo"); + assert_collision!("///hey/hayo", "/a/hi/hey/hayo"); + assert_collision!("/", "/foo"); - #[test] - fn query_dependent_non_collisions() { - assert!(!m_collide((Get, "/"), (Get, "/?a"))); - assert!(!m_collide((Get, "/"), (Get, "/?"))); - assert!(!m_collide((Get, "/a/"), (Get, "/a/?d"))); - } + assert_collision!("/", "/"); + assert_collision!("/a", "/a/"); + assert_collision!("/a/", "/a/"); + assert_collision!("//", "/a/"); + assert_collision!("/", "/a/"); - #[test] - fn test_str_non_collisions() { - assert!(!s_s_collide("/a", "/b")); - assert!(!s_s_collide("/a/b", "/a")); - assert!(!s_s_collide("/a/b", "/a/c")); - assert!(!s_s_collide("/a/hello", "/a/c")); - assert!(!s_s_collide("/hello", "/a/c")); - assert!(!s_s_collide("/hello/there", "/hello/there/guy")); - assert!(!s_s_collide("/a/", "/b/")); - assert!(!s_s_collide("/a", "/b")); - assert!(!s_s_collide("/a/b", "/a")); - assert!(!s_s_collide("/a/b", "/a/c")); - assert!(!s_s_collide("/a/hello", "/a/c")); - assert!(!s_s_collide("/hello", "/a/c")); - assert!(!s_s_collide("/hello/there", "/hello/there/guy")); - assert!(!s_s_collide("/a/", "/b/")); - assert!(!s_s_collide("/a", "/b")); - assert!(!s_s_collide("/a/b", "/a")); - assert!(!s_s_collide("/a/b", "/a/c")); - assert!(!s_s_collide("/a/hello", "/a/c")); - assert!(!s_s_collide("/hello", "/a/c")); - assert!(!s_s_collide("/hello/there", "/hello/there/guy")); - assert!(!s_s_collide("/a/", "/b/")); - assert!(!s_s_collide("/t", "/test")); - assert!(!s_s_collide("/a", "/aa")); - assert!(!s_s_collide("/a", "/aaa")); - assert!(!s_s_collide("/", "/a")); + assert_collision!("/", "/b"); + assert_collision!("/hello/", "/hello/bob"); + assert_collision!("/", "//bob"); - assert!(s_s_collide("/a/hi/", "/a/hi/")); - assert!(s_s_collide("/hi/", "/hi/")); - assert!(s_s_collide("/", "/")); + assert_collision!("/", "///a///"); + assert_collision!("/", "//a/bcjdklfj//"); + assert_collision!("/a/", "//a/bcjdklfj//"); + assert_collision!("/a//", "//a/bcjdklfj//"); + assert_collision!("/", "/"); + assert_collision!("/", "/<_..>"); + assert_collision!("/a/b/", "/a/"); + assert_collision!("/a/b/", "/a//"); + assert_collision!("/hi/", "/hi"); + assert_collision!("/hi/", "/hi/"); + assert_collision!("/", "//////"); + + assert_collision!("/?", "/?"); + assert_collision!("/a/?", "/a/?"); + assert_collision!("/a?", "/a?"); + assert_collision!("/?", "/?"); + assert_collision!("/a/b/c?", "/a/b/c?"); + assert_collision!("//b/c?", "/a/b/?"); + assert_collision!("/?", "/"); + assert_collision!("/a?", "/a"); + assert_collision!("/a?", "/a"); + assert_collision!("/a/b?", "/a/b"); + assert_collision!("/a/b", "/a/b?"); + + assert_collision!("/a/hi/", "/a/hi/"); + assert_collision!("/hi/", "/hi/"); + assert_collision!("/", "/"); } fn mt_mt_collide(mt1: &str, mt2: &str) -> bool { @@ -458,119 +345,6 @@ mod tests { assert!(!r_mt_mt_collide(Post, "other/html", "text/html")); } - fn req_route_mt_collide(m: Method, mt1: S1, mt2: S2) -> bool - where S1: Into>, S2: Into> - { - let client = Client::debug_with(vec![]).expect("client"); - let mut req = client.req(m, "/"); - if let Some(mt_str) = mt1.into() { - if m.supports_payload() { - req.replace_header(mt_str.parse::().unwrap()); - } else { - req.replace_header(mt_str.parse::().unwrap()); - } - } - - let mut route = Route::new(m, "/", dummy_handler); - if let Some(mt_str) = mt2.into() { - route.format = Some(mt_str.parse::().unwrap()); - } - - route.matches(&req) - } - - #[test] - fn test_req_route_mt_collisions() { - assert!(req_route_mt_collide(Post, "application/json", "application/json")); - assert!(req_route_mt_collide(Post, "application/json", "application/*")); - assert!(req_route_mt_collide(Post, "application/json", "*/json")); - assert!(req_route_mt_collide(Post, "text/html", "*/*")); - - assert!(req_route_mt_collide(Get, "application/json", "application/json")); - assert!(req_route_mt_collide(Get, "text/html", "text/html")); - assert!(req_route_mt_collide(Get, "text/html", "*/*")); - assert!(req_route_mt_collide(Get, None, "*/*")); - assert!(req_route_mt_collide(Get, None, "text/*")); - assert!(req_route_mt_collide(Get, None, "text/html")); - assert!(req_route_mt_collide(Get, None, "application/json")); - - assert!(req_route_mt_collide(Post, "text/html", None)); - assert!(req_route_mt_collide(Post, "application/json", None)); - assert!(req_route_mt_collide(Post, "x-custom/anything", None)); - assert!(req_route_mt_collide(Post, None, None)); - - assert!(req_route_mt_collide(Get, "text/html", None)); - assert!(req_route_mt_collide(Get, "application/json", None)); - assert!(req_route_mt_collide(Get, "x-custom/anything", None)); - assert!(req_route_mt_collide(Get, None, None)); - assert!(req_route_mt_collide(Get, None, "text/html")); - assert!(req_route_mt_collide(Get, None, "application/json")); - - assert!(req_route_mt_collide(Get, "text/html, text/plain", "text/html")); - assert!(req_route_mt_collide(Get, "text/html; q=0.5, text/xml", "text/xml")); - - assert!(!req_route_mt_collide(Post, None, "text/html")); - assert!(!req_route_mt_collide(Post, None, "text/*")); - assert!(!req_route_mt_collide(Post, None, "*/text")); - assert!(!req_route_mt_collide(Post, None, "*/*")); - assert!(!req_route_mt_collide(Post, None, "text/html")); - assert!(!req_route_mt_collide(Post, None, "application/json")); - - assert!(!req_route_mt_collide(Post, "application/json", "text/html")); - assert!(!req_route_mt_collide(Post, "application/json", "text/*")); - assert!(!req_route_mt_collide(Post, "application/json", "*/xml")); - assert!(!req_route_mt_collide(Get, "application/json", "text/html")); - assert!(!req_route_mt_collide(Get, "application/json", "text/*")); - assert!(!req_route_mt_collide(Get, "application/json", "*/xml")); - - assert!(!req_route_mt_collide(Post, None, "text/html")); - assert!(!req_route_mt_collide(Post, None, "application/json")); - } - - fn req_route_path_match(a: &'static str, b: &'static str) -> bool { - let client = Client::debug_with(vec![]).expect("client"); - let req = client.get(Origin::parse(a).expect("valid URI")); - let route = Route::ranked(0, Get, b, dummy_handler); - route.matches(&req) - } - - #[test] - fn test_req_route_query_collisions() { - assert!(req_route_path_match("/a/b?a=b", "/a/b?")); - assert!(req_route_path_match("/a/b?a=b", "//b?")); - assert!(req_route_path_match("/a/b?a=b", "//?")); - assert!(req_route_path_match("/a/b?a=b", "/a/?")); - assert!(req_route_path_match("/?b=c", "/?")); - - assert!(req_route_path_match("/a/b?a=b", "/a/b")); - assert!(req_route_path_match("/a/b", "/a/b")); - assert!(req_route_path_match("/a/b/c/d?", "/a/b/c/d")); - assert!(req_route_path_match("/a/b/c/d?v=1&v=2", "/a/b/c/d")); - - assert!(req_route_path_match("/a/b", "/a/b?")); - assert!(req_route_path_match("/a/b", "/a/b?")); - assert!(req_route_path_match("/a/b?c", "/a/b?c")); - assert!(req_route_path_match("/a/b?c", "/a/b?")); - assert!(req_route_path_match("/a/b?c=foo&d=z", "/a/b?")); - assert!(req_route_path_match("/a/b?c=foo&d=z", "/a/b?")); - - assert!(req_route_path_match("/a/b?c=foo&d=z", "/a/b?c=foo&")); - assert!(req_route_path_match("/a/b?c=foo&d=z", "/a/b?d=z&")); - - assert!(!req_route_path_match("/a/b/c", "/a/b?")); - assert!(!req_route_path_match("/a?b=c", "/a/b?")); - assert!(!req_route_path_match("/?b=c", "/a/b?")); - assert!(!req_route_path_match("/?b=c", "/a?")); - - assert!(!req_route_path_match("/a/b?c=foo&d=z", "/a/b?a=b&")); - assert!(!req_route_path_match("/a/b?c=foo&d=z", "/a/b?d=b&")); - assert!(!req_route_path_match("/a/b", "/a/b?c")); - assert!(!req_route_path_match("/a/b", "/a/b?foo")); - assert!(!req_route_path_match("/a/b", "/a/b?foo&")); - assert!(!req_route_path_match("/a/b", "/a/b?&b&")); - } - - fn catchers_collide(a: A, ap: &str, b: B, bp: &str) -> bool where A: Into>, B: Into> { diff --git a/core/lib/src/router/matcher.rs b/core/lib/src/router/matcher.rs new file mode 100644 index 00000000..496efac0 --- /dev/null +++ b/core/lib/src/router/matcher.rs @@ -0,0 +1,257 @@ +use crate::{Route, Request, Catcher}; +use crate::router::Collide; +use crate::http::Status; +use crate::route::Color; + +impl Route { + /// Determines if this route matches against the given request. + /// + /// This means that: + /// + /// * The route's method matches that of the incoming request. + /// * The route's format (if any) matches that of the incoming request. + /// - If route specifies format, it only gets requests for that format. + /// - If route doesn't specify format, it gets requests for any format. + /// * All static components in the route's path match the corresponding + /// components in the same position in the incoming request. + /// * All static components in the route's query string are also in the + /// request query string, though in any position. If there is no query + /// in the route, requests with/without queries match. + pub(crate) fn matches(&self, req: &Request<'_>) -> bool { + self.method == req.method() + && paths_match(self, req) + && queries_match(self, req) + && formats_match(self, req) + } +} + +impl Catcher { + /// Determines if this catcher is responsible for handling the error with + /// `status` that occurred during request `req`. A catcher matches if: + /// + /// * It is a default catcher _or_ has a code of `status`. + /// * Its base is a prefix of the normalized/decoded `req.path()`. + pub(crate) fn matches(&self, status: Status, req: &Request<'_>) -> bool { + dbg!(self.base.path().segments()); + dbg!(req.uri().path().segments()); + + self.code.map_or(true, |code| code == status.code) + && self.base.path().segments().prefix_of(req.uri().path().segments()) + } +} + +fn paths_match(route: &Route, req: &Request<'_>) -> bool { + trace!("checking path match: route {} vs. request {}", route, req); + let route_segments = &route.uri.metadata.uri_segments; + let req_segments = req.uri().path().segments(); + + // requests with longer paths only match if we have dynamic trail (). + if req_segments.num() > route_segments.len() { + if !route.uri.metadata.dynamic_trail { + return false; + } + } + + // The last route segment can be trailing (`/<..>`), which is allowed to be + // empty in the request. That is, we want to match `GET /a` to `/a/`. + if route_segments.len() > req_segments.num() { + if route_segments.len() != req_segments.num() + 1 { + return false; + } + + if !route.uri.metadata.dynamic_trail { + return false; + } + } + + // We've checked everything beyond the zip of their lengths already. + for (route_seg, req_seg) in route_segments.iter().zip(req_segments.clone()) { + if route_seg.dynamic_trail { + return true; + } + + if route_seg.dynamic && req_seg.is_empty() { + return false; + } + + if !route_seg.dynamic && route_seg.value != req_seg { + return false; + } + } + + true +} + +fn queries_match(route: &Route, req: &Request<'_>) -> bool { + trace!("checking query match: route {} vs. request {}", route, req); + if matches!(route.uri.metadata.query_color, None | Some(Color::Wild)) { + return true; + } + + let route_query_fields = route.uri.metadata.static_query_fields.iter() + .map(|(k, v)| (k.as_str(), v.as_str())); + + for route_seg in route_query_fields { + if let Some(query) = req.uri().query() { + if !query.segments().any(|req_seg| req_seg == route_seg) { + trace_!("request {} missing static query {:?}", req, route_seg); + return false; + } + } else { + trace_!("query-less request {} missing static query {:?}", req, route_seg); + return false; + } + } + + true +} + +fn formats_match(route: &Route, req: &Request<'_>) -> bool { + trace!("checking format match: route {} vs. request {}", route, req); + let route_format = match route.format { + Some(ref format) => format, + None => return true, + }; + + if route.method.supports_payload() { + match req.format() { + Some(f) if f.specificity() == 2 => route_format.collides_with(f), + _ => false + } + } else { + match req.format() { + Some(f) => route_format.collides_with(f), + None => true + } + } +} + +#[cfg(test)] +mod tests { + use crate::local::blocking::Client; + use crate::route::{Route, dummy_handler}; + use crate::http::{Method, Method::*, MediaType, ContentType, Accept}; + + fn req_matches_route(a: &'static str, b: &'static str) -> bool { + let client = Client::debug_with(vec![]).expect("client"); + let route = Route::ranked(0, Get, b, dummy_handler); + route.matches(&client.get(a)) + } + + #[test] + fn request_route_matching() { + assert!(req_matches_route("/a/b?a=b", "/a/b?")); + assert!(req_matches_route("/a/b?a=b", "//b?")); + assert!(req_matches_route("/a/b?a=b", "//?")); + assert!(req_matches_route("/a/b?a=b", "/a/?")); + assert!(req_matches_route("/?b=c", "/?")); + + assert!(req_matches_route("/a/b?a=b", "/a/b")); + assert!(req_matches_route("/a/b", "/a/b")); + assert!(req_matches_route("/a/b/c/d?", "/a/b/c/d")); + assert!(req_matches_route("/a/b/c/d?v=1&v=2", "/a/b/c/d")); + + assert!(req_matches_route("/a/b", "/a/b?")); + assert!(req_matches_route("/a/b", "/a/b?")); + assert!(req_matches_route("/a/b?c", "/a/b?c")); + assert!(req_matches_route("/a/b?c", "/a/b?")); + assert!(req_matches_route("/a/b?c=foo&d=z", "/a/b?")); + assert!(req_matches_route("/a/b?c=foo&d=z", "/a/b?")); + + assert!(req_matches_route("/a/b?c=foo&d=z", "/a/b?c=foo&")); + assert!(req_matches_route("/a/b?c=foo&d=z", "/a/b?d=z&")); + + assert!(req_matches_route("/a", "/a")); + assert!(req_matches_route("/a/", "/a/")); + + assert!(req_matches_route("//", "/")); + assert!(req_matches_route("/a///", "/a/")); + assert!(req_matches_route("/a/b", "/a/b")); + + assert!(!req_matches_route("/a///", "/a")); + assert!(!req_matches_route("/a", "/a/")); + assert!(!req_matches_route("/a/", "/a")); + assert!(!req_matches_route("/a/b", "/a/b/")); + + assert!(!req_matches_route("/a/b/c", "/a/b?")); + assert!(!req_matches_route("/a?b=c", "/a/b?")); + assert!(!req_matches_route("/?b=c", "/a/b?")); + assert!(!req_matches_route("/?b=c", "/a?")); + + assert!(!req_matches_route("/a/b?c=foo&d=z", "/a/b?a=b&")); + assert!(!req_matches_route("/a/b?c=foo&d=z", "/a/b?d=b&")); + assert!(!req_matches_route("/a/b", "/a/b?c")); + assert!(!req_matches_route("/a/b", "/a/b?foo")); + assert!(!req_matches_route("/a/b", "/a/b?foo&")); + assert!(!req_matches_route("/a/b", "/a/b?&b&")); + } + + fn req_matches_format(m: Method, mt1: S1, mt2: S2) -> bool + where S1: Into>, S2: Into> + { + let client = Client::debug_with(vec![]).expect("client"); + let mut req = client.req(m, "/"); + if let Some(mt_str) = mt1.into() { + if m.supports_payload() { + req.replace_header(mt_str.parse::().unwrap()); + } else { + req.replace_header(mt_str.parse::().unwrap()); + } + } + + let mut route = Route::new(m, "/", dummy_handler); + if let Some(mt_str) = mt2.into() { + route.format = Some(mt_str.parse::().unwrap()); + } + + route.matches(&req) + } + + #[test] + fn test_req_route_mt_collisions() { + assert!(req_matches_format(Post, "application/json", "application/json")); + assert!(req_matches_format(Post, "application/json", "application/*")); + assert!(req_matches_format(Post, "application/json", "*/json")); + assert!(req_matches_format(Post, "text/html", "*/*")); + + assert!(req_matches_format(Get, "application/json", "application/json")); + assert!(req_matches_format(Get, "text/html", "text/html")); + assert!(req_matches_format(Get, "text/html", "*/*")); + assert!(req_matches_format(Get, None, "*/*")); + assert!(req_matches_format(Get, None, "text/*")); + assert!(req_matches_format(Get, None, "text/html")); + assert!(req_matches_format(Get, None, "application/json")); + + assert!(req_matches_format(Post, "text/html", None)); + assert!(req_matches_format(Post, "application/json", None)); + assert!(req_matches_format(Post, "x-custom/anything", None)); + assert!(req_matches_format(Post, None, None)); + + assert!(req_matches_format(Get, "text/html", None)); + assert!(req_matches_format(Get, "application/json", None)); + assert!(req_matches_format(Get, "x-custom/anything", None)); + assert!(req_matches_format(Get, None, None)); + assert!(req_matches_format(Get, None, "text/html")); + assert!(req_matches_format(Get, None, "application/json")); + + assert!(req_matches_format(Get, "text/html, text/plain", "text/html")); + assert!(req_matches_format(Get, "text/html; q=0.5, text/xml", "text/xml")); + + assert!(!req_matches_format(Post, None, "text/html")); + assert!(!req_matches_format(Post, None, "text/*")); + assert!(!req_matches_format(Post, None, "*/text")); + assert!(!req_matches_format(Post, None, "*/*")); + assert!(!req_matches_format(Post, None, "text/html")); + assert!(!req_matches_format(Post, None, "application/json")); + + assert!(!req_matches_format(Post, "application/json", "text/html")); + assert!(!req_matches_format(Post, "application/json", "text/*")); + assert!(!req_matches_format(Post, "application/json", "*/xml")); + assert!(!req_matches_format(Get, "application/json", "text/html")); + assert!(!req_matches_format(Get, "application/json", "text/*")); + assert!(!req_matches_format(Get, "application/json", "*/xml")); + + assert!(!req_matches_format(Post, None, "text/html")); + assert!(!req_matches_format(Post, None, "application/json")); + } +} diff --git a/core/lib/src/router/mod.rs b/core/lib/src/router/mod.rs index dc1a6621..c0bbccfb 100644 --- a/core/lib/src/router/mod.rs +++ b/core/lib/src/router/mod.rs @@ -2,6 +2,7 @@ mod router; mod collider; +mod matcher; pub(crate) use router::*; pub(crate) use collider::*; diff --git a/core/lib/src/router/router.rs b/core/lib/src/router/router.rs index 08408a4f..d74e6d3b 100644 --- a/core/lib/src/router/router.rs +++ b/core/lib/src/router/router.rs @@ -32,7 +32,7 @@ impl Router { pub fn add_catcher(&mut self, catcher: Catcher) { let catchers = self.catchers.entry(catcher.code).or_default(); catchers.push(catcher); - catchers.sort_by(|a, b| b.base.path().segments().len().cmp(&a.base.path().segments().len())) + catchers.sort_by_key(|c| c.rank); } #[inline] @@ -67,13 +67,8 @@ impl Router { match (explicit, default) { (None, None) => None, (None, c@Some(_)) | (c@Some(_), None) => c, - (Some(a), Some(b)) => { - if b.base.path().segments().len() > a.base.path().segments().len() { - Some(b) - } else { - Some(a) - } - } + (Some(a), Some(b)) if a.rank <= b.rank => Some(a), + (Some(_), Some(b)) => Some(b), } } @@ -194,15 +189,11 @@ mod test { #[test] fn test_collisions_normalize() { - assert!(rankless_route_collisions(&["/hello/", "/hello"])); - assert!(rankless_route_collisions(&["//hello/", "/hello"])); assert!(rankless_route_collisions(&["//hello/", "/hello//"])); - assert!(rankless_route_collisions(&["/", "/hello//"])); - assert!(rankless_route_collisions(&["/", "/hello///"])); assert!(rankless_route_collisions(&["/hello///bob", "/hello/"])); - assert!(rankless_route_collisions(&["///", "/a//"])); assert!(rankless_route_collisions(&["/a///", "/a/"])); assert!(rankless_route_collisions(&["/a///", "/a/b//c//d/"])); + assert!(rankless_route_collisions(&["///", "/a//"])); assert!(rankless_route_collisions(&["/a//", "/a/bd/e/"])); assert!(rankless_route_collisions(&["//", "/a/bd/e/"])); assert!(rankless_route_collisions(&["//", "/"])); @@ -233,6 +224,11 @@ mod test { #[test] fn test_no_collisions() { + assert!(!rankless_route_collisions(&["/a", "/a/"])); + assert!(!rankless_route_collisions(&["/", "/hello//"])); + assert!(!rankless_route_collisions(&["/", "/hello///"])); + assert!(!rankless_route_collisions(&["/hello/", "/hello"])); + assert!(!rankless_route_collisions(&["//hello/", "/hello"])); assert!(!rankless_route_collisions(&["/a/b", "/a/b/c"])); assert!(!rankless_route_collisions(&["/a/b/c/d", "/a/b/c//e"])); assert!(!rankless_route_collisions(&["/a/d/", "/a/b/c"])); @@ -309,7 +305,6 @@ mod test { let router = router_with_routes(&["//"]); assert!(route(&router, Get, "/hello/hi").is_some()); - assert!(route(&router, Get, "/a/b/").is_some()); assert!(route(&router, Get, "/i/a").is_some()); assert!(route(&router, Get, "/jdlk/asdij").is_some()); @@ -352,29 +347,33 @@ mod test { assert!(route(&router, Put, "/hello").is_none()); assert!(route(&router, Post, "/hello").is_none()); assert!(route(&router, Options, "/hello").is_none()); - assert!(route(&router, Get, "/hello/there").is_none()); - assert!(route(&router, Get, "/hello/i").is_none()); + assert!(route(&router, Get, "/").is_none()); + assert!(route(&router, Get, "/hello/").is_none()); + assert!(route(&router, Get, "/hello/there/").is_none()); + assert!(route(&router, Get, "/hello/there/").is_none()); let router = router_with_routes(&["//"]); assert!(route(&router, Get, "/a/b/c").is_none()); assert!(route(&router, Get, "/a").is_none()); assert!(route(&router, Get, "/a/").is_none()); assert!(route(&router, Get, "/a/b/c/d").is_none()); + assert!(route(&router, Get, "/a/b/").is_none()); assert!(route(&router, Put, "/hello/hi").is_none()); assert!(route(&router, Put, "/a/b").is_none()); - assert!(route(&router, Put, "/a/b").is_none()); let router = router_with_routes(&["/prefix/"]); assert!(route(&router, Get, "/").is_none()); assert!(route(&router, Get, "/prefi/").is_none()); } + /// Asserts that `$to` routes to `$want` given `$routes` are present. macro_rules! assert_ranked_match { ($routes:expr, $to:expr => $want:expr) => ({ let router = router_with_routes($routes); assert!(!router.has_collisions()); let route_path = route(&router, Get, $to).unwrap().uri.to_string(); - assert_eq!(route_path, $want.to_string()); + assert_eq!(route_path, $want.to_string(), + "\nmatched {} with {}, wanted {} in {:#?}", $to, route_path, $want, router); }) } @@ -578,8 +577,11 @@ mod test { for (req, expected) in requests.iter().zip(expected.iter()) { let req_status = Status::from_code(req.0).expect("valid status"); let catcher = catcher(&router, req_status, req.1).expect("some catcher"); - assert_eq!(catcher.code, expected.0, "<- got, expected ->"); - assert_eq!(catcher.base.path(), expected.1, "<- got, expected ->"); + assert_eq!(catcher.code, expected.0, + "\nmatched {}, expected {:?} for req {:?}", catcher, expected, req); + + assert_eq!(catcher.base.path(), expected.1, + "\nmatched {}, expected {:?} for req {:?}", catcher, expected, req); } }) } diff --git a/core/lib/src/sentinel.rs b/core/lib/src/sentinel.rs index 39df5813..df503c31 100644 --- a/core/lib/src/sentinel.rs +++ b/core/lib/src/sentinel.rs @@ -263,7 +263,7 @@ use crate::{Rocket, Ignite}; /// return true; /// } /// -/// if !rocket.catchers().any(|c| c.code == Some(400) && c.base == "/") { +/// if !rocket.catchers().any(|c| c.code == Some(400) && c.base() == "/") { /// return true; /// } /// diff --git a/examples/hello/src/main.rs b/examples/hello/src/main.rs index 0f8c55cb..e9c1081a 100644 --- a/examples/hello/src/main.rs +++ b/examples/hello/src/main.rs @@ -74,6 +74,10 @@ fn hello(lang: Option, opt: Options<'_>) -> String { #[launch] fn rocket() -> _ { + // FIXME: Check docs corresponding to normalization/matching/colliding. + // FUZZ: If rand_req1.matches(foo) && rand_req2.matches(bar) => + // rand_req1.collides_with(rand_req2) + rocket::build() .mount("/", routes![hello]) .mount("/hello", routes![world, mir])