Allow implementations of on_request fairings to return a Future that borrows from self, request, and data.

This commit is contained in:
Jeb Rosen 2019-12-10 16:34:32 -08:00 committed by Sergio Benitez
parent cc3298c3e4
commit 4bb4c61528
10 changed files with 79 additions and 51 deletions

View File

@ -163,10 +163,14 @@ impl Fairing for TemplateFairing {
} }
#[cfg(debug_assertions)] #[cfg(debug_assertions)]
fn on_request(&self, req: &mut rocket::Request<'_>, _data: &rocket::Data) { fn on_request<'a>(&'a self, req: &'a mut rocket::Request<'_>, _data: &'a rocket::Data)
let cm = req.guard::<rocket::State<'_, ContextManager>>() -> std::pin::Pin<Box<dyn std::future::Future<Output=()> + Send + 'a>>
.expect("Template ContextManager registered in on_attach"); {
Box::pin(async move {
let cm = req.guard::<rocket::State<'_, ContextManager>>()
.expect("Template ContextManager registered in on_attach");
cm.reload_if_needed(&*self.custom_callback); cm.reload_if_needed(&*self.custom_callback);
})
} }
} }

View File

@ -47,7 +47,7 @@ const PEEK_BYTES: usize = 512;
pub struct Data { pub struct Data {
buffer: Vec<u8>, buffer: Vec<u8>,
is_complete: bool, is_complete: bool,
stream: Box<dyn AsyncRead + Unpin + Send>, stream: Box<dyn AsyncRead + Unpin + Send + Sync>,
} }
impl Data { impl Data {

View File

@ -34,7 +34,9 @@ use crate::fairing::{Fairing, Kind, Info};
/// println!("Rocket is about to launch! Exciting! Here we go..."); /// println!("Rocket is about to launch! Exciting! Here we go...");
/// })) /// }))
/// .attach(AdHoc::on_request("Put Rewriter", |req, _| { /// .attach(AdHoc::on_request("Put Rewriter", |req, _| {
/// req.set_method(Method::Put); /// Box::pin(async move {
/// req.set_method(Method::Put);
/// })
/// })); /// }));
/// ``` /// ```
pub struct AdHoc { pub struct AdHoc {
@ -48,7 +50,7 @@ enum AdHocKind {
/// An ad-hoc **launch** fairing. Called just before Rocket launches. /// An ad-hoc **launch** fairing. Called just before Rocket launches.
Launch(Mutex<Option<Box<dyn FnOnce(&Rocket) + Send + 'static>>>), Launch(Mutex<Option<Box<dyn FnOnce(&Rocket) + Send + 'static>>>),
/// An ad-hoc **request** fairing. Called when a request is received. /// An ad-hoc **request** fairing. Called when a request is received.
Request(Box<dyn Fn(&mut Request<'_>, &Data) + Send + Sync + 'static>), Request(Box<dyn for<'a> Fn(&'a mut Request<'_>, &'a Data) -> BoxFuture<'a, ()> + Send + Sync + 'static>),
/// An ad-hoc **response** fairing. Called when a response is ready to be /// An ad-hoc **response** fairing. Called when a response is ready to be
/// sent to a client. /// sent to a client.
Response(Box<dyn for<'a, 'r> Fn(&'a Request<'r>, &'a mut Response<'r>) -> BoxFuture<'a, ()> + Send + Sync + 'static>), Response(Box<dyn for<'a, 'r> Fn(&'a Request<'r>, &'a mut Response<'r>) -> BoxFuture<'a, ()> + Send + Sync + 'static>),
@ -101,12 +103,14 @@ impl AdHoc {
/// ///
/// // The no-op request fairing. /// // The no-op request fairing.
/// let fairing = AdHoc::on_request("Dummy", |req, data| { /// let fairing = AdHoc::on_request("Dummy", |req, data| {
/// // do something with the request and data... /// Box::pin(async move {
/// # let (_, _) = (req, data); /// // do something with the request and data...
/// # let (_, _) = (req, data);
/// })
/// }); /// });
/// ``` /// ```
pub fn on_request<F>(name: &'static str, f: F) -> AdHoc pub fn on_request<F>(name: &'static str, f: F) -> AdHoc
where F: Fn(&mut Request<'_>, &Data) + Send + Sync + 'static where F: for<'a> Fn(&'a mut Request<'_>, &'a Data) -> BoxFuture<'a, ()> + Send + Sync + 'static
{ {
AdHoc { name, kind: AdHocKind::Request(Box::new(f)) } AdHoc { name, kind: AdHocKind::Request(Box::new(f)) }
} }
@ -164,9 +168,11 @@ impl Fairing for AdHoc {
} }
} }
fn on_request(&self, request: &mut Request<'_>, data: &Data) { fn on_request<'a>(&'a self, request: &'a mut Request<'_>, data: &'a Data) -> BoxFuture<'a, ()> {
if let AdHocKind::Request(ref callback) = self.kind { if let AdHocKind::Request(ref callback) = self.kind {
callback(request, data) callback(request, data)
} else {
Box::pin(async { })
} }
} }

View File

@ -59,9 +59,9 @@ impl Fairings {
} }
#[inline(always)] #[inline(always)]
pub fn handle_request(&self, req: &mut Request<'_>, data: &Data) { pub async fn handle_request(&self, req: &mut Request<'_>, data: &Data) {
for &i in &self.request { for &i in &self.request {
self.all_fairings[i].on_request(req, data); self.all_fairings[i].on_request(req, data).await;
} }
} }

