From be522783ecb3d470e3f3ba01d405adc9357ba622 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jesper=20Steen=20M=C3=B8ller?= Date: Thu, 5 Sep 2024 22:18:14 +0200 Subject: [PATCH] Fix #1224 by searching routes after failed match --- core/codegen/tests/route-format.rs | 9 ++-- core/codegen/tests/route.rs | 4 +- core/lib/src/lifecycle.rs | 19 +++++++ core/lib/src/router/matcher.rs | 6 +-- core/lib/src/router/router.rs | 53 +++++++++++++++++++ core/lib/tests/form_method-issue-45.rs | 2 +- .../tests/precise-content-type-matching.rs | 2 +- docs/guide/05-requests.md | 7 +++ examples/error-handling/src/tests.rs | 9 ++++ examples/responders/src/tests.rs | 3 +- examples/templating/src/tests.rs | 3 +- 11 files changed, 103 insertions(+), 14 deletions(-) diff --git a/core/codegen/tests/route-format.rs b/core/codegen/tests/route-format.rs index 221ce975..b7b94c72 100644 --- a/core/codegen/tests/route-format.rs +++ b/core/codegen/tests/route-format.rs @@ -61,7 +61,7 @@ fn test_formats() { assert_eq!(response.into_string().unwrap(), "plain"); let response = client.put("/").header(ContentType::HTML).dispatch(); - assert_eq!(response.status(), Status::NotFound); + assert_eq!(response.status(), Status::MethodNotAllowed); } // Test custom formats. @@ -109,9 +109,12 @@ fn test_custom_formats() { let response = client.get("/").dispatch(); assert_eq!(response.into_string().unwrap(), "get_foo"); + let response = client.get("/").header(Accept::JPEG).dispatch(); + assert_eq!(response.status(), Status::NotAcceptable); // Route can't produce JPEG + let response = client.put("/").header(ContentType::HTML).dispatch(); - assert_eq!(response.status(), Status::NotFound); + assert_eq!(response.status(), Status::UnsupportedMediaType); // Route expects "bar/baz" let response = client.post("/").header(ContentType::HTML).dispatch(); - assert_eq!(response.status(), Status::NotFound); + assert_eq!(response.status(), Status::UnsupportedMediaType); // Route expects "foo" } diff --git a/core/codegen/tests/route.rs b/core/codegen/tests/route.rs index f9a1bf66..96855e1e 100644 --- a/core/codegen/tests/route.rs +++ b/core/codegen/tests/route.rs @@ -105,7 +105,7 @@ fn test_full_route() { assert_eq!(response.status(), Status::NotFound); let response = client.post(format!("/1{}", uri)).body(simple).dispatch(); - assert_eq!(response.status(), Status::NotFound); + assert_eq!(response.status(), Status::UnsupportedMediaType); let response = client .post(format!("/1{}", uri)) @@ -117,7 +117,7 @@ fn test_full_route() { sky, name.percent_decode().unwrap(), "A A", "inside", path, simple, expected_uri)); let response = client.post(format!("/2{}", uri)).body(simple).dispatch(); - assert_eq!(response.status(), Status::NotFound); + assert_eq!(response.status(), Status::UnsupportedMediaType); let response = client .post(format!("/2{}", uri)) diff --git a/core/lib/src/lifecycle.rs b/core/lib/src/lifecycle.rs index cf76e3be..434c05a6 100644 --- a/core/lib/src/lifecycle.rs +++ b/core/lib/src/lifecycle.rs @@ -112,6 +112,25 @@ impl Rocket { Outcome::Forward((_, status)) => self.dispatch_error(status, request).await, } } + Outcome::Forward((_, status)) if status == Status::NotFound => { + // We failed to find a route which matches on path, query AND formats. + // Before we return the error to the user, we'll check for "near misses". + + // This code path primarily exists to assist clients in debugging their requests. + let mut status = status; + if self.router.matches_except_formats(request) { + // Tailor the error code to the interpretation of the request in question. + if request.method().allows_request_body() == Some(true) { + status = Status::UnsupportedMediaType; + } else if request.headers().contains("Accept") { + status = Status::NotAcceptable; + } + } else if self.router.matches_except_method(request) { + // Found a more suitable error code for paths implemented on different methods. + status = Status::MethodNotAllowed; + } + self.dispatch_error(status, request).await + } Outcome::Forward((_, status)) => self.dispatch_error(status, request).await, Outcome::Error(status) => self.dispatch_error(status, request).await, }; diff --git a/core/lib/src/router/matcher.rs b/core/lib/src/router/matcher.rs index 5cb5b918..c12e3ba6 100644 --- a/core/lib/src/router/matcher.rs +++ b/core/lib/src/router/matcher.rs @@ -144,7 +144,7 @@ fn methods_match(route: &Route, req: &Request<'_>) -> bool { route.method.map_or(true, |method| method == req.method()) } -fn paths_match(route: &Route, req: &Request<'_>) -> bool { +pub(crate) fn paths_match(route: &Route, req: &Request<'_>) -> bool { trace!(route.uri = %route.uri, request.uri = %req.uri()); let route_segments = &route.uri.metadata.uri_segments; let req_segments = req.uri().path().segments(); @@ -174,7 +174,7 @@ fn paths_match(route: &Route, req: &Request<'_>) -> bool { true } -fn queries_match(route: &Route, req: &Request<'_>) -> bool { +pub(crate) fn queries_match(route: &Route, req: &Request<'_>) -> bool { trace!( route.query = route.uri.query().map(display), route.query.color = route.uri.metadata.query_color.map(debug), @@ -201,7 +201,7 @@ fn queries_match(route: &Route, req: &Request<'_>) -> bool { true } -fn formats_match(route: &Route, req: &Request<'_>) -> bool { +pub(crate) fn formats_match(route: &Route, req: &Request<'_>) -> bool { trace!( route.format = route.format.as_ref().map(display), request.format = req.format().map(display), diff --git a/core/lib/src/router/router.rs b/core/lib/src/router/router.rs index 017486a6..d0d0ab54 100644 --- a/core/lib/src/router/router.rs +++ b/core/lib/src/router/router.rs @@ -6,6 +6,8 @@ use crate::http::{Method, Status}; use crate::{Route, Catcher}; use crate::router::Collide; +use super::matcher::{paths_match, queries_match, formats_match}; + #[derive(Debug)] pub(crate) struct Router(T); @@ -104,6 +106,38 @@ impl Router { .filter(move |r| r.matches(req)) } + pub(crate) fn matches_except_formats<'r, 'a: 'r>( + &'a self, + req: &'r Request<'r> + ) -> bool { + self.route_map.get(&req.method()) + .into_iter() + .flatten() + .any(|&route| + paths_match(&self.routes[route], req) + && + queries_match(&self.routes[route], req) + && + !formats_match(&self.routes[route], req)) + } + + const ALL_METHODS: &'static [Method] = &[ + Method::Get, Method::Put, Method::Post, Method::Delete, Method::Options, + Method::Head, Method::Trace, Method::Connect, Method::Patch, + ]; + + pub(crate) fn matches_except_method<'r, 'a: 'r>( + &'a self, + req: &'r Request<'r> + ) -> bool { + Self::ALL_METHODS + .iter() + .filter(|method| *method != &req.method()) + .filter_map(|method| self.route_map.get(method)) + .flatten() + .any(|route| paths_match(&self.routes[*route], req)) + } + // For many catchers, using aho-corasick or similar should be much faster. #[track_caller] pub fn catch<'r>(&self, status: Status, req: &'r Request<'r>) -> Option<&Catcher> { @@ -396,6 +430,25 @@ mod test { assert!(route(&router, Get, "/prefi/").is_none()); } + fn has_mismatched_method<'a>( + router: &'a Router, + method: Method, uri: &'a str + ) -> bool { + let client = Client::debug_with(vec![]).expect("client"); + let request = client.req(method, Origin::parse(uri).unwrap()); + router.matches_except_method(&request) + } + + #[test] + fn test_bad_method_routing() { + let router = router_with_routes(&["/hello"]); + assert!(route(&router, Put, "/hello").is_none()); + assert!(has_mismatched_method(&router, Put, "/hello")); + assert!(has_mismatched_method(&router, Post, "/hello")); + + assert!(! has_mismatched_method(&router, Get, "/hello")); + } + /// Asserts that `$to` routes to `$want` given `$routes` are present. macro_rules! assert_ranked_match { ($routes:expr, $to:expr => $want:expr) => ({ diff --git a/core/lib/tests/form_method-issue-45.rs b/core/lib/tests/form_method-issue-45.rs index 4e57a48b..e4779c7c 100644 --- a/core/lib/tests/form_method-issue-45.rs +++ b/core/lib/tests/form_method-issue-45.rs @@ -70,6 +70,6 @@ mod tests { .body("_method=patch&form_data=Form+data") .dispatch(); - assert_eq!(response.status(), Status::NotFound); + assert_eq!(response.status(), Status::MethodNotAllowed); } } diff --git a/core/lib/tests/precise-content-type-matching.rs b/core/lib/tests/precise-content-type-matching.rs index 6848c738..e57a431b 100644 --- a/core/lib/tests/precise-content-type-matching.rs +++ b/core/lib/tests/precise-content-type-matching.rs @@ -48,7 +48,7 @@ mod tests { let body: Option<&'static str> = $body; match body { Some(string) => assert_eq!(body_str, Some(string.to_string())), - None => assert_eq!(status, Status::NotFound) + None => assert_eq!(status, Status::UnsupportedMediaType) } ) } diff --git a/docs/guide/05-requests.md b/docs/guide/05-requests.md index d3ad9792..7ef3a15e 100644 --- a/docs/guide/05-requests.md +++ b/docs/guide/05-requests.md @@ -2074,6 +2074,13 @@ fn main() { } ``` +Besides `404 Not Found` for unknown URIs, Rocket may also produce +`405 Method Not Allowed` if a request matches a URI but not a declared method +for that URI. For routes declaring formats, Rocket will produce +`406 Not Acceptable` status for a client request _accepting_ a format which +isn't declared by the matching routes, or `415 Unsupported Media Type` in case +the _payload_ of a `PUT` or `POST` is not allowed by the route. + ### Scoping The first argument to `register()` is a path to scope the catcher under called diff --git a/examples/error-handling/src/tests.rs b/examples/error-handling/src/tests.rs index fcd78424..d1dfa2e6 100644 --- a/examples/error-handling/src/tests.rs +++ b/examples/error-handling/src/tests.rs @@ -64,6 +64,15 @@ fn test_hello_invalid_age() { } } +#[test] +fn test_method_not_allowed() { + let client = Client::tracked(super::rocket()).unwrap(); + let (name, age) = ("Pat", 86); + let request = client.post(format!("/hello/{}/{}", name, age)).body("body"); + let response = request.dispatch(); + assert_eq!(response.status(), Status::MethodNotAllowed); +} + #[test] fn test_hello_sergio() { let client = Client::tracked(super::rocket()).unwrap(); diff --git a/examples/responders/src/tests.rs b/examples/responders/src/tests.rs index a084b1be..c1e76ca9 100644 --- a/examples/responders/src/tests.rs +++ b/examples/responders/src/tests.rs @@ -126,8 +126,7 @@ fn test_xml() { assert_eq!(r.into_string().unwrap(), r#"{ "payload": "I'm here" }"#); let r = client.get(uri!(super::xml)).header(Accept::CSV).dispatch(); - assert_eq!(r.status(), Status::NotFound); - assert!(r.into_string().unwrap().contains("not supported")); + assert_eq!(r.status(), Status::NotAcceptable); let r = client.get("/content/i/dont/exist").header(Accept::HTML).dispatch(); assert_eq!(r.content_type().unwrap(), ContentType::HTML); diff --git a/examples/templating/src/tests.rs b/examples/templating/src/tests.rs index 44984038..38e929d7 100644 --- a/examples/templating/src/tests.rs +++ b/examples/templating/src/tests.rs @@ -22,8 +22,7 @@ fn test_root(kind: &str) { let expected = Template::show(client.rocket(), format!("{}/error/404", kind), &context); let response = client.req(*method, format!("/{}", kind)).dispatch(); - assert_eq!(response.status(), Status::NotFound); - assert_eq!(response.into_string(), expected); + assert_eq!(response.status(), Status::MethodNotAllowed); } }