mirror of https://github.com/rwf2/Rocket.git
Implement graceful shutdown.
The crux of the implementation is as follows: * Configurable ctrl-c, signals that trigger a graceful shutdown. * Configurable grace period before forced I/O termination. * Programatic triggering via an application-wide method. * A future (`Shutdown`) that resolves only when shutdown is requested. Resolves #180.
This commit is contained in:
parent
63e6845386
commit
a72e8da735
|
@ -24,7 +24,10 @@ pub trait Listener {
|
|||
fn local_addr(&self) -> Option<SocketAddr>;
|
||||
|
||||
/// Try to accept an incoming Connection if ready
|
||||
fn poll_accept(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<Self::Connection>>;
|
||||
fn poll_accept(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>
|
||||
) -> Poll<io::Result<Self::Connection>>;
|
||||
}
|
||||
|
||||
/// A 'Connection' represents an open connection to a client
|
||||
|
@ -40,16 +43,17 @@ pin_project_lite::pin_project! {
|
|||
/// Accept). This type is internal to Rocket.
|
||||
#[must_use = "streams do nothing unless polled"]
|
||||
pub struct Incoming<L> {
|
||||
listener: L,
|
||||
sleep_on_errors: Option<Duration>,
|
||||
#[pin]
|
||||
pending_error_delay: Option<Sleep>,
|
||||
#[pin]
|
||||
listener: L,
|
||||
}
|
||||
}
|
||||
|
||||
impl<L: Listener> Incoming<L> {
|
||||
/// Construct an `Incoming` from an existing `Listener`.
|
||||
pub fn from_listener(listener: L) -> Self {
|
||||
pub fn new(listener: L) -> Self {
|
||||
Self {
|
||||
listener,
|
||||
sleep_on_errors: Some(Duration::from_millis(250)),
|
||||
|
@ -96,7 +100,7 @@ impl<L: Listener> Incoming<L> {
|
|||
|
||||
me.pending_error_delay.set(None);
|
||||
|
||||
match me.listener.poll_accept(cx) {
|
||||
match me.listener.as_mut().poll_accept(cx) {
|
||||
Poll::Ready(Ok(stream)) => {
|
||||
return Poll::Ready(Ok(stream));
|
||||
},
|
||||
|
@ -123,7 +127,7 @@ impl<L: Listener> Incoming<L> {
|
|||
}
|
||||
}
|
||||
|
||||
impl<L: Listener + Unpin> Accept for Incoming<L> {
|
||||
impl<L: Listener> Accept for Incoming<L> {
|
||||
type Conn = L::Connection;
|
||||
type Error = io::Error;
|
||||
|
||||
|
@ -171,7 +175,10 @@ impl Listener for TcpListener {
|
|||
self.local_addr().ok()
|
||||
}
|
||||
|
||||
fn poll_accept(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<Self::Connection>> {
|
||||
fn poll_accept(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>
|
||||
) -> Poll<io::Result<Self::Connection>> {
|
||||
(*self).poll_accept(cx).map_ok(|(stream, _addr)| stream)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -64,7 +64,10 @@ impl Listener for TlsListener {
|
|||
self.listener.local_addr().ok()
|
||||
}
|
||||
|
||||
fn poll_accept(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<Self::Connection>> {
|
||||
fn poll_accept(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>
|
||||
) -> Poll<io::Result<Self::Connection>> {
|
||||
loop {
|
||||
match self.state {
|
||||
TlsListenerState::Listening => {
|
||||
|
|
|
@ -7,7 +7,7 @@ use figment::value::{Map, Dict};
|
|||
use serde::{Deserialize, Serialize};
|
||||
use yansi::Paint;
|
||||
|
||||
use crate::config::{TlsConfig, LogLevel};
|
||||
use crate::config::{TlsConfig, LogLevel, Shutdown};
|
||||
use crate::request::{self, Request, FromRequest};
|
||||
use crate::data::Limits;
|
||||
|
||||
|
@ -82,17 +82,16 @@ pub struct Config {
|
|||
#[cfg_attr(nightly, doc(cfg(feature = "secrets")))]
|
||||
#[serde(serialize_with = "SecretKey::serialize_zero")]
|
||||
pub secret_key: SecretKey,
|
||||
/// The directory to store temporary files in. **(default:
|
||||
/// [`std::env::temp_dir`]).
|
||||
/// Directory to store temporary files in. **(default:
|
||||
/// [`std::env::temp_dir()`])**
|
||||
pub temp_dir: PathBuf,
|
||||
/// Max level to log. **(default: _debug_ `normal` / _release_ `critical`)**
|
||||
pub log_level: LogLevel,
|
||||
/// Graceful shutdown configuration. **(default: [`Shutdown::default()`])**
|
||||
pub shutdown: Shutdown,
|
||||
/// Whether to use colors and emoji when logging. **(default: `true`)**
|
||||
#[serde(deserialize_with = "figment::util::bool_from_str_or_int")]
|
||||
pub cli_colors: bool,
|
||||
/// Whether `ctrl-c` initiates a server shutdown. **(default: `true`)**
|
||||
#[serde(deserialize_with = "figment::util::bool_from_str_or_int")]
|
||||
pub ctrlc: bool,
|
||||
}
|
||||
|
||||
impl Default for Config {
|
||||
|
@ -152,7 +151,7 @@ impl Config {
|
|||
temp_dir: std::env::temp_dir(),
|
||||
log_level: LogLevel::Normal,
|
||||
cli_colors: true,
|
||||
ctrlc: true,
|
||||
shutdown: Shutdown::default(),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -318,6 +317,7 @@ impl Config {
|
|||
launch_info_!("temp dir: {}", Paint::default(&self.temp_dir.display()).bold());
|
||||
launch_info_!("log level: {}", Paint::default(self.log_level).bold());
|
||||
launch_info_!("cli colors: {}", Paint::default(&self.cli_colors).bold());
|
||||
launch_info_!("shutdown: {}", Paint::default(&self.shutdown).bold());
|
||||
|
||||
// Check for now depreacted config values.
|
||||
for (key, replacement) in Self::DEPRECATED_KEYS {
|
||||
|
@ -398,8 +398,8 @@ impl Config {
|
|||
/// The stringy parameter name for setting/extracting [`Config::log_level`].
|
||||
pub const LOG_LEVEL: &'static str = "log_level";
|
||||
|
||||
/// The stringy parameter name for setting/extracting [`Config::ctrlc`].
|
||||
pub const CTRLC: &'static str = "ctrlc";
|
||||
/// The stringy parameter name for setting/extracting [`Config::shutdown`].
|
||||
pub const SHUTDOWN: &'static str = "shutdown";
|
||||
}
|
||||
|
||||
impl Provider for Config {
|
||||
|
|
|
@ -113,6 +113,7 @@
|
|||
|
||||
mod config;
|
||||
mod tls;
|
||||
mod shutdown;
|
||||
|
||||
#[cfg(feature = "secrets")]
|
||||
mod secret_key;
|
||||
|
@ -121,19 +122,24 @@ mod secret_key;
|
|||
|
||||
pub use config::Config;
|
||||
pub use crate::log::LogLevel;
|
||||
pub use shutdown::Shutdown;
|
||||
pub use tls::TlsConfig;
|
||||
|
||||
#[cfg(feature = "secrets")]
|
||||
#[cfg_attr(nightly, doc(cfg(feature = "secrets")))]
|
||||
pub use secret_key::SecretKey;
|
||||
|
||||
#[cfg(unix)]
|
||||
#[cfg_attr(nightly, doc(cfg(unix)))]
|
||||
pub use shutdown::Sig;
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::net::Ipv4Addr;
|
||||
use figment::{Figment, Profile};
|
||||
use pretty_assertions::assert_eq;
|
||||
|
||||
use crate::config::{Config, TlsConfig};
|
||||
use crate::config::{Config, TlsConfig, Shutdown};
|
||||
use crate::log::LogLevel;
|
||||
use crate::data::{Limits, ToByteUnit};
|
||||
|
||||
|
@ -217,7 +223,7 @@ mod tests {
|
|||
|
||||
jail.create_file("Rocket.toml", r#"
|
||||
[global]
|
||||
ctrlc = 0
|
||||
shutdown.ctrlc = 0
|
||||
|
||||
[global.tls]
|
||||
certs = "/ssl/cert.pem"
|
||||
|
@ -231,7 +237,7 @@ mod tests {
|
|||
|
||||
let config = Config::from(Config::figment());
|
||||
assert_eq!(config, Config {
|
||||
ctrlc: false,
|
||||
shutdown: Shutdown { ctrlc: false, ..Default::default() },
|
||||
tls: Some(TlsConfig::from_paths("/ssl/cert.pem", "/ssl/key.pem")),
|
||||
limits: Limits::default()
|
||||
.limit("forms", 1.mebibytes())
|
||||
|
|
|
@ -0,0 +1,273 @@
|
|||
use std::fmt;
|
||||
use std::future::Future;
|
||||
|
||||
#[cfg(unix)]
|
||||
use std::collections::HashSet;
|
||||
|
||||
use futures::future::{Either, pending};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// A Unix signal for triggering graceful shutdown.
|
||||
///
|
||||
/// Each variant corresponds to a Unix process signal which can be used to
|
||||
/// trigger a graceful shutdown. See [`Shutdown`] for details.
|
||||
///
|
||||
/// ## (De)serialization
|
||||
///
|
||||
/// A `Sig` variant serializes and deserializes as a lowercase string equal to
|
||||
/// the name of the variant: `"alrm"` for [`Sig::Alrm`], `"chld"` for
|
||||
/// [`Sig::Chld`], and so on.
|
||||
#[cfg(unix)]
|
||||
#[cfg_attr(nightly, doc(cfg(unix)))]
|
||||
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum Sig {
|
||||
/// The `SIGALRM` Unix signal.
|
||||
Alrm,
|
||||
/// The `SIGCHLD` Unix signal.
|
||||
Chld,
|
||||
/// The `SIGHUP` Unix signal.
|
||||
Hup,
|
||||
/// The `SIGINT` Unix signal.
|
||||
Int,
|
||||
/// The `SIGIO` Unix signal.
|
||||
Io,
|
||||
/// The `SIGPIPE` Unix signal.
|
||||
Pipe,
|
||||
/// The `SIGQUIT` Unix signal.
|
||||
Quit,
|
||||
/// The `SIGTERM` Unix signal.
|
||||
Term,
|
||||
/// The `SIGUSR1` Unix signal.
|
||||
Usr1,
|
||||
/// The `SIGUSR2` Unix signal.
|
||||
Usr2
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
#[cfg_attr(nightly, doc(cfg(unix)))]
|
||||
impl fmt::Display for Sig {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
let s = match self {
|
||||
Sig::Alrm => "SIGALRM",
|
||||
Sig::Chld => "SIGCHLD",
|
||||
Sig::Hup => "SIGHUP",
|
||||
Sig::Int => "SIGINT",
|
||||
Sig::Io => "SIGIO",
|
||||
Sig::Pipe => "SIGPIPE",
|
||||
Sig::Quit => "SIGQUIT",
|
||||
Sig::Term => "SIGTERM",
|
||||
Sig::Usr1 => "SIGUSR1",
|
||||
Sig::Usr2 => "SIGUSR2",
|
||||
};
|
||||
|
||||
s.fmt(f)
|
||||
}
|
||||
}
|
||||
|
||||
/// Graceful shutdown configuration.
|
||||
///
|
||||
/// This structure configures when and how graceful shutdown occurs. The `ctrlc`
|
||||
/// and `signals` properties control _when_ and the `grace` property controls
|
||||
/// _how_.
|
||||
///
|
||||
/// # Triggers
|
||||
///
|
||||
/// _All_ graceful shutdowns are initiated via
|
||||
/// [`Shutdown::notify()`](crate::Shutdown::notify()). Rocket can be configured
|
||||
/// to trigger shutdown automatically on certain conditions, specified via the
|
||||
/// `ctrlc` and `signals` properties of this structure. More specifically, if
|
||||
/// `ctrlc` is `true` (the default), `ctrl-c` (`SIGINT`) initiates a server
|
||||
/// shutdown, and on Unix, `signals` specifies a list of IPC signals that
|
||||
/// trigger a shutdown (`["term"]` by default).
|
||||
///
|
||||
/// # Grace Period
|
||||
///
|
||||
/// Once a shutdown is triggered, Rocket stops accepting new connections and
|
||||
/// waits at most `grace` seconds before force-closing all outstanding I/O.
|
||||
/// Applications can `await` the [`Shutdown`](crate::Shutdown) future to detect
|
||||
/// a shutdown and cancel any server-initiated I/O, such, as from [infinite
|
||||
/// responders](crate::response::stream#graceful-shutdown), to avoid abrupt I/O
|
||||
/// cancellation.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// As with all Rocket configuration options, when using the default
|
||||
/// [`Config::figment()`](crate::Config::figment()), `Shutdown` can be
|
||||
/// configured via a `Rocket.toml` file. As always, defaults are provided
|
||||
/// (documented below), and thus configuration only needs to provided to change
|
||||
/// defaults.
|
||||
///
|
||||
/// ```rust
|
||||
/// # use rocket::figment::{Figment, providers::{Format, Toml}};
|
||||
/// use rocket::{Rocket, Config};
|
||||
///
|
||||
/// // If these are the contents of `Rocket.toml`...
|
||||
/// # let toml = Toml::string(r#"
|
||||
/// [default.shutdown]
|
||||
/// ctrlc = false
|
||||
/// signals = ["term", "hup"]
|
||||
/// grace = 10
|
||||
/// # "#).nested();
|
||||
///
|
||||
/// // The config parses as follows:
|
||||
/// # let config = Config::from(Figment::from(Config::debug_default()).merge(toml));
|
||||
/// assert_eq!(config.shutdown.ctrlc, false);
|
||||
/// assert_eq!(config.shutdown.grace, 10);
|
||||
///
|
||||
/// # #[cfg(unix)] {
|
||||
/// use rocket::config::Sig;
|
||||
///
|
||||
/// assert_eq!(config.shutdown.signals.len(), 2);
|
||||
/// assert!(config.shutdown.signals.contains(&Sig::Term));
|
||||
/// assert!(config.shutdown.signals.contains(&Sig::Hup));
|
||||
/// # }
|
||||
/// ```
|
||||
///
|
||||
/// Or, as with all configuration options, programatically:
|
||||
///
|
||||
/// ```rust
|
||||
/// # use rocket::figment::{Figment, providers::{Format, Toml}};
|
||||
/// use rocket::{Rocket, Config};
|
||||
/// use rocket::config::Shutdown;
|
||||
///
|
||||
/// #[cfg(unix)]
|
||||
/// use rocket::config::Sig;
|
||||
///
|
||||
/// let config = Config {
|
||||
/// shutdown: Shutdown {
|
||||
/// ctrlc: false,
|
||||
/// #[cfg(unix)]
|
||||
/// signals: {
|
||||
/// let mut set = std::collections::HashSet::new();
|
||||
/// set.insert(Sig::Term);
|
||||
/// set.insert(Sig::Hup);
|
||||
/// set
|
||||
/// },
|
||||
/// grace: 10
|
||||
/// },
|
||||
/// ..Config::default()
|
||||
/// };
|
||||
///
|
||||
/// assert_eq!(config.shutdown.ctrlc, false);
|
||||
/// assert_eq!(config.shutdown.grace, 10);
|
||||
///
|
||||
/// #[cfg(unix)] {
|
||||
/// assert_eq!(config.shutdown.signals.len(), 2);
|
||||
/// assert!(config.shutdown.signals.contains(&Sig::Term));
|
||||
/// assert!(config.shutdown.signals.contains(&Sig::Hup));
|
||||
/// }
|
||||
/// ```
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
pub struct Shutdown {
|
||||
/// Whether `ctrl-c` (`SIGINT`) initiates a server shutdown.
|
||||
///
|
||||
/// **default: `true`**
|
||||
#[serde(deserialize_with = "figment::util::bool_from_str_or_int")]
|
||||
pub ctrlc: bool,
|
||||
/// On Unix, a set of signal which trigger a shutdown. On non-Unix, this
|
||||
/// option is unavailable and silently ignored.
|
||||
///
|
||||
/// **default: { [`Sig::Term`] }**
|
||||
#[cfg(unix)]
|
||||
#[cfg_attr(nightly, doc(cfg(unix)))]
|
||||
pub signals: HashSet<Sig>,
|
||||
/// The shutdown grace period: number of seconds to continue to try to
|
||||
/// finish outstanding I/O for before forcibly terminating it.
|
||||
///
|
||||
/// **default: `5`**
|
||||
pub grace: u32,
|
||||
}
|
||||
|
||||
impl fmt::Display for Shutdown {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(f, "ctrlc = {}, ", self.ctrlc)?;
|
||||
|
||||
#[cfg(unix)] {
|
||||
write!(f, "signals = [")?;
|
||||
for (i, sig) in self.signals.iter().enumerate() {
|
||||
if i != 0 { write!(f, ", ")?; }
|
||||
write!(f, "{}", sig)?;
|
||||
}
|
||||
write!(f, "], ")?;
|
||||
}
|
||||
|
||||
write!(f, "grace = {}s", self.grace)?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for Shutdown {
|
||||
fn default() -> Self {
|
||||
Shutdown {
|
||||
ctrlc: true,
|
||||
#[cfg(unix)]
|
||||
signals: { let mut set = HashSet::new(); set.insert(Sig::Term); set },
|
||||
grace: 5,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Shutdown {
|
||||
#[cfg(unix)]
|
||||
pub(crate) fn collective_signal(&self) -> impl Future<Output = ()> {
|
||||
use futures::future::{FutureExt, select_all};
|
||||
use tokio::signal::unix::{signal, SignalKind};
|
||||
|
||||
if !self.ctrlc && self.signals.is_empty() {
|
||||
return Either::Right(pending());
|
||||
}
|
||||
|
||||
let mut signals = self.signals.clone();
|
||||
if self.ctrlc {
|
||||
signals.insert(Sig::Int);
|
||||
}
|
||||
|
||||
let mut sigfuts = vec![];
|
||||
for sig in signals {
|
||||
let sigkind = match sig {
|
||||
Sig::Alrm => SignalKind::alarm(),
|
||||
Sig::Chld => SignalKind::child(),
|
||||
Sig::Hup => SignalKind::hangup(),
|
||||
Sig::Int => SignalKind::interrupt(),
|
||||
Sig::Io => SignalKind::io(),
|
||||
Sig::Pipe => SignalKind::pipe(),
|
||||
Sig::Quit => SignalKind::quit(),
|
||||
Sig::Term => SignalKind::terminate(),
|
||||
Sig::Usr1 => SignalKind::user_defined1(),
|
||||
Sig::Usr2 => SignalKind::user_defined2()
|
||||
};
|
||||
|
||||
let sigfut = match signal(sigkind) {
|
||||
Ok(mut signal) => Box::pin(async move {
|
||||
signal.recv().await;
|
||||
warn!("Received {} signal. Requesting shutdown.", sig);
|
||||
}),
|
||||
Err(e) => {
|
||||
warn!("Failed to enable `{}` shutdown signal.", sig);
|
||||
info_!("Error: {}", e);
|
||||
continue
|
||||
}
|
||||
};
|
||||
|
||||
sigfuts.push(sigfut);
|
||||
}
|
||||
|
||||
Either::Left(select_all(sigfuts).map(|_| ()))
|
||||
}
|
||||
|
||||
#[cfg(not(unix))]
|
||||
pub(crate) fn collective_signal(&self) -> impl Future<Output = ()> {
|
||||
use futures::future::FutureExt;
|
||||
|
||||
match self.ctrlc {
|
||||
true => Either::Left(tokio::signal::ctrl_c().map(|result| {
|
||||
if let Err(e) = result {
|
||||
warn!("Failed to enable `ctrl-c` shutdown signal.");
|
||||
info_!("Error: {}", e);
|
||||
}
|
||||
})),
|
||||
false => Either::Right(pending()),
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,11 +1,14 @@
|
|||
use std::io;
|
||||
use std::pin::Pin;
|
||||
use std::{io, time::Duration};
|
||||
use std::task::{Poll, Context};
|
||||
use std::pin::Pin;
|
||||
|
||||
use bytes::BytesMut;
|
||||
use tokio::io::{AsyncRead, ReadBuf};
|
||||
use pin_project_lite::pin_project;
|
||||
use futures::{ready, stream::Stream};
|
||||
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
|
||||
use tokio::time::{sleep, Sleep};
|
||||
|
||||
use futures::stream::Stream;
|
||||
use futures::future::{Future, Fuse, FutureExt};
|
||||
|
||||
use crate::http::hyper::Bytes;
|
||||
|
||||
|
@ -115,7 +118,7 @@ impl<T: AsyncRead, U: AsyncRead> AsyncRead for Chain<T, U> {
|
|||
|
||||
if !*me.done_first {
|
||||
let init_rem = buf.remaining();
|
||||
ready!(me.first.poll_read(cx, buf))?;
|
||||
futures::ready!(me.first.poll_read(cx, buf))?;
|
||||
if buf.remaining() == init_rem {
|
||||
*me.done_first = true;
|
||||
} else {
|
||||
|
@ -125,3 +128,138 @@ impl<T: AsyncRead, U: AsyncRead> AsyncRead for Chain<T, U> {
|
|||
me.second.poll_read(cx, buf)
|
||||
}
|
||||
}
|
||||
|
||||
pin_project! {
|
||||
/// I/O that can be cancelled when a future `F` resolves.
|
||||
#[must_use = "futures do nothing unless polled"]
|
||||
pub struct CancellableIo<F, I> {
|
||||
#[pin]
|
||||
io: I,
|
||||
#[pin]
|
||||
trigger: Fuse<F>,
|
||||
sleep: Option<Pin<Box<Sleep>>>,
|
||||
grace: Duration,
|
||||
}
|
||||
}
|
||||
|
||||
impl<F: Future, I> CancellableIo<F, I> {
|
||||
pub fn new(trigger: F, io: I, grace: Duration) -> Self {
|
||||
CancellableIo {
|
||||
trigger: trigger.fuse(),
|
||||
sleep: None,
|
||||
io, grace,
|
||||
}
|
||||
}
|
||||
|
||||
fn poll_trigger(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>
|
||||
) -> io::Result<()> {
|
||||
let me = self.project();
|
||||
|
||||
if me.trigger.poll(cx).is_ready() {
|
||||
*me.sleep = Some(Box::pin(sleep(*me.grace)));
|
||||
}
|
||||
|
||||
if let Some(sleep) = me.sleep {
|
||||
if sleep.as_mut().poll(cx).is_ready() {
|
||||
return Err(io::Error::new(io::ErrorKind::TimedOut, "..."));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl<F: Future, I: AsyncRead> AsyncRead for CancellableIo<F, I> {
|
||||
fn poll_read(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &mut ReadBuf<'_>,
|
||||
) -> Poll<io::Result<()>> {
|
||||
self.as_mut().poll_trigger(cx)?;
|
||||
self.as_mut().project().io.poll_read(cx, buf)
|
||||
}
|
||||
}
|
||||
|
||||
impl<F: Future, I: AsyncWrite> AsyncWrite for CancellableIo<F, I> {
|
||||
fn poll_write(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &[u8],
|
||||
) -> Poll<Result<usize, io::Error>> {
|
||||
self.as_mut().poll_trigger(cx)?;
|
||||
self.as_mut().project().io.poll_write(cx, buf)
|
||||
}
|
||||
|
||||
fn poll_flush(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>
|
||||
) -> Poll<io::Result<()>> {
|
||||
self.as_mut().poll_trigger(cx)?;
|
||||
self.as_mut().project().io.poll_flush(cx)
|
||||
}
|
||||
|
||||
fn poll_shutdown(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>
|
||||
) -> Poll<Result<(), io::Error>> {
|
||||
self.as_mut().poll_trigger(cx)?;
|
||||
self.as_mut().project().io.poll_shutdown(cx)
|
||||
}
|
||||
|
||||
fn poll_write_vectored(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
bufs: &[io::IoSlice<'_>],
|
||||
) -> Poll<Result<usize, io::Error>> {
|
||||
self.as_mut().poll_trigger(cx)?;
|
||||
self.as_mut().project().io.poll_write_vectored(cx, bufs)
|
||||
}
|
||||
|
||||
fn is_write_vectored(&self) -> bool {
|
||||
self.io.is_write_vectored()
|
||||
}
|
||||
}
|
||||
|
||||
use crate::http::private::{Listener, Connection};
|
||||
|
||||
impl<F: Future, C: Connection> Connection for CancellableIo<F, C> {
|
||||
fn remote_addr(&self) -> Option<std::net::SocketAddr> {
|
||||
self.io.remote_addr()
|
||||
}
|
||||
}
|
||||
|
||||
pin_project! {
|
||||
pub struct CancellableListener<F, L> {
|
||||
pub trigger: F,
|
||||
#[pin]
|
||||
pub listener: L,
|
||||
pub grace: Duration,
|
||||
}
|
||||
}
|
||||
|
||||
impl<F, L> CancellableListener<F, L> {
|
||||
pub fn new(trigger: F, listener: L, grace: u64) -> Self {
|
||||
CancellableListener { trigger, listener, grace: Duration::from_secs(grace) }
|
||||
}
|
||||
}
|
||||
|
||||
impl<L: Listener, F: Future + Clone> Listener for CancellableListener<F, L> {
|
||||
type Connection = CancellableIo<F, L::Connection>;
|
||||
|
||||
fn local_addr(&self) -> Option<std::net::SocketAddr> {
|
||||
self.listener.local_addr()
|
||||
}
|
||||
|
||||
fn poll_accept(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>
|
||||
) -> Poll<io::Result<Self::Connection>> {
|
||||
self.as_mut().project().listener
|
||||
.poll_accept(cx)
|
||||
.map(|res| res.map(|conn| {
|
||||
CancellableIo::new(self.trigger.clone(), conn, self.grace)
|
||||
}))
|
||||
}
|
||||
}
|
||||
|
|
|
@ -104,8 +104,8 @@
|
|||
|
||||
/// These are public dependencies! Update docs if these are changed, especially
|
||||
/// figment's version number in docs.
|
||||
#[doc(hidden)]
|
||||
pub use yansi;
|
||||
#[doc(hidden)] pub use yansi;
|
||||
#[doc(hidden)] pub use async_stream;
|
||||
pub use futures;
|
||||
pub use tokio;
|
||||
pub use figment;
|
||||
|
@ -139,6 +139,8 @@ pub mod http {
|
|||
pub use crate::cookies::*;
|
||||
}
|
||||
|
||||
/// TODO: We need a futures mod or something.
|
||||
mod trip_wire;
|
||||
mod shutdown;
|
||||
mod server;
|
||||
mod ext;
|
||||
|
@ -183,7 +185,7 @@ pub use async_trait::async_trait;
|
|||
|
||||
/// WARNING: This is unstable! Do not use this method outside of Rocket!
|
||||
#[doc(hidden)]
|
||||
pub fn async_test<R>(fut: impl std::future::Future<Output = R> + Send) -> R {
|
||||
pub fn async_test<R>(fut: impl std::future::Future<Output = R>) -> R {
|
||||
tokio::runtime::Builder::new_multi_thread()
|
||||
.thread_name("rocket-test-worker-thread")
|
||||
.worker_threads(1)
|
||||
|
|
|
@ -1,10 +1,7 @@
|
|||
use std::sync::Arc;
|
||||
|
||||
use state::Container;
|
||||
use figment::Figment;
|
||||
use tokio::sync::Notify;
|
||||
|
||||
use crate::{Route, Catcher, Config, Rocket};
|
||||
use crate::{Catcher, Config, Rocket, Route, Shutdown};
|
||||
use crate::router::Router;
|
||||
use crate::fairing::Fairings;
|
||||
|
||||
|
@ -100,7 +97,7 @@ phases! {
|
|||
pub(crate) figment: Figment,
|
||||
pub(crate) config: Config,
|
||||
pub(crate) state: Container![Send + Sync],
|
||||
pub(crate) shutdown: Arc<Notify>,
|
||||
pub(crate) shutdown: Shutdown,
|
||||
}
|
||||
|
||||
/// The final launch [`Phase`].
|
||||
|
@ -113,6 +110,6 @@ phases! {
|
|||
pub(crate) figment: Figment,
|
||||
pub(crate) config: Config,
|
||||
pub(crate) state: Container![Send + Sync],
|
||||
pub(crate) shutdown: Arc<Notify>,
|
||||
pub(crate) shutdown: Shutdown,
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,15 +1,14 @@
|
|||
use std::fmt;
|
||||
use std::ops::{Deref, DerefMut};
|
||||
use std::convert::TryInto;
|
||||
use std::sync::Arc;
|
||||
|
||||
use yansi::Paint;
|
||||
use either::Either;
|
||||
use tokio::sync::Notify;
|
||||
use figment::{Figment, Provider};
|
||||
|
||||
use crate::{Route, Catcher, Config, Shutdown, sentinel};
|
||||
use crate::{Catcher, Config, Route, Shutdown, sentinel};
|
||||
use crate::router::Router;
|
||||
use crate::trip_wire::TripWire;
|
||||
use crate::fairing::{Fairing, Fairings};
|
||||
use crate::phase::{Phase, Build, Building, Ignite, Igniting, Orbit, Orbiting};
|
||||
use crate::phase::{Stateful, StateRef, State};
|
||||
|
@ -505,7 +504,7 @@ impl Rocket<Build> {
|
|||
// Ignite the rocket.
|
||||
let rocket: Rocket<Ignite> = Rocket(Igniting {
|
||||
router, config,
|
||||
shutdown: Arc::new(Notify::new()),
|
||||
shutdown: Shutdown(TripWire::new()),
|
||||
figment: self.0.figment,
|
||||
fairings: self.0.fairings,
|
||||
state: self.0.state,
|
||||
|
@ -553,18 +552,14 @@ impl Rocket<Ignite> {
|
|||
&self.config
|
||||
}
|
||||
|
||||
/// Returns a handle which can be used to notify this instance of Rocket to
|
||||
/// stop serving connections, resolving the future returned by
|
||||
/// [`Rocket::launch()`]. If [`Shutdown::notify()`] is called _before_ the
|
||||
/// instance is launched, it will be immediately shutdown after liftoff.
|
||||
/// Returns a handle which can be used to trigger a shutdown and detect a
|
||||
/// triggered shutdown.
|
||||
///
|
||||
/// # Caveats
|
||||
///
|
||||
/// Due to [bugs](https://github.com/hyperium/hyper/issues/1885) in Rocket's
|
||||
/// upstream HTTP library, graceful shutdown currently works by stopping new
|
||||
/// connections from arriving without stopping in-process connections from
|
||||
/// sending or receiving. As a result, shutdown will stall if a response is
|
||||
/// infinite or if a client stalls a connection.
|
||||
/// A completed graceful shutdown resolves the future returned by
|
||||
/// [`Rocket::launch()`]. If [`Shutdown::notify()`] is called _before_ an
|
||||
/// instance is launched, it will be immediately shutdown after liftoff. See
|
||||
/// [`Shutdown`] and [`config::Shutdown`](crate::config::Shutdown) for
|
||||
/// details on graceful shutdown.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
|
@ -590,7 +585,7 @@ impl Rocket<Ignite> {
|
|||
/// }
|
||||
/// ```
|
||||
pub fn shutdown(&self) -> Shutdown {
|
||||
Shutdown(self.shutdown.clone())
|
||||
self.shutdown.clone()
|
||||
}
|
||||
|
||||
fn into_orbit(self) -> Rocket<Orbit> {
|
||||
|
@ -650,17 +645,13 @@ impl Rocket<Orbit> {
|
|||
&self.config
|
||||
}
|
||||
|
||||
/// Returns a handle which can be used to notify this instance of Rocket to
|
||||
/// stop serving connections, resolving the future returned by
|
||||
/// [`Rocket::launch()`].
|
||||
/// Returns a handle which can be used to trigger a shutdown and detect a
|
||||
/// triggered shutdown.
|
||||
///
|
||||
/// # Caveats
|
||||
///
|
||||
/// Due to [bugs](https://github.com/hyperium/hyper/issues/1885) in Rocket's
|
||||
/// upstream HTTP library, graceful shutdown currently works by stopping new
|
||||
/// connections from arriving without stopping in-process connections from
|
||||
/// sending or receiving. As a result, shutdown will stall if a response is
|
||||
/// infinite or if a client stalls a connection.
|
||||
/// A completed graceful shutdown resolves the future returned by
|
||||
/// [`Rocket::launch()`]. See [`Shutdown`] and
|
||||
/// [`config::Shutdown`](crate::config::Shutdown) for details on graceful
|
||||
/// shutdown.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
|
@ -682,7 +673,7 @@ impl Rocket<Orbit> {
|
|||
/// }
|
||||
/// ```
|
||||
pub fn shutdown(&self) -> Shutdown {
|
||||
Shutdown(self.shutdown.clone())
|
||||
self.shutdown.clone()
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -817,8 +808,7 @@ impl<P: Phase> Rocket<P> {
|
|||
///
|
||||
/// The `Future` resolves as an `Ok` if any of the following occur:
|
||||
///
|
||||
/// * the server is shutdown via [`Shutdown::notify()`].
|
||||
/// * if the `ctrlc` config option is `true`, when `Ctrl+C` is pressed.
|
||||
/// * graceful shutdown via [`Shutdown::notify()`] completes.
|
||||
///
|
||||
/// The `Future` does not resolve otherwise.
|
||||
///
|
||||
|
|
|
@ -6,12 +6,11 @@ use futures::future::{self, FutureExt, Future, TryFutureExt, BoxFuture};
|
|||
use tokio::sync::oneshot;
|
||||
use yansi::Paint;
|
||||
|
||||
use crate::{Rocket, Orbit, Request, Data, route};
|
||||
use crate::{Rocket, Orbit, Request, Response, Data, route};
|
||||
use crate::form::Form;
|
||||
use crate::response::{Response, Body};
|
||||
use crate::outcome::Outcome;
|
||||
use crate::error::{Error, ErrorKind};
|
||||
use crate::ext::AsyncReadExt;
|
||||
use crate::ext::{AsyncReadExt, CancellableListener, CancellableIo};
|
||||
|
||||
use crate::http::{Method, Status, Header, hyper};
|
||||
use crate::http::private::{Listener, Connection, Incoming};
|
||||
|
@ -117,7 +116,7 @@ impl Rocket<Orbit> {
|
|||
) {
|
||||
match self.make_response(response, tx).await {
|
||||
Ok(()) => info_!("{}", Paint::green("Response succeeded.")),
|
||||
Err(e) => error_!("Failed to write response: {:?}.", e),
|
||||
Err(e) => error_!("Failed to write response: {}.", e),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -393,8 +392,7 @@ impl Rocket<Orbit> {
|
|||
|
||||
// TODO.async: Solidify the Listener APIs and make this function public
|
||||
pub(crate) async fn http_server<L>(self, listener: L) -> Result<(), Error>
|
||||
where L: Listener + Send + Unpin + 'static,
|
||||
<L as Listener>::Connection: Send + Unpin + 'static
|
||||
where L: Listener + Send, <L as Listener>::Connection: Send + Unpin + 'static
|
||||
{
|
||||
// Determine keep-alives.
|
||||
let http1_keepalive = self.config.keep_alive != 0;
|
||||
|
@ -403,15 +401,16 @@ impl Rocket<Orbit> {
|
|||
n => Some(std::time::Duration::from_secs(n as u64))
|
||||
};
|
||||
|
||||
// Get the shutdown handle (to initiate) and signal (when initiated).
|
||||
let shutdown_handle = self.shutdown.clone();
|
||||
let shutdown_signal = match self.config.ctrlc {
|
||||
true => tokio::signal::ctrl_c().boxed(),
|
||||
false => future::pending().boxed(),
|
||||
};
|
||||
// Set up cancellable I/O from the given listener. Shutdown occurs when
|
||||
// `Shutdown` (`TripWire`) resolves. This can occur directly through a
|
||||
// notification or indirectly through an external signal which, when
|
||||
// received, results in triggering the notify.
|
||||
let shutdown = self.shutdown();
|
||||
let external_shutdown = self.config.shutdown.collective_signal();
|
||||
let grace = self.config.shutdown.grace as u64;
|
||||
|
||||
let rocket = Arc::new(self);
|
||||
let service = hyper::make_service_fn(move |conn: &<L as Listener>::Connection| {
|
||||
let service_fn = move |conn: &CancellableIo<_, L::Connection>| {
|
||||
let rocket = rocket.clone();
|
||||
let remote = conn.remote_addr().unwrap_or_else(|| ([0, 0, 0, 0], 0).into());
|
||||
async move {
|
||||
|
@ -419,33 +418,26 @@ impl Rocket<Orbit> {
|
|||
hyper_service_fn(rocket.clone(), remote, req)
|
||||
}))
|
||||
}
|
||||
});
|
||||
};
|
||||
|
||||
// NOTE: `hyper` uses `tokio::spawn()` as the default executor.
|
||||
let shutdown_receiver = shutdown_handle.clone();
|
||||
let server = hyper::Server::builder(Incoming::from_listener(listener))
|
||||
let listener = CancellableListener::new(shutdown.clone(), listener, grace);
|
||||
let server = hyper::Server::builder(Incoming::new(listener))
|
||||
.http1_keepalive(http1_keepalive)
|
||||
.http2_keep_alive_interval(http2_keep_alive)
|
||||
.serve(service)
|
||||
.with_graceful_shutdown(async move { shutdown_receiver.notified().await; })
|
||||
.serve(hyper::make_service_fn(service_fn))
|
||||
.with_graceful_shutdown(shutdown.clone())
|
||||
.map_err(|e| Error::new(ErrorKind::Runtime(Box::new(e))));
|
||||
|
||||
tokio::pin!(server);
|
||||
|
||||
let selecter = future::select(shutdown_signal, server);
|
||||
tokio::pin!(server, external_shutdown);
|
||||
let selecter = future::select(external_shutdown, server);
|
||||
match selecter.await {
|
||||
future::Either::Left((Ok(()), server)) => {
|
||||
// Ctrl-was pressed. Signal shutdown, wait for the server.
|
||||
shutdown_handle.notify_one();
|
||||
future::Either::Left((_, server)) => {
|
||||
// External signal received. Request shutdown, wait for server.
|
||||
shutdown.notify();
|
||||
server.await
|
||||
}
|
||||
future::Either::Left((Err(err), server)) => {
|
||||
// Error setting up ctrl-c signal. Let the user know.
|
||||
warn!("Failed to enable `ctrl-c` graceful signal shutdown.");
|
||||
info_!("Error: {}", err);
|
||||
server.await
|
||||
}
|
||||
// Server shut down before Ctrl-C; return the result.
|
||||
// Internal shutdown or server error. Return the result.
|
||||
future::Either::Right((result, _)) => result,
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,18 +1,45 @@
|
|||
use std::sync::Arc;
|
||||
use std::future::Future;
|
||||
use std::task::{Context, Poll};
|
||||
use std::pin::Pin;
|
||||
|
||||
use tokio::sync::Notify;
|
||||
use futures::FutureExt;
|
||||
|
||||
use crate::request::{FromRequest, Outcome, Request};
|
||||
use crate::trip_wire::TripWire;
|
||||
|
||||
/// A request guard to gracefully shutdown a Rocket server.
|
||||
/// A request guard and future for graceful shutdown.
|
||||
///
|
||||
/// A server shutdown is manually requested by calling [`Shutdown::notify()`]
|
||||
/// or, if enabled, by pressing `Ctrl-C`. Rocket will finish handling any
|
||||
/// pending requests and return `Ok()` to the caller of [`Rocket::launch()`].
|
||||
/// or, if enabled, through [automatic triggers] like `Ctrl-C`. Rocket will stop accepting new
|
||||
/// requests, finish handling any pending requests, wait a grace period before
|
||||
/// cancelling any outstanding I/O, and return `Ok()` to the caller of
|
||||
/// [`Rocket::launch()`]. Graceful shutdown is configured via
|
||||
/// [`config::Shutdown`](crate::config::Shutdown).
|
||||
///
|
||||
/// [`Rocket::launch()`]: crate::Rocket::launch()
|
||||
/// [automatic triggers]: crate::config::Shutdown#triggers
|
||||
///
|
||||
/// # Example
|
||||
/// # Detecting Shutdown
|
||||
///
|
||||
/// `Shutdown` is also a future that resolves when [`Shutdown::notify()`] is
|
||||
/// called. This can be used to detect shutdown in any part of the application:
|
||||
///
|
||||
/// ```rust
|
||||
/// # use rocket::*;
|
||||
/// use rocket::Shutdown;
|
||||
///
|
||||
/// #[get("/wait/for/shutdown")]
|
||||
/// async fn wait_for_shutdown(shutdown: Shutdown) -> &'static str {
|
||||
/// shutdown.await;
|
||||
/// "Somewhere, shutdown was requested."
|
||||
/// }
|
||||
/// ```
|
||||
///
|
||||
/// See the [`stream`](crate::response::stream#graceful-shutdown) docs for an
|
||||
/// example of detecting shutdown in an infinite responder.
|
||||
///
|
||||
/// Additionally, a completed shutdown request resolves the future returned from
|
||||
/// [`Rocket::launch()`](crate::Rocket::launch()):
|
||||
///
|
||||
/// ```rust,no_run
|
||||
/// # #[macro_use] extern crate rocket;
|
||||
|
@ -36,19 +63,32 @@ use crate::request::{FromRequest, Outcome, Request};
|
|||
/// result.expect("server failed unexpectedly");
|
||||
/// }
|
||||
/// ```
|
||||
#[must_use = "a shutdown request is only sent on `shutdown.notify()`"]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Shutdown(pub(crate) Arc<Notify>);
|
||||
#[must_use = "`Shutdown` does nothing unless polled or `notify`ed"]
|
||||
pub struct Shutdown(pub(crate) TripWire);
|
||||
|
||||
impl Shutdown {
|
||||
/// Notify Rocket to shut down gracefully.
|
||||
/// Notify the application to shut down gracefully.
|
||||
///
|
||||
/// This function returns immediately; pending requests will continue to run
|
||||
/// until completion before the actual shutdown occurs.
|
||||
/// until completion or expiration of the grace period, which ever comes
|
||||
/// first, before the actual shutdown occurs. The grace period can be
|
||||
/// configured via [`Shutdown::grace`](crate::config::Shutdown::grace).
|
||||
///
|
||||
/// ```rust
|
||||
/// # use rocket::*;
|
||||
/// use rocket::Shutdown;
|
||||
///
|
||||
/// #[get("/shutdown")]
|
||||
/// fn shutdown(shutdown: Shutdown) -> &'static str {
|
||||
/// shutdown.notify();
|
||||
/// "Shutting down..."
|
||||
/// }
|
||||
/// ```
|
||||
#[inline]
|
||||
pub fn notify(self) {
|
||||
self.0.notify_one();
|
||||
info!("Server shutdown requested, waiting for all pending requests to finish.");
|
||||
self.0.trip();
|
||||
info!("Shutdown requested. Waiting for pending I/O to finish...");
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -58,7 +98,25 @@ impl<'r> FromRequest<'r> for Shutdown {
|
|||
|
||||
#[inline]
|
||||
async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Self::Error> {
|
||||
let notifier = request.rocket().shutdown.clone();
|
||||
Outcome::Success(Shutdown(notifier))
|
||||
Outcome::Success(request.rocket().shutdown())
|
||||
}
|
||||
}
|
||||
|
||||
impl Future for Shutdown {
|
||||
type Output = ();
|
||||
|
||||
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
self.0.poll_unpin(cx)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::Shutdown;
|
||||
|
||||
#[test]
|
||||
fn ensure_is_send_sync_clone_unpin() {
|
||||
fn is_send_sync_clone_unpin<T: Send + Sync + Clone + Unpin>() {}
|
||||
is_send_sync_clone_unpin::<Shutdown>();
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,201 @@
|
|||
use std::fmt;
|
||||
use std::{ops::Deref, pin::Pin, future::Future};
|
||||
use std::task::{Context, Poll};
|
||||
use std::sync::{Arc, atomic::{AtomicBool, Ordering}};
|
||||
|
||||
use tokio::sync::Notify;
|
||||
|
||||
#[doc(hidden)]
|
||||
pub struct State {
|
||||
tripped: AtomicBool,
|
||||
notify: Notify,
|
||||
}
|
||||
|
||||
#[must_use = "`TripWire` does nothing unless polled or `trip()`ed"]
|
||||
pub struct TripWire {
|
||||
state: Arc<State>,
|
||||
// `Notified` is `!Unpin`. Even if we could name it, we'd need to pin it.
|
||||
event: Option<Pin<Box<dyn Future<Output = ()> + Send + Sync>>>,
|
||||
}
|
||||
|
||||
impl Deref for TripWire {
|
||||
type Target = State;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.state
|
||||
}
|
||||
}
|
||||
|
||||
impl Clone for TripWire {
|
||||
fn clone(&self) -> Self {
|
||||
TripWire {
|
||||
state: self.state.clone(),
|
||||
event: None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Debug for TripWire {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
f.debug_struct("TripWire")
|
||||
.field("tripped", &self.tripped)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl Future for TripWire {
|
||||
type Output = ();
|
||||
|
||||
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
if self.tripped.load(Ordering::Acquire) {
|
||||
self.event = None;
|
||||
return Poll::Ready(());
|
||||
}
|
||||
|
||||
if self.event.is_none() {
|
||||
let state = self.state.clone();
|
||||
self.event = Some(Box::pin(async move {
|
||||
let notified = state.notify.notified();
|
||||
notified.await
|
||||
}));
|
||||
}
|
||||
|
||||
if let Some(ref mut event) = self.event {
|
||||
if event.as_mut().poll(cx).is_ready() {
|
||||
// We need to call `trip()` to avoid a race condition where:
|
||||
// 1) many trip wires have seen !self.tripped but have not
|
||||
// polled for `self.event` yet, so are not subscribed
|
||||
// 2) trip() is called, adding a permit to `event`
|
||||
// 3) some trip wires poll `event` for the first time
|
||||
// 4) one of those wins, returns `Ready()`
|
||||
// 5) the rest return pending
|
||||
//
|
||||
// Without this `self.trip()` those will never be awoken. With
|
||||
// the call to self.trip(), those that made it to poll() in 3)
|
||||
// will be awoken by `notify_waiters()`. For those the didn't,
|
||||
// one will be awoken by `notify_one()`, which will in-turn call
|
||||
// self.trip(), awaking more until there are no more to awake.
|
||||
self.trip();
|
||||
self.event = None;
|
||||
return Poll::Ready(());
|
||||
}
|
||||
}
|
||||
|
||||
Poll::Pending
|
||||
}
|
||||
}
|
||||
|
||||
impl TripWire {
|
||||
pub fn new() -> Self {
|
||||
TripWire {
|
||||
state: Arc::new(State {
|
||||
tripped: AtomicBool::new(false),
|
||||
notify: Notify::new()
|
||||
}),
|
||||
event: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn trip(&self) {
|
||||
self.tripped.store(true, Ordering::Release);
|
||||
self.notify.notify_waiters();
|
||||
self.notify.notify_one();
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::TripWire;
|
||||
|
||||
#[test]
|
||||
fn ensure_is_send_sync_clone_unpin() {
|
||||
fn is_send_sync_clone_unpin<T: Send + Sync + Clone + Unpin>() {}
|
||||
is_send_sync_clone_unpin::<TripWire>();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn simple_trip() {
|
||||
let wire = TripWire::new();
|
||||
wire.trip();
|
||||
wire.await;
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn no_trip() {
|
||||
use tokio::time::{sleep, Duration};
|
||||
use futures::stream::{FuturesUnordered as Set, StreamExt};
|
||||
use futures::future::{BoxFuture, FutureExt};
|
||||
|
||||
let wire = TripWire::new();
|
||||
let mut futs: Set<BoxFuture<'static, bool>> = Set::new();
|
||||
for _ in 0..10 {
|
||||
futs.push(Box::pin(wire.clone().map(|_| false)));
|
||||
}
|
||||
|
||||
let sleep = sleep(Duration::from_secs(1));
|
||||
futs.push(Box::pin(sleep.map(|_| true)));
|
||||
assert!(futs.next().await.unwrap());
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 10)]
|
||||
async fn general_trip() {
|
||||
let wire = TripWire::new();
|
||||
let mut tasks = vec![];
|
||||
for _ in 0..1000 {
|
||||
tasks.push(tokio::spawn(wire.clone()));
|
||||
tokio::task::yield_now().await;
|
||||
}
|
||||
|
||||
wire.trip();
|
||||
for task in tasks {
|
||||
task.await.unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 10)]
|
||||
async fn single_stage_trip() {
|
||||
let mut tasks = vec![];
|
||||
for i in 0..1000 {
|
||||
// Trip once every 100. 50 will be left "untripped", but should be.
|
||||
if i % 2 == 0 {
|
||||
let wire = TripWire::new();
|
||||
tasks.push(tokio::spawn(wire.clone()));
|
||||
tasks.push(tokio::spawn(async move { wire.trip() }));
|
||||
} else {
|
||||
let wire = TripWire::new();
|
||||
let wire2 = wire.clone();
|
||||
tasks.push(tokio::spawn(async move { wire.trip() }));
|
||||
tasks.push(tokio::spawn(wire2));
|
||||
}
|
||||
}
|
||||
|
||||
for task in tasks {
|
||||
task.await.unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 10)]
|
||||
async fn staged_trip() {
|
||||
let wire = TripWire::new();
|
||||
let mut tasks = vec![];
|
||||
for i in 0..1050 {
|
||||
let wire = wire.clone();
|
||||
// Trip once every 100. 50 will be left "untripped", but should be.
|
||||
let task = if i % 100 == 0 {
|
||||
tokio::spawn(async move { wire.trip() })
|
||||
} else {
|
||||
tokio::spawn(wire)
|
||||
};
|
||||
|
||||
if i % 20 == 0 {
|
||||
tokio::task::yield_now().await;
|
||||
}
|
||||
|
||||
tasks.push(task);
|
||||
}
|
||||
|
||||
for task in tasks {
|
||||
task.await.unwrap();
|
||||
}
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue