Allow implementations of on_request fairings to return a Future that borrows from self, request, and data.

This commit is contained in:
Jeb Rosen 2019-12-10 16:34:32 -08:00 committed by Sergio Benitez
parent cc3298c3e4
commit 4bb4c61528
10 changed files with 79 additions and 51 deletions

View File

@ -163,10 +163,14 @@ impl Fairing for TemplateFairing {
}
#[cfg(debug_assertions)]
fn on_request(&self, req: &mut rocket::Request<'_>, _data: &rocket::Data) {
let cm = req.guard::<rocket::State<'_, ContextManager>>()
.expect("Template ContextManager registered in on_attach");
fn on_request<'a>(&'a self, req: &'a mut rocket::Request<'_>, _data: &'a rocket::Data)
-> std::pin::Pin<Box<dyn std::future::Future<Output=()> + Send + 'a>>
{
Box::pin(async move {
let cm = req.guard::<rocket::State<'_, ContextManager>>()
.expect("Template ContextManager registered in on_attach");
cm.reload_if_needed(&*self.custom_callback);
cm.reload_if_needed(&*self.custom_callback);
})
}
}

View File

@ -47,7 +47,7 @@ const PEEK_BYTES: usize = 512;
pub struct Data {
buffer: Vec<u8>,
is_complete: bool,
stream: Box<dyn AsyncRead + Unpin + Send>,
stream: Box<dyn AsyncRead + Unpin + Send + Sync>,
}
impl Data {

View File

@ -34,7 +34,9 @@ use crate::fairing::{Fairing, Kind, Info};
/// println!("Rocket is about to launch! Exciting! Here we go...");
/// }))
/// .attach(AdHoc::on_request("Put Rewriter", |req, _| {
/// req.set_method(Method::Put);
/// Box::pin(async move {
/// req.set_method(Method::Put);
/// })
/// }));
/// ```
pub struct AdHoc {
@ -48,7 +50,7 @@ enum AdHocKind {
/// An ad-hoc **launch** fairing. Called just before Rocket launches.
Launch(Mutex<Option<Box<dyn FnOnce(&Rocket) + Send + 'static>>>),
/// An ad-hoc **request** fairing. Called when a request is received.
Request(Box<dyn Fn(&mut Request<'_>, &Data) + Send + Sync + 'static>),
Request(Box<dyn for<'a> Fn(&'a mut Request<'_>, &'a Data) -> BoxFuture<'a, ()> + Send + Sync + 'static>),
/// An ad-hoc **response** fairing. Called when a response is ready to be
/// sent to a client.
Response(Box<dyn for<'a, 'r> Fn(&'a Request<'r>, &'a mut Response<'r>) -> BoxFuture<'a, ()> + Send + Sync + 'static>),
@ -101,12 +103,14 @@ impl AdHoc {
///
/// // The no-op request fairing.
/// let fairing = AdHoc::on_request("Dummy", |req, data| {
/// // do something with the request and data...
/// # let (_, _) = (req, data);
/// Box::pin(async move {
/// // do something with the request and data...
/// # let (_, _) = (req, data);
/// })
/// });
/// ```
pub fn on_request<F>(name: &'static str, f: F) -> AdHoc
where F: Fn(&mut Request<'_>, &Data) + Send + Sync + 'static
where F: for<'a> Fn(&'a mut Request<'_>, &'a Data) -> BoxFuture<'a, ()> + Send + Sync + 'static
{
AdHoc { name, kind: AdHocKind::Request(Box::new(f)) }
}
@ -164,9 +168,11 @@ impl Fairing for AdHoc {
}
}
fn on_request(&self, request: &mut Request<'_>, data: &Data) {
fn on_request<'a>(&'a self, request: &'a mut Request<'_>, data: &'a Data) -> BoxFuture<'a, ()> {
if let AdHocKind::Request(ref callback) = self.kind {
callback(request, data)
} else {
Box::pin(async { })
}
}

View File

@ -59,9 +59,9 @@ impl Fairings {
}
#[inline(always)]
pub fn handle_request(&self, req: &mut Request<'_>, data: &Data) {
pub async fn handle_request(&self, req: &mut Request<'_>, data: &Data) {
for &i in &self.request {
self.all_fairings[i].on_request(req, data);
self.all_fairings[i].on_request(req, data).await;
}
}

View File

@ -21,7 +21,7 @@
//!
//! ```rust
//! # use rocket::fairing::AdHoc;
//! # let req_fairing = AdHoc::on_request("Request", |_, _| ());
//! # let req_fairing = AdHoc::on_request("Request", |_, _| Box::pin(async move {}));
//! # let res_fairing = AdHoc::on_response("Response", |_, _| Box::pin(async move {}));
//! let rocket = rocket::ignite()
//! .attach(req_fairing)
@ -228,12 +228,14 @@ pub use self::info_kind::{Info, Kind};
/// }
/// }
///
/// fn on_request(&self, request: &mut Request, _: &Data) {
/// if request.method() == Method::Get {
/// self.get.fetch_add(1, Ordering::Relaxed);
/// } else if request.method() == Method::Post {
/// self.post.fetch_add(1, Ordering::Relaxed);
/// }
/// fn on_request<'a>(&'a self, request: &'a mut Request, _: &'a Data) -> Pin<Box<dyn Future<Output=()> + Send + 'a>> {
/// Box::pin(async move {
/// if request.method() == Method::Get {
/// self.get.fetch_add(1, Ordering::Relaxed);
/// } else if request.method() == Method::Post {
/// self.post.fetch_add(1, Ordering::Relaxed);
/// }
/// })
/// }
///
/// fn on_response<'a, 'r>(&'a self, request: &'a Request<'r>, response: &'a mut Response<'r>) -> Pin<Box<dyn Future<Output=()> + Send + 'a>> {
@ -293,11 +295,13 @@ pub use self::info_kind::{Info, Kind};
/// }
///
/// /// Stores the start time of the request in request-local state.
/// fn on_request(&self, request: &mut Request, _: &Data) {
/// // Store a `TimerStart` instead of directly storing a `SystemTime`
/// // to ensure that this usage doesn't conflict with anything else
/// // that might store a `SystemTime` in request-local cache.
/// request.local_cache(|| TimerStart(Some(SystemTime::now())));
/// fn on_request<'a>(&'a self, request: &'a mut Request, _: &'a Data) -> Pin<Box<dyn Future<Output=()> + Send + 'a>> {
/// Box::pin(async move {
/// // Store a `TimerStart` instead of directly storing a `SystemTime`
/// // to ensure that this usage doesn't conflict with anything else
/// // that might store a `SystemTime` in request-local cache.
/// request.local_cache(|| TimerStart(Some(SystemTime::now())));
/// })
/// }
///
/// /// Adds a header to the response indicating how long the server took to
@ -405,7 +409,9 @@ pub trait Fairing: Send + Sync + 'static {
///
/// The default implementation of this method does nothing.
#[allow(unused_variables)]
fn on_request(&self, request: &mut Request<'_>, data: &Data) {}
fn on_request<'a>(&'a self, request: &'a mut Request<'_>, data: &'a Data) -> BoxFuture<'a, ()> {
Box::pin(async { })
}
/// The response callback.
///
@ -440,7 +446,7 @@ impl<T: Fairing> Fairing for std::sync::Arc<T> {
}
#[inline]
fn on_request(&self, request: &mut Request<'_>, data: &Data) {
fn on_request<'a>(&'a self, request: &'a mut Request<'_>, data: &'a Data) -> BoxFuture<'a, ()> {
(self as &T).on_request(request, data)
}

View File

@ -208,7 +208,7 @@ impl Rocket {
self.preprocess_request(request, &data);
// Run the request fairings.
self.fairings.handle_request(request, &data);
self.fairings.handle_request(request, &data).await;
// Remember if the request is a `HEAD` request for later body stripping.
let was_head_request = request.method() == Method::Head;

View File

@ -19,7 +19,7 @@ fn index(mut cookies: Cookies) -> &'static str {
mod tests {
use super::*;
use rocket::local::Client;
use rocket::local::blocking::Client;
use rocket::fairing::AdHoc;
#[test]
@ -27,9 +27,9 @@ mod tests {
let rocket = rocket::ignite()
.mount("/", routes![index])
.register(catchers![not_found])
.attach(AdHoc::on_request("Add Fairing Cookie", |req, _| {
.attach(AdHoc::on_request("Add Fairing Cookie", |req, _| Box::pin(async move {
req.cookies().add(Cookie::new("fairing", "hi"));
}));
})));
let client = Client::new(rocket).unwrap();

View File

@ -31,7 +31,9 @@ mod fairing_before_head_strip {
let rocket = rocket::ignite()
.mount("/", routes![head])
.attach(AdHoc::on_request("Check HEAD", |req, _| {
assert_eq!(req.method(), Method::Head);
Box::pin(async move {
assert_eq!(req.method(), Method::Head);
})
}))
.attach(AdHoc::on_response("Check HEAD 2", |req, res| {
Box::pin(async move {
@ -56,11 +58,13 @@ mod fairing_before_head_strip {
.mount("/", routes![auto])
.manage(counter)
.attach(AdHoc::on_request("Check HEAD + Count", |req, _| {
assert_eq!(req.method(), Method::Head);
Box::pin(async move {
assert_eq!(req.method(), Method::Head);
// This should be called exactly once.
let c = req.guard::<State<Counter>>().unwrap();
assert_eq!(c.0.fetch_add(1, Ordering::SeqCst), 0);
// This should be called exactly once.
let c = req.guard::<State<Counter>>().unwrap();
assert_eq!(c.0.fetch_add(1, Ordering::SeqCst), 0);
})
}))
.attach(AdHoc::on_response("Check GET", |req, res| {
Box::pin(async move {

View File

@ -29,10 +29,12 @@ fn rocket() -> rocket::Rocket {
counter.attach.fetch_add(1, Ordering::Relaxed);
let rocket = rocket.manage(counter)
.attach(AdHoc::on_request("Inner", |req, _| {
if req.method() == Method::Get {
let counter = req.guard::<State<'_, Counter>>().unwrap();
counter.get.fetch_add(1, Ordering::Release);
}
Box::pin(async move {
if req.method() == Method::Get {
let counter = req.guard::<State<'_, Counter>>().unwrap();
counter.get.fetch_add(1, Ordering::Release);
}
})
}));
Ok(rocket)

View File

@ -27,12 +27,16 @@ impl Fairing for Counter {
}
}
fn on_request(&self, request: &mut Request<'_>, _: &Data) {
if request.method() == Method::Get {
self.get.fetch_add(1, Ordering::Relaxed);
} else if request.method() == Method::Post {
self.post.fetch_add(1, Ordering::Relaxed);
}
fn on_request<'a>(&'a self, request: &'a mut Request<'_>, _: &'a Data)
-> std::pin::Pin<Box<dyn std::future::Future<Output=()> + Send + 'a>>
{
Box::pin(async move {
if request.method() == Method::Get {
self.get.fetch_add(1, Ordering::Relaxed);
} else if request.method() == Method::Post {
self.post.fetch_add(1, Ordering::Relaxed);
}
})
}
fn on_response<'a, 'r>(&'a self, request: &'a Request<'r>, response: &'a mut Response<'r>)
@ -79,11 +83,13 @@ fn rocket() -> rocket::Rocket {
println!("Rocket is about to launch!");
}))
.attach(AdHoc::on_request("PUT Rewriter", |req, _| {
println!(" => Incoming request: {}", req);
if req.uri().path() == "/" {
println!(" => Changing method to `PUT`.");
req.set_method(Method::Put);
}
Box::pin(async move {
println!(" => Incoming request: {}", req);
if req.uri().path() == "/" {
println!(" => Changing method to `PUT`.");
req.set_method(Method::Put);
}
})
}))
.attach(AdHoc::on_response("Response Rewriter", |req, res| {
Box::pin(async move {