mirror of https://github.com/rwf2/Rocket.git
Very simple CORS implementation
This commit is contained in:
parent
44e367c64c
commit
9d7ad109cd
|
@ -4,6 +4,7 @@ members = [
|
|||
"codegen/",
|
||||
"contrib/",
|
||||
"examples/cookies",
|
||||
"examples/cors",
|
||||
"examples/errors",
|
||||
"examples/extended_validation",
|
||||
"examples/forms",
|
||||
|
|
|
@ -0,0 +1,140 @@
|
|||
use std::collections::HashSet;
|
||||
use rocket::response::{self, Response, Responder};
|
||||
use rocket::http::Method;
|
||||
|
||||
/// The CORS type, which implements `Responder`. This type allows
|
||||
/// you to request resources from another domain.
|
||||
///
|
||||
/// # Usage
|
||||
///
|
||||
/// You can use the `CORS` type for you routes as a preflight like that:
|
||||
///
|
||||
/// ```rust,ignore
|
||||
/// #[route(OPTIONS, "/user")]
|
||||
/// fn cors_preflight() -> PreflightCORS {
|
||||
/// CORS::preflight("http://somehost.com")
|
||||
/// .methods(vec![Method::Options, Methods::Get])
|
||||
/// .headers(vec!["Content-Type"])
|
||||
/// }
|
||||
/// ```
|
||||
///
|
||||
/// And then you can just simply do something like this:
|
||||
///
|
||||
/// ```rust,ignore
|
||||
/// #[get("/user")]
|
||||
/// fn user() -> CORS<String> {
|
||||
/// CORS::any("Hello I'm User!".to_string())
|
||||
/// }
|
||||
/// ```
|
||||
pub struct CORS<R> {
|
||||
responder: R,
|
||||
allow_origin: &'static str,
|
||||
allow_credentials: bool,
|
||||
expose_headers: HashSet<&'static str>,
|
||||
max_age: Option<usize>,
|
||||
allow_methods: HashSet<Method>,
|
||||
allow_headers: HashSet<&'static str>
|
||||
}
|
||||
|
||||
pub type PreflightCORS = CORS<()>;
|
||||
|
||||
impl PreflightCORS {
|
||||
/// Consumes origin for which it will allow to use `CORS`
|
||||
/// and return a basic origin `CORS`
|
||||
pub fn preflight(origin: &'static str) -> PreflightCORS {
|
||||
CORS::origin((), origin)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'r, R: Responder<'r>> CORS<R> {
|
||||
/// Consumes responder and returns CORS with any origin
|
||||
pub fn any(responder: R) -> CORS<R> {
|
||||
CORS::origin(responder, "*")
|
||||
}
|
||||
|
||||
/// Consumes the responder and origin and returns basic CORS
|
||||
pub fn origin(responder: R, origin: &'static str) -> CORS<R> {
|
||||
CORS {
|
||||
responder: responder,
|
||||
allow_origin: origin,
|
||||
allow_credentials: false,
|
||||
expose_headers: HashSet::new(),
|
||||
max_age: None,
|
||||
allow_methods: HashSet::new(),
|
||||
allow_headers: HashSet::new()
|
||||
}
|
||||
}
|
||||
|
||||
/// Consumes the CORS, set allow_credentials to
|
||||
/// new value and returns changed CORS
|
||||
pub fn credentials(mut self, value: bool) -> CORS<R> {
|
||||
self.allow_credentials = value;
|
||||
self
|
||||
}
|
||||
|
||||
/// Consumes the CORS, set expose_headers to
|
||||
/// passed headers and returns changed CORS
|
||||
pub fn exposed_headers(mut self, headers: &[&'static str]) -> CORS<R> {
|
||||
self.expose_headers = headers.into_iter().cloned().collect();
|
||||
self
|
||||
}
|
||||
|
||||
/// Consumes the CORS, set max_age to
|
||||
/// passed value and returns changed CORS
|
||||
pub fn max_age(mut self, value: Option<usize>) -> CORS<R> {
|
||||
self.max_age = value;
|
||||
self
|
||||
}
|
||||
|
||||
/// Consumes the CORS, set allow_methods to
|
||||
/// passed methods and returns changed CORS
|
||||
pub fn methods(mut self, methods: &[Method]) -> CORS<R> {
|
||||
self.allow_methods = methods.into_iter().cloned().collect();
|
||||
self
|
||||
}
|
||||
|
||||
/// Consumes the CORS, set allow_headers to
|
||||
/// passed headers and returns changed CORS
|
||||
pub fn headers(mut self, headers: &[&'static str]) -> CORS<R> {
|
||||
self.allow_headers = headers.into_iter().cloned().collect();
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl <'r, R: Responder<'r>> Responder<'r> for CORS<R> {
|
||||
fn respond(self) -> response::Result<'r> {
|
||||
let mut response = Response::build_from(self.responder.respond()?)
|
||||
.raw_header("Access-Control-Allow-Origin", self.allow_origin)
|
||||
.finalize();
|
||||
|
||||
if self.allow_credentials {
|
||||
response.set_raw_header("Access-Control-Allow-Credentials", "true");
|
||||
} else {
|
||||
response.set_raw_header("Access-Control-Allow-Credentials", "false");
|
||||
}
|
||||
|
||||
if !self.expose_headers.is_empty() {
|
||||
let headers: Vec<_> = self.expose_headers.into_iter().collect();
|
||||
let headers = headers.join(", ");
|
||||
|
||||
response.set_raw_header("Access-Control-Expose-Headers", headers);
|
||||
}
|
||||
|
||||
if !self.allow_methods.is_empty() {
|
||||
let methods: Vec<_> = self.allow_methods
|
||||
.into_iter()
|
||||
.map(|m| m.as_str())
|
||||
.collect();
|
||||
let methods = methods.join(", ");
|
||||
|
||||
response.set_raw_header("Access-Control-Allow-Methods", methods);
|
||||
}
|
||||
|
||||
if self.max_age.is_some() {
|
||||
let max_age = self.max_age.unwrap();
|
||||
response.set_raw_header("Access-Control-Max-Age", max_age.to_string());
|
||||
}
|
||||
|
||||
Ok(response)
|
||||
}
|
||||
}
|
|
@ -40,6 +40,8 @@
|
|||
#[cfg(feature = "lazy_static_macro")]
|
||||
extern crate lazy_static;
|
||||
|
||||
mod cors;
|
||||
|
||||
#[cfg_attr(feature = "json", macro_use)]
|
||||
#[cfg(feature = "json")]
|
||||
mod json;
|
||||
|
@ -50,6 +52,8 @@ mod templates;
|
|||
#[cfg(feature = "uuid")]
|
||||
mod uuid;
|
||||
|
||||
pub use cors::{PreflightCORS, CORS};
|
||||
|
||||
#[cfg(feature = "json")]
|
||||
pub use json::JSON;
|
||||
|
||||
|
|
|
@ -0,0 +1,16 @@
|
|||
[package]
|
||||
name = "cors"
|
||||
version = "0.0.1"
|
||||
authors = ["Artem Biryukov <impowski@protonmail.ch>"]
|
||||
workspace = "../../"
|
||||
|
||||
[dependencies]
|
||||
rocket = { path = "../../lib" }
|
||||
rocket_codegen = { path = "../../codegen" }
|
||||
|
||||
[dependencies.rocket_contrib]
|
||||
default-features = true
|
||||
path = "../../contrib"
|
||||
|
||||
[dev-dependencies]
|
||||
rocket = { path = "../../lib", features = ["testing"] }
|
|
@ -0,0 +1,22 @@
|
|||
#![feature(plugin)]
|
||||
#![plugin(rocket_codegen)]
|
||||
|
||||
extern crate rocket;
|
||||
extern crate rocket_contrib;
|
||||
|
||||
use rocket::http::Method;
|
||||
use rocket_contrib::{PreflightCORS, CORS};
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests;
|
||||
|
||||
#[get("/hello")]
|
||||
fn hello() -> CORS<String> {
|
||||
CORS::any("Hello there!".to_string())
|
||||
}
|
||||
|
||||
fn main() {
|
||||
rocket::ignite()
|
||||
.mount("/", routes![cors_preflight, hello])
|
||||
.launch();
|
||||
}
|
|
@ -0,0 +1,16 @@
|
|||
use super::rocket;
|
||||
use rocket::testing::MockRequest;
|
||||
use rocket::http::Header;
|
||||
use rocket::http::Method::*;
|
||||
|
||||
#[test]
|
||||
fn user() {
|
||||
let rocket = rocket::ignite().mount("/", routes![super::hello]);
|
||||
let mut req = MockRequest::new(Get, "/hello");
|
||||
let mut response = req.dispatch_with(&rocket);
|
||||
|
||||
let body_str = response.body().and_then(|body| body.into_string());
|
||||
let values: Vec<_> = response.header_values("Access-Control-Allow-Origin").collect();
|
||||
assert_eq!(values, vec!["*"]);
|
||||
assert_eq!(body_str, Some("Hello there!".to_string()));
|
||||
}
|
Loading…
Reference in New Issue