Stream body data instead of buffering it.

This requires some awkward channel and spawning work because
Body might contain borrowed data.
This commit is contained in:
Jeb Rosen 2019-07-11 18:54:48 -07:00
parent 0fe3f39304
commit f83caf2d08
3 changed files with 113 additions and 47 deletions

View File

@ -16,6 +16,7 @@
#[doc(hidden)] pub use http::header::HeaderValue as HeaderValue; #[doc(hidden)] pub use http::header::HeaderValue as HeaderValue;
#[doc(hidden)] pub use http::method::Method; #[doc(hidden)] pub use http::method::Method;
#[doc(hidden)] pub use http::request::Parts as RequestParts; #[doc(hidden)] pub use http::request::Parts as RequestParts;
#[doc(hidden)] pub use http::response::Builder as ResponseBuilder;
#[doc(hidden)] pub use http::status::StatusCode; #[doc(hidden)] pub use http::status::StatusCode;
#[doc(hidden)] pub use http::uri::Uri; #[doc(hidden)] pub use http::uri::Uri;

View File

@ -3,8 +3,11 @@ use std::pin::Pin;
use futures::io::{AsyncRead, AsyncReadExt as _}; use futures::io::{AsyncRead, AsyncReadExt as _};
use futures::future::{Future}; use futures::future::{Future};
use futures::stream::Stream;
use futures::task::{Poll, Context}; use futures::task::{Poll, Context};
use crate::http::hyper::Chunk;
// Based on std::io::Take, but for AsyncRead instead of Read // Based on std::io::Take, but for AsyncRead instead of Read
pub struct Take<R>{ pub struct Take<R>{
inner: R, inner: R,
@ -30,11 +33,45 @@ impl<R> AsyncRead for Take<R> where R: AsyncRead + Unpin {
} }
} }
pub struct IntoChunkStream<R> {
inner: R,
buf_size: usize,
buffer: Vec<u8>,
}
// TODO.async: Verify correctness of this implementation.
impl<R> Stream for IntoChunkStream<R>
where R: AsyncRead + Unpin
{
type Item = Result<Chunk, io::Error>;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>>{
assert!(self.buffer.len() == self.buf_size);
let Self { ref mut inner, ref mut buffer, buf_size } = *self;
match Pin::new(inner).poll_read(cx, &mut buffer[..]) {
Poll::Pending => Poll::Pending,
Poll::Ready(Err(e)) => Poll::Ready(Some(Err(e))),
Poll::Ready(Ok(n)) if n == 0 => Poll::Ready(None),
Poll::Ready(Ok(n)) => {
let mut next = std::mem::replace(buffer, vec![0; buf_size]);
next.truncate(n);
Poll::Ready(Some(Ok(Chunk::from(next))))
}
}
}
}
pub trait AsyncReadExt: AsyncRead { pub trait AsyncReadExt: AsyncRead {
fn take(self, limit: u64) -> Take<Self> where Self: Sized { fn take(self, limit: u64) -> Take<Self> where Self: Sized {
Take { inner: self, limit } Take { inner: self, limit }
} }
fn into_chunk_stream(self, buf_size: usize) -> IntoChunkStream<Self> where Self: Sized {
IntoChunkStream { inner: self, buf_size, buffer: vec![0; buf_size] }
}
// TODO.async: Verify correctness of this implementation. // TODO.async: Verify correctness of this implementation.
fn read_max<'a>(&'a mut self, mut buf: &'a mut [u8]) -> Pin<Box<dyn Future<Output=io::Result<usize>> + Send + '_>> fn read_max<'a>(&'a mut self, mut buf: &'a mut [u8]) -> Pin<Box<dyn Future<Output=io::Result<usize>> + Send + '_>>
where Self: Send + Unpin where Self: Send + Unpin

View File

@ -1,5 +1,5 @@
use std::collections::HashMap; use std::collections::HashMap;
use std::convert::From; use std::convert::{From, TryInto};
use std::cmp::min; use std::cmp::min;
use std::io; use std::io;
use std::mem; use std::mem;
@ -8,9 +8,11 @@ use std::sync::Arc;
use std::time::Duration; use std::time::Duration;
use std::pin::Pin; use std::pin::Pin;
use futures::compat::Compat; use futures::compat::{Compat, Executor01CompatExt, Sink01CompatExt};
use futures::future::{Future, FutureExt, TryFutureExt}; use futures::future::{Future, FutureExt, TryFutureExt};
use futures::io::AsyncReadExt; use futures::sink::SinkExt;
use futures::stream::StreamExt;
use futures::task::SpawnExt;
use yansi::Paint; use yansi::Paint;
use state::Container; use state::Container;
@ -27,6 +29,7 @@ use crate::catcher::{self, Catcher};
use crate::outcome::Outcome; use crate::outcome::Outcome;
use crate::error::{LaunchError, LaunchErrorKind}; use crate::error::{LaunchError, LaunchErrorKind};
use crate::fairing::{Fairing, Fairings}; use crate::fairing::{Fairing, Fairings};
use crate::ext::AsyncReadExt;
use crate::http::{Method, Status, Header}; use crate::http::{Method, Status, Header};
use crate::http::hyper::{self, header}; use crate::http::hyper::{self, header};
@ -43,9 +46,9 @@ pub struct Rocket {
fairings: Fairings, fairings: Fairings,
} }
#[derive(Clone)]
struct RocketHyperService { struct RocketHyperService {
rocket: Arc<Rocket>, rocket: Arc<Rocket>,
spawn: Box<dyn futures::task::Spawn + Send>,
remote_addr: std::net::SocketAddr, remote_addr: std::net::SocketAddr,
} }
@ -76,7 +79,13 @@ impl hyper::Service for RocketHyperService {
let rocket = self.rocket.clone(); let rocket = self.rocket.clone();
let h_addr = self.remote_addr; let h_addr = self.remote_addr;
async move { // This future must return a hyper::Response, but that's not easy
// because the response body might borrow from the request. Instead,
// we do the body writing in another future that will send us
// the response metadata (and a body channel) beforehand.
let (tx, rx) = futures::channel::oneshot::channel();
self.spawn.spawn(async move {
// Get all of the information from Hyper. // Get all of the information from Hyper.
let (h_parts, h_body) = hyp_req.into_parts(); let (h_parts, h_body) = hyp_req.into_parts();
@ -92,7 +101,7 @@ impl hyper::Service for RocketHyperService {
// handler) instead of doing this. // handler) instead of doing this.
let dummy = Request::new(&rocket, Method::Get, Origin::dummy()); let dummy = Request::new(&rocket, Method::Get, Origin::dummy());
let r = rocket.handle_error(Status::BadRequest, &dummy).await; let r = rocket.handle_error(Status::BadRequest, &dummy).await;
return rocket.issue_response(r).await; return rocket.issue_response(r, tx).await;
} }
}; };
@ -101,7 +110,11 @@ impl hyper::Service for RocketHyperService {
// Dispatch the request to get a response, then write that response out. // Dispatch the request to get a response, then write that response out.
let r = rocket.dispatch(&mut req, data).await; let r = rocket.dispatch(&mut req, data).await;
rocket.issue_response(r).await rocket.issue_response(r, tx).await;
}).expect("failed to spawn handler");
async move {
Ok(rx.await.expect("TODO.async: sender was dropped, error instead"))
}.boxed().compat() }.boxed().compat()
} }
} }
@ -109,28 +122,31 @@ impl hyper::Service for RocketHyperService {
impl Rocket { impl Rocket {
// TODO.async: Reconsider io::Result // TODO.async: Reconsider io::Result
#[inline] #[inline]
fn issue_response<'r>(&self, response: Response<'r>) -> impl Future<Output = io::Result<hyper::Response<hyper::Body>>> + 'r { fn issue_response<'r>(
let result = self.write_response(response); &self,
Box::pin(async move { response: Response<'r>,
tx: futures::channel::oneshot::Sender<hyper::Response<hyper::Body>>,
) -> impl Future<Output = ()> + 'r {
let result = self.write_response(response, tx);
async move {
match result.await { match result.await {
Ok(r) => { Ok(()) => {
info_!("{}", Paint::green("Response succeeded.")); info_!("{}", Paint::green("Response succeeded."));
Ok(r)
} }
Err(e) => { Err(e) => {
error_!("Failed to write response: {:?}.", e); error_!("Failed to write response: {:?}.", e);
Err(e)
} }
} }
}) }
} }
#[inline] #[inline]
fn write_response<'r>( fn write_response<'r>(
&self, &self,
mut response: Response<'r>, mut response: Response<'r>,
) -> impl Future<Output = io::Result<hyper::Response<hyper::Body>>> + 'r { tx: futures::channel::oneshot::Sender<hyper::Response<hyper::Body>>,
Box::pin(async move { ) -> impl Future<Output = io::Result<()>> + 'r {
async move {
let mut hyp_res = hyper::Response::builder(); let mut hyp_res = hyper::Response::builder();
hyp_res.status(response.status().code); hyp_res.status(response.status().code);
@ -140,48 +156,58 @@ impl Rocket {
hyp_res.header(name, value); hyp_res.header(name, value);
} }
let body = match response.body() { let send_response = move |mut hyp_res: hyper::ResponseBuilder, body| -> io::Result<()> {
let response = hyp_res.body(body).map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
tx.send(response).expect("channel receiver should not be dropped");
Ok(())
};
match response.body() {
None => { None => {
hyp_res.header(header::CONTENT_LENGTH, "0"); hyp_res.header(header::CONTENT_LENGTH, "0");
hyper::Body::empty() send_response(hyp_res, hyper::Body::empty())?;
} }
Some(Body::Sized(body, size)) => { Some(Body::Sized(body, size)) => {
hyp_res.header(header::CONTENT_LENGTH, size.to_string()); hyp_res.header(header::CONTENT_LENGTH, size.to_string());
let (sender, hyp_body) = hyper::Body::channel();
send_response(hyp_res, hyp_body)?;
// TODO.async: Stream the data instead of buffering. let mut stream = body.into_chunk_stream(4096);
// TODO.async: Possible truncation (u64 -> usize) let mut sink = sender.sink_compat().sink_map_err(|e| {
let mut buffer = Vec::with_capacity(size as usize); io::Error::new(io::ErrorKind::Other, e)
body.read_to_end(&mut buffer).await?; });
hyper::Body::from(buffer)
while let Some(next) = stream.next().await {
sink.send(next?).await?;
}
// TODO.async: This should be better, but it creates an
// incomprehensible error messasge instead
// stream.forward(sink).await;
} }
Some(Body::Chunked(body, _chunk_size)) => { Some(Body::Chunked(body, chunk_size)) => {
// // This _might_ happen on a 32-bit machine! // TODO.async: This is identical to Body::Sized except for the chunk size
// if chunk_size > (usize::max_value() as u64) {
// let msg = "chunk size exceeds limits of usize type";
// return Err(io::Error::new(io::ErrorKind::Other, msg));
// }
//
// // The buffer stores the current chunk being written out.
// let mut buffer = vec![0; chunk_size as usize];
// let mut stream = hyp_res.start()?;
// loop {
// match body.read_max(&mut buffer)? {
// 0 => break,
// n => stream.write_all(&buffer[..n])?,
// }
// }
//
// stream.end()
// TODO.async: Stream the data instead of buffering. let (sender, hyp_body) = hyper::Body::channel();
let mut buffer = Vec::new(); send_response(hyp_res, hyp_body)?;
body.read_to_end(&mut buffer).await?;
hyper::Body::from(buffer) let mut stream = body.into_chunk_stream(chunk_size.try_into().expect("u64 -> usize overflow"));
let mut sink = sender.sink_compat().sink_map_err(|e| {
io::Error::new(io::ErrorKind::Other, e)
});
while let Some(next) = stream.next().await {
sink.send(next?).await?;
}
// TODO.async: This should be better, but it creates an
// incomprehensible error messasge instead
// stream.forward(sink).await;
} }
}; };
Ok(hyp_res.body(body).map_err(|e| io::Error::new(io::ErrorKind::Other, e))?) Ok(())
}) }
} }
} }
@ -767,9 +793,11 @@ impl Rocket {
logger::pop_max_level(); logger::pop_max_level();
let rocket = Arc::new(self); let rocket = Arc::new(self);
let spawn = Box::new(runtime.executor().compat());
let service = hyper::make_service_fn(move |socket: &hyper::AddrStream| { let service = hyper::make_service_fn(move |socket: &hyper::AddrStream| {
futures::future::ok::<_, Box<dyn std::error::Error + Send + Sync>>(RocketHyperService { futures::future::ok::<_, Box<dyn std::error::Error + Send + Sync>>(RocketHyperService {
rocket: rocket.clone(), rocket: rocket.clone(),
spawn: spawn.clone(),
remote_addr: socket.remote_addr(), remote_addr: socket.remote_addr(),
}).compat() }).compat()
}); });