From 0f7c075bfef2ac0d75767906594a3df2969a8de9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jesper=20Steen=20M=C3=B8ller?= Date: Thu, 26 Sep 2024 11:42:05 +0200 Subject: [PATCH] Allow accepting headers by implementing FromRequest --- core/http/src/header/header.rs | 8 +- core/lib/Cargo.toml | 1 + core/lib/src/request/from_request_headers.rs | 103 +++++++++++++++++++ core/lib/src/request/mod.rs | 1 + core/lib/tests/typed-headers.rs | 56 +++++++++- 5 files changed, 161 insertions(+), 8 deletions(-) create mode 100644 core/lib/src/request/from_request_headers.rs diff --git a/core/http/src/header/header.rs b/core/http/src/header/header.rs index af03a9c6..888303cb 100644 --- a/core/http/src/header/header.rs +++ b/core/http/src/header/header.rs @@ -839,7 +839,7 @@ impl Extend for HeaderValueDestination { } } -macro_rules! import_typed_headers { +macro_rules! from_typed_header { ($($name:ident),*) => ($( pub use headers::$name; @@ -854,7 +854,7 @@ macro_rules! import_typed_headers { )*) } -macro_rules! import_generic_typed_headers { +macro_rules! generic_from_typed_header { ($($name:ident<$bound:ident>),*) => ($( pub use headers::$name; @@ -879,7 +879,7 @@ macro_rules! import_generic_typed_headers { // * Location, // Location header, defined in RFC7231 // * SetCookie, // Set-Cookie header, defined RFC6265 -import_typed_headers! { +from_typed_header! { AcceptRanges, // Accept-Ranges header, defined in RFC7233 AccessControlAllowCredentials, // Access-Control-Allow-Credentials header, part of CORS AccessControlAllowHeaders, // Access-Control-Allow-Headers header, part of CORS @@ -926,7 +926,7 @@ import_typed_headers! { Vary // Vary header, defined in RFC7231 } -import_generic_typed_headers! { +generic_from_typed_header! { Authorization, // Authorization header, defined in RFC7235 ProxyAuthorization // Proxy-Authorization header, defined in RFC7235 } diff --git a/core/lib/Cargo.toml b/core/lib/Cargo.toml index 0ce5f3a8..e0270774 100644 --- a/core/lib/Cargo.toml +++ b/core/lib/Cargo.toml @@ -49,6 +49,7 @@ http = "1" bytes = "1.4" hyper = { version = "1.1", default-features = false, features = ["http1", "server"] } hyper-util = { version = "0.1.3", default-features = false, features = ["http1", "server", "tokio"] } +headers = "0.4.0" # Non-optional, core dependencies from here on out. yansi = { version = "1.0.1", features = ["detect-tty"] } diff --git a/core/lib/src/request/from_request_headers.rs b/core/lib/src/request/from_request_headers.rs new file mode 100644 index 00000000..1e9c7c62 --- /dev/null +++ b/core/lib/src/request/from_request_headers.rs @@ -0,0 +1,103 @@ +use crate::{outcome::IntoOutcome, Request}; +use super::FromRequest; + +use headers::{Header as HHeader, HeaderValue as HHeaderValue}; +use rocket_http::Status; + +macro_rules! typed_headers_from_request { + ($($name:ident),*) => ($( + pub use crate::http::$name; + + #[rocket::async_trait] + impl<'r> FromRequest<'r> for $name { + type Error = headers::Error; + async fn from_request(req: &'r Request<'_>) -> crate::request::Outcome { + req.headers().get($name::name().as_str()).next().or_forward(Status::NotFound) + .and_then(|h| HHeaderValue::from_str(h).or_error(Status::BadRequest)) + .map_error(|(s, _)| (s, headers::Error::invalid())) + .and_then(|h| $name::decode(&mut std::iter::once(&h)).or_forward(Status::BadRequest)) + } + } + )*) +} + +macro_rules! generic_typed_headers_from_request { +($($name:ident<$bound:ident>),*) => ($( + pub use crate::http::$name; + + #[rocket::async_trait] + impl<'r, T1: 'static + $bound> FromRequest<'r> for $name { + type Error = headers::Error; + async fn from_request(req: &'r Request<'_>) -> crate::request::Outcome { + req.headers().get($name::::name().as_str()).next().or_forward(Status::NotFound) + .and_then(|h| HHeaderValue::from_str(h).or_error(Status::BadRequest)) + .map_error(|(s, _)| (s, headers::Error::invalid())) + .and_then(|h| $name::decode(&mut std::iter::once(&h)).or_forward(Status::BadRequest)) + } + } +)*) +} + +// The following headers from 'headers' 0.4 are not imported, since they are +// provided by other Rocket features. + +// * ContentType, // Content-Type header, defined in RFC7231 +// * Cookie, // Cookie header, defined in RFC6265 +// * Host, // The Host header. +// * Location, // Location header, defined in RFC7231 +// * SetCookie, // Set-Cookie header, defined RFC6265 + +typed_headers_from_request! { + AcceptRanges, // Accept-Ranges header, defined in RFC7233 + AccessControlAllowCredentials, // Access-Control-Allow-Credentials header, part of CORS + AccessControlAllowHeaders, // Access-Control-Allow-Headers header, part of CORS + AccessControlAllowMethods, // Access-Control-Allow-Methods header, part of CORS + AccessControlAllowOrigin, // The Access-Control-Allow-Origin response header, part of CORS + AccessControlExposeHeaders, // Access-Control-Expose-Headers header, part of CORS + AccessControlMaxAge, // Access-Control-Max-Age header, part of CORS + AccessControlRequestHeaders, // Access-Control-Request-Headers header, part of CORS + AccessControlRequestMethod, // Access-Control-Request-Method header, part of CORS + Age, // Age header, defined in RFC7234 + Allow, // Allow header, defined in RFC7231 + CacheControl, // Cache-Control header, defined in RFC7234 with extensions in RFC8246 + Connection, // Connection header, defined in RFC7230 + ContentDisposition, // A Content-Disposition header, (re)defined in RFC6266. + ContentEncoding, // Content-Encoding header, defined in RFC7231 + ContentLength, // Content-Length header, defined in RFC7230 + ContentLocation, // Content-Location header, defined in RFC7231 + ContentRange, // Content-Range, described in RFC7233 + Date, // Date header, defined in RFC7231 + ETag, // ETag header, defined in RFC7232 + Expect, // The Expect header. + Expires, // Expires header, defined in RFC7234 + IfMatch, // If-Match header, defined in RFC7232 + IfModifiedSince, // If-Modified-Since header, defined in RFC7232 + IfNoneMatch, // If-None-Match header, defined in RFC7232 + IfRange, // If-Range header, defined in RFC7233 + IfUnmodifiedSince, // If-Unmodified-Since header, defined in RFC7232 + LastModified, // Last-Modified header, defined in RFC7232 + Origin, // The Origin header. + Pragma, // The Pragma header defined by HTTP/1.0. + Range, // Range header, defined in RFC7233 + Referer, // Referer header, defined in RFC7231 + ReferrerPolicy, // Referrer-Policy header, part of Referrer Policy + RetryAfter, // The Retry-After header. + SecWebsocketAccept, // The Sec-Websocket-Accept header. + SecWebsocketKey, // The Sec-Websocket-Key header. + SecWebsocketVersion, // The Sec-Websocket-Version header. + Server, // Server header, defined in RFC7231 + StrictTransportSecurity, // StrictTransportSecurity header, defined in RFC6797 + Te, // TE header, defined in RFC7230 + TransferEncoding, // Transfer-Encoding header, defined in RFC7230 + Upgrade, // Upgrade header, defined in RFC7230 + UserAgent, // User-Agent header, defined in RFC7231 + Vary // Vary header, defined in RFC7231 +} + +pub use headers::authorization::Credentials; + +generic_typed_headers_from_request! { + Authorization, // Authorization header, defined in RFC7235 + ProxyAuthorization // Proxy-Authorization header, defined in RFC7235 +} + diff --git a/core/lib/src/request/mod.rs b/core/lib/src/request/mod.rs index 48ac79c7..e184a890 100644 --- a/core/lib/src/request/mod.rs +++ b/core/lib/src/request/mod.rs @@ -3,6 +3,7 @@ mod request; mod from_param; mod from_request; +mod from_request_headers; mod atomic_method; #[cfg(test)] diff --git a/core/lib/tests/typed-headers.rs b/core/lib/tests/typed-headers.rs index e304e021..2bd55425 100644 --- a/core/lib/tests/typed-headers.rs +++ b/core/lib/tests/typed-headers.rs @@ -2,7 +2,9 @@ extern crate rocket; use std::time::{Duration, SystemTime}; +use headers::IfModifiedSince; use rocket::http::Expires; +use rocket_http::{Header, Status}; #[derive(Responder)] struct MyResponse { @@ -10,7 +12,7 @@ struct MyResponse { expires: Expires, } -#[get("/")] +#[get("/expires")] fn index() -> MyResponse { let some_future_time = SystemTime::UNIX_EPOCH.checked_add(Duration::from_secs(60 * 60 * 24 * 365 * 100)).unwrap(); @@ -21,10 +23,56 @@ fn index() -> MyResponse { } } +#[get("/data")] +fn get_data_with_opt_header(since: Option) -> String { + if let Some(time) = since { + format!("GET after: {:}", time::OffsetDateTime::from(SystemTime::from(time))) + } else { + format!("Unconditional GET") + } +} + +#[get("/data_since")] +fn get_data_with_header(since: IfModifiedSince) -> String { + format!("GET after: {:}", time::OffsetDateTime::from(SystemTime::from(since))) +} + #[test] -fn typed_header() { - let rocket = rocket::build().mount("/", routes![index]); +fn respond_with_typed_header() { + let rocket = rocket::build().mount( + "/", + routes![index, get_data_with_opt_header, get_data_with_header]); let client = rocket::local::blocking::Client::debug(rocket).unwrap(); - let response = client.get("/").dispatch(); + + let response = client.get("/expires").dispatch(); assert_eq!(response.headers().get_one("Expires").unwrap(), "Sat, 07 Dec 2069 00:00:00 GMT"); } + +#[test] +fn read_typed_header() { + let rocket = rocket::build().mount( + "/", + routes![index, get_data_with_opt_header, get_data_with_header]); + let client = rocket::local::blocking::Client::debug(rocket).unwrap(); + + let response = client.get("/data").dispatch(); + assert_eq!(response.into_string().unwrap(), "Unconditional GET".to_string()); + + let response = client.get("/data") + .header(Header::new("if-modified-since", "Mon, 07 Dec 2020 00:00:00 GMT")).dispatch(); + assert_eq!(response.into_string().unwrap(), + "GET after: 2020-12-07 0:00:00.0 +00:00:00".to_string()); + + let response = client.get("/data_since") + .header(Header::new("if-modified-since", "Tue, 08 Dec 2020 00:00:00 GMT")).dispatch(); + assert_eq!(response.into_string().unwrap(), + "GET after: 2020-12-08 0:00:00.0 +00:00:00".to_string()); + + let response = client.get("/data_since") + .header(Header::new("if-modified-since", "WTF, 07 Dec 2020 00:00:00 GMT")).dispatch(); + assert_eq!(response.status(), Status::BadRequest); + + let response = client.get("/data_since") + .header(Header::new("if-modified-since", "\x0c , 07 Dec 2020 00:00:00 GMT")).dispatch(); + assert_eq!(response.status(), Status::BadRequest); +}