mirror of https://github.com/rwf2/Rocket.git
parent
a285625f80
commit
b3abc760ae
|
@ -1,9 +1,10 @@
|
|||
use crate::tokio::io::AsyncReadExt;
|
||||
use crate::data::data_stream::DataStream;
|
||||
use crate::data::{ByteUnit, StreamReader};
|
||||
use std::io;
|
||||
use std::pin::Pin;
|
||||
|
||||
/// The number of bytes to read into the "peek" buffer.
|
||||
pub const PEEK_BYTES: usize = 512;
|
||||
use crate::data::ByteUnit;
|
||||
use crate::data::data_stream::{DataStream, RawReader, RawStream};
|
||||
use crate::data::peekable::Peekable;
|
||||
use crate::data::transform::{Transform, TransformBuf, Inspect, InPlaceMap};
|
||||
|
||||
/// Type representing the body data of a request.
|
||||
///
|
||||
|
@ -38,31 +39,27 @@ pub const PEEK_BYTES: usize = 512;
|
|||
/// body data. This enables partially or fully reading from a `Data` object
|
||||
/// without consuming the `Data` object.
|
||||
pub struct Data<'r> {
|
||||
buffer: Vec<u8>,
|
||||
is_complete: bool,
|
||||
stream: StreamReader<'r>,
|
||||
stream: Peekable<512, RawReader<'r>>,
|
||||
transforms: Vec<Pin<Box<dyn Transform + Send + Sync + 'r>>>,
|
||||
}
|
||||
|
||||
// TODO: Before `async`, we had a read timeout of 5s. Such a short read timeout
|
||||
// is likely no longer necessary, but an idle timeout should be implemented.
|
||||
impl<'r> Data<'r> {
|
||||
/// Create a `Data` from a recognized `stream`.
|
||||
pub(crate) fn from<S: Into<StreamReader<'r>>>(stream: S) -> Data<'r> {
|
||||
// TODO.async: This used to also set the read timeout to 5 seconds.
|
||||
// Such a short read timeout is likely no longer necessary, but some
|
||||
// kind of idle timeout should be implemented.
|
||||
#[inline]
|
||||
pub(crate) fn new(stream: Peekable<512, RawReader<'r>>) -> Self {
|
||||
Self { stream, transforms: Vec::new() }
|
||||
}
|
||||
|
||||
let stream = stream.into();
|
||||
let buffer = Vec::with_capacity(PEEK_BYTES / 8);
|
||||
Data { buffer, stream, is_complete: false }
|
||||
#[inline]
|
||||
pub(crate) fn from<S: Into<RawStream<'r>>>(stream: S) -> Data<'r> {
|
||||
Data::new(Peekable::new(RawReader::new(stream.into())))
|
||||
}
|
||||
|
||||
/// Creates a `Data` object from a local data source `data`.
|
||||
#[inline]
|
||||
pub(crate) fn local(data: Vec<u8>) -> Data<'r> {
|
||||
Data {
|
||||
buffer: data,
|
||||
stream: StreamReader::empty(),
|
||||
is_complete: true,
|
||||
}
|
||||
Data::new(Peekable::with_buffer(data, true, RawReader::new(RawStream::Empty)))
|
||||
}
|
||||
|
||||
/// Returns the raw data stream, limited to `limit` bytes.
|
||||
|
@ -82,18 +79,31 @@ impl<'r> Data<'r> {
|
|||
/// let stream = data.open(2.mebibytes());
|
||||
/// }
|
||||
/// ```
|
||||
#[inline(always)]
|
||||
pub fn open(self, limit: ByteUnit) -> DataStream<'r> {
|
||||
DataStream::new(self.buffer, self.stream, limit.into())
|
||||
DataStream::new(self.transforms, self.stream, limit.into())
|
||||
}
|
||||
|
||||
/// Retrieve at most `num` bytes from the `peek` buffer without consuming
|
||||
/// `self`.
|
||||
/// Fills the peek buffer with body data until it contains at least `num`
|
||||
/// bytes (capped to 512), or the complete body data, whichever is less, and
|
||||
/// returns it. If the buffer already contains either at least `num` bytes
|
||||
/// or all of the body data, no I/O is performed and the buffer is simply
|
||||
/// returned. If `num` is greater than `512`, it is artificially capped to
|
||||
/// `512`.
|
||||
///
|
||||
/// The peek buffer contains at most 512 bytes of the body of the request.
|
||||
/// The actual size of the returned buffer is the `min` of the request's
|
||||
/// body, `num` and `512`. The [`peek_complete`](#method.peek_complete)
|
||||
/// method can be used to determine if this buffer contains _all_ of the
|
||||
/// data in the body of the request.
|
||||
/// No guarantees are made about the actual size of the returned buffer
|
||||
/// except that it will not exceed the length of the body data. It may be:
|
||||
///
|
||||
/// * Less than `num` if `num > 512` or the complete body data is `< num`
|
||||
/// or an error occurred while reading the body.
|
||||
/// * Equal to `num` if `num` is `<= 512` and exactly `num` bytes of the
|
||||
/// body data were successfully read.
|
||||
/// * Greater than `num` if `> num` bytes of the body data have
|
||||
/// successfully been read, either by this request, a previous request,
|
||||
/// or opportunistically.
|
||||
///
|
||||
/// [`Data::peek_complete()`] can be used to determine if this buffer
|
||||
/// contains the complete body data.
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
|
@ -147,30 +157,13 @@ impl<'r> Data<'r> {
|
|||
/// }
|
||||
/// }
|
||||
/// ```
|
||||
#[inline(always)]
|
||||
pub async fn peek(&mut self, num: usize) -> &[u8] {
|
||||
let num = std::cmp::min(PEEK_BYTES, num);
|
||||
let mut len = self.buffer.len();
|
||||
if len >= num {
|
||||
return &self.buffer[..num];
|
||||
}
|
||||
|
||||
while len < num {
|
||||
match self.stream.read_buf(&mut self.buffer).await {
|
||||
Ok(0) => { self.is_complete = true; break },
|
||||
Ok(n) => len += n,
|
||||
Err(e) => {
|
||||
error_!("Failed to read into peek buffer: {:?}.", e);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
&self.buffer[..std::cmp::min(len, num)]
|
||||
self.stream.peek(num).await
|
||||
}
|
||||
|
||||
/// Returns true if the `peek` buffer contains all of the data in the body
|
||||
/// of the request. Returns `false` if it does not or if it is not known if
|
||||
/// it does.
|
||||
/// of the request. Returns `false` if it does not or it is not known.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
|
@ -185,6 +178,43 @@ impl<'r> Data<'r> {
|
|||
/// ```
|
||||
#[inline(always)]
|
||||
pub fn peek_complete(&self) -> bool {
|
||||
self.is_complete
|
||||
self.stream.complete
|
||||
}
|
||||
|
||||
/// Chains the [`Transform`] `transform` to `self`.
|
||||
///
|
||||
/// Note that transforms do nothing until the data is
|
||||
/// [`open()`ed](Data::open()) and read.
|
||||
#[inline(always)]
|
||||
pub fn chain_transform<T>(&mut self, transform: T) -> &mut Self
|
||||
where T: Transform + Send + Sync + 'static
|
||||
{
|
||||
self.transforms.push(Box::pin(transform));
|
||||
self
|
||||
}
|
||||
|
||||
/// Chain a [`Transform`] that can inspect the data as it streams.
|
||||
pub fn chain_inspect<F>(&mut self, f: F) -> &mut Self
|
||||
where F: FnMut(&[u8]) + Send + Sync + 'static
|
||||
{
|
||||
self.chain_transform(Inspect(Box::new(f)))
|
||||
}
|
||||
|
||||
/// Chain a [`Transform`] that can in-place map the data as it streams.
|
||||
/// Unlike [`Data::chain_try_inplace_map()`], this version assumes the
|
||||
/// mapper is infallible.
|
||||
pub fn chain_inplace_map<F>(&mut self, mut f: F) -> &mut Self
|
||||
where F: FnMut(&mut TransformBuf<'_, '_>) + Send + Sync + 'static
|
||||
{
|
||||
self.chain_transform(InPlaceMap(Box::new(move |buf| Ok(f(buf)))))
|
||||
}
|
||||
|
||||
/// Chain a [`Transform`] that can in-place map the data as it streams.
|
||||
/// Unlike [`Data::chain_inplace_map()`], this version allows the mapper to
|
||||
/// be fallible.
|
||||
pub fn chain_try_inplace_map<F>(&mut self, f: F) -> &mut Self
|
||||
where F: FnMut(&mut TransformBuf<'_, '_>) -> io::Result<()> + Send + Sync + 'static
|
||||
{
|
||||
self.chain_transform(InPlaceMap(Box::new(f)))
|
||||
}
|
||||
}
|
||||
|
|
|
@ -5,13 +5,17 @@ use std::io::{self, Cursor};
|
|||
|
||||
use tokio::fs::File;
|
||||
use tokio::io::{AsyncRead, AsyncWrite, AsyncReadExt, ReadBuf, Take};
|
||||
use futures::stream::Stream;
|
||||
use futures::ready;
|
||||
use yansi::Paint;
|
||||
use tokio_util::io::StreamReader;
|
||||
use futures::{ready, stream::Stream};
|
||||
|
||||
use crate::http::hyper;
|
||||
use crate::ext::{PollExt, Chain};
|
||||
use crate::data::{Capped, N};
|
||||
use crate::http::hyper::body::Bytes;
|
||||
use crate::data::transform::Transform;
|
||||
|
||||
use super::peekable::Peekable;
|
||||
use super::transform::TransformBuf;
|
||||
|
||||
/// Raw data stream of a request body.
|
||||
///
|
||||
|
@ -40,47 +44,101 @@ use crate::data::{Capped, N};
|
|||
///
|
||||
/// [`DataStream::stream_to(&mut vec)`]: DataStream::stream_to()
|
||||
/// [`DataStream::stream_to(&mut file)`]: DataStream::stream_to()
|
||||
pub struct DataStream<'r> {
|
||||
pub(crate) chain: Take<Chain<Cursor<Vec<u8>>, StreamReader<'r>>>,
|
||||
#[non_exhaustive]
|
||||
pub enum DataStream<'r> {
|
||||
#[doc(hidden)]
|
||||
Base(BaseReader<'r>),
|
||||
#[doc(hidden)]
|
||||
Transform(TransformReader<'r>),
|
||||
}
|
||||
|
||||
/// An adapter: turns a `T: Stream` (in `StreamKind`) into a `tokio::AsyncRead`.
|
||||
pub struct StreamReader<'r> {
|
||||
state: State,
|
||||
inner: StreamKind<'r>,
|
||||
/// A data stream that has a `transformer` applied to it.
|
||||
pub struct TransformReader<'r> {
|
||||
transformer: Pin<Box<dyn Transform + Send + Sync + 'r>>,
|
||||
stream: Pin<Box<DataStream<'r>>>,
|
||||
inner_done: bool,
|
||||
}
|
||||
|
||||
/// The current state of `StreamReader` `AsyncRead` adapter.
|
||||
enum State {
|
||||
Pending,
|
||||
Partial(Cursor<hyper::body::Bytes>),
|
||||
Done,
|
||||
}
|
||||
/// Limited, pre-buffered reader to the underlying data stream.
|
||||
pub type BaseReader<'r> = Take<Chain<Cursor<Vec<u8>>, RawReader<'r>>>;
|
||||
|
||||
/// The kinds of streams we accept as `Data`.
|
||||
enum StreamKind<'r> {
|
||||
/// Direct reader to the underlying data stream. Not limited in any manner.
|
||||
pub type RawReader<'r> = StreamReader<RawStream<'r>, Bytes>;
|
||||
|
||||
/// Raw underlying data stream.
|
||||
pub enum RawStream<'r> {
|
||||
Empty,
|
||||
Body(&'r mut hyper::Body),
|
||||
Multipart(multer::Field<'r>)
|
||||
Multipart(multer::Field<'r>),
|
||||
}
|
||||
|
||||
impl<'r> TransformReader<'r> {
|
||||
/// Returns the underlying `BaseReader`.
|
||||
fn base_mut(&mut self) -> &mut BaseReader<'r> {
|
||||
match self.stream.as_mut().get_mut() {
|
||||
DataStream::Base(base) => base,
|
||||
DataStream::Transform(inner) => inner.base_mut(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the underlying `BaseReader`.
|
||||
fn base(&self) -> &BaseReader<'r> {
|
||||
match self.stream.as_ref().get_ref() {
|
||||
DataStream::Base(base) => base,
|
||||
DataStream::Transform(inner) => inner.base(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'r> DataStream<'r> {
|
||||
pub(crate) fn new(buf: Vec<u8>, stream: StreamReader<'r>, limit: u64) -> Self {
|
||||
let chain = Chain::new(Cursor::new(buf), stream).take(limit).into();
|
||||
Self { chain }
|
||||
pub(crate) fn new(
|
||||
transformers: Vec<Pin<Box<dyn Transform + Send + Sync + 'r>>>,
|
||||
Peekable { buffer, reader, .. }: Peekable<512, RawReader<'r>>,
|
||||
limit: u64
|
||||
) -> Self {
|
||||
let mut stream = DataStream::Base(Chain::new(Cursor::new(buffer), reader).take(limit));
|
||||
for transformer in transformers {
|
||||
stream = DataStream::Transform(TransformReader {
|
||||
transformer,
|
||||
stream: Box::pin(stream),
|
||||
inner_done: false,
|
||||
});
|
||||
}
|
||||
|
||||
stream
|
||||
}
|
||||
|
||||
/// Returns the underlying `BaseReader`.
|
||||
fn base_mut(&mut self) -> &mut BaseReader<'r> {
|
||||
match self {
|
||||
DataStream::Base(base) => base,
|
||||
DataStream::Transform(transform) => transform.base_mut(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the underlying `BaseReader`.
|
||||
fn base(&self) -> &BaseReader<'r> {
|
||||
match self {
|
||||
DataStream::Base(base) => base,
|
||||
DataStream::Transform(transform) => transform.base(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Whether a previous read exhausted the set limit _and then some_.
|
||||
async fn limit_exceeded(&mut self) -> io::Result<bool> {
|
||||
let base = self.base_mut();
|
||||
|
||||
#[cold]
|
||||
async fn _limit_exceeded(stream: &mut DataStream<'_>) -> io::Result<bool> {
|
||||
async fn _limit_exceeded(base: &mut BaseReader<'_>) -> io::Result<bool> {
|
||||
// Read one more byte after reaching limit to see if we cut early.
|
||||
stream.chain.set_limit(1);
|
||||
base.set_limit(1);
|
||||
let mut buf = [0u8; 1];
|
||||
Ok(stream.read(&mut buf).await? != 0)
|
||||
let exceeded = base.read(&mut buf).await? != 0;
|
||||
base.set_limit(0);
|
||||
Ok(exceeded)
|
||||
}
|
||||
|
||||
Ok(self.chain.limit() == 0 && _limit_exceeded(self).await?)
|
||||
Ok(base.limit() == 0 && _limit_exceeded(base).await?)
|
||||
}
|
||||
|
||||
/// Number of bytes a full read from `self` will _definitely_ read.
|
||||
|
@ -95,8 +153,9 @@ impl<'r> DataStream<'r> {
|
|||
/// }
|
||||
/// ```
|
||||
pub fn hint(&self) -> usize {
|
||||
let buf_len = self.chain.get_ref().get_ref().0.get_ref().len();
|
||||
std::cmp::min(buf_len, self.chain.limit() as usize)
|
||||
let base = self.base();
|
||||
let buf_len = base.get_ref().get_ref().0.get_ref().len();
|
||||
std::cmp::min(buf_len, base.limit() as usize)
|
||||
}
|
||||
|
||||
/// A helper method to write the body of the request to any `AsyncWrite`
|
||||
|
@ -227,97 +286,86 @@ impl<'r> DataStream<'r> {
|
|||
}
|
||||
}
|
||||
|
||||
// TODO.async: Consider implementing `AsyncBufRead`.
|
||||
|
||||
impl StreamReader<'_> {
|
||||
pub fn empty() -> Self {
|
||||
Self { inner: StreamKind::Empty, state: State::Done }
|
||||
}
|
||||
}
|
||||
|
||||
impl<'r> From<&'r mut hyper::Body> for StreamReader<'r> {
|
||||
fn from(body: &'r mut hyper::Body) -> Self {
|
||||
Self { inner: StreamKind::Body(body), state: State::Pending }
|
||||
}
|
||||
}
|
||||
|
||||
impl<'r> From<multer::Field<'r>> for StreamReader<'r> {
|
||||
fn from(field: multer::Field<'r>) -> Self {
|
||||
Self { inner: StreamKind::Multipart(field), state: State::Pending }
|
||||
}
|
||||
}
|
||||
|
||||
impl AsyncRead for DataStream<'_> {
|
||||
#[inline(always)]
|
||||
fn poll_read(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &mut ReadBuf<'_>,
|
||||
) -> Poll<io::Result<()>> {
|
||||
match self.get_mut() {
|
||||
DataStream::Base(inner) => Pin::new(inner).poll_read(cx, buf),
|
||||
DataStream::Transform(inner) => Pin::new(inner).poll_read(cx, buf),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl AsyncRead for TransformReader<'_> {
|
||||
fn poll_read(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &mut ReadBuf<'_>,
|
||||
) -> Poll<io::Result<()>> {
|
||||
if self.chain.limit() == 0 {
|
||||
let stream: &StreamReader<'_> = &self.chain.get_ref().get_ref().1;
|
||||
let kind = match stream.inner {
|
||||
StreamKind::Empty => "an empty stream (vacuous)",
|
||||
StreamKind::Body(_) => "the request body",
|
||||
StreamKind::Multipart(_) => "a multipart form field",
|
||||
};
|
||||
|
||||
warn_!("Data limit reached while reading {}.", kind.primary().bold());
|
||||
let init_fill = buf.filled().len();
|
||||
if !self.inner_done {
|
||||
ready!(Pin::new(&mut self.stream).poll_read(cx, buf))?;
|
||||
self.inner_done = init_fill == buf.filled().len();
|
||||
}
|
||||
|
||||
Pin::new(&mut self.chain).poll_read(cx, buf)
|
||||
if self.inner_done {
|
||||
return self.transformer.as_mut().poll_finish(cx, buf);
|
||||
}
|
||||
|
||||
let mut tbuf = TransformBuf { buf, cursor: init_fill };
|
||||
self.transformer.as_mut().transform(&mut tbuf)?;
|
||||
if buf.filled().len() == init_fill {
|
||||
cx.waker().wake_by_ref();
|
||||
return Poll::Pending;
|
||||
}
|
||||
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
}
|
||||
|
||||
impl Stream for StreamKind<'_> {
|
||||
type Item = io::Result<hyper::body::Bytes>;
|
||||
impl Stream for RawStream<'_> {
|
||||
type Item = io::Result<Bytes>;
|
||||
|
||||
fn poll_next(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
) -> Poll<Option<Self::Item>> {
|
||||
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
|
||||
match self.get_mut() {
|
||||
StreamKind::Body(body) => Pin::new(body).poll_next(cx)
|
||||
RawStream::Body(body) => Pin::new(body).poll_next(cx)
|
||||
.map_err_ext(|e| io::Error::new(io::ErrorKind::Other, e)),
|
||||
StreamKind::Multipart(mp) => Pin::new(mp).poll_next(cx)
|
||||
RawStream::Multipart(mp) => Pin::new(mp).poll_next(cx)
|
||||
.map_err_ext(|e| io::Error::new(io::ErrorKind::Other, e)),
|
||||
StreamKind::Empty => Poll::Ready(None),
|
||||
RawStream::Empty => Poll::Ready(None),
|
||||
}
|
||||
}
|
||||
|
||||
fn size_hint(&self) -> (usize, Option<usize>) {
|
||||
match self {
|
||||
StreamKind::Body(body) => body.size_hint(),
|
||||
StreamKind::Multipart(mp) => mp.size_hint(),
|
||||
StreamKind::Empty => (0, Some(0)),
|
||||
RawStream::Body(body) => body.size_hint(),
|
||||
RawStream::Multipart(mp) => mp.size_hint(),
|
||||
RawStream::Empty => (0, Some(0)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl AsyncRead for StreamReader<'_> {
|
||||
fn poll_read(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &mut ReadBuf<'_>,
|
||||
) -> Poll<io::Result<()>> {
|
||||
loop {
|
||||
self.state = match self.state {
|
||||
State::Pending => {
|
||||
match ready!(Pin::new(&mut self.inner).poll_next(cx)) {
|
||||
Some(Err(e)) => return Poll::Ready(Err(e)),
|
||||
Some(Ok(bytes)) => State::Partial(Cursor::new(bytes)),
|
||||
None => State::Done,
|
||||
}
|
||||
},
|
||||
State::Partial(ref mut cursor) => {
|
||||
let rem = buf.remaining();
|
||||
match ready!(Pin::new(cursor).poll_read(cx, buf)) {
|
||||
Ok(()) if rem == buf.remaining() => State::Pending,
|
||||
result => return Poll::Ready(result),
|
||||
}
|
||||
}
|
||||
State::Done => return Poll::Ready(Ok(())),
|
||||
impl std::fmt::Display for RawStream<'_> {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
RawStream::Empty => f.write_str("empty stream"),
|
||||
RawStream::Body(_) => f.write_str("request body"),
|
||||
RawStream::Multipart(_) => f.write_str("multipart form field"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'r> From<&'r mut hyper::Body> for RawStream<'r> {
|
||||
fn from(value: &'r mut hyper::Body) -> Self {
|
||||
Self::Body(value)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'r> From<multer::Field<'r>> for RawStream<'r> {
|
||||
fn from(value: multer::Field<'r>) -> Self {
|
||||
Self::Multipart(value)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -7,6 +7,8 @@ mod data_stream;
|
|||
mod from_data;
|
||||
mod limits;
|
||||
mod io_stream;
|
||||
mod transform;
|
||||
mod peekable;
|
||||
|
||||
pub use self::data::Data;
|
||||
pub use self::data_stream::DataStream;
|
||||
|
@ -15,5 +17,4 @@ pub use self::limits::Limits;
|
|||
pub use self::capped::{N, Capped};
|
||||
pub use self::io_stream::{IoHandler, IoStream};
|
||||
pub use ubyte::{ByteUnit, ToByteUnit};
|
||||
|
||||
pub(crate) use self::data_stream::StreamReader;
|
||||
pub use self::transform::{Transform, TransformBuf};
|
||||
|
|
|
@ -0,0 +1,48 @@
|
|||
use tokio::io::{AsyncRead, AsyncReadExt};
|
||||
|
||||
pub struct Peekable<const N: usize, R> {
|
||||
pub(crate) buffer: Vec<u8>,
|
||||
pub(crate) complete: bool,
|
||||
pub(crate) reader: R,
|
||||
}
|
||||
|
||||
impl<const N: usize, R: AsyncRead + Unpin> Peekable<N, R> {
|
||||
pub fn new(reader: R) -> Self {
|
||||
Self { buffer: Vec::new(), complete: false, reader }
|
||||
}
|
||||
|
||||
pub fn with_buffer(buffer: Vec<u8>, complete: bool, reader: R) -> Self {
|
||||
Self { buffer, complete, reader }
|
||||
}
|
||||
|
||||
pub async fn peek(&mut self, num: usize) -> &[u8] {
|
||||
if self.complete {
|
||||
return self.buffer.as_slice();
|
||||
}
|
||||
|
||||
let to_read = std::cmp::min(N, num);
|
||||
if self.buffer.len() >= to_read {
|
||||
return &self.buffer.as_slice();
|
||||
}
|
||||
|
||||
if self.buffer.capacity() == 0 {
|
||||
self.buffer.reserve(N);
|
||||
}
|
||||
|
||||
while self.buffer.len() < to_read {
|
||||
match self.reader.read_buf::<Vec<u8>>(&mut self.buffer).await {
|
||||
Ok(0) => {
|
||||
self.complete = self.buffer.capacity() > self.buffer.len();
|
||||
break;
|
||||
},
|
||||
Ok(_) => { /* continue */ },
|
||||
Err(e) => {
|
||||
error_!("Failed to read into peek buffer: {:?}.", e);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
self.buffer.as_slice()
|
||||
}
|
||||
}
|
|
@ -0,0 +1,304 @@
|
|||
use std::io;
|
||||
use std::ops::{Deref, DerefMut};
|
||||
use std::pin::Pin;
|
||||
use std::task::{Poll, Context};
|
||||
|
||||
use tokio::io::ReadBuf;
|
||||
|
||||
/// Chainable, in-place, streaming data transformer.
|
||||
///
|
||||
/// [`Transform`] operates on [`TransformBuf`]s similar to how [`AsyncRead`]
|
||||
/// operates on [`ReadBuf`]. A [`Transform`] sits somewhere in a chain of
|
||||
/// transforming readers. The head (most upstream part) of the chain is _always_
|
||||
/// an [`AsyncRead`]: the data source. The tail (all downstream parts) is
|
||||
/// composed _only_ of other [`Transform`]s:
|
||||
///
|
||||
/// ```text
|
||||
/// downstream --->
|
||||
/// AsyncRead | Transform | .. | Transform
|
||||
/// <---- upstream
|
||||
/// ```
|
||||
///
|
||||
/// When the upstream source makes data available, the
|
||||
/// [`Transform::transform()`] method is called. [`Transform`]s may obtain the
|
||||
/// subset of the filled section added by an upstream data source with
|
||||
/// [`TransformBuf::fresh()`]. They may modify this data at will, potentially
|
||||
/// changing the size of the filled section. For example,
|
||||
/// [`TransformBuf::spoil()`] "removes" all of the fresh data, and
|
||||
/// [`TransformBuf::fresh_mut()`] can be used to modify the data in-place.
|
||||
///
|
||||
/// Additionally, new data may be added in-place via the traditional approach:
|
||||
/// write to (or overwrite) the initialized section of the buffer and mark it as
|
||||
/// filled. All of the remaining filled data will be passed to downstream
|
||||
/// transforms as "fresh" data. To add data to the end of the (potentially
|
||||
/// rewritten) stream, the [`Transform::poll_finish()`] method can be
|
||||
/// implemented.
|
||||
///
|
||||
/// [`AsyncRead`]: tokio::io::AsyncRead
|
||||
pub trait Transform {
|
||||
/// Called when data is read from the upstream source. For any given fresh
|
||||
/// data, this method is called only once. [`TransformBuf::fresh()`] is
|
||||
/// guaranteed to contain at least one byte.
|
||||
///
|
||||
/// While this method is not _async_ (it does not return [`Poll`]), it is
|
||||
/// nevertheless executed in an async context and should respect all such
|
||||
/// restrictions including not blocking.
|
||||
fn transform(
|
||||
self: Pin<&mut Self>,
|
||||
buf: &mut TransformBuf<'_, '_>,
|
||||
) -> io::Result<()>;
|
||||
|
||||
/// Called when the upstream is finished, that is, it has no more data to
|
||||
/// fill. At this point, the transform becomes an async reader. This method
|
||||
/// thus has identical semantics to [`AsyncRead::poll_read()`]. This method
|
||||
/// may never be called if the upstream does not finish.
|
||||
///
|
||||
/// The default implementation returns `Poll::Ready(Ok(()))`.
|
||||
///
|
||||
/// [`AsyncRead::poll_read()`]: tokio::io::AsyncRead::poll_read()
|
||||
fn poll_finish(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &mut ReadBuf<'_>,
|
||||
) -> Poll<io::Result<()>> {
|
||||
let (_, _) = (cx, buf);
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
}
|
||||
|
||||
/// A buffer of transformable streaming data.
|
||||
///
|
||||
/// # Overview
|
||||
///
|
||||
/// A byte buffer, similar to a [`ReadBuf`], with a "fresh" dimension. Fresh
|
||||
/// data is always a subset of the filled data, filled data is always a subset
|
||||
/// of initialized data, and initialized data is always a subset of the buffer
|
||||
/// itself. Both the filled and initialized data sections are guaranteed to be
|
||||
/// at the start of the buffer, but the fresh subset is likely to begin
|
||||
/// somewhere inside the filled section.
|
||||
///
|
||||
/// To visualize this, the diagram below represents a possible state for the
|
||||
/// byte buffer being tracked. The square `[ ]` brackets represent the complete
|
||||
/// buffer, while the curly `{ }` represent the named subset.
|
||||
///
|
||||
/// ```text
|
||||
/// [ { !! fresh !! } ]
|
||||
/// { +++ filled +++ } unfilled ]
|
||||
/// { ----- initialized ------ } uninitialized ]
|
||||
/// [ capacity ]
|
||||
/// ```
|
||||
///
|
||||
/// The same buffer represented in its true single dimension is below:
|
||||
///
|
||||
/// ```text
|
||||
/// [ ++!!!!!!!!!!!!!!---------xxxxxxxxxxxxxxxxxxxxxxxx]
|
||||
/// ```
|
||||
///
|
||||
/// * `+`: filled (implies initialized)
|
||||
/// * `!`: fresh (implies filled)
|
||||
/// * `-`: unfilled but initialized
|
||||
/// * `x`: uninitialized (implies unfilled)
|
||||
///
|
||||
/// As with [`ReadBuf`], [`AsyncRead`] readers fill the initialized portion of a
|
||||
/// [`TransformBuf`] to indicate that data is available. _Filling_ initialized
|
||||
/// portions of the byte buffers is what increases the size of the _filled_
|
||||
/// section. Because a [`ReadBuf`] may already be partially filled when a reader
|
||||
/// adds bytes to it, a mechanism to track where the _newly_ filled portion
|
||||
/// exists is needed. This is exactly what the "fresh" section tracks.
|
||||
///
|
||||
/// [`AsyncRead`]: tokio::io::AsyncRead
|
||||
pub struct TransformBuf<'a, 'b> {
|
||||
pub(crate) buf: &'a mut ReadBuf<'b>,
|
||||
pub(crate) cursor: usize,
|
||||
}
|
||||
|
||||
impl TransformBuf<'_, '_> {
|
||||
/// Returns a borrow to the fresh data: data filled by the upstream source.
|
||||
pub fn fresh(&self) -> &[u8] {
|
||||
&self.filled()[self.cursor..]
|
||||
}
|
||||
|
||||
/// Returns a mutable borrow to the fresh data: data filled by the upstream
|
||||
/// source.
|
||||
pub fn fresh_mut(&mut self) -> &mut [u8] {
|
||||
let cursor = self.cursor;
|
||||
&mut self.filled_mut()[cursor..]
|
||||
}
|
||||
|
||||
/// Spoils the fresh data by resetting the filled section to its value
|
||||
/// before any new data was added. As a result, the data will never be seen
|
||||
/// by any downstream consumer unless it is returned via another mechanism.
|
||||
pub fn spoil(&mut self) {
|
||||
let cursor = self.cursor;
|
||||
self.set_filled(cursor);
|
||||
}
|
||||
}
|
||||
|
||||
pub struct Inspect(pub(crate) Box<dyn FnMut(&[u8]) + Send + Sync + 'static>);
|
||||
|
||||
impl Transform for Inspect {
|
||||
fn transform(mut self: Pin<&mut Self>, buf: &mut TransformBuf<'_, '_>) -> io::Result<()> {
|
||||
(self.0)(buf.fresh());
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
pub struct InPlaceMap(
|
||||
pub(crate) Box<dyn FnMut(&mut TransformBuf<'_, '_>) -> io::Result<()> + Send + Sync + 'static>
|
||||
);
|
||||
|
||||
impl Transform for InPlaceMap {
|
||||
fn transform(mut self: Pin<&mut Self>, buf: &mut TransformBuf<'_, '_>,) -> io::Result<()> {
|
||||
(self.0)(buf)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, 'b> Deref for TransformBuf<'a, 'b> {
|
||||
type Target = ReadBuf<'b>;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.buf
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, 'b> DerefMut for TransformBuf<'a, 'b> {
|
||||
fn deref_mut(&mut self) -> &mut Self::Target {
|
||||
&mut self.buf
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: Test chaining various transform combinations:
|
||||
// * consume | consume
|
||||
// * add | consume
|
||||
// * consume | add
|
||||
// * add | add
|
||||
// Where `add` is a transformer that adds data to the stream, and `consume` is
|
||||
// one that removes data.
|
||||
#[cfg(test)]
|
||||
#[allow(deprecated)]
|
||||
mod tests {
|
||||
use std::hash::SipHasher;
|
||||
use std::sync::{Arc, atomic::{AtomicU64, AtomicU8}};
|
||||
|
||||
use parking_lot::Mutex;
|
||||
use ubyte::ToByteUnit;
|
||||
|
||||
use crate::http::Method;
|
||||
use crate::local::blocking::Client;
|
||||
use crate::fairing::AdHoc;
|
||||
use crate::{route, Route, Data, Response, Request};
|
||||
|
||||
mod hash_transform {
|
||||
use std::io::Cursor;
|
||||
use std::hash::Hasher;
|
||||
|
||||
use tokio::io::AsyncRead;
|
||||
|
||||
use super::super::*;
|
||||
|
||||
pub struct HashTransform<H: Hasher> {
|
||||
pub(crate) hasher: H,
|
||||
pub(crate) hash: Option<Cursor<[u8; 8]>>
|
||||
}
|
||||
|
||||
impl<H: Hasher + Unpin> Transform for HashTransform<H> {
|
||||
fn transform(
|
||||
mut self: Pin<&mut Self>,
|
||||
buf: &mut TransformBuf<'_, '_>,
|
||||
) -> io::Result<()> {
|
||||
self.hasher.write(buf.fresh());
|
||||
buf.spoil();
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn poll_finish(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &mut ReadBuf<'_>,
|
||||
) -> Poll<io::Result<()>> {
|
||||
if self.hash.is_none() {
|
||||
let hash = self.hasher.finish();
|
||||
self.hash = Some(Cursor::new(hash.to_be_bytes()));
|
||||
}
|
||||
|
||||
let cursor = self.hash.as_mut().unwrap();
|
||||
Pin::new(cursor).poll_read(cx, buf)
|
||||
}
|
||||
}
|
||||
|
||||
impl crate::Data<'_> {
|
||||
/// Chain an in-place hash [`Transform`] to `self`.
|
||||
pub fn chain_hash_transform<H: std::hash::Hasher>(&mut self, hasher: H) -> &mut Self
|
||||
where H: Unpin + Send + Sync + 'static
|
||||
{
|
||||
self.chain_transform(HashTransform { hasher, hash: None })
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_transform_series() {
|
||||
fn handler<'r>(_: &'r Request<'_>, data: Data<'r>) -> route::BoxFuture<'r> {
|
||||
Box::pin(async move {
|
||||
data.open(128.bytes()).stream_to(tokio::io::sink()).await.expect("read ok");
|
||||
route::Outcome::Success(Response::new())
|
||||
})
|
||||
}
|
||||
|
||||
let inspect2: Arc<AtomicU8> = Arc::new(AtomicU8::new(0));
|
||||
let raw_data: Arc<Mutex<Vec<u8>>> = Arc::new(Mutex::new(Vec::new()));
|
||||
let hash: Arc<AtomicU64> = Arc::new(AtomicU64::new(0));
|
||||
let rocket = crate::build()
|
||||
.manage(hash.clone())
|
||||
.manage(raw_data.clone())
|
||||
.manage(inspect2.clone())
|
||||
.mount("/", vec![Route::new(Method::Post, "/", handler)])
|
||||
.attach(AdHoc::on_request("transforms", |req, data| Box::pin(async {
|
||||
let hash1 = req.rocket().state::<Arc<AtomicU64>>().cloned().unwrap();
|
||||
let hash2 = req.rocket().state::<Arc<AtomicU64>>().cloned().unwrap();
|
||||
let raw_data = req.rocket().state::<Arc<Mutex<Vec<u8>>>>().cloned().unwrap();
|
||||
let inspect2 = req.rocket().state::<Arc<AtomicU8>>().cloned().unwrap();
|
||||
data.chain_inspect(move |bytes| { *raw_data.lock() = bytes.to_vec(); })
|
||||
.chain_hash_transform(SipHasher::new())
|
||||
.chain_inspect(move |bytes| {
|
||||
assert_eq!(bytes.len(), 8);
|
||||
let bytes: [u8; 8] = bytes.try_into().expect("[u8; 8]");
|
||||
let value = u64::from_be_bytes(bytes);
|
||||
hash1.store(value, atomic::Ordering::Release);
|
||||
})
|
||||
.chain_inspect(move |bytes| {
|
||||
assert_eq!(bytes.len(), 8);
|
||||
let bytes: [u8; 8] = bytes.try_into().expect("[u8; 8]");
|
||||
let value = u64::from_be_bytes(bytes);
|
||||
let prev = hash2.load(atomic::Ordering::Acquire);
|
||||
assert_eq!(prev, value);
|
||||
inspect2.fetch_add(1, atomic::Ordering::Release);
|
||||
});
|
||||
})));
|
||||
|
||||
// Make sure nothing has happened yet.
|
||||
assert!(raw_data.lock().is_empty());
|
||||
assert_eq!(hash.load(atomic::Ordering::Acquire), 0);
|
||||
assert_eq!(inspect2.load(atomic::Ordering::Acquire), 0);
|
||||
|
||||
// Check that nothing happens if the data isn't read.
|
||||
let client = Client::debug(rocket).unwrap();
|
||||
client.get("/").body("Hello, world!").dispatch();
|
||||
assert!(raw_data.lock().is_empty());
|
||||
assert_eq!(hash.load(atomic::Ordering::Acquire), 0);
|
||||
assert_eq!(inspect2.load(atomic::Ordering::Acquire), 0);
|
||||
|
||||
// Check inspect + hash + inspect + inspect.
|
||||
client.post("/").body("Hello, world!").dispatch();
|
||||
assert_eq!(raw_data.lock().as_slice(), "Hello, world!".as_bytes());
|
||||
assert_eq!(hash.load(atomic::Ordering::Acquire), 0xae5020d7cf49d14f);
|
||||
assert_eq!(inspect2.load(atomic::Ordering::Acquire), 1);
|
||||
|
||||
// Check inspect + hash + inspect + inspect, round 2.
|
||||
let string = "Rocket, Rocket, where art thee? Oh, tis in the sky, I see!";
|
||||
client.post("/").body(string).dispatch();
|
||||
assert_eq!(raw_data.lock().as_slice(), string.as_bytes());
|
||||
assert_eq!(hash.load(atomic::Ordering::Acquire), 0x323f9aa98f907faf);
|
||||
assert_eq!(inspect2.load(atomic::Ordering::Acquire), 2);
|
||||
}
|
||||
}
|
|
@ -59,7 +59,7 @@ enum AdHocKind {
|
|||
Liftoff(Once<dyn for<'a> FnOnce(&'a Rocket<Orbit>) -> BoxFuture<'a, ()> + Send + 'static>),
|
||||
|
||||
/// An ad-hoc **request** fairing. Called when a request is received.
|
||||
Request(Box<dyn for<'a> Fn(&'a mut Request<'_>, &'a Data<'_>)
|
||||
Request(Box<dyn for<'a> Fn(&'a mut Request<'_>, &'a mut Data<'_>)
|
||||
-> BoxFuture<'a, ()> + Send + Sync + 'static>),
|
||||
|
||||
/// An ad-hoc **response** fairing. Called when a response is ready to be
|
||||
|
@ -153,7 +153,7 @@ impl AdHoc {
|
|||
/// });
|
||||
/// ```
|
||||
pub fn on_request<F: Send + Sync + 'static>(name: &'static str, f: F) -> AdHoc
|
||||
where F: for<'a> Fn(&'a mut Request<'_>, &'a Data<'_>) -> BoxFuture<'a, ()>
|
||||
where F: for<'a> Fn(&'a mut Request<'_>, &'a mut Data<'_>) -> BoxFuture<'a, ()>
|
||||
{
|
||||
AdHoc { name, kind: AdHocKind::Request(Box::new(f)) }
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue