Use AddrIncoming/AddrStream.

This lets us keep support for keep-alive and remote address while doing
other work on async, at the cost of TLS. Abstracting over the connection
type will be done more thoroughly later.
This commit is contained in:
Jeb Rosen 2019-07-10 23:38:58 -07:00
parent e8322dbfb4
commit 0fe3f39304
7 changed files with 29 additions and 102 deletions

View File

@ -22,7 +22,7 @@ private-cookies = ["cookie/secure"]
[dependencies] [dependencies]
smallvec = "0.6" smallvec = "0.6"
percent-encoding = "1" percent-encoding = "1"
hyper = { version = "0.12.31", default-features = false, features = ["tokio"] } hyper = { version = "0.12.31", default-features = false, features = ["runtime"] }
http = "0.1.17" http = "0.1.17"
mime = "0.3.13" mime = "0.3.13"
time = "0.1" time = "0.1"

View File

@ -7,7 +7,8 @@
#[doc(hidden)] pub use hyper::{Body, Request, Response, Server}; #[doc(hidden)] pub use hyper::{Body, Request, Response, Server};
#[doc(hidden)] pub use hyper::body::Payload as Payload; #[doc(hidden)] pub use hyper::body::Payload as Payload;
#[doc(hidden)] pub use hyper::error::Error; #[doc(hidden)] pub use hyper::error::Error;
#[doc(hidden)] pub use hyper::service::{MakeService, Service}; #[doc(hidden)] pub use hyper::service::{make_service_fn, MakeService, Service};
#[doc(hidden)] pub use hyper::server::conn::{AddrIncoming, AddrStream};
#[doc(hidden)] pub use hyper::Chunk; #[doc(hidden)] pub use hyper::Chunk;
#[doc(hidden)] pub use http::header::HeaderMap; #[doc(hidden)] pub use http::header::HeaderMap;

View File

@ -3,3 +3,6 @@ pub use tokio_rustls::rustls;
pub use rustls::internal::pemfile; pub use rustls::internal::pemfile;
pub use rustls::{Certificate, NoClientAuth, PrivateKey, ServerConfig}; pub use rustls::{Certificate, NoClientAuth, PrivateKey, ServerConfig};
// TODO.async: extract from hyper-sync-rustls some convenience
// functions to load certs and keys

View File

@ -6,7 +6,7 @@ use futures::io::{self, AsyncRead, AsyncReadExt as _, AsyncWrite};
use futures::future::Future; use futures::future::Future;
use futures::stream::TryStreamExt; use futures::stream::TryStreamExt;
use super::data_stream::{DataStream, kill_stream}; use super::data_stream::DataStream;
use crate::http::hyper; use crate::http::hyper;
@ -234,9 +234,3 @@ impl std::borrow::Borrow<()> for Data {
&() &()
} }
} }
impl Drop for Data {
fn drop(&mut self) {
kill_stream(&mut self.stream);
}
}

View File

@ -34,28 +34,3 @@ impl AsyncRead for DataStream {
} }
} }
} }
// TODO.async: Either implement this somehow, or remove the
// `Drop` impl and other references to kill_stream
pub fn kill_stream(_stream: &mut dyn AsyncRead) {
// // Only do the expensive reading if we're not sure we're done.
//
// // Take <= 1k from the stream. If there might be more data, force close.
// const FLUSH_LEN: u64 = 1024;
// match io::copy(&mut stream.take(FLUSH_LEN), &mut io::sink()) {
// Ok(FLUSH_LEN) | Err(_) => {
// warn_!("Data left unread. Force closing network stream.");
// let (_, network) = stream.get_mut().get_mut();
// if let Err(e) = network.close(Shutdown::Read) {
// error_!("Failed to close network stream: {:?}", e);
// }
// }
// Ok(n) => debug!("flushed {} unread bytes", n)
// }
}
impl Drop for DataStream {
fn drop(&mut self) {
kill_stream(&mut self.1);
}
}

View File

