Use 'async_trait' for 'FromRequest'.

Removes 'FromRequestAsync'.
This commit is contained in:
Sergio Benitez 2020-01-31 01:34:15 -08:00
parent 48c333721c
commit 431b963774
22 changed files with 159 additions and 200 deletions

View File

@ -141,21 +141,22 @@ pub fn database_attr(attr: TokenStream, input: TokenStream) -> Result<TokenStrea
} }
} }
impl<'a, 'r> #request::FromRequestAsync<'a, 'r> for #guard_type { #[::rocket::async_trait]
impl<'a, 'r> #request::FromRequest<'a, 'r> for #guard_type {
type Error = (); type Error = ();
fn from_request(request: &'a #request::Request<'r>) -> #request::FromRequestFuture<'a, Self, Self::Error> { async fn from_request(request: &'a #request::Request<'r>) -> #request::Outcome<Self, ()> {
use ::rocket::{Outcome, http::Status}; use ::rocket::{Outcome, http::Status};
Box::pin(async move {
let pool = ::rocket::try_outcome!(request.guard::<::rocket::State<'_, #pool_type>>()).0.clone();
#spawn_blocking(move || { let guard = request.guard::<::rocket::State<'_, #pool_type>>();
match pool.get() { let pool = ::rocket::try_outcome!(guard.await).0.clone();
Ok(conn) => Outcome::Success(#guard_type(conn)),
Err(_) => Outcome::Failure((Status::ServiceUnavailable, ())), #spawn_blocking(move || {
} match pool.get() {
}).await.expect("failed to spawn a blocking task to get a pooled connection") Ok(conn) => Outcome::Success(#guard_type(conn)),
}) Err(_) => Outcome::Failure((Status::ServiceUnavailable, ())),
}
}).await.expect("failed to spawn a blocking task to get a pooled connection")
} }
} }

View File