View File

@ -21,7 +21,7 @@
//! //!
//! ```rust //! ```rust
//! # use rocket::fairing::AdHoc; //! # use rocket::fairing::AdHoc;
//! # let req_fairing = AdHoc::on_request("Request", |_, _| ()); //! # let req_fairing = AdHoc::on_request("Request", |_, _| Box::pin(async move {}));
//! # let res_fairing = AdHoc::on_response("Response", |_, _| Box::pin(async move {})); //! # let res_fairing = AdHoc::on_response("Response", |_, _| Box::pin(async move {}));
//! let rocket = rocket::ignite() //! let rocket = rocket::ignite()
//! .attach(req_fairing) //! .attach(req_fairing)
@ -228,12 +228,14 @@ pub use self::info_kind::{Info, Kind};
/// } /// }
/// } /// }
/// ///
/// fn on_request(&self, request: &mut Request, _: &Data) { /// fn on_request<'a>(&'a self, request: &'a mut Request, _: &'a Data) -> Pin<Box<dyn Future<Output=()> + Send + 'a>> {
/// if request.method() == Method::Get { /// Box::pin(async move {
/// self.get.fetch_add(1, Ordering::Relaxed); /// if request.method() == Method::Get {
/// } else if request.method() == Method::Post { /// self.get.fetch_add(1, Ordering::Relaxed);
/// self.post.fetch_add(1, Ordering::Relaxed); /// } else if request.method() == Method::Post {
/// } /// self.post.fetch_add(1, Ordering::Relaxed);
/// }
/// })
/// } /// }
/// ///
/// fn on_response<'a, 'r>(&'a self, request: &'a Request<'r>, response: &'a mut Response<'r>) -> Pin<Box<dyn Future<Output=()> + Send + 'a>> { /// fn on_response<'a, 'r>(&'a self, request: &'a Request<'r>, response: &'a mut Response<'r>) -> Pin<Box<dyn Future<Output=()> + Send + 'a>> {
@ -293,11 +295,13 @@ pub use self::info_kind::{Info, Kind};
/// } /// }
/// ///
/// /// Stores the start time of the request in request-local state. /// /// Stores the start time of the request in request-local state.
/// fn on_request(&self, request: &mut Request, _: &Data) { /// fn on_request<'a>(&'a self, request: &'a mut Request, _: &'a Data) -> Pin<Box<dyn Future<Output=()> + Send + 'a>> {
/// // Store a `TimerStart` instead of directly storing a `SystemTime` /// Box::pin(async move {
/// // to ensure that this usage doesn't conflict with anything else /// // Store a `TimerStart` instead of directly storing a `SystemTime`
/// // that might store a `SystemTime` in request-local cache. /// // to ensure that this usage doesn't conflict with anything else
/// request.local_cache(|| TimerStart(Some(SystemTime::now()))); /// // that might store a `SystemTime` in request-local cache.
/// request.local_cache(|| TimerStart(Some(SystemTime::now())));
/// })
/// } /// }
/// ///
/// /// Adds a header to the response indicating how long the server took to /// /// Adds a header to the response indicating how long the server took to
@ -405,7 +409,9 @@ pub trait Fairing: Send + Sync + 'static {
/// ///
/// The default implementation of this method does nothing. /// The default implementation of this method does nothing.
#[allow(unused_variables)] #[allow(unused_variables)]
fn on_request(&self, request: &mut Request<'_>, data: &Data) {} fn on_request<'a>(&'a self, request: &'a mut Request<'_>, data: &'a Data) -> BoxFuture<'a, ()> {
Box::pin(async { })
}
/// The response callback. /// The response callback.
/// ///
@ -440,7 +446,7 @@ impl<T: Fairing> Fairing for std::sync::Arc<T> {
} }
#[inline] #[inline]
fn on_request(&self, request: &mut Request<'_>, data: &Data) { fn on_request<'a>(&'a self, request: &'a mut Request<'_>, data: &'a Data) -> BoxFuture<'a, ()> {
(self as &T).on_request(request, data) (self as &T).on_request(request, data)
} }

View File

