mirror of https://github.com/rwf2/Rocket.git
Use AddrIncoming/AddrStream.
This lets us keep support for keep-alive and remote address while doing other work on async, at the cost of TLS. Abstracting over the connection type will be done more thoroughly later.
This commit is contained in:
parent
e8322dbfb4
commit
0fe3f39304
|
@ -22,7 +22,7 @@ private-cookies = ["cookie/secure"]
|
|||
[dependencies]
|
||||
smallvec = "0.6"
|
||||
percent-encoding = "1"
|
||||
hyper = { version = "0.12.31", default-features = false, features = ["tokio"] }
|
||||
hyper = { version = "0.12.31", default-features = false, features = ["runtime"] }
|
||||
http = "0.1.17"
|
||||
mime = "0.3.13"
|
||||
time = "0.1"
|
||||
|
|
|
@ -7,7 +7,8 @@
|
|||
#[doc(hidden)] pub use hyper::{Body, Request, Response, Server};
|
||||
#[doc(hidden)] pub use hyper::body::Payload as Payload;
|
||||
#[doc(hidden)] pub use hyper::error::Error;
|
||||
#[doc(hidden)] pub use hyper::service::{MakeService, Service};
|
||||
#[doc(hidden)] pub use hyper::service::{make_service_fn, MakeService, Service};
|
||||
#[doc(hidden)] pub use hyper::server::conn::{AddrIncoming, AddrStream};
|
||||
|
||||
#[doc(hidden)] pub use hyper::Chunk;
|
||||
#[doc(hidden)] pub use http::header::HeaderMap;
|
||||
|
|
|
@ -3,3 +3,6 @@ pub use tokio_rustls::rustls;
|
|||
|
||||
pub use rustls::internal::pemfile;
|
||||
pub use rustls::{Certificate, NoClientAuth, PrivateKey, ServerConfig};
|
||||
|
||||
// TODO.async: extract from hyper-sync-rustls some convenience
|
||||
// functions to load certs and keys
|
||||
|
|
|
@ -6,7 +6,7 @@ use futures::io::{self, AsyncRead, AsyncReadExt as _, AsyncWrite};
|
|||
use futures::future::Future;
|
||||
use futures::stream::TryStreamExt;
|
||||
|
||||
use super::data_stream::{DataStream, kill_stream};
|
||||
use super::data_stream::DataStream;
|
||||
|
||||
use crate::http::hyper;
|
||||
|
||||
|
@ -234,9 +234,3 @@ impl std::borrow::Borrow<()> for Data {
|
|||
&()
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for Data {
|
||||
fn drop(&mut self) {
|
||||
kill_stream(&mut self.stream);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -34,28 +34,3 @@ impl AsyncRead for DataStream {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TODO.async: Either implement this somehow, or remove the
|
||||
// `Drop` impl and other references to kill_stream
|
||||
pub fn kill_stream(_stream: &mut dyn AsyncRead) {
|
||||
// // Only do the expensive reading if we're not sure we're done.
|
||||
//
|
||||
// // Take <= 1k from the stream. If there might be more data, force close.
|
||||
// const FLUSH_LEN: u64 = 1024;
|
||||
// match io::copy(&mut stream.take(FLUSH_LEN), &mut io::sink()) {
|
||||
// Ok(FLUSH_LEN) | Err(_) => {
|
||||
// warn_!("Data left unread. Force closing network stream.");
|
||||
// let (_, network) = stream.get_mut().get_mut();
|
||||
// if let Err(e) = network.close(Shutdown::Read) {
|
||||
// error_!("Failed to close network stream: {:?}", e);
|
||||
// }
|
||||
// }
|
||||
// Ok(n) => debug!("flushed {} unread bytes", n)
|
||||
// }
|
||||
}
|
||||
|
||||
impl Drop for DataStream {
|
||||
fn drop(&mut self) {
|
||||
kill_stream(&mut self.1);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -19,7 +19,7 @@ use crate::router::Route;
|
|||
#[derive(Debug)]
|
||||
pub enum LaunchErrorKind {
|
||||
/// Binding to the provided address/port failed.
|
||||
Bind(io::Error),
|
||||
Bind(hyper::Error),
|
||||
/// An I/O error occurred during launch.
|
||||
Io(io::Error),
|
||||
/// Route collisions were detected.
|
||||
|
|
|
@ -5,6 +5,7 @@ use std::io;
|
|||
use std::mem;
|
||||
use std::net::ToSocketAddrs;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use std::pin::Pin;
|
||||
|
||||
use futures::compat::Compat;
|
||||
|
@ -13,8 +14,6 @@ use futures::io::AsyncReadExt;
|
|||
|
||||
use yansi::Paint;
|
||||
use state::Container;
|
||||
use tokio::net::TcpListener;
|
||||
use tokio::prelude::Stream as _;
|
||||
|
||||
#[cfg(feature = "tls")] use crate::http::tls::TlsAcceptor;
|
||||
|
||||
|
@ -47,6 +46,7 @@ pub struct Rocket {
|
|||
#[derive(Clone)]
|
||||
struct RocketHyperService {
|
||||
rocket: Arc<Rocket>,
|
||||
remote_addr: std::net::SocketAddr,
|
||||
}
|
||||
|
||||
impl std::ops::Deref for RocketHyperService {
|
||||
|
@ -57,19 +57,6 @@ impl std::ops::Deref for RocketHyperService {
|
|||
}
|
||||
}
|
||||
|
||||
impl<Ctx> hyper::MakeService<Ctx> for RocketHyperService {
|
||||
type ReqBody = hyper::Body;
|
||||
type ResBody = hyper::Body;
|
||||
type Error = io::Error;
|
||||
type Service = RocketHyperService;
|
||||
type Future = Compat<futures::future::Ready<Result<Self::Service, Self::MakeError>>>;
|
||||
type MakeError = Self::Error;
|
||||
|
||||
fn make_service(&mut self, _: Ctx) -> Self::Future {
|
||||
futures::future::ok(RocketHyperService { rocket: self.rocket.clone() }).compat()
|
||||
}
|
||||
}
|
||||
|
||||
#[doc(hidden)]
|
||||
impl hyper::Service for RocketHyperService {
|
||||
type ReqBody = hyper::Body;
|
||||
|
@ -87,13 +74,12 @@ impl hyper::Service for RocketHyperService {
|
|||
hyp_req: hyper::Request<Self::ReqBody>,
|
||||
) -> Self::Future {
|
||||
let rocket = self.rocket.clone();
|
||||
let h_addr = self.remote_addr;
|
||||
|
||||
async move {
|
||||
// Get all of the information from Hyper.
|
||||
let (h_parts, h_body) = hyp_req.into_parts();
|
||||
|
||||
// TODO.async: Get the client address somehow.
|
||||
let h_addr = "0.0.0.0:0".parse().expect("socket addr");
|
||||
|
||||
// Convert the Hyper request into a Rocket request.
|
||||
let req_res = Request::from_hyp(&rocket, h_parts.method, h_parts.headers, h_parts.uri, h_addr);
|
||||
let mut req = match req_res {
|
||||
|
@ -748,60 +734,22 @@ impl Rocket {
|
|||
Err(e) => return From::from(io::Error::new(io::ErrorKind::Other, e)),
|
||||
};
|
||||
|
||||
let listener = match TcpListener::bind(&addrs[0]) {
|
||||
Ok(listener) => listener,
|
||||
// TODO.async: support for TLS, unix sockets.
|
||||
// Likely will be implemented with a custom "Incoming" type.
|
||||
|
||||
let mut incoming = match hyper::AddrIncoming::bind(&addrs[0]) {
|
||||
Ok(incoming) => incoming,
|
||||
Err(e) => return LaunchError::new(LaunchErrorKind::Bind(e)),
|
||||
};
|
||||
|
||||
// Determine the address and port we actually binded to.
|
||||
match listener.local_addr() {
|
||||
Ok(server_addr) => self.config.port = server_addr.port(),
|
||||
Err(e) => return LaunchError::from(e),
|
||||
}
|
||||
self.config.port = incoming.local_addr().port();
|
||||
|
||||
// TODO.async Move all of this to http crate somewhere
|
||||
// TODO.async Is boxing everything really the best we can do here?
|
||||
trait AsyncReadWrite: tokio::io::AsyncRead + tokio::io::AsyncWrite + Send { }
|
||||
impl<T> AsyncReadWrite for T where T: tokio::io::AsyncRead + tokio::io::AsyncWrite + Send { }
|
||||
let proto = "http://";
|
||||
|
||||
let proto;
|
||||
let incoming: Box<dyn tokio::prelude::Stream<Item=Box<dyn AsyncReadWrite>, Error=std::io::Error> + Send>;
|
||||
|
||||
#[cfg(feature = "tls")]
|
||||
{
|
||||
use tokio::prelude::Future;
|
||||
|
||||
// TODO.async: Can/should we make the clone unnecessary (by reference, or by moving out?)
|
||||
if let Some(tls) = self.config.tls.clone() {
|
||||
proto = "https://";
|
||||
let mut config = tls::rustls::ServerConfig::new(tls::rustls::NoClientAuth::new());
|
||||
config.set_single_cert(tls.certs, tls.key).expect("invalid key or certificate");
|
||||
|
||||
// TODO.async: I once observed an unhandled AlertReceived(UnknownCA) but
|
||||
// have no idea what happened and cannot reproduce.
|
||||
let config = TlsAcceptor::from(Arc::new(config));
|
||||
|
||||
incoming = Box::new(listener.incoming().and_then(move |stream| {
|
||||
config.accept(stream)
|
||||
.map(|stream| Box::new(stream) as Box<dyn AsyncReadWrite>)
|
||||
}));
|
||||
} else {
|
||||
proto = "http://";
|
||||
incoming = Box::new(listener.incoming().map(|stream| Box::new(stream) as Box<dyn AsyncReadWrite>));
|
||||
}
|
||||
}
|
||||
|
||||
// TODO.async: Duplicated code
|
||||
#[cfg(not(feature = "tls"))]
|
||||
{
|
||||
proto = "http://";
|
||||
incoming = Box::new(listener.incoming().map(|stream| Box::new(stream) as Box<dyn AsyncReadWrite>));
|
||||
}
|
||||
|
||||
// TODO.async: Set the keep-alive.
|
||||
// // Set the keep-alive.
|
||||
// let timeout = self.config.keep_alive.map(|s| Duration::from_secs(s as u64));
|
||||
// server.keep_alive(timeout);
|
||||
// Set the keep-alive.
|
||||
let timeout = self.config.keep_alive.map(|s| Duration::from_secs(s as u64));
|
||||
incoming.set_keepalive(timeout);
|
||||
|
||||
// Freeze managed state for synchronization-free accesses later.
|
||||
self.state.freeze();
|
||||
|
@ -818,7 +766,13 @@ impl Rocket {
|
|||
// Restore the log level back to what it originally was.
|
||||
logger::pop_max_level();
|
||||
|
||||
let service = RocketHyperService { rocket: Arc::new(self) };
|
||||
let rocket = Arc::new(self);
|
||||
let service = hyper::make_service_fn(move |socket: &hyper::AddrStream| {
|
||||
futures::future::ok::<_, Box<dyn std::error::Error + Send + Sync>>(RocketHyperService {
|
||||
rocket: rocket.clone(),
|
||||
remote_addr: socket.remote_addr(),
|
||||
}).compat()
|
||||
});
|
||||
|
||||
// NB: executor must be passed manually here, see hyperium/hyper#1537
|
||||
let server = hyper::Server::builder(incoming)
|
||||
|
|
Loading…
Reference in New Issue