Enable configurable 'ctrl-c' shutdown by default.

This removes the 'ctrl_c_shutdown' feature opting instead for a 'ctrlc'
configuration option. To avoid further merge conflicts with the master
branch, the option is currently read as an extra.

Co-authored-by: Jeb Rosen <jeb@jebrosen.com>
This commit is contained in:
Sergio Benitez 2020-07-08 22:05:54 -07:00
parent 9277ddafdf
commit 824de061c3
4 changed files with 61 additions and 76 deletions

View File

@ -24,7 +24,7 @@ pub trait Listener {
fn local_addr(&self) -> Option<SocketAddr>; fn local_addr(&self) -> Option<SocketAddr>;
/// Try to accept an incoming Connection if ready /// Try to accept an incoming Connection if ready
fn poll_accept(&mut self, cx: &mut Context<'_>) -> Poll<Result<Self::Connection, io::Error>>; fn poll_accept(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<Self::Connection>>;
} }
/// A 'Connection' represents an open connection to a client /// A 'Connection' represents an open connection to a client
@ -125,7 +125,10 @@ impl<L: Listener + Unpin> Accept for Incoming<L> {
type Conn = L::Connection; type Conn = L::Connection;
type Error = io::Error; type Error = io::Error;
fn poll_accept(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Result<Self::Conn, Self::Error>>> { fn poll_accept(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>
) -> Poll<Option<io::Result<Self::Conn>>> {
self.poll_next(cx).map(Some) self.poll_next(cx).map(Some)
} }
} }
@ -154,10 +157,8 @@ impl<L: fmt::Debug> fmt::Debug for Incoming<L> {
} }
} }
pub fn bind_tcp(address: SocketAddr) -> Pin<Box<dyn Future<Output=Result<TcpListener, io::Error>> + Send>> { pub async fn bind_tcp(address: SocketAddr) -> io::Result<TcpListener> {
Box::pin(async move { Ok(TcpListener::bind(address).await?)
Ok(TcpListener::bind(address).await?)
})
} }
impl Listener for TcpListener { impl Listener for TcpListener {
@ -167,7 +168,7 @@ impl Listener for TcpListener {
self.local_addr().ok() self.local_addr().ok()
} }
fn poll_accept(&mut self, cx: &mut Context<'_>) -> Poll<Result<Self::Connection, io::Error>> { fn poll_accept(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<Self::Connection>> {
self.poll_accept(cx).map_ok(|(stream, _addr)| stream) self.poll_accept(cx).map_ok(|(stream, _addr)| stream)
} }
} }

View File

@ -113,28 +113,24 @@ impl Listener for TlsListener {
} }
} }
pub fn bind_tls(address: SocketAddr, cert_chain: Vec<Certificate>, key: PrivateKey) pub async fn bind_tls(
-> Pin<Box<dyn Future<Output=Result<TlsListener, io::Error>> + Send>> address: SocketAddr,
{ cert_chain: Vec<Certificate>,
Box::pin(async move { key: PrivateKey
let listener = TcpListener::bind(address).await?; ) -> io::Result<TlsListener> {
let listener = TcpListener::bind(address).await?;
let client_auth = rustls::NoClientAuth::new(); let client_auth = rustls::NoClientAuth::new();
let mut tls_config = ServerConfig::new(client_auth); let mut tls_config = ServerConfig::new(client_auth);
let cache = rustls::ServerSessionMemoryCache::new(1024); let cache = rustls::ServerSessionMemoryCache::new(1024);
tls_config.set_persistence(cache); tls_config.set_persistence(cache);
tls_config.ticketer = rustls::Ticketer::new(); tls_config.ticketer = rustls::Ticketer::new();
tls_config.set_single_cert(cert_chain, key).expect("invalid key"); tls_config.set_single_cert(cert_chain, key).expect("invalid key");
let acceptor = TlsAcceptor::from(Arc::new(tls_config));
let state = TlsListenerState::Listening;
let acceptor = TlsAcceptor::from(Arc::new(tls_config)); Ok(TlsListener { listener, acceptor, state })
Ok(TlsListener {
listener,
acceptor,
state: TlsListenerState::Listening,
})
})
} }
impl Connection for TlsStream<TcpStream> { impl Connection for TlsStream<TcpStream> {

View File

@ -19,16 +19,14 @@ edition = "2018"
all-features = true all-features = true
[features] [features]
default = ["private-cookies", "ctrl_c_shutdown"] default = ["private-cookies"]
tls = ["rocket_http/tls"] tls = ["rocket_http/tls"]
private-cookies = ["rocket_http/private-cookies"] private-cookies = ["rocket_http/private-cookies"]
ctrl_c_shutdown = ["tokio/signal"]
[dependencies] [dependencies]
rocket_codegen = { version = "0.5.0-dev", path = "../codegen" } rocket_codegen = { version = "0.5.0-dev", path = "../codegen" }
rocket_http = { version = "0.5.0-dev", path = "../http" } rocket_http = { version = "0.5.0-dev", path = "../http" }
futures = "0.3.0" futures = "0.3.0"
tokio = { version = "0.2.9", features = ["fs", "io-std", "io-util", "rt-threaded", "sync"] }
yansi = "0.5" yansi = "0.5"
log = { version = "0.4", features = ["std"] } log = { version = "0.4", features = ["std"] }
toml = "0.4.7" toml = "0.4.7"
@ -42,6 +40,10 @@ atty = "0.2"
async-trait = "0.1" async-trait = "0.1"
ref-cast = "1.0" ref-cast = "1.0"
[dependencies.tokio]
version = "0.2.9"
features = ["fs", "io-std", "io-util", "rt-threaded", "sync", "signal"]
[build-dependencies] [build-dependencies]
yansi = "0.5" yansi = "0.5"
version_check = "0.9.1" version_check = "0.9.1"

View File

@ -523,7 +523,9 @@ impl Rocket {
#[derive(Clone)] #[derive(Clone)]
struct TokioExecutor; struct TokioExecutor;
impl<Fut> hyper::Executor<Fut> for TokioExecutor where Fut: Future + Send + 'static, Fut::Output: Send { impl<Fut> hyper::Executor<Fut> for TokioExecutor
where Fut: Future + Send + 'static, Fut::Output: Send
{
fn execute(&self, fut: Fut) { fn execute(&self, fut: Fut) {
tokio::spawn(fut); tokio::spawn(fut);
} }
@ -945,9 +947,9 @@ impl Rocket {
/// Returns a `Future` that drives the server, listening for and dispatching /// Returns a `Future` that drives the server, listening for and dispatching
/// requests to mounted routes and catchers. The `Future` completes when the /// requests to mounted routes and catchers. The `Future` completes when the
/// server is shut down, via a [`ShutdownHandle`], or encounters a fatal /// server is shut down via a [`ShutdownHandle`], encounters a fatal error,
/// error. If the `ctrl_c_shutdown` feature is enabled, the server will /// or if the the `ctrlc` configuration option is set, when `Ctrl+C` is
/// also shut down once `Ctrl-C` is pressed. /// pressed.
/// ///
/// # Error /// # Error
/// ///
@ -969,6 +971,7 @@ impl Rocket {
/// ``` /// ```
pub async fn launch(mut self) -> Result<(), crate::error::Error> { pub async fn launch(mut self) -> Result<(), crate::error::Error> {
use std::net::ToSocketAddrs; use std::net::ToSocketAddrs;
use futures::future::Either;
use crate::error::Error::Launch; use crate::error::Error::Launch;
self.prelaunch_check().await.map_err(crate::error::Error::Launch)?; self.prelaunch_check().await.map_err(crate::error::Error::Launch)?;
@ -979,70 +982,53 @@ impl Rocket {
Err(e) => return Err(Launch(e.into())), Err(e) => return Err(Launch(e.into())),
}; };
#[cfg(feature = "ctrl_c_shutdown")] // FIXME: Make `ctrlc` a known `Rocket` config option.
let ( // If `ctrl-c` shutdown is enabled, we `select` on `the ctrl-c` signal
shutdown_handle, // and server. Otherwise, we only wait on the `server`, hence `pending`.
(cancel_ctrl_c_listener_sender, cancel_ctrl_c_listener_receiver) let shutdown_handle = self.shutdown_handle.clone();
) = ( let shutdown_signal = match self.config.get_bool("ctrlc") {
self.shutdown_handle.clone(), Ok(false) => futures::future::pending().boxed(),
oneshot::channel(), _ => tokio::signal::ctrl_c().boxed(),
); };
let server = { let server = {
macro_rules! listen_on { macro_rules! listen_on {
($expr:expr) => {{ ($expr:expr) => {{
let listener = match $expr { let listener = match $expr {
Ok(ok) => ok, Ok(ok) => ok,
Err(err) => return Err(Launch(LaunchError::new(LaunchErrorKind::Bind(err)))), Err(err) => return Err(Launch(LaunchError::new(LaunchErrorKind::Bind(err))))
}; };
self.listen_on(listener) self.listen_on(listener)
}}; }};
} }
#[cfg(feature = "tls")] #[cfg(feature = "tls")] {
{
if let Some(tls) = self.config.tls.clone() { if let Some(tls) = self.config.tls.clone() {
listen_on!(crate::http::tls::bind_tls(addr, tls.certs, tls.key).await).boxed() listen_on!(crate::http::tls::bind_tls(addr, tls.certs, tls.key).await).boxed()
} else { } else {
listen_on!(crate::http::private::bind_tcp(addr).await).boxed() listen_on!(crate::http::private::bind_tcp(addr).await).boxed()
} }
} }
#[cfg(not(feature = "tls"))] #[cfg(not(feature = "tls"))] {
{ listen_on!(crate::http::private::bind_tcp(addr).await).boxed()
listen_on!(crate::http::private::bind_tcp(addr).await)
} }
}; };
#[cfg(feature = "ctrl_c_shutdown")] match futures::future::select(shutdown_signal, server).await {
let server = server.inspect(|_| { Either::Left((Ok(()), server)) => {
let _ = cancel_ctrl_c_listener_sender.send(()); // Ctrl-was pressed. Signal shutdown, wait for the server.
}); shutdown_handle.shutdown();
server.await
#[cfg(feature = "ctrl_c_shutdown")] }
{ Either::Left((Err(err), server)) => {
tokio::spawn(async move { // Error setting up ctrl-c signal. Let the user know.
use futures::future::{select, Either}; warn!("Failed to enable `ctrl+c` graceful signal shutdown.");
info_!("Error: {}", err);
let either = select( server.await
tokio::signal::ctrl_c().boxed(), }
cancel_ctrl_c_listener_receiver, // Server shut down before Ctrl-C; return the result.
).await; Either::Right((result, _)) => result,
match either {
Either::Left((Ok(()), _)) | Either::Right((_, _)) => shutdown_handle.shutdown(),
Either::Left((Err(err), _)) => {
// Signal handling isn't strictly necessary, so we can skip it
// if necessary. It's a good idea to let the user know we're
// doing so in case they are expecting certain behavior.
let message = "Not listening for shutdown keybinding.";
warn!("{}", Paint::yellow(message));
info_!("Error: {}", err);
}
}
});
} }
server.await
} }
} }