From 6fd0503ceaf51a92d19095815d7546f056296897 Mon Sep 17 00:00:00 2001 From: Sergio Benitez Date: Thu, 12 Jan 2017 23:07:01 -0800 Subject: [PATCH 1/6] Expose SerdeError. --- contrib/src/json/mod.rs | 3 ++- contrib/src/lib.rs | 3 +++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/contrib/src/json/mod.rs b/contrib/src/json/mod.rs index 5d94aae3..e5a83954 100644 --- a/contrib/src/json/mod.rs +++ b/contrib/src/json/mod.rs @@ -11,7 +11,8 @@ use rocket::response::{self, Responder, content}; use rocket::http::Status; use self::serde::{Serialize, Deserialize}; -use self::serde_json::error::Error as SerdeError; + +pub use self::serde_json::error::Error as SerdeError; /// The JSON type, which implements `FromData` and `Responder`. This type allows /// you to trivially consume and respond with JSON in your Rocket application. diff --git a/contrib/src/lib.rs b/contrib/src/lib.rs index 59dd96b6..b8600862 100644 --- a/contrib/src/lib.rs +++ b/contrib/src/lib.rs @@ -48,5 +48,8 @@ mod templates; #[cfg(feature = "json")] pub use json::JSON; +#[cfg(feature = "json")] +pub use json::SerdeError; + #[cfg(feature = "templates")] pub use templates::Template; From dec585dbd4e3a03ab0189d2c1aac669f867b5ea9 Mon Sep 17 00:00:00 2001 From: Seth Lopez Date: Thu, 29 Dec 2016 21:53:48 -0500 Subject: [PATCH 2/6] Add tests for content_types example. --- examples/content_types/Cargo.toml | 3 +++ examples/content_types/src/main.rs | 15 +++++++---- examples/content_types/src/tests.rs | 40 +++++++++++++++++++++++++++++ 3 files changed, 53 insertions(+), 5 deletions(-) create mode 100644 examples/content_types/src/tests.rs diff --git a/examples/content_types/Cargo.toml b/examples/content_types/Cargo.toml index 973be5bd..cd4388f1 100644 --- a/examples/content_types/Cargo.toml +++ b/examples/content_types/Cargo.toml @@ -10,3 +10,6 @@ rocket_codegen = { path = "../../codegen" } serde = "0.8" serde_json = "0.8" serde_derive = "0.8" + +[dev-dependencies] +rocket = { path = "../../lib", features = ["testing"] } diff --git a/examples/content_types/src/main.rs b/examples/content_types/src/main.rs index d33f789b..84b2b0ca 100644 --- a/examples/content_types/src/main.rs +++ b/examples/content_types/src/main.rs @@ -3,7 +3,11 @@ extern crate rocket; extern crate serde_json; -#[macro_use] extern crate serde_derive; +#[macro_use] +extern crate serde_derive; + +#[cfg(test)] +mod tests; use rocket::{Request, Error}; use rocket::http::ContentType; @@ -34,14 +38,15 @@ fn not_found(_: Error, request: &Request) -> String { format!("

This server only supports JSON requests, not '{}'.

", request.content_type()) } else { - format!("

Sorry, '{}' is not a valid path!

-

Try visiting /hello/<name>/<age> instead.

", - request.uri()) + format!("

Sorry, '{}' is an invalid path! Try \ + /hello/<name>/<age> instead.

", + request.uri()) } } fn main() { rocket::ignite() - .mount("/hello", routes![hello]).catch(errors![not_found]) + .mount("/hello", routes![hello]) + .catch(errors![not_found]) .launch(); } diff --git a/examples/content_types/src/tests.rs b/examples/content_types/src/tests.rs new file mode 100644 index 00000000..b7b0f254 --- /dev/null +++ b/examples/content_types/src/tests.rs @@ -0,0 +1,40 @@ +use super::rocket; +use super::serde_json; +use super::Person; +use rocket::http::{ContentType, Method, Status}; +use rocket::testing::MockRequest; + +fn test(uri: &str, content_type: ContentType, status: Status, body: String) { + let rocket = rocket::ignite() + .mount("/hello", routes![super::hello]) + .catch(errors![super::not_found]); + let mut request = MockRequest::new(Method::Get, uri).header(content_type); + let mut response = request.dispatch_with(&rocket); + + assert_eq!(response.status(), status); + assert_eq!(response.body().and_then(|b| b.into_string()), Some(body)); +} + +#[test] +fn test_hello() { + let person = Person { + name: "Michael".to_string(), + age: 80, + }; + let body = serde_json::to_string(&person).unwrap(); + test("/hello/Michael/80", ContentType::JSON, Status::Ok, body); +} + +#[test] +fn test_hello_invalid_content_type() { + let body = format!("

