Introduce chainable data transformers.

Resolves #775.
This commit is contained in:
Sergio Benitez 2023-12-12 13:08:49 -08:00
parent a285625f80
commit b3abc760ae
6 changed files with 579 additions and 148 deletions

View File

@ -1,9 +1,10 @@
use crate::tokio::io::AsyncReadExt;
use crate::data::data_stream::DataStream;
use crate::data::{ByteUnit, StreamReader};
use std::io;
use std::pin::Pin;
/// The number of bytes to read into the "peek" buffer.
pub const PEEK_BYTES: usize = 512;
use crate::data::ByteUnit;
use crate::data::data_stream::{DataStream, RawReader, RawStream};
use crate::data::peekable::Peekable;
use crate::data::transform::{Transform, TransformBuf, Inspect, InPlaceMap};
/// Type representing the body data of a request.
///
@ -38,31 +39,27 @@ pub const PEEK_BYTES: usize = 512;
/// body data. This enables partially or fully reading from a `Data` object
/// without consuming the `Data` object.
pub struct Data<'r> {
buffer: Vec<u8>,
is_complete: bool,
stream: StreamReader<'r>,
stream: Peekable<512, RawReader<'r>>,
transforms: Vec<Pin<Box<dyn Transform + Send + Sync + 'r>>>,
}
// TODO: Before `async`, we had a read timeout of 5s. Such a short read timeout
// is likely no longer necessary, but an idle timeout should be implemented.
impl<'r> Data<'r> {
/// Create a `Data` from a recognized `stream`.
pub(crate) fn from<S: Into<StreamReader<'r>>>(stream: S) -> Data<'r> {
// TODO.async: This used to also set the read timeout to 5 seconds.
// Such a short read timeout is likely no longer necessary, but some
// kind of idle timeout should be implemented.
#[inline]
pub(crate) fn new(stream: Peekable<512, RawReader<'r>>) -> Self {
Self { stream, transforms: Vec::new() }
}
let stream = stream.into();
let buffer = Vec::with_capacity(PEEK_BYTES / 8);
Data { buffer, stream, is_complete: false }
#[inline]
pub(crate) fn from<S: Into<RawStream<'r>>>(stream: S) -> Data<'r> {
Data::new(Peekable::new(RawReader::new(stream.into())))
}
/// This creates a `data` object from a local data source `data`.
#[inline]
pub(crate) fn local(data: Vec<u8>) -> Data<'r> {
Data {
buffer: data,
stream: StreamReader::empty(),
is_complete: true,
}
Data::new(Peekable::with_buffer(data, true, RawReader::new(RawStream::Empty)))
}
/// Returns the raw data stream, limited to `limit` bytes.
@ -82,18 +79,31 @@ impl<'r> Data<'r> {
/// let stream = data.open(2.mebibytes());
/// }
/// ```
#[inline(always)]
pub fn open(self, limit: ByteUnit) -> DataStream<'r> {
DataStream::new(self.buffer, self.stream, limit.into())
DataStream::new(self.transforms, self.stream, limit.into())
}
/// Retrieve at most `num` bytes from the `peek` buffer without consuming
/// `self`.
/// Fills the peek buffer with body data until it contains at least `num`
/// bytes (capped to 512), or the complete body data, whichever is less, and
/// returns it. If the buffer already contains either at least `num` bytes
/// or all of the body data, no I/O is performed and the buffer is simply
/// returned. If `num` is greater than `512`, it is artificially capped to
/// `512`.
///
/// The peek buffer contains at most 512 bytes of the body of the request.
/// The actual size of the returned buffer is the `min` of the request's
/// body, `num` and `512`. The [`peek_complete`](#method.peek_complete)
/// method can be used to determine if this buffer contains _all_ of the
/// data in the body of the request.
/// No guarantees are made about the actual size of the returned buffer
/// except that it will not exceed the length of the body data. It may be:
///
/// * Less than `num` if `num > 512` or the complete body data is `< 512`
/// or an error occurred while reading the body.
/// * Equal to `num` if `num` is `<= 512` and exactly `num` bytes of the
/// body data were successfully read.
/// * Greater than `num` if `> num` bytes of the body data have
/// successfully been read, either by this request, a previous request,
/// or opportunistically.
///
/// [`Data::peek_complete()`] can be used to determine if this buffer
/// contains the complete body data.
///
/// # Examples
///
@ -147,30 +157,13 @@ impl<'r> Data<'r> {
/// }
/// }
/// ```
#[inline(always)]
pub async fn peek(&mut self, num: usize) -> &[u8] {
let num = std::cmp::min(PEEK_BYTES, num);
let mut len = self.buffer.len();
if len >= num {
return &self.buffer[..num];
}
while len < num {
match self.stream.read_buf(&mut self.buffer).await {
Ok(0) => { self.is_complete = true; break },
Ok(n) => len += n,
Err(e) => {
error_!("Failed to read into peek buffer: {:?}.", e);
break;
}
}
}
&self.buffer[..std::cmp::min(len, num)]
self.stream.peek(num).await
}
/// Returns true if the `peek` buffer contains all of the data in the body
/// of the request. Returns `false` if it does not or if it is not known if
/// it does.
/// of the request. Returns `false` if it does not or it is not known.
///
/// # Example
///
@ -185,6 +178,43 @@ impl<'r> Data<'r> {
/// ```
#[inline(always)]
pub fn peek_complete(&self) -> bool {
self.is_complete
self.stream.complete
}
/// Chains the [`Transform`] `transform` to `self`.
///
/// Note that transforms do nothing until the data is
/// [`open()`ed](Data::open()) and read.
#[inline(always)]
pub fn chain_transform<T>(&mut self, transform: T) -> &mut Self
where T: Transform + Send + Sync + 'static
{
self.transforms.push(Box::pin(transform));
self
}
/// Chain a [`Transform`] that can inspect the data as it streams.
pub fn chain_inspect<F>(&mut self, f: F) -> &mut Self
where F: FnMut(&[u8]) + Send + Sync + 'static
{
self.chain_transform(Inspect(Box::new(f)))
}
/// Chain a [`Transform`] that can in-place map the data as it streams.
/// Unlike [`Data::chain_try_inplace_map()`], this version assumes the
/// mapper is infallible.
pub fn chain_inplace_map<F>(&mut self, mut f: F) -> &mut Self
where F: FnMut(&mut TransformBuf<'_, '_>) + Send + Sync + 'static
{
self.chain_transform(InPlaceMap(Box::new(move |buf| Ok(f(buf)))))
}
/// Chain a [`Transform`] that can in-place map the data as it streams.
/// Unlike [`Data::chain_inplace_map()`], this version allows the mapper to
/// be infallible.
pub fn chain_try_inplace_map<F>(&mut self, f: F) -> &mut Self
where F: FnMut(&mut TransformBuf<'_, '_>) -> io::Result<()> + Send + Sync + 'static
{
self.chain_transform(InPlaceMap(Box::new(f)))
}
}

View File

@ -5,13 +5,17 @@ use std::io::{self, Cursor};
use tokio::fs::File;
use tokio::io::{AsyncRead, AsyncWrite, AsyncReadExt, ReadBuf, Take};
use futures::stream::Stream;
use futures::ready;
use yansi::Paint;
use tokio_util::io::StreamReader;
use futures::{ready, stream::Stream};
use crate::http::hyper;
use crate::ext::{PollExt, Chain};
use crate::data::{Capped, N};
use crate::http::hyper::body::Bytes;
use crate::data::transform::Transform;
use super::peekable::Peekable;
use super::transform::TransformBuf;
/// Raw data stream of a request body.
///
@ -40,47 +44,101 @@ use crate::data::{Capped, N};
///
/// [`DataStream::stream_to(&mut vec)`]: DataStream::stream_to()
/// [`DataStream::stream_to(&mut file)`]: DataStream::stream_to()
pub struct DataStream<'r> {
pub(crate) chain: Take<Chain<Cursor<Vec<u8>>, StreamReader<'r>>>,
#[non_exhaustive]
pub enum DataStream<'r> {
#[doc(hidden)]
Base(BaseReader<'r>),
#[doc(hidden)]
Transform(TransformReader<'r>),
}
/// An adapter: turns a `T: Stream` (in `StreamKind`) into a `tokio::AsyncRead`.
pub struct StreamReader<'r> {
state: State,
inner: StreamKind<'r>,
/// A data stream that has a `transformer` applied to it.
pub struct TransformReader<'r> {
transformer: Pin<Box<dyn Transform + Send + Sync + 'r>>,
stream: Pin<Box<DataStream<'r>>>,
inner_done: bool,
}
/// The current state of `StreamReader` `AsyncRead` adapter.
enum State {
Pending,
Partial(Cursor<hyper::body::Bytes>),
Done,
}
/// Limited, pre-buffered reader to the underlying data stream.
pub type BaseReader<'r> = Take<Chain<Cursor<Vec<u8>>, RawReader<'r>>>;
/// The kinds of streams we accept as `Data`.
enum StreamKind<'r> {
/// Direct reader to the underlying data stream. Not limited in any manner.
pub type RawReader<'r> = StreamReader<RawStream<'r>, Bytes>;
/// Raw underlying data stream.
pub enum RawStream<'r> {
Empty,
Body(&'r mut hyper::Body),
Multipart(multer::Field<'r>)
Multipart(multer::Field<'r>),
}
impl<'r> TransformReader<'r> {
/// Returns the underlying `BaseReader`.
fn base_mut(&mut self) -> &mut BaseReader<'r> {
match self.stream.as_mut().get_mut() {
DataStream::Base(base) => base,
DataStream::Transform(inner) => inner.base_mut(),
}
}
/// Returns the underlying `BaseReader`.
fn base(&self) -> &BaseReader<'r> {
match self.stream.as_ref().get_ref() {
DataStream::Base(base) => base,
DataStream::Transform(inner) => inner.base(),
}
}
}
impl<'r> DataStream<'r> {
pub(crate) fn new(buf: Vec<u8>, stream: StreamReader<'r>, limit: u64) -> Self {
let chain = Chain::new(Cursor::new(buf), stream).take(limit).into();
Self { chain }
pub(crate) fn new(
transformers: Vec<Pin<Box<dyn Transform + Send + Sync + 'r>>>,
Peekable { buffer, reader, .. }: Peekable<512, RawReader<'r>>,
limit: u64
) -> Self {
let mut stream = DataStream::Base(Chain::new(Cursor::new(buffer), reader).take(limit));
for transformer in transformers {
stream = DataStream::Transform(TransformReader {
transformer,
stream: Box::pin(stream),
inner_done: false,
});
}
stream
}
/// Returns the underlying `BaseReader`.
fn base_mut(&mut self) -> &mut BaseReader<'r> {
match self {
DataStream::Base(base) => base,
DataStream::Transform(transform) => transform.base_mut(),
}
}
/// Returns the underlying `BaseReader`.
fn base(&self) -> &BaseReader<'r> {
match self {
DataStream::Base(base) => base,
DataStream::Transform(transform) => transform.base(),
}
}
/// Whether a previous read exhausted the set limit _and then some_.
async fn limit_exceeded(&mut self) -> io::Result<bool> {
let base = self.base_mut();
#[cold]
async fn _limit_exceeded(stream: &mut DataStream<'_>) -> io::Result<bool> {
async fn _limit_exceeded(base: &mut BaseReader<'_>) -> io::Result<bool> {
// Read one more byte after reaching limit to see if we cut early.
stream.chain.set_limit(1);
base.set_limit(1);
let mut buf = [0u8; 1];
Ok(stream.read(&mut buf).await? != 0)
let exceeded = base.read(&mut buf).await? != 0;
base.set_limit(0);
Ok(exceeded)
}
Ok(self.chain.limit() == 0 && _limit_exceeded(self).await?)
Ok(base.limit() == 0 && _limit_exceeded(base).await?)
}
/// Number of bytes a full read from `self` will _definitely_ read.
@ -95,8 +153,9 @@ impl<'r> DataStream<'r> {
/// }
/// ```
pub fn hint(&self) -> usize {
let buf_len = self.chain.get_ref().get_ref().0.get_ref().len();
std::cmp::min(buf_len, self.chain.limit() as usize)
let base = self.base();
let buf_len = base.get_ref().get_ref().0.get_ref().len();
std::cmp::min(buf_len, base.limit() as usize)
}
/// A helper method to write the body of the request to any `AsyncWrite`
@ -227,97 +286,86 @@ impl<'r> DataStream<'r> {
}
}
// TODO.async: Consider implementing `AsyncBufRead`.
impl StreamReader<'_> {
pub fn empty() -> Self {
Self { inner: StreamKind::Empty, state: State::Done }
}
}
impl<'r> From<&'r mut hyper::Body> for StreamReader<'r> {
fn from(body: &'r mut hyper::Body) -> Self {
Self { inner: StreamKind::Body(body), state: State::Pending }
}
}
impl<'r> From<multer::Field<'r>> for StreamReader<'r> {
fn from(field: multer::Field<'r>) -> Self {
Self { inner: StreamKind::Multipart(field), state: State::Pending }
}
}
impl AsyncRead for DataStream<'_> {
#[inline(always)]
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
match self.get_mut() {
DataStream::Base(inner) => Pin::new(inner).poll_read(cx, buf),
DataStream::Transform(inner) => Pin::new(inner).poll_read(cx, buf),
}
}
}
impl AsyncRead for TransformReader<'_> {
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
if self.chain.limit() == 0 {
let stream: &StreamReader<'_> = &self.chain.get_ref().get_ref().1;
let kind = match stream.inner {
StreamKind::Empty => "an empty stream (vacuous)",
StreamKind::Body(_) => "the request body",
StreamKind::Multipart(_) => "a multipart form field",
};
warn_!("Data limit reached while reading {}.", kind.primary().bold());
let init_fill = buf.filled().len();
if !self.inner_done {
ready!(Pin::new(&mut self.stream).poll_read(cx, buf))?;
self.inner_done = init_fill == buf.filled().len();
}
Pin::new(&mut self.chain).poll_read(cx, buf)
if self.inner_done {
return self.transformer.as_mut().poll_finish(cx, buf);
}
let mut tbuf = TransformBuf { buf, cursor: init_fill };
self.transformer.as_mut().transform(&mut tbuf)?;
if buf.filled().len() == init_fill {
cx.waker().wake_by_ref();
return Poll::Pending;
}
Poll::Ready(Ok(()))
}
}
impl Stream for StreamKind<'_> {
type Item = io::Result<hyper::body::Bytes>;
impl Stream for RawStream<'_> {
type Item = io::Result<Bytes>;
fn poll_next(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Self::Item>> {
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
match self.get_mut() {
StreamKind::Body(body) => Pin::new(body).poll_next(cx)
RawStream::Body(body) => Pin::new(body).poll_next(cx)
.map_err_ext(|e| io::Error::new(io::ErrorKind::Other, e)),
StreamKind::Multipart(mp) => Pin::new(mp).poll_next(cx)
RawStream::Multipart(mp) => Pin::new(mp).poll_next(cx)
.map_err_ext(|e| io::Error::new(io::ErrorKind::Other, e)),
StreamKind::Empty => Poll::Ready(None),
RawStream::Empty => Poll::Ready(None),
}
}
fn size_hint(&self) -> (usize, Option<usize>) {
match self {
StreamKind::Body(body) => body.size_hint(),
StreamKind::Multipart(mp) => mp.size_hint(),
StreamKind::Empty => (0, Some(0)),
RawStream::Body(body) => body.size_hint(),
RawStream::Multipart(mp) => mp.size_hint(),
RawStream::Empty => (0, Some(0)),
}
}
}
impl AsyncRead for StreamReader<'_> {
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
loop {
self.state = match self.state {
State::Pending => {
match ready!(Pin::new(&mut self.inner).poll_next(cx)) {
Some(Err(e)) => return Poll::Ready(Err(e)),
Some(Ok(bytes)) => State::Partial(Cursor::new(bytes)),
None => State::Done,
}
},
State::Partial(ref mut cursor) => {
let rem = buf.remaining();
match ready!(Pin::new(cursor).poll_read(cx, buf)) {
Ok(()) if rem == buf.remaining() => State::Pending,
result => return Poll::Ready(result),
}
}
State::Done => return Poll::Ready(Ok(())),
impl std::fmt::Display for RawStream<'_> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
RawStream::Empty => f.write_str("empty stream"),
RawStream::Body(_) => f.write_str("request body"),
RawStream::Multipart(_) => f.write_str("multipart form field"),
}
}
}
impl<'r> From<&'r mut hyper::Body> for RawStream<'r> {
fn from(value: &'r mut hyper::Body) -> Self {
Self::Body(value)
}
}
impl<'r> From<multer::Field<'r>> for RawStream<'r> {
fn from(value: multer::Field<'r>) -> Self {
Self::Multipart(value)
}
}

