From 431b963774bab757ad84b15452d5b7fedc007aef Mon Sep 17 00:00:00 2001 From: Sergio Benitez Date: Fri, 31 Jan 2020 01:34:15 -0800 Subject: [PATCH] Use 'async_trait' for 'FromRequest'. Removes 'FromRequestAsync'. --- contrib/codegen/src/database.rs | 23 +- contrib/lib/src/templates/fairing.rs | 2 +- contrib/lib/src/templates/metadata.rs | 7 +- contrib/lib/src/templates/mod.rs | 2 +- core/codegen/src/attribute/route.rs | 2 +- core/http/src/cookies.rs | 5 +- core/lib/src/fairing/mod.rs | 6 +- core/lib/src/outcome.rs | 10 +- core/lib/src/request/from_request.rs | 211 +++++++----------- core/lib/src/request/mod.rs | 2 +- core/lib/src/request/request.rs | 6 +- core/lib/src/request/state.rs | 12 +- core/lib/src/response/flash.rs | 3 +- core/lib/src/shutdown.rs | 5 +- .../fairing_before_head_strip-issue-546.rs | 2 +- .../local-request-content-type-issue-505.rs | 3 +- core/lib/tests/nested-fairing-attaches.rs | 3 +- examples/request_guard/src/main.rs | 3 +- examples/request_local_state/src/main.rs | 44 ++-- examples/session/src/main.rs | 3 +- site/guide/3-overview.md | 2 +- site/guide/4-requests.md | 3 +- 22 files changed, 159 insertions(+), 200 deletions(-) diff --git a/contrib/codegen/src/database.rs b/contrib/codegen/src/database.rs index 49a8796d..8261721e 100644 --- a/contrib/codegen/src/database.rs +++ b/contrib/codegen/src/database.rs @@ -141,21 +141,22 @@ pub fn database_attr(attr: TokenStream, input: TokenStream) -> Result #request::FromRequestAsync<'a, 'r> for #guard_type { + #[::rocket::async_trait] + impl<'a, 'r> #request::FromRequest<'a, 'r> for #guard_type { type Error = (); - fn from_request(request: &'a #request::Request<'r>) -> #request::FromRequestFuture<'a, Self, Self::Error> { + async fn from_request(request: &'a #request::Request<'r>) -> #request::Outcome { use ::rocket::{Outcome, http::Status}; - Box::pin(async move { - let pool = ::rocket::try_outcome!(request.guard::<::rocket::State<'_, #pool_type>>()).0.clone(); - #spawn_blocking(move || { - match pool.get() { - Ok(conn) => Outcome::Success(#guard_type(conn)), - Err(_) => Outcome::Failure((Status::ServiceUnavailable, ())), - } - }).await.expect("failed to spawn a blocking task to get a pooled connection") - }) + let guard = request.guard::<::rocket::State<'_, #pool_type>>(); + let pool = ::rocket::try_outcome!(guard.await).0.clone(); + + #spawn_blocking(move || { + match pool.get() { + Ok(conn) => Outcome::Success(#guard_type(conn)), + Err(_) => Outcome::Failure((Status::ServiceUnavailable, ())), + } + }).await.expect("failed to spawn a blocking task to get a pooled connection") } } diff --git a/contrib/lib/src/templates/fairing.rs b/contrib/lib/src/templates/fairing.rs index 4bc4bb36..8e209b7b 100644 --- a/contrib/lib/src/templates/fairing.rs +++ b/contrib/lib/src/templates/fairing.rs @@ -173,7 +173,7 @@ impl Fairing for TemplateFairing { #[cfg(debug_assertions)] async fn on_request<'a>(&'a self, req: &'a mut rocket::Request<'_>, _data: &'a rocket::Data) { - let cm = req.guard::>() + let cm = req.guard::>().await .expect("Template ContextManager registered in on_attach"); cm.reload_if_needed(&*self.custom_callback); diff --git a/contrib/lib/src/templates/metadata.rs b/contrib/lib/src/templates/metadata.rs index b60bb673..256eb663 100644 --- a/contrib/lib/src/templates/metadata.rs +++ b/contrib/lib/src/templates/metadata.rs @@ -87,11 +87,12 @@ impl Metadata<'_> { /// Retrieves the template metadata. If a template fairing hasn't been attached, /// an error is printed and an empty `Err` with status `InternalServerError` /// (`500`) is returned. -impl<'a> FromRequest<'a, '_> for Metadata<'a> { +#[rocket::async_trait] +impl<'a, 'r> FromRequest<'a, 'r> for Metadata<'a> { type Error = (); - fn from_request(request: &'a Request<'_>) -> request::Outcome { - request.guard::>() + async fn from_request(request: &'a Request<'r>) -> request::Outcome { + request.guard::>().await .succeeded() .and_then(|cm| Some(Outcome::Success(Metadata(cm.inner())))) .unwrap_or_else(|| { diff --git a/contrib/lib/src/templates/mod.rs b/contrib/lib/src/templates/mod.rs index 1876c98e..0ad40e69 100644 --- a/contrib/lib/src/templates/mod.rs +++ b/contrib/lib/src/templates/mod.rs @@ -387,7 +387,7 @@ impl<'r> Responder<'r> for Template { fn respond_to(self, req: &'r Request<'_>) -> response::ResultFuture<'r> { Box::pin(async move { let (render, content_type) = { - let ctxt = req.guard::>().succeeded().ok_or_else(|| { + let ctxt = req.guard::>().await.succeeded().ok_or_else(|| { error_!("Uninitialized template context: missing fairing."); info_!("To use templates, you must attach `Template::fairing()`."); info_!("See the `Template` documentation for more information."); diff --git a/core/codegen/src/attribute/route.rs b/core/codegen/src/attribute/route.rs index 2370ad62..c98d401e 100644 --- a/core/codegen/src/attribute/route.rs +++ b/core/codegen/src/attribute/route.rs @@ -313,7 +313,7 @@ fn request_guard_expr(ident: &syn::Ident, ty: &syn::Type) -> TokenStream2 { let span = ident.span().unstable().join(ty.span()).unwrap().into(); quote_spanned! { span => #[allow(non_snake_case, unreachable_patterns, unreachable_code)] - let #ident: #ty = match <#ty as #request::FromRequestAsync>::from_request(#req).await { + let #ident: #ty = match <#ty as #request::FromRequest>::from_request(#req).await { #Outcome::Success(__v) => __v, #Outcome::Forward(_) => return #Outcome::Forward(#data), #Outcome::Failure((__c, _)) => return #Outcome::Failure(__c), diff --git a/core/http/src/cookies.rs b/core/http/src/cookies.rs index 73eb65a1..becd9ec3 100644 --- a/core/http/src/cookies.rs +++ b/core/http/src/cookies.rs @@ -84,10 +84,11 @@ mod key { /// // In practice, we'd probably fetch the user from the database. /// struct User(usize); /// -/// impl FromRequest<'_, '_> for User { +/// #[rocket::async_trait] +/// impl<'a, 'r> FromRequest<'a, 'r> for User { /// type Error = std::convert::Infallible; /// -/// fn from_request(request: &Request<'_>) -> request::Outcome { +/// async fn from_request(request: &'a Request<'r>) -> request::Outcome { /// request.cookies() /// .get_private("user_id") /// .and_then(|cookie| cookie.value().parse().ok()) diff --git a/core/lib/src/fairing/mod.rs b/core/lib/src/fairing/mod.rs index 0e1e57ef..4ed2d3aa 100644 --- a/core/lib/src/fairing/mod.rs +++ b/core/lib/src/fairing/mod.rs @@ -353,10 +353,11 @@ pub use self::info_kind::{Info, Kind}; /// pub struct StartTime(pub SystemTime); /// /// // Allows a route to access the time a request was initiated. -/// impl FromRequest<'_, '_> for StartTime { +/// #[rocket::async_trait] +/// impl<'a, 'r> FromRequest<'a, 'r> for StartTime { /// type Error = (); /// -/// fn from_request(request: &Request<'_>) -> request::Outcome { +/// async fn from_request(request: &'a Request<'r>) -> request::Outcome { /// match *request.local_cache(|| TimerStart(None)) { /// TimerStart(Some(time)) => Outcome::Success(StartTime(time)), /// TimerStart(None) => Outcome::Failure((Status::InternalServerError, ())), @@ -366,7 +367,6 @@ pub use self::info_kind::{Info, Kind}; /// ``` /// /// [request-local state]: https://rocket.rs/v0.5/guide/state/#request-local-state - #[crate::async_trait] pub trait Fairing: Send + Sync + 'static { /// Returns an [`Info`] structure containing the `name` and [`Kind`] of this diff --git a/core/lib/src/outcome.rs b/core/lib/src/outcome.rs index 72a6392f..c4de4599 100644 --- a/core/lib/src/outcome.rs +++ b/core/lib/src/outcome.rs @@ -628,12 +628,13 @@ impl Outcome { /// struct Guard1; /// struct Guard2; /// +/// #[rocket::async_trait] /// impl<'a, 'r> FromRequest<'a, 'r> for Guard1 { /// type Error = (); /// -/// fn from_request(req: &'a Request<'r>) -> request::Outcome { +/// async fn from_request(req: &'a Request<'r>) -> request::Outcome { /// // Attempt to fetch the guard, passing through any error or forward. -/// let atomics = try_outcome!(req.guard::>()); +/// let atomics = try_outcome!(req.guard::>().await); /// atomics.uncached.fetch_add(1, Ordering::Relaxed); /// req.local_cache(|| atomics.cached.fetch_add(1, Ordering::Relaxed)); /// @@ -641,12 +642,13 @@ impl Outcome { /// } /// } /// +/// #[rocket::async_trait] /// impl<'a, 'r> FromRequest<'a, 'r> for Guard2 { /// type Error = (); /// -/// fn from_request(req: &'a Request<'r>) -> request::Outcome { +/// async fn from_request(req: &'a Request<'r>) -> request::Outcome { /// // Attempt to fetch the guard, passing through any error or forward. -/// let guard1: Guard1 = try_outcome!(req.guard::()); +/// let guard1: Guard1 = try_outcome!(req.guard::().await); /// Success(Guard2) /// } /// } diff --git a/core/lib/src/request/from_request.rs b/core/lib/src/request/from_request.rs index 78cee3ca..ce01e3d5 100644 --- a/core/lib/src/request/from_request.rs +++ b/core/lib/src/request/from_request.rs @@ -34,84 +34,6 @@ impl IntoOutcome for Result { } } -/// Type alias for the future returned by [`FromRequestAsync::from_request`]. -pub type FromRequestFuture<'fut, T, E> = BoxFuture<'fut, Outcome>; - -/// Trait implemented by asynchronous request guards to derive a value from -/// incoming requests. -/// -/// For more information on request guards in general, see the [`FromRequest`] -/// trait. -/// -/// ## Example -/// -/// Imagine you're running an authenticated service backed by a database. You -/// want to ensure that certain handlers will only run if a valid API key is -/// present in the request, and you need to make a database request in order to -/// determine if the key is valid or not. -/// -/// ```rust -/// # #![feature(proc_macro_hygiene)] -/// # #[macro_use] extern crate rocket; -/// # -/// # struct Database; -/// # impl Database { -/// # async fn check_key(&self, key: &str) -> bool { -/// # true -/// # } -/// # } -/// # -/// use rocket::Outcome; -/// use rocket::http::Status; -/// use rocket::request::{self, Request, State, FromRequestAsync}; -/// -/// struct ApiKey(String); -/// -/// #[derive(Debug)] -/// enum ApiKeyError { -/// BadCount, -/// Missing, -/// Invalid, -/// } -/// -/// impl<'a, 'r> FromRequestAsync<'a, 'r> for ApiKey { -/// type Error = ApiKeyError; -/// -/// fn from_request(request: &'a Request<'r>) -> request::FromRequestFuture<'a, Self, Self::Error> { -/// Box::pin(async move { -/// let keys: Vec<_> = request.headers().get("x-api-key").collect(); -/// let database: State<'_, Database> = request.guard().expect("get database connection"); -/// match keys.len() { -/// 0 => Outcome::Failure((Status::BadRequest, ApiKeyError::Missing)), -/// 1 if database.check_key(keys[0]).await => Outcome::Success(ApiKey(keys[0].to_string())), -/// 1 => Outcome::Failure((Status::BadRequest, ApiKeyError::Invalid)), -/// _ => Outcome::Failure((Status::BadRequest, ApiKeyError::BadCount)), -/// } -/// }) -/// } -/// } -/// -/// #[get("/sensitive")] -/// fn sensitive(key: ApiKey) -> &'static str { -/// # let _key = key; -/// "Sensitive data." -/// } -/// -/// # fn main() { } -/// ``` -pub trait FromRequestAsync<'a, 'r>: Sized { - /// The associated error to be returned if derivation fails. - type Error: Debug; - - /// Derives an instance of `Self` from the incoming request metadata. - /// - /// If the derivation is successful, an outcome of `Success` is returned. If - /// the derivation fails in an unrecoverable fashion, `Failure` is returned. - /// `Forward` is returned to indicate that the request should be forwarded - /// to other matching routes, if any. - fn from_request(request: &'a Request<'r>) -> FromRequestFuture<'a, Self, Self::Error>; -} - /// Trait implemented by request guards to derive a value from incoming /// requests. /// @@ -127,9 +49,26 @@ pub trait FromRequestAsync<'a, 'r>: Sized { /// the handler. Rocket only dispatches requests to a handler when all of its /// guards pass. /// -/// Request guards can be made *asynchronous* by implementing -/// [`FromRequestAsync`] instead of `FromRequest`. This is useful when the -/// validation requires working with a database or performing other I/O. +/// ## Async Trait +/// +/// [`FromRequest`] is an _async_ trait. Implementations of `FromRequest` must +/// be decorated with an attribute of `#[rocket::async_trait]`: +/// +/// ```rust +/// use rocket::request::{self, Request, FromRequest}; +/// # struct MyType; +/// # type MyError = String; +/// +/// #[rocket::async_trait] +/// impl<'a, 'r> FromRequest<'a, 'r> for MyType { +/// type Error = MyError; +/// +/// async fn from_request(req: &'a Request<'r>) -> request::Outcome { +/// /* .. */ +/// # unimplemented!() +/// } +/// } +/// ``` /// /// ## Example /// @@ -270,11 +209,12 @@ pub trait FromRequestAsync<'a, 'r>: Sized { /// Invalid, /// } /// -/// impl FromRequest<'_, '_> for ApiKey { +/// #[rocket::async_trait] +/// impl<'a, 'r> FromRequest<'a, 'r> for ApiKey { /// type Error = ApiKeyError; /// -/// fn from_request(request: &Request<'_>) -> request::Outcome { -/// let keys: Vec<_> = request.headers().get("x-api-key").collect(); +/// async fn from_request(req: &'a Request<'r>) -> request::Outcome { +/// let keys: Vec<_> = req.headers().get("x-api-key").collect(); /// match keys.len() { /// 0 => Outcome::Failure((Status::BadRequest, ApiKeyError::Missing)), /// 1 if is_valid(keys[0]) => Outcome::Success(ApiKey(keys[0].to_string())), @@ -316,20 +256,22 @@ pub trait FromRequestAsync<'a, 'r>: Sized { /// # Ok(User { id, is_admin: false }) /// # } /// # } -/// # impl FromRequest<'_, '_> for Database { +/// # #[rocket::async_trait] +/// # impl<'a, 'r> FromRequest<'a, 'r> for Database { /// # type Error = (); -/// # fn from_request(request: &Request<'_>) -> request::Outcome { +/// # async fn from_request(request: &'a Request<'r>) -> request::Outcome { /// # Outcome::Success(Database) /// # } /// # } /// # /// # struct Admin { user: User } /// # -/// impl FromRequest<'_, '_> for User { +/// #[rocket::async_trait] +/// impl<'a, 'r> FromRequest<'a, 'r> for User { /// type Error = (); /// -/// fn from_request(request: &Request<'_>) -> request::Outcome { -/// let db = try_outcome!(request.guard::()); +/// async fn from_request(request: &'a Request<'r>) -> request::Outcome { +/// let db = try_outcome!(request.guard::().await); /// request.cookies() /// .get_private("user_id") /// .and_then(|cookie| cookie.value().parse().ok()) @@ -338,13 +280,13 @@ pub trait FromRequestAsync<'a, 'r>: Sized { /// } /// } /// -/// impl FromRequest<'_, '_> for Admin { +/// #[rocket::async_trait] +/// impl<'a, 'r> FromRequest<'a, 'r> for Admin { /// type Error = (); /// -/// fn from_request(request: &Request<'_>) -> request::Outcome { +/// async fn from_request(request: &'a Request<'r>) -> request::Outcome { /// // This will unconditionally query the database! -/// let user = try_outcome!(request.guard::()); -/// +/// let user = try_outcome!(request.guard::().await); /// if user.is_admin { /// Outcome::Success(Admin { user }) /// } else { @@ -379,39 +321,41 @@ pub trait FromRequestAsync<'a, 'r>: Sized { /// # Ok(User { id, is_admin: false }) /// # } /// # } -/// # impl FromRequest<'_, '_> for Database { +/// # #[rocket::async_trait] +/// # impl<'a, 'r> FromRequest<'a, 'r> for Database { /// # type Error = (); -/// # fn from_request(request: &Request<'_>) -> request::Outcome { +/// # async fn from_request(request: &'a Request<'r>) -> request::Outcome { /// # Outcome::Success(Database) /// # } /// # } /// # /// # struct Admin<'a> { user: &'a User } /// # -/// impl<'a> FromRequest<'a, '_> for &'a User { +/// #[rocket::async_trait] +/// impl<'a, 'r> FromRequest<'a, 'r> for &'a User { /// type Error = std::convert::Infallible; /// -/// fn from_request(request: &'a Request<'_>) -> request::Outcome { +/// async fn from_request(request: &'a Request<'r>) -> request::Outcome { /// // This closure will execute at most once per request, regardless of /// // the number of times the `User` guard is executed. -/// let user_result = request.local_cache(|| { -/// let db = request.guard::().succeeded()?; +/// let user_result = request.local_cache_async(async { +/// let db = request.guard::().await.succeeded()?; /// request.cookies() /// .get_private("user_id") /// .and_then(|cookie| cookie.value().parse().ok()) /// .and_then(|id| db.get_user(id).ok()) -/// }); +/// }).await; /// /// user_result.as_ref().or_forward(()) /// } /// } /// -/// impl<'a> FromRequest<'a, '_> for Admin<'a> { +/// #[rocket::async_trait] +/// impl<'a, 'r> FromRequest<'a, 'r> for Admin<'a> { /// type Error = std::convert::Infallible; /// -/// fn from_request(request: &'a Request<'_>) -> request::Outcome { -/// let user = try_outcome!(request.guard::<&User>()); -/// +/// async fn from_request(request: &'a Request<'r>) -> request::Outcome { +/// let user = try_outcome!(request.guard::<&User>().await); /// if user.is_admin { /// Outcome::Success(Admin { user }) /// } else { @@ -427,6 +371,7 @@ pub trait FromRequestAsync<'a, 'r>: Sized { /// /// [request-local state]: https://rocket.rs/v0.5/guide/state/#request-local-state +#[crate::async_trait] pub trait FromRequest<'a, 'r>: Sized { /// The associated error to be returned if derivation fails. type Error: Debug; @@ -437,37 +382,32 @@ pub trait FromRequest<'a, 'r>: Sized { /// the derivation fails in an unrecoverable fashion, `Failure` is returned. /// `Forward` is returned to indicate that the request should be forwarded /// to other matching routes, if any. - fn from_request(request: &'a Request<'r>) -> Outcome; + async fn from_request(request: &'a Request<'r>) -> Outcome; } -impl<'a, 'r, T: FromRequest<'a, 'r>> FromRequestAsync<'a, 'r> for T { - type Error = T::Error; - - fn from_request(request: &'a Request<'r>) -> BoxFuture<'a, Outcome> { - Box::pin(async move { T::from_request(request) }) - } -} - -impl FromRequest<'_, '_> for Method { +#[crate::async_trait] +impl<'a, 'r> FromRequest<'a, 'r> for Method { type Error = std::convert::Infallible; - fn from_request(request: &Request<'_>) -> Outcome { + async fn from_request(request: &'a Request<'r>) -> Outcome { Success(request.method()) } } -impl<'a> FromRequest<'a, '_> for &'a Origin<'a> { +#[crate::async_trait] +impl<'a, 'r> FromRequest<'a, 'r> for &'a Origin<'a> { type Error = std::convert::Infallible; - fn from_request(request: &'a Request<'_>) -> Outcome { + async fn from_request(request: &'a Request<'r>) -> Outcome { Success(request.uri()) } } -impl<'r> FromRequest<'_, 'r> for &'r Route { +#[crate::async_trait] +impl<'a, 'r> FromRequest<'a, 'r> for &'r Route { type Error = std::convert::Infallible; - fn from_request(request: &Request<'r>) -> Outcome { + async fn from_request(request: &'a Request<'r>) -> Outcome { match request.route() { Some(route) => Success(route), None => Forward(()) @@ -475,18 +415,20 @@ impl<'r> FromRequest<'_, 'r> for &'r Route { } } -impl<'a> FromRequest<'a, '_> for Cookies<'a> { +#[crate::async_trait] +impl<'a, 'r> FromRequest<'a, 'r> for Cookies<'a> { type Error = std::convert::Infallible; - fn from_request(request: &'a Request<'_>) -> Outcome { + async fn from_request(request: &'a Request<'r>) -> Outcome { Success(request.cookies()) } } -impl<'a> FromRequest<'a, '_> for &'a Accept { +#[crate::async_trait] +impl<'a, 'r> FromRequest<'a, 'r> for &'a Accept { type Error = std::convert::Infallible; - fn from_request(request: &'a Request<'_>) -> Outcome { + async fn from_request(request: &'a Request<'r>) -> Outcome { match request.accept() { Some(accept) => Success(accept), None => Forward(()) @@ -494,10 +436,11 @@ impl<'a> FromRequest<'a, '_> for &'a Accept { } } -impl<'a> FromRequest<'a, '_> for &'a ContentType { +#[crate::async_trait] +impl<'a, 'r> FromRequest<'a, 'r> for &'a ContentType { type Error = std::convert::Infallible; - fn from_request(request: &'a Request<'_>) -> Outcome { + async fn from_request(request: &'a Request<'r>) -> Outcome { match request.content_type() { Some(content_type) => Success(content_type), None => Forward(()) @@ -505,10 +448,11 @@ impl<'a> FromRequest<'a, '_> for &'a ContentType { } } -impl FromRequest<'_, '_> for SocketAddr { +#[crate::async_trait] +impl<'a, 'r> FromRequest<'a, 'r> for SocketAddr { type Error = std::convert::Infallible; - fn from_request(request: &Request<'_>) -> Outcome { + async fn from_request(request: &'a Request<'r>) -> Outcome { match request.remote() { Some(addr) => Success(addr), None => Forward(()) @@ -516,10 +460,12 @@ impl FromRequest<'_, '_> for SocketAddr { } } -impl<'a, 'r, T: FromRequestAsync<'a, 'r> + 'a> FromRequestAsync<'a, 'r> for Result { +impl<'a, 'r, T: FromRequest<'a, 'r> + 'a> FromRequest<'a, 'r> for Result { type Error = std::convert::Infallible; - fn from_request(request: &'a Request<'r>) -> BoxFuture<'a, Outcome> { + fn from_request<'y>(request: &'a Request<'r>) -> BoxFuture<'y, Outcome> + where 'a: 'y, 'r: 'y + { // TODO: FutureExt::map is a workaround (see rust-lang/rust#60658) use futures_util::future::FutureExt; T::from_request(request).map(|x| match x { @@ -530,10 +476,12 @@ impl<'a, 'r, T: FromRequestAsync<'a, 'r> + 'a> FromRequestAsync<'a, 'r> for Resu } } -impl<'a, 'r, T: FromRequestAsync<'a, 'r> + 'a> FromRequestAsync<'a, 'r> for Option { +impl<'a, 'r, T: FromRequest<'a, 'r> + 'a> FromRequest<'a, 'r> for Option { type Error = std::convert::Infallible; - fn from_request(request: &'a Request<'r>) -> BoxFuture<'a, Outcome> { + fn from_request<'y>(request: &'a Request<'r>) -> BoxFuture<'y, Outcome> + where 'a: 'y, 'r: 'y + { // TODO: FutureExt::map is a workaround (see rust-lang/rust#60658) use futures_util::future::FutureExt; T::from_request(request).map(|x| match x { @@ -542,4 +490,3 @@ impl<'a, 'r, T: FromRequestAsync<'a, 'r> + 'a> FromRequestAsync<'a, 'r> for Opti }).boxed() } } - diff --git a/core/lib/src/request/mod.rs b/core/lib/src/request/mod.rs index fbd45176..94a6d43a 100644 --- a/core/lib/src/request/mod.rs +++ b/core/lib/src/request/mod.rs @@ -13,7 +13,7 @@ mod tests; #[doc(hidden)] pub use rocket_codegen::{FromForm, FromFormValue}; pub use self::request::Request; -pub use self::from_request::{FromRequest, FromRequestAsync, FromRequestFuture, Outcome}; +pub use self::from_request::{FromRequest, Outcome}; pub use self::param::{FromParam, FromSegments}; pub use self::form::{FromForm, FromFormValue}; pub use self::form::{Form, LenientForm, FormItems, FormItem}; diff --git a/core/lib/src/request/request.rs b/core/lib/src/request/request.rs index 3bc260d2..0e436d02 100644 --- a/core/lib/src/request/request.rs +++ b/core/lib/src/request/request.rs @@ -6,6 +6,7 @@ use std::str; use yansi::Paint; use state::{Container, Storage}; +use futures_util::future::BoxFuture; use crate::request::{FromParam, FromSegments, FromRequest, Outcome}; use crate::request::{FromFormValue, FormItems, FormItem}; @@ -530,8 +531,9 @@ impl<'r> Request<'r> { /// let pool = request.guard::>(); /// # }); /// ``` - #[inline(always)] - pub fn guard<'a, T: FromRequest<'a, 'r>>(&'a self) -> Outcome { + pub fn guard<'z, 'a, T>(&'a self) -> BoxFuture<'z, Outcome> + where T: FromRequest<'a, 'r> + 'z, 'a: 'z, 'r: 'z + { T::from_request(self) } diff --git a/core/lib/src/request/state.rs b/core/lib/src/request/state.rs index a2f84c05..05759dd0 100644 --- a/core/lib/src/request/state.rs +++ b/core/lib/src/request/state.rs @@ -70,11 +70,12 @@ use crate::http::Status; /// # struct MyConfig{ user_val: String }; /// struct Item(String); /// -/// impl FromRequest<'_, '_> for Item { +/// #[rocket::async_trait] +/// impl<'a, 'r> FromRequest<'a, 'r> for Item { /// type Error = (); /// -/// fn from_request(request: &Request<'_>) -> request::Outcome { -/// request.guard::>() +/// async fn from_request(request: &'a Request<'r>) -> request::Outcome { +/// request.guard::>().await /// .map(|my_config| Item(my_config.user_val.clone())) /// } /// } @@ -165,11 +166,12 @@ impl<'r, T: Send + Sync + 'static> State<'r, T> { } } -impl<'r, T: Send + Sync + 'static> FromRequest<'_, 'r> for State<'r, T> { +#[crate::async_trait] +impl<'a, 'r, T: Send + Sync + 'static> FromRequest<'a, 'r> for State<'r, T> { type Error = (); #[inline(always)] - fn from_request(req: &Request<'r>) -> request::Outcome, ()> { + async fn from_request(req: &'a Request<'r>) -> request::Outcome, ()> { match req.state.managed.try_get::() { Some(state) => Outcome::Success(State(state)), None => { diff --git a/core/lib/src/response/flash.rs b/core/lib/src/response/flash.rs index 4a0e2bdc..b9c7f308 100644 --- a/core/lib/src/response/flash.rs +++ b/core/lib/src/response/flash.rs @@ -245,10 +245,11 @@ impl<'a, 'r> Flash<&'a Request<'r>> { /// /// The suggested use is through an `Option` and the `FlashMessage` type alias /// in `request`: `Option`. +#[crate::async_trait] impl<'a, 'r> FromRequest<'a, 'r> for Flash<&'a Request<'r>> { type Error = (); - fn from_request(req: &'a Request<'r>) -> request::Outcome { + async fn from_request(req: &'a Request<'r>) -> request::Outcome { trace_!("Flash: attempting to retrieve message."); req.cookies().get(FLASH_COOKIE_NAME).ok_or(()).and_then(|cookie| { trace_!("Flash: retrieving message: {:?}", cookie); diff --git a/core/lib/src/shutdown.rs b/core/lib/src/shutdown.rs index b30853ce..463e3732 100644 --- a/core/lib/src/shutdown.rs +++ b/core/lib/src/shutdown.rs @@ -46,11 +46,12 @@ impl ShutdownHandle { } } -impl FromRequest<'_, '_> for ShutdownHandle { +#[crate::async_trait] +impl<'a, 'r> FromRequest<'a, 'r> for ShutdownHandle { type Error = std::convert::Infallible; #[inline] - fn from_request(request: &Request<'_>) -> Outcome { + async fn from_request(request: &'a Request<'r>) -> Outcome { Outcome::Success(request.state.managed.get::().0.clone()) } } diff --git a/core/lib/tests/fairing_before_head_strip-issue-546.rs b/core/lib/tests/fairing_before_head_strip-issue-546.rs index 8f297f53..cb96caac 100644 --- a/core/lib/tests/fairing_before_head_strip-issue-546.rs +++ b/core/lib/tests/fairing_before_head_strip-issue-546.rs @@ -62,7 +62,7 @@ mod fairing_before_head_strip { assert_eq!(req.method(), Method::Head); // This should be called exactly once. - let c = req.guard::>().unwrap(); + let c = req.guard::>().await.unwrap(); assert_eq!(c.0.fetch_add(1, Ordering::SeqCst), 0); }) })) diff --git a/core/lib/tests/local-request-content-type-issue-505.rs b/core/lib/tests/local-request-content-type-issue-505.rs index 9a0a172d..cec69aa3 100644 --- a/core/lib/tests/local-request-content-type-issue-505.rs +++ b/core/lib/tests/local-request-content-type-issue-505.rs @@ -8,10 +8,11 @@ use rocket::request::{self, FromRequest}; struct HasContentType; +#[rocket::async_trait] impl<'a, 'r> FromRequest<'a, 'r> for HasContentType { type Error = (); - fn from_request(request: &'a Request<'r>) -> request::Outcome { + async fn from_request(request: &'a Request<'r>) -> request::Outcome { if request.content_type().is_some() { Success(HasContentType) } else { diff --git a/core/lib/tests/nested-fairing-attaches.rs b/core/lib/tests/nested-fairing-attaches.rs index d9d5201f..7cda19bb 100644 --- a/core/lib/tests/nested-fairing-attaches.rs +++ b/core/lib/tests/nested-fairing-attaches.rs @@ -31,7 +31,8 @@ fn rocket() -> rocket::Rocket { .attach(AdHoc::on_request("Inner", |req, _| { Box::pin(async move { if req.method() == Method::Get { - let counter = req.guard::>().unwrap(); + let counter = req.guard::>() + .await.unwrap(); counter.get.fetch_add(1, Ordering::Release); } }) diff --git a/examples/request_guard/src/main.rs b/examples/request_guard/src/main.rs index c41c54d4..520e5810 100644 --- a/examples/request_guard/src/main.rs +++ b/examples/request_guard/src/main.rs @@ -8,10 +8,11 @@ use rocket::outcome::Outcome::*; #[derive(Debug)] struct HeaderCount(usize); +#[rocket::async_trait] impl<'a, 'r> FromRequest<'a, 'r> for HeaderCount { type Error = std::convert::Infallible; - fn from_request(request: &'a Request<'r>) -> request::Outcome { + async fn from_request(request: &'a Request<'r>) -> request::Outcome { Success(HeaderCount(request.headers().len())) } } diff --git a/examples/request_local_state/src/main.rs b/examples/request_local_state/src/main.rs index 6f261981..55678d27 100644 --- a/examples/request_local_state/src/main.rs +++ b/examples/request_local_state/src/main.rs @@ -5,7 +5,7 @@ use std::sync::atomic::{AtomicUsize, Ordering}; use rocket::outcome::Outcome::*; -use rocket::request::{self, FromRequest, FromRequestAsync, FromRequestFuture, Request, State}; +use rocket::request::{self, FromRequest, Request, State}; #[cfg(test)] mod tests; @@ -20,11 +20,12 @@ struct Guard2; struct Guard3; struct Guard4; +#[rocket::async_trait] impl<'a, 'r> FromRequest<'a, 'r> for Guard1 { type Error = (); - fn from_request(req: &'a Request<'r>) -> request::Outcome { - let atomics = try_outcome!(req.guard::>()); + async fn from_request(req: &'a Request<'r>) -> request::Outcome { + let atomics = try_outcome!(req.guard::>().await); atomics.uncached.fetch_add(1, Ordering::Relaxed); req.local_cache(|| atomics.cached.fetch_add(1, Ordering::Relaxed)); @@ -32,41 +33,38 @@ impl<'a, 'r> FromRequest<'a, 'r> for Guard1 { } } +#[rocket::async_trait] impl<'a, 'r> FromRequest<'a, 'r> for Guard2 { type Error = (); - fn from_request(req: &'a Request<'r>) -> request::Outcome { - try_outcome!(req.guard::()); + async fn from_request(req: &'a Request<'r>) -> request::Outcome { + try_outcome!(req.guard::().await); Success(Guard2) } } -impl<'a, 'r> FromRequestAsync<'a, 'r> for Guard3 { +#[rocket::async_trait] +impl<'a, 'r> FromRequest<'a, 'r> for Guard3 { type Error = (); - fn from_request(req: &'a Request<'r>) -> FromRequestFuture<'a, Self, ()> - { - Box::pin(async move { - let atomics = try_outcome!(req.guard::>()); - atomics.uncached.fetch_add(1, Ordering::Relaxed); - req.local_cache_async(async { - atomics.cached.fetch_add(1, Ordering::Relaxed) - }).await; + async fn from_request(req: &'a Request<'r>) -> request::Outcome { + let atomics = try_outcome!(req.guard::>().await); + atomics.uncached.fetch_add(1, Ordering::Relaxed); + req.local_cache_async(async { + atomics.cached.fetch_add(1, Ordering::Relaxed) + }).await; - Success(Guard3) - }) + Success(Guard3) } } -impl<'a, 'r> FromRequestAsync<'a, 'r> for Guard4 { +#[rocket::async_trait] +impl<'a, 'r> FromRequest<'a, 'r> for Guard4 { type Error = (); - fn from_request(req: &'a Request<'r>) -> FromRequestFuture<'a, Self, ()> - { - Box::pin(async move { - try_outcome!(Guard3::from_request(req).await); - Success(Guard4) - }) + async fn from_request(req: &'a Request<'r>) -> request::Outcome { + try_outcome!(Guard3::from_request(req).await); + Success(Guard4) } } diff --git a/examples/session/src/main.rs b/examples/session/src/main.rs index eb1c89e5..bc1007bd 100644 --- a/examples/session/src/main.rs +++ b/examples/session/src/main.rs @@ -21,10 +21,11 @@ struct Login { #[derive(Debug)] struct User(usize); +#[rocket::async_trait] impl<'a, 'r> FromRequest<'a, 'r> for User { type Error = std::convert::Infallible; - fn from_request(request: &'a Request<'r>) -> request::Outcome { + async fn from_request(request: &'a Request<'r>) -> request::Outcome { request.cookies() .get_private("user_id") .and_then(|cookie| cookie.value().parse().ok()) diff --git a/site/guide/3-overview.md b/site/guide/3-overview.md index e904d456..e915f2f4 100644 --- a/site/guide/3-overview.md +++ b/site/guide/3-overview.md @@ -227,7 +227,7 @@ synchronous equivalents inside Rocket applications. `async fn`s. Inside an `async fn`, you can `.await` `Future`s from Rocket or other libraries * Several of Rocket's traits, such as [`FromData`](../requests#body-data) and - [`FromRequestAsync`](../requests#request-guards), have methods that return + [`FromRequest`](../requests#request-guards), have methods that return `Future`s. * `Data` and `DataStream` (incoming request data) and `Response` and `Body` (outgoing response data) are based on `tokio::io::AsyncRead` instead of diff --git a/site/guide/4-requests.md b/site/guide/4-requests.md index adf15378..1c6ae186 100644 --- a/site/guide/4-requests.md +++ b/site/guide/4-requests.md @@ -412,7 +412,7 @@ imply, a request guard protects a handler from being called erroneously based on information contained in an incoming request. More specifically, a request guard is a type that represents an arbitrary validation policy. The validation policy is implemented through the [`FromRequest`] trait. Every type that implements -`FromRequest` (or the related [`FromRequestAsync`]) is a request guard. +`FromRequest` is a request guard. Request guards appear as inputs to handlers. An arbitrary number of request guards can appear as arguments in a route handler. Rocket will automatically @@ -444,7 +444,6 @@ more about request guards and implementing them, see the [`FromRequest`] documentation. [`FromRequest`]: @api/rocket/request/trait.FromRequest.html -[`FromRequestAsync`]: @api/rocket/request/trait.FromRequestAsync.html [`Cookies`]: @api/rocket/http/enum.Cookies.html ### Custom Guards