From 4bb4c6152856db98b93d28e7483755a1b482cb6e Mon Sep 17 00:00:00 2001 From: Jeb Rosen Date: Tue, 10 Dec 2019 16:34:32 -0800 Subject: [PATCH] Allow implementations of on_request fairings to return a Future that borrows from self, request, and data. --- contrib/lib/src/templates/fairing.rs | 12 ++++--- core/lib/src/data/data.rs | 2 +- core/lib/src/fairing/ad_hoc.rs | 18 ++++++---- core/lib/src/fairing/fairings.rs | 4 +-- core/lib/src/fairing/mod.rs | 34 +++++++++++-------- core/lib/src/rocket.rs | 2 +- core/lib/tests/catcher-cookies-1213.rs | 6 ++-- .../fairing_before_head_strip-issue-546.rs | 14 +++++--- core/lib/tests/nested-fairing-attaches.rs | 10 +++--- examples/fairings/src/main.rs | 28 +++++++++------ 10 files changed, 79 insertions(+), 51 deletions(-) diff --git a/contrib/lib/src/templates/fairing.rs b/contrib/lib/src/templates/fairing.rs index 007776e1..e47f874d 100644 --- a/contrib/lib/src/templates/fairing.rs +++ b/contrib/lib/src/templates/fairing.rs @@ -163,10 +163,14 @@ impl Fairing for TemplateFairing { } #[cfg(debug_assertions)] - fn on_request(&self, req: &mut rocket::Request<'_>, _data: &rocket::Data) { - let cm = req.guard::>() - .expect("Template ContextManager registered in on_attach"); + fn on_request<'a>(&'a self, req: &'a mut rocket::Request<'_>, _data: &'a rocket::Data) + -> std::pin::Pin + Send + 'a>> + { + Box::pin(async move { + let cm = req.guard::>() + .expect("Template ContextManager registered in on_attach"); - cm.reload_if_needed(&*self.custom_callback); + cm.reload_if_needed(&*self.custom_callback); + }) } } diff --git a/core/lib/src/data/data.rs b/core/lib/src/data/data.rs index eeeebbaf..00a5ff78 100644 --- a/core/lib/src/data/data.rs +++ b/core/lib/src/data/data.rs @@ -47,7 +47,7 @@ const PEEK_BYTES: usize = 512; pub struct Data { buffer: Vec, is_complete: bool, - stream: Box, + stream: Box, } impl Data { diff --git a/core/lib/src/fairing/ad_hoc.rs b/core/lib/src/fairing/ad_hoc.rs index 493dad81..b1ae3472 100644 --- a/core/lib/src/fairing/ad_hoc.rs +++ b/core/lib/src/fairing/ad_hoc.rs @@ -34,7 +34,9 @@ use crate::fairing::{Fairing, Kind, Info}; /// println!("Rocket is about to launch! Exciting! Here we go..."); /// })) /// .attach(AdHoc::on_request("Put Rewriter", |req, _| { -/// req.set_method(Method::Put); +/// Box::pin(async move { +/// req.set_method(Method::Put); +/// }) /// })); /// ``` pub struct AdHoc { @@ -48,7 +50,7 @@ enum AdHocKind { /// An ad-hoc **launch** fairing. Called just before Rocket launches. Launch(Mutex>>), /// An ad-hoc **request** fairing. Called when a request is received. - Request(Box, &Data) + Send + Sync + 'static>), + Request(Box Fn(&'a mut Request<'_>, &'a Data) -> BoxFuture<'a, ()> + Send + Sync + 'static>), /// An ad-hoc **response** fairing. Called when a response is ready to be /// sent to a client. Response(Box Fn(&'a Request<'r>, &'a mut Response<'r>) -> BoxFuture<'a, ()> + Send + Sync + 'static>), @@ -101,12 +103,14 @@ impl AdHoc { /// /// // The no-op request fairing. /// let fairing = AdHoc::on_request("Dummy", |req, data| { - /// // do something with the request and data... - /// # let (_, _) = (req, data); + /// Box::pin(async move { + /// // do something with the request and data... + /// # let (_, _) = (req, data); + /// }) /// }); /// ``` pub fn on_request(name: &'static str, f: F) -> AdHoc - where F: Fn(&mut Request<'_>, &Data) + Send + Sync + 'static + where F: for<'a> Fn(&'a mut Request<'_>, &'a Data) -> BoxFuture<'a, ()> + Send + Sync + 'static { AdHoc { name, kind: AdHocKind::Request(Box::new(f)) } } @@ -164,9 +168,11 @@ impl Fairing for AdHoc { } } - fn on_request(&self, request: &mut Request<'_>, data: &Data) { + fn on_request<'a>(&'a self, request: &'a mut Request<'_>, data: &'a Data) -> BoxFuture<'a, ()> { if let AdHocKind::Request(ref callback) = self.kind { callback(request, data) + } else { + Box::pin(async { }) } } diff --git a/core/lib/src/fairing/fairings.rs b/core/lib/src/fairing/fairings.rs index 5e3ae720..3444c357 100644 --- a/core/lib/src/fairing/fairings.rs +++ b/core/lib/src/fairing/fairings.rs @@ -59,9 +59,9 @@ impl Fairings { } #[inline(always)] - pub fn handle_request(&self, req: &mut Request<'_>, data: &Data) { + pub async fn handle_request(&self, req: &mut Request<'_>, data: &Data) { for &i in &self.request { - self.all_fairings[i].on_request(req, data); + self.all_fairings[i].on_request(req, data).await; } } diff --git a/core/lib/src/fairing/mod.rs b/core/lib/src/fairing/mod.rs index 3f958f01..2b3e2510 100644 --- a/core/lib/src/fairing/mod.rs +++ b/core/lib/src/fairing/mod.rs @@ -21,7 +21,7 @@ //! //! ```rust //! # use rocket::fairing::AdHoc; -//! # let req_fairing = AdHoc::on_request("Request", |_, _| ()); +//! # let req_fairing = AdHoc::on_request("Request", |_, _| Box::pin(async move {})); //! # let res_fairing = AdHoc::on_response("Response", |_, _| Box::pin(async move {})); //! let rocket = rocket::ignite() //! .attach(req_fairing) @@ -228,12 +228,14 @@ pub use self::info_kind::{Info, Kind}; /// } /// } /// -/// fn on_request(&self, request: &mut Request, _: &Data) { -/// if request.method() == Method::Get { -/// self.get.fetch_add(1, Ordering::Relaxed); -/// } else if request.method() == Method::Post { -/// self.post.fetch_add(1, Ordering::Relaxed); -/// } +/// fn on_request<'a>(&'a self, request: &'a mut Request, _: &'a Data) -> Pin + Send + 'a>> { +/// Box::pin(async move { +/// if request.method() == Method::Get { +/// self.get.fetch_add(1, Ordering::Relaxed); +/// } else if request.method() == Method::Post { +/// self.post.fetch_add(1, Ordering::Relaxed); +/// } +/// }) /// } /// /// fn on_response<'a, 'r>(&'a self, request: &'a Request<'r>, response: &'a mut Response<'r>) -> Pin + Send + 'a>> { @@ -293,11 +295,13 @@ pub use self::info_kind::{Info, Kind}; /// } /// /// /// Stores the start time of the request in request-local state. -/// fn on_request(&self, request: &mut Request, _: &Data) { -/// // Store a `TimerStart` instead of directly storing a `SystemTime` -/// // to ensure that this usage doesn't conflict with anything else -/// // that might store a `SystemTime` in request-local cache. -/// request.local_cache(|| TimerStart(Some(SystemTime::now()))); +/// fn on_request<'a>(&'a self, request: &'a mut Request, _: &'a Data) -> Pin + Send + 'a>> { +/// Box::pin(async move { +/// // Store a `TimerStart` instead of directly storing a `SystemTime` +/// // to ensure that this usage doesn't conflict with anything else +/// // that might store a `SystemTime` in request-local cache. +/// request.local_cache(|| TimerStart(Some(SystemTime::now()))); +/// }) /// } /// /// /// Adds a header to the response indicating how long the server took to @@ -405,7 +409,9 @@ pub trait Fairing: Send + Sync + 'static { /// /// The default implementation of this method does nothing. #[allow(unused_variables)] - fn on_request(&self, request: &mut Request<'_>, data: &Data) {} + fn on_request<'a>(&'a self, request: &'a mut Request<'_>, data: &'a Data) -> BoxFuture<'a, ()> { + Box::pin(async { }) + } /// The response callback. /// @@ -440,7 +446,7 @@ impl Fairing for std::sync::Arc { } #[inline] - fn on_request(&self, request: &mut Request<'_>, data: &Data) { + fn on_request<'a>(&'a self, request: &'a mut Request<'_>, data: &'a Data) -> BoxFuture<'a, ()> { (self as &T).on_request(request, data) } diff --git a/core/lib/src/rocket.rs b/core/lib/src/rocket.rs index 581384b1..dd9203a1 100644 --- a/core/lib/src/rocket.rs +++ b/core/lib/src/rocket.rs @@ -208,7 +208,7 @@ impl Rocket { self.preprocess_request(request, &data); // Run the request fairings. - self.fairings.handle_request(request, &data); + self.fairings.handle_request(request, &data).await; // Remember if the request is a `HEAD` request for later body stripping. let was_head_request = request.method() == Method::Head; diff --git a/core/lib/tests/catcher-cookies-1213.rs b/core/lib/tests/catcher-cookies-1213.rs index f9f66ce3..e5fe28e2 100644 --- a/core/lib/tests/catcher-cookies-1213.rs +++ b/core/lib/tests/catcher-cookies-1213.rs @@ -19,7 +19,7 @@ fn index(mut cookies: Cookies) -> &'static str { mod tests { use super::*; - use rocket::local::Client; + use rocket::local::blocking::Client; use rocket::fairing::AdHoc; #[test] @@ -27,9 +27,9 @@ mod tests { let rocket = rocket::ignite() .mount("/", routes![index]) .register(catchers![not_found]) - .attach(AdHoc::on_request("Add Fairing Cookie", |req, _| { + .attach(AdHoc::on_request("Add Fairing Cookie", |req, _| Box::pin(async move { req.cookies().add(Cookie::new("fairing", "hi")); - })); + }))); let client = Client::new(rocket).unwrap(); diff --git a/core/lib/tests/fairing_before_head_strip-issue-546.rs b/core/lib/tests/fairing_before_head_strip-issue-546.rs index 7a2170b8..8f297f53 100644 --- a/core/lib/tests/fairing_before_head_strip-issue-546.rs +++ b/core/lib/tests/fairing_before_head_strip-issue-546.rs @@ -31,7 +31,9 @@ mod fairing_before_head_strip { let rocket = rocket::ignite() .mount("/", routes![head]) .attach(AdHoc::on_request("Check HEAD", |req, _| { - assert_eq!(req.method(), Method::Head); + Box::pin(async move { + assert_eq!(req.method(), Method::Head); + }) })) .attach(AdHoc::on_response("Check HEAD 2", |req, res| { Box::pin(async move { @@ -56,11 +58,13 @@ mod fairing_before_head_strip { .mount("/", routes![auto]) .manage(counter) .attach(AdHoc::on_request("Check HEAD + Count", |req, _| { - assert_eq!(req.method(), Method::Head); + Box::pin(async move { + assert_eq!(req.method(), Method::Head); - // This should be called exactly once. - let c = req.guard::>().unwrap(); - assert_eq!(c.0.fetch_add(1, Ordering::SeqCst), 0); + // This should be called exactly once. + let c = req.guard::>().unwrap(); + assert_eq!(c.0.fetch_add(1, Ordering::SeqCst), 0); + }) })) .attach(AdHoc::on_response("Check GET", |req, res| { Box::pin(async move { diff --git a/core/lib/tests/nested-fairing-attaches.rs b/core/lib/tests/nested-fairing-attaches.rs index 456b3a9b..d9d5201f 100644 --- a/core/lib/tests/nested-fairing-attaches.rs +++ b/core/lib/tests/nested-fairing-attaches.rs @@ -29,10 +29,12 @@ fn rocket() -> rocket::Rocket { counter.attach.fetch_add(1, Ordering::Relaxed); let rocket = rocket.manage(counter) .attach(AdHoc::on_request("Inner", |req, _| { - if req.method() == Method::Get { - let counter = req.guard::>().unwrap(); - counter.get.fetch_add(1, Ordering::Release); - } + Box::pin(async move { + if req.method() == Method::Get { + let counter = req.guard::>().unwrap(); + counter.get.fetch_add(1, Ordering::Release); + } + }) })); Ok(rocket) diff --git a/examples/fairings/src/main.rs b/examples/fairings/src/main.rs index e5cc0f37..a4cc0513 100644 --- a/examples/fairings/src/main.rs +++ b/examples/fairings/src/main.rs @@ -27,12 +27,16 @@ impl Fairing for Counter { } } - fn on_request(&self, request: &mut Request<'_>, _: &Data) { - if request.method() == Method::Get { - self.get.fetch_add(1, Ordering::Relaxed); - } else if request.method() == Method::Post { - self.post.fetch_add(1, Ordering::Relaxed); - } + fn on_request<'a>(&'a self, request: &'a mut Request<'_>, _: &'a Data) + -> std::pin::Pin + Send + 'a>> + { + Box::pin(async move { + if request.method() == Method::Get { + self.get.fetch_add(1, Ordering::Relaxed); + } else if request.method() == Method::Post { + self.post.fetch_add(1, Ordering::Relaxed); + } + }) } fn on_response<'a, 'r>(&'a self, request: &'a Request<'r>, response: &'a mut Response<'r>) @@ -79,11 +83,13 @@ fn rocket() -> rocket::Rocket { println!("Rocket is about to launch!"); })) .attach(AdHoc::on_request("PUT Rewriter", |req, _| { - println!(" => Incoming request: {}", req); - if req.uri().path() == "/" { - println!(" => Changing method to `PUT`."); - req.set_method(Method::Put); - } + Box::pin(async move { + println!(" => Incoming request: {}", req); + if req.uri().path() == "/" { + println!(" => Changing method to `PUT`."); + req.set_method(Method::Put); + } + }) })) .attach(AdHoc::on_response("Response Rewriter", |req, res| { Box::pin(async move {