mirror of https://github.com/rwf2/Rocket.git
Use 'async_trait' for 'Fairing' trait.
Also re-exports the 'async_trait' attribute from 'rocket'.
This commit is contained in:
parent
a4e7972b4b
commit
48c333721c
|
@ -188,6 +188,7 @@ impl SpaceHelmet {
|
|||
}
|
||||
}
|
||||
|
||||
#[rocket::async_trait]
|
||||
impl Fairing for SpaceHelmet {
|
||||
fn info(&self) -> Info {
|
||||
Info {
|
||||
|
@ -196,10 +197,8 @@ impl Fairing for SpaceHelmet {
|
|||
}
|
||||
}
|
||||
|
||||
fn on_response<'a>(&'a self, _request: &'a Request<'_>, response: &'a mut Response<'_>) -> std::pin::Pin<Box<dyn std::future::Future<Output=()> + Send + 'a>> {
|
||||
Box::pin(async move {
|
||||
self.apply(response);
|
||||
})
|
||||
async fn on_response<'a>(&'a self, _: &'a Request<'_>, res: &'a mut Response<'_>) {
|
||||
self.apply(res);
|
||||
}
|
||||
|
||||
fn on_launch(&self, rocket: &Rocket) {
|
||||
|
|
|
@ -124,17 +124,26 @@ pub struct TemplateFairing {
|
|||
pub custom_callback: Box<dyn Fn(&mut Engines) + Send + Sync + 'static>,
|
||||
}
|
||||
|
||||
#[rocket::async_trait]
|
||||
impl Fairing for TemplateFairing {
|
||||
fn info(&self) -> Info {
|
||||
// The on_request part of this fairing only applies in debug
|
||||
// mode, so only register it in debug mode.
|
||||
Info {
|
||||
#[cfg(debug_assertions)]
|
||||
let info = Info {
|
||||
name: "Templates",
|
||||
#[cfg(debug_assertions)]
|
||||
kind: Kind::Attach | Kind::Request,
|
||||
#[cfg(not(debug_assertions))]
|
||||
};
|
||||
|
||||
// FIXME: We declare two `info` variables here, instead of just one with
|
||||
// `cfg`s on `kind`, due to issue #63 in `async_trait`.
|
||||
#[cfg(not(debug_assertions))]
|
||||
let info = Info {
|
||||
name: "Templates",
|
||||
kind: Kind::Attach,
|
||||
}
|
||||
};
|
||||
|
||||
info
|
||||
}
|
||||
|
||||
/// Initializes the template context. Templates will be searched for in the
|
||||
|
@ -163,14 +172,10 @@ impl Fairing for TemplateFairing {
|
|||
}
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
fn on_request<'a>(&'a self, req: &'a mut rocket::Request<'_>, _data: &'a rocket::Data)
|
||||
-> std::pin::Pin<Box<dyn std::future::Future<Output=()> + Send + 'a>>
|
||||
{
|
||||
Box::pin(async move {
|
||||
let cm = req.guard::<rocket::State<'_, ContextManager>>()
|
||||
.expect("Template ContextManager registered in on_attach");
|
||||
async fn on_request<'a>(&'a self, req: &'a mut rocket::Request<'_>, _data: &'a rocket::Data) {
|
||||
let cm = req.guard::<rocket::State<'_, ContextManager>>()
|
||||
.expect("Template ContextManager registered in on_attach");
|
||||
|
||||
cm.reload_if_needed(&*self.custom_callback);
|
||||
})
|
||||
cm.reload_if_needed(&*self.custom_callback);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
#![feature(proc_macro_hygiene)]
|
||||
#![allow(dead_code)]
|
||||
|
||||
#[macro_use] extern crate rocket;
|
||||
use rocket::http::uri::Origin;
|
||||
|
|
|
@ -39,6 +39,7 @@ memchr = "2" # TODO: Use pear instead.
|
|||
binascii = "0.1"
|
||||
pear = "0.1"
|
||||
atty = "0.2"
|
||||
async-trait = "0.1"
|
||||
|
||||
[build-dependencies]
|
||||
yansi = "0.5"
|
||||
|
|
|
@ -68,8 +68,8 @@ impl AdHoc {
|
|||
/// // The no-op attach fairing.
|
||||
/// let fairing = AdHoc::on_attach("No-Op", |rocket| Ok(rocket));
|
||||
/// ```
|
||||
pub fn on_attach<F>(name: &'static str, f: F) -> AdHoc
|
||||
where F: FnOnce(Rocket) -> Result<Rocket, Rocket> + Send + 'static
|
||||
pub fn on_attach<F: Send + 'static>(name: &'static str, f: F) -> AdHoc
|
||||
where F: FnOnce(Rocket) -> Result<Rocket, Rocket>
|
||||
{
|
||||
AdHoc { name, kind: AdHocKind::Attach(Mutex::new(Some(Box::new(f)))) }
|
||||
}
|
||||
|
@ -87,8 +87,8 @@ impl AdHoc {
|
|||
/// println!("Launching in T-3..2..1..");
|
||||
/// });
|
||||
/// ```
|
||||
pub fn on_launch<F>(name: &'static str, f: F) -> AdHoc
|
||||
where F: FnOnce(&Rocket) + Send + 'static
|
||||
pub fn on_launch<F: Send + 'static>(name: &'static str, f: F) -> AdHoc
|
||||
where F: FnOnce(&Rocket)
|
||||
{
|
||||
AdHoc { name, kind: AdHocKind::Launch(Mutex::new(Some(Box::new(f)))) }
|
||||
}
|
||||
|
@ -110,8 +110,8 @@ impl AdHoc {
|
|||
/// })
|
||||
/// });
|
||||
/// ```
|
||||
pub fn on_request<F>(name: &'static str, f: F) -> AdHoc
|
||||
where F: for<'a> Fn(&'a mut Request<'_>, &'a Data) -> BoxFuture<'a, ()> + Send + Sync + 'static
|
||||
pub fn on_request<F: Send + Sync + 'static>(name: &'static str, f: F) -> AdHoc
|
||||
where F: for<'a> Fn(&'a mut Request<'_>, &'a Data) -> BoxFuture<'a, ()>
|
||||
{
|
||||
AdHoc { name, kind: AdHocKind::Request(Box::new(f)) }
|
||||
}
|
||||
|
@ -133,13 +133,14 @@ impl AdHoc {
|
|||
/// })
|
||||
/// });
|
||||
/// ```
|
||||
pub fn on_response<F>(name: &'static str, f: F) -> AdHoc
|
||||
where F: for<'a> Fn(&'a Request<'_>, &'a mut Response<'_>) -> BoxFuture<'a, ()> + Send + Sync + 'static
|
||||
pub fn on_response<F: Send + Sync + 'static>(name: &'static str, f: F) -> AdHoc
|
||||
where F: for<'a> Fn(&'a Request<'_>, &'a mut Response<'_>) -> BoxFuture<'a, ()>
|
||||
{
|
||||
AdHoc { name, kind: AdHocKind::Response(Box::new(f)) }
|
||||
}
|
||||
}
|
||||
|
||||
#[crate::async_trait]
|
||||
impl Fairing for AdHoc {
|
||||
fn info(&self) -> Info {
|
||||
let kind = match self.kind {
|
||||
|
@ -170,19 +171,15 @@ impl Fairing for AdHoc {
|
|||
}
|
||||
}
|
||||
|
||||
fn on_request<'a>(&'a self, request: &'a mut Request<'_>, data: &'a Data) -> BoxFuture<'a, ()> {
|
||||
async fn on_request<'a>(&'a self, req: &'a mut Request<'_>, data: &'a Data) {
|
||||
if let AdHocKind::Request(ref callback) = self.kind {
|
||||
callback(request, data)
|
||||
} else {
|
||||
Box::pin(async { })
|
||||
callback(req, data).await;
|
||||
}
|
||||
}
|
||||
|
||||
fn on_response<'a>(&'a self, request: &'a Request<'_>, response: &'a mut Response<'_>) -> BoxFuture<'a, ()> {
|
||||
async fn on_response<'a>(&'a self, req: &'a Request<'_>, res: &'a mut Response<'_>) {
|
||||
if let AdHocKind::Response(ref callback) = self.kind {
|
||||
callback(request, response)
|
||||
} else {
|
||||
Box::pin(async { })
|
||||
callback(req, res).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -47,8 +47,6 @@
|
|||
//! of other `Fairings` are not jeopardized. For instance, unless it is made
|
||||
//! abundantly clear, a fairing should not rewrite every request.
|
||||
|
||||
use futures_util::future::BoxFuture;
|
||||
|
||||
use crate::{Rocket, Request, Response, Data};
|
||||
|
||||
mod fairings;
|
||||
|
@ -138,7 +136,7 @@ pub use self::info_kind::{Info, Kind};
|
|||
/// to the request; these issues are better handled via [request guards] or
|
||||
/// via response callbacks. Any modifications to a request are persisted and
|
||||
/// can potentially alter how a request is routed.
|
||||
///=
|
||||
///
|
||||
/// * **Response (`on_response`)**
|
||||
///
|
||||
/// A response callback, represented by the [`Fairing::on_response()`]
|
||||
|
@ -192,6 +190,45 @@ pub use self::info_kind::{Info, Kind};
|
|||
/// these bounds _do not_ prohibit a `Fairing` from holding state: the state
|
||||
/// need simply be thread-safe and statically available or heap allocated.
|
||||
///
|
||||
/// ## Async Trait
|
||||
///
|
||||
/// [`Fairing`] is an _async_ trait. Implementations of `Fairing` must be
|
||||
/// decorated with an attribute of `#[rocket::async_trait]`:
|
||||
///
|
||||
/// ```rust
|
||||
/// use rocket::{Rocket, Request, Data, Response};
|
||||
/// use rocket::fairing::{Fairing, Info, Kind};
|
||||
///
|
||||
/// # struct MyType;
|
||||
/// #[rocket::async_trait]
|
||||
/// impl Fairing for MyType {
|
||||
/// fn info(&self) -> Info {
|
||||
/// /* ... */
|
||||
/// # unimplemented!()
|
||||
/// }
|
||||
///
|
||||
/// fn on_attach(&self, rocket: Rocket) -> Result<Rocket, Rocket> {
|
||||
/// /* ... */
|
||||
/// # unimplemented!()
|
||||
/// }
|
||||
///
|
||||
/// fn on_launch(&self, rocket: &Rocket) {
|
||||
/// /* ... */
|
||||
/// # unimplemented!()
|
||||
/// }
|
||||
///
|
||||
/// async fn on_request<'a>(&'a self, req: &'a mut Request<'_>, data: &'a Data) {
|
||||
/// /* ... */
|
||||
/// # unimplemented!()
|
||||
/// }
|
||||
///
|
||||
/// async fn on_response<'a>(&'a self, req: &'a Request<'_>, res: &'a mut Response<'_>) {
|
||||
/// /* ... */
|
||||
/// # unimplemented!()
|
||||
/// }
|
||||
/// }
|
||||
/// ```
|
||||
///
|
||||
/// ## Example
|
||||
///
|
||||
/// Imagine that we want to record the number of `GET` and `POST` requests that
|
||||
|
@ -220,6 +257,7 @@ pub use self::info_kind::{Info, Kind};
|
|||
/// post: AtomicUsize,
|
||||
/// }
|
||||
///
|
||||
/// #[rocket::async_trait]
|
||||
/// impl Fairing for Counter {
|
||||
/// fn info(&self) -> Info {
|
||||
/// Info {
|
||||
|
@ -228,33 +266,29 @@ pub use self::info_kind::{Info, Kind};
|
|||
/// }
|
||||
/// }
|
||||
///
|
||||
/// fn on_request<'a>(&'a self, request: &'a mut Request, _: &'a Data) -> Pin<Box<dyn Future<Output=()> + Send + 'a>> {
|
||||
/// Box::pin(async move {
|
||||
/// if request.method() == Method::Get {
|
||||
/// self.get.fetch_add(1, Ordering::Relaxed);
|
||||
/// } else if request.method() == Method::Post {
|
||||
/// self.post.fetch_add(1, Ordering::Relaxed);
|
||||
/// }
|
||||
/// })
|
||||
/// async fn on_request<'a>(&'a self, req: &'a mut Request<'_>, _: &'a Data) {
|
||||
/// if req.method() == Method::Get {
|
||||
/// self.get.fetch_add(1, Ordering::Relaxed);
|
||||
/// } else if req.method() == Method::Post {
|
||||
/// self.post.fetch_add(1, Ordering::Relaxed);
|
||||
/// }
|
||||
/// }
|
||||
///
|
||||
/// fn on_response<'a>(&'a self, request: &'a Request<'_>, response: &'a mut Response<'_>) -> Pin<Box<dyn Future<Output=()> + Send + 'a>> {
|
||||
/// Box::pin(async move {
|
||||
/// // Don't change a successful user's response, ever.
|
||||
/// if response.status() != Status::NotFound {
|
||||
/// return
|
||||
/// }
|
||||
/// async fn on_response<'a>(&'a self, req: &'a Request<'_>, res: &'a mut Response<'_>) {
|
||||
/// // Don't change a successful user's response, ever.
|
||||
/// if res.status() != Status::NotFound {
|
||||
/// return
|
||||
/// }
|
||||
///
|
||||
/// if request.method() == Method::Get && request.uri().path() == "/counts" {
|
||||
/// let get_count = self.get.load(Ordering::Relaxed);
|
||||
/// let post_count = self.post.load(Ordering::Relaxed);
|
||||
/// if req.method() == Method::Get && req.uri().path() == "/counts" {
|
||||
/// let get_count = self.get.load(Ordering::Relaxed);
|
||||
/// let post_count = self.post.load(Ordering::Relaxed);
|
||||
///
|
||||
/// let body = format!("Get: {}\nPost: {}", get_count, post_count);
|
||||
/// response.set_status(Status::Ok);
|
||||
/// response.set_header(ContentType::Plain);
|
||||
/// response.set_sized_body(Cursor::new(body));
|
||||
/// }
|
||||
/// })
|
||||
/// let body = format!("Get: {}\nPost: {}", get_count, post_count);
|
||||
/// res.set_status(Status::Ok);
|
||||
/// res.set_header(ContentType::Plain);
|
||||
/// res.set_sized_body(Cursor::new(body));
|
||||
/// }
|
||||
/// }
|
||||
/// }
|
||||
/// ```
|
||||
|
@ -286,6 +320,7 @@ pub use self::info_kind::{Info, Kind};
|
|||
/// #[derive(Copy, Clone)]
|
||||
/// struct TimerStart(Option<SystemTime>);
|
||||
///
|
||||
/// #[rocket::async_trait]
|
||||
/// impl Fairing for RequestTimer {
|
||||
/// fn info(&self) -> Info {
|
||||
/// Info {
|
||||
|
@ -295,25 +330,21 @@ pub use self::info_kind::{Info, Kind};
|
|||
/// }
|
||||
///
|
||||
/// /// Stores the start time of the request in request-local state.
|
||||
/// fn on_request<'a>(&'a self, request: &'a mut Request, _: &'a Data) -> Pin<Box<dyn Future<Output=()> + Send + 'a>> {
|
||||
/// Box::pin(async move {
|
||||
/// // Store a `TimerStart` instead of directly storing a `SystemTime`
|
||||
/// // to ensure that this usage doesn't conflict with anything else
|
||||
/// // that might store a `SystemTime` in request-local cache.
|
||||
/// request.local_cache(|| TimerStart(Some(SystemTime::now())));
|
||||
/// })
|
||||
/// async fn on_request<'a>(&'a self, request: &'a mut Request<'_>, _: &'a Data) {
|
||||
/// // Store a `TimerStart` instead of directly storing a `SystemTime`
|
||||
/// // to ensure that this usage doesn't conflict with anything else
|
||||
/// // that might store a `SystemTime` in request-local cache.
|
||||
/// request.local_cache(|| TimerStart(Some(SystemTime::now())));
|
||||
/// }
|
||||
///
|
||||
/// /// Adds a header to the response indicating how long the server took to
|
||||
/// /// process the request.
|
||||
/// fn on_response<'a>(&'a self, request: &'a Request<'_>, response: &'a mut Response<'_>) -> Pin<Box<dyn Future<Output=()> + Send + 'a>> {
|
||||
/// Box::pin(async move {
|
||||
/// let start_time = request.local_cache(|| TimerStart(None));
|
||||
/// if let Some(Ok(duration)) = start_time.0.map(|st| st.elapsed()) {
|
||||
/// let ms = duration.as_secs() * 1000 + duration.subsec_millis() as u64;
|
||||
/// response.set_raw_header("X-Response-Time", format!("{} ms", ms));
|
||||
/// }
|
||||
/// })
|
||||
/// async fn on_response<'a>(&'a self, req: &'a Request<'_>, res: &'a mut Response<'_>) {
|
||||
/// let start_time = req.local_cache(|| TimerStart(None));
|
||||
/// if let Some(Ok(duration)) = start_time.0.map(|st| st.elapsed()) {
|
||||
/// let ms = duration.as_secs() * 1000 + duration.subsec_millis() as u64;
|
||||
/// res.set_raw_header("X-Response-Time", format!("{} ms", ms));
|
||||
/// }
|
||||
/// }
|
||||
/// }
|
||||
///
|
||||
|
@ -336,6 +367,7 @@ pub use self::info_kind::{Info, Kind};
|
|||
///
|
||||
/// [request-local state]: https://rocket.rs/v0.5/guide/state/#request-local-state
|
||||
|
||||
#[crate::async_trait]
|
||||
pub trait Fairing: Send + Sync + 'static {
|
||||
/// Returns an [`Info`] structure containing the `name` and [`Kind`] of this
|
||||
/// fairing. The `name` can be any arbitrary string. `Kind` must be an `or`d
|
||||
|
@ -409,9 +441,7 @@ pub trait Fairing: Send + Sync + 'static {
|
|||
///
|
||||
/// The default implementation of this method does nothing.
|
||||
#[allow(unused_variables)]
|
||||
fn on_request<'a>(&'a self, request: &'a mut Request<'_>, data: &'a Data) -> BoxFuture<'a, ()> {
|
||||
Box::pin(async { })
|
||||
}
|
||||
async fn on_request<'a>(&'a self, req: &'a mut Request<'_>, data: &'a Data) {}
|
||||
|
||||
/// The response callback.
|
||||
///
|
||||
|
@ -424,11 +454,10 @@ pub trait Fairing: Send + Sync + 'static {
|
|||
///
|
||||
/// The default implementation of this method does nothing.
|
||||
#[allow(unused_variables)]
|
||||
fn on_response<'a>(&'a self, request: &'a Request<'_>, response: &'a mut Response<'_>) -> BoxFuture<'a, ()> {
|
||||
Box::pin(async { })
|
||||
}
|
||||
async fn on_response<'a>(&'a self, req: &'a Request<'_>, res: &'a mut Response<'_>) {}
|
||||
}
|
||||
|
||||
#[crate::async_trait]
|
||||
impl<T: Fairing> Fairing for std::sync::Arc<T> {
|
||||
#[inline]
|
||||
fn info(&self) -> Info {
|
||||
|
@ -446,12 +475,12 @@ impl<T: Fairing> Fairing for std::sync::Arc<T> {
|
|||
}
|
||||
|
||||
#[inline]
|
||||
fn on_request<'a>(&'a self, request: &'a mut Request<'_>, data: &'a Data) -> BoxFuture<'a, ()> {
|
||||
(self as &T).on_request(request, data)
|
||||
async fn on_request<'a>(&'a self, req: &'a mut Request<'_>, data: &'a Data) {
|
||||
(self as &T).on_request(req, data).await;
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn on_response<'a>(&'a self, request: &'a Request<'_>, response: &'a mut Response<'_>) -> BoxFuture<'a, ()> {
|
||||
(self as &T).on_response(request, response)
|
||||
async fn on_response<'a>(&'a self, req: &'a Request<'_>, res: &'a mut Response<'_>) {
|
||||
(self as &T).on_response(req, res).await;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -94,6 +94,7 @@
|
|||
|
||||
#[allow(unused_imports)] #[macro_use] extern crate rocket_codegen;
|
||||
pub use rocket_codegen::*;
|
||||
pub use async_trait::*;
|
||||
|
||||
#[macro_use] extern crate log;
|
||||
#[macro_use] extern crate pear;
|
||||
|
|
|
@ -19,6 +19,7 @@ struct Counter {
|
|||
post: AtomicUsize,
|
||||
}
|
||||
|
||||
#[rocket::async_trait]
|
||||
impl Fairing for Counter {
|
||||
fn info(&self) -> Info {
|
||||
Info {
|
||||
|
@ -27,36 +28,28 @@ impl Fairing for Counter {
|
|||
}
|
||||
}
|
||||
|
||||
fn on_request<'a>(&'a self, request: &'a mut Request<'_>, _: &'a Data)
|
||||
-> std::pin::Pin<Box<dyn std::future::Future<Output=()> + Send + 'a>>
|
||||
{
|
||||
Box::pin(async move {
|
||||
if request.method() == Method::Get {
|
||||
self.get.fetch_add(1, Ordering::Relaxed);
|
||||
} else if request.method() == Method::Post {
|
||||
self.post.fetch_add(1, Ordering::Relaxed);
|
||||
}
|
||||
})
|
||||
async fn on_request<'a>(&'a self, request: &'a mut Request<'_>, _: &'a Data) {
|
||||
if request.method() == Method::Get {
|
||||
self.get.fetch_add(1, Ordering::Relaxed);
|
||||
} else if request.method() == Method::Post {
|
||||
self.post.fetch_add(1, Ordering::Relaxed);
|
||||
}
|
||||
}
|
||||
|
||||
fn on_response<'a>(&'a self, request: &'a Request<'_>, response: &'a mut Response<'_>)
|
||||
-> std::pin::Pin<Box<dyn std::future::Future<Output=()> + Send + 'a>>
|
||||
{
|
||||
Box::pin(async move {
|
||||
if response.status() != Status::NotFound {
|
||||
return
|
||||
}
|
||||
async fn on_response<'a>(&'a self, req: &'a Request<'_>, res: &'a mut Response<'_>) {
|
||||
if res.status() != Status::NotFound {
|
||||
return
|
||||
}
|
||||
|
||||
if request.method() == Method::Get && request.uri().path() == "/counts" {
|
||||
let get_count = self.get.load(Ordering::Relaxed);
|
||||
let post_count = self.post.load(Ordering::Relaxed);
|
||||
if req.method() == Method::Get && req.uri().path() == "/counts" {
|
||||
let get_count = self.get.load(Ordering::Relaxed);
|
||||
let post_count = self.post.load(Ordering::Relaxed);
|
||||
|
||||
let body = format!("Get: {}\nPost: {}", get_count, post_count);
|
||||
response.set_status(Status::Ok);
|
||||
response.set_header(ContentType::Plain);
|
||||
response.set_sized_body(Cursor::new(body)).await;
|
||||
}
|
||||
})
|
||||
let body = format!("Get: {}\nPost: {}", get_count, post_count);
|
||||
res.set_status(Status::Ok);
|
||||
res.set_header(ContentType::Plain);
|
||||
res.set_sized_body(Cursor::new(body)).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue