From c16105dc58fa5feedba359e5531fe721cd2968c6 Mon Sep 17 00:00:00 2001 From: Sergio Benitez Date: Tue, 13 Apr 2021 17:40:22 -0700 Subject: [PATCH] Impl 'DerefMut', 'inner_mut()' for 'LocalRequest'. --- core/lib/src/local/asynchronous/request.rs | 43 ++++++++++++++------- core/lib/src/local/blocking/request.rs | 6 +++ core/lib/src/local/request.rs | 31 +++++++++++++-- core/lib/tests/can-correct-bad-local-uri.rs | 17 ++++++++ 4 files changed, 79 insertions(+), 18 deletions(-) create mode 100644 core/lib/tests/can-correct-bad-local-uri.rs diff --git a/core/lib/src/local/asynchronous/request.rs b/core/lib/src/local/asynchronous/request.rs index dc5c8882..e30b5c58 100644 --- a/core/lib/src/local/asynchronous/request.rs +++ b/core/lib/src/local/asynchronous/request.rs @@ -2,7 +2,7 @@ use std::fmt; use std::convert::TryInto; use crate::{Request, Data}; -use crate::http::{Status, Method, ext::IntoOwned}; +use crate::http::{Status, Method}; use crate::http::uri::Origin; use super::{Client, LocalResponse}; @@ -35,19 +35,22 @@ pub struct LocalRequest<'c> { pub(in super) client: &'c Client, pub(in super) request: Request<'c>, data: Vec, - uri: Result, String>, + // The `Origin` on the right is INVALID! It should _not_ be used! + uri: Result, Origin<'static>>, } impl<'c> LocalRequest<'c> { pub(crate) fn new<'u: 'c, U>(client: &'c Client, method: Method, uri: U) -> Self where U: TryInto> + fmt::Display { - // We try to validate the URI now so that the inner `Request` contains a - // valid URI. If it doesn't, we set a dummy one. - let uri_string = uri.to_string(); - let uri = uri.try_into().map_err(move |_| uri_string); - let origin = uri.clone().unwrap_or_else(|_| Origin::dummy()); - let mut request = Request::new(client.rocket(), method, origin.into_owned()); + // Try to parse `uri` into an `Origin`, storing whether it's good. + let uri_str = uri.to_string(); + let try_origin = uri.try_into() + .map_err(|_| Origin::new::<_, &'static str>(uri_str, None)); + + // Create a request. We'll handle bad URIs later, in `_dispatch`. + let origin = try_origin.clone().unwrap_or_else(|bad| bad); + let mut request = Request::new(client.rocket(), method, origin); // Add any cookies we know about. if client.tracked { @@ -58,7 +61,7 @@ impl<'c> LocalRequest<'c> { }) } - LocalRequest { client, request, uri, data: vec![] } + LocalRequest { client, request, uri: try_origin, data: vec![] } } pub(crate) fn _request(&self) -> &Request<'c> { @@ -77,13 +80,17 @@ impl<'c> LocalRequest<'c> { async fn _dispatch(mut self) -> LocalResponse<'c> { // First, revalidate the URI, returning an error response (generated // from an error catcher) immediately if it's invalid. If it's valid, - // then `request` already contains the correct URI. + // then `request` already contains a correct URI. let rocket = self.client.rocket(); - if let Err(malformed) = self.uri { - error!("Malformed request URI: {}", malformed); - return LocalResponse::new(self.request, move |req| { - rocket.handle_error(Status::BadRequest, req) - }).await + if let Err(ref invalid) = self.uri { + // The user may have changed the URI in the request in which case we + // _shouldn't_ error. Check that now and error only if not. + if self.inner().uri() == invalid { + error!("invalid request URI: {:?}", invalid.path()); + return LocalResponse::new(self.request, move |req| { + rocket.handle_error(Status::BadRequest, req) + }).await + } } // Actually dispatch the request. @@ -142,3 +149,9 @@ impl<'c> std::ops::Deref for LocalRequest<'c> { self.inner() } } + +impl<'c> std::ops::DerefMut for LocalRequest<'c> { + fn deref_mut(&mut self) -> &mut Self::Target { + self.inner_mut() + } +} diff --git a/core/lib/src/local/blocking/request.rs b/core/lib/src/local/blocking/request.rs index 18abef16..004103db 100644 --- a/core/lib/src/local/blocking/request.rs +++ b/core/lib/src/local/blocking/request.rs @@ -79,3 +79,9 @@ impl<'c> std::ops::Deref for LocalRequest<'c> { self.inner() } } + +impl<'c> std::ops::DerefMut for LocalRequest<'c> { + fn deref_mut(&mut self) -> &mut Self::Target { + self.inner_mut() + } +} diff --git a/core/lib/src/local/request.rs b/core/lib/src/local/request.rs index 020ff42f..a6a6303b 100644 --- a/core/lib/src/local/request.rs +++ b/core/lib/src/local/request.rs @@ -1,10 +1,11 @@ macro_rules! pub_request_impl { ($import:literal $($prefix:tt $suffix:tt)?) => { - /// Retrieves the inner `Request` as seen by Rocket. + /// Borrows the inner `Request` as seen by Rocket. /// /// Note that no routing has occurred and that there is no remote - /// connection. + /// address unless one has been explicitly set with + /// [`set_remote()`](Request::set_remote()). /// /// # Example /// @@ -21,6 +22,27 @@ macro_rules! pub_request_impl { self._request() } + /// Mutably borrows the inner `Request` as seen by Rocket. + /// + /// Note that no routing has occurred and that there is no remote + /// address unless one has been explicitly set with + /// [`set_remote()`](Request::set_remote()). + /// + /// # Example + /// + /// ```rust + #[doc = $import] + /// + /// # Client::_test(|_, request, _| { + /// let mut request: LocalRequest = request; + /// let inner: &mut rocket::Request = request.inner_mut(); + /// # }); + /// ``` + #[inline(always)] + pub fn inner_mut(&mut self) -> &mut Request<'c> { + self._request_mut() + } + /// Add a header to this request. /// /// Any type that implements `Into
` can be used here. Among @@ -92,7 +114,7 @@ macro_rules! pub_request_impl { /// ``` #[inline] pub fn remote(mut self, address: std::net::SocketAddr) -> Self { - self._request_mut().set_remote(address); + self.set_remote(address); self } @@ -246,5 +268,8 @@ macro_rules! pub_request_impl { fn is_deref_req<'a, T: std::ops::Deref>>() {} is_deref_req::(); + + fn is_deref_mut_req<'a, T: std::ops::DerefMut>>() {} + is_deref_mut_req::(); } }} diff --git a/core/lib/tests/can-correct-bad-local-uri.rs b/core/lib/tests/can-correct-bad-local-uri.rs new file mode 100644 index 00000000..fcd5e2a7 --- /dev/null +++ b/core/lib/tests/can-correct-bad-local-uri.rs @@ -0,0 +1,17 @@ +use rocket::http::uri::Origin; +use rocket::local::blocking::Client; + +#[test] +fn can_correct_bad_local_uri() { + #[rocket::get("/")] fn f() {} + + let client = Client::debug_with(rocket::routes![f]).unwrap(); + let mut req = client.get("this is a bad URI"); + req.set_uri(Origin::parse("/").unwrap()); + + assert_eq!(req.uri(), "/"); + assert!(req.dispatch().status().class().is_success()); + + let req = client.get("this is a bad URI"); + assert!(req.dispatch().status().class().is_client_error()); +}