From 90d8621adfda7e27402a39285c5d976afc53ae40 Mon Sep 17 00:00:00 2001 From: Sergio Benitez Date: Fri, 26 Aug 2016 01:55:11 -0700 Subject: [PATCH] Major overhual: Request, ErrorHandler, ContentType. --- examples/content_types/src/main.rs | 22 ++-- examples/errors/src/main.rs | 6 +- examples/manual_routes/src/main.rs | 6 +- lib/src/catcher.rs | 21 ++-- lib/src/codegen.rs | 4 +- lib/src/content_type.rs | 189 ++++++++++++++++++++++++++--- lib/src/error.rs | 30 +---- lib/src/lib.rs | 9 +- lib/src/logger.rs | 5 +- lib/src/method.rs | 4 +- lib/src/request/from_request.rs | 2 +- lib/src/request/request.rs | 100 +++++++++++---- lib/src/rocket.rs | 76 +++++------- lib/src/router/collider.rs | 52 +++++++- lib/src/router/mod.rs | 96 ++++++++------- lib/src/router/route.rs | 51 ++++---- macros/Cargo.toml | 2 + macros/src/error_decorator.rs | 10 +- macros/src/lib.rs | 1 + macros/src/meta_item_parser.rs | 8 +- macros/src/route_decorator.rs | 83 ++++++++++--- 21 files changed, 527 insertions(+), 250 deletions(-) diff --git a/examples/content_types/src/main.rs b/examples/content_types/src/main.rs index 5cd33f1d..6ea98df3 100644 --- a/examples/content_types/src/main.rs +++ b/examples/content_types/src/main.rs @@ -4,9 +4,7 @@ extern crate rocket; extern crate serde_json; -use rocket::{Rocket, RoutingError}; -use rocket::ContentType; -use rocket::Error; +use rocket::{Rocket, Request, ContentType, Error}; #[derive(Debug, Serialize, Deserialize)] struct Person { @@ -25,23 +23,23 @@ fn hello(name: String, age: i8) -> String { } #[error(code = "404")] -fn not_found(error: RoutingError) -> String { - match error.error { - // Error::BadMethod if !error.request.content_type.is_json() => { - // format!("

This server only supports JSON requests, not '{}'.

", - // error.request.data) - // } - Error::BadMethod => { +fn not_found<'r>(error: Error, request: &'r Request<'r>) -> String { + match error { + Error::NoRoute if !request.content_type().is_json() => { + format!("

This server only supports JSON requests, not '{}'.

", + request.content_type()) + } + Error::NoRoute => { format!("

Sorry, this server but '{}' is not a valid path!

Try visiting /hello/<name>/<age> instead.

", - error.request.uri) + request.uri()) } _ => format!("

Bad Request

"), } } fn main() { - let mut rocket = Rocket::new("localhost", 8000); + let mut rocket = Rocket::new("0.0.0.0", 8000); rocket.mount("/hello", routes![hello]); rocket.catch(errors![not_found]); rocket.launch(); diff --git a/examples/errors/src/main.rs b/examples/errors/src/main.rs index b143cfcf..263cb7e1 100644 --- a/examples/errors/src/main.rs +++ b/examples/errors/src/main.rs @@ -2,7 +2,7 @@ #![plugin(rocket_macros)] extern crate rocket; -use rocket::{Rocket, RoutingError}; +use rocket::{Rocket, Error, Request}; #[route(GET, path = "/hello//")] fn hello(name: &str, age: i8) -> String { @@ -10,10 +10,10 @@ fn hello(name: &str, age: i8) -> String { } #[error(code = "404")] -fn not_found(error: RoutingError) -> String { +fn not_found<'r>(_error: Error, request: &'r Request<'r>) -> String { format!("

Sorry, but '{}' is not a valid path!

Try visiting /hello/<name>/<age> instead.

