From 549c9241c41320fc5af76b53c2ffc3bd8db88f8c Mon Sep 17 00:00:00 2001 From: Sergio Benitez Date: Thu, 30 Jul 2020 04:17:38 -0700 Subject: [PATCH] Require data limits on 'Data::open()'. Closes #1325. --- contrib/lib/src/json.rs | 20 +- contrib/lib/src/msgpack.rs | 12 +- core/codegen/tests/route-data.rs | 12 +- core/codegen/tests/route.rs | 7 +- core/lib/Cargo.toml | 1 + core/lib/src/config/builder.rs | 8 +- core/lib/src/config/config.rs | 29 +-- core/lib/src/config/custom_values.rs | 151 +---------- core/lib/src/config/mod.rs | 1 - core/lib/src/data/data.rs | 279 ++++++++------------- core/lib/src/data/data_stream.rs | 134 ++++++++-- core/lib/src/data/from_data.rs | 50 ++-- core/lib/src/data/limits.rs | 138 ++++++++++ core/lib/src/data/mod.rs | 6 +- core/lib/src/ext.rs | 68 +++-- core/lib/src/lib.rs | 2 +- core/lib/src/local/asynchronous/request.rs | 4 +- core/lib/src/request/form/form.rs | 17 +- core/lib/src/request/request.rs | 8 +- core/lib/src/response/debug.rs | 7 +- core/lib/src/rocket.rs | 13 +- core/lib/tests/limits.rs | 5 +- examples/content_types/src/main.rs | 12 +- examples/manual_routes/src/main.rs | 5 +- examples/pastebin/src/main.rs | 4 +- examples/raw_upload/src/main.rs | 6 +- scripts/test.sh | 4 +- site/guide/10-pastebin.md | 18 +- site/guide/4-requests.md | 33 +-- 29 files changed, 525 insertions(+), 529 deletions(-) create mode 100644 core/lib/src/data/limits.rs diff --git a/contrib/lib/src/json.rs b/contrib/lib/src/json.rs index 32640e7f..2ac5867c 100644 --- a/contrib/lib/src/json.rs +++ b/contrib/lib/src/json.rs @@ -18,13 +18,12 @@ use std::ops::{Deref, DerefMut}; use std::io; use std::iter::FromIterator; -use tokio::io::AsyncReadExt; - use rocket::request::Request; use rocket::outcome::Outcome::*; -use rocket::data::{Transform::*, Transformed, Data, FromTransformedData, TransformFuture, FromDataFuture}; -use rocket::response::{self, Responder, content}; +use rocket::data::{Data, ByteUnit, Transform::*, Transformed}; +use 
rocket::data::{FromTransformedData, TransformFuture, FromDataFuture}; use rocket::http::Status; +use rocket::response::{self, Responder, content}; use serde::{Serialize, Serializer}; use serde::de::{Deserialize, Deserializer}; @@ -111,9 +110,6 @@ impl Json { } } -/// Default limit for JSON is 1MB. -const LIMIT: u64 = 1 << 20; - /// An error returned by the [`Json`] data guard when incoming data fails to /// serialize as JSON. #[derive(Debug)] @@ -128,6 +124,8 @@ pub enum JsonError<'a> { Parse(&'a str, serde_json::error::Error), } +const DEFAULT_LIMIT: ByteUnit = ByteUnit::Mebibyte(1); + impl<'a, T: Deserialize<'a>> FromTransformedData<'a> for Json { type Error = JsonError<'a>; type Owned = String; @@ -135,11 +133,9 @@ impl<'a, T: Deserialize<'a>> FromTransformedData<'a> for Json { fn transform<'r>(r: &'r Request<'_>, d: Data) -> TransformFuture<'r, Self::Owned, Self::Error> { Box::pin(async move { - let size_limit = r.limits().get("json").unwrap_or(LIMIT); - let mut s = String::with_capacity(512); - let mut reader = d.open().take(size_limit); - match reader.read_to_string(&mut s).await { - Ok(_) => Borrowed(Success(s)), + let size_limit = r.limits().get("json").unwrap_or(DEFAULT_LIMIT); + match d.open(size_limit).stream_to_string().await { + Ok(s) => Borrowed(Success(s)), Err(e) => Borrowed(Failure((Status::BadRequest, JsonError::Io(e)))) } }) diff --git a/contrib/lib/src/msgpack.rs b/contrib/lib/src/msgpack.rs index f2434e55..0f476d4d 100644 --- a/contrib/lib/src/msgpack.rs +++ b/contrib/lib/src/msgpack.rs @@ -20,9 +20,10 @@ use tokio::io::AsyncReadExt; use rocket::request::Request; use rocket::outcome::Outcome::*; -use rocket::data::{Data, FromTransformedData, FromDataFuture, Transform::*, TransformFuture, Transformed}; -use rocket::http::Status; +use rocket::data::{Data, ByteUnit, Transform::*, TransformFuture, Transformed}; +use rocket::data::{FromTransformedData, FromDataFuture}; use rocket::response::{self, content, Responder}; +use rocket::http::Status; use 
serde::Serialize; use serde::de::Deserialize; @@ -110,8 +111,7 @@ impl MsgPack { } } -/// Default limit for MessagePack is 1MB. -const LIMIT: u64 = 1 << 20; +const DEFAULT_LIMIT: ByteUnit = ByteUnit::Mebibyte(1); impl<'a, T: Deserialize<'a>> FromTransformedData<'a> for MsgPack { type Error = Error; @@ -120,9 +120,9 @@ impl<'a, T: Deserialize<'a>> FromTransformedData<'a> for MsgPack { fn transform<'r>(r: &'r Request<'_>, d: Data) -> TransformFuture<'r, Self::Owned, Self::Error> { Box::pin(async move { - let size_limit = r.limits().get("msgpack").unwrap_or(LIMIT); + let size_limit = r.limits().get("msgpack").unwrap_or(DEFAULT_LIMIT); let mut buf = Vec::new(); - let mut reader = d.open().take(size_limit); + let mut reader = d.open(size_limit); match reader.read_to_end(&mut buf).await { Ok(_) => Borrowed(Success(buf)), Err(e) => Borrowed(Failure((Status::BadRequest, Error::InvalidDataRead(e)))), diff --git a/core/codegen/tests/route-data.rs b/core/codegen/tests/route-data.rs index c8503343..f036d3c5 100644 --- a/core/codegen/tests/route-data.rs +++ b/core/codegen/tests/route-data.rs @@ -3,9 +3,8 @@ use rocket::{Request, Data}; use rocket::local::blocking::Client; use rocket::request::Form; -use rocket::data::{self, FromData}; +use rocket::data::{self, FromData, ToByteUnit}; use rocket::http::{RawStr, ContentType, Status}; -use rocket::tokio::io::AsyncReadExt; // Test that the data parameters works as expected. 
@@ -21,13 +20,10 @@ impl FromData for Simple { type Error = (); async fn from_data(_: &Request<'_>, data: Data) -> data::Outcome { - let mut string = String::new(); - let mut stream = data.open().take(64); - if let Err(_) = stream.read_to_string(&mut string).await { - return data::Outcome::Failure((Status::InternalServerError, ())); + match data.open(64.bytes()).stream_to_string().await { + Ok(string) => data::Outcome::Success(Simple(string)), + Err(_) => data::Outcome::Failure((Status::InternalServerError, ())), } - - data::Outcome::Success(Simple(string)) } } diff --git a/core/codegen/tests/route.rs b/core/codegen/tests/route.rs index cfedce77..0afe1682 100644 --- a/core/codegen/tests/route.rs +++ b/core/codegen/tests/route.rs @@ -9,10 +9,9 @@ use std::path::PathBuf; use rocket::http::ext::Normalize; use rocket::local::blocking::Client; -use rocket::data::{self, Data, FromData}; +use rocket::data::{self, Data, FromData, ToByteUnit}; use rocket::request::{Request, Form}; use rocket::http::{Status, RawStr, ContentType}; -use rocket::tokio::io::AsyncReadExt; // Use all of the code generation available at once. 
@@ -28,9 +27,7 @@ impl FromData for Simple { type Error = (); async fn from_data(_: &Request<'_>, data: Data) -> data::Outcome { - let mut string = String::new(); - let mut stream = data.open().take(64); - stream.read_to_string(&mut string).await.unwrap(); + let string = data.open(64.bytes()).stream_to_string().await.unwrap(); data::Outcome::Success(Simple(string)) } } diff --git a/core/lib/Cargo.toml b/core/lib/Cargo.toml index 8e95c0a3..c467e642 100644 --- a/core/lib/Cargo.toml +++ b/core/lib/Cargo.toml @@ -39,6 +39,7 @@ atty = "0.2" async-trait = "0.1" ref-cast = "1.0" atomic = "0.4" +ubyte = "0.9.1" [dependencies.pear] git = "https://github.com/SergioBenitez/Pear.git" diff --git a/core/lib/src/config/builder.rs b/core/lib/src/config/builder.rs index 6d60fa22..c29d849d 100644 --- a/core/lib/src/config/builder.rs +++ b/core/lib/src/config/builder.rs @@ -1,7 +1,8 @@ use std::collections::HashMap; use std::path::{Path, PathBuf}; -use crate::config::{Result, Config, Value, Environment, Limits, LoggingLevel}; +use crate::config::{Result, Config, Value, Environment, LoggingLevel}; +use crate::data::Limits; /// Structure following the builder pattern for building `Config` structures. 
#[derive(Clone)] @@ -189,10 +190,11 @@ impl ConfigBuilder { /// # Example /// /// ```rust - /// use rocket::config::{Config, Environment, Limits}; + /// use rocket::config::{Config, Environment}; + /// use rocket::data::{Limits, ToByteUnit}; /// /// let mut config = Config::build(Environment::Staging) - /// .limits(Limits::new().limit("json", 5 * (1 << 20))) + /// .limits(Limits::new().limit("json", 5.mebibytes())) /// .unwrap(); /// ``` pub fn limits(mut self, limits: Limits) -> Self { diff --git a/core/lib/src/config/config.rs b/core/lib/src/config/config.rs index 8efed903..01e4350a 100644 --- a/core/lib/src/config/config.rs +++ b/core/lib/src/config/config.rs @@ -7,6 +7,7 @@ use std::fmt; use crate::config::Environment::*; use crate::config::{Result, ConfigBuilder, Environment, ConfigError, LoggingLevel}; use crate::config::{FullConfig, Table, Value, Array, Datetime}; +use crate::data::Limits; use crate::http::private::Key; use super::custom_values::*; @@ -51,7 +52,7 @@ pub struct Config { pub(crate) secret_key: SecretKey, /// TLS configuration. pub(crate) tls: Option, - /// Streaming read size limits. + /// Streaming data limits. pub limits: Limits, /// Extra parameters that aren't part of Rocket's core config. 
pub extras: HashMap, @@ -96,10 +97,8 @@ impl Config { /// # Example /// /// ```rust - /// use rocket::config::Config; - /// /// # if false { - /// let config = Config::read().unwrap(); + /// let config = rocket::Config::read().unwrap(); /// # } /// ``` pub fn read() -> Result { @@ -115,10 +114,8 @@ impl Config { /// # Example /// /// ```rust - /// use rocket::config::Config; - /// /// # if false { - /// let config = Config::read_from("/var/my-config.toml").unwrap(); + /// let config = rocket::Config::read_from("/var/my-config.toml").unwrap(); /// # } /// ``` pub fn read_from>(path: P) -> Result { @@ -185,9 +182,7 @@ impl Config { /// # Example /// /// ```rust - /// use rocket::config::Config; - /// - /// let mut my_config = Config::active().unwrap(); + /// let mut my_config = rocket::Config::active().unwrap(); /// my_config.set_port(1001); /// ``` pub fn active() -> Result { @@ -453,9 +448,7 @@ impl Config { /// # Example /// /// ```rust - /// use rocket::config::Config; - /// - /// let mut config = Config::development(); + /// let mut config = rocket::Config::development(); /// /// // Set keep-alive timeout to 10 seconds. 
/// config.set_keep_alive(10); @@ -537,10 +530,10 @@ impl Config { /// # Example /// /// ```rust - /// use rocket::config::{Config, Limits}; + /// use rocket::data::{Limits, ToByteUnit}; /// - /// let mut config = Config::development(); - /// config.set_limits(Limits::default().limit("json", 4 * (1 << 20))); + /// let mut config = rocket::Config::development(); + /// config.set_limits(Limits::default().limit("json", 4.mebibytes())); /// ``` #[inline] pub fn set_limits(&mut self, limits: Limits) { @@ -563,11 +556,9 @@ impl Config { /// # Example /// /// ```rust - /// use rocket::config::Config; - /// /// # use rocket::config::ConfigError; /// # fn config_test() -> Result<(), ConfigError> { - /// let mut config = Config::development(); + /// let mut config = rocket::Config::development(); /// config.set_tls("/etc/ssl/my_certs.pem", "/etc/ssl/priv.key")?; /// # Ok(()) /// # } diff --git a/core/lib/src/config/custom_values.rs b/core/lib/src/config/custom_values.rs index e4772813..641f2cf6 100644 --- a/core/lib/src/config/custom_values.rs +++ b/core/lib/src/config/custom_values.rs @@ -4,6 +4,7 @@ use std::fmt; use crate::http::private::Key; use crate::config::{Result, Config, Value, ConfigError, LoggingLevel}; +use crate::data::Limits; #[derive(Clone)] pub enum SecretKey { @@ -53,154 +54,6 @@ pub struct TlsConfig { #[derive(Clone)] pub struct TlsConfig; -/// Mapping from data type to size limits. -/// -/// A `Limits` structure contains a mapping from a given data type ("forms", -/// "json", and so on) to the maximum size in bytes that should be accepted by a -/// Rocket application for that data type. For instance, if the limit for -/// "forms" is set to `256`, only 256 bytes from an incoming form request will -/// be read. 
-/// -/// # Defaults -/// -/// As documented in [`config`](crate::config), the default limits are as follows: -/// -/// * **forms**: 32KiB -/// -/// # Usage -/// -/// A `Limits` structure is created following the builder pattern: -/// -/// ```rust -/// use rocket::config::Limits; -/// -/// // Set a limit of 64KiB for forms and 3MiB for JSON. -/// let limits = Limits::new() -/// .limit("forms", 64 * 1024) -/// .limit("json", 3 * 1024 * 1024); -/// ``` -#[derive(Debug, Clone)] -pub struct Limits { - // We cache this internally but don't share that fact in the API. - pub(crate) forms: u64, - extra: Vec<(String, u64)> -} - -impl Default for Limits { - fn default() -> Limits { - // Default limit for forms is 32KiB. - Limits { forms: 32 * 1024, extra: Vec::new() } - } -} - -impl Limits { - /// Construct a new `Limits` structure with the default limits set. - /// - /// # Example - /// - /// ```rust - /// use rocket::config::Limits; - /// - /// let limits = Limits::new(); - /// assert_eq!(limits.get("forms"), Some(32 * 1024)); - /// ``` - #[inline] - pub fn new() -> Self { - Limits::default() - } - - /// Adds or replaces a limit in `self`, consuming `self` and returning a new - /// `Limits` structure with the added or replaced limit. 
- /// - /// # Example - /// - /// ```rust - /// use rocket::config::Limits; - /// - /// let limits = Limits::new() - /// .limit("json", 1 * 1024 * 1024); - /// - /// assert_eq!(limits.get("forms"), Some(32 * 1024)); - /// assert_eq!(limits.get("json"), Some(1 * 1024 * 1024)); - /// - /// let new_limits = limits.limit("json", 64 * 1024 * 1024); - /// assert_eq!(new_limits.get("json"), Some(64 * 1024 * 1024)); - /// ``` - pub fn limit>(mut self, name: S, limit: u64) -> Self { - let name = name.into(); - match name.as_str() { - "forms" => self.forms = limit, - _ => { - let mut found = false; - for tuple in &mut self.extra { - if tuple.0 == name { - tuple.1 = limit; - found = true; - break; - } - } - - if !found { - self.extra.push((name, limit)) - } - } - } - - self - } - - /// Retrieve the set limit, if any, for the data type with name `name`. - /// - /// # Example - /// - /// ```rust - /// use rocket::config::Limits; - /// - /// let limits = Limits::new() - /// .limit("json", 64 * 1024 * 1024); - /// - /// assert_eq!(limits.get("forms"), Some(32 * 1024)); - /// assert_eq!(limits.get("json"), Some(64 * 1024 * 1024)); - /// assert!(limits.get("msgpack").is_none()); - /// ``` - pub fn get(&self, name: &str) -> Option { - if name == "forms" { - return Some(self.forms); - } - - for &(ref key, val) in &self.extra { - if key == name { - return Some(val); - } - } - - None - } -} - -impl fmt::Display for Limits { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - fn fmt_size(n: u64, f: &mut fmt::Formatter<'_>) -> fmt::Result { - if (n & ((1 << 20) - 1)) == 0 { - write!(f, "{}MiB", n >> 20) - } else if (n & ((1 << 10) - 1)) == 0 { - write!(f, "{}KiB", n >> 10) - } else { - write!(f, "{}B", n) - } - } - - write!(f, "forms = ")?; - fmt_size(self.forms, f)?; - for &(ref key, val) in &self.extra { - write!(f, ", {}* = ", key)?; - fmt_size(val, f)?; - } - - Ok(()) - } -} - pub fn str<'a>(conf: &Config, name: &str, v: &'a Value) -> Result<&'a str> { 
v.as_str().ok_or_else(|| conf.bad_type(name, v.type_str(), "a string")) } @@ -266,7 +119,7 @@ pub fn limits(conf: &Config, name: &str, value: &Value) -> Result { let mut limits = Limits::default(); for (key, val) in table { let val = u64(conf, &format!("limits.{}", key), val)?; - limits = limits.limit(key.as_str(), val); + limits = limits.limit(key.as_str(), val.into()); } Ok(limits) diff --git a/core/lib/src/config/mod.rs b/core/lib/src/config/mod.rs index b85b6be4..bb6268de 100644 --- a/core/lib/src/config/mod.rs +++ b/core/lib/src/config/mod.rs @@ -198,7 +198,6 @@ use std::path::{Path, PathBuf}; use toml; -pub use self::custom_values::Limits; pub use toml::value::{Array, Map, Table, Value, Datetime}; pub use self::error::ConfigError; pub use self::environment::Environment; diff --git a/core/lib/src/data/data.rs b/core/lib/src/data/data.rs index b4126889..d54a93de 100644 --- a/core/lib/src/data/data.rs +++ b/core/lib/src/data/data.rs @@ -1,23 +1,20 @@ -use std::future::Future; -use std::io; -use std::path::Path; - -use tokio::io::AsyncWrite; - -use super::data_stream::DataStream; +use std::io::Cursor; use crate::http::hyper; -use crate::ext::{AsyncReadExt, AsyncReadBody}; +use crate::ext::AsyncReadBody; +use crate::tokio::io::AsyncReadExt; +use crate::data::data_stream::DataStream; +use crate::data::ByteUnit; /// The number of bytes to read into the "peek" buffer. -const PEEK_BYTES: usize = 512; +pub const PEEK_BYTES: usize = 512; /// Type representing the data in the body of an incoming request. /// /// This type is the only means by which the body of a request can be retrieved. /// This type is not usually used directly. 
Instead, types that implement -/// [`FromTransformedData`](crate::data::FromTransformedData) are used via code generation by -/// specifying the `data = ""` route parameter as follows: +/// [`FromTransformedData`](crate::data::FromTransformedData) are used via code +/// generation by specifying the `data = ""` route parameter as follows: /// /// ```rust /// # #[macro_use] extern crate rocket; @@ -27,8 +24,9 @@ const PEEK_BYTES: usize = 512; /// # fn main() { } /// ``` /// -/// Above, `DataGuard` can be any type that implements `FromTransformedData`. Note that -/// `Data` itself implements `FromTransformedData`. +/// Above, `DataGuard` can be any type that implements `FromTransformedData` (or +/// equivalently, `FromData`). Note that `Data` itself implements +/// `FromTransformedData`. /// /// # Reading Data /// @@ -50,162 +48,14 @@ pub struct Data { } impl Data { - /// Returns the raw data stream. - /// - /// The stream contains all of the data in the body of the request, - /// including that in the `peek` buffer. The method consumes the `Data` - /// instance. This ensures that a `Data` type _always_ represents _all_ of - /// the data in a request. - /// - /// # Example - /// - /// ```rust - /// use rocket::Data; - /// - /// fn handler(data: Data) { - /// let stream = data.open(); - /// } - /// ``` - pub fn open(mut self) -> DataStream { - let buffer = std::mem::replace(&mut self.buffer, vec![]); - let stream = std::mem::replace(&mut self.stream, AsyncReadBody::empty()); - DataStream(buffer, stream) - } - - pub(crate) fn from_hyp(body: hyper::Body) -> impl Future { + pub(crate) async fn from_hyp(body: hyper::Body) -> Data { // TODO.async: This used to also set the read timeout to 5 seconds. // Such a short read timeout is likely no longer necessary, but some // kind of idle timeout should be implemented. - Data::new(body) - } - - /// Retrieve the `peek` buffer. - /// - /// The peek buffer contains at most 512 bytes of the body of the request. 
- /// The actual size of the returned buffer varies by web request. The - /// [`peek_complete`](#method.peek_complete) method can be used to determine - /// if this buffer contains _all_ of the data in the body of the request. - /// - /// # Example - /// - /// ```rust - /// use rocket::Data; - /// - /// fn handler(data: Data) { - /// let peek = data.peek(); - /// } - /// ``` - #[inline(always)] - pub fn peek(&self) -> &[u8] { - if self.buffer.len() > PEEK_BYTES { - &self.buffer[..PEEK_BYTES] - } else { - &self.buffer - } - } - - /// Returns true if the `peek` buffer contains all of the data in the body - /// of the request. Returns `false` if it does not or if it is not known if - /// it does. - /// - /// # Example - /// - /// ```rust - /// use rocket::Data; - /// - /// fn handler(data: Data) { - /// if data.peek_complete() { - /// println!("All of the data: {:?}", data.peek()); - /// } - /// } - /// ``` - #[inline(always)] - pub fn peek_complete(&self) -> bool { - self.is_complete - } - - /// A helper method to write the body of the request to any `AsyncWrite` type. - /// - /// This method is identical to `tokio::io::copy(&mut data.open(), &mut writer)`. - /// - /// # Example - /// - /// ```rust - /// use std::io; - /// use rocket::Data; - /// - /// async fn handler(mut data: Data) -> io::Result { - /// // write all of the data to stdout - /// let written = data.stream_to(tokio::io::stdout()).await?; - /// Ok(format!("Wrote {} bytes.", written)) - /// } - /// ``` - #[inline(always)] - pub async fn stream_to(self, mut writer: W) -> io::Result { - let mut stream = self.open(); - tokio::io::copy(&mut stream, &mut writer).await - } - - /// A helper method to write the body of the request to a file at the path - /// determined by `path`. - /// - /// This method is identical to - /// `tokio::io::copy(&mut self.open(), &mut File::create(path).await?)`. 
- /// - /// # Example - /// - /// ```rust - /// use std::io; - /// use rocket::Data; - /// - /// async fn handler(mut data: Data) -> io::Result { - /// let written = data.stream_to_file("/static/file").await?; - /// Ok(format!("Wrote {} bytes to /static/file", written)) - /// } - /// ``` - #[inline(always)] - pub async fn stream_to_file>(self, path: P) -> io::Result { - let mut file = tokio::fs::File::create(path).await?; - self.stream_to(&mut file).await - } - - // Creates a new data object with an internal buffer `buf`, where the cursor - // in the buffer is at `pos` and the buffer has `cap` valid bytes. Thus, the - // bytes `vec[pos..cap]` are buffered and unread. The remainder of the data - // bytes can be read from `stream`. - #[inline(always)] - pub(crate) async fn new(body: hyper::Body) -> Data { - trace_!("Data::new({:?})", body); - - let mut stream = AsyncReadBody::from(body); - - let mut peek_buf = vec![0; PEEK_BYTES]; - - let eof = match stream.read_max(&mut peek_buf[..]).await { - Ok(n) => { - trace_!("Filled peek buf with {} bytes.", n); - - // TODO.async: This has not gone away, and I don't entirely - // understand what's happening here - - // We can use `set_len` here instead of `truncate`, but we'll - // take the performance hit to avoid `unsafe`. All of this code - // should go away when we migrate away from hyper 0.10.x. - - peek_buf.truncate(n); - n < PEEK_BYTES - } - Err(e) => { - error_!("Failed to read into peek buffer: {:?}.", e); - // Likewise here as above. - peek_buf.truncate(0); - false - } - }; - - trace_!("Peek bytes: {}/{} bytes.", peek_buf.len(), PEEK_BYTES); - Data { buffer: peek_buf, stream, is_complete: eof } + let stream = AsyncReadBody::from(body); + let buffer = Vec::with_capacity(PEEK_BYTES / 8); + Data { buffer, stream, is_complete: false } } /// This creates a `data` object from a local data source `data`. 
@@ -217,10 +67,101 @@ impl Data { is_complete: true, } } -} -impl std::borrow::Borrow<()> for Data { - fn borrow(&self) -> &() { - &() + /// Returns the raw data stream, limited to `limit` bytes. + /// + /// The stream contains all of the data in the body of the request, + /// including that in the `peek` buffer. The method consumes the `Data` + /// instance. This ensures that a `Data` type _always_ represents _all_ of + /// the data in a request. + /// + /// # Example + /// + /// ```rust + /// use rocket::data::{Data, ToByteUnit}; + /// + /// # const SIZE_LIMIT: u64 = 2 << 20; // 2MiB + /// fn handler(data: Data) { + /// let stream = data.open(2.mebibytes()); + /// } + /// ``` + pub fn open(self, limit: ByteUnit) -> DataStream { + let buffer_limit = std::cmp::min(self.buffer.len().into(), limit); + let stream_limit = limit - buffer_limit; + let buffer = Cursor::new(self.buffer).take(buffer_limit.into()); + let stream = self.stream.take(stream_limit.into()); + DataStream { buffer, stream } + } + + /// Retrieve at most `num` bytes from the `peek` buffer without consuming + /// `self`. + /// + /// The peek buffer contains at most 512 bytes of the body of the request. + /// The actual size of the returned buffer is the `min` of the request's + /// body, `num` and `512`. The [`peek_complete`](#method.peek_complete) + /// method can be used to determine if this buffer contains _all_ of the + /// data in the body of the request. + /// + /// # Example + /// + /// ```rust + /// use rocket::request::Request; + /// use rocket::data::{self, Data, FromData}; + /// # struct MyType; + /// # type MyError = String; + /// + /// #[rocket::async_trait] + /// impl FromData for MyType { + /// type Error = MyError; + /// + /// async fn from_data(req: &Request<'_>, mut data: Data) -> data::Outcome { + /// if data.peek(10).await != b"hi" { + /// return data::Outcome::Forward(data) + /// } + /// + /// /* ..
*/ + /// # unimplemented!() + /// } + /// } + /// ``` + pub async fn peek(&mut self, num: usize) -> &[u8] { + let num = std::cmp::min(PEEK_BYTES, num); + let mut len = self.buffer.len(); + if len >= num { + return &self.buffer[..num]; + } + + while len < num { + match self.stream.read_buf(&mut self.buffer).await { + Ok(0) => { self.is_complete = true; break }, + Ok(n) => len += n, + Err(e) => { + error_!("Failed to read into peek buffer: {:?}.", e); + break; + } + } + } + + &self.buffer[..std::cmp::min(len, num)] + } + + /// Returns true if the `peek` buffer contains all of the data in the body + /// of the request. Returns `false` if it does not or if it is not known if + /// it does. + /// + /// # Example + /// + /// ```rust + /// use rocket::data::Data; + /// + /// async fn handler(mut data: Data) { + /// if data.peek_complete() { + /// println!("All of the data: {:?}", data.peek(512).await); + /// } + /// } + /// ``` + #[inline(always)] + pub fn peek_complete(&self) -> bool { + self.is_complete } } diff --git a/core/lib/src/data/data_stream.rs b/core/lib/src/data/data_stream.rs index 0ebd04fb..b1e6b776 100644 --- a/core/lib/src/data/data_stream.rs +++ b/core/lib/src/data/data_stream.rs @@ -1,34 +1,132 @@ use std::pin::Pin; use std::task::{Context, Poll}; +use std::path::Path; +use std::io::{self, Cursor}; -use tokio::io::AsyncRead; +use tokio::io::{AsyncRead, AsyncWrite, AsyncReadExt, Take}; use crate::ext::AsyncReadBody; /// Raw data stream of a request body. /// /// This stream can only be obtained by calling -/// [`Data::open()`](crate::data::Data::open()). The stream contains all of the data -/// in the body of the request. It exposes no methods directly. Instead, it must -/// be used as an opaque [`AsyncRead`] structure. -pub struct DataStream(pub(crate) Vec, pub(crate) AsyncReadBody); +/// [`Data::open()`](crate::data::Data::open()). The stream contains all of the +/// data in the body of the request. It exposes no methods directly. 
Instead, it +/// must be used as an opaque [`AsyncRead`] structure. +pub struct DataStream { + pub(crate) buffer: Take>>, + pub(crate) stream: Take +} -// TODO.async: Consider implementing `AsyncBufRead` +impl DataStream { + /// A helper method to write the body of the request to any `AsyncWrite` + /// type. + /// + /// This method is identical to `tokio::io::copy(&mut self, &mut writer)`. + /// + /// # Example + /// + /// ```rust + /// use std::io; + /// use rocket::data::{Data, ToByteUnit}; + /// + /// async fn handler(mut data: Data) -> io::Result { + /// // write all of the data to stdout + /// let written = data.open(512.kibibytes()).stream_to(tokio::io::stdout()).await?; + /// Ok(format!("Wrote {} bytes.", written)) + /// } + /// ``` + #[inline(always)] + pub async fn stream_to(mut self, mut writer: W) -> io::Result + where W: AsyncWrite + Unpin + { + tokio::io::copy(&mut self, &mut writer).await + } + + /// A helper method to write the body of the request to a file at the path + /// determined by `path`. + /// + /// This method is identical to `self.stream_to(&mut + /// File::create(path).await?)`. + /// + /// # Example + /// + /// ```rust + /// use std::io; + /// use rocket::data::{Data, ToByteUnit}; + /// + /// async fn handler(mut data: Data) -> io::Result { + /// let written = data.open(1.megabytes()).stream_to_file("/static/file").await?; + /// Ok(format!("Wrote {} bytes to /static/file", written)) + /// } + /// ``` + #[inline(always)] + pub async fn stream_to_file>(self, path: P) -> io::Result { + let mut file = tokio::fs::File::create(path).await?; + self.stream_to(&mut file).await + } + + /// A helper method to write the body of the request to a `String`. 
+ /// + /// # Example + /// + /// ```rust + /// use std::io; + /// use rocket::data::{Data, ToByteUnit}; + /// + /// async fn handler(data: Data) -> io::Result { + /// data.open(10.bytes()).stream_to_string().await + /// } + /// ``` + pub async fn stream_to_string(mut self) -> io::Result { + let buf_len = self.buffer.get_ref().get_ref().len(); + let max_from_buf = std::cmp::min(buf_len, self.buffer.limit() as usize); + let capacity = std::cmp::min(max_from_buf, 1024); + let mut string = String::with_capacity(capacity); + self.read_to_string(&mut string).await?; + Ok(string) + } + + /// A helper method to write the body of the request to a `Vec`. + /// + /// # Example + /// + /// ```rust + /// use std::io; + /// use rocket::data::{Data, ToByteUnit}; + /// + /// async fn handler(data: Data) -> io::Result> { + /// data.open(4.kibibytes()).stream_to_vec().await + /// } + /// ``` + pub async fn stream_to_vec(mut self) -> io::Result> { + let buf_len = self.buffer.get_ref().get_ref().len(); + let max_from_buf = std::cmp::min(buf_len, self.buffer.limit() as usize); + let capacity = std::cmp::min(max_from_buf, 1024); + let mut vec = Vec::with_capacity(capacity); + self.read_to_end(&mut vec).await?; + Ok(vec) + } +} + +// TODO.async: Consider implementing `AsyncBufRead`. 
impl AsyncRead for DataStream { #[inline(always)] - fn poll_read(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll> { - trace_!("DataStream::poll_read()"); - if self.0.len() > 0 { - let count = std::cmp::min(buf.len(), self.0.len()); - trace_!("Reading peeked {} into dest {} = {} bytes", self.0.len(), buf.len(), count); - let next = self.0.split_off(count); - (&mut buf[..count]).copy_from_slice(&self.0[..]); - self.0 = next; - Poll::Ready(Ok(count)) - } else { - trace_!("Delegating to remaining stream"); - Pin::new(&mut self.1).poll_read(cx, buf) + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut [u8] + ) -> Poll> { + if self.buffer.limit() > 0 { + trace_!("DataStream::buffer_read()"); + match Pin::new(&mut self.buffer).poll_read(cx, buf) { + Poll::Ready(Ok(0)) => { /* fall through */ }, + poll => return poll, + } } + + trace_!("DataStream::stream_read()"); + Pin::new(&mut self.stream).poll_read(cx, buf) } } diff --git a/core/lib/src/data/from_data.rs b/core/lib/src/data/from_data.rs index b6774fed..24ac7c6a 100644 --- a/core/lib/src/data/from_data.rs +++ b/core/lib/src/data/from_data.rs @@ -7,7 +7,7 @@ use crate::outcome::{self, IntoOutcome}; use crate::outcome::Outcome::*; use crate::http::Status; use crate::request::Request; -use crate::data::Data; +use crate::data::{Data, ByteUnit}; /// Type alias for the `Outcome` of a `FromTransformedData` conversion. 
pub type Outcome = outcome::Outcome; @@ -197,13 +197,12 @@ pub type FromDataFuture<'fut, T, E> = BoxFuture<'fut, Outcome>; /// # struct Name<'a> { first: &'a str, last: &'a str, } /// use std::io::{self, Read}; /// -/// use tokio::io::AsyncReadExt; -/// -/// use rocket::{Request, Data}; -/// use rocket::data::{FromTransformedData, Outcome, Transform, Transformed, TransformFuture, FromDataFuture}; +/// use rocket::Request; +/// use rocket::data::{Data, Outcome, FromDataFuture, ByteUnit}; +/// use rocket::data::{FromTransformedData, Transform, Transformed, TransformFuture}; /// use rocket::http::Status; /// -/// const NAME_LIMIT: u64 = 256; +/// const NAME_LIMIT: ByteUnit = ByteUnit::Byte(256); /// /// enum NameError { /// Io(io::Error), @@ -217,10 +216,8 @@ pub type FromDataFuture<'fut, T, E> = BoxFuture<'fut, Outcome>; /// /// fn transform<'r>(_: &'r Request, data: Data) -> TransformFuture<'r, Self::Owned, Self::Error> { /// Box::pin(async move { -/// let mut stream = data.open().take(NAME_LIMIT); -/// let mut string = String::with_capacity((NAME_LIMIT / 2) as usize); -/// let outcome = match stream.read_to_string(&mut string).await { -/// Ok(_) => Outcome::Success(string), +/// let outcome = match data.open(NAME_LIMIT).stream_to_string().await { +/// Ok(string) => Outcome::Success(string), /// Err(e) => Outcome::Failure((Status::InternalServerError, NameError::Io(e))) /// }; /// @@ -231,9 +228,9 @@ pub type FromDataFuture<'fut, T, E> = BoxFuture<'fut, Outcome>; /// /// fn from_data(_: &'a Request, outcome: Transformed<'a, Self>) -> FromDataFuture<'a, Self, Self::Error> { /// Box::pin(async move { -/// // Retrieve a borrow to the now transformed `String` (an &str). This -/// // is only correct because we know we _always_ return a `Borrowed` from -/// // `transform` above. +/// // Retrieve a borrow to the now transformed `String` (an &str). +/// // This is only correct because we know we _always_ return a +/// // `Borrowed` from `transform` above. 
/// let string = try_outcome!(outcome.borrowed()); /// /// // Perform a crude, inefficient parse. @@ -407,7 +404,7 @@ pub trait FromTransformedData<'a>: Sized { impl<'a> FromTransformedData<'a> for Data { type Error = std::convert::Infallible; type Owned = Data; - type Borrowed = (); + type Borrowed = Data; #[inline(always)] fn transform<'r>(_: &'r Request<'_>, data: Data) -> TransformFuture<'r, Self::Owned, Self::Error> { @@ -486,12 +483,12 @@ impl<'a> FromTransformedData<'a> for Data { /// use std::io::Read; /// /// use rocket::{Request, Data}; -/// use rocket::data::{self, Outcome, FromData, FromDataFuture}; +/// use rocket::data::{self, Outcome, FromData, FromDataFuture, ByteUnit}; /// use rocket::http::{Status, ContentType}; /// use rocket::tokio::io::AsyncReadExt; /// /// // Always use a limit to prevent DoS attacks. -/// const LIMIT: u64 = 256; +/// const LIMIT: ByteUnit = ByteUnit::Byte(256); /// /// #[rocket::async_trait] /// impl FromData for Person { @@ -505,11 +502,10 @@ impl<'a> FromTransformedData<'a> for Data { /// } /// /// // Read the data into a String. -/// let mut string = String::new(); -/// let mut reader = data.open().take(LIMIT); -/// if let Err(e) = reader.read_to_string(&mut string).await { -/// return Outcome::Failure((Status::InternalServerError, format!("{:?}", e))); -/// } +/// let string = match data.open(LIMIT).stream_to_string().await { +/// Ok(string) => string, +/// Err(e) => return Outcome::Failure((Status::InternalServerError, format!("{}", e))) +/// }; /// /// // Split the string into two pieces at ':'. 
/// let (name, age) = match string.find(':') { @@ -550,7 +546,7 @@ pub trait FromData: Sized { impl<'a, T: FromData + 'a> FromTransformedData<'a> for T { type Error = T::Error; type Owned = Data; - type Borrowed = (); + type Borrowed = Data; #[inline(always)] fn transform<'r>(_: &'r Request<'_>, d: Data) -> TransformFuture<'r, Self::Owned, Self::Error> { @@ -612,12 +608,8 @@ impl FromData for String { #[inline(always)] async fn from_data(_: &Request<'_>, data: Data) -> Outcome { - use tokio::io::AsyncReadExt; - - let mut string = String::new(); - let mut reader = data.open(); - match reader.read_to_string(&mut string).await { - Ok(_) => Success(string), + match data.open(ByteUnit::max_value()).stream_to_string().await { + Ok(string) => Success(string), Err(e) => Failure((Status::BadRequest, e)), } } @@ -632,7 +624,7 @@ impl FromData for Vec { async fn from_data(_: &Request<'_>, data: Data) -> Outcome { use tokio::io::AsyncReadExt; - let mut stream = data.open(); + let mut stream = data.open(ByteUnit::max_value()); let mut buf = Vec::new(); match stream.read_to_end(&mut buf).await { Ok(_) => Success(buf), diff --git a/core/lib/src/data/limits.rs b/core/lib/src/data/limits.rs new file mode 100644 index 00000000..7bdb10ba --- /dev/null +++ b/core/lib/src/data/limits.rs @@ -0,0 +1,138 @@ +use std::fmt; + +use crate::data::{ByteUnit, ToByteUnit}; + +/// Mapping from data type to size limits. +/// +/// A `Limits` structure contains a mapping from a given data type ("forms", +/// "json", and so on) to the maximum size in bytes that should be accepted by a +/// Rocket application for that data type. For instance, if the limit for +/// "forms" is set to `256`, only 256 bytes from an incoming form request will +/// be read. 
+/// +/// # Defaults +/// +/// As documented in [`config`](crate::config), the default limits are as follows: +/// +/// * **forms**: 32KiB +/// +/// # Usage +/// +/// A `Limits` structure is created following the builder pattern: +/// +/// ```rust +/// use rocket::data::{Limits, ToByteUnit}; +/// +/// // Set a limit of 64KiB for forms and 3MiB for JSON. +/// let limits = Limits::new() +/// .limit("forms", 64.kibibytes()) +/// .limit("json", 3.mebibytes()); +/// ``` +#[derive(Debug, Clone)] +pub struct Limits { + // We cache this internally but don't share that fact in the API. + pub(crate) forms: ByteUnit, + extra: Vec<(String, ByteUnit)> +} + +impl Default for Limits { + fn default() -> Limits { + // Default limit for forms is 32KiB. + Limits { forms: 32.kibibytes(), extra: Vec::new() } + } +} + +impl Limits { + /// Construct a new `Limits` structure with the default limits set. + /// + /// # Example + /// + /// ```rust + /// use rocket::data::{Limits, ToByteUnit}; + /// + /// let limits = Limits::new(); + /// assert_eq!(limits.get("forms"), Some(32.kibibytes())); + /// ``` + #[inline] + pub fn new() -> Self { + Limits::default() + } + + /// Adds or replaces a limit in `self`, consuming `self` and returning a new + /// `Limits` structure with the added or replaced limit. 
+ /// + /// # Example + /// + /// ```rust + /// use rocket::data::{Limits, ToByteUnit}; + /// + /// let limits = Limits::new().limit("json", 1.mebibytes()); + /// + /// assert_eq!(limits.get("forms"), Some(32.kibibytes())); + /// assert_eq!(limits.get("json"), Some(1.mebibytes())); + /// + /// let new_limits = limits.limit("json", 64.mebibytes()); + /// assert_eq!(new_limits.get("json"), Some(64.mebibytes())); + /// ``` + pub fn limit>(mut self, name: S, limit: ByteUnit) -> Self { + let name = name.into(); + match name.as_str() { + "forms" => self.forms = limit, + _ => { + let mut found = false; + for tuple in &mut self.extra { + if tuple.0 == name { + tuple.1 = limit; + found = true; + break; + } + } + + if !found { + self.extra.push((name, limit)) + } + } + } + + self + } + + /// Retrieve the set limit, if any, for the data type with name `name`. + /// + /// # Example + /// + /// ```rust + /// use rocket::data::{Limits, ToByteUnit}; + /// + /// let limits = Limits::new().limit("json", 64.mebibytes()); + /// + /// assert_eq!(limits.get("forms"), Some(32.kibibytes())); + /// assert_eq!(limits.get("json"), Some(64.mebibytes())); + /// assert!(limits.get("msgpack").is_none()); + /// ``` + pub fn get(&self, name: &str) -> Option { + if name == "forms" { + return Some(self.forms); + } + + for &(ref key, val) in &self.extra { + if key == name { + return Some(val); + } + } + + None + } +} + +impl fmt::Display for Limits { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "forms = {}", self.forms)?; + for (key, val) in &self.extra { + write!(f, ", {}* = {}", key, val)?; + } + + Ok(()) + } +} + diff --git a/core/lib/src/data/mod.rs b/core/lib/src/data/mod.rs index f127efe9..e93d4337 100644 --- a/core/lib/src/data/mod.rs +++ b/core/lib/src/data/mod.rs @@ -3,7 +3,11 @@ mod data; mod data_stream; mod from_data; +mod limits; pub use self::data::Data; pub use self::data_stream::DataStream; -pub use self::from_data::{FromTransformedData, FromDataFuture, 
FromData, Outcome, Transform, Transformed, TransformFuture}; +pub use self::from_data::{FromData, Outcome, FromTransformedData, FromDataFuture}; +pub use self::from_data::{Transform, Transformed, TransformFuture}; +pub use self::limits::Limits; +pub use ubyte::{ByteUnit, ToByteUnit}; diff --git a/core/lib/src/ext.rs b/core/lib/src/ext.rs index 81236216..97bfa760 100644 --- a/core/lib/src/ext.rs +++ b/core/lib/src/ext.rs @@ -2,11 +2,10 @@ use std::io::{self, Cursor}; use std::pin::Pin; use std::task::{Poll, Context}; -use futures::{ready, future::BoxFuture, stream::Stream}; -use tokio::io::{AsyncRead, AsyncReadExt as _}; +use futures::{ready, stream::Stream}; +use tokio::io::AsyncRead; -use crate::http::hyper; -use hyper::{Bytes, HttpBody}; +use crate::http::hyper::{self, Bytes, HttpBody}; pub struct IntoBytesStream { inner: R, @@ -29,6 +28,7 @@ impl Stream for IntoBytesStream Poll::Ready(Err(e)) => Poll::Ready(Some(Err(e))), Poll::Ready(Ok(n)) if n == 0 => Poll::Ready(None), Poll::Ready(Ok(n)) => { + // FIXME(perf). 
let mut next = std::mem::replace(buffer, vec![0; buf_size]); next.truncate(n); Poll::Ready(Some(Ok(Bytes::from(next)))) @@ -37,38 +37,20 @@ impl Stream for IntoBytesStream } } -pub trait AsyncReadExt: AsyncRead { - fn into_bytes_stream(self, buf_size: usize) -> IntoBytesStream where Self: Sized { +pub trait AsyncReadExt: AsyncRead + Sized { + fn into_bytes_stream(self, buf_size: usize) -> IntoBytesStream { IntoBytesStream { inner: self, buf_size, buffer: vec![0; buf_size] } } - - fn read_max<'a>(&'a mut self, mut buf: &'a mut [u8]) -> BoxFuture<'_, io::Result> - where Self: Send + Unpin - { - Box::pin(async move { - let start_len = buf.len(); - while !buf.is_empty() { - match self.read(buf).await { - Ok(0) => break, - Ok(n) => buf = &mut buf[n..], - Err(ref e) if e.kind() == io::ErrorKind::Interrupted => {} - Err(e) => return Err(e), - } - } - - Ok(start_len - buf.len()) - }) - } } impl AsyncReadExt for T { } pub struct AsyncReadBody { inner: hyper::Body, - state: AsyncReadBodyState, + state: State, } -enum AsyncReadBodyState { +enum State { Pending, Partial(Cursor), Done, @@ -76,37 +58,43 @@ enum AsyncReadBodyState { impl AsyncReadBody { pub fn empty() -> Self { - Self { inner: hyper::Body::empty(), state: AsyncReadBodyState::Done } + Self { inner: hyper::Body::empty(), state: State::Done } } } impl From for AsyncReadBody { fn from(body: hyper::Body) -> Self { - Self { inner: body, state: AsyncReadBodyState::Pending } + Self { inner: body, state: State::Pending } } } impl AsyncRead for AsyncReadBody { - fn poll_read(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll> { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut [u8] + ) -> Poll> { loop { match self.state { - AsyncReadBodyState::Pending => { + State::Pending => { match ready!(Pin::new(&mut self.inner).poll_data(cx)) { - Some(Ok(bytes)) => self.state = AsyncReadBodyState::Partial(Cursor::new(bytes)), - Some(Err(e)) => return 
Poll::Ready(Err(io::Error::new(io::ErrorKind::Other, e))), - None => self.state = AsyncReadBodyState::Done, + Some(Ok(bytes)) => { + self.state = State::Partial(Cursor::new(bytes)); + } + Some(Err(e)) => { + let error = io::Error::new(io::ErrorKind::Other, e); + return Poll::Ready(Err(error)); + } + None => self.state = State::Done, } }, - AsyncReadBodyState::Partial(ref mut cursor) => { + State::Partial(ref mut cursor) => { match ready!(Pin::new(cursor).poll_read(cx, buf)) { - Ok(n) if n == 0 => { - self.state = AsyncReadBodyState::Pending; - } - Ok(n) => return Poll::Ready(Ok(n)), - Err(e) => return Poll::Ready(Err(e)), + Ok(n) if n == 0 => self.state = State::Pending, + result => return Poll::Ready(result), } } - AsyncReadBodyState::Done => return Poll::Ready(Ok(0)), + State::Done => return Poll::Ready(Ok(0)), } } } diff --git a/core/lib/src/lib.rs b/core/lib/src/lib.rs index ed8eccde..e2d52f31 100644 --- a/core/lib/src/lib.rs +++ b/core/lib/src/lib.rs @@ -142,7 +142,7 @@ pub fn ignite() -> Rocket { /// Alias to [`Rocket::custom()`]. Creates a new instance of `Rocket` with a /// custom configuration. -pub fn custom(config: config::Config) -> Rocket { +pub fn custom(config: Config) -> Rocket { Rocket::custom(config) } diff --git a/core/lib/src/local/asynchronous/request.rs b/core/lib/src/local/asynchronous/request.rs index bfa43ba8..0ef11df5 100644 --- a/core/lib/src/local/asynchronous/request.rs +++ b/core/lib/src/local/asynchronous/request.rs @@ -84,8 +84,8 @@ impl<'c> LocalRequest<'c> { } // Actually dispatch the request. 
- let data = Data::local(self.data); - let token = rocket.preprocess_request(&mut self.request, &data).await; + let mut data = Data::local(self.data); + let token = rocket.preprocess_request(&mut self.request, &mut data).await; let response = LocalResponse::new(self.request, move |request| { rocket.dispatch(token, request, data) }).await; diff --git a/core/lib/src/request/form/form.rs b/core/lib/src/request/form/form.rs index a8351d38..6e0f4832 100644 --- a/core/lib/src/request/form/form.rs +++ b/core/lib/src/request/form/form.rs @@ -1,7 +1,5 @@ use std::ops::Deref; -use tokio::io::AsyncReadExt; - use crate::outcome::Outcome::*; use crate::request::{Request, form::{FromForm, FormItems, FormDataError}}; use crate::data::{Outcome, Transform, Transformed, Data, FromTransformedData, TransformFuture, FromDataFuture}; @@ -193,21 +191,18 @@ impl<'f, T: FromForm<'f> + Send + 'f> FromTransformedData<'f> for Form { data: Data ) -> TransformFuture<'r, Self::Owned, Self::Error> { Box::pin(async move { - use std::cmp::min; - if !request.content_type().map_or(false, |ct| ct.is_form()) { warn_!("Form data does not have form content type."); return Transform::Borrowed(Forward(data)); } - let limit = request.limits().forms; - let mut stream = data.open().take(limit); - let mut form_string = String::with_capacity(min(4096, limit) as usize); - if let Err(e) = stream.read_to_string(&mut form_string).await { - return Transform::Borrowed(Failure((Status::InternalServerError, FormDataError::Io(e)))); + match data.open(request.limits().forms).stream_to_string().await { + Ok(form_string) => Transform::Borrowed(Success(form_string)), + Err(e) => { + let err = (Status::InternalServerError, FormDataError::Io(e)); + Transform::Borrowed(Failure(err)) + } } - - Transform::Borrowed(Success(form_string)) }) } diff --git a/core/lib/src/request/request.rs b/core/lib/src/request/request.rs index 10d51c93..f7600a37 100644 --- a/core/lib/src/request/request.rs +++ b/core/lib/src/request/request.rs @@ 
-12,14 +12,12 @@ use atomic::Atomic; use crate::request::{FromParam, FromSegments, FromRequest, Outcome}; use crate::request::{FromFormValue, FormItems, FormItem}; -use crate::rocket::Rocket; -use crate::shutdown::Shutdown; -use crate::router::Route; -use crate::config::{Config, Limits}; +use crate::{Rocket, Config, Shutdown, Route}; use crate::http::{hyper, uri::{Origin, Segments}}; use crate::http::{Method, Header, HeaderMap, Cookies}; use crate::http::{RawStr, ContentType, Accept, MediaType}; use crate::http::private::{Indexed, SmallVec, CookieJar}; +use crate::data::Limits; type Indices = (usize, usize); @@ -513,7 +511,7 @@ impl<'r> Request<'r> { } } - /// Returns the configured application receive limits. + /// Returns the configured application data limits. /// /// # Example /// diff --git a/core/lib/src/response/debug.rs b/core/lib/src/response/debug.rs index 354d3e1f..8ff46e36 100644 --- a/core/lib/src/response/debug.rs +++ b/core/lib/src/response/debug.rs @@ -19,16 +19,13 @@ use yansi::Paint; /// ```rust /// use std::io; /// -/// use tokio::io::AsyncReadExt; -/// /// # use rocket::post; -/// use rocket::Data; +/// use rocket::data::{Data, ToByteUnit}; /// use rocket::response::Debug; /// /// #[post("/", format = "plain", data = "")] /// async fn post(data: Data) -> Result> { -/// let mut name = String::with_capacity(32); -/// data.open().take(32).read_to_string(&mut name).await?; +/// let name = data.open(32.bytes()).stream_to_string().await?; /// Ok(name) /// } /// ``` diff --git a/core/lib/src/rocket.rs b/core/lib/src/rocket.rs index 6c144a48..4b69bc88 100644 --- a/core/lib/src/rocket.rs +++ b/core/lib/src/rocket.rs @@ -1,5 +1,4 @@ use std::{io, mem}; -use std::cmp::min; use std::sync::Arc; use std::collections::HashMap; @@ -211,10 +210,10 @@ async fn hyper_service_fn( }; // Retrieve the data from the hyper body. 
- let data = Data::from_hyp(h_body).await; + let mut data = Data::from_hyp(h_body).await; // Dispatch the request to get a response, then write that response out. - let token = rocket.preprocess_request(&mut req, &data).await; + let token = rocket.preprocess_request(&mut req, &mut data).await; let r = rocket.dispatch(token, &mut req, data).await; rocket.issue_response(r, tx).await; }); @@ -303,16 +302,16 @@ impl Rocket { pub(crate) async fn preprocess_request( &self, req: &mut Request<'_>, - data: &Data + data: &mut Data ) -> Token { // Check if this is a form and if the form contains the special _method // field which we use to reinterpret the request's method. - let data_len = data.peek().len(); let (min_len, max_len) = ("_method=get".len(), "_method=delete".len()); + let peek_buffer = data.peek(max_len).await; let is_form = req.content_type().map_or(false, |ct| ct.is_form()); - if is_form && req.method() == Method::Post && data_len >= min_len { - if let Ok(form) = std::str::from_utf8(&data.peek()[..min(data_len, max_len)]) { + if is_form && req.method() == Method::Post && peek_buffer.len() >= min_len { + if let Ok(form) = std::str::from_utf8(peek_buffer) { let method: Option> = FormItems::from(form) .filter(|item| item.key.as_str() == "_method") .map(|item| item.value.parse()) diff --git a/core/lib/tests/limits.rs b/core/lib/tests/limits.rs index 067d1230..5fe7804c 100644 --- a/core/lib/tests/limits.rs +++ b/core/lib/tests/limits.rs @@ -14,13 +14,14 @@ fn index(form: Form) -> String { mod limits_tests { use rocket; - use rocket::config::{Environment, Config, Limits}; + use rocket::config::{Environment, Config}; use rocket::local::blocking::Client; use rocket::http::{Status, ContentType}; + use rocket::data::Limits; fn rocket_with_forms_limit(limit: u64) -> rocket::Rocket { let config = Config::build(Environment::Development) - .limits(Limits::default().limit("forms", limit)) + .limits(Limits::default().limit("forms", limit.into())) .unwrap(); 
rocket::custom(config).mount("/", routes![super::index]) diff --git a/examples/content_types/src/main.rs b/examples/content_types/src/main.rs index 0ebf2236..ba972174 100644 --- a/examples/content_types/src/main.rs +++ b/examples/content_types/src/main.rs @@ -4,8 +4,8 @@ use std::io; -use rocket::tokio::io::AsyncReadExt; -use rocket::{Request, data::Data}; +use rocket::request::Request; +use rocket::data::{Data, ToByteUnit}; use rocket::response::{Debug, content::{Json, Html}}; use serde::{Serialize, Deserialize}; @@ -27,7 +27,7 @@ struct Person { #[get("//", format = "json")] fn get_hello(name: String, age: u8) -> Json { // NOTE: In a real application, we'd use `rocket_contrib::json::Json`. - let person = Person { name: name, age: age, }; + let person = Person { name, age }; Json(serde_json::to_string(&person).unwrap()) } @@ -39,10 +39,8 @@ fn get_hello(name: String, age: u8) -> Json { // use `contrib::Json` to automatically serialize a type into JSON. #[post("/", format = "plain", data = "")] async fn post_hello(age: u8, name_data: Data) -> Result, Debug> { - let mut name = String::with_capacity(32); - let mut stream = name_data.open().take(32); - stream.read_to_string(&mut name).await?; - let person = Person { name: name, age: age, }; + let name = name_data.open(64.bytes()).stream_to_string().await?; + let person = Person { name, age }; // NOTE: In a real application, we'd use `rocket_contrib::json::Json`. 
Ok(Json(serde_json::to_string(&person).expect("valid JSON"))) } diff --git a/examples/manual_routes/src/main.rs b/examples/manual_routes/src/main.rs index 7d5c34cf..aa16a3bb 100644 --- a/examples/manual_routes/src/main.rs +++ b/examples/manual_routes/src/main.rs @@ -3,7 +3,8 @@ mod tests; use std::env; -use rocket::{Request, Route, Data}; +use rocket::{Request, Route}; +use rocket::data::{Data, ToByteUnit}; use rocket::http::{Status, RawStr, Method::*}; use rocket::response::{Responder, status::Custom}; use rocket::handler::{Handler, Outcome, HandlerFuture}; @@ -47,7 +48,7 @@ fn upload<'r>(req: &'r Request, data: Data) -> HandlerFuture<'r> { let file = File::create(env::temp_dir().join("upload.txt")).await; if let Ok(file) = file { - if let Ok(n) = data.stream_to(file).await { + if let Ok(n) = data.open(2.mebibytes()).stream_to(file).await { return Outcome::from(req, format!("OK: {} bytes uploaded.", n)); } diff --git a/examples/pastebin/src/main.rs b/examples/pastebin/src/main.rs index 90006d45..f3b13c1a 100644 --- a/examples/pastebin/src/main.rs +++ b/examples/pastebin/src/main.rs @@ -5,7 +5,7 @@ mod paste_id; use std::io; -use rocket::Data; +use rocket::data::{Data, ToByteUnit}; use rocket::response::{content::Plain, Debug}; use rocket::tokio::fs::File; @@ -20,7 +20,7 @@ async fn upload(paste: Data) -> Result> { let filename = format!("upload/{id}", id = id); let url = format!("{host}/{id}\n", host = HOST, id = id); - paste.stream_to_file(filename).await?; + paste.open(128.kibibytes()).stream_to_file(filename).await?; Ok(url) } diff --git a/examples/raw_upload/src/main.rs b/examples/raw_upload/src/main.rs index 37e88e8d..22f971a6 100644 --- a/examples/raw_upload/src/main.rs +++ b/examples/raw_upload/src/main.rs @@ -3,11 +3,13 @@ #[cfg(test)] mod tests; use std::{io, env}; -use rocket::{Data, response::Debug}; +use rocket::data::{Data, ToByteUnit}; +use rocket::response::Debug; #[post("/upload", format = "plain", data = "")] async fn upload(data: Data) -> Result> 
{ - Ok(data.stream_to_file(env::temp_dir().join("upload.txt")).await?.to_string()) + let path = env::temp_dir().join("upload.txt"); + Ok(data.open(128.kibibytes()).stream_to_file(path).await?.to_string()) } #[get("/")] diff --git a/scripts/test.sh b/scripts/test.sh index e9459b06..032ba022 100755 --- a/scripts/test.sh +++ b/scripts/test.sh @@ -70,7 +70,9 @@ echo ":: Checking for trailing whitespace..." ensure_trailing_whitespace_free echo ":: Updating dependencies..." -$CARGO update +if ! $CARGO update ; then + echo " WARNING: Update failed! Proceeding with possibly outdated deps..." +fi if [ "$1" = "--contrib" ]; then FEATURES=( diff --git a/site/guide/10-pastebin.md b/site/guide/10-pastebin.md index 17c464ac..a104a252 100644 --- a/site/guide/10-pastebin.md +++ b/site/guide/10-pastebin.md @@ -260,8 +260,8 @@ Here's our version (in `src/main.rs`): # fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { Ok(()) } # } -use rocket::Data; use rocket::response::Debug; +use rocket::data::{Data, ToByteUnit}; #[post("/", data = "")] async fn upload(paste: Data) -> Result> { @@ -269,13 +269,14 @@ async fn upload(paste: Data) -> Result> { let filename = format!("upload/{id}", id = id); let url = format!("{host}/{id}\n", host = "http://localhost:8000", id = id); - // Write the paste out to the file and return the URL. - paste.stream_to_file(filename).await?; + // Write the paste out, limited to 128KiB, and return the URL. + paste.open(128.kibibytes()).stream_to_file(filename).await?; Ok(url) } ``` -Ensure that the route is mounted at the root path: +Note the [`kibibytes()`] method call: this method comes from the [`ToByteUnit`] +extension trait. Ensure that the route is mounted at the root path: ```rust # #[macro_use] extern crate rocket; @@ -310,6 +311,9 @@ cat upload/* # ensure that contents are correct Note that since we haven't created a `GET /` route, visiting the returned URL will result in a **404**. We'll fix that now. 
+[`kibibytes()`]: @api/rocket/data/trait.ToByteUnit.html#tymethod.kibibytes +[`ToByteUnit`]: @api/rocket/data/trait.ToByteUnit.html + ## Retrieving Pastes The final step is to create the `retrieve` route which, given an ``, will @@ -443,9 +447,9 @@ through some of them to get a better feel for Rocket. Here are some ideas: the two `POST /` routes should be called. * Support **deletion** of pastes by adding a new `DELETE /` route. Use `PasteId` to validate ``. - * **Limit the upload** to a maximum size. If the upload exceeds that size, - return a **206** partial status code. Otherwise, return a **201** created - status code. + * Indicate **partial uploads** with a **206** partial status code. If the user + uploads a paste that meets or exceeds the allowed limit, return a **206** + partial status code. Otherwise, return a **201** created status code. * Set the `Content-Type` of the return value in `upload` and `retrieve` to `text/plain`. * **Return a unique "key"** after each upload and require that the key is diff --git a/site/guide/4-requests.md b/site/guide/4-requests.md index 17086260..f23629ac 100644 --- a/site/guide/4-requests.md +++ b/site/guide/4-requests.md @@ -1037,36 +1037,39 @@ The only condition is that the generic type in `Json` implements the Sometimes you just want to handle incoming data directly. For example, you might want to stream the incoming data out to a file. Rocket makes this as simple as -possible via the [`Data`](@api/rocket/data/struct.Data.html) -type: +possible via the [`Data`](@api/rocket/data/struct.Data.html) type: ```rust # #[macro_use] extern crate rocket; # fn main() {} -use rocket::Data; +use rocket::data::{Data, ToByteUnit}; use rocket::response::Debug; #[post("/upload", format = "plain", data = "")] async fn upload(data: Data) -> Result> { - Ok(data.stream_to_file("/tmp/upload.txt").await.map(|n| n.to_string())?) 
+ let bytes_written = data.open(128.kibibytes()) + .stream_to_file("/tmp/upload.txt") + .await?; + + Ok(bytes_written.to_string()) } ``` The route above accepts any `POST` request to the `/upload` path with -`Content-Type: text/plain` The incoming data is streamed out to -`tmp/upload.txt`, and the number of bytes written is returned as a plain text -response if the upload succeeds. If the upload fails, an error response is -returned. The handler above is complete. It really is that simple! See the -[GitHub example code](@example/raw_upload) for the full crate. +`Content-Type: text/plain`. At most 128KiB (`128 << 10` bytes) of the incoming +data are streamed out to `/tmp/upload.txt`, and the number of bytes written is +returned as a plain text response if the upload succeeds. If the upload fails, +an error response is returned. The handler above is complete. It really is that +simple! See the [GitHub example code](@example/raw_upload) for the full crate. -! warning: You should _always_ set limits when reading incoming data. +! note: Rocket requires setting limits when reading incoming data. - To prevent DoS attacks, you should limit the amount of data you're willing to - accept. The [`take()`] reader adapter makes doing this easy: - `data.open().take(LIMIT)`. - - [`take()`]: https://doc.rust-lang.org/std/io/trait.Read.html#method.take + To aid in preventing DoS attacks, Rocket requires you to specify, as a + [`ByteUnit`](@api/rocket/data/struct.ByteUnit.html), the amount of data you're + willing to accept from the client when `open`ing a data stream. The + [`ToByteUnit`](@api/rocket/data/trait.ToByteUnit.html) trait makes specifying + such a value as idiomatic as `128.kibibytes()`. ## Async Routes