From fa77435187aaedda73b0caad0ab5bfdb24a8b476 Mon Sep 17 00:00:00 2001 From: Sergio Benitez Date: Thu, 5 Nov 2020 20:43:31 -0800 Subject: [PATCH] Bust cache on 'Request::{add,replace}_header()'. Also changes 'Header::name()' to return '&UncasedStr'. Resolves #518. --- core/http/src/header.rs | 21 ++++++----- core/lib/src/request/request.rs | 41 +++++++++++++--------- core/lib/tests/replace-content-type-518.rs | 37 +++++++++++++++++++ 3 files changed, 71 insertions(+), 28 deletions(-) create mode 100644 core/lib/tests/replace-content-type-518.rs diff --git a/core/http/src/header.rs b/core/http/src/header.rs index dbb5b23f..b0c5a162 100644 --- a/core/http/src/header.rs +++ b/core/http/src/header.rs @@ -54,8 +54,7 @@ impl<'h> Header<'h> { } } - /// Returns the name of this header with casing preserved. To do a - /// case-insensitive equality check, use `.name` directly. + /// Returns the name of this header. /// /// # Example /// @@ -67,23 +66,23 @@ impl<'h> Header<'h> { /// /// let value = format!("{} value", "custom"); /// let header = Header::new("X-Custom-Header", value); - /// assert_eq!(header.name(), "X-Custom-Header"); - /// assert!(header.name() != "X-CUSTOM-HEADER"); + /// assert_eq!(header.name().as_str(), "X-Custom-Header"); + /// assert_ne!(header.name().as_str(), "X-CUSTOM-HEADER"); /// ``` /// - /// A case-insensitive equality check via `.name`: + /// A case-insensitive equality check: /// /// ```rust /// # extern crate rocket; /// use rocket::http::Header; /// /// let header = Header::new("X-Custom-Header", "custom value"); - /// assert_eq!(header.name, "X-Custom-Header"); - /// assert_eq!(header.name, "X-CUSTOM-HEADER"); + /// assert_eq!(header.name(), "X-Custom-Header"); + /// assert_eq!(header.name(), "X-CUSTOM-HEADER"); /// ``` #[inline(always)] - pub fn name(&self) -> &str { - self.name.as_str() + pub fn name(&self) -> &UncasedStr { + &self.name } /// Returns the value of this header. @@ -553,7 +552,7 @@ impl<'h> HeaderMap<'h> { /// /// // Actually iterate through them. /// for header in map.iter() { - /// match header.name() { + /// match header.name().as_str() { /// "X-Custom" => assert_eq!(header.value(), "value_1"), /// "X-Other" => assert_eq!(header.value(), "other"), /// "X-Third" => assert_eq!(header.value(), "third"), @@ -598,7 +597,7 @@ impl<'h> HeaderMap<'h> { /// /// // Actually iterate through them. /// for header in map.into_iter() { - /// match header.name() { + /// match header.name().as_str() { /// "X-Custom" => assert_eq!(header.value(), "value_1"), /// "X-Other" => assert_eq!(header.value(), "other"), /// "X-Third" => assert_eq!(header.value(), "third"), diff --git a/core/lib/src/request/request.rs b/core/lib/src/request/request.rs index 6b1e4da1..1327620b 100644 --- a/core/lib/src/request/request.rs +++ b/core/lib/src/request/request.rs @@ -14,7 +14,7 @@ use crate::request::{FromFormValue, FormItems, FormItem}; use crate::{Rocket, Config, Shutdown, Route}; use crate::http::{hyper, uri::{Origin, Segments}}; -use crate::http::{Method, Header, HeaderMap}; +use crate::http::{Method, Header, HeaderMap, uncased::UncasedStr}; use crate::http::{RawStr, ContentType, Accept, MediaType, CookieJar, Cookie}; use crate::http::private::{Indexed, SmallVec}; use crate::data::Limits; @@ -350,7 +350,9 @@ impl<'r> Request<'r> { /// ``` #[inline(always)] pub fn add_header<'h: 'r, H: Into>>(&mut self, header: H) { - self.headers.add(header.into()); + let header = header.into(); + self.bust_header_cache(header.name(), false); + self.headers.add(header); } /// Replaces the value of the header with name `header.name` with @@ -369,20 +371,22 @@ impl<'r> Request<'r> { /// /// request.add_header(ContentType::Any); /// assert_eq!(request.headers().get_one("Content-Type"), Some("*/*")); + /// assert_eq!(request.content_type(), Some(&ContentType::Any)); /// /// request.replace_header(ContentType::PNG); /// assert_eq!(request.headers().get_one("Content-Type"), Some("image/png")); + /// assert_eq!(request.content_type(), Some(&ContentType::PNG)); /// # }); /// ``` #[inline(always)] pub fn replace_header<'h: 'r, H: Into>>(&mut self, header: H) { - self.headers.replace(header.into()); + let header = header.into(); + self.bust_header_cache(header.name(), true); + self.headers.replace(header); } /// Returns the Content-Type header of `self`. If the header is not present, - /// returns `None`. The Content-Type header is cached after the first call - /// to this function. As a result, subsequent calls will always return the - /// same value. + /// returns `None`. /// /// # Example /// @@ -394,10 +398,6 @@ impl<'r> Request<'r> { /// # Request::example(Method::Get, "/uri", |mut request| { /// request.add_header(ContentType::JSON); /// assert_eq!(request.content_type(), Some(&ContentType::JSON)); - /// - /// // The header is cached; it cannot be replaced after first access. - /// request.replace_header(ContentType::HTML); - /// assert_eq!(request.content_type(), Some(&ContentType::JSON)); /// # }); /// ``` #[inline(always)] @@ -408,9 +408,7 @@ impl<'r> Request<'r> { } /// Returns the Accept header of `self`. If the header is not present, - /// returns `None`. The Accept header is cached after the first call to this - /// function. As a result, subsequent calls will always return the same - /// value. + /// returns `None`. /// /// # Example /// @@ -422,10 +420,6 @@ impl<'r> Request<'r> { /// # Request::example(Method::Get, "/uri", |mut request| { /// request.add_header(Accept::JSON); /// assert_eq!(request.accept(), Some(&Accept::JSON)); - /// - /// // The header is cached; it cannot be replaced after first access. - /// request.replace_header(Accept::HTML); - /// assert_eq!(request.accept(), Some(&Accept::JSON)); /// # }); /// ``` #[inline(always)] @@ -747,6 +741,19 @@ impl<'r> Request<'r> { // They _are not_ part of the stable API. Please, don't use these. #[doc(hidden)] impl<'r> Request<'r> { + /// Resets the cached value (if any) for the header with name `name`. + fn bust_header_cache(&mut self, name: &UncasedStr, replace: bool) { + if name == "Content-Type" { + if self.content_type().is_none() || replace { + self.state.content_type = Storage::new(); + } + } else if name == "Accept" { + if self.accept().is_none() || replace { + self.state.accept = Storage::new(); + } + } + } + // Only used by doc-tests! Needs to be `pub` because doc-test are external. pub fn example)>(method: Method, uri: &str, f: F) { let rocket = Rocket::custom(Config::default()); diff --git a/core/lib/tests/replace-content-type-518.rs b/core/lib/tests/replace-content-type-518.rs new file mode 100644 index 00000000..3ed4b9a7 --- /dev/null +++ b/core/lib/tests/replace-content-type-518.rs @@ -0,0 +1,37 @@ +use rocket::{fairing::AdHoc, http::ContentType, local::blocking::Client}; + +#[rocket::post("/", data = "<_data>", format = "json")] +fn index(_data: rocket::Data) -> &'static str { "json" } + +#[rocket::post("/", data = "<_data>", rank = 2)] +fn other_index(_data: rocket::Data) -> &'static str { "other" } + +fn rocket() -> rocket::Rocket { + rocket::ignite() + .mount("/", rocket::routes![index, other_index]) + .attach(AdHoc::on_request("Change CT", |req, _| Box::pin(async move { + let need_ct = req.content_type().is_none(); + if req.uri().path().starts_with("/add") { + req.set_uri(rocket::uri!(index)); + if need_ct { req.add_header(ContentType::JSON); } + } else if need_ct { + req.replace_header(ContentType::JSON); + } + }))) +} + +#[test] +fn check_fairing_changes_content_type() { + let client = Client::untracked(rocket()).unwrap(); + let response = client.post("/").header(ContentType::PNG).dispatch(); + assert_eq!(response.into_string().unwrap(), "other"); + + let response = client.post("/").dispatch(); + assert_eq!(response.into_string().unwrap(), "json"); + + let response = client.post("/add").dispatch(); + assert_eq!(response.into_string().unwrap(), "json"); + + let response = client.post("/add").header(ContentType::HTML).dispatch(); + assert_eq!(response.into_string().unwrap(), "other"); +}