diff --git a/contrib/lib/src/templates/metadata.rs b/contrib/lib/src/templates/metadata.rs index 80590ea5..3cb81696 100644 --- a/contrib/lib/src/templates/metadata.rs +++ b/contrib/lib/src/templates/metadata.rs @@ -1,4 +1,4 @@ -use rocket::{Request, State, Rocket, Ignite, Sentinel}; +use rocket::{Request, Rocket, Ignite, Sentinel}; use rocket::http::Status; use rocket::request::{self, FromRequest}; @@ -103,9 +103,8 @@ impl<'r> FromRequest<'r> for Metadata<'r> { type Error = (); async fn from_request(request: &'r Request<'_>) -> request::Outcome { - request.guard::>().await - .succeeded() - .and_then(|cm| Some(request::Outcome::Success(Metadata(cm.inner())))) + request.rocket().state::() + .map(|cm| request::Outcome::Success(Metadata(cm))) .unwrap_or_else(|| { error_!("Uninitialized template context: missing fairing."); info_!("To use templates, you must attach `Template::fairing()`."); diff --git a/core/codegen/tests/expansion.rs b/core/codegen/tests/expansion.rs index bb130472..012af88d 100644 --- a/core/codegen/tests/expansion.rs +++ b/core/codegen/tests/expansion.rs @@ -49,8 +49,8 @@ fn test_reexpansion() { macro_rules! index { ($type:ty) => { #[get("/")] - fn index(thing: rocket::State<$type>) -> String { - format!("Thing: {}", *thing) + fn index(thing: &rocket::State<$type>) -> String { + format!("Thing: {}", thing) } } } diff --git a/core/lib/src/outcome.rs b/core/lib/src/outcome.rs index d050ec98..3aa7b8d7 100644 --- a/core/lib/src/outcome.rs +++ b/core/lib/src/outcome.rs @@ -666,7 +666,7 @@ crate::export! { /// /// async fn from_request(req: &'r Request<'_>) -> request::Outcome { /// // Attempt to fetch the guard, passing through any error or forward. - /// let atomics = try_outcome!(req.guard::>().await); + /// let atomics = try_outcome!(req.guard::<&State>().await); /// atomics.uncached.fetch_add(1, Ordering::Relaxed); /// req.local_cache(|| atomics.cached.fetch_add(1, Ordering::Relaxed)); /// diff --git a/core/lib/src/rocket.rs b/core/lib/src/rocket.rs index bd12c71f..7bae39b2 100644 --- a/core/lib/src/rocket.rs +++ b/core/lib/src/rocket.rs @@ -366,13 +366,13 @@ impl Rocket { /// struct MyString(String); /// /// #[get("/int")] - /// fn int(state: State<'_, MyInt>) -> String { + /// fn int(state: &State) -> String { /// format!("The stateful int is: {}", state.0) /// } /// /// #[get("/string")] - /// fn string<'r>(state: State<'r, MyString>) -> &'r str { - /// &state.inner().0 + /// fn string(state: &State) -> &str { + /// &state.0 /// } /// /// #[launch] diff --git a/core/lib/src/route/handler.rs b/core/lib/src/route/handler.rs index 9a5b4833..1d8c33d4 100644 --- a/core/lib/src/route/handler.rs +++ b/core/lib/src/route/handler.rs @@ -103,8 +103,8 @@ pub type BoxFuture<'r, T = Outcome<'r>> = futures::future::BoxFuture<'r, T>; /// use rocket::State; /// /// #[get("/")] -/// fn custom_handler(state: State) -> &'static str { -/// match *state { +/// fn custom_handler(state: &State) -> &'static str { +/// match state.inner() { /// Kind::Simple => "simple", /// Kind::Intermediate => "intermediate", /// Kind::Complex => "complex", diff --git a/core/lib/src/sentinel.rs b/core/lib/src/sentinel.rs index 4182aac1..d23d6dc9 100644 --- a/core/lib/src/sentinel.rs +++ b/core/lib/src/sentinel.rs @@ -30,7 +30,7 @@ use crate::{Rocket, Ignite}; /// # use rocket::*; /// # type Response = (); /// #[get("/")] -/// fn index(id: usize, state: State) -> Response { +/// fn index(id: usize, state: &State) -> Response { /// /* ... */ /// } /// @@ -63,7 +63,7 @@ use crate::{Rocket, Ignite}; /// # use rocket::*; /// # type Response = (); /// # #[get("/")] -/// # fn index(id: usize, state: State) -> Response { +/// # fn index(id: usize, state: &State) -> Response { /// # /* ... */ /// # } /// # @@ -92,19 +92,21 @@ use crate::{Rocket, Ignite}; /// # type Foo = (); /// # type Bar = (); /// #[get("/")] -/// fn f(guard: Option>) -> Either> { +/// fn f(guard: Option<&State>) -> Either> { /// unimplemented!() /// } /// ``` /// /// The directly eligible sentinel types, guard and responders, are: /// -/// * `Option>` +/// * `Option<&State>` /// * `Either>` /// /// In addition, all embedded types are _also_ eligble. These are: /// -/// * `State<'_, String>` +/// * `&State` +/// * `State` +/// * `String` /// * `Foo` /// * `Inner` /// * `Bar` @@ -116,15 +118,17 @@ use crate::{Rocket, Ignite}; /// breadth-first order, is queried: /// /// ```text -/// Option> Either> -/// | / \ -/// State<'_, String> Foo Inner -/// | -/// Bar +/// Option<&State> Either> +/// | / \ +/// &State Foo Inner +/// | | +/// State Bar +/// | +/// String /// ``` /// /// Neither `Option` nor `Either` are sentinels, so they won't be queried. In -/// the next level, `State` is a `Sentinel`, so it _is_ queried. If `Foo` is a +/// the next level, `&State` is a `Sentinel`, so it _is_ queried. If `Foo` is a /// sentinel, it is queried as well. If `Inner` is a sentinel, it is queried, /// and traversal stops without considering `Bar`. If `Inner` is _not_ a /// `Sentinel`, `Bar` is considered and queried if it is a sentinel. @@ -178,10 +182,10 @@ use crate::{Rocket, Ignite}; /// ## Aliases /// /// Embedded discovery of sentinels is syntactic in nature: an embedded sentinel -/// is only discovered if its named in the type. As such, sentinels made opaque -/// by a type alias will fail to be considered. In the example below, only -/// `Result` will be considered, while the embedded `Foo` and `Bar` -/// will not. +/// is only discovered if its named in the type. As such, embedded sentinels +/// made opaque by a type alias will fail to be considered. In the example +/// below, only `Result` will be considered, while the embedded `Foo` +/// and `Bar` will not. /// /// ```rust /// # use rocket::get; diff --git a/core/lib/src/state.rs b/core/lib/src/state.rs index aa5233d8..17f4b238 100644 --- a/core/lib/src/state.rs +++ b/core/lib/src/state.rs @@ -1,6 +1,9 @@ +use std::fmt; use std::ops::Deref; use std::any::type_name; +use ref_cast::RefCast; + use crate::{Phase, Rocket, Ignite, Sentinel}; use crate::request::{self, FromRequest, Request}; use crate::outcome::Outcome; @@ -8,14 +11,13 @@ use crate::http::Status; /// Request guard to retrieve managed state. /// -/// This type can be used as a request guard to retrieve the state Rocket is -/// managing for some type `T`. This allows for the sharing of state across any -/// number of handlers. A value for the given type must previously have been -/// registered to be managed by Rocket via [`Rocket::manage()`]. The type being -/// managed must be thread safe and sendable across thread boundaries. In other -/// words, it must implement [`Send`] + [`Sync`] + `'static`. -/// -/// [`Rocket::manage()`]: crate::Rocket::manage() +/// A reference `&State` type is a request guard which retrieves the managed +/// state managing for some type `T`. A value for the given type must previously +/// have been registered to be managed by Rocket via [`Rocket::manage()`]. The +/// type being managed must be thread safe and sendable across thread +/// boundaries as multiple handlers in multiple threads may be accessing the +/// value at once. In other words, it must implement [`Send`] + [`Sync`] + +/// `'static`. /// /// # Example /// @@ -33,14 +35,13 @@ use crate::http::Status; /// } /// /// #[get("/")] -/// fn index(state: State<'_, MyConfig>) -> String { +/// fn index(state: &State) -> String { /// format!("The config value is: {}", state.user_val) /// } /// /// #[get("/raw")] -/// fn raw_config_value<'r>(state: State<'r, MyConfig>) -> &'r str { -/// // use `inner()` to get a lifetime longer than `deref` gives us -/// state.inner().user_val.as_str() +/// fn raw_config_value(state: &State) -> &str { +/// &state.user_val /// } /// /// #[launch] @@ -72,10 +73,10 @@ use crate::http::Status; /// /// async fn from_request(request: &'r Request<'_>) -> request::Outcome { /// // Using `State` as a request guard. Use `inner()` to get an `'r`. -/// let outcome = request.guard::>().await -/// .map(|my_config| Item(&my_config.inner().user_val)); +/// let outcome = request.guard::<&State>().await +/// .map(|my_config| Item(&my_config.user_val)); /// -/// // Or alternatively, using `Request::managed_state()`: +/// // Or alternatively, using `Rocket::state()`: /// let outcome = request.rocket().state::() /// .map(|my_config| Item(&my_config.user_val)) /// .or_forward(()); @@ -89,7 +90,7 @@ use crate::http::Status; /// /// When unit testing your application, you may find it necessary to manually /// construct a type of `State` to pass to your functions. To do so, use the -/// [`State::from()`] static method: +/// [`State::get()`] static method or the `From<&T>` implementation: /// /// ```rust /// # #[macro_use] extern crate rocket; @@ -98,49 +99,22 @@ use crate::http::Status; /// struct MyManagedState(usize); /// /// #[get("/")] -/// fn handler(state: State<'_, MyManagedState>) -> String { +/// fn handler(state: &State) -> String { /// state.0.to_string() /// } /// /// let mut rocket = rocket::build().manage(MyManagedState(127)); -/// let state = State::from(&rocket).expect("managed `MyManagedState`"); +/// let state = State::get(&rocket).expect("managed `MyManagedState`"); /// assert_eq!(handler(state), "127"); +/// +/// let managed = MyManagedState(77); +/// assert_eq!(handler(State::from(&managed)), "77"); /// ``` -#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] -pub struct State<'r, T: Send + Sync + 'static>(&'r T); - -impl<'r, T: Send + Sync + 'static> State<'r, T> { - /// Retrieve a borrow to the underlying value with a lifetime of `'r`. - /// - /// Using this method is typically unnecessary as `State` implements - /// [`Deref`] with a [`Deref::Target`] of `T`. This means Rocket will - /// automatically coerce a `State` to an `&T` as required. This method - /// should only be used when a longer lifetime is required. - /// - /// # Example - /// - /// ```rust - /// use rocket::State; - /// - /// struct MyConfig { - /// user_val: String - /// } - /// - /// // Use `inner()` to get a lifetime of `'r` - /// fn handler1<'r>(config: State<'r, MyConfig>) -> &'r str { - /// &config.inner().user_val - /// } - /// - /// // Use the `Deref` implementation which coerces implicitly - /// fn handler2(config: State<'_, MyConfig>) -> String { - /// config.user_val.clone() - /// } - /// ``` - #[inline(always)] - pub fn inner(&self) -> &'r T { - self.0 - } +#[repr(transparent)] +#[derive(RefCast, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub struct State(T); +impl State { /// Returns the managed state value in `rocket` for the type `T` if it is /// being managed by `rocket`. Otherwise, returns `None`. /// @@ -157,26 +131,73 @@ impl<'r, T: Send + Sync + 'static> State<'r, T> { /// /// let rocket = rocket::build().manage(Managed(7)); /// - /// let state: Option> = State::from(&rocket); + /// let state: Option<&State> = State::get(&rocket); /// assert_eq!(state.map(|s| s.inner()), Some(&Managed(7))); /// - /// let state: Option> = State::from(&rocket); + /// let state: Option<&State> = State::get(&rocket); /// assert_eq!(state, None); /// ``` #[inline(always)] - pub fn from(rocket: &'r Rocket

) -> Option { - rocket.state().map(State) + pub fn get(rocket: &Rocket

) -> Option<&State> { + rocket.state::().map(State::ref_cast) + } + + /// This exists because `State::from()` would otherwise be nothing. But we + /// want `State::from(&foo)` to give us `<&State>::from(&foo)`. Here it is. + #[doc(hidden)] + #[inline(always)] + pub fn from(value: &T) -> &State { + State::ref_cast(value) + } + + /// Borrow the inner value. + /// + /// Using this method is typically unnecessary as `State` implements + /// [`Deref`] with a [`Deref::Target`] of `T`. This means Rocket will + /// automatically coerce a `State` to an `&T` as required. This method + /// should only be used when a longer lifetime is required. + /// + /// # Example + /// + /// ```rust + /// use rocket::State; + /// + /// #[derive(Clone)] + /// struct MyConfig { + /// user_val: String + /// } + /// + /// fn handler1<'r>(config: &State) -> String { + /// let config = config.inner().clone(); + /// config.user_val + /// } + /// + /// // Use the `Deref` implementation which coerces implicitly + /// fn handler2(config: &State) -> String { + /// config.user_val.clone() + /// } + /// ``` + #[inline(always)] + pub fn inner(&self) -> &T { + &self.0 + } +} + +impl<'r, T: Send + Sync + 'static> From<&'r T> for &'r State { + #[inline(always)] + fn from(reference: &'r T) -> Self { + State::ref_cast(reference) } } #[crate::async_trait] -impl<'r, T: Send + Sync + 'static> FromRequest<'r> for State<'r, T> { +impl<'r, T: Send + Sync + 'static> FromRequest<'r> for &'r State { type Error = (); #[inline(always)] async fn from_request(req: &'r Request<'_>) -> request::Outcome { - match req.rocket().state::() { - Some(state) => Outcome::Success(State(state)), + match State::get(req.rocket()) { + Some(state) => Outcome::Success(state), None => { error_!("Attempted to retrieve unmanaged state `{}`!", type_name::()); Outcome::Failure((Status::InternalServerError, ())) @@ -185,7 +206,7 @@ impl<'r, T: Send + Sync + 'static> FromRequest<'r> for State<'r, T> { } } -impl Sentinel for State<'_, T> { +impl Sentinel for &State { fn abort(rocket: &Rocket) -> bool { if rocket.state::().is_none() { let type_name = yansi::Paint::default(type_name::()).bold(); @@ -198,30 +219,17 @@ impl Sentinel for State<'_, T> { } } -impl Deref for State<'_, T> { +impl fmt::Display for State { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.0.fmt(f) + } +} + +impl Deref for State { type Target = T; #[inline(always)] fn deref(&self) -> &T { - self.0 - } -} - -impl Clone for State<'_, T> { - fn clone(&self) -> Self { - State(self.0) - } -} - -#[cfg(test)] -mod tests { - #[test] - fn state_is_cloneable() { - struct Token(usize); - - let rocket = crate::custom(crate::Config::default()).manage(Token(123)); - let state = rocket.state::().unwrap(); - assert_eq!(state.0, 123); - assert_eq!(state.clone().0, 123); + &self.0 } } diff --git a/core/lib/tests/fairing_before_head_strip-issue-546.rs b/core/lib/tests/fairing_before_head_strip-issue-546.rs index 368098b3..16424716 100644 --- a/core/lib/tests/fairing_before_head_strip-issue-546.rs +++ b/core/lib/tests/fairing_before_head_strip-issue-546.rs @@ -19,7 +19,6 @@ mod fairing_before_head_strip { use std::sync::atomic::{AtomicUsize, Ordering}; use std::io::Cursor; - use rocket::State; use rocket::fairing::AdHoc; use rocket::local::blocking::Client; use rocket::http::{Method, Status}; @@ -62,7 +61,7 @@ mod fairing_before_head_strip { assert_eq!(req.method(), Method::Head); // This should be called exactly once. - let c = req.guard::>().await.unwrap(); + let c = req.rocket().state::().unwrap(); assert_eq!(c.0.fetch_add(1, Ordering::SeqCst), 0); }) })) diff --git a/core/lib/tests/nested-fairing-attaches.rs b/core/lib/tests/nested-fairing-attaches.rs index 3fa8364f..e2da0558 100644 --- a/core/lib/tests/nested-fairing-attaches.rs +++ b/core/lib/tests/nested-fairing-attaches.rs @@ -13,7 +13,7 @@ struct Counter { } #[get("/")] -fn index(counter: State<'_, Counter>) -> String { +fn index(counter: &State) -> String { let attaches = counter.attach.load(Ordering::Relaxed); let gets = counter.get.load(Ordering::Acquire); format!("{}, {}", attaches, gets) @@ -29,8 +29,7 @@ fn rocket() -> Rocket { .attach(AdHoc::on_request("Inner", |req, _| { Box::pin(async move { if req.method() == Method::Get { - let counter = req.guard::>() - .await.unwrap(); + let counter = req.rocket().state::().unwrap(); counter.get.fetch_add(1, Ordering::Release); } }) diff --git a/core/lib/tests/responder_lifetime-issue-345.rs b/core/lib/tests/responder_lifetime-issue-345.rs index aef958e8..ce7305f4 100644 --- a/core/lib/tests/responder_lifetime-issue-345.rs +++ b/core/lib/tests/responder_lifetime-issue-345.rs @@ -19,11 +19,11 @@ impl<'r, 'o: 'r, R: Responder<'r, 'o>> Responder<'r, 'o> for CustomResponder<'r, } #[get("/unit_state")] -fn unit_state(state: State) -> CustomResponder<()> { - CustomResponder { responder: (), state: state.inner() } +fn unit_state(state: &State) -> CustomResponder<()> { + CustomResponder { responder: (), state: &*state } } #[get("/string_state")] -fn string_state(state: State) -> CustomResponder { - CustomResponder { responder: "".to_string(), state: state.inner() } +fn string_state(state: &State) -> CustomResponder { + CustomResponder { responder: "".to_string(), state: &*state } } diff --git a/core/lib/tests/sentinel.rs b/core/lib/tests/sentinel.rs index d0ac364a..073cbf01 100644 --- a/core/lib/tests/sentinel.rs +++ b/core/lib/tests/sentinel.rs @@ -1,10 +1,10 @@ use rocket::{*, error::ErrorKind::SentinelAborts}; #[get("/two")] -fn two_states(_one: State, _two: State) {} +fn two_states(_one: &State, _two: &State) {} #[get("/one")] -fn one_state(_three: State) {} +fn one_state(_three: &State) {} #[async_test] async fn state_sentinel_works() { diff --git a/examples/config/src/main.rs b/examples/config/src/main.rs index 810337e6..b10d7963 100644 --- a/examples/config/src/main.rs +++ b/examples/config/src/main.rs @@ -14,7 +14,7 @@ struct AppConfig { } #[get("/")] -fn read_config(rocket_config: &Config, app_config: State<'_, AppConfig>) -> String { +fn read_config(rocket_config: &Config, app_config: &State) -> String { format!("{:#?}\n{:#?}", app_config, rocket_config) } diff --git a/examples/databases/src/sqlx.rs b/examples/databases/src/sqlx.rs index 350ba3cd..1f1bf572 100644 --- a/examples/databases/src/sqlx.rs +++ b/examples/databases/src/sqlx.rs @@ -21,19 +21,19 @@ struct Post { } #[post("/", data = "")] -async fn create(db: State<'_, Db>, post: Json) -> Result>> { +async fn create(db: &State, post: Json) -> Result>> { // There is no support for `RETURNING`. sqlx::query!("INSERT INTO posts (title, text) VALUES (?, ?)", post.title, post.text) - .execute(&*db) + .execute(&**db) .await?; Ok(Created::new("/").body(post)) } #[get("/")] -async fn list(db: State<'_, Db>) -> Result>> { +async fn list(db: &State) -> Result>> { let ids = sqlx::query!("SELECT id FROM posts") - .fetch(&*db) + .fetch(&**db) .map_ok(|record| record.id) .try_collect::>() .await?; @@ -42,26 +42,26 @@ async fn list(db: State<'_, Db>) -> Result>> { } #[get("/")] -async fn read(db: State<'_, Db>, id: i64) -> Option> { +async fn read(db: &State, id: i64) -> Option> { sqlx::query!("SELECT id, title, text FROM posts WHERE id = ?", id) - .fetch_one(&*db) + .fetch_one(&**db) .map_ok(|r| Json(Post { id: Some(r.id), title: r.title, text: r.text })) .await .ok() } #[delete("/")] -async fn delete(db: State<'_, Db>, id: i64) -> Result> { +async fn delete(db: &State, id: i64) -> Result> { let result = sqlx::query!("DELETE FROM posts WHERE id = ?", id) - .execute(&*db) + .execute(&**db) .await?; Ok((result.rows_affected() == 1).then(|| ())) } #[delete("/")] -async fn destroy(db: State<'_, Db>) -> Result<()> { - sqlx::query!("DELETE FROM posts").execute(&*db).await?; +async fn destroy(db: &State) -> Result<()> { + sqlx::query!("DELETE FROM posts").execute(&**db).await?; Ok(()) } diff --git a/examples/error-handling/src/main.rs b/examples/error-handling/src/main.rs index a82b73f3..b31cfe38 100644 --- a/examples/error-handling/src/main.rs +++ b/examples/error-handling/src/main.rs @@ -45,7 +45,7 @@ fn default_catcher(status: Status, req: &Request<'_>) -> status::Custom #[allow(dead_code)] #[get("/unmanaged")] -fn unmanaged(_u8: rocket::State<'_, u8>, _string: rocket::State<'_, String>) { } +fn unmanaged(_u8: &rocket::State, _string: &rocket::State) { } fn rocket() -> Rocket { rocket::build() diff --git a/examples/fairings/src/main.rs b/examples/fairings/src/main.rs index 2b769e7a..48ce44da 100644 --- a/examples/fairings/src/main.rs +++ b/examples/fairings/src/main.rs @@ -29,7 +29,7 @@ impl Fairing for Counter { async fn on_ignite(&self, rocket: Rocket) -> fairing::Result { #[get("/counts")] - fn counts(counts: State<'_, Counter>) -> String { + fn counts(counts: &State) -> String { let get_count = counts.get.load(Ordering::Relaxed); let post_count = counts.post.load(Ordering::Relaxed); format!("Get: {}\nPost: {}", get_count, post_count) @@ -53,7 +53,7 @@ fn hello() -> &'static str { } #[get("/token")] -fn token(token: State<'_, Token>) -> String { +fn token(token: &State) -> String { format!("{}", token.0) } diff --git a/examples/pastebin/src/main.rs b/examples/pastebin/src/main.rs index 996bf9c2..ac883e9c 100644 --- a/examples/pastebin/src/main.rs +++ b/examples/pastebin/src/main.rs @@ -17,7 +17,7 @@ const HOST: &str = "http://localhost:8000"; const ID_LENGTH: usize = 3; #[post("/", data = "")] -async fn upload(paste: Data, host: State<'_, Absolute<'_>>) -> io::Result { +async fn upload(paste: Data, host: &State>) -> io::Result { let id = PasteId::new(ID_LENGTH); paste.open(128.kibibytes()).into_file(id.file_path()).await?; diff --git a/examples/serialization/src/json.rs b/examples/serialization/src/json.rs index 640bd719..eda61879 100644 --- a/examples/serialization/src/json.rs +++ b/examples/serialization/src/json.rs @@ -11,7 +11,7 @@ type Id = usize; // We're going to store all of the messages here. No need for a DB. type MessageList = Mutex>; -type Messages<'r> = State<'r, MessageList>; +type Messages<'r> = &'r State; #[derive(Serialize, Deserialize)] struct Message<'r> { diff --git a/examples/state/src/managed_hit_count.rs b/examples/state/src/managed_hit_count.rs index 7b4c59db..6138eead 100644 --- a/examples/state/src/managed_hit_count.rs +++ b/examples/state/src/managed_hit_count.rs @@ -7,7 +7,7 @@ use rocket::fairing::AdHoc; struct HitCount(AtomicUsize); #[get("/")] -fn index(hit_count: State<'_, HitCount>) -> content::Html { +fn index(hit_count: &State) -> content::Html { let count = hit_count.0.fetch_add(1, Ordering::Relaxed) + 1; content::Html(format!("Your visit is recorded!

Visits: {}", count)) } diff --git a/examples/state/src/managed_queue.rs b/examples/state/src/managed_queue.rs index 4dcf919e..2cafff28 100644 --- a/examples/state/src/managed_queue.rs +++ b/examples/state/src/managed_queue.rs @@ -6,12 +6,12 @@ struct Tx(flume::Sender); struct Rx(flume::Receiver); #[put("/push?")] -fn push(event: String, tx: State<'_, Tx>) -> Result<(), Status> { +fn push(event: String, tx: &State) -> Result<(), Status> { tx.0.try_send(event).map_err(|_| Status::ServiceUnavailable) } #[get("/pop")] -fn pop(rx: State<'_, Rx>) -> Option { +fn pop(rx: &State) -> Option { rx.0.try_recv().ok() } diff --git a/examples/state/src/request_local.rs b/examples/state/src/request_local.rs index afe2254a..e6e1ed29 100644 --- a/examples/state/src/request_local.rs +++ b/examples/state/src/request_local.rs @@ -23,7 +23,7 @@ impl<'r> FromRequest<'r> for Guard1 { async fn from_request(req: &'r Request<'_>) -> request::Outcome { rocket::info_!("-- 1 --"); - let atomics = try_outcome!(req.guard::>().await); + let atomics = try_outcome!(req.guard::<&State>().await); atomics.uncached.fetch_add(1, Ordering::Relaxed); req.local_cache(|| { rocket::info_!("1: populating cache!"); @@ -53,7 +53,7 @@ impl<'r> FromRequest<'r> for Guard3 { async fn from_request(req: &'r Request<'_>) -> request::Outcome { rocket::info_!("-- 3 --"); - let atomics = try_outcome!(req.guard::>().await); + let atomics = try_outcome!(req.guard::<&State>().await); atomics.uncached.fetch_add(1, Ordering::Relaxed); req.local_cache_async(async { rocket::info_!("3: populating cache!"); @@ -77,12 +77,12 @@ impl<'r> FromRequest<'r> for Guard4 { } #[get("/1-2")] -fn one_two(_g1: Guard1, _g2: Guard2, state: State<'_, Atomics>) -> String { +fn one_two(_g1: Guard1, _g2: Guard2, state: &State) -> String { format!("{:#?}", state) } #[get("/3-4")] -fn three_four(_g3: Guard3, _g4: Guard4, state: State<'_, Atomics>) -> String { +fn three_four(_g3: Guard3, _g4: Guard4, state: &State) -> String { format!("{:#?}", state) } @@ -92,7 +92,7 @@ fn all( _g2: Guard2, _g3: Guard3, _g4: Guard4, - state: State<'_, Atomics> + state: &State ) -> String { format!("{:#?}", state) } diff --git a/examples/testing/src/async_required.rs b/examples/testing/src/async_required.rs index 7c9660f4..4dbcc3cd 100644 --- a/examples/testing/src/async_required.rs +++ b/examples/testing/src/async_required.rs @@ -3,7 +3,7 @@ use rocket::fairing::AdHoc; use rocket::tokio::sync::Barrier; #[get("/barrier")] -async fn rendezvous(barrier: State<'_, Barrier>) -> &'static str { +async fn rendezvous(barrier: &State) -> &'static str { println!("Waiting for second task..."); barrier.wait().await; "Rendezvous reached." diff --git a/examples/uuid/src/main.rs b/examples/uuid/src/main.rs index fc7e0f54..b999805b 100644 --- a/examples/uuid/src/main.rs +++ b/examples/uuid/src/main.rs @@ -15,7 +15,7 @@ use rocket_contrib::uuid::extern_uuid; struct People(HashMap); #[get("/people/")] -fn people(id: Uuid, people: State) -> Result { +fn people(id: Uuid, people: &State) -> Result { // Because Uuid implements the Deref trait, we use Deref coercion to convert // rocket_contrib::uuid::Uuid to uuid::Uuid. Ok(people.0.get(&id) diff --git a/site/guide/6-state.md b/site/guide/6-state.md index 7f25b1c5..72df1225 100644 --- a/site/guide/6-state.md +++ b/site/guide/6-state.md @@ -16,7 +16,7 @@ The process for using managed state is simple: 1. Call `manage` on the `Rocket` instance corresponding to your application with the initial value of the state. - 2. Add a `State` type to any request handler, where `T` is the type of the + 2. Add a `&State` type to any request handler, where `T` is the type of the value passed into `manage`. ! note: All managed state must be thread-safe. @@ -63,10 +63,10 @@ rocket::build() ### Retrieving State State that is being managed by Rocket can be retrieved via the -[`State`](@api/rocket/struct.State.html) type: a [request -guard](../requests/#request-guards) for managed state. To use the request -guard, add a `State` type to any request handler, where `T` is the type of -the managed state. For example, we can retrieve and respond with the current +[`&State`](@api/rocket/struct.State.html) type: a [request +guard](../requests/#request-guards) for managed state. To use the request guard, +add a `&State` type to any request handler, where `T` is the type of the +managed state. For example, we can retrieve and respond with the current `HitCount` in a `count` route as follows: ```rust @@ -79,13 +79,13 @@ the managed state. For example, we can retrieve and respond with the current use rocket::State; #[get("/count")] -fn count(hit_count: State) -> String { +fn count(hit_count: &State) -> String { let current_count = hit_count.count.load(Ordering::Relaxed); format!("Number of visits: {}", current_count) } ``` -You can retrieve more than one `State` type in a single route as well: +You can retrieve more than one `&State` type in a single route as well: ```rust # #[macro_use] extern crate rocket; @@ -96,52 +96,60 @@ You can retrieve more than one `State` type in a single route as well: # use rocket::State; #[get("/state")] -fn state(hit_count: State, config: State) { /* .. */ } +fn state(hit_count: &State, config: &State) { /* .. */ } ``` ! warning - If you request a `State` for a `T` that is not `managed`, Rocket won't call - the offending route. Instead, Rocket will log an error message and return a - **500** error to the client. + If you request a `&State` for a `T` that is not `managed`, Rocket will + refuse to start your application. This prevents what would have been an + unmanaged state runtime error. Unmanaged state is detected at runtime through + [_sentinels_](@api/rocket/trait.Sentinel.html), so there are limitations. If a + limitation is hit, Rocket still won't call an the offending route. Instead, + Rocket will log an error message and return a **500** error to the client. You can find a complete example using the `HitCount` structure in the [state example on GitHub](@example/state) and learn more about the [`manage` method](@api/rocket/struct.Rocket.html#method.manage) and [`State` type](@api/rocket/struct.State.html) in the API docs. -### Within Guards +# Within Guards -It can also be useful to retrieve managed state from a `FromRequest` -implementation. To do so, simply invoke `State` as a guard using the -[`Request::guard()`] method. +Because `State` is itself a request guard, managed state can be retrieved from +another request guard's implementation using either [`Request::guard()`] or +[`Rocket::state()`]. In the following code example, the `Item` request guard +retrieves `MyConfig` from managed state using both methods: ```rust -# #[macro_use] extern crate rocket; -# fn main() {} - use rocket::State; use rocket::request::{self, Request, FromRequest}; -use rocket::outcome::try_outcome; -# use std::sync::atomic::{AtomicUsize, Ordering}; +use rocket::outcome::IntoOutcome; + +# struct MyConfig { user_val: String }; +struct Item<'r>(&'r str); -# struct T; -# struct HitCount { count: AtomicUsize } -# type ErrorType = (); #[rocket::async_trait] -impl<'r> FromRequest<'r> for T { - type Error = ErrorType; +impl<'r> FromRequest<'r> for Item<'r> { + type Error = (); - async fn from_request(req: &'r Request<'_>) -> request::Outcome { - let hit_count_state = try_outcome!(req.guard::>().await); - let current_count = hit_count_state.count.load(Ordering::Relaxed); - /* ... */ - # request::Outcome::Success(T) + async fn from_request(request: &'r Request<'_>) -> request::Outcome { + // Using `State` as a request guard. Use `inner()` to get an `'r`. + let outcome = request.guard::<&State>().await + .map(|my_config| Item(&my_config.user_val)); + + // Or alternatively, using `Rocket::state()`: + let outcome = request.rocket().state::() + .map(|my_config| Item(&my_config.user_val)) + .or_forward(()); + + outcome } } ``` + [`Request::guard()`]: @api/rocket/struct.Request.html#method.guard +[`Rocket::state()`]: @api/rocket/struct.Rocket.html#method.state ## Request-Local State diff --git a/site/guide/9-configuration.md b/site/guide/9-configuration.md index 99a143dc..dac53657 100644 --- a/site/guide/9-configuration.md +++ b/site/guide/9-configuration.md @@ -253,7 +253,7 @@ Because it is common to store configuration in managed state, Rocket provides an use rocket::{State, fairing::AdHoc}; #[get("/custom")] -fn custom(config: State<'_, Config>) -> String { +fn custom(config: &State) -> String { config.custom.get(0).cloned().unwrap_or("default".into()) }