From d24b5d4d6de1460003feda5f39b339942d60e67d Mon Sep 17 00:00:00 2001 From: Sergio Benitez Date: Thu, 4 May 2023 17:30:37 -0700 Subject: [PATCH] Handle more cases in 'AdHoc::normalizer()'. The compatibility normalizer previously missed or was overly egregious in several cases. This commit resolves those issue. In particular: * Only request URIs that would not match any route are normalized. * Synthetic routes are added to the igniting `Rocket` so that requests with URIs of the form `/foo` match routes with URIs of the form `/foo/`, as they did prior to the trailing slash overhaul. Tests are added for all of these cases. --- core/lib/src/fairing/ad_hoc.rs | 94 +++++++++++++++++++++++--- core/lib/src/route/route.rs | 4 ++ core/lib/tests/adhoc-uri-normalizer.rs | 79 ++++++++++++++++++++++ 3 files changed, 168 insertions(+), 9 deletions(-) create mode 100644 core/lib/tests/adhoc-uri-normalizer.rs diff --git a/core/lib/src/fairing/ad_hoc.rs b/core/lib/src/fairing/ad_hoc.rs index 4a4d75ee..4da164b3 100644 --- a/core/lib/src/fairing/ad_hoc.rs +++ b/core/lib/src/fairing/ad_hoc.rs @@ -1,6 +1,7 @@ use futures::future::{Future, BoxFuture, FutureExt}; use parking_lot::Mutex; +use crate::route::RouteUri; use crate::{Rocket, Request, Response, Data, Build, Orbit}; use crate::fairing::{Fairing, Kind, Info, Result}; @@ -309,16 +310,91 @@ impl AdHoc { /// let response = client.get("/bar").dispatch(); /// assert_eq!(response.into_string().unwrap(), "bar"); /// ``` - #[deprecated(since = "0.6", note = "routing from Rocket v0.5 is now standard")] - pub fn uri_normalizer() -> AdHoc { - AdHoc::on_request("URI Normalizer", |req, _| Box::pin(async move { - if !req.uri().is_normalized_nontrailing() { - let normal = req.uri().clone().into_normalized_nontrailing(); - warn!("Incoming request URI was normalized for compatibility."); - info_!("{} -> {}", req.uri(), normal); - req.set_uri(normal); + // #[deprecated(since = "0.6", note = "routing from Rocket v0.5 is now standard")] + pub fn uri_normalizer() -> impl Fairing { + #[derive(Default)] + struct Normalizer { + routes: state::Storage>, + } + + impl Normalizer { + fn routes(&self, rocket: &Rocket) -> &[crate::Route] { + self.routes.get_or_set(|| { + rocket.routes() + .filter(|r| r.uri.has_trailing_slash() || r.uri.metadata.dynamic_trail) + .cloned() + .collect() + }) } - })) + } + + #[crate::async_trait] + impl Fairing for Normalizer { + fn info(&self) -> Info { + Info { name: "URI Normalizer", kind: Kind::Ignite | Kind::Liftoff | Kind::Request } + } + + async fn on_ignite(&self, rocket: Rocket) -> Result { + // We want a route like `/foo/` to match a request for + // `/foo` as it would have before. While we could check if a + // route is mounted that would cause this match and then rewrite + // the request URI as `/foo/`, doing so is expensive and + // potentially incorrect due to request guards and ranking. + // + // Instead, we generate a new route with URI `/foo` with the + // same rank and handler as the `/foo/` route and mount + // it to this instance of `rocket`. This preserves the previous + // matching while still checking request guards. + let normalized_trailing = rocket.routes() + .filter(|r| r.uri.metadata.dynamic_trail) + .filter(|r| r.uri.path().segments().num() > 1) + .filter_map(|route| { + let path = route.uri.unmounted().path(); + let new_path = path.as_str() + .rsplit_once('/') + .map(|(prefix, _)| prefix) + .unwrap_or(path.as_str()); + + let base = route.uri.base().as_str(); + let uri = match route.uri.unmounted().query() { + Some(q) => format!("{}?{}", new_path, q), + None => new_path.to_string() + }; + + let mut route = route.clone(); + route.uri = RouteUri::try_new(base, &uri).expect("valid => valid"); + route.name = route.name.map(|r| format!("{} [normalized]", r).into()); + Some(route) + }) + .collect::>(); + + Ok(rocket.mount("/", normalized_trailing)) + } + + async fn on_liftoff(&self, rocket: &Rocket) { + let _ = self.routes(rocket); + } + + async fn on_request(&self, req: &mut Request<'_>, _: &mut Data<'_>) { + // If the URI has no trailing slash, it routes as before. + if req.uri().is_normalized_nontrailing() { + return + } + + // Otherwise, check if there's a route that matches the request + // with a trailing slash. If there is, leave the request alone. + // This allows incremental compatibility updates. Otherwise, + // rewrite the request URI to remove the `/`. + if !self.routes(req.rocket()).iter().any(|r| r.matches(req)) { + let normal = req.uri().clone().into_normalized_nontrailing(); + warn!("Incoming request URI was normalized for compatibility."); + info_!("{} -> {}", req.uri(), normal); + req.set_uri(normal); + } + } + } + + Normalizer::default() } } diff --git a/core/lib/src/route/route.rs b/core/lib/src/route/route.rs index 131d2a30..44ef2790 100644 --- a/core/lib/src/route/route.rs +++ b/core/lib/src/route/route.rs @@ -278,6 +278,10 @@ impl Route { /// let rebased = rebased.rebase(uri!("/base")); /// assert_eq!(rebased.uri.base(), "/base/boo"); /// + /// // Rebasing to `/` does nothing. + /// let rebased = rebased.rebase(uri!("/")); + /// assert_eq!(rebased.uri.base(), "/base/boo"); + /// /// // Note that trailing slashes are preserved: /// let index = Route::new(Method::Get, "/foo", handler); /// let rebased = index.rebase(uri!("/boo/")); diff --git a/core/lib/tests/adhoc-uri-normalizer.rs b/core/lib/tests/adhoc-uri-normalizer.rs new file mode 100644 index 00000000..a3774f04 --- /dev/null +++ b/core/lib/tests/adhoc-uri-normalizer.rs @@ -0,0 +1,79 @@ +#[macro_use] extern crate rocket; + +use std::path::PathBuf; + +use rocket::local::blocking::Client; +use rocket::fairing::AdHoc; + +#[get("/foo")] +fn foo() -> &'static str { "foo" } + +#[get("/bar")] +fn not_bar() -> &'static str { "not_bar" } + +#[get("/bar/")] +fn bar() -> &'static str { "bar" } + +#[get("/foo/<_>/<_baz..>")] +fn baz(_baz: PathBuf) -> &'static str { "baz" } + +#[get("/doggy/<_>/<_baz..>?doggy")] +fn doggy(_baz: PathBuf) -> &'static str { "doggy" } + +#[test] +fn test_adhoc_normalizer_works_as_expected () { + let rocket = rocket::build() + .mount("/", routes![foo, bar, not_bar, baz, doggy]) + .mount("/base", routes![foo, bar, not_bar, baz, doggy]) + .attach(AdHoc::uri_normalizer()); + + let client = Client::debug(rocket).unwrap(); + + let response = client.get("/foo/").dispatch(); + assert_eq!(response.into_string().unwrap(), "foo"); + + let response = client.get("/foo").dispatch(); + assert_eq!(response.into_string().unwrap(), "foo"); + + let response = client.get("/bar/").dispatch(); + assert_eq!(response.into_string().unwrap(), "bar"); + + let response = client.get("/bar").dispatch(); + assert_eq!(response.into_string().unwrap(), "not_bar"); + + let response = client.get("/foo/bar").dispatch(); + assert_eq!(response.into_string().unwrap(), "baz"); + + let response = client.get("/doggy/bar?doggy").dispatch(); + assert_eq!(response.into_string().unwrap(), "doggy"); + + let response = client.get("/foo/bar/").dispatch(); + assert_eq!(response.into_string().unwrap(), "baz"); + + let response = client.get("/foo/bar/baz").dispatch(); + assert_eq!(response.into_string().unwrap(), "baz"); + + let response = client.get("/base/foo/").dispatch(); + assert_eq!(response.into_string().unwrap(), "foo"); + + let response = client.get("/base/foo").dispatch(); + assert_eq!(response.into_string().unwrap(), "foo"); + + let response = client.get("/base/bar/").dispatch(); + assert_eq!(response.into_string().unwrap(), "bar"); + + let response = client.get("/base/bar").dispatch(); + assert_eq!(response.into_string().unwrap(), "not_bar"); + + let response = client.get("/base/foo/bar").dispatch(); + assert_eq!(response.into_string().unwrap(), "baz"); + + let response = client.get("/doggy/foo/bar?doggy").dispatch(); + assert_eq!(response.into_string().unwrap(), "doggy"); + + let response = client.get("/base/foo/bar/").dispatch(); + assert_eq!(response.into_string().unwrap(), "baz"); + + let response = client.get("/base/foo/bar/baz").dispatch(); + assert_eq!(response.into_string().unwrap(), "baz"); +}