Impl 'DerefMut', 'inner_mut()' for 'LocalRequest'.

This commit is contained in:
Sergio Benitez 2021-04-13 17:40:22 -07:00
parent ad36b769bc
commit c16105dc58
4 changed files with 79 additions and 18 deletions

View File

@ -2,7 +2,7 @@ use std::fmt;
use std::convert::TryInto;
use crate::{Request, Data};
use crate::http::{Status, Method, ext::IntoOwned};
use crate::http::{Status, Method};
use crate::http::uri::Origin;
use super::{Client, LocalResponse};
@ -35,19 +35,22 @@ pub struct LocalRequest<'c> {
pub(in super) client: &'c Client,
pub(in super) request: Request<'c>,
data: Vec<u8>,
uri: Result<Origin<'c>, String>,
// The `Origin` on the right is INVALID! It should _not_ be used!
uri: Result<Origin<'c>, Origin<'static>>,
}
impl<'c> LocalRequest<'c> {
pub(crate) fn new<'u: 'c, U>(client: &'c Client, method: Method, uri: U) -> Self
where U: TryInto<Origin<'u>> + fmt::Display
{
// We try to validate the URI now so that the inner `Request` contains a
// valid URI. If it doesn't, we set a dummy one.
let uri_string = uri.to_string();
let uri = uri.try_into().map_err(move |_| uri_string);
let origin = uri.clone().unwrap_or_else(|_| Origin::dummy());
let mut request = Request::new(client.rocket(), method, origin.into_owned());
// Try to parse `uri` into an `Origin`, storing whether it's good.
let uri_str = uri.to_string();
let try_origin = uri.try_into()
.map_err(|_| Origin::new::<_, &'static str>(uri_str, None));
// Create a request. We'll handle bad URIs later, in `_dispatch`.
let origin = try_origin.clone().unwrap_or_else(|bad| bad);
let mut request = Request::new(client.rocket(), method, origin);
// Add any cookies we know about.
if client.tracked {
@ -58,7 +61,7 @@ impl<'c> LocalRequest<'c> {
})
}
LocalRequest { client, request, uri, data: vec![] }
LocalRequest { client, request, uri: try_origin, data: vec![] }
}
pub(crate) fn _request(&self) -> &Request<'c> {
@ -77,13 +80,17 @@ impl<'c> LocalRequest<'c> {
async fn _dispatch(mut self) -> LocalResponse<'c> {
// First, revalidate the URI, returning an error response (generated
// from an error catcher) immediately if it's invalid. If it's valid,
// then `request` already contains the correct URI.
// then `request` already contains a correct URI.
let rocket = self.client.rocket();
if let Err(malformed) = self.uri {
error!("Malformed request URI: {}", malformed);
return LocalResponse::new(self.request, move |req| {
rocket.handle_error(Status::BadRequest, req)
}).await
if let Err(ref invalid) = self.uri {
// The user may have changed the URI in the request in which case we
// _shouldn't_ error. Check that now and error only if not.
if self.inner().uri() == invalid {
error!("invalid request URI: {:?}", invalid.path());
return LocalResponse::new(self.request, move |req| {
rocket.handle_error(Status::BadRequest, req)
}).await
}
}
// Actually dispatch the request.
@ -142,3 +149,9 @@ impl<'c> std::ops::Deref for LocalRequest<'c> {
self.inner()
}
}
impl<'c> std::ops::DerefMut for LocalRequest<'c> {
fn deref_mut(&mut self) -> &mut Self::Target {
self.inner_mut()
}
}

View File

@ -79,3 +79,9 @@ impl<'c> std::ops::Deref for LocalRequest<'c> {
self.inner()
}
}
impl<'c> std::ops::DerefMut for LocalRequest<'c> {
fn deref_mut(&mut self) -> &mut Self::Target {
self.inner_mut()
}
}

View File

@ -1,10 +1,11 @@
macro_rules! pub_request_impl {
($import:literal $($prefix:tt $suffix:tt)?) =>
{
/// Retrieves the inner `Request` as seen by Rocket.
/// Borrows the inner `Request` as seen by Rocket.
///
/// Note that no routing has occurred and that there is no remote
/// connection.
/// address unless one has been explicitly set with
/// [`set_remote()`](Request::set_remote()).
///
/// # Example
///
@ -21,6 +22,27 @@ macro_rules! pub_request_impl {
self._request()
}
/// Mutably borrows the inner `Request` as seen by Rocket.
///
/// Note that no routing has occurred and that there is no remote
/// address unless one has been explicitly set with
/// [`set_remote()`](Request::set_remote()).
///
/// # Example
///
/// ```rust
#[doc = $import]
///
/// # Client::_test(|_, request, _| {
/// let mut request: LocalRequest = request;
/// let inner: &mut rocket::Request = request.inner_mut();
/// # });
/// ```
#[inline(always)]
pub fn inner_mut(&mut self) -> &mut Request<'c> {
self._request_mut()
}
/// Add a header to this request.
///
/// Any type that implements `Into<Header>` can be used here. Among
@ -92,7 +114,7 @@ macro_rules! pub_request_impl {
/// ```
#[inline]
pub fn remote(mut self, address: std::net::SocketAddr) -> Self {
self._request_mut().set_remote(address);
self.set_remote(address);
self
}
@ -246,5 +268,8 @@ macro_rules! pub_request_impl {
fn is_deref_req<'a, T: std::ops::Deref<Target = Request<'a>>>() {}
is_deref_req::<Self>();
fn is_deref_mut_req<'a, T: std::ops::DerefMut<Target = Request<'a>>>() {}
is_deref_mut_req::<Self>();
}
}}

View File

@ -0,0 +1,17 @@
use rocket::http::uri::Origin;
use rocket::local::blocking::Client;
#[test]
fn can_correct_bad_local_uri() {
#[rocket::get("/")] fn f() {}
let client = Client::debug_with(rocket::routes![f]).unwrap();
let mut req = client.get("this is a bad URI");
req.set_uri(Origin::parse("/").unwrap());
assert_eq!(req.uri(), "/");
assert!(req.dispatch().status().class().is_success());
let req = client.get("this is a bad URI");
assert!(req.dispatch().status().class().is_client_error());
}