Retrieve managed state via a borrow: '&State<T>'.

This has the following positive effects:

  1) The lifetime retrieved through 'Deref' is now long-lived.
  2) An '&State<T>` can be created via an '&T'.
  3) '&State<T>' is shorter to type than 'State<'_, T>'.
This commit is contained in:
Sergio Benitez 2021-05-11 08:56:35 -05:00
parent f442ad93cb
commit d03a07b183
24 changed files with 193 additions and 176 deletions

View File

@ -1,4 +1,4 @@
use rocket::{Request, State, Rocket, Ignite, Sentinel}; use rocket::{Request, Rocket, Ignite, Sentinel};
use rocket::http::Status; use rocket::http::Status;
use rocket::request::{self, FromRequest}; use rocket::request::{self, FromRequest};
@ -103,9 +103,8 @@ impl<'r> FromRequest<'r> for Metadata<'r> {
type Error = (); type Error = ();
async fn from_request(request: &'r Request<'_>) -> request::Outcome<Self, ()> { async fn from_request(request: &'r Request<'_>) -> request::Outcome<Self, ()> {
request.guard::<State<'_, ContextManager>>().await request.rocket().state::<ContextManager>()
.succeeded() .map(|cm| request::Outcome::Success(Metadata(cm)))
.and_then(|cm| Some(request::Outcome::Success(Metadata(cm.inner()))))
.unwrap_or_else(|| { .unwrap_or_else(|| {
error_!("Uninitialized template context: missing fairing."); error_!("Uninitialized template context: missing fairing.");
info_!("To use templates, you must attach `Template::fairing()`."); info_!("To use templates, you must attach `Template::fairing()`.");

View File

@ -49,8 +49,8 @@ fn test_reexpansion() {
macro_rules! index { macro_rules! index {
($type:ty) => { ($type:ty) => {
#[get("/")] #[get("/")]
fn index(thing: rocket::State<$type>) -> String { fn index(thing: &rocket::State<$type>) -> String {
format!("Thing: {}", *thing) format!("Thing: {}", thing)
} }
} }
} }

View File

@ -666,7 +666,7 @@ crate::export! {
/// ///
/// async fn from_request(req: &'r Request<'_>) -> request::Outcome<Self, ()> { /// async fn from_request(req: &'r Request<'_>) -> request::Outcome<Self, ()> {
/// // Attempt to fetch the guard, passing through any error or forward. /// // Attempt to fetch the guard, passing through any error or forward.
/// let atomics = try_outcome!(req.guard::<State<'_, Atomics>>().await); /// let atomics = try_outcome!(req.guard::<&State<Atomics>>().await);
/// atomics.uncached.fetch_add(1, Ordering::Relaxed); /// atomics.uncached.fetch_add(1, Ordering::Relaxed);
/// req.local_cache(|| atomics.cached.fetch_add(1, Ordering::Relaxed)); /// req.local_cache(|| atomics.cached.fetch_add(1, Ordering::Relaxed));
/// ///

View File

@ -366,13 +366,13 @@ impl Rocket<Build> {
/// struct MyString(String); /// struct MyString(String);
/// ///
/// #[get("/int")] /// #[get("/int")]
/// fn int(state: State<'_, MyInt>) -> String { /// fn int(state: &State<MyInt>) -> String {
/// format!("The stateful int is: {}", state.0) /// format!("The stateful int is: {}", state.0)
/// } /// }
/// ///
/// #[get("/string")] /// #[get("/string")]
/// fn string<'r>(state: State<'r, MyString>) -> &'r str { /// fn string(state: &State<MyString>) -> &str {
/// &state.inner().0 /// &state.0
/// } /// }
/// ///
/// #[launch] /// #[launch]

View File

@ -103,8 +103,8 @@ pub type BoxFuture<'r, T = Outcome<'r>> = futures::future::BoxFuture<'r, T>;
/// use rocket::State; /// use rocket::State;
/// ///
/// #[get("/")] /// #[get("/")]
/// fn custom_handler(state: State<Kind>) -> &'static str { /// fn custom_handler(state: &State<Kind>) -> &'static str {
/// match *state { /// match state.inner() {
/// Kind::Simple => "simple", /// Kind::Simple => "simple",
/// Kind::Intermediate => "intermediate", /// Kind::Intermediate => "intermediate",
/// Kind::Complex => "complex", /// Kind::Complex => "complex",

View File

@ -30,7 +30,7 @@ use crate::{Rocket, Ignite};
/// # use rocket::*; /// # use rocket::*;
/// # type Response = (); /// # type Response = ();
/// #[get("/<id>")] /// #[get("/<id>")]
/// fn index(id: usize, state: State<String>) -> Response { /// fn index(id: usize, state: &State<String>) -> Response {
/// /* ... */ /// /* ... */
/// } /// }
/// ///
@ -63,7 +63,7 @@ use crate::{Rocket, Ignite};
/// # use rocket::*; /// # use rocket::*;
/// # type Response = (); /// # type Response = ();
/// # #[get("/<id>")] /// # #[get("/<id>")]
/// # fn index(id: usize, state: State<String>) -> Response { /// # fn index(id: usize, state: &State<String>) -> Response {
/// # /* ... */ /// # /* ... */
/// # } /// # }
/// # /// #
@ -92,19 +92,21 @@ use crate::{Rocket, Ignite};
/// # type Foo = (); /// # type Foo = ();
/// # type Bar = (); /// # type Bar = ();
/// #[get("/")] /// #[get("/")]
/// fn f(guard: Option<State<'_, String>>) -> Either<Foo, Inner<Bar>> { /// fn f(guard: Option<&State<String>>) -> Either<Foo, Inner<Bar>> {
/// unimplemented!() /// unimplemented!()
/// } /// }
/// ``` /// ```
/// ///
/// The directly eligible sentinel types, guard and responders, are: /// The directly eligible sentinel types, guard and responders, are:
/// ///
/// * `Option<State<'_, String>>` /// * `Option<&State<String>>`
/// * `Either<Foo, INner<Bar>>` /// * `Either<Foo, INner<Bar>>`
/// ///
/// In addition, all embedded types are _also_ eligble. These are: /// In addition, all embedded types are _also_ eligble. These are:
/// ///
/// * `State<'_, String>` /// * `&State<String>`
/// * `State<String>`
/// * `String`
/// * `Foo` /// * `Foo`
/// * `Inner<Bar>` /// * `Inner<Bar>`
/// * `Bar` /// * `Bar`
@ -116,15 +118,17 @@ use crate::{Rocket, Ignite};
/// breadth-first order, is queried: /// breadth-first order, is queried:
/// ///
/// ```text /// ```text
/// Option<State<'_, String>> Either<Foo, Inner<Bar>> /// Option<&State<String>> Either<Foo, Inner<Bar>>
/// | / \ /// | / \
/// State<'_, String> Foo Inner<Bar> /// &State<String> Foo Inner<Bar>
/// | /// | |
/// Bar /// State<String> Bar
/// |
/// String
/// ``` /// ```
/// ///
/// Neither `Option` nor `Either` are sentinels, so they won't be queried. In /// Neither `Option` nor `Either` are sentinels, so they won't be queried. In
/// the next level, `State` is a `Sentinel`, so it _is_ queried. If `Foo` is a /// the next level, `&State` is a `Sentinel`, so it _is_ queried. If `Foo` is a
/// sentinel, it is queried as well. If `Inner` is a sentinel, it is queried, /// sentinel, it is queried as well. If `Inner` is a sentinel, it is queried,
/// and traversal stops without considering `Bar`. If `Inner` is _not_ a /// and traversal stops without considering `Bar`. If `Inner` is _not_ a
/// `Sentinel`, `Bar` is considered and queried if it is a sentinel. /// `Sentinel`, `Bar` is considered and queried if it is a sentinel.
@ -178,10 +182,10 @@ use crate::{Rocket, Ignite};
/// ## Aliases /// ## Aliases
/// ///
/// Embedded discovery of sentinels is syntactic in nature: an embedded sentinel /// Embedded discovery of sentinels is syntactic in nature: an embedded sentinel
/// is only discovered if its named in the type. As such, sentinels made opaque /// is only discovered if its named in the type. As such, embedded sentinels
/// by a type alias will fail to be considered. In the example below, only /// made opaque by a type alias will fail to be considered. In the example
/// `Result<Foo, Bar>` will be considered, while the embedded `Foo` and `Bar` /// below, only `Result<Foo, Bar>` will be considered, while the embedded `Foo`
/// will not. /// and `Bar` will not.
/// ///
/// ```rust /// ```rust
/// # use rocket::get; /// # use rocket::get;

View File

@ -1,6 +1,9 @@
use std::fmt;
use std::ops::Deref; use std::ops::Deref;
use std::any::type_name; use std::any::type_name;
use ref_cast::RefCast;
use crate::{Phase, Rocket, Ignite, Sentinel}; use crate::{Phase, Rocket, Ignite, Sentinel};
use crate::request::{self, FromRequest, Request}; use crate::request::{self, FromRequest, Request};
use crate::outcome::Outcome; use crate::outcome::Outcome;
@ -8,14 +11,13 @@ use crate::http::Status;
/// Request guard to retrieve managed state. /// Request guard to retrieve managed state.
/// ///
/// This type can be used as a request guard to retrieve the state Rocket is /// A reference `&State<T>` type is a request guard which retrieves the managed
/// managing for some type `T`. This allows for the sharing of state across any /// state managing for some type `T`. A value for the given type must previously
/// number of handlers. A value for the given type must previously have been /// have been registered to be managed by Rocket via [`Rocket::manage()`]. The
/// registered to be managed by Rocket via [`Rocket::manage()`]. The type being /// type being managed must be thread safe and sendable across thread
/// managed must be thread safe and sendable across thread boundaries. In other /// boundaries as multiple handlers in multiple threads may be accessing the
/// words, it must implement [`Send`] + [`Sync`] + `'static`. /// value at once. In other words, it must implement [`Send`] + [`Sync`] +
/// /// `'static`.
/// [`Rocket::manage()`]: crate::Rocket::manage()
/// ///
/// # Example /// # Example
/// ///
@ -33,14 +35,13 @@ use crate::http::Status;
/// } /// }
/// ///
/// #[get("/")] /// #[get("/")]
/// fn index(state: State<'_, MyConfig>) -> String { /// fn index(state: &State<MyConfig>) -> String {
/// format!("The config value is: {}", state.user_val) /// format!("The config value is: {}", state.user_val)
/// } /// }
/// ///
/// #[get("/raw")] /// #[get("/raw")]
/// fn raw_config_value<'r>(state: State<'r, MyConfig>) -> &'r str { /// fn raw_config_value(state: &State<MyConfig>) -> &str {
/// // use `inner()` to get a lifetime longer than `deref` gives us /// &state.user_val
/// state.inner().user_val.as_str()
/// } /// }
/// ///
/// #[launch] /// #[launch]
@ -72,10 +73,10 @@ use crate::http::Status;
/// ///
/// async fn from_request(request: &'r Request<'_>) -> request::Outcome<Self, ()> { /// async fn from_request(request: &'r Request<'_>) -> request::Outcome<Self, ()> {
/// // Using `State` as a request guard. Use `inner()` to get an `'r`. /// // Using `State` as a request guard. Use `inner()` to get an `'r`.
/// let outcome = request.guard::<State<MyConfig>>().await /// let outcome = request.guard::<&State<MyConfig>>().await
/// .map(|my_config| Item(&my_config.inner().user_val)); /// .map(|my_config| Item(&my_config.user_val));
/// ///
/// // Or alternatively, using `Request::managed_state()`: /// // Or alternatively, using `Rocket::state()`:
/// let outcome = request.rocket().state::<MyConfig>() /// let outcome = request.rocket().state::<MyConfig>()
/// .map(|my_config| Item(&my_config.user_val)) /// .map(|my_config| Item(&my_config.user_val))
/// .or_forward(()); /// .or_forward(());
@ -89,7 +90,7 @@ use crate::http::Status;
/// ///
/// When unit testing your application, you may find it necessary to manually /// When unit testing your application, you may find it necessary to manually
/// construct a type of `State` to pass to your functions. To do so, use the /// construct a type of `State` to pass to your functions. To do so, use the
/// [`State::from()`] static method: /// [`State::get()`] static method or the `From<&T>` implementation:
/// ///
/// ```rust /// ```rust
/// # #[macro_use] extern crate rocket; /// # #[macro_use] extern crate rocket;
@ -98,49 +99,22 @@ use crate::http::Status;
/// struct MyManagedState(usize); /// struct MyManagedState(usize);
/// ///
/// #[get("/")] /// #[get("/")]
/// fn handler(state: State<'_, MyManagedState>) -> String { /// fn handler(state: &State<MyManagedState>) -> String {
/// state.0.to_string() /// state.0.to_string()
/// } /// }
/// ///
/// let mut rocket = rocket::build().manage(MyManagedState(127)); /// let mut rocket = rocket::build().manage(MyManagedState(127));
/// let state = State::from(&rocket).expect("managed `MyManagedState`"); /// let state = State::get(&rocket).expect("managed `MyManagedState`");
/// assert_eq!(handler(state), "127"); /// assert_eq!(handler(state), "127");
///
/// let managed = MyManagedState(77);
/// assert_eq!(handler(State::from(&managed)), "77");
/// ``` /// ```
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] #[repr(transparent)]
pub struct State<'r, T: Send + Sync + 'static>(&'r T); #[derive(RefCast, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct State<T: Send + Sync + 'static>(T);
impl<'r, T: Send + Sync + 'static> State<'r, T> {
/// Retrieve a borrow to the underlying value with a lifetime of `'r`.
///
/// Using this method is typically unnecessary as `State` implements
/// [`Deref`] with a [`Deref::Target`] of `T`. This means Rocket will
/// automatically coerce a `State<T>` to an `&T` as required. This method
/// should only be used when a longer lifetime is required.
///
/// # Example
///
/// ```rust
/// use rocket::State;
///
/// struct MyConfig {
/// user_val: String
/// }
///
/// // Use `inner()` to get a lifetime of `'r`
/// fn handler1<'r>(config: State<'r, MyConfig>) -> &'r str {
/// &config.inner().user_val
/// }
///
/// // Use the `Deref` implementation which coerces implicitly
/// fn handler2(config: State<'_, MyConfig>) -> String {
/// config.user_val.clone()
/// }
/// ```
#[inline(always)]
pub fn inner(&self) -> &'r T {
self.0
}
impl<T: Send + Sync + 'static> State<T> {
/// Returns the managed state value in `rocket` for the type `T` if it is /// Returns the managed state value in `rocket` for the type `T` if it is
/// being managed by `rocket`. Otherwise, returns `None`. /// being managed by `rocket`. Otherwise, returns `None`.
/// ///
@ -157,26 +131,73 @@ impl<'r, T: Send + Sync + 'static> State<'r, T> {
/// ///
/// let rocket = rocket::build().manage(Managed(7)); /// let rocket = rocket::build().manage(Managed(7));
/// ///
/// let state: Option<State<Managed>> = State::from(&rocket); /// let state: Option<&State<Managed>> = State::get(&rocket);
/// assert_eq!(state.map(|s| s.inner()), Some(&Managed(7))); /// assert_eq!(state.map(|s| s.inner()), Some(&Managed(7)));
/// ///
/// let state: Option<State<Unmanaged>> = State::from(&rocket); /// let state: Option<&State<Unmanaged>> = State::get(&rocket);
/// assert_eq!(state, None); /// assert_eq!(state, None);
/// ``` /// ```
#[inline(always)] #[inline(always)]
pub fn from<P: Phase>(rocket: &'r Rocket<P>) -> Option<Self> { pub fn get<P: Phase>(rocket: &Rocket<P>) -> Option<&State<T>> {
rocket.state().map(State) rocket.state::<T>().map(State::ref_cast)
}
/// This exists because `State::from()` would otherwise be nothing. But we
/// want `State::from(&foo)` to give us `<&State>::from(&foo)`. Here it is.
#[doc(hidden)]
#[inline(always)]
pub fn from(value: &T) -> &State<T> {
State::ref_cast(value)
}
/// Borrow the inner value.
///
/// Using this method is typically unnecessary as `State` implements
/// [`Deref`] with a [`Deref::Target`] of `T`. This means Rocket will
/// automatically coerce a `State<T>` to an `&T` as required. This method
/// should only be used when a longer lifetime is required.
///
/// # Example
///
/// ```rust
/// use rocket::State;
///
/// #[derive(Clone)]
/// struct MyConfig {
/// user_val: String
/// }
///
/// fn handler1<'r>(config: &State<MyConfig>) -> String {
/// let config = config.inner().clone();
/// config.user_val
/// }
///
/// // Use the `Deref` implementation which coerces implicitly
/// fn handler2(config: &State<MyConfig>) -> String {
/// config.user_val.clone()
/// }
/// ```
#[inline(always)]
pub fn inner(&self) -> &T {
&self.0
}
}
impl<'r, T: Send + Sync + 'static> From<&'r T> for &'r State<T> {
#[inline(always)]
fn from(reference: &'r T) -> Self {
State::ref_cast(reference)
} }
} }
#[crate::async_trait] #[crate::async_trait]
impl<'r, T: Send + Sync + 'static> FromRequest<'r> for State<'r, T> { impl<'r, T: Send + Sync + 'static> FromRequest<'r> for &'r State<T> {
type Error = (); type Error = ();
#[inline(always)] #[inline(always)]
async fn from_request(req: &'r Request<'_>) -> request::Outcome<Self, ()> { async fn from_request(req: &'r Request<'_>) -> request::Outcome<Self, ()> {
match req.rocket().state::<T>() { match State::get(req.rocket()) {
Some(state) => Outcome::Success(State(state)), Some(state) => Outcome::Success(state),
None => { None => {
error_!("Attempted to retrieve unmanaged state `{}`!", type_name::<T>()); error_!("Attempted to retrieve unmanaged state `{}`!", type_name::<T>());
Outcome::Failure((Status::InternalServerError, ())) Outcome::Failure((Status::InternalServerError, ()))
@ -185,7 +206,7 @@ impl<'r, T: Send + Sync + 'static> FromRequest<'r> for State<'r, T> {
} }
} }
impl<T: Send + Sync + 'static> Sentinel for State<'_, T> { impl<T: Send + Sync + 'static> Sentinel for &State<T> {
fn abort(rocket: &Rocket<Ignite>) -> bool { fn abort(rocket: &Rocket<Ignite>) -> bool {
if rocket.state::<T>().is_none() { if rocket.state::<T>().is_none() {
let type_name = yansi::Paint::default(type_name::<T>()).bold(); let type_name = yansi::Paint::default(type_name::<T>()).bold();
@ -198,30 +219,17 @@ impl<T: Send + Sync + 'static> Sentinel for State<'_, T> {
} }
} }
impl<T: Send + Sync + 'static> Deref for State<'_, T> { impl<T: Send + Sync + fmt::Display + 'static> fmt::Display for State<T> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
self.0.fmt(f)
}
}
impl<T: Send + Sync + 'static> Deref for State<T> {
type Target = T; type Target = T;
#[inline(always)] #[inline(always)]
fn deref(&self) -> &T { fn deref(&self) -> &T {
self.0 &self.0
}
}
impl<T: Send + Sync + 'static> Clone for State<'_, T> {
fn clone(&self) -> Self {
State(self.0)
}
}
#[cfg(test)]
mod tests {
#[test]
fn state_is_cloneable() {
struct Token(usize);
let rocket = crate::custom(crate::Config::default()).manage(Token(123));
let state = rocket.state::<Token>().unwrap();
assert_eq!(state.0, 123);
assert_eq!(state.clone().0, 123);
} }
} }

View File

@ -19,7 +19,6 @@ mod fairing_before_head_strip {
use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::atomic::{AtomicUsize, Ordering};
use std::io::Cursor; use std::io::Cursor;
use rocket::State;
use rocket::fairing::AdHoc; use rocket::fairing::AdHoc;
use rocket::local::blocking::Client; use rocket::local::blocking::Client;
use rocket::http::{Method, Status}; use rocket::http::{Method, Status};
@ -62,7 +61,7 @@ mod fairing_before_head_strip {
assert_eq!(req.method(), Method::Head); assert_eq!(req.method(), Method::Head);
// This should be called exactly once. // This should be called exactly once.
let c = req.guard::<State<Counter>>().await.unwrap(); let c = req.rocket().state::<Counter>().unwrap();
assert_eq!(c.0.fetch_add(1, Ordering::SeqCst), 0); assert_eq!(c.0.fetch_add(1, Ordering::SeqCst), 0);
}) })
})) }))

View File

@ -13,7 +13,7 @@ struct Counter {
} }
#[get("/")] #[get("/")]
fn index(counter: State<'_, Counter>) -> String { fn index(counter: &State<Counter>) -> String {
let attaches = counter.attach.load(Ordering::Relaxed); let attaches = counter.attach.load(Ordering::Relaxed);
let gets = counter.get.load(Ordering::Acquire); let gets = counter.get.load(Ordering::Acquire);
format!("{}, {}", attaches, gets) format!("{}, {}", attaches, gets)
@ -29,8 +29,7 @@ fn rocket() -> Rocket<Build> {
.attach(AdHoc::on_request("Inner", |req, _| { .attach(AdHoc::on_request("Inner", |req, _| {
Box::pin(async move { Box::pin(async move {
if req.method() == Method::Get { if req.method() == Method::Get {
let counter = req.guard::<State<'_, Counter>>() let counter = req.rocket().state::<Counter>().unwrap();
.await.unwrap();
counter.get.fetch_add(1, Ordering::Release); counter.get.fetch_add(1, Ordering::Release);
} }
}) })

View File

@ -19,11 +19,11 @@ impl<'r, 'o: 'r, R: Responder<'r, 'o>> Responder<'r, 'o> for CustomResponder<'r,
} }
#[get("/unit_state")] #[get("/unit_state")]
fn unit_state(state: State<SomeState>) -> CustomResponder<()> { fn unit_state(state: &State<SomeState>) -> CustomResponder<()> {
CustomResponder { responder: (), state: state.inner() } CustomResponder { responder: (), state: &*state }
} }
#[get("/string_state")] #[get("/string_state")]
fn string_state(state: State<SomeState>) -> CustomResponder<String> { fn string_state(state: &State<SomeState>) -> CustomResponder<String> {
CustomResponder { responder: "".to_string(), state: state.inner() } CustomResponder { responder: "".to_string(), state: &*state }
} }

View File

@ -1,10 +1,10 @@
use rocket::{*, error::ErrorKind::SentinelAborts}; use rocket::{*, error::ErrorKind::SentinelAborts};
#[get("/two")] #[get("/two")]
fn two_states(_one: State<u32>, _two: State<String>) {} fn two_states(_one: &State<u32>, _two: &State<String>) {}
#[get("/one")] #[get("/one")]
fn one_state(_three: State<u8>) {} fn one_state(_three: &State<u8>) {}
#[async_test] #[async_test]
async fn state_sentinel_works() { async fn state_sentinel_works() {

View File

@ -14,7 +14,7 @@ struct AppConfig {
} }
#[get("/")] #[get("/")]
fn read_config(rocket_config: &Config, app_config: State<'_, AppConfig>) -> String { fn read_config(rocket_config: &Config, app_config: &State<AppConfig>) -> String {
format!("{:#?}\n{:#?}", app_config, rocket_config) format!("{:#?}\n{:#?}", app_config, rocket_config)
} }

View File

@ -21,19 +21,19 @@ struct Post {
} }
#[post("/", data = "<post>")] #[post("/", data = "<post>")]
async fn create(db: State<'_, Db>, post: Json<Post>) -> Result<Created<Json<Post>>> { async fn create(db: &State<Db>, post: Json<Post>) -> Result<Created<Json<Post>>> {
// There is no support for `RETURNING`. // There is no support for `RETURNING`.
sqlx::query!("INSERT INTO posts (title, text) VALUES (?, ?)", post.title, post.text) sqlx::query!("INSERT INTO posts (title, text) VALUES (?, ?)", post.title, post.text)
.execute(&*db) .execute(&**db)
.await?; .await?;
Ok(Created::new("/").body(post)) Ok(Created::new("/").body(post))
} }
#[get("/")] #[get("/")]
async fn list(db: State<'_, Db>) -> Result<Json<Vec<i64>>> { async fn list(db: &State<Db>) -> Result<Json<Vec<i64>>> {
let ids = sqlx::query!("SELECT id FROM posts") let ids = sqlx::query!("SELECT id FROM posts")
.fetch(&*db) .fetch(&**db)
.map_ok(|record| record.id) .map_ok(|record| record.id)
.try_collect::<Vec<_>>() .try_collect::<Vec<_>>()
.await?; .await?;
@ -42,26 +42,26 @@ async fn list(db: State<'_, Db>) -> Result<Json<Vec<i64>>> {
} }
#[get("/<id>")] #[get("/<id>")]
async fn read(db: State<'_, Db>, id: i64) -> Option<Json<Post>> { async fn read(db: &State<Db>, id: i64) -> Option<Json<Post>> {
sqlx::query!("SELECT id, title, text FROM posts WHERE id = ?", id) sqlx::query!("SELECT id, title, text FROM posts WHERE id = ?", id)
.fetch_one(&*db) .fetch_one(&**db)
.map_ok(|r| Json(Post { id: Some(r.id), title: r.title, text: r.text })) .map_ok(|r| Json(Post { id: Some(r.id), title: r.title, text: r.text }))
.await .await
.ok() .ok()
} }
#[delete("/<id>")] #[delete("/<id>")]
async fn delete(db: State<'_, Db>, id: i64) -> Result<Option<()>> { async fn delete(db: &State<Db>, id: i64) -> Result<Option<()>> {
let result = sqlx::query!("DELETE FROM posts WHERE id = ?", id) let result = sqlx::query!("DELETE FROM posts WHERE id = ?", id)
.execute(&*db) .execute(&**db)
.await?; .await?;
Ok((result.rows_affected() == 1).then(|| ())) Ok((result.rows_affected() == 1).then(|| ()))
} }
#[delete("/")] #[delete("/")]
async fn destroy(db: State<'_, Db>) -> Result<()> { async fn destroy(db: &State<Db>) -> Result<()> {
sqlx::query!("DELETE FROM posts").execute(&*db).await?; sqlx::query!("DELETE FROM posts").execute(&**db).await?;
Ok(()) Ok(())
} }

View File

@ -45,7 +45,7 @@ fn default_catcher(status: Status, req: &Request<'_>) -> status::Custom<String>
#[allow(dead_code)] #[allow(dead_code)]
#[get("/unmanaged")] #[get("/unmanaged")]
fn unmanaged(_u8: rocket::State<'_, u8>, _string: rocket::State<'_, String>) { } fn unmanaged(_u8: &rocket::State<u8>, _string: &rocket::State<String>) { }
fn rocket() -> Rocket<Build> { fn rocket() -> Rocket<Build> {
rocket::build() rocket::build()

View File

@ -29,7 +29,7 @@ impl Fairing for Counter {
async fn on_ignite(&self, rocket: Rocket<Build>) -> fairing::Result { async fn on_ignite(&self, rocket: Rocket<Build>) -> fairing::Result {
#[get("/counts")] #[get("/counts")]
fn counts(counts: State<'_, Counter>) -> String { fn counts(counts: &State<Counter>) -> String {
let get_count = counts.get.load(Ordering::Relaxed); let get_count = counts.get.load(Ordering::Relaxed);
let post_count = counts.post.load(Ordering::Relaxed); let post_count = counts.post.load(Ordering::Relaxed);
format!("Get: {}\nPost: {}", get_count, post_count) format!("Get: {}\nPost: {}", get_count, post_count)
@ -53,7 +53,7 @@ fn hello() -> &'static str {
} }
#[get("/token")] #[get("/token")]
fn token(token: State<'_, Token>) -> String { fn token(token: &State<Token>) -> String {
format!("{}", token.0) format!("{}", token.0)
} }

View File

@ -17,7 +17,7 @@ const HOST: &str = "http://localhost:8000";
const ID_LENGTH: usize = 3; const ID_LENGTH: usize = 3;
#[post("/", data = "<paste>")] #[post("/", data = "<paste>")]
async fn upload(paste: Data, host: State<'_, Absolute<'_>>) -> io::Result<String> { async fn upload(paste: Data, host: &State<Absolute<'_>>) -> io::Result<String> {
let id = PasteId::new(ID_LENGTH); let id = PasteId::new(ID_LENGTH);
paste.open(128.kibibytes()).into_file(id.file_path()).await?; paste.open(128.kibibytes()).into_file(id.file_path()).await?;

View File

@ -11,7 +11,7 @@ type Id = usize;
// We're going to store all of the messages here. No need for a DB. // We're going to store all of the messages here. No need for a DB.
type MessageList = Mutex<Vec<String>>; type MessageList = Mutex<Vec<String>>;
type Messages<'r> = State<'r, MessageList>; type Messages<'r> = &'r State<MessageList>;
#[derive(Serialize, Deserialize)] #[derive(Serialize, Deserialize)]
struct Message<'r> { struct Message<'r> {

View File

@ -7,7 +7,7 @@ use rocket::fairing::AdHoc;
struct HitCount(AtomicUsize); struct HitCount(AtomicUsize);
#[get("/")] #[get("/")]
fn index(hit_count: State<'_, HitCount>) -> content::Html<String> { fn index(hit_count: &State<HitCount>) -> content::Html<String> {
let count = hit_count.0.fetch_add(1, Ordering::Relaxed) + 1; let count = hit_count.0.fetch_add(1, Ordering::Relaxed) + 1;
content::Html(format!("Your visit is recorded!<br /><br />Visits: {}", count)) content::Html(format!("Your visit is recorded!<br /><br />Visits: {}", count))
} }

View File

@ -6,12 +6,12 @@ struct Tx(flume::Sender<String>);
struct Rx(flume::Receiver<String>); struct Rx(flume::Receiver<String>);
#[put("/push?<event>")] #[put("/push?<event>")]
fn push(event: String, tx: State<'_, Tx>) -> Result<(), Status> { fn push(event: String, tx: &State<Tx>) -> Result<(), Status> {
tx.0.try_send(event).map_err(|_| Status::ServiceUnavailable) tx.0.try_send(event).map_err(|_| Status::ServiceUnavailable)
} }
#[get("/pop")] #[get("/pop")]
fn pop(rx: State<'_, Rx>) -> Option<String> { fn pop(rx: &State<Rx>) -> Option<String> {
rx.0.try_recv().ok() rx.0.try_recv().ok()
} }

View File

@ -23,7 +23,7 @@ impl<'r> FromRequest<'r> for Guard1 {
async fn from_request(req: &'r Request<'_>) -> request::Outcome<Self, ()> { async fn from_request(req: &'r Request<'_>) -> request::Outcome<Self, ()> {
rocket::info_!("-- 1 --"); rocket::info_!("-- 1 --");
let atomics = try_outcome!(req.guard::<State<'_, Atomics>>().await); let atomics = try_outcome!(req.guard::<&State<Atomics>>().await);
atomics.uncached.fetch_add(1, Ordering::Relaxed); atomics.uncached.fetch_add(1, Ordering::Relaxed);
req.local_cache(|| { req.local_cache(|| {
rocket::info_!("1: populating cache!"); rocket::info_!("1: populating cache!");
@ -53,7 +53,7 @@ impl<'r> FromRequest<'r> for Guard3 {
async fn from_request(req: &'r Request<'_>) -> request::Outcome<Self, ()> { async fn from_request(req: &'r Request<'_>) -> request::Outcome<Self, ()> {
rocket::info_!("-- 3 --"); rocket::info_!("-- 3 --");
let atomics = try_outcome!(req.guard::<State<'_, Atomics>>().await); let atomics = try_outcome!(req.guard::<&State<Atomics>>().await);
atomics.uncached.fetch_add(1, Ordering::Relaxed); atomics.uncached.fetch_add(1, Ordering::Relaxed);
req.local_cache_async(async { req.local_cache_async(async {
rocket::info_!("3: populating cache!"); rocket::info_!("3: populating cache!");
@ -77,12 +77,12 @@ impl<'r> FromRequest<'r> for Guard4 {
} }
#[get("/1-2")] #[get("/1-2")]
fn one_two(_g1: Guard1, _g2: Guard2, state: State<'_, Atomics>) -> String { fn one_two(_g1: Guard1, _g2: Guard2, state: &State<Atomics>) -> String {
format!("{:#?}", state) format!("{:#?}", state)
} }
#[get("/3-4")] #[get("/3-4")]
fn three_four(_g3: Guard3, _g4: Guard4, state: State<'_, Atomics>) -> String { fn three_four(_g3: Guard3, _g4: Guard4, state: &State<Atomics>) -> String {
format!("{:#?}", state) format!("{:#?}", state)
} }
@ -92,7 +92,7 @@ fn all(
_g2: Guard2, _g2: Guard2,
_g3: Guard3, _g3: Guard3,
_g4: Guard4, _g4: Guard4,
state: State<'_, Atomics> state: &State<Atomics>
) -> String { ) -> String {
format!("{:#?}", state) format!("{:#?}", state)
} }

View File

@ -3,7 +3,7 @@ use rocket::fairing::AdHoc;
use rocket::tokio::sync::Barrier; use rocket::tokio::sync::Barrier;
#[get("/barrier")] #[get("/barrier")]
async fn rendezvous(barrier: State<'_, Barrier>) -> &'static str { async fn rendezvous(barrier: &State<Barrier>) -> &'static str {
println!("Waiting for second task..."); println!("Waiting for second task...");
barrier.wait().await; barrier.wait().await;
"Rendezvous reached." "Rendezvous reached."

View File

@ -15,7 +15,7 @@ use rocket_contrib::uuid::extern_uuid;
struct People(HashMap<extern_uuid::Uuid, &'static str>); struct People(HashMap<extern_uuid::Uuid, &'static str>);
#[get("/people/<id>")] #[get("/people/<id>")]
fn people(id: Uuid, people: State<People>) -> Result<String, String> { fn people(id: Uuid, people: &State<People>) -> Result<String, String> {
// Because Uuid implements the Deref trait, we use Deref coercion to convert // Because Uuid implements the Deref trait, we use Deref coercion to convert
// rocket_contrib::uuid::Uuid to uuid::Uuid. // rocket_contrib::uuid::Uuid to uuid::Uuid.
Ok(people.0.get(&id) Ok(people.0.get(&id)

View File

@ -16,7 +16,7 @@ The process for using managed state is simple:
1. Call `manage` on the `Rocket` instance corresponding to your application 1. Call `manage` on the `Rocket` instance corresponding to your application
with the initial value of the state. with the initial value of the state.
2. Add a `State<T>` type to any request handler, where `T` is the type of the 2. Add a `&State<T>` type to any request handler, where `T` is the type of the
value passed into `manage`. value passed into `manage`.
! note: All managed state must be thread-safe. ! note: All managed state must be thread-safe.
@ -63,10 +63,10 @@ rocket::build()
### Retrieving State ### Retrieving State
State that is being managed by Rocket can be retrieved via the State that is being managed by Rocket can be retrieved via the
[`State`](@api/rocket/struct.State.html) type: a [request [`&State`](@api/rocket/struct.State.html) type: a [request
guard](../requests/#request-guards) for managed state. To use the request guard](../requests/#request-guards) for managed state. To use the request guard,
guard, add a `State<T>` type to any request handler, where `T` is the type of add a `&State<T>` type to any request handler, where `T` is the type of the
the managed state. For example, we can retrieve and respond with the current managed state. For example, we can retrieve and respond with the current
`HitCount` in a `count` route as follows: `HitCount` in a `count` route as follows:
```rust ```rust
@ -79,13 +79,13 @@ the managed state. For example, we can retrieve and respond with the current
use rocket::State; use rocket::State;
#[get("/count")] #[get("/count")]
fn count(hit_count: State<HitCount>) -> String { fn count(hit_count: &State<HitCount>) -> String {
let current_count = hit_count.count.load(Ordering::Relaxed); let current_count = hit_count.count.load(Ordering::Relaxed);
format!("Number of visits: {}", current_count) format!("Number of visits: {}", current_count)
} }
``` ```
You can retrieve more than one `State` type in a single route as well: You can retrieve more than one `&State` type in a single route as well:
```rust ```rust
# #[macro_use] extern crate rocket; # #[macro_use] extern crate rocket;
@ -96,52 +96,60 @@ You can retrieve more than one `State` type in a single route as well:
# use rocket::State; # use rocket::State;
#[get("/state")] #[get("/state")]
fn state(hit_count: State<HitCount>, config: State<Config>) { /* .. */ } fn state(hit_count: &State<HitCount>, config: &State<Config>) { /* .. */ }
``` ```
! warning ! warning
If you request a `State<T>` for a `T` that is not `managed`, Rocket won't call If you request a `&State<T>` for a `T` that is not `managed`, Rocket will
the offending route. Instead, Rocket will log an error message and return a refuse to start your application. This prevents what would have been an
**500** error to the client. unmanaged state runtime error. Unmanaged state is detected at runtime through
[_sentinels_](@api/rocket/trait.Sentinel.html), so there are limitations. If a
limitation is hit, Rocket still won't call an the offending route. Instead,
Rocket will log an error message and return a **500** error to the client.
You can find a complete example using the `HitCount` structure in the [state You can find a complete example using the `HitCount` structure in the [state
example on GitHub](@example/state) and learn more about the [`manage` example on GitHub](@example/state) and learn more about the [`manage`
method](@api/rocket/struct.Rocket.html#method.manage) and [`State` method](@api/rocket/struct.Rocket.html#method.manage) and [`State`
type](@api/rocket/struct.State.html) in the API docs. type](@api/rocket/struct.State.html) in the API docs.
### Within Guards # Within Guards
It can also be useful to retrieve managed state from a `FromRequest` Because `State` is itself a request guard, managed state can be retrieved from
implementation. To do so, simply invoke `State<T>` as a guard using the another request guard's implementation using either [`Request::guard()`] or
[`Request::guard()`] method. [`Rocket::state()`]. In the following code example, the `Item` request guard
retrieves `MyConfig` from managed state using both methods:
```rust ```rust
# #[macro_use] extern crate rocket;
# fn main() {}
use rocket::State; use rocket::State;
use rocket::request::{self, Request, FromRequest}; use rocket::request::{self, Request, FromRequest};
use rocket::outcome::try_outcome; use rocket::outcome::IntoOutcome;
# use std::sync::atomic::{AtomicUsize, Ordering};
# struct MyConfig { user_val: String };
struct Item<'r>(&'r str);
# struct T;
# struct HitCount { count: AtomicUsize }
# type ErrorType = ();
#[rocket::async_trait] #[rocket::async_trait]
impl<'r> FromRequest<'r> for T { impl<'r> FromRequest<'r> for Item<'r> {
type Error = ErrorType; type Error = ();
async fn from_request(req: &'r Request<'_>) -> request::Outcome<T, Self::Error> { async fn from_request(request: &'r Request<'_>) -> request::Outcome<Self, ()> {
let hit_count_state = try_outcome!(req.guard::<State<HitCount>>().await); // Using `State` as a request guard. Use `inner()` to get an `'r`.
let current_count = hit_count_state.count.load(Ordering::Relaxed); let outcome = request.guard::<&State<MyConfig>>().await
/* ... */ .map(|my_config| Item(&my_config.user_val));
# request::Outcome::Success(T)
// Or alternatively, using `Rocket::state()`:
let outcome = request.rocket().state::<MyConfig>()
.map(|my_config| Item(&my_config.user_val))
.or_forward(());
outcome
} }
} }
``` ```
[`Request::guard()`]: @api/rocket/struct.Request.html#method.guard [`Request::guard()`]: @api/rocket/struct.Request.html#method.guard
[`Rocket::state()`]: @api/rocket/struct.Rocket.html#method.state
## Request-Local State ## Request-Local State

View File

@ -253,7 +253,7 @@ Because it is common to store configuration in managed state, Rocket provides an
use rocket::{State, fairing::AdHoc}; use rocket::{State, fairing::AdHoc};
#[get("/custom")] #[get("/custom")]
fn custom(config: State<'_, Config>) -> String { fn custom(config: &State<Config>) -> String {
config.custom.get(0).cloned().unwrap_or("default".into()) config.custom.get(0).cloned().unwrap_or("default".into())
} }