diff --git a/core/http/src/hyper.rs b/core/http/src/hyper.rs index 7f1a4f01..143afcdc 100644 --- a/core/http/src/hyper.rs +++ b/core/http/src/hyper.rs @@ -16,6 +16,7 @@ #[doc(hidden)] pub use http::header::HeaderValue as HeaderValue; #[doc(hidden)] pub use http::method::Method; #[doc(hidden)] pub use http::request::Parts as RequestParts; +#[doc(hidden)] pub use http::response::Builder as ResponseBuilder; #[doc(hidden)] pub use http::status::StatusCode; #[doc(hidden)] pub use http::uri::Uri; diff --git a/core/lib/src/ext.rs b/core/lib/src/ext.rs index 6cb1c16b..f7996e4f 100644 --- a/core/lib/src/ext.rs +++ b/core/lib/src/ext.rs @@ -3,8 +3,11 @@ use std::pin::Pin; use futures::io::{AsyncRead, AsyncReadExt as _}; use futures::future::{Future}; +use futures::stream::Stream; use futures::task::{Poll, Context}; +use crate::http::hyper::Chunk; + // Based on std::io::Take, but for AsyncRead instead of Read pub struct Take{ inner: R, @@ -30,11 +33,45 @@ impl AsyncRead for Take where R: AsyncRead + Unpin { } } +pub struct IntoChunkStream { + inner: R, + buf_size: usize, + buffer: Vec, +} + +// TODO.async: Verify correctness of this implementation. +impl Stream for IntoChunkStream + where R: AsyncRead + Unpin +{ + type Item = Result; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll>{ + assert!(self.buffer.len() == self.buf_size); + + let Self { ref mut inner, ref mut buffer, buf_size } = *self; + + match Pin::new(inner).poll_read(cx, &mut buffer[..]) { + Poll::Pending => Poll::Pending, + Poll::Ready(Err(e)) => Poll::Ready(Some(Err(e))), + Poll::Ready(Ok(n)) if n == 0 => Poll::Ready(None), + Poll::Ready(Ok(n)) => { + let mut next = std::mem::replace(buffer, vec![0; buf_size]); + next.truncate(n); + Poll::Ready(Some(Ok(Chunk::from(next)))) + } + } + } +} + pub trait AsyncReadExt: AsyncRead { fn take(self, limit: u64) -> Take where Self: Sized { Take { inner: self, limit } } + fn into_chunk_stream(self, buf_size: usize) -> IntoChunkStream where Self: Sized { + IntoChunkStream { inner: self, buf_size, buffer: vec![0; buf_size] } + } + // TODO.async: Verify correctness of this implementation. fn read_max<'a>(&'a mut self, mut buf: &'a mut [u8]) -> Pin> + Send + '_>> where Self: Send + Unpin diff --git a/core/lib/src/rocket.rs b/core/lib/src/rocket.rs index 7226eced..cb1e7456 100644 --- a/core/lib/src/rocket.rs +++ b/core/lib/src/rocket.rs @@ -1,5 +1,5 @@ use std::collections::HashMap; -use std::convert::From; +use std::convert::{From, TryInto}; use std::cmp::min; use std::io; use std::mem; @@ -8,9 +8,11 @@ use std::sync::Arc; use std::time::Duration; use std::pin::Pin; -use futures::compat::Compat; +use futures::compat::{Compat, Executor01CompatExt, Sink01CompatExt}; use futures::future::{Future, FutureExt, TryFutureExt}; -use futures::io::AsyncReadExt; +use futures::sink::SinkExt; +use futures::stream::StreamExt; +use futures::task::SpawnExt; use yansi::Paint; use state::Container; @@ -27,6 +29,7 @@ use crate::catcher::{self, Catcher}; use crate::outcome::Outcome; use crate::error::{LaunchError, LaunchErrorKind}; use crate::fairing::{Fairing, Fairings}; +use crate::ext::AsyncReadExt; use crate::http::{Method, Status, Header}; use crate::http::hyper::{self, header}; @@ -43,9 +46,9 @@ pub struct Rocket { fairings: Fairings, } -#[derive(Clone)] struct RocketHyperService { rocket: Arc, + spawn: Box, remote_addr: std::net::SocketAddr, } @@ -76,7 +79,13 @@ impl hyper::Service for RocketHyperService { let rocket = self.rocket.clone(); let h_addr = self.remote_addr; - async move { + // This future must return a hyper::Response, but that's not easy + // because the response body might borrow from the request. Instead, + // we do the body writing in another future that will send us + // the response metadata (and a body channel) beforehand. + let (tx, rx) = futures::channel::oneshot::channel(); + + self.spawn.spawn(async move { // Get all of the information from Hyper. let (h_parts, h_body) = hyp_req.into_parts(); @@ -92,7 +101,7 @@ impl hyper::Service for RocketHyperService { // handler) instead of doing this. let dummy = Request::new(&rocket, Method::Get, Origin::dummy()); let r = rocket.handle_error(Status::BadRequest, &dummy).await; - return rocket.issue_response(r).await; + return rocket.issue_response(r, tx).await; } }; @@ -101,7 +110,11 @@ impl hyper::Service for RocketHyperService { // Dispatch the request to get a response, then write that response out. let r = rocket.dispatch(&mut req, data).await; - rocket.issue_response(r).await + rocket.issue_response(r, tx).await; + }).expect("failed to spawn handler"); + + async move { + Ok(rx.await.expect("TODO.async: sender was dropped, error instead")) }.boxed().compat() } } @@ -109,28 +122,31 @@ impl hyper::Service for RocketHyperService { impl Rocket { // TODO.async: Reconsider io::Result #[inline] - fn issue_response<'r>(&self, response: Response<'r>) -> impl Future>> + 'r { - let result = self.write_response(response); - Box::pin(async move { + fn issue_response<'r>( + &self, + response: Response<'r>, + tx: futures::channel::oneshot::Sender>, + ) -> impl Future + 'r { + let result = self.write_response(response, tx); + async move { match result.await { - Ok(r) => { + Ok(()) => { info_!("{}", Paint::green("Response succeeded.")); - Ok(r) } Err(e) => { error_!("Failed to write response: {:?}.", e); - Err(e) } } - }) + } } #[inline] fn write_response<'r>( &self, mut response: Response<'r>, - ) -> impl Future>> + 'r { - Box::pin(async move { + tx: futures::channel::oneshot::Sender>, + ) -> impl Future> + 'r { + async move { let mut hyp_res = hyper::Response::builder(); hyp_res.status(response.status().code); @@ -140,48 +156,58 @@ impl Rocket { hyp_res.header(name, value); } - let body = match response.body() { + let send_response = move |mut hyp_res: hyper::ResponseBuilder, body| -> io::Result<()> { + let response = hyp_res.body(body).map_err(|e| io::Error::new(io::ErrorKind::Other, e))?; + tx.send(response).expect("channel receiver should not be dropped"); + Ok(()) + }; + + match response.body() { None => { hyp_res.header(header::CONTENT_LENGTH, "0"); - hyper::Body::empty() + send_response(hyp_res, hyper::Body::empty())?; } Some(Body::Sized(body, size)) => { hyp_res.header(header::CONTENT_LENGTH, size.to_string()); + let (sender, hyp_body) = hyper::Body::channel(); + send_response(hyp_res, hyp_body)?; - // TODO.async: Stream the data instead of buffering. - // TODO.async: Possible truncation (u64 -> usize) - let mut buffer = Vec::with_capacity(size as usize); - body.read_to_end(&mut buffer).await?; - hyper::Body::from(buffer) + let mut stream = body.into_chunk_stream(4096); + let mut sink = sender.sink_compat().sink_map_err(|e| { + io::Error::new(io::ErrorKind::Other, e) + }); + + while let Some(next) = stream.next().await { + sink.send(next?).await?; + } + + // TODO.async: This should be better, but it creates an + // incomprehensible error messasge instead + // stream.forward(sink).await; } - Some(Body::Chunked(body, _chunk_size)) => { - // // This _might_ happen on a 32-bit machine! - // if chunk_size > (usize::max_value() as u64) { - // let msg = "chunk size exceeds limits of usize type"; - // return Err(io::Error::new(io::ErrorKind::Other, msg)); - // } - // - // // The buffer stores the current chunk being written out. - // let mut buffer = vec![0; chunk_size as usize]; - // let mut stream = hyp_res.start()?; - // loop { - // match body.read_max(&mut buffer)? { - // 0 => break, - // n => stream.write_all(&buffer[..n])?, - // } - // } - // - // stream.end() + Some(Body::Chunked(body, chunk_size)) => { + // TODO.async: This is identical to Body::Sized except for the chunk size - // TODO.async: Stream the data instead of buffering. - let mut buffer = Vec::new(); - body.read_to_end(&mut buffer).await?; - hyper::Body::from(buffer) + let (sender, hyp_body) = hyper::Body::channel(); + send_response(hyp_res, hyp_body)?; + + let mut stream = body.into_chunk_stream(chunk_size.try_into().expect("u64 -> usize overflow")); + let mut sink = sender.sink_compat().sink_map_err(|e| { + io::Error::new(io::ErrorKind::Other, e) + }); + + while let Some(next) = stream.next().await { + sink.send(next?).await?; + } + + // TODO.async: This should be better, but it creates an + // incomprehensible error messasge instead + // stream.forward(sink).await; } }; - Ok(hyp_res.body(body).map_err(|e| io::Error::new(io::ErrorKind::Other, e))?) - }) + Ok(()) + } } } @@ -767,9 +793,11 @@ impl Rocket { logger::pop_max_level(); let rocket = Arc::new(self); + let spawn = Box::new(runtime.executor().compat()); let service = hyper::make_service_fn(move |socket: &hyper::AddrStream| { futures::future::ok::<_, Box>(RocketHyperService { rocket: rocket.clone(), + spawn: spawn.clone(), remote_addr: socket.remote_addr(), }).compat() });