diff --git a/examples/content_types/src/main.rs b/examples/content_types/src/main.rs index 6ea98df3..ab9f6f88 100644 --- a/examples/content_types/src/main.rs +++ b/examples/content_types/src/main.rs @@ -4,7 +4,8 @@ extern crate rocket; extern crate serde_json; -use rocket::{Rocket, Request, ContentType, Error}; +use rocket::{Rocket, Request, Error}; +use rocket::response::JSON; #[derive(Debug, Serialize, Deserialize)] struct Person { @@ -13,24 +14,24 @@ struct Person { } #[GET(path = "//", content = "application/json")] -fn hello(name: String, age: i8) -> String { +fn hello(name: String, age: i8) -> JSON { let person = Person { name: name, age: age, }; - serde_json::to_string(&person).unwrap() + JSON(serde_json::to_string(&person).unwrap()) } #[error(code = "404")] fn not_found<'r>(error: Error, request: &'r Request<'r>) -> String { match error { - Error::NoRoute if !request.content_type().is_json() => { + Error::BadMethod if !request.content_type().is_json() => { format!("

This server only supports JSON requests, not '{}'.

", request.content_type()) } Error::NoRoute => { - format!("

Sorry, this server but '{}' is not a valid path!

+ format!("

Sorry, '{}' is not a valid path!

Try visiting /hello/<name>/<age> instead.

", request.uri()) } @@ -40,7 +41,6 @@ fn not_found<'r>(error: Error, request: &'r Request<'r>) -> String { fn main() { let mut rocket = Rocket::new("0.0.0.0", 8000); - rocket.mount("/hello", routes![hello]); - rocket.catch(errors![not_found]); + rocket.mount("/hello", routes![hello]).catch(errors![not_found]); rocket.launch(); } diff --git a/lib/src/content_type.rs b/lib/src/content_type.rs index dac1b238..140107ea 100644 --- a/lib/src/content_type.rs +++ b/lib/src/content_type.rs @@ -1,9 +1,9 @@ -pub use hyper::mime::{Mime, TopLevel, SubLevel}; +pub use response::mime::{Mime, TopLevel, SubLevel}; +use response::mime::{Param}; use std::str::FromStr; use std::borrow::Borrow; use std::fmt; -use hyper::mime::{Param}; use self::TopLevel::{Text, Application}; use self::SubLevel::{Json, Html}; diff --git a/lib/src/error.rs b/lib/src/error.rs index 69693e27..760fd713 100644 --- a/lib/src/error.rs +++ b/lib/src/error.rs @@ -3,5 +3,6 @@ pub enum Error { BadMethod, BadParse, NoRoute, // FIXME: Add a chain of routes attempted. + Internal, NoKey } diff --git a/lib/src/form.rs b/lib/src/form.rs index d2000dd9..fbb43c0a 100644 --- a/lib/src/form.rs +++ b/lib/src/form.rs @@ -1,6 +1,6 @@ use std::str::FromStr; use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddrV4, SocketAddrV6, SocketAddr}; -use url::{self}; +use url; use error::Error; diff --git a/lib/src/lib.rs b/lib/src/lib.rs index 6e0788e7..238cd1f7 100644 --- a/lib/src/lib.rs +++ b/lib/src/lib.rs @@ -29,7 +29,7 @@ pub mod handler { pub type ErrorHandler = for<'r> fn(error: Error, &'r Request<'r>) -> Response<'r>; } -pub use logger::RocketLogger; +pub use logger::{RocketLogger, LoggingLevel}; pub use content_type::ContentType; pub use codegen::{StaticRouteInfo, StaticCatchInfo}; pub use request::Request; diff --git a/lib/src/logger.rs b/lib/src/logger.rs index 6fe83936..0b981e30 100644 --- a/lib/src/logger.rs +++ b/lib/src/logger.rs @@ -2,9 +2,10 @@ use log::{self, Log, LogLevel, LogRecord, LogMetadata}; use term_painter::Color::*; use term_painter::ToStyle; -pub struct RocketLogger(Level); +pub struct RocketLogger(LoggingLevel); -pub enum Level { +#[derive(PartialEq)] +pub enum LoggingLevel { /// Only shows errors and warning. Critical, /// Shows everything except debug and trace information. @@ -13,13 +14,13 @@ pub enum Level { Debug, } -impl Level { +impl LoggingLevel { #[inline(always)] fn max_log_level(&self) -> LogLevel { match *self { - Level::Critical => LogLevel::Warn, - Level::Normal => LogLevel::Info, - Level::Debug => LogLevel::Trace + LoggingLevel::Critical => LogLevel::Warn, + LoggingLevel::Normal => LogLevel::Info, + LoggingLevel::Debug => LogLevel::Trace } } } @@ -55,7 +56,7 @@ impl Log for RocketLogger { } // In Rocket, we abuse target with value "_" to indicate indentation. - if record.target() == "_" { + if record.target() == "_" && self.0 != LoggingLevel::Critical { print!(" {} ", White.paint("=>")); } @@ -83,7 +84,7 @@ impl Log for RocketLogger { } } -pub fn init(level: Level) { +pub fn init(level: LoggingLevel) { let result = log::set_logger(|max_log_level| { max_log_level.set(level.max_log_level().to_log_level_filter()); Box::new(RocketLogger(level)) diff --git a/lib/src/param.rs b/lib/src/param.rs index 772ec829..f0b01a33 100644 --- a/lib/src/param.rs +++ b/lib/src/param.rs @@ -1,5 +1,6 @@ use std::str::FromStr; use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddrV4, SocketAddrV6, SocketAddr}; +use url; use error::Error; @@ -13,6 +14,13 @@ impl<'a> FromParam<'a> for &'a str { } } +impl<'a> FromParam<'a> for String { + fn from_param(p: &'a str) -> Result { + let decoder = url::percent_encoding::percent_decode(p.as_bytes()); + decoder.decode_utf8().map_err(|_| Error::BadParse).map(|s| s.into_owned()) + } +} + macro_rules! impl_with_fromstr { ($($T:ident),+) => ($( impl<'a> FromParam<'a> for $T { @@ -24,5 +32,5 @@ macro_rules! impl_with_fromstr { } impl_with_fromstr!(f32, f64, isize, i8, i16, i32, i64, usize, u8, u16, u32, u64, - bool, String, IpAddr, Ipv4Addr, Ipv6Addr, SocketAddrV4, SocketAddrV6, + bool, IpAddr, Ipv4Addr, Ipv6Addr, SocketAddrV4, SocketAddrV6, SocketAddr); diff --git a/lib/src/request/request.rs b/lib/src/request/request.rs index 83f845e4..e9ed1cb8 100644 --- a/lib/src/request/request.rs +++ b/lib/src/request/request.rs @@ -1,5 +1,9 @@ use std::io::{Read}; use std::cell::RefCell; +use std::fmt; + +use term_painter::Color::*; +use term_painter::ToStyle; use error::Error; use param::FromParam; @@ -16,10 +20,10 @@ use router::Route; use request::{HyperHeaders, HyperRequest}; pub struct Request<'a> { - pub params: RefCell>>, // This also sucks. pub method: Method, pub uri: URIBuf, // FIXME: Should be URI (without Hyper). pub data: Vec, // FIXME: Don't read this! (bad Hyper.) + params: RefCell>>, // This also sucks. headers: HyperHeaders, // This sucks. } @@ -33,7 +37,6 @@ impl<'a> Request<'a> { } } - #[cfg(test)] pub fn mock(method: Method, uri: &str) -> Request { Request { params: RefCell::new(None), @@ -44,7 +47,6 @@ impl<'a> Request<'a> { } } - // FIXME: Get rid of Hyper. #[inline(always)] pub fn headers(&self) -> &HyperHeaders { @@ -70,27 +72,38 @@ impl<'a> Request<'a> { self.headers.set::(hyper_ct) } -} - -impl<'a, 'h, 'k> From> for Request<'a> { - fn from(hyper_req: HyperRequest<'h, 'k>) -> Request<'a> { + pub fn from_hyp<'h, 'k>(hyper_req: HyperRequest<'h, 'k>) + -> Result, String> { let (_, h_method, h_headers, h_uri, _, mut h_body) = hyper_req.deconstruct(); let uri = match h_uri { HyperRequestUri::AbsolutePath(s) => URIBuf::from(s), - _ => panic!("Can only accept absolute paths!") + _ => return Err(format!("Bad URI: {}", h_uri)) + }; + + let method = match Method::from_hyp(&h_method) { + Some(m) => m, + _ => return Err(format!("Bad method: {}", h_method)) }; // FIXME: GRRR. let mut data = vec![]; h_body.read_to_end(&mut data).unwrap(); - Request { + let request = Request { params: RefCell::new(None), - method: Method::from_hyp(&h_method).unwrap(), + method: method, uri: uri, data: data, headers: h_headers, - } + }; + + Ok(request) + } +} + +impl<'r> fmt::Display for Request<'r> { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{} {}", Green.paint(&self.method), Blue.paint(&self.uri)) } } diff --git a/lib/src/response/data_type.rs b/lib/src/response/data_type.rs new file mode 100644 index 00000000..c96ecb2c --- /dev/null +++ b/lib/src/response/data_type.rs @@ -0,0 +1,19 @@ +use response::{header, Responder, FreshHyperResponse, Outcome}; +use response::mime::{Mime, TopLevel, SubLevel}; + +macro_rules! impl_data_type_responder { + ($name:ident: $top:ident/$sub:ident) => ( + pub struct $name(pub T); + + impl Responder for $name { + fn respond<'b>(&mut self, mut res: FreshHyperResponse<'b>) -> Outcome<'b> { + let mime = Mime(TopLevel::$top, SubLevel::$sub, vec![]); + res.headers_mut().set(header::ContentType(mime)); + self.0.respond(res) + } + }) +} + +impl_data_type_responder!(JSON: Application/Json); +impl_data_type_responder!(HTML: Text/Html); +impl_data_type_responder!(Plain: Text/Plain); diff --git a/lib/src/response/mod.rs b/lib/src/response/mod.rs index a83b0dd9..268266c5 100644 --- a/lib/src/response/mod.rs +++ b/lib/src/response/mod.rs @@ -4,6 +4,7 @@ mod redirect; mod with_status; mod outcome; mod cookied; +mod data_type; pub use hyper::server::Response as HyperResponse; pub use hyper::net::Fresh as HyperFresh; @@ -11,6 +12,7 @@ pub use hyper::status::StatusCode; pub use hyper::header; pub use hyper::mime; +pub use self::data_type::*; pub use self::responder::Responder; pub use self::empty::{Empty, Forward}; pub use self::redirect::Redirect; diff --git a/lib/src/response/responder.rs b/lib/src/response/responder.rs index d05fc4b6..5f0b2f94 100644 --- a/lib/src/response/responder.rs +++ b/lib/src/response/responder.rs @@ -14,8 +14,11 @@ pub trait Responder { impl<'a> Responder for &'a str { fn respond<'b>(&mut self, mut res: FreshHyperResponse<'b>) -> Outcome<'b> { - let mime = Mime(TopLevel::Text, SubLevel::Html, vec![]); - res.headers_mut().set(header::ContentType(mime)); + if res.headers().get::().is_none() { + let mime = Mime(TopLevel::Text, SubLevel::Plain, vec![]); + res.headers_mut().set(header::ContentType(mime)); + } + res.send(self.as_bytes()).unwrap(); Outcome::Complete } @@ -23,8 +26,10 @@ impl<'a> Responder for &'a str { impl Responder for String { fn respond<'a>(&mut self, mut res: FreshHyperResponse<'a>) -> Outcome<'a> { - let mime = Mime(TopLevel::Text, SubLevel::Html, vec![]); - res.headers_mut().set(header::ContentType(mime)); + if res.headers().get::().is_none() { + let mime = Mime(TopLevel::Text, SubLevel::Html, vec![]); + res.headers_mut().set(header::ContentType(mime)); + } res.send(self.as_bytes()).unwrap(); Outcome::Complete } diff --git a/lib/src/rocket.rs b/lib/src/rocket.rs index 7765c809..a26a0de1 100644 --- a/lib/src/rocket.rs +++ b/lib/src/rocket.rs @@ -3,14 +3,11 @@ use response::FreshHyperResponse; use request::HyperRequest; use catcher; -use std::io::Read; use std::collections::HashMap; use term_painter::Color::*; use term_painter::ToStyle; -use hyper::uri::RequestUri as HyperRequestUri; -use hyper::method::Method as HyperMethod; use hyper::server::Server as HyperServer; use hyper::server::Handler as HyperHandler; @@ -19,91 +16,80 @@ pub struct Rocket { port: isize, router: Router, catchers: HashMap, -} - -fn uri_is_absolute(uri: &HyperRequestUri) -> bool { - match *uri { - HyperRequestUri::AbsolutePath(_) => true, - _ => false - } -} - -fn method_is_valid(method: &HyperMethod) -> bool { - Method::from_hyp(method).is_some() + log_set: bool, } impl HyperHandler for Rocket { fn handle<'h, 'k>(&self, req: HyperRequest<'h, 'k>, mut res: FreshHyperResponse<'h>) { - info!("{:?} '{}':", Green.paint(&req.method), Blue.paint(&req.uri)); - - let finalize = |mut req: HyperRequest, _res: FreshHyperResponse| { - let mut buf = vec![]; - // FIXME: Simple DOS attack here. Working around Hyper bug. - let _ = req.read_to_end(&mut buf); - }; - - if !uri_is_absolute(&req.uri) { - error_!("Internal failure. Bad URI."); - debug_!("Debug: {}", req.uri); - return finalize(req, res); - } - - if !method_is_valid(&req.method) { - error_!("Internal failure. Bad method."); - debug_!("Method: {}", req.method); - return finalize(req, res); - } - res.headers_mut().set(response::header::Server("rocket".to_string())); self.dispatch(req, res) } } impl Rocket { - fn dispatch<'h, 'k>(&self, hyper_req: HyperRequest<'h, 'k>, + fn dispatch<'h, 'k>(&self, hyp_req: HyperRequest<'h, 'k>, res: FreshHyperResponse<'h>) { - let req = Request::from(hyper_req); - let route = self.router.route(&req); - if let Some(route) = route { + // Get a copy of the URI for later use. + let uri = hyp_req.uri.to_string(); + + // Try to create a Rocket request from the hyper request. + let request = match Request::from_hyp(hyp_req) { + Ok(req) => req, + Err(reason) => { + let mock_request = Request::mock(Method::Get, uri.as_str()); + return self.handle_internal_error(reason, &mock_request, res); + } + }; + + info!("{}:", request); + let route = self.router.route(&request); + if let Some(ref route) = route { // Retrieve and set the requests parameters. - req.set_params(&route); + request.set_params(route); // Here's the magic: dispatch the request to the handler. - let outcome = (route.handler)(&req).respond(res); + let outcome = (route.handler)(&request).respond(res); info_!("{} {}", White.paint("Outcome:"), outcome); - // // TODO: keep trying lower ranked routes before dispatching a not - // // found error. - // outcome.map_forward(|res| { - // error_!("No further matching routes."); - // // TODO: Have some way to know why this was failed forward. Use that - // // instead of always using an unchained error. - // self.handle_not_found(req, res); - // }); + // TODO: keep trying lower ranked routes before dispatching a not + // found error. + outcome.map_forward(|res| { + error_!("No further matching routes."); + // TODO: Have some way to know why this was failed forward. Use that + // instead of always using an unchained error. + self.handle_not_found(&request, res); + }); } else { error_!("No matching routes."); - return self.handle_not_found(&req, res); + self.handle_not_found(&request, res); } } - // A closure which we call when we know there is no route. + // Call on internal server error. + fn handle_internal_error<'r>(&self, reason: String, request: &'r Request<'r>, + response: FreshHyperResponse) { + error!("Internal server error."); + debug!("{}", reason); + let catcher = self.catchers.get(&500).unwrap(); + catcher.handle(Error::Internal, request).respond(response); + } + + // Call when no route was found. fn handle_not_found<'r>(&self, request: &'r Request<'r>, response: FreshHyperResponse) { - error_!("Dispatch failed. Returning 404."); + error_!("{} dispatch failed: 404.", request); let catcher = self.catchers.get(&404).unwrap(); catcher.handle(Error::NoRoute, request).respond(response); } pub fn new(address: &'static str, port: isize) -> Rocket { - // FIXME: Allow user to override level/disable logging. - logger::init(logger::Level::Normal); - Rocket { address: address, port: port, router: Router::new(), catchers: catcher::defaults::get(), + log_set: false, } } @@ -138,11 +124,24 @@ impl Rocket { self } - pub fn launch(self) { + pub fn log(&mut self, level: LoggingLevel) { + if self.log_set { + warn!("Log level already set! Not overriding."); + } else { + logger::init(level); + self.log_set = true; + } + } + + pub fn launch(mut self) { if self.router.has_collisions() { warn!("Route collisions detected!"); } + if !self.log_set { + self.log(LoggingLevel::Normal) + } + let full_addr = format!("{}:{}", self.address, self.port); info!("🚀 {} {}...", White.paint("Rocket has launched from"), White.bold().paint(&full_addr)); diff --git a/macros/src/route_decorator.rs b/macros/src/route_decorator.rs index 6466511d..07c4252e 100644 --- a/macros/src/route_decorator.rs +++ b/macros/src/route_decorator.rs @@ -181,7 +181,7 @@ fn content_type_to_expr(ecx: &ExtCtxt, content_type: &ContentType) -> P { pub fn route_decorator(known_method: Option>, ecx: &mut ExtCtxt, sp: Span, meta_item: &MetaItem, annotated: &Annotatable, push: &mut FnMut(Annotatable)) { - ::rocket::logger::init(::rocket::logger::Level::Debug); + ::rocket::logger::init(::rocket::LoggingLevel::Debug); // Get the encompassing item and function declaration for the annotated func. let parser = MetaItemParser::new(ecx, meta_item, annotated, &sp);