Bust cache on 'Request::{add,replace}_header()'.

Also changes 'Header::name()' to return '&UncasedStr'.

Resolves #518.
This commit is contained in:
Sergio Benitez 2020-11-05 20:43:31 -08:00
parent c6298b9e11
commit fa77435187
3 changed files with 71 additions and 28 deletions

View File

@ -54,8 +54,7 @@ impl<'h> Header<'h> {
} }
} }
/// Returns the name of this header with casing preserved. To do a /// Returns the name of this header.
/// case-insensitive equality check, use `.name` directly.
/// ///
/// # Example /// # Example
/// ///
@ -67,23 +66,23 @@ impl<'h> Header<'h> {
/// ///
/// let value = format!("{} value", "custom"); /// let value = format!("{} value", "custom");
/// let header = Header::new("X-Custom-Header", value); /// let header = Header::new("X-Custom-Header", value);
/// assert_eq!(header.name(), "X-Custom-Header"); /// assert_eq!(header.name().as_str(), "X-Custom-Header");
/// assert!(header.name() != "X-CUSTOM-HEADER"); /// assert_ne!(header.name().as_str(), "X-CUSTOM-HEADER");
/// ``` /// ```
/// ///
/// A case-insensitive equality check via `.name`: /// A case-insensitive equality check:
/// ///
/// ```rust /// ```rust
/// # extern crate rocket; /// # extern crate rocket;
/// use rocket::http::Header; /// use rocket::http::Header;
/// ///
/// let header = Header::new("X-Custom-Header", "custom value"); /// let header = Header::new("X-Custom-Header", "custom value");
/// assert_eq!(header.name, "X-Custom-Header"); /// assert_eq!(header.name(), "X-Custom-Header");
/// assert_eq!(header.name, "X-CUSTOM-HEADER"); /// assert_eq!(header.name(), "X-CUSTOM-HEADER");
/// ``` /// ```
#[inline(always)] #[inline(always)]
pub fn name(&self) -> &str { pub fn name(&self) -> &UncasedStr {
self.name.as_str() &self.name
} }
/// Returns the value of this header. /// Returns the value of this header.
@ -553,7 +552,7 @@ impl<'h> HeaderMap<'h> {
/// ///
/// // Actually iterate through them. /// // Actually iterate through them.
/// for header in map.iter() { /// for header in map.iter() {
/// match header.name() { /// match header.name().as_str() {
/// "X-Custom" => assert_eq!(header.value(), "value_1"), /// "X-Custom" => assert_eq!(header.value(), "value_1"),
/// "X-Other" => assert_eq!(header.value(), "other"), /// "X-Other" => assert_eq!(header.value(), "other"),
/// "X-Third" => assert_eq!(header.value(), "third"), /// "X-Third" => assert_eq!(header.value(), "third"),
@ -598,7 +597,7 @@ impl<'h> HeaderMap<'h> {
/// ///
/// // Actually iterate through them. /// // Actually iterate through them.
/// for header in map.into_iter() { /// for header in map.into_iter() {
/// match header.name() { /// match header.name().as_str() {
/// "X-Custom" => assert_eq!(header.value(), "value_1"), /// "X-Custom" => assert_eq!(header.value(), "value_1"),
/// "X-Other" => assert_eq!(header.value(), "other"), /// "X-Other" => assert_eq!(header.value(), "other"),
/// "X-Third" => assert_eq!(header.value(), "third"), /// "X-Third" => assert_eq!(header.value(), "third"),

View File

@ -14,7 +14,7 @@ use crate::request::{FromFormValue, FormItems, FormItem};
use crate::{Rocket, Config, Shutdown, Route}; use crate::{Rocket, Config, Shutdown, Route};
use crate::http::{hyper, uri::{Origin, Segments}}; use crate::http::{hyper, uri::{Origin, Segments}};
use crate::http::{Method, Header, HeaderMap}; use crate::http::{Method, Header, HeaderMap, uncased::UncasedStr};
use crate::http::{RawStr, ContentType, Accept, MediaType, CookieJar, Cookie}; use crate::http::{RawStr, ContentType, Accept, MediaType, CookieJar, Cookie};
use crate::http::private::{Indexed, SmallVec}; use crate::http::private::{Indexed, SmallVec};
use crate::data::Limits; use crate::data::Limits;
@ -350,7 +350,9 @@ impl<'r> Request<'r> {
/// ``` /// ```
#[inline(always)] #[inline(always)]
pub fn add_header<'h: 'r, H: Into<Header<'h>>>(&mut self, header: H) { pub fn add_header<'h: 'r, H: Into<Header<'h>>>(&mut self, header: H) {
self.headers.add(header.into()); let header = header.into();
self.bust_header_cache(header.name(), false);
self.headers.add(header);
} }
/// Replaces the value of the header with name `header.name` with /// Replaces the value of the header with name `header.name` with
@ -369,20 +371,22 @@ impl<'r> Request<'r> {
/// ///
/// request.add_header(ContentType::Any); /// request.add_header(ContentType::Any);
/// assert_eq!(request.headers().get_one("Content-Type"), Some("*/*")); /// assert_eq!(request.headers().get_one("Content-Type"), Some("*/*"));
/// assert_eq!(request.content_type(), Some(&ContentType::Any));
/// ///
/// request.replace_header(ContentType::PNG); /// request.replace_header(ContentType::PNG);
/// assert_eq!(request.headers().get_one("Content-Type"), Some("image/png")); /// assert_eq!(request.headers().get_one("Content-Type"), Some("image/png"));
/// assert_eq!(request.content_type(), Some(&ContentType::PNG));
/// # }); /// # });
/// ``` /// ```
#[inline(always)] #[inline(always)]
pub fn replace_header<'h: 'r, H: Into<Header<'h>>>(&mut self, header: H) { pub fn replace_header<'h: 'r, H: Into<Header<'h>>>(&mut self, header: H) {
self.headers.replace(header.into()); let header = header.into();
self.bust_header_cache(header.name(), true);
self.headers.replace(header);
} }
/// Returns the Content-Type header of `self`. If the header is not present, /// Returns the Content-Type header of `self`. If the header is not present,
/// returns `None`. The Content-Type header is cached after the first call /// returns `None`.
/// to this function. As a result, subsequent calls will always return the
/// same value.
/// ///
/// # Example /// # Example
/// ///
@ -394,10 +398,6 @@ impl<'r> Request<'r> {
/// # Request::example(Method::Get, "/uri", |mut request| { /// # Request::example(Method::Get, "/uri", |mut request| {
/// request.add_header(ContentType::JSON); /// request.add_header(ContentType::JSON);
/// assert_eq!(request.content_type(), Some(&ContentType::JSON)); /// assert_eq!(request.content_type(), Some(&ContentType::JSON));
///
/// // The header is cached; it cannot be replaced after first access.
/// request.replace_header(ContentType::HTML);
/// assert_eq!(request.content_type(), Some(&ContentType::JSON));
/// # }); /// # });
/// ``` /// ```
#[inline(always)] #[inline(always)]
@ -408,9 +408,7 @@ impl<'r> Request<'r> {
} }
/// Returns the Accept header of `self`. If the header is not present, /// Returns the Accept header of `self`. If the header is not present,
/// returns `None`. The Accept header is cached after the first call to this /// returns `None`.
/// function. As a result, subsequent calls will always return the same
/// value.
/// ///
/// # Example /// # Example
/// ///
@ -422,10 +420,6 @@ impl<'r> Request<'r> {
/// # Request::example(Method::Get, "/uri", |mut request| { /// # Request::example(Method::Get, "/uri", |mut request| {
/// request.add_header(Accept::JSON); /// request.add_header(Accept::JSON);
/// assert_eq!(request.accept(), Some(&Accept::JSON)); /// assert_eq!(request.accept(), Some(&Accept::JSON));
///
/// // The header is cached; it cannot be replaced after first access.
/// request.replace_header(Accept::HTML);
/// assert_eq!(request.accept(), Some(&Accept::JSON));
/// # }); /// # });
/// ``` /// ```
#[inline(always)] #[inline(always)]
@ -747,6 +741,19 @@ impl<'r> Request<'r> {
// They _are not_ part of the stable API. Please, don't use these. // They _are not_ part of the stable API. Please, don't use these.
#[doc(hidden)] #[doc(hidden)]
impl<'r> Request<'r> { impl<'r> Request<'r> {
/// Resets the cached value (if any) for the header with name `name`.
fn bust_header_cache(&mut self, name: &UncasedStr, replace: bool) {
if name == "Content-Type" {
if self.content_type().is_none() || replace {
self.state.content_type = Storage::new();
}
} else if name == "Accept" {
if self.accept().is_none() || replace {
self.state.accept = Storage::new();
}
}
}
// Only used by doc-tests! Needs to be `pub` because doc-test are external. // Only used by doc-tests! Needs to be `pub` because doc-test are external.
pub fn example<F: Fn(&mut Request<'_>)>(method: Method, uri: &str, f: F) { pub fn example<F: Fn(&mut Request<'_>)>(method: Method, uri: &str, f: F) {
let rocket = Rocket::custom(Config::default()); let rocket = Rocket::custom(Config::default());

View File

@ -0,0 +1,37 @@
use rocket::{fairing::AdHoc, http::ContentType, local::blocking::Client};
#[rocket::post("/", data = "<_data>", format = "json")]
fn index(_data: rocket::Data) -> &'static str { "json" }
#[rocket::post("/", data = "<_data>", rank = 2)]
fn other_index(_data: rocket::Data) -> &'static str { "other" }
fn rocket() -> rocket::Rocket {
rocket::ignite()
.mount("/", rocket::routes![index, other_index])
.attach(AdHoc::on_request("Change CT", |req, _| Box::pin(async move {
let need_ct = req.content_type().is_none();
if req.uri().path().starts_with("/add") {
req.set_uri(rocket::uri!(index));
if need_ct { req.add_header(ContentType::JSON); }
} else if need_ct {
req.replace_header(ContentType::JSON);
}
})))
}
#[test]
fn check_fairing_changes_content_type() {
let client = Client::untracked(rocket()).unwrap();
let response = client.post("/").header(ContentType::PNG).dispatch();
assert_eq!(response.into_string().unwrap(), "other");
let response = client.post("/").dispatch();
assert_eq!(response.into_string().unwrap(), "json");
let response = client.post("/add").dispatch();
assert_eq!(response.into_string().unwrap(), "json");
let response = client.post("/add").header(ContentType::HTML).dispatch();
assert_eq!(response.into_string().unwrap(), "other");
}