Use 'async_trait' for 'Fairing' trait.

Also re-exports the 'async_trait' attribute from 'rocket'.
This commit is contained in:
Sergio Benitez 2020-01-30 20:47:57 -08:00
parent a4e7972b4b
commit 48c333721c
8 changed files with 135 additions and 109 deletions

View File

@ -188,6 +188,7 @@ impl SpaceHelmet {
}
}
#[rocket::async_trait]
impl Fairing for SpaceHelmet {
fn info(&self) -> Info {
Info {
@ -196,10 +197,8 @@ impl Fairing for SpaceHelmet {
}
}
fn on_response<'a>(&'a self, _request: &'a Request<'_>, response: &'a mut Response<'_>) -> std::pin::Pin<Box<dyn std::future::Future<Output=()> + Send + 'a>> {
Box::pin(async move {
self.apply(response);
})
async fn on_response<'a>(&'a self, _: &'a Request<'_>, res: &'a mut Response<'_>) {
self.apply(res);
}
fn on_launch(&self, rocket: &Rocket) {

View File

@ -124,17 +124,26 @@ pub struct TemplateFairing {
pub custom_callback: Box<dyn Fn(&mut Engines) + Send + Sync + 'static>,
}
#[rocket::async_trait]
impl Fairing for TemplateFairing {
fn info(&self) -> Info {
// The on_request part of this fairing only applies in debug
// mode, so only register it in debug mode.
Info {
#[cfg(debug_assertions)]
let info = Info {
name: "Templates",
#[cfg(debug_assertions)]
kind: Kind::Attach | Kind::Request,
#[cfg(not(debug_assertions))]
};
// FIXME: We declare two `info` variables here, instead of just one with
// `cfg`s on `kind`, due to issue #63 in `async_trait`.
#[cfg(not(debug_assertions))]
let info = Info {
name: "Templates",
kind: Kind::Attach,
}
};
info
}
/// Initializes the template context. Templates will be searched for in the
@ -163,14 +172,10 @@ impl Fairing for TemplateFairing {
}
#[cfg(debug_assertions)]
fn on_request<'a>(&'a self, req: &'a mut rocket::Request<'_>, _data: &'a rocket::Data)
-> std::pin::Pin<Box<dyn std::future::Future<Output=()> + Send + 'a>>
{
Box::pin(async move {
let cm = req.guard::<rocket::State<'_, ContextManager>>()
.expect("Template ContextManager registered in on_attach");
async fn on_request<'a>(&'a self, req: &'a mut rocket::Request<'_>, _data: &'a rocket::Data) {
let cm = req.guard::<rocket::State<'_, ContextManager>>()
.expect("Template ContextManager registered in on_attach");
cm.reload_if_needed(&*self.custom_callback);
})
cm.reload_if_needed(&*self.custom_callback);
}
}

View File

@ -1,4 +1,5 @@
#![feature(proc_macro_hygiene)]
#![allow(dead_code)]
#[macro_use] extern crate rocket;
use rocket::http::uri::Origin;

View File

@ -39,6 +39,7 @@ memchr = "2" # TODO: Use pear instead.
binascii = "0.1"
pear = "0.1"
atty = "0.2"
async-trait = "0.1"
[build-dependencies]
yansi = "0.5"

View File

@ -68,8 +68,8 @@ impl AdHoc {
/// // The no-op attach fairing.
/// let fairing = AdHoc::on_attach("No-Op", |rocket| Ok(rocket));
/// ```
pub fn on_attach<F>(name: &'static str, f: F) -> AdHoc
where F: FnOnce(Rocket) -> Result<Rocket, Rocket> + Send + 'static
pub fn on_attach<F: Send + 'static>(name: &'static str, f: F) -> AdHoc
where F: FnOnce(Rocket) -> Result<Rocket, Rocket>
{
AdHoc { name, kind: AdHocKind::Attach(Mutex::new(Some(Box::new(f)))) }
}
@ -87,8 +87,8 @@ impl AdHoc {
/// println!("Launching in T-3..2..1..");
/// });
/// ```
pub fn on_launch<F>(name: &'static str, f: F) -> AdHoc
where F: FnOnce(&Rocket) + Send + 'static
pub fn on_launch<F: Send + 'static>(name: &'static str, f: F) -> AdHoc
where F: FnOnce(&Rocket)
{
AdHoc { name, kind: AdHocKind::Launch(Mutex::new(Some(Box::new(f)))) }
}
@ -110,8 +110,8 @@ impl AdHoc {
/// })
/// });
/// ```
pub fn on_request<F>(name: &'static str, f: F) -> AdHoc
where F: for<'a> Fn(&'a mut Request<'_>, &'a Data) -> BoxFuture<'a, ()> + Send + Sync + 'static
pub fn on_request<F: Send + Sync + 'static>(name: &'static str, f: F) -> AdHoc
where F: for<'a> Fn(&'a mut Request<'_>, &'a Data) -> BoxFuture<'a, ()>
{
AdHoc { name, kind: AdHocKind::Request(Box::new(f)) }
}
@ -133,13 +133,14 @@ impl AdHoc {
/// })
/// });
/// ```
pub fn on_response<F>(name: &'static str, f: F) -> AdHoc
where F: for<'a> Fn(&'a Request<'_>, &'a mut Response<'_>) -> BoxFuture<'a, ()> + Send + Sync + 'static
pub fn on_response<F: Send + Sync + 'static>(name: &'static str, f: F) -> AdHoc
where F: for<'a> Fn(&'a Request<'_>, &'a mut Response<'_>) -> BoxFuture<'a, ()>
{
AdHoc { name, kind: AdHocKind::Response(Box::new(f)) }
}
}
#[crate::async_trait]
impl Fairing for AdHoc {
fn info(&self) -> Info {
let kind = match self.kind {
@ -170,19 +171,15 @@ impl Fairing for AdHoc {
}
}
fn on_request<'a>(&'a self, request: &'a mut Request<'_>, data: &'a Data) -> BoxFuture<'a, ()> {
async fn on_request<'a>(&'a self, req: &'a mut Request<'_>, data: &'a Data) {
if let AdHocKind::Request(ref callback) = self.kind {
callback(request, data)
} else {
Box::pin(async { })
callback(req, data).await;
}
}
fn on_response<'a>(&'a self, request: &'a Request<'_>, response: &'a mut Response<'_>) -> BoxFuture<'a, ()> {
async fn on_response<'a>(&'a self, req: &'a Request<'_>, res: &'a mut Response<'_>) {
if let AdHocKind::Response(ref callback) = self.kind {
callback(request, response)
} else {
Box::pin(async { })
callback(req, res).await;
}
}
}

View File

@ -47,8 +47,6 @@
//! of other `Fairings` are not jeopardized. For instance, unless it is made
//! abundantly clear, a fairing should not rewrite every request.
use futures_util::future::BoxFuture;
use crate::{Rocket, Request, Response, Data};
mod fairings;
@ -138,7 +136,7 @@ pub use self::info_kind::{Info, Kind};
/// to the request; these issues are better handled via [request guards] or
/// via response callbacks. Any modifications to a request are persisted and
/// can potentially alter how a request is routed.
///=
///
/// * **Response (`on_response`)**
///
/// A response callback, represented by the [`Fairing::on_response()`]
@ -192,6 +190,45 @@ pub use self::info_kind::{Info, Kind};
/// these bounds _do not_ prohibit a `Fairing` from holding state: the state
/// need simply be thread-safe and statically available or heap allocated.
///
/// ## Async Trait
///
/// [`Fairing`] is an _async_ trait. Implementations of `Fairing` must be
/// decorated with an attribute of `#[rocket::async_trait]`:
///
/// ```rust
/// use rocket::{Rocket, Request, Data, Response};
/// use rocket::fairing::{Fairing, Info, Kind};
///
/// # struct MyType;
/// #[rocket::async_trait]
/// impl Fairing for MyType {
/// fn info(&self) -> Info {
/// /* ... */
/// # unimplemented!()
/// }
///
/// fn on_attach(&self, rocket: Rocket) -> Result<Rocket, Rocket> {
/// /* ... */
/// # unimplemented!()
/// }
///
/// fn on_launch(&self, rocket: &Rocket) {
/// /* ... */
/// # unimplemented!()
/// }
///
/// async fn on_request<'a>(&'a self, req: &'a mut Request<'_>, data: &'a Data) {
/// /* ... */
/// # unimplemented!()
/// }
///
/// async fn on_response<'a>(&'a self, req: &'a Request<'_>, res: &'a mut Response<'_>) {
/// /* ... */
/// # unimplemented!()
/// }
/// }
/// ```
///
/// ## Example
///
/// Imagine that we want to record the number of `GET` and `POST` requests that
@ -220,6 +257,7 @@ pub use self::info_kind::{Info, Kind};
/// post: AtomicUsize,
/// }
///
/// #[rocket::async_trait]
/// impl Fairing for Counter {
/// fn info(&self) -> Info {
/// Info {
@ -228,33 +266,29 @@ pub use self::info_kind::{Info, Kind};
/// }
/// }
///
/// fn on_request<'a>(&'a self, request: &'a mut Request, _: &'a Data) -> Pin<Box<dyn Future<Output=()> + Send + 'a>> {
/// Box::pin(async move {
/// if request.method() == Method::Get {
/// self.get.fetch_add(1, Ordering::Relaxed);
/// } else if request.method() == Method::Post {
/// self.post.fetch_add(1, Ordering::Relaxed);
/// }
/// })
/// async fn on_request<'a>(&'a self, req: &'a mut Request<'_>, _: &'a Data) {
/// if req.method() == Method::Get {
/// self.get.fetch_add(1, Ordering::Relaxed);
/// } else if req.method() == Method::Post {
/// self.post.fetch_add(1, Ordering::Relaxed);
/// }
/// }
///
/// fn on_response<'a>(&'a self, request: &'a Request<'_>, response: &'a mut Response<'_>) -> Pin<Box<dyn Future<Output=()> + Send + 'a>> {
/// Box::pin(async move {
/// // Don't change a successful user's response, ever.
/// if response.status() != Status::NotFound {
/// return
/// }
/// async fn on_response<'a>(&'a self, req: &'a Request<'_>, res: &'a mut Response<'_>) {
/// // Don't change a successful user's response, ever.
/// if res.status() != Status::NotFound {
/// return
/// }
///
/// if request.method() == Method::Get && request.uri().path() == "/counts" {
/// let get_count = self.get.load(Ordering::Relaxed);
/// let post_count = self.post.load(Ordering::Relaxed);
/// if req.method() == Method::Get && req.uri().path() == "/counts" {
/// let get_count = self.get.load(Ordering::Relaxed);
/// let post_count = self.post.load(Ordering::Relaxed);
///
/// let body = format!("Get: {}\nPost: {}", get_count, post_count);
/// response.set_status(Status::Ok);
/// response.set_header(ContentType::Plain);
/// response.set_sized_body(Cursor::new(body));
/// }
/// })
/// let body = format!("Get: {}\nPost: {}", get_count, post_count);
/// res.set_status(Status::Ok);
/// res.set_header(ContentType::Plain);
/// res.set_sized_body(Cursor::new(body));
/// }
/// }
/// }
/// ```
@ -286,6 +320,7 @@ pub use self::info_kind::{Info, Kind};
/// #[derive(Copy, Clone)]
/// struct TimerStart(Option<SystemTime>);
///
/// #[rocket::async_trait]
/// impl Fairing for RequestTimer {
/// fn info(&self) -> Info {
/// Info {
@ -295,25 +330,21 @@ pub use self::info_kind::{Info, Kind};
/// }
///
/// /// Stores the start time of the request in request-local state.
/// fn on_request<'a>(&'a self, request: &'a mut Request, _: &'a Data) -> Pin<Box<dyn Future<Output=()> + Send + 'a>> {
/// Box::pin(async move {
/// // Store a `TimerStart` instead of directly storing a `SystemTime`
/// // to ensure that this usage doesn't conflict with anything else
/// // that might store a `SystemTime` in request-local cache.
/// request.local_cache(|| TimerStart(Some(SystemTime::now())));
/// })
/// async fn on_request<'a>(&'a self, request: &'a mut Request<'_>, _: &'a Data) {
/// // Store a `TimerStart` instead of directly storing a `SystemTime`
/// // to ensure that this usage doesn't conflict with anything else
/// // that might store a `SystemTime` in request-local cache.
/// request.local_cache(|| TimerStart(Some(SystemTime::now())));
/// }
///
/// /// Adds a header to the response indicating how long the server took to
/// /// process the request.
/// fn on_response<'a>(&'a self, request: &'a Request<'_>, response: &'a mut Response<'_>) -> Pin<Box<dyn Future<Output=()> + Send + 'a>> {
/// Box::pin(async move {
/// let start_time = request.local_cache(|| TimerStart(None));
/// if let Some(Ok(duration)) = start_time.0.map(|st| st.elapsed()) {
/// let ms = duration.as_secs() * 1000 + duration.subsec_millis() as u64;
/// response.set_raw_header("X-Response-Time", format!("{} ms", ms));
/// }
/// })
/// async fn on_response<'a>(&'a self, req: &'a Request<'_>, res: &'a mut Response<'_>) {
/// let start_time = req.local_cache(|| TimerStart(None));
/// if let Some(Ok(duration)) = start_time.0.map(|st| st.elapsed()) {
/// let ms = duration.as_secs() * 1000 + duration.subsec_millis() as u64;
/// res.set_raw_header("X-Response-Time", format!("{} ms", ms));
/// }
/// }
/// }
///
@ -336,6 +367,7 @@ pub use self::info_kind::{Info, Kind};
///
/// [request-local state]: https://rocket.rs/v0.5/guide/state/#request-local-state
#[crate::async_trait]
pub trait Fairing: Send + Sync + 'static {
/// Returns an [`Info`] structure containing the `name` and [`Kind`] of this
/// fairing. The `name` can be any arbitrary string. `Kind` must be an `or`d
@ -409,9 +441,7 @@ pub trait Fairing: Send + Sync + 'static {
///
/// The default implementation of this method does nothing.
#[allow(unused_variables)]
fn on_request<'a>(&'a self, request: &'a mut Request<'_>, data: &'a Data) -> BoxFuture<'a, ()> {
Box::pin(async { })
}
async fn on_request<'a>(&'a self, req: &'a mut Request<'_>, data: &'a Data) {}
/// The response callback.
///
@ -424,11 +454,10 @@ pub trait Fairing: Send + Sync + 'static {
///
/// The default implementation of this method does nothing.
#[allow(unused_variables)]
fn on_response<'a>(&'a self, request: &'a Request<'_>, response: &'a mut Response<'_>) -> BoxFuture<'a, ()> {
Box::pin(async { })
}
async fn on_response<'a>(&'a self, req: &'a Request<'_>, res: &'a mut Response<'_>) {}
}
#[crate::async_trait]
impl<T: Fairing> Fairing for std::sync::Arc<T> {
#[inline]
fn info(&self) -> Info {
@ -446,12 +475,12 @@ impl<T: Fairing> Fairing for std::sync::Arc<T> {
}
#[inline]
fn on_request<'a>(&'a self, request: &'a mut Request<'_>, data: &'a Data) -> BoxFuture<'a, ()> {
(self as &T).on_request(request, data)
async fn on_request<'a>(&'a self, req: &'a mut Request<'_>, data: &'a Data) {
(self as &T).on_request(req, data).await;
}
#[inline]
fn on_response<'a>(&'a self, request: &'a Request<'_>, response: &'a mut Response<'_>) -> BoxFuture<'a, ()> {
(self as &T).on_response(request, response)
async fn on_response<'a>(&'a self, req: &'a Request<'_>, res: &'a mut Response<'_>) {
(self as &T).on_response(req, res).await;
}
}

