Very simple CORS implementation

This commit is contained in:
Artem Biryukov 2017-01-19 04:15:30 +03:00
parent 44e367c64c
commit 9d7ad109cd
No known key found for this signature in database
GPG Key ID: AC9644A627DDAA9D
6 changed files with 199 additions and 0 deletions

View File

@ -4,6 +4,7 @@ members = [
"codegen/", "codegen/",
"contrib/", "contrib/",
"examples/cookies", "examples/cookies",
"examples/cors",
"examples/errors", "examples/errors",
"examples/extended_validation", "examples/extended_validation",
"examples/forms", "examples/forms",

140
contrib/src/cors.rs Normal file
View File

@ -0,0 +1,140 @@
use std::collections::HashSet;
use rocket::response::{self, Response, Responder};
use rocket::http::Method;
/// The CORS type, which implements `Responder`. This type allows
/// you to request resources from another domain.
///
/// # Usage
///
/// You can use the `CORS` type for you routes as a preflight like that:
///
/// ```rust,ignore
/// #[route(OPTIONS, "/user")]
/// fn cors_preflight() -> PreflightCORS {
/// CORS::preflight("http://somehost.com")
/// .methods(vec![Method::Options, Methods::Get])
/// .headers(vec!["Content-Type"])
/// }
/// ```
///
/// And then you can just simply do something like this:
///
/// ```rust,ignore
/// #[get("/user")]
/// fn user() -> CORS<String> {
/// CORS::any("Hello I'm User!".to_string())
/// }
/// ```
pub struct CORS<R> {
responder: R,
allow_origin: &'static str,
allow_credentials: bool,
expose_headers: HashSet<&'static str>,
max_age: Option<usize>,
allow_methods: HashSet<Method>,
allow_headers: HashSet<&'static str>
}
pub type PreflightCORS = CORS<()>;
impl PreflightCORS {
/// Consumes origin for which it will allow to use `CORS`
/// and return a basic origin `CORS`
pub fn preflight(origin: &'static str) -> PreflightCORS {
CORS::origin((), origin)
}
}
impl<'r, R: Responder<'r>> CORS<R> {
/// Consumes responder and returns CORS with any origin
pub fn any(responder: R) -> CORS<R> {
CORS::origin(responder, "*")
}
/// Consumes the responder and origin and returns basic CORS
pub fn origin(responder: R, origin: &'static str) -> CORS<R> {
CORS {
responder: responder,
allow_origin: origin,
allow_credentials: false,
expose_headers: HashSet::new(),
max_age: None,
allow_methods: HashSet::new(),
allow_headers: HashSet::new()
}
}
/// Consumes the CORS, set allow_credentials to
/// new value and returns changed CORS
pub fn credentials(mut self, value: bool) -> CORS<R> {
self.allow_credentials = value;
self
}
/// Consumes the CORS, set expose_headers to
/// passed headers and returns changed CORS
pub fn exposed_headers(mut self, headers: &[&'static str]) -> CORS<R> {
self.expose_headers = headers.into_iter().cloned().collect();
self
}
/// Consumes the CORS, set max_age to
/// passed value and returns changed CORS
pub fn max_age(mut self, value: Option<usize>) -> CORS<R> {
self.max_age = value;
self
}
/// Consumes the CORS, set allow_methods to
/// passed methods and returns changed CORS
pub fn methods(mut self, methods: &[Method]) -> CORS<R> {
self.allow_methods = methods.into_iter().cloned().collect();
self
}
/// Consumes the CORS, set allow_headers to
/// passed headers and returns changed CORS
pub fn headers(mut self, headers: &[&'static str]) -> CORS<R> {
self.allow_headers = headers.into_iter().cloned().collect();
self
}
}
impl <'r, R: Responder<'r>> Responder<'r> for CORS<R> {
fn respond(self) -> response::Result<'r> {
let mut response = Response::build_from(self.responder.respond()?)
.raw_header("Access-Control-Allow-Origin", self.allow_origin)
.finalize();
if self.allow_credentials {
response.set_raw_header("Access-Control-Allow-Credentials", "true");
} else {
response.set_raw_header("Access-Control-Allow-Credentials", "false");
}
if !self.expose_headers.is_empty() {
let headers: Vec<_> = self.expose_headers.into_iter().collect();
let headers = headers.join(", ");
response.set_raw_header("Access-Control-Expose-Headers", headers);
}
if !self.allow_methods.is_empty() {
let methods: Vec<_> = self.allow_methods
.into_iter()
.map(|m| m.as_str())
.collect();
let methods = methods.join(", ");
response.set_raw_header("Access-Control-Allow-Methods", methods);
}
if self.max_age.is_some() {
let max_age = self.max_age.unwrap();
response.set_raw_header("Access-Control-Max-Age", max_age.to_string());
}
Ok(response)
}
}

View File

@ -40,6 +40,8 @@
#[cfg(feature = "lazy_static_macro")] #[cfg(feature = "lazy_static_macro")]
extern crate lazy_static; extern crate lazy_static;
mod cors;
#[cfg_attr(feature = "json", macro_use)] #[cfg_attr(feature = "json", macro_use)]
#[cfg(feature = "json")] #[cfg(feature = "json")]
mod json; mod json;
@ -50,6 +52,8 @@ mod templates;
#[cfg(feature = "uuid")] #[cfg(feature = "uuid")]
mod uuid; mod uuid;
pub use cors::{PreflightCORS, CORS};
#[cfg(feature = "json")] #[cfg(feature = "json")]
pub use json::JSON; pub use json::JSON;

16
examples/cors/Cargo.toml Normal file
View File

@ -0,0 +1,16 @@
[package]
name = "cors"
version = "0.0.1"
authors = ["Artem Biryukov <impowski@protonmail.ch>"]
workspace = "../../"
[dependencies]
rocket = { path = "../../lib" }
rocket_codegen = { path = "../../codegen" }
[dependencies.rocket_contrib]
default-features = true
path = "../../contrib"
[dev-dependencies]
rocket = { path = "../../lib", features = ["testing"] }

22
examples/cors/src/main.rs Normal file
View File

@ -0,0 +1,22 @@
#![feature(plugin)]
#![plugin(rocket_codegen)]
extern crate rocket;
extern crate rocket_contrib;
use rocket::http::Method;
use rocket_contrib::{PreflightCORS, CORS};
#[cfg(test)]
mod tests;
#[get("/hello")]
fn hello() -> CORS<String> {
CORS::any("Hello there!".to_string())
}
fn main() {
rocket::ignite()
.mount("/", routes![cors_preflight, hello])
.launch();
}

View File

@ -0,0 +1,16 @@
use super::rocket;
use rocket::testing::MockRequest;
use rocket::http::Header;
use rocket::http::Method::*;
#[test]
fn user() {
let rocket = rocket::ignite().mount("/", routes![super::hello]);
let mut req = MockRequest::new(Get, "/hello");
let mut response = req.dispatch_with(&rocket);
let body_str = response.body().and_then(|body| body.into_string());
let values: Vec<_> = response.header_values("Access-Control-Allow-Origin").collect();
assert_eq!(values, vec!["*"]);
assert_eq!(body_str, Some("Hello there!".to_string()));
}