Require data limits on 'Data::open()'.

Closes #1325.
This commit is contained in:
Sergio Benitez 2020-07-30 04:17:38 -07:00
parent 45b4436ed3
commit 549c9241c4
29 changed files with 525 additions and 529 deletions

View File

@ -18,13 +18,12 @@ use std::ops::{Deref, DerefMut};
use std::io; use std::io;
use std::iter::FromIterator; use std::iter::FromIterator;
use tokio::io::AsyncReadExt;
use rocket::request::Request; use rocket::request::Request;
use rocket::outcome::Outcome::*; use rocket::outcome::Outcome::*;
use rocket::data::{Transform::*, Transformed, Data, FromTransformedData, TransformFuture, FromDataFuture}; use rocket::data::{Data, ByteUnit, Transform::*, Transformed};
use rocket::response::{self, Responder, content}; use rocket::data::{FromTransformedData, TransformFuture, FromDataFuture};
use rocket::http::Status; use rocket::http::Status;
use rocket::response::{self, Responder, content};
use serde::{Serialize, Serializer}; use serde::{Serialize, Serializer};
use serde::de::{Deserialize, Deserializer}; use serde::de::{Deserialize, Deserializer};
@ -111,9 +110,6 @@ impl<T> Json<T> {
} }
} }
/// Default limit for JSON is 1MiB.
const LIMIT: u64 = 1 << 20;
/// An error returned by the [`Json`] data guard when incoming data fails to /// An error returned by the [`Json`] data guard when incoming data fails to
/// serialize as JSON. /// serialize as JSON.
#[derive(Debug)] #[derive(Debug)]
@ -128,6 +124,8 @@ pub enum JsonError<'a> {
Parse(&'a str, serde_json::error::Error), Parse(&'a str, serde_json::error::Error),
} }
const DEFAULT_LIMIT: ByteUnit = ByteUnit::Mebibyte(1);
impl<'a, T: Deserialize<'a>> FromTransformedData<'a> for Json<T> { impl<'a, T: Deserialize<'a>> FromTransformedData<'a> for Json<T> {
type Error = JsonError<'a>; type Error = JsonError<'a>;
type Owned = String; type Owned = String;
@ -135,11 +133,9 @@ impl<'a, T: Deserialize<'a>> FromTransformedData<'a> for Json<T> {
fn transform<'r>(r: &'r Request<'_>, d: Data) -> TransformFuture<'r, Self::Owned, Self::Error> { fn transform<'r>(r: &'r Request<'_>, d: Data) -> TransformFuture<'r, Self::Owned, Self::Error> {
Box::pin(async move { Box::pin(async move {
let size_limit = r.limits().get("json").unwrap_or(LIMIT); let size_limit = r.limits().get("json").unwrap_or(DEFAULT_LIMIT);
let mut s = String::with_capacity(512); match d.open(size_limit).stream_to_string().await {
let mut reader = d.open().take(size_limit); Ok(s) => Borrowed(Success(s)),
match reader.read_to_string(&mut s).await {
Ok(_) => Borrowed(Success(s)),
Err(e) => Borrowed(Failure((Status::BadRequest, JsonError::Io(e)))) Err(e) => Borrowed(Failure((Status::BadRequest, JsonError::Io(e))))
} }
}) })

View File

@ -20,9 +20,10 @@ use tokio::io::AsyncReadExt;
use rocket::request::Request; use rocket::request::Request;
use rocket::outcome::Outcome::*; use rocket::outcome::Outcome::*;
use rocket::data::{Data, FromTransformedData, FromDataFuture, Transform::*, TransformFuture, Transformed}; use rocket::data::{Data, ByteUnit, Transform::*, TransformFuture, Transformed};
use rocket::http::Status; use rocket::data::{FromTransformedData, FromDataFuture};
use rocket::response::{self, content, Responder}; use rocket::response::{self, content, Responder};
use rocket::http::Status;
use serde::Serialize; use serde::Serialize;
use serde::de::Deserialize; use serde::de::Deserialize;
@ -110,8 +111,7 @@ impl<T> MsgPack<T> {
} }
} }
/// Default limit for MessagePack is 1MB. const DEFAULT_LIMIT: ByteUnit = ByteUnit::Mebibyte(1);
const LIMIT: u64 = 1 << 20;
impl<'a, T: Deserialize<'a>> FromTransformedData<'a> for MsgPack<T> { impl<'a, T: Deserialize<'a>> FromTransformedData<'a> for MsgPack<T> {
type Error = Error; type Error = Error;
@ -120,9 +120,9 @@ impl<'a, T: Deserialize<'a>> FromTransformedData<'a> for MsgPack<T> {
fn transform<'r>(r: &'r Request<'_>, d: Data) -> TransformFuture<'r, Self::Owned, Self::Error> { fn transform<'r>(r: &'r Request<'_>, d: Data) -> TransformFuture<'r, Self::Owned, Self::Error> {
Box::pin(async move { Box::pin(async move {
let size_limit = r.limits().get("msgpack").unwrap_or(LIMIT); let size_limit = r.limits().get("msgpack").unwrap_or(DEFAULT_LIMIT);
let mut buf = Vec::new(); let mut buf = Vec::new();
let mut reader = d.open().take(size_limit); let mut reader = d.open(size_limit);
match reader.read_to_end(&mut buf).await { match reader.read_to_end(&mut buf).await {
Ok(_) => Borrowed(Success(buf)), Ok(_) => Borrowed(Success(buf)),
Err(e) => Borrowed(Failure((Status::BadRequest, Error::InvalidDataRead(e)))), Err(e) => Borrowed(Failure((Status::BadRequest, Error::InvalidDataRead(e)))),

View File

@ -3,9 +3,8 @@
use rocket::{Request, Data}; use rocket::{Request, Data};
use rocket::local::blocking::Client; use rocket::local::blocking::Client;
use rocket::request::Form; use rocket::request::Form;
use rocket::data::{self, FromData}; use rocket::data::{self, FromData, ToByteUnit};
use rocket::http::{RawStr, ContentType, Status}; use rocket::http::{RawStr, ContentType, Status};
use rocket::tokio::io::AsyncReadExt;
// Test that the data parameters works as expected. // Test that the data parameters works as expected.
@ -21,13 +20,10 @@ impl FromData for Simple {
type Error = (); type Error = ();
async fn from_data(_: &Request<'_>, data: Data) -> data::Outcome<Self, ()> { async fn from_data(_: &Request<'_>, data: Data) -> data::Outcome<Self, ()> {
let mut string = String::new(); match data.open(64.bytes()).stream_to_string().await {
let mut stream = data.open().take(64); Ok(string) => data::Outcome::Success(Simple(string)),
if let Err(_) = stream.read_to_string(&mut string).await { Err(_) => data::Outcome::Failure((Status::InternalServerError, ())),
return data::Outcome::Failure((Status::InternalServerError, ()));
} }
data::Outcome::Success(Simple(string))
} }
} }

View File

@ -9,10 +9,9 @@ use std::path::PathBuf;
use rocket::http::ext::Normalize; use rocket::http::ext::Normalize;
use rocket::local::blocking::Client; use rocket::local::blocking::Client;
use rocket::data::{self, Data, FromData}; use rocket::data::{self, Data, FromData, ToByteUnit};
use rocket::request::{Request, Form}; use rocket::request::{Request, Form};
use rocket::http::{Status, RawStr, ContentType}; use rocket::http::{Status, RawStr, ContentType};
use rocket::tokio::io::AsyncReadExt;
// Use all of the code generation available at once. // Use all of the code generation available at once.
@ -28,9 +27,7 @@ impl FromData for Simple {
type Error = (); type Error = ();
async fn from_data(_: &Request<'_>, data: Data) -> data::Outcome<Self, ()> { async fn from_data(_: &Request<'_>, data: Data) -> data::Outcome<Self, ()> {
let mut string = String::new(); let string = data.open(64.bytes()).stream_to_string().await.unwrap();
let mut stream = data.open().take(64);
stream.read_to_string(&mut string).await.unwrap();
data::Outcome::Success(Simple(string)) data::Outcome::Success(Simple(string))
} }
} }

View File

@ -39,6 +39,7 @@ atty = "0.2"
async-trait = "0.1" async-trait = "0.1"
ref-cast = "1.0" ref-cast = "1.0"
atomic = "0.4" atomic = "0.4"
ubyte = "0.9.1"
[dependencies.pear] [dependencies.pear]
git = "https://github.com/SergioBenitez/Pear.git" git = "https://github.com/SergioBenitez/Pear.git"

View File

@ -1,7 +1,8 @@
use std::collections::HashMap; use std::collections::HashMap;
use std::path::{Path, PathBuf}; use std::path::{Path, PathBuf};
use crate::config::{Result, Config, Value, Environment, Limits, LoggingLevel}; use crate::config::{Result, Config, Value, Environment, LoggingLevel};
use crate::data::Limits;
/// Structure following the builder pattern for building `Config` structures. /// Structure following the builder pattern for building `Config` structures.
#[derive(Clone)] #[derive(Clone)]
@ -189,10 +190,11 @@ impl ConfigBuilder {
/// # Example /// # Example
/// ///
/// ```rust /// ```rust
/// use rocket::config::{Config, Environment, Limits}; /// use rocket::config::{Config, Environment};
/// use rocket::data::{Limits, ToByteUnit};
/// ///
/// let mut config = Config::build(Environment::Staging) /// let mut config = Config::build(Environment::Staging)
/// .limits(Limits::new().limit("json", 5 * (1 << 20))) /// .limits(Limits::new().limit("json", 5.mebibytes()))
/// .unwrap(); /// .unwrap();
/// ``` /// ```
pub fn limits(mut self, limits: Limits) -> Self { pub fn limits(mut self, limits: Limits) -> Self {

View File

@ -7,6 +7,7 @@ use std::fmt;
use crate::config::Environment::*; use crate::config::Environment::*;
use crate::config::{Result, ConfigBuilder, Environment, ConfigError, LoggingLevel}; use crate::config::{Result, ConfigBuilder, Environment, ConfigError, LoggingLevel};
use crate::config::{FullConfig, Table, Value, Array, Datetime}; use crate::config::{FullConfig, Table, Value, Array, Datetime};
use crate::data::Limits;
use crate::http::private::Key; use crate::http::private::Key;
use super::custom_values::*; use super::custom_values::*;
@ -51,7 +52,7 @@ pub struct Config {
pub(crate) secret_key: SecretKey, pub(crate) secret_key: SecretKey,
/// TLS configuration. /// TLS configuration.
pub(crate) tls: Option<TlsConfig>, pub(crate) tls: Option<TlsConfig>,
/// Streaming read size limits. /// Streaming data limits.
pub limits: Limits, pub limits: Limits,
/// Extra parameters that aren't part of Rocket's core config. /// Extra parameters that aren't part of Rocket's core config.
pub extras: HashMap<String, Value>, pub extras: HashMap<String, Value>,
@ -96,10 +97,8 @@ impl Config {
/// # Example /// # Example
/// ///
/// ```rust /// ```rust
/// use rocket::config::Config;
///
/// # if false { /// # if false {
/// let config = Config::read().unwrap(); /// let config = rocket::Config::read().unwrap();
/// # } /// # }
/// ``` /// ```
pub fn read() -> Result<Config> { pub fn read() -> Result<Config> {
@ -115,10 +114,8 @@ impl Config {
/// # Example /// # Example
/// ///
/// ```rust /// ```rust
/// use rocket::config::Config;
///
/// # if false { /// # if false {
/// let config = Config::read_from("/var/my-config.toml").unwrap(); /// let config = rocket::Config::read_from("/var/my-config.toml").unwrap();
/// # } /// # }
/// ``` /// ```
pub fn read_from<P: AsRef<Path>>(path: P) -> Result<Config> { pub fn read_from<P: AsRef<Path>>(path: P) -> Result<Config> {
@ -185,9 +182,7 @@ impl Config {
/// # Example /// # Example
/// ///
/// ```rust /// ```rust
/// use rocket::config::Config; /// let mut my_config = rocket::Config::active().unwrap();
///
/// let mut my_config = Config::active().unwrap();
/// my_config.set_port(1001); /// my_config.set_port(1001);
/// ``` /// ```
pub fn active() -> Result<Config> { pub fn active() -> Result<Config> {
@ -453,9 +448,7 @@ impl Config {
/// # Example /// # Example
/// ///
/// ```rust /// ```rust
/// use rocket::config::Config; /// let mut config = rocket::Config::development();
///
/// let mut config = Config::development();
/// ///
/// // Set keep-alive timeout to 10 seconds. /// // Set keep-alive timeout to 10 seconds.
/// config.set_keep_alive(10); /// config.set_keep_alive(10);
@ -537,10 +530,10 @@ impl Config {
/// # Example /// # Example
/// ///
/// ```rust /// ```rust
/// use rocket::config::{Config, Limits}; /// use rocket::data::{Limits, ToByteUnit};
/// ///
/// let mut config = Config::development(); /// let mut config = rocket::Config::development();
/// config.set_limits(Limits::default().limit("json", 4 * (1 << 20))); /// config.set_limits(Limits::default().limit("json", 4.mebibytes()));
/// ``` /// ```
#[inline] #[inline]
pub fn set_limits(&mut self, limits: Limits) { pub fn set_limits(&mut self, limits: Limits) {
@ -563,11 +556,9 @@ impl Config {
/// # Example /// # Example
/// ///
/// ```rust /// ```rust
/// use rocket::config::Config;
///
/// # use rocket::config::ConfigError; /// # use rocket::config::ConfigError;
/// # fn config_test() -> Result<(), ConfigError> { /// # fn config_test() -> Result<(), ConfigError> {
/// let mut config = Config::development(); /// let mut config = rocket::Config::development();
/// config.set_tls("/etc/ssl/my_certs.pem", "/etc/ssl/priv.key")?; /// config.set_tls("/etc/ssl/my_certs.pem", "/etc/ssl/priv.key")?;
/// # Ok(()) /// # Ok(())
/// # } /// # }

View File

@ -4,6 +4,7 @@ use std::fmt;
use crate::http::private::Key; use crate::http::private::Key;
use crate::config::{Result, Config, Value, ConfigError, LoggingLevel}; use crate::config::{Result, Config, Value, ConfigError, LoggingLevel};
use crate::data::Limits;
#[derive(Clone)] #[derive(Clone)]
pub enum SecretKey { pub enum SecretKey {
@ -53,154 +54,6 @@ pub struct TlsConfig {
#[derive(Clone)] #[derive(Clone)]
pub struct TlsConfig; pub struct TlsConfig;
/// Mapping from data type to size limits.
///
/// A `Limits` structure contains a mapping from a given data type ("forms",
/// "json", and so on) to the maximum size in bytes that should be accepted by a
/// Rocket application for that data type. For instance, if the limit for
/// "forms" is set to `256`, only 256 bytes from an incoming form request will
/// be read.
///
/// # Defaults
///
/// As documented in [`config`](crate::config), the default limits are as follows:
///
/// * **forms**: 32KiB
///
/// # Usage
///
/// A `Limits` structure is created following the builder pattern:
///
/// ```rust
/// use rocket::config::Limits;
///
/// // Set a limit of 64KiB for forms and 3MiB for JSON.
/// let limits = Limits::new()
/// .limit("forms", 64 * 1024)
/// .limit("json", 3 * 1024 * 1024);
/// ```
#[derive(Debug, Clone)]
pub struct Limits {
    // The "forms" limit is stored in a dedicated field; callers only ever
    // observe it through `get("forms")` / `limit("forms", ..)`.
    pub(crate) forms: u64,
    // All other named limits, kept in insertion order.
    extra: Vec<(String, u64)>
}

impl Default for Limits {
    /// Returns the default limits: 32KiB for "forms" and no other entries.
    fn default() -> Limits {
        Limits { forms: 32 * 1024, extra: Vec::new() }
    }
}

impl Limits {
    /// Creates a `Limits` structure holding only the default limits.
    ///
    /// # Example
    ///
    /// ```rust
    /// use rocket::config::Limits;
    ///
    /// let limits = Limits::new();
    /// assert_eq!(limits.get("forms"), Some(32 * 1024));
    /// ```
    #[inline]
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the limit named `name` to `limit` bytes, overwriting any limit
    /// previously stored under that name. Consumes `self` and returns the
    /// updated `Limits` so calls can be chained.
    ///
    /// # Example
    ///
    /// ```rust
    /// use rocket::config::Limits;
    ///
    /// let limits = Limits::new()
    ///     .limit("json", 1 * 1024 * 1024);
    ///
    /// assert_eq!(limits.get("forms"), Some(32 * 1024));
    /// assert_eq!(limits.get("json"), Some(1 * 1024 * 1024));
    ///
    /// let new_limits = limits.limit("json", 64 * 1024 * 1024);
    /// assert_eq!(new_limits.get("json"), Some(64 * 1024 * 1024));
    /// ```
    pub fn limit<S: Into<String>>(mut self, name: S, limit: u64) -> Self {
        let name = name.into();
        if name == "forms" {
            self.forms = limit;
            return self;
        }

        // Overwrite the existing entry for `name` if there is one; otherwise
        // append a new entry at the end.
        let slot = self.extra.iter().position(|(key, _)| *key == name);
        match slot {
            Some(index) => self.extra[index].1 = limit,
            None => self.extra.push((name, limit)),
        }

        self
    }

    /// Looks up the limit stored under `name`, returning `None` when no limit
    /// with that name has been set.
    ///
    /// # Example
    ///
    /// ```rust
    /// use rocket::config::Limits;
    ///
    /// let limits = Limits::new()
    ///     .limit("json", 64 * 1024 * 1024);
    ///
    /// assert_eq!(limits.get("forms"), Some(32 * 1024));
    /// assert_eq!(limits.get("json"), Some(64 * 1024 * 1024));
    /// assert!(limits.get("msgpack").is_none());
    /// ```
    pub fn get(&self, name: &str) -> Option<u64> {
        if name == "forms" {
            return Some(self.forms);
        }

        self.extra
            .iter()
            .find(|(key, _)| key.as_str() == name)
            .map(|entry| entry.1)
    }
}

impl fmt::Display for Limits {
    /// Writes `forms = <size>` followed by `, <name>* = <size>` for every
    /// additional limit, in insertion order.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Prints `n` in the largest of MiB / KiB / B that divides it evenly.
        fn fmt_size(n: u64, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            const KIB: u64 = 1 << 10;
            const MIB: u64 = 1 << 20;

            if n % MIB == 0 {
                write!(f, "{}MiB", n / MIB)
            } else if n % KIB == 0 {
                write!(f, "{}KiB", n / KIB)
            } else {
                write!(f, "{}B", n)
            }
        }

        write!(f, "forms = ")?;
        fmt_size(self.forms, f)?;

        for (key, val) in &self.extra {
            write!(f, ", {}* = ", key)?;
            fmt_size(*val, f)?;
        }

        Ok(())
    }
}
pub fn str<'a>(conf: &Config, name: &str, v: &'a Value) -> Result<&'a str> { pub fn str<'a>(conf: &Config, name: &str, v: &'a Value) -> Result<&'a str> {
v.as_str().ok_or_else(|| conf.bad_type(name, v.type_str(), "a string")) v.as_str().ok_or_else(|| conf.bad_type(name, v.type_str(), "a string"))
} }
@ -266,7 +119,7 @@ pub fn limits(conf: &Config, name: &str, value: &Value) -> Result<Limits> {
let mut limits = Limits::default(); let mut limits = Limits::default();
for (key, val) in table { for (key, val) in table {
let val = u64(conf, &format!("limits.{}", key), val)?; let val = u64(conf, &format!("limits.{}", key), val)?;
limits = limits.limit(key.as_str(), val); limits = limits.limit(key.as_str(), val.into());
} }
Ok(limits) Ok(limits)

View File

@ -198,7 +198,6 @@ use std::path::{Path, PathBuf};
use toml; use toml;
pub use self::custom_values::Limits;
pub use toml::value::{Array, Map, Table, Value, Datetime}; pub use toml::value::{Array, Map, Table, Value, Datetime};
pub use self::error::ConfigError; pub use self::error::ConfigError;
pub use self::environment::Environment; pub use self::environment::Environment;

View File

@ -1,23 +1,20 @@
use std::future::Future; use std::io::Cursor;
use std::io;
use std::path::Path;
use tokio::io::AsyncWrite;
use super::data_stream::DataStream;
use crate::http::hyper; use crate::http::hyper;
use crate::ext::{AsyncReadExt, AsyncReadBody}; use crate::ext::AsyncReadBody;
use crate::tokio::io::AsyncReadExt;
use crate::data::data_stream::DataStream;
use crate::data::ByteUnit;
/// The number of bytes to read into the "peek" buffer. /// The number of bytes to read into the "peek" buffer.
const PEEK_BYTES: usize = 512; pub const PEEK_BYTES: usize = 512;
/// Type representing the data in the body of an incoming request. /// Type representing the data in the body of an incoming request.
/// ///
/// This type is the only means by which the body of a request can be retrieved. /// This type is the only means by which the body of a request can be retrieved.
/// This type is not usually used directly. Instead, types that implement /// This type is not usually used directly. Instead, types that implement
/// [`FromTransformedData`](crate::data::FromTransformedData) are used via code generation by /// [`FromTransformedData`](crate::data::FromTransformedData) are used via code
/// specifying the `data = "<var>"` route parameter as follows: /// generation by specifying the `data = "<var>"` route parameter as follows:
/// ///
/// ```rust /// ```rust
/// # #[macro_use] extern crate rocket; /// # #[macro_use] extern crate rocket;
@ -27,8 +24,9 @@ const PEEK_BYTES: usize = 512;
/// # fn main() { } /// # fn main() { }
/// ``` /// ```
/// ///
/// Above, `DataGuard` can be any type that implements `FromTransformedData`. Note that /// Above, `DataGuard` can be any type that implements `FromTransformedData` (or
/// `Data` itself implements `FromTransformedData`. /// equivalently, `FromData`). Note that `Data` itself implements
/// `FromTransformedData`.
/// ///
/// # Reading Data /// # Reading Data
/// ///
@ -50,162 +48,14 @@ pub struct Data {
} }
impl Data { impl Data {
/// Returns the raw data stream. pub(crate) async fn from_hyp(body: hyper::Body) -> Data {
///
/// The stream contains all of the data in the body of the request,
/// including that in the `peek` buffer. The method consumes the `Data`
/// instance. This ensures that a `Data` type _always_ represents _all_ of
/// the data in a request.
///
/// # Example
///
/// ```rust
/// use rocket::Data;
///
/// fn handler(data: Data) {
/// let stream = data.open();
/// }
/// ```
pub fn open(mut self) -> DataStream {
    // Move both halves out of `self`, leaving an empty vector and
    // `AsyncReadBody::empty()` behind, and hand them to the stream.
    let peeked = std::mem::take(&mut self.buffer);
    let body = std::mem::replace(&mut self.stream, AsyncReadBody::empty());
    DataStream(peeked, body)
}
pub(crate) fn from_hyp(body: hyper::Body) -> impl Future<Output = Data> {
// TODO.async: This used to also set the read timeout to 5 seconds. // TODO.async: This used to also set the read timeout to 5 seconds.
// Such a short read timeout is likely no longer necessary, but some // Such a short read timeout is likely no longer necessary, but some
// kind of idle timeout should be implemented. // kind of idle timeout should be implemented.
Data::new(body) let stream = AsyncReadBody::from(body);
} let buffer = Vec::with_capacity(PEEK_BYTES / 8);
Data { buffer, stream, is_complete: false }
/// Retrieve the `peek` buffer.
///
/// The peek buffer contains at most 512 bytes of the body of the request.
/// The actual size of the returned buffer varies by web request. The
/// [`peek_complete`](#method.peek_complete) method can be used to determine
/// if this buffer contains _all_ of the data in the body of the request.
///
/// # Example
///
/// ```rust
/// use rocket::Data;
///
/// fn handler(data: Data) {
/// let peek = data.peek();
/// }
/// ```
#[inline(always)]
pub fn peek(&self) -> &[u8] {
    // Never expose more than `PEEK_BYTES`, even if more has been buffered.
    let visible = std::cmp::min(self.buffer.len(), PEEK_BYTES);
    &self.buffer[..visible]
}
/// Returns true if the `peek` buffer contains all of the data in the body
/// of the request. Returns `false` if it does not or if it is not known if
/// it does.
///
/// # Example
///
/// ```rust
/// use rocket::Data;
///
/// fn handler(data: Data) {
/// if data.peek_complete() {
/// println!("All of the data: {:?}", data.peek());
/// }
/// }
/// ```
#[inline(always)]
pub fn peek_complete(&self) -> bool {
// Plain accessor: reports the stored flag; nothing is read from the stream.
self.is_complete
}
/// A helper method to write the body of the request to any `AsyncWrite` type.
///
/// This method is identical to `tokio::io::copy(&mut data.open(), &mut writer)`.
///
/// # Example
///
/// ```rust
/// use std::io;
/// use rocket::Data;
///
/// async fn handler(mut data: Data) -> io::Result<String> {
/// // write all of the data to stdout
/// let written = data.stream_to(tokio::io::stdout()).await?;
/// Ok(format!("Wrote {} bytes.", written))
/// }
/// ```
#[inline(always)]
pub async fn stream_to<W>(self, mut writer: W) -> io::Result<u64>
    where W: AsyncWrite + Unpin
{
    // Opening consumes `self`; the copy then drains the resulting stream
    // into `writer` and yields the value returned by `tokio::io::copy`.
    let mut body = self.open();
    tokio::io::copy(&mut body, &mut writer).await
}
/// A helper method to write the body of the request to a file at the path
/// determined by `path`.
///
/// This method is identical to
/// `tokio::io::copy(&mut self.open(), &mut File::create(path).await?)`.
///
/// # Example
///
/// ```rust
/// use std::io;
/// use rocket::Data;
///
/// async fn handler(mut data: Data) -> io::Result<String> {
/// let written = data.stream_to_file("/static/file").await?;
/// Ok(format!("Wrote {} bytes to /static/file", written))
/// }
/// ```
#[inline(always)]
pub async fn stream_to_file<P: AsRef<Path>>(self, path: P) -> io::Result<u64> {
    // Creation errors are returned to the caller before any body data is
    // streamed.
    let mut destination = tokio::fs::File::create(path).await?;
    self.stream_to(&mut destination).await
}
// Creates a new data object with an internal buffer `buf`, where the cursor
// in the buffer is at `pos` and the buffer has `cap` valid bytes. Thus, the
// bytes `vec[pos..cap]` are buffered and unread. The remainder of the data
// bytes can be read from `stream`.
#[inline(always)]
pub(crate) async fn new(body: hyper::Body) -> Data {
trace_!("Data::new({:?})", body);
let mut stream = AsyncReadBody::from(body);
let mut peek_buf = vec![0; PEEK_BYTES];
let eof = match stream.read_max(&mut peek_buf[..]).await {
Ok(n) => {
trace_!("Filled peek buf with {} bytes.", n);
// TODO.async: This has not gone away, and I don't entirely
// understand what's happening here
// We can use `set_len` here instead of `truncate`, but we'll
// take the performance hit to avoid `unsafe`. All of this code
// should go away when we migrate away from hyper 0.10.x.
peek_buf.truncate(n);
n < PEEK_BYTES
}
Err(e) => {
error_!("Failed to read into peek buffer: {:?}.", e);
// Likewise here as above.
peek_buf.truncate(0);
false
}
};
trace_!("Peek bytes: {}/{} bytes.", peek_buf.len(), PEEK_BYTES);
Data { buffer: peek_buf, stream, is_complete: eof }
} }
/// This creates a `data` object from a local data source `data`. /// This creates a `data` object from a local data source `data`.
@ -217,10 +67,101 @@ impl Data {
is_complete: true, is_complete: true,
} }
} }
}
impl std::borrow::Borrow<()> for Data { /// Returns the raw data stream, limited to `limit` bytes.
fn borrow(&self) -> &() { ///
&() /// The stream contains all of the data in the body of the request,
/// including that in the `peek` buffer. The method consumes the `Data`
/// instance. This ensures that a `Data` type _always_ represents _all_ of
/// the data in a request.
///
/// # Example
///
/// ```rust
/// use rocket::data::{Data, ToByteUnit};
///
/// # const SIZE_LIMIT: u64 = 2 << 20; // 2MiB
/// fn handler(data: Data) {
/// let stream = data.open(2.mebibytes());
/// }
/// ```
pub fn open(self, limit: ByteUnit) -> DataStream {
    // Split the overall `limit` in two: first whatever is already sitting in
    // the buffer (capped at `limit`), then whatever allowance remains for the
    // not-yet-read part of the body.
    let from_buffer = std::cmp::min(self.buffer.len().into(), limit);
    let from_stream = limit - from_buffer;

    DataStream {
        buffer: Cursor::new(self.buffer).take(from_buffer.into()),
        stream: self.stream.take(from_stream.into()),
    }
}
/// Retrieve at most `num` bytes from the `peek` buffer without consuming
/// `self`.
///
/// The peek buffer contains at most 512 bytes of the body of the request.
/// The actual size of the returned buffer is the `min` of the request's
/// body, `num` and `512`. The [`peek_complete`](#method.peek_complete)
/// method can be used to determine if this buffer contains _all_ of the
/// data in the body of the request.
///
/// # Example
///
/// ```rust
/// use rocket::request::Request;
/// use rocket::data::{self, Data, FromData};
/// # struct MyType;
/// # type MyError = String;
///
/// #[rocket::async_trait]
/// impl FromData for MyType {
/// type Error = MyError;
///
/// async fn from_data(req: &Request<'_>, mut data: Data) -> data::Outcome<Self, MyError> {
/// if data.peek(10).await != b"hi" {
/// return data::Outcome::Forward(data)
/// }
///
/// /* .. */
/// # unimplemented!()
/// }
/// }
/// ```
pub async fn peek(&mut self, num: usize) -> &[u8] {
    // The peek window is capped at `PEEK_BYTES`, whatever the caller asks for.
    let wanted = std::cmp::min(PEEK_BYTES, num);
    let mut available = self.buffer.len();

    // Keep pulling from the stream until the window is full, a read returns
    // zero bytes, or a read fails. When enough is already buffered the loop
    // body never runs.
    while available < wanted {
        match self.stream.read_buf(&mut self.buffer).await {
            Ok(0) => {
                // A zero-byte read: record that the whole body is buffered.
                self.is_complete = true;
                break;
            }
            Ok(read) => available += read,
            Err(e) => {
                error_!("Failed to read into peek buffer: {:?}.", e);
                break;
            }
        }
    }

    &self.buffer[..std::cmp::min(available, wanted)]
}
/// Returns true if the `peek` buffer contains all of the data in the body
/// of the request. Returns `false` if it does not or if it is not known if
/// it does.
///
/// # Example
///
/// ```rust
/// use rocket::data::Data;
///
/// async fn handler(mut data: Data) {
/// if data.peek_complete() {
/// println!("All of the data: {:?}", data.peek(512).await);
/// }
/// }
/// ```
#[inline(always)]
pub fn peek_complete(&self) -> bool {
self.is_complete
} }
} }

View File

@ -1,34 +1,132 @@
use std::pin::Pin; use std::pin::Pin;
use std::task::{Context, Poll}; use std::task::{Context, Poll};
use std::path::Path;
use std::io::{self, Cursor};
use tokio::io::AsyncRead; use tokio::io::{AsyncRead, AsyncWrite, AsyncReadExt, Take};
use crate::ext::AsyncReadBody; use crate::ext::AsyncReadBody;
/// Raw data stream of a request body. /// Raw data stream of a request body.
/// ///
/// This stream can only be obtained by calling /// This stream can only be obtained by calling
/// [`Data::open()`](crate::data::Data::open()). The stream contains all of the data /// [`Data::open()`](crate::data::Data::open()). The stream contains all of the
/// in the body of the request. It exposes no methods directly. Instead, it must /// data in the body of the request. It exposes no methods directly. Instead, it
/// be used as an opaque [`AsyncRead`] structure. /// must be used as an opaque [`AsyncRead`] structure.
pub struct DataStream(pub(crate) Vec<u8>, pub(crate) AsyncReadBody); pub struct DataStream {
pub(crate) buffer: Take<Cursor<Vec<u8>>>,
pub(crate) stream: Take<AsyncReadBody>
}
// TODO.async: Consider implementing `AsyncBufRead` impl DataStream {
/// A helper method to write the body of the request to any `AsyncWrite`
/// type.
///
/// This method is identical to `tokio::io::copy(&mut self, &mut writer)`.
///
/// # Example
///
/// ```rust
/// use std::io;
/// use rocket::data::{Data, ToByteUnit};
///
/// async fn handler(mut data: Data) -> io::Result<String> {
/// // write at most 512KiB of the data to stdout
/// let written = data.open(512.kibibytes()).stream_to(tokio::io::stdout()).await?;
/// Ok(format!("Wrote {} bytes.", written))
/// }
/// ```
#[inline(always)]
pub async fn stream_to<W: AsyncWrite + Unpin>(mut self, mut writer: W) -> io::Result<u64> {
    // `self` is the reader here; the result is whatever `tokio::io::copy`
    // reports for the transfer into `writer`.
    let reader = &mut self;
    tokio::io::copy(reader, &mut writer).await
}
/// A helper method to write the body of the request to a file at the path
/// determined by `path`.
///
/// This method is identical to `self.stream_to(&mut
/// File::create(path).await?)`.
///
/// # Example
///
/// ```rust
/// use std::io;
/// use rocket::data::{Data, ToByteUnit};
///
/// async fn handler(mut data: Data) -> io::Result<String> {
/// let written = data.open(1.megabytes()).stream_to_file("/static/file").await?;
/// Ok(format!("Wrote {} bytes to /static/file", written))
/// }
/// ```
#[inline(always)]
pub async fn stream_to_file<P: AsRef<Path>>(self, path: P) -> io::Result<u64> {
    // Creation errors are returned to the caller before any data is streamed.
    let mut destination = tokio::fs::File::create(path).await?;
    self.stream_to(&mut destination).await
}
/// A helper method to write the body of the request to a `String`.
///
/// # Example
///
/// ```rust
/// use std::io;
/// use rocket::data::{Data, ToByteUnit};
///
/// async fn handler(data: Data) -> io::Result<String> {
/// data.open(10.bytes()).stream_to_string().await
/// }
/// ```
pub async fn stream_to_string(mut self) -> io::Result<String> {
    // Initial capacity: the bytes already buffered, bounded by the buffer's
    // remaining limit and by 1024.
    let buffered = self.buffer.get_ref().get_ref().len();
    let readable = std::cmp::min(buffered, self.buffer.limit() as usize);
    let mut output = String::with_capacity(std::cmp::min(readable, 1024));

    self.read_to_string(&mut output).await?;
    Ok(output)
}
/// A helper method to write the body of the request to a `Vec<u8>`.
///
/// # Example
///
/// ```rust
/// use std::io;
/// use rocket::data::{Data, ToByteUnit};
///
/// async fn handler(data: Data) -> io::Result<Vec<u8>> {
/// data.open(4.kibibytes()).stream_to_vec().await
/// }
/// ```
pub async fn stream_to_vec(mut self) -> io::Result<Vec<u8>> {
    // Initial capacity: the bytes already buffered, bounded by the buffer's
    // remaining limit and by 1024.
    let buffered = self.buffer.get_ref().get_ref().len();
    let readable = std::cmp::min(buffered, self.buffer.limit() as usize);
    let mut output = Vec::with_capacity(std::cmp::min(readable, 1024));

    self.read_to_end(&mut output).await?;
    Ok(output)
}
}
// TODO.async: Consider implementing `AsyncBufRead`.
impl AsyncRead for DataStream { impl AsyncRead for DataStream {
#[inline(always)] #[inline(always)]
fn poll_read(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll<Result<usize, std::io::Error>> { fn poll_read(
trace_!("DataStream::poll_read()"); mut self: Pin<&mut Self>,
if self.0.len() > 0 { cx: &mut Context<'_>,
let count = std::cmp::min(buf.len(), self.0.len()); buf: &mut [u8]
trace_!("Reading peeked {} into dest {} = {} bytes", self.0.len(), buf.len(), count); ) -> Poll<io::Result<usize>> {
let next = self.0.split_off(count); if self.buffer.limit() > 0 {
(&mut buf[..count]).copy_from_slice(&self.0[..]); trace_!("DataStream::buffer_read()");
self.0 = next; match Pin::new(&mut self.buffer).poll_read(cx, buf) {
Poll::Ready(Ok(count)) Poll::Ready(Ok(0)) => { /* fall through */ },
} else { poll => return poll,
trace_!("Delegating to remaining stream"); }
Pin::new(&mut self.1).poll_read(cx, buf)
} }
trace_!("DataStream::stream_read()");
Pin::new(&mut self.stream).poll_read(cx, buf)
} }
} }

View File

@ -7,7 +7,7 @@ use crate::outcome::{self, IntoOutcome};
use crate::outcome::Outcome::*; use crate::outcome::Outcome::*;
use crate::http::Status; use crate::http::Status;
use crate::request::Request; use crate::request::Request;
use crate::data::Data; use crate::data::{Data, ByteUnit};
/// Type alias for the `Outcome` of a `FromTransformedData` conversion. /// Type alias for the `Outcome` of a `FromTransformedData` conversion.
pub type Outcome<S, E> = outcome::Outcome<S, (Status, E), Data>; pub type Outcome<S, E> = outcome::Outcome<S, (Status, E), Data>;
@ -197,13 +197,12 @@ pub type FromDataFuture<'fut, T, E> = BoxFuture<'fut, Outcome<T, E>>;
/// # struct Name<'a> { first: &'a str, last: &'a str, } /// # struct Name<'a> { first: &'a str, last: &'a str, }
/// use std::io::{self, Read}; /// use std::io::{self, Read};
/// ///
/// use tokio::io::AsyncReadExt; /// use rocket::Request;
/// /// use rocket::data::{Data, Outcome, FromDataFuture, ByteUnit};
/// use rocket::{Request, Data}; /// use rocket::data::{FromTransformedData, Transform, Transformed, TransformFuture};
/// use rocket::data::{FromTransformedData, Outcome, Transform, Transformed, TransformFuture, FromDataFuture};
/// use rocket::http::Status; /// use rocket::http::Status;
/// ///
/// const NAME_LIMIT: u64 = 256; /// const NAME_LIMIT: ByteUnit = ByteUnit::Byte(256);
/// ///
/// enum NameError { /// enum NameError {
/// Io(io::Error), /// Io(io::Error),
@ -217,10 +216,8 @@ pub type FromDataFuture<'fut, T, E> = BoxFuture<'fut, Outcome<T, E>>;
/// ///
/// fn transform<'r>(_: &'r Request, data: Data) -> TransformFuture<'r, Self::Owned, Self::Error> { /// fn transform<'r>(_: &'r Request, data: Data) -> TransformFuture<'r, Self::Owned, Self::Error> {
/// Box::pin(async move { /// Box::pin(async move {
/// let mut stream = data.open().take(NAME_LIMIT); /// let outcome = match data.open(NAME_LIMIT).stream_to_string().await {
/// let mut string = String::with_capacity((NAME_LIMIT / 2) as usize); /// Ok(string) => Outcome::Success(string),
/// let outcome = match stream.read_to_string(&mut string).await {
/// Ok(_) => Outcome::Success(string),
/// Err(e) => Outcome::Failure((Status::InternalServerError, NameError::Io(e))) /// Err(e) => Outcome::Failure((Status::InternalServerError, NameError::Io(e)))
/// }; /// };
/// ///
@ -231,9 +228,9 @@ pub type FromDataFuture<'fut, T, E> = BoxFuture<'fut, Outcome<T, E>>;
/// ///
/// fn from_data(_: &'a Request, outcome: Transformed<'a, Self>) -> FromDataFuture<'a, Self, Self::Error> { /// fn from_data(_: &'a Request, outcome: Transformed<'a, Self>) -> FromDataFuture<'a, Self, Self::Error> {
/// Box::pin(async move { /// Box::pin(async move {
/// // Retrieve a borrow to the now transformed `String` (an &str). This /// // Retrieve a borrow to the now transformed `String` (an &str).
/// // is only correct because we know we _always_ return a `Borrowed` from /// // This is only correct because we know we _always_ return a
/// // `transform` above. /// // `Borrowed` from `transform` above.
/// let string = try_outcome!(outcome.borrowed()); /// let string = try_outcome!(outcome.borrowed());
/// ///
/// // Perform a crude, inefficient parse. /// // Perform a crude, inefficient parse.
@ -407,7 +404,7 @@ pub trait FromTransformedData<'a>: Sized {
impl<'a> FromTransformedData<'a> for Data { impl<'a> FromTransformedData<'a> for Data {
type Error = std::convert::Infallible; type Error = std::convert::Infallible;
type Owned = Data; type Owned = Data;
type Borrowed = (); type Borrowed = Data;
#[inline(always)] #[inline(always)]
fn transform<'r>(_: &'r Request<'_>, data: Data) -> TransformFuture<'r, Self::Owned, Self::Error> { fn transform<'r>(_: &'r Request<'_>, data: Data) -> TransformFuture<'r, Self::Owned, Self::Error> {
@ -486,12 +483,12 @@ impl<'a> FromTransformedData<'a> for Data {
/// use std::io::Read; /// use std::io::Read;
/// ///
/// use rocket::{Request, Data}; /// use rocket::{Request, Data};
/// use rocket::data::{self, Outcome, FromData, FromDataFuture}; /// use rocket::data::{self, Outcome, FromData, FromDataFuture, ByteUnit};
/// use rocket::http::{Status, ContentType}; /// use rocket::http::{Status, ContentType};
/// use rocket::tokio::io::AsyncReadExt; /// use rocket::tokio::io::AsyncReadExt;
/// ///
/// // Always use a limit to prevent DoS attacks. /// // Always use a limit to prevent DoS attacks.
/// const LIMIT: u64 = 256; /// const LIMIT: ByteUnit = ByteUnit::Byte(256);
/// ///
/// #[rocket::async_trait] /// #[rocket::async_trait]
/// impl FromData for Person { /// impl FromData for Person {
@ -505,11 +502,10 @@ impl<'a> FromTransformedData<'a> for Data {
/// } /// }
/// ///
/// // Read the data into a String. /// // Read the data into a String.
/// let mut string = String::new(); /// let string = match data.open(LIMIT).stream_to_string().await {
/// let mut reader = data.open().take(LIMIT); /// Ok(string) => string,
/// if let Err(e) = reader.read_to_string(&mut string).await { /// Err(e) => return Outcome::Failure((Status::InternalServerError, format!("{}", e)))
/// return Outcome::Failure((Status::InternalServerError, format!("{:?}", e))); /// };
/// }
/// ///
/// // Split the string into two pieces at ':'. /// // Split the string into two pieces at ':'.
/// let (name, age) = match string.find(':') { /// let (name, age) = match string.find(':') {
@ -550,7 +546,7 @@ pub trait FromData: Sized {
impl<'a, T: FromData + 'a> FromTransformedData<'a> for T { impl<'a, T: FromData + 'a> FromTransformedData<'a> for T {
type Error = T::Error; type Error = T::Error;
type Owned = Data; type Owned = Data;
type Borrowed = (); type Borrowed = Data;
#[inline(always)] #[inline(always)]
fn transform<'r>(_: &'r Request<'_>, d: Data) -> TransformFuture<'r, Self::Owned, Self::Error> { fn transform<'r>(_: &'r Request<'_>, d: Data) -> TransformFuture<'r, Self::Owned, Self::Error> {
@ -612,12 +608,8 @@ impl FromData for String {
#[inline(always)] #[inline(always)]
async fn from_data(_: &Request<'_>, data: Data) -> Outcome<Self, Self::Error> { async fn from_data(_: &Request<'_>, data: Data) -> Outcome<Self, Self::Error> {
use tokio::io::AsyncReadExt; match data.open(ByteUnit::max_value()).stream_to_string().await {
Ok(string) => Success(string),
let mut string = String::new();
let mut reader = data.open();
match reader.read_to_string(&mut string).await {
Ok(_) => Success(string),
Err(e) => Failure((Status::BadRequest, e)), Err(e) => Failure((Status::BadRequest, e)),
} }
} }
@ -632,7 +624,7 @@ impl FromData for Vec<u8> {
async fn from_data(_: &Request<'_>, data: Data) -> Outcome<Self, Self::Error> { async fn from_data(_: &Request<'_>, data: Data) -> Outcome<Self, Self::Error> {
use tokio::io::AsyncReadExt; use tokio::io::AsyncReadExt;
let mut stream = data.open(); let mut stream = data.open(ByteUnit::max_value());
let mut buf = Vec::new(); let mut buf = Vec::new();
match stream.read_to_end(&mut buf).await { match stream.read_to_end(&mut buf).await {
Ok(_) => Success(buf), Ok(_) => Success(buf),

138
core/lib/src/data/limits.rs Normal file
View File

@ -0,0 +1,138 @@
use std::fmt;
use crate::data::{ByteUnit, ToByteUnit};
/// Mapping from data type to size limits.
///
/// A `Limits` structure contains a mapping from a given data type ("forms",
/// "json", and so on) to the maximum size in bytes that should be accepted by a
/// Rocket application for that data type. For instance, if the limit for
/// "forms" is set to `256`, only 256 bytes from an incoming form request will
/// be read.
///
/// # Defaults
///
/// As documented in [`config`](crate::config), the default limits are as follows:
///
/// * **forms**: 32KiB
///
/// # Usage
///
/// A `Limits` structure is created following the builder pattern:
///
/// ```rust
/// use rocket::data::{Limits, ToByteUnit};
///
/// // Set a limit of 64KiB for forms and 3MiB for JSON.
/// let limits = Limits::new()
/// .limit("forms", 64.kibibytes())
/// .limit("json", 3.mebibytes());
/// ```
#[derive(Debug, Clone)]
pub struct Limits {
// We cache this internally but don't share that fact in the API.
// The "forms" limit lives in its own field; `limit("forms", ..)` and
// `get("forms")` read and write it directly instead of searching `extra`.
pub(crate) forms: ByteUnit,
// Limits for every name other than "forms", stored as `(name, limit)` pairs
// in insertion order. `limit()` overwrites an existing pair with the same
// name rather than pushing a second one, so names here are unique.
extra: Vec<(String, ByteUnit)>
}
impl Default for Limits {
    /// Returns the default limits: 32KiB for "forms" and no other entries.
    fn default() -> Self {
        let forms = 32.kibibytes();
        Self { forms, extra: Vec::new() }
    }
}
impl Limits {
/// Construct a new `Limits` structure with the default limits set.
///
/// # Example
///
/// ```rust
/// use rocket::data::{Limits, ToByteUnit};
///
/// let limits = Limits::new();
/// assert_eq!(limits.get("forms"), Some(32.kibibytes()));
/// ```
#[inline]
pub fn new() -> Self {
Limits::default()
}
/// Adds or replaces a limit in `self`, consuming `self` and returning a new
/// `Limits` structure with the added or replaced limit.
///
/// # Example
///
/// ```rust
/// use rocket::data::{Limits, ToByteUnit};
///
/// let limits = Limits::new().limit("json", 1.mebibytes());
///
/// assert_eq!(limits.get("forms"), Some(32.kibibytes()));
/// assert_eq!(limits.get("json"), Some(1.mebibytes()));
///
/// let new_limits = limits.limit("json", 64.mebibytes());
/// assert_eq!(new_limits.get("json"), Some(64.mebibytes()));
/// ```
pub fn limit<S: Into<String>>(mut self, name: S, limit: ByteUnit) -> Self {
let name = name.into();
match name.as_str() {
"forms" => self.forms = limit,
_ => {
let mut found = false;
for tuple in &mut self.extra {
if tuple.0 == name {
tuple.1 = limit;
found = true;
break;
}
}
if !found {
self.extra.push((name, limit))
}
}
}
self
}
/// Retrieve the set limit, if any, for the data type with name `name`.
///
/// # Example
///
/// ```rust
/// use rocket::data::{Limits, ToByteUnit};
///
/// let limits = Limits::new().limit("json", 64.mebibytes());
///
/// assert_eq!(limits.get("forms"), Some(32.kibibytes()));
/// assert_eq!(limits.get("json"), Some(64.mebibytes()));
/// assert!(limits.get("msgpack").is_none());
/// ```
pub fn get(&self, name: &str) -> Option<ByteUnit> {
if name == "forms" {
return Some(self.forms);
}
for &(ref key, val) in &self.extra {
if key == name {
return Some(val);
}
}
None
}
}
impl fmt::Display for Limits {
    /// Writes `forms = <limit>` followed by `, <name>* = <limit>` for every
    /// additional limit, in the order they are stored.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "forms = {}", self.forms)?;
        self.extra.iter()
            .try_for_each(|(name, limit)| write!(f, ", {}* = {}", name, limit))
    }
}

View File

@ -3,7 +3,11 @@
mod data; mod data;
mod data_stream; mod data_stream;
mod from_data; mod from_data;
mod limits;
pub use self::data::Data; pub use self::data::Data;
pub use self::data_stream::DataStream; pub use self::data_stream::DataStream;
pub use self::from_data::{FromTransformedData, FromDataFuture, FromData, Outcome, Transform, Transformed, TransformFuture}; pub use self::from_data::{FromData, Outcome, FromTransformedData, FromDataFuture};
pub use self::from_data::{Transform, Transformed, TransformFuture};
pub use self::limits::Limits;
pub use ubyte::{ByteUnit, ToByteUnit};

View File

@ -2,11 +2,10 @@ use std::io::{self, Cursor};
use std::pin::Pin; use std::pin::Pin;
use std::task::{Poll, Context}; use std::task::{Poll, Context};
use futures::{ready, future::BoxFuture, stream::Stream}; use futures::{ready, stream::Stream};
use tokio::io::{AsyncRead, AsyncReadExt as _}; use tokio::io::AsyncRead;
use crate::http::hyper; use crate::http::hyper::{self, Bytes, HttpBody};
use hyper::{Bytes, HttpBody};
pub struct IntoBytesStream<R> { pub struct IntoBytesStream<R> {
inner: R, inner: R,
@ -29,6 +28,7 @@ impl<R> Stream for IntoBytesStream<R>
Poll::Ready(Err(e)) => Poll::Ready(Some(Err(e))), Poll::Ready(Err(e)) => Poll::Ready(Some(Err(e))),
Poll::Ready(Ok(n)) if n == 0 => Poll::Ready(None), Poll::Ready(Ok(n)) if n == 0 => Poll::Ready(None),
Poll::Ready(Ok(n)) => { Poll::Ready(Ok(n)) => {
// FIXME(perf).
let mut next = std::mem::replace(buffer, vec![0; buf_size]); let mut next = std::mem::replace(buffer, vec![0; buf_size]);
next.truncate(n); next.truncate(n);
Poll::Ready(Some(Ok(Bytes::from(next)))) Poll::Ready(Some(Ok(Bytes::from(next))))
@ -37,38 +37,20 @@ impl<R> Stream for IntoBytesStream<R>
} }
} }
pub trait AsyncReadExt: AsyncRead { pub trait AsyncReadExt: AsyncRead + Sized {
fn into_bytes_stream(self, buf_size: usize) -> IntoBytesStream<Self> where Self: Sized { fn into_bytes_stream(self, buf_size: usize) -> IntoBytesStream<Self> {
IntoBytesStream { inner: self, buf_size, buffer: vec![0; buf_size] } IntoBytesStream { inner: self, buf_size, buffer: vec![0; buf_size] }
} }
fn read_max<'a>(&'a mut self, mut buf: &'a mut [u8]) -> BoxFuture<'_, io::Result<usize>>
where Self: Send + Unpin
{
Box::pin(async move {
let start_len = buf.len();
while !buf.is_empty() {
match self.read(buf).await {
Ok(0) => break,
Ok(n) => buf = &mut buf[n..],
Err(ref e) if e.kind() == io::ErrorKind::Interrupted => {}
Err(e) => return Err(e),
}
}
Ok(start_len - buf.len())
})
}
} }
impl<T: AsyncRead> AsyncReadExt for T { } impl<T: AsyncRead> AsyncReadExt for T { }
pub struct AsyncReadBody { pub struct AsyncReadBody {
inner: hyper::Body, inner: hyper::Body,
state: AsyncReadBodyState, state: State,
} }
enum AsyncReadBodyState { enum State {
Pending, Pending,
Partial(Cursor<Bytes>), Partial(Cursor<Bytes>),
Done, Done,
@ -76,37 +58,43 @@ enum AsyncReadBodyState {
impl AsyncReadBody { impl AsyncReadBody {
pub fn empty() -> Self { pub fn empty() -> Self {
Self { inner: hyper::Body::empty(), state: AsyncReadBodyState::Done } Self { inner: hyper::Body::empty(), state: State::Done }
} }
} }
impl From<hyper::Body> for AsyncReadBody { impl From<hyper::Body> for AsyncReadBody {
fn from(body: hyper::Body) -> Self { fn from(body: hyper::Body) -> Self {
Self { inner: body, state: AsyncReadBodyState::Pending } Self { inner: body, state: State::Pending }
} }
} }
impl AsyncRead for AsyncReadBody { impl AsyncRead for AsyncReadBody {
fn poll_read(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll<io::Result<usize>> { fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut [u8]
) -> Poll<io::Result<usize>> {
loop { loop {
match self.state { match self.state {
AsyncReadBodyState::Pending => { State::Pending => {
match ready!(Pin::new(&mut self.inner).poll_data(cx)) { match ready!(Pin::new(&mut self.inner).poll_data(cx)) {
Some(Ok(bytes)) => self.state = AsyncReadBodyState::Partial(Cursor::new(bytes)), Some(Ok(bytes)) => {
Some(Err(e)) => return Poll::Ready(Err(io::Error::new(io::ErrorKind::Other, e))), self.state = State::Partial(Cursor::new(bytes));
None => self.state = AsyncReadBodyState::Done, }
Some(Err(e)) => {
let error = io::Error::new(io::ErrorKind::Other, e);
return Poll::Ready(Err(error));
}
None => self.state = State::Done,
} }
}, },
AsyncReadBodyState::Partial(ref mut cursor) => { State::Partial(ref mut cursor) => {
match ready!(Pin::new(cursor).poll_read(cx, buf)) { match ready!(Pin::new(cursor).poll_read(cx, buf)) {
Ok(n) if n == 0 => { Ok(n) if n == 0 => self.state = State::Pending,
self.state = AsyncReadBodyState::Pending; result => return Poll::Ready(result),
}
Ok(n) => return Poll::Ready(Ok(n)),
Err(e) => return Poll::Ready(Err(e)),
} }
} }
AsyncReadBodyState::Done => return Poll::Ready(Ok(0)), State::Done => return Poll::Ready(Ok(0)),
} }
} }
} }

View File

@ -142,7 +142,7 @@ pub fn ignite() -> Rocket {
/// Alias to [`Rocket::custom()`]. Creates a new instance of `Rocket` with a /// Alias to [`Rocket::custom()`]. Creates a new instance of `Rocket` with a
/// custom configuration. /// custom configuration.
pub fn custom(config: config::Config) -> Rocket { pub fn custom(config: Config) -> Rocket {
Rocket::custom(config) Rocket::custom(config)
} }

View File

@ -84,8 +84,8 @@ impl<'c> LocalRequest<'c> {
} }
// Actually dispatch the request. // Actually dispatch the request.
let data = Data::local(self.data); let mut data = Data::local(self.data);
let token = rocket.preprocess_request(&mut self.request, &data).await; let token = rocket.preprocess_request(&mut self.request, &mut data).await;
let response = LocalResponse::new(self.request, move |request| { let response = LocalResponse::new(self.request, move |request| {
rocket.dispatch(token, request, data) rocket.dispatch(token, request, data)
}).await; }).await;

View File

@ -1,7 +1,5 @@
use std::ops::Deref; use std::ops::Deref;
use tokio::io::AsyncReadExt;
use crate::outcome::Outcome::*; use crate::outcome::Outcome::*;
use crate::request::{Request, form::{FromForm, FormItems, FormDataError}}; use crate::request::{Request, form::{FromForm, FormItems, FormDataError}};
use crate::data::{Outcome, Transform, Transformed, Data, FromTransformedData, TransformFuture, FromDataFuture}; use crate::data::{Outcome, Transform, Transformed, Data, FromTransformedData, TransformFuture, FromDataFuture};
@ -193,21 +191,18 @@ impl<'f, T: FromForm<'f> + Send + 'f> FromTransformedData<'f> for Form<T> {
data: Data data: Data
) -> TransformFuture<'r, Self::Owned, Self::Error> { ) -> TransformFuture<'r, Self::Owned, Self::Error> {
Box::pin(async move { Box::pin(async move {
use std::cmp::min;
if !request.content_type().map_or(false, |ct| ct.is_form()) { if !request.content_type().map_or(false, |ct| ct.is_form()) {
warn_!("Form data does not have form content type."); warn_!("Form data does not have form content type.");
return Transform::Borrowed(Forward(data)); return Transform::Borrowed(Forward(data));
} }
let limit = request.limits().forms; match data.open(request.limits().forms).stream_to_string().await {
let mut stream = data.open().take(limit); Ok(form_string) => Transform::Borrowed(Success(form_string)),
let mut form_string = String::with_capacity(min(4096, limit) as usize); Err(e) => {
if let Err(e) = stream.read_to_string(&mut form_string).await { let err = (Status::InternalServerError, FormDataError::Io(e));
return Transform::Borrowed(Failure((Status::InternalServerError, FormDataError::Io(e)))); Transform::Borrowed(Failure(err))
}
} }
Transform::Borrowed(Success(form_string))
}) })
} }

View File

@ -12,14 +12,12 @@ use atomic::Atomic;
use crate::request::{FromParam, FromSegments, FromRequest, Outcome}; use crate::request::{FromParam, FromSegments, FromRequest, Outcome};
use crate::request::{FromFormValue, FormItems, FormItem}; use crate::request::{FromFormValue, FormItems, FormItem};
use crate::rocket::Rocket; use crate::{Rocket, Config, Shutdown, Route};
use crate::shutdown::Shutdown;
use crate::router::Route;
use crate::config::{Config, Limits};
use crate::http::{hyper, uri::{Origin, Segments}}; use crate::http::{hyper, uri::{Origin, Segments}};
use crate::http::{Method, Header, HeaderMap, Cookies}; use crate::http::{Method, Header, HeaderMap, Cookies};
use crate::http::{RawStr, ContentType, Accept, MediaType}; use crate::http::{RawStr, ContentType, Accept, MediaType};
use crate::http::private::{Indexed, SmallVec, CookieJar}; use crate::http::private::{Indexed, SmallVec, CookieJar};
use crate::data::Limits;
type Indices = (usize, usize); type Indices = (usize, usize);
@ -513,7 +511,7 @@ impl<'r> Request<'r> {
} }
} }
/// Returns the configured application receive limits. /// Returns the configured application data limits.
/// ///
/// # Example /// # Example
/// ///

View File

@ -19,16 +19,13 @@ use yansi::Paint;
/// ```rust /// ```rust
/// use std::io; /// use std::io;
/// ///
/// use tokio::io::AsyncReadExt;
///
/// # use rocket::post; /// # use rocket::post;
/// use rocket::Data; /// use rocket::data::{Data, ToByteUnit};
/// use rocket::response::Debug; /// use rocket::response::Debug;
/// ///
/// #[post("/", format = "plain", data = "<data>")] /// #[post("/", format = "plain", data = "<data>")]
/// async fn post(data: Data) -> Result<String, Debug<io::Error>> { /// async fn post(data: Data) -> Result<String, Debug<io::Error>> {
/// let mut name = String::with_capacity(32); /// let name = data.open(32.bytes()).stream_to_string().await?;
/// data.open().take(32).read_to_string(&mut name).await?;
/// Ok(name) /// Ok(name)
/// } /// }
/// ``` /// ```

View File

@ -1,5 +1,4 @@
use std::{io, mem}; use std::{io, mem};
use std::cmp::min;
use std::sync::Arc; use std::sync::Arc;
use std::collections::HashMap; use std::collections::HashMap;
@ -211,10 +210,10 @@ async fn hyper_service_fn(
}; };
// Retrieve the data from the hyper body. // Retrieve the data from the hyper body.
let data = Data::from_hyp(h_body).await; let mut data = Data::from_hyp(h_body).await;
// Dispatch the request to get a response, then write that response out. // Dispatch the request to get a response, then write that response out.
let token = rocket.preprocess_request(&mut req, &data).await; let token = rocket.preprocess_request(&mut req, &mut data).await;
let r = rocket.dispatch(token, &mut req, data).await; let r = rocket.dispatch(token, &mut req, data).await;
rocket.issue_response(r, tx).await; rocket.issue_response(r, tx).await;
}); });
@ -303,16 +302,16 @@ impl Rocket {
pub(crate) async fn preprocess_request( pub(crate) async fn preprocess_request(
&self, &self,
req: &mut Request<'_>, req: &mut Request<'_>,
data: &Data data: &mut Data
) -> Token { ) -> Token {
// Check if this is a form and if the form contains the special _method // Check if this is a form and if the form contains the special _method
// field which we use to reinterpret the request's method. // field which we use to reinterpret the request's method.
let data_len = data.peek().len();
let (min_len, max_len) = ("_method=get".len(), "_method=delete".len()); let (min_len, max_len) = ("_method=get".len(), "_method=delete".len());
let peek_buffer = data.peek(max_len).await;
let is_form = req.content_type().map_or(false, |ct| ct.is_form()); let is_form = req.content_type().map_or(false, |ct| ct.is_form());
if is_form && req.method() == Method::Post && data_len >= min_len { if is_form && req.method() == Method::Post && peek_buffer.len() >= min_len {
if let Ok(form) = std::str::from_utf8(&data.peek()[..min(data_len, max_len)]) { if let Ok(form) = std::str::from_utf8(peek_buffer) {
let method: Option<Result<Method, _>> = FormItems::from(form) let method: Option<Result<Method, _>> = FormItems::from(form)
.filter(|item| item.key.as_str() == "_method") .filter(|item| item.key.as_str() == "_method")
.map(|item| item.value.parse()) .map(|item| item.value.parse())

View File

@ -14,13 +14,14 @@ fn index(form: Form<Simple>) -> String {
mod limits_tests { mod limits_tests {
use rocket; use rocket;
use rocket::config::{Environment, Config, Limits}; use rocket::config::{Environment, Config};
use rocket::local::blocking::Client; use rocket::local::blocking::Client;
use rocket::http::{Status, ContentType}; use rocket::http::{Status, ContentType};
use rocket::data::Limits;
fn rocket_with_forms_limit(limit: u64) -> rocket::Rocket { fn rocket_with_forms_limit(limit: u64) -> rocket::Rocket {
let config = Config::build(Environment::Development) let config = Config::build(Environment::Development)
.limits(Limits::default().limit("forms", limit)) .limits(Limits::default().limit("forms", limit.into()))
.unwrap(); .unwrap();
rocket::custom(config).mount("/", routes![super::index]) rocket::custom(config).mount("/", routes![super::index])

View File

@ -4,8 +4,8 @@
use std::io; use std::io;
use rocket::tokio::io::AsyncReadExt; use rocket::request::Request;
use rocket::{Request, data::Data}; use rocket::data::{Data, ToByteUnit};
use rocket::response::{Debug, content::{Json, Html}}; use rocket::response::{Debug, content::{Json, Html}};
use serde::{Serialize, Deserialize}; use serde::{Serialize, Deserialize};
@ -27,7 +27,7 @@ struct Person {
#[get("/<name>/<age>", format = "json")] #[get("/<name>/<age>", format = "json")]
fn get_hello(name: String, age: u8) -> Json<String> { fn get_hello(name: String, age: u8) -> Json<String> {
// NOTE: In a real application, we'd use `rocket_contrib::json::Json`. // NOTE: In a real application, we'd use `rocket_contrib::json::Json`.
let person = Person { name: name, age: age, }; let person = Person { name, age };
Json(serde_json::to_string(&person).unwrap()) Json(serde_json::to_string(&person).unwrap())
} }
@ -39,10 +39,8 @@ fn get_hello(name: String, age: u8) -> Json<String> {
// use `contrib::Json` to automatically serialize a type into JSON. // use `contrib::Json` to automatically serialize a type into JSON.
#[post("/<age>", format = "plain", data = "<name_data>")] #[post("/<age>", format = "plain", data = "<name_data>")]
async fn post_hello(age: u8, name_data: Data) -> Result<Json<String>, Debug<io::Error>> { async fn post_hello(age: u8, name_data: Data) -> Result<Json<String>, Debug<io::Error>> {
let mut name = String::with_capacity(32); let name = name_data.open(64.bytes()).stream_to_string().await?;
let mut stream = name_data.open().take(32); let person = Person { name, age };
stream.read_to_string(&mut name).await?;
let person = Person { name: name, age: age, };
// NOTE: In a real application, we'd use `rocket_contrib::json::Json`. // NOTE: In a real application, we'd use `rocket_contrib::json::Json`.
Ok(Json(serde_json::to_string(&person).expect("valid JSON"))) Ok(Json(serde_json::to_string(&person).expect("valid JSON")))
} }

View File

@ -3,7 +3,8 @@ mod tests;
use std::env; use std::env;
use rocket::{Request, Route, Data}; use rocket::{Request, Route};
use rocket::data::{Data, ToByteUnit};
use rocket::http::{Status, RawStr, Method::*}; use rocket::http::{Status, RawStr, Method::*};
use rocket::response::{Responder, status::Custom}; use rocket::response::{Responder, status::Custom};
use rocket::handler::{Handler, Outcome, HandlerFuture}; use rocket::handler::{Handler, Outcome, HandlerFuture};
@ -47,7 +48,7 @@ fn upload<'r>(req: &'r Request, data: Data) -> HandlerFuture<'r> {
let file = File::create(env::temp_dir().join("upload.txt")).await; let file = File::create(env::temp_dir().join("upload.txt")).await;
if let Ok(file) = file { if let Ok(file) = file {
if let Ok(n) = data.stream_to(file).await { if let Ok(n) = data.open(2.mebibytes()).stream_to(file).await {
return Outcome::from(req, format!("OK: {} bytes uploaded.", n)); return Outcome::from(req, format!("OK: {} bytes uploaded.", n));
} }

View File

@ -5,7 +5,7 @@ mod paste_id;
use std::io; use std::io;
use rocket::Data; use rocket::data::{Data, ToByteUnit};
use rocket::response::{content::Plain, Debug}; use rocket::response::{content::Plain, Debug};
use rocket::tokio::fs::File; use rocket::tokio::fs::File;
@ -20,7 +20,7 @@ async fn upload(paste: Data) -> Result<String, Debug<io::Error>> {
let filename = format!("upload/{id}", id = id); let filename = format!("upload/{id}", id = id);
let url = format!("{host}/{id}\n", host = HOST, id = id); let url = format!("{host}/{id}\n", host = HOST, id = id);
paste.stream_to_file(filename).await?; paste.open(128.kibibytes()).stream_to_file(filename).await?;
Ok(url) Ok(url)
} }

View File

@ -3,11 +3,13 @@
#[cfg(test)] mod tests; #[cfg(test)] mod tests;
use std::{io, env}; use std::{io, env};
use rocket::{Data, response::Debug}; use rocket::data::{Data, ToByteUnit};
use rocket::response::Debug;
#[post("/upload", format = "plain", data = "<data>")] #[post("/upload", format = "plain", data = "<data>")]
async fn upload(data: Data) -> Result<String, Debug<io::Error>> { async fn upload(data: Data) -> Result<String, Debug<io::Error>> {
Ok(data.stream_to_file(env::temp_dir().join("upload.txt")).await?.to_string()) let path = env::temp_dir().join("upload.txt");
Ok(data.open(128.kibibytes()).stream_to_file(path).await?.to_string())
} }
#[get("/")] #[get("/")]

View File

@ -70,7 +70,9 @@ echo ":: Checking for trailing whitespace..."
ensure_trailing_whitespace_free ensure_trailing_whitespace_free
echo ":: Updating dependencies..." echo ":: Updating dependencies..."
$CARGO update if ! $CARGO update ; then
echo " WARNING: Update failed! Proceeding with possibly outdated deps..."
fi
if [ "$1" = "--contrib" ]; then if [ "$1" = "--contrib" ]; then
FEATURES=( FEATURES=(

View File

@ -260,8 +260,8 @@ Here's our version (in `src/main.rs`):
# fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { Ok(()) } # fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { Ok(()) }
# } # }
use rocket::Data;
use rocket::response::Debug; use rocket::response::Debug;
use rocket::data::{Data, ToByteUnit};
#[post("/", data = "<paste>")] #[post("/", data = "<paste>")]
async fn upload(paste: Data) -> Result<String, Debug<std::io::Error>> { async fn upload(paste: Data) -> Result<String, Debug<std::io::Error>> {
@ -269,13 +269,14 @@ async fn upload(paste: Data) -> Result<String, Debug<std::io::Error>> {
let filename = format!("upload/{id}", id = id); let filename = format!("upload/{id}", id = id);
let url = format!("{host}/{id}\n", host = "http://localhost:8000", id = id); let url = format!("{host}/{id}\n", host = "http://localhost:8000", id = id);
// Write the paste out to the file and return the URL. // Write the paste out, limited to 128KiB, and return the URL.
paste.stream_to_file(filename).await?; paste.open(128.kibibytes()).stream_to_file(filename).await?;
Ok(url) Ok(url)
} }
``` ```
Ensure that the route is mounted at the root path: Note the [`kibibytes()`] method call: this method comes from the [`ToByteUnit`]
extension trait. Ensure that the route is mounted at the root path:
```rust ```rust
# #[macro_use] extern crate rocket; # #[macro_use] extern crate rocket;
@ -310,6 +311,9 @@ cat upload/* # ensure that contents are correct
Note that since we haven't created a `GET /<id>` route, visiting the returned URL Note that since we haven't created a `GET /<id>` route, visiting the returned URL
will result in a **404**. We'll fix that now. will result in a **404**. We'll fix that now.
[`kibibytes()`]: @api/rocket/data/trait.ToByteUnit.html#tymethod.kibibytes
[`ToByteUnit`]: @api/rocket/data/trait.ToByteUnit.html
## Retrieving Pastes ## Retrieving Pastes
The final step is to create the `retrieve` route which, given an `<id>`, will The final step is to create the `retrieve` route which, given an `<id>`, will
@ -443,9 +447,9 @@ through some of them to get a better feel for Rocket. Here are some ideas:
the two `POST /` routes should be called. the two `POST /` routes should be called.
* Support **deletion** of pastes by adding a new `DELETE /<id>` route. Use * Support **deletion** of pastes by adding a new `DELETE /<id>` route. Use
`PasteId` to validate `<id>`. `PasteId` to validate `<id>`.
* **Limit the upload** to a maximum size. If the upload exceeds that size, * Indicate **partial uploads** with a **206** partial status code. If the user
return a **206** partial status code. Otherwise, return a **201** created uploads a paste that meets or exceeds the allowed limit, return a **206**
status code. partial status code. Otherwise, return a **201** created status code.
* Set the `Content-Type` of the return value in `upload` and `retrieve` to * Set the `Content-Type` of the return value in `upload` and `retrieve` to
`text/plain`. `text/plain`.
* **Return a unique "key"** after each upload and require that the key is * **Return a unique "key"** after each upload and require that the key is

View File

@ -1037,36 +1037,39 @@ The only condition is that the generic type in `Json` implements the
Sometimes you just want to handle incoming data directly. For example, you might Sometimes you just want to handle incoming data directly. For example, you might
want to stream the incoming data out to a file. Rocket makes this as simple as want to stream the incoming data out to a file. Rocket makes this as simple as
possible via the [`Data`](@api/rocket/data/struct.Data.html) possible via the [`Data`](@api/rocket/data/struct.Data.html) type:
type:
```rust ```rust
# #[macro_use] extern crate rocket; # #[macro_use] extern crate rocket;
# fn main() {} # fn main() {}
use rocket::Data; use rocket::data::{Data, ToByteUnit};
use rocket::response::Debug; use rocket::response::Debug;
#[post("/upload", format = "plain", data = "<data>")] #[post("/upload", format = "plain", data = "<data>")]
async fn upload(data: Data) -> Result<String, Debug<std::io::Error>> { async fn upload(data: Data) -> Result<String, Debug<std::io::Error>> {
Ok(data.stream_to_file("/tmp/upload.txt").await.map(|n| n.to_string())?) let bytes_written = data.open(128.kibibytes())
.stream_to_file("/tmp/upload.txt")
.await?;
Ok(bytes_written.to_string())
} }
``` ```
The route above accepts any `POST` request to the `/upload` path with The route above accepts any `POST` request to the `/upload` path with
`Content-Type: text/plain`. The incoming data is streamed out to `Content-Type: text/plain`. At most 128KiB (`128 << 10` bytes) of the incoming
`/tmp/upload.txt`, and the number of bytes written is returned as a plain text data are streamed out to `/tmp/upload.txt`, and the number of bytes written is
response if the upload succeeds. If the upload fails, an error response is returned as a plain text response if the upload succeeds. If the upload fails,
returned. The handler above is complete. It really is that simple! See the an error response is returned. The handler above is complete. It really is that
[GitHub example code](@example/raw_upload) for the full crate. simple! See the [GitHub example code](@example/raw_upload) for the full crate.
! warning: You should _always_ set limits when reading incoming data. ! note: Rocket requires setting limits when reading incoming data.
To prevent DoS attacks, you should limit the amount of data you're willing to To aid in preventing DoS attacks, Rocket requires you to specify, as a
accept. The [`take()`] reader adapter makes doing this easy: [`ByteUnit`](@api/rocket/data/struct.ByteUnit.html), the amount of data you're
`data.open().take(LIMIT)`. willing to accept from the client when `open`ing a data stream. The
[`ToByteUnit`](@api/rocket/data/trait.ToByteUnit.html) trait makes specifying
[`take()`]: https://doc.rust-lang.org/std/io/trait.Read.html#method.take such a value as idiomatic as `128.kibibytes()`.
## Async Routes ## Async Routes