From 48c333721cd5d6a042adc4ab8fb3dd8ad1d6d9d7 Mon Sep 17 00:00:00 2001 From: Sergio Benitez Date: Thu, 30 Jan 2020 20:47:57 -0800 Subject: [PATCH] Use 'async_trait' for 'Fairing' trait. Also re-exports the 'async_trait' attribute from 'rocket'. --- contrib/lib/src/helmet/helmet.rs | 7 +- contrib/lib/src/templates/fairing.rs | 29 +++--- core/codegen/tests/async-routes.rs | 1 + core/lib/Cargo.toml | 1 + core/lib/src/fairing/ad_hoc.rs | 29 +++--- core/lib/src/fairing/mod.rs | 131 ++++++++++++++++----------- core/lib/src/lib.rs | 1 + examples/fairings/src/main.rs | 45 ++++----- 8 files changed, 135 insertions(+), 109 deletions(-) diff --git a/contrib/lib/src/helmet/helmet.rs b/contrib/lib/src/helmet/helmet.rs index dd54bf0d..8c5da983 100644 --- a/contrib/lib/src/helmet/helmet.rs +++ b/contrib/lib/src/helmet/helmet.rs @@ -188,6 +188,7 @@ impl SpaceHelmet { } } +#[rocket::async_trait] impl Fairing for SpaceHelmet { fn info(&self) -> Info { Info { @@ -196,10 +197,8 @@ impl Fairing for SpaceHelmet { } } - fn on_response<'a>(&'a self, _request: &'a Request<'_>, response: &'a mut Response<'_>) -> std::pin::Pin + Send + 'a>> { - Box::pin(async move { - self.apply(response); - }) + async fn on_response<'a>(&'a self, _: &'a Request<'_>, res: &'a mut Response<'_>) { + self.apply(res); } fn on_launch(&self, rocket: &Rocket) { diff --git a/contrib/lib/src/templates/fairing.rs b/contrib/lib/src/templates/fairing.rs index e47f874d..4bc4bb36 100644 --- a/contrib/lib/src/templates/fairing.rs +++ b/contrib/lib/src/templates/fairing.rs @@ -124,17 +124,26 @@ pub struct TemplateFairing { pub custom_callback: Box, } +#[rocket::async_trait] impl Fairing for TemplateFairing { fn info(&self) -> Info { // The on_request part of this fairing only applies in debug // mode, so only register it in debug mode. - Info { + #[cfg(debug_assertions)] + let info = Info { name: "Templates", - #[cfg(debug_assertions)] kind: Kind::Attach | Kind::Request, - #[cfg(not(debug_assertions))] + }; + + // FIXME: We declare two `info` variables here, instead of just one with + // `cfg`s on `kind`, due to issue #63 in `async_trait`. + #[cfg(not(debug_assertions))] + let info = Info { + name: "Templates", kind: Kind::Attach, - } + }; + + info } /// Initializes the template context. Templates will be searched for in the @@ -163,14 +172,10 @@ impl Fairing for TemplateFairing { } #[cfg(debug_assertions)] - fn on_request<'a>(&'a self, req: &'a mut rocket::Request<'_>, _data: &'a rocket::Data) - -> std::pin::Pin + Send + 'a>> - { - Box::pin(async move { - let cm = req.guard::>() - .expect("Template ContextManager registered in on_attach"); + async fn on_request<'a>(&'a self, req: &'a mut rocket::Request<'_>, _data: &'a rocket::Data) { + let cm = req.guard::>() + .expect("Template ContextManager registered in on_attach"); - cm.reload_if_needed(&*self.custom_callback); - }) + cm.reload_if_needed(&*self.custom_callback); } } diff --git a/core/codegen/tests/async-routes.rs b/core/codegen/tests/async-routes.rs index 37186950..0385a274 100644 --- a/core/codegen/tests/async-routes.rs +++ b/core/codegen/tests/async-routes.rs @@ -1,4 +1,5 @@ #![feature(proc_macro_hygiene)] +#![allow(dead_code)] #[macro_use] extern crate rocket; use rocket::http::uri::Origin; diff --git a/core/lib/Cargo.toml b/core/lib/Cargo.toml index 85f19ae9..639d6bca 100644 --- a/core/lib/Cargo.toml +++ b/core/lib/Cargo.toml @@ -39,6 +39,7 @@ memchr = "2" # TODO: Use pear instead. binascii = "0.1" pear = "0.1" atty = "0.2" +async-trait = "0.1" [build-dependencies] yansi = "0.5" diff --git a/core/lib/src/fairing/ad_hoc.rs b/core/lib/src/fairing/ad_hoc.rs index fd63c1bd..6139f643 100644 --- a/core/lib/src/fairing/ad_hoc.rs +++ b/core/lib/src/fairing/ad_hoc.rs @@ -68,8 +68,8 @@ impl AdHoc { /// // The no-op attach fairing. /// let fairing = AdHoc::on_attach("No-Op", |rocket| Ok(rocket)); /// ``` - pub fn on_attach(name: &'static str, f: F) -> AdHoc - where F: FnOnce(Rocket) -> Result + Send + 'static + pub fn on_attach(name: &'static str, f: F) -> AdHoc + where F: FnOnce(Rocket) -> Result { AdHoc { name, kind: AdHocKind::Attach(Mutex::new(Some(Box::new(f)))) } } @@ -87,8 +87,8 @@ impl AdHoc { /// println!("Launching in T-3..2..1.."); /// }); /// ``` - pub fn on_launch(name: &'static str, f: F) -> AdHoc - where F: FnOnce(&Rocket) + Send + 'static + pub fn on_launch(name: &'static str, f: F) -> AdHoc + where F: FnOnce(&Rocket) { AdHoc { name, kind: AdHocKind::Launch(Mutex::new(Some(Box::new(f)))) } } @@ -110,8 +110,8 @@ impl AdHoc { /// }) /// }); /// ``` - pub fn on_request(name: &'static str, f: F) -> AdHoc - where F: for<'a> Fn(&'a mut Request<'_>, &'a Data) -> BoxFuture<'a, ()> + Send + Sync + 'static + pub fn on_request(name: &'static str, f: F) -> AdHoc + where F: for<'a> Fn(&'a mut Request<'_>, &'a Data) -> BoxFuture<'a, ()> { AdHoc { name, kind: AdHocKind::Request(Box::new(f)) } } @@ -133,13 +133,14 @@ impl AdHoc { /// }) /// }); /// ``` - pub fn on_response(name: &'static str, f: F) -> AdHoc - where F: for<'a> Fn(&'a Request<'_>, &'a mut Response<'_>) -> BoxFuture<'a, ()> + Send + Sync + 'static + pub fn on_response(name: &'static str, f: F) -> AdHoc + where F: for<'a> Fn(&'a Request<'_>, &'a mut Response<'_>) -> BoxFuture<'a, ()> { AdHoc { name, kind: AdHocKind::Response(Box::new(f)) } } } +#[crate::async_trait] impl Fairing for AdHoc { fn info(&self) -> Info { let kind = match self.kind { @@ -170,19 +171,15 @@ impl Fairing for AdHoc { } } - fn on_request<'a>(&'a self, request: &'a mut Request<'_>, data: &'a Data) -> BoxFuture<'a, ()> { + async fn on_request<'a>(&'a self, req: &'a mut Request<'_>, data: &'a Data) { if let AdHocKind::Request(ref callback) = self.kind { - callback(request, data) - } else { - Box::pin(async { }) + callback(req, data).await; } } - fn on_response<'a>(&'a self, request: &'a Request<'_>, response: &'a mut Response<'_>) -> BoxFuture<'a, ()> { + async fn on_response<'a>(&'a self, req: &'a Request<'_>, res: &'a mut Response<'_>) { if let AdHocKind::Response(ref callback) = self.kind { - callback(request, response) - } else { - Box::pin(async { }) + callback(req, res).await; } } } diff --git a/core/lib/src/fairing/mod.rs b/core/lib/src/fairing/mod.rs index 84e961bf..0e1e57ef 100644 --- a/core/lib/src/fairing/mod.rs +++ b/core/lib/src/fairing/mod.rs @@ -47,8 +47,6 @@ //! of other `Fairings` are not jeopardized. For instance, unless it is made //! abundantly clear, a fairing should not rewrite every request. -use futures_util::future::BoxFuture; - use crate::{Rocket, Request, Response, Data}; mod fairings; @@ -138,7 +136,7 @@ pub use self::info_kind::{Info, Kind}; /// to the request; these issues are better handled via [request guards] or /// via response callbacks. Any modifications to a request are persisted and /// can potentially alter how a request is routed. -///= +/// /// * **Response (`on_response`)** /// /// A response callback, represented by the [`Fairing::on_response()`] @@ -192,6 +190,45 @@ pub use self::info_kind::{Info, Kind}; /// these bounds _do not_ prohibit a `Fairing` from holding state: the state /// need simply be thread-safe and statically available or heap allocated. /// +/// ## Async Trait +/// +/// [`Fairing`] is an _async_ trait. Implementations of `Fairing` must be +/// decorated with an attribute of `#[rocket::async_trait]`: +/// +/// ```rust +/// use rocket::{Rocket, Request, Data, Response}; +/// use rocket::fairing::{Fairing, Info, Kind}; +/// +/// # struct MyType; +/// #[rocket::async_trait] +/// impl Fairing for MyType { +/// fn info(&self) -> Info { +/// /* ... */ +/// # unimplemented!() +/// } +/// +/// fn on_attach(&self, rocket: Rocket) -> Result { +/// /* ... */ +/// # unimplemented!() +/// } +/// +/// fn on_launch(&self, rocket: &Rocket) { +/// /* ... */ +/// # unimplemented!() +/// } +/// +/// async fn on_request<'a>(&'a self, req: &'a mut Request<'_>, data: &'a Data) { +/// /* ... */ +/// # unimplemented!() +/// } +/// +/// async fn on_response<'a>(&'a self, req: &'a Request<'_>, res: &'a mut Response<'_>) { +/// /* ... */ +/// # unimplemented!() +/// } +/// } +/// ``` +/// /// ## Example /// /// Imagine that we want to record the number of `GET` and `POST` requests that @@ -220,6 +257,7 @@ pub use self::info_kind::{Info, Kind}; /// post: AtomicUsize, /// } /// +/// #[rocket::async_trait] /// impl Fairing for Counter { /// fn info(&self) -> Info { /// Info { @@ -228,33 +266,29 @@ pub use self::info_kind::{Info, Kind}; /// } /// } /// -/// fn on_request<'a>(&'a self, request: &'a mut Request, _: &'a Data) -> Pin + Send + 'a>> { -/// Box::pin(async move { -/// if request.method() == Method::Get { -/// self.get.fetch_add(1, Ordering::Relaxed); -/// } else if request.method() == Method::Post { -/// self.post.fetch_add(1, Ordering::Relaxed); -/// } -/// }) +/// async fn on_request<'a>(&'a self, req: &'a mut Request<'_>, _: &'a Data) { +/// if req.method() == Method::Get { +/// self.get.fetch_add(1, Ordering::Relaxed); +/// } else if req.method() == Method::Post { +/// self.post.fetch_add(1, Ordering::Relaxed); +/// } /// } /// -/// fn on_response<'a>(&'a self, request: &'a Request<'_>, response: &'a mut Response<'_>) -> Pin + Send + 'a>> { -/// Box::pin(async move { -/// // Don't change a successful user's response, ever. -/// if response.status() != Status::NotFound { -/// return -/// } +/// async fn on_response<'a>(&'a self, req: &'a Request<'_>, res: &'a mut Response<'_>) { +/// // Don't change a successful user's response, ever. +/// if res.status() != Status::NotFound { +/// return +/// } /// -/// if request.method() == Method::Get && request.uri().path() == "/counts" { -/// let get_count = self.get.load(Ordering::Relaxed); -/// let post_count = self.post.load(Ordering::Relaxed); +/// if req.method() == Method::Get && req.uri().path() == "/counts" { +/// let get_count = self.get.load(Ordering::Relaxed); +/// let post_count = self.post.load(Ordering::Relaxed); /// -/// let body = format!("Get: {}\nPost: {}", get_count, post_count); -/// response.set_status(Status::Ok); -/// response.set_header(ContentType::Plain); -/// response.set_sized_body(Cursor::new(body)); -/// } -/// }) +/// let body = format!("Get: {}\nPost: {}", get_count, post_count); +/// res.set_status(Status::Ok); +/// res.set_header(ContentType::Plain); +/// res.set_sized_body(Cursor::new(body)); +/// } /// } /// } /// ``` @@ -286,6 +320,7 @@ pub use self::info_kind::{Info, Kind}; /// #[derive(Copy, Clone)] /// struct TimerStart(Option); /// +/// #[rocket::async_trait] /// impl Fairing for RequestTimer { /// fn info(&self) -> Info { /// Info { @@ -295,25 +330,21 @@ pub use self::info_kind::{Info, Kind}; /// } /// /// /// Stores the start time of the request in request-local state. -/// fn on_request<'a>(&'a self, request: &'a mut Request, _: &'a Data) -> Pin + Send + 'a>> { -/// Box::pin(async move { -/// // Store a `TimerStart` instead of directly storing a `SystemTime` -/// // to ensure that this usage doesn't conflict with anything else -/// // that might store a `SystemTime` in request-local cache. -/// request.local_cache(|| TimerStart(Some(SystemTime::now()))); -/// }) +/// async fn on_request<'a>(&'a self, request: &'a mut Request<'_>, _: &'a Data) { +/// // Store a `TimerStart` instead of directly storing a `SystemTime` +/// // to ensure that this usage doesn't conflict with anything else +/// // that might store a `SystemTime` in request-local cache. +/// request.local_cache(|| TimerStart(Some(SystemTime::now()))); /// } /// /// /// Adds a header to the response indicating how long the server took to /// /// process the request. -/// fn on_response<'a>(&'a self, request: &'a Request<'_>, response: &'a mut Response<'_>) -> Pin + Send + 'a>> { -/// Box::pin(async move { -/// let start_time = request.local_cache(|| TimerStart(None)); -/// if let Some(Ok(duration)) = start_time.0.map(|st| st.elapsed()) { -/// let ms = duration.as_secs() * 1000 + duration.subsec_millis() as u64; -/// response.set_raw_header("X-Response-Time", format!("{} ms", ms)); -/// } -/// }) +/// async fn on_response<'a>(&'a self, req: &'a Request<'_>, res: &'a mut Response<'_>) { +/// let start_time = req.local_cache(|| TimerStart(None)); +/// if let Some(Ok(duration)) = start_time.0.map(|st| st.elapsed()) { +/// let ms = duration.as_secs() * 1000 + duration.subsec_millis() as u64; +/// res.set_raw_header("X-Response-Time", format!("{} ms", ms)); +/// } /// } /// } /// @@ -336,6 +367,7 @@ pub use self::info_kind::{Info, Kind}; /// /// [request-local state]: https://rocket.rs/v0.5/guide/state/#request-local-state +#[crate::async_trait] pub trait Fairing: Send + Sync + 'static { /// Returns an [`Info`] structure containing the `name` and [`Kind`] of this /// fairing. The `name` can be any arbitrary string. `Kind` must be an `or`d @@ -409,9 +441,7 @@ pub trait Fairing: Send + Sync + 'static { /// /// The default implementation of this method does nothing. #[allow(unused_variables)] - fn on_request<'a>(&'a self, request: &'a mut Request<'_>, data: &'a Data) -> BoxFuture<'a, ()> { - Box::pin(async { }) - } + async fn on_request<'a>(&'a self, req: &'a mut Request<'_>, data: &'a Data) {} /// The response callback. /// @@ -424,11 +454,10 @@ pub trait Fairing: Send + Sync + 'static { /// /// The default implementation of this method does nothing. #[allow(unused_variables)] - fn on_response<'a>(&'a self, request: &'a Request<'_>, response: &'a mut Response<'_>) -> BoxFuture<'a, ()> { - Box::pin(async { }) - } + async fn on_response<'a>(&'a self, req: &'a Request<'_>, res: &'a mut Response<'_>) {} } +#[crate::async_trait] impl Fairing for std::sync::Arc { #[inline] fn info(&self) -> Info { @@ -446,12 +475,12 @@ impl Fairing for std::sync::Arc { } #[inline] - fn on_request<'a>(&'a self, request: &'a mut Request<'_>, data: &'a Data) -> BoxFuture<'a, ()> { - (self as &T).on_request(request, data) + async fn on_request<'a>(&'a self, req: &'a mut Request<'_>, data: &'a Data) { + (self as &T).on_request(req, data).await; } #[inline] - fn on_response<'a>(&'a self, request: &'a Request<'_>, response: &'a mut Response<'_>) -> BoxFuture<'a, ()> { - (self as &T).on_response(request, response) + async fn on_response<'a>(&'a self, req: &'a Request<'_>, res: &'a mut Response<'_>) { + (self as &T).on_response(req, res).await; } } diff --git a/core/lib/src/lib.rs b/core/lib/src/lib.rs index bdddab5f..61798419 100644 --- a/core/lib/src/lib.rs +++ b/core/lib/src/lib.rs @@ -94,6 +94,7 @@ #[allow(unused_imports)] #[macro_use] extern crate rocket_codegen; pub use rocket_codegen::*; +pub use async_trait::*; #[macro_use] extern crate log; #[macro_use] extern crate pear; diff --git a/examples/fairings/src/main.rs b/examples/fairings/src/main.rs index b6398b8f..a1842863 100644 --- a/examples/fairings/src/main.rs +++ b/examples/fairings/src/main.rs @@ -19,6 +19,7 @@ struct Counter { post: AtomicUsize, } +#[rocket::async_trait] impl Fairing for Counter { fn info(&self) -> Info { Info { @@ -27,36 +28,28 @@ impl Fairing for Counter { } } - fn on_request<'a>(&'a self, request: &'a mut Request<'_>, _: &'a Data) - -> std::pin::Pin + Send + 'a>> - { - Box::pin(async move { - if request.method() == Method::Get { - self.get.fetch_add(1, Ordering::Relaxed); - } else if request.method() == Method::Post { - self.post.fetch_add(1, Ordering::Relaxed); - } - }) + async fn on_request<'a>(&'a self, request: &'a mut Request<'_>, _: &'a Data) { + if request.method() == Method::Get { + self.get.fetch_add(1, Ordering::Relaxed); + } else if request.method() == Method::Post { + self.post.fetch_add(1, Ordering::Relaxed); + } } - fn on_response<'a>(&'a self, request: &'a Request<'_>, response: &'a mut Response<'_>) - -> std::pin::Pin + Send + 'a>> - { - Box::pin(async move { - if response.status() != Status::NotFound { - return - } + async fn on_response<'a>(&'a self, req: &'a Request<'_>, res: &'a mut Response<'_>) { + if res.status() != Status::NotFound { + return + } - if request.method() == Method::Get && request.uri().path() == "/counts" { - let get_count = self.get.load(Ordering::Relaxed); - let post_count = self.post.load(Ordering::Relaxed); + if req.method() == Method::Get && req.uri().path() == "/counts" { + let get_count = self.get.load(Ordering::Relaxed); + let post_count = self.post.load(Ordering::Relaxed); - let body = format!("Get: {}\nPost: {}", get_count, post_count); - response.set_status(Status::Ok); - response.set_header(ContentType::Plain); - response.set_sized_body(Cursor::new(body)).await; - } - }) + let body = format!("Get: {}\nPost: {}", get_count, post_count); + res.set_status(Status::Ok); + res.set_header(ContentType::Plain); + res.set_sized_body(Cursor::new(body)).await; + } } }