mirror of https://github.com/rwf2/Rocket.git
Introduce Managed State.
This commit is contained in:
parent
9ef65a8c91
commit
c815911705
|
@ -28,4 +28,5 @@ members = [
|
|||
"examples/hello_alt_methods",
|
||||
"examples/raw_upload",
|
||||
"examples/pastebin",
|
||||
"examples/state",
|
||||
]
|
||||
|
|
|
@ -40,13 +40,15 @@ mod test {
|
|||
use rocket::http::Header;
|
||||
|
||||
fn test_header_count<'h>(headers: Vec<Header<'static>>) {
|
||||
let rocket = rocket::ignite()
|
||||
.mount("/", routes![super::header_count]);
|
||||
|
||||
let num_headers = headers.len();
|
||||
let mut req = MockRequest::new(Get, "/");
|
||||
for header in headers {
|
||||
req = req.header(header);
|
||||
}
|
||||
|
||||
let rocket = rocket::ignite().mount("/", routes![super::header_count]);
|
||||
let mut response = req.dispatch_with(&rocket);
|
||||
|
||||
let expect = format!("Your request contained {} headers!", num_headers);
|
||||
|
|
|
@ -11,7 +11,8 @@ macro_rules! run_test {
|
|||
.mount("/", routes![super::index, super::get])
|
||||
.catch(errors![super::not_found]);
|
||||
|
||||
$test_fn($req.dispatch_with(&rocket));
|
||||
let mut req = $req;
|
||||
$test_fn(req.dispatch_with(&rocket));
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -19,7 +20,7 @@ macro_rules! run_test {
|
|||
fn test_root() {
|
||||
// Check that the redirect works.
|
||||
for method in &[Get, Head] {
|
||||
let mut req = MockRequest::new(*method, "/");
|
||||
let req = MockRequest::new(*method, "/");
|
||||
run_test!(req, |mut response: Response| {
|
||||
assert_eq!(response.status(), Status::SeeOther);
|
||||
assert!(response.body().is_none());
|
||||
|
@ -31,7 +32,7 @@ fn test_root() {
|
|||
|
||||
// Check that other request methods are not accepted (and instead caught).
|
||||
for method in &[Post, Put, Delete, Options, Trace, Connect, Patch] {
|
||||
let mut req = MockRequest::new(*method, "/");
|
||||
let req = MockRequest::new(*method, "/");
|
||||
run_test!(req, |mut response: Response| {
|
||||
assert_eq!(response.status(), Status::NotFound);
|
||||
|
||||
|
@ -48,7 +49,7 @@ fn test_root() {
|
|||
#[test]
|
||||
fn test_name() {
|
||||
// Check that the /hello/<name> route works.
|
||||
let mut req = MockRequest::new(Get, "/hello/Jack");
|
||||
let req = MockRequest::new(Get, "/hello/Jack");
|
||||
run_test!(req, |mut response: Response| {
|
||||
assert_eq!(response.status(), Status::Ok);
|
||||
|
||||
|
@ -66,7 +67,7 @@ fn test_name() {
|
|||
#[test]
|
||||
fn test_404() {
|
||||
// Check that the error catcher works.
|
||||
let mut req = MockRequest::new(Get, "/hello/");
|
||||
let req = MockRequest::new(Get, "/hello/");
|
||||
run_test!(req, |mut response: Response| {
|
||||
assert_eq!(response.status(), Status::NotFound);
|
||||
|
||||
|
|
|
@ -10,14 +10,15 @@ macro_rules! run_test {
|
|||
.mount("/message", routes![super::new, super::update, super::get])
|
||||
.catch(errors![super::not_found]);
|
||||
|
||||
$test_fn($req.dispatch_with(&rocket));
|
||||
let mut req = $req;
|
||||
$test_fn(req.dispatch_with(&rocket));
|
||||
})
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn bad_get_put() {
|
||||
// Try to get a message with an ID that doesn't exist.
|
||||
let mut req = MockRequest::new(Get, "/message/99").header(ContentType::JSON);
|
||||
let req = MockRequest::new(Get, "/message/99").header(ContentType::JSON);
|
||||
run_test!(req, |mut response: Response| {
|
||||
assert_eq!(response.status(), Status::NotFound);
|
||||
|
||||
|
@ -27,7 +28,7 @@ fn bad_get_put() {
|
|||
});
|
||||
|
||||
// Try to get a message with an invalid ID.
|
||||
let mut req = MockRequest::new(Get, "/message/hi").header(ContentType::JSON);
|
||||
let req = MockRequest::new(Get, "/message/hi").header(ContentType::JSON);
|
||||
run_test!(req, |mut response: Response| {
|
||||
assert_eq!(response.status(), Status::NotFound);
|
||||
let body = response.body().unwrap().into_string().unwrap();
|
||||
|
@ -35,13 +36,13 @@ fn bad_get_put() {
|
|||
});
|
||||
|
||||
// Try to put a message without a proper body.
|
||||
let mut req = MockRequest::new(Put, "/message/80").header(ContentType::JSON);
|
||||
let req = MockRequest::new(Put, "/message/80").header(ContentType::JSON);
|
||||
run_test!(req, |response: Response| {
|
||||
assert_eq!(response.status(), Status::BadRequest);
|
||||
});
|
||||
|
||||
// Try to put a message for an ID that doesn't exist.
|
||||
let mut req = MockRequest::new(Put, "/message/80")
|
||||
let req = MockRequest::new(Put, "/message/80")
|
||||
.header(ContentType::JSON)
|
||||
.body(r#"{ "contents": "Bye bye, world!" }"#);
|
||||
|
||||
|
@ -53,13 +54,13 @@ fn bad_get_put() {
|
|||
#[test]
|
||||
fn post_get_put_get() {
|
||||
// Check that a message with ID 1 doesn't exist.
|
||||
let mut req = MockRequest::new(Get, "/message/1").header(ContentType::JSON);
|
||||
let req = MockRequest::new(Get, "/message/1").header(ContentType::JSON);
|
||||
run_test!(req, |response: Response| {
|
||||
assert_eq!(response.status(), Status::NotFound);
|
||||
});
|
||||
|
||||
// Add a new message with ID 1.
|
||||
let mut req = MockRequest::new(Post, "/message/1")
|
||||
let req = MockRequest::new(Post, "/message/1")
|
||||
.header(ContentType::JSON)
|
||||
.body(r#"{ "contents": "Hello, world!" }"#);
|
||||
|
||||
|
@ -68,7 +69,7 @@ fn post_get_put_get() {
|
|||
});
|
||||
|
||||
// Check that the message exists with the correct contents.
|
||||
let mut req = MockRequest::new(Get, "/message/1") .header(ContentType::JSON);
|
||||
let req = MockRequest::new(Get, "/message/1") .header(ContentType::JSON);
|
||||
run_test!(req, |mut response: Response| {
|
||||
assert_eq!(response.status(), Status::Ok);
|
||||
let body = response.body().unwrap().into_string().unwrap();
|
||||
|
@ -76,7 +77,7 @@ fn post_get_put_get() {
|
|||
});
|
||||
|
||||
// Change the message contents.
|
||||
let mut req = MockRequest::new(Put, "/message/1")
|
||||
let req = MockRequest::new(Put, "/message/1")
|
||||
.header(ContentType::JSON)
|
||||
.body(r#"{ "contents": "Bye bye, world!" }"#);
|
||||
|
||||
|
@ -85,7 +86,7 @@ fn post_get_put_get() {
|
|||
});
|
||||
|
||||
// Check that the message exists with the updated contents.
|
||||
let mut req = MockRequest::new(Get, "/message/1") .header(ContentType::JSON);
|
||||
let req = MockRequest::new(Get, "/message/1") .header(ContentType::JSON);
|
||||
run_test!(req, |mut response: Response| {
|
||||
assert_eq!(response.status(), Status::Ok);
|
||||
let body = response.body().unwrap().into_string().unwrap();
|
||||
|
|
|
@ -0,0 +1,11 @@
|
|||
[package]
|
||||
name = "state"
|
||||
version = "0.0.1"
|
||||
workspace = "../../"
|
||||
|
||||
[dependencies]
|
||||
rocket = { path = "../../lib" }
|
||||
rocket_codegen = { path = "../../codegen" }
|
||||
|
||||
[dev-dependencies]
|
||||
rocket = { path = "../../lib", features = ["testing"] }
|
|
@ -0,0 +1,36 @@
|
|||
#![feature(plugin)]
|
||||
#![plugin(rocket_codegen)]
|
||||
|
||||
extern crate rocket;
|
||||
|
||||
#[cfg(test)] mod tests;
|
||||
|
||||
use std::sync::atomic::{AtomicUsize, Ordering};
|
||||
|
||||
use rocket::State;
|
||||
use rocket::response::content;
|
||||
|
||||
struct HitCount(AtomicUsize);
|
||||
|
||||
#[get("/")]
|
||||
fn index(hit_count: State<HitCount>) -> content::HTML<String> {
|
||||
hit_count.0.fetch_add(1, Ordering::Relaxed);
|
||||
let msg = "Your visit has been recorded!";
|
||||
let count = format!("Visits: {}", count(hit_count));
|
||||
content::HTML(format!("{}<br /><br />{}", msg, count))
|
||||
}
|
||||
|
||||
#[get("/count")]
|
||||
fn count(hit_count: State<HitCount>) -> String {
|
||||
hit_count.0.load(Ordering::Relaxed).to_string()
|
||||
}
|
||||
|
||||
fn rocket() -> rocket::Rocket {
|
||||
rocket::ignite()
|
||||
.mount("/", routes![index, count])
|
||||
.manage(HitCount(AtomicUsize::new(0)))
|
||||
}
|
||||
|
||||
fn main() {
|
||||
rocket().launch();
|
||||
}
|
|
@ -0,0 +1,43 @@
|
|||
use rocket::Rocket;
|
||||
use rocket::testing::MockRequest;
|
||||
use rocket::http::Method::*;
|
||||
use rocket::http::Status;
|
||||
|
||||
fn register_hit(rocket: &Rocket) {
|
||||
let mut req = MockRequest::new(Get, "/");
|
||||
let response = req.dispatch_with(&rocket);
|
||||
assert_eq!(response.status(), Status::Ok);
|
||||
}
|
||||
|
||||
fn get_count(rocket: &Rocket) -> usize {
|
||||
let mut req = MockRequest::new(Get, "/count");
|
||||
let mut response = req.dispatch_with(&rocket);
|
||||
let body_string = response.body().and_then(|b| b.into_string()).unwrap();
|
||||
body_string.parse().unwrap()
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_count() {
|
||||
let rocket = super::rocket();
|
||||
|
||||
// Count should start at 0.
|
||||
assert_eq!(get_count(&rocket), 0);
|
||||
|
||||
for _ in 0..99 { register_hit(&rocket); }
|
||||
assert_eq!(get_count(&rocket), 99);
|
||||
|
||||
register_hit(&rocket);
|
||||
assert_eq!(get_count(&rocket), 100);
|
||||
}
|
||||
|
||||
// Cargo runs each test in parallel on different threads. We use all of these
|
||||
// tests below to show (and assert) that state is managed per-Rocket instance.
|
||||
#[test] fn test_count_parallel() { test_count() }
|
||||
#[test] fn test_count_parallel_2() { test_count() }
|
||||
#[test] fn test_count_parallel_3() { test_count() }
|
||||
#[test] fn test_count_parallel_4() { test_count() }
|
||||
#[test] fn test_count_parallel_5() { test_count() }
|
||||
#[test] fn test_count_parallel_6() { test_count() }
|
||||
#[test] fn test_count_parallel_7() { test_count() }
|
||||
#[test] fn test_count_parallel_8() { test_count() }
|
||||
#[test] fn test_count_parallel_9() { test_count() }
|
|
@ -21,6 +21,7 @@ url = "^1"
|
|||
hyper = { version = "^0.9.14", default-features = false }
|
||||
toml = { version = "^0.2", default-features = false }
|
||||
num_cpus = "1"
|
||||
state = "^0.2"
|
||||
# cookie = "^0.3"
|
||||
|
||||
[dev-dependencies]
|
||||
|
|
|
@ -96,6 +96,7 @@ extern crate hyper;
|
|||
extern crate url;
|
||||
extern crate toml;
|
||||
extern crate num_cpus;
|
||||
extern crate state;
|
||||
|
||||
#[cfg(test)] #[macro_use] extern crate lazy_static;
|
||||
|
||||
|
@ -123,7 +124,7 @@ mod ext;
|
|||
#[doc(inline)] pub use outcome::Outcome;
|
||||
#[doc(inline)] pub use data::Data;
|
||||
pub use router::Route;
|
||||
pub use request::Request;
|
||||
pub use request::{Request, State};
|
||||
pub use error::Error;
|
||||
pub use catcher::Catcher;
|
||||
pub use rocket::Rocket;
|
||||
|
|
|
@ -4,11 +4,13 @@ mod request;
|
|||
mod param;
|
||||
mod form;
|
||||
mod from_request;
|
||||
mod state;
|
||||
|
||||
pub use self::request::Request;
|
||||
pub use self::from_request::{FromRequest, Outcome};
|
||||
pub use self::param::{FromParam, FromSegments};
|
||||
pub use self::form::{Form, FromForm, FromFormValue, FormItems};
|
||||
pub use self::state::State;
|
||||
|
||||
/// Type alias to retrieve flash messages from a request.
|
||||
pub type FlashMessage = ::response::Flash<()>;
|
||||
|
|
|
@ -5,6 +5,8 @@ use std::fmt;
|
|||
use term_painter::Color::*;
|
||||
use term_painter::ToStyle;
|
||||
|
||||
use state::Container;
|
||||
|
||||
use error::Error;
|
||||
use super::{FromParam, FromSegments};
|
||||
|
||||
|
@ -28,6 +30,7 @@ pub struct Request<'r> {
|
|||
remote: Option<SocketAddr>,
|
||||
params: RefCell<Vec<(usize, usize)>>,
|
||||
cookies: Cookies,
|
||||
state: Option<&'r Container>,
|
||||
}
|
||||
|
||||
impl<'r> Request<'r> {
|
||||
|
@ -51,6 +54,7 @@ impl<'r> Request<'r> {
|
|||
remote: None,
|
||||
params: RefCell::new(Vec::new()),
|
||||
cookies: Cookies::new(&[]),
|
||||
state: None
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -391,13 +395,25 @@ impl<'r> Request<'r> {
|
|||
Some(Segments(&path[i..j]))
|
||||
}
|
||||
|
||||
/// Get the managed state container, if it exists. For internal use only!
|
||||
#[doc(hidden)]
|
||||
pub fn get_state(&self) -> Option<&'r Container> {
|
||||
self.state
|
||||
}
|
||||
|
||||
/// Set the state. For internal use only!
|
||||
#[doc(hidden)]
|
||||
pub fn set_state(&mut self, state: &'r Container) {
|
||||
self.state = Some(state);
|
||||
}
|
||||
|
||||
/// Convert from Hyper types into a Rocket Request.
|
||||
#[doc(hidden)]
|
||||
pub fn from_hyp(h_method: hyper::Method,
|
||||
h_headers: hyper::header::Headers,
|
||||
h_uri: hyper::RequestUri,
|
||||
h_addr: SocketAddr,
|
||||
) -> Result<Request<'static>, String> {
|
||||
) -> Result<Request<'r>, String> {
|
||||
// Get a copy of the URI for later use.
|
||||
let uri = match h_uri {
|
||||
hyper::RequestUri::AbsolutePath(s) => s,
|
||||
|
|
|
@ -0,0 +1,49 @@
|
|||
use std::ops::Deref;
|
||||
|
||||
use request::{self, FromRequest, Request};
|
||||
use outcome::Outcome;
|
||||
use http::Status;
|
||||
|
||||
// TODO: Doc.
|
||||
#[derive(Debug, PartialEq, Eq)]
|
||||
pub struct State<'r, T: Send + Sync + 'static>(&'r T);
|
||||
|
||||
impl<'r, T: Send + Sync + 'static> State<'r, T> {
|
||||
/// Retrieve a borrow to the underyling value.
|
||||
///
|
||||
/// Using this method is typically unnecessary as `State` implements `Deref`
|
||||
/// with a `Target` of `T`. This means Rocket will automatically coerce a
|
||||
/// `State<T>` to an `&T` when the types call for it.
|
||||
pub fn inner(&self) -> &'r T {
|
||||
self.0
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: Doc.
|
||||
impl<'a, 'r, T: Send + Sync + 'static> FromRequest<'a, 'r> for State<'r, T> {
|
||||
type Error = ();
|
||||
|
||||
fn from_request(req: &'a Request<'r>) -> request::Outcome<State<'r, T>, ()> {
|
||||
if let Some(state) = req.get_state() {
|
||||
match state.try_get::<T>() {
|
||||
Some(state) => Outcome::Success(State(state)),
|
||||
None => {
|
||||
error_!("Attempted to retrieve unmanaged state!");
|
||||
Outcome::Failure((Status::InternalServerError, ()))
|
||||
}
|
||||
}
|
||||
} else {
|
||||
error_!("Internal Rocket error: managed state is unset!");
|
||||
error_!("Please report this error in the Rocket GitHub issue tracker.");
|
||||
Outcome::Failure((Status::InternalServerError, ()))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'r, T: Send + Sync + 'static> Deref for State<'r, T> {
|
||||
type Target = T;
|
||||
|
||||
fn deref(&self) -> &T {
|
||||
self.0
|
||||
}
|
||||
}
|
|
@ -7,6 +7,8 @@ use std::io::{self, Write};
|
|||
use term_painter::Color::*;
|
||||
use term_painter::ToStyle;
|
||||
|
||||
use state::Container;
|
||||
|
||||
use {logger, handler};
|
||||
use ext::ReadExt;
|
||||
use config::{self, Config};
|
||||
|
@ -29,6 +31,7 @@ pub struct Rocket {
|
|||
router: Router,
|
||||
default_catchers: HashMap<u16, Catcher>,
|
||||
catchers: HashMap<u16, Catcher>,
|
||||
state: Container
|
||||
}
|
||||
|
||||
#[doc(hidden)]
|
||||
|
@ -175,9 +178,13 @@ impl Rocket {
|
|||
|
||||
#[doc(hidden)]
|
||||
#[inline(always)]
|
||||
pub fn dispatch<'r>(&self, request: &'r mut Request, data: Data) -> Response<'r> {
|
||||
pub fn dispatch<'s, 'r>(&'s self, request: &'r mut Request<'s>, data: Data)
|
||||
-> Response<'r> {
|
||||
info!("{}:", request);
|
||||
|
||||
// Inform the request about the state.
|
||||
request.set_state(&self.state);
|
||||
|
||||
// Do a bit of preprocessing before routing.
|
||||
self.preprocess_request(request, &data);
|
||||
|
||||
|
@ -353,6 +360,7 @@ impl Rocket {
|
|||
router: Router::new(),
|
||||
default_catchers: catcher::defaults::get(),
|
||||
catchers: catcher::defaults::get(),
|
||||
state: Container::new()
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -472,6 +480,50 @@ impl Rocket {
|
|||
self
|
||||
}
|
||||
|
||||
/// Add `state` to the state managed by this instance of Rocket.
|
||||
///
|
||||
/// Managed state can be retrieved by any request handler via the
|
||||
/// [State](/rocket/struct.State.html) request guard. In particular, if a
|
||||
/// value of type `T` is managed by Rocket, adding `State<T>` to the list of
|
||||
/// arguments in a request handler instructs Rocket to retrieve the managed
|
||||
/// value.
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// Panics if state of type `T` is already being managed.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust
|
||||
/// # #![feature(plugin)]
|
||||
/// # #![plugin(rocket_codegen)]
|
||||
/// # extern crate rocket;
|
||||
/// use rocket::State;
|
||||
///
|
||||
/// struct MyValue(usize);
|
||||
///
|
||||
/// #[get("/")]
|
||||
/// fn index(state: State<MyValue>) -> String {
|
||||
/// format!("The stateful value is: {}", state.0)
|
||||
/// }
|
||||
///
|
||||
/// fn main() {
|
||||
/// # if false { // We don't actually want to launch the server in an example.
|
||||
/// rocket::ignite()
|
||||
/// .manage(MyValue(10))
|
||||
/// # .launch()
|
||||
/// # }
|
||||
/// }
|
||||
/// ```
|
||||
pub fn manage<T: Send + Sync + 'static>(self, state: T) -> Self {
|
||||
if !self.state.set::<T>(state) {
|
||||
error!("State for this type is already being managed!");
|
||||
panic!("Aborting due to duplicately managed state.");
|
||||
}
|
||||
|
||||
self
|
||||
}
|
||||
|
||||
/// Starts the application server and begins listening for and dispatching
|
||||
/// requests to mounted routes and catchers.
|
||||
///
|
||||
|
|
|
@ -111,12 +111,12 @@ use http::{Method, Header, Cookie};
|
|||
use std::net::SocketAddr;
|
||||
|
||||
/// A type for mocking requests for testing Rocket applications.
|
||||
pub struct MockRequest {
|
||||
request: Request<'static>,
|
||||
pub struct MockRequest<'r> {
|
||||
request: Request<'r>,
|
||||
data: Data
|
||||
}
|
||||
|
||||
impl MockRequest {
|
||||
impl<'r> MockRequest<'r> {
|
||||
/// Constructs a new mocked request with the given `method` and `uri`.
|
||||
#[inline]
|
||||
pub fn new<S: AsRef<str>>(method: Method, uri: S) -> Self {
|
||||
|
@ -259,7 +259,7 @@ impl MockRequest {
|
|||
/// assert_eq!(body_str, Some("Hello, world!".to_string()));
|
||||
/// # }
|
||||
/// ```
|
||||
pub fn dispatch_with<'r>(&'r mut self, rocket: &Rocket) -> Response<'r> {
|
||||
pub fn dispatch_with<'s>(&'s mut self, rocket: &'r Rocket) -> Response<'s> {
|
||||
let data = ::std::mem::replace(&mut self.data, Data::new(vec![]));
|
||||
rocket.dispatch(&mut self.request, data)
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue