Precisely route formats. Make 'content_type' an 'Option' in 'Request'.

This commit changes the routing algorithm. In particular, it enforces
precise matching of formats. With this change, a route with a specified
format only matches requests that have the same format specified. A
route with no format specified matches any request's format. This is
contrast to the previous behavior, where a route without a specified
format would match requests regardless of their format or whether one
was specified.

This commit also changes the following:
  * The return type of the 'content_type' method of 'Request' is now
    'Option<ContentType>'.
  * The 'ContentType' request guard forwards when the request has no
    specified ContentType.
  * The 'add_header' and 'replace_header' methods take the header
    argument generically.

Closes #120.
This commit is contained in:
Sergio Benitez 2017-02-01 03:12:24 -08:00
parent 3235e1e5e6
commit cc22836867
15 changed files with 273 additions and 132 deletions

View File

@ -73,7 +73,7 @@ impl<T: Deserialize> FromData for JSON<T> {
type Error = SerdeError; type Error = SerdeError;
fn from_data(request: &Request, data: Data) -> data::Outcome<Self, SerdeError> { fn from_data(request: &Request, data: Data) -> data::Outcome<Self, SerdeError> {
if !request.content_type().is_json() { if !request.content_type().map_or(false, |ct| ct.is_json()) {
error_!("Content-Type is not JSON."); error_!("Content-Type is not JSON.");
return Outcome::Forward(data); return Outcome::Forward(data);
} }

View File

@ -6,8 +6,7 @@ extern crate serde_json;
#[macro_use] #[macro_use]
extern crate serde_derive; extern crate serde_derive;
#[cfg(test)] #[cfg(test)] mod tests;
mod tests;
use rocket::Request; use rocket::Request;
use rocket::http::ContentType; use rocket::http::ContentType;
@ -34,13 +33,13 @@ fn hello(content_type: ContentType, name: String, age: i8) -> content::JSON<Stri
#[error(404)] #[error(404)]
fn not_found(request: &Request) -> content::HTML<String> { fn not_found(request: &Request) -> content::HTML<String> {
let html = if !request.content_type().is_json() { let html = match request.content_type() {
format!("<p>This server only supports JSON requests, not '{}'.</p>", Some(ref ct) if !ct.is_json() => {
request.content_type()) format!("<p>This server only supports JSON requests, not '{}'.</p>", ct)
} else { }
format!("<p>Sorry, '{}' is an invalid path! Try \ _ => format!("<p>Sorry, '{}' is an invalid path! Try \
/hello/&lt;name&gt;/&lt;age&gt; instead.</p>", /hello/&lt;name&gt;/&lt;age&gt; instead.</p>",
request.uri()) request.uri())
}; };
content::HTML(html) content::HTML(html)

View File

@ -28,7 +28,7 @@ fn echo_url(req: &Request, _: Data) -> Outcome<'static> {
} }
fn upload<'r>(req: &'r Request, data: Data) -> Outcome<'r> { fn upload<'r>(req: &'r Request, data: Data) -> Outcome<'r> {
if !req.content_type().is_plain() { if !req.content_type().map_or(false, |ct| ct.is_plain()) {
println!(" => Content-Type of upload must be text/plain. Ignoring."); println!(" => Content-Type of upload must be text/plain. Ignoring.");
return Outcome::failure(Status::BadRequest); return Outcome::failure(Status::BadRequest);
} }

View File

@ -99,7 +99,7 @@ impl<'a, S, E> IntoOutcome<S, (Status, E), Data> for Result<S, E> {
/// fn from_data(req: &Request, data: Data) -> data::Outcome<Self, String> { /// fn from_data(req: &Request, data: Data) -> data::Outcome<Self, String> {
/// // Ensure the content type is correct before opening the data. /// // Ensure the content type is correct before opening the data.
/// let person_ct = ContentType::new("application", "x-person"); /// let person_ct = ContentType::new("application", "x-person");
/// if req.content_type() != person_ct { /// if req.content_type() != Some(person_ct) {
/// return Outcome::Forward(data); /// return Outcome::Forward(data);
/// } /// }
/// ///

View File

@ -5,7 +5,6 @@ use std::fmt;
use http::Header; use http::Header;
use http::hyper::mime::Mime; use http::hyper::mime::Mime;
use http::ascii::{uncased_eq, UncasedAscii}; use http::ascii::{uncased_eq, UncasedAscii};
use router::Collider;
/// Representation of HTTP Content-Types. /// Representation of HTTP Content-Types.
/// ///
@ -419,13 +418,6 @@ impl Into<Header<'static>> for ContentType {
} }
} }
impl Collider for ContentType {
fn collides_with(&self, other: &ContentType) -> bool {
let collide = |a, b| a == "*" || b == "*" || a == b;
collide(&self.ttype, &other.ttype) && collide(&self.subtype, &other.subtype)
}
}
#[cfg(test)] #[cfg(test)]
mod test { mod test {
use super::ContentType; use super::ContentType;

View File

@ -8,8 +8,6 @@ use std::str::Utf8Error;
use url; use url;
use router::Collider;
/// Index (start, end) into a string, to prevent borrowing. /// Index (start, end) into a string, to prevent borrowing.
type Index = (usize, usize); type Index = (usize, usize);
@ -299,26 +297,6 @@ impl<'a> fmt::Display for URI<'a> {
unsafe impl<'a> Sync for URI<'a> { /* It's safe! */ } unsafe impl<'a> Sync for URI<'a> { /* It's safe! */ }
impl<'a, 'b> Collider<URI<'b>> for URI<'a> {
fn collides_with(&self, other: &URI<'b>) -> bool {
for (seg_a, seg_b) in self.segments().zip(other.segments()) {
if seg_a.ends_with("..>") || seg_b.ends_with("..>") {
return true;
}
if !seg_a.collides_with(seg_b) {
return false;
}
}
if self.segment_count() != other.segment_count() {
return false;
}
true
}
}
/// Iterator over the segments of an absolute URI path. Skips empty segments. /// Iterator over the segments of an absolute URI path. Skips empty segments.
/// ///
/// ### Examples /// ### Examples

View File

@ -234,7 +234,7 @@ impl<'f, T: FromForm<'f>> FromData for Form<'f, T> where T::Error: Debug {
type Error = Option<String>; type Error = Option<String>;
fn from_data(request: &Request, data: Data) -> data::Outcome<Self, Self::Error> { fn from_data(request: &Request, data: Data) -> data::Outcome<Self, Self::Error> {
if !request.content_type().is_form() { if !request.content_type().map_or(false, |ct| ct.is_form()) {
warn_!("Form data does not have form content type."); warn_!("Form data does not have form content type.");
return Forward(data); return Forward(data);
} }

View File

@ -97,10 +97,8 @@ impl<S, E> IntoOutcome<S, (Status, E), ()> for Result<S, E> {
/// * **ContentType** /// * **ContentType**
/// ///
/// Extracts the [ContentType](/rocket/http/struct.ContentType.html) from /// Extracts the [ContentType](/rocket/http/struct.ContentType.html) from
/// the incoming request. If the request didn't specify a Content-Type, a /// the incoming request. If the request didn't specify a Content-Type, the
/// Content-Type of `*/*` (`Any`) is returned. /// request is forwarded.
///
/// _This implementation always returns successfully._
/// ///
/// * **SocketAddr** /// * **SocketAddr**
/// ///
@ -217,7 +215,10 @@ impl<'a, 'r> FromRequest<'a, 'r> for ContentType {
type Error = (); type Error = ();
fn from_request(request: &'a Request<'r>) -> Outcome<Self, Self::Error> { fn from_request(request: &'a Request<'r>) -> Outcome<Self, Self::Error> {
Success(request.content_type()) match request.content_type() {
Some(content_type) => Success(content_type),
None => Forward(())
}
} }
} }

View File

@ -201,13 +201,13 @@ impl<'r> Request<'r> {
/// let mut request = Request::new(Method::Get, "/uri"); /// let mut request = Request::new(Method::Get, "/uri");
/// assert!(request.headers().is_empty()); /// assert!(request.headers().is_empty());
/// ///
/// request.add_header(ContentType::HTML.into()); /// request.add_header(ContentType::HTML);
/// assert!(request.headers().contains("Content-Type")); /// assert!(request.headers().contains("Content-Type"));
/// assert_eq!(request.headers().len(), 1); /// assert_eq!(request.headers().len(), 1);
/// ``` /// ```
#[inline(always)] #[inline(always)]
pub fn add_header(&mut self, header: Header<'r>) { pub fn add_header<H: Into<Header<'r>>>(&mut self, header: H) {
self.headers.add(header); self.headers.add(header.into());
} }
/// Replaces the value of the header with `header.name` with `header.value`. /// Replaces the value of the header with `header.name` with `header.value`.
@ -222,15 +222,15 @@ impl<'r> Request<'r> {
/// let mut request = Request::new(Method::Get, "/uri"); /// let mut request = Request::new(Method::Get, "/uri");
/// assert!(request.headers().is_empty()); /// assert!(request.headers().is_empty());
/// ///
/// request.add_header(ContentType::HTML.into()); /// request.add_header(ContentType::HTML);
/// assert_eq!(request.content_type(), ContentType::HTML); /// assert_eq!(request.content_type(), Some(ContentType::HTML));
/// ///
/// request.replace_header(ContentType::JSON.into()); /// request.replace_header(ContentType::JSON);
/// assert_eq!(request.content_type(), ContentType::JSON); /// assert_eq!(request.content_type(), Some(ContentType::JSON));
/// ``` /// ```
#[inline(always)] #[inline(always)]
pub fn replace_header(&mut self, header: Header<'r>) { pub fn replace_header<H: Into<Header<'r>>>(&mut self, header: H) {
self.headers.replace(header); self.headers.replace(header.into());
} }
/// Returns a borrow to the cookies in `self`. /// Returns a borrow to the cookies in `self`.
@ -262,8 +262,8 @@ impl<'r> Request<'r> {
self.cookies = cookies; self.cookies = cookies;
} }
/// Returns the Content-Type header of `self`. If the header is not present, /// Returns `Some` of the Content-Type header of `self`. If the header is
/// returns `ContentType::Any`. /// not present, returns `None`.
/// ///
/// # Example /// # Example
/// ///
@ -272,16 +272,15 @@ impl<'r> Request<'r> {
/// use rocket::http::{Method, ContentType}; /// use rocket::http::{Method, ContentType};
/// ///
/// let mut request = Request::new(Method::Get, "/uri"); /// let mut request = Request::new(Method::Get, "/uri");
/// assert_eq!(request.content_type(), ContentType::Any); /// assert_eq!(request.content_type(), None);
/// ///
/// request.replace_header(ContentType::JSON.into()); /// request.replace_header(ContentType::JSON);
/// assert_eq!(request.content_type(), ContentType::JSON); /// assert_eq!(request.content_type(), Some(ContentType::JSON));
/// ``` /// ```
#[inline(always)] #[inline(always)]
pub fn content_type(&self) -> ContentType { pub fn content_type(&self) -> Option<ContentType> {
self.headers().get_one("Content-Type") self.headers().get_one("Content-Type")
.and_then(|value| value.parse().ok()) .and_then(|value| value.parse().ok())
.unwrap_or(ContentType::Any)
} }
/// Retrieves and parses into `T` the 0-indexed `n`th dynamic parameter from /// Retrieves and parses into `T` the 0-indexed `n`th dynamic parameter from
@ -458,8 +457,10 @@ impl<'r> fmt::Display for Request<'r> {
/// infrastructure. /// infrastructure.
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "{} {}", Green.paint(&self.method), Blue.paint(&self.uri))?; write!(f, "{} {}", Green.paint(&self.method), Blue.paint(&self.uri))?;
if self.method.supports_payload() && !self.content_type().is_any() { if let Some(content_type) = self.content_type() {
write!(f, " {}", Yellow.paint(self.content_type()))?; if self.method.supports_payload() {
write!(f, " {}", Yellow.paint(content_type))?;
}
} }
Ok(()) Ok(())

View File

@ -161,7 +161,7 @@ impl Rocket {
// field which we use to reinterpret the request's method. // field which we use to reinterpret the request's method.
let data_len = data.peek().len(); let data_len = data.peek().len();
let (min_len, max_len) = ("_method=get".len(), "_method=delete".len()); let (min_len, max_len) = ("_method=get".len(), "_method=delete".len());
let is_form = req.content_type().is_form(); let is_form = req.content_type().map_or(false, |ct| ct.is_form());
if is_form && req.method() == Method::Post && data_len >= min_len { if is_form && req.method() == Method::Post && data_len >= min_len {
let form = unsafe { let form = unsafe {
from_utf8_unchecked(&data.peek()[..min(data_len, max_len)]) from_utf8_unchecked(&data.peek()[..min(data_len, max_len)])

View File

@ -1,3 +1,9 @@
use super::Route;
use http::uri::URI;
use http::ContentType;
use request::Request;
/// The Collider trait is used to determine if two items that can be routed on /// The Collider trait is used to determine if two items that can be routed on
/// can match against a given request. That is, if two items `collide`, they /// can match against a given request. That is, if two items `collide`, they
/// will both match against _some_ request. /// will both match against _some_ request.
@ -44,18 +50,80 @@ impl<'a> Collider<str> for &'a str {
} }
} }
impl<'a, 'b> Collider<URI<'b>> for URI<'a> {
fn collides_with(&self, other: &URI<'b>) -> bool {
for (seg_a, seg_b) in self.segments().zip(other.segments()) {
if seg_a.ends_with("..>") || seg_b.ends_with("..>") {
return true;
}
if !seg_a.collides_with(seg_b) {
return false;
}
}
if self.segment_count() != other.segment_count() {
return false;
}
true
}
}
impl Collider for ContentType {
fn collides_with(&self, other: &ContentType) -> bool {
let collide = |a, b| a == "*" || b == "*" || a == b;
collide(&self.ttype, &other.ttype) && collide(&self.subtype, &other.subtype)
}
}
// This implementation is used at initialization to check if two user routes
// collide before launching. Format collisions works like this:
// * If route a specifies format, it only gets requests for that format.
// * If a route doesn't specify format, it gets requests for any format.
impl Collider for Route {
fn collides_with(&self, b: &Route) -> bool {
self.method == b.method
&& self.rank == b.rank
&& self.path.collides_with(&b.path)
&& match (self.format.as_ref(), b.format.as_ref()) {
(Some(ct_a), Some(ct_b)) => ct_a.collides_with(ct_b),
(Some(_), None) => true,
(None, Some(_)) => true,
(None, None) => true
}
}
}
// This implementation is used at runtime to check if a given request is
// intended for this Route. Format collisions works like this:
// * If route a specifies format, it only gets requests for that format.
// * If a route doesn't specify format, it gets requests for any format.
impl<'r> Collider<Request<'r>> for Route {
fn collides_with(&self, req: &Request<'r>) -> bool {
self.method == req.method()
&& req.uri().collides_with(&self.path)
// FIXME: On payload requests, check Content-Type, else Accept.
&& match (req.content_type().as_ref(), self.format.as_ref()) {
(Some(ct_a), Some(ct_b)) => ct_a.collides_with(ct_b),
(Some(_), None) => true,
(None, Some(_)) => false,
(None, None) => true
}
}
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use std::str::FromStr; use std::str::FromStr;
use router::Collider; use super::Collider;
use request::Request; use request::Request;
use data::Data; use data::Data;
use handler::Outcome; use handler::Outcome;
use router::route::Route; use router::route::Route;
use http::{Method, ContentType}; use http::{Method, ContentType};
use http::uri::URI; use http::uri::URI;
use http::Method::*; use http::Method::*;
type SimpleRoute = (Method, &'static str); type SimpleRoute = (Method, &'static str);
@ -199,12 +267,6 @@ mod tests {
assert!(!s_s_collide("/a/hi/<a..>", "/a/hi/")); assert!(!s_s_collide("/a/hi/<a..>", "/a/hi/"));
} }
fn ct_route(m: Method, s: &str, ct: &str) -> Route {
let mut route_a = Route::new(m, s, dummy_handler);
route_a.format = ContentType::from_str(ct).expect("Whoops!");
route_a
}
fn ct_ct_collide(ct1: &str, ct2: &str) -> bool { fn ct_ct_collide(ct1: &str, ct2: &str) -> bool {
let ct_a = ContentType::from_str(ct1).expect(ct1); let ct_a = ContentType::from_str(ct1).expect(ct1);
let ct_b = ContentType::from_str(ct2).expect(ct2); let ct_b = ContentType::from_str(ct2).expect(ct2);
@ -229,10 +291,20 @@ mod tests {
assert!(!ct_ct_collide("something/*", "random/else")); assert!(!ct_ct_collide("something/*", "random/else"));
} }
fn r_ct_ct_collide(m1: Method, ct1: &str, m2: Method, ct2: &str) -> bool { fn r_ct_ct_collide<S1, S2>(m1: Method, ct1: S1, m2: Method, ct2: S2) -> bool
let a_route = ct_route(m1, "a", ct1); where S1: Into<Option<&'static str>>, S2: Into<Option<&'static str>>
let b_route = ct_route(m2, "a", ct2); {
a_route.collides_with(&b_route) let mut route_a = Route::new(m1, "/", dummy_handler);
if let Some(ct_str) = ct1.into() {
route_a.format = Some(ct_str.parse::<ContentType>().unwrap());
}
let mut route_b = Route::new(m2, "/", dummy_handler);
if let Some(ct_str) = ct2.into() {
route_b.format = Some(ct_str.parse::<ContentType>().unwrap());
}
route_a.collides_with(&route_b)
} }
#[test] #[test]
@ -241,9 +313,56 @@ mod tests {
assert!(r_ct_ct_collide(Get, "*/json", Get, "application/json")); assert!(r_ct_ct_collide(Get, "*/json", Get, "application/json"));
assert!(r_ct_ct_collide(Get, "*/json", Get, "application/*")); assert!(r_ct_ct_collide(Get, "*/json", Get, "application/*"));
assert!(r_ct_ct_collide(Get, "text/html", Get, "text/*")); assert!(r_ct_ct_collide(Get, "text/html", Get, "text/*"));
assert!(r_ct_ct_collide(Get, "any/thing", Get, "*/*"));
assert!(r_ct_ct_collide(Get, None, Get, "text/*"));
assert!(r_ct_ct_collide(Get, None, Get, "text/html"));
assert!(r_ct_ct_collide(Get, None, Get, "*/*"));
assert!(r_ct_ct_collide(Get, "text/html", Get, None));
assert!(r_ct_ct_collide(Get, "*/*", Get, None));
assert!(r_ct_ct_collide(Get, "application/json", Get, None));
assert!(!r_ct_ct_collide(Get, "text/html", Get, "application/*")); assert!(!r_ct_ct_collide(Get, "text/html", Get, "application/*"));
assert!(!r_ct_ct_collide(Get, "application/html", Get, "text/*")); assert!(!r_ct_ct_collide(Get, "application/html", Get, "text/*"));
assert!(!r_ct_ct_collide(Get, "*/json", Get, "text/html")); assert!(!r_ct_ct_collide(Get, "*/json", Get, "text/html"));
assert!(!r_ct_ct_collide(Get, "text/html", Get, "text/css"));
}
fn req_route_collide<S1, S2>(m1: Method, ct1: S1, m2: Method, ct2: S2) -> bool
where S1: Into<Option<&'static str>>, S2: Into<Option<&'static str>>
{
let mut req = Request::new(m1, "/");
if let Some(ct_str) = ct1.into() {
req.replace_header(ct_str.parse::<ContentType>().unwrap());
}
let mut route = Route::new(m2, "/", dummy_handler);
if let Some(ct_str) = ct2.into() {
route.format = Some(ct_str.parse::<ContentType>().unwrap());
}
route.collides_with(&req)
}
#[test]
fn test_req_route_ct_collisions() {
assert!(req_route_collide(Get, "application/json", Get, "application/json"));
assert!(req_route_collide(Get, "application/json", Get, "application/*"));
assert!(req_route_collide(Get, "application/json", Get, "*/json"));
assert!(req_route_collide(Get, "text/html", Get, "text/html"));
assert!(req_route_collide(Get, "text/html", Get, "*/*"));
assert!(req_route_collide(Get, "text/html", Get, None));
assert!(req_route_collide(Get, None, Get, None));
assert!(req_route_collide(Get, "application/json", Get, None));
assert!(req_route_collide(Get, "x-custom/anything", Get, None));
assert!(!req_route_collide(Get, "application/json", Get, "text/html"));
assert!(!req_route_collide(Get, "application/json", Get, "text/*"));
assert!(!req_route_collide(Get, "application/json", Get, "*/xml"));
assert!(!req_route_collide(Get, None, Get, "text/html"));
assert!(!req_route_collide(Get, None, Get, "*/*"));
assert!(!req_route_collide(Get, None, Get, "application/json"));
} }
} }

View File

@ -1,11 +1,11 @@
mod collider; mod collider;
mod route; mod route;
pub use self::collider::Collider;
pub use self::route::Route;
use std::collections::hash_map::HashMap; use std::collections::hash_map::HashMap;
use self::collider::Collider;
pub use self::route::Route;
use request::Request; use request::Request;
use http::Method; use http::Method;
@ -206,7 +206,6 @@ mod test {
assert!(route(&router, Get, "/a/b/").is_some()); assert!(route(&router, Get, "/a/b/").is_some());
assert!(route(&router, Get, "/i/a").is_some()); assert!(route(&router, Get, "/i/a").is_some());
assert!(route(&router, Get, "/a/b/c/d/e/f").is_some()); assert!(route(&router, Get, "/a/b/c/d/e/f").is_some());
} }
#[test] #[test]

View File

@ -1,14 +1,11 @@
use std::fmt; use std::fmt;
use std::convert::From; use std::convert::From;
use super::Collider; // :D
use term_painter::ToStyle; use term_painter::ToStyle;
use term_painter::Color::*; use term_painter::Color::*;
use codegen::StaticRouteInfo; use codegen::StaticRouteInfo;
use handler::Handler; use handler::Handler;
use request::Request;
use http::{Method, ContentType}; use http::{Method, ContentType};
use http::uri::URI; use http::uri::URI;
@ -23,7 +20,7 @@ pub struct Route {
/// The rank of this route. Lower ranks have higher priorities. /// The rank of this route. Lower ranks have higher priorities.
pub rank: isize, pub rank: isize,
/// The Content-Type this route matches against. /// The Content-Type this route matches against.
pub format: ContentType, pub format: Option<ContentType>,
} }
fn default_rank(path: &str) -> isize { fn default_rank(path: &str) -> isize {
@ -45,7 +42,7 @@ impl Route {
handler: handler, handler: handler,
rank: default_rank(path.as_ref()), rank: default_rank(path.as_ref()),
path: URI::from(path.as_ref().to_string()), path: URI::from(path.as_ref().to_string()),
format: ContentType::Any, format: None,
} }
} }
@ -58,7 +55,7 @@ impl Route {
path: URI::from(path.as_ref().to_string()), path: URI::from(path.as_ref().to_string()),
handler: handler, handler: handler,
rank: rank, rank: rank,
format: ContentType::Any, format: None,
} }
} }
@ -115,11 +112,11 @@ impl fmt::Display for Route {
write!(f, " [{}]", White.paint(&self.rank))?; write!(f, " [{}]", White.paint(&self.rank))?;
} }
if !self.format.is_any() { if let Some(ref format) = self.format {
write!(f, " {}", Yellow.paint(&self.format)) write!(f, " {}", Yellow.paint(format))?;
} else {
Ok(())
} }
Ok(())
} }
} }
@ -133,7 +130,7 @@ impl fmt::Debug for Route {
impl<'a> From<&'a StaticRouteInfo> for Route { impl<'a> From<&'a StaticRouteInfo> for Route {
fn from(info: &'a StaticRouteInfo) -> Route { fn from(info: &'a StaticRouteInfo) -> Route {
let mut route = Route::new(info.method, info.path, info.handler); let mut route = Route::new(info.method, info.path, info.handler);
route.format = info.format.clone().unwrap_or(ContentType::Any); route.format = info.format.clone();
if let Some(rank) = info.rank { if let Some(rank) = info.rank {
route.rank = rank; route.rank = rank;
} }
@ -141,22 +138,3 @@ impl<'a> From<&'a StaticRouteInfo> for Route {
route route
} }
} }
impl Collider for Route {
fn collides_with(&self, b: &Route) -> bool {
self.method == b.method
&& self.rank == b.rank
&& self.format.collides_with(&b.format)
&& self.path.collides_with(&b.path)
}
}
impl<'r> Collider<Request<'r>> for Route {
fn collides_with(&self, req: &Request<'r>) -> bool {
self.method == req.method()
&& req.uri().collides_with(&self.path)
// FIXME: On payload requests, check Content-Type. On non-payload
// requests, check Accept.
&& req.content_type().collides_with(&self.format)
}
}

View File

@ -140,11 +140,30 @@ impl<'r> MockRequest<'r> {
/// let req = MockRequest::new(Get, "/").header(ContentType::JSON); /// let req = MockRequest::new(Get, "/").header(ContentType::JSON);
/// ``` /// ```
#[inline] #[inline]
pub fn header<'h, H: Into<Header<'static>>>(mut self, header: H) -> Self { pub fn header<H: Into<Header<'static>>>(mut self, header: H) -> Self {
self.request.add_header(header.into()); self.request.add_header(header.into());
self self
} }
/// Adds a header to this request without consuming `self`.
///
/// # Examples
///
/// Add the Content-Type header:
///
/// ```rust
/// use rocket::http::Method::*;
/// use rocket::testing::MockRequest;
/// use rocket::http::ContentType;
///
/// let mut req = MockRequest::new(Get, "/");
/// req.add_header(ContentType::JSON);
/// ```
#[inline]
pub fn add_header<H: Into<Header<'static>>>(&mut self, header: H) {
self.request.add_header(header.into());
}
/// Set the remote address of this request. /// Set the remote address of this request.
/// ///
/// # Examples /// # Examples
@ -164,25 +183,6 @@ impl<'r> MockRequest<'r> {
self self
} }
/// Adds a header to this request. Does not consume `self`.
///
/// # Examples
///
/// Add the Content-Type header:
///
/// ```rust
/// use rocket::http::Method::*;
/// use rocket::testing::MockRequest;
/// use rocket::http::ContentType;
///
/// let mut req = MockRequest::new(Get, "/");
/// req.add_header(ContentType::JSON);
/// ```
#[inline]
pub fn add_header<'h, H: Into<Header<'static>>>(&mut self, header: H) {
self.request.add_header(header.into());
}
/// Add a cookie to this request. /// Add a cookie to this request.
/// ///
/// # Examples /// # Examples

View File

@ -0,0 +1,74 @@
#![feature(plugin)]
#![plugin(rocket_codegen)]
extern crate rocket;
#[post("/", format = "application/json")]
fn specified() -> &'static str {
"specified"
}
#[post("/", rank = 2)]
fn unspecified() -> &'static str {
"unspecified"
}
#[post("/", format = "application/json")]
fn specified_json() -> &'static str {
"specified_json"
}
#[post("/", format = "text/html")]
fn specified_html() -> &'static str {
"specified_html"
}
#[cfg(feature = "testing")]
mod tests {
use super::*;
use rocket::Rocket;
use rocket::testing::MockRequest;
use rocket::http::Method::*;
use rocket::http::{Status, ContentType};
fn rocket() -> Rocket {
rocket::ignite()
.mount("/first", routes![specified, unspecified])
.mount("/second", routes![specified_json, specified_html])
}
macro_rules! check_dispatch {
($mount:expr, $ct:expr, $body:expr) => (
let rocket = rocket();
let mut req = MockRequest::new(Post, $mount);
let ct: Option<ContentType> = $ct;
if let Some(ct) = ct {
req.add_header(ct);
}
let mut response = req.dispatch_with(&rocket);
let body_str = response.body().and_then(|b| b.into_string());
let body: Option<&'static str> = $body;
match body {
Some(string) => assert_eq!(body_str, Some(string.to_string())),
None => assert_eq!(response.status(), Status::NotFound)
}
)
}
#[test]
fn exact_match_or_forward() {
check_dispatch!("/first", Some(ContentType::JSON), Some("specified"));
check_dispatch!("/first", None, Some("unspecified"));
check_dispatch!("/first", Some(ContentType::HTML), Some("unspecified"));
}
#[test]
fn exact_match_or_none() {
check_dispatch!("/second", Some(ContentType::JSON), Some("specified_json"));
check_dispatch!("/second", Some(ContentType::HTML), Some("specified_html"));
check_dispatch!("/second", Some(ContentType::CSV), None);
check_dispatch!("/second", None, None);
}
}