@ -19,7 +19,7 @@ use crate::router::Route;
#[derive(Debug)] #[derive(Debug)]
pub enum LaunchErrorKind { pub enum LaunchErrorKind {
/// Binding to the provided address/port failed. /// Binding to the provided address/port failed.
Bind(io::Error), Bind(hyper::Error),
/// An I/O error occurred during launch. /// An I/O error occurred during launch.
Io(io::Error), Io(io::Error),
/// Route collisions were detected. /// Route collisions were detected.

View File

@ -5,6 +5,7 @@ use std::io;
use std::mem; use std::mem;
use std::net::ToSocketAddrs; use std::net::ToSocketAddrs;
use std::sync::Arc; use std::sync::Arc;
use std::time::Duration;
use std::pin::Pin; use std::pin::Pin;
use futures::compat::Compat; use futures::compat::Compat;
@ -13,8 +14,6 @@ use futures::io::AsyncReadExt;
use yansi::Paint; use yansi::Paint;
use state::Container; use state::Container;
use tokio::net::TcpListener;
use tokio::prelude::Stream as _;
#[cfg(feature = "tls")] use crate::http::tls::TlsAcceptor; #[cfg(feature = "tls")] use crate::http::tls::TlsAcceptor;
@ -47,6 +46,7 @@ pub struct Rocket {
#[derive(Clone)] #[derive(Clone)]
struct RocketHyperService { struct RocketHyperService {
rocket: Arc<Rocket>, rocket: Arc<Rocket>,
remote_addr: std::net::SocketAddr,
} }
impl std::ops::Deref for RocketHyperService { impl std::ops::Deref for RocketHyperService {
@ -57,19 +57,6 @@ impl std::ops::Deref for RocketHyperService {
} }
} }
impl<Ctx> hyper::MakeService<Ctx> for RocketHyperService {
type ReqBody = hyper::Body;
type ResBody = hyper::Body;
type Error = io::Error;
type Service = RocketHyperService;
type Future = Compat<futures::future::Ready<Result<Self::Service, Self::MakeError>>>;
type MakeError = Self::Error;
fn make_service(&mut self, _: Ctx) -> Self::Future {
futures::future::ok(RocketHyperService { rocket: self.rocket.clone() }).compat()
}
}
#[doc(hidden)] #[doc(hidden)]
impl hyper::Service for RocketHyperService { impl hyper::Service for RocketHyperService {
type ReqBody = hyper::Body; type ReqBody = hyper::Body;
@ -87,13 +74,12 @@ impl hyper::Service for RocketHyperService {
hyp_req: hyper::Request<Self::ReqBody>, hyp_req: hyper::Request<Self::ReqBody>,
) -> Self::Future { ) -> Self::Future {
let rocket = self.rocket.clone(); let rocket = self.rocket.clone();
let h_addr = self.remote_addr;
async move { async move {
// Get all of the information from Hyper. // Get all of the information from Hyper.
let (h_parts, h_body) = hyp_req.into_parts(); let (h_parts, h_body) = hyp_req.into_parts();
// TODO.async: Get the client address somehow.
let h_addr = "0.0.0.0:0".parse().expect("socket addr");
// Convert the Hyper request into a Rocket request. // Convert the Hyper request into a Rocket request.
let req_res = Request::from_hyp(&rocket, h_parts.method, h_parts.headers, h_parts.uri, h_addr); let req_res = Request::from_hyp(&rocket, h_parts.method, h_parts.headers, h_parts.uri, h_addr);
let mut req = match req_res { let mut req = match req_res {
@ -748,60 +734,22 @@ impl Rocket {
Err(e) => return From::from(io::Error::new(io::ErrorKind::Other, e)), Err(e) => return From::from(io::Error::new(io::ErrorKind::Other, e)),
}; };
let listener = match TcpListener::bind(&addrs[0]) { // TODO.async: support for TLS, unix sockets.
Ok(listener) => listener, // Likely will be implemented with a custom "Incoming" type.
let mut incoming = match hyper::AddrIncoming::bind(&addrs[0]) {
Ok(incoming) => incoming,
Err(e) => return LaunchError::new(LaunchErrorKind::Bind(e)), Err(e) => return LaunchError::new(LaunchErrorKind::Bind(e)),
}; };
// Determine the address and port we actually binded to. // Determine the address and port we actually binded to.
match listener.local_addr() { self.config.port = incoming.local_addr().port();
Ok(server_addr) => self.config.port = server_addr.port(),
Err(e) => return LaunchError::from(e),
}
// TODO.async Move all of this to http crate somewhere let proto = "http://";
// TODO.async Is boxing everything really the best we can do here?
trait AsyncReadWrite: tokio::io::AsyncRead + tokio::io::AsyncWrite + Send { }
impl<T> AsyncReadWrite for T where T: tokio::io::AsyncRead + tokio::io::AsyncWrite + Send { }
let proto; // Set the keep-alive.
let incoming: Box<dyn tokio::prelude::Stream<Item=Box<dyn AsyncReadWrite>, Error=std::io::Error> + Send>; let timeout = self.config.keep_alive.map(|s| Duration::from_secs(s as u64));
incoming.set_keepalive(timeout);
#[cfg(feature = "tls")]
{
use tokio::prelude::Future;
// TODO.async: Can/should we make the clone unnecessary (by reference, or by moving out?)
if let Some(tls) = self.config.tls.clone() {
proto = "https://";
let mut config = tls::rustls::ServerConfig::new(tls::rustls::NoClientAuth::new());
config.set_single_cert(tls.certs, tls.key).expect("invalid key or certificate");
// TODO.async: I once observed an unhandled AlertReceived(UnknownCA) but
// have no idea what happened and cannot reproduce.
let config = TlsAcceptor::from(Arc::new(config));
incoming = Box::new(listener.incoming().and_then(move |stream| {
config.accept(stream)
.map(|stream| Box::new(stream) as Box<dyn AsyncReadWrite>)
}));
} else {
proto = "http://";
incoming = Box::new(listener.incoming().map(|stream| Box::new(stream) as Box<dyn AsyncReadWrite>));
}
}
// TODO.async: Duplicated code
#[cfg(not(feature = "tls"))]
{
proto = "http://";
incoming = Box::new(listener.incoming().map(|stream| Box::new(stream) as Box<dyn AsyncReadWrite>));
}
// TODO.async: Set the keep-alive.
// // Set the keep-alive.
// let timeout = self.config.keep_alive.map(|s| Duration::from_secs(s as u64));
// server.keep_alive(timeout);
// Freeze managed state for synchronization-free accesses later. // Freeze managed state for synchronization-free accesses later.
self.state.freeze(); self.state.freeze();
@ -818,7 +766,13 @@ impl Rocket {
// Restore the log level back to what it originally was. // Restore the log level back to what it originally was.
logger::pop_max_level(); logger::pop_max_level();
let service = RocketHyperService { rocket: Arc::new(self) }; let rocket = Arc::new(self);
let service = hyper::make_service_fn(move |socket: &hyper::AddrStream| {
futures::future::ok::<_, Box<dyn std::error::Error + Send + Sync>>(RocketHyperService {
rocket: rocket.clone(),
remote_addr: socket.remote_addr(),
}).compat()
});
// NB: executor must be passed manually here, see hyperium/hyper#1537 // NB: executor must be passed manually here, see hyperium/hyper#1537
let server = hyper::Server::builder(incoming) let server = hyper::Server::builder(incoming)