Merge branch 'master' into uuid_support_v2

This commit is contained in:
Lori Holden 2017-01-13 19:06:30 -05:00
commit 21198bd3cf
16 changed files with 480 additions and 32 deletions

View File

@ -4,7 +4,7 @@
extern crate rocket; extern crate rocket;
use std::path::PathBuf; use std::path::PathBuf;
use std::str::Utf8Error; use rocket::http::uri::SegmentError;
#[post("/<a>/<b..>")] #[post("/<a>/<b..>")]
fn get(a: String, b: PathBuf) -> String { fn get(a: String, b: PathBuf) -> String {
@ -12,7 +12,7 @@ fn get(a: String, b: PathBuf) -> String {
} }
#[post("/<a>/<b..>")] #[post("/<a>/<b..>")]
fn get2(a: String, b: Result<PathBuf, Utf8Error>) -> String { fn get2(a: String, b: Result<PathBuf, SegmentError>) -> String {
format!("{}/{}", a, b.unwrap().to_string_lossy()) format!("{}/{}", a, b.unwrap().to_string_lossy())
} }

View File

@ -11,7 +11,8 @@ use rocket::response::{self, Responder, content};
use rocket::http::Status; use rocket::http::Status;
use self::serde::{Serialize, Deserialize}; use self::serde::{Serialize, Deserialize};
use self::serde_json::error::Error as SerdeError;
pub use self::serde_json::error::Error as SerdeError;
/// The JSON type, which implements `FromData` and `Responder`. This type allows /// The JSON type, which implements `FromData` and `Responder`. This type allows
/// you to trivially consume and respond with JSON in your Rocket application. /// you to trivially consume and respond with JSON in your Rocket application.

View File

@ -53,6 +53,9 @@ mod uuid;
#[cfg(feature = "json")] #[cfg(feature = "json")]
pub use json::JSON; pub use json::JSON;
#[cfg(feature = "json")]
pub use json::SerdeError;
#[cfg(feature = "templates")] #[cfg(feature = "templates")]
pub use templates::Template; pub use templates::Template;

View File

@ -10,3 +10,6 @@ rocket_codegen = { path = "../../codegen" }
serde = "0.8" serde = "0.8"
serde_json = "0.8" serde_json = "0.8"
serde_derive = "0.8" serde_derive = "0.8"
[dev-dependencies]
rocket = { path = "../../lib", features = ["testing"] }

View File

@ -3,7 +3,11 @@
extern crate rocket; extern crate rocket;
extern crate serde_json; extern crate serde_json;
#[macro_use] extern crate serde_derive; #[macro_use]
extern crate serde_derive;
#[cfg(test)]
mod tests;
use rocket::{Request, Error}; use rocket::{Request, Error};
use rocket::http::ContentType; use rocket::http::ContentType;
@ -34,14 +38,15 @@ fn not_found(_: Error, request: &Request) -> String {
format!("<p>This server only supports JSON requests, not '{}'.</p>", format!("<p>This server only supports JSON requests, not '{}'.</p>",
request.content_type()) request.content_type())
} else { } else {
format!("<p>Sorry, '{}' is not a valid path!</p> format!("<p>Sorry, '{}' is an invalid path! Try \
<p>Try visiting /hello/&lt;name&gt;/&lt;age&gt; instead.</p>", /hello/&lt;name&gt;/&lt;age&gt; instead.</p>",
request.uri()) request.uri())
} }
} }
fn main() { fn main() {
rocket::ignite() rocket::ignite()
.mount("/hello", routes![hello]).catch(errors![not_found]) .mount("/hello", routes![hello])
.catch(errors![not_found])
.launch(); .launch();
} }

View File

