mirror of https://github.com/rwf2/Rocket.git
Impl 'DerefMut', 'inner_mut()' for 'LocalRequest'.
This commit is contained in:
parent
ad36b769bc
commit
c16105dc58
|
@ -2,7 +2,7 @@ use std::fmt;
|
|||
use std::convert::TryInto;
|
||||
|
||||
use crate::{Request, Data};
|
||||
use crate::http::{Status, Method, ext::IntoOwned};
|
||||
use crate::http::{Status, Method};
|
||||
use crate::http::uri::Origin;
|
||||
|
||||
use super::{Client, LocalResponse};
|
||||
|
@ -35,19 +35,22 @@ pub struct LocalRequest<'c> {
|
|||
pub(in super) client: &'c Client,
|
||||
pub(in super) request: Request<'c>,
|
||||
data: Vec<u8>,
|
||||
uri: Result<Origin<'c>, String>,
|
||||
// The `Origin` on the right is INVALID! It should _not_ be used!
|
||||
uri: Result<Origin<'c>, Origin<'static>>,
|
||||
}
|
||||
|
||||
impl<'c> LocalRequest<'c> {
|
||||
pub(crate) fn new<'u: 'c, U>(client: &'c Client, method: Method, uri: U) -> Self
|
||||
where U: TryInto<Origin<'u>> + fmt::Display
|
||||
{
|
||||
// We try to validate the URI now so that the inner `Request` contains a
|
||||
// valid URI. If it doesn't, we set a dummy one.
|
||||
let uri_string = uri.to_string();
|
||||
let uri = uri.try_into().map_err(move |_| uri_string);
|
||||
let origin = uri.clone().unwrap_or_else(|_| Origin::dummy());
|
||||
let mut request = Request::new(client.rocket(), method, origin.into_owned());
|
||||
// Try to parse `uri` into an `Origin`, storing whether it's good.
|
||||
let uri_str = uri.to_string();
|
||||
let try_origin = uri.try_into()
|
||||
.map_err(|_| Origin::new::<_, &'static str>(uri_str, None));
|
||||
|
||||
// Create a request. We'll handle bad URIs later, in `_dispatch`.
|
||||
let origin = try_origin.clone().unwrap_or_else(|bad| bad);
|
||||
let mut request = Request::new(client.rocket(), method, origin);
|
||||
|
||||
// Add any cookies we know about.
|
||||
if client.tracked {
|
||||
|
@ -58,7 +61,7 @@ impl<'c> LocalRequest<'c> {
|
|||
})
|
||||
}
|
||||
|
||||
LocalRequest { client, request, uri, data: vec![] }
|
||||
LocalRequest { client, request, uri: try_origin, data: vec![] }
|
||||
}
|
||||
|
||||
pub(crate) fn _request(&self) -> &Request<'c> {
|
||||
|
@ -77,13 +80,17 @@ impl<'c> LocalRequest<'c> {
|
|||
async fn _dispatch(mut self) -> LocalResponse<'c> {
|
||||
// First, revalidate the URI, returning an error response (generated
|
||||
// from an error catcher) immediately if it's invalid. If it's valid,
|
||||
// then `request` already contains the correct URI.
|
||||
// then `request` already contains a correct URI.
|
||||
let rocket = self.client.rocket();
|
||||
if let Err(malformed) = self.uri {
|
||||
error!("Malformed request URI: {}", malformed);
|
||||
return LocalResponse::new(self.request, move |req| {
|
||||
rocket.handle_error(Status::BadRequest, req)
|
||||
}).await
|
||||
if let Err(ref invalid) = self.uri {
|
||||
// The user may have changed the URI in the request in which case we
|
||||
// _shouldn't_ error. Check that now and error only if not.
|
||||
if self.inner().uri() == invalid {
|
||||
error!("invalid request URI: {:?}", invalid.path());
|
||||
return LocalResponse::new(self.request, move |req| {
|
||||
rocket.handle_error(Status::BadRequest, req)
|
||||
}).await
|
||||
}
|
||||
}
|
||||
|
||||
// Actually dispatch the request.
|
||||
|
@ -142,3 +149,9 @@ impl<'c> std::ops::Deref for LocalRequest<'c> {
|
|||
self.inner()
|
||||
}
|
||||
}
|
||||
|
||||
impl<'c> std::ops::DerefMut for LocalRequest<'c> {
|
||||
fn deref_mut(&mut self) -> &mut Self::Target {
|
||||
self.inner_mut()
|
||||
}
|
||||
}
|
||||
|
|
|
@ -79,3 +79,9 @@ impl<'c> std::ops::Deref for LocalRequest<'c> {
|
|||
self.inner()
|
||||
}
|
||||
}
|
||||
|
||||
impl<'c> std::ops::DerefMut for LocalRequest<'c> {
|
||||
fn deref_mut(&mut self) -> &mut Self::Target {
|
||||
self.inner_mut()
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,10 +1,11 @@
|
|||
macro_rules! pub_request_impl {
|
||||
($import:literal $($prefix:tt $suffix:tt)?) =>
|
||||
{
|
||||
/// Retrieves the inner `Request` as seen by Rocket.
|
||||
/// Borrows the inner `Request` as seen by Rocket.
|
||||
///
|
||||
/// Note that no routing has occurred and that there is no remote
|
||||
/// connection.
|
||||
/// address unless one has been explicitly set with
|
||||
/// [`set_remote()`](Request::set_remote()).
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
|
@ -21,6 +22,27 @@ macro_rules! pub_request_impl {
|
|||
self._request()
|
||||
}
|
||||
|
||||
/// Mutably borrows the inner `Request` as seen by Rocket.
|
||||
///
|
||||
/// Note that no routing has occurred and that there is no remote
|
||||
/// address unless one has been explicitly set with
|
||||
/// [`set_remote()`](Request::set_remote()).
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust
|
||||
#[doc = $import]
|
||||
///
|
||||
/// # Client::_test(|_, request, _| {
|
||||
/// let mut request: LocalRequest = request;
|
||||
/// let inner: &mut rocket::Request = request.inner_mut();
|
||||
/// # });
|
||||
/// ```
|
||||
#[inline(always)]
|
||||
pub fn inner_mut(&mut self) -> &mut Request<'c> {
|
||||
self._request_mut()
|
||||
}
|
||||
|
||||
/// Add a header to this request.
|
||||
///
|
||||
/// Any type that implements `Into<Header>` can be used here. Among
|
||||
|
@ -92,7 +114,7 @@ macro_rules! pub_request_impl {
|
|||
/// ```
|
||||
#[inline]
|
||||
pub fn remote(mut self, address: std::net::SocketAddr) -> Self {
|
||||
self._request_mut().set_remote(address);
|
||||
self.set_remote(address);
|
||||
self
|
||||
}
|
||||
|
||||
|
@ -246,5 +268,8 @@ macro_rules! pub_request_impl {
|
|||
|
||||
fn is_deref_req<'a, T: std::ops::Deref<Target = Request<'a>>>() {}
|
||||
is_deref_req::<Self>();
|
||||
|
||||
fn is_deref_mut_req<'a, T: std::ops::DerefMut<Target = Request<'a>>>() {}
|
||||
is_deref_mut_req::<Self>();
|
||||
}
|
||||
}}
|
||||
|
|
|
@ -0,0 +1,17 @@
|
|||
use rocket::http::uri::Origin;
|
||||
use rocket::local::blocking::Client;
|
||||
|
||||
#[test]
|
||||
fn can_correct_bad_local_uri() {
|
||||
#[rocket::get("/")] fn f() {}
|
||||
|
||||
let client = Client::debug_with(rocket::routes![f]).unwrap();
|
||||
let mut req = client.get("this is a bad URI");
|
||||
req.set_uri(Origin::parse("/").unwrap());
|
||||
|
||||
assert_eq!(req.uri(), "/");
|
||||
assert!(req.dispatch().status().class().is_success());
|
||||
|
||||
let req = client.get("this is a bad URI");
|
||||
assert!(req.dispatch().status().class().is_client_error());
|
||||
}
|
Loading…
Reference in New Issue