From e9d46b917e4ca6558b1ec76bdd7969e0c9465ed9 Mon Sep 17 00:00:00 2001 From: Sergio Benitez Date: Sat, 30 Apr 2022 15:41:07 -0700 Subject: [PATCH] Fully 'drop()' I/O struct in 'CancellableIo'. This should improve the reliability of graceful shutdown. --- core/http/Cargo.toml | 1 + core/http/src/lib.rs | 2 +- core/http/src/listener.rs | 115 +++++++++++++----------------- core/http/src/tls/listener.rs | 29 ++++---- core/lib/src/ext.rs | 129 ++++++++++++++-------------------- core/lib/src/server.rs | 4 +- 6 files changed, 122 insertions(+), 158 deletions(-) diff --git a/core/http/Cargo.toml b/core/http/Cargo.toml index bcc7542e..cb2228d0 100644 --- a/core/http/Cargo.toml +++ b/core/http/Cargo.toml @@ -44,6 +44,7 @@ memchr = "2" stable-pattern = "0.1" cookie = { version = "0.16.0", features = ["percent-encode", "secure"] } state = "0.5.1" +futures = { version = "0.3", default-features = false } [dependencies.x509-parser] version = "0.13" diff --git a/core/http/src/lib.rs b/core/http/src/lib.rs index 11289d43..8a0b8799 100644 --- a/core/http/src/lib.rs +++ b/core/http/src/lib.rs @@ -44,7 +44,7 @@ pub mod uncased { pub mod private { pub use crate::parse::Indexed; pub use smallvec::{SmallVec, Array}; - pub use crate::listener::{bind_tcp, Incoming, Listener, Connection, RawCertificate}; + pub use crate::listener::{TcpListener, Incoming, Listener, Connection, RawCertificate}; pub use cookie; } diff --git a/core/http/src/listener.rs b/core/http/src/listener.rs index 2915daf1..beec78f1 100644 --- a/core/http/src/listener.rs +++ b/core/http/src/listener.rs @@ -6,12 +6,20 @@ use std::pin::Pin; use std::task::{Context, Poll}; use std::time::Duration; -use log::{debug, error}; +use log::warn; use hyper::server::accept::Accept; use tokio::time::Sleep; use tokio::io::{AsyncRead, AsyncWrite}; -use tokio::net::{TcpListener, TcpStream}; +use tokio::net::TcpStream; + +pub use tokio::net::TcpListener; + +/// A thin wrapper over raw, DER-encoded X.509 client certificate data. +// NOTE: `rustls::Certificate` is exactly isomorphic to `RawCertificate`. +#[doc(inline)] +#[cfg(feature = "tls")] +pub use rustls::Certificate as RawCertificate; // TODO.async: 'Listener' and 'Connection' provide common enough functionality // that they could be introduced in upstream libraries. @@ -23,24 +31,14 @@ pub trait Listener { /// Return the actual address this listener bound to. fn local_addr(&self) -> Option; - /// Try to accept an incoming Connection if ready + /// Try to accept an incoming Connection if ready. This should only return + /// an `Err` when a fatal problem occurs as Hyper kills the server on `Err`. fn poll_accept( self: Pin<&mut Self>, cx: &mut Context<'_> ) -> Poll>; } -/// A thin wrapper over raw, DER-encoded X.509 client certificate data. -#[cfg(not(feature = "tls"))] -#[derive(Clone, Eq, PartialEq)] -pub struct RawCertificate(pub Vec); - -/// A thin wrapper over raw, DER-encoded X.509 client certificate data. -// NOTE: `rustls::Certificate` is exactly isomorphic to `RawCertificate`. -#[doc(inline)] -#[cfg(feature = "tls")] -pub use rustls::Certificate as RawCertificate; - /// A 'Connection' represents an open connection to a client pub trait Connection: AsyncRead + AsyncWrite { /// The remote address, i.e. the client's socket address, if it is known. @@ -62,6 +60,11 @@ pub trait Connection: AsyncRead + AsyncWrite { fn peer_certificates(&self) -> Option<&[RawCertificate]> { None } } +/// A thin wrapper over raw, DER-encoded X.509 client certificate data. +#[cfg(not(feature = "tls"))] +#[derive(Clone, Eq, PartialEq)] +pub struct RawCertificate(pub Vec); + pin_project_lite::pin_project! { /// This is a generic version of hyper's AddrIncoming that is intended to be /// usable with listeners other than a plain TCP stream, e.g. TLS and/or Unix @@ -116,47 +119,46 @@ impl Incoming { self } - fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let mut me = self.project(); - let mut optimistic_retry = true; + fn poll_accept_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + /// This function defines per-connection errors: errors that affect only + /// a single connection. Since the error affects only one connection, we + /// can attempt to `accept()` another connection immediately. All other + /// errors will incur a delay before the next `accept()` is performed. + /// The delay is useful to handle resource exhaustion errors like ENFILE + /// and EMFILE. Otherwise, could enter into tight loop. + fn is_connection_error(e: &io::Error) -> bool { + matches!(e.kind(), + | io::ErrorKind::ConnectionRefused + | io::ErrorKind::ConnectionAborted + | io::ErrorKind::ConnectionReset) + } + + let mut this = self.project(); loop { - // Check if a previous sleep timer is active that was set by IO errors. - if let Some(delay) = me.pending_error_delay.as_mut().as_pin_mut() { - if optimistic_retry { - error!("optimistically retrying now"); - optimistic_retry = false; - } else { - error!("retrying in {:?}", me.sleep_on_errors); - match delay.poll(cx) { - Poll::Ready(()) => {} - Poll::Pending => return Poll::Pending, - } - } + // Check if a previous sleep timer is active, set on I/O errors. + if let Some(delay) = this.pending_error_delay.as_mut().as_pin_mut() { + futures::ready!(delay.poll(cx)); } - me.pending_error_delay.set(None); + this.pending_error_delay.set(None); - match me.listener.as_mut().poll_accept(cx) { - Poll::Ready(Ok(stream)) => { - if *me.nodelay { - let _ = stream.enable_nodelay(); + match futures::ready!(this.listener.as_mut().poll_accept(cx)) { + Ok(stream) => { + if *this.nodelay { + if let Err(e) = stream.enable_nodelay() { + warn!("failed to enable NODELAY: {}", e); + } } return Poll::Ready(Ok(stream)); }, - Poll::Pending => return Poll::Pending, - Poll::Ready(Err(e)) => { - // Connection errors can be ignored directly, continue by - // accepting the next request. + Err(e) => { if is_connection_error(&e) { - debug!("accepted connection already errored: {}", e); - continue; - } - - if let Some(duration) = me.sleep_on_errors { - // Sleep for the specified duration - error!("connection accept error: {}", e); - me.pending_error_delay.set(Some(tokio::time::sleep(*duration))); + warn!("single connection accept error {}; accepting next now", e); + } else if let Some(duration) = this.sleep_on_errors { + // We might be able to recover. Try again in a bit. + warn!("accept error {}; recovery attempt in {}ms", e, duration.as_millis()); + this.pending_error_delay.set(Some(tokio::time::sleep(*duration))); } else { return Poll::Ready(Err(e)); } @@ -174,24 +176,10 @@ impl Accept for Incoming { self: Pin<&mut Self>, cx: &mut Context<'_> ) -> Poll>> { - self.poll_next(cx).map(Some) + self.poll_accept_next(cx).map(Some) } } -/// This function defines errors that are per-connection. Which basically -/// means that if we get this error from `accept()` system call it means -/// next connection might be ready to be accepted. -/// -/// All other errors will incur a delay before next `accept()` is performed. -/// The delay is useful to handle resource exhaustion errors like ENFILE -/// and EMFILE. Otherwise, could enter into tight loop. -fn is_connection_error(e: &io::Error) -> bool { - matches!(e.kind(), - io::ErrorKind::ConnectionRefused | - io::ErrorKind::ConnectionAborted | - io::ErrorKind::ConnectionReset) - } - impl fmt::Debug for Incoming { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("Incoming") @@ -200,11 +188,6 @@ impl fmt::Debug for Incoming { } } -/// Binds a TCP listener to `address` and returns it. -pub async fn bind_tcp(address: SocketAddr) -> io::Result { - Ok(TcpListener::bind(address).await?) -} - impl Listener for TcpListener { type Connection = TcpStream; diff --git a/core/http/src/tls/listener.rs b/core/http/src/tls/listener.rs index c4cd44cd..6c3909c0 100644 --- a/core/http/src/tls/listener.rs +++ b/core/http/src/tls/listener.rs @@ -5,6 +5,7 @@ use std::task::{Context, Poll}; use std::net::SocketAddr; use std::future::Future; +use futures::ready; use tokio_rustls::{TlsAcceptor, Accept, server::TlsStream}; use tokio::net::{TcpListener, TcpStream}; @@ -20,7 +21,7 @@ pub struct TlsListener { enum State { Listening, - Accepting(Accept), + Accepting(Accept, SocketAddr), } pub struct Config { @@ -92,23 +93,25 @@ impl Listener for TlsListener { cx: &mut Context<'_> ) -> Poll> { loop { - match self.state { + match &mut self.state { State::Listening => { - match self.listener.poll_accept(cx) { - Poll::Pending => return Poll::Pending, - Poll::Ready(Err(e)) => return Poll::Ready(Err(e)), - Poll::Ready(Ok((stream, _addr))) => { - let fut = self.acceptor.accept(stream); - self.state = State::Accepting(fut); + match ready!(self.listener.poll_accept(cx)) { + Err(e) => return Poll::Ready(Err(e)), + Ok((stream, addr)) => { + let accept = self.acceptor.accept(stream); + self.state = State::Accepting(accept, addr); } } } - State::Accepting(ref mut fut) => { - match Pin::new(fut).poll(cx) { - Poll::Pending => return Poll::Pending, - Poll::Ready(result) => { + State::Accepting(accept, addr) => { + match ready!(Pin::new(accept).poll(cx)) { + Ok(stream) => { + self.state = State::Listening; + return Poll::Ready(Ok(stream)); + }, + Err(e) => { + log::warn!("TLS accept {} failure: {}", addr, e); self.state = State::Listening; - return Poll::Ready(result); } } } diff --git a/core/lib/src/ext.rs b/core/lib/src/ext.rs index 02251882..d2a8aab8 100644 --- a/core/lib/src/ext.rs +++ b/core/lib/src/ext.rs @@ -135,10 +135,6 @@ enum State { /// Grace period elapsed. Shutdown the connection, waiting for the timer /// until we force close. Mercy(Pin>), - /// We failed to shutdown and are force-closing the connection. - Terminated, - /// We successfully shutdown the connection. - Inactive, } pin_project! { @@ -146,7 +142,7 @@ pin_project! { #[must_use = "futures do nothing unless polled"] pub struct CancellableIo { #[pin] - io: I, + io: Option, #[pin] trigger: future::Fuse, state: State, @@ -158,82 +154,60 @@ pin_project! { impl CancellableIo { pub fn new(trigger: F, io: I, grace: Duration, mercy: Duration) -> Self { CancellableIo { - io, grace, mercy, + grace, mercy, + io: Some(io), trigger: trigger.fuse(), - state: State::Active + state: State::Active, } } - /// Returns `Ok(true)` if connection processing should continue. + pub fn io(&self) -> Option<&I> { + self.io.as_ref() + } + + /// Run `do_io` while connection processing should continue. fn poll_trigger_then( - self: Pin<&mut Self>, + mut self: Pin<&mut Self>, cx: &mut Context<'_>, - io: impl FnOnce(Pin<&mut I>, &mut Context<'_>) -> Poll>, + do_io: impl FnOnce(Pin<&mut I>, &mut Context<'_>) -> Poll>, ) -> Poll> { - let mut me = self.project(); - - // CORRECTNESS: _EVERY_ branch must reset `state`! If `state` is - // unchanged in a branch, that branch _must_ `break`! No `return`! - let mut state = std::mem::replace(me.state, State::Active); - let result = loop { - match state { - State::Active => { - if me.trigger.as_mut().poll(cx).is_ready() { - state = State::Grace(Box::pin(sleep(*me.grace))); - } else { - state = State::Active; - break io(me.io, cx); - } - } - State::Grace(mut sleep) => { - if sleep.as_mut().poll(cx).is_ready() { - if let Some(deadline) = sleep.deadline().checked_add(*me.mercy) { - sleep.as_mut().reset(deadline); - state = State::Mercy(sleep); - } else { - state = State::Terminated; - } - } else { - state = State::Grace(sleep); - break io(me.io, cx); - } - }, - State::Mercy(mut sleep) => { - if sleep.as_mut().poll(cx).is_ready() { - state = State::Terminated; - continue; - } - - match me.io.as_mut().poll_shutdown(cx) { - Poll::Ready(Err(e)) => { - state = State::Terminated; - break Poll::Ready(Err(e)); - } - Poll::Ready(Ok(())) => { - state = State::Inactive; - break Poll::Ready(Err(gone())); - } - Poll::Pending => { - state = State::Mercy(sleep); - break Poll::Pending; - } - } - }, - State::Terminated => { - // Just in case, as a last ditch effort. Ignore pending. - state = State::Terminated; - let _ = me.io.as_mut().poll_shutdown(cx); - break Poll::Ready(Err(time_out())); - }, - State::Inactive => { - state = State::Inactive; - break Poll::Ready(Err(gone())); - } - } + let mut me = self.as_mut().project(); + let io = match me.io.as_pin_mut() { + Some(io) => io, + None => return Poll::Ready(Err(gone())), }; - *me.state = state; - result + loop { + match me.state { + State::Active => { + if me.trigger.as_mut().poll(cx).is_ready() { + *me.state = State::Grace(Box::pin(sleep(*me.grace))); + } else { + return do_io(io, cx); + } + } + State::Grace(timer) => { + if timer.as_mut().poll(cx).is_ready() { + *me.state = State::Mercy(Box::pin(sleep(*me.mercy))); + } else { + return do_io(io, cx); + } + } + State::Mercy(timer) => { + if timer.as_mut().poll(cx).is_ready() { + self.project().io.set(None); + return Poll::Ready(Err(time_out())); + } else { + let result = futures::ready!(io.poll_shutdown(cx)); + self.project().io.set(None); + return match result { + Err(e) => Poll::Ready(Err(e)), + Ok(()) => Poll::Ready(Err(gone())) + }; + } + }, + } + } } } @@ -287,7 +261,7 @@ impl AsyncWrite for CancellableIo { } fn is_write_vectored(&self) -> bool { - self.io.is_write_vectored() + self.io().map(|io| io.is_write_vectored()).unwrap_or(false) } } @@ -295,15 +269,18 @@ use crate::http::private::{Listener, Connection, RawCertificate}; impl Connection for CancellableIo { fn peer_address(&self) -> Option { - self.io.peer_address() + self.io().and_then(|io| io.peer_address()) } fn peer_certificates(&self) -> Option<&[RawCertificate]> { - self.io.peer_certificates() + self.io().and_then(|io| io.peer_certificates()) } fn enable_nodelay(&self) -> io::Result<()> { - self.io.enable_nodelay() + match self.io() { + Some(io) => io.enable_nodelay(), + None => Ok(()) + } } } diff --git a/core/lib/src/server.rs b/core/lib/src/server.rs index 5f793160..9e08c8d5 100644 --- a/core/lib/src/server.rs +++ b/core/lib/src/server.rs @@ -15,7 +15,7 @@ use crate::ext::{AsyncReadExt, CancellableListener, CancellableIo}; use crate::request::ConnectionMeta; use crate::http::{uri::Origin, hyper, Method, Status, Header}; -use crate::http::private::{bind_tcp, Listener, Connection, Incoming}; +use crate::http::private::{TcpListener, Listener, Connection, Incoming}; // A token returned to force the execution of one method before another. pub(crate) struct RequestToken; @@ -381,7 +381,7 @@ impl Rocket { } } - let l = bind_tcp(addr).await.map_err(ErrorKind::Bind)?; + let l = TcpListener::bind(addr).await.map_err(ErrorKind::Bind)?; addr = l.local_addr().unwrap_or(addr); self.config.address = addr.ip(); self.config.port = addr.port();