diff --git a/core/lib/src/request/request.rs b/core/lib/src/request/request.rs index e1c73f73..bb3a0b03 100644 --- a/core/lib/src/request/request.rs +++ b/core/lib/src/request/request.rs @@ -1,5 +1,6 @@ use std::sync::{Arc, RwLock, Mutex}; use std::net::{IpAddr, SocketAddr}; +use std::future::Future; use std::fmt; use std::str; @@ -564,6 +565,39 @@ impl<'r> Request<'r> { }) } + /// Retrieves the cached value for type `T` from the request-local cached + /// state of `self`. If no such value has previously been cached for this + /// request, `fut` is `await`ed to produce the value which is subsequently + /// returned. + /// + /// # Example + /// + /// ```rust + /// # use rocket::http::Method; + /// # use rocket::Request; + /// # type User = (); + /// async fn current_user<'r>(request: &Request<'r>) -> User { + /// // Validate request for a given user, load from database, etc. + /// } + /// + /// # Request::example(Method::Get, "/uri", |request| rocket::async_test(async { + /// let user = request.local_cache_async(async { + /// current_user(request).await + /// }).await; + /// # })); + pub async fn local_cache_async<'a, T, F>(&'a self, fut: F) -> &'a T + where F: Future, + T: Send + Sync + 'static + { + match self.state.cache.try_get() { + Some(s) => s, + None => { + self.state.cache.set(fut.await); + self.state.cache.get() + } + } + } + /// Retrieves and parses into `T` the 0-indexed `n`th segment from the /// request. Returns `None` if `n` is greater than the number of segments. /// Returns `Some(Err(T::Error))` if the parameter type `T` failed to be diff --git a/examples/request_local_state/src/main.rs b/examples/request_local_state/src/main.rs index 93ee8bd2..936d960f 100644 --- a/examples/request_local_state/src/main.rs +++ b/examples/request_local_state/src/main.rs @@ -4,8 +4,8 @@ use std::sync::atomic::{AtomicUsize, Ordering}; -use rocket::request::{self, Request, FromRequest, State}; use rocket::outcome::Outcome::*; +use rocket::request::{self, FromRequest, FromRequestAsync, FromRequestFuture, Request, State}; #[cfg(test)] mod tests; @@ -17,6 +17,8 @@ struct Atomics { struct Guard1; struct Guard2; +struct Guard3; +struct Guard4; impl<'a, 'r> FromRequest<'a, 'r> for Guard1 { type Error = (); @@ -39,15 +41,51 @@ impl<'a, 'r> FromRequest<'a, 'r> for Guard2 { } } -#[get("/")] -fn index(_g1: Guard1, _g2: Guard2) { +impl<'a, 'r> FromRequestAsync<'a, 'r> for Guard3 { + type Error = (); + + fn from_request<'fut>(req: &'a Request<'r>) -> FromRequestFuture<'fut, Self, ()> + where 'a: 'fut + { + Box::pin(async move { + let atomics = try_outcome!(req.guard::>()); + atomics.uncached.fetch_add(1, Ordering::Relaxed); + req.local_cache_async(async { + atomics.cached.fetch_add(1, Ordering::Relaxed) + }).await; + + Success(Guard3) + }) + } +} + +impl<'a, 'r> FromRequestAsync<'a, 'r> for Guard4 { + type Error = (); + + fn from_request<'fut>(req: &'a Request<'r>) -> FromRequestFuture<'fut, Self, ()> + where 'a: 'fut + { + Box::pin(async move { + try_outcome!(Guard3::from_request(req).await); + Success(Guard4) + }) + } +} + +#[get("/sync")] +fn r_sync(_g1: Guard1, _g2: Guard2) { + // This exists only to run the request guards. +} + +#[get("/async")] +async fn r_async(_g1: Guard3, _g2: Guard4) { // This exists only to run the request guards. } fn rocket() -> rocket::Rocket { rocket::ignite() .manage(Atomics::default()) - .mount("/", routes!(index)) + .mount("/", routes![r_sync, r_async]) } fn main() { diff --git a/examples/request_local_state/src/tests.rs b/examples/request_local_state/src/tests.rs index 2e70dea9..e5b7046d 100644 --- a/examples/request_local_state/src/tests.rs +++ b/examples/request_local_state/src/tests.rs @@ -6,9 +6,15 @@ use rocket::local::Client; #[rocket::async_test] async fn test() { let client = Client::new(rocket()).unwrap(); - client.get("/").dispatch().await; + client.get("/sync").dispatch().await; let atomics = client.rocket().state::().unwrap(); assert_eq!(atomics.uncached.load(Ordering::Relaxed), 2); assert_eq!(atomics.cached.load(Ordering::Relaxed), 1); + + client.get("/async").dispatch().await; + + let atomics = client.rocket().state::().unwrap(); + assert_eq!(atomics.uncached.load(Ordering::Relaxed), 4); + assert_eq!(atomics.cached.load(Ordering::Relaxed), 2); }