Retrieve managed state via a borrow: '&State<T>'.

This has the following positive effects:

  1) The lifetime retrieved through 'Deref' is now long-lived.
  2) An '&State<T>` can be created via an '&T'.
  3) '&State<T>' is shorter to type than 'State<'_, T>'.
This commit is contained in:
Sergio Benitez 2021-05-11 08:56:35 -05:00
parent f442ad93cb
commit d03a07b183
24 changed files with 193 additions and 176 deletions

View File

@ -1,4 +1,4 @@
use rocket::{Request, State, Rocket, Ignite, Sentinel};
use rocket::{Request, Rocket, Ignite, Sentinel};
use rocket::http::Status;
use rocket::request::{self, FromRequest};
@ -103,9 +103,8 @@ impl<'r> FromRequest<'r> for Metadata<'r> {
type Error = ();
async fn from_request(request: &'r Request<'_>) -> request::Outcome<Self, ()> {
request.guard::<State<'_, ContextManager>>().await
.succeeded()
.and_then(|cm| Some(request::Outcome::Success(Metadata(cm.inner()))))
request.rocket().state::<ContextManager>()
.map(|cm| request::Outcome::Success(Metadata(cm)))
.unwrap_or_else(|| {
error_!("Uninitialized template context: missing fairing.");
info_!("To use templates, you must attach `Template::fairing()`.");

View File

@ -49,8 +49,8 @@ fn test_reexpansion() {
macro_rules! index {
($type:ty) => {
#[get("/")]
fn index(thing: rocket::State<$type>) -> String {
format!("Thing: {}", *thing)
fn index(thing: &rocket::State<$type>) -> String {
format!("Thing: {}", thing)
}
}
}

View File

@ -666,7 +666,7 @@ crate::export! {
///
/// async fn from_request(req: &'r Request<'_>) -> request::Outcome<Self, ()> {
/// // Attempt to fetch the guard, passing through any error or forward.
/// let atomics = try_outcome!(req.guard::<State<'_, Atomics>>().await);
/// let atomics = try_outcome!(req.guard::<&State<Atomics>>().await);
/// atomics.uncached.fetch_add(1, Ordering::Relaxed);
/// req.local_cache(|| atomics.cached.fetch_add(1, Ordering::Relaxed));
///

View File

@ -366,13 +366,13 @@ impl Rocket<Build> {
/// struct MyString(String);
///
/// #[get("/int")]
/// fn int(state: State<'_, MyInt>) -> String {
/// fn int(state: &State<MyInt>) -> String {
/// format!("The stateful int is: {}", state.0)
/// }
///
/// #[get("/string")]
/// fn string<'r>(state: State<'r, MyString>) -> &'r str {
/// &state.inner().0
/// fn string(state: &State<MyString>) -> &str {
/// &state.0
/// }
///
/// #[launch]

View File

@ -103,8 +103,8 @@ pub type BoxFuture<'r, T = Outcome<'r>> = futures::future::BoxFuture<'r, T>;
/// use rocket::State;
///
/// #[get("/")]
/// fn custom_handler(state: State<Kind>) -> &'static str {
/// match *state {
/// fn custom_handler(state: &State<Kind>) -> &'static str {
/// match state.inner() {
/// Kind::Simple => "simple",
/// Kind::Intermediate => "intermediate",
/// Kind::Complex => "complex",

View File

@ -30,7 +30,7 @@ use crate::{Rocket, Ignite};
/// # use rocket::*;
/// # type Response = ();
/// #[get("/<id>")]
/// fn index(id: usize, state: State<String>) -> Response {
/// fn index(id: usize, state: &State<String>) -> Response {
/// /* ... */
/// }
///
@ -63,7 +63,7 @@ use crate::{Rocket, Ignite};
/// # use rocket::*;
/// # type Response = ();
/// # #[get("/<id>")]
/// # fn index(id: usize, state: State<String>) -> Response {
/// # fn index(id: usize, state: &State<String>) -> Response {
/// # /* ... */
/// # }
/// #
@ -92,19 +92,21 @@ use crate::{Rocket, Ignite};
/// # type Foo = ();
/// # type Bar = ();
/// #[get("/")]
/// fn f(guard: Option<State<'_, String>>) -> Either<Foo, Inner<Bar>> {
/// fn f(guard: Option<&State<String>>) -> Either<Foo, Inner<Bar>> {
/// unimplemented!()
/// }
/// ```
///
/// The directly eligible sentinel types, guard and responders, are:
///
/// * `Option<State<'_, String>>`
/// * `Option<&State<String>>`
/// * `Either<Foo, INner<Bar>>`
///
/// In addition, all embedded types are _also_ eligble. These are:
///
/// * `State<'_, String>`
/// * `&State<String>`
/// * `State<String>`
/// * `String`
/// * `Foo`
/// * `Inner<Bar>`
/// * `Bar`
@ -116,15 +118,17 @@ use crate::{Rocket, Ignite};
/// breadth-first order, is queried:
///
/// ```text
/// Option<State<'_, String>> Either<Foo, Inner<Bar>>
/// | / \
/// State<'_, String> Foo Inner<Bar>
/// |
/// Bar
/// Option<&State<String>> Either<Foo, Inner<Bar>>
/// | / \
/// &State<String> Foo Inner<Bar>
/// | |
/// State<String> Bar
/// |
/// String
/// ```
///
/// Neither `Option` nor `Either` are sentinels, so they won't be queried. In
/// the next level, `State` is a `Sentinel`, so it _is_ queried. If `Foo` is a
/// the next level, `&State` is a `Sentinel`, so it _is_ queried. If `Foo` is a
/// sentinel, it is queried as well. If `Inner` is a sentinel, it is queried,
/// and traversal stops without considering `Bar`. If `Inner` is _not_ a
/// `Sentinel`, `Bar` is considered and queried if it is a sentinel.
@ -178,10 +182,10 @@ use crate::{Rocket, Ignite};
/// ## Aliases
///
/// Embedded discovery of sentinels is syntactic in nature: an embedded sentinel
/// is only discovered if its named in the type. As such, sentinels made opaque
/// by a type alias will fail to be considered. In the example below, only
/// `Result<Foo, Bar>` will be considered, while the embedded `Foo` and `Bar`
/// will not.
/// is only discovered if its named in the type. As such, embedded sentinels
/// made opaque by a type alias will fail to be considered. In the example
/// below, only `Result<Foo, Bar>` will be considered, while the embedded `Foo`
/// and `Bar` will not.
///
/// ```rust
/// # use rocket::get;

View File

@ -1,6 +1,9 @@
use std::fmt;
use std::ops::Deref;
use std::any::type_name;
use ref_cast::RefCast;
use crate::{Phase, Rocket, Ignite, Sentinel};
use crate::request::{self, FromRequest, Request};
use crate::outcome::Outcome;
@ -8,14 +11,13 @@ use crate::http::Status;
/// Request guard to retrieve managed state.
///
/// This type can be used as a request guard to retrieve the state Rocket is
/// managing for some type `T`. This allows for the sharing of state across any
/// number of handlers. A value for the given type must previously have been
/// registered to be managed by Rocket via [`Rocket::manage()`]. The type being
/// managed must be thread safe and sendable across thread boundaries. In other
/// words, it must implement [`Send`] + [`Sync`] + `'static`.
///
/// [`Rocket::manage()`]: crate::Rocket::manage()
/// A reference `&State<T>` type is a request guard which retrieves the managed
/// state managing for some type `T`. A value for the given type must previously
/// have been registered to be managed by Rocket via [`Rocket::manage()`]. The
/// type being managed must be thread safe and sendable across thread
/// boundaries as multiple handlers in multiple threads may be accessing the
/// value at once. In other words, it must implement [`Send`] + [`Sync`] +
/// `'static`.
///
/// # Example
///
@ -33,14 +35,13 @@ use crate::http::Status;
/// }
///
/// #[get("/")]
/// fn index(state: State<'_, MyConfig>) -> String {
/// fn index(state: &State<MyConfig>) -> String {
/// format!("The config value is: {}", state.user_val)
/// }
///
/// #[get("/raw")]
/// fn raw_config_value<'r>(state: State<'r, MyConfig>) -> &'r str {
/// // use `inner()` to get a lifetime longer than `deref` gives us
/// state.inner().user_val.as_str()
/// fn raw_config_value(state: &State<MyConfig>) -> &str {
/// &state.user_val
/// }
///
/// #[launch]
@ -72,10 +73,10 @@ use crate::http::Status;
///
/// async fn from_request(request: &'r Request<'_>) -> request::Outcome<Self, ()> {
/// // Using `State` as a request guard. Use `inner()` to get an `'r`.
/// let outcome = request.guard::<State<MyConfig>>().await
/// .map(|my_config| Item(&my_config.inner().user_val));
/// let outcome = request.guard::<&State<MyConfig>>().await
/// .map(|my_config| Item(&my_config.user_val));
///
/// // Or alternatively, using `Request::managed_state()`:
/// // Or alternatively, using `Rocket::state()`:
/// let outcome = request.rocket().state::<MyConfig>()
/// .map(|my_config| Item(&my_config.user_val))
/// .or_forward(());
@ -89,7 +90,7 @@ use crate::http::Status;
///
/// When unit testing your application, you may find it necessary to manually
/// construct a type of `State` to pass to your functions. To do so, use the
/// [`State::from()`] static method:
/// [`State::get()`] static method or the `From<&T>` implementation:
///
/// ```rust
/// # #[macro_use] extern crate rocket;
@ -98,49 +99,22 @@ use crate::http::Status;
/// struct MyManagedState(usize);
///
/// #[get("/")]
/// fn handler(state: State<'_, MyManagedState>) -> String {
/// fn handler(state: &State<MyManagedState>) -> String {
/// state.0.to_string()
/// }
///
/// let mut rocket = rocket::build().manage(MyManagedState(127));
/// let state = State::from(&rocket).expect("managed `MyManagedState`");
/// let state = State::get(&rocket).expect("managed `MyManagedState`");
/// assert_eq!(handler(state), "127");
///
/// let managed = MyManagedState(77);
/// assert_eq!(handler(State::from(&managed)), "77");
/// ```
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct State<'r, T: Send + Sync + 'static>(&'r T);
impl<'r, T: Send + Sync + 'static> State<'r, T> {
/// Retrieve a borrow to the underlying value with a lifetime of `'r`.
///
/// Using this method is typically unnecessary as `State` implements
/// [`Deref`] with a [`Deref::Target`] of `T`. This means Rocket will
/// automatically coerce a `State<T>` to an `&T` as required. This method
/// should only be used when a longer lifetime is required.
///
/// # Example
///
/// ```rust
/// use rocket::State;
///
/// struct MyConfig {
/// user_val: String
/// }
///
/// // Use `inner()` to get a lifetime of `'r`
/// fn handler1<'r>(config: State<'r, MyConfig>) -> &'r str {
/// &config.inner().user_val
/// }
///
/// // Use the `Deref` implementation which coerces implicitly
/// fn handler2(config: State<'_, MyConfig>) -> String {
/// config.user_val.clone()
/// }
/// ```
#[inline(always)]
pub fn inner(&self) -> &'r T {
self.0
}
#[repr(transparent)]
#[derive(RefCast, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct State<T: Send + Sync + 'static>(T);
impl<T: Send + Sync + 'static> State<T> {
/// Returns the managed state value in `rocket` for the type `T` if it is
/// being managed by `rocket`. Otherwise, returns `None`.
///
@ -157,26 +131,73 @@ impl<'r, T: Send + Sync + 'static> State<'r, T> {
///
/// let rocket = rocket::build().manage(Managed(7));
///
/// let state: Option<State<Managed>> = State::from(&rocket);
/// let state: Option<&State<Managed>> = State::get(&rocket);
/// assert_eq!(state.map(|s| s.inner()), Some(&Managed(7)));
///
/// let state: Option<State<Unmanaged>> = State::from(&rocket);
/// let state: Option<&State<Unmanaged>> = State::get(&rocket);
/// assert_eq!(state, None);
/// ```
#[inline(always)]
pub fn from<P: Phase>(rocket: &'r Rocket<P>) -> Option<Self> {
rocket.state().map(State)
pub fn get<P: Phase>(rocket: &Rocket<P>) -> Option<&State<T>> {
rocket.state::<T>().map(State::ref_cast)
}
/// This exists because `State::from()` would otherwise be nothing. But we
/// want `State::from(&foo)` to give us `<&State>::from(&foo)`. Here it is.
#[doc(hidden)]
#[inline(always)]
pub fn from(value: &T) -> &State<T> {
State::ref_cast(value)
}
/// Borrow the inner value.
///
/// Using this method is typically unnecessary as `State` implements
/// [`Deref`] with a [`Deref::Target`] of `T`. This means Rocket will
/// automatically coerce a `State<T>` to an `&T` as required. This method
/// should only be used when a longer lifetime is required.
///
/// # Example
///
/// ```rust
/// use rocket::State;
///
/// #[derive(Clone)]
/// struct MyConfig {
/// user_val: String
/// }
///
/// fn handler1<'r>(config: &State<MyConfig>) -> String {
/// let config = config.inner().clone();
/// config.user_val
/// }
///
/// // Use the `Deref` implementation which coerces implicitly
/// fn handler2(config: &State<MyConfig>) -> String {
/// config.user_val.clone()
/// }
/// ```
#[inline(always)]
pub fn inner(&self) -> &T {
&self.0
}
}
impl<'r, T: Send + Sync + 'static> From<&'r T> for &'r State<T> {
#[inline(always)]
fn from(reference: &'r T) -> Self {
State::ref_cast(reference)
}
}
#[crate::async_trait]
impl<'r, T: Send + Sync + 'static> FromRequest<'r> for State<'r, T> {
impl<'r, T: Send + Sync + 'static> FromRequest<'r> for &'r State<T> {
type Error = ();
#[inline(always)]
async fn from_request(req: &'r Request<'_>) -> request::Outcome<Self, ()> {
match req.rocket().state::<T>() {
Some(state) => Outcome::Success(State(state)),
match State::get(req.rocket()) {
Some(state) => Outcome::Success(state),
None => {
error_!("Attempted to retrieve unmanaged state `{}`!", type_name::<T>());
Outcome::Failure((Status::InternalServerError, ()))
@ -185,7 +206,7 @@ impl<'r, T: Send + Sync + 'static> FromRequest<'r> for State<'r, T> {
}
}
impl<T: Send + Sync + 'static> Sentinel for State<'_, T> {
impl<T: Send + Sync + 'static> Sentinel for &State<T> {
fn abort(rocket: &Rocket<Ignite>) -> bool {
if rocket.state::<T>().is_none() {
let type_name = yansi::Paint::default(type_name::<T>()).bold();
@ -198,30 +219,17 @@ impl<T: Send + Sync + 'static> Sentinel for State<'_, T> {
}
}
impl<T: Send + Sync + 'static> Deref for State<'_, T> {
impl<T: Send + Sync + fmt::Display + 'static> fmt::Display for State<T> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
self.0.fmt(f)
}
}
impl<T: Send + Sync + 'static> Deref for State<T> {
type Target = T;
#[inline(always)]
fn deref(&self) -> &T {
self.0
}
}
impl<T: Send + Sync + 'static> Clone for State<'_, T> {
fn clone(&self) -> Self {
State(self.0)
}
}
#[cfg(test)]
mod tests {
#[test]
fn state_is_cloneable() {
struct Token(usize);
let rocket = crate::custom(crate::Config::default()).manage(Token(123));
let state = rocket.state::<Token>().unwrap();
assert_eq!(state.0, 123);
assert_eq!(state.clone().0, 123);
&self.0
}
}

View File

@ -19,7 +19,6 @@ mod fairing_before_head_strip {
use std::sync::atomic::{AtomicUsize, Ordering};
use std::io::Cursor;
use rocket::State;
use rocket::fairing::AdHoc;
use rocket::local::blocking::Client;
use rocket::http::{Method, Status};
@ -62,7 +61,7 @@ mod fairing_before_head_strip {
assert_eq!(req.method(), Method::Head);
// This should be called exactly once.
let c = req.guard::<State<Counter>>().await.unwrap();
let c = req.rocket().state::<Counter>().unwrap();
assert_eq!(c.0.fetch_add(1, Ordering::SeqCst), 0);
})
}))

View File

@ -13,7 +13,7 @@ struct Counter {
}
#[get("/")]
fn index(counter: State<'_, Counter>) -> String {
fn index(counter: &State<Counter>) -> String {
let attaches = counter.attach.load(Ordering::Relaxed);
let gets = counter.get.load(Ordering::Acquire);
format!("{}, {}", attaches, gets)
@ -29,8 +29,7 @@ fn rocket() -> Rocket<Build> {
.attach(AdHoc::on_request("Inner", |req, _| {
Box::pin(async move {
if req.method() == Method::Get {
let counter = req.guard::<State<'_, Counter>>()
.await.unwrap();
let counter = req.rocket().state::<Counter>().unwrap();
counter.get.fetch_add(1, Ordering::Release);
}
})

View File

@ -19,11 +19,11 @@ impl<'r, 'o: 'r, R: Responder<'r, 'o>> Responder<'r, 'o> for CustomResponder<'r,
}
#[get("/unit_state")]
fn unit_state(state: State<SomeState>) -> CustomResponder<()> {
CustomResponder { responder: (), state: state.inner() }
fn unit_state(state: &State<SomeState>) -> CustomResponder<()> {
CustomResponder { responder: (), state: &*state }
}
#[get("/string_state")]
fn string_state(state: State<SomeState>) -> CustomResponder<String> {
CustomResponder { responder: "".to_string(), state: state.inner() }
fn string_state(state: &State<SomeState>) -> CustomResponder<String> {
CustomResponder { responder: "".to_string(), state: &*state }
}

View File

@ -1,10 +1,10 @@
use rocket::{*, error::ErrorKind::SentinelAborts};
#[get("/two")]
fn two_states(_one: State<u32>, _two: State<String>) {}
fn two_states(_one: &State<u32>, _two: &State<String>) {}
#[get("/one")]
fn one_state(_three: State<u8>) {}
fn one_state(_three: &State<u8>) {}
#[async_test]
async fn state_sentinel_works() {

View File

@ -14,7 +14,7 @@ struct AppConfig {
}
#[get("/")]
fn read_config(rocket_config: &Config, app_config: State<'_, AppConfig>) -> String {
fn read_config(rocket_config: &Config, app_config: &State<AppConfig>) -> String {
format!("{:#?}\n{:#?}", app_config, rocket_config)
}

View File

@ -21,19 +21,19 @@ struct Post {
}
#[post("/", data = "<post>")]
async fn create(db: State<'_, Db>, post: Json<Post>) -> Result<Created<Json<Post>>> {
async fn create(db: &State<Db>, post: Json<Post>) -> Result<Created<Json<Post>>> {
// There is no support for `RETURNING`.
sqlx::query!("INSERT INTO posts (title, text) VALUES (?, ?)", post.title, post.text)
.execute(&*db)
.execute(&**db)
.await?;
Ok(Created::new("/").body(post))
}
#[get("/")]
async fn list(db: State<'_, Db>) -> Result<Json<Vec<i64>>> {
async fn list(db: &State<Db>) -> Result<Json<Vec<i64>>> {
let ids = sqlx::query!("SELECT id FROM posts")
.fetch(&*db)
.fetch(&**db)
.map_ok(|record| record.id)
.try_collect::<Vec<_>>()
.await?;
@ -42,26 +42,26 @@ async fn list(db: State<'_, Db>) -> Result<Json<Vec<i64>>> {
}
#[get("/<id>")]
async fn read(db: State<'_, Db>, id: i64) -> Option<Json<Post>> {
async fn read(db: &State<Db>, id: i64) -> Option<Json<Post>> {
sqlx::query!("SELECT id, title, text FROM posts WHERE id = ?", id)
.fetch_one(&*db)
.fetch_one(&**db)
.map_ok(|r| Json(Post { id: Some(r.id), title: r.title, text: r.text }))
.await
.ok()
}
#[delete("/<id>")]
async fn delete(db: State<'_, Db>, id: i64) -> Result<Option<()>> {
async fn delete(db: &State<Db>, id: i64) -> Result<Option<()>> {
let result = sqlx::query!("DELETE FROM posts WHERE id = ?", id)
.execute(&*db)
.execute(&**db)
.await?;
Ok((result.rows_affected() == 1).then(|| ()))
}
#[delete("/")]
async fn destroy(db: State<'_, Db>) -> Result<()> {
sqlx::query!("DELETE FROM posts").execute(&*db).await?;
async fn destroy(db: &State<Db>) -> Result<()> {
sqlx::query!("DELETE FROM posts").execute(&**db).await?;
Ok(())
}

View File

@ -45,7 +45,7 @@ fn default_catcher(status: Status, req: &Request<'_>) -> status::Custom<String>
#[allow(dead_code)]
#[get("/unmanaged")]
fn unmanaged(_u8: rocket::State<'_, u8>, _string: rocket::State<'_, String>) { }
fn unmanaged(_u8: &rocket::State<u8>, _string: &rocket::State<String>) { }
fn rocket() -> Rocket<Build> {
rocket::build()

View File

@ -29,7 +29,7 @@ impl Fairing for Counter {
async fn on_ignite(&self, rocket: Rocket<Build>) -> fairing::Result {
#[get("/counts")]
fn counts(counts: State<'_, Counter>) -> String {
fn counts(counts: &State<Counter>) -> String {
let get_count = counts.get.load(Ordering::Relaxed);
let post_count = counts.post.load(Ordering::Relaxed);
format!("Get: {}\nPost: {}", get_count, post_count)
@ -53,7 +53,7 @@ fn hello() -> &'static str {
}
#[get("/token")]
fn token(token: State<'_, Token>) -> String {
fn token(token: &State<Token>) -> String {
format!("{}", token.0)
}

View File

@ -17,7 +17,7 @@ const HOST: &str = "http://localhost:8000";
const ID_LENGTH: usize = 3;
#[post("/", data = "<paste>")]
async fn upload(paste: Data, host: State<'_, Absolute<'_>>) -> io::Result<String> {
async fn upload(paste: Data, host: &State<Absolute<'_>>) -> io::Result<String> {
let id = PasteId::new(ID_LENGTH);
paste.open(128.kibibytes()).into_file(id.file_path()).await?;

View File

@ -11,7 +11,7 @@ type Id = usize;
// We're going to store all of the messages here. No need for a DB.
type MessageList = Mutex<Vec<String>>;
type Messages<'r> = State<'r, MessageList>;
type Messages<'r> = &'r State<MessageList>;
#[derive(Serialize, Deserialize)]
struct Message<'r> {

View File

@ -7,7 +7,7 @@ use rocket::fairing::AdHoc;
struct HitCount(AtomicUsize);
#[get("/")]
fn index(hit_count: State<'_, HitCount>) -> content::Html<String> {
fn index(hit_count: &State<HitCount>) -> content::Html<String> {
let count = hit_count.0.fetch_add(1, Ordering::Relaxed) + 1;
content::Html(format!("Your visit is recorded!<br /><br />Visits: {}", count))
}

View File

@ -6,12 +6,12 @@ struct Tx(flume::Sender<String>);
struct Rx(flume::Receiver<String>);
#[put("/push?<event>")]
fn push(event: String, tx: State<'_, Tx>) -> Result<(), Status> {
fn push(event: String, tx: &State<Tx>) -> Result<(), Status> {
tx.0.try_send(event).map_err(|_| Status::ServiceUnavailable)
}
#[get("/pop")]
fn pop(rx: State<'_, Rx>) -> Option<String> {
fn pop(rx: &State<Rx>) -> Option<String> {
rx.0.try_recv().ok()
}

View File

@ -23,7 +23,7 @@ impl<'r> FromRequest<'r> for Guard1 {
async fn from_request(req: &'r Request<'_>) -> request::Outcome<Self, ()> {
rocket::info_!("-- 1 --");
let atomics = try_outcome!(req.guard::<State<'_, Atomics>>().await);
let atomics = try_outcome!(req.guard::<&State<Atomics>>().await);
atomics.uncached.fetch_add(1, Ordering::Relaxed);
req.local_cache(|| {
rocket::info_!("1: populating cache!");
@ -53,7 +53,7 @@ impl<'r> FromRequest<'r> for Guard3 {
async fn from_request(req: &'r Request<'_>) -> request::Outcome<Self, ()> {
rocket::info_!("-- 3 --");
let atomics = try_outcome!(req.guard::<State<'_, Atomics>>().await);
let atomics = try_outcome!(req.guard::<&State<Atomics>>().await);
atomics.uncached.fetch_add(1, Ordering::Relaxed);
req.local_cache_async(async {
rocket::info_!("3: populating cache!");
@ -77,12 +77,12 @@ impl<'r> FromRequest<'r> for Guard4 {
}
#[get("/1-2")]
fn one_two(_g1: Guard1, _g2: Guard2, state: State<'_, Atomics>) -> String {
fn one_two(_g1: Guard1, _g2: Guard2, state: &State<Atomics>) -> String {
format!("{:#?}", state)
}
#[get("/3-4")]
fn three_four(_g3: Guard3, _g4: Guard4, state: State<'_, Atomics>) -> String {
fn three_four(_g3: Guard3, _g4: Guard4, state: &State<Atomics>) -> String {
format!("{:#?}", state)
}
@ -92,7 +92,7 @@ fn all(
_g2: Guard2,
_g3: Guard3,
_g4: Guard4,
state: State<'_, Atomics>
state: &State<Atomics>
) -> String {
format!("{:#?}", state)
}

View File

@ -3,7 +3,7 @@ use rocket::fairing::AdHoc;
use rocket::tokio::sync::Barrier;
#[get("/barrier")]
async fn rendezvous(barrier: State<'_, Barrier>) -> &'static str {
async fn rendezvous(barrier: &State<Barrier>) -> &'static str {
println!("Waiting for second task...");
barrier.wait().await;
"Rendezvous reached."

View File

@ -15,7 +15,7 @@ use rocket_contrib::uuid::extern_uuid;
struct People(HashMap<extern_uuid::Uuid, &'static str>);
#[get("/people/<id>")]
fn people(id: Uuid, people: State<People>) -> Result<String, String> {
fn people(id: Uuid, people: &State<People>) -> Result<String, String> {
// Because Uuid implements the Deref trait, we use Deref coercion to convert
// rocket_contrib::uuid::Uuid to uuid::Uuid.
Ok(people.0.get(&id)

View File

@ -16,7 +16,7 @@ The process for using managed state is simple:
1. Call `manage` on the `Rocket` instance corresponding to your application
with the initial value of the state.
2. Add a `State<T>` type to any request handler, where `T` is the type of the
2. Add a `&State<T>` type to any request handler, where `T` is the type of the
value passed into `manage`.
! note: All managed state must be thread-safe.
@ -63,10 +63,10 @@ rocket::build()
### Retrieving State
State that is being managed by Rocket can be retrieved via the
[`State`](@api/rocket/struct.State.html) type: a [request
guard](../requests/#request-guards) for managed state. To use the request
guard, add a `State<T>` type to any request handler, where `T` is the type of
the managed state. For example, we can retrieve and respond with the current
[`&State`](@api/rocket/struct.State.html) type: a [request
guard](../requests/#request-guards) for managed state. To use the request guard,
add a `&State<T>` type to any request handler, where `T` is the type of the
managed state. For example, we can retrieve and respond with the current
`HitCount` in a `count` route as follows:
```rust
@ -79,13 +79,13 @@ the managed state. For example, we can retrieve and respond with the current
use rocket::State;
#[get("/count")]
fn count(hit_count: State<HitCount>) -> String {
fn count(hit_count: &State<HitCount>) -> String {
let current_count = hit_count.count.load(Ordering::Relaxed);
format!("Number of visits: {}", current_count)
}
```
You can retrieve more than one `State` type in a single route as well:
You can retrieve more than one `&State` type in a single route as well:
```rust
# #[macro_use] extern crate rocket;
@ -96,52 +96,60 @@ You can retrieve more than one `State` type in a single route as well:
# use rocket::State;
#[get("/state")]
fn state(hit_count: State<HitCount>, config: State<Config>) { /* .. */ }
fn state(hit_count: &State<HitCount>, config: &State<Config>) { /* .. */ }
```
! warning
If you request a `State<T>` for a `T` that is not `managed`, Rocket won't call
the offending route. Instead, Rocket will log an error message and return a
**500** error to the client.
If you request a `&State<T>` for a `T` that is not `managed`, Rocket will
refuse to start your application. This prevents what would have been an
unmanaged state runtime error. Unmanaged state is detected at runtime through
[_sentinels_](@api/rocket/trait.Sentinel.html), so there are limitations. If a
limitation is hit, Rocket still won't call an the offending route. Instead,
Rocket will log an error message and return a **500** error to the client.
You can find a complete example using the `HitCount` structure in the [state
example on GitHub](@example/state) and learn more about the [`manage`
method](@api/rocket/struct.Rocket.html#method.manage) and [`State`
type](@api/rocket/struct.State.html) in the API docs.
### Within Guards
# Within Guards
It can also be useful to retrieve managed state from a `FromRequest`
implementation. To do so, simply invoke `State<T>` as a guard using the
[`Request::guard()`] method.
Because `State` is itself a request guard, managed state can be retrieved from
another request guard's implementation using either [`Request::guard()`] or
[`Rocket::state()`]. In the following code example, the `Item` request guard
retrieves `MyConfig` from managed state using both methods:
```rust
# #[macro_use] extern crate rocket;
# fn main() {}
use rocket::State;
use rocket::request::{self, Request, FromRequest};
use rocket::outcome::try_outcome;
# use std::sync::atomic::{AtomicUsize, Ordering};
use rocket::outcome::IntoOutcome;
# struct MyConfig { user_val: String };
struct Item<'r>(&'r str);
# struct T;
# struct HitCount { count: AtomicUsize }
# type ErrorType = ();
#[rocket::async_trait]
impl<'r> FromRequest<'r> for T {
type Error = ErrorType;
impl<'r> FromRequest<'r> for Item<'r> {
type Error = ();
async fn from_request(req: &'r Request<'_>) -> request::Outcome<T, Self::Error> {
let hit_count_state = try_outcome!(req.guard::<State<HitCount>>().await);
let current_count = hit_count_state.count.load(Ordering::Relaxed);
/* ... */
# request::Outcome::Success(T)
async fn from_request(request: &'r Request<'_>) -> request::Outcome<Self, ()> {
// Using `State` as a request guard. Use `inner()` to get an `'r`.
let outcome = request.guard::<&State<MyConfig>>().await
.map(|my_config| Item(&my_config.user_val));
// Or alternatively, using `Rocket::state()`:
let outcome = request.rocket().state::<MyConfig>()
.map(|my_config| Item(&my_config.user_val))
.or_forward(());
outcome
}
}
```
[`Request::guard()`]: @api/rocket/struct.Request.html#method.guard
[`Rocket::state()`]: @api/rocket/struct.Rocket.html#method.state
## Request-Local State

View File

@ -253,7 +253,7 @@ Because it is common to store configuration in managed state, Rocket provides an
use rocket::{State, fairing::AdHoc};
#[get("/custom")]
fn custom(config: State<'_, Config>) -> String {
fn custom(config: &State<Config>) -> String {
config.custom.get(0).cloned().unwrap_or("default".into())
}