", - error.request.uri) + request.uri) } fn main() { diff --git a/examples/manual_routes/src/main.rs b/examples/manual_routes/src/main.rs index ca5c0608..00420a59 100644 --- a/examples/manual_routes/src/main.rs +++ b/examples/manual_routes/src/main.rs @@ -3,14 +3,14 @@ extern crate rocket; use rocket::{Rocket, Request, Response, Route}; use rocket::Method::*; -fn root(req: Request) -> Response { +fn root<'r>(req: &'r Request<'r>) -> Response<'r> { let name = req.get_param(0).unwrap_or("unnamed"); Response::new(format!("Hello, {}!", name)) } #[allow(dead_code)] -fn echo_url<'a>(req: Request<'a>) -> Response<'a> { - Response::new(req.get_uri().split_at(6).1) +fn echo_url<'a>(req: &'a Request<'a>) -> Response<'a> { + Response::new(req.uri().as_str().split_at(6).1) } fn main() { diff --git a/lib/src/catcher.rs b/lib/src/catcher.rs index d3cc6403..deba89af 100644 --- a/lib/src/catcher.rs +++ b/lib/src/catcher.rs @@ -1,7 +1,8 @@ -use handler::Handler; +use handler::ErrorHandler; use response::Response; -use error::RoutingError; use codegen::StaticCatchInfo; +use error::Error; +use request::Request; use std::fmt; use term_painter::ToStyle; @@ -9,7 +10,7 @@ use term_painter::Color::*; pub struct Catcher { pub code: u16, - handler: Handler, + handler: ErrorHandler, is_default: bool } @@ -18,15 +19,15 @@ pub struct Catcher { // interface here? impl Catcher { - pub fn new(code: u16, handler: Handler) -> Catcher { + pub fn new(code: u16, handler: ErrorHandler) -> Catcher { Catcher::new_with_default(code, handler, false) } - pub fn handle<'a>(&'a self, error: RoutingError<'a>) -> Response { - (self.handler)(error.request) + pub fn handle<'r>(&self, error: Error, request: &'r Request<'r>) -> Response<'r> { + (self.handler)(error, request) } - fn new_with_default(code: u16, handler: Handler, default: bool) -> Catcher { + fn new_with_default(code: u16, handler: ErrorHandler, default: bool) -> Catcher { Catcher { code: code, handler: handler, @@ -55,9 +56,10 @@ pub mod defaults { use request::Request; use response::{StatusCode, Response}; use super::Catcher; + use error::Error; use std::collections::HashMap; - pub fn not_found(_request: Request) -> Response { + pub fn not_found<'r>(_error: Error, _request: &'r Request<'r>) -> Response<'r> { Response::with_status(StatusCode::NotFound, "\ \ \ @@ -72,7 +74,8 @@ pub mod defaults { ") } - pub fn internal_error(_request: Request) -> Response { + pub fn internal_error<'r>(_error: Error, _request: &'r Request<'r>) + -> Response<'r> { Response::with_status(StatusCode::InternalServerError, "\ \ \ diff --git a/lib/src/codegen.rs b/lib/src/codegen.rs index 5b43f493..ce0b275c 100644 --- a/lib/src/codegen.rs +++ b/lib/src/codegen.rs @@ -1,4 +1,4 @@ -use ::{Method, Handler}; +use ::{Method, Handler, ErrorHandler}; use content_type::ContentType; pub struct StaticRouteInfo { @@ -10,6 +10,6 @@ pub struct StaticRouteInfo { pub struct StaticCatchInfo { pub code: u16, - pub handler: Handler + pub handler: ErrorHandler } diff --git a/lib/src/content_type.rs b/lib/src/content_type.rs index 5c7f1da4..dac1b238 100644 --- a/lib/src/content_type.rs +++ b/lib/src/content_type.rs @@ -1,14 +1,23 @@ -pub use mime::{Mime, TopLevel, SubLevel}; +pub use hyper::mime::{Mime, TopLevel, SubLevel}; use std::str::FromStr; -use mime::{Param}; +use std::borrow::Borrow; +use std::fmt; +use hyper::mime::{Param}; use self::TopLevel::{Text, Application}; use self::SubLevel::{Json, Html}; +use router::Collider; + #[derive(Debug, Clone)] pub struct ContentType(pub TopLevel, pub SubLevel, pub Option>); impl ContentType { + #[inline(always)] + pub fn new(t: TopLevel, s: SubLevel, params: Option>) -> ContentType { + ContentType(t, s, params) + } + #[inline(always)] pub fn of(t: TopLevel, s: SubLevel) -> ContentType { ContentType(t, s, None) @@ -20,17 +29,11 @@ impl ContentType { } pub fn is_json(&self) -> bool { - match *self { - ContentType(Application, Json, _) => true, - _ => false, - } + self.0 == Application && self.1 == Json } pub fn is_any(&self) -> bool { - match *self { - ContentType(TopLevel::Star, SubLevel::Star, None) => true, - _ => false, - } + self.0 == TopLevel::Star && self.1 == SubLevel::Star } pub fn is_ext(&self) -> bool { @@ -44,10 +47,7 @@ impl ContentType { } pub fn is_html(&self) -> bool { - match *self { - ContentType(Text, Html, _) => true, - _ => false, - } + self.0 == Text && self.1 == Html } } @@ -57,6 +57,13 @@ impl Into for ContentType { } } +impl> From for ContentType { + default fn from(mime: T) -> ContentType { + let mime: Mime = mime.borrow().clone(); + ContentType::from(mime) + } +} + impl From for ContentType { fn from(mime: Mime) -> ContentType { let params = match mime.2.len() { @@ -68,11 +75,153 @@ impl From for ContentType { } } -impl FromStr for ContentType { - type Err = (); - - fn from_str(raw: &str) -> Result { - let mime = Mime::from_str(raw)?; - Ok(ContentType::from(mime)) +fn is_valid_first_char(c: char) -> bool { + match c { + 'a'...'z' | 'A'...'Z' | '0'...'9' | '*' => true, + _ => false + } +} + +fn is_valid_char(c: char) -> bool { + is_valid_first_char(c) || match c { + '!' | '#' | '$' | '&' | '-' | '^' | '.' | '+' | '_' => true, + _ => false + } +} + +impl FromStr for ContentType { + type Err = &'static str; + + fn from_str(raw: &str) -> Result { + let slash = match raw.find('/') { + Some(i) => i, + None => return Err("Missing / in MIME type.") + }; + + let top_str = &raw[..slash]; + let (sub_str, rest) = match raw.find(';') { + Some(j) => (&raw[(slash + 1)..j], Some(&raw[(j + 1)..])), + None => (&raw[(slash + 1)..], None) + }; + + if top_str.len() < 1 || sub_str.len() < 1 { + return Err("Empty string.") + } + + if !is_valid_first_char(top_str.chars().next().unwrap()) + || !is_valid_first_char(sub_str.chars().next().unwrap()) { + return Err("Invalid first char.") + } + + if top_str.contains(|c| !is_valid_char(c)) + || sub_str.contains(|c| !is_valid_char(c)) { + return Err("Invalid character in string.") + } + + let (top_str, sub_str) = (&*top_str.to_lowercase(), &*sub_str.to_lowercase()); + let top_level = TopLevel::from_str(top_str).map_err(|_| "Bad TopLevel")?; + let sub_level = SubLevel::from_str(sub_str).map_err(|_| "Bad SubLevel")?; + // FIXME: Use `rest` to find params. + Ok(ContentType::new(top_level, sub_level, None)) + } +} + +impl fmt::Display for ContentType { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{}/{}", self.0.as_str(), self.1.as_str())?; + + self.2.as_ref().map_or(Ok(()), |params| { + for param in params.iter() { + let (ref attr, ref value) = *param; + write!(f, "; {}={}", attr, value)?; + } + + Ok(()) + }) + } +} + +impl Collider for ContentType { + fn collides_with(&self, other: &ContentType) -> bool { + self.0.collides_with(&other.0) && self.1.collides_with(&other.1) + } +} + +impl Collider for TopLevel { + fn collides_with(&self, other: &TopLevel) -> bool { + *self == TopLevel::Star + || *other == TopLevel::Star + || *self == *other + } +} + +impl Collider for SubLevel { + fn collides_with(&self, other: &SubLevel) -> bool { + *self == SubLevel::Star + || *other == SubLevel::Star + || *self == *other + } +} + +#[cfg(test)] +mod test { + use super::ContentType; + use hyper::mime::{TopLevel, SubLevel}; + use std::str::FromStr; + + + macro_rules! assert_no_parse { + ($string:expr) => ({ + let result = ContentType::from_str($string); + if !result.is_err() { + println!("{} parsed!", $string); + } + + assert!(result.is_err()); + }); + } + + macro_rules! assert_parse { + ($string:expr) => ({ + let result = ContentType::from_str($string); + assert!(result.is_ok()); + result.unwrap() + }); + ($string:expr, $top:tt/$sub:tt) => ({ + let c = assert_parse!($string); + assert_eq!(c.0, TopLevel::$top); + assert_eq!(c.1, SubLevel::$sub); + c + }) + } + + #[test] + fn test_simple() { + assert_parse!("application/json", Application/Json); + assert_parse!("*/json", Star/Json); + assert_parse!("text/html", Text/Html); + assert_parse!("TEXT/html", Text/Html); + assert_parse!("*/*", Star/Star); + assert_parse!("application/*", Application/Star); + } + + #[test] + fn test_params() { + assert_parse!("application/json; charset=utf8", Application/Json); + assert_parse!("application/*;charset=utf8;else=1", Application/Star); + assert_parse!("*/*;charset=utf8;else=1", Star/Star); + } + + #[test] + fn test_bad_parses() { + assert_no_parse!("application//json"); + assert_no_parse!("application///json"); + assert_no_parse!("/json"); + assert_no_parse!("text/"); + assert_no_parse!("text//"); + assert_no_parse!("/"); + assert_no_parse!("*/"); + assert_no_parse!("/*"); + assert_no_parse!("///"); } } diff --git a/lib/src/error.rs b/lib/src/error.rs index c860f396..69693e27 100644 --- a/lib/src/error.rs +++ b/lib/src/error.rs @@ -1,35 +1,7 @@ -use request::Request; - #[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)] pub enum Error { BadMethod, BadParse, - NoRoute, + NoRoute, // FIXME: Add a chain of routes attempted. NoKey } - -pub struct RoutingError<'r> { - pub error: Error, - pub request: Request<'r>, - pub chain: Option<&'r [&'r str]> -} - -impl<'a> RoutingError<'a> { - pub fn unchained(request: Request<'a>) - -> RoutingError<'a> { - RoutingError { - error: Error::NoRoute, - request: request, - chain: None, - } - } - - pub fn new(error: Error, request: Request<'a>, chain: &'a [&'a str]) - -> RoutingError<'a> { - RoutingError { - error: error, - request: request, - chain: Some(chain) - } - } -} diff --git a/lib/src/lib.rs b/lib/src/lib.rs index 6d18836f..6e0788e7 100644 --- a/lib/src/lib.rs +++ b/lib/src/lib.rs @@ -23,9 +23,10 @@ mod codegen; mod catcher; pub mod handler { - use super::{Request, Response}; + use super::{Request, Response, Error}; - pub type Handler = for<'r> fn(Request<'r>) -> Response<'r>; + pub type Handler = for<'r> fn(&'r Request<'r>) -> Response<'r>; + pub type ErrorHandler = for<'r> fn(error: Error, &'r Request<'r>) -> Response<'r>; } pub use logger::RocketLogger; @@ -34,9 +35,9 @@ pub use codegen::{StaticRouteInfo, StaticCatchInfo}; pub use request::Request; pub use method::Method; pub use response::{Response, Responder}; -pub use error::{Error, RoutingError}; +pub use error::Error; pub use param::FromParam; pub use router::{Router, Route}; pub use catcher::Catcher; pub use rocket::Rocket; -pub use handler::Handler; +pub use handler::{Handler, ErrorHandler}; diff --git a/lib/src/logger.rs b/lib/src/logger.rs index bb5cf341..6fe83936 100644 --- a/lib/src/logger.rs +++ b/lib/src/logger.rs @@ -75,8 +75,9 @@ impl Log for RocketLogger { } Debug => { let loc = record.location(); - println!("{} {}:{}", Cyan.paint("-->"), loc.file(), loc.line()); - println!("{}", Cyan.paint(record.args())); + print!("\n{} ", Blue.bold().paint("-->")); + println!("{}:{}", Blue.paint(loc.file()), Blue.paint(loc.line())); + println!("{}", record.args()); } } } diff --git a/lib/src/method.rs b/lib/src/method.rs index 4e7e5890..0b727d1b 100644 --- a/lib/src/method.rs +++ b/lib/src/method.rs @@ -30,7 +30,7 @@ impl Method { HyperMethod::Trace => Some(Trace), HyperMethod::Connect => Some(Connect), HyperMethod::Patch => Some(Patch), - _ => None + HyperMethod::Extension(_) => None } } } @@ -49,7 +49,7 @@ impl FromStr for Method { "TRACE" => Ok(Trace), "CONNECT" => Ok(Connect), "PATCH" => Ok(Patch), - _ => Err(Error::BadMethod) + _ => Err(Error::BadMethod), } } } diff --git a/lib/src/request/from_request.rs b/lib/src/request/from_request.rs index 245f3f6b..d1a9608d 100644 --- a/lib/src/request/from_request.rs +++ b/lib/src/request/from_request.rs @@ -28,7 +28,7 @@ impl<'r, 'c> FromRequest<'r, 'c> for Cookies { type Error = &'static str; fn from_request(request: &'r Request<'c>) -> Result { - match request.headers.get::() { + match request.headers().get::() { // TODO: What to do about key? Some(cookie) => Ok(cookie.to_cookie_jar(&[])), None => Ok(Cookies::new(&[])) diff --git a/lib/src/request/request.rs b/lib/src/request/request.rs index 348c0155..83f845e4 100644 --- a/lib/src/request/request.rs +++ b/lib/src/request/request.rs @@ -1,38 +1,96 @@ +use std::io::{Read}; +use std::cell::RefCell; + use error::Error; use param::FromParam; use method::Method; -use request::HyperHeaders; -#[derive(Clone, Debug)] +use content_type::ContentType; +use hyper::uri::RequestUri as HyperRequestUri; +use hyper::header; +use router::URIBuf; +use router::URI; +use router::Route; + +// Hyper stuff. +use request::{HyperHeaders, HyperRequest}; + pub struct Request<'a> { - params: Option>, - pub headers: &'a HyperHeaders, // TODO: Don't make pub?.... + pub params: RefCell>>, // This also sucks. pub method: Method, - pub uri: &'a str, - pub data: &'a [u8] + pub uri: URIBuf, // FIXME: Should be URI (without Hyper). + pub data: Vec, // FIXME: Don't read this! (bad Hyper.) + headers: HyperHeaders, // This sucks. } impl<'a> Request<'a> { - pub fn new(headers: &'a HyperHeaders, method: Method, uri: &'a str, - params: Option>, data: &'a [u8]) -> Request<'a> { - Request { - headers: headers, - method: method, - params: params, - uri: uri, - data: data + pub fn get_param>(&'a self, n: usize) -> Result { + let params = self.params.borrow(); + if params.is_none() || n >= params.as_ref().unwrap().len() { + Err(Error::NoKey) + } else { + T::from_param(params.as_ref().unwrap()[n]) } } - pub fn get_uri(&self) -> &'a str { - self.uri + #[cfg(test)] + pub fn mock(method: Method, uri: &str) -> Request { + Request { + params: RefCell::new(None), + method: method, + uri: URIBuf::from(uri), + data: vec![], + headers: HyperHeaders::new() + } } - pub fn get_param>(&'a self, n: usize) -> Result { - if self.params.is_none() || n >= self.params.as_ref().unwrap().len() { - Err(Error::NoKey) - } else { - T::from_param(self.params.as_ref().unwrap()[n]) + + // FIXME: Get rid of Hyper. + #[inline(always)] + pub fn headers(&self) -> &HyperHeaders { + &self.headers + } + + pub fn content_type(&self) -> ContentType { + let hyp_ct = self.headers().get::(); + hyp_ct.map_or(ContentType::any(), |ct| ContentType::from(&ct.0)) + } + + pub fn uri(&'a self) -> URI<'a> { + self.uri.as_uri() + } + + pub fn set_params(&'a self, route: &Route) { + *self.params.borrow_mut() = Some(route.get_params(self.uri.as_uri())) + } + + #[cfg(test)] + pub fn set_content_type(&mut self, ct: ContentType) { + let hyper_ct = header::ContentType(ct.into()); + self.headers.set::(hyper_ct) + } + +} + +impl<'a, 'h, 'k> From> for Request<'a> { + fn from(hyper_req: HyperRequest<'h, 'k>) -> Request<'a> { + let (_, h_method, h_headers, h_uri, _, mut h_body) = hyper_req.deconstruct(); + + let uri = match h_uri { + HyperRequestUri::AbsolutePath(s) => URIBuf::from(s), + _ => panic!("Can only accept absolute paths!") + }; + + // FIXME: GRRR. + let mut data = vec![]; + h_body.read_to_end(&mut data).unwrap(); + + Request { + params: RefCell::new(None), + method: Method::from_hyp(&h_method).unwrap(), + uri: uri, + data: data, + headers: h_headers, } } } diff --git a/lib/src/rocket.rs b/lib/src/rocket.rs index b2265fa7..7765c809 100644 --- a/lib/src/rocket.rs +++ b/lib/src/rocket.rs @@ -28,20 +28,13 @@ fn uri_is_absolute(uri: &HyperRequestUri) -> bool { } } -fn unwrap_absolute_path(uri: &HyperRequestUri) -> &str { - match *uri { - HyperRequestUri::AbsolutePath(ref s) => s.as_str(), - _ => panic!("Can only accept absolute paths!") - } -} - fn method_is_valid(method: &HyperMethod) -> bool { Method::from_hyp(method).is_some() } impl HyperHandler for Rocket { - fn handle<'a, 'k>(&'a self, req: HyperRequest<'a, 'k>, - mut res: FreshHyperResponse<'a>) { + fn handle<'h, 'k>(&self, req: HyperRequest<'h, 'k>, + mut res: FreshHyperResponse<'h>) { info!("{:?} '{}':", Green.paint(&req.method), Blue.paint(&req.uri)); let finalize = |mut req: HyperRequest, _res: FreshHyperResponse| { @@ -68,51 +61,38 @@ impl HyperHandler for Rocket { } impl Rocket { - fn dispatch<'h, 'k>(&self, mut req: HyperRequest<'h, 'k>, + fn dispatch<'h, 'k>(&self, hyper_req: HyperRequest<'h, 'k>, res: FreshHyperResponse<'h>) { - // We read all of the contents now because we have to do it at some - // point thanks to Hyper. FIXME: Simple DOS attack here. - let mut buf = vec![]; - let _ = req.read_to_end(&mut buf); + let req = Request::from(hyper_req); + let route = self.router.route(&req); + if let Some(route) = route { + // Retrieve and set the requests parameters. + req.set_params(&route); - // Extract the method, uri, and try to find a route. - let method = Method::from_hyp(&req.method).unwrap(); - let uri = unwrap_absolute_path(&req.uri); - let route = self.router.route(method, uri); + // Here's the magic: dispatch the request to the handler. + let outcome = (route.handler)(&req).respond(res); + info_!("{} {}", White.paint("Outcome:"), outcome); - // A closure which we call when we know there is no route. - let handle_not_found = |response: FreshHyperResponse| { - error_!("Dispatch failed. Returning 404."); - - let request = Request::new(&req.headers, method, uri, None, &buf); - let catcher = self.catchers.get(&404).unwrap(); - catcher.handle(RoutingError::unchained(request)).respond(response); - }; - - // No route found. Handle the not_found error and return. - if route.is_none() { + // // TODO: keep trying lower ranked routes before dispatching a not + // // found error. + // outcome.map_forward(|res| { + // error_!("No further matching routes."); + // // TODO: Have some way to know why this was failed forward. Use that + // // instead of always using an unchained error. + // self.handle_not_found(req, res); + // }); + } else { error_!("No matching routes."); - return handle_not_found(res); + return self.handle_not_found(&req, res); } + } - // Okay, we've got a route. Unwrap it, generate a request, and dispatch. - let route = route.unwrap(); - let params = route.get_params(uri); - let request = Request::new(&req.headers, method, uri, Some(params), &buf); - - // TODO: Paint these magenta. - trace_!("Dispatching request."); - let outcome = (route.handler)(request).respond(res); - - // TODO: keep trying lower ranked routes before dispatching a not found - // error. - info_!("{} {}", White.paint("Outcome:"), outcome); - outcome.map_forward(|res| { - error_!("No further matching routes."); - // TODO: Have some way to know why this was failed forward. Use that - // instead of always using an unchained error. - handle_not_found(res); - }); + // A closure which we call when we know there is no route. + fn handle_not_found<'r>(&self, request: &'r Request<'r>, + response: FreshHyperResponse) { + error_!("Dispatch failed. Returning 404."); + let catcher = self.catchers.get(&404).unwrap(); + catcher.handle(Error::NoRoute, request).respond(response); } pub fn new(address: &'static str, port: isize) -> Rocket { diff --git a/lib/src/router/collider.rs b/lib/src/router/collider.rs index a94d4110..3ba97027 100644 --- a/lib/src/router/collider.rs +++ b/lib/src/router/collider.rs @@ -45,10 +45,12 @@ mod tests { use Method; use Method::*; use {Request, Response}; + use content_type::{ContentType, TopLevel, SubLevel}; + use std::str::FromStr; type SimpleRoute = (Method, &'static str); - fn dummy_handler(_req: Request) -> Response<'static> { + fn dummy_handler(_req: &Request) -> Response<'static> { Response::empty() } @@ -183,4 +185,52 @@ mod tests { assert!(!s_s_collide("/a/", "/b/")); assert!(!s_s_collide("/a/", "/b/")); } + + fn ct_route(m: Method, s: &str, ct: &str) -> Route { + let mut route_a = Route::new(m, s, dummy_handler); + route_a.content_type = ContentType::from_str(ct).expect("Whoops!"); + route_a + } + + fn ct_ct_collide(ct1: &str, ct2: &str) -> bool { + let ct_a = ContentType::from_str(ct1).expect(ct1); + let ct_b = ContentType::from_str(ct2).expect(ct2); + ct_a.collides_with(&ct_b) + } + + #[test] + fn test_content_type_colliions() { + assert!(ct_ct_collide("application/json", "application/json")); + assert!(ct_ct_collide("*/json", "application/json")); + assert!(ct_ct_collide("*/*", "application/json")); + assert!(ct_ct_collide("application/*", "application/json")); + assert!(ct_ct_collide("application/*", "*/json")); + assert!(ct_ct_collide("something/random", "something/random")); + + assert!(!ct_ct_collide("text/*", "application/*")); + assert!(!ct_ct_collide("*/text", "*/json")); + assert!(!ct_ct_collide("*/text", "application/test")); + assert!(!ct_ct_collide("something/random", "something_else/random")); + assert!(!ct_ct_collide("something/random", "*/else")); + assert!(!ct_ct_collide("*/random", "*/else")); + assert!(!ct_ct_collide("something/*", "random/else")); + } + + fn r_ct_ct_collide(m1: Method, ct1: &str, m2: Method, ct2: &str) -> bool { + let a_route = ct_route(m1, "a", ct1); + let b_route = ct_route(m2, "a", ct2); + a_route.collides_with(&b_route) + } + + #[test] + fn test_route_content_type_colliions() { + assert!(r_ct_ct_collide(Get, "application/json", Get, "application/json")); + assert!(r_ct_ct_collide(Get, "*/json", Get, "application/json")); + assert!(r_ct_ct_collide(Get, "*/json", Get, "application/*")); + assert!(r_ct_ct_collide(Get, "text/html", Get, "text/*")); + + assert!(!r_ct_ct_collide(Get, "text/html", Get, "application/*")); + assert!(!r_ct_ct_collide(Get, "application/html", Get, "text/*")); + assert!(!r_ct_ct_collide(Get, "*/json", Get, "text/html")); + } } diff --git a/lib/src/router/mod.rs b/lib/src/router/mod.rs index 702d8374..9000aff3 100644 --- a/lib/src/router/mod.rs +++ b/lib/src/router/mod.rs @@ -8,6 +8,7 @@ pub use self::route::Route; use std::collections::hash_map::HashMap; use method::Method; +use request::Request; type Selector = (Method, usize); @@ -31,13 +32,15 @@ impl Router { // `Route` structure is inflexible. Have it be an associated type. // FIXME: Figure out a way to get more than one route, i.e., to correctly // handle ranking. - pub fn route<'b>(&'b self, method: Method, uri: &str) -> Option<&'b Route> { - let mut matched_route: Option<&Route> = None; + // TODO: Should the Selector include the content-type? If it does, can't + // warn the user that a match was found for the wrong content-type. It + // doesn't, can, but this method is slower. + pub fn route<'b>(&'b self, req: &Request) -> Option<&'b Route> { + let num_segments = req.uri.segment_count(); - let path = URI::new(uri); - let num_segments = path.segment_count(); - if let Some(routes) = self.routes.get(&(method, num_segments)) { - for route in routes.iter().filter(|r| r.collides_with(uri)) { + let mut matched_route: Option<&Route> = None; + if let Some(routes) = self.routes.get(&(req.method, num_segments)) { + for route in routes.iter().filter(|r| r.collides_with(req)) { info_!("Matched: {}", route); if let Some(existing_route) = matched_route { if route.rank > existing_route.rank { @@ -71,11 +74,13 @@ impl Router { #[cfg(test)] mod test { - use Method::*; + use method::Method; + use method::Method::*; use super::{Router, Route}; use {Response, Request}; + use super::URI; - fn dummy_handler(_req: Request) -> Response<'static> { + fn dummy_handler(_req: &Request) -> Response<'static> { Response::empty() } @@ -132,65 +137,70 @@ mod test { assert!(!router.has_collisions()); } + fn route<'a>(router: &'a Router, method: Method, uri: &str) -> Option<&'a Route> { + let request = Request::mock(method, uri); + router.route(&request) + } + #[test] fn test_ok_routing() { let router = router_with_routes(&["/hello"]); - assert!(router.route(Get, "/hello").is_some()); + assert!(route(&router, Get, "/hello").is_some()); let router = router_with_routes(&["/"]); - assert!(router.route(Get, "/hello").is_some()); - assert!(router.route(Get, "/hi").is_some()); - assert!(router.route(Get, "/bobbbbbbbbbby").is_some()); - assert!(router.route(Get, "/dsfhjasdf").is_some()); + assert!(route(&router, Get, "/hello").is_some()); + assert!(route(&router, Get, "/hi").is_some()); + assert!(route(&router, Get, "/bobbbbbbbbbby").is_some()); + assert!(route(&router, Get, "/dsfhjasdf").is_some()); let router = router_with_routes(&["//"]); - assert!(router.route(Get, "/hello/hi").is_some()); - assert!(router.route(Get, "/a/b/").is_some()); - assert!(router.route(Get, "/i/a").is_some()); - assert!(router.route(Get, "/jdlk/asdij").is_some()); + assert!(route(&router, Get, "/hello/hi").is_some()); + assert!(route(&router, Get, "/a/b/").is_some()); + assert!(route(&router, Get, "/i/a").is_some()); + assert!(route(&router, Get, "/jdlk/asdij").is_some()); let mut router = Router::new(); router.add(Route::new(Put, "/hello".to_string(), dummy_handler)); router.add(Route::new(Post, "/hello".to_string(), dummy_handler)); router.add(Route::new(Delete, "/hello".to_string(), dummy_handler)); - assert!(router.route(Put, "/hello").is_some()); - assert!(router.route(Post, "/hello").is_some()); - assert!(router.route(Delete, "/hello").is_some()); + assert!(route(&router, Put, "/hello").is_some()); + assert!(route(&router, Post, "/hello").is_some()); + assert!(route(&router, Delete, "/hello").is_some()); } #[test] fn test_err_routing() { let router = router_with_routes(&["/hello"]); - assert!(router.route(Put, "/hello").is_none()); - assert!(router.route(Post, "/hello").is_none()); - assert!(router.route(Options, "/hello").is_none()); - assert!(router.route(Get, "/hell").is_none()); - assert!(router.route(Get, "/hi").is_none()); - assert!(router.route(Get, "/hello/there").is_none()); - assert!(router.route(Get, "/hello/i").is_none()); - assert!(router.route(Get, "/hillo").is_none()); + assert!(route(&router, Put, "/hello").is_none()); + assert!(route(&router, Post, "/hello").is_none()); + assert!(route(&router, Options, "/hello").is_none()); + assert!(route(&router, Get, "/hell").is_none()); + assert!(route(&router, Get, "/hi").is_none()); + assert!(route(&router, Get, "/hello/there").is_none()); + assert!(route(&router, Get, "/hello/i").is_none()); + assert!(route(&router, Get, "/hillo").is_none()); let router = router_with_routes(&["/"]); - assert!(router.route(Put, "/hello").is_none()); - assert!(router.route(Post, "/hello").is_none()); - assert!(router.route(Options, "/hello").is_none()); - assert!(router.route(Get, "/hello/there").is_none()); - assert!(router.route(Get, "/hello/i").is_none()); + assert!(route(&router, Put, "/hello").is_none()); + assert!(route(&router, Post, "/hello").is_none()); + assert!(route(&router, Options, "/hello").is_none()); + assert!(route(&router, Get, "/hello/there").is_none()); + assert!(route(&router, Get, "/hello/i").is_none()); let router = router_with_routes(&["//"]); - assert!(router.route(Get, "/a/b/c").is_none()); - assert!(router.route(Get, "/a").is_none()); - assert!(router.route(Get, "/a/").is_none()); - assert!(router.route(Get, "/a/b/c/d").is_none()); - assert!(router.route(Put, "/hello/hi").is_none()); - assert!(router.route(Put, "/a/b").is_none()); - assert!(router.route(Put, "/a/b").is_none()); + assert!(route(&router, Get, "/a/b/c").is_none()); + assert!(route(&router, Get, "/a").is_none()); + assert!(route(&router, Get, "/a/").is_none()); + assert!(route(&router, Get, "/a/b/c/d").is_none()); + assert!(route(&router, Put, "/hello/hi").is_none()); + assert!(route(&router, Put, "/a/b").is_none()); + assert!(route(&router, Put, "/a/b").is_none()); } macro_rules! assert_ranked_routes { ($routes:expr, $to:expr, $want:expr) => ({ let router = router_with_routes($routes); - let route_path = router.route(Get, $to).unwrap().path.as_str(); + let route_path = route(&router, Get, $to).unwrap().path.as_str(); assert_eq!(route_path as &str, $want as &str); }) } @@ -211,8 +221,8 @@ mod test { } fn match_params(router: &Router, path: &str, expected: &[&str]) -> bool { - router.route(Get, path).map_or(false, |route| { - let params = route.get_params(path); + route(router, Get, path).map_or(false, |route| { + let params = route.get_params(URI::new(path)); if params.len() != expected.len() { return false; } diff --git a/lib/src/router/route.rs b/lib/src/router/route.rs index 7c2016bb..232c0edd 100644 --- a/lib/src/router/route.rs +++ b/lib/src/router/route.rs @@ -7,6 +7,7 @@ use term_painter::Color::*; use std::fmt; use std::convert::From; +use request::Request; pub struct Route { pub method: Method, @@ -17,17 +18,6 @@ pub struct Route { } impl Route { - pub fn full(rank: isize, m: Method, path: S, handler: Handler, t: ContentType) - -> Route where S: AsRef { - Route { - method: m, - path: URIBuf::from(path.as_ref()), - handler: handler, - rank: rank, - content_type: t, - } - } - pub fn ranked(rank: isize, m: Method, path: S, handler: Handler) -> Route where S: AsRef { Route { @@ -39,8 +29,7 @@ impl Route { } } - pub fn new(m: Method, path: S, handler: Handler) - -> Route where S: AsRef { + pub fn new(m: Method, path: S, handler: Handler) -> Route where S: AsRef { Route { method: m, handler: handler, @@ -57,9 +46,9 @@ impl Route { // FIXME: Decide whether a component has to be fully variable or not. That // is, whether you can have: /ab/ or even /:/ // TODO: Don't return a Vec...take in an &mut [&'a str] (no alloc!) - pub fn get_params<'a>(&self, uri: &'a str) -> Vec<&'a str> { + pub fn get_params<'a>(&self, uri: URI<'a>) -> Vec<&'a str> { let route_components = self.path.segments(); - let uri_components = URI::new(uri).segments(); + let uri_components = uri.segments(); let mut result = Vec::with_capacity(self.path.segment_count()); for (route_seg, uri_seg) in route_components.zip(uri_components) { @@ -74,25 +63,31 @@ impl Route { impl fmt::Display for Route { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "{} {}", Green.paint(&self.method), Blue.paint(&self.path)) + write!(f, "{} {} ", Green.paint(&self.method), Blue.paint(&self.path))?; + + if !self.content_type.is_any() { + write!(f, "{}", Yellow.paint(&self.content_type)) + } else { + Ok(()) + } } } impl<'a> From<&'a StaticRouteInfo> for Route { fn from(info: &'a StaticRouteInfo) -> Route { - Route::new(info.method, info.path, info.handler) + let mut route = Route::new(info.method, info.path, info.handler); + route.content_type = info.content_type.clone(); + route } } impl Collider for Route { fn collides_with(&self, b: &Route) -> bool { - if self.path.segment_count() != b.path.segment_count() - || self.method != b.method - || self.rank != b.rank { - return false; - } - - self.path.collides_with(&b.path) + self.path.segment_count() == b.path.segment_count() + && self.method == b.method + && self.rank == b.rank + && self.content_type.collides_with(&b.content_type) + && self.path.collides_with(&b.path) } } @@ -108,3 +103,11 @@ impl Collider for Route { other.collides_with(self) } } + +impl<'r> Collider> for Route { + fn collides_with(&self, req: &Request) -> bool { + self.method == req.method + && req.uri.collides_with(&self.path) + && req.content_type().collides_with(&self.content_type) + } +} diff --git a/macros/Cargo.toml b/macros/Cargo.toml index 0511dc2c..d84f59c2 100644 --- a/macros/Cargo.toml +++ b/macros/Cargo.toml @@ -8,3 +8,5 @@ plugin = true [dependencies] rocket = { path = "../lib/" } +log = "*" +env_logger = "*" diff --git a/macros/src/error_decorator.rs b/macros/src/error_decorator.rs index 545da59c..b6bdb971 100644 --- a/macros/src/error_decorator.rs +++ b/macros/src/error_decorator.rs @@ -7,9 +7,6 @@ use syntax::ast::{MetaItem}; use syntax::ext::base::{Annotatable, ExtCtxt}; use syntax::print::pprust::{item_to_string}; -#[allow(dead_code)] -const DEBUG: bool = true; - #[derive(Debug)] struct Params { code: KVSpanned, @@ -58,10 +55,11 @@ pub fn error_decorator(ecx: &mut ExtCtxt, sp: Span, meta_item: &MetaItem, let catch_fn_name = prepend_ident(CATCH_FN_PREFIX, &item.ident); let catch_code = error_params.code.node; let catch_fn_item = quote_item!(ecx, - fn $catch_fn_name<'rocket>(_req: rocket::Request<'rocket>) - -> rocket::Response<'rocket> { + fn $catch_fn_name<'rocket>(err: ::rocket::Error, + req: &'rocket ::rocket::Request<'rocket>) + -> ::rocket::Response<'rocket> { // TODO: Figure out what type signature of catcher should be. - let result = $fn_name(RoutingError::unchained(_req)); + let result = $fn_name(err, req); rocket::Response::with_raw_status($catch_code, result) } ).unwrap(); diff --git a/macros/src/lib.rs b/macros/src/lib.rs index 30567cd2..698def25 100644 --- a/macros/src/lib.rs +++ b/macros/src/lib.rs @@ -7,6 +7,7 @@ extern crate syntax_ext; extern crate rustc; extern crate rustc_plugin; extern crate rocket; +extern crate env_logger; #[macro_use] mod utils; mod routes_macro; diff --git a/macros/src/meta_item_parser.rs b/macros/src/meta_item_parser.rs index 56afa68b..2185743b 100644 --- a/macros/src/meta_item_parser.rs +++ b/macros/src/meta_item_parser.rs @@ -8,9 +8,6 @@ use syntax::ptr::P; use utils::*; use rocket::{Method, ContentType}; -#[allow(dead_code)] -const DEBUG: bool = true; - pub struct MetaItemParser<'a, 'c: 'a> { attr_name: &'a str, ctxt: &'a ExtCtxt<'c>, @@ -122,7 +119,7 @@ pub struct RouteParams { pub method: Spanned, pub path: KVSpanned, pub form: Option>, - pub content_type: Option>, + pub content_type: KVSpanned, } pub trait RouteDecoratorExt { @@ -219,8 +216,7 @@ impl<'a, 'c> RouteDecoratorExt for MetaItemParser<'a, 'c> { self.ctxt.span_err(data.v_span, &msg); None } - }); - + }).unwrap_or(KVSpanned::dummy(ContentType::any())); RouteParams { method: method, diff --git a/macros/src/route_decorator.rs b/macros/src/route_decorator.rs index ac400427..6466511d 100644 --- a/macros/src/route_decorator.rs +++ b/macros/src/route_decorator.rs @@ -11,7 +11,8 @@ use syntax::ptr::P; use syntax::print::pprust::{item_to_string, stmt_to_string}; use syntax::parse::token::{self, str_to_ident}; -use rocket::Method; +use rocket::{Method, ContentType}; +use rocket::content_type::{TopLevel, SubLevel}; pub fn extract_params_from_kv<'a>(parser: &MetaItemParser, params: &'a KVSpanned) -> Vec> { @@ -93,11 +94,12 @@ fn get_form_stmt(ecx: &ExtCtxt, fn_args: &mut Vec, // The actual code we'll be inserting. quote_stmt!(ecx, let $param_ident: $param_ty = - if let Ok(form_string) = ::std::str::from_utf8(_req.data) { + if let Ok(form_string) = ::std::str::from_utf8(_req.data.as_slice()) { match ::rocket::form::FromForm::from_form_string(form_string) { Ok(v) => v, Err(_) => { - debug!("\t=> Form failed to parse."); + // TODO: + // debug!("\t=> Form failed to parse."); return ::rocket::Response::not_found(); } } @@ -107,9 +109,9 @@ fn get_form_stmt(ecx: &ExtCtxt, fn_args: &mut Vec, ) } -// Is there a better way to do this? I need something with ToTokens for the -// quote_expr macro that builds the route struct. I tried using -// str_to_ident("rocket::Method::Options"), but this seems to miss the context, +// TODO: Is there a better way to do this? I need something with ToTokens for +// the quote_expr macro that builds the route struct. I tried using +// str_to_ident("::rocket::Method::Options"), but this seems to miss the context, // and you get an 'ident not found' on compile. I also tried using the path expr // builder from ASTBuilder: same thing. fn method_variant_to_expr(ecx: &ExtCtxt, method: Method) -> P { @@ -126,10 +128,61 @@ fn method_variant_to_expr(ecx: &ExtCtxt, method: Method) -> P { } } +// Same here. +fn top_level_to_expr(ecx: &ExtCtxt, level: &TopLevel) -> P { + use rocket::content_type::TopLevel::*; + match *level { + Star => quote_expr!(ecx, ::rocket::content_type::TopLevel::Star), + Text => quote_expr!(ecx, ::rocket::content_type::TopLevel::Text), + Image => quote_expr!(ecx, ::rocket::content_type::TopLevel::Image), + Audio => quote_expr!(ecx, ::rocket::content_type::TopLevel::Audio), + Video => quote_expr!(ecx, ::rocket::content_type::TopLevel::Video), + Application => quote_expr!(ecx, ::rocket::content_type::TopLevel::Application), + Multipart => quote_expr!(ecx, ::rocket::content_type::TopLevel::Multipart), + Message => quote_expr!(ecx, ::rocket::content_type::TopLevel::Message), + Model => quote_expr!(ecx, ::rocket::content_type::TopLevel::Model), + Ext(ref s) => quote_expr!(ecx, ::rocket::content_type::TopLevel::Ext($s)), + } +} + +// Same here. +fn sub_level_to_expr(ecx: &ExtCtxt, level: &SubLevel) -> P { + use rocket::content_type::SubLevel::*; + match *level { + Star => quote_expr!(ecx, ::rocket::content_type::SubLevel::Star), + Plain => quote_expr!(ecx, ::rocket::content_type::SubLevel::Plain), + Html => quote_expr!(ecx, ::rocket::content_type::SubLevel::Html), + Xml => quote_expr!(ecx, ::rocket::content_type::SubLevel::Xml), + Javascript => quote_expr!(ecx, ::rocket::content_type::SubLevel::Javascript), + Css => quote_expr!(ecx, ::rocket::content_type::SubLevel::Css), + EventStream => quote_expr!(ecx, ::rocket::content_type::SubLevel::EventStream), + Json => quote_expr!(ecx, ::rocket::content_type::SubLevel::Json), + WwwFormUrlEncoded => + quote_expr!(ecx, ::rocket::content_type::SubLevel::WwwFormUrlEncoded), + Msgpack => quote_expr!(ecx, ::rocket::content_type::SubLevel::Msgpack), + OctetStream => + quote_expr!(ecx, ::rocket::content_type::SubLevel::OctetStream), + FormData => quote_expr!(ecx, ::rocket::content_type::SubLevel::FormData), + Png => quote_expr!(ecx, ::rocket::content_type::SubLevel::Png), + Gif => quote_expr!(ecx, ::rocket::content_type::SubLevel::Gif), + Bmp => quote_expr!(ecx, ::rocket::content_type::SubLevel::Bmp), + Jpeg => quote_expr!(ecx, ::rocket::content_type::SubLevel::Jpeg), + Ext(ref s) => quote_expr!(ecx, ::rocket::content_type::SubLevel::Ext($s)), + } +} + +fn content_type_to_expr(ecx: &ExtCtxt, content_type: &ContentType) -> P { + let top_level = top_level_to_expr(ecx, &content_type.0); + let sub_level = sub_level_to_expr(ecx, &content_type.1); + quote_expr!(ecx, ::rocket::ContentType($top_level, $sub_level, None)) +} + // FIXME: Compilation fails when parameters have the same name as the function! pub fn route_decorator(known_method: Option>, ecx: &mut ExtCtxt, sp: Span, meta_item: &MetaItem, annotated: &Annotatable, push: &mut FnMut(Annotatable)) { + ::rocket::logger::init(::rocket::logger::Level::Debug); + // Get the encompassing item and function declaration for the annotated func. let parser = MetaItemParser::new(ecx, meta_item, annotated, &sp); let (item, fn_decl) = (parser.expect_item(), parser.expect_fn_decl()); @@ -186,14 +239,14 @@ pub fn route_decorator(known_method: Option>, ecx: &mut ExtCtxt, ).unwrap() }; - debug!("Param FN: {:?}", stmt_to_string(¶m_fn_item)); + debug!("Param FN: {}", stmt_to_string(¶m_fn_item)); fn_param_exprs.push(param_fn_item); } let route_fn_name = prepend_ident(ROUTE_FN_PREFIX, &item.ident); let fn_name = item.ident; let route_fn_item = quote_item!(ecx, - fn $route_fn_name<'rocket>(_req: ::rocket::Request<'rocket>) + fn $route_fn_name<'rocket>(_req: &'rocket ::rocket::Request<'rocket>) -> ::rocket::Response<'rocket> { $form_stmt $fn_param_exprs @@ -208,19 +261,21 @@ pub fn route_decorator(known_method: Option>, ecx: &mut ExtCtxt, let struct_name = prepend_ident(ROUTE_STRUCT_PREFIX, &item.ident); let path = &route.path.node; let method = method_variant_to_expr(ecx, route.method.node); - push(Annotatable::Item(quote_item!(ecx, + let content_type = content_type_to_expr(ecx, &route.content_type.node); + + let static_item = quote_item!(ecx, #[allow(non_upper_case_globals)] pub static $struct_name: ::rocket::StaticRouteInfo = ::rocket::StaticRouteInfo { method: $method, path: $path, handler: $route_fn_name, - content_type: ::rocket::ContentType( - ::rocket::content_type::TopLevel::Star, - ::rocket::content_type::SubLevel::Star, - None) + content_type: $content_type, }; - ).unwrap())); + ).unwrap(); + + debug!("Emitting static: {}", item_to_string(&static_item)); + push(Annotatable::Item(static_item)); }