Implement graceful shutdown.

The crux of the implementation is as follows:

  * Configurable ctrl-c, signals that trigger a graceful shutdown.
  * Configurable grace period before forced I/O termination.
  * Programatic triggering via an application-wide method.
  * A future (`Shutdown`) that resolves only when shutdown is requested.

Resolves #180.
This commit is contained in:
Sergio Benitez 2021-04-28 02:01:35 -07:00
parent 63e6845386
commit a72e8da735
12 changed files with 774 additions and 107 deletions

View File

@ -24,7 +24,10 @@ pub trait Listener {
fn local_addr(&self) -> Option<SocketAddr>; fn local_addr(&self) -> Option<SocketAddr>;
/// Try to accept an incoming Connection if ready /// Try to accept an incoming Connection if ready
fn poll_accept(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<Self::Connection>>; fn poll_accept(
self: Pin<&mut Self>,
cx: &mut Context<'_>
) -> Poll<io::Result<Self::Connection>>;
} }
/// A 'Connection' represents an open connection to a client /// A 'Connection' represents an open connection to a client
@ -40,16 +43,17 @@ pin_project_lite::pin_project! {
/// Accept). This type is internal to Rocket. /// Accept). This type is internal to Rocket.
#[must_use = "streams do nothing unless polled"] #[must_use = "streams do nothing unless polled"]
pub struct Incoming<L> { pub struct Incoming<L> {
listener: L,
sleep_on_errors: Option<Duration>, sleep_on_errors: Option<Duration>,
#[pin] #[pin]
pending_error_delay: Option<Sleep>, pending_error_delay: Option<Sleep>,
#[pin]
listener: L,
} }
} }
impl<L: Listener> Incoming<L> { impl<L: Listener> Incoming<L> {
/// Construct an `Incoming` from an existing `Listener`. /// Construct an `Incoming` from an existing `Listener`.
pub fn from_listener(listener: L) -> Self { pub fn new(listener: L) -> Self {
Self { Self {
listener, listener,
sleep_on_errors: Some(Duration::from_millis(250)), sleep_on_errors: Some(Duration::from_millis(250)),
@ -96,7 +100,7 @@ impl<L: Listener> Incoming<L> {
me.pending_error_delay.set(None); me.pending_error_delay.set(None);
match me.listener.poll_accept(cx) { match me.listener.as_mut().poll_accept(cx) {
Poll::Ready(Ok(stream)) => { Poll::Ready(Ok(stream)) => {
return Poll::Ready(Ok(stream)); return Poll::Ready(Ok(stream));
}, },
@ -123,7 +127,7 @@ impl<L: Listener> Incoming<L> {
} }
} }
impl<L: Listener + Unpin> Accept for Incoming<L> { impl<L: Listener> Accept for Incoming<L> {
type Conn = L::Connection; type Conn = L::Connection;
type Error = io::Error; type Error = io::Error;
@ -171,7 +175,10 @@ impl Listener for TcpListener {
self.local_addr().ok() self.local_addr().ok()
} }
fn poll_accept(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<Self::Connection>> { fn poll_accept(
self: Pin<&mut Self>,
cx: &mut Context<'_>
) -> Poll<io::Result<Self::Connection>> {
(*self).poll_accept(cx).map_ok(|(stream, _addr)| stream) (*self).poll_accept(cx).map_ok(|(stream, _addr)| stream)
} }
} }

View File

@ -64,7 +64,10 @@ impl Listener for TlsListener {
self.listener.local_addr().ok() self.listener.local_addr().ok()
} }
fn poll_accept(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<Self::Connection>> { fn poll_accept(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>
) -> Poll<io::Result<Self::Connection>> {
loop { loop {
match self.state { match self.state {
TlsListenerState::Listening => { TlsListenerState::Listening => {

View File

@ -7,7 +7,7 @@ use figment::value::{Map, Dict};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use yansi::Paint; use yansi::Paint;
use crate::config::{TlsConfig, LogLevel}; use crate::config::{TlsConfig, LogLevel, Shutdown};
use crate::request::{self, Request, FromRequest}; use crate::request::{self, Request, FromRequest};
use crate::data::Limits; use crate::data::Limits;
@ -82,17 +82,16 @@ pub struct Config {
#[cfg_attr(nightly, doc(cfg(feature = "secrets")))] #[cfg_attr(nightly, doc(cfg(feature = "secrets")))]
#[serde(serialize_with = "SecretKey::serialize_zero")] #[serde(serialize_with = "SecretKey::serialize_zero")]
pub secret_key: SecretKey, pub secret_key: SecretKey,
/// The directory to store temporary files in. **(default: /// Directory to store temporary files in. **(default:
/// [`std::env::temp_dir`]). /// [`std::env::temp_dir()`])**
pub temp_dir: PathBuf, pub temp_dir: PathBuf,
/// Max level to log. **(default: _debug_ `normal` / _release_ `critical`)** /// Max level to log. **(default: _debug_ `normal` / _release_ `critical`)**
pub log_level: LogLevel, pub log_level: LogLevel,
/// Graceful shutdown configuration. **(default: [`Shutdown::default()`])**
pub shutdown: Shutdown,
/// Whether to use colors and emoji when logging. **(default: `true`)** /// Whether to use colors and emoji when logging. **(default: `true`)**
#[serde(deserialize_with = "figment::util::bool_from_str_or_int")] #[serde(deserialize_with = "figment::util::bool_from_str_or_int")]
pub cli_colors: bool, pub cli_colors: bool,
/// Whether `ctrl-c` initiates a server shutdown. **(default: `true`)**
#[serde(deserialize_with = "figment::util::bool_from_str_or_int")]
pub ctrlc: bool,
} }
impl Default for Config { impl Default for Config {
@ -152,7 +151,7 @@ impl Config {
temp_dir: std::env::temp_dir(), temp_dir: std::env::temp_dir(),
log_level: LogLevel::Normal, log_level: LogLevel::Normal,
cli_colors: true, cli_colors: true,
ctrlc: true, shutdown: Shutdown::default(),
} }
} }
@ -318,6 +317,7 @@ impl Config {
launch_info_!("temp dir: {}", Paint::default(&self.temp_dir.display()).bold()); launch_info_!("temp dir: {}", Paint::default(&self.temp_dir.display()).bold());
launch_info_!("log level: {}", Paint::default(self.log_level).bold()); launch_info_!("log level: {}", Paint::default(self.log_level).bold());
launch_info_!("cli colors: {}", Paint::default(&self.cli_colors).bold()); launch_info_!("cli colors: {}", Paint::default(&self.cli_colors).bold());
launch_info_!("shutdown: {}", Paint::default(&self.shutdown).bold());
// Check for now depreacted config values. // Check for now depreacted config values.
for (key, replacement) in Self::DEPRECATED_KEYS { for (key, replacement) in Self::DEPRECATED_KEYS {
@ -398,8 +398,8 @@ impl Config {
/// The stringy parameter name for setting/extracting [`Config::log_level`]. /// The stringy parameter name for setting/extracting [`Config::log_level`].
pub const LOG_LEVEL: &'static str = "log_level"; pub const LOG_LEVEL: &'static str = "log_level";
/// The stringy parameter name for setting/extracting [`Config::ctrlc`]. /// The stringy parameter name for setting/extracting [`Config::shutdown`].
pub const CTRLC: &'static str = "ctrlc"; pub const SHUTDOWN: &'static str = "shutdown";
} }
impl Provider for Config { impl Provider for Config {

View File

@ -113,6 +113,7 @@
mod config; mod config;
mod tls; mod tls;
mod shutdown;
#[cfg(feature = "secrets")] #[cfg(feature = "secrets")]
mod secret_key; mod secret_key;
@ -121,19 +122,24 @@ mod secret_key;
pub use config::Config; pub use config::Config;
pub use crate::log::LogLevel; pub use crate::log::LogLevel;
pub use shutdown::Shutdown;
pub use tls::TlsConfig; pub use tls::TlsConfig;
#[cfg(feature = "secrets")] #[cfg(feature = "secrets")]
#[cfg_attr(nightly, doc(cfg(feature = "secrets")))] #[cfg_attr(nightly, doc(cfg(feature = "secrets")))]
pub use secret_key::SecretKey; pub use secret_key::SecretKey;
#[cfg(unix)]
#[cfg_attr(nightly, doc(cfg(unix)))]
pub use shutdown::Sig;
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use std::net::Ipv4Addr; use std::net::Ipv4Addr;
use figment::{Figment, Profile}; use figment::{Figment, Profile};
use pretty_assertions::assert_eq; use pretty_assertions::assert_eq;
use crate::config::{Config, TlsConfig}; use crate::config::{Config, TlsConfig, Shutdown};
use crate::log::LogLevel; use crate::log::LogLevel;
use crate::data::{Limits, ToByteUnit}; use crate::data::{Limits, ToByteUnit};
@ -217,7 +223,7 @@ mod tests {
jail.create_file("Rocket.toml", r#" jail.create_file("Rocket.toml", r#"
[global] [global]
ctrlc = 0 shutdown.ctrlc = 0
[global.tls] [global.tls]
certs = "/ssl/cert.pem" certs = "/ssl/cert.pem"
@ -231,7 +237,7 @@ mod tests {
let config = Config::from(Config::figment()); let config = Config::from(Config::figment());
assert_eq!(config, Config { assert_eq!(config, Config {
ctrlc: false, shutdown: Shutdown { ctrlc: false, ..Default::default() },
tls: Some(TlsConfig::from_paths("/ssl/cert.pem", "/ssl/key.pem")), tls: Some(TlsConfig::from_paths("/ssl/cert.pem", "/ssl/key.pem")),
limits: Limits::default() limits: Limits::default()
.limit("forms", 1.mebibytes()) .limit("forms", 1.mebibytes())

View File

@ -0,0 +1,273 @@
use std::fmt;
use std::future::Future;
#[cfg(unix)]
use std::collections::HashSet;
use futures::future::{Either, pending};
use serde::{Deserialize, Serialize};
/// A Unix signal for triggering graceful shutdown.
///
/// Each variant corresponds to a Unix process signal which can be used to
/// trigger a graceful shutdown. See [`Shutdown`] for details.
///
/// ## (De)serialization
///
/// A `Sig` variant serializes and deserializes as a lowercase string equal to
/// the name of the variant: `"alrm"` for [`Sig::Alrm`], `"chld"` for
/// [`Sig::Chld`], and so on.
#[cfg(unix)]
#[cfg_attr(nightly, doc(cfg(unix)))]
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum Sig {
/// The `SIGALRM` Unix signal.
Alrm,
/// The `SIGCHLD` Unix signal.
Chld,
/// The `SIGHUP` Unix signal.
Hup,
/// The `SIGINT` Unix signal.
Int,
/// The `SIGIO` Unix signal.
Io,
/// The `SIGPIPE` Unix signal.
Pipe,
/// The `SIGQUIT` Unix signal.
Quit,
/// The `SIGTERM` Unix signal.
Term,
/// The `SIGUSR1` Unix signal.
Usr1,
/// The `SIGUSR2` Unix signal.
Usr2
}
#[cfg(unix)]
#[cfg_attr(nightly, doc(cfg(unix)))]
impl fmt::Display for Sig {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let s = match self {
Sig::Alrm => "SIGALRM",
Sig::Chld => "SIGCHLD",
Sig::Hup => "SIGHUP",
Sig::Int => "SIGINT",
Sig::Io => "SIGIO",
Sig::Pipe => "SIGPIPE",
Sig::Quit => "SIGQUIT",
Sig::Term => "SIGTERM",
Sig::Usr1 => "SIGUSR1",
Sig::Usr2 => "SIGUSR2",
};
s.fmt(f)
}
}
/// Graceful shutdown configuration.
///
/// This structure configures when and how graceful shutdown occurs. The `ctrlc`
/// and `signals` properties control _when_ and the `grace` property controls
/// _how_.
///
/// # Triggers
///
/// _All_ graceful shutdowns are initiated via
/// [`Shutdown::notify()`](crate::Shutdown::notify()). Rocket can be configured
/// to trigger shutdown automatically on certain conditions, specified via the
/// `ctrlc` and `signals` properties of this structure. More specifically, if
/// `ctrlc` is `true` (the default), `ctrl-c` (`SIGINT`) initiates a server
/// shutdown, and on Unix, `signals` specifies a list of IPC signals that
/// trigger a shutdown (`["term"]` by default).
///
/// # Grace Period
///
/// Once a shutdown is triggered, Rocket stops accepting new connections and
/// waits at most `grace` seconds before force-closing all outstanding I/O.
/// Applications can `await` the [`Shutdown`](crate::Shutdown) future to detect
/// a shutdown and cancel any server-initiated I/O, such, as from [infinite
/// responders](crate::response::stream#graceful-shutdown), to avoid abrupt I/O
/// cancellation.
///
/// # Example
///
/// As with all Rocket configuration options, when using the default
/// [`Config::figment()`](crate::Config::figment()), `Shutdown` can be
/// configured via a `Rocket.toml` file. As always, defaults are provided
/// (documented below), and thus configuration only needs to provided to change
/// defaults.
///
/// ```rust
/// # use rocket::figment::{Figment, providers::{Format, Toml}};
/// use rocket::{Rocket, Config};
///
/// // If these are the contents of `Rocket.toml`...
/// # let toml = Toml::string(r#"
/// [default.shutdown]
/// ctrlc = false
/// signals = ["term", "hup"]
/// grace = 10
/// # "#).nested();
///
/// // The config parses as follows:
/// # let config = Config::from(Figment::from(Config::debug_default()).merge(toml));
/// assert_eq!(config.shutdown.ctrlc, false);
/// assert_eq!(config.shutdown.grace, 10);
///
/// # #[cfg(unix)] {
/// use rocket::config::Sig;
///
/// assert_eq!(config.shutdown.signals.len(), 2);
/// assert!(config.shutdown.signals.contains(&Sig::Term));
/// assert!(config.shutdown.signals.contains(&Sig::Hup));
/// # }
/// ```
///
/// Or, as with all configuration options, programatically:
///
/// ```rust
/// # use rocket::figment::{Figment, providers::{Format, Toml}};
/// use rocket::{Rocket, Config};
/// use rocket::config::Shutdown;
///
/// #[cfg(unix)]
/// use rocket::config::Sig;
///
/// let config = Config {
/// shutdown: Shutdown {
/// ctrlc: false,
/// #[cfg(unix)]
/// signals: {
/// let mut set = std::collections::HashSet::new();
/// set.insert(Sig::Term);
/// set.insert(Sig::Hup);
/// set
/// },
/// grace: 10
/// },
/// ..Config::default()
/// };
///
/// assert_eq!(config.shutdown.ctrlc, false);
/// assert_eq!(config.shutdown.grace, 10);
///
/// #[cfg(unix)] {
/// assert_eq!(config.shutdown.signals.len(), 2);
/// assert!(config.shutdown.signals.contains(&Sig::Term));
/// assert!(config.shutdown.signals.contains(&Sig::Hup));
/// }
/// ```
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct Shutdown {
/// Whether `ctrl-c` (`SIGINT`) initiates a server shutdown.
///
/// **default: `true`**
#[serde(deserialize_with = "figment::util::bool_from_str_or_int")]
pub ctrlc: bool,
/// On Unix, a set of signal which trigger a shutdown. On non-Unix, this
/// option is unavailable and silently ignored.
///
/// **default: { [`Sig::Term`] }**
#[cfg(unix)]
#[cfg_attr(nightly, doc(cfg(unix)))]
pub signals: HashSet<Sig>,
/// The shutdown grace period: number of seconds to continue to try to
/// finish outstanding I/O for before forcibly terminating it.
///
/// **default: `5`**
pub grace: u32,
}
impl fmt::Display for Shutdown {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "ctrlc = {}, ", self.ctrlc)?;
#[cfg(unix)] {
write!(f, "signals = [")?;
for (i, sig) in self.signals.iter().enumerate() {
if i != 0 { write!(f, ", ")?; }
write!(f, "{}", sig)?;
}
write!(f, "], ")?;
}
write!(f, "grace = {}s", self.grace)?;
Ok(())
}
}
impl Default for Shutdown {
fn default() -> Self {
Shutdown {
ctrlc: true,
#[cfg(unix)]
signals: { let mut set = HashSet::new(); set.insert(Sig::Term); set },
grace: 5,
}
}
}
impl Shutdown {
#[cfg(unix)]
pub(crate) fn collective_signal(&self) -> impl Future<Output = ()> {
use futures::future::{FutureExt, select_all};
use tokio::signal::unix::{signal, SignalKind};
if !self.ctrlc && self.signals.is_empty() {
return Either::Right(pending());
}
let mut signals = self.signals.clone();
if self.ctrlc {
signals.insert(Sig::Int);
}
let mut sigfuts = vec![];
for sig in signals {
let sigkind = match sig {
Sig::Alrm => SignalKind::alarm(),
Sig::Chld => SignalKind::child(),
Sig::Hup => SignalKind::hangup(),
Sig::Int => SignalKind::interrupt(),
Sig::Io => SignalKind::io(),
Sig::Pipe => SignalKind::pipe(),
Sig::Quit => SignalKind::quit(),
Sig::Term => SignalKind::terminate(),
Sig::Usr1 => SignalKind::user_defined1(),
Sig::Usr2 => SignalKind::user_defined2()
};
let sigfut = match signal(sigkind) {
Ok(mut signal) => Box::pin(async move {
signal.recv().await;
warn!("Received {} signal. Requesting shutdown.", sig);
}),
Err(e) => {
warn!("Failed to enable `{}` shutdown signal.", sig);
info_!("Error: {}", e);
continue
}
};
sigfuts.push(sigfut);
}
Either::Left(select_all(sigfuts).map(|_| ()))
}
#[cfg(not(unix))]
pub(crate) fn collective_signal(&self) -> impl Future<Output = ()> {
use futures::future::FutureExt;
match self.ctrlc {
true => Either::Left(tokio::signal::ctrl_c().map(|result| {
if let Err(e) = result {
warn!("Failed to enable `ctrl-c` shutdown signal.");
info_!("Error: {}", e);
}
})),
false => Either::Right(pending()),
}
}
}

View File

@ -1,11 +1,14 @@
use std::io; use std::{io, time::Duration};
use std::pin::Pin;
use std::task::{Poll, Context}; use std::task::{Poll, Context};
use std::pin::Pin;
use bytes::BytesMut; use bytes::BytesMut;
use tokio::io::{AsyncRead, ReadBuf};
use pin_project_lite::pin_project; use pin_project_lite::pin_project;
use futures::{ready, stream::Stream}; use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tokio::time::{sleep, Sleep};
use futures::stream::Stream;
use futures::future::{Future, Fuse, FutureExt};
use crate::http::hyper::Bytes; use crate::http::hyper::Bytes;
@ -115,7 +118,7 @@ impl<T: AsyncRead, U: AsyncRead> AsyncRead for Chain<T, U> {
if !*me.done_first { if !*me.done_first {
let init_rem = buf.remaining(); let init_rem = buf.remaining();
ready!(me.first.poll_read(cx, buf))?; futures::ready!(me.first.poll_read(cx, buf))?;
if buf.remaining() == init_rem { if buf.remaining() == init_rem {
*me.done_first = true; *me.done_first = true;
} else { } else {
@ -125,3 +128,138 @@ impl<T: AsyncRead, U: AsyncRead> AsyncRead for Chain<T, U> {
me.second.poll_read(cx, buf) me.second.poll_read(cx, buf)
} }
} }
pin_project! {
/// I/O that can be cancelled when a future `F` resolves.
#[must_use = "futures do nothing unless polled"]
pub struct CancellableIo<F, I> {
#[pin]
io: I,
#[pin]
trigger: Fuse<F>,
sleep: Option<Pin<Box<Sleep>>>,
grace: Duration,
}
}
impl<F: Future, I> CancellableIo<F, I> {
pub fn new(trigger: F, io: I, grace: Duration) -> Self {
CancellableIo {
trigger: trigger.fuse(),
sleep: None,
io, grace,
}
}
fn poll_trigger(
self: Pin<&mut Self>,
cx: &mut Context<'_>
) -> io::Result<()> {
let me = self.project();
if me.trigger.poll(cx).is_ready() {
*me.sleep = Some(Box::pin(sleep(*me.grace)));
}
if let Some(sleep) = me.sleep {
if sleep.as_mut().poll(cx).is_ready() {
return Err(io::Error::new(io::ErrorKind::TimedOut, "..."));
}
}
Ok(())
}
}
impl<F: Future, I: AsyncRead> AsyncRead for CancellableIo<F, I> {
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
self.as_mut().poll_trigger(cx)?;
self.as_mut().project().io.poll_read(cx, buf)
}
}
impl<F: Future, I: AsyncWrite> AsyncWrite for CancellableIo<F, I> {
fn poll_write(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<Result<usize, io::Error>> {
self.as_mut().poll_trigger(cx)?;
self.as_mut().project().io.poll_write(cx, buf)
}
fn poll_flush(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>
) -> Poll<io::Result<()>> {
self.as_mut().poll_trigger(cx)?;
self.as_mut().project().io.poll_flush(cx)
}
fn poll_shutdown(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>
) -> Poll<Result<(), io::Error>> {
self.as_mut().poll_trigger(cx)?;
self.as_mut().project().io.poll_shutdown(cx)
}
fn poll_write_vectored(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
bufs: &[io::IoSlice<'_>],
) -> Poll<Result<usize, io::Error>> {
self.as_mut().poll_trigger(cx)?;
self.as_mut().project().io.poll_write_vectored(cx, bufs)
}
fn is_write_vectored(&self) -> bool {
self.io.is_write_vectored()
}
}
use crate::http::private::{Listener, Connection};
impl<F: Future, C: Connection> Connection for CancellableIo<F, C> {
fn remote_addr(&self) -> Option<std::net::SocketAddr> {
self.io.remote_addr()
}
}
pin_project! {
pub struct CancellableListener<F, L> {
pub trigger: F,
#[pin]
pub listener: L,
pub grace: Duration,
}
}
impl<F, L> CancellableListener<F, L> {
pub fn new(trigger: F, listener: L, grace: u64) -> Self {
CancellableListener { trigger, listener, grace: Duration::from_secs(grace) }
}
}
impl<L: Listener, F: Future + Clone> Listener for CancellableListener<F, L> {
type Connection = CancellableIo<F, L::Connection>;
fn local_addr(&self) -> Option<std::net::SocketAddr> {
self.listener.local_addr()
}
fn poll_accept(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>
) -> Poll<io::Result<Self::Connection>> {
self.as_mut().project().listener
.poll_accept(cx)
.map(|res| res.map(|conn| {
CancellableIo::new(self.trigger.clone(), conn, self.grace)
}))
}
}

View File

@ -104,8 +104,8 @@
/// These are public dependencies! Update docs if these are changed, especially /// These are public dependencies! Update docs if these are changed, especially
/// figment's version number in docs. /// figment's version number in docs.
#[doc(hidden)] #[doc(hidden)] pub use yansi;
pub use yansi; #[doc(hidden)] pub use async_stream;
pub use futures; pub use futures;
pub use tokio; pub use tokio;
pub use figment; pub use figment;
@ -139,6 +139,8 @@ pub mod http {
pub use crate::cookies::*; pub use crate::cookies::*;
} }
/// TODO: We need a futures mod or something.
mod trip_wire;
mod shutdown; mod shutdown;
mod server; mod server;
mod ext; mod ext;
@ -183,7 +185,7 @@ pub use async_trait::async_trait;
/// WARNING: This is unstable! Do not use this method outside of Rocket! /// WARNING: This is unstable! Do not use this method outside of Rocket!
#[doc(hidden)] #[doc(hidden)]
pub fn async_test<R>(fut: impl std::future::Future<Output = R> + Send) -> R { pub fn async_test<R>(fut: impl std::future::Future<Output = R>) -> R {
tokio::runtime::Builder::new_multi_thread() tokio::runtime::Builder::new_multi_thread()
.thread_name("rocket-test-worker-thread") .thread_name("rocket-test-worker-thread")
.worker_threads(1) .worker_threads(1)

View File

@ -1,10 +1,7 @@
use std::sync::Arc;
use state::Container; use state::Container;
use figment::Figment; use figment::Figment;
use tokio::sync::Notify;
use crate::{Route, Catcher, Config, Rocket}; use crate::{Catcher, Config, Rocket, Route, Shutdown};
use crate::router::Router; use crate::router::Router;
use crate::fairing::Fairings; use crate::fairing::Fairings;
@ -100,7 +97,7 @@ phases! {
pub(crate) figment: Figment, pub(crate) figment: Figment,
pub(crate) config: Config, pub(crate) config: Config,
pub(crate) state: Container![Send + Sync], pub(crate) state: Container![Send + Sync],
pub(crate) shutdown: Arc<Notify>, pub(crate) shutdown: Shutdown,
} }
/// The final launch [`Phase`]. /// The final launch [`Phase`].
@ -113,6 +110,6 @@ phases! {
pub(crate) figment: Figment, pub(crate) figment: Figment,
pub(crate) config: Config, pub(crate) config: Config,
pub(crate) state: Container![Send + Sync], pub(crate) state: Container![Send + Sync],
pub(crate) shutdown: Arc<Notify>, pub(crate) shutdown: Shutdown,
} }
} }

View File

@ -1,15 +1,14 @@
use std::fmt; use std::fmt;
use std::ops::{Deref, DerefMut}; use std::ops::{Deref, DerefMut};
use std::convert::TryInto; use std::convert::TryInto;
use std::sync::Arc;
use yansi::Paint; use yansi::Paint;
use either::Either; use either::Either;
use tokio::sync::Notify;
use figment::{Figment, Provider}; use figment::{Figment, Provider};
use crate::{Route, Catcher, Config, Shutdown, sentinel}; use crate::{Catcher, Config, Route, Shutdown, sentinel};
use crate::router::Router; use crate::router::Router;
use crate::trip_wire::TripWire;
use crate::fairing::{Fairing, Fairings}; use crate::fairing::{Fairing, Fairings};
use crate::phase::{Phase, Build, Building, Ignite, Igniting, Orbit, Orbiting}; use crate::phase::{Phase, Build, Building, Ignite, Igniting, Orbit, Orbiting};
use crate::phase::{Stateful, StateRef, State}; use crate::phase::{Stateful, StateRef, State};
@ -505,7 +504,7 @@ impl Rocket<Build> {
// Ignite the rocket. // Ignite the rocket.
let rocket: Rocket<Ignite> = Rocket(Igniting { let rocket: Rocket<Ignite> = Rocket(Igniting {
router, config, router, config,
shutdown: Arc::new(Notify::new()), shutdown: Shutdown(TripWire::new()),
figment: self.0.figment, figment: self.0.figment,
fairings: self.0.fairings, fairings: self.0.fairings,
state: self.0.state, state: self.0.state,
@ -553,18 +552,14 @@ impl Rocket<Ignite> {
&self.config &self.config
} }
/// Returns a handle which can be used to notify this instance of Rocket to /// Returns a handle which can be used to trigger a shutdown and detect a
/// stop serving connections, resolving the future returned by /// triggered shutdown.
/// [`Rocket::launch()`]. If [`Shutdown::notify()`] is called _before_ the
/// instance is launched, it will be immediately shutdown after liftoff.
/// ///
/// # Caveats /// A completed graceful shutdown resolves the future returned by
/// /// [`Rocket::launch()`]. If [`Shutdown::notify()`] is called _before_ an
/// Due to [bugs](https://github.com/hyperium/hyper/issues/1885) in Rocket's /// instance is launched, it will be immediately shutdown after liftoff. See
/// upstream HTTP library, graceful shutdown currently works by stopping new /// [`Shutdown`] and [`config::Shutdown`](crate::config::Shutdown) for
/// connections from arriving without stopping in-process connections from /// details on graceful shutdown.
/// sending or receiving. As a result, shutdown will stall if a response is
/// infinite or if a client stalls a connection.
/// ///
/// # Example /// # Example
/// ///
@ -590,7 +585,7 @@ impl Rocket<Ignite> {
/// } /// }
/// ``` /// ```
pub fn shutdown(&self) -> Shutdown { pub fn shutdown(&self) -> Shutdown {
Shutdown(self.shutdown.clone()) self.shutdown.clone()
} }
fn into_orbit(self) -> Rocket<Orbit> { fn into_orbit(self) -> Rocket<Orbit> {
@ -650,17 +645,13 @@ impl Rocket<Orbit> {
&self.config &self.config
} }
/// Returns a handle which can be used to notify this instance of Rocket to /// Returns a handle which can be used to trigger a shutdown and detect a
/// stop serving connections, resolving the future returned by /// triggered shutdown.
/// [`Rocket::launch()`].
/// ///
/// # Caveats /// A completed graceful shutdown resolves the future returned by
/// /// [`Rocket::launch()`]. See [`Shutdown`] and
/// Due to [bugs](https://github.com/hyperium/hyper/issues/1885) in Rocket's /// [`config::Shutdown`](crate::config::Shutdown) for details on graceful
/// upstream HTTP library, graceful shutdown currently works by stopping new /// shutdown.
/// connections from arriving without stopping in-process connections from
/// sending or receiving. As a result, shutdown will stall if a response is
/// infinite or if a client stalls a connection.
/// ///
/// # Example /// # Example
/// ///
@ -682,7 +673,7 @@ impl Rocket<Orbit> {
/// } /// }
/// ``` /// ```
pub fn shutdown(&self) -> Shutdown { pub fn shutdown(&self) -> Shutdown {
Shutdown(self.shutdown.clone()) self.shutdown.clone()
} }
} }
@ -817,8 +808,7 @@ impl<P: Phase> Rocket<P> {
/// ///
/// The `Future` resolves as an `Ok` if any of the following occur: /// The `Future` resolves as an `Ok` if any of the following occur:
/// ///
/// * the server is shutdown via [`Shutdown::notify()`]. /// * graceful shutdown via [`Shutdown::notify()`] completes.
/// * if the `ctrlc` config option is `true`, when `Ctrl+C` is pressed.
/// ///
/// The `Future` does not resolve otherwise. /// The `Future` does not resolve otherwise.
/// ///

View File

@ -6,12 +6,11 @@ use futures::future::{self, FutureExt, Future, TryFutureExt, BoxFuture};
use tokio::sync::oneshot; use tokio::sync::oneshot;
use yansi::Paint; use yansi::Paint;
use crate::{Rocket, Orbit, Request, Data, route}; use crate::{Rocket, Orbit, Request, Response, Data, route};
use crate::form::Form; use crate::form::Form;
use crate::response::{Response, Body};
use crate::outcome::Outcome; use crate::outcome::Outcome;
use crate::error::{Error, ErrorKind}; use crate::error::{Error, ErrorKind};
use crate::ext::AsyncReadExt; use crate::ext::{AsyncReadExt, CancellableListener, CancellableIo};
use crate::http::{Method, Status, Header, hyper}; use crate::http::{Method, Status, Header, hyper};
use crate::http::private::{Listener, Connection, Incoming}; use crate::http::private::{Listener, Connection, Incoming};
@ -117,7 +116,7 @@ impl Rocket<Orbit> {
) { ) {
match self.make_response(response, tx).await { match self.make_response(response, tx).await {
Ok(()) => info_!("{}", Paint::green("Response succeeded.")), Ok(()) => info_!("{}", Paint::green("Response succeeded.")),
Err(e) => error_!("Failed to write response: {:?}.", e), Err(e) => error_!("Failed to write response: {}.", e),
} }
} }
@ -393,8 +392,7 @@ impl Rocket<Orbit> {
// TODO.async: Solidify the Listener APIs and make this function public // TODO.async: Solidify the Listener APIs and make this function public
pub(crate) async fn http_server<L>(self, listener: L) -> Result<(), Error> pub(crate) async fn http_server<L>(self, listener: L) -> Result<(), Error>
where L: Listener + Send + Unpin + 'static, where L: Listener + Send, <L as Listener>::Connection: Send + Unpin + 'static
<L as Listener>::Connection: Send + Unpin + 'static
{ {
// Determine keep-alives. // Determine keep-alives.
let http1_keepalive = self.config.keep_alive != 0; let http1_keepalive = self.config.keep_alive != 0;
@ -403,15 +401,16 @@ impl Rocket<Orbit> {
n => Some(std::time::Duration::from_secs(n as u64)) n => Some(std::time::Duration::from_secs(n as u64))
}; };
// Get the shutdown handle (to initiate) and signal (when initiated). // Set up cancellable I/O from the given listener. Shutdown occurs when
let shutdown_handle = self.shutdown.clone(); // `Shutdown` (`TripWire`) resolves. This can occur directly through a
let shutdown_signal = match self.config.ctrlc { // notification or indirectly through an external signal which, when
true => tokio::signal::ctrl_c().boxed(), // received, results in triggering the notify.
false => future::pending().boxed(), let shutdown = self.shutdown();
}; let external_shutdown = self.config.shutdown.collective_signal();
let grace = self.config.shutdown.grace as u64;
let rocket = Arc::new(self); let rocket = Arc::new(self);
let service = hyper::make_service_fn(move |conn: &<L as Listener>::Connection| { let service_fn = move |conn: &CancellableIo<_, L::Connection>| {
let rocket = rocket.clone(); let rocket = rocket.clone();
let remote = conn.remote_addr().unwrap_or_else(|| ([0, 0, 0, 0], 0).into()); let remote = conn.remote_addr().unwrap_or_else(|| ([0, 0, 0, 0], 0).into());
async move { async move {
@ -419,33 +418,26 @@ impl Rocket<Orbit> {
hyper_service_fn(rocket.clone(), remote, req) hyper_service_fn(rocket.clone(), remote, req)
})) }))
} }
}); };
// NOTE: `hyper` uses `tokio::spawn()` as the default executor. // NOTE: `hyper` uses `tokio::spawn()` as the default executor.
let shutdown_receiver = shutdown_handle.clone(); let listener = CancellableListener::new(shutdown.clone(), listener, grace);
let server = hyper::Server::builder(Incoming::from_listener(listener)) let server = hyper::Server::builder(Incoming::new(listener))
.http1_keepalive(http1_keepalive) .http1_keepalive(http1_keepalive)
.http2_keep_alive_interval(http2_keep_alive) .http2_keep_alive_interval(http2_keep_alive)
.serve(service) .serve(hyper::make_service_fn(service_fn))
.with_graceful_shutdown(async move { shutdown_receiver.notified().await; }) .with_graceful_shutdown(shutdown.clone())
.map_err(|e| Error::new(ErrorKind::Runtime(Box::new(e)))); .map_err(|e| Error::new(ErrorKind::Runtime(Box::new(e))));
tokio::pin!(server); tokio::pin!(server, external_shutdown);
let selecter = future::select(external_shutdown, server);
let selecter = future::select(shutdown_signal, server);
match selecter.await { match selecter.await {
future::Either::Left((Ok(()), server)) => { future::Either::Left((_, server)) => {
// Ctrl-was pressed. Signal shutdown, wait for the server. // External signal received. Request shutdown, wait for server.
shutdown_handle.notify_one(); shutdown.notify();
server.await server.await
} }
future::Either::Left((Err(err), server)) => { // Internal shutdown or server error. Return the result.
// Error setting up ctrl-c signal. Let the user know.
warn!("Failed to enable `ctrl-c` graceful signal shutdown.");
info_!("Error: {}", err);
server.await
}
// Server shut down before Ctrl-C; return the result.
future::Either::Right((result, _)) => result, future::Either::Right((result, _)) => result,
} }
} }

View File

@ -1,18 +1,45 @@
use std::sync::Arc; use std::future::Future;
use std::task::{Context, Poll};
use std::pin::Pin;
use tokio::sync::Notify; use futures::FutureExt;
use crate::request::{FromRequest, Outcome, Request}; use crate::request::{FromRequest, Outcome, Request};
use crate::trip_wire::TripWire;
/// A request guard to gracefully shutdown a Rocket server. /// A request guard and future for graceful shutdown.
/// ///
/// A server shutdown is manually requested by calling [`Shutdown::notify()`] /// A server shutdown is manually requested by calling [`Shutdown::notify()`]
/// or, if enabled, by pressing `Ctrl-C`. Rocket will finish handling any /// or, if enabled, through [automatic triggers] like `Ctrl-C`. Rocket will stop accepting new
/// pending requests and return `Ok()` to the caller of [`Rocket::launch()`]. /// requests, finish handling any pending requests, wait a grace period before
/// cancelling any outstanding I/O, and return `Ok()` to the caller of
/// [`Rocket::launch()`]. Graceful shutdown is configured via
/// [`config::Shutdown`](crate::config::Shutdown).
/// ///
/// [`Rocket::launch()`]: crate::Rocket::launch() /// [`Rocket::launch()`]: crate::Rocket::launch()
/// [automatic triggers]: crate::config::Shutdown#triggers
/// ///
/// # Example /// # Detecting Shutdown
///
/// `Shutdown` is also a future that resolves when [`Shutdown::notify()`] is
/// called. This can be used to detect shutdown in any part of the application:
///
/// ```rust
/// # use rocket::*;
/// use rocket::Shutdown;
///
/// #[get("/wait/for/shutdown")]
/// async fn wait_for_shutdown(shutdown: Shutdown) -> &'static str {
/// shutdown.await;
/// "Somewhere, shutdown was requested."
/// }
/// ```
///
/// See the [`stream`](crate::response::stream#graceful-shutdown) docs for an
/// example of detecting shutdown in an infinite responder.
///
/// Additionally, a completed shutdown request resolves the future returned from
/// [`Rocket::launch()`](crate::Rocket::launch()):
/// ///
/// ```rust,no_run /// ```rust,no_run
/// # #[macro_use] extern crate rocket; /// # #[macro_use] extern crate rocket;
@ -36,19 +63,32 @@ use crate::request::{FromRequest, Outcome, Request};
/// result.expect("server failed unexpectedly"); /// result.expect("server failed unexpectedly");
/// } /// }
/// ``` /// ```
#[must_use = "a shutdown request is only sent on `shutdown.notify()`"]
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct Shutdown(pub(crate) Arc<Notify>); #[must_use = "`Shutdown` does nothing unless polled or `notify`ed"]
pub struct Shutdown(pub(crate) TripWire);
impl Shutdown { impl Shutdown {
/// Notify Rocket to shut down gracefully. /// Notify the application to shut down gracefully.
/// ///
/// This function returns immediately; pending requests will continue to run /// This function returns immediately; pending requests will continue to run
/// until completion before the actual shutdown occurs. /// until completion or expiration of the grace period, which ever comes
/// first, before the actual shutdown occurs. The grace period can be
/// configured via [`Shutdown::grace`](crate::config::Shutdown::grace).
///
/// ```rust
/// # use rocket::*;
/// use rocket::Shutdown;
///
/// #[get("/shutdown")]
/// fn shutdown(shutdown: Shutdown) -> &'static str {
/// shutdown.notify();
/// "Shutting down..."
/// }
/// ```
#[inline] #[inline]
pub fn notify(self) { pub fn notify(self) {
self.0.notify_one(); self.0.trip();
info!("Server shutdown requested, waiting for all pending requests to finish."); info!("Shutdown requested. Waiting for pending I/O to finish...");
} }
} }
@ -58,7 +98,25 @@ impl<'r> FromRequest<'r> for Shutdown {
#[inline] #[inline]
async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Self::Error> { async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Self::Error> {
let notifier = request.rocket().shutdown.clone(); Outcome::Success(request.rocket().shutdown())
Outcome::Success(Shutdown(notifier)) }
}
impl Future for Shutdown {
type Output = ();
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
self.0.poll_unpin(cx)
}
}
#[cfg(test)]
mod tests {
use super::Shutdown;
#[test]
fn ensure_is_send_sync_clone_unpin() {
fn is_send_sync_clone_unpin<T: Send + Sync + Clone + Unpin>() {}
is_send_sync_clone_unpin::<Shutdown>();
} }
} }

201
core/lib/src/trip_wire.rs Normal file
View File

@ -0,0 +1,201 @@
use std::fmt;
use std::{ops::Deref, pin::Pin, future::Future};
use std::task::{Context, Poll};
use std::sync::{Arc, atomic::{AtomicBool, Ordering}};
use tokio::sync::Notify;
#[doc(hidden)]
pub struct State {
tripped: AtomicBool,
notify: Notify,
}
#[must_use = "`TripWire` does nothing unless polled or `trip()`ed"]
pub struct TripWire {
state: Arc<State>,
// `Notified` is `!Unpin`. Even if we could name it, we'd need to pin it.
event: Option<Pin<Box<dyn Future<Output = ()> + Send + Sync>>>,
}
impl Deref for TripWire {
type Target = State;
fn deref(&self) -> &Self::Target {
&self.state
}
}
impl Clone for TripWire {
fn clone(&self) -> Self {
TripWire {
state: self.state.clone(),
event: None
}
}
}
impl fmt::Debug for TripWire {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("TripWire")
.field("tripped", &self.tripped)
.finish()
}
}
impl Future for TripWire {
type Output = ();
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
if self.tripped.load(Ordering::Acquire) {
self.event = None;
return Poll::Ready(());
}
if self.event.is_none() {
let state = self.state.clone();
self.event = Some(Box::pin(async move {
let notified = state.notify.notified();
notified.await
}));
}
if let Some(ref mut event) = self.event {
if event.as_mut().poll(cx).is_ready() {
// We need to call `trip()` to avoid a race condition where:
// 1) many trip wires have seen !self.tripped but have not
// polled for `self.event` yet, so are not subscribed
// 2) trip() is called, adding a permit to `event`
// 3) some trip wires poll `event` for the first time
// 4) one of those wins, returns `Ready()`
// 5) the rest return pending
//
// Without this `self.trip()` those will never be awoken. With
// the call to self.trip(), those that made it to poll() in 3)
// will be awoken by `notify_waiters()`. For those the didn't,
// one will be awoken by `notify_one()`, which will in-turn call
// self.trip(), awaking more until there are no more to awake.
self.trip();
self.event = None;
return Poll::Ready(());
}
}
Poll::Pending
}
}
impl TripWire {
pub fn new() -> Self {
TripWire {
state: Arc::new(State {
tripped: AtomicBool::new(false),
notify: Notify::new()
}),
event: None,
}
}
pub fn trip(&self) {
self.tripped.store(true, Ordering::Release);
self.notify.notify_waiters();
self.notify.notify_one();
}
}
#[cfg(test)]
mod tests {
use super::TripWire;
#[test]
fn ensure_is_send_sync_clone_unpin() {
fn is_send_sync_clone_unpin<T: Send + Sync + Clone + Unpin>() {}
is_send_sync_clone_unpin::<TripWire>();
}
#[tokio::test]
async fn simple_trip() {
let wire = TripWire::new();
wire.trip();
wire.await;
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn no_trip() {
use tokio::time::{sleep, Duration};
use futures::stream::{FuturesUnordered as Set, StreamExt};
use futures::future::{BoxFuture, FutureExt};
let wire = TripWire::new();
let mut futs: Set<BoxFuture<'static, bool>> = Set::new();
for _ in 0..10 {
futs.push(Box::pin(wire.clone().map(|_| false)));
}
let sleep = sleep(Duration::from_secs(1));
futs.push(Box::pin(sleep.map(|_| true)));
assert!(futs.next().await.unwrap());
}
#[tokio::test(flavor = "multi_thread", worker_threads = 10)]
async fn general_trip() {
let wire = TripWire::new();
let mut tasks = vec![];
for _ in 0..1000 {
tasks.push(tokio::spawn(wire.clone()));
tokio::task::yield_now().await;
}
wire.trip();
for task in tasks {
task.await.unwrap();
}
}
#[tokio::test(flavor = "multi_thread", worker_threads = 10)]
async fn single_stage_trip() {
let mut tasks = vec![];
for i in 0..1000 {
// Trip once every 100. 50 will be left "untripped", but should be.
if i % 2 == 0 {
let wire = TripWire::new();
tasks.push(tokio::spawn(wire.clone()));
tasks.push(tokio::spawn(async move { wire.trip() }));
} else {
let wire = TripWire::new();
let wire2 = wire.clone();
tasks.push(tokio::spawn(async move { wire.trip() }));
tasks.push(tokio::spawn(wire2));
}
}
for task in tasks {
task.await.unwrap();
}
}
#[tokio::test(flavor = "multi_thread", worker_threads = 10)]
async fn staged_trip() {
let wire = TripWire::new();
let mut tasks = vec![];
for i in 0..1050 {
let wire = wire.clone();
// Trip once every 100. 50 will be left "untripped", but should be.
let task = if i % 100 == 0 {
tokio::spawn(async move { wire.trip() })
} else {
tokio::spawn(wire)
};
if i % 20 == 0 {
tokio::task::yield_now().await;
}
tasks.push(task);
}
for task in tasks {
task.await.unwrap();
}
}
}