diff --git a/core/lib/src/data/data.rs b/core/lib/src/data/data.rs
index bbd2bb75..c54fa00d 100644
--- a/core/lib/src/data/data.rs
+++ b/core/lib/src/data/data.rs
@@ -1,9 +1,10 @@
-use crate::tokio::io::AsyncReadExt;
-use crate::data::data_stream::DataStream;
-use crate::data::{ByteUnit, StreamReader};
+use std::io;
+use std::pin::Pin;
 
-/// The number of bytes to read into the "peek" buffer.
-pub const PEEK_BYTES: usize = 512;
+use crate::data::ByteUnit;
+use crate::data::data_stream::{DataStream, RawReader, RawStream};
+use crate::data::peekable::Peekable;
+use crate::data::transform::{Transform, TransformBuf, Inspect, InPlaceMap};
 
 /// Type representing the body data of a request.
 ///
@@ -38,31 +39,27 @@ pub const PEEK_BYTES: usize = 512;
 /// body data. This enables partially or fully reading from a `Data` object
 /// without consuming the `Data` object.
 pub struct Data<'r> {
-    buffer: Vec<u8>,
-    is_complete: bool,
-    stream: StreamReader<'r>,
+    stream: Peekable<512, RawReader<'r>>,
+    transforms: Vec<Pin<Box<dyn Transform + Send + Sync>>>,
 }
 
+// TODO: Before `async`, we had a read timeout of 5s. Such a short read timeout
+// is likely no longer necessary, but an idle timeout should be implemented.
 impl<'r> Data<'r> {
-    /// Create a `Data` from a recognized `stream`.
-    pub(crate) fn from<S: Into<StreamReader<'r>>>(stream: S) -> Data<'r> {
-        // TODO.async: This used to also set the read timeout to 5 seconds.
-        // Such a short read timeout is likely no longer necessary, but some
-        // kind of idle timeout should be implemented.
+    #[inline]
+    pub(crate) fn new(stream: Peekable<512, RawReader<'r>>) -> Self {
+        Self { stream, transforms: Vec::new() }
+    }
 
-        let stream = stream.into();
-        let buffer = Vec::with_capacity(PEEK_BYTES / 8);
-        Data { buffer, stream, is_complete: false }
+    #[inline]
+    pub(crate) fn from<S: Into<RawStream<'r>>>(stream: S) -> Data<'r> {
+        Data::new(Peekable::new(RawReader::new(stream.into())))
     }
 
     /// This creates a `data` object from a local data source `data`.
     #[inline]
     pub(crate) fn local(data: Vec<u8>) -> Data<'r> {
-        Data {
-            buffer: data,
-            stream: StreamReader::empty(),
-            is_complete: true,
-        }
+        Data::new(Peekable::with_buffer(data, true, RawReader::new(RawStream::Empty)))
     }
 
     /// Returns the raw data stream, limited to `limit` bytes.
@@ -82,18 +79,31 @@ impl<'r> Data<'r> {
     ///     let stream = data.open(2.mebibytes());
     /// }
     /// ```
+    #[inline(always)]
     pub fn open(self, limit: ByteUnit) -> DataStream<'r> {
-        DataStream::new(self.buffer, self.stream, limit.into())
+        DataStream::new(self.transforms, self.stream, limit.into())
     }
 
-    /// Retrieve at most `num` bytes from the `peek` buffer without consuming
-    /// `self`.
+    /// Fills the peek buffer with body data until it contains at least `num`
+    /// bytes (capped to 512), or the complete body data, whichever is less,
+    /// and returns it. If the buffer already contains either at least `num`
+    /// bytes or all of the body data, no I/O is performed and the buffer is
+    /// simply returned. If `num` is greater than `512`, it is artificially
+    /// capped to `512`.
     ///
-    /// The peek buffer contains at most 512 bytes of the body of the request.
-    /// The actual size of the returned buffer is the `min` of the request's
-    /// body, `num` and `512`. The [`peek_complete`](#method.peek_complete)
-    /// method can be used to determine if this buffer contains _all_ of the
-    /// data in the body of the request.
+    /// No guarantees are made about the actual size of the returned buffer
+    /// except that it will not exceed the length of the body data. It may be:
+    ///
+    /// * Less than `num` if `num > 512` or the complete body data is `< 512`
+    ///   or an error occurred while reading the body.
+    /// * Equal to `num` if `num` is `<= 512` and exactly `num` bytes of the
+    ///   body data were successfully read.
+    /// * Greater than `num` if `> num` bytes of the body data have
+    ///   successfully been read, either by this request, a previous request,
+    ///   or opportunistically.
+    ///
+    /// [`Data::peek_complete()`] can be used to determine if this buffer
+    /// contains the complete body data.
     ///
     /// # Examples
     ///
@@ -147,30 +157,13 @@ impl<'r> Data<'r> {
     ///     }
     /// }
     /// ```
+    #[inline(always)]
     pub async fn peek(&mut self, num: usize) -> &[u8] {
-        let num = std::cmp::min(PEEK_BYTES, num);
-        let mut len = self.buffer.len();
-        if len >= num {
-            return &self.buffer[..num];
-        }
-
-        while len < num {
-            match self.stream.read_buf(&mut self.buffer).await {
-                Ok(0) => { self.is_complete = true; break },
-                Ok(n) => len += n,
-                Err(e) => {
-                    error_!("Failed to read into peek buffer: {:?}.", e);
-                    break;
-                }
-            }
-        }
-
-        &self.buffer[..std::cmp::min(len, num)]
+        self.stream.peek(num).await
     }
 
     /// Returns true if the `peek` buffer contains all of the data in the body
-    /// of the request. Returns `false` if it does not or if it is not known if
-    /// it does.
+    /// of the request. Returns `false` if it does not or if it is not known.
     ///
     /// # Example
     ///
@@ -185,6 +178,43 @@ impl<'r> Data<'r> {
     /// ```
     #[inline(always)]
     pub fn peek_complete(&self) -> bool {
-        self.is_complete
+        self.stream.complete
     }
+
+    /// Chains the [`Transform`] `transform` to `self`.
+    ///
+    /// Note that transforms do nothing until the data is
+    /// [`open()`ed](Data::open()) and read.
+    #[inline(always)]
+    pub fn chain_transform<T>(&mut self, transform: T) -> &mut Self
+        where T: Transform + Send + Sync + 'static
+    {
+        self.transforms.push(Box::pin(transform));
+        self
+    }
+
+    /// Chain a [`Transform`] that can inspect the data as it streams.
+    pub fn chain_inspect<F>(&mut self, f: F) -> &mut Self
+        where F: FnMut(&[u8]) + Send + Sync + 'static
+    {
+        self.chain_transform(Inspect(Box::new(f)))
+    }
+
+    /// Chain a [`Transform`] that can in-place map the data as it streams.
+    /// Unlike [`Data::chain_try_inplace_map()`], this version assumes the
+    /// mapper is infallible.
+    pub fn chain_inplace_map<F>(&mut self, mut f: F) -> &mut Self
+        where F: FnMut(&mut TransformBuf<'_, '_>) + Send + Sync + 'static
+    {
+        self.chain_transform(InPlaceMap(Box::new(move |buf| Ok(f(buf)))))
+    }
+
+    /// Chain a [`Transform`] that can in-place map the data as it streams.
+    /// Unlike [`Data::chain_inplace_map()`], this version allows the mapper
+    /// to be fallible.
+    pub fn chain_try_inplace_map<F>(&mut self, f: F) -> &mut Self
+        where F: FnMut(&mut TransformBuf<'_, '_>) -> io::Result<()> + Send + Sync + 'static
+    {
+        self.chain_transform(InPlaceMap(Box::new(f)))
+    }
 }
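Reviewer note: a minimal sketch of how the new chaining API composes from a fairing, using only items added in this patch. The fairing name and closures are illustrative, not part of the patch:

```rust
use rocket::fairing::AdHoc;

// Illustrative only: log each freshly read chunk, then fail the read if a
// chunk contains a NUL byte. Transforms run only once the data is
// `open()`ed and read by a handler or guard.
fn body_filters() -> AdHoc {
    AdHoc::on_request("body filters", |_req, data| Box::pin(async move {
        data.chain_inspect(|chunk| println!("read {} fresh bytes", chunk.len()))
            .chain_try_inplace_map(|buf| match buf.fresh().contains(&0) {
                true => Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "NUL in body")),
                false => Ok(()),
            });
    }))
}
```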
diff --git a/core/lib/src/data/data_stream.rs b/core/lib/src/data/data_stream.rs
index a9a9e07a..f30f046e 100644
--- a/core/lib/src/data/data_stream.rs
+++ b/core/lib/src/data/data_stream.rs
@@ -5,13 +5,17 @@ use std::io::{self, Cursor};
 
 use tokio::fs::File;
 use tokio::io::{AsyncRead, AsyncWrite, AsyncReadExt, ReadBuf, Take};
-use futures::stream::Stream;
-use futures::ready;
-use yansi::Paint;
+use tokio_util::io::StreamReader;
+use futures::{ready, stream::Stream};
 
 use crate::http::hyper;
 use crate::ext::{PollExt, Chain};
 use crate::data::{Capped, N};
+use crate::http::hyper::body::Bytes;
+use crate::data::transform::Transform;
+
+use super::peekable::Peekable;
+use super::transform::TransformBuf;
 
 /// Raw data stream of a request body.
 ///
@@ -40,47 +44,101 @@ use crate::data::{Capped, N};
 ///
 /// [`DataStream::stream_to(&mut vec)`]: DataStream::stream_to()
 /// [`DataStream::stream_to(&mut file)`]: DataStream::stream_to()
-pub struct DataStream<'r> {
-    pub(crate) chain: Take<Chain<Cursor<Vec<u8>>, StreamReader<'r>>>,
+#[non_exhaustive]
+pub enum DataStream<'r> {
+    #[doc(hidden)]
+    Base(BaseReader<'r>),
+    #[doc(hidden)]
+    Transform(TransformReader<'r>),
 }
 
-/// An adapter: turns a `T: Stream` (in `StreamKind`) into a `tokio::AsyncRead`.
-pub struct StreamReader<'r> {
-    state: State,
-    inner: StreamKind<'r>,
+/// A data stream that has a `transformer` applied to it.
+pub struct TransformReader<'r> {
+    transformer: Pin<Box<dyn Transform + Send + Sync>>,
+    stream: Pin<Box<DataStream<'r>>>,
+    inner_done: bool,
 }
 
-/// The current state of `StreamReader` `AsyncRead` adapter.
-enum State {
-    Pending,
-    Partial(Cursor<hyper::body::Bytes>),
-    Done,
-}
+/// Limited, pre-buffered reader to the underlying data stream.
+pub type BaseReader<'r> = Take<Chain<Cursor<Vec<u8>>, RawReader<'r>>>;
 
-/// The kinds of streams we accept as `Data`.
-enum StreamKind<'r> {
+/// Direct reader to the underlying data stream. Not limited in any manner.
+pub type RawReader<'r> = StreamReader<RawStream<'r>, Bytes>;
+
+/// Raw underlying data stream.
+pub enum RawStream<'r> {
     Empty,
     Body(&'r mut hyper::Body),
-    Multipart(multer::Field<'r>)
+    Multipart(multer::Field<'r>),
+}
+
+impl<'r> TransformReader<'r> {
+    /// Returns the underlying `BaseReader`.
+    fn base_mut(&mut self) -> &mut BaseReader<'r> {
+        match self.stream.as_mut().get_mut() {
+            DataStream::Base(base) => base,
+            DataStream::Transform(inner) => inner.base_mut(),
+        }
+    }
+
+    /// Returns the underlying `BaseReader`.
+    fn base(&self) -> &BaseReader<'r> {
+        match self.stream.as_ref().get_ref() {
+            DataStream::Base(base) => base,
+            DataStream::Transform(inner) => inner.base(),
+        }
+    }
 }
 
 impl<'r> DataStream<'r> {
-    pub(crate) fn new(buf: Vec<u8>, stream: StreamReader<'r>, limit: u64) -> Self {
-        let chain = Chain::new(Cursor::new(buf), stream).take(limit).into();
-        Self { chain }
+    pub(crate) fn new(
+        transformers: Vec<Pin<Box<dyn Transform + Send + Sync>>>,
+        Peekable { buffer, reader, .. }: Peekable<512, RawReader<'r>>,
+        limit: u64
+    ) -> Self {
+        let mut stream = DataStream::Base(Chain::new(Cursor::new(buffer), reader).take(limit));
+        for transformer in transformers {
+            stream = DataStream::Transform(TransformReader {
+                transformer,
+                stream: Box::pin(stream),
+                inner_done: false,
+            });
+        }
+
+        stream
+    }
+
+    /// Returns the underlying `BaseReader`.
+    fn base_mut(&mut self) -> &mut BaseReader<'r> {
+        match self {
+            DataStream::Base(base) => base,
+            DataStream::Transform(transform) => transform.base_mut(),
+        }
+    }
+
+    /// Returns the underlying `BaseReader`.
+    fn base(&self) -> &BaseReader<'r> {
+        match self {
+            DataStream::Base(base) => base,
+            DataStream::Transform(transform) => transform.base(),
+        }
+    }
 
     /// Whether a previous read exhausted the set limit _and then some_.
     async fn limit_exceeded(&mut self) -> io::Result<bool> {
+        let base = self.base_mut();
+
         #[cold]
-        async fn _limit_exceeded(stream: &mut DataStream<'_>) -> io::Result<bool> {
+        async fn _limit_exceeded(base: &mut BaseReader<'_>) -> io::Result<bool> {
             // Read one more byte after reaching limit to see if we cut early.
-            stream.chain.set_limit(1);
+            base.set_limit(1);
             let mut buf = [0u8; 1];
-            Ok(stream.read(&mut buf).await? != 0)
+            let exceeded = base.read(&mut buf).await? != 0;
+            base.set_limit(0);
+            Ok(exceeded)
         }
 
-        Ok(self.chain.limit() == 0 && _limit_exceeded(self).await?)
+        Ok(base.limit() == 0 && _limit_exceeded(base).await?)
     }
 
     /// Number of bytes a full read from `self` will _definitely_ read.
@@ -95,8 +153,9 @@ impl<'r> DataStream<'r> {
     ///     }
     /// }
     /// ```
     pub fn hint(&self) -> usize {
-        let buf_len = self.chain.get_ref().get_ref().0.get_ref().len();
-        std::cmp::min(buf_len, self.chain.limit() as usize)
+        let base = self.base();
+        let buf_len = base.get_ref().get_ref().0.get_ref().len();
+        std::cmp::min(buf_len, base.limit() as usize)
     }
 
     /// A helper method to write the body of the request to any `AsyncWrite`
@@ -227,97 +286,86 @@ impl<'r> DataStream<'r> {
     }
 }
 
-// TODO.async: Consider implementing `AsyncBufRead`.
-
-impl StreamReader<'_> {
-    pub fn empty() -> Self {
-        Self { inner: StreamKind::Empty, state: State::Done }
-    }
-}
-
-impl<'r> From<&'r mut hyper::Body> for StreamReader<'r> {
-    fn from(body: &'r mut hyper::Body) -> Self {
-        Self { inner: StreamKind::Body(body), state: State::Pending }
-    }
-}
-
-impl<'r> From<multer::Field<'r>> for StreamReader<'r> {
-    fn from(field: multer::Field<'r>) -> Self {
-        Self { inner: StreamKind::Multipart(field), state: State::Pending }
-    }
-}
-
 impl AsyncRead for DataStream<'_> {
-    #[inline(always)]
+    fn poll_read(
+        self: Pin<&mut Self>,
+        cx: &mut Context<'_>,
+        buf: &mut ReadBuf<'_>,
+    ) -> Poll<io::Result<()>> {
+        match self.get_mut() {
+            DataStream::Base(inner) => Pin::new(inner).poll_read(cx, buf),
+            DataStream::Transform(inner) => Pin::new(inner).poll_read(cx, buf),
+        }
+    }
+}
+
+impl AsyncRead for TransformReader<'_> {
     fn poll_read(
         mut self: Pin<&mut Self>,
         cx: &mut Context<'_>,
         buf: &mut ReadBuf<'_>,
     ) -> Poll<io::Result<()>> {
-        if self.chain.limit() == 0 {
-            let stream: &StreamReader<'_> = &self.chain.get_ref().get_ref().1;
-            let kind = match stream.inner {
-                StreamKind::Empty => "an empty stream (vacuous)",
-                StreamKind::Body(_) => "the request body",
-                StreamKind::Multipart(_) => "a multipart form field",
-            };
-
-            warn_!("Data limit reached while reading {}.", kind.primary().bold());
+        let init_fill = buf.filled().len();
+        if !self.inner_done {
+            ready!(Pin::new(&mut self.stream).poll_read(cx, buf))?;
+            self.inner_done = init_fill == buf.filled().len();
         }
 
-        Pin::new(&mut self.chain).poll_read(cx, buf)
+        if self.inner_done {
+            return self.transformer.as_mut().poll_finish(cx, buf);
+        }
+
+        let mut tbuf = TransformBuf { buf, cursor: init_fill };
+        self.transformer.as_mut().transform(&mut tbuf)?;
+        if buf.filled().len() == init_fill {
+            cx.waker().wake_by_ref();
+            return Poll::Pending;
+        }
+
+        Poll::Ready(Ok(()))
     }
 }
 
-impl Stream for StreamKind<'_> {
-    type Item = io::Result<hyper::body::Bytes>;
+impl Stream for RawStream<'_> {
+    type Item = io::Result<Bytes>;
 
-    fn poll_next(
-        self: Pin<&mut Self>,
-        cx: &mut Context<'_>,
-    ) -> Poll<Option<Self::Item>> {
+    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
         match self.get_mut() {
-            StreamKind::Body(body) => Pin::new(body).poll_next(cx)
+            RawStream::Body(body) => Pin::new(body).poll_next(cx)
                 .map_err_ext(|e| io::Error::new(io::ErrorKind::Other, e)),
-            StreamKind::Multipart(mp) => Pin::new(mp).poll_next(cx)
+            RawStream::Multipart(mp) => Pin::new(mp).poll_next(cx)
                 .map_err_ext(|e| io::Error::new(io::ErrorKind::Other, e)),
-            StreamKind::Empty => Poll::Ready(None),
+            RawStream::Empty => Poll::Ready(None),
         }
     }
 
     fn size_hint(&self) -> (usize, Option<usize>) {
         match self {
-            StreamKind::Body(body) => body.size_hint(),
-            StreamKind::Multipart(mp) => mp.size_hint(),
-            StreamKind::Empty => (0, Some(0)),
+            RawStream::Body(body) => body.size_hint(),
+            RawStream::Multipart(mp) => mp.size_hint(),
+            RawStream::Empty => (0, Some(0)),
         }
     }
 }
 
-impl AsyncRead for StreamReader<'_> {
-    fn poll_read(
-        mut self: Pin<&mut Self>,
-        cx: &mut Context<'_>,
-        buf: &mut ReadBuf<'_>,
-    ) -> Poll<io::Result<()>> {
-        loop {
-            self.state = match self.state {
-                State::Pending => {
-                    match ready!(Pin::new(&mut self.inner).poll_next(cx)) {
-                        Some(Err(e)) => return Poll::Ready(Err(e)),
-                        Some(Ok(bytes)) => State::Partial(Cursor::new(bytes)),
-                        None => State::Done,
-                    }
-                },
-                State::Partial(ref mut cursor) => {
-                    let rem = buf.remaining();
-                    match ready!(Pin::new(cursor).poll_read(cx, buf)) {
-                        Ok(()) if rem == buf.remaining() => State::Pending,
-                        result => return Poll::Ready(result),
-                    }
-                }
-                State::Done => return Poll::Ready(Ok(())),
-            }
+impl std::fmt::Display for RawStream<'_> {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        match self {
+            RawStream::Empty => f.write_str("empty stream"),
+            RawStream::Body(_) => f.write_str("request body"),
+            RawStream::Multipart(_) => f.write_str("multipart form field"),
         }
     }
 }
+
+impl<'r> From<&'r mut hyper::Body> for RawStream<'r> {
+    fn from(value: &'r mut hyper::Body) -> Self {
+        Self::Body(value)
+    }
+}
+
+impl<'r> From<multer::Field<'r>> for RawStream<'r> {
+    fn from(value: multer::Field<'r>) -> Self {
+        Self::Multipart(value)
+    }
+}
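Reviewer note: for intuition, `BaseReader` is simply a `take()` over a chain of the already-peeked buffer and the raw stream; Rocket uses its own `ext::Chain` so `hint()` can reach the buffer through `get_ref()`. A standalone sketch of the same replay-then-limit behavior using plain tokio combinators (names here are illustrative, not Rocket API):

```rust
use std::io::Cursor;
use tokio::io::AsyncReadExt;

#[tokio::main]
async fn main() -> std::io::Result<()> {
    // Bytes already peeked are replayed before the rest of the body, and the
    // combined reader is capped at the route's data limit (13 bytes here).
    let peeked = Cursor::new(b"Hello, ".to_vec());
    let rest = Cursor::new(b"world! (and beyond the limit)".to_vec());
    let mut base = peeked.chain(rest).take(13);

    let mut body = String::new();
    base.read_to_string(&mut body).await?;
    assert_eq!(body, "Hello, world!");
    Ok(())
}
```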
diff --git a/core/lib/src/data/mod.rs b/core/lib/src/data/mod.rs
index 9c7a3314..e3eebdd2 100644
--- a/core/lib/src/data/mod.rs
+++ b/core/lib/src/data/mod.rs
@@ -7,6 +7,8 @@ mod data_stream;
 mod from_data;
 mod limits;
 mod io_stream;
+mod transform;
+mod peekable;
 
 pub use self::data::Data;
 pub use self::data_stream::DataStream;
@@ -15,5 +17,4 @@ pub use self::limits::Limits;
 pub use self::capped::{N, Capped};
 pub use self::io_stream::{IoHandler, IoStream};
 pub use ubyte::{ByteUnit, ToByteUnit};
-
-pub(crate) use self::data_stream::StreamReader;
+pub use self::transform::{Transform, TransformBuf};
diff --git a/core/lib/src/data/peekable.rs b/core/lib/src/data/peekable.rs
new file mode 100644
index 00000000..58daac70
--- /dev/null
+++ b/core/lib/src/data/peekable.rs
@@ -0,0 +1,48 @@
+use tokio::io::{AsyncRead, AsyncReadExt};
+
+pub struct Peekable<const N: usize, R> {
+    pub(crate) buffer: Vec<u8>,
+    pub(crate) complete: bool,
+    pub(crate) reader: R,
+}
+
+impl<const N: usize, R: AsyncRead + Unpin> Peekable<N, R> {
+    pub fn new(reader: R) -> Self {
+        Self { buffer: Vec::new(), complete: false, reader }
+    }
+
+    pub fn with_buffer(buffer: Vec<u8>, complete: bool, reader: R) -> Self {
+        Self { buffer, complete, reader }
+    }
+
+    pub async fn peek(&mut self, num: usize) -> &[u8] {
+        if self.complete {
+            return self.buffer.as_slice();
+        }
+
+        let to_read = std::cmp::min(N, num);
+        if self.buffer.len() >= to_read {
+            return self.buffer.as_slice();
+        }
+
+        if self.buffer.capacity() == 0 {
+            self.buffer.reserve(N);
+        }
+
+        while self.buffer.len() < to_read {
+            match self.reader.read_buf::<Vec<u8>>(&mut self.buffer).await {
+                Ok(0) => {
+                    self.complete = self.buffer.capacity() > self.buffer.len();
+                    break;
+                },
+                Ok(_) => { /* continue */ },
+                Err(e) => {
+                    error_!("Failed to read into peek buffer: {:?}.", e);
+                    break;
+                }
+            }
+        }
+
+        self.buffer.as_slice()
+    }
+}
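Reviewer note: a unit-test style sketch of `Peekable`'s semantics that could live alongside `peekable.rs` (assumes tokio's `macros` feature for `#[tokio::test]`): the first `peek()` buffers opportunistically, and `complete` only flips once EOF is actually observed.

```rust
#[cfg(test)]
mod peek_semantics {
    use std::io::Cursor;
    use super::Peekable;

    #[tokio::test]
    async fn peek_buffers_opportunistically_and_tracks_eof() {
        let body = Cursor::new(b"hello world".to_vec());
        let mut stream: Peekable<512, _> = Peekable::new(body);

        // Asking for 4 bytes may return more: `read_buf()` filled the whole
        // 512-byte reservation in one call, so all 11 bytes are buffered.
        let peeked = stream.peek(4).await;
        assert!(peeked.len() >= 4 && peeked.starts_with(b"hell"));

        // EOF hasn't been observed yet, so completeness is still unknown.
        assert!(!stream.complete);

        // Requesting more than is buffered forces a read that hits EOF.
        assert_eq!(stream.peek(512).await, &b"hello world"[..]);
        assert!(stream.complete);
    }
}
```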
diff --git a/core/lib/src/data/transform.rs b/core/lib/src/data/transform.rs
new file mode 100644
index 00000000..e3be992c
--- /dev/null
+++ b/core/lib/src/data/transform.rs
@@ -0,0 +1,304 @@
+use std::io;
+use std::ops::{Deref, DerefMut};
+use std::pin::Pin;
+use std::task::{Poll, Context};
+
+use tokio::io::ReadBuf;
+
+/// Chainable, in-place, streaming data transformer.
+///
+/// [`Transform`] operates on [`TransformBuf`]s similar to how [`AsyncRead`]
+/// operates on [`ReadBuf`]. A [`Transform`] sits somewhere in a chain of
+/// transforming readers. The head (most upstream part) of the chain is
+/// _always_ an [`AsyncRead`]: the data source. The tail (all downstream
+/// parts) is composed _only_ of other [`Transform`]s:
+///
+/// ```text
+///                  downstream --->
+/// AsyncRead | Transform | .. | Transform
+///                  <---- upstream
+/// ```
+///
+/// When the upstream source makes data available, the
+/// [`Transform::transform()`] method is called. [`Transform`]s may obtain
+/// the subset of the filled section added by an upstream data source with
+/// [`TransformBuf::fresh()`]. They may modify this data at will, potentially
+/// changing the size of the filled section. For example,
+/// [`TransformBuf::spoil()`] "removes" all of the fresh data, and
+/// [`TransformBuf::fresh_mut()`] can be used to modify the data in-place.
+///
+/// Additionally, new data may be added in-place via the traditional
+/// approach: write to (or overwrite) the initialized section of the buffer
+/// and mark it as filled. All of the remaining filled data will be passed to
+/// downstream transforms as "fresh" data. To add data to the end of the
+/// (potentially rewritten) stream, the [`Transform::poll_finish()`] method
+/// can be implemented.
+///
+/// [`AsyncRead`]: tokio::io::AsyncRead
+pub trait Transform {
+    /// Called when data is read from the upstream source. For any given
+    /// fresh data, this method is called only once. [`TransformBuf::fresh()`]
+    /// is guaranteed to contain at least one byte.
+    ///
+    /// While this method is not _async_ (it does not return [`Poll`]), it is
+    /// nevertheless executed in an async context and should respect all such
+    /// restrictions, including not blocking.
+    fn transform(
+        self: Pin<&mut Self>,
+        buf: &mut TransformBuf<'_, '_>,
+    ) -> io::Result<()>;
+
+    /// Called when the upstream is finished, that is, it has no more data to
+    /// fill. At this point, the transform becomes an async reader. This
+    /// method thus has identical semantics to [`AsyncRead::poll_read()`].
+    /// This method may never be called if the upstream does not finish.
+    ///
+    /// The default implementation returns `Poll::Ready(Ok(()))`.
+    ///
+    /// [`AsyncRead::poll_read()`]: tokio::io::AsyncRead::poll_read()
+    fn poll_finish(
+        self: Pin<&mut Self>,
+        cx: &mut Context<'_>,
+        buf: &mut ReadBuf<'_>,
+    ) -> Poll<io::Result<()>> {
+        let (_, _) = (cx, buf);
+        Poll::Ready(Ok(()))
+    }
+}
+
+/// A buffer of transformable streaming data.
+///
+/// # Overview
+///
+/// A byte buffer, similar to a [`ReadBuf`], with a "fresh" dimension. Fresh
+/// data is always a subset of the filled data, filled data is always a
+/// subset of initialized data, and initialized data is always a subset of
+/// the buffer itself. Both the filled and initialized data sections are
+/// guaranteed to be at the start of the buffer, but the fresh subset is
+/// likely to begin somewhere inside the filled section.
+///
+/// To visualize this, the diagram below represents a possible state for the
+/// byte buffer being tracked. The square `[ ]` brackets represent the
+/// complete buffer, while the curly `{ }` brackets represent the named
+/// subset.
+///
+/// ```text
+/// [             { !! fresh !! }                           ]
+/// [ { ++++++ filled +++++++++ }          unfilled         ]
+/// [ { -------- initialized ------- }     uninitialized    ]
+/// [                       capacity                        ]
+/// ```
+///
+/// The same buffer represented in its true single dimension is below:
+///
+/// ```text
+/// [ ++!!!!!!!!!!!!!!----------xxxxxxxxxxxxxxxxxxxxxxxxxxx ]
+/// ```
+///
+/// * `+`: filled (implies initialized)
+/// * `!`: fresh (implies filled)
+/// * `-`: unfilled / initialized (implies initialized)
+/// * `x`: uninitialized (implies unfilled)
+///
+/// As with [`ReadBuf`], [`AsyncRead`] readers fill the initialized portion
+/// of a [`TransformBuf`] to indicate that data is available. _Filling_
+/// initialized portions of the byte buffer is what increases the size of
+/// the _filled_ section. Because a [`ReadBuf`] may already be partially
+/// filled when a reader adds bytes to it, a mechanism to track where the
+/// _newly_ filled portion exists is needed. This is exactly what the
+/// "fresh" section tracks.
+///
+/// [`AsyncRead`]: tokio::io::AsyncRead
+pub struct TransformBuf<'a, 'b> {
+    pub(crate) buf: &'a mut ReadBuf<'b>,
+    pub(crate) cursor: usize,
+}
+
+impl TransformBuf<'_, '_> {
+    /// Returns a borrow to the fresh data: data filled by the upstream
+    /// source.
+    pub fn fresh(&self) -> &[u8] {
+        &self.filled()[self.cursor..]
+    }
+
+    /// Returns a mutable borrow to the fresh data: data filled by the
+    /// upstream source.
+    pub fn fresh_mut(&mut self) -> &mut [u8] {
+        let cursor = self.cursor;
+        &mut self.filled_mut()[cursor..]
+    }
+
+    /// Spoils the fresh data by resetting the filled section to its value
+    /// before any new data was added. As a result, the data will never be
+    /// seen by any downstream consumer unless it is returned via another
+    /// mechanism.
+    pub fn spoil(&mut self) {
+        let cursor = self.cursor;
+        self.set_filled(cursor);
+    }
+}
+
+pub struct Inspect(pub(crate) Box<dyn FnMut(&[u8]) + Send + Sync + 'static>);
+
+impl Transform for Inspect {
+    fn transform(mut self: Pin<&mut Self>, buf: &mut TransformBuf<'_, '_>) -> io::Result<()> {
+        (self.0)(buf.fresh());
+        Ok(())
+    }
+}
+
+pub struct InPlaceMap(
+    pub(crate) Box<dyn FnMut(&mut TransformBuf<'_, '_>) -> io::Result<()> + Send + Sync + 'static>
+);
+
+impl Transform for InPlaceMap {
+    fn transform(mut self: Pin<&mut Self>, buf: &mut TransformBuf<'_, '_>) -> io::Result<()> {
+        (self.0)(buf)
+    }
+}
+
+impl<'a, 'b> Deref for TransformBuf<'a, 'b> {
+    type Target = ReadBuf<'b>;
+
+    fn deref(&self) -> &Self::Target {
+        &self.buf
+    }
+}
+
+impl<'a, 'b> DerefMut for TransformBuf<'a, 'b> {
+    fn deref_mut(&mut self) -> &mut Self::Target {
+        &mut self.buf
+    }
+}
+
+// TODO: Test chaining various transform combinations:
+//   * consume | consume
+//   * add | consume
+//   * consume | add
+//   * add | add
+// Where `add` is a transformer that adds data to the stream, and `consume`
+// is one that removes data.
+#[cfg(test)]
+#[allow(deprecated)]
+mod tests {
+    use std::hash::SipHasher;
+    use std::sync::{Arc, atomic::{self, AtomicU64, AtomicU8}};
+
+    use parking_lot::Mutex;
+    use ubyte::ToByteUnit;
+
+    use crate::http::Method;
+    use crate::local::blocking::Client;
+    use crate::fairing::AdHoc;
+    use crate::{route, Route, Data, Response, Request};
+
+    mod hash_transform {
+        use std::io::Cursor;
+        use std::hash::Hasher;
+
+        use tokio::io::AsyncRead;
+
+        use super::super::*;
+
+        pub struct HashTransform<H: Hasher> {
+            pub(crate) hasher: H,
+            pub(crate) hash: Option<Cursor<[u8; 8]>>,
+        }
+
+        impl<H: Hasher + Unpin> Transform for HashTransform<H> {
+            fn transform(
+                mut self: Pin<&mut Self>,
+                buf: &mut TransformBuf<'_, '_>,
+            ) -> io::Result<()> {
+                self.hasher.write(buf.fresh());
+                buf.spoil();
+                Ok(())
+            }
+
+            fn poll_finish(
+                mut self: Pin<&mut Self>,
+                cx: &mut Context<'_>,
+                buf: &mut ReadBuf<'_>,
+            ) -> Poll<io::Result<()>> {
+                if self.hash.is_none() {
+                    let hash = self.hasher.finish();
+                    self.hash = Some(Cursor::new(hash.to_be_bytes()));
+                }
+
+                let cursor = self.hash.as_mut().unwrap();
+                Pin::new(cursor).poll_read(cx, buf)
+            }
+        }
+
+        impl crate::Data<'_> {
+            /// Chain an in-place hash [`Transform`] to `self`.
+            pub fn chain_hash_transform<H: Hasher>(&mut self, hasher: H) -> &mut Self
+                where H: Unpin + Send + Sync + 'static
+            {
+                self.chain_transform(HashTransform { hasher, hash: None })
+            }
+        }
+    }
+
+    #[test]
+    fn test_transform_series() {
+        fn handler<'r>(_: &'r Request<'_>, data: Data<'r>) -> route::BoxFuture<'r> {
+            Box::pin(async move {
+                data.open(128.bytes()).stream_to(tokio::io::sink()).await.expect("read ok");
+                route::Outcome::Success(Response::new())
+            })
+        }
+
+        let inspect2: Arc<AtomicU8> = Arc::new(AtomicU8::new(0));
+        let raw_data: Arc<Mutex<Vec<u8>>> = Arc::new(Mutex::new(Vec::new()));
+        let hash: Arc<AtomicU64> = Arc::new(AtomicU64::new(0));
+        let rocket = crate::build()
+            .manage(hash.clone())
+            .manage(raw_data.clone())
+            .manage(inspect2.clone())
+            .mount("/", vec![Route::new(Method::Post, "/", handler)])
+            .attach(AdHoc::on_request("transforms", |req, data| Box::pin(async {
+                let hash1 = req.rocket().state::<Arc<AtomicU64>>().cloned().unwrap();
+                let hash2 = req.rocket().state::<Arc<AtomicU64>>().cloned().unwrap();
+                let raw_data = req.rocket().state::<Arc<Mutex<Vec<u8>>>>().cloned().unwrap();
+                let inspect2 = req.rocket().state::<Arc<AtomicU8>>().cloned().unwrap();
+                data.chain_inspect(move |bytes| { *raw_data.lock() = bytes.to_vec(); })
+                    .chain_hash_transform(SipHasher::new())
+                    .chain_inspect(move |bytes| {
+                        assert_eq!(bytes.len(), 8);
+                        let bytes: [u8; 8] = bytes.try_into().expect("[u8; 8]");
+                        let value = u64::from_be_bytes(bytes);
+                        hash1.store(value, atomic::Ordering::Release);
+                    })
+                    .chain_inspect(move |bytes| {
+                        assert_eq!(bytes.len(), 8);
+                        let bytes: [u8; 8] = bytes.try_into().expect("[u8; 8]");
+                        let value = u64::from_be_bytes(bytes);
+                        let prev = hash2.load(atomic::Ordering::Acquire);
+                        assert_eq!(prev, value);
+                        inspect2.fetch_add(1, atomic::Ordering::Release);
+                    });
+            })));
+
+        // Make sure nothing has happened yet.
+        assert!(raw_data.lock().is_empty());
+        assert_eq!(hash.load(atomic::Ordering::Acquire), 0);
+        assert_eq!(inspect2.load(atomic::Ordering::Acquire), 0);
+
+        // Check that nothing happens if the data isn't read.
+        let client = Client::debug(rocket).unwrap();
+        client.get("/").body("Hello, world!").dispatch();
+        assert!(raw_data.lock().is_empty());
+        assert_eq!(hash.load(atomic::Ordering::Acquire), 0);
+        assert_eq!(inspect2.load(atomic::Ordering::Acquire), 0);
+
+        // Check inspect + hash + inspect + inspect.
+        client.post("/").body("Hello, world!").dispatch();
+        assert_eq!(raw_data.lock().as_slice(), "Hello, world!".as_bytes());
+        assert_eq!(hash.load(atomic::Ordering::Acquire), 0xae5020d7cf49d14f);
+        assert_eq!(inspect2.load(atomic::Ordering::Acquire), 1);
+
+        // Check inspect + hash + inspect + inspect, round 2.
+        let string = "Rocket, Rocket, where art thee? Oh, tis in the sky, I see!";
+        client.post("/").body(string).dispatch();
+        assert_eq!(raw_data.lock().as_slice(), string.as_bytes());
+        assert_eq!(hash.load(atomic::Ordering::Acquire), 0x323f9aa98f907faf);
+        assert_eq!(inspect2.load(atomic::Ordering::Acquire), 2);
+    }
+}
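Reviewer note: beyond the built-in `Inspect` and `InPlaceMap` wrappers, the `Transform` trait can be implemented directly. A sketch (not part of this patch) of a transform that uppercases ASCII bytes in place via `fresh_mut()`; `chain_inplace_map()` with a closure would achieve the same:

```rust
use std::io;
use std::pin::Pin;

use rocket::data::{Transform, TransformBuf};

/// Illustrative transform: uppercase ASCII bytes as they stream by. Only the
/// freshly read portion is touched, so each byte is rewritten exactly once.
struct AsciiUppercase;

impl Transform for AsciiUppercase {
    fn transform(
        self: Pin<&mut Self>,
        buf: &mut TransformBuf<'_, '_>,
    ) -> io::Result<()> {
        buf.fresh_mut().make_ascii_uppercase();
        Ok(())
    }
}

// Usage, e.g. from a fairing or request guard:
// data.chain_transform(AsciiUppercase);
```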
diff --git a/core/lib/src/fairing/ad_hoc.rs b/core/lib/src/fairing/ad_hoc.rs
index e689cdfa..a83d6e58 100644
--- a/core/lib/src/fairing/ad_hoc.rs
+++ b/core/lib/src/fairing/ad_hoc.rs
@@ -59,7 +59,7 @@ enum AdHocKind {
     Liftoff(Once<Box<dyn for<'a> FnOnce(&'a Rocket<Orbit>) -> BoxFuture<'a, ()> + Send + 'static>>),
 
     /// An ad-hoc **request** fairing. Called when a request is received.
-    Request(Box<dyn for<'a> Fn(&'a mut Request<'_>, &'a Data<'_>)
+    Request(Box<dyn for<'a> Fn(&'a mut Request<'_>, &'a mut Data<'_>)
         -> BoxFuture<'a, ()> + Send + Sync + 'static>),
 
     /// An ad-hoc **response** fairing. Called when a response is ready to be
@@ -153,7 +153,7 @@ impl AdHoc {
     ///     });
     /// ```
     pub fn on_request<F>(name: &'static str, f: F) -> AdHoc
-        where F: for<'a> Fn(&'a mut Request<'_>, &'a Data<'_>) -> BoxFuture<'a, ()>
+        where F: for<'a> Fn(&'a mut Request<'_>, &'a mut Data<'_>) -> BoxFuture<'a, ()>
     {
         AdHoc { name, kind: AdHocKind::Request(Box::new(f)) }
     }
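Reviewer note: the `&Data` to `&mut Data` change above is what makes the new chaining methods reachable from `on_request` fairings; it also makes `peek()`, which always took `&mut self`, callable there. A sketch under that assumption, with illustrative names:

```rust
use rocket::fairing::AdHoc;

// Illustrative: with `&mut Data`, a request fairing can now peek at the body.
fn peek_logger() -> AdHoc {
    AdHoc::on_request("peek body", |_req, data| Box::pin(async move {
        let head = data.peek(16).await;
        println!("body starts with {} buffered bytes: {:?}", head.len(), head);
    }))
}
```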