View File

@ -94,6 +94,7 @@
#[allow(unused_imports)] #[macro_use] extern crate rocket_codegen;
pub use rocket_codegen::*;
pub use async_trait::*;
#[macro_use] extern crate log;
#[macro_use] extern crate pear;

View File

@ -19,6 +19,7 @@ struct Counter {
post: AtomicUsize,
}
#[rocket::async_trait]
impl Fairing for Counter {
fn info(&self) -> Info {
Info {
@ -27,36 +28,28 @@ impl Fairing for Counter {
}
}
fn on_request<'a>(&'a self, request: &'a mut Request<'_>, _: &'a Data)
-> std::pin::Pin<Box<dyn std::future::Future<Output=()> + Send + 'a>>
{
Box::pin(async move {
if request.method() == Method::Get {
self.get.fetch_add(1, Ordering::Relaxed);
} else if request.method() == Method::Post {
self.post.fetch_add(1, Ordering::Relaxed);
}
})
async fn on_request<'a>(&'a self, request: &'a mut Request<'_>, _: &'a Data) {
if request.method() == Method::Get {
self.get.fetch_add(1, Ordering::Relaxed);
} else if request.method() == Method::Post {
self.post.fetch_add(1, Ordering::Relaxed);
}
}
fn on_response<'a>(&'a self, request: &'a Request<'_>, response: &'a mut Response<'_>)
-> std::pin::Pin<Box<dyn std::future::Future<Output=()> + Send + 'a>>
{
Box::pin(async move {
if response.status() != Status::NotFound {
return
}
async fn on_response<'a>(&'a self, req: &'a Request<'_>, res: &'a mut Response<'_>) {
if res.status() != Status::NotFound {
return
}
if request.method() == Method::Get && request.uri().path() == "/counts" {
let get_count = self.get.load(Ordering::Relaxed);
let post_count = self.post.load(Ordering::Relaxed);
if req.method() == Method::Get && req.uri().path() == "/counts" {
let get_count = self.get.load(Ordering::Relaxed);
let post_count = self.post.load(Ordering::Relaxed);
let body = format!("Get: {}\nPost: {}", get_count, post_count);
response.set_status(Status::Ok);
response.set_header(ContentType::Plain);
response.set_sized_body(Cursor::new(body)).await;
}
})
let body = format!("Get: {}\nPost: {}", get_count, post_count);
res.set_status(Status::Ok);
res.set_header(ContentType::Plain);
res.set_sized_body(Cursor::new(body)).await;
}
}
}