View File

@ -7,6 +7,8 @@ mod data_stream;
mod from_data;
mod limits;
mod io_stream;
mod transform;
mod peekable;
pub use self::data::Data;
pub use self::data_stream::DataStream;
@ -15,5 +17,4 @@ pub use self::limits::Limits;
pub use self::capped::{N, Capped};
pub use self::io_stream::{IoHandler, IoStream};
pub use ubyte::{ByteUnit, ToByteUnit};
pub(crate) use self::data_stream::StreamReader;
pub use self::transform::{Transform, TransformBuf};

View File

@ -0,0 +1,48 @@
use tokio::io::{AsyncRead, AsyncReadExt};
pub struct Peekable<const N: usize, R> {
pub(crate) buffer: Vec<u8>,
pub(crate) complete: bool,
pub(crate) reader: R,
}
impl<const N: usize, R: AsyncRead + Unpin> Peekable<N, R> {
pub fn new(reader: R) -> Self {
Self { buffer: Vec::new(), complete: false, reader }
}
pub fn with_buffer(buffer: Vec<u8>, complete: bool, reader: R) -> Self {
Self { buffer, complete, reader }
}
pub async fn peek(&mut self, num: usize) -> &[u8] {
if self.complete {
return self.buffer.as_slice();
}
let to_read = std::cmp::min(N, num);
if self.buffer.len() >= to_read {
return &self.buffer.as_slice();
}
if self.buffer.capacity() == 0 {
self.buffer.reserve(N);
}
while self.buffer.len() < to_read {
match self.reader.read_buf::<Vec<u8>>(&mut self.buffer).await {
Ok(0) => {
self.complete = self.buffer.capacity() > self.buffer.len();
break;
},
Ok(_) => { /* continue */ },
Err(e) => {
error_!("Failed to read into peek buffer: {:?}.", e);
break;
}
}
}
self.buffer.as_slice()
}
}

