diff --git a/core/http/src/header/header.rs b/core/http/src/header/header.rs index 0571755c..e51be209 100644 --- a/core/http/src/header/header.rs +++ b/core/http/src/header/header.rs @@ -54,6 +54,62 @@ impl<'h> Header<'h> { } } + /// Returns `true` if `name` is a valid header name. + /// + /// This implements a simple (i.e, correct but not particularly performant) + /// header "field-name" checker as defined in RFC 7230. + /// + /// ```text + /// header-field = field-name ":" OWS field-value OWS + /// field-name = token + /// token = 1*tchar + /// tchar = "!" / "#" / "$" / "%" / "&" / "'" / "*" + /// / "+" / "-" / "." / "^" / "_" / "`" / "|" / "~" + /// / DIGIT / ALPHA + /// ; any VCHAR, except delimiters + /// ``` + /// + /// # Example + /// + /// ```rust + /// # extern crate rocket; + /// use rocket::http::Header; + /// + /// assert!(!Header::is_valid_name("")); + /// assert!(!Header::is_valid_name("some header")); + /// assert!(!Header::is_valid_name("some()")); + /// assert!(!Header::is_valid_name("[SomeHeader]")); + /// assert!(!Header::is_valid_name("<")); + /// assert!(!Header::is_valid_name("")); + /// assert!(!Header::is_valid_name("header,here")); + /// + /// assert!(Header::is_valid_name("Some#Header")); + /// assert!(Header::is_valid_name("Some-Header")); + /// assert!(Header::is_valid_name("This-Is_A~Header")); + /// ``` + #[doc(hidden)] + pub const fn is_valid_name(name: &str) -> bool { + const fn is_tchar(b: &u8) -> bool { + b.is_ascii_alphanumeric() || match *b { + b'!' | b'#' | b'$' | b'%' | b'&' | b'\'' | b'*' | b'+' | b'-' | + b'.' | b'^' | b'_' | b'`' | b'|' | b'~' => true, + _ => false + } + } + + let mut i = 0; + let bytes = name.as_bytes(); + while i < bytes.len() { + if !is_tchar(&bytes[i]) { + return false + } + + i += 1; + } + + i > 0 + } + /// Returns `true` if `val` is a valid header value. /// /// If `allow_empty` is `true`, this function returns `true` for empty diff --git a/core/lib/src/config/config.rs b/core/lib/src/config/config.rs index b7a03df0..eaca7f16 100644 --- a/core/lib/src/config/config.rs +++ b/core/lib/src/config/config.rs @@ -8,6 +8,7 @@ use yansi::Paint; use crate::config::{LogLevel, Shutdown, Ident}; use crate::request::{self, Request, FromRequest}; +use crate::http::uncased::Uncased; use crate::data::Limits; #[cfg(feature = "tls")] @@ -78,6 +79,18 @@ pub struct Config { /// How, if at all, to identify the server via the `Server` header. /// **(default: `"Rocket"`)** pub ident: Ident, + /// The name of a header, whose value is typically set by an intermediary + /// server or proxy, which contains the real IP address of the connecting + /// client. Used internally and by [`Request::client_ip()`] and + /// [`Request::real_ip()`]. + /// + /// To disable using any header for this purpose, set this value to `false`. + /// Deserialization semantics are identical to those of [`Ident`] except + /// that the value must syntactically be a valid HTTP header name. + /// + /// **(default: `"X-Real-IP"`)** + #[serde(deserialize_with = "crate::config::ip_header::deserialize")] + pub ip_header: Option>, /// Streaming read size limits. **(default: [`Limits::default()`])** pub limits: Limits, /// Directory to store temporary files in. **(default: @@ -174,6 +187,7 @@ impl Config { workers: num_cpus::get(), max_blocking: 512, ident: Ident::default(), + ip_header: Some(Uncased::from_borrowed("X-Real-IP")), limits: Limits::default(), temp_dir: std::env::temp_dir().into(), keep_alive: 5, @@ -363,6 +377,12 @@ impl Config { launch_meta_!("workers: {}", bold(self.workers)); launch_meta_!("max blocking threads: {}", bold(self.max_blocking)); launch_meta_!("ident: {}", bold(&self.ident)); + + match self.ip_header { + Some(ref name) => launch_meta_!("IP header: {}", bold(name)), + None => launch_meta_!("IP header: {}", bold("disabled")) + } + launch_meta_!("limits: {}", bold(&self.limits)); launch_meta_!("temp dir: {}", bold(&self.temp_dir.relative().display())); launch_meta_!("http/2: {}", bold(cfg!(feature = "http2"))); diff --git a/core/lib/src/config/ip_header.rs b/core/lib/src/config/ip_header.rs new file mode 100644 index 00000000..6a251767 --- /dev/null +++ b/core/lib/src/config/ip_header.rs @@ -0,0 +1,56 @@ +use std::fmt; + +use serde::de; + +use crate::http::Header; +use crate::http::uncased::Uncased; + +pub(crate) fn deserialize<'de, D>(de: D) -> Result>, D::Error> + where D: de::Deserializer<'de> +{ + struct Visitor; + + impl<'de> de::Visitor<'de> for Visitor { + type Value = Option>; + + fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result { + formatter.write_str("a valid header name or `false`") + } + + fn visit_bool(self, v: bool) -> Result { + if !v { + return Ok(None); + } + + Err(E::invalid_value(de::Unexpected::Bool(v), &self)) + } + + fn visit_some(self, de: D) -> Result + where D: de::Deserializer<'de> + { + de.deserialize_string(self) + } + + fn visit_none(self) -> Result { + Ok(None) + } + + fn visit_unit(self) -> Result { + Ok(None) + } + + fn visit_str(self, v: &str) -> Result { + self.visit_string(v.into()) + } + + fn visit_string(self, v: String) -> Result { + if Header::is_valid_name(&v) { + Ok(Some(Uncased::from_owned(v))) + } else { + Err(E::invalid_value(de::Unexpected::Str(&v), &self)) + } + } + } + + de.deserialize_string(Visitor) +} diff --git a/core/lib/src/config/mod.rs b/core/lib/src/config/mod.rs index 8b577e77..00f0e926 100644 --- a/core/lib/src/config/mod.rs +++ b/core/lib/src/config/mod.rs @@ -114,6 +114,7 @@ mod ident; mod config; mod shutdown; +mod ip_header; #[cfg(feature = "tls")] mod tls; diff --git a/core/lib/src/request/request.rs b/core/lib/src/request/request.rs index c51b8c41..26ec3b41 100644 --- a/core/lib/src/request/request.rs +++ b/core/lib/src/request/request.rs @@ -307,9 +307,10 @@ impl<'r> Request<'r> { /// /// Because it is common for proxies to forward connections for clients, the /// remote address may contain information about the proxy instead of the - /// client. For this reason, proxies typically set the "X-Real-IP" header - /// with the client's true IP. To extract this IP from the request, use the - /// [`real_ip()`] or [`client_ip()`] methods. + /// client. For this reason, proxies typically set a "X-Real-IP" header + /// [`ip_header`](rocket::Config::ip_header) with the client's true IP. To + /// extract this IP from the request, use the [`real_ip()`] or + /// [`client_ip()`] methods. /// /// [`real_ip()`]: #method.real_ip /// [`client_ip()`]: #method.client_ip @@ -356,8 +357,9 @@ impl<'r> Request<'r> { self.connection.remote = Some(address); } - /// Returns the IP address in the "X-Real-IP" header of the request if such - /// a header exists and contains a valid IP address. + /// Returns the IP address of the configured + /// [`ip_header`](rocket::Config::ip_header) of the request if such a header + /// is configured, exists and contains a valid IP address. /// /// # Example /// @@ -369,25 +371,40 @@ impl<'r> Request<'r> { /// # let req = c.get("/"); /// assert_eq!(req.real_ip(), None); /// + /// // `ip_header` defaults to `X-Real-IP`. /// let req = req.header(Header::new("X-Real-IP", "127.0.0.1")); /// assert_eq!(req.real_ip(), Some(Ipv4Addr::LOCALHOST.into())); /// ``` pub fn real_ip(&self) -> Option { + let ip_header = self.rocket().config.ip_header.as_ref()?.as_str(); self.headers() - .get_one("X-Real-IP") + .get_one(ip_header) .and_then(|ip| { ip.parse() - .map_err(|_| warn_!("'X-Real-IP' header is malformed: {}", ip)) + .map_err(|_| warn_!("'{}' header is malformed: {}", ip_header, ip)) .ok() }) } /// Attempts to return the client's IP address by first inspecting the - /// "X-Real-IP" header and then using the remote connection's IP address. + /// [`ip_header`](rocket::Config::ip_header) and then using the remote + /// connection's IP address. Note that the built-in `IpAddr` request guard + /// can be used to retrieve the same information in a handler: /// - /// If the "X-Real-IP" header exists and contains a valid IP address, that - /// address is returned. Otherwise, if the address of the remote connection - /// is known, that address is returned. Otherwise, `None` is returned. + /// ```rust + /// # use rocket::get; + /// use std::net::IpAddr; + /// + /// #[get("/")] + /// fn get_ip(client_ip: IpAddr) { /* ... */ } + /// + /// #[get("/")] + /// fn try_get_ip(client_ip: Option) { /* ... */ } + /// ```` + /// + /// If the `ip_header` exists and contains a valid IP address, that address + /// is returned. Otherwise, if the address of the remote connection is + /// known, that address is returned. Otherwise, `None` is returned. /// /// # Example /// @@ -405,7 +422,7 @@ impl<'r> Request<'r> { /// request.set_remote("127.0.0.1:8000".parse().unwrap()); /// assert_eq!(request.client_ip(), Some("127.0.0.1".parse().unwrap())); /// - /// // now with an X-Real-IP header + /// // now with an X-Real-IP header, the default value for `ip_header`. /// request.add_header(Header::new("X-Real-IP", "8.8.8.8")); /// assert_eq!(request.client_ip(), Some("8.8.8.8".parse().unwrap())); /// ``` diff --git a/core/lib/tests/config-real-ip-header.rs b/core/lib/tests/config-real-ip-header.rs new file mode 100644 index 00000000..3711f43d --- /dev/null +++ b/core/lib/tests/config-real-ip-header.rs @@ -0,0 +1,101 @@ +#[macro_use] extern crate rocket; + +#[get("/")] +fn inspect_ip(ip: Option) -> String { + ip.map(|ip| ip.to_string()).unwrap_or("".into()) +} + +mod tests { + use rocket::{Rocket, Build, Route}; + use rocket::local::blocking::Client; + use rocket::figment::Figment; + use rocket::http::Header; + + fn routes() -> Vec { + routes![super::inspect_ip] + } + + fn rocket_with_custom_ip_header(header: Option<&'static str>) -> Rocket { + let mut config = rocket::Config::debug_default(); + config.ip_header = header.map(|h| h.into()); + rocket::custom(config).mount("/", routes()) + } + + #[test] + fn check_real_ip_header_works() { + let client = Client::debug(rocket_with_custom_ip_header(Some("IP"))).unwrap(); + let response = client.get("/") + .header(Header::new("X-Real-IP", "1.2.3.4")) + .header(Header::new("IP", "8.8.8.8")) + .dispatch(); + + assert_eq!(response.into_string(), Some("8.8.8.8".into())); + + let response = client.get("/") + .header(Header::new("IP", "1.1.1.1")) + .dispatch(); + + assert_eq!(response.into_string(), Some("1.1.1.1".into())); + + let response = client.get("/").dispatch(); + assert_eq!(response.into_string(), Some("".into())); + } + + #[test] + fn check_real_ip_header_works_again() { + let client = Client::debug(rocket_with_custom_ip_header(Some("x-forward-ip"))).unwrap(); + let response = client.get("/") + .header(Header::new("X-Forward-IP", "1.2.3.4")) + .dispatch(); + + assert_eq!(response.into_string(), Some("1.2.3.4".into())); + + let config = Figment::from(rocket::Config::debug_default()) + .merge(("ip_header", "x-forward-ip")); + + let client = Client::debug(rocket::custom(config).mount("/", routes())).unwrap(); + let response = client.get("/") + .header(Header::new("X-Forward-IP", "1.2.3.4")) + .dispatch(); + + assert_eq!(response.into_string(), Some("1.2.3.4".into())); + } + + #[test] + fn check_default_real_ip_header_works() { + let client = Client::debug_with(routes()).unwrap(); + let response = client.get("/") + .header(Header::new("X-Real-IP", "1.2.3.4")) + .dispatch(); + + assert_eq!(response.into_string(), Some("1.2.3.4".into())); + } + + #[test] + fn check_no_ip_header_works() { + let client = Client::debug(rocket_with_custom_ip_header(None)).unwrap(); + let response = client.get("/") + .header(Header::new("X-Real-IP", "1.2.3.4")) + .dispatch(); + + assert_eq!(response.into_string(), Some("".into())); + + let config = Figment::from(rocket::Config::debug_default()) + .merge(("ip_header", false)); + + let client = Client::debug(rocket::custom(config).mount("/", routes())).unwrap(); + let response = client.get("/") + .header(Header::new("X-Real-IP", "1.2.3.4")) + .dispatch(); + + assert_eq!(response.into_string(), Some("".into())); + + let config = Figment::from(rocket::Config::debug_default()); + let client = Client::debug(rocket::custom(config).mount("/", routes())).unwrap(); + let response = client.get("/") + .header(Header::new("X-Real-IP", "1.2.3.4")) + .dispatch(); + + assert_eq!(response.into_string(), Some("1.2.3.4".into())); + } +} diff --git a/examples/config/Rocket.toml b/examples/config/Rocket.toml index d6409b0a..e6059180 100644 --- a/examples/config/Rocket.toml +++ b/examples/config/Rocket.toml @@ -11,6 +11,7 @@ msgpack = "2 MiB" key = "a default app-key" extra = false ident = "Rocket" +ip_header = "CF-Connecting-IP" [debug] address = "127.0.0.1" diff --git a/site/guide/9-configuration.md b/site/guide/9-configuration.md index e8fdc6d3..b77de10e 100644 --- a/site/guide/9-configuration.md +++ b/site/guide/9-configuration.md @@ -24,6 +24,7 @@ values: | `workers`* | `usize` | Number of threads to use for executing futures. | cpu core count | | `max_blocking`* | `usize` | Limit on threads to start for blocking tasks. | `512` | | `ident` | `string`, `false` | If and how to identify via the `Server` header. | `"Rocket"` | +| `ip_header` | `string`, `false` | IP header to inspect to get [client's real IP]. | `"X-Real-IP"` | | `keep_alive` | `u32` | Keep-alive timeout seconds; disabled when `0`. | `5` | | `log_level` | [`LogLevel`] | Max level to log. (off/normal/debug/critical) | `normal`/`critical` | | `cli_colors` | `bool` | Whether to use colors and emoji when logging. | `true` | @@ -37,6 +38,8 @@ values: * Note: the `workers`, `max_blocking`, and `shutdown.force` configuration parameters are only read from the [default provider](#default-provider). +[client's real IP]: @api/rocket/request/struct.Request.html#method.real_ip + ### Profiles Configurations can be arbitrarily namespaced by [`Profile`]s. Rocket's @@ -127,6 +130,7 @@ port = 9001 [release] port = 9999 secret_key = "hPRYyVRiMyxpw5sBB1XeCMN1kFsDCqKvBi2QJxBVHQk=" +ip_header = false ``` The following is a `Rocket.toml` file with all configuration options set for @@ -142,6 +146,7 @@ workers = 16 max_blocking = 512 keep_alive = 5 ident = "Rocket" +ip_header = "X-Real-IP" # set to `false` to disable log_level = "normal" temp_dir = "/tmp" cli_colors = true