@ -208,7 +208,7 @@ impl Rocket {
self.preprocess_request(request, &data); self.preprocess_request(request, &data);
// Run the request fairings. // Run the request fairings.
self.fairings.handle_request(request, &data); self.fairings.handle_request(request, &data).await;
// Remember if the request is a `HEAD` request for later body stripping. // Remember if the request is a `HEAD` request for later body stripping.
let was_head_request = request.method() == Method::Head; let was_head_request = request.method() == Method::Head;

View File

@ -19,7 +19,7 @@ fn index(mut cookies: Cookies) -> &'static str {
mod tests { mod tests {
use super::*; use super::*;
use rocket::local::Client; use rocket::local::blocking::Client;
use rocket::fairing::AdHoc; use rocket::fairing::AdHoc;
#[test] #[test]
@ -27,9 +27,9 @@ mod tests {
let rocket = rocket::ignite() let rocket = rocket::ignite()
.mount("/", routes![index]) .mount("/", routes![index])
.register(catchers![not_found]) .register(catchers![not_found])
.attach(AdHoc::on_request("Add Fairing Cookie", |req, _| { .attach(AdHoc::on_request("Add Fairing Cookie", |req, _| Box::pin(async move {
req.cookies().add(Cookie::new("fairing", "hi")); req.cookies().add(Cookie::new("fairing", "hi"));
})); })));
let client = Client::new(rocket).unwrap(); let client = Client::new(rocket).unwrap();

View File

@ -31,7 +31,9 @@ mod fairing_before_head_strip {
let rocket = rocket::ignite() let rocket = rocket::ignite()
.mount("/", routes![head]) .mount("/", routes![head])
.attach(AdHoc::on_request("Check HEAD", |req, _| { .attach(AdHoc::on_request("Check HEAD", |req, _| {
assert_eq!(req.method(), Method::Head); Box::pin(async move {
assert_eq!(req.method(), Method::Head);
})
})) }))
.attach(AdHoc::on_response("Check HEAD 2", |req, res| { .attach(AdHoc::on_response("Check HEAD 2", |req, res| {
Box::pin(async move { Box::pin(async move {
@ -56,11 +58,13 @@ mod fairing_before_head_strip {
.mount("/", routes![auto]) .mount("/", routes![auto])
.manage(counter) .manage(counter)
.attach(AdHoc::on_request("Check HEAD + Count", |req, _| { .attach(AdHoc::on_request("Check HEAD + Count", |req, _| {
assert_eq!(req.method(), Method::Head); Box::pin(async move {
assert_eq!(req.method(), Method::Head);
// This should be called exactly once. // This should be called exactly once.
let c = req.guard::<State<Counter>>().unwrap(); let c = req.guard::<State<Counter>>().unwrap();
assert_eq!(c.0.fetch_add(1, Ordering::SeqCst), 0); assert_eq!(c.0.fetch_add(1, Ordering::SeqCst), 0);
})
})) }))
.attach(AdHoc::on_response("Check GET", |req, res| { .attach(AdHoc::on_response("Check GET", |req, res| {
Box::pin(async move { Box::pin(async move {

View File

@ -29,10 +29,12 @@ fn rocket() -> rocket::Rocket {
counter.attach.fetch_add(1, Ordering::Relaxed); counter.attach.fetch_add(1, Ordering::Relaxed);
let rocket = rocket.manage(counter) let rocket = rocket.manage(counter)
.attach(AdHoc::on_request("Inner", |req, _| { .attach(AdHoc::on_request("Inner", |req, _| {
if req.method() == Method::Get { Box::pin(async move {
let counter = req.guard::<State<'_, Counter>>().unwrap(); if req.method() == Method::Get {
counter.get.fetch_add(1, Ordering::Release); let counter = req.guard::<State<'_, Counter>>().unwrap();
} counter.get.fetch_add(1, Ordering::Release);
}
})
})); }));
Ok(rocket) Ok(rocket)

View File

@ -27,12 +27,16 @@ impl Fairing for Counter {
} }
} }
fn on_request(&self, request: &mut Request<'_>, _: &Data) { fn on_request<'a>(&'a self, request: &'a mut Request<'_>, _: &'a Data)
if request.method() == Method::Get { -> std::pin::Pin<Box<dyn std::future::Future<Output=()> + Send + 'a>>
self.get.fetch_add(1, Ordering::Relaxed); {
} else if request.method() == Method::Post { Box::pin(async move {
self.post.fetch_add(1, Ordering::Relaxed); if request.method() == Method::Get {
} self.get.fetch_add(1, Ordering::Relaxed);
} else if request.method() == Method::Post {
self.post.fetch_add(1, Ordering::Relaxed);
}
})
} }
fn on_response<'a, 'r>(&'a self, request: &'a Request<'r>, response: &'a mut Response<'r>) fn on_response<'a, 'r>(&'a self, request: &'a Request<'r>, response: &'a mut Response<'r>)
@ -79,11 +83,13 @@ fn rocket() -> rocket::Rocket {
println!("Rocket is about to launch!"); println!("Rocket is about to launch!");
})) }))
.attach(AdHoc::on_request("PUT Rewriter", |req, _| { .attach(AdHoc::on_request("PUT Rewriter", |req, _| {
println!(" => Incoming request: {}", req); Box::pin(async move {
if req.uri().path() == "/" { println!(" => Incoming request: {}", req);
println!(" => Changing method to `PUT`."); if req.uri().path() == "/" {
req.set_method(Method::Put); println!(" => Changing method to `PUT`.");
} req.set_method(Method::Put);
}
})
})) }))
.attach(AdHoc::on_response("Response Rewriter", |req, res| { .attach(AdHoc::on_response("Response Rewriter", |req, res| {
Box::pin(async move { Box::pin(async move {