Fully 'drop()' I/O struct in 'CancellableIo'.

This should improve the reliability of graceful shutdown.
Sergio Benitez 2022-04-30 15:41:07 -07:00
parent bf84b1cdb5
commit e9d46b917e
6 changed files with 122 additions and 158 deletions
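The core of the change is that `CancellableIo` now owns its underlying I/O object inside an `Option` and sets it to `None` once shutdown completes or times out, so the connection's resources are released even while hyper still holds the wrapper. A minimal, hypothetical sketch of that pattern, using standard-library types rather than Rocket's actual ones:

use std::io::{self, Write};

/// Hypothetical wrapper illustrating the pattern: own the inner I/O in an
/// `Option` so it can be dropped in full the moment shutdown finishes.
struct Cancellable<W: Write> {
    io: Option<W>,
}

impl<W: Write> Cancellable<W> {
    fn new(io: W) -> Self {
        Cancellable { io: Some(io) }
    }

    /// Mirrors the new `CancellableIo::io()` accessor: callers must handle
    /// the "already shut down" case.
    fn io(&self) -> Option<&W> {
        self.io.as_ref()
    }

    /// Once shutdown runs, drop the inner I/O outright instead of keeping it
    /// around behind a "terminated" marker state.
    fn shutdown(&mut self) -> io::Result<()> {
        if let Some(mut io) = self.io.take() {
            io.flush()?; // stand-in for the async `poll_shutdown` in the diff
        }

        Ok(())
    }
}

fn main() -> io::Result<()> {
    let mut conn = Cancellable::new(Vec::<u8>::new());
    assert!(conn.io().is_some());
    conn.shutdown()?;
    assert!(conn.io().is_none()); // the I/O value is gone, not just flagged
    Ok(())
}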


@@ -44,6 +44,7 @@ memchr = "2"
 stable-pattern = "0.1"
 cookie = { version = "0.16.0", features = ["percent-encode", "secure"] }
 state = "0.5.1"
+futures = { version = "0.3", default-features = false }

 [dependencies.x509-parser]
 version = "0.13"


@@ -44,7 +44,7 @@ pub mod uncased {
 pub mod private {
     pub use crate::parse::Indexed;
     pub use smallvec::{SmallVec, Array};
-    pub use crate::listener::{bind_tcp, Incoming, Listener, Connection, RawCertificate};
+    pub use crate::listener::{TcpListener, Incoming, Listener, Connection, RawCertificate};
     pub use cookie;
 }


@@ -6,12 +6,20 @@ use std::pin::Pin;
 use std::task::{Context, Poll};
 use std::time::Duration;

-use log::{debug, error};
+use log::warn;
 use hyper::server::accept::Accept;
 use tokio::time::Sleep;
 use tokio::io::{AsyncRead, AsyncWrite};
-use tokio::net::{TcpListener, TcpStream};
+use tokio::net::TcpStream;
+
+pub use tokio::net::TcpListener;
+
+/// A thin wrapper over raw, DER-encoded X.509 client certificate data.
+// NOTE: `rustls::Certificate` is exactly isomorphic to `RawCertificate`.
+#[doc(inline)]
+#[cfg(feature = "tls")]
+pub use rustls::Certificate as RawCertificate;

 // TODO.async: 'Listener' and 'Connection' provide common enough functionality
 // that they could be introduced in upstream libraries.
@@ -23,24 +31,14 @@ pub trait Listener {
     /// Return the actual address this listener bound to.
     fn local_addr(&self) -> Option<SocketAddr>;

-    /// Try to accept an incoming Connection if ready
+    /// Try to accept an incoming Connection if ready. This should only return
+    /// an `Err` when a fatal problem occurs as Hyper kills the server on `Err`.
     fn poll_accept(
         self: Pin<&mut Self>,
         cx: &mut Context<'_>
     ) -> Poll<io::Result<Self::Connection>>;
 }

-/// A thin wrapper over raw, DER-encoded X.509 client certificate data.
-#[cfg(not(feature = "tls"))]
-#[derive(Clone, Eq, PartialEq)]
-pub struct RawCertificate(pub Vec<u8>);
-
-/// A thin wrapper over raw, DER-encoded X.509 client certificate data.
-// NOTE: `rustls::Certificate` is exactly isomorphic to `RawCertificate`.
-#[doc(inline)]
-#[cfg(feature = "tls")]
-pub use rustls::Certificate as RawCertificate;
-
 /// A 'Connection' represents an open connection to a client
 pub trait Connection: AsyncRead + AsyncWrite {
     /// The remote address, i.e. the client's socket address, if it is known.
@@ -62,6 +60,11 @@ pub trait Connection: AsyncRead + AsyncWrite {
     fn peer_certificates(&self) -> Option<&[RawCertificate]> { None }
 }

+/// A thin wrapper over raw, DER-encoded X.509 client certificate data.
+#[cfg(not(feature = "tls"))]
+#[derive(Clone, Eq, PartialEq)]
+pub struct RawCertificate(pub Vec<u8>);
+
 pin_project_lite::pin_project! {
     /// This is a generic version of hyper's AddrIncoming that is intended to be
     /// usable with listeners other than a plain TCP stream, e.g. TLS and/or Unix
@@ -116,47 +119,46 @@ impl<L: Listener> Incoming<L> {
         self
     }

-    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<L::Connection>> {
-        let mut me = self.project();
-        let mut optimistic_retry = true;
+    fn poll_accept_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<L::Connection>> {
+        /// This function defines per-connection errors: errors that affect only
+        /// a single connection. Since the error affects only one connection, we
+        /// can attempt to `accept()` another connection immediately. All other
+        /// errors will incur a delay before the next `accept()` is performed.
+        /// The delay is useful to handle resource exhaustion errors like ENFILE
+        /// and EMFILE. Otherwise, could enter into tight loop.
+        fn is_connection_error(e: &io::Error) -> bool {
+            matches!(e.kind(),
+                | io::ErrorKind::ConnectionRefused
+                | io::ErrorKind::ConnectionAborted
+                | io::ErrorKind::ConnectionReset)
+        }
+
+        let mut this = self.project();

         loop {
-            // Check if a previous sleep timer is active that was set by IO errors.
-            if let Some(delay) = me.pending_error_delay.as_mut().as_pin_mut() {
-                if optimistic_retry {
-                    error!("optimistically retrying now");
-                    optimistic_retry = false;
-                } else {
-                    error!("retrying in {:?}", me.sleep_on_errors);
-                    match delay.poll(cx) {
-                        Poll::Ready(()) => {}
-                        Poll::Pending => return Poll::Pending,
-                    }
-                }
+            // Check if a previous sleep timer is active, set on I/O errors.
+            if let Some(delay) = this.pending_error_delay.as_mut().as_pin_mut() {
+                futures::ready!(delay.poll(cx));
             }

-            me.pending_error_delay.set(None);
+            this.pending_error_delay.set(None);

-            match me.listener.as_mut().poll_accept(cx) {
-                Poll::Ready(Ok(stream)) => {
-                    if *me.nodelay {
-                        let _ = stream.enable_nodelay();
+            match futures::ready!(this.listener.as_mut().poll_accept(cx)) {
+                Ok(stream) => {
+                    if *this.nodelay {
+                        if let Err(e) = stream.enable_nodelay() {
+                            warn!("failed to enable NODELAY: {}", e);
+                        }
                     }

                     return Poll::Ready(Ok(stream));
                 },
-                Poll::Pending => return Poll::Pending,
-                Poll::Ready(Err(e)) => {
-                    // Connection errors can be ignored directly, continue by
-                    // accepting the next request.
+                Err(e) => {
                     if is_connection_error(&e) {
-                        debug!("accepted connection already errored: {}", e);
-                        continue;
-                    }
-
-                    if let Some(duration) = me.sleep_on_errors {
-                        // Sleep for the specified duration
-                        error!("connection accept error: {}", e);
-                        me.pending_error_delay.set(Some(tokio::time::sleep(*duration)));
+                        warn!("single connection accept error {}; accepting next now", e);
+                    } else if let Some(duration) = this.sleep_on_errors {
+                        // We might be able to recover. Try again in a bit.
+                        warn!("accept error {}; recovery attempt in {}ms", e, duration.as_millis());
+                        this.pending_error_delay.set(Some(tokio::time::sleep(*duration)));
                     } else {
                         return Poll::Ready(Err(e));
                     }
@@ -174,24 +176,10 @@ impl<L: Listener> Accept for Incoming<L> {
         self: Pin<&mut Self>,
         cx: &mut Context<'_>
     ) -> Poll<Option<io::Result<Self::Conn>>> {
-        self.poll_next(cx).map(Some)
+        self.poll_accept_next(cx).map(Some)
     }
 }

-/// This function defines errors that are per-connection. Which basically
-/// means that if we get this error from `accept()` system call it means
-/// next connection might be ready to be accepted.
-///
-/// All other errors will incur a delay before next `accept()` is performed.
-/// The delay is useful to handle resource exhaustion errors like ENFILE
-/// and EMFILE. Otherwise, could enter into tight loop.
-fn is_connection_error(e: &io::Error) -> bool {
-    matches!(e.kind(),
-        io::ErrorKind::ConnectionRefused |
-        io::ErrorKind::ConnectionAborted |
-        io::ErrorKind::ConnectionReset)
-}
-
 impl<L: fmt::Debug> fmt::Debug for Incoming<L> {
     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
         f.debug_struct("Incoming")
@@ -200,11 +188,6 @@ impl<L: fmt::Debug> fmt::Debug for Incoming<L> {
     }
 }

-/// Binds a TCP listener to `address` and returns it.
-pub async fn bind_tcp(address: SocketAddr) -> io::Result<TcpListener> {
-    Ok(TcpListener::bind(address).await?)
-}
-
 impl Listener for TcpListener {
     type Connection = TcpStream;
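The `futures = { version = "0.3", default-features = false }` dependency added above is used here chiefly for the `ready!` macro: each `futures::ready!(...)` call returns `Poll::Pending` from the enclosing `poll_*` function when the inner value isn't ready yet, replacing the hand-written `match` arms the old code carried. A small, self-contained illustration of that behavior with a hypothetical future (assuming the `futures` crate plus `tokio` with its time, macros, and rt features):

use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::time::Duration;

use futures::ready;
use tokio::time::Sleep;

/// Hypothetical future: completes once its inner sleep elapses.
struct AfterDelay {
    delay: Pin<Box<Sleep>>,
}

impl Future for AfterDelay {
    type Output = &'static str;

    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // Equivalent to the removed pattern:
        //     match self.delay.as_mut().poll(cx) {
        //         Poll::Ready(()) => {}
        //         Poll::Pending => return Poll::Pending,
        //     }
        ready!(self.delay.as_mut().poll(cx));
        Poll::Ready("delay elapsed")
    }
}

#[tokio::main]
async fn main() {
    let delay = Box::pin(tokio::time::sleep(Duration::from_millis(10)));
    println!("{}", AfterDelay { delay }.await);
}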


@@ -5,6 +5,7 @@ use std::task::{Context, Poll};
 use std::net::SocketAddr;
 use std::future::Future;

+use futures::ready;
 use tokio_rustls::{TlsAcceptor, Accept, server::TlsStream};
 use tokio::net::{TcpListener, TcpStream};
@@ -20,7 +21,7 @@ pub struct TlsListener {

 enum State {
     Listening,
-    Accepting(Accept<TcpStream>),
+    Accepting(Accept<TcpStream>, SocketAddr),
 }

 pub struct Config<R> {
@@ -92,23 +93,25 @@ impl Listener for TlsListener {
         cx: &mut Context<'_>
     ) -> Poll<io::Result<Self::Connection>> {
         loop {
-            match self.state {
+            match &mut self.state {
                 State::Listening => {
-                    match self.listener.poll_accept(cx) {
-                        Poll::Pending => return Poll::Pending,
-                        Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
-                        Poll::Ready(Ok((stream, _addr))) => {
-                            let fut = self.acceptor.accept(stream);
-                            self.state = State::Accepting(fut);
+                    match ready!(self.listener.poll_accept(cx)) {
+                        Err(e) => return Poll::Ready(Err(e)),
+                        Ok((stream, addr)) => {
+                            let accept = self.acceptor.accept(stream);
+                            self.state = State::Accepting(accept, addr);
                         }
                     }
                 }
-                State::Accepting(ref mut fut) => {
-                    match Pin::new(fut).poll(cx) {
-                        Poll::Pending => return Poll::Pending,
-                        Poll::Ready(result) => {
+                State::Accepting(accept, addr) => {
+                    match ready!(Pin::new(accept).poll(cx)) {
+                        Ok(stream) => {
+                            self.state = State::Listening;
+                            return Poll::Ready(Ok(stream));
+                        },
+                        Err(e) => {
+                            log::warn!("TLS accept {} failure: {}", addr, e);
                             self.state = State::Listening;
-                            return Poll::Ready(result);
                         }
                     }
                 }


@@ -135,10 +135,6 @@ enum State {
     /// Grace period elapsed. Shutdown the connection, waiting for the timer
     /// until we force close.
     Mercy(Pin<Box<Sleep>>),
-    /// We failed to shutdown and are force-closing the connection.
-    Terminated,
-    /// We successfully shutdown the connection.
-    Inactive,
 }

 pin_project! {
@@ -146,7 +142,7 @@ pin_project! {
     #[must_use = "futures do nothing unless polled"]
     pub struct CancellableIo<F, I> {
         #[pin]
-        io: I,
+        io: Option<I>,
         #[pin]
         trigger: future::Fuse<F>,
         state: State,
@@ -158,82 +154,60 @@ pin_project! {
 impl<F: Future, I: AsyncWrite> CancellableIo<F, I> {
     pub fn new(trigger: F, io: I, grace: Duration, mercy: Duration) -> Self {
         CancellableIo {
-            io, grace, mercy,
+            grace, mercy,
+            io: Some(io),
             trigger: trigger.fuse(),
-            state: State::Active
+            state: State::Active,
         }
     }

-    /// Returns `Ok(true)` if connection processing should continue.
+    pub fn io(&self) -> Option<&I> {
+        self.io.as_ref()
+    }
+
+    /// Run `do_io` while connection processing should continue.
     fn poll_trigger_then<T>(
-        self: Pin<&mut Self>,
+        mut self: Pin<&mut Self>,
         cx: &mut Context<'_>,
-        io: impl FnOnce(Pin<&mut I>, &mut Context<'_>) -> Poll<io::Result<T>>,
+        do_io: impl FnOnce(Pin<&mut I>, &mut Context<'_>) -> Poll<io::Result<T>>,
     ) -> Poll<io::Result<T>> {
-        let mut me = self.project();
-
-        // CORRECTNESS: _EVERY_ branch must reset `state`! If `state` is
-        // unchanged in a branch, that branch _must_ `break`! No `return`!
-        let mut state = std::mem::replace(me.state, State::Active);
-        let result = loop {
-            match state {
-                State::Active => {
-                    if me.trigger.as_mut().poll(cx).is_ready() {
-                        state = State::Grace(Box::pin(sleep(*me.grace)));
-                    } else {
-                        state = State::Active;
-                        break io(me.io, cx);
-                    }
-                }
-                State::Grace(mut sleep) => {
-                    if sleep.as_mut().poll(cx).is_ready() {
-                        if let Some(deadline) = sleep.deadline().checked_add(*me.mercy) {
-                            sleep.as_mut().reset(deadline);
-                            state = State::Mercy(sleep);
-                        } else {
-                            state = State::Terminated;
-                        }
-                    } else {
-                        state = State::Grace(sleep);
-                        break io(me.io, cx);
-                    }
-                },
-                State::Mercy(mut sleep) => {
-                    if sleep.as_mut().poll(cx).is_ready() {
-                        state = State::Terminated;
-                        continue;
-                    }
-
-                    match me.io.as_mut().poll_shutdown(cx) {
-                        Poll::Ready(Err(e)) => {
-                            state = State::Terminated;
-                            break Poll::Ready(Err(e));
-                        }
-                        Poll::Ready(Ok(())) => {
-                            state = State::Inactive;
-                            break Poll::Ready(Err(gone()));
-                        }
-                        Poll::Pending => {
-                            state = State::Mercy(sleep);
-                            break Poll::Pending;
-                        }
-                    }
-                },
-                State::Terminated => {
-                    // Just in case, as a last ditch effort. Ignore pending.
-                    state = State::Terminated;
-                    let _ = me.io.as_mut().poll_shutdown(cx);
-                    break Poll::Ready(Err(time_out()));
-                },
-                State::Inactive => {
-                    state = State::Inactive;
-                    break Poll::Ready(Err(gone()));
-                }
-            }
-        };
-
-        *me.state = state;
-        result
+        let mut me = self.as_mut().project();
+        let io = match me.io.as_pin_mut() {
+            Some(io) => io,
+            None => return Poll::Ready(Err(gone())),
+        };
+
+        loop {
+            match me.state {
+                State::Active => {
+                    if me.trigger.as_mut().poll(cx).is_ready() {
+                        *me.state = State::Grace(Box::pin(sleep(*me.grace)));
+                    } else {
+                        return do_io(io, cx);
+                    }
+                }
+                State::Grace(timer) => {
+                    if timer.as_mut().poll(cx).is_ready() {
+                        *me.state = State::Mercy(Box::pin(sleep(*me.mercy)));
+                    } else {
+                        return do_io(io, cx);
+                    }
+                }
+                State::Mercy(timer) => {
+                    if timer.as_mut().poll(cx).is_ready() {
+                        self.project().io.set(None);
+                        return Poll::Ready(Err(time_out()));
+                    } else {
+                        let result = futures::ready!(io.poll_shutdown(cx));
+                        self.project().io.set(None);
+                        return match result {
+                            Err(e) => Poll::Ready(Err(e)),
+                            Ok(()) => Poll::Ready(Err(gone()))
+                        };
+                    }
+                },
+            }
+        }
     }
 }
@@ -287,7 +261,7 @@ impl<F: Future, I: AsyncWrite> AsyncWrite for CancellableIo<F, I> {
     }

     fn is_write_vectored(&self) -> bool {
-        self.io.is_write_vectored()
+        self.io().map(|io| io.is_write_vectored()).unwrap_or(false)
     }
 }
@@ -295,15 +269,18 @@ use crate::http::private::{Listener, Connection, RawCertificate};

 impl<F: Future, C: Connection> Connection for CancellableIo<F, C> {
     fn peer_address(&self) -> Option<std::net::SocketAddr> {
-        self.io.peer_address()
+        self.io().and_then(|io| io.peer_address())
     }

     fn peer_certificates(&self) -> Option<&[RawCertificate]> {
-        self.io.peer_certificates()
+        self.io().and_then(|io| io.peer_certificates())
     }

     fn enable_nodelay(&self) -> io::Result<()> {
-        self.io.enable_nodelay()
+        match self.io() {
+            Some(io) => io.enable_nodelay(),
+            None => Ok(())
+        }
     }
 }
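Taken together, these changes give a connection three phases once the shutdown trigger fires: I/O proceeds normally during the grace window, the wrapper then spends at most the mercy window attempting a clean `poll_shutdown`, and after that the inner I/O value is set to `None` and dropped, with later calls observing `gone()` or a timeout error. A deliberately simplified, hypothetical sketch of that timeline using plain tokio primitives (assuming tokio's io-util, time, macros, and rt features; this is not Rocket's actual API, and the real `CancellableIo` keeps servicing reads and writes during the grace window rather than sleeping):

use std::time::Duration;

use tokio::io::{AsyncWrite, AsyncWriteExt};
use tokio::time::{sleep, timeout};

/// Hypothetical stand-in for the grace -> mercy -> drop sequence.
async fn wind_down<W: AsyncWrite + Unpin>(mut io: W, grace: Duration, mercy: Duration) {
    // Grace window: the connection is left alone to finish in-flight work.
    sleep(grace).await;

    // Mercy window: attempt a clean shutdown, but give up after `mercy`.
    // Either way, `io` is dropped when this function returns, which is what
    // setting `io: Option<I>` to `None` accomplishes inside `CancellableIo`.
    let _ = timeout(mercy, io.shutdown()).await;
}

#[tokio::main]
async fn main() {
    let (_peer, ours) = tokio::io::duplex(64);
    wind_down(ours, Duration::from_millis(50), Duration::from_millis(50)).await;
}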


@@ -15,7 +15,7 @@ use crate::ext::{AsyncReadExt, CancellableListener, CancellableIo};
 use crate::request::ConnectionMeta;
 use crate::http::{uri::Origin, hyper, Method, Status, Header};
-use crate::http::private::{bind_tcp, Listener, Connection, Incoming};
+use crate::http::private::{TcpListener, Listener, Connection, Incoming};

 // A token returned to force the execution of one method before another.
 pub(crate) struct RequestToken;
@@ -381,7 +381,7 @@ impl Rocket<Orbit> {
             }
         }

-        let l = bind_tcp(addr).await.map_err(ErrorKind::Bind)?;
+        let l = TcpListener::bind(addr).await.map_err(ErrorKind::Bind)?;
         addr = l.local_addr().unwrap_or(addr);
         self.config.address = addr.ip();
         self.config.port = addr.port();