From 8749d7293a685b2f35534bfc7ba84c48c3a220e9 Mon Sep 17 00:00:00 2001 From: Sergio Benitez Date: Fri, 19 Mar 2021 03:41:52 -0700 Subject: [PATCH] Simplify and optimize router. This surfaced a dormant concurrency related issue. Prior to this commit, the router used `routed_segments()` to retrieve the path segments of the request. This was okay as there was no route in the request, and matched segments were retrieved eagerly. This commit makes segment matching lazy, so no matching occurs if unnecessary. Between two matches, a `route` is atomically set of `Request`. This is now visible in `routed_segments()`, which should not have considered the current route in the first place. This was fixed. --- core/lib/src/rocket.rs | 6 ++- core/lib/src/router/collider.rs | 2 +- core/lib/src/router/mod.rs | 76 +++++++++------------------------ core/lib/src/server.rs | 3 +- 4 files changed, 26 insertions(+), 61 deletions(-) diff --git a/core/lib/src/rocket.rs b/core/lib/src/rocket.rs index b8fb7904..6c8e0148 100644 --- a/core/lib/src/rocket.rs +++ b/core/lib/src/rocket.rs @@ -525,8 +525,10 @@ impl Rocket { /// * there were no fairing failures /// * a secret key, if needed, is securely configured pub(crate) async fn prelaunch_check(&mut self) -> Result<(), Error> { - if let Err(e) = self.router.collisions() { - return Err(Error::new(ErrorKind::Collision(e))); + let collisions: Vec<_> = self.router.collisions().collect(); + if !collisions.is_empty() { + let owned = collisions.into_iter().map(|(a, b)| (a.clone(), b.clone())); + return Err(Error::new(ErrorKind::Collision(owned.collect()))); } if let Some(failures) = self.fairings.failures() { diff --git a/core/lib/src/router/collider.rs b/core/lib/src/router/collider.rs index 088e64f5..50fe7dc7 100644 --- a/core/lib/src/router/collider.rs +++ b/core/lib/src/router/collider.rs @@ -71,7 +71,7 @@ fn paths_collide(route: &Route, other: &Route) -> bool { fn paths_match(route: &Route, req: &Request<'_>) -> bool { let route_segments = &route.uri.metadata.path_segs; - let req_segments = req.routed_segments(0..); + let req_segments = req.uri().path_segments(); if route_segments.len() > req_segments.len() + 1 { return false; } diff --git a/core/lib/src/router/mod.rs b/core/lib/src/router/mod.rs index d1ef2fb5..c203ac86 100644 --- a/core/lib/src/router/mod.rs +++ b/core/lib/src/router/mod.rs @@ -7,7 +7,6 @@ use std::collections::HashMap; use crate::request::Request; use crate::http::Method; -use crate::handler::dummy; pub use self::route::Route; pub use self::uri::RouteUri; @@ -34,64 +33,33 @@ impl Router { entries.insert(i, route); } - pub fn route<'b>(&'b self, req: &Request<'_>) -> Vec<&'b Route> { + pub fn route<'r, 'a: 'r>(&'a self, req: &'r Request<'r>) -> impl Iterator + 'r { // Note that routes are presorted by rank on each `add`. - let matches = self.routes.get(&req.method()).map_or(vec![], |routes| { - routes.iter() - .filter(|r| r.matches(req)) - .collect() - }); - - trace_!("Routing the request: {}", req); - trace_!("All matches: {:?}", matches); - matches + self.routes.get(&req.method()) + .into_iter() + .flat_map(move |routes| routes.iter().filter(move |r| r.matches(req))) } - pub(crate) fn collisions(&mut self) -> Result<(), Vec<(Route, Route)>> { - let mut collisions = vec![]; - for routes in self.routes.values_mut() { - for i in 0..routes.len() { - let (left, right) = routes.split_at_mut(i); - for a_route in left.iter_mut() { - for b_route in right.iter_mut() { - if a_route.collides_with(b_route) { - let dummy_a = Route::new(Method::Get, "/", dummy); - let a = std::mem::replace(a_route, dummy_a); - let dummy_b = Route::new(Method::Get, "/", dummy); - let b = std::mem::replace(b_route, dummy_b); - collisions.push((a, b)); - } - } - } - } - } - - if collisions.is_empty() { - Ok(()) - } else { - Err(collisions) - } + pub(crate) fn collisions(&self) -> impl Iterator { + let all_routes = self.routes.values().flat_map(|v| v.iter()); + all_routes.clone().enumerate() + .flat_map(move |(i, a)| { + all_routes.clone() + .skip(i + 1) + .filter(move |b| b.collides_with(a)) + .map(move |b| (a, b)) + }) } #[inline] - pub fn routes<'a>(&'a self) -> impl Iterator + 'a { + pub fn routes(&self) -> impl Iterator { self.routes.values().flat_map(|v| v.iter()) } // This is slow. Don't expose this publicly; only for tests. #[cfg(test)] fn has_collisions(&self) -> bool { - for routes in self.routes.values() { - for (i, a_route) in routes.iter().enumerate() { - for b_route in routes.iter().skip(i + 1) { - if a_route.collides_with(b_route) { - return true; - } - } - } - } - - false + self.collisions().next().is_some() } } @@ -266,21 +234,17 @@ mod test { assert!(!default_rank_route_collisions(&["/hi?", "/hi?c"])); } - fn route<'a>(router: &'a Router, method: Method, uri: &str) -> Option<&'a Route> { + fn route<'a>(router: &'a Router, method: Method, uri: &'a str) -> Option<&'a Route> { let rocket = Rocket::custom(Config::default()); let request = Request::new(&rocket, method, Origin::parse(uri).unwrap()); - let matches = router.route(&request); - if matches.len() > 0 { - Some(matches[0]) - } else { - None - } + let route = router.route(&request).next(); + route } - fn matches<'a>(router: &'a Router, method: Method, uri: &str) -> Vec<&'a Route> { + fn matches<'a>(router: &'a Router, method: Method, uri: &'a str) -> Vec<&'a Route> { let rocket = Rocket::custom(Config::default()); let request = Request::new(&rocket, method, Origin::parse(uri).unwrap()); - router.route(&request) + router.route(&request).collect() } #[test] diff --git a/core/lib/src/server.rs b/core/lib/src/server.rs index e9e20f15..35060755 100644 --- a/core/lib/src/server.rs +++ b/core/lib/src/server.rs @@ -286,8 +286,7 @@ impl Rocket { mut data: Data, ) -> handler::Outcome<'r> { // Go through the list of matching routes until we fail or succeed. - let matches = self.router.route(request); - for route in matches { + for route in self.router.route(request) { // Retrieve and set the requests parameters. info_!("Matched: {}", route); request.set_route(route);