mirror of https://github.com/rwf2/Rocket.git
Stream body data instead of buffering it.
This requires some awkward channel and spawning work because Body might contain borrowed data.
This commit is contained in:
parent
0fe3f39304
commit
f83caf2d08
|
@ -16,6 +16,7 @@
|
|||
#[doc(hidden)] pub use http::header::HeaderValue as HeaderValue;
|
||||
#[doc(hidden)] pub use http::method::Method;
|
||||
#[doc(hidden)] pub use http::request::Parts as RequestParts;
|
||||
#[doc(hidden)] pub use http::response::Builder as ResponseBuilder;
|
||||
#[doc(hidden)] pub use http::status::StatusCode;
|
||||
#[doc(hidden)] pub use http::uri::Uri;
|
||||
|
||||
|
|
|
@ -3,8 +3,11 @@ use std::pin::Pin;
|
|||
|
||||
use futures::io::{AsyncRead, AsyncReadExt as _};
|
||||
use futures::future::{Future};
|
||||
use futures::stream::Stream;
|
||||
use futures::task::{Poll, Context};
|
||||
|
||||
use crate::http::hyper::Chunk;
|
||||
|
||||
// Based on std::io::Take, but for AsyncRead instead of Read
|
||||
pub struct Take<R>{
|
||||
inner: R,
|
||||
|
@ -30,11 +33,45 @@ impl<R> AsyncRead for Take<R> where R: AsyncRead + Unpin {
|
|||
}
|
||||
}
|
||||
|
||||
pub struct IntoChunkStream<R> {
|
||||
inner: R,
|
||||
buf_size: usize,
|
||||
buffer: Vec<u8>,
|
||||
}
|
||||
|
||||
// TODO.async: Verify correctness of this implementation.
|
||||
impl<R> Stream for IntoChunkStream<R>
|
||||
where R: AsyncRead + Unpin
|
||||
{
|
||||
type Item = Result<Chunk, io::Error>;
|
||||
|
||||
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>>{
|
||||
assert!(self.buffer.len() == self.buf_size);
|
||||
|
||||
let Self { ref mut inner, ref mut buffer, buf_size } = *self;
|
||||
|
||||
match Pin::new(inner).poll_read(cx, &mut buffer[..]) {
|
||||
Poll::Pending => Poll::Pending,
|
||||
Poll::Ready(Err(e)) => Poll::Ready(Some(Err(e))),
|
||||
Poll::Ready(Ok(n)) if n == 0 => Poll::Ready(None),
|
||||
Poll::Ready(Ok(n)) => {
|
||||
let mut next = std::mem::replace(buffer, vec![0; buf_size]);
|
||||
next.truncate(n);
|
||||
Poll::Ready(Some(Ok(Chunk::from(next))))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub trait AsyncReadExt: AsyncRead {
|
||||
fn take(self, limit: u64) -> Take<Self> where Self: Sized {
|
||||
Take { inner: self, limit }
|
||||
}
|
||||
|
||||
fn into_chunk_stream(self, buf_size: usize) -> IntoChunkStream<Self> where Self: Sized {
|
||||
IntoChunkStream { inner: self, buf_size, buffer: vec![0; buf_size] }
|
||||
}
|
||||
|
||||
// TODO.async: Verify correctness of this implementation.
|
||||
fn read_max<'a>(&'a mut self, mut buf: &'a mut [u8]) -> Pin<Box<dyn Future<Output=io::Result<usize>> + Send + '_>>
|
||||
where Self: Send + Unpin
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
use std::collections::HashMap;
|
||||
use std::convert::From;
|
||||
use std::convert::{From, TryInto};
|
||||
use std::cmp::min;
|
||||
use std::io;
|
||||
use std::mem;
|
||||
|
@ -8,9 +8,11 @@ use std::sync::Arc;
|
|||
use std::time::Duration;
|
||||
use std::pin::Pin;
|
||||
|
||||
use futures::compat::Compat;
|
||||
use futures::compat::{Compat, Executor01CompatExt, Sink01CompatExt};
|
||||
use futures::future::{Future, FutureExt, TryFutureExt};
|
||||
use futures::io::AsyncReadExt;
|
||||
use futures::sink::SinkExt;
|
||||
use futures::stream::StreamExt;
|
||||
use futures::task::SpawnExt;
|
||||
|
||||
use yansi::Paint;
|
||||
use state::Container;
|
||||
|
@ -27,6 +29,7 @@ use crate::catcher::{self, Catcher};
|
|||
use crate::outcome::Outcome;
|
||||
use crate::error::{LaunchError, LaunchErrorKind};
|
||||
use crate::fairing::{Fairing, Fairings};
|
||||
use crate::ext::AsyncReadExt;
|
||||
|
||||
use crate::http::{Method, Status, Header};
|
||||
use crate::http::hyper::{self, header};
|
||||
|
@ -43,9 +46,9 @@ pub struct Rocket {
|
|||
fairings: Fairings,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct RocketHyperService {
|
||||
rocket: Arc<Rocket>,
|
||||
spawn: Box<dyn futures::task::Spawn + Send>,
|
||||
remote_addr: std::net::SocketAddr,
|
||||
}
|
||||
|
||||
|
@ -76,7 +79,13 @@ impl hyper::Service for RocketHyperService {
|
|||
let rocket = self.rocket.clone();
|
||||
let h_addr = self.remote_addr;
|
||||
|
||||
async move {
|
||||
// This future must return a hyper::Response, but that's not easy
|
||||
// because the response body might borrow from the request. Instead,
|
||||
// we do the body writing in another future that will send us
|
||||
// the response metadata (and a body channel) beforehand.
|
||||
let (tx, rx) = futures::channel::oneshot::channel();
|
||||
|
||||
self.spawn.spawn(async move {
|
||||
// Get all of the information from Hyper.
|
||||
let (h_parts, h_body) = hyp_req.into_parts();
|
||||
|
||||
|
@ -92,7 +101,7 @@ impl hyper::Service for RocketHyperService {
|
|||
// handler) instead of doing this.
|
||||
let dummy = Request::new(&rocket, Method::Get, Origin::dummy());
|
||||
let r = rocket.handle_error(Status::BadRequest, &dummy).await;
|
||||
return rocket.issue_response(r).await;
|
||||
return rocket.issue_response(r, tx).await;
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -101,7 +110,11 @@ impl hyper::Service for RocketHyperService {
|
|||
|
||||
// Dispatch the request to get a response, then write that response out.
|
||||
let r = rocket.dispatch(&mut req, data).await;
|
||||
rocket.issue_response(r).await
|
||||
rocket.issue_response(r, tx).await;
|
||||
}).expect("failed to spawn handler");
|
||||
|
||||
async move {
|
||||
Ok(rx.await.expect("TODO.async: sender was dropped, error instead"))
|
||||
}.boxed().compat()
|
||||
}
|
||||
}
|
||||
|
@ -109,28 +122,31 @@ impl hyper::Service for RocketHyperService {
|
|||
impl Rocket {
|
||||
// TODO.async: Reconsider io::Result
|
||||
#[inline]
|
||||
fn issue_response<'r>(&self, response: Response<'r>) -> impl Future<Output = io::Result<hyper::Response<hyper::Body>>> + 'r {
|
||||
let result = self.write_response(response);
|
||||
Box::pin(async move {
|
||||
fn issue_response<'r>(
|
||||
&self,
|
||||
response: Response<'r>,
|
||||
tx: futures::channel::oneshot::Sender<hyper::Response<hyper::Body>>,
|
||||
) -> impl Future<Output = ()> + 'r {
|
||||
let result = self.write_response(response, tx);
|
||||
async move {
|
||||
match result.await {
|
||||
Ok(r) => {
|
||||
Ok(()) => {
|
||||
info_!("{}", Paint::green("Response succeeded."));
|
||||
Ok(r)
|
||||
}
|
||||
Err(e) => {
|
||||
error_!("Failed to write response: {:?}.", e);
|
||||
Err(e)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn write_response<'r>(
|
||||
&self,
|
||||
mut response: Response<'r>,
|
||||
) -> impl Future<Output = io::Result<hyper::Response<hyper::Body>>> + 'r {
|
||||
Box::pin(async move {
|
||||
tx: futures::channel::oneshot::Sender<hyper::Response<hyper::Body>>,
|
||||
) -> impl Future<Output = io::Result<()>> + 'r {
|
||||
async move {
|
||||
let mut hyp_res = hyper::Response::builder();
|
||||
hyp_res.status(response.status().code);
|
||||
|
||||
|
@ -140,48 +156,58 @@ impl Rocket {
|
|||
hyp_res.header(name, value);
|
||||
}
|
||||
|
||||
let body = match response.body() {
|
||||
let send_response = move |mut hyp_res: hyper::ResponseBuilder, body| -> io::Result<()> {
|
||||
let response = hyp_res.body(body).map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
|
||||
tx.send(response).expect("channel receiver should not be dropped");
|
||||
Ok(())
|
||||
};
|
||||
|
||||
match response.body() {
|
||||
None => {
|
||||
hyp_res.header(header::CONTENT_LENGTH, "0");
|
||||
hyper::Body::empty()
|
||||
send_response(hyp_res, hyper::Body::empty())?;
|
||||
}
|
||||
Some(Body::Sized(body, size)) => {
|
||||
hyp_res.header(header::CONTENT_LENGTH, size.to_string());
|
||||
let (sender, hyp_body) = hyper::Body::channel();
|
||||
send_response(hyp_res, hyp_body)?;
|
||||
|
||||
// TODO.async: Stream the data instead of buffering.
|
||||
// TODO.async: Possible truncation (u64 -> usize)
|
||||
let mut buffer = Vec::with_capacity(size as usize);
|
||||
body.read_to_end(&mut buffer).await?;
|
||||
hyper::Body::from(buffer)
|
||||
let mut stream = body.into_chunk_stream(4096);
|
||||
let mut sink = sender.sink_compat().sink_map_err(|e| {
|
||||
io::Error::new(io::ErrorKind::Other, e)
|
||||
});
|
||||
|
||||
while let Some(next) = stream.next().await {
|
||||
sink.send(next?).await?;
|
||||
}
|
||||
|
||||
// TODO.async: This should be better, but it creates an
|
||||
// incomprehensible error messasge instead
|
||||
// stream.forward(sink).await;
|
||||
}
|
||||
Some(Body::Chunked(body, _chunk_size)) => {
|
||||
// // This _might_ happen on a 32-bit machine!
|
||||
// if chunk_size > (usize::max_value() as u64) {
|
||||
// let msg = "chunk size exceeds limits of usize type";
|
||||
// return Err(io::Error::new(io::ErrorKind::Other, msg));
|
||||
// }
|
||||
//
|
||||
// // The buffer stores the current chunk being written out.
|
||||
// let mut buffer = vec![0; chunk_size as usize];
|
||||
// let mut stream = hyp_res.start()?;
|
||||
// loop {
|
||||
// match body.read_max(&mut buffer)? {
|
||||
// 0 => break,
|
||||
// n => stream.write_all(&buffer[..n])?,
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// stream.end()
|
||||
Some(Body::Chunked(body, chunk_size)) => {
|
||||
// TODO.async: This is identical to Body::Sized except for the chunk size
|
||||
|
||||
// TODO.async: Stream the data instead of buffering.
|
||||
let mut buffer = Vec::new();
|
||||
body.read_to_end(&mut buffer).await?;
|
||||
hyper::Body::from(buffer)
|
||||
let (sender, hyp_body) = hyper::Body::channel();
|
||||
send_response(hyp_res, hyp_body)?;
|
||||
|
||||
let mut stream = body.into_chunk_stream(chunk_size.try_into().expect("u64 -> usize overflow"));
|
||||
let mut sink = sender.sink_compat().sink_map_err(|e| {
|
||||
io::Error::new(io::ErrorKind::Other, e)
|
||||
});
|
||||
|
||||
while let Some(next) = stream.next().await {
|
||||
sink.send(next?).await?;
|
||||
}
|
||||
|
||||
// TODO.async: This should be better, but it creates an
|
||||
// incomprehensible error messasge instead
|
||||
// stream.forward(sink).await;
|
||||
}
|
||||
};
|
||||
|
||||
Ok(hyp_res.body(body).map_err(|e| io::Error::new(io::ErrorKind::Other, e))?)
|
||||
})
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -767,9 +793,11 @@ impl Rocket {
|
|||
logger::pop_max_level();
|
||||
|
||||
let rocket = Arc::new(self);
|
||||
let spawn = Box::new(runtime.executor().compat());
|
||||
let service = hyper::make_service_fn(move |socket: &hyper::AddrStream| {
|
||||
futures::future::ok::<_, Box<dyn std::error::Error + Send + Sync>>(RocketHyperService {
|
||||
rocket: rocket.clone(),
|
||||
spawn: spawn.clone(),
|
||||
remote_addr: socket.remote_addr(),
|
||||
}).compat()
|
||||
});
|
||||
|
|
Loading…
Reference in New Issue