Handle more cases in 'AdHoc::normalizer()'.

The compatibility normalizer previously missed or was overly egregious
in several cases. This commit resolves those issue. In particular:

  * Only request URIs that would not match any route are normalized.

  * Synthetic routes are added to the igniting `Rocket` so that requests
    with URIs of the form `/foo` match routes with URIs of the form
    `/foo/<b..>`, as they did prior to the trailing slash overhaul.

Tests are added for all of these cases.
This commit is contained in:
Sergio Benitez 2023-05-04 17:30:37 -07:00
parent 541952bc58
commit d24b5d4d6d
3 changed files with 168 additions and 9 deletions

View File

@ -1,6 +1,7 @@
use futures::future::{Future, BoxFuture, FutureExt};
use parking_lot::Mutex;
use crate::route::RouteUri;
use crate::{Rocket, Request, Response, Data, Build, Orbit};
use crate::fairing::{Fairing, Kind, Info, Result};
@ -309,16 +310,91 @@ impl AdHoc {
/// let response = client.get("/bar").dispatch();
/// assert_eq!(response.into_string().unwrap(), "bar");
/// ```
#[deprecated(since = "0.6", note = "routing from Rocket v0.5 is now standard")]
pub fn uri_normalizer() -> AdHoc {
AdHoc::on_request("URI Normalizer", |req, _| Box::pin(async move {
if !req.uri().is_normalized_nontrailing() {
let normal = req.uri().clone().into_normalized_nontrailing();
warn!("Incoming request URI was normalized for compatibility.");
info_!("{} -> {}", req.uri(), normal);
req.set_uri(normal);
// #[deprecated(since = "0.6", note = "routing from Rocket v0.5 is now standard")]
pub fn uri_normalizer() -> impl Fairing {
#[derive(Default)]
struct Normalizer {
routes: state::Storage<Vec<crate::Route>>,
}
impl Normalizer {
fn routes(&self, rocket: &Rocket<Orbit>) -> &[crate::Route] {
self.routes.get_or_set(|| {
rocket.routes()
.filter(|r| r.uri.has_trailing_slash() || r.uri.metadata.dynamic_trail)
.cloned()
.collect()
})
}
}))
}
#[crate::async_trait]
impl Fairing for Normalizer {
fn info(&self) -> Info {
Info { name: "URI Normalizer", kind: Kind::Ignite | Kind::Liftoff | Kind::Request }
}
async fn on_ignite(&self, rocket: Rocket<Build>) -> Result {
// We want a route like `/foo/<bar..>` to match a request for
// `/foo` as it would have before. While we could check if a
// route is mounted that would cause this match and then rewrite
// the request URI as `/foo/`, doing so is expensive and
// potentially incorrect due to request guards and ranking.
//
// Instead, we generate a new route with URI `/foo` with the
// same rank and handler as the `/foo/<bar..>` route and mount
// it to this instance of `rocket`. This preserves the previous
// matching while still checking request guards.
let normalized_trailing = rocket.routes()
.filter(|r| r.uri.metadata.dynamic_trail)
.filter(|r| r.uri.path().segments().num() > 1)
.filter_map(|route| {
let path = route.uri.unmounted().path();
let new_path = path.as_str()
.rsplit_once('/')
.map(|(prefix, _)| prefix)
.unwrap_or(path.as_str());
let base = route.uri.base().as_str();
let uri = match route.uri.unmounted().query() {
Some(q) => format!("{}?{}", new_path, q),
None => new_path.to_string()
};
let mut route = route.clone();
route.uri = RouteUri::try_new(base, &uri).expect("valid => valid");
route.name = route.name.map(|r| format!("{} [normalized]", r).into());
Some(route)
})
.collect::<Vec<_>>();
Ok(rocket.mount("/", normalized_trailing))
}
async fn on_liftoff(&self, rocket: &Rocket<Orbit>) {
let _ = self.routes(rocket);
}
async fn on_request(&self, req: &mut Request<'_>, _: &mut Data<'_>) {
// If the URI has no trailing slash, it routes as before.
if req.uri().is_normalized_nontrailing() {
return
}
// Otherwise, check if there's a route that matches the request
// with a trailing slash. If there is, leave the request alone.
// This allows incremental compatibility updates. Otherwise,
// rewrite the request URI to remove the `/`.
if !self.routes(req.rocket()).iter().any(|r| r.matches(req)) {
let normal = req.uri().clone().into_normalized_nontrailing();
warn!("Incoming request URI was normalized for compatibility.");
info_!("{} -> {}", req.uri(), normal);
req.set_uri(normal);
}
}
}
Normalizer::default()
}
}

View File

@ -278,6 +278,10 @@ impl Route {
/// let rebased = rebased.rebase(uri!("/base"));
/// assert_eq!(rebased.uri.base(), "/base/boo");
///
/// // Rebasing to `/` does nothing.
/// let rebased = rebased.rebase(uri!("/"));
/// assert_eq!(rebased.uri.base(), "/base/boo");
///
/// // Note that trailing slashes are preserved:
/// let index = Route::new(Method::Get, "/foo", handler);
/// let rebased = index.rebase(uri!("/boo/"));

View File

@ -0,0 +1,79 @@
#[macro_use] extern crate rocket;
use std::path::PathBuf;
use rocket::local::blocking::Client;
use rocket::fairing::AdHoc;
#[get("/foo")]
fn foo() -> &'static str { "foo" }
#[get("/bar")]
fn not_bar() -> &'static str { "not_bar" }
#[get("/bar/")]
fn bar() -> &'static str { "bar" }
#[get("/foo/<_>/<_baz..>")]
fn baz(_baz: PathBuf) -> &'static str { "baz" }
#[get("/doggy/<_>/<_baz..>?doggy")]
fn doggy(_baz: PathBuf) -> &'static str { "doggy" }
#[test]
fn test_adhoc_normalizer_works_as_expected () {
let rocket = rocket::build()
.mount("/", routes![foo, bar, not_bar, baz, doggy])
.mount("/base", routes![foo, bar, not_bar, baz, doggy])
.attach(AdHoc::uri_normalizer());
let client = Client::debug(rocket).unwrap();
let response = client.get("/foo/").dispatch();
assert_eq!(response.into_string().unwrap(), "foo");
let response = client.get("/foo").dispatch();
assert_eq!(response.into_string().unwrap(), "foo");
let response = client.get("/bar/").dispatch();
assert_eq!(response.into_string().unwrap(), "bar");
let response = client.get("/bar").dispatch();
assert_eq!(response.into_string().unwrap(), "not_bar");
let response = client.get("/foo/bar").dispatch();
assert_eq!(response.into_string().unwrap(), "baz");
let response = client.get("/doggy/bar?doggy").dispatch();
assert_eq!(response.into_string().unwrap(), "doggy");
let response = client.get("/foo/bar/").dispatch();
assert_eq!(response.into_string().unwrap(), "baz");
let response = client.get("/foo/bar/baz").dispatch();
assert_eq!(response.into_string().unwrap(), "baz");
let response = client.get("/base/foo/").dispatch();
assert_eq!(response.into_string().unwrap(), "foo");
let response = client.get("/base/foo").dispatch();
assert_eq!(response.into_string().unwrap(), "foo");
let response = client.get("/base/bar/").dispatch();
assert_eq!(response.into_string().unwrap(), "bar");
let response = client.get("/base/bar").dispatch();
assert_eq!(response.into_string().unwrap(), "not_bar");
let response = client.get("/base/foo/bar").dispatch();
assert_eq!(response.into_string().unwrap(), "baz");
let response = client.get("/doggy/foo/bar?doggy").dispatch();
assert_eq!(response.into_string().unwrap(), "doggy");
let response = client.get("/base/foo/bar/").dispatch();
assert_eq!(response.into_string().unwrap(), "baz");
let response = client.get("/base/foo/bar/baz").dispatch();
assert_eq!(response.into_string().unwrap(), "baz");
}