@ -0,0 +1,40 @@
use super::rocket;
use super::serde_json;
use super::Person;
use rocket::http::{ContentType, Method, Status};
use rocket::testing::MockRequest;
fn test(uri: &str, content_type: ContentType, status: Status, body: String) {
let rocket = rocket::ignite()
.mount("/hello", routes![super::hello])
.catch(errors![super::not_found]);
let mut request = MockRequest::new(Method::Get, uri).header(content_type);
let mut response = request.dispatch_with(&rocket);
assert_eq!(response.status(), status);
assert_eq!(response.body().and_then(|b| b.into_string()), Some(body));
}
#[test]
fn test_hello() {
let person = Person {
name: "Michael".to_string(),
age: 80,
};
let body = serde_json::to_string(&person).unwrap();
test("/hello/Michael/80", ContentType::JSON, Status::Ok, body);
}
#[test]
fn test_hello_invalid_content_type() {
let body = format!("<p>This server only supports JSON requests, not '{}'.</p>",
ContentType::HTML);
test("/hello/Michael/80", ContentType::HTML, Status::NotFound, body);
}
#[test]
fn test_404() {
let body = "<p>Sorry, '/unknown' is an invalid path! Try \
/hello/&lt;name&gt;/&lt;age&gt; instead.</p>";
test("/unknown", ContentType::JSON, Status::NotFound, body.to_string());
}

View File

@ -15,3 +15,6 @@ serde_json = "*"
path = "../../contrib" path = "../../contrib"
default-features = false default-features = false
features = ["handlebars_templates"] features = ["handlebars_templates"]
[dev-dependencies]
rocket = { path = "../../lib", features = ["testing"] }

View File

@ -6,7 +6,9 @@ extern crate rocket;
extern crate serde_json; extern crate serde_json;
#[macro_use] extern crate serde_derive; #[macro_use] extern crate serde_derive;
use rocket::{Request}; #[cfg(test)] mod tests;
use rocket::Request;
use rocket::response::Redirect; use rocket::response::Redirect;
use rocket_contrib::Template; use rocket_contrib::Template;

View File

@ -0,0 +1,80 @@
use rocket;
use rocket::testing::MockRequest;
use rocket::http::Method::*;
use rocket::http::Status;
use rocket::Response;
use rocket_contrib::Template;
macro_rules! run_test {
($req:expr, $test_fn:expr) => ({
let rocket = rocket::ignite()
.mount("/", routes![super::index, super::get])
.catch(errors![super::not_found]);
$test_fn($req.dispatch_with(&rocket));
})
}
#[test]
fn test_root() {
// Check that the redirect works.
for method in &[Get, Head] {
let mut req = MockRequest::new(*method, "/");
run_test!(req, |mut response: Response| {
assert_eq!(response.status(), Status::SeeOther);
assert!(response.body().is_none());
let location_headers: Vec<_> = response.header_values("Location").collect();
assert_eq!(location_headers, vec!["/hello/Unknown"]);
});
}
// Check that other request methods are not accepted (and instead caught).
for method in &[Post, Put, Delete, Options, Trace, Connect, Patch] {
let mut req = MockRequest::new(*method, "/");
run_test!(req, |mut response: Response| {
assert_eq!(response.status(), Status::NotFound);
let mut map = ::std::collections::HashMap::new();
map.insert("path", "/");
let expected = Template::render("error/404", &map).to_string();
let body_string = response.body().and_then(|body| body.into_string());
assert_eq!(body_string, Some(expected));
});
}
}
#[test]
fn test_name() {
// Check that the /hello/<name> route works.
let mut req = MockRequest::new(Get, "/hello/Jack");
run_test!(req, |mut response: Response| {
assert_eq!(response.status(), Status::Ok);
let context = super::TemplateContext {
name: "Jack".to_string(),
items: vec!["One", "Two", "Three"].iter().map(|s| s.to_string()).collect()
};
let expected = Template::render("index", &context).to_string();
let body_string = response.body().and_then(|body| body.into_string());
assert_eq!(body_string, Some(expected));
});
}
#[test]
fn test_404() {
// Check that the error catcher works.
let mut req = MockRequest::new(Get, "/hello/");
run_test!(req, |mut response: Response| {
assert_eq!(response.status(), Status::NotFound);
let mut map = ::std::collections::HashMap::new();
map.insert("path", "/hello/");
let expected = Template::render("error/404", &map).to_string();
let body_string = response.body().and_then(|body| body.into_string());
assert_eq!(body_string, Some(expected));
});
}

View File

