mirror of https://github.com/rwf2/Rocket.git
Add 'Request::local_cache_async' for use in async request guards.
This commit is contained in:
parent
5317664893
commit
189fd65b17
|
@ -1,5 +1,6 @@
|
||||||
use std::sync::{Arc, RwLock, Mutex};
|
use std::sync::{Arc, RwLock, Mutex};
|
||||||
use std::net::{IpAddr, SocketAddr};
|
use std::net::{IpAddr, SocketAddr};
|
||||||
|
use std::future::Future;
|
||||||
use std::fmt;
|
use std::fmt;
|
||||||
use std::str;
|
use std::str;
|
||||||
|
|
||||||
|
@ -564,6 +565,39 @@ impl<'r> Request<'r> {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Retrieves the cached value for type `T` from the request-local cached
|
||||||
|
/// state of `self`. If no such value has previously been cached for this
|
||||||
|
/// request, `fut` is `await`ed to produce the value which is subsequently
|
||||||
|
/// returned.
|
||||||
|
///
|
||||||
|
/// # Example
|
||||||
|
///
|
||||||
|
/// ```rust
|
||||||
|
/// # use rocket::http::Method;
|
||||||
|
/// # use rocket::Request;
|
||||||
|
/// # type User = ();
|
||||||
|
/// async fn current_user<'r>(request: &Request<'r>) -> User {
|
||||||
|
/// // Validate request for a given user, load from database, etc.
|
||||||
|
/// }
|
||||||
|
///
|
||||||
|
/// # Request::example(Method::Get, "/uri", |request| rocket::async_test(async {
|
||||||
|
/// let user = request.local_cache_async(async {
|
||||||
|
/// current_user(request).await
|
||||||
|
/// }).await;
|
||||||
|
/// # }));
|
||||||
|
pub async fn local_cache_async<'a, T, F>(&'a self, fut: F) -> &'a T
|
||||||
|
where F: Future<Output = T>,
|
||||||
|
T: Send + Sync + 'static
|
||||||
|
{
|
||||||
|
match self.state.cache.try_get() {
|
||||||
|
Some(s) => s,
|
||||||
|
None => {
|
||||||
|
self.state.cache.set(fut.await);
|
||||||
|
self.state.cache.get()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Retrieves and parses into `T` the 0-indexed `n`th segment from the
|
/// Retrieves and parses into `T` the 0-indexed `n`th segment from the
|
||||||
/// request. Returns `None` if `n` is greater than the number of segments.
|
/// request. Returns `None` if `n` is greater than the number of segments.
|
||||||
/// Returns `Some(Err(T::Error))` if the parameter type `T` failed to be
|
/// Returns `Some(Err(T::Error))` if the parameter type `T` failed to be
|
||||||
|
|
|
@ -4,8 +4,8 @@
|
||||||
|
|
||||||
use std::sync::atomic::{AtomicUsize, Ordering};
|
use std::sync::atomic::{AtomicUsize, Ordering};
|
||||||
|
|
||||||
use rocket::request::{self, Request, FromRequest, State};
|
|
||||||
use rocket::outcome::Outcome::*;
|
use rocket::outcome::Outcome::*;
|
||||||
|
use rocket::request::{self, FromRequest, FromRequestAsync, FromRequestFuture, Request, State};
|
||||||
|
|
||||||
#[cfg(test)] mod tests;
|
#[cfg(test)] mod tests;
|
||||||
|
|
||||||
|
@ -17,6 +17,8 @@ struct Atomics {
|
||||||
|
|
||||||
struct Guard1;
|
struct Guard1;
|
||||||
struct Guard2;
|
struct Guard2;
|
||||||
|
struct Guard3;
|
||||||
|
struct Guard4;
|
||||||
|
|
||||||
impl<'a, 'r> FromRequest<'a, 'r> for Guard1 {
|
impl<'a, 'r> FromRequest<'a, 'r> for Guard1 {
|
||||||
type Error = ();
|
type Error = ();
|
||||||
|
@ -39,15 +41,51 @@ impl<'a, 'r> FromRequest<'a, 'r> for Guard2 {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[get("/")]
|
impl<'a, 'r> FromRequestAsync<'a, 'r> for Guard3 {
|
||||||
fn index(_g1: Guard1, _g2: Guard2) {
|
type Error = ();
|
||||||
|
|
||||||
|
fn from_request<'fut>(req: &'a Request<'r>) -> FromRequestFuture<'fut, Self, ()>
|
||||||
|
where 'a: 'fut
|
||||||
|
{
|
||||||
|
Box::pin(async move {
|
||||||
|
let atomics = try_outcome!(req.guard::<State<'_, Atomics>>());
|
||||||
|
atomics.uncached.fetch_add(1, Ordering::Relaxed);
|
||||||
|
req.local_cache_async(async {
|
||||||
|
atomics.cached.fetch_add(1, Ordering::Relaxed)
|
||||||
|
}).await;
|
||||||
|
|
||||||
|
Success(Guard3)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a, 'r> FromRequestAsync<'a, 'r> for Guard4 {
|
||||||
|
type Error = ();
|
||||||
|
|
||||||
|
fn from_request<'fut>(req: &'a Request<'r>) -> FromRequestFuture<'fut, Self, ()>
|
||||||
|
where 'a: 'fut
|
||||||
|
{
|
||||||
|
Box::pin(async move {
|
||||||
|
try_outcome!(Guard3::from_request(req).await);
|
||||||
|
Success(Guard4)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[get("/sync")]
|
||||||
|
fn r_sync(_g1: Guard1, _g2: Guard2) {
|
||||||
|
// This exists only to run the request guards.
|
||||||
|
}
|
||||||
|
|
||||||
|
#[get("/async")]
|
||||||
|
async fn r_async(_g1: Guard3, _g2: Guard4) {
|
||||||
// This exists only to run the request guards.
|
// This exists only to run the request guards.
|
||||||
}
|
}
|
||||||
|
|
||||||
fn rocket() -> rocket::Rocket {
|
fn rocket() -> rocket::Rocket {
|
||||||
rocket::ignite()
|
rocket::ignite()
|
||||||
.manage(Atomics::default())
|
.manage(Atomics::default())
|
||||||
.mount("/", routes!(index))
|
.mount("/", routes![r_sync, r_async])
|
||||||
}
|
}
|
||||||
|
|
||||||
fn main() {
|
fn main() {
|
||||||
|
|
|
@ -6,9 +6,15 @@ use rocket::local::Client;
|
||||||
#[rocket::async_test]
|
#[rocket::async_test]
|
||||||
async fn test() {
|
async fn test() {
|
||||||
let client = Client::new(rocket()).unwrap();
|
let client = Client::new(rocket()).unwrap();
|
||||||
client.get("/").dispatch().await;
|
client.get("/sync").dispatch().await;
|
||||||
|
|
||||||
let atomics = client.rocket().state::<Atomics>().unwrap();
|
let atomics = client.rocket().state::<Atomics>().unwrap();
|
||||||
assert_eq!(atomics.uncached.load(Ordering::Relaxed), 2);
|
assert_eq!(atomics.uncached.load(Ordering::Relaxed), 2);
|
||||||
assert_eq!(atomics.cached.load(Ordering::Relaxed), 1);
|
assert_eq!(atomics.cached.load(Ordering::Relaxed), 1);
|
||||||
|
|
||||||
|
client.get("/async").dispatch().await;
|
||||||
|
|
||||||
|
let atomics = client.rocket().state::<Atomics>().unwrap();
|
||||||
|
assert_eq!(atomics.uncached.load(Ordering::Relaxed), 4);
|
||||||
|
assert_eq!(atomics.cached.load(Ordering::Relaxed), 2);
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue