From 6dc21e5380ed61b5691a258af8c33312729598c8 Mon Sep 17 00:00:00 2001 From: Sergio Benitez Date: Tue, 18 Apr 2017 00:25:13 -0700 Subject: [PATCH] Add support for configurable size limits. --- contrib/src/json.rs | 12 ++- contrib/src/msgpack.rs | 12 ++- examples/config/Rocket.toml | 7 +- lib/src/config/builder.rs | 30 +++++- lib/src/config/config.rs | 100 ++++------------- lib/src/config/custom_values.rs | 165 +++++++++++++++++++++++++++++ lib/src/config/mod.rs | 2 + lib/src/request/form/form_items.rs | 1 + lib/src/request/form/mod.rs | 7 +- lib/src/rocket.rs | 1 + lib/tests/limits.rs | 83 +++++++++++++++ 11 files changed, 323 insertions(+), 97 deletions(-) create mode 100644 lib/src/config/custom_values.rs create mode 100644 lib/tests/limits.rs diff --git a/contrib/src/json.rs b/contrib/src/json.rs index f8ab601b..b31ba43b 100644 --- a/contrib/src/json.rs +++ b/contrib/src/json.rs @@ -1,6 +1,7 @@ use std::ops::{Deref, DerefMut}; use std::io::Read; +use rocket::config; use rocket::outcome::Outcome; use rocket::request::Request; use rocket::data::{self, Data, FromData}; @@ -65,9 +66,8 @@ impl JSON { } } -/// Maximum size of JSON is 1MB. -/// TODO: Determine this size from some configuration parameter. -const MAX_SIZE: u64 = 1048576; +/// Default limit for JSON is 1MB. 
+const LIMIT: u64 = 1 << 20; impl FromData for JSON { type Error = SerdeError; @@ -78,7 +78,11 @@ impl FromData for JSON { return Outcome::Forward(data); } - let reader = data.open().take(MAX_SIZE); + let size_limit = config::active() + .and_then(|c| c.limits.get("json")) + .unwrap_or(LIMIT); + + let reader = data.open().take(size_limit); match serde_json::from_reader(reader).map(|val| JSON(val)) { Ok(value) => Outcome::Success(value), Err(e) => { diff --git a/contrib/src/msgpack.rs b/contrib/src/msgpack.rs index f33ede0c..05927c96 100644 --- a/contrib/src/msgpack.rs +++ b/contrib/src/msgpack.rs @@ -3,6 +3,7 @@ extern crate rmp_serde; use std::ops::{Deref, DerefMut}; use std::io::{Cursor, Read}; +use rocket::config; use rocket::outcome::Outcome; use rocket::request::Request; use rocket::data::{self, Data, FromData}; @@ -70,9 +71,8 @@ impl MsgPack { } } -/// Maximum size of MessagePack data is 1MB. -/// TODO: Determine this size from some configuration parameter. -const MAX_SIZE: u64 = 1048576; +/// Default limit for MessagePack is 1MB. +const LIMIT: u64 = 1 << 20; /// Accepted content types are: `application/msgpack`, `application/x-msgpack`, /// `bin/msgpack`, and `bin/x-msgpack`. 
@@ -91,8 +91,12 @@ impl FromData for MsgPack { return Outcome::Forward(data); } + let size_limit = config::active() + .and_then(|c| c.limits.get("msgpack")) + .unwrap_or(LIMIT); + let mut buf = Vec::new(); - if let Err(e) = data.open().take(MAX_SIZE).read_to_end(&mut buf) { + if let Err(e) = data.open().take(size_limit).read_to_end(&mut buf) { let e = MsgPackError::InvalidDataRead(e); error_!("Couldn't read request data: {:?}", e); return Outcome::Failure((Status::BadRequest, e)); diff --git a/examples/config/Rocket.toml b/examples/config/Rocket.toml index 2ae9bd0f..64ebdfc2 100644 --- a/examples/config/Rocket.toml +++ b/examples/config/Rocket.toml @@ -1,6 +1,11 @@ -# Except for the session key, none of these are actually needed; Rocket has sane +# Except for the session key, nothing here is necessary; Rocket has sane # defaults. We show all of them here explicitly for demonstrative purposes. +[global.limits] +forms = 32768 +json = 1048576 # this is an extra used by the json contrib module +msgpack = 1048576 # this is an extra used by the msgpack contrib module + [development] address = "localhost" port = 8000 diff --git a/lib/src/config/builder.rs b/lib/src/config/builder.rs index 251ef98b..5916a9d1 100644 --- a/lib/src/config/builder.rs +++ b/lib/src/config/builder.rs @@ -1,7 +1,7 @@ use std::collections::HashMap; use std::path::{Path, PathBuf}; -use config::{Result, Config, Value, Environment}; +use config::{Result, Config, Value, Environment, Limits}; use config::toml_ext::IntoValue; use logger::LoggingLevel; @@ -21,7 +21,9 @@ pub struct ConfigBuilder { /// The session key. pub session_key: Option, /// TLS configuration (path to certificates file, path to private key file). - pub tls_config: Option<(String, String)>, + pub tls: Option<(String, String)>, + /// Size limits. + pub limits: Limits, /// Any extra parameters that aren't part of Rocket's config. pub extras: HashMap, /// The root directory of this config. 
@@ -65,7 +67,8 @@ impl ConfigBuilder { workers: config.workers, log_level: config.log_level, session_key: None, - tls_config: None, + tls: None, + limits: config.limits, extras: config.extras, root: root_dir, } @@ -165,6 +168,22 @@ impl ConfigBuilder { self } + /// Sets the `limits` in the configuration being built. + /// + /// # Example + /// + /// ```rust + /// use rocket::config::{Config, Environment, Limits}; + /// + /// let mut config = Config::build(Environment::Staging) + /// .limits(Limits::default().add("json", 5 * (1 << 20))) + /// .unwrap(); + /// ``` + pub fn limits(mut self, limits: Limits) -> Self { + self.limits = limits; + self + } + /// Sets the `tls_config` in the configuration being built. /// /// # Example @@ -181,7 +200,7 @@ impl ConfigBuilder { pub fn tls(mut self, certs_path: C, key_path: K) -> Self where C: Into, K: Into { - self.tls_config = Some((certs_path.into(), key_path.into())); + self.tls = Some((certs_path.into(), key_path.into())); self } @@ -282,8 +301,9 @@ impl ConfigBuilder { config.set_log_level(self.log_level); config.set_extras(self.extras); config.set_root(self.root); + config.set_limits(self.limits); - if let Some((certs_path, key_path)) = self.tls_config { + if let Some((certs_path, key_path)) = self.tls { config.set_tls(&certs_path, &key_path)?; } diff --git a/lib/src/config/config.rs b/lib/src/config/config.rs index 8f6b3d51..e51fb90d 100644 --- a/lib/src/config/config.rs +++ b/lib/src/config/config.rs @@ -5,45 +5,13 @@ use std::convert::AsRef; use std::fmt; use std::env; -#[cfg(feature = "tls")] use rustls::{Certificate, PrivateKey}; - +use super::custom_values::*; use {num_cpus, base64}; use config::Environment::*; use config::{Result, Table, Value, ConfigBuilder, Environment, ConfigError}; use logger::LoggingLevel; use http::Key; -pub enum SessionKey { - Generated(Key), - Provided(Key) -} - -impl SessionKey { - #[inline(always)] - pub fn kind(&self) -> &'static str { - match *self { - SessionKey::Generated(_) => 
"generated", - SessionKey::Provided(_) => "provided", - } - } - - #[inline(always)] - fn inner(&self) -> &Key { - match *self { - SessionKey::Generated(ref key) | SessionKey::Provided(ref key) => key - } - } -} - -#[cfg(feature = "tls")] -pub struct TlsConfig { - pub certs: Vec, - pub key: PrivateKey -} - -#[cfg(not(feature = "tls"))] -pub struct TlsConfig; - /// Structure for Rocket application configuration. /// /// A `Config` structure is typically built using the [build](#method.build) @@ -75,6 +43,8 @@ pub struct Config { pub(crate) session_key: SessionKey, /// TLS configuration. pub(crate) tls: Option, + /// Streaming read size limits. + pub limits: Limits, /// Extra parameters that aren't part of Rocket's core config. pub extras: HashMap, /// The path to the configuration file this config belongs to. @@ -94,52 +64,6 @@ macro_rules! config_from_raw { ) } -#[inline(always)] -fn value_as_str<'a>(config: &Config, name: &str, value: &'a Value) -> Result<&'a str> { - value.as_str().ok_or(config.bad_type(name, value.type_str(), "a string")) -} - -#[inline(always)] -fn value_as_u16(config: &Config, name: &str, value: &Value) -> Result { - match value.as_integer() { - Some(x) if x >= 0 && x <= (u16::max_value() as i64) => Ok(x as u16), - _ => Err(config.bad_type(name, value.type_str(), "a 16-bit unsigned integer")) - } -} - -#[inline(always)] -fn value_as_log_level(config: &Config, name: &str, value: &Value) -> Result { - value_as_str(config, name, value) - .and_then(|s| s.parse().map_err(|e| config.bad_type(name, value.type_str(), e))) -} - -#[inline(always)] -fn value_as_tls_config<'v>(config: &Config, - name: &str, - value: &'v Value, - ) -> Result<(&'v str, &'v str)> -{ - let (mut certs_path, mut key_path) = (None, None); - let table = value.as_table() - .ok_or_else(|| config.bad_type(name, value.type_str(), "a table"))?; - - let env = config.environment; - for (key, value) in table { - match key.as_str() { - "certs" => certs_path = Some(value_as_str(config, 
"tls.certs", value)?), - "key" => key_path = Some(value_as_str(config, "tls.key", value)?), - _ => return Err(ConfigError::UnknownKey(format!("{}.tls.{}", env, key))) - } - } - - if let (Some(certs), Some(key)) = (certs_path, key_path) { - Ok((certs, key)) - } else { - Err(config.bad_type(name, "a table with missing entries", - "a table with `certs` and `key` entries")) - } -} - impl Config { /// Returns a builder for `Config` structure where the default parameters /// are set to those of `env`. The root configuration directory is set to @@ -219,6 +143,7 @@ impl Config { log_level: LoggingLevel::Normal, session_key: key, tls: None, + limits: Limits::default(), extras: HashMap::new(), config_path: config_path, } @@ -232,6 +157,7 @@ impl Config { log_level: LoggingLevel::Normal, session_key: key, tls: None, + limits: Limits::default(), extras: HashMap::new(), config_path: config_path, } @@ -245,6 +171,7 @@ impl Config { log_level: LoggingLevel::Critical, session_key: key, tls: None, + limits: Limits::default(), extras: HashMap::new(), config_path: config_path, } @@ -255,8 +182,10 @@ impl Config { /// Constructs a `BadType` error given the entry `name`, the invalid `val` /// at that entry, and the `expect`ed type name. #[inline(always)] - fn bad_type(&self, name: &str, actual: &'static str, expect: &'static str) - -> ConfigError { + pub(crate) fn bad_type(&self, + name: &str, + actual: &'static str, + expect: &'static str) -> ConfigError { let id = format!("{}.{}", self.environment, name); ConfigError::BadType(id, expect, actual, self.config_path.clone()) } @@ -284,7 +213,8 @@ impl Config { workers => (u16, set_workers, ok), session_key => (str, set_session_key, id), log => (log_level, set_log_level, ok), - tls => (tls_config, set_raw_tls, id) + tls => (tls_config, set_raw_tls, id), + limits => (limits, set_limits, ok) | _ => { self.extras.insert(name.into(), val.clone()); Ok(()) @@ -442,6 +372,12 @@ impl Config { self.log_level = log_level; } + /// Sets limits. 
+ #[inline] + pub fn set_limits(&mut self, limits: Limits) { + self.limits = limits; + } + #[cfg(feature = "tls")] pub fn set_tls(&mut self, certs_path: &str, key_path: &str) -> Result<()> { use hyper_rustls::util as tls; diff --git a/lib/src/config/custom_values.rs b/lib/src/config/custom_values.rs new file mode 100644 index 00000000..948bdacf --- /dev/null +++ b/lib/src/config/custom_values.rs @@ -0,0 +1,165 @@ +use std::fmt; + +#[cfg(feature = "tls")] use rustls::{Certificate, PrivateKey}; + +use logger::LoggingLevel; +use config::{Result, Config, Value, ConfigError}; +use http::Key; + +pub enum SessionKey { + Generated(Key), + Provided(Key) +} + +impl SessionKey { + #[inline(always)] + pub fn kind(&self) -> &'static str { + match *self { + SessionKey::Generated(_) => "generated", + SessionKey::Provided(_) => "provided", + } + } + + #[inline(always)] + pub(crate) fn inner(&self) -> &Key { + match *self { + SessionKey::Generated(ref key) | SessionKey::Provided(ref key) => key + } + } +} + +#[cfg(feature = "tls")] +pub struct TlsConfig { + pub certs: Vec, + pub key: PrivateKey +} + +#[cfg(not(feature = "tls"))] +pub struct TlsConfig; + +// Size limit configuration. We cache those used by Rocket internally but don't +// share that fact in the API. 
+#[derive(Debug, Clone)] +pub struct Limits { + pub(crate) forms: u64, + extra: Vec<(String, u64)> +} + +impl Default for Limits { + fn default() -> Limits { + Limits { forms: 1024 * 32, extra: Vec::new() } + } +} + +impl Limits { + pub fn add>(mut self, name: S, limit: u64) -> Self { + let name = name.into(); + match name.as_str() { + "forms" => self.forms = limit, + _ => self.extra.push((name, limit)) + } + + self + } + + pub fn get(&self, name: &str) -> Option { + if name == "forms" { + return Some(self.forms); + } + + for &(ref key, val) in &self.extra { + if key == name { + return Some(val); + } + } + + None + } +} + +impl fmt::Display for Limits { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + fn fmt_size(n: u64, f: &mut fmt::Formatter) -> fmt::Result { + if (n & ((1 << 20) - 1)) == 0 { + write!(f, "{}MiB", n >> 20) + } else if (n & ((1 << 10) - 1)) == 0 { + write!(f, "{}KiB", n >> 10) + } else { + write!(f, "{}B", n) + } + } + + write!(f, "forms = ")?; + fmt_size(self.forms, f)?; + for &(ref key, val) in &self.extra { + write!(f, ", {}* = ", key)?; + fmt_size(val, f)?; + } + + Ok(()) + } +} + +pub fn value_as_str<'a>(conf: &Config, name: &str, v: &'a Value) -> Result<&'a str> { + v.as_str().ok_or(conf.bad_type(name, v.type_str(), "a string")) +} + +pub fn value_as_u64(conf: &Config, name: &str, value: &Value) -> Result { + match value.as_integer() { + Some(x) if x >= 0 => Ok(x as u64), + _ => Err(conf.bad_type(name, value.type_str(), "an unsigned integer")) + } +} + +pub fn value_as_u16(conf: &Config, name: &str, value: &Value) -> Result { + match value.as_integer() { + Some(x) if x >= 0 && x <= (u16::max_value() as i64) => Ok(x as u16), + _ => Err(conf.bad_type(name, value.type_str(), "a 16-bit unsigned integer")) + } +} + +pub fn value_as_log_level(conf: &Config, + name: &str, + value: &Value + ) -> Result { + value_as_str(conf, name, value) + .and_then(|s| s.parse().map_err(|e| conf.bad_type(name, value.type_str(), e))) +} + +pub fn 
value_as_tls_config<'v>(conf: &Config, + name: &str, + value: &'v Value, + ) -> Result<(&'v str, &'v str)> { + let (mut certs_path, mut key_path) = (None, None); + let table = value.as_table() + .ok_or_else(|| conf.bad_type(name, value.type_str(), "a table"))?; + + let env = conf.environment; + for (key, value) in table { + match key.as_str() { + "certs" => certs_path = Some(value_as_str(conf, "tls.certs", value)?), + "key" => key_path = Some(value_as_str(conf, "tls.key", value)?), + _ => return Err(ConfigError::UnknownKey(format!("{}.tls.{}", env, key))) + } + } + + if let (Some(certs), Some(key)) = (certs_path, key_path) { + Ok((certs, key)) + } else { + Err(conf.bad_type(name, "a table with missing entries", + "a table with `certs` and `key` entries")) + } +} + +pub fn value_as_limits(conf: &Config, name: &str, value: &Value) -> Result { + let table = value.as_table() + .ok_or_else(|| conf.bad_type(name, value.type_str(), "a table"))?; + + let mut limits = Limits::default(); + for (key, val) in table { + let val = value_as_u64(conf, &format!("limits.{}", key), val)?; + limits = limits.add(key.as_str(), val); + } + + Ok(limits) +} diff --git a/lib/src/config/mod.rs b/lib/src/config/mod.rs index 67244900..e03798a2 100644 --- a/lib/src/config/mod.rs +++ b/lib/src/config/mod.rs @@ -179,6 +179,7 @@ mod environment; mod config; mod builder; mod toml_ext; +mod custom_values; use std::sync::{Once, ONCE_INIT}; use std::fs::{self, File}; @@ -190,6 +191,7 @@ use std::env; use toml; +pub use self::custom_values::Limits; pub use toml::{Array, Table, Value}; pub use self::error::{ConfigError, ParsingError}; pub use self::environment::Environment; diff --git a/lib/src/request/form/form_items.rs b/lib/src/request/form/form_items.rs index ab26d0e4..8a0744af 100644 --- a/lib/src/request/form/form_items.rs +++ b/lib/src/request/form/form_items.rs @@ -277,6 +277,7 @@ mod test { &[("user", ""), ("password", "pass")]); check_form!("a=b", &[("a", "b")]); + 
check_form!("value=Hello+World", &[("value", "Hello+World")]); check_form!("user=", &[("user", "")]); check_form!("user=&", &[("user", "")]); diff --git a/lib/src/request/form/mod.rs b/lib/src/request/form/mod.rs index 16c43a64..0fd9da52 100644 --- a/lib/src/request/form/mod.rs +++ b/lib/src/request/form/mod.rs @@ -28,6 +28,7 @@ use std::marker::PhantomData; use std::fmt::{self, Debug}; use std::io::Read; +use config; use http::Status; use request::Request; use data::{self, Data, FromData}; @@ -255,6 +256,9 @@ impl<'f, T: FromForm<'f> + Debug + 'f> Debug for Form<'f, T> { } } +/// Default limit for forms is 32KiB. +const LIMIT: u64 = 32 * (1 << 10); + /// Parses a `Form` from incoming form data. /// /// If the content type of the request data is not @@ -279,7 +283,8 @@ impl<'f, T: FromForm<'f>> FromData for Form<'f, T> where T::Error: Debug { } let mut form_string = String::with_capacity(4096); - let mut stream = data.open().take(32768); // TODO: Make this configurable? + let limit = config::active().map(|c| c.limits.forms).unwrap_or(LIMIT); + let mut stream = data.open().take(limit); if let Err(e) = stream.read_to_string(&mut form_string) { error_!("IO Error: {:?}", e); Failure((Status::InternalServerError, None)) diff --git a/lib/src/rocket.rs b/lib/src/rocket.rs index 5874e638..fefc5e79 100644 --- a/lib/src/rocket.rs +++ b/lib/src/rocket.rs @@ -382,6 +382,7 @@ impl Rocket { info_!("log: {}", White.paint(config.log_level)); info_!("workers: {}", White.paint(config.workers)); info_!("session key: {}", White.paint(config.session_key.kind())); + info_!("limits: {}", White.paint(&config.limits)); let tls_configured = config.tls.is_some(); if tls_configured && cfg!(feature = "tls") { diff --git a/lib/tests/limits.rs b/lib/tests/limits.rs new file mode 100644 index 00000000..9070c42f --- /dev/null +++ b/lib/tests/limits.rs @@ -0,0 +1,83 @@ +#![feature(plugin, custom_derive)] +#![plugin(rocket_codegen)] + +extern crate rocket; + +use rocket::request::Form; + 
+#[derive(FromForm)] +struct Simple { + value: String +} + +#[post("/", data = "
")] +fn index(form: Form) -> String { + form.into_inner().value +} + +#[cfg(feature = "testing")] +mod tests { + use rocket; + use rocket::config::{Environment, Config, Limits}; + use rocket::testing::MockRequest; + use rocket::http::Method::*; + use rocket::http::{Status, ContentType}; + + fn rocket_with_forms_limit(limit: u64) -> rocket::Rocket { + let config = Config::build(Environment::Development) + .limits(Limits::default().add("forms", limit)) + .unwrap(); + + rocket::custom(config, true).mount("/", routes![super::index]) + } + + // FIXME: Config is global (it's the only global thing). Each of these tests + // will run in different threads in the same process, so the config used by + // all of the tests will be identical: whichever of these gets executed + // first. As such, only one test will pass; the rest will fail. Make config + // _not_ global so we can actually do these tests. + + // #[test] + // fn large_enough() { + // let rocket = rocket_with_forms_limit(128); + // let mut req = MockRequest::new(Post, "/") + // .body("value=Hello+world") + // .header(ContentType::Form); + + // let mut response = req.dispatch_with(&rocket); + // assert_eq!(response.body_string(), Some("Hello world".into())); + // } + + // #[test] + // fn just_large_enough() { + // let rocket = rocket_with_forms_limit(17); + // let mut req = MockRequest::new(Post, "/") + // .body("value=Hello+world") + // .header(ContentType::Form); + + // let mut response = req.dispatch_with(&rocket); + // assert_eq!(response.body_string(), Some("Hello world".into())); + // } + + // #[test] + // fn much_too_small() { + // let rocket = rocket_with_forms_limit(4); + // let mut req = MockRequest::new(Post, "/") + // .body("value=Hello+world") + // .header(ContentType::Form); + + // let response = req.dispatch_with(&rocket); + // assert_eq!(response.status(), Status::BadRequest); + // } + + #[test] + fn contracted() { + let rocket = rocket_with_forms_limit(10); + let mut req = MockRequest::new(Post, 
"/") + .body("value=Hello+world") + .header(ContentType::Form); + + let mut response = req.dispatch_with(&rocket); + assert_eq!(response.body_string(), Some("Hell".into())); + } +}