View File

@ -0,0 +1,304 @@
use std::io;
use std::ops::{Deref, DerefMut};
use std::pin::Pin;
use std::task::{Poll, Context};
use tokio::io::ReadBuf;
/// Chainable, in-place, streaming data transformer.
///
/// [`Transform`] operates on [`TransformBuf`]s similar to how [`AsyncRead`]
/// operats on [`ReadBuf`]. A [`Transform`] sits somewhere in a chain of
/// transforming readers. The head (most upstream part) of the chain is _always_
/// an [`AsyncRead`]: the data source. The tail (all downstream parts) is
/// composed _only_ of other [`Transform`]s:
///
/// ```text
/// downstream --->
/// AsyncRead | Transform | .. | Transform
/// <---- upstream
/// ```
///
/// When the upstream source makes data available, the
/// [`Transform::transform()`] method is called. [`Transform`]s may obtain the
/// subset of the filled section added by an upstream data source with
/// [`TransformBuf::fresh()`]. They may modify this data at will, potentially
/// changing the size of the filled section. For example,
/// [`TransformBuf::spoil()`] "removes" all of the fresh data, and
/// [`TransformBuf::fresh_mut()`] can be used to modify the data in-place.
///
/// Additionally, new data may be added in-place via the traditional approach:
/// write to (or overwrite) the initialized section of the buffer and mark it as
/// filled. All of the remaining filled data will be passed to downstream
/// transforms as "fresh" data. To add data to the end of the (potentially
/// rewritten) stream, the [`Transform::poll_finish()`] method can be
/// implemented.
///
/// [`AsyncRead`]: tokio::io::AsyncRead
pub trait Transform {
/// Called when data is read from the upstream source. For any given fresh
/// data, this method is called only once. [`TransformBuf::fresh()`] is
/// guaranteed to contain at least one byte.
///
/// While this method is not _async_ (it does not return [`Poll`]), it is
/// nevertheless executed in an async context and should respect all such
/// restrictions including not blocking.
fn transform(
self: Pin<&mut Self>,
buf: &mut TransformBuf<'_, '_>,
) -> io::Result<()>;
/// Called when the upstream is finished, that is, it has no more data to
/// fill. At this point, the transform becomes an async reader. This method
/// thus has identical semantics to [`AsyncRead::poll_read()`]. This method
/// may never be called if the upstream does not finish.
///
/// The default implementation returns `Poll::Ready(Ok(()))`.
///
/// [`AsyncRead::poll_read()`]: tokio::io::AsyncRead::poll_read()
fn poll_finish(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
let (_, _) = (cx, buf);
Poll::Ready(Ok(()))
}
}
/// A buffer of transformable streaming data.
///
/// # Overview
///
/// A byte buffer, similar to a [`ReadBuf`], with a "fresh" dimension. Fresh
/// data is always a subset of the filled data, filled data is always a subset
/// of initialized data, and initialized data is always a subset of the buffer
/// itself. Both the filled and initialized data sections are guaranteed to be
/// at the start of the buffer, but the fresh subset is likely to begin
/// somewhere inside the filled section.
///
/// To visualize this, the diagram below represents a possible state for the
/// byte buffer being tracked. The square `[ ]` brackets represent the complete
/// buffer, while the curly `{ }` represent the named subset.
///
/// ```text
/// [ { !! fresh !! } ]
/// { +++ filled +++ } unfilled ]
/// { ----- initialized ------ } uninitialized ]
/// [ capacity ]
/// ```
///
/// The same buffer represented in its true single dimension is below:
///
/// ```text
/// [ ++!!!!!!!!!!!!!!---------xxxxxxxxxxxxxxxxxxxxxxxx]
/// ```
///
/// * `+`: filled (implies initialized)
/// * `!`: fresh (implies filled)
/// * `-`: unfilled / initialized (implies initialized)
/// * `x`: uninitialized (implies unfilled)
///
/// As with [`ReadBuf`], [`AsyncRead`] readers fill the initialized portion of a
/// [`TransformBuf`] to indicate that data is available. _Filling_ initialized
/// portions of the byte buffers is what increases the size of the _filled_
/// section. Because a [`ReadBuf`] may already be partially filled when a reader
/// adds bytes to it, a mechanism to track where the _newly_ filled portion
/// exists is needed. This is exactly what the "fresh" section tracks.
///
/// [`AsyncRead`]: tokio::io::AsyncRead
pub struct TransformBuf<'a, 'b> {
pub(crate) buf: &'a mut ReadBuf<'b>,
pub(crate) cursor: usize,
}
impl TransformBuf<'_, '_> {
/// Returns a borrow to the fresh data: data filled by the upstream source.
pub fn fresh(&self) -> &[u8] {
&self.filled()[self.cursor..]
}
/// Returns a mutable borrow to the fresh data: data filled by the upstream
/// source.
pub fn fresh_mut(&mut self) -> &mut [u8] {
let cursor = self.cursor;
&mut self.filled_mut()[cursor..]
}
/// Spoils the fresh data by resetting the filled section to its value
/// before any new data was added. As a result, the data will never be seen
/// by any downstream consumer unless it is returned via another mechanism.
pub fn spoil(&mut self) {
let cursor = self.cursor;
self.set_filled(cursor);
}
}
pub struct Inspect(pub(crate) Box<dyn FnMut(&[u8]) + Send + Sync + 'static>);
impl Transform for Inspect {
fn transform(mut self: Pin<&mut Self>, buf: &mut TransformBuf<'_, '_>) -> io::Result<()> {
(self.0)(buf.fresh());
Ok(())
}
}
pub struct InPlaceMap(
pub(crate) Box<dyn FnMut(&mut TransformBuf<'_, '_>) -> io::Result<()> + Send + Sync + 'static>
);
impl Transform for InPlaceMap {
fn transform(mut self: Pin<&mut Self>, buf: &mut TransformBuf<'_, '_>,) -> io::Result<()> {
(self.0)(buf)
}
}
impl<'a, 'b> Deref for TransformBuf<'a, 'b> {
type Target = ReadBuf<'b>;
fn deref(&self) -> &Self::Target {
&self.buf
}
}
impl<'a, 'b> DerefMut for TransformBuf<'a, 'b> {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.buf
}
}
// TODO: Test chaining various transform combinations:
// * consume | consume
// * add | consume
// * consume | add
// * add | add
// Where `add` is a transformer that adds data to the stream, and `consume` is
// one that removes data.
#[cfg(test)]
#[allow(deprecated)]
mod tests {
use std::hash::SipHasher;
use std::sync::{Arc, atomic::{AtomicU64, AtomicU8}};
use parking_lot::Mutex;
use ubyte::ToByteUnit;
use crate::http::Method;
use crate::local::blocking::Client;
use crate::fairing::AdHoc;
use crate::{route, Route, Data, Response, Request};
mod hash_transform {
use std::io::Cursor;
use std::hash::Hasher;
use tokio::io::AsyncRead;
use super::super::*;
pub struct HashTransform<H: Hasher> {
pub(crate) hasher: H,
pub(crate) hash: Option<Cursor<[u8; 8]>>
}
impl<H: Hasher + Unpin> Transform for HashTransform<H> {
fn transform(
mut self: Pin<&mut Self>,
buf: &mut TransformBuf<'_, '_>,
) -> io::Result<()> {
self.hasher.write(buf.fresh());
buf.spoil();
Ok(())
}
fn poll_finish(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
if self.hash.is_none() {
let hash = self.hasher.finish();
self.hash = Some(Cursor::new(hash.to_be_bytes()));
}
let cursor = self.hash.as_mut().unwrap();
Pin::new(cursor).poll_read(cx, buf)
}
}
impl crate::Data<'_> {
/// Chain an in-place hash [`Transform`] to `self`.
pub fn chain_hash_transform<H: std::hash::Hasher>(&mut self, hasher: H) -> &mut Self
where H: Unpin + Send + Sync + 'static
{
self.chain_transform(HashTransform { hasher, hash: None })
}
}
}
#[test]
fn test_transform_series() {
fn handler<'r>(_: &'r Request<'_>, data: Data<'r>) -> route::BoxFuture<'r> {
Box::pin(async move {
data.open(128.bytes()).stream_to(tokio::io::sink()).await.expect("read ok");
route::Outcome::Success(Response::new())
})
}
let inspect2: Arc<AtomicU8> = Arc::new(AtomicU8::new(0));
let raw_data: Arc<Mutex<Vec<u8>>> = Arc::new(Mutex::new(Vec::new()));
let hash: Arc<AtomicU64> = Arc::new(AtomicU64::new(0));
let rocket = crate::build()
.manage(hash.clone())
.manage(raw_data.clone())
.manage(inspect2.clone())
.mount("/", vec![Route::new(Method::Post, "/", handler)])
.attach(AdHoc::on_request("transforms", |req, data| Box::pin(async {
let hash1 = req.rocket().state::<Arc<AtomicU64>>().cloned().unwrap();
let hash2 = req.rocket().state::<Arc<AtomicU64>>().cloned().unwrap();
let raw_data = req.rocket().state::<Arc<Mutex<Vec<u8>>>>().cloned().unwrap();
let inspect2 = req.rocket().state::<Arc<AtomicU8>>().cloned().unwrap();
data.chain_inspect(move |bytes| { *raw_data.lock() = bytes.to_vec(); })
.chain_hash_transform(SipHasher::new())
.chain_inspect(move |bytes| {
assert_eq!(bytes.len(), 8);
let bytes: [u8; 8] = bytes.try_into().expect("[u8; 8]");
let value = u64::from_be_bytes(bytes);
hash1.store(value, atomic::Ordering::Release);
})
.chain_inspect(move |bytes| {
assert_eq!(bytes.len(), 8);
let bytes: [u8; 8] = bytes.try_into().expect("[u8; 8]");
let value = u64::from_be_bytes(bytes);
let prev = hash2.load(atomic::Ordering::Acquire);
assert_eq!(prev, value);
inspect2.fetch_add(1, atomic::Ordering::Release);
});
})));
// Make sure nothing has happened yet.
assert!(raw_data.lock().is_empty());
assert_eq!(hash.load(atomic::Ordering::Acquire), 0);
assert_eq!(inspect2.load(atomic::Ordering::Acquire), 0);
// Check that nothing happens if the data isn't read.
let client = Client::debug(rocket).unwrap();
client.get("/").body("Hello, world!").dispatch();
assert!(raw_data.lock().is_empty());
assert_eq!(hash.load(atomic::Ordering::Acquire), 0);
assert_eq!(inspect2.load(atomic::Ordering::Acquire), 0);
// Check inspect + hash + inspect + inspect.
client.post("/").body("Hello, world!").dispatch();
assert_eq!(raw_data.lock().as_slice(), "Hello, world!".as_bytes());
assert_eq!(hash.load(atomic::Ordering::Acquire), 0xae5020d7cf49d14f);
assert_eq!(inspect2.load(atomic::Ordering::Acquire), 1);
// Check inspect + hash + inspect + inspect, round 2.
let string = "Rocket, Rocket, where art thee? Oh, tis in the sky, I see!";
client.post("/").body(string).dispatch();
assert_eq!(raw_data.lock().as_slice(), string.as_bytes());
assert_eq!(hash.load(atomic::Ordering::Acquire), 0x323f9aa98f907faf);
assert_eq!(inspect2.load(atomic::Ordering::Acquire), 2);
}
}

View File

@ -59,7 +59,7 @@ enum AdHocKind {
Liftoff(Once<dyn for<'a> FnOnce(&'a Rocket<Orbit>) -> BoxFuture<'a, ()> + Send + 'static>),
/// An ad-hoc **request** fairing. Called when a request is received.
Request(Box<dyn for<'a> Fn(&'a mut Request<'_>, &'a Data<'_>)
Request(Box<dyn for<'a> Fn(&'a mut Request<'_>, &'a mut Data<'_>)
-> BoxFuture<'a, ()> + Send + Sync + 'static>),
/// An ad-hoc **response** fairing. Called when a response is ready to be
@ -153,7 +153,7 @@ impl AdHoc {
/// });
/// ```
pub fn on_request<F: Send + Sync + 'static>(name: &'static str, f: F) -> AdHoc
where F: for<'a> Fn(&'a mut Request<'_>, &'a Data<'_>) -> BoxFuture<'a, ()>
where F: for<'a> Fn(&'a mut Request<'_>, &'a mut Data<'_>) -> BoxFuture<'a, ()>
{
AdHoc { name, kind: AdHocKind::Request(Box::new(f)) }
}