Use 'async_trait' for 'Fairing' trait.

Also re-exports the 'async_trait' attribute from 'rocket'.
This commit is contained in:
Sergio Benitez 2020-01-30 20:47:57 -08:00
parent a4e7972b4b
commit 48c333721c
8 changed files with 135 additions and 109 deletions

View File

@ -188,6 +188,7 @@ impl SpaceHelmet {
} }
} }
#[rocket::async_trait]
impl Fairing for SpaceHelmet { impl Fairing for SpaceHelmet {
fn info(&self) -> Info { fn info(&self) -> Info {
Info { Info {
@ -196,10 +197,8 @@ impl Fairing for SpaceHelmet {
} }
} }
fn on_response<'a>(&'a self, _request: &'a Request<'_>, response: &'a mut Response<'_>) -> std::pin::Pin<Box<dyn std::future::Future<Output=()> + Send + 'a>> { async fn on_response<'a>(&'a self, _: &'a Request<'_>, res: &'a mut Response<'_>) {
Box::pin(async move { self.apply(res);
self.apply(response);
})
} }
fn on_launch(&self, rocket: &Rocket) { fn on_launch(&self, rocket: &Rocket) {

View File

@ -124,17 +124,26 @@ pub struct TemplateFairing {
pub custom_callback: Box<dyn Fn(&mut Engines) + Send + Sync + 'static>, pub custom_callback: Box<dyn Fn(&mut Engines) + Send + Sync + 'static>,
} }
#[rocket::async_trait]
impl Fairing for TemplateFairing { impl Fairing for TemplateFairing {
fn info(&self) -> Info { fn info(&self) -> Info {
// The on_request part of this fairing only applies in debug // The on_request part of this fairing only applies in debug
// mode, so only register it in debug mode. // mode, so only register it in debug mode.
Info {
name: "Templates",
#[cfg(debug_assertions)] #[cfg(debug_assertions)]
let info = Info {
name: "Templates",
kind: Kind::Attach | Kind::Request, kind: Kind::Attach | Kind::Request,
};
// FIXME: We declare two `info` variables here, instead of just one with
// `cfg`s on `kind`, due to issue #63 in `async_trait`.
#[cfg(not(debug_assertions))] #[cfg(not(debug_assertions))]
let info = Info {
name: "Templates",
kind: Kind::Attach, kind: Kind::Attach,
} };
info
} }
/// Initializes the template context. Templates will be searched for in the /// Initializes the template context. Templates will be searched for in the
@ -163,14 +172,10 @@ impl Fairing for TemplateFairing {
} }
#[cfg(debug_assertions)] #[cfg(debug_assertions)]
fn on_request<'a>(&'a self, req: &'a mut rocket::Request<'_>, _data: &'a rocket::Data) async fn on_request<'a>(&'a self, req: &'a mut rocket::Request<'_>, _data: &'a rocket::Data) {
-> std::pin::Pin<Box<dyn std::future::Future<Output=()> + Send + 'a>>
{
Box::pin(async move {
let cm = req.guard::<rocket::State<'_, ContextManager>>() let cm = req.guard::<rocket::State<'_, ContextManager>>()
.expect("Template ContextManager registered in on_attach"); .expect("Template ContextManager registered in on_attach");
cm.reload_if_needed(&*self.custom_callback); cm.reload_if_needed(&*self.custom_callback);
})
} }
} }

View File

@ -1,4 +1,5 @@
#![feature(proc_macro_hygiene)] #![feature(proc_macro_hygiene)]
#![allow(dead_code)]
#[macro_use] extern crate rocket; #[macro_use] extern crate rocket;
use rocket::http::uri::Origin; use rocket::http::uri::Origin;

View File

@ -39,6 +39,7 @@ memchr = "2" # TODO: Use pear instead.
binascii = "0.1" binascii = "0.1"
pear = "0.1" pear = "0.1"
atty = "0.2" atty = "0.2"
async-trait = "0.1"
[build-dependencies] [build-dependencies]
yansi = "0.5" yansi = "0.5"

View File

@ -68,8 +68,8 @@ impl AdHoc {
/// // The no-op attach fairing. /// // The no-op attach fairing.
/// let fairing = AdHoc::on_attach("No-Op", |rocket| Ok(rocket)); /// let fairing = AdHoc::on_attach("No-Op", |rocket| Ok(rocket));
/// ``` /// ```
pub fn on_attach<F>(name: &'static str, f: F) -> AdHoc pub fn on_attach<F: Send + 'static>(name: &'static str, f: F) -> AdHoc
where F: FnOnce(Rocket) -> Result<Rocket, Rocket> + Send + 'static where F: FnOnce(Rocket) -> Result<Rocket, Rocket>
{ {
AdHoc { name, kind: AdHocKind::Attach(Mutex::new(Some(Box::new(f)))) } AdHoc { name, kind: AdHocKind::Attach(Mutex::new(Some(Box::new(f)))) }
} }
@ -87,8 +87,8 @@ impl AdHoc {
/// println!("Launching in T-3..2..1.."); /// println!("Launching in T-3..2..1..");
/// }); /// });
/// ``` /// ```
pub fn on_launch<F>(name: &'static str, f: F) -> AdHoc pub fn on_launch<F: Send + 'static>(name: &'static str, f: F) -> AdHoc
where F: FnOnce(&Rocket) + Send + 'static where F: FnOnce(&Rocket)
{ {
AdHoc { name, kind: AdHocKind::Launch(Mutex::new(Some(Box::new(f)))) } AdHoc { name, kind: AdHocKind::Launch(Mutex::new(Some(Box::new(f)))) }
} }
@ -110,8 +110,8 @@ impl AdHoc {
/// }) /// })
/// }); /// });
/// ``` /// ```
pub fn on_request<F>(name: &'static str, f: F) -> AdHoc pub fn on_request<F: Send + Sync + 'static>(name: &'static str, f: F) -> AdHoc
where F: for<'a> Fn(&'a mut Request<'_>, &'a Data) -> BoxFuture<'a, ()> + Send + Sync + 'static where F: for<'a> Fn(&'a mut Request<'_>, &'a Data) -> BoxFuture<'a, ()>
{ {
AdHoc { name, kind: AdHocKind::Request(Box::new(f)) } AdHoc { name, kind: AdHocKind::Request(Box::new(f)) }
} }
@ -133,13 +133,14 @@ impl AdHoc {
/// }) /// })
/// }); /// });
/// ``` /// ```
pub fn on_response<F>(name: &'static str, f: F) -> AdHoc pub fn on_response<F: Send + Sync + 'static>(name: &'static str, f: F) -> AdHoc
where F: for<'a> Fn(&'a Request<'_>, &'a mut Response<'_>) -> BoxFuture<'a, ()> + Send + Sync + 'static where F: for<'a> Fn(&'a Request<'_>, &'a mut Response<'_>) -> BoxFuture<'a, ()>
{ {
AdHoc { name, kind: AdHocKind::Response(Box::new(f)) } AdHoc { name, kind: AdHocKind::Response(Box::new(f)) }
} }
} }
#[crate::async_trait]
impl Fairing for AdHoc { impl Fairing for AdHoc {
fn info(&self) -> Info { fn info(&self) -> Info {
let kind = match self.kind { let kind = match self.kind {
@ -170,19 +171,15 @@ impl Fairing for AdHoc {
} }
} }
fn on_request<'a>(&'a self, request: &'a mut Request<'_>, data: &'a Data) -> BoxFuture<'a, ()> { async fn on_request<'a>(&'a self, req: &'a mut Request<'_>, data: &'a Data) {
if let AdHocKind::Request(ref callback) = self.kind { if let AdHocKind::Request(ref callback) = self.kind {
callback(request, data) callback(req, data).await;
} else {
Box::pin(async { })
} }
} }
fn on_response<'a>(&'a self, request: &'a Request<'_>, response: &'a mut Response<'_>) -> BoxFuture<'a, ()> { async fn on_response<'a>(&'a self, req: &'a Request<'_>, res: &'a mut Response<'_>) {
if let AdHocKind::Response(ref callback) = self.kind { if let AdHocKind::Response(ref callback) = self.kind {
callback(request, response) callback(req, res).await;
} else {
Box::pin(async { })
} }
} }
} }

View File

@ -47,8 +47,6 @@
//! of other `Fairings` are not jeopardized. For instance, unless it is made //! of other `Fairings` are not jeopardized. For instance, unless it is made
//! abundantly clear, a fairing should not rewrite every request. //! abundantly clear, a fairing should not rewrite every request.
use futures_util::future::BoxFuture;
use crate::{Rocket, Request, Response, Data}; use crate::{Rocket, Request, Response, Data};
mod fairings; mod fairings;
@ -138,7 +136,7 @@ pub use self::info_kind::{Info, Kind};
/// to the request; these issues are better handled via [request guards] or /// to the request; these issues are better handled via [request guards] or
/// via response callbacks. Any modifications to a request are persisted and /// via response callbacks. Any modifications to a request are persisted and
/// can potentially alter how a request is routed. /// can potentially alter how a request is routed.
///= ///
/// * **Response (`on_response`)** /// * **Response (`on_response`)**
/// ///
/// A response callback, represented by the [`Fairing::on_response()`] /// A response callback, represented by the [`Fairing::on_response()`]
@ -192,6 +190,45 @@ pub use self::info_kind::{Info, Kind};
/// these bounds _do not_ prohibit a `Fairing` from holding state: the state /// these bounds _do not_ prohibit a `Fairing` from holding state: the state
/// need simply be thread-safe and statically available or heap allocated. /// need simply be thread-safe and statically available or heap allocated.
/// ///
/// ## Async Trait
///
/// [`Fairing`] is an _async_ trait. Implementations of `Fairing` must be
/// decorated with an attribute of `#[rocket::async_trait]`:
///
/// ```rust
/// use rocket::{Rocket, Request, Data, Response};
/// use rocket::fairing::{Fairing, Info, Kind};
///
/// # struct MyType;
/// #[rocket::async_trait]
/// impl Fairing for MyType {
/// fn info(&self) -> Info {
/// /* ... */
/// # unimplemented!()
/// }
///
/// fn on_attach(&self, rocket: Rocket) -> Result<Rocket, Rocket> {
/// /* ... */
/// # unimplemented!()
/// }
///
/// fn on_launch(&self, rocket: &Rocket) {
/// /* ... */
/// # unimplemented!()
/// }
///
/// async fn on_request<'a>(&'a self, req: &'a mut Request<'_>, data: &'a Data) {
/// /* ... */
/// # unimplemented!()
/// }
///
/// async fn on_response<'a>(&'a self, req: &'a Request<'_>, res: &'a mut Response<'_>) {
/// /* ... */
/// # unimplemented!()
/// }
/// }
/// ```
///
/// ## Example /// ## Example
/// ///
/// Imagine that we want to record the number of `GET` and `POST` requests that /// Imagine that we want to record the number of `GET` and `POST` requests that
@ -220,6 +257,7 @@ pub use self::info_kind::{Info, Kind};
/// post: AtomicUsize, /// post: AtomicUsize,
/// } /// }
/// ///
/// #[rocket::async_trait]
/// impl Fairing for Counter { /// impl Fairing for Counter {
/// fn info(&self) -> Info { /// fn info(&self) -> Info {
/// Info { /// Info {
@ -228,33 +266,29 @@ pub use self::info_kind::{Info, Kind};
/// } /// }
/// } /// }
/// ///
/// fn on_request<'a>(&'a self, request: &'a mut Request, _: &'a Data) -> Pin<Box<dyn Future<Output=()> + Send + 'a>> { /// async fn on_request<'a>(&'a self, req: &'a mut Request<'_>, _: &'a Data) {
/// Box::pin(async move { /// if req.method() == Method::Get {
/// if request.method() == Method::Get {
/// self.get.fetch_add(1, Ordering::Relaxed); /// self.get.fetch_add(1, Ordering::Relaxed);
/// } else if request.method() == Method::Post { /// } else if req.method() == Method::Post {
/// self.post.fetch_add(1, Ordering::Relaxed); /// self.post.fetch_add(1, Ordering::Relaxed);
/// } /// }
/// })
/// } /// }
/// ///
/// fn on_response<'a>(&'a self, request: &'a Request<'_>, response: &'a mut Response<'_>) -> Pin<Box<dyn Future<Output=()> + Send + 'a>> { /// async fn on_response<'a>(&'a self, req: &'a Request<'_>, res: &'a mut Response<'_>) {
/// Box::pin(async move {
/// // Don't change a successful user's response, ever. /// // Don't change a successful user's response, ever.
/// if response.status() != Status::NotFound { /// if res.status() != Status::NotFound {
/// return /// return
/// } /// }
/// ///
/// if request.method() == Method::Get && request.uri().path() == "/counts" { /// if req.method() == Method::Get && req.uri().path() == "/counts" {
/// let get_count = self.get.load(Ordering::Relaxed); /// let get_count = self.get.load(Ordering::Relaxed);
/// let post_count = self.post.load(Ordering::Relaxed); /// let post_count = self.post.load(Ordering::Relaxed);
/// ///
/// let body = format!("Get: {}\nPost: {}", get_count, post_count); /// let body = format!("Get: {}\nPost: {}", get_count, post_count);
/// response.set_status(Status::Ok); /// res.set_status(Status::Ok);
/// response.set_header(ContentType::Plain); /// res.set_header(ContentType::Plain);
/// response.set_sized_body(Cursor::new(body)); /// res.set_sized_body(Cursor::new(body));
/// } /// }
/// })
/// } /// }
/// } /// }
/// ``` /// ```
@ -286,6 +320,7 @@ pub use self::info_kind::{Info, Kind};
/// #[derive(Copy, Clone)] /// #[derive(Copy, Clone)]
/// struct TimerStart(Option<SystemTime>); /// struct TimerStart(Option<SystemTime>);
/// ///
/// #[rocket::async_trait]
/// impl Fairing for RequestTimer { /// impl Fairing for RequestTimer {
/// fn info(&self) -> Info { /// fn info(&self) -> Info {
/// Info { /// Info {
@ -295,25 +330,21 @@ pub use self::info_kind::{Info, Kind};
/// } /// }
/// ///
/// /// Stores the start time of the request in request-local state. /// /// Stores the start time of the request in request-local state.
/// fn on_request<'a>(&'a self, request: &'a mut Request, _: &'a Data) -> Pin<Box<dyn Future<Output=()> + Send + 'a>> { /// async fn on_request<'a>(&'a self, request: &'a mut Request<'_>, _: &'a Data) {
/// Box::pin(async move {
/// // Store a `TimerStart` instead of directly storing a `SystemTime` /// // Store a `TimerStart` instead of directly storing a `SystemTime`
/// // to ensure that this usage doesn't conflict with anything else /// // to ensure that this usage doesn't conflict with anything else
/// // that might store a `SystemTime` in request-local cache. /// // that might store a `SystemTime` in request-local cache.
/// request.local_cache(|| TimerStart(Some(SystemTime::now()))); /// request.local_cache(|| TimerStart(Some(SystemTime::now())));
/// })
/// } /// }
/// ///
/// /// Adds a header to the response indicating how long the server took to /// /// Adds a header to the response indicating how long the server took to
/// /// process the request. /// /// process the request.
/// fn on_response<'a>(&'a self, request: &'a Request<'_>, response: &'a mut Response<'_>) -> Pin<Box<dyn Future<Output=()> + Send + 'a>> { /// async fn on_response<'a>(&'a self, req: &'a Request<'_>, res: &'a mut Response<'_>) {
/// Box::pin(async move { /// let start_time = req.local_cache(|| TimerStart(None));
/// let start_time = request.local_cache(|| TimerStart(None));
/// if let Some(Ok(duration)) = start_time.0.map(|st| st.elapsed()) { /// if let Some(Ok(duration)) = start_time.0.map(|st| st.elapsed()) {
/// let ms = duration.as_secs() * 1000 + duration.subsec_millis() as u64; /// let ms = duration.as_secs() * 1000 + duration.subsec_millis() as u64;
/// response.set_raw_header("X-Response-Time", format!("{} ms", ms)); /// res.set_raw_header("X-Response-Time", format!("{} ms", ms));
/// } /// }
/// })
/// } /// }
/// } /// }
/// ///
@ -336,6 +367,7 @@ pub use self::info_kind::{Info, Kind};
/// ///
/// [request-local state]: https://rocket.rs/v0.5/guide/state/#request-local-state /// [request-local state]: https://rocket.rs/v0.5/guide/state/#request-local-state
#[crate::async_trait]
pub trait Fairing: Send + Sync + 'static { pub trait Fairing: Send + Sync + 'static {
/// Returns an [`Info`] structure containing the `name` and [`Kind`] of this /// Returns an [`Info`] structure containing the `name` and [`Kind`] of this
/// fairing. The `name` can be any arbitrary string. `Kind` must be an `or`d /// fairing. The `name` can be any arbitrary string. `Kind` must be an `or`d
@ -409,9 +441,7 @@ pub trait Fairing: Send + Sync + 'static {
/// ///
/// The default implementation of this method does nothing. /// The default implementation of this method does nothing.
#[allow(unused_variables)] #[allow(unused_variables)]
fn on_request<'a>(&'a self, request: &'a mut Request<'_>, data: &'a Data) -> BoxFuture<'a, ()> { async fn on_request<'a>(&'a self, req: &'a mut Request<'_>, data: &'a Data) {}
Box::pin(async { })
}
/// The response callback. /// The response callback.
/// ///
@ -424,11 +454,10 @@ pub trait Fairing: Send + Sync + 'static {
/// ///
/// The default implementation of this method does nothing. /// The default implementation of this method does nothing.
#[allow(unused_variables)] #[allow(unused_variables)]
fn on_response<'a>(&'a self, request: &'a Request<'_>, response: &'a mut Response<'_>) -> BoxFuture<'a, ()> { async fn on_response<'a>(&'a self, req: &'a Request<'_>, res: &'a mut Response<'_>) {}
Box::pin(async { })
}
} }
#[crate::async_trait]
impl<T: Fairing> Fairing for std::sync::Arc<T> { impl<T: Fairing> Fairing for std::sync::Arc<T> {
#[inline] #[inline]
fn info(&self) -> Info { fn info(&self) -> Info {
@ -446,12 +475,12 @@ impl<T: Fairing> Fairing for std::sync::Arc<T> {
} }
#[inline] #[inline]
fn on_request<'a>(&'a self, request: &'a mut Request<'_>, data: &'a Data) -> BoxFuture<'a, ()> { async fn on_request<'a>(&'a self, req: &'a mut Request<'_>, data: &'a Data) {
(self as &T).on_request(request, data) (self as &T).on_request(req, data).await;
} }
#[inline] #[inline]
fn on_response<'a>(&'a self, request: &'a Request<'_>, response: &'a mut Response<'_>) -> BoxFuture<'a, ()> { async fn on_response<'a>(&'a self, req: &'a Request<'_>, res: &'a mut Response<'_>) {
(self as &T).on_response(request, response) (self as &T).on_response(req, res).await;
} }
} }

View File

@ -94,6 +94,7 @@
#[allow(unused_imports)] #[macro_use] extern crate rocket_codegen; #[allow(unused_imports)] #[macro_use] extern crate rocket_codegen;
pub use rocket_codegen::*; pub use rocket_codegen::*;
pub use async_trait::*;
#[macro_use] extern crate log; #[macro_use] extern crate log;
#[macro_use] extern crate pear; #[macro_use] extern crate pear;

View File

@ -19,6 +19,7 @@ struct Counter {
post: AtomicUsize, post: AtomicUsize,
} }
#[rocket::async_trait]
impl Fairing for Counter { impl Fairing for Counter {
fn info(&self) -> Info { fn info(&self) -> Info {
Info { Info {
@ -27,36 +28,28 @@ impl Fairing for Counter {
} }
} }
fn on_request<'a>(&'a self, request: &'a mut Request<'_>, _: &'a Data) async fn on_request<'a>(&'a self, request: &'a mut Request<'_>, _: &'a Data) {
-> std::pin::Pin<Box<dyn std::future::Future<Output=()> + Send + 'a>>
{
Box::pin(async move {
if request.method() == Method::Get { if request.method() == Method::Get {
self.get.fetch_add(1, Ordering::Relaxed); self.get.fetch_add(1, Ordering::Relaxed);
} else if request.method() == Method::Post { } else if request.method() == Method::Post {
self.post.fetch_add(1, Ordering::Relaxed); self.post.fetch_add(1, Ordering::Relaxed);
} }
})
} }
fn on_response<'a>(&'a self, request: &'a Request<'_>, response: &'a mut Response<'_>) async fn on_response<'a>(&'a self, req: &'a Request<'_>, res: &'a mut Response<'_>) {
-> std::pin::Pin<Box<dyn std::future::Future<Output=()> + Send + 'a>> if res.status() != Status::NotFound {
{
Box::pin(async move {
if response.status() != Status::NotFound {
return return
} }
if request.method() == Method::Get && request.uri().path() == "/counts" { if req.method() == Method::Get && req.uri().path() == "/counts" {
let get_count = self.get.load(Ordering::Relaxed); let get_count = self.get.load(Ordering::Relaxed);
let post_count = self.post.load(Ordering::Relaxed); let post_count = self.post.load(Ordering::Relaxed);
let body = format!("Get: {}\nPost: {}", get_count, post_count); let body = format!("Get: {}\nPost: {}", get_count, post_count);
response.set_status(Status::Ok); res.set_status(Status::Ok);
response.set_header(ContentType::Plain); res.set_header(ContentType::Plain);
response.set_sized_body(Cursor::new(body)).await; res.set_sized_body(Cursor::new(body)).await;
} }
})
} }
} }