Add support for configurable size limits.

This commit is contained in:
Sergio Benitez 2017-04-18 00:25:13 -07:00
parent e6bbeacb1c
commit 6dc21e5380
11 changed files with 323 additions and 97 deletions

View File

@ -1,6 +1,7 @@
use std::ops::{Deref, DerefMut}; use std::ops::{Deref, DerefMut};
use std::io::Read; use std::io::Read;
use rocket::config;
use rocket::outcome::Outcome; use rocket::outcome::Outcome;
use rocket::request::Request; use rocket::request::Request;
use rocket::data::{self, Data, FromData}; use rocket::data::{self, Data, FromData};
@ -65,9 +66,8 @@ impl<T> JSON<T> {
} }
} }
/// Maximum size of JSON is 1MB. /// Default limit for JSON is 1MB.
/// TODO: Determine this size from some configuration parameter. const LIMIT: u64 = 1 << 20;
const MAX_SIZE: u64 = 1048576;
impl<T: Deserialize> FromData for JSON<T> { impl<T: Deserialize> FromData for JSON<T> {
type Error = SerdeError; type Error = SerdeError;
@ -78,7 +78,11 @@ impl<T: Deserialize> FromData for JSON<T> {
return Outcome::Forward(data); return Outcome::Forward(data);
} }
let reader = data.open().take(MAX_SIZE); let size_limit = config::active()
.and_then(|c| c.limits.get("json"))
.unwrap_or(LIMIT);
let reader = data.open().take(size_limit);
match serde_json::from_reader(reader).map(|val| JSON(val)) { match serde_json::from_reader(reader).map(|val| JSON(val)) {
Ok(value) => Outcome::Success(value), Ok(value) => Outcome::Success(value),
Err(e) => { Err(e) => {

View File

@ -3,6 +3,7 @@ extern crate rmp_serde;
use std::ops::{Deref, DerefMut}; use std::ops::{Deref, DerefMut};
use std::io::{Cursor, Read}; use std::io::{Cursor, Read};
use rocket::config;
use rocket::outcome::Outcome; use rocket::outcome::Outcome;
use rocket::request::Request; use rocket::request::Request;
use rocket::data::{self, Data, FromData}; use rocket::data::{self, Data, FromData};
@ -70,9 +71,8 @@ impl<T> MsgPack<T> {
} }
} }
/// Maximum size of MessagePack data is 1MB. /// Default limit for MessagePack is 1MB.
/// TODO: Determine this size from some configuration parameter. const LIMIT: u64 = 1 << 20;
const MAX_SIZE: u64 = 1048576;
/// Accepted content types are: `application/msgpack`, `application/x-msgpack`, /// Accepted content types are: `application/msgpack`, `application/x-msgpack`,
/// `bin/msgpack`, and `bin/x-msgpack`. /// `bin/msgpack`, and `bin/x-msgpack`.
@ -91,8 +91,12 @@ impl<T: Deserialize> FromData for MsgPack<T> {
return Outcome::Forward(data); return Outcome::Forward(data);
} }
let size_limit = config::active()
.and_then(|c| c.limits.get("msgpack"))
.unwrap_or(LIMIT);
let mut buf = Vec::new(); let mut buf = Vec::new();
if let Err(e) = data.open().take(MAX_SIZE).read_to_end(&mut buf) { if let Err(e) = data.open().take(size_limit).read_to_end(&mut buf) {
let e = MsgPackError::InvalidDataRead(e); let e = MsgPackError::InvalidDataRead(e);
error_!("Couldn't read request data: {:?}", e); error_!("Couldn't read request data: {:?}", e);
return Outcome::Failure((Status::BadRequest, e)); return Outcome::Failure((Status::BadRequest, e));

View File

@ -1,6 +1,11 @@
# Except for the session key, none of these are actually needed; Rocket has sane # Except for the session key, nothing here is necessary; Rocket has sane
# defaults. We show all of them here explicitly for demonstrative purposes. # defaults. We show all of them here explicitly for demonstrative purposes.
[global.limits]
forms = 32768
json = 1048576 # this is an extra used by the json contrib module
msgpack = 1048576 # this is an extra used by the msgpack contrib module
[development] [development]
address = "localhost" address = "localhost"
port = 8000 port = 8000

View File

@ -1,7 +1,7 @@
use std::collections::HashMap; use std::collections::HashMap;
use std::path::{Path, PathBuf}; use std::path::{Path, PathBuf};
use config::{Result, Config, Value, Environment}; use config::{Result, Config, Value, Environment, Limits};
use config::toml_ext::IntoValue; use config::toml_ext::IntoValue;
use logger::LoggingLevel; use logger::LoggingLevel;
@ -21,7 +21,9 @@ pub struct ConfigBuilder {
/// The session key. /// The session key.
pub session_key: Option<String>, pub session_key: Option<String>,
/// TLS configuration (path to certificates file, path to private key file). /// TLS configuration (path to certificates file, path to private key file).
pub tls_config: Option<(String, String)>, pub tls: Option<(String, String)>,
/// Size limits.
pub limits: Limits,
/// Any extra parameters that aren't part of Rocket's config. /// Any extra parameters that aren't part of Rocket's config.
pub extras: HashMap<String, Value>, pub extras: HashMap<String, Value>,
/// The root directory of this config. /// The root directory of this config.
@ -65,7 +67,8 @@ impl ConfigBuilder {
workers: config.workers, workers: config.workers,
log_level: config.log_level, log_level: config.log_level,
session_key: None, session_key: None,
tls_config: None, tls: None,
limits: config.limits,
extras: config.extras, extras: config.extras,
root: root_dir, root: root_dir,
} }
@ -165,6 +168,22 @@ impl ConfigBuilder {
self self
} }
/// Replaces the size limits of the configuration being built with `limits`.
///
/// # Example
///
/// ```rust
/// use rocket::config::{Config, Environment, Limits};
///
/// let mut config = Config::build(Environment::Staging)
///     .limits(Limits::default().add("json", 5 * (1 << 20)))
///     .unwrap();
/// ```
pub fn limits(self, limits: Limits) -> Self {
    let mut builder = self;
    builder.limits = limits;
    builder
}
/// Sets the `tls_config` in the configuration being built. /// Sets the `tls_config` in the configuration being built.
/// ///
/// # Example /// # Example
@ -181,7 +200,7 @@ impl ConfigBuilder {
pub fn tls<C, K>(mut self, certs_path: C, key_path: K) -> Self pub fn tls<C, K>(mut self, certs_path: C, key_path: K) -> Self
where C: Into<String>, K: Into<String> where C: Into<String>, K: Into<String>
{ {
self.tls_config = Some((certs_path.into(), key_path.into())); self.tls = Some((certs_path.into(), key_path.into()));
self self
} }
@ -282,8 +301,9 @@ impl ConfigBuilder {
config.set_log_level(self.log_level); config.set_log_level(self.log_level);
config.set_extras(self.extras); config.set_extras(self.extras);
config.set_root(self.root); config.set_root(self.root);
config.set_limits(self.limits);
if let Some((certs_path, key_path)) = self.tls_config { if let Some((certs_path, key_path)) = self.tls {
config.set_tls(&certs_path, &key_path)?; config.set_tls(&certs_path, &key_path)?;
} }

View File

@ -5,45 +5,13 @@ use std::convert::AsRef;
use std::fmt; use std::fmt;
use std::env; use std::env;
#[cfg(feature = "tls")] use rustls::{Certificate, PrivateKey}; use super::custom_values::*;
use {num_cpus, base64}; use {num_cpus, base64};
use config::Environment::*; use config::Environment::*;
use config::{Result, Table, Value, ConfigBuilder, Environment, ConfigError}; use config::{Result, Table, Value, ConfigBuilder, Environment, ConfigError};
use logger::LoggingLevel; use logger::LoggingLevel;
use http::Key; use http::Key;
/// A session key tagged with how it is labelled: generated or provided.
pub enum SessionKey {
    /// A key labelled as generated.
    Generated(Key),
    /// A key labelled as provided.
    Provided(Key)
}

impl SessionKey {
    /// Returns `"generated"` for `Generated` and `"provided"` for `Provided`.
    #[inline(always)]
    pub fn kind(&self) -> &'static str {
        if let SessionKey::Generated(_) = *self {
            "generated"
        } else {
            "provided"
        }
    }

    /// Borrows the wrapped key, whichever variant holds it.
    #[inline(always)]
    fn inner(&self) -> &Key {
        match *self {
            SessionKey::Generated(ref generated) => generated,
            SessionKey::Provided(ref provided) => provided,
        }
    }
}
/// TLS material, compiled in only when the `tls` feature is enabled.
#[cfg(feature = "tls")]
pub struct TlsConfig {
/// The certificates, as `rustls` `Certificate` values.
pub certs: Vec<Certificate>,
/// The private key, as a `rustls` `PrivateKey`.
pub key: PrivateKey
}
/// Field-less stand-in used when the `tls` feature is disabled, so the type
/// name exists in both build configurations.
#[cfg(not(feature = "tls"))]
pub struct TlsConfig;
/// Structure for Rocket application configuration. /// Structure for Rocket application configuration.
/// ///
/// A `Config` structure is typically built using the [build](#method.build) /// A `Config` structure is typically built using the [build](#method.build)
@ -75,6 +43,8 @@ pub struct Config {
pub(crate) session_key: SessionKey, pub(crate) session_key: SessionKey,
/// TLS configuration. /// TLS configuration.
pub(crate) tls: Option<TlsConfig>, pub(crate) tls: Option<TlsConfig>,
/// Streaming read size limits.
pub limits: Limits,
/// Extra parameters that aren't part of Rocket's core config. /// Extra parameters that aren't part of Rocket's core config.
pub extras: HashMap<String, Value>, pub extras: HashMap<String, Value>,
/// The path to the configuration file this config belongs to. /// The path to the configuration file this config belongs to.
@ -94,52 +64,6 @@ macro_rules! config_from_raw {
) )
} }
/// Returns `value` as a string slice, or a `BadType` error for the entry
/// `name` if `value` is not a string.
///
/// The error is built lazily: `bad_type` formats an identifier and clones the
/// config path, which is wasted work whenever the value is a string.
#[inline(always)]
fn value_as_str<'a>(config: &Config, name: &str, value: &'a Value) -> Result<&'a str> {
    value.as_str().ok_or_else(|| config.bad_type(name, value.type_str(), "a string"))
}
/// Returns `value` as a `u16`, or a `BadType` error for the entry `name` if
/// it is not an integer in the range `0..=u16::max_value()`.
#[inline(always)]
fn value_as_u16(config: &Config, name: &str, value: &Value) -> Result<u16> {
    if let Some(int) = value.as_integer() {
        let upper = u16::max_value() as i64;
        if 0 <= int && int <= upper {
            return Ok(int as u16);
        }
    }

    Err(config.bad_type(name, value.type_str(), "a 16-bit unsigned integer"))
}
/// Returns `value` parsed as a `LoggingLevel`.
///
/// Fails with a `BadType` error for the entry `name` if `value` is not a
/// string, or if the string does not parse; in the latter case the parse
/// error is used as the "expected" description.
#[inline(always)]
fn value_as_log_level(config: &Config, name: &str, value: &Value) -> Result<LoggingLevel> {
    let level = value_as_str(config, name, value)?;
    level.parse().map_err(|e| config.bad_type(name, value.type_str(), e))
}
/// Returns the `(certs, key)` path pair stored in the table `value`.
///
/// Fails with `BadType` if `value` is not a table, if either entry is not a
/// string, or if either entry is missing; fails with `UnknownKey` if the
/// table holds any key other than `certs` and `key`.
#[inline(always)]
fn value_as_tls_config<'v>(config: &Config,
                           name: &str,
                           value: &'v Value,
                           ) -> Result<(&'v str, &'v str)>
{
    let table = match value.as_table() {
        Some(table) => table,
        None => return Err(config.bad_type(name, value.type_str(), "a table")),
    };

    let env = config.environment;
    let mut certs = None;
    let mut private_key = None;
    for (entry, path) in table {
        match entry.as_str() {
            "certs" => certs = Some(value_as_str(config, "tls.certs", path)?),
            "key" => private_key = Some(value_as_str(config, "tls.key", path)?),
            _ => return Err(ConfigError::UnknownKey(format!("{}.tls.{}", env, entry)))
        }
    }

    match (certs, private_key) {
        (Some(certs), Some(key)) => Ok((certs, key)),
        _ => Err(config.bad_type(name, "a table with missing entries",
                                 "a table with `certs` and `key` entries")),
    }
}
impl Config { impl Config {
/// Returns a builder for `Config` structure where the default parameters /// Returns a builder for `Config` structure where the default parameters
/// are set to those of `env`. The root configuration directory is set to /// are set to those of `env`. The root configuration directory is set to
@ -219,6 +143,7 @@ impl Config {
log_level: LoggingLevel::Normal, log_level: LoggingLevel::Normal,
session_key: key, session_key: key,
tls: None, tls: None,
limits: Limits::default(),
extras: HashMap::new(), extras: HashMap::new(),
config_path: config_path, config_path: config_path,
} }
@ -232,6 +157,7 @@ impl Config {
log_level: LoggingLevel::Normal, log_level: LoggingLevel::Normal,
session_key: key, session_key: key,
tls: None, tls: None,
limits: Limits::default(),
extras: HashMap::new(), extras: HashMap::new(),
config_path: config_path, config_path: config_path,
} }
@ -245,6 +171,7 @@ impl Config {
log_level: LoggingLevel::Critical, log_level: LoggingLevel::Critical,
session_key: key, session_key: key,
tls: None, tls: None,
limits: Limits::default(),
extras: HashMap::new(), extras: HashMap::new(),
config_path: config_path, config_path: config_path,
} }
@ -255,8 +182,10 @@ impl Config {
/// Constructs a `BadType` error given the entry `name`, the invalid `val` /// Constructs a `BadType` error given the entry `name`, the invalid `val`
/// at that entry, and the `expect`ed type name. /// at that entry, and the `expect`ed type name.
#[inline(always)] #[inline(always)]
fn bad_type(&self, name: &str, actual: &'static str, expect: &'static str) pub(crate) fn bad_type(&self,
-> ConfigError { name: &str,
actual: &'static str,
expect: &'static str) -> ConfigError {
let id = format!("{}.{}", self.environment, name); let id = format!("{}.{}", self.environment, name);
ConfigError::BadType(id, expect, actual, self.config_path.clone()) ConfigError::BadType(id, expect, actual, self.config_path.clone())
} }
@ -284,7 +213,8 @@ impl Config {
workers => (u16, set_workers, ok), workers => (u16, set_workers, ok),
session_key => (str, set_session_key, id), session_key => (str, set_session_key, id),
log => (log_level, set_log_level, ok), log => (log_level, set_log_level, ok),
tls => (tls_config, set_raw_tls, id) tls => (tls_config, set_raw_tls, id),
limits => (limits, set_limits, ok)
| _ => { | _ => {
self.extras.insert(name.into(), val.clone()); self.extras.insert(name.into(), val.clone());
Ok(()) Ok(())
@ -442,6 +372,12 @@ impl Config {
self.log_level = log_level; self.log_level = log_level;
} }
/// Overwrites this configuration's size limits with `limits`.
#[inline]
pub fn set_limits(&mut self, limits: Limits) {
    let slot = &mut self.limits;
    *slot = limits;
}
#[cfg(feature = "tls")] #[cfg(feature = "tls")]
pub fn set_tls(&mut self, certs_path: &str, key_path: &str) -> Result<()> { pub fn set_tls(&mut self, certs_path: &str, key_path: &str) -> Result<()> {
use hyper_rustls::util as tls; use hyper_rustls::util as tls;

View File

@ -0,0 +1,165 @@
use std::fmt;
#[cfg(feature = "tls")] use rustls::{Certificate, PrivateKey};
use logger::LoggingLevel;
use config::{Result, Config, Value, ConfigError};
use http::Key;
/// A session key tagged with how it is labelled: generated or provided.
pub enum SessionKey {
    /// A key labelled as generated.
    Generated(Key),
    /// A key labelled as provided.
    Provided(Key)
}

impl SessionKey {
    /// Returns `"generated"` for `Generated` and `"provided"` for `Provided`.
    #[inline(always)]
    pub fn kind(&self) -> &'static str {
        if let SessionKey::Generated(_) = *self {
            "generated"
        } else {
            "provided"
        }
    }

    /// Borrows the wrapped key, whichever variant holds it.
    #[inline(always)]
    pub(crate) fn inner(&self) -> &Key {
        match *self {
            SessionKey::Generated(ref generated) => generated,
            SessionKey::Provided(ref provided) => provided,
        }
    }
}
/// TLS material, compiled in only when the `tls` feature is enabled.
#[cfg(feature = "tls")]
pub struct TlsConfig {
/// The certificates, as `rustls` `Certificate` values.
pub certs: Vec<Certificate>,
/// The private key, as a `rustls` `PrivateKey`.
pub key: PrivateKey
}
/// Field-less stand-in used when the `tls` feature is disabled, so the type
/// name exists in both build configurations.
#[cfg(not(feature = "tls"))]
pub struct TlsConfig;
// Size limit configuration. We cache those used by Rocket internally but don't
// share that fact in the API.
/// Named size limits, in bytes.
///
/// The `forms` limit defaults to 32KiB. Any other name can be registered with
/// [`add`](#method.add) and read back with [`get`](#method.get).
#[derive(Debug, Clone)]
pub struct Limits {
    pub(crate) forms: u64,
    extra: Vec<(String, u64)>
}

impl Default for Limits {
    /// Returns limits with `forms` set to 32KiB and no other entries.
    fn default() -> Limits {
        Limits { forms: 1024 * 32, extra: Vec::new() }
    }
}

impl Limits {
    /// Sets the limit called `name` to `limit` bytes and returns the updated
    /// limits.
    ///
    /// Adding a name that is already present replaces its value. Previously a
    /// second entry was appended, which `get` never reached (it returns the
    /// first match), so the later value was silently ignored.
    pub fn add<S: Into<String>>(mut self, name: S, limit: u64) -> Self {
        let name = name.into();
        if name == "forms" {
            self.forms = limit;
            return self;
        }

        // Look the name up first so the iterator borrow ends before mutating.
        let existing = self.extra.iter().position(|&(ref key, _)| *key == name);
        match existing {
            Some(index) => self.extra[index].1 = limit,
            None => self.extra.push((name, limit)),
        }

        self
    }

    /// Returns the limit called `name`, or `None` if no such limit was set.
    ///
    /// `forms` is always present.
    pub fn get(&self, name: &str) -> Option<u64> {
        if name == "forms" {
            return Some(self.forms);
        }

        self.extra.iter()
            .find(|&&(ref key, _)| key == name)
            .map(|&(_, limit)| limit)
    }
}
impl fmt::Display for Limits {
    /// Writes `forms = <size>` followed by `, <name>* = <size>` for each extra
    /// limit. A size is shown in MiB or KiB when it is an exact multiple of
    /// that unit, and in bytes otherwise.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        // Picks the largest unit that divides `bytes` exactly.
        fn scaled(bytes: u64) -> (u64, &'static str) {
            const KIB: u64 = 1 << 10;
            const MIB: u64 = 1 << 20;

            if bytes % MIB == 0 {
                (bytes / MIB, "MiB")
            } else if bytes % KIB == 0 {
                (bytes / KIB, "KiB")
            } else {
                (bytes, "B")
            }
        }

        let (amount, unit) = scaled(self.forms);
        write!(f, "forms = {}{}", amount, unit)?;

        for entry in self.extra.iter() {
            let (amount, unit) = scaled(entry.1);
            write!(f, ", {}* = {}{}", entry.0, amount, unit)?;
        }

        Ok(())
    }
}
/// Returns `v` as a string slice, or a `BadType` error for the entry `name`
/// if `v` is not a string.
///
/// The error is built lazily: `bad_type` formats an identifier and clones the
/// config path, which is wasted work whenever the value is a string.
pub fn value_as_str<'a>(conf: &Config, name: &str, v: &'a Value) -> Result<&'a str> {
    v.as_str().ok_or_else(|| conf.bad_type(name, v.type_str(), "a string"))
}
/// Returns `value` as a `u64`, or a `BadType` error for the entry `name` if
/// it is not a non-negative integer.
pub fn value_as_u64(conf: &Config, name: &str, value: &Value) -> Result<u64> {
    if let Some(int) = value.as_integer() {
        if int >= 0 {
            return Ok(int as u64);
        }
    }

    Err(conf.bad_type(name, value.type_str(), "an unsigned integer"))
}
/// Returns `value` as a `u16`, or a `BadType` error for the entry `name` if
/// it is not an integer in the range `0..=u16::max_value()`.
pub fn value_as_u16(conf: &Config, name: &str, value: &Value) -> Result<u16> {
    if let Some(int) = value.as_integer() {
        let upper = u16::max_value() as i64;
        if 0 <= int && int <= upper {
            return Ok(int as u16);
        }
    }

    Err(conf.bad_type(name, value.type_str(), "a 16-bit unsigned integer"))
}
/// Returns `value` parsed as a `LoggingLevel`.
///
/// Fails with a `BadType` error for the entry `name` if `value` is not a
/// string, or if the string does not parse; in the latter case the parse
/// error is used as the "expected" description.
pub fn value_as_log_level(conf: &Config,
                          name: &str,
                          value: &Value
                          ) -> Result<LoggingLevel> {
    let level = value_as_str(conf, name, value)?;
    level.parse().map_err(|e| conf.bad_type(name, value.type_str(), e))
}
/// Returns the `(certs, key)` path pair stored in the table `value`.
///
/// Fails with `BadType` if `value` is not a table, if either entry is not a
/// string, or if either entry is missing; fails with `UnknownKey` if the
/// table holds any key other than `certs` and `key`.
pub fn value_as_tls_config<'v>(conf: &Config,
                               name: &str,
                               value: &'v Value,
                               ) -> Result<(&'v str, &'v str)> {
    let table = match value.as_table() {
        Some(table) => table,
        None => return Err(conf.bad_type(name, value.type_str(), "a table")),
    };

    let env = conf.environment;
    let mut certs = None;
    let mut private_key = None;
    for (entry, path) in table {
        match entry.as_str() {
            "certs" => certs = Some(value_as_str(conf, "tls.certs", path)?),
            "key" => private_key = Some(value_as_str(conf, "tls.key", path)?),
            _ => return Err(ConfigError::UnknownKey(format!("{}.tls.{}", env, entry)))
        }
    }

    match (certs, private_key) {
        (Some(certs), Some(key)) => Ok((certs, key)),
        _ => Err(conf.bad_type(name, "a table with missing entries",
                               "a table with `certs` and `key` entries")),
    }
}
/// Builds a `Limits` from the table `value`.
///
/// Starts from `Limits::default()` and applies every table entry on top of
/// it. Fails with `BadType` if `value` is not a table or if any entry is not
/// a non-negative integer; the failing entry is reported as `limits.<key>`.
pub fn value_as_limits(conf: &Config, name: &str, value: &Value) -> Result<Limits> {
    let entries = match value.as_table() {
        Some(entries) => entries,
        None => return Err(conf.bad_type(name, value.type_str(), "a table")),
    };

    let mut limits = Limits::default();
    for (limit_name, raw) in entries {
        let path = format!("limits.{}", limit_name);
        let bytes = value_as_u64(conf, &path, raw)?;
        limits = limits.add(limit_name.as_str(), bytes);
    }

    Ok(limits)
}

