Fully 'drop()' I/O struct in 'CancellableIo'.

This should improve the reliability of graceful shutdown.
This commit is contained in:
Sergio Benitez 2022-04-30 15:41:07 -07:00
parent bf84b1cdb5
commit e9d46b917e
6 changed files with 122 additions and 158 deletions

View File

@ -44,6 +44,7 @@ memchr = "2"
stable-pattern = "0.1"
cookie = { version = "0.16.0", features = ["percent-encode", "secure"] }
state = "0.5.1"
futures = { version = "0.3", default-features = false }
[dependencies.x509-parser]
version = "0.13"

View File

@ -44,7 +44,7 @@ pub mod uncased {
pub mod private {
pub use crate::parse::Indexed;
pub use smallvec::{SmallVec, Array};
pub use crate::listener::{bind_tcp, Incoming, Listener, Connection, RawCertificate};
pub use crate::listener::{TcpListener, Incoming, Listener, Connection, RawCertificate};
pub use cookie;
}

View File

@ -6,12 +6,20 @@ use std::pin::Pin;
use std::task::{Context, Poll};
use std::time::Duration;
use log::{debug, error};
use log::warn;
use hyper::server::accept::Accept;
use tokio::time::Sleep;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::net::{TcpListener, TcpStream};
use tokio::net::TcpStream;
pub use tokio::net::TcpListener;
/// A thin wrapper over raw, DER-encoded X.509 client certificate data.
// NOTE: `rustls::Certificate` is exactly isomorphic to `RawCertificate`.
#[doc(inline)]
#[cfg(feature = "tls")]
pub use rustls::Certificate as RawCertificate;
// TODO.async: 'Listener' and 'Connection' provide common enough functionality
// that they could be introduced in upstream libraries.
@ -23,24 +31,14 @@ pub trait Listener {
/// Return the actual address this listener bound to.
fn local_addr(&self) -> Option<SocketAddr>;
/// Try to accept an incoming Connection if ready
/// Try to accept an incoming Connection if ready. This should only return
/// an `Err` when a fatal problem occurs as Hyper kills the server on `Err`.
fn poll_accept(
self: Pin<&mut Self>,
cx: &mut Context<'_>
) -> Poll<io::Result<Self::Connection>>;
}
/// A thin wrapper over raw, DER-encoded X.509 client certificate data.
///
/// The single public field holds the certificate bytes.
// NOTE: This definition is compiled only when the `tls` feature is disabled.
// When `tls` is enabled, `RawCertificate` is a re-export of
// `rustls::Certificate` instead.
#[cfg(not(feature = "tls"))]
#[derive(Clone, Eq, PartialEq)]
pub struct RawCertificate(pub Vec<u8>);
/// A thin wrapper over raw, DER-encoded X.509 client certificate data.
// NOTE: `rustls::Certificate` is exactly isomorphic to `RawCertificate`.
#[doc(inline)]
#[cfg(feature = "tls")]
pub use rustls::Certificate as RawCertificate;
/// A 'Connection' represents an open connection to a client
pub trait Connection: AsyncRead + AsyncWrite {
/// The remote address, i.e. the client's socket address, if it is known.
@ -62,6 +60,11 @@ pub trait Connection: AsyncRead + AsyncWrite {
fn peer_certificates(&self) -> Option<&[RawCertificate]> { None }
}
/// A thin wrapper over raw, DER-encoded X.509 client certificate data.
///
/// The single public field holds the certificate bytes.
// NOTE: This definition is compiled only when the `tls` feature is disabled.
// When `tls` is enabled, `RawCertificate` is a re-export of
// `rustls::Certificate` instead.
#[cfg(not(feature = "tls"))]
#[derive(Clone, Eq, PartialEq)]
pub struct RawCertificate(pub Vec<u8>);
pin_project_lite::pin_project! {
/// This is a generic version of hyper's AddrIncoming that is intended to be
/// usable with listeners other than a plain TCP stream, e.g. TLS and/or Unix
@ -116,47 +119,46 @@ impl<L: Listener> Incoming<L> {
self
}
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<L::Connection>> {
let mut me = self.project();
let mut optimistic_retry = true;
fn poll_accept_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<L::Connection>> {
/// Classifies `e` as a per-connection error: one that affects only the
/// single connection being accepted. Because nothing else is affected, the
/// caller can attempt another `accept()` right away. Every other error
/// incurs a delay before the next `accept()`; the delay matters for
/// resource exhaustion errors such as ENFILE and EMFILE, where retrying
/// immediately could spin in a tight loop.
fn is_connection_error(e: &io::Error) -> bool {
    const PER_CONNECTION: [io::ErrorKind; 3] = [
        io::ErrorKind::ConnectionRefused,
        io::ErrorKind::ConnectionAborted,
        io::ErrorKind::ConnectionReset,
    ];

    PER_CONNECTION.contains(&e.kind())
}
let mut this = self.project();
loop {
// Check if a previous sleep timer is active that was set by IO errors.
if let Some(delay) = me.pending_error_delay.as_mut().as_pin_mut() {
if optimistic_retry {
error!("optimistically retrying now");
optimistic_retry = false;
} else {
error!("retrying in {:?}", me.sleep_on_errors);
match delay.poll(cx) {
Poll::Ready(()) => {}
Poll::Pending => return Poll::Pending,
}
}
// Check if a previous sleep timer is active, set on I/O errors.
if let Some(delay) = this.pending_error_delay.as_mut().as_pin_mut() {
futures::ready!(delay.poll(cx));
}
me.pending_error_delay.set(None);
this.pending_error_delay.set(None);
match me.listener.as_mut().poll_accept(cx) {
Poll::Ready(Ok(stream)) => {
if *me.nodelay {
let _ = stream.enable_nodelay();
match futures::ready!(this.listener.as_mut().poll_accept(cx)) {
Ok(stream) => {
if *this.nodelay {
if let Err(e) = stream.enable_nodelay() {
warn!("failed to enable NODELAY: {}", e);
}
}
return Poll::Ready(Ok(stream));
},
Poll::Pending => return Poll::Pending,
Poll::Ready(Err(e)) => {
// Connection errors can be ignored directly, continue by
// accepting the next request.
Err(e) => {
if is_connection_error(&e) {
debug!("accepted connection already errored: {}", e);
continue;
}
if let Some(duration) = me.sleep_on_errors {
// Sleep for the specified duration
error!("connection accept error: {}", e);
me.pending_error_delay.set(Some(tokio::time::sleep(*duration)));
warn!("single connection accept error {}; accepting next now", e);
} else if let Some(duration) = this.sleep_on_errors {
// We might be able to recover. Try again in a bit.
warn!("accept error {}; recovery attempt in {}ms", e, duration.as_millis());
this.pending_error_delay.set(Some(tokio::time::sleep(*duration)));
} else {
return Poll::Ready(Err(e));
}
@ -174,24 +176,10 @@ impl<L: Listener> Accept for Incoming<L> {
self: Pin<&mut Self>,
cx: &mut Context<'_>
) -> Poll<Option<io::Result<Self::Conn>>> {
self.poll_next(cx).map(Some)
self.poll_accept_next(cx).map(Some)
}
}
/// Reports whether `e` is an error tied to one particular connection.
/// Getting such an error back from the `accept()` system call means the
/// next connection might already be ready to be accepted.
///
/// Any other error incurs a delay before the next `accept()` is performed.
/// That delay is what keeps resource exhaustion errors such as ENFILE and
/// EMFILE from sending the accept loop into a tight spin.
fn is_connection_error(e: &io::Error) -> bool {
    let kind = e.kind();

    kind == io::ErrorKind::ConnectionRefused
        || kind == io::ErrorKind::ConnectionAborted
        || kind == io::ErrorKind::ConnectionReset
}
impl<L: fmt::Debug> fmt::Debug for Incoming<L> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("Incoming")
@ -200,11 +188,6 @@ impl<L: fmt::Debug> fmt::Debug for Incoming<L> {
}
}
/// Binds a TCP listener to `address` and returns it.
///
/// # Errors
///
/// Returns the `io::Error` produced by `TcpListener::bind`, unchanged.
pub async fn bind_tcp(address: SocketAddr) -> io::Result<TcpListener> {
    // `bind` already yields `io::Result<TcpListener>`; wrapping it in
    // `Ok(..?)` would only unwrap and re-wrap the very same value.
    TcpListener::bind(address).await
}
impl Listener for TcpListener {
type Connection = TcpStream;

View File

@ -5,6 +5,7 @@ use std::task::{Context, Poll};
use std::net::SocketAddr;
use std::future::Future;
use futures::ready;
use tokio_rustls::{TlsAcceptor, Accept, server::TlsStream};
use tokio::net::{TcpListener, TcpStream};
@ -20,7 +21,7 @@ pub struct TlsListener {
enum State {
Listening,
Accepting(Accept<TcpStream>),
Accepting(Accept<TcpStream>, SocketAddr),
}
pub struct Config<R> {
@ -92,23 +93,25 @@ impl Listener for TlsListener {
cx: &mut Context<'_>
) -> Poll<io::Result<Self::Connection>> {
loop {
match self.state {
match &mut self.state {
State::Listening => {
match self.listener.poll_accept(cx) {
Poll::Pending => return Poll::Pending,
Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
Poll::Ready(Ok((stream, _addr))) => {
let fut = self.acceptor.accept(stream);
self.state = State::Accepting(fut);
match ready!(self.listener.poll_accept(cx)) {
Err(e) => return Poll::Ready(Err(e)),
Ok((stream, addr)) => {
let accept = self.acceptor.accept(stream);
self.state = State::Accepting(accept, addr);
}
}
}
State::Accepting(ref mut fut) => {
match Pin::new(fut).poll(cx) {
Poll::Pending => return Poll::Pending,
Poll::Ready(result) => {
State::Accepting(accept, addr) => {
match ready!(Pin::new(accept).poll(cx)) {
Ok(stream) => {
self.state = State::Listening;
return Poll::Ready(Ok(stream));
},
Err(e) => {
log::warn!("TLS accept {} failure: {}", addr, e);
self.state = State::Listening;
return Poll::Ready(result);
}
}
}

View File

@ -135,10 +135,6 @@ enum State {
/// Grace period elapsed. Shutdown the connection, waiting for the timer
/// until we force close.
Mercy(Pin<Box<Sleep>>),
/// We failed to shutdown and are force-closing the connection.
Terminated,
/// We successfully shutdown the connection.
Inactive,
}
pin_project! {
@ -146,7 +142,7 @@ pin_project! {
#[must_use = "futures do nothing unless polled"]
pub struct CancellableIo<F, I> {
#[pin]
io: I,
io: Option<I>,
#[pin]
trigger: future::Fuse<F>,
state: State,
@ -158,82 +154,60 @@ pin_project! {
impl<F: Future, I: AsyncWrite> CancellableIo<F, I> {
pub fn new(trigger: F, io: I, grace: Duration, mercy: Duration) -> Self {
CancellableIo {
io, grace, mercy,
grace, mercy,
io: Some(io),
trigger: trigger.fuse(),
state: State::Active
state: State::Active,
}
}
/// Returns a shared reference to the wrapped I/O value, or `None` once the
/// value has been dropped (the shutdown path sets the `io` field to `None`).
pub fn io(&self) -> Option<&I> {
self.io.as_ref()
}
/// Run `do_io` while connection processing should continue.
fn poll_trigger_then<T>(
self: Pin<&mut Self>,
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
io: impl FnOnce(Pin<&mut I>, &mut Context<'_>) -> Poll<io::Result<T>>,
do_io: impl FnOnce(Pin<&mut I>, &mut Context<'_>) -> Poll<io::Result<T>>,
) -> Poll<io::Result<T>> {
let mut me = self.project();
// CORRECTNESS: _EVERY_ branch must reset `state`! If `state` is
// unchanged in a branch, that branch _must_ `break`! No `return`!
let mut state = std::mem::replace(me.state, State::Active);
let result = loop {
match state {
State::Active => {
if me.trigger.as_mut().poll(cx).is_ready() {
state = State::Grace(Box::pin(sleep(*me.grace)));
} else {
state = State::Active;
break io(me.io, cx);
}
}
State::Grace(mut sleep) => {
if sleep.as_mut().poll(cx).is_ready() {
if let Some(deadline) = sleep.deadline().checked_add(*me.mercy) {
sleep.as_mut().reset(deadline);
state = State::Mercy(sleep);
} else {
state = State::Terminated;
}
} else {
state = State::Grace(sleep);
break io(me.io, cx);
}
},
State::Mercy(mut sleep) => {
if sleep.as_mut().poll(cx).is_ready() {
state = State::Terminated;
continue;
}
match me.io.as_mut().poll_shutdown(cx) {
Poll::Ready(Err(e)) => {
state = State::Terminated;
break Poll::Ready(Err(e));
}
Poll::Ready(Ok(())) => {
state = State::Inactive;
break Poll::Ready(Err(gone()));
}
Poll::Pending => {
state = State::Mercy(sleep);
break Poll::Pending;
}
}
},
State::Terminated => {
// Just in case, as a last ditch effort. Ignore pending.
state = State::Terminated;
let _ = me.io.as_mut().poll_shutdown(cx);
break Poll::Ready(Err(time_out()));
},
State::Inactive => {
state = State::Inactive;
break Poll::Ready(Err(gone()));
}
}
let mut me = self.as_mut().project();
let io = match me.io.as_pin_mut() {
Some(io) => io,
None => return Poll::Ready(Err(gone())),
};
*me.state = state;
result
loop {
match me.state {
State::Active => {
if me.trigger.as_mut().poll(cx).is_ready() {
*me.state = State::Grace(Box::pin(sleep(*me.grace)));
} else {
return do_io(io, cx);
}
}
State::Grace(timer) => {
if timer.as_mut().poll(cx).is_ready() {
*me.state = State::Mercy(Box::pin(sleep(*me.mercy)));
} else {
return do_io(io, cx);
}
}
State::Mercy(timer) => {
if timer.as_mut().poll(cx).is_ready() {
self.project().io.set(None);
return Poll::Ready(Err(time_out()));
} else {
let result = futures::ready!(io.poll_shutdown(cx));
self.project().io.set(None);
return match result {
Err(e) => Poll::Ready(Err(e)),
Ok(()) => Poll::Ready(Err(gone()))
};
}
},
}
}
}
}
@ -287,7 +261,7 @@ impl<F: Future, I: AsyncWrite> AsyncWrite for CancellableIo<F, I> {
}
fn is_write_vectored(&self) -> bool {
self.io.is_write_vectored()
self.io().map(|io| io.is_write_vectored()).unwrap_or(false)
}
}
@ -295,15 +269,18 @@ use crate::http::private::{Listener, Connection, RawCertificate};
impl<F: Future, C: Connection> Connection for CancellableIo<F, C> {
fn peer_address(&self) -> Option<std::net::SocketAddr> {
self.io.peer_address()
self.io().and_then(|io| io.peer_address())
}
fn peer_certificates(&self) -> Option<&[RawCertificate]> {
self.io.peer_certificates()
self.io().and_then(|io| io.peer_certificates())
}
fn enable_nodelay(&self) -> io::Result<()> {
self.io.enable_nodelay()
match self.io() {
Some(io) => io.enable_nodelay(),
None => Ok(())
}
}
}

View File

@ -15,7 +15,7 @@ use crate::ext::{AsyncReadExt, CancellableListener, CancellableIo};
use crate::request::ConnectionMeta;
use crate::http::{uri::Origin, hyper, Method, Status, Header};
use crate::http::private::{bind_tcp, Listener, Connection, Incoming};
use crate::http::private::{TcpListener, Listener, Connection, Incoming};
// A token returned to force the execution of one method before another.
pub(crate) struct RequestToken;
@ -381,7 +381,7 @@ impl Rocket<Orbit> {
}
}
let l = bind_tcp(addr).await.map_err(ErrorKind::Bind)?;
let l = TcpListener::bind(addr).await.map_err(ErrorKind::Bind)?;
addr = l.local_addr().unwrap_or(addr);
self.config.address = addr.ip();
self.config.port = addr.port();