From 95a8a51b76e66f858c43d2757c1fec90125702cc Mon Sep 17 00:00:00 2001 From: Sergio Benitez Date: Mon, 8 Aug 2016 03:10:23 -0700 Subject: [PATCH] Added FromRequest and modified macro to use it: any parameters not declared by the user in the attributes will automatically be retrieved using FromRequest. --- examples/cookies/README.md | 16 ------- examples/cookies/src/main.rs | 7 ++- examples/cookies/src/static_files.rs | 19 -------- lib/src/request/from_request.rs | 46 +++++++++++++++++++ lib/src/request/mod.rs | 6 +++ lib/src/{ => request}/request.rs | 2 - macros/src/route_decorator.rs | 67 ++++++++++++++++++---------- macros/src/utils.rs | 30 +++++++++++++ 8 files changed, 129 insertions(+), 64 deletions(-) delete mode 100644 examples/cookies/README.md delete mode 100644 examples/cookies/src/static_files.rs create mode 100644 lib/src/request/from_request.rs create mode 100644 lib/src/request/mod.rs rename lib/src/{ => request}/request.rs (94%) diff --git a/examples/cookies/README.md b/examples/cookies/README.md deleted file mode 100644 index f284443b..00000000 --- a/examples/cookies/README.md +++ /dev/null @@ -1,16 +0,0 @@ -Rocket Todo Example -=================== - -Before running this example, you'll need to ensure there's a database file -present. You can do this with Diesel. - -Running migration with Diesel ------------------------------ - -Just run the following commands in your shell: - -``` -cargo install diesel_cli # installs the diesel CLI tools -DATABASE_URL=db/db.sql diesel migration run # create db/db.sql -``` - diff --git a/examples/cookies/src/main.rs b/examples/cookies/src/main.rs index 93e30f93..9b4d5e9e 100644 --- a/examples/cookies/src/main.rs +++ b/examples/cookies/src/main.rs @@ -6,10 +6,9 @@ extern crate lazy_static; extern crate rocket; extern crate tera; -mod static_files; - use rocket::Rocket; use rocket::response::{Cookied, Redirect}; +use rocket::Method; lazy_static!(static ref TERA: tera::Tera = tera::Tera::new("templates/**/*");); @@ -31,13 +30,13 @@ fn submit(message: Message) -> Cookied { } #[route(GET, path = "/")] -fn index() -> tera::TeraResult { +fn index(method: Method) -> tera::TeraResult { + println!("Method is: {}", method); TERA.render("index.html", ctxt(None)) } fn main() { let mut rocket = Rocket::new("127.0.0.1", 8000); - rocket.mount("/", static_files::routes()); rocket.mount("/", routes![submit, index]); rocket.launch(); } diff --git a/examples/cookies/src/static_files.rs b/examples/cookies/src/static_files.rs deleted file mode 100644 index 3333147e..00000000 --- a/examples/cookies/src/static_files.rs +++ /dev/null @@ -1,19 +0,0 @@ -use std::fs::File; -use std::io; -use rocket; - -#[route(GET, path = "//")] -fn all_level_one(top: &str, file: &str) -> io::Result { - let file = format!("static/{}/{}", top, file); - File::open(file) -} - -#[route(GET, path = "/")] -fn all(file: &str) -> io::Result { - let file = format!("static/{}", file); - File::open(file) -} - -pub fn routes() -> Vec { - routes![all_level_one, all] -} diff --git a/lib/src/request/from_request.rs b/lib/src/request/from_request.rs new file mode 100644 index 00000000..42e523f1 --- /dev/null +++ b/lib/src/request/from_request.rs @@ -0,0 +1,46 @@ +use request::*; +use method::Method; +use std::fmt::Debug; + +pub trait FromRequest<'r, 'c>: Sized { + type Error: Debug; + + fn from_request(request: &'r Request<'c>) -> Result; +} + +impl<'r, 'c> FromRequest<'r, 'c> for &'r Request<'c> { + type Error = (); + + fn from_request(request: &'r Request<'c>) -> Result { + Ok(request) + } +} + +impl<'r, 'c> FromRequest<'r, 'c> for Method { + type Error = &'static str; + + fn from_request(request: &'r Request<'c>) -> Result { + Ok(request.method) + } +} + +impl<'r, 'c, T: FromRequest<'r, 'c>> FromRequest<'r, 'c> for Option { + type Error = (); + + fn from_request(request: &'r Request<'c>) -> Result { + let opt = match T::from_request(request) { + Ok(v) => Some(v), + Err(_) => None + }; + + Ok(opt) + } +} + +impl<'r, 'c, T: FromRequest<'r, 'c>> FromRequest<'r, 'c> for Result { + type Error = (); + + fn from_request(request: &'r Request<'c>) -> Result { + Ok(T::from_request(request)) + } +} diff --git a/lib/src/request/mod.rs b/lib/src/request/mod.rs new file mode 100644 index 00000000..6477e7eb --- /dev/null +++ b/lib/src/request/mod.rs @@ -0,0 +1,6 @@ +mod request; +mod from_request; + +pub use hyper::server::Request as HyperRequest; +pub use self::request::Request; +pub use self::from_request::FromRequest; diff --git a/lib/src/request.rs b/lib/src/request/request.rs similarity index 94% rename from lib/src/request.rs rename to lib/src/request/request.rs index 17d0ed7e..de40a23a 100644 --- a/lib/src/request.rs +++ b/lib/src/request/request.rs @@ -2,8 +2,6 @@ use error::Error; use param::FromParam; use method::Method; -pub use hyper::server::Request as HyperRequest; - #[derive(Clone, Debug)] pub struct Request<'a> { params: Option>, diff --git a/macros/src/route_decorator.rs b/macros/src/route_decorator.rs index fc2ba4f3..5aa208cc 100644 --- a/macros/src/route_decorator.rs +++ b/macros/src/route_decorator.rs @@ -163,8 +163,12 @@ pub fn extract_params_from_kv<'a>(ecx: &ExtCtxt, params: &'a KVSpanned) }) } +// Analyzes the declared parameters against the function declaration. Returns +// two vectors. The first is the set of parameters declared by the user, and +// the second is the set of parameters not declared by the user. fn get_fn_params<'a, T: Iterator>>(ecx: &ExtCtxt, - declared_params: T, fn_decl: &Spanned<&FnDecl>) -> Vec { + declared_params: T, fn_decl: &Spanned<&FnDecl>) + -> Vec { debug!("FUNCTION: {:?}", fn_decl); // First, check that all of the parameters are unique. @@ -179,20 +183,20 @@ fn get_fn_params<'a, T: Iterator>>(ecx: &ExtCtxt, } } + let mut user_params = vec![]; + // Ensure every param in the function declaration was declared by the user. - let mut result = vec![]; for arg in &fn_decl.node.inputs { let name = arg.pat.expect_ident(ecx, "Expected identifier."); - if seen.remove(&*name.to_string()).is_none() { - let msg = format!("'{}' appears in the function declaration \ - but does not appear as a parameter in the attribute.", name); - ecx.span_err(arg.pat.span, msg.as_str()); + let arg = SimpleArg::new(name, arg.ty.clone(), arg.pat.span); + if seen.remove(&*name.to_string()).is_some() { + user_params.push(UserParam::new(arg, true)); + } else { + user_params.push(UserParam::new(arg, false)); } - - result.push(SimpleArg::new(name, arg.ty.clone(), arg.pat.span)); } - // Ensure every declared parameter is in the function declaration. + // Emit an error on every attribute param that didn't match in fn params. for item in seen.values() { let msg = format!("'{}' was declared in the attribute...", item.node); ecx.span_err(item.span, msg.as_str()); @@ -200,10 +204,10 @@ fn get_fn_params<'a, T: Iterator>>(ecx: &ExtCtxt, declaration."); } - result + user_params } -fn get_form_stmt(ecx: &ExtCtxt, fn_args: &mut Vec, +fn get_form_stmt(ecx: &ExtCtxt, fn_args: &mut Vec, form_params: &[Spanned<&str>]) -> Option { if form_params.len() < 1 { return None; @@ -281,35 +285,52 @@ pub fn route_decorator(ecx: &mut ExtCtxt, sp: Span, meta_item: &MetaItem, // Ensure the params match the function declaration and return the params. let all_params = path_params.iter().chain(form_params.iter()); - let mut fn_params = get_fn_params(ecx, all_params, &fn_decl); + let mut user_params = get_fn_params(ecx, all_params, &fn_decl); // Create a comma seperated list (token tree) of the function parameters // We pass this in to the user's function that we're wrapping. - let fn_param_idents = token_separate(ecx, &fn_params, token::Comma); + let fn_param_idents = token_separate(ecx, &user_params, token::Comma); // Generate the statements that will attempt to parse forms during run-time. // Calling this function also remove the form parameter from fn_params. - let form_stmt = get_form_stmt(ecx, &mut fn_params, &form_params); + let form_stmt = get_form_stmt(ecx, &mut user_params, &form_params); form_stmt.as_ref().map(|s| debug!("Form stmt: {:?}", stmt_to_string(s))); // Generate the statements that will attempt to parse the paramaters during // run-time. let mut fn_param_exprs = vec![]; - for (i, param) in fn_params.iter().enumerate() { - let param_ident = str_to_ident(param.as_str()); - let param_ty = ¶m.ty; - let param_fn_item = quote_stmt!(ecx, - let $param_ident: $param_ty = match _req.get_param($i) { - Ok(v) => v, - Err(_) => return ::rocket::Response::forward() + for (i, param) in user_params.iter().enumerate() { + let ident = str_to_ident(param.as_str()); + let ty = ¶m.ty; + let param_fn_item = + if param.declared { + quote_stmt!(ecx, + let $ident: $ty = match _req.get_param($i) { + Ok(v) => v, + Err(_) => return ::rocket::Response::forward() + }; + ).unwrap() + } else { + quote_stmt!(ecx, + let $ident: $ty = match + <$ty as ::rocket::request::FromRequest>::from_request(&_req) { + Ok(v) => v, + Err(e) => { + // TODO: Add $ident and $ty to the string. + // TODO: Add some kind of loggin facility in Rocket + // to get the formatting right (IE, so it idents + // correctly). + println!("Failed to parse: {:?}", e); + return ::rocket::Response::forward(); + } + }; + ).unwrap() }; - ).unwrap(); debug!("Param FN: {:?}", stmt_to_string(¶m_fn_item)); fn_param_exprs.push(param_fn_item); } - debug!("Final Params: {:?}", fn_params); let route_fn_name = prepend_ident(ROUTE_FN_PREFIX, &item.ident); let fn_name = item.ident; let route_fn_item = quote_item!(ecx, diff --git a/macros/src/utils.rs b/macros/src/utils.rs index 55c1d953..75670fea 100644 --- a/macros/src/utils.rs +++ b/macros/src/utils.rs @@ -1,3 +1,5 @@ +use std::ops::Deref; + use syntax::parse::{token}; use syntax::parse::token::Token; use syntax::tokenstream::TokenTree; @@ -240,3 +242,31 @@ impl ToTokens for SimpleArg { } } +pub struct UserParam { + pub arg: SimpleArg, + pub declared: bool +} + +impl UserParam { + pub fn new(arg: SimpleArg, declared: bool) -> UserParam { + UserParam { + arg: arg, + declared: declared + } + } +} + +impl Deref for UserParam { + type Target = SimpleArg; + + fn deref(&self) -> &SimpleArg { + &self.arg + } +} + +impl ToTokens for UserParam { + fn to_tokens(&self, cx: &ExtCtxt) -> Vec { + self.arg.to_tokens(cx) + } +} +