mirror of https://github.com/rwf2/Rocket.git
Very simple CORS implementation
This commit is contained in:
parent
44e367c64c
commit
9d7ad109cd
|
@ -4,6 +4,7 @@ members = [
|
||||||
"codegen/",
|
"codegen/",
|
||||||
"contrib/",
|
"contrib/",
|
||||||
"examples/cookies",
|
"examples/cookies",
|
||||||
|
"examples/cors",
|
||||||
"examples/errors",
|
"examples/errors",
|
||||||
"examples/extended_validation",
|
"examples/extended_validation",
|
||||||
"examples/forms",
|
"examples/forms",
|
||||||
|
|
|
@ -0,0 +1,140 @@
|
||||||
|
use std::collections::HashSet;
|
||||||
|
use rocket::response::{self, Response, Responder};
|
||||||
|
use rocket::http::Method;
|
||||||
|
|
||||||
|
/// The CORS type, which implements `Responder`. This type allows
|
||||||
|
/// you to request resources from another domain.
|
||||||
|
///
|
||||||
|
/// # Usage
|
||||||
|
///
|
||||||
|
/// You can use the `CORS` type for you routes as a preflight like that:
|
||||||
|
///
|
||||||
|
/// ```rust,ignore
|
||||||
|
/// #[route(OPTIONS, "/user")]
|
||||||
|
/// fn cors_preflight() -> PreflightCORS {
|
||||||
|
/// CORS::preflight("http://somehost.com")
|
||||||
|
/// .methods(vec![Method::Options, Methods::Get])
|
||||||
|
/// .headers(vec!["Content-Type"])
|
||||||
|
/// }
|
||||||
|
/// ```
|
||||||
|
///
|
||||||
|
/// And then you can just simply do something like this:
|
||||||
|
///
|
||||||
|
/// ```rust,ignore
|
||||||
|
/// #[get("/user")]
|
||||||
|
/// fn user() -> CORS<String> {
|
||||||
|
/// CORS::any("Hello I'm User!".to_string())
|
||||||
|
/// }
|
||||||
|
/// ```
|
||||||
|
pub struct CORS<R> {
|
||||||
|
responder: R,
|
||||||
|
allow_origin: &'static str,
|
||||||
|
allow_credentials: bool,
|
||||||
|
expose_headers: HashSet<&'static str>,
|
||||||
|
max_age: Option<usize>,
|
||||||
|
allow_methods: HashSet<Method>,
|
||||||
|
allow_headers: HashSet<&'static str>
|
||||||
|
}
|
||||||
|
|
||||||
|
pub type PreflightCORS = CORS<()>;
|
||||||
|
|
||||||
|
impl PreflightCORS {
|
||||||
|
/// Consumes origin for which it will allow to use `CORS`
|
||||||
|
/// and return a basic origin `CORS`
|
||||||
|
pub fn preflight(origin: &'static str) -> PreflightCORS {
|
||||||
|
CORS::origin((), origin)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'r, R: Responder<'r>> CORS<R> {
|
||||||
|
/// Consumes responder and returns CORS with any origin
|
||||||
|
pub fn any(responder: R) -> CORS<R> {
|
||||||
|
CORS::origin(responder, "*")
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Consumes the responder and origin and returns basic CORS
|
||||||
|
pub fn origin(responder: R, origin: &'static str) -> CORS<R> {
|
||||||
|
CORS {
|
||||||
|
responder: responder,
|
||||||
|
allow_origin: origin,
|
||||||
|
allow_credentials: false,
|
||||||
|
expose_headers: HashSet::new(),
|
||||||
|
max_age: None,
|
||||||
|
allow_methods: HashSet::new(),
|
||||||
|
allow_headers: HashSet::new()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Consumes the CORS, set allow_credentials to
|
||||||
|
/// new value and returns changed CORS
|
||||||
|
pub fn credentials(mut self, value: bool) -> CORS<R> {
|
||||||
|
self.allow_credentials = value;
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Consumes the CORS, set expose_headers to
|
||||||
|
/// passed headers and returns changed CORS
|
||||||
|
pub fn exposed_headers(mut self, headers: &[&'static str]) -> CORS<R> {
|
||||||
|
self.expose_headers = headers.into_iter().cloned().collect();
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Consumes the CORS, set max_age to
|
||||||
|
/// passed value and returns changed CORS
|
||||||
|
pub fn max_age(mut self, value: Option<usize>) -> CORS<R> {
|
||||||
|
self.max_age = value;
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Consumes the CORS, set allow_methods to
|
||||||
|
/// passed methods and returns changed CORS
|
||||||
|
pub fn methods(mut self, methods: &[Method]) -> CORS<R> {
|
||||||
|
self.allow_methods = methods.into_iter().cloned().collect();
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Consumes the CORS, set allow_headers to
|
||||||
|
/// passed headers and returns changed CORS
|
||||||
|
pub fn headers(mut self, headers: &[&'static str]) -> CORS<R> {
|
||||||
|
self.allow_headers = headers.into_iter().cloned().collect();
|
||||||
|
self
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl <'r, R: Responder<'r>> Responder<'r> for CORS<R> {
|
||||||
|
fn respond(self) -> response::Result<'r> {
|
||||||
|
let mut response = Response::build_from(self.responder.respond()?)
|
||||||
|
.raw_header("Access-Control-Allow-Origin", self.allow_origin)
|
||||||
|
.finalize();
|
||||||
|
|
||||||
|
if self.allow_credentials {
|
||||||
|
response.set_raw_header("Access-Control-Allow-Credentials", "true");
|
||||||
|
} else {
|
||||||
|
response.set_raw_header("Access-Control-Allow-Credentials", "false");
|
||||||
|
}
|
||||||
|
|
||||||
|
if !self.expose_headers.is_empty() {
|
||||||
|
let headers: Vec<_> = self.expose_headers.into_iter().collect();
|
||||||
|
let headers = headers.join(", ");
|
||||||
|
|
||||||
|
response.set_raw_header("Access-Control-Expose-Headers", headers);
|
||||||
|
}
|
||||||
|
|
||||||
|
if !self.allow_methods.is_empty() {
|
||||||
|
let methods: Vec<_> = self.allow_methods
|
||||||
|
.into_iter()
|
||||||
|
.map(|m| m.as_str())
|
||||||
|
.collect();
|
||||||
|
let methods = methods.join(", ");
|
||||||
|
|
||||||
|
response.set_raw_header("Access-Control-Allow-Methods", methods);
|
||||||
|
}
|
||||||
|
|
||||||
|
if self.max_age.is_some() {
|
||||||
|
let max_age = self.max_age.unwrap();
|
||||||
|
response.set_raw_header("Access-Control-Max-Age", max_age.to_string());
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(response)
|
||||||
|
}
|
||||||
|
}
|
|
@ -40,6 +40,8 @@
|
||||||
#[cfg(feature = "lazy_static_macro")]
|
#[cfg(feature = "lazy_static_macro")]
|
||||||
extern crate lazy_static;
|
extern crate lazy_static;
|
||||||
|
|
||||||
|
mod cors;
|
||||||
|
|
||||||
#[cfg_attr(feature = "json", macro_use)]
|
#[cfg_attr(feature = "json", macro_use)]
|
||||||
#[cfg(feature = "json")]
|
#[cfg(feature = "json")]
|
||||||
mod json;
|
mod json;
|
||||||
|
@ -50,6 +52,8 @@ mod templates;
|
||||||
#[cfg(feature = "uuid")]
|
#[cfg(feature = "uuid")]
|
||||||
mod uuid;
|
mod uuid;
|
||||||
|
|
||||||
|
pub use cors::{PreflightCORS, CORS};
|
||||||
|
|
||||||
#[cfg(feature = "json")]
|
#[cfg(feature = "json")]
|
||||||
pub use json::JSON;
|
pub use json::JSON;
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,16 @@
|
||||||
|
[package]
|
||||||
|
name = "cors"
|
||||||
|
version = "0.0.1"
|
||||||
|
authors = ["Artem Biryukov <impowski@protonmail.ch>"]
|
||||||
|
workspace = "../../"
|
||||||
|
|
||||||
|
[dependencies]
|
||||||
|
rocket = { path = "../../lib" }
|
||||||
|
rocket_codegen = { path = "../../codegen" }
|
||||||
|
|
||||||
|
[dependencies.rocket_contrib]
|
||||||
|
default-features = true
|
||||||
|
path = "../../contrib"
|
||||||
|
|
||||||
|
[dev-dependencies]
|
||||||
|
rocket = { path = "../../lib", features = ["testing"] }
|
|
@ -0,0 +1,22 @@
|
||||||
|
#![feature(plugin)]
|
||||||
|
#![plugin(rocket_codegen)]
|
||||||
|
|
||||||
|
extern crate rocket;
|
||||||
|
extern crate rocket_contrib;
|
||||||
|
|
||||||
|
use rocket::http::Method;
|
||||||
|
use rocket_contrib::{PreflightCORS, CORS};
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests;
|
||||||
|
|
||||||
|
#[get("/hello")]
|
||||||
|
fn hello() -> CORS<String> {
|
||||||
|
CORS::any("Hello there!".to_string())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn main() {
|
||||||
|
rocket::ignite()
|
||||||
|
.mount("/", routes![cors_preflight, hello])
|
||||||
|
.launch();
|
||||||
|
}
|
|
@ -0,0 +1,16 @@
|
||||||
|
use super::rocket;
|
||||||
|
use rocket::testing::MockRequest;
|
||||||
|
use rocket::http::Header;
|
||||||
|
use rocket::http::Method::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn user() {
|
||||||
|
let rocket = rocket::ignite().mount("/", routes![super::hello]);
|
||||||
|
let mut req = MockRequest::new(Get, "/hello");
|
||||||
|
let mut response = req.dispatch_with(&rocket);
|
||||||
|
|
||||||
|
let body_str = response.body().and_then(|body| body.into_string());
|
||||||
|
let values: Vec<_> = response.header_values("Access-Control-Allow-Origin").collect();
|
||||||
|
assert_eq!(values, vec!["*"]);
|
||||||
|
assert_eq!(body_str, Some("Hello there!".to_string()));
|
||||||
|
}
|
Loading…
Reference in New Issue