@ -373,6 +373,20 @@ impl<'a> Iterator for Segments<'a> {
// } // }
} }
/// Errors which can occur when attempting to interpret a segment string as a
/// valid path segment.
#[derive(Debug, PartialEq, Eq, Clone)]
pub enum SegmentError {
/// The segment contained invalid UTF8 characters when percent decoded.
Utf8(Utf8Error),
/// The segment started with the wrapped invalid character.
BadStart(char),
/// The segment contained the wrapped invalid character.
BadChar(char),
/// The segment ended with the wrapped invalid character.
BadEnd(char),
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::URI; use super::URI;

View File

@ -1,4 +1,5 @@
use std::fmt::Debug; use std::fmt::Debug;
use std::net::SocketAddr;
use outcome::{self, IntoOutcome}; use outcome::{self, IntoOutcome};
use request::Request; use request::Request;
@ -65,6 +66,65 @@ impl<S, E> IntoOutcome<S, (Status, E), ()> for Result<S, E> {
/// matching request. Note that users can request an `Option<S>` to catch /// matching request. Note that users can request an `Option<S>` to catch
/// `Forward`s. /// `Forward`s.
/// ///
/// # Provided Implementations
///
/// Rocket implements `FromRequest` for several built-in types. Their behavior
/// is documented here.
///
/// * **URI**
///
/// Extracts the [URI](/rocket/http/uri/struct.URI.html) from the incoming
/// request.
///
/// _This implementation always returns successfully._
///
/// * **Method**
///
/// Extracts the [Method](/rocket/http/enum.Method.html) from the incoming
/// request.
///
/// _This implementation always returns successfully._
///
/// * **&Cookies**
///
/// Returns a borrow to the [Cookies](/rocket/http/type.Cookies.html) in the
/// incoming request. Note that `Cookies` implements internal mutability, so
/// a handle to `&Cookies` allows you to get _and_ set cookies in the
/// request.
///
/// _This implementation always returns successfully._
///
/// * **ContentType**
///
/// Extracts the [ContentType](/rocket/http/struct.ContentType.html) from
/// the incoming request. If the request didn't specify a Content-Type, a
/// Content-Type of `*/*` (`Any`) is returned.
///
/// _This implementation always returns successfully._
///
/// * **SocketAddr**
///
/// Extracts the remote address of the incoming request as a `SocketAddr`.
/// If the remote address is not known, the request is forwarded.
///
/// _This implementation always returns successfully._
///
/// * **Option&lt;T>** _where_ **T: FromRequest**
///
/// The type `T` is derived from the incoming request using `T`'s
/// `FromRequest` implementation. If the derivation is a `Success`, the
/// dervived value is returned in `Some`. Otherwise, a `None` is returned.
///
/// _This implementation always returns successfully._
///
/// * **Result&lt;T, T::Error>** _where_ **T: FromRequest**
///
/// The type `T` is derived from the incoming request using `T`'s
/// `FromRequest` implementation. If derivation is a `Success`, the value is
/// returned in `Ok`. If the derivation is a `Failure`, the error value is
/// returned in `Err`. If the derivation is a `Forward`, the request is
/// forwarded.
///
/// # Example /// # Example
/// ///
/// Imagine you're running an authenticated API service that requires that some /// Imagine you're running an authenticated API service that requires that some
@ -161,6 +221,17 @@ impl<'a, 'r> FromRequest<'a, 'r> for ContentType {
} }
} }
impl<'a, 'r> FromRequest<'a, 'r> for SocketAddr {
type Error = ();
fn from_request(request: &'a Request<'r>) -> Outcome<Self, Self::Error> {
match request.remote() {
Some(addr) => Success(addr),
None => Forward(())
}
}
}
impl<'a, 'r, T: FromRequest<'a, 'r>> FromRequest<'a, 'r> for Result<T, T::Error> { impl<'a, 'r, T: FromRequest<'a, 'r>> FromRequest<'a, 'r> for Result<T, T::Error> {
type Error = (); type Error = ();

View File

@ -1,9 +1,9 @@
use std::str::{Utf8Error, FromStr}; use std::str::FromStr;
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddrV4, SocketAddrV6, SocketAddr}; use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddrV4, SocketAddrV6, SocketAddr};
use std::path::PathBuf; use std::path::PathBuf;
use std::fmt::Debug; use std::fmt::Debug;
use http::uri::{URI, Segments}; use http::uri::{URI, Segments, SegmentError};
/// Trait to convert a dynamic path segment string to a concrete value. /// Trait to convert a dynamic path segment string to a concrete value.
/// ///
@ -274,6 +274,7 @@ pub trait FromSegments<'a>: Sized {
impl<'a> FromSegments<'a> for Segments<'a> { impl<'a> FromSegments<'a> for Segments<'a> {
type Error = (); type Error = ();
fn from_segments(segments: Segments<'a>) -> Result<Segments<'a>, ()> { fn from_segments(segments: Segments<'a>) -> Result<Segments<'a>, ()> {
Ok(segments) Ok(segments)
} }
@ -281,19 +282,46 @@ impl<'a> FromSegments<'a> for Segments<'a> {
/// Creates a `PathBuf` from a `Segments` iterator. The returned `PathBuf` is /// Creates a `PathBuf` from a `Segments` iterator. The returned `PathBuf` is
/// percent-decoded. If a segment is equal to "..", the previous segment (if /// percent-decoded. If a segment is equal to "..", the previous segment (if
/// any) is skipped. For security purposes, any other segments that begin with /// any) is skipped.
/// "*" or "." are ignored. If a percent-decoded segment results in invalid ///
/// UTF8, an `Err` is returned. /// For security purposes, if a segment meets any of the following conditions,
/// an `Err` is returned indicating the condition met:
///
/// * Decoded segment starts with any of: `.`, `*`
/// * Decoded segment ends with any of: `:`, `>`, `<`
/// * Decoded segment contains any of: `/`
/// * On Windows, decoded segment contains any of: '\'
/// * Percent-encoding results in invalid UTF8.
///
/// As a result of these conditions, a `PathBuf` derived via `FromSegments` is
/// safe to interpolate within, or use as a suffix of, a path without additional
/// checks.
impl<'a> FromSegments<'a> for PathBuf { impl<'a> FromSegments<'a> for PathBuf {
type Error = Utf8Error; type Error = SegmentError;
fn from_segments(segments: Segments<'a>) -> Result<PathBuf, Utf8Error> { fn from_segments(segments: Segments<'a>) -> Result<PathBuf, SegmentError> {
let mut buf = PathBuf::new(); let mut buf = PathBuf::new();
for segment in segments { for segment in segments {
let decoded = URI::percent_decode(segment.as_bytes())?; let decoded = URI::percent_decode(segment.as_bytes())
.map_err(|e| SegmentError::Utf8(e))?;
if decoded == ".." { if decoded == ".." {
buf.pop(); buf.pop();
} else if !(decoded.starts_with('.') || decoded.starts_with('*')) { } else if decoded.starts_with('.') {
return Err(SegmentError::BadStart('.'))
} else if decoded.starts_with('*') {
return Err(SegmentError::BadStart('*'))
} else if decoded.ends_with(':') {
return Err(SegmentError::BadEnd(':'))
} else if decoded.ends_with('>') {
return Err(SegmentError::BadEnd('>'))
} else if decoded.ends_with('<') {
return Err(SegmentError::BadEnd('<'))
} else if decoded.contains('/') {
return Err(SegmentError::BadChar('/'))
} else if cfg!(windows) && decoded.contains('\\') {
return Err(SegmentError::BadChar('\\'))
} else {
buf.push(&*decoded) buf.push(&*decoded)
} }
} }

View File

@ -1,4 +1,5 @@
use std::cell::RefCell; use std::cell::RefCell;
use std::net::SocketAddr;
use std::fmt; use std::fmt;
use term_painter::Color::*; use term_painter::Color::*;
@ -24,6 +25,7 @@ pub struct Request<'r> {
method: Method, method: Method,
uri: URI<'r>, uri: URI<'r>,
headers: HeaderMap<'r>, headers: HeaderMap<'r>,
remote: Option<SocketAddr>,
params: RefCell<Vec<(usize, usize)>>, params: RefCell<Vec<(usize, usize)>>,
cookies: Cookies, cookies: Cookies,
} }
@ -46,6 +48,7 @@ impl<'r> Request<'r> {
method: method, method: method,
uri: uri.into(), uri: uri.into(),
headers: HeaderMap::new(), headers: HeaderMap::new(),
remote: None,
params: RefCell::new(Vec::new()), params: RefCell::new(Vec::new()),
cookies: Cookies::new(&[]), cookies: Cookies::new(&[]),
} }
@ -123,6 +126,49 @@ impl<'r> Request<'r> {
self.params = RefCell::new(Vec::new()); self.params = RefCell::new(Vec::new());
} }
/// Returns the address of the remote connection that initiated this
/// request if the address is known. If the address is not known, `None` is
/// returned.
///
/// # Example
///
/// ```rust
/// use rocket::Request;
/// use rocket::http::Method;
///
/// let request = Request::new(Method::Get, "/uri");
/// assert!(request.remote().is_none());
/// ```
#[inline(always)]
pub fn remote(&self) -> Option<SocketAddr> {
self.remote
}
/// Sets the remote address of `self` to `address`.
///
/// # Example
///
/// Set the remote address to be 127.0.0.1:8000:
///
/// ```rust
/// use rocket::Request;
/// use rocket::http::Method;
/// use std::net::{SocketAddr, IpAddr, Ipv4Addr};
///
/// let mut request = Request::new(Method::Get, "/uri");
///
/// let (ip, port) = (IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 8000);
/// let localhost = SocketAddr::new(ip, port);
/// request.set_remote(localhost);
///
/// assert_eq!(request.remote(), Some(localhost));
/// ```
#[doc(hidden)]
#[inline(always)]
pub fn set_remote(&mut self, address: SocketAddr) {
self.remote = Some(address);
}
/// Returns a `HeaderMap` of all of the headers in `self`. /// Returns a `HeaderMap` of all of the headers in `self`.
/// ///
/// # Example /// # Example
@ -185,8 +231,8 @@ impl<'r> Request<'r> {
/// Returns a borrow to the cookies in `self`. /// Returns a borrow to the cookies in `self`.
/// ///
/// Note that `Cookie` implements internal mutability, so this method allows /// Note that `Cookies` implements internal mutability, so this method
/// you to get _and_ set cookies in `self`. /// allows you to get _and_ set cookies in `self`.
/// ///
/// # Example /// # Example
/// ///
@ -274,6 +320,7 @@ impl<'r> Request<'r> {
/// Set `self`'s parameters given that the route used to reach this request /// Set `self`'s parameters given that the route used to reach this request
/// was `route`. This should only be used internally by `Rocket` as improper /// was `route`. This should only be used internally by `Rocket` as improper
/// use may result in out of bounds indexing. /// use may result in out of bounds indexing.
/// TODO: Figure out the mount path from here.
#[doc(hidden)] #[doc(hidden)]
#[inline(always)] #[inline(always)]
pub fn set_params(&self, route: &Route) { pub fn set_params(&self, route: &Route) {
@ -348,8 +395,9 @@ impl<'r> Request<'r> {
#[doc(hidden)] #[doc(hidden)]
pub fn from_hyp(h_method: hyper::Method, pub fn from_hyp(h_method: hyper::Method,
h_headers: hyper::header::Headers, h_headers: hyper::header::Headers,
h_uri: hyper::RequestUri) h_uri: hyper::RequestUri,
-> Result<Request<'static>, String> { h_addr: SocketAddr,
) -> Result<Request<'static>, String> {
// Get a copy of the URI for later use. // Get a copy of the URI for later use.
let uri = match h_uri { let uri = match h_uri {
hyper::RequestUri::AbsolutePath(s) => s, hyper::RequestUri::AbsolutePath(s) => s,
@ -376,6 +424,9 @@ impl<'r> Request<'r> {
request.add_header(header); request.add_header(header);
} }
// Set the remote address.
request.set_remote(h_addr);
Ok(request) Ok(request)
} }
} }

View File

@ -1,6 +1,7 @@
use std::collections::HashMap; use std::collections::HashMap;
use std::str::from_utf8_unchecked; use std::str::from_utf8_unchecked;
use std::cmp::min; use std::cmp::min;
use std::net::SocketAddr;
use std::io::{self, Write}; use std::io::{self, Write};
use term_painter::Color::*; use term_painter::Color::*;
@ -41,11 +42,11 @@ impl hyper::Handler for Rocket {
hyp_req: hyper::Request<'h, 'k>, hyp_req: hyper::Request<'h, 'k>,
res: hyper::FreshResponse<'h>) { res: hyper::FreshResponse<'h>) {
// Get all of the information from Hyper. // Get all of the information from Hyper.
let (_, h_method, h_headers, h_uri, _, h_body) = hyp_req.deconstruct(); let (h_addr, h_method, h_headers, h_uri, _, h_body) = hyp_req.deconstruct();
// Convert the Hyper request into a Rocket request. // Convert the Hyper request into a Rocket request.
let mut request = match Request::from_hyp(h_method, h_headers, h_uri) { let mut req = match Request::from_hyp(h_method, h_headers, h_uri, h_addr) {
Ok(request) => request, Ok(req) => req,
Err(e) => { Err(e) => {
error!("Bad incoming request: {}", e); error!("Bad incoming request: {}", e);
let dummy = Request::new(Method::Get, URI::new("<unknown>")); let dummy = Request::new(Method::Get, URI::new("<unknown>"));
@ -59,13 +60,13 @@ impl hyper::Handler for Rocket {
Ok(data) => data, Ok(data) => data,
Err(reason) => { Err(reason) => {
error_!("Bad data in request: {}", reason); error_!("Bad data in request: {}", reason);
let r = self.handle_error(Status::InternalServerError, &request); let r = self.handle_error(Status::InternalServerError, &req);
return self.issue_response(r, res); return self.issue_response(r, res);
} }
}; };
// Dispatch the request to get a response, then write that response out. // Dispatch the request to get a response, then write that response out.
let response = self.dispatch(&mut request, data); let response = self.dispatch(&mut req, data);
self.issue_response(response, res) self.issue_response(response, res)
} }
} }
@ -132,15 +133,33 @@ impl Rocket {
} }
} }
/// Preprocess the request for Rocket-specific things. At this time, we're /// Preprocess the request for Rocket things. Currently, this means:
/// only checking for _method in forms. Keep this in-sync with derive_form ///
/// when preprocessing form fields. /// * Rewriting the method in the request if _method form field exists.
/// * Rewriting the remote IP if the 'X-Real-IP' header is set.
///
/// Keep this in-sync with derive_form when preprocessing form fields.
fn preprocess_request(&self, req: &mut Request, data: &Data) { fn preprocess_request(&self, req: &mut Request, data: &Data) {
// Rewrite the remote IP address. The request must already have an
// address associated with it to do this since we need to know the port.
if let Some(current) = req.remote() {
let ip = req.headers()
.get_one("X-Real-IP")
.and_then(|ip_str| ip_str.parse().map_err(|_| {
warn_!("The 'X-Real-IP' header is malformed: {}", ip_str)
}).ok());
if let Some(ip) = ip {
req.set_remote(SocketAddr::new(ip, current.port()));
}
}
// Check if this is a form and if the form contains the special _method // Check if this is a form and if the form contains the special _method
// field which we use to reinterpret the request's method. // field which we use to reinterpret the request's method.
let data_len = data.peek().len(); let data_len = data.peek().len();
let (min_len, max_len) = ("_method=get".len(), "_method=delete".len()); let (min_len, max_len) = ("_method=get".len(), "_method=delete".len());
if req.method() == Method::Post && req.content_type().is_form() && data_len >= min_len { let is_form = req.content_type().is_form();
if is_form && req.method() == Method::Post && data_len >= min_len {
let form = unsafe { let form = unsafe {
from_utf8_unchecked(&data.peek()[..min(data_len, max_len)]) from_utf8_unchecked(&data.peek()[..min(data_len, max_len)])
}; };
@ -157,6 +176,8 @@ impl Rocket {
#[doc(hidden)] #[doc(hidden)]
#[inline(always)] #[inline(always)]
pub fn dispatch<'r>(&self, request: &'r mut Request, data: Data) -> Response<'r> { pub fn dispatch<'r>(&self, request: &'r mut Request, data: Data) -> Response<'r> {
info!("{}:", request);
// Do a bit of preprocessing before routing. // Do a bit of preprocessing before routing.
self.preprocess_request(request, &data); self.preprocess_request(request, &data);
@ -207,7 +228,6 @@ impl Rocket {
pub fn route<'r>(&self, request: &'r Request, mut data: Data) pub fn route<'r>(&self, request: &'r Request, mut data: Data)
-> handler::Outcome<'r> { -> handler::Outcome<'r> {
// Go through the list of matching routes until we fail or succeed. // Go through the list of matching routes until we fail or succeed.
info!("{}:", request);
let matches = self.router.route(request); let matches = self.router.route(request);
for route in matches { for route in matches {
// Retrieve and set the requests parameters. // Retrieve and set the requests parameters.

View File

@ -108,6 +108,8 @@
use ::{Rocket, Request, Response, Data}; use ::{Rocket, Request, Response, Data};
use http::{Method, Header, Cookie}; use http::{Method, Header, Cookie};
use std::net::SocketAddr;
/// A type for mocking requests for testing Rocket applications. /// A type for mocking requests for testing Rocket applications.
pub struct MockRequest { pub struct MockRequest {
request: Request<'static>, request: Request<'static>,
@ -143,6 +145,44 @@ impl MockRequest {
self self
} }
/// Set the remote address of this request.
///
/// # Examples
///
/// Set the remote address to "8.8.8.8:80":
///
/// ```rust
/// use rocket::http::Method::*;
/// use rocket::testing::MockRequest;
///
/// let address = "8.8.8.8:80".parse().unwrap();
/// let req = MockRequest::new(Get, "/").remote(address);
/// ```
#[inline]
pub fn remote(mut self, address: SocketAddr) -> Self {
self.request.set_remote(address);
self
}
/// Adds a header to this request. Does not consume `self`.
///
/// # Examples
///
/// Add the Content-Type header:
///
/// ```rust
/// use rocket::http::Method::*;
/// use rocket::testing::MockRequest;
/// use rocket::http::ContentType;
///
/// let mut req = MockRequest::new(Get, "/");
/// req.add_header(ContentType::JSON);
/// ```
#[inline]
pub fn add_header<'h, H: Into<Header<'static>>>(&mut self, header: H) {
self.request.add_header(header.into());
}
/// Add a cookie to this request. /// Add a cookie to this request.
/// ///
/// # Examples /// # Examples

View File

@ -0,0 +1,87 @@
#![feature(plugin, custom_derive)]
#![plugin(rocket_codegen)]
extern crate rocket;
use std::net::SocketAddr;
#[get("/")]
fn get_ip(remote: SocketAddr) -> String {
remote.to_string()
}
#[cfg(feature = "testing")]
mod remote_rewrite_tests {
use super::*;
use rocket::testing::MockRequest;
use rocket::http::Method::*;
use rocket::http::{Header, Status};
use std::net::SocketAddr;
const KNOWN_IP: &'static str = "127.0.0.1:8000";
fn check_ip(header: Option<Header<'static>>, ip: Option<String>) {
let address: SocketAddr = KNOWN_IP.parse().unwrap();
let port = address.port();
let rocket = rocket::ignite().mount("/", routes![get_ip]);
let mut req = MockRequest::new(Get, "/").remote(address);
if let Some(header) = header {
req.add_header(header);
}
let mut response = req.dispatch_with(&rocket);
assert_eq!(response.status(), Status::Ok);
let body_str = response.body().and_then(|b| b.into_string());
match ip {
Some(ip) => assert_eq!(body_str, Some(format!("{}:{}", ip, port))),
None => assert_eq!(body_str, Some(KNOWN_IP.into()))
}
}
#[test]
fn x_real_ip_rewrites() {
let ip = "8.8.8.8";
check_ip(Some(Header::new("X-Real-IP", ip)), Some(ip.to_string()));
let ip = "129.120.111.200";
check_ip(Some(Header::new("X-Real-IP", ip)), Some(ip.to_string()));
}
#[test]
fn x_real_ip_rewrites_ipv6() {
let ip = "2001:db8:0:1:1:1:1:1";
check_ip(Some(Header::new("X-Real-IP", ip)), Some(format!("[{}]", ip)));
let ip = "2001:db8::2:1";
check_ip(Some(Header::new("X-Real-IP", ip)), Some(format!("[{}]", ip)));
}
#[test]
fn uncased_header_rewrites() {
let ip = "8.8.8.8";
check_ip(Some(Header::new("x-REAL-ip", ip)), Some(ip.to_string()));
let ip = "1.2.3.4";
check_ip(Some(Header::new("x-real-ip", ip)), Some(ip.to_string()));
}
#[test]
fn no_header_no_rewrite() {
check_ip(Some(Header::new("real-ip", "?")), None);
check_ip(None, None);
}
#[test]
fn bad_header_doesnt_rewrite() {
let ip = "092348092348";
check_ip(Some(Header::new("X-Real-IP", ip)), None);
let ip = "1200:100000:0120129";
check_ip(Some(Header::new("X-Real-IP", ip)), None);
let ip = "192.168.1.900";
check_ip(Some(Header::new("X-Real-IP", ip)), None);
}
}