View File

@ -179,6 +179,7 @@ mod environment;
mod config; mod config;
mod builder; mod builder;
mod toml_ext; mod toml_ext;
mod custom_values;
use std::sync::{Once, ONCE_INIT}; use std::sync::{Once, ONCE_INIT};
use std::fs::{self, File}; use std::fs::{self, File};
@ -190,6 +191,7 @@ use std::env;
use toml; use toml;
pub use self::custom_values::Limits;
pub use toml::{Array, Table, Value}; pub use toml::{Array, Table, Value};
pub use self::error::{ConfigError, ParsingError}; pub use self::error::{ConfigError, ParsingError};
pub use self::environment::Environment; pub use self::environment::Environment;

View File

@ -277,6 +277,7 @@ mod test {
&[("user", ""), ("password", "pass")]); &[("user", ""), ("password", "pass")]);
check_form!("a=b", &[("a", "b")]); check_form!("a=b", &[("a", "b")]);
check_form!("value=Hello+World", &[("value", "Hello+World")]);
check_form!("user=", &[("user", "")]); check_form!("user=", &[("user", "")]);
check_form!("user=&", &[("user", "")]); check_form!("user=&", &[("user", "")]);

View File

@ -28,6 +28,7 @@ use std::marker::PhantomData;
use std::fmt::{self, Debug}; use std::fmt::{self, Debug};
use std::io::Read; use std::io::Read;
use config;
use http::Status; use http::Status;
use request::Request; use request::Request;
use data::{self, Data, FromData}; use data::{self, Data, FromData};
@ -255,6 +256,9 @@ impl<'f, T: FromForm<'f> + Debug + 'f> Debug for Form<'f, T> {
} }
} }
/// Default limit for forms is 32KiB.
const LIMIT: u64 = 32 * (1 << 10);
/// Parses a `Form` from incoming form data. /// Parses a `Form` from incoming form data.
/// ///
/// If the content type of the request data is not /// If the content type of the request data is not
@ -279,7 +283,8 @@ impl<'f, T: FromForm<'f>> FromData for Form<'f, T> where T::Error: Debug {
} }
let mut form_string = String::with_capacity(4096); let mut form_string = String::with_capacity(4096);
let mut stream = data.open().take(32768); // TODO: Make this configurable? let limit = config::active().map(|c| c.limits.forms).unwrap_or(LIMIT);
let mut stream = data.open().take(limit);
if let Err(e) = stream.read_to_string(&mut form_string) { if let Err(e) = stream.read_to_string(&mut form_string) {
error_!("IO Error: {:?}", e); error_!("IO Error: {:?}", e);
Failure((Status::InternalServerError, None)) Failure((Status::InternalServerError, None))

View File

@ -382,6 +382,7 @@ impl Rocket {
info_!("log: {}", White.paint(config.log_level)); info_!("log: {}", White.paint(config.log_level));
info_!("workers: {}", White.paint(config.workers)); info_!("workers: {}", White.paint(config.workers));
info_!("session key: {}", White.paint(config.session_key.kind())); info_!("session key: {}", White.paint(config.session_key.kind()));
info_!("limits: {}", White.paint(&config.limits));
let tls_configured = config.tls.is_some(); let tls_configured = config.tls.is_some();
if tls_configured && cfg!(feature = "tls") { if tls_configured && cfg!(feature = "tls") {

83
lib/tests/limits.rs Normal file
View File

@ -0,0 +1,83 @@
#![feature(plugin, custom_derive)]
#![plugin(rocket_codegen)]
extern crate rocket;
use rocket::request::Form;
// Form payload with a single `value` field; `FromForm` is derived so the
// `index` route can accept it as `Form<Simple>`.
#[derive(FromForm)]
struct Simple {
value: String
}
#[post("/", data = "<form>")]
fn index(form: Form<Simple>) -> String {
form.into_inner().value
}
#[cfg(feature = "testing")]
mod tests {
    use rocket;
    use rocket::config::{Environment, Config, Limits};
    use rocket::testing::MockRequest;
    use rocket::http::Method::*;
    use rocket::http::{Status, ContentType};

    // Builds a development-environment Rocket whose `forms` limit is `limit`
    // bytes, with the `index` route mounted at `/`.
    fn rocket_with_forms_limit(limit: u64) -> rocket::Rocket {
        let limits = Limits::default().add("forms", limit);
        let config = Config::build(Environment::Development)
            .limits(limits)
            .unwrap();

        rocket::custom(config, true).mount("/", routes![super::index])
    }

    // FIXME: Config is global (it's the only global thing). Each of these tests
    // will run in different threads in the same process, so the config used by
    // all of the tests will be identical: whichever of these gets executed
    // first. As such, only one test will pass; the rest will fail. Make config
    // _not_ global so we can actually do these tests.

    // #[test]
    // fn large_enough() {
    //     let rocket = rocket_with_forms_limit(128);
    //     let mut req = MockRequest::new(Post, "/")
    //         .body("value=Hello+world")
    //         .header(ContentType::Form);

    //     let mut response = req.dispatch_with(&rocket);
    //     assert_eq!(response.body_string(), Some("Hello world".into()));
    // }

    // #[test]
    // fn just_large_enough() {
    //     let rocket = rocket_with_forms_limit(17);
    //     let mut req = MockRequest::new(Post, "/")
    //         .body("value=Hello+world")
    //         .header(ContentType::Form);

    //     let mut response = req.dispatch_with(&rocket);
    //     assert_eq!(response.body_string(), Some("Hello world".into()));
    // }

    // #[test]
    // fn much_too_small() {
    //     let rocket = rocket_with_forms_limit(4);
    //     let mut req = MockRequest::new(Post, "/")
    //         .body("value=Hello+world")
    //         .header(ContentType::Form);

    //     let response = req.dispatch_with(&rocket);
    //     assert_eq!(response.status(), Status::BadRequest);
    // }

    // With a 10-byte limit only `value=Hell` of the body is read, so the
    // route sees a truncated `value`.
    #[test]
    fn contracted() {
        let instance = rocket_with_forms_limit(10);
        let mut request = MockRequest::new(Post, "/")
            .body("value=Hello+world")
            .header(ContentType::Form);

        let mut response = request.dispatch_with(&instance);
        assert_eq!(response.body_string(), Some("Hell".into()));
    }
}