This server only supports JSON requests, not '{}'.

", + ContentType::HTML); + test("/hello/Michael/80", ContentType::HTML, Status::NotFound, body); +} + +#[test] +fn test_404() { + let body = "

Sorry, '/unknown' is an invalid path! Try \ + /hello/<name>/<age> instead.

"; + test("/unknown", ContentType::JSON, Status::NotFound, body.to_string()); +} From 99a17b42aea0663256524ec1f8caac8c72c08de4 Mon Sep 17 00:00:00 2001 From: FliegendeWurst <2012gdwu@web.de> Date: Tue, 3 Jan 2017 12:12:27 +0100 Subject: [PATCH 3/6] Add tests for handlebars_templates example. --- examples/handlebars_templates/Cargo.toml | 3 + examples/handlebars_templates/src/main.rs | 4 +- examples/handlebars_templates/src/tests.rs | 82 ++++++++++++++++++++++ 3 files changed, 88 insertions(+), 1 deletion(-) create mode 100644 examples/handlebars_templates/src/tests.rs diff --git a/examples/handlebars_templates/Cargo.toml b/examples/handlebars_templates/Cargo.toml index d11180a3..0c32a3d7 100644 --- a/examples/handlebars_templates/Cargo.toml +++ b/examples/handlebars_templates/Cargo.toml @@ -15,3 +15,6 @@ serde_json = "*" path = "../../contrib" default-features = false features = ["handlebars_templates"] + +[dev-dependencies] +rocket = { path = "../../lib", features = ["testing"] } diff --git a/examples/handlebars_templates/src/main.rs b/examples/handlebars_templates/src/main.rs index 5e29e6b1..2eaa7d34 100644 --- a/examples/handlebars_templates/src/main.rs +++ b/examples/handlebars_templates/src/main.rs @@ -6,7 +6,9 @@ extern crate rocket; extern crate serde_json; #[macro_use] extern crate serde_derive; -use rocket::{Request}; +#[cfg(test)] mod tests; + +use rocket::Request; use rocket::response::Redirect; use rocket_contrib::Template; diff --git a/examples/handlebars_templates/src/tests.rs b/examples/handlebars_templates/src/tests.rs new file mode 100644 index 00000000..89e2bfab --- /dev/null +++ b/examples/handlebars_templates/src/tests.rs @@ -0,0 +1,82 @@ +use rocket; +use rocket::testing::MockRequest; +use rocket::http::Method::*; +use rocket::http::Status; +use rocket::Response; +use rocket_contrib::Template; + +macro_rules! run_test { + ($req:expr, $test_fn:expr) => ({ + let rocket = rocket::ignite() + .mount("/", routes![super::index, super::get]) + .catch(errors![super::not_found]); + + $test_fn($req.dispatch_with(&rocket)); + }) +} + +#[test] +fn test_root() { + // Check that the redirect works. + for method in &[Get, Head] { + let mut req = MockRequest::new(*method, "/"); + run_test!(req, |mut response: Response| { + assert_eq!(response.status(), Status::SeeOther); + + assert!(response.body().is_none()); + + let location_headers: Vec<_> = response.header_values("Location").collect(); + + assert_eq!(location_headers, vec!["/hello/Unknown"]); + }); + } + + // Check that other request methods are not accepted (and instead caught). + for method in &[Post, Put, Delete, Options, Trace, Connect, Patch] { + let mut req = MockRequest::new(*method, "/"); + run_test!(req, |mut response: Response| { + assert_eq!(response.status(), Status::NotFound); + + let mut map = ::std::collections::HashMap::new(); + map.insert("path", "/"); + let expected = Template::render("error/404", &map).to_string(); + + let body_string = response.body().and_then(|body| body.into_string()); + assert_eq!(body_string, Some(expected)); + }); + } +} + +#[test] +fn test_name() { + // Check that the /hello/ route works. + let mut req = MockRequest::new(Get, "/hello/Jack"); + run_test!(req, |mut response: Response| { + assert_eq!(response.status(), Status::Ok); + + let context = super::TemplateContext { + name: "Jack".to_string(), + items: vec!["One", "Two", "Three"].iter().map(|s| s.to_string()).collect() + }; + let expected = Template::render("index", &context).to_string(); + + let body_string = response.body().and_then(|body| body.into_string()); + assert_eq!(body_string, Some(expected)); + }); +} + +#[test] +fn test_404() { + // Check that the error catcher works. + let mut req = MockRequest::new(Get, "/hello/"); + run_test!(req, |mut response: Response| { + assert_eq!(response.status(), Status::NotFound); + + let mut map = ::std::collections::HashMap::new(); + map.insert("path", "/hello/"); + let expected = Template::render("error/404", &map).to_string(); + + let body_string = response.body().and_then(|body| body.into_string()); + assert_eq!(body_string, Some(expected)); + }); +} From 725191d3c3b8e01041246c11f378ef9badfa8c32 Mon Sep 17 00:00:00 2001 From: Sergio Benitez Date: Fri, 13 Jan 2017 00:22:16 -0800 Subject: [PATCH 4/6] Adjust spacing in handlebars_templates example. --- examples/handlebars_templates/src/tests.rs | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/examples/handlebars_templates/src/tests.rs b/examples/handlebars_templates/src/tests.rs index 89e2bfab..9e385afd 100644 --- a/examples/handlebars_templates/src/tests.rs +++ b/examples/handlebars_templates/src/tests.rs @@ -22,21 +22,19 @@ fn test_root() { let mut req = MockRequest::new(*method, "/"); run_test!(req, |mut response: Response| { assert_eq!(response.status(), Status::SeeOther); - assert!(response.body().is_none()); - + let location_headers: Vec<_> = response.header_values("Location").collect(); - assert_eq!(location_headers, vec!["/hello/Unknown"]); }); } - + // Check that other request methods are not accepted (and instead caught). for method in &[Post, Put, Delete, Options, Trace, Connect, Patch] { let mut req = MockRequest::new(*method, "/"); run_test!(req, |mut response: Response| { assert_eq!(response.status(), Status::NotFound); - + let mut map = ::std::collections::HashMap::new(); map.insert("path", "/"); let expected = Template::render("error/404", &map).to_string(); @@ -53,13 +51,13 @@ fn test_name() { let mut req = MockRequest::new(Get, "/hello/Jack"); run_test!(req, |mut response: Response| { assert_eq!(response.status(), Status::Ok); - + let context = super::TemplateContext { name: "Jack".to_string(), items: vec!["One", "Two", "Three"].iter().map(|s| s.to_string()).collect() }; - let expected = Template::render("index", &context).to_string(); + let expected = Template::render("index", &context).to_string(); let body_string = response.body().and_then(|body| body.into_string()); assert_eq!(body_string, Some(expected)); }); @@ -71,7 +69,7 @@ fn test_404() { let mut req = MockRequest::new(Get, "/hello/"); run_test!(req, |mut response: Response| { assert_eq!(response.status(), Status::NotFound); - + let mut map = ::std::collections::HashMap::new(); map.insert("path", "/hello/"); let expected = Template::render("error/404", &map).to_string(); From 41aecc3e7ff1a14f4519bb22b71a82f4ec2ef286 Mon Sep 17 00:00:00 2001 From: Sergio Benitez Date: Fri, 13 Jan 2017 07:50:51 -0800 Subject: [PATCH 5/6] Expose the remote address via `remote()` in `Request`. This commit also includes the following changes: * `FromRequest` for `SocketAddr` implemented: extracts remote address. * All built-in `FromRequest` implementations are documented. * Request preprocessing overrides remote IP with value from X-Real-IP header. * `MockRequest` allows setting the remote address with `remote()`. Resolves #38. --- lib/src/request/from_request.rs | 71 +++++++++++++++++++++++++++ lib/src/request/request.rs | 59 ++++++++++++++++++++-- lib/src/rocket.rs | 40 +++++++++++---- lib/src/testing.rs | 40 +++++++++++++++ lib/tests/remote-rewrite.rs | 87 +++++++++++++++++++++++++++++++++ 5 files changed, 283 insertions(+), 14 deletions(-) create mode 100644 lib/tests/remote-rewrite.rs diff --git a/lib/src/request/from_request.rs b/lib/src/request/from_request.rs index 3253bc19..16ac5cba 100644 --- a/lib/src/request/from_request.rs +++ b/lib/src/request/from_request.rs @@ -1,4 +1,5 @@ use std::fmt::Debug; +use std::net::SocketAddr; use outcome::{self, IntoOutcome}; use request::Request; @@ -65,6 +66,65 @@ impl IntoOutcome for Result { /// matching request. Note that users can request an `Option` to catch /// `Forward`s. /// +/// # Provided Implementations +/// +/// Rocket implements `FromRequest` for several built-in types. Their behavior +/// is documented here. +/// +/// * **URI** +/// +/// Extracts the [URI](/rocket/http/uri/struct.URI.html) from the incoming +/// request. +/// +/// _This implementation always returns successfully._ +/// +/// * **Method** +/// +/// Extracts the [Method](/rocket/http/enum.Method.html) from the incoming +/// request. +/// +/// _This implementation always returns successfully._ +/// +/// * **&Cookies** +/// +/// Returns a borrow to the [Cookies](/rocket/http/type.Cookies.html) in the +/// incoming request. Note that `Cookies` implements internal mutability, so +/// a handle to `&Cookies` allows you to get _and_ set cookies in the +/// request. +/// +/// _This implementation always returns successfully._ +/// +/// * **ContentType** +/// +/// Extracts the [ContentType](/rocket/http/struct.ContentType.html) from +/// the incoming request. If the request didn't specify a Content-Type, a +/// Content-Type of `*/*` (`Any`) is returned. +/// +/// _This implementation always returns successfully._ +/// +/// * **SocketAddr** +/// +/// Extracts the remote address of the incoming request as a `SocketAddr`. +/// If the remote address is not known, the request is forwarded. +/// +/// _This implementation always returns successfully._ +/// +/// * **Option<T>** _where_ **T: FromRequest** +/// +/// The type `T` is derived from the incoming request using `T`'s +/// `FromRequest` implementation. If the derivation is a `Success`, the +/// dervived value is returned in `Some`. Otherwise, a `None` is returned. +/// +/// _This implementation always returns successfully._ +/// +/// * **Result<T, T::Error>** _where_ **T: FromRequest** +/// +/// The type `T` is derived from the incoming request using `T`'s +/// `FromRequest` implementation. If derivation is a `Success`, the value is +/// returned in `Ok`. If the derivation is a `Failure`, the error value is +/// returned in `Err`. If the derivation is a `Forward`, the request is +/// forwarded. +/// /// # Example /// /// Imagine you're running an authenticated API service that requires that some @@ -161,6 +221,17 @@ impl<'a, 'r> FromRequest<'a, 'r> for ContentType { } } +impl<'a, 'r> FromRequest<'a, 'r> for SocketAddr { + type Error = (); + + fn from_request(request: &'a Request<'r>) -> Outcome { + match request.remote() { + Some(addr) => Success(addr), + None => Forward(()) + } + } +} + impl<'a, 'r, T: FromRequest<'a, 'r>> FromRequest<'a, 'r> for Result { type Error = (); diff --git a/lib/src/request/request.rs b/lib/src/request/request.rs index f6868d11..b5982930 100644 --- a/lib/src/request/request.rs +++ b/lib/src/request/request.rs @@ -1,4 +1,5 @@ use std::cell::RefCell; +use std::net::SocketAddr; use std::fmt; use term_painter::Color::*; @@ -24,6 +25,7 @@ pub struct Request<'r> { method: Method, uri: URI<'r>, headers: HeaderMap<'r>, + remote: Option, params: RefCell>, cookies: Cookies, } @@ -46,6 +48,7 @@ impl<'r> Request<'r> { method: method, uri: uri.into(), headers: HeaderMap::new(), + remote: None, params: RefCell::new(Vec::new()), cookies: Cookies::new(&[]), } @@ -123,6 +126,49 @@ impl<'r> Request<'r> { self.params = RefCell::new(Vec::new()); } + /// Returns the address of the remote connection that initiated this + /// request if the address is known. If the address is not known, `None` is + /// returned. + /// + /// # Example + /// + /// ```rust + /// use rocket::Request; + /// use rocket::http::Method; + /// + /// let request = Request::new(Method::Get, "/uri"); + /// assert!(request.remote().is_none()); + /// ``` + #[inline(always)] + pub fn remote(&self) -> Option { + self.remote + } + + /// Sets the remote address of `self` to `address`. + /// + /// # Example + /// + /// Set the remote address to be 127.0.0.1:8000: + /// + /// ```rust + /// use rocket::Request; + /// use rocket::http::Method; + /// use std::net::{SocketAddr, IpAddr, Ipv4Addr}; + /// + /// let mut request = Request::new(Method::Get, "/uri"); + /// + /// let (ip, port) = (IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 8000); + /// let localhost = SocketAddr::new(ip, port); + /// request.set_remote(localhost); + /// + /// assert_eq!(request.remote(), Some(localhost)); + /// ``` + #[doc(hidden)] + #[inline(always)] + pub fn set_remote(&mut self, address: SocketAddr) { + self.remote = Some(address); + } + /// Returns a `HeaderMap` of all of the headers in `self`. /// /// # Example @@ -185,8 +231,8 @@ impl<'r> Request<'r> { /// Returns a borrow to the cookies in `self`. /// - /// Note that `Cookie` implements internal mutability, so this method allows - /// you to get _and_ set cookies in `self`. + /// Note that `Cookies` implements internal mutability, so this method + /// allows you to get _and_ set cookies in `self`. /// /// # Example /// @@ -274,6 +320,7 @@ impl<'r> Request<'r> { /// Set `self`'s parameters given that the route used to reach this request /// was `route`. This should only be used internally by `Rocket` as improper /// use may result in out of bounds indexing. + /// TODO: Figure out the mount path from here. #[doc(hidden)] #[inline(always)] pub fn set_params(&self, route: &Route) { @@ -348,8 +395,9 @@ impl<'r> Request<'r> { #[doc(hidden)] pub fn from_hyp(h_method: hyper::Method, h_headers: hyper::header::Headers, - h_uri: hyper::RequestUri) - -> Result, String> { + h_uri: hyper::RequestUri, + h_addr: SocketAddr, + ) -> Result, String> { // Get a copy of the URI for later use. let uri = match h_uri { hyper::RequestUri::AbsolutePath(s) => s, @@ -376,6 +424,9 @@ impl<'r> Request<'r> { request.add_header(header); } + // Set the remote address. + request.set_remote(h_addr); + Ok(request) } } diff --git a/lib/src/rocket.rs b/lib/src/rocket.rs index cf871fd9..17d22c81 100644 --- a/lib/src/rocket.rs +++ b/lib/src/rocket.rs @@ -1,6 +1,7 @@ use std::collections::HashMap; use std::str::from_utf8_unchecked; use std::cmp::min; +use std::net::SocketAddr; use std::io::{self, Write}; use term_painter::Color::*; @@ -41,11 +42,11 @@ impl hyper::Handler for Rocket { hyp_req: hyper::Request<'h, 'k>, res: hyper::FreshResponse<'h>) { // Get all of the information from Hyper. - let (_, h_method, h_headers, h_uri, _, h_body) = hyp_req.deconstruct(); + let (h_addr, h_method, h_headers, h_uri, _, h_body) = hyp_req.deconstruct(); // Convert the Hyper request into a Rocket request. - let mut request = match Request::from_hyp(h_method, h_headers, h_uri) { - Ok(request) => request, + let mut req = match Request::from_hyp(h_method, h_headers, h_uri, h_addr) { + Ok(req) => req, Err(e) => { error!("Bad incoming request: {}", e); let dummy = Request::new(Method::Get, URI::new("")); @@ -59,13 +60,13 @@ impl hyper::Handler for Rocket { Ok(data) => data, Err(reason) => { error_!("Bad data in request: {}", reason); - let r = self.handle_error(Status::InternalServerError, &request); + let r = self.handle_error(Status::InternalServerError, &req); return self.issue_response(r, res); } }; // Dispatch the request to get a response, then write that response out. - let response = self.dispatch(&mut request, data); + let response = self.dispatch(&mut req, data); self.issue_response(response, res) } } @@ -132,15 +133,33 @@ impl Rocket { } } - /// Preprocess the request for Rocket-specific things. At this time, we're - /// only checking for _method in forms. Keep this in-sync with derive_form - /// when preprocessing form fields. + /// Preprocess the request for Rocket things. Currently, this means: + /// + /// * Rewriting the method in the request if _method form field exists. + /// * Rewriting the remote IP if the 'X-Real-IP' header is set. + /// + /// Keep this in-sync with derive_form when preprocessing form fields. fn preprocess_request(&self, req: &mut Request, data: &Data) { + // Rewrite the remote IP address. The request must already have an + // address associated with it to do this since we need to know the port. + if let Some(current) = req.remote() { + let ip = req.headers() + .get_one("X-Real-IP") + .and_then(|ip_str| ip_str.parse().map_err(|_| { + warn_!("The 'X-Real-IP' header is malformed: {}", ip_str) + }).ok()); + + if let Some(ip) = ip { + req.set_remote(SocketAddr::new(ip, current.port())); + } + } + // Check if this is a form and if the form contains the special _method // field which we use to reinterpret the request's method. let data_len = data.peek().len(); let (min_len, max_len) = ("_method=get".len(), "_method=delete".len()); - if req.method() == Method::Post && req.content_type().is_form() && data_len >= min_len { + let is_form = req.content_type().is_form(); + if is_form && req.method() == Method::Post && data_len >= min_len { let form = unsafe { from_utf8_unchecked(&data.peek()[..min(data_len, max_len)]) }; @@ -157,6 +176,8 @@ impl Rocket { #[doc(hidden)] #[inline(always)] pub fn dispatch<'r>(&self, request: &'r mut Request, data: Data) -> Response<'r> { + info!("{}:", request); + // Do a bit of preprocessing before routing. self.preprocess_request(request, &data); @@ -207,7 +228,6 @@ impl Rocket { pub fn route<'r>(&self, request: &'r Request, mut data: Data) -> handler::Outcome<'r> { // Go through the list of matching routes until we fail or succeed. - info!("{}:", request); let matches = self.router.route(request); for route in matches { // Retrieve and set the requests parameters. diff --git a/lib/src/testing.rs b/lib/src/testing.rs index 1fb55d66..134f6ed5 100644 --- a/lib/src/testing.rs +++ b/lib/src/testing.rs @@ -108,6 +108,8 @@ use ::{Rocket, Request, Response, Data}; use http::{Method, Header, Cookie}; +use std::net::SocketAddr; + /// A type for mocking requests for testing Rocket applications. pub struct MockRequest { request: Request<'static>, @@ -143,6 +145,44 @@ impl MockRequest { self } + /// Set the remote address of this request. + /// + /// # Examples + /// + /// Set the remote address to "8.8.8.8:80": + /// + /// ```rust + /// use rocket::http::Method::*; + /// use rocket::testing::MockRequest; + /// + /// let address = "8.8.8.8:80".parse().unwrap(); + /// let req = MockRequest::new(Get, "/").remote(address); + /// ``` + #[inline] + pub fn remote(mut self, address: SocketAddr) -> Self { + self.request.set_remote(address); + self + } + + /// Adds a header to this request. Does not consume `self`. + /// + /// # Examples + /// + /// Add the Content-Type header: + /// + /// ```rust + /// use rocket::http::Method::*; + /// use rocket::testing::MockRequest; + /// use rocket::http::ContentType; + /// + /// let mut req = MockRequest::new(Get, "/"); + /// req.add_header(ContentType::JSON); + /// ``` + #[inline] + pub fn add_header<'h, H: Into>>(&mut self, header: H) { + self.request.add_header(header.into()); + } + /// Add a cookie to this request. /// /// # Examples diff --git a/lib/tests/remote-rewrite.rs b/lib/tests/remote-rewrite.rs new file mode 100644 index 00000000..56c99402 --- /dev/null +++ b/lib/tests/remote-rewrite.rs @@ -0,0 +1,87 @@ +#![feature(plugin, custom_derive)] +#![plugin(rocket_codegen)] + +extern crate rocket; + +use std::net::SocketAddr; + +#[get("/")] +fn get_ip(remote: SocketAddr) -> String { + remote.to_string() +} + +#[cfg(feature = "testing")] +mod remote_rewrite_tests { + use super::*; + use rocket::testing::MockRequest; + use rocket::http::Method::*; + use rocket::http::{Header, Status}; + + use std::net::SocketAddr; + + const KNOWN_IP: &'static str = "127.0.0.1:8000"; + + fn check_ip(header: Option>, ip: Option) { + let address: SocketAddr = KNOWN_IP.parse().unwrap(); + let port = address.port(); + + let rocket = rocket::ignite().mount("/", routes![get_ip]); + let mut req = MockRequest::new(Get, "/").remote(address); + if let Some(header) = header { + req.add_header(header); + } + + let mut response = req.dispatch_with(&rocket); + assert_eq!(response.status(), Status::Ok); + let body_str = response.body().and_then(|b| b.into_string()); + match ip { + Some(ip) => assert_eq!(body_str, Some(format!("{}:{}", ip, port))), + None => assert_eq!(body_str, Some(KNOWN_IP.into())) + } + } + + #[test] + fn x_real_ip_rewrites() { + let ip = "8.8.8.8"; + check_ip(Some(Header::new("X-Real-IP", ip)), Some(ip.to_string())); + + let ip = "129.120.111.200"; + check_ip(Some(Header::new("X-Real-IP", ip)), Some(ip.to_string())); + } + + #[test] + fn x_real_ip_rewrites_ipv6() { + let ip = "2001:db8:0:1:1:1:1:1"; + check_ip(Some(Header::new("X-Real-IP", ip)), Some(format!("[{}]", ip))); + + let ip = "2001:db8::2:1"; + check_ip(Some(Header::new("X-Real-IP", ip)), Some(format!("[{}]", ip))); + } + + #[test] + fn uncased_header_rewrites() { + let ip = "8.8.8.8"; + check_ip(Some(Header::new("x-REAL-ip", ip)), Some(ip.to_string())); + + let ip = "1.2.3.4"; + check_ip(Some(Header::new("x-real-ip", ip)), Some(ip.to_string())); + } + + #[test] + fn no_header_no_rewrite() { + check_ip(Some(Header::new("real-ip", "?")), None); + check_ip(None, None); + } + + #[test] + fn bad_header_doesnt_rewrite() { + let ip = "092348092348"; + check_ip(Some(Header::new("X-Real-IP", ip)), None); + + let ip = "1200:100000:0120129"; + check_ip(Some(Header::new("X-Real-IP", ip)), None); + + let ip = "192.168.1.900"; + check_ip(Some(Header::new("X-Real-IP", ip)), None); + } +} From 4bc5c20a45757be6b33e5fe9530e759d6c5bb032 Mon Sep 17 00:00:00 2001 From: Sergio Benitez Date: Fri, 13 Jan 2017 13:25:33 -0800 Subject: [PATCH 6/6] Fix security checks in `PathBuf::FromSegments`. In #134, @tunz discovered that Rocket does not properly prevent path traversal or local file inclusion attacks. The issue is caused by a failure to check for some dangerous characters after decoding. In this case, the path separator '/' was left as-is after decoding. As such, an attacker could construct a path with containing any number of `..%2f..` sequences to traverse the file system. This commit resolves the issue by ensuring that the decoded segment does not contains any `/` characters. It further hardens the `FromSegments` implementation by checking for additional risky characters: ':', '>', '<' as the last character, and '\' on Windows. This is in addition to the already present checks for '.' and '*' as the first character. The behavior for a failing check has also changed. Previously, Rocket would skip segments that contained illegal characters. In this commit, the implementation instead return an error. The `Error` type of the `PathBuf::FromSegment` implementations was changed to a new `SegmentError` type that indicates the condition that failed. Closes #134. --- codegen/tests/run-pass/segments.rs | 4 +-- lib/src/http/uri.rs | 14 +++++++++ lib/src/request/param.rs | 46 ++++++++++++++++++++++++------ 3 files changed, 53 insertions(+), 11 deletions(-) diff --git a/codegen/tests/run-pass/segments.rs b/codegen/tests/run-pass/segments.rs index fb22fd49..8e7ec45a 100644 --- a/codegen/tests/run-pass/segments.rs +++ b/codegen/tests/run-pass/segments.rs @@ -4,7 +4,7 @@ extern crate rocket; use std::path::PathBuf; -use std::str::Utf8Error; +use rocket::http::uri::SegmentError; #[post("//")] fn get(a: String, b: PathBuf) -> String { @@ -12,7 +12,7 @@ fn get(a: String, b: PathBuf) -> String { } #[post("//")] -fn get2(a: String, b: Result) -> String { +fn get2(a: String, b: Result) -> String { format!("{}/{}", a, b.unwrap().to_string_lossy()) } diff --git a/lib/src/http/uri.rs b/lib/src/http/uri.rs index fe01b854..f873efe3 100644 --- a/lib/src/http/uri.rs +++ b/lib/src/http/uri.rs @@ -373,6 +373,20 @@ impl<'a> Iterator for Segments<'a> { // } } +/// Errors which can occur when attempting to interpret a segment string as a +/// valid path segment. +#[derive(Debug, PartialEq, Eq, Clone)] +pub enum SegmentError { + /// The segment contained invalid UTF8 characters when percent decoded. + Utf8(Utf8Error), + /// The segment started with the wrapped invalid character. + BadStart(char), + /// The segment contained the wrapped invalid character. + BadChar(char), + /// The segment ended with the wrapped invalid character. + BadEnd(char), +} + #[cfg(test)] mod tests { use super::URI; diff --git a/lib/src/request/param.rs b/lib/src/request/param.rs index da85f3b6..385dc4a8 100644 --- a/lib/src/request/param.rs +++ b/lib/src/request/param.rs @@ -1,9 +1,9 @@ -use std::str::{Utf8Error, FromStr}; +use std::str::FromStr; use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddrV4, SocketAddrV6, SocketAddr}; use std::path::PathBuf; use std::fmt::Debug; -use http::uri::{URI, Segments}; +use http::uri::{URI, Segments, SegmentError}; /// Trait to convert a dynamic path segment string to a concrete value. /// @@ -274,6 +274,7 @@ pub trait FromSegments<'a>: Sized { impl<'a> FromSegments<'a> for Segments<'a> { type Error = (); + fn from_segments(segments: Segments<'a>) -> Result, ()> { Ok(segments) } @@ -281,19 +282,46 @@ impl<'a> FromSegments<'a> for Segments<'a> { /// Creates a `PathBuf` from a `Segments` iterator. The returned `PathBuf` is /// percent-decoded. If a segment is equal to "..", the previous segment (if -/// any) is skipped. For security purposes, any other segments that begin with -/// "*" or "." are ignored. If a percent-decoded segment results in invalid -/// UTF8, an `Err` is returned. +/// any) is skipped. +/// +/// For security purposes, if a segment meets any of the following conditions, +/// an `Err` is returned indicating the condition met: +/// +/// * Decoded segment starts with any of: `.`, `*` +/// * Decoded segment ends with any of: `:`, `>`, `<` +/// * Decoded segment contains any of: `/` +/// * On Windows, decoded segment contains any of: '\' +/// * Percent-encoding results in invalid UTF8. +/// +/// As a result of these conditions, a `PathBuf` derived via `FromSegments` is +/// safe to interpolate within, or use as a suffix of, a path without additional +/// checks. impl<'a> FromSegments<'a> for PathBuf { - type Error = Utf8Error; + type Error = SegmentError; - fn from_segments(segments: Segments<'a>) -> Result { + fn from_segments(segments: Segments<'a>) -> Result { let mut buf = PathBuf::new(); for segment in segments { - let decoded = URI::percent_decode(segment.as_bytes())?; + let decoded = URI::percent_decode(segment.as_bytes()) + .map_err(|e| SegmentError::Utf8(e))?; + if decoded == ".." { buf.pop(); - } else if !(decoded.starts_with('.') || decoded.starts_with('*')) { + } else if decoded.starts_with('.') { + return Err(SegmentError::BadStart('.')) + } else if decoded.starts_with('*') { + return Err(SegmentError::BadStart('*')) + } else if decoded.ends_with(':') { + return Err(SegmentError::BadEnd(':')) + } else if decoded.ends_with('>') { + return Err(SegmentError::BadEnd('>')) + } else if decoded.ends_with('<') { + return Err(SegmentError::BadEnd('<')) + } else if decoded.contains('/') { + return Err(SegmentError::BadChar('/')) + } else if cfg!(windows) && decoded.contains('\\') { + return Err(SegmentError::BadChar('\\')) + } else { buf.push(&*decoded) } }