Impl 'DerefMut', 'inner_mut()' for 'LocalRequest'.

This commit is contained in:
Sergio Benitez 2021-04-13 17:40:22 -07:00
parent ad36b769bc
commit c16105dc58
4 changed files with 79 additions and 18 deletions

View File

@ -2,7 +2,7 @@ use std::fmt;
use std::convert::TryInto; use std::convert::TryInto;
use crate::{Request, Data}; use crate::{Request, Data};
use crate::http::{Status, Method, ext::IntoOwned}; use crate::http::{Status, Method};
use crate::http::uri::Origin; use crate::http::uri::Origin;
use super::{Client, LocalResponse}; use super::{Client, LocalResponse};
@ -35,19 +35,22 @@ pub struct LocalRequest<'c> {
pub(in super) client: &'c Client, pub(in super) client: &'c Client,
pub(in super) request: Request<'c>, pub(in super) request: Request<'c>,
data: Vec<u8>, data: Vec<u8>,
uri: Result<Origin<'c>, String>, // The `Origin` on the right is INVALID! It should _not_ be used!
uri: Result<Origin<'c>, Origin<'static>>,
} }
impl<'c> LocalRequest<'c> { impl<'c> LocalRequest<'c> {
pub(crate) fn new<'u: 'c, U>(client: &'c Client, method: Method, uri: U) -> Self pub(crate) fn new<'u: 'c, U>(client: &'c Client, method: Method, uri: U) -> Self
where U: TryInto<Origin<'u>> + fmt::Display where U: TryInto<Origin<'u>> + fmt::Display
{ {
// We try to validate the URI now so that the inner `Request` contains a // Try to parse `uri` into an `Origin`, storing whether it's good.
// valid URI. If it doesn't, we set a dummy one. let uri_str = uri.to_string();
let uri_string = uri.to_string(); let try_origin = uri.try_into()
let uri = uri.try_into().map_err(move |_| uri_string); .map_err(|_| Origin::new::<_, &'static str>(uri_str, None));
let origin = uri.clone().unwrap_or_else(|_| Origin::dummy());
let mut request = Request::new(client.rocket(), method, origin.into_owned()); // Create a request. We'll handle bad URIs later, in `_dispatch`.
let origin = try_origin.clone().unwrap_or_else(|bad| bad);
let mut request = Request::new(client.rocket(), method, origin);
// Add any cookies we know about. // Add any cookies we know about.
if client.tracked { if client.tracked {
@ -58,7 +61,7 @@ impl<'c> LocalRequest<'c> {
}) })
} }
LocalRequest { client, request, uri, data: vec![] } LocalRequest { client, request, uri: try_origin, data: vec![] }
} }
pub(crate) fn _request(&self) -> &Request<'c> { pub(crate) fn _request(&self) -> &Request<'c> {
@ -77,13 +80,17 @@ impl<'c> LocalRequest<'c> {
async fn _dispatch(mut self) -> LocalResponse<'c> { async fn _dispatch(mut self) -> LocalResponse<'c> {
// First, revalidate the URI, returning an error response (generated // First, revalidate the URI, returning an error response (generated
// from an error catcher) immediately if it's invalid. If it's valid, // from an error catcher) immediately if it's invalid. If it's valid,
// then `request` already contains the correct URI. // then `request` already contains a correct URI.
let rocket = self.client.rocket(); let rocket = self.client.rocket();
if let Err(malformed) = self.uri { if let Err(ref invalid) = self.uri {
error!("Malformed request URI: {}", malformed); // The user may have changed the URI in the request in which case we
return LocalResponse::new(self.request, move |req| { // _shouldn't_ error. Check that now and error only if not.
rocket.handle_error(Status::BadRequest, req) if self.inner().uri() == invalid {
}).await error!("invalid request URI: {:?}", invalid.path());
return LocalResponse::new(self.request, move |req| {
rocket.handle_error(Status::BadRequest, req)
}).await
}
} }
// Actually dispatch the request. // Actually dispatch the request.
@ -142,3 +149,9 @@ impl<'c> std::ops::Deref for LocalRequest<'c> {
self.inner() self.inner()
} }
} }
impl<'c> std::ops::DerefMut for LocalRequest<'c> {
fn deref_mut(&mut self) -> &mut Self::Target {
self.inner_mut()
}
}

View File

@ -79,3 +79,9 @@ impl<'c> std::ops::Deref for LocalRequest<'c> {
self.inner() self.inner()
} }
} }
impl<'c> std::ops::DerefMut for LocalRequest<'c> {
fn deref_mut(&mut self) -> &mut Self::Target {
self.inner_mut()
}
}

View File

@ -1,10 +1,11 @@
macro_rules! pub_request_impl { macro_rules! pub_request_impl {
($import:literal $($prefix:tt $suffix:tt)?) => ($import:literal $($prefix:tt $suffix:tt)?) =>
{ {
/// Retrieves the inner `Request` as seen by Rocket. /// Borrows the inner `Request` as seen by Rocket.
/// ///
/// Note that no routing has occurred and that there is no remote /// Note that no routing has occurred and that there is no remote
/// connection. /// address unless one has been explicitly set with
/// [`set_remote()`](Request::set_remote()).
/// ///
/// # Example /// # Example
/// ///
@ -21,6 +22,27 @@ macro_rules! pub_request_impl {
self._request() self._request()
} }
/// Mutably borrows the inner `Request` as seen by Rocket.
///
/// Note that no routing has occurred and that there is no remote
/// address unless one has been explicitly set with
/// [`set_remote()`](Request::set_remote()).
///
/// # Example
///
/// ```rust
#[doc = $import]
///
/// # Client::_test(|_, request, _| {
/// let mut request: LocalRequest = request;
/// let inner: &mut rocket::Request = request.inner_mut();
/// # });
/// ```
#[inline(always)]
pub fn inner_mut(&mut self) -> &mut Request<'c> {
self._request_mut()
}
/// Add a header to this request. /// Add a header to this request.
/// ///
/// Any type that implements `Into<Header>` can be used here. Among /// Any type that implements `Into<Header>` can be used here. Among
@ -92,7 +114,7 @@ macro_rules! pub_request_impl {
/// ``` /// ```
#[inline] #[inline]
pub fn remote(mut self, address: std::net::SocketAddr) -> Self { pub fn remote(mut self, address: std::net::SocketAddr) -> Self {
self._request_mut().set_remote(address); self.set_remote(address);
self self
} }
@ -246,5 +268,8 @@ macro_rules! pub_request_impl {
fn is_deref_req<'a, T: std::ops::Deref<Target = Request<'a>>>() {} fn is_deref_req<'a, T: std::ops::Deref<Target = Request<'a>>>() {}
is_deref_req::<Self>(); is_deref_req::<Self>();
fn is_deref_mut_req<'a, T: std::ops::DerefMut<Target = Request<'a>>>() {}
is_deref_mut_req::<Self>();
} }
}} }}

View File

@ -0,0 +1,17 @@
use rocket::http::uri::Origin;
use rocket::local::blocking::Client;
#[test]
fn can_correct_bad_local_uri() {
#[rocket::get("/")] fn f() {}
let client = Client::debug_with(rocket::routes![f]).unwrap();
let mut req = client.get("this is a bad URI");
req.set_uri(Origin::parse("/").unwrap());
assert_eq!(req.uri(), "/");
assert!(req.dispatch().status().class().is_success());
let req = client.get("this is a bad URI");
assert!(req.dispatch().status().class().is_client_error());
}