diff --git a/lib/src/data/data.rs b/lib/src/data/data.rs index 74d2fe79..df51eef9 100644 --- a/lib/src/data/data.rs +++ b/lib/src/data/data.rs @@ -1,22 +1,24 @@ -use std::io::{self, Read, Write, Cursor, BufReader}; +use std::io::{self, Read, Write, Cursor, BufReader, Chain, Take}; use std::path::Path; use std::fs::File; use std::time::Duration; -use std::mem::transmute; #[cfg(feature = "tls")] use hyper_rustls::WrappedStream; -use super::data_stream::{DataStream, StreamReader, kill_stream}; +use super::data_stream::DataStream; use super::net_stream::NetStream; use ext::ReadExt; +use http::hyper; use http::hyper::h1::HttpReader; -use http::hyper::buffer; use http::hyper::h1::HttpReader::*; use http::hyper::net::{HttpStream, NetworkStream}; -pub type BodyReader<'a, 'b> = - self::HttpReader<&'a mut self::buffer::BufReader<&'b mut NetworkStream>>; +pub type HyperBodyReader<'a, 'b> = + self::HttpReader<&'a mut hyper::buffer::BufReader<&'b mut NetworkStream>>; + +// |---- from hyper ----| +pub type BodyReader = HttpReader>>, BufReader>>; /// The number of bytes to read into the "peek" buffer. const PEEK_BYTES: usize = 4096; @@ -51,12 +53,8 @@ const PEEK_BYTES: usize = 4096; /// without consuming the `Data` object. pub struct Data { buffer: Vec, - is_done: bool, - // TODO: This sucks as it depends on a TCPStream. Oh, hyper. - stream: StreamReader, - // Ideally we wouldn't have these, but Hyper forces us to. - position: usize, - capacity: usize, + is_complete: bool, + stream: BodyReader, } impl Data { @@ -67,39 +65,28 @@ impl Data { /// instance. This ensures that a `Data` type _always_ represents _all_ of /// the data in a request. pub fn open(mut self) -> DataStream { - // Swap out the buffer and stream for empty ones so we can move. - let mut buffer = vec![]; - let mut stream = EmptyReader(self.stream.get_ref().clone()); - ::std::mem::swap(&mut buffer, &mut self.buffer); - ::std::mem::swap(&mut stream, &mut self.stream); + let buffer = ::std::mem::replace(&mut self.buffer, vec![]); + let empty_stream = Cursor::new(vec![]).take(0) + .chain(BufReader::new(NetStream::Local(Cursor::new(vec![])))); - // Setup the underlying reader at the correct pointers. - let mut cursor = Cursor::new(buffer); - cursor.set_position(self.position as u64); - - // Get a reference to the underlying network stream. - let network = stream.get_ref().clone(); - - // The first part of the stream is the buffer. Then the real steam. - let buf = cursor.take((self.capacity - self.position) as u64); - let stream = buf.chain(BufReader::new(stream)); - - DataStream::new(stream, network) + let empty_http_stream = HttpReader::SizedReader(empty_stream, 0); + let stream = ::std::mem::replace(&mut self.stream, empty_http_stream); + DataStream(Cursor::new(buffer).chain(stream)) } // FIXME: This is absolutely terrible (downcasting!), thanks to Hyper. - pub(crate) fn from_hyp(mut h_body: BodyReader) -> Result { - // Create the Data object from hyper's buffer. - let (vec, pos, cap) = h_body.get_mut().take_buf(); - let net_stream = h_body.get_ref().get_ref(); + pub(crate) fn from_hyp(mut body: HyperBodyReader) -> Result { + // Steal the internal, undecoded data buffer and net stream from Hyper. + let (hyper_buf, pos, cap) = body.get_mut().take_buf(); + let hyper_net_stream = body.get_ref().get_ref(); #[cfg(feature = "tls")] fn concrete_stream(stream: &&mut NetworkStream) -> Option { - stream.downcast_ref::() - .map(|s| NetStream::Http(s.clone())) + stream.downcast_ref::() + .map(|s| NetStream::Https(s.clone())) .or_else(|| { - stream.downcast_ref::() - .map(|s| NetStream::Https(s.clone())) + stream.downcast_ref::() + .map(|s| NetStream::Http(s.clone())) }) } @@ -109,25 +96,32 @@ impl Data { .map(|s| NetStream::Http(s.clone())) } - // Retrieve the underlying HTTPStream from Hyper. - let stream = match concrete_stream(net_stream) { - Some(stream) => stream, + // Retrieve the underlying Http(s)Stream from Hyper. + let net_stream = match concrete_stream(hyper_net_stream) { + Some(net_stream) => net_stream, None => return Err("Stream is not an HTTP(s) stream!") }; // Set the read timeout to 5 seconds. - stream.set_read_timeout(Some(Duration::from_secs(5))).expect("timeout set"); + net_stream.set_read_timeout(Some(Duration::from_secs(5))).expect("timeout set"); - // Create a reader from the stream. Don't read what's already buffered. - let buffered = (cap - pos) as u64; - let reader = match h_body { - SizedReader(_, n) => SizedReader(stream, n - buffered), - EofReader(_) => EofReader(stream), - EmptyReader(_) => EmptyReader(stream), - ChunkedReader(_, n) => ChunkedReader(stream, n.map(|k| k - buffered)), + // TODO: Explain this. + trace_!("Hyper buffer: [{}..{}] ({} bytes).", pos, cap, cap - pos); + let (start, remaining) = (pos as u64, (cap - pos) as u64); + let mut cursor = Cursor::new(hyper_buf); + cursor.set_position(start); + let inner_data = cursor.take(remaining) + .chain(BufReader::new(net_stream.clone())); + + // Create an HTTP reader from the stream. + let http_stream = match body { + SizedReader(_, n) => SizedReader(inner_data, n), + EofReader(_) => EofReader(inner_data), + EmptyReader(_) => EmptyReader(inner_data), + ChunkedReader(_, n) => ChunkedReader(inner_data, n) }; - Ok(Data::new(vec, pos, cap, reader)) + Ok(Data::new(http_stream)) } /// Retrieve the `peek` buffer. @@ -138,7 +132,7 @@ impl Data { /// buffer contains _all_ of the data in the body of the request. #[inline(always)] pub fn peek(&self) -> &[u8] { - &self.buffer[self.position..self.capacity] + &self.buffer } /// Returns true if the `peek` buffer contains all of the data in the body @@ -146,7 +140,7 @@ impl Data { /// it does. #[inline(always)] pub fn peek_complete(&self) -> bool { - self.is_done + self.is_complete } /// A helper method to write the body of the request to any `Write` type. @@ -171,40 +165,32 @@ impl Data { // in the buffer is at `pos` and the buffer has `cap` valid bytes. Thus, the // bytes `vec[pos..cap]` are buffered and unread. The remainder of the data // bytes can be read from `stream`. - pub(crate) fn new(mut buf: Vec, - pos: usize, - mut cap: usize, - mut stream: StreamReader - ) -> Data { - // Make sure the buffer is large enough for the bytes we want to peek. - if buf.len() < PEEK_BYTES { - trace_!("Resizing peek buffer from {} to {}.", buf.len(), PEEK_BYTES); - buf.resize(PEEK_BYTES, 0); - } + pub(crate) fn new(mut stream: BodyReader) -> Data { + trace_!("Date::new({:?})", stream); + let mut peek_buf = vec![0; PEEK_BYTES]; // Fill the buffer with as many bytes as possible. If we read less than // that buffer's length, we know we reached the EOF. Otherwise, it's // unclear, so we just say we didn't reach EOF. - trace!("Init buffer cap: {}", cap); - let eof = match stream.read_max(&mut buf[cap..]) { + let eof = match stream.read_max(&mut peek_buf[..]) { Ok(n) => { trace_!("Filled peek buf with {} bytes.", n); - cap += n; - cap < buf.len() + // TODO: Explain this. + unsafe { peek_buf.set_len(n); } + n < PEEK_BYTES } Err(e) => { error_!("Failed to read into peek buffer: {:?}.", e); + unsafe { peek_buf.set_len(0); } false }, }; - trace_!("Peek buffer size: {}, remaining: {}", buf.len(), buf.len() - cap); + trace_!("Peek bytes: {}/{} bytes.", peek_buf.len(), PEEK_BYTES); Data { - buffer: buf, + buffer: peek_buf, stream: stream, - is_done: eof, - position: pos, - capacity: cap, + is_complete: eof, } } @@ -218,24 +204,23 @@ impl Data { (data, rest) }; - let (buf_len, stream_len) = (buf.len(), rest.len() as u64); - let stream = NetStream::Local(Cursor::new(rest)); + let stream_len = rest.len() as u64; + let stream = Cursor::new(vec![]).take(0) + .chain(BufReader::new(NetStream::Local(Cursor::new(rest)))); + Data { buffer: buf, stream: HttpReader::SizedReader(stream, stream_len), - is_done: stream_len == 0, - position: 0, - capacity: buf_len, + is_complete: stream_len == 0, } } } -impl Drop for Data { - fn drop(&mut self) { - // This is okay since the network stream expects to be shared mutably. - unsafe { - let stream: &mut StreamReader = transmute(self.stream.by_ref()); - kill_stream(stream, self.stream.get_mut()); - } - } -} +// impl Drop for Data { +// fn drop(&mut self) { +// // FIXME: Do a read; if > 1024, kill the stream. Need access to the +// // internals of `Chain` to do this efficiently/without crazy baggage. +// // https://github.com/rust-lang/rust/pull/41463 +// let _ = io::copy(&mut self.stream, &mut io::sink()); +// } +// } diff --git a/lib/src/data/data_stream.rs b/lib/src/data/data_stream.rs index c98b8293..58432f14 100644 --- a/lib/src/data/data_stream.rs +++ b/lib/src/data/data_stream.rs @@ -1,13 +1,14 @@ -use std::io::{self, BufRead, Read, Cursor, BufReader, Chain, Take}; -use std::net::Shutdown; +use std::io::{self, BufRead, Read, Cursor, BufReader, Chain}; -use super::net_stream::NetStream; +use super::data::BodyReader; -use http::hyper::net::NetworkStream; -use http::hyper::h1::HttpReader; - -pub type StreamReader = HttpReader; -pub type InnerStream = Chain>>, BufReader>; +// It's very unfortunate that we have to wrap `BodyReader` in a `BufReader` +// since it already contains another `BufReader`. The issue is that Hyper's +// `HttpReader` doesn't implement `BufRead`. Unfortunately, this will likely +// stay "double buffered" until we switch HTTP libraries. +// |-- peek buf --| +// pub type InnerStream = Chain>, BufReader>; +pub type InnerStream = Chain>, BodyReader>; /// Raw data stream of a request body. /// @@ -15,55 +16,33 @@ pub type InnerStream = Chain>>, BufReader>; /// [Data::open](/rocket/data/struct.Data.html#method.open). The stream contains /// all of the data in the body of the request. It exposes no methods directly. /// Instead, it must be used as an opaque `Read` or `BufRead` structure. -pub struct DataStream { - stream: InnerStream, - network: NetStream, -} - -impl DataStream { - #[inline(always)] - pub(crate) fn new(stream: InnerStream, network: NetStream) -> DataStream { - DataStream { stream, network } - } -} +pub struct DataStream(pub(crate) InnerStream); impl Read for DataStream { #[inline(always)] fn read(&mut self, buf: &mut [u8]) -> io::Result { - self.stream.read(buf) + trace_!("DataStream::read()"); + self.0.read(buf) } } -impl BufRead for DataStream { - #[inline(always)] - fn fill_buf(&mut self) -> io::Result<&[u8]> { - self.stream.fill_buf() - } +// impl BufRead for DataStream { +// #[inline(always)] +// fn fill_buf(&mut self) -> io::Result<&[u8]> { +// self.0.fill_buf() +// } - #[inline(always)] - fn consume(&mut self, amt: usize) { - self.stream.consume(amt) - } -} +// #[inline(always)] +// fn consume(&mut self, amt: usize) { +// self.0.consume(amt) +// } +// } - -pub fn kill_stream(stream: &mut S, network: &mut N) { - // Take <= 1k from the stream. If there might be more data, force close. - const FLUSH_LEN: u64 = 1024; - match io::copy(&mut stream.take(FLUSH_LEN), &mut io::sink()) { - Ok(FLUSH_LEN) | Err(_) => { - warn_!("Data left unread. Force closing network stream."); - if let Err(e) = network.close(Shutdown::Both) { - error_!("Failed to close network stream: {:?}", e); - } - } - Ok(n) => debug!("flushed {} unread bytes", n) - } -} - -impl Drop for DataStream { - // Be a bad citizen and close the TCP stream if there's unread data. - fn drop(&mut self) { - kill_stream(&mut self.stream, &mut self.network); - } -} +// impl Drop for DataStream { +// fn drop(&mut self) { +// // FIXME: Do a read; if > 1024, kill the stream. Need access to the +// // internals of `Chain` to do this efficiently/without crazy baggage. +// // https://github.com/rust-lang/rust/pull/41463 +// let _ = io::copy(&mut self.0, &mut io::sink()); +// } +// } diff --git a/lib/src/data/net_stream.rs b/lib/src/data/net_stream.rs index 28c3d766..908eb43a 100644 --- a/lib/src/data/net_stream.rs +++ b/lib/src/data/net_stream.rs @@ -20,17 +20,21 @@ pub enum NetStream { impl io::Read for NetStream { #[inline(always)] fn read(&mut self, buf: &mut [u8]) -> io::Result { - match *self { + trace_!("NetStream::read()"); + let res = match *self { Http(ref mut stream) => stream.read(buf), Local(ref mut stream) => stream.read(buf), #[cfg(feature = "tls")] Https(ref mut stream) => stream.read(buf) - } + }; + trace_!("NetStream::read() -- complete"); + res } } impl io::Write for NetStream { #[inline(always)] fn write(&mut self, buf: &[u8]) -> io::Result { + trace_!("NetStream::write()"); match *self { Http(ref mut stream) => stream.write(buf), Local(ref mut stream) => stream.write(buf), @@ -85,3 +89,20 @@ impl NetworkStream for NetStream { } } } + +// impl Drop for NetStream { +// fn drop(&mut self) { +// // Take <= 1k from the stream. If there might be more data, force close. +// trace_!("Dropping the network stream..."); +// // const FLUSH_LEN: u64 = 1024; +// // match io::copy(&mut self.take(FLUSH_LEN), &mut io::sink()) { +// // Ok(FLUSH_LEN) | Err(_) => { +// // warn_!("Data left unread. Force closing network stream."); +// // if let Err(e) = self.close(Shutdown::Both) { +// // error_!("Failed to close network stream: {:?}", e); +// // } +// // } +// // Ok(n) => debug!("flushed {} unread bytes", n) +// // } +// } +// }