@ -173,7 +173,7 @@ impl Fairing for TemplateFairing {
#[cfg(debug_assertions)] #[cfg(debug_assertions)]
async fn on_request<'a>(&'a self, req: &'a mut rocket::Request<'_>, _data: &'a rocket::Data) { async fn on_request<'a>(&'a self, req: &'a mut rocket::Request<'_>, _data: &'a rocket::Data) {
let cm = req.guard::<rocket::State<'_, ContextManager>>() let cm = req.guard::<rocket::State<'_, ContextManager>>().await
.expect("Template ContextManager registered in on_attach"); .expect("Template ContextManager registered in on_attach");
cm.reload_if_needed(&*self.custom_callback); cm.reload_if_needed(&*self.custom_callback);

View File

@ -87,11 +87,12 @@ impl Metadata<'_> {
/// Retrieves the template metadata. If a template fairing hasn't been attached, /// Retrieves the template metadata. If a template fairing hasn't been attached,
/// an error is printed and an empty `Err` with status `InternalServerError` /// an error is printed and an empty `Err` with status `InternalServerError`
/// (`500`) is returned. /// (`500`) is returned.
impl<'a> FromRequest<'a, '_> for Metadata<'a> { #[rocket::async_trait]
impl<'a, 'r> FromRequest<'a, 'r> for Metadata<'a> {
type Error = (); type Error = ();
fn from_request(request: &'a Request<'_>) -> request::Outcome<Self, ()> { async fn from_request(request: &'a Request<'r>) -> request::Outcome<Self, ()> {
request.guard::<State<'_, ContextManager>>() request.guard::<State<'_, ContextManager>>().await
.succeeded() .succeeded()
.and_then(|cm| Some(Outcome::Success(Metadata(cm.inner())))) .and_then(|cm| Some(Outcome::Success(Metadata(cm.inner()))))
.unwrap_or_else(|| { .unwrap_or_else(|| {

View File

@ -387,7 +387,7 @@ impl<'r> Responder<'r> for Template {
fn respond_to(self, req: &'r Request<'_>) -> response::ResultFuture<'r> { fn respond_to(self, req: &'r Request<'_>) -> response::ResultFuture<'r> {
Box::pin(async move { Box::pin(async move {
let (render, content_type) = { let (render, content_type) = {
let ctxt = req.guard::<State<'_, ContextManager>>().succeeded().ok_or_else(|| { let ctxt = req.guard::<State<'_, ContextManager>>().await.succeeded().ok_or_else(|| {
error_!("Uninitialized template context: missing fairing."); error_!("Uninitialized template context: missing fairing.");
info_!("To use templates, you must attach `Template::fairing()`."); info_!("To use templates, you must attach `Template::fairing()`.");
info_!("See the `Template` documentation for more information."); info_!("See the `Template` documentation for more information.");

View File

@ -313,7 +313,7 @@ fn request_guard_expr(ident: &syn::Ident, ty: &syn::Type) -> TokenStream2 {
let span = ident.span().unstable().join(ty.span()).unwrap().into(); let span = ident.span().unstable().join(ty.span()).unwrap().into();
quote_spanned! { span => quote_spanned! { span =>
#[allow(non_snake_case, unreachable_patterns, unreachable_code)] #[allow(non_snake_case, unreachable_patterns, unreachable_code)]
let #ident: #ty = match <#ty as #request::FromRequestAsync>::from_request(#req).await { let #ident: #ty = match <#ty as #request::FromRequest>::from_request(#req).await {
#Outcome::Success(__v) => __v, #Outcome::Success(__v) => __v,
#Outcome::Forward(_) => return #Outcome::Forward(#data), #Outcome::Forward(_) => return #Outcome::Forward(#data),
#Outcome::Failure((__c, _)) => return #Outcome::Failure(__c), #Outcome::Failure((__c, _)) => return #Outcome::Failure(__c),

View File

@ -84,10 +84,11 @@ mod key {
/// // In practice, we'd probably fetch the user from the database. /// // In practice, we'd probably fetch the user from the database.
/// struct User(usize); /// struct User(usize);
/// ///
/// impl FromRequest<'_, '_> for User { /// #[rocket::async_trait]
/// impl<'a, 'r> FromRequest<'a, 'r> for User {
/// type Error = std::convert::Infallible; /// type Error = std::convert::Infallible;
/// ///
/// fn from_request(request: &Request<'_>) -> request::Outcome<Self, Self::Error> { /// async fn from_request(request: &'a Request<'r>) -> request::Outcome<Self, Self::Error> {
/// request.cookies() /// request.cookies()
/// .get_private("user_id") /// .get_private("user_id")
/// .and_then(|cookie| cookie.value().parse().ok()) /// .and_then(|cookie| cookie.value().parse().ok())

View File

@ -353,10 +353,11 @@ pub use self::info_kind::{Info, Kind};
/// pub struct StartTime(pub SystemTime); /// pub struct StartTime(pub SystemTime);
/// ///
/// // Allows a route to access the time a request was initiated. /// // Allows a route to access the time a request was initiated.
/// impl FromRequest<'_, '_> for StartTime { /// #[rocket::async_trait]
/// impl<'a, 'r> FromRequest<'a, 'r> for StartTime {
/// type Error = (); /// type Error = ();
/// ///
/// fn from_request(request: &Request<'_>) -> request::Outcome<StartTime, ()> { /// async fn from_request(request: &'a Request<'r>) -> request::Outcome<Self, ()> {
/// match *request.local_cache(|| TimerStart(None)) { /// match *request.local_cache(|| TimerStart(None)) {
/// TimerStart(Some(time)) => Outcome::Success(StartTime(time)), /// TimerStart(Some(time)) => Outcome::Success(StartTime(time)),
/// TimerStart(None) => Outcome::Failure((Status::InternalServerError, ())), /// TimerStart(None) => Outcome::Failure((Status::InternalServerError, ())),
@ -366,7 +367,6 @@ pub use self::info_kind::{Info, Kind};
/// ``` /// ```
/// ///
/// [request-local state]: https://rocket.rs/v0.5/guide/state/#request-local-state /// [request-local state]: https://rocket.rs/v0.5/guide/state/#request-local-state
#[crate::async_trait] #[crate::async_trait]
pub trait Fairing: Send + Sync + 'static { pub trait Fairing: Send + Sync + 'static {
/// Returns an [`Info`] structure containing the `name` and [`Kind`] of this /// Returns an [`Info`] structure containing the `name` and [`Kind`] of this

View File

@ -628,12 +628,13 @@ impl<S, E, F> Outcome<S, E, F> {
/// struct Guard1; /// struct Guard1;
/// struct Guard2; /// struct Guard2;
/// ///
/// #[rocket::async_trait]
/// impl<'a, 'r> FromRequest<'a, 'r> for Guard1 { /// impl<'a, 'r> FromRequest<'a, 'r> for Guard1 {
/// type Error = (); /// type Error = ();
/// ///
/// fn from_request(req: &'a Request<'r>) -> request::Outcome<Self, ()> { /// async fn from_request(req: &'a Request<'r>) -> request::Outcome<Self, ()> {
/// // Attempt to fetch the guard, passing through any error or forward. /// // Attempt to fetch the guard, passing through any error or forward.
/// let atomics = try_outcome!(req.guard::<State<'_, Atomics>>()); /// let atomics = try_outcome!(req.guard::<State<'_, Atomics>>().await);
/// atomics.uncached.fetch_add(1, Ordering::Relaxed); /// atomics.uncached.fetch_add(1, Ordering::Relaxed);
/// req.local_cache(|| atomics.cached.fetch_add(1, Ordering::Relaxed)); /// req.local_cache(|| atomics.cached.fetch_add(1, Ordering::Relaxed));
/// ///
@ -641,12 +642,13 @@ impl<S, E, F> Outcome<S, E, F> {
/// } /// }
/// } /// }
/// ///
/// #[rocket::async_trait]
/// impl<'a, 'r> FromRequest<'a, 'r> for Guard2 { /// impl<'a, 'r> FromRequest<'a, 'r> for Guard2 {
/// type Error = (); /// type Error = ();
/// ///
/// fn from_request(req: &'a Request<'r>) -> request::Outcome<Self, ()> { /// async fn from_request(req: &'a Request<'r>) -> request::Outcome<Self, ()> {
/// // Attempt to fetch the guard, passing through any error or forward. /// // Attempt to fetch the guard, passing through any error or forward.
/// let guard1: Guard1 = try_outcome!(req.guard::<Guard1>()); /// let guard1: Guard1 = try_outcome!(req.guard::<Guard1>().await);
/// Success(Guard2) /// Success(Guard2)
/// } /// }
/// } /// }

View File

@ -34,84 +34,6 @@ impl<S, E> IntoOutcome<S, (Status, E), ()> for Result<S, E> {
} }
} }
/// Type alias for the future returned by [`FromRequestAsync::from_request`].
pub type FromRequestFuture<'fut, T, E> = BoxFuture<'fut, Outcome<T, E>>;
/// Trait implemented by asynchronous request guards to derive a value from
/// incoming requests.
///
/// For more information on request guards in general, see the [`FromRequest`]
/// trait.
///
/// ## Example
///
/// Imagine you're running an authenticated service backed by a database. You
/// want to ensure that certain handlers will only run if a valid API key is
/// present in the request, and you need to make a database request in order to
/// determine if the key is valid or not.
///
/// ```rust
/// # #![feature(proc_macro_hygiene)]
/// # #[macro_use] extern crate rocket;
/// #
/// # struct Database;
/// # impl Database {
/// # async fn check_key(&self, key: &str) -> bool {
/// # true
/// # }
/// # }
/// #
/// use rocket::Outcome;
/// use rocket::http::Status;
/// use rocket::request::{self, Request, State, FromRequestAsync};
///
/// struct ApiKey(String);
///
/// #[derive(Debug)]
/// enum ApiKeyError {
/// BadCount,
/// Missing,
/// Invalid,
/// }
///
/// impl<'a, 'r> FromRequestAsync<'a, 'r> for ApiKey {
/// type Error = ApiKeyError;
///
/// fn from_request(request: &'a Request<'r>) -> request::FromRequestFuture<'a, Self, Self::Error> {
/// Box::pin(async move {
/// let keys: Vec<_> = request.headers().get("x-api-key").collect();
/// let database: State<'_, Database> = request.guard().expect("get database connection");
/// match keys.len() {
/// 0 => Outcome::Failure((Status::BadRequest, ApiKeyError::Missing)),
/// 1 if database.check_key(keys[0]).await => Outcome::Success(ApiKey(keys[0].to_string())),
/// 1 => Outcome::Failure((Status::BadRequest, ApiKeyError::Invalid)),
/// _ => Outcome::Failure((Status::BadRequest, ApiKeyError::BadCount)),
/// }
/// })
/// }
/// }
///
/// #[get("/sensitive")]
/// fn sensitive(key: ApiKey) -> &'static str {
/// # let _key = key;
/// "Sensitive data."
/// }
///
/// # fn main() { }
/// ```
pub trait FromRequestAsync<'a, 'r>: Sized {
/// The associated error to be returned if derivation fails.
type Error: Debug;
/// Derives an instance of `Self` from the incoming request metadata.
///
/// If the derivation is successful, an outcome of `Success` is returned. If
/// the derivation fails in an unrecoverable fashion, `Failure` is returned.
/// `Forward` is returned to indicate that the request should be forwarded
/// to other matching routes, if any.
fn from_request(request: &'a Request<'r>) -> FromRequestFuture<'a, Self, Self::Error>;
}
/// Trait implemented by request guards to derive a value from incoming /// Trait implemented by request guards to derive a value from incoming
/// requests. /// requests.
/// ///
@ -127,9 +49,26 @@ pub trait FromRequestAsync<'a, 'r>: Sized {
/// the handler. Rocket only dispatches requests to a handler when all of its /// the handler. Rocket only dispatches requests to a handler when all of its
/// guards pass. /// guards pass.
/// ///
/// Request guards can be made *asynchronous* by implementing /// ## Async Trait
/// [`FromRequestAsync`] instead of `FromRequest`. This is useful when the ///
/// validation requires working with a database or performing other I/O. /// [`FromRequest`] is an _async_ trait. Implementations of `FromRequest` must
/// be decorated with an attribute of `#[rocket::async_trait]`:
///
/// ```rust
/// use rocket::request::{self, Request, FromRequest};
/// # struct MyType;
/// # type MyError = String;
///
/// #[rocket::async_trait]
/// impl<'a, 'r> FromRequest<'a, 'r> for MyType {
/// type Error = MyError;
///
/// async fn from_request(req: &'a Request<'r>) -> request::Outcome<Self, Self::Error> {
/// /* .. */
/// # unimplemented!()
/// }
/// }
/// ```
/// ///
/// ## Example /// ## Example
/// ///
@ -270,11 +209,12 @@ pub trait FromRequestAsync<'a, 'r>: Sized {
/// Invalid, /// Invalid,
/// } /// }
/// ///
/// impl FromRequest<'_, '_> for ApiKey { /// #[rocket::async_trait]
/// impl<'a, 'r> FromRequest<'a, 'r> for ApiKey {
/// type Error = ApiKeyError; /// type Error = ApiKeyError;
/// ///
/// fn from_request(request: &Request<'_>) -> request::Outcome<Self, Self::Error> { /// async fn from_request(req: &'a Request<'r>) -> request::Outcome<Self, Self::Error> {
/// let keys: Vec<_> = request.headers().get("x-api-key").collect(); /// let keys: Vec<_> = req.headers().get("x-api-key").collect();
/// match keys.len() { /// match keys.len() {
/// 0 => Outcome::Failure((Status::BadRequest, ApiKeyError::Missing)), /// 0 => Outcome::Failure((Status::BadRequest, ApiKeyError::Missing)),
/// 1 if is_valid(keys[0]) => Outcome::Success(ApiKey(keys[0].to_string())), /// 1 if is_valid(keys[0]) => Outcome::Success(ApiKey(keys[0].to_string())),
@ -316,20 +256,22 @@ pub trait FromRequestAsync<'a, 'r>: Sized {
/// # Ok(User { id, is_admin: false }) /// # Ok(User { id, is_admin: false })
/// # } /// # }
/// # } /// # }
/// # impl FromRequest<'_, '_> for Database { /// # #[rocket::async_trait]
/// # impl<'a, 'r> FromRequest<'a, 'r> for Database {
/// # type Error = (); /// # type Error = ();
/// # fn from_request(request: &Request<'_>) -> request::Outcome<Database, ()> { /// # async fn from_request(request: &'a Request<'r>) -> request::Outcome<Database, ()> {
/// # Outcome::Success(Database) /// # Outcome::Success(Database)
/// # } /// # }
/// # } /// # }
/// # /// #
/// # struct Admin { user: User } /// # struct Admin { user: User }
/// # /// #
/// impl FromRequest<'_, '_> for User { /// #[rocket::async_trait]
/// impl<'a, 'r> FromRequest<'a, 'r> for User {
/// type Error = (); /// type Error = ();
/// ///
/// fn from_request(request: &Request<'_>) -> request::Outcome<User, ()> { /// async fn from_request(request: &'a Request<'r>) -> request::Outcome<User, ()> {
/// let db = try_outcome!(request.guard::<Database>()); /// let db = try_outcome!(request.guard::<Database>().await);
/// request.cookies() /// request.cookies()
/// .get_private("user_id") /// .get_private("user_id")
/// .and_then(|cookie| cookie.value().parse().ok()) /// .and_then(|cookie| cookie.value().parse().ok())
@ -338,13 +280,13 @@ pub trait FromRequestAsync<'a, 'r>: Sized {
/// } /// }
/// } /// }
/// ///
/// impl FromRequest<'_, '_> for Admin { /// #[rocket::async_trait]
/// impl<'a, 'r> FromRequest<'a, 'r> for Admin {
/// type Error = (); /// type Error = ();
/// ///
/// fn from_request(request: &Request<'_>) -> request::Outcome<Admin, ()> { /// async fn from_request(request: &'a Request<'r>) -> request::Outcome<Admin, ()> {
/// // This will unconditionally query the database! /// // This will unconditionally query the database!
/// let user = try_outcome!(request.guard::<User>()); /// let user = try_outcome!(request.guard::<User>().await);
///
/// if user.is_admin { /// if user.is_admin {
/// Outcome::Success(Admin { user }) /// Outcome::Success(Admin { user })
/// } else { /// } else {
@ -379,39 +321,41 @@ pub trait FromRequestAsync<'a, 'r>: Sized {
/// # Ok(User { id, is_admin: false }) /// # Ok(User { id, is_admin: false })
/// # } /// # }
/// # } /// # }
/// # impl FromRequest<'_, '_> for Database { /// # #[rocket::async_trait]
/// # impl<'a, 'r> FromRequest<'a, 'r> for Database {
/// # type Error = (); /// # type Error = ();
/// # fn from_request(request: &Request<'_>) -> request::Outcome<Database, ()> { /// # async fn from_request(request: &'a Request<'r>) -> request::Outcome<Database, ()> {
/// # Outcome::Success(Database) /// # Outcome::Success(Database)
/// # } /// # }
/// # } /// # }
/// # /// #
/// # struct Admin<'a> { user: &'a User } /// # struct Admin<'a> { user: &'a User }
/// # /// #
/// impl<'a> FromRequest<'a, '_> for &'a User { /// #[rocket::async_trait]
/// impl<'a, 'r> FromRequest<'a, 'r> for &'a User {
/// type Error = std::convert::Infallible; /// type Error = std::convert::Infallible;
/// ///
/// fn from_request(request: &'a Request<'_>) -> request::Outcome<Self, Self::Error> { /// async fn from_request(request: &'a Request<'r>) -> request::Outcome<Self, Self::Error> {
/// // This closure will execute at most once per request, regardless of /// // This closure will execute at most once per request, regardless of
/// // the number of times the `User` guard is executed. /// // the number of times the `User` guard is executed.
/// let user_result = request.local_cache(|| { /// let user_result = request.local_cache_async(async {
/// let db = request.guard::<Database>().succeeded()?; /// let db = request.guard::<Database>().await.succeeded()?;
/// request.cookies() /// request.cookies()
/// .get_private("user_id") /// .get_private("user_id")
/// .and_then(|cookie| cookie.value().parse().ok()) /// .and_then(|cookie| cookie.value().parse().ok())
/// .and_then(|id| db.get_user(id).ok()) /// .and_then(|id| db.get_user(id).ok())
/// }); /// }).await;
/// ///
/// user_result.as_ref().or_forward(()) /// user_result.as_ref().or_forward(())
/// } /// }
/// } /// }
/// ///
/// impl<'a> FromRequest<'a, '_> for Admin<'a> { /// #[rocket::async_trait]
/// impl<'a, 'r> FromRequest<'a, 'r> for Admin<'a> {
/// type Error = std::convert::Infallible; /// type Error = std::convert::Infallible;
/// ///
/// fn from_request(request: &'a Request<'_>) -> request::Outcome<Self, Self::Error> { /// async fn from_request(request: &'a Request<'r>) -> request::Outcome<Self, Self::Error> {
/// let user = try_outcome!(request.guard::<&User>()); /// let user = try_outcome!(request.guard::<&User>().await);
///
/// if user.is_admin { /// if user.is_admin {
/// Outcome::Success(Admin { user }) /// Outcome::Success(Admin { user })
/// } else { /// } else {
@ -427,6 +371,7 @@ pub trait FromRequestAsync<'a, 'r>: Sized {
/// ///
/// [request-local state]: https://rocket.rs/v0.5/guide/state/#request-local-state /// [request-local state]: https://rocket.rs/v0.5/guide/state/#request-local-state
#[crate::async_trait]
pub trait FromRequest<'a, 'r>: Sized { pub trait FromRequest<'a, 'r>: Sized {
/// The associated error to be returned if derivation fails. /// The associated error to be returned if derivation fails.
type Error: Debug; type Error: Debug;
@ -437,37 +382,32 @@ pub trait FromRequest<'a, 'r>: Sized {
/// the derivation fails in an unrecoverable fashion, `Failure` is returned. /// the derivation fails in an unrecoverable fashion, `Failure` is returned.
/// `Forward` is returned to indicate that the request should be forwarded /// `Forward` is returned to indicate that the request should be forwarded
/// to other matching routes, if any. /// to other matching routes, if any.
fn from_request(request: &'a Request<'r>) -> Outcome<Self, Self::Error>; async fn from_request(request: &'a Request<'r>) -> Outcome<Self, Self::Error>;
} }
impl<'a, 'r, T: FromRequest<'a, 'r>> FromRequestAsync<'a, 'r> for T { #[crate::async_trait]
type Error = T::Error; impl<'a, 'r> FromRequest<'a, 'r> for Method {
fn from_request(request: &'a Request<'r>) -> BoxFuture<'a, Outcome<Self, Self::Error>> {
Box::pin(async move { T::from_request(request) })
}
}
impl FromRequest<'_, '_> for Method {
type Error = std::convert::Infallible; type Error = std::convert::Infallible;
fn from_request(request: &Request<'_>) -> Outcome<Self, Self::Error> { async fn from_request(request: &'a Request<'r>) -> Outcome<Self, Self::Error> {
Success(request.method()) Success(request.method())
} }
} }
impl<'a> FromRequest<'a, '_> for &'a Origin<'a> { #[crate::async_trait]
impl<'a, 'r> FromRequest<'a, 'r> for &'a Origin<'a> {
type Error = std::convert::Infallible; type Error = std::convert::Infallible;
fn from_request(request: &'a Request<'_>) -> Outcome<Self, Self::Error> { async fn from_request(request: &'a Request<'r>) -> Outcome<Self, Self::Error> {
Success(request.uri()) Success(request.uri())
} }
} }
impl<'r> FromRequest<'_, 'r> for &'r Route { #[crate::async_trait]
impl<'a, 'r> FromRequest<'a, 'r> for &'r Route {
type Error = std::convert::Infallible; type Error = std::convert::Infallible;
fn from_request(request: &Request<'r>) -> Outcome<Self, Self::Error> { async fn from_request(request: &'a Request<'r>) -> Outcome<Self, Self::Error> {
match request.route() { match request.route() {
Some(route) => Success(route), Some(route) => Success(route),
None => Forward(()) None => Forward(())
@ -475,18 +415,20 @@ impl<'r> FromRequest<'_, 'r> for &'r Route {
} }
} }
impl<'a> FromRequest<'a, '_> for Cookies<'a> { #[crate::async_trait]
impl<'a, 'r> FromRequest<'a, 'r> for Cookies<'a> {
type Error = std::convert::Infallible; type Error = std::convert::Infallible;
fn from_request(request: &'a Request<'_>) -> Outcome<Self, Self::Error> { async fn from_request(request: &'a Request<'r>) -> Outcome<Self, Self::Error> {
Success(request.cookies()) Success(request.cookies())
} }
} }
impl<'a> FromRequest<'a, '_> for &'a Accept { #[crate::async_trait]
impl<'a, 'r> FromRequest<'a, 'r> for &'a Accept {
type Error = std::convert::Infallible; type Error = std::convert::Infallible;
fn from_request(request: &'a Request<'_>) -> Outcome<Self, Self::Error> { async fn from_request(request: &'a Request<'r>) -> Outcome<Self, Self::Error> {
match request.accept() { match request.accept() {
Some(accept) => Success(accept), Some(accept) => Success(accept),
None => Forward(()) None => Forward(())
@ -494,10 +436,11 @@ impl<'a> FromRequest<'a, '_> for &'a Accept {
} }
} }
impl<'a> FromRequest<'a, '_> for &'a ContentType { #[crate::async_trait]
impl<'a, 'r> FromRequest<'a, 'r> for &'a ContentType {
type Error = std::convert::Infallible; type Error = std::convert::Infallible;
fn from_request(request: &'a Request<'_>) -> Outcome<Self, Self::Error> { async fn from_request(request: &'a Request<'r>) -> Outcome<Self, Self::Error> {
match request.content_type() { match request.content_type() {
Some(content_type) => Success(content_type), Some(content_type) => Success(content_type),
None => Forward(()) None => Forward(())
@ -505,10 +448,11 @@ impl<'a> FromRequest<'a, '_> for &'a ContentType {
} }
} }
impl FromRequest<'_, '_> for SocketAddr { #[crate::async_trait]
impl<'a, 'r> FromRequest<'a, 'r> for SocketAddr {
type Error = std::convert::Infallible; type Error = std::convert::Infallible;
fn from_request(request: &Request<'_>) -> Outcome<Self, Self::Error> { async fn from_request(request: &'a Request<'r>) -> Outcome<Self, Self::Error> {
match request.remote() { match request.remote() {
Some(addr) => Success(addr), Some(addr) => Success(addr),
None => Forward(()) None => Forward(())
@ -516,10 +460,12 @@ impl FromRequest<'_, '_> for SocketAddr {
} }
} }
impl<'a, 'r, T: FromRequestAsync<'a, 'r> + 'a> FromRequestAsync<'a, 'r> for Result<T, T::Error> { impl<'a, 'r, T: FromRequest<'a, 'r> + 'a> FromRequest<'a, 'r> for Result<T, T::Error> {
type Error = std::convert::Infallible; type Error = std::convert::Infallible;
fn from_request(request: &'a Request<'r>) -> BoxFuture<'a, Outcome<Self, Self::Error>> { fn from_request<'y>(request: &'a Request<'r>) -> BoxFuture<'y, Outcome<Self, Self::Error>>
where 'a: 'y, 'r: 'y
{
// TODO: FutureExt::map is a workaround (see rust-lang/rust#60658) // TODO: FutureExt::map is a workaround (see rust-lang/rust#60658)
use futures_util::future::FutureExt; use futures_util::future::FutureExt;
T::from_request(request).map(|x| match x { T::from_request(request).map(|x| match x {
@ -530,10 +476,12 @@ impl<'a, 'r, T: FromRequestAsync<'a, 'r> + 'a> FromRequestAsync<'a, 'r> for Resu
} }
} }
impl<'a, 'r, T: FromRequestAsync<'a, 'r> + 'a> FromRequestAsync<'a, 'r> for Option<T> { impl<'a, 'r, T: FromRequest<'a, 'r> + 'a> FromRequest<'a, 'r> for Option<T> {
type Error = std::convert::Infallible; type Error = std::convert::Infallible;
fn from_request(request: &'a Request<'r>) -> BoxFuture<'a, Outcome<Self, Self::Error>> { fn from_request<'y>(request: &'a Request<'r>) -> BoxFuture<'y, Outcome<Self, Self::Error>>
where 'a: 'y, 'r: 'y
{
// TODO: FutureExt::map is a workaround (see rust-lang/rust#60658) // TODO: FutureExt::map is a workaround (see rust-lang/rust#60658)
use futures_util::future::FutureExt; use futures_util::future::FutureExt;
T::from_request(request).map(|x| match x { T::from_request(request).map(|x| match x {
@ -542,4 +490,3 @@ impl<'a, 'r, T: FromRequestAsync<'a, 'r> + 'a> FromRequestAsync<'a, 'r> for Opti
}).boxed() }).boxed()
} }
} }

View File

@ -13,7 +13,7 @@ mod tests;
#[doc(hidden)] pub use rocket_codegen::{FromForm, FromFormValue}; #[doc(hidden)] pub use rocket_codegen::{FromForm, FromFormValue};
pub use self::request::Request; pub use self::request::Request;
pub use self::from_request::{FromRequest, FromRequestAsync, FromRequestFuture, Outcome}; pub use self::from_request::{FromRequest, Outcome};
pub use self::param::{FromParam, FromSegments}; pub use self::param::{FromParam, FromSegments};
pub use self::form::{FromForm, FromFormValue}; pub use self::form::{FromForm, FromFormValue};
pub use self::form::{Form, LenientForm, FormItems, FormItem}; pub use self::form::{Form, LenientForm, FormItems, FormItem};

View File

@ -6,6 +6,7 @@ use std::str;
use yansi::Paint; use yansi::Paint;
use state::{Container, Storage}; use state::{Container, Storage};
use futures_util::future::BoxFuture;
use crate::request::{FromParam, FromSegments, FromRequest, Outcome}; use crate::request::{FromParam, FromSegments, FromRequest, Outcome};
use crate::request::{FromFormValue, FormItems, FormItem}; use crate::request::{FromFormValue, FormItems, FormItem};
@ -530,8 +531,9 @@ impl<'r> Request<'r> {
/// let pool = request.guard::<State<Pool>>(); /// let pool = request.guard::<State<Pool>>();
/// # }); /// # });
/// ``` /// ```
#[inline(always)] pub fn guard<'z, 'a, T>(&'a self) -> BoxFuture<'z, Outcome<T, T::Error>>
pub fn guard<'a, T: FromRequest<'a, 'r>>(&'a self) -> Outcome<T, T::Error> { where T: FromRequest<'a, 'r> + 'z, 'a: 'z, 'r: 'z
{
T::from_request(self) T::from_request(self)
} }

View File

@ -70,11 +70,12 @@ use crate::http::Status;
/// # struct MyConfig{ user_val: String }; /// # struct MyConfig{ user_val: String };
/// struct Item(String); /// struct Item(String);
/// ///
/// impl FromRequest<'_, '_> for Item { /// #[rocket::async_trait]
/// impl<'a, 'r> FromRequest<'a, 'r> for Item {
/// type Error = (); /// type Error = ();
/// ///
/// fn from_request(request: &Request<'_>) -> request::Outcome<Item, ()> { /// async fn from_request(request: &'a Request<'r>) -> request::Outcome<Item, ()> {
/// request.guard::<State<MyConfig>>() /// request.guard::<State<MyConfig>>().await
/// .map(|my_config| Item(my_config.user_val.clone())) /// .map(|my_config| Item(my_config.user_val.clone()))
/// } /// }
/// } /// }
@ -165,11 +166,12 @@ impl<'r, T: Send + Sync + 'static> State<'r, T> {
} }
} }
impl<'r, T: Send + Sync + 'static> FromRequest<'_, 'r> for State<'r, T> { #[crate::async_trait]
impl<'a, 'r, T: Send + Sync + 'static> FromRequest<'a, 'r> for State<'r, T> {
type Error = (); type Error = ();
#[inline(always)] #[inline(always)]
fn from_request(req: &Request<'r>) -> request::Outcome<State<'r, T>, ()> { async fn from_request(req: &'a Request<'r>) -> request::Outcome<State<'r, T>, ()> {
match req.state.managed.try_get::<T>() { match req.state.managed.try_get::<T>() {
Some(state) => Outcome::Success(State(state)), Some(state) => Outcome::Success(State(state)),
None => { None => {

View File

@ -245,10 +245,11 @@ impl<'a, 'r> Flash<&'a Request<'r>> {
/// ///
/// The suggested use is through an `Option` and the `FlashMessage` type alias /// The suggested use is through an `Option` and the `FlashMessage` type alias
/// in `request`: `Option<FlashMessage>`. /// in `request`: `Option<FlashMessage>`.
#[crate::async_trait]
impl<'a, 'r> FromRequest<'a, 'r> for Flash<&'a Request<'r>> { impl<'a, 'r> FromRequest<'a, 'r> for Flash<&'a Request<'r>> {
type Error = (); type Error = ();
fn from_request(req: &'a Request<'r>) -> request::Outcome<Self, Self::Error> { async fn from_request(req: &'a Request<'r>) -> request::Outcome<Self, Self::Error> {
trace_!("Flash: attempting to retrieve message."); trace_!("Flash: attempting to retrieve message.");
req.cookies().get(FLASH_COOKIE_NAME).ok_or(()).and_then(|cookie| { req.cookies().get(FLASH_COOKIE_NAME).ok_or(()).and_then(|cookie| {
trace_!("Flash: retrieving message: {:?}", cookie); trace_!("Flash: retrieving message: {:?}", cookie);

View File

@ -46,11 +46,12 @@ impl ShutdownHandle {
} }
} }
impl FromRequest<'_, '_> for ShutdownHandle { #[crate::async_trait]
impl<'a, 'r> FromRequest<'a, 'r> for ShutdownHandle {
type Error = std::convert::Infallible; type Error = std::convert::Infallible;
#[inline] #[inline]
fn from_request(request: &Request<'_>) -> Outcome<Self, Self::Error> { async fn from_request(request: &'a Request<'r>) -> Outcome<Self, Self::Error> {
Outcome::Success(request.state.managed.get::<ShutdownHandleManaged>().0.clone()) Outcome::Success(request.state.managed.get::<ShutdownHandleManaged>().0.clone())
} }
} }

View File

@ -62,7 +62,7 @@ mod fairing_before_head_strip {
assert_eq!(req.method(), Method::Head); assert_eq!(req.method(), Method::Head);
// This should be called exactly once. // This should be called exactly once.
let c = req.guard::<State<Counter>>().unwrap(); let c = req.guard::<State<Counter>>().await.unwrap();
assert_eq!(c.0.fetch_add(1, Ordering::SeqCst), 0); assert_eq!(c.0.fetch_add(1, Ordering::SeqCst), 0);
}) })
})) }))

View File

@ -8,10 +8,11 @@ use rocket::request::{self, FromRequest};
struct HasContentType; struct HasContentType;
#[rocket::async_trait]
impl<'a, 'r> FromRequest<'a, 'r> for HasContentType { impl<'a, 'r> FromRequest<'a, 'r> for HasContentType {
type Error = (); type Error = ();
fn from_request(request: &'a Request<'r>) -> request::Outcome<Self, ()> { async fn from_request(request: &'a Request<'r>) -> request::Outcome<Self, ()> {
if request.content_type().is_some() { if request.content_type().is_some() {
Success(HasContentType) Success(HasContentType)
} else { } else {

View File

@ -31,7 +31,8 @@ fn rocket() -> rocket::Rocket {
.attach(AdHoc::on_request("Inner", |req, _| { .attach(AdHoc::on_request("Inner", |req, _| {
Box::pin(async move { Box::pin(async move {
if req.method() == Method::Get { if req.method() == Method::Get {
let counter = req.guard::<State<'_, Counter>>().unwrap(); let counter = req.guard::<State<'_, Counter>>()
.await.unwrap();
counter.get.fetch_add(1, Ordering::Release); counter.get.fetch_add(1, Ordering::Release);
} }
}) })

View File

@ -8,10 +8,11 @@ use rocket::outcome::Outcome::*;
#[derive(Debug)] #[derive(Debug)]
struct HeaderCount(usize); struct HeaderCount(usize);
#[rocket::async_trait]
impl<'a, 'r> FromRequest<'a, 'r> for HeaderCount { impl<'a, 'r> FromRequest<'a, 'r> for HeaderCount {
type Error = std::convert::Infallible; type Error = std::convert::Infallible;
fn from_request(request: &'a Request<'r>) -> request::Outcome<Self, Self::Error> { async fn from_request(request: &'a Request<'r>) -> request::Outcome<Self, Self::Error> {
Success(HeaderCount(request.headers().len())) Success(HeaderCount(request.headers().len()))
} }
} }

View File

@ -5,7 +5,7 @@
use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::atomic::{AtomicUsize, Ordering};
use rocket::outcome::Outcome::*; use rocket::outcome::Outcome::*;
use rocket::request::{self, FromRequest, FromRequestAsync, FromRequestFuture, Request, State}; use rocket::request::{self, FromRequest, Request, State};
#[cfg(test)] mod tests; #[cfg(test)] mod tests;
@ -20,11 +20,12 @@ struct Guard2;
struct Guard3; struct Guard3;
struct Guard4; struct Guard4;
#[rocket::async_trait]
impl<'a, 'r> FromRequest<'a, 'r> for Guard1 { impl<'a, 'r> FromRequest<'a, 'r> for Guard1 {
type Error = (); type Error = ();
fn from_request(req: &'a Request<'r>) -> request::Outcome<Self, ()> { async fn from_request(req: &'a Request<'r>) -> request::Outcome<Self, ()> {
let atomics = try_outcome!(req.guard::<State<'_, Atomics>>()); let atomics = try_outcome!(req.guard::<State<'_, Atomics>>().await);
atomics.uncached.fetch_add(1, Ordering::Relaxed); atomics.uncached.fetch_add(1, Ordering::Relaxed);
req.local_cache(|| atomics.cached.fetch_add(1, Ordering::Relaxed)); req.local_cache(|| atomics.cached.fetch_add(1, Ordering::Relaxed));
@ -32,41 +33,38 @@ impl<'a, 'r> FromRequest<'a, 'r> for Guard1 {
} }
} }
#[rocket::async_trait]
impl<'a, 'r> FromRequest<'a, 'r> for Guard2 { impl<'a, 'r> FromRequest<'a, 'r> for Guard2 {
type Error = (); type Error = ();
fn from_request(req: &'a Request<'r>) -> request::Outcome<Self, ()> { async fn from_request(req: &'a Request<'r>) -> request::Outcome<Self, ()> {
try_outcome!(req.guard::<Guard1>()); try_outcome!(req.guard::<Guard1>().await);
Success(Guard2) Success(Guard2)
} }
} }
impl<'a, 'r> FromRequestAsync<'a, 'r> for Guard3 { #[rocket::async_trait]
impl<'a, 'r> FromRequest<'a, 'r> for Guard3 {
type Error = (); type Error = ();
fn from_request(req: &'a Request<'r>) -> FromRequestFuture<'a, Self, ()> async fn from_request(req: &'a Request<'r>) -> request::Outcome<Self, ()> {
{ let atomics = try_outcome!(req.guard::<State<'_, Atomics>>().await);
Box::pin(async move { atomics.uncached.fetch_add(1, Ordering::Relaxed);
let atomics = try_outcome!(req.guard::<State<'_, Atomics>>()); req.local_cache_async(async {
atomics.uncached.fetch_add(1, Ordering::Relaxed); atomics.cached.fetch_add(1, Ordering::Relaxed)
req.local_cache_async(async { }).await;
atomics.cached.fetch_add(1, Ordering::Relaxed)
}).await;
Success(Guard3) Success(Guard3)
})
} }
} }
impl<'a, 'r> FromRequestAsync<'a, 'r> for Guard4 { #[rocket::async_trait]
impl<'a, 'r> FromRequest<'a, 'r> for Guard4 {
type Error = (); type Error = ();
fn from_request(req: &'a Request<'r>) -> FromRequestFuture<'a, Self, ()> async fn from_request(req: &'a Request<'r>) -> request::Outcome<Self, ()> {
{ try_outcome!(Guard3::from_request(req).await);
Box::pin(async move { Success(Guard4)
try_outcome!(Guard3::from_request(req).await);
Success(Guard4)
})
} }
} }

View File

@ -21,10 +21,11 @@ struct Login {
#[derive(Debug)] #[derive(Debug)]
struct User(usize); struct User(usize);
#[rocket::async_trait]
impl<'a, 'r> FromRequest<'a, 'r> for User { impl<'a, 'r> FromRequest<'a, 'r> for User {
type Error = std::convert::Infallible; type Error = std::convert::Infallible;
fn from_request(request: &'a Request<'r>) -> request::Outcome<User, Self::Error> { async fn from_request(request: &'a Request<'r>) -> request::Outcome<User, Self::Error> {
request.cookies() request.cookies()
.get_private("user_id") .get_private("user_id")
.and_then(|cookie| cookie.value().parse().ok()) .and_then(|cookie| cookie.value().parse().ok())

View File

@ -227,7 +227,7 @@ synchronous equivalents inside Rocket applications.
`async fn`s. Inside an `async fn`, you can `.await` `Future`s from Rocket or `async fn`s. Inside an `async fn`, you can `.await` `Future`s from Rocket or
other libraries other libraries
* Several of Rocket's traits, such as [`FromData`](../requests#body-data) and * Several of Rocket's traits, such as [`FromData`](../requests#body-data) and
[`FromRequestAsync`](../requests#request-guards), have methods that return [`FromRequest`](../requests#request-guards), have methods that return
`Future`s. `Future`s.
* `Data` and `DataStream` (incoming request data) and `Response` and `Body` * `Data` and `DataStream` (incoming request data) and `Response` and `Body`
(outgoing response data) are based on `tokio::io::AsyncRead` instead of (outgoing response data) are based on `tokio::io::AsyncRead` instead of

View File

@ -412,7 +412,7 @@ imply, a request guard protects a handler from being called erroneously based on
information contained in an incoming request. More specifically, a request guard information contained in an incoming request. More specifically, a request guard
is a type that represents an arbitrary validation policy. The validation policy is a type that represents an arbitrary validation policy. The validation policy
is implemented through the [`FromRequest`] trait. Every type that implements is implemented through the [`FromRequest`] trait. Every type that implements
`FromRequest` (or the related [`FromRequestAsync`]) is a request guard. `FromRequest` is a request guard.
Request guards appear as inputs to handlers. An arbitrary number of request Request guards appear as inputs to handlers. An arbitrary number of request
guards can appear as arguments in a route handler. Rocket will automatically guards can appear as arguments in a route handler. Rocket will automatically
@ -444,7 +444,6 @@ more about request guards and implementing them, see the [`FromRequest`]
documentation. documentation.
[`FromRequest`]: @api/rocket/request/trait.FromRequest.html [`FromRequest`]: @api/rocket/request/trait.FromRequest.html
[`FromRequestAsync`]: @api/rocket/request/trait.FromRequestAsync.html
[`Cookies`]: @api/rocket/http/enum.Cookies.html [`Cookies`]: @api/rocket/http/enum.Cookies.html
### Custom Guards ### Custom Guards