From 9d7ad109cd8a8f621cdda8d9ff7e0e652b7a02ab Mon Sep 17 00:00:00 2001 From: Artem Biryukov Date: Thu, 19 Jan 2017 04:15:30 +0300 Subject: [PATCH] Very simple CORS implementation --- Cargo.toml | 1 + contrib/src/cors.rs | 140 +++++++++++++++++++++++++++++++++++++ contrib/src/lib.rs | 4 ++ examples/cors/Cargo.toml | 16 +++++ examples/cors/src/main.rs | 22 ++++++ examples/cors/src/tests.rs | 16 +++++ 6 files changed, 199 insertions(+) create mode 100644 contrib/src/cors.rs create mode 100644 examples/cors/Cargo.toml create mode 100644 examples/cors/src/main.rs create mode 100644 examples/cors/src/tests.rs diff --git a/Cargo.toml b/Cargo.toml index a215d08c..3ff3471c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,6 +4,7 @@ members = [ "codegen/", "contrib/", "examples/cookies", + "examples/cors", "examples/errors", "examples/extended_validation", "examples/forms", diff --git a/contrib/src/cors.rs b/contrib/src/cors.rs new file mode 100644 index 00000000..55a3280c --- /dev/null +++ b/contrib/src/cors.rs @@ -0,0 +1,140 @@ +use std::collections::HashSet; +use rocket::response::{self, Response, Responder}; +use rocket::http::Method; + +/// The CORS type, which implements `Responder`. This type allows +/// you to request resources from another domain. +/// +/// # Usage +/// +/// You can use the `CORS` type for you routes as a preflight like that: +/// +/// ```rust,ignore +/// #[route(OPTIONS, "/user")] +/// fn cors_preflight() -> PreflightCORS { +/// CORS::preflight("http://somehost.com") +/// .methods(vec![Method::Options, Methods::Get]) +/// .headers(vec!["Content-Type"]) +/// } +/// ``` +/// +/// And then you can just simply do something like this: +/// +/// ```rust,ignore +/// #[get("/user")] +/// fn user() -> CORS { +/// CORS::any("Hello I'm User!".to_string()) +/// } +/// ``` +pub struct CORS { + responder: R, + allow_origin: &'static str, + allow_credentials: bool, + expose_headers: HashSet<&'static str>, + max_age: Option, + allow_methods: HashSet, + allow_headers: HashSet<&'static str> +} + +pub type PreflightCORS = CORS<()>; + +impl PreflightCORS { + /// Consumes origin for which it will allow to use `CORS` + /// and return a basic origin `CORS` + pub fn preflight(origin: &'static str) -> PreflightCORS { + CORS::origin((), origin) + } +} + +impl<'r, R: Responder<'r>> CORS { + /// Consumes responder and returns CORS with any origin + pub fn any(responder: R) -> CORS { + CORS::origin(responder, "*") + } + + /// Consumes the responder and origin and returns basic CORS + pub fn origin(responder: R, origin: &'static str) -> CORS { + CORS { + responder: responder, + allow_origin: origin, + allow_credentials: false, + expose_headers: HashSet::new(), + max_age: None, + allow_methods: HashSet::new(), + allow_headers: HashSet::new() + } + } + + /// Consumes the CORS, set allow_credentials to + /// new value and returns changed CORS + pub fn credentials(mut self, value: bool) -> CORS { + self.allow_credentials = value; + self + } + + /// Consumes the CORS, set expose_headers to + /// passed headers and returns changed CORS + pub fn exposed_headers(mut self, headers: &[&'static str]) -> CORS { + self.expose_headers = headers.into_iter().cloned().collect(); + self + } + + /// Consumes the CORS, set max_age to + /// passed value and returns changed CORS + pub fn max_age(mut self, value: Option) -> CORS { + self.max_age = value; + self + } + + /// Consumes the CORS, set allow_methods to + /// passed methods and returns changed CORS + pub fn methods(mut self, methods: &[Method]) -> CORS { + self.allow_methods = methods.into_iter().cloned().collect(); + self + } + + /// Consumes the CORS, set allow_headers to + /// passed headers and returns changed CORS + pub fn headers(mut self, headers: &[&'static str]) -> CORS { + self.allow_headers = headers.into_iter().cloned().collect(); + self + } +} + +impl <'r, R: Responder<'r>> Responder<'r> for CORS { + fn respond(self) -> response::Result<'r> { + let mut response = Response::build_from(self.responder.respond()?) + .raw_header("Access-Control-Allow-Origin", self.allow_origin) + .finalize(); + + if self.allow_credentials { + response.set_raw_header("Access-Control-Allow-Credentials", "true"); + } else { + response.set_raw_header("Access-Control-Allow-Credentials", "false"); + } + + if !self.expose_headers.is_empty() { + let headers: Vec<_> = self.expose_headers.into_iter().collect(); + let headers = headers.join(", "); + + response.set_raw_header("Access-Control-Expose-Headers", headers); + } + + if !self.allow_methods.is_empty() { + let methods: Vec<_> = self.allow_methods + .into_iter() + .map(|m| m.as_str()) + .collect(); + let methods = methods.join(", "); + + response.set_raw_header("Access-Control-Allow-Methods", methods); + } + + if self.max_age.is_some() { + let max_age = self.max_age.unwrap(); + response.set_raw_header("Access-Control-Max-Age", max_age.to_string()); + } + + Ok(response) + } +} diff --git a/contrib/src/lib.rs b/contrib/src/lib.rs index 34b518b9..851fe878 100644 --- a/contrib/src/lib.rs +++ b/contrib/src/lib.rs @@ -40,6 +40,8 @@ #[cfg(feature = "lazy_static_macro")] extern crate lazy_static; +mod cors; + #[cfg_attr(feature = "json", macro_use)] #[cfg(feature = "json")] mod json; @@ -50,6 +52,8 @@ mod templates; #[cfg(feature = "uuid")] mod uuid; +pub use cors::{PreflightCORS, CORS}; + #[cfg(feature = "json")] pub use json::JSON; diff --git a/examples/cors/Cargo.toml b/examples/cors/Cargo.toml new file mode 100644 index 00000000..bdfae231 --- /dev/null +++ b/examples/cors/Cargo.toml @@ -0,0 +1,16 @@ +[package] +name = "cors" +version = "0.0.1" +authors = ["Artem Biryukov "] +workspace = "../../" + +[dependencies] +rocket = { path = "../../lib" } +rocket_codegen = { path = "../../codegen" } + +[dependencies.rocket_contrib] +default-features = true +path = "../../contrib" + +[dev-dependencies] +rocket = { path = "../../lib", features = ["testing"] } \ No newline at end of file diff --git a/examples/cors/src/main.rs b/examples/cors/src/main.rs new file mode 100644 index 00000000..f4c3f25b --- /dev/null +++ b/examples/cors/src/main.rs @@ -0,0 +1,22 @@ +#![feature(plugin)] +#![plugin(rocket_codegen)] + +extern crate rocket; +extern crate rocket_contrib; + +use rocket::http::Method; +use rocket_contrib::{PreflightCORS, CORS}; + +#[cfg(test)] +mod tests; + +#[get("/hello")] +fn hello() -> CORS { + CORS::any("Hello there!".to_string()) +} + +fn main() { + rocket::ignite() + .mount("/", routes![cors_preflight, hello]) + .launch(); +} diff --git a/examples/cors/src/tests.rs b/examples/cors/src/tests.rs new file mode 100644 index 00000000..95d72bce --- /dev/null +++ b/examples/cors/src/tests.rs @@ -0,0 +1,16 @@ +use super::rocket; +use rocket::testing::MockRequest; +use rocket::http::Header; +use rocket::http::Method::*; + +#[test] +fn user() { + let rocket = rocket::ignite().mount("/", routes![super::hello]); + let mut req = MockRequest::new(Get, "/hello"); + let mut response = req.dispatch_with(&rocket); + + let body_str = response.body().and_then(|body| body.into_string()); + let values: Vec<_> = response.header_values("Access-Control-Allow-Origin").collect(); + assert_eq!(values, vec!["*"]); + assert_eq!(body_str, Some("Hello there!".to_string())); +}