Use 'async_trait' for 'FromRequest'.

Removes 'FromRequestAsync'.
This commit is contained in:
Sergio Benitez 2020-01-31 01:34:15 -08:00
parent 48c333721c
commit 431b963774
22 changed files with 159 additions and 200 deletions

View File

@ -141,21 +141,22 @@ pub fn database_attr(attr: TokenStream, input: TokenStream) -> Result<TokenStrea
}
}
impl<'a, 'r> #request::FromRequestAsync<'a, 'r> for #guard_type {
#[::rocket::async_trait]
impl<'a, 'r> #request::FromRequest<'a, 'r> for #guard_type {
type Error = ();
fn from_request(request: &'a #request::Request<'r>) -> #request::FromRequestFuture<'a, Self, Self::Error> {
async fn from_request(request: &'a #request::Request<'r>) -> #request::Outcome<Self, ()> {
use ::rocket::{Outcome, http::Status};
Box::pin(async move {
let pool = ::rocket::try_outcome!(request.guard::<::rocket::State<'_, #pool_type>>()).0.clone();
#spawn_blocking(move || {
match pool.get() {
Ok(conn) => Outcome::Success(#guard_type(conn)),
Err(_) => Outcome::Failure((Status::ServiceUnavailable, ())),
}
}).await.expect("failed to spawn a blocking task to get a pooled connection")
})
let guard = request.guard::<::rocket::State<'_, #pool_type>>();
let pool = ::rocket::try_outcome!(guard.await).0.clone();
#spawn_blocking(move || {
match pool.get() {
Ok(conn) => Outcome::Success(#guard_type(conn)),
Err(_) => Outcome::Failure((Status::ServiceUnavailable, ())),
}
}).await.expect("failed to spawn a blocking task to get a pooled connection")
}
}

View File

@ -173,7 +173,7 @@ impl Fairing for TemplateFairing {
#[cfg(debug_assertions)]
async fn on_request<'a>(&'a self, req: &'a mut rocket::Request<'_>, _data: &'a rocket::Data) {
let cm = req.guard::<rocket::State<'_, ContextManager>>()
let cm = req.guard::<rocket::State<'_, ContextManager>>().await
.expect("Template ContextManager registered in on_attach");
cm.reload_if_needed(&*self.custom_callback);

View File

@ -87,11 +87,12 @@ impl Metadata<'_> {
/// Retrieves the template metadata. If a template fairing hasn't been attached,
/// an error is printed and an empty `Err` with status `InternalServerError`
/// (`500`) is returned.
impl<'a> FromRequest<'a, '_> for Metadata<'a> {
#[rocket::async_trait]
impl<'a, 'r> FromRequest<'a, 'r> for Metadata<'a> {
type Error = ();
fn from_request(request: &'a Request<'_>) -> request::Outcome<Self, ()> {
request.guard::<State<'_, ContextManager>>()
async fn from_request(request: &'a Request<'r>) -> request::Outcome<Self, ()> {
request.guard::<State<'_, ContextManager>>().await
.succeeded()
.and_then(|cm| Some(Outcome::Success(Metadata(cm.inner()))))
.unwrap_or_else(|| {

View File

@ -387,7 +387,7 @@ impl<'r> Responder<'r> for Template {
fn respond_to(self, req: &'r Request<'_>) -> response::ResultFuture<'r> {
Box::pin(async move {
let (render, content_type) = {
let ctxt = req.guard::<State<'_, ContextManager>>().succeeded().ok_or_else(|| {
let ctxt = req.guard::<State<'_, ContextManager>>().await.succeeded().ok_or_else(|| {
error_!("Uninitialized template context: missing fairing.");
info_!("To use templates, you must attach `Template::fairing()`.");
info_!("See the `Template` documentation for more information.");

View File

@ -313,7 +313,7 @@ fn request_guard_expr(ident: &syn::Ident, ty: &syn::Type) -> TokenStream2 {
let span = ident.span().unstable().join(ty.span()).unwrap().into();
quote_spanned! { span =>
#[allow(non_snake_case, unreachable_patterns, unreachable_code)]
let #ident: #ty = match <#ty as #request::FromRequestAsync>::from_request(#req).await {
let #ident: #ty = match <#ty as #request::FromRequest>::from_request(#req).await {
#Outcome::Success(__v) => __v,
#Outcome::Forward(_) => return #Outcome::Forward(#data),
#Outcome::Failure((__c, _)) => return #Outcome::Failure(__c),

View File

@ -84,10 +84,11 @@ mod key {
/// // In practice, we'd probably fetch the user from the database.
/// struct User(usize);
///
/// impl FromRequest<'_, '_> for User {
/// #[rocket::async_trait]
/// impl<'a, 'r> FromRequest<'a, 'r> for User {
/// type Error = std::convert::Infallible;
///
/// fn from_request(request: &Request<'_>) -> request::Outcome<Self, Self::Error> {
/// async fn from_request(request: &'a Request<'r>) -> request::Outcome<Self, Self::Error> {
/// request.cookies()
/// .get_private("user_id")
/// .and_then(|cookie| cookie.value().parse().ok())

View File

@ -353,10 +353,11 @@ pub use self::info_kind::{Info, Kind};
/// pub struct StartTime(pub SystemTime);
///
/// // Allows a route to access the time a request was initiated.
/// impl FromRequest<'_, '_> for StartTime {
/// #[rocket::async_trait]
/// impl<'a, 'r> FromRequest<'a, 'r> for StartTime {
/// type Error = ();
///
/// fn from_request(request: &Request<'_>) -> request::Outcome<StartTime, ()> {
/// async fn from_request(request: &'a Request<'r>) -> request::Outcome<Self, ()> {
/// match *request.local_cache(|| TimerStart(None)) {
/// TimerStart(Some(time)) => Outcome::Success(StartTime(time)),
/// TimerStart(None) => Outcome::Failure((Status::InternalServerError, ())),
@ -366,7 +367,6 @@ pub use self::info_kind::{Info, Kind};
/// ```
///
/// [request-local state]: https://rocket.rs/v0.5/guide/state/#request-local-state
#[crate::async_trait]
pub trait Fairing: Send + Sync + 'static {
/// Returns an [`Info`] structure containing the `name` and [`Kind`] of this

View File

@ -628,12 +628,13 @@ impl<S, E, F> Outcome<S, E, F> {
/// struct Guard1;
/// struct Guard2;
///
/// #[rocket::async_trait]
/// impl<'a, 'r> FromRequest<'a, 'r> for Guard1 {
/// type Error = ();
///
/// fn from_request(req: &'a Request<'r>) -> request::Outcome<Self, ()> {
/// async fn from_request(req: &'a Request<'r>) -> request::Outcome<Self, ()> {
/// // Attempt to fetch the guard, passing through any error or forward.
/// let atomics = try_outcome!(req.guard::<State<'_, Atomics>>());
/// let atomics = try_outcome!(req.guard::<State<'_, Atomics>>().await);
/// atomics.uncached.fetch_add(1, Ordering::Relaxed);
/// req.local_cache(|| atomics.cached.fetch_add(1, Ordering::Relaxed));
///
@ -641,12 +642,13 @@ impl<S, E, F> Outcome<S, E, F> {
/// }
/// }
///
/// #[rocket::async_trait]
/// impl<'a, 'r> FromRequest<'a, 'r> for Guard2 {
/// type Error = ();
///
/// fn from_request(req: &'a Request<'r>) -> request::Outcome<Self, ()> {
/// async fn from_request(req: &'a Request<'r>) -> request::Outcome<Self, ()> {
/// // Attempt to fetch the guard, passing through any error or forward.
/// let guard1: Guard1 = try_outcome!(req.guard::<Guard1>());
/// let guard1: Guard1 = try_outcome!(req.guard::<Guard1>().await);
/// Success(Guard2)
/// }
/// }

View File

@ -34,84 +34,6 @@ impl<S, E> IntoOutcome<S, (Status, E), ()> for Result<S, E> {
}
}
/// Type alias for the future returned by [`FromRequestAsync::from_request`].
pub type FromRequestFuture<'fut, T, E> = BoxFuture<'fut, Outcome<T, E>>;
/// Trait implemented by asynchronous request guards to derive a value from
/// incoming requests.
///
/// For more information on request guards in general, see the [`FromRequest`]
/// trait.
///
/// ## Example
///
/// Imagine you're running an authenticated service backed by a database. You
/// want to ensure that certain handlers will only run if a valid API key is
/// present in the request, and you need to make a database request in order to
/// determine if the key is valid or not.
///
/// ```rust
/// # #![feature(proc_macro_hygiene)]
/// # #[macro_use] extern crate rocket;
/// #
/// # struct Database;
/// # impl Database {
/// # async fn check_key(&self, key: &str) -> bool {
/// # true
/// # }
/// # }
/// #
/// use rocket::Outcome;
/// use rocket::http::Status;
/// use rocket::request::{self, Request, State, FromRequestAsync};
///
/// struct ApiKey(String);
///
/// #[derive(Debug)]
/// enum ApiKeyError {
/// BadCount,
/// Missing,
/// Invalid,
/// }
///
/// impl<'a, 'r> FromRequestAsync<'a, 'r> for ApiKey {
/// type Error = ApiKeyError;
///
/// fn from_request(request: &'a Request<'r>) -> request::FromRequestFuture<'a, Self, Self::Error> {
/// Box::pin(async move {
/// let keys: Vec<_> = request.headers().get("x-api-key").collect();
/// let database: State<'_, Database> = request.guard().expect("get database connection");
/// match keys.len() {
/// 0 => Outcome::Failure((Status::BadRequest, ApiKeyError::Missing)),
/// 1 if database.check_key(keys[0]).await => Outcome::Success(ApiKey(keys[0].to_string())),
/// 1 => Outcome::Failure((Status::BadRequest, ApiKeyError::Invalid)),
/// _ => Outcome::Failure((Status::BadRequest, ApiKeyError::BadCount)),
/// }
/// })
/// }
/// }
///
/// #[get("/sensitive")]
/// fn sensitive(key: ApiKey) -> &'static str {
/// # let _key = key;
/// "Sensitive data."
/// }
///
/// # fn main() { }
/// ```
pub trait FromRequestAsync<'a, 'r>: Sized {
/// The associated error to be returned if derivation fails.
type Error: Debug;
/// Derives an instance of `Self` from the incoming request metadata.
///
/// If the derivation is successful, an outcome of `Success` is returned. If
/// the derivation fails in an unrecoverable fashion, `Failure` is returned.
/// `Forward` is returned to indicate that the request should be forwarded
/// to other matching routes, if any.
fn from_request(request: &'a Request<'r>) -> FromRequestFuture<'a, Self, Self::Error>;
}
/// Trait implemented by request guards to derive a value from incoming
/// requests.
///
@ -127,9 +49,26 @@ pub trait FromRequestAsync<'a, 'r>: Sized {
/// the handler. Rocket only dispatches requests to a handler when all of its
/// guards pass.
///
/// Request guards can be made *asynchronous* by implementing
/// [`FromRequestAsync`] instead of `FromRequest`. This is useful when the
/// validation requires working with a database or performing other I/O.
/// ## Async Trait
///
/// [`FromRequest`] is an _async_ trait. Implementations of `FromRequest` must
/// be decorated with an attribute of `#[rocket::async_trait]`:
///
/// ```rust
/// use rocket::request::{self, Request, FromRequest};
/// # struct MyType;
/// # type MyError = String;
///
/// #[rocket::async_trait]
/// impl<'a, 'r> FromRequest<'a, 'r> for MyType {
/// type Error = MyError;
///
/// async fn from_request(req: &'a Request<'r>) -> request::Outcome<Self, Self::Error> {
/// /* .. */
/// # unimplemented!()
/// }
/// }
/// ```
///
/// ## Example
///
@ -270,11 +209,12 @@ pub trait FromRequestAsync<'a, 'r>: Sized {
/// Invalid,
/// }
///
/// impl FromRequest<'_, '_> for ApiKey {
/// #[rocket::async_trait]
/// impl<'a, 'r> FromRequest<'a, 'r> for ApiKey {
/// type Error = ApiKeyError;
///
/// fn from_request(request: &Request<'_>) -> request::Outcome<Self, Self::Error> {
/// let keys: Vec<_> = request.headers().get("x-api-key").collect();
/// async fn from_request(req: &'a Request<'r>) -> request::Outcome<Self, Self::Error> {
/// let keys: Vec<_> = req.headers().get("x-api-key").collect();
/// match keys.len() {
/// 0 => Outcome::Failure((Status::BadRequest, ApiKeyError::Missing)),
/// 1 if is_valid(keys[0]) => Outcome::Success(ApiKey(keys[0].to_string())),
@ -316,20 +256,22 @@ pub trait FromRequestAsync<'a, 'r>: Sized {
/// # Ok(User { id, is_admin: false })
/// # }
/// # }
/// # impl FromRequest<'_, '_> for Database {
/// # #[rocket::async_trait]
/// # impl<'a, 'r> FromRequest<'a, 'r> for Database {
/// # type Error = ();
/// # fn from_request(request: &Request<'_>) -> request::Outcome<Database, ()> {
/// # async fn from_request(request: &'a Request<'r>) -> request::Outcome<Database, ()> {
/// # Outcome::Success(Database)
/// # }
/// # }
/// #
/// # struct Admin { user: User }
/// #
/// impl FromRequest<'_, '_> for User {
/// #[rocket::async_trait]
/// impl<'a, 'r> FromRequest<'a, 'r> for User {
/// type Error = ();
///
/// fn from_request(request: &Request<'_>) -> request::Outcome<User, ()> {
/// let db = try_outcome!(request.guard::<Database>());
/// async fn from_request(request: &'a Request<'r>) -> request::Outcome<User, ()> {
/// let db = try_outcome!(request.guard::<Database>().await);
/// request.cookies()
/// .get_private("user_id")
/// .and_then(|cookie| cookie.value().parse().ok())
@ -338,13 +280,13 @@ pub trait FromRequestAsync<'a, 'r>: Sized {
/// }
/// }
///
/// impl FromRequest<'_, '_> for Admin {
/// #[rocket::async_trait]
/// impl<'a, 'r> FromRequest<'a, 'r> for Admin {
/// type Error = ();
///
/// fn from_request(request: &Request<'_>) -> request::Outcome<Admin, ()> {
/// async fn from_request(request: &'a Request<'r>) -> request::Outcome<Admin, ()> {
/// // This will unconditionally query the database!
/// let user = try_outcome!(request.guard::<User>());
///
/// let user = try_outcome!(request.guard::<User>().await);
/// if user.is_admin {
/// Outcome::Success(Admin { user })
/// } else {
@ -379,39 +321,41 @@ pub trait FromRequestAsync<'a, 'r>: Sized {
/// # Ok(User { id, is_admin: false })
/// # }
/// # }
/// # impl FromRequest<'_, '_> for Database {
/// # #[rocket::async_trait]
/// # impl<'a, 'r> FromRequest<'a, 'r> for Database {
/// # type Error = ();
/// # fn from_request(request: &Request<'_>) -> request::Outcome<Database, ()> {
/// # async fn from_request(request: &'a Request<'r>) -> request::Outcome<Database, ()> {
/// # Outcome::Success(Database)
/// # }
/// # }
/// #
/// # struct Admin<'a> { user: &'a User }
/// #
/// impl<'a> FromRequest<'a, '_> for &'a User {
/// #[rocket::async_trait]
/// impl<'a, 'r> FromRequest<'a, 'r> for &'a User {
/// type Error = std::convert::Infallible;
///
/// fn from_request(request: &'a Request<'_>) -> request::Outcome<Self, Self::Error> {
/// async fn from_request(request: &'a Request<'r>) -> request::Outcome<Self, Self::Error> {
/// // This closure will execute at most once per request, regardless of
/// // the number of times the `User` guard is executed.
/// let user_result = request.local_cache(|| {
/// let db = request.guard::<Database>().succeeded()?;
/// let user_result = request.local_cache_async(async {
/// let db = request.guard::<Database>().await.succeeded()?;
/// request.cookies()
/// .get_private("user_id")
/// .and_then(|cookie| cookie.value().parse().ok())
/// .and_then(|id| db.get_user(id).ok())
/// });
/// }).await;
///
/// user_result.as_ref().or_forward(())
/// }
/// }
///
/// impl<'a> FromRequest<'a, '_> for Admin<'a> {
/// #[rocket::async_trait]
/// impl<'a, 'r> FromRequest<'a, 'r> for Admin<'a> {
/// type Error = std::convert::Infallible;
///
/// fn from_request(request: &'a Request<'_>) -> request::Outcome<Self, Self::Error> {
/// let user = try_outcome!(request.guard::<&User>());
///
/// async fn from_request(request: &'a Request<'r>) -> request::Outcome<Self, Self::Error> {
/// let user = try_outcome!(request.guard::<&User>().await);
/// if user.is_admin {
/// Outcome::Success(Admin { user })
/// } else {
@ -427,6 +371,7 @@ pub trait FromRequestAsync<'a, 'r>: Sized {
///
/// [request-local state]: https://rocket.rs/v0.5/guide/state/#request-local-state
#[crate::async_trait]
pub trait FromRequest<'a, 'r>: Sized {
/// The associated error to be returned if derivation fails.
type Error: Debug;
@ -437,37 +382,32 @@ pub trait FromRequest<'a, 'r>: Sized {
/// the derivation fails in an unrecoverable fashion, `Failure` is returned.
/// `Forward` is returned to indicate that the request should be forwarded
/// to other matching routes, if any.
fn from_request(request: &'a Request<'r>) -> Outcome<Self, Self::Error>;
async fn from_request(request: &'a Request<'r>) -> Outcome<Self, Self::Error>;
}
impl<'a, 'r, T: FromRequest<'a, 'r>> FromRequestAsync<'a, 'r> for T {
type Error = T::Error;
fn from_request(request: &'a Request<'r>) -> BoxFuture<'a, Outcome<Self, Self::Error>> {
Box::pin(async move { T::from_request(request) })
}
}
impl FromRequest<'_, '_> for Method {
#[crate::async_trait]
impl<'a, 'r> FromRequest<'a, 'r> for Method {
type Error = std::convert::Infallible;
fn from_request(request: &Request<'_>) -> Outcome<Self, Self::Error> {
async fn from_request(request: &'a Request<'r>) -> Outcome<Self, Self::Error> {
Success(request.method())
}
}
impl<'a> FromRequest<'a, '_> for &'a Origin<'a> {
#[crate::async_trait]
impl<'a, 'r> FromRequest<'a, 'r> for &'a Origin<'a> {
type Error = std::convert::Infallible;
fn from_request(request: &'a Request<'_>) -> Outcome<Self, Self::Error> {
async fn from_request(request: &'a Request<'r>) -> Outcome<Self, Self::Error> {
Success(request.uri())
}
}
impl<'r> FromRequest<'_, 'r> for &'r Route {
#[crate::async_trait]
impl<'a, 'r> FromRequest<'a, 'r> for &'r Route {
type Error = std::convert::Infallible;
fn from_request(request: &Request<'r>) -> Outcome<Self, Self::Error> {
async fn from_request(request: &'a Request<'r>) -> Outcome<Self, Self::Error> {
match request.route() {
Some(route) => Success(route),
None => Forward(())
@ -475,18 +415,20 @@ impl<'r> FromRequest<'_, 'r> for &'r Route {
}
}
impl<'a> FromRequest<'a, '_> for Cookies<'a> {
#[crate::async_trait]
impl<'a, 'r> FromRequest<'a, 'r> for Cookies<'a> {
type Error = std::convert::Infallible;
fn from_request(request: &'a Request<'_>) -> Outcome<Self, Self::Error> {
async fn from_request(request: &'a Request<'r>) -> Outcome<Self, Self::Error> {
Success(request.cookies())
}
}
impl<'a> FromRequest<'a, '_> for &'a Accept {
#[crate::async_trait]
impl<'a, 'r> FromRequest<'a, 'r> for &'a Accept {
type Error = std::convert::Infallible;
fn from_request(request: &'a Request<'_>) -> Outcome<Self, Self::Error> {
async fn from_request(request: &'a Request<'r>) -> Outcome<Self, Self::Error> {
match request.accept() {
Some(accept) => Success(accept),
None => Forward(())
@ -494,10 +436,11 @@ impl<'a> FromRequest<'a, '_> for &'a Accept {
}
}
impl<'a> FromRequest<'a, '_> for &'a ContentType {
#[crate::async_trait]
impl<'a, 'r> FromRequest<'a, 'r> for &'a ContentType {
type Error = std::convert::Infallible;
fn from_request(request: &'a Request<'_>) -> Outcome<Self, Self::Error> {
async fn from_request(request: &'a Request<'r>) -> Outcome<Self, Self::Error> {
match request.content_type() {
Some(content_type) => Success(content_type),
None => Forward(())
@ -505,10 +448,11 @@ impl<'a> FromRequest<'a, '_> for &'a ContentType {
}
}
impl FromRequest<'_, '_> for SocketAddr {
#[crate::async_trait]
impl<'a, 'r> FromRequest<'a, 'r> for SocketAddr {
type Error = std::convert::Infallible;
fn from_request(request: &Request<'_>) -> Outcome<Self, Self::Error> {
async fn from_request(request: &'a Request<'r>) -> Outcome<Self, Self::Error> {
match request.remote() {
Some(addr) => Success(addr),
None => Forward(())
@ -516,10 +460,12 @@ impl FromRequest<'_, '_> for SocketAddr {
}
}
impl<'a, 'r, T: FromRequestAsync<'a, 'r> + 'a> FromRequestAsync<'a, 'r> for Result<T, T::Error> {
impl<'a, 'r, T: FromRequest<'a, 'r> + 'a> FromRequest<'a, 'r> for Result<T, T::Error> {
type Error = std::convert::Infallible;
fn from_request(request: &'a Request<'r>) -> BoxFuture<'a, Outcome<Self, Self::Error>> {
fn from_request<'y>(request: &'a Request<'r>) -> BoxFuture<'y, Outcome<Self, Self::Error>>
where 'a: 'y, 'r: 'y
{
// TODO: FutureExt::map is a workaround (see rust-lang/rust#60658)
use futures_util::future::FutureExt;
T::from_request(request).map(|x| match x {
@ -530,10 +476,12 @@ impl<'a, 'r, T: FromRequestAsync<'a, 'r> + 'a> FromRequestAsync<'a, 'r> for Resu
}
}
impl<'a, 'r, T: FromRequestAsync<'a, 'r> + 'a> FromRequestAsync<'a, 'r> for Option<T> {
impl<'a, 'r, T: FromRequest<'a, 'r> + 'a> FromRequest<'a, 'r> for Option<T> {
type Error = std::convert::Infallible;
fn from_request(request: &'a Request<'r>) -> BoxFuture<'a, Outcome<Self, Self::Error>> {
fn from_request<'y>(request: &'a Request<'r>) -> BoxFuture<'y, Outcome<Self, Self::Error>>
where 'a: 'y, 'r: 'y
{
// TODO: FutureExt::map is a workaround (see rust-lang/rust#60658)
use futures_util::future::FutureExt;
T::from_request(request).map(|x| match x {
@ -542,4 +490,3 @@ impl<'a, 'r, T: FromRequestAsync<'a, 'r> + 'a> FromRequestAsync<'a, 'r> for Opti
}).boxed()
}
}

View File

@ -13,7 +13,7 @@ mod tests;
#[doc(hidden)] pub use rocket_codegen::{FromForm, FromFormValue};
pub use self::request::Request;
pub use self::from_request::{FromRequest, FromRequestAsync, FromRequestFuture, Outcome};
pub use self::from_request::{FromRequest, Outcome};
pub use self::param::{FromParam, FromSegments};
pub use self::form::{FromForm, FromFormValue};
pub use self::form::{Form, LenientForm, FormItems, FormItem};

View File

@ -6,6 +6,7 @@ use std::str;
use yansi::Paint;
use state::{Container, Storage};
use futures_util::future::BoxFuture;
use crate::request::{FromParam, FromSegments, FromRequest, Outcome};
use crate::request::{FromFormValue, FormItems, FormItem};
@ -530,8 +531,9 @@ impl<'r> Request<'r> {
/// let pool = request.guard::<State<Pool>>();
/// # });
/// ```
#[inline(always)]
pub fn guard<'a, T: FromRequest<'a, 'r>>(&'a self) -> Outcome<T, T::Error> {
pub fn guard<'z, 'a, T>(&'a self) -> BoxFuture<'z, Outcome<T, T::Error>>
where T: FromRequest<'a, 'r> + 'z, 'a: 'z, 'r: 'z
{
T::from_request(self)
}

View File

@ -70,11 +70,12 @@ use crate::http::Status;
/// # struct MyConfig{ user_val: String };
/// struct Item(String);
///
/// impl FromRequest<'_, '_> for Item {
/// #[rocket::async_trait]
/// impl<'a, 'r> FromRequest<'a, 'r> for Item {
/// type Error = ();
///
/// fn from_request(request: &Request<'_>) -> request::Outcome<Item, ()> {
/// request.guard::<State<MyConfig>>()
/// async fn from_request(request: &'a Request<'r>) -> request::Outcome<Item, ()> {
/// request.guard::<State<MyConfig>>().await
/// .map(|my_config| Item(my_config.user_val.clone()))
/// }
/// }
@ -165,11 +166,12 @@ impl<'r, T: Send + Sync + 'static> State<'r, T> {
}
}
impl<'r, T: Send + Sync + 'static> FromRequest<'_, 'r> for State<'r, T> {
#[crate::async_trait]
impl<'a, 'r, T: Send + Sync + 'static> FromRequest<'a, 'r> for State<'r, T> {
type Error = ();
#[inline(always)]
fn from_request(req: &Request<'r>) -> request::Outcome<State<'r, T>, ()> {
async fn from_request(req: &'a Request<'r>) -> request::Outcome<State<'r, T>, ()> {
match req.state.managed.try_get::<T>() {
Some(state) => Outcome::Success(State(state)),
None => {

View File

@ -245,10 +245,11 @@ impl<'a, 'r> Flash<&'a Request<'r>> {
///
/// The suggested use is through an `Option` and the `FlashMessage` type alias
/// in `request`: `Option<FlashMessage>`.
#[crate::async_trait]
impl<'a, 'r> FromRequest<'a, 'r> for Flash<&'a Request<'r>> {
type Error = ();
fn from_request(req: &'a Request<'r>) -> request::Outcome<Self, Self::Error> {
async fn from_request(req: &'a Request<'r>) -> request::Outcome<Self, Self::Error> {
trace_!("Flash: attempting to retrieve message.");
req.cookies().get(FLASH_COOKIE_NAME).ok_or(()).and_then(|cookie| {
trace_!("Flash: retrieving message: {:?}", cookie);

View File

@ -46,11 +46,12 @@ impl ShutdownHandle {
}
}
impl FromRequest<'_, '_> for ShutdownHandle {
#[crate::async_trait]
impl<'a, 'r> FromRequest<'a, 'r> for ShutdownHandle {
type Error = std::convert::Infallible;
#[inline]
fn from_request(request: &Request<'_>) -> Outcome<Self, Self::Error> {
async fn from_request(request: &'a Request<'r>) -> Outcome<Self, Self::Error> {
Outcome::Success(request.state.managed.get::<ShutdownHandleManaged>().0.clone())
}
}

View File

@ -62,7 +62,7 @@ mod fairing_before_head_strip {
assert_eq!(req.method(), Method::Head);
// This should be called exactly once.
let c = req.guard::<State<Counter>>().unwrap();
let c = req.guard::<State<Counter>>().await.unwrap();
assert_eq!(c.0.fetch_add(1, Ordering::SeqCst), 0);
})
}))

View File

@ -8,10 +8,11 @@ use rocket::request::{self, FromRequest};
struct HasContentType;
#[rocket::async_trait]
impl<'a, 'r> FromRequest<'a, 'r> for HasContentType {
type Error = ();
fn from_request(request: &'a Request<'r>) -> request::Outcome<Self, ()> {
async fn from_request(request: &'a Request<'r>) -> request::Outcome<Self, ()> {
if request.content_type().is_some() {
Success(HasContentType)
} else {

View File

@ -31,7 +31,8 @@ fn rocket() -> rocket::Rocket {
.attach(AdHoc::on_request("Inner", |req, _| {
Box::pin(async move {
if req.method() == Method::Get {
let counter = req.guard::<State<'_, Counter>>().unwrap();
let counter = req.guard::<State<'_, Counter>>()
.await.unwrap();
counter.get.fetch_add(1, Ordering::Release);
}
})

View File

@ -8,10 +8,11 @@ use rocket::outcome::Outcome::*;
#[derive(Debug)]
struct HeaderCount(usize);
#[rocket::async_trait]
impl<'a, 'r> FromRequest<'a, 'r> for HeaderCount {
type Error = std::convert::Infallible;
fn from_request(request: &'a Request<'r>) -> request::Outcome<Self, Self::Error> {
async fn from_request(request: &'a Request<'r>) -> request::Outcome<Self, Self::Error> {
Success(HeaderCount(request.headers().len()))
}
}

View File

@ -5,7 +5,7 @@
use std::sync::atomic::{AtomicUsize, Ordering};
use rocket::outcome::Outcome::*;
use rocket::request::{self, FromRequest, FromRequestAsync, FromRequestFuture, Request, State};
use rocket::request::{self, FromRequest, Request, State};
#[cfg(test)] mod tests;
@ -20,11 +20,12 @@ struct Guard2;
struct Guard3;
struct Guard4;
#[rocket::async_trait]
impl<'a, 'r> FromRequest<'a, 'r> for Guard1 {
type Error = ();
fn from_request(req: &'a Request<'r>) -> request::Outcome<Self, ()> {
let atomics = try_outcome!(req.guard::<State<'_, Atomics>>());
async fn from_request(req: &'a Request<'r>) -> request::Outcome<Self, ()> {
let atomics = try_outcome!(req.guard::<State<'_, Atomics>>().await);
atomics.uncached.fetch_add(1, Ordering::Relaxed);
req.local_cache(|| atomics.cached.fetch_add(1, Ordering::Relaxed));
@ -32,41 +33,38 @@ impl<'a, 'r> FromRequest<'a, 'r> for Guard1 {
}
}
#[rocket::async_trait]
impl<'a, 'r> FromRequest<'a, 'r> for Guard2 {
type Error = ();
fn from_request(req: &'a Request<'r>) -> request::Outcome<Self, ()> {
try_outcome!(req.guard::<Guard1>());
async fn from_request(req: &'a Request<'r>) -> request::Outcome<Self, ()> {
try_outcome!(req.guard::<Guard1>().await);
Success(Guard2)
}
}
impl<'a, 'r> FromRequestAsync<'a, 'r> for Guard3 {
#[rocket::async_trait]
impl<'a, 'r> FromRequest<'a, 'r> for Guard3 {
type Error = ();
fn from_request(req: &'a Request<'r>) -> FromRequestFuture<'a, Self, ()>
{
Box::pin(async move {
let atomics = try_outcome!(req.guard::<State<'_, Atomics>>());
atomics.uncached.fetch_add(1, Ordering::Relaxed);
req.local_cache_async(async {
atomics.cached.fetch_add(1, Ordering::Relaxed)
}).await;
async fn from_request(req: &'a Request<'r>) -> request::Outcome<Self, ()> {
let atomics = try_outcome!(req.guard::<State<'_, Atomics>>().await);
atomics.uncached.fetch_add(1, Ordering::Relaxed);
req.local_cache_async(async {
atomics.cached.fetch_add(1, Ordering::Relaxed)
}).await;
Success(Guard3)
})
Success(Guard3)
}
}
impl<'a, 'r> FromRequestAsync<'a, 'r> for Guard4 {
#[rocket::async_trait]
impl<'a, 'r> FromRequest<'a, 'r> for Guard4 {
type Error = ();
fn from_request(req: &'a Request<'r>) -> FromRequestFuture<'a, Self, ()>
{
Box::pin(async move {
try_outcome!(Guard3::from_request(req).await);
Success(Guard4)
})
async fn from_request(req: &'a Request<'r>) -> request::Outcome<Self, ()> {
try_outcome!(Guard3::from_request(req).await);
Success(Guard4)
}
}

View File

@ -21,10 +21,11 @@ struct Login {
#[derive(Debug)]
struct User(usize);
#[rocket::async_trait]
impl<'a, 'r> FromRequest<'a, 'r> for User {
type Error = std::convert::Infallible;
fn from_request(request: &'a Request<'r>) -> request::Outcome<User, Self::Error> {
async fn from_request(request: &'a Request<'r>) -> request::Outcome<User, Self::Error> {
request.cookies()
.get_private("user_id")
.and_then(|cookie| cookie.value().parse().ok())

View File

@ -227,7 +227,7 @@ synchronous equivalents inside Rocket applications.
`async fn`s. Inside an `async fn`, you can `.await` `Future`s from Rocket or
other libraries
* Several of Rocket's traits, such as [`FromData`](../requests#body-data) and
[`FromRequestAsync`](../requests#request-guards), have methods that return
[`FromRequest`](../requests#request-guards), have methods that return
`Future`s.
* `Data` and `DataStream` (incoming request data) and `Response` and `Body`
(outgoing response data) are based on `tokio::io::AsyncRead` instead of

View File

@ -412,7 +412,7 @@ imply, a request guard protects a handler from being called erroneously based on
information contained in an incoming request. More specifically, a request guard
is a type that represents an arbitrary validation policy. The validation policy
is implemented through the [`FromRequest`] trait. Every type that implements
`FromRequest` (or the related [`FromRequestAsync`]) is a request guard.
`FromRequest` is a request guard.
Request guards appear as inputs to handlers. An arbitrary number of request
guards can appear as arguments in a route handler. Rocket will automatically
@ -444,7 +444,6 @@ more about request guards and implementing them, see the [`FromRequest`]
documentation.
[`FromRequest`]: @api/rocket/request/trait.FromRequest.html
[`FromRequestAsync`]: @api/rocket/request/trait.FromRequestAsync.html
[`Cookies`]: @api/rocket/http/enum.Cookies.html
### Custom Guards