Introduce Managed State.

This commit is contained in:
Sergio Benitez 2017-01-20 19:31:46 -08:00
parent 9ef65a8c91
commit c815911705
14 changed files with 239 additions and 23 deletions

View File

@ -28,4 +28,5 @@ members = [
"examples/hello_alt_methods",
"examples/raw_upload",
"examples/pastebin",
"examples/state",
]

View File

@ -40,13 +40,15 @@ mod test {
use rocket::http::Header;
fn test_header_count<'h>(headers: Vec<Header<'static>>) {
let rocket = rocket::ignite()
.mount("/", routes![super::header_count]);
let num_headers = headers.len();
let mut req = MockRequest::new(Get, "/");
for header in headers {
req = req.header(header);
}
let rocket = rocket::ignite().mount("/", routes![super::header_count]);
let mut response = req.dispatch_with(&rocket);
let expect = format!("Your request contained {} headers!", num_headers);

View File

@ -11,7 +11,8 @@ macro_rules! run_test {
.mount("/", routes![super::index, super::get])
.catch(errors![super::not_found]);
$test_fn($req.dispatch_with(&rocket));
let mut req = $req;
$test_fn(req.dispatch_with(&rocket));
})
}
@ -19,7 +20,7 @@ macro_rules! run_test {
fn test_root() {
// Check that the redirect works.
for method in &[Get, Head] {
let mut req = MockRequest::new(*method, "/");
let req = MockRequest::new(*method, "/");
run_test!(req, |mut response: Response| {
assert_eq!(response.status(), Status::SeeOther);
assert!(response.body().is_none());
@ -31,7 +32,7 @@ fn test_root() {
// Check that other request methods are not accepted (and instead caught).
for method in &[Post, Put, Delete, Options, Trace, Connect, Patch] {
let mut req = MockRequest::new(*method, "/");
let req = MockRequest::new(*method, "/");
run_test!(req, |mut response: Response| {
assert_eq!(response.status(), Status::NotFound);
@ -48,7 +49,7 @@ fn test_root() {
#[test]
fn test_name() {
// Check that the /hello/<name> route works.
let mut req = MockRequest::new(Get, "/hello/Jack");
let req = MockRequest::new(Get, "/hello/Jack");
run_test!(req, |mut response: Response| {
assert_eq!(response.status(), Status::Ok);
@ -66,7 +67,7 @@ fn test_name() {
#[test]
fn test_404() {
// Check that the error catcher works.
let mut req = MockRequest::new(Get, "/hello/");
let req = MockRequest::new(Get, "/hello/");
run_test!(req, |mut response: Response| {
assert_eq!(response.status(), Status::NotFound);

View File

@ -10,14 +10,15 @@ macro_rules! run_test {
.mount("/message", routes![super::new, super::update, super::get])
.catch(errors![super::not_found]);
$test_fn($req.dispatch_with(&rocket));
let mut req = $req;
$test_fn(req.dispatch_with(&rocket));
})
}
#[test]
fn bad_get_put() {
// Try to get a message with an ID that doesn't exist.
let mut req = MockRequest::new(Get, "/message/99").header(ContentType::JSON);
let req = MockRequest::new(Get, "/message/99").header(ContentType::JSON);
run_test!(req, |mut response: Response| {
assert_eq!(response.status(), Status::NotFound);
@ -27,7 +28,7 @@ fn bad_get_put() {
});
// Try to get a message with an invalid ID.
let mut req = MockRequest::new(Get, "/message/hi").header(ContentType::JSON);
let req = MockRequest::new(Get, "/message/hi").header(ContentType::JSON);
run_test!(req, |mut response: Response| {
assert_eq!(response.status(), Status::NotFound);
let body = response.body().unwrap().into_string().unwrap();
@ -35,13 +36,13 @@ fn bad_get_put() {
});
// Try to put a message without a proper body.
let mut req = MockRequest::new(Put, "/message/80").header(ContentType::JSON);
let req = MockRequest::new(Put, "/message/80").header(ContentType::JSON);
run_test!(req, |response: Response| {
assert_eq!(response.status(), Status::BadRequest);
});
// Try to put a message for an ID that doesn't exist.
let mut req = MockRequest::new(Put, "/message/80")
let req = MockRequest::new(Put, "/message/80")
.header(ContentType::JSON)
.body(r#"{ "contents": "Bye bye, world!" }"#);
@ -53,13 +54,13 @@ fn bad_get_put() {
#[test]
fn post_get_put_get() {
// Check that a message with ID 1 doesn't exist.
let mut req = MockRequest::new(Get, "/message/1").header(ContentType::JSON);
let req = MockRequest::new(Get, "/message/1").header(ContentType::JSON);
run_test!(req, |response: Response| {
assert_eq!(response.status(), Status::NotFound);
});
// Add a new message with ID 1.
let mut req = MockRequest::new(Post, "/message/1")
let req = MockRequest::new(Post, "/message/1")
.header(ContentType::JSON)
.body(r#"{ "contents": "Hello, world!" }"#);
@ -68,7 +69,7 @@ fn post_get_put_get() {
});
// Check that the message exists with the correct contents.
let mut req = MockRequest::new(Get, "/message/1") .header(ContentType::JSON);
let req = MockRequest::new(Get, "/message/1") .header(ContentType::JSON);
run_test!(req, |mut response: Response| {
assert_eq!(response.status(), Status::Ok);
let body = response.body().unwrap().into_string().unwrap();
@ -76,7 +77,7 @@ fn post_get_put_get() {
});
// Change the message contents.
let mut req = MockRequest::new(Put, "/message/1")
let req = MockRequest::new(Put, "/message/1")
.header(ContentType::JSON)
.body(r#"{ "contents": "Bye bye, world!" }"#);
@ -85,7 +86,7 @@ fn post_get_put_get() {
});
// Check that the message exists with the updated contents.
let mut req = MockRequest::new(Get, "/message/1") .header(ContentType::JSON);
let req = MockRequest::new(Get, "/message/1") .header(ContentType::JSON);
run_test!(req, |mut response: Response| {
assert_eq!(response.status(), Status::Ok);
let body = response.body().unwrap().into_string().unwrap();

11
examples/state/Cargo.toml Normal file
View File

@ -0,0 +1,11 @@
[package]
name = "state"
version = "0.0.1"
workspace = "../../"
[dependencies]
rocket = { path = "../../lib" }
rocket_codegen = { path = "../../codegen" }
[dev-dependencies]
rocket = { path = "../../lib", features = ["testing"] }

View File

@ -0,0 +1,36 @@
#![feature(plugin)]
#![plugin(rocket_codegen)]
extern crate rocket;
#[cfg(test)] mod tests;
use std::sync::atomic::{AtomicUsize, Ordering};
use rocket::State;
use rocket::response::content;
struct HitCount(AtomicUsize);
#[get("/")]
fn index(hit_count: State<HitCount>) -> content::HTML<String> {
hit_count.0.fetch_add(1, Ordering::Relaxed);
let msg = "Your visit has been recorded!";
let count = format!("Visits: {}", count(hit_count));
content::HTML(format!("{}<br /><br />{}", msg, count))
}
#[get("/count")]
fn count(hit_count: State<HitCount>) -> String {
hit_count.0.load(Ordering::Relaxed).to_string()
}
fn rocket() -> rocket::Rocket {
rocket::ignite()
.mount("/", routes![index, count])
.manage(HitCount(AtomicUsize::new(0)))
}
fn main() {
rocket().launch();
}

View File

@ -0,0 +1,43 @@
use rocket::Rocket;
use rocket::testing::MockRequest;
use rocket::http::Method::*;
use rocket::http::Status;
fn register_hit(rocket: &Rocket) {
let mut req = MockRequest::new(Get, "/");
let response = req.dispatch_with(&rocket);
assert_eq!(response.status(), Status::Ok);
}
fn get_count(rocket: &Rocket) -> usize {
let mut req = MockRequest::new(Get, "/count");
let mut response = req.dispatch_with(&rocket);
let body_string = response.body().and_then(|b| b.into_string()).unwrap();
body_string.parse().unwrap()
}
#[test]
fn test_count() {
let rocket = super::rocket();
// Count should start at 0.
assert_eq!(get_count(&rocket), 0);
for _ in 0..99 { register_hit(&rocket); }
assert_eq!(get_count(&rocket), 99);
register_hit(&rocket);
assert_eq!(get_count(&rocket), 100);
}
// Cargo runs each test in parallel on different threads. We use all of these
// tests below to show (and assert) that state is managed per-Rocket instance.
#[test] fn test_count_parallel() { test_count() }
#[test] fn test_count_parallel_2() { test_count() }
#[test] fn test_count_parallel_3() { test_count() }
#[test] fn test_count_parallel_4() { test_count() }
#[test] fn test_count_parallel_5() { test_count() }
#[test] fn test_count_parallel_6() { test_count() }
#[test] fn test_count_parallel_7() { test_count() }
#[test] fn test_count_parallel_8() { test_count() }
#[test] fn test_count_parallel_9() { test_count() }

View File

@ -21,6 +21,7 @@ url = "^1"
hyper = { version = "^0.9.14", default-features = false }
toml = { version = "^0.2", default-features = false }
num_cpus = "1"
state = "^0.2"
# cookie = "^0.3"
[dev-dependencies]

View File

@ -96,6 +96,7 @@ extern crate hyper;
extern crate url;
extern crate toml;
extern crate num_cpus;
extern crate state;
#[cfg(test)] #[macro_use] extern crate lazy_static;
@ -123,7 +124,7 @@ mod ext;
#[doc(inline)] pub use outcome::Outcome;
#[doc(inline)] pub use data::Data;
pub use router::Route;
pub use request::Request;
pub use request::{Request, State};
pub use error::Error;
pub use catcher::Catcher;
pub use rocket::Rocket;

View File

@ -4,11 +4,13 @@ mod request;
mod param;
mod form;
mod from_request;
mod state;
pub use self::request::Request;
pub use self::from_request::{FromRequest, Outcome};
pub use self::param::{FromParam, FromSegments};
pub use self::form::{Form, FromForm, FromFormValue, FormItems};
pub use self::state::State;
/// Type alias to retrieve flash messages from a request.
pub type FlashMessage = ::response::Flash<()>;

View File

@ -5,6 +5,8 @@ use std::fmt;
use term_painter::Color::*;
use term_painter::ToStyle;
use state::Container;
use error::Error;
use super::{FromParam, FromSegments};
@ -28,6 +30,7 @@ pub struct Request<'r> {
remote: Option<SocketAddr>,
params: RefCell<Vec<(usize, usize)>>,
cookies: Cookies,
state: Option<&'r Container>,
}
impl<'r> Request<'r> {
@ -51,6 +54,7 @@ impl<'r> Request<'r> {
remote: None,
params: RefCell::new(Vec::new()),
cookies: Cookies::new(&[]),
state: None
}
}
@ -391,13 +395,25 @@ impl<'r> Request<'r> {
Some(Segments(&path[i..j]))
}
/// Get the managed state container, if it exists. For internal use only!
#[doc(hidden)]
pub fn get_state(&self) -> Option<&'r Container> {
self.state
}
/// Set the state. For internal use only!
#[doc(hidden)]
pub fn set_state(&mut self, state: &'r Container) {
self.state = Some(state);
}
/// Convert from Hyper types into a Rocket Request.
#[doc(hidden)]
pub fn from_hyp(h_method: hyper::Method,
h_headers: hyper::header::Headers,
h_uri: hyper::RequestUri,
h_addr: SocketAddr,
) -> Result<Request<'static>, String> {
) -> Result<Request<'r>, String> {
// Get a copy of the URI for later use.
let uri = match h_uri {
hyper::RequestUri::AbsolutePath(s) => s,

49
lib/src/request/state.rs Normal file
View File

@ -0,0 +1,49 @@
use std::ops::Deref;
use request::{self, FromRequest, Request};
use outcome::Outcome;
use http::Status;
// TODO: Doc.
#[derive(Debug, PartialEq, Eq)]
pub struct State<'r, T: Send + Sync + 'static>(&'r T);
impl<'r, T: Send + Sync + 'static> State<'r, T> {
/// Retrieve a borrow to the underyling value.
///
/// Using this method is typically unnecessary as `State` implements `Deref`
/// with a `Target` of `T`. This means Rocket will automatically coerce a
/// `State<T>` to an `&T` when the types call for it.
pub fn inner(&self) -> &'r T {
self.0
}
}
// TODO: Doc.
impl<'a, 'r, T: Send + Sync + 'static> FromRequest<'a, 'r> for State<'r, T> {
type Error = ();
fn from_request(req: &'a Request<'r>) -> request::Outcome<State<'r, T>, ()> {
if let Some(state) = req.get_state() {
match state.try_get::<T>() {
Some(state) => Outcome::Success(State(state)),
None => {
error_!("Attempted to retrieve unmanaged state!");
Outcome::Failure((Status::InternalServerError, ()))
}
}
} else {
error_!("Internal Rocket error: managed state is unset!");
error_!("Please report this error in the Rocket GitHub issue tracker.");
Outcome::Failure((Status::InternalServerError, ()))
}
}
}
impl<'r, T: Send + Sync + 'static> Deref for State<'r, T> {
type Target = T;
fn deref(&self) -> &T {
self.0
}
}

View File

@ -7,6 +7,8 @@ use std::io::{self, Write};
use term_painter::Color::*;
use term_painter::ToStyle;
use state::Container;
use {logger, handler};
use ext::ReadExt;
use config::{self, Config};
@ -29,6 +31,7 @@ pub struct Rocket {
router: Router,
default_catchers: HashMap<u16, Catcher>,
catchers: HashMap<u16, Catcher>,
state: Container
}
#[doc(hidden)]
@ -175,9 +178,13 @@ impl Rocket {
#[doc(hidden)]
#[inline(always)]
pub fn dispatch<'r>(&self, request: &'r mut Request, data: Data) -> Response<'r> {
pub fn dispatch<'s, 'r>(&'s self, request: &'r mut Request<'s>, data: Data)
-> Response<'r> {
info!("{}:", request);
// Inform the request about the state.
request.set_state(&self.state);
// Do a bit of preprocessing before routing.
self.preprocess_request(request, &data);
@ -353,6 +360,7 @@ impl Rocket {
router: Router::new(),
default_catchers: catcher::defaults::get(),
catchers: catcher::defaults::get(),
state: Container::new()
}
}
@ -472,6 +480,50 @@ impl Rocket {
self
}
/// Add `state` to the state managed by this instance of Rocket.
///
/// Managed state can be retrieved by any request handler via the
/// [State](/rocket/struct.State.html) request guard. In particular, if a
/// value of type `T` is managed by Rocket, adding `State<T>` to the list of
/// arguments in a request handler instructs Rocket to retrieve the managed
/// value.
///
/// # Panics
///
/// Panics if state of type `T` is already being managed.
///
/// # Example
///
/// ```rust
/// # #![feature(plugin)]
/// # #![plugin(rocket_codegen)]
/// # extern crate rocket;
/// use rocket::State;
///
/// struct MyValue(usize);
///
/// #[get("/")]
/// fn index(state: State<MyValue>) -> String {
/// format!("The stateful value is: {}", state.0)
/// }
///
/// fn main() {
/// # if false { // We don't actually want to launch the server in an example.
/// rocket::ignite()
/// .manage(MyValue(10))
/// # .launch()
/// # }
/// }
/// ```
pub fn manage<T: Send + Sync + 'static>(self, state: T) -> Self {
if !self.state.set::<T>(state) {
error!("State for this type is already being managed!");
panic!("Aborting due to duplicately managed state.");
}
self
}
/// Starts the application server and begins listening for and dispatching
/// requests to mounted routes and catchers.
///

View File

@ -111,12 +111,12 @@ use http::{Method, Header, Cookie};
use std::net::SocketAddr;
/// A type for mocking requests for testing Rocket applications.
pub struct MockRequest {
request: Request<'static>,
pub struct MockRequest<'r> {
request: Request<'r>,
data: Data
}
impl MockRequest {
impl<'r> MockRequest<'r> {
/// Constructs a new mocked request with the given `method` and `uri`.
#[inline]
pub fn new<S: AsRef<str>>(method: Method, uri: S) -> Self {
@ -259,7 +259,7 @@ impl MockRequest {
/// assert_eq!(body_str, Some("Hello, world!".to_string()));
/// # }
/// ```
pub fn dispatch_with<'r>(&'r mut self, rocket: &Rocket) -> Response<'r> {
pub fn dispatch_with<'s>(&'s mut self, rocket: &'r Rocket) -> Response<'s> {
let data = ::std::mem::replace(&mut self.data, Data::new(vec![]));
rocket.dispatch(&mut self.request, data)
}