Rework TLS listener/connection implementations.

The previous implementation allowed a trivial DoS attack in which the
client need simply maintain open connections with incomplete handshakes.
This commit resolves that by allowing a server worker to progress as
soon as a TCP connection has been established. This comes at the expense
of a more complex implementation necessitated by deficiencies in Hyper.

Potentially resolves #2118.
This commit is contained in:
Sergio Benitez 2022-05-03 08:27:39 -07:00
parent e9d46b917e
commit 07460df279
9 changed files with 236 additions and 100 deletions

View File

@ -31,7 +31,7 @@ http = "0.2"
time = { version = "0.3", features = ["formatting", "macros"] }
indexmap = { version = "1.5.2", features = ["std"] }
rustls = { version = "0.20", optional = true }
tokio-rustls = { version = "0.23.0", optional = true }
# tokio-rustls = { version = "0.23.0", optional = true }
rustls-pemfile = { version = "1", optional = true }
tokio = { version = "1.6.1", features = ["net", "sync", "time"] }
log = "0.4"
@ -43,9 +43,14 @@ pin-project-lite = "0.2"
memchr = "2"
stable-pattern = "0.1"
cookie = { version = "0.16.0", features = ["percent-encode", "secure"] }
state = "0.5.1"
state = "0.5.3"
futures = { version = "0.3", default-features = false }
[dependencies.tokio-rustls]
git = "https://github.com/SergioBenitez/tokio-tls/"
branch = "stream-from-accept"
optional = true
[dependencies.x509-parser]
version = "0.13"
optional = true

View File

@ -44,7 +44,7 @@ pub mod uncased {
pub mod private {
pub use crate::parse::Indexed;
pub use smallvec::{SmallVec, Array};
pub use crate::listener::{TcpListener, Incoming, Listener, Connection, RawCertificate};
pub use crate::listener::{TcpListener, Incoming, Listener, Connection, Certificates};
pub use cookie;
}

View File

@ -5,21 +5,45 @@ use std::net::SocketAddr;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::time::Duration;
use std::sync::Arc;
use log::warn;
use hyper::server::accept::Accept;
use tokio::time::Sleep;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::net::TcpStream;
use hyper::server::accept::Accept;
use state::Storage;
pub use tokio::net::TcpListener;
/// A thin wrapper over raw, DER-encoded X.509 client certificate data.
// NOTE: `rustls::Certificate` is exactly isomorphic to `RawCertificate`.
// NOTE: `rustls::Certificate` is exactly isomorphic to `CertificateData`.
#[doc(inline)]
#[cfg(feature = "tls")]
pub use rustls::Certificate as RawCertificate;
pub use rustls::Certificate as CertificateData;
/// A thin wrapper over raw, DER-encoded X.509 client certificate data.
#[cfg(not(feature = "tls"))]
#[derive(Clone, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct CertificateData(pub Vec<u8>);
/// A collection of raw certificate data.
#[derive(Clone, Default)]
pub struct Certificates(Arc<Storage<Vec<CertificateData>>>);
impl Certificates {
/// Set the the raw certificate chain data. Only the first call actually
/// sets the data; the remaining do nothing.
#[cfg(feature = "tls")]
pub(crate) fn set(&self, data: Vec<CertificateData>) {
self.0.set(data);
}
/// Returns the raw certificate chain data, if any is available.
pub fn chain_data(&self) -> Option<&[CertificateData]> {
self.0.try_get().map(|v| v.as_slice())
}
}
// TODO.async: 'Listener' and 'Connection' provide common enough functionality
// that they could be introduced in upstream libraries.
@ -57,14 +81,9 @@ pub trait Connection: AsyncRead + AsyncWrite {
///
/// Defaults to an empty vector to indicate that no certificates were
/// presented.
fn peer_certificates(&self) -> Option<&[RawCertificate]> { None }
fn peer_certificates(&self) -> Option<Certificates> { None }
}
/// A thin wrapper over raw, DER-encoded X.509 client certificate data.
#[cfg(not(feature = "tls"))]
#[derive(Clone, Eq, PartialEq)]
pub struct RawCertificate(pub Vec<u8>);
pin_project_lite::pin_project! {
/// This is a generic version of hyper's AddrIncoming that is intended to be
/// usable with listeners other than a plain TCP stream, e.g. TLS and/or Unix
@ -119,7 +138,10 @@ impl<L: Listener> Incoming<L> {
self
}
fn poll_accept_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<L::Connection>> {
fn poll_accept_next(
self: Pin<&mut Self>,
cx: &mut Context<'_>
) -> Poll<io::Result<L::Connection>> {
/// This function defines per-connection errors: errors that affect only
/// a single connection. Since the error affects only one connection, we
/// can attempt to `accept()` another connection immediately. All other
@ -172,6 +194,7 @@ impl<L: Listener> Accept for Incoming<L> {
type Conn = L::Connection;
type Error = io::Error;
#[inline]
fn poll_accept(
self: Pin<&mut Self>,
cx: &mut Context<'_>
@ -191,10 +214,12 @@ impl<L: fmt::Debug> fmt::Debug for Incoming<L> {
impl Listener for TcpListener {
type Connection = TcpStream;
#[inline]
fn local_addr(&self) -> Option<SocketAddr> {
self.local_addr().ok()
}
#[inline]
fn poll_accept(
self: Pin<&mut Self>,
cx: &mut Context<'_>
@ -204,10 +229,12 @@ impl Listener for TcpListener {
}
impl Connection for TcpStream {
#[inline]
fn peer_address(&self) -> Option<SocketAddr> {
self.peer_addr().ok()
}
#[inline]
fn enable_nodelay(&self) -> io::Result<()> {
self.set_nodelay(true)
}

View File

@ -2,28 +2,66 @@ use std::io;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use std::net::SocketAddr;
use std::future::Future;
use std::net::SocketAddr;
use futures::ready;
use tokio_rustls::{TlsAcceptor, Accept, server::TlsStream};
use tokio::net::{TcpListener, TcpStream};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio_rustls::{Accept, TlsAcceptor, server::TlsStream as BareTlsStream};
use crate::tls::util::{load_certs, load_private_key, load_ca_certs};
use crate::listener::{Connection, Listener, RawCertificate};
use crate::listener::{Connection, Listener, Certificates};
/// A TLS listener over TCP.
pub struct TlsListener {
listener: TcpListener,
acceptor: TlsAcceptor,
state: State,
}
enum State {
Listening,
Accepting(Accept<TcpStream>, SocketAddr),
/// This implementation exists so that ROCKET_WORKERS=1 can make progress while
/// a TLS handshake is being completed. It does this by returning `Ready` from
/// `poll_accept()` as soon as we have a TCP connection and performing the
/// handshake in the `AsyncRead` and `AsyncWrite` implementations.
///
/// A straight-forward implementation of this strategy results in none of the
/// TLS information being available at the time the connection is "established",
/// that is, when `poll_accept()` returns, since the handshake has yet to occur.
/// Importantly, certificate information isn't available at the time that we
/// request it.
///
/// The underlying problem is hyper's "Accept" trait. Were we to manage
/// connections ourselves, we'd likely want to:
///
/// 1. Stop blocking the worker as soon as we have a TCP connection.
/// 2. Perform the handshake in the background.
/// 3. Give the connection to Rocket when/if the handshake is done.
///
/// See hyperium/hyper/issues/2321 for more details.
///
/// To work around this, we "lie" when `peer_certificates()` are requested and
/// always return `Some(Certificates)`. Internally, `Certificates` is an
/// `Arc<Storage<Vec<CertificateData>>>`, effectively a shared, thread-safe,
/// `OnceCell`. The cell is initially empty and is filled as soon as the
/// handshake is complete. If the certificate data were to be requested prior to
/// this point, it would be empty. However, in Rocket, we only request
/// certificate data when we have a `Request` object, which implies we're
/// receiving payload data, which implies the TLS handshake has finished, so the
/// certificate data as seen by a Rocket application will always be "fresh".
pub struct TlsStream {
remote: SocketAddr,
state: TlsState,
certs: Certificates,
}
/// State of `TlsStream`.
pub enum TlsState {
/// The TLS handshake is taking place. We don't have a full connection yet.
Handshaking(Accept<TcpStream>),
/// TLS handshake completed successfully; we're getting payload data.
Streaming(BareTlsStream<TcpStream>),
}
/// TLS as ~configured by `TlsConfig` in `rocket` core.
pub struct Config<R> {
pub cert_chain: R,
pub private_key: R,
@ -77,59 +115,124 @@ impl TlsListener {
let listener = TcpListener::bind(addr).await?;
let acceptor = TlsAcceptor::from(Arc::new(tls_config));
Ok(TlsListener { listener, acceptor, state: State::Listening })
Ok(TlsListener { listener, acceptor })
}
}
impl Listener for TlsListener {
type Connection = TlsStream<TcpStream>;
type Connection = TlsStream;
fn local_addr(&self) -> Option<SocketAddr> {
self.listener.local_addr().ok()
}
fn poll_accept(
mut self: Pin<&mut Self>,
self: Pin<&mut Self>,
cx: &mut Context<'_>
) -> Poll<io::Result<Self::Connection>> {
match futures::ready!(self.listener.poll_accept(cx)) {
Ok((io, addr)) => Poll::Ready(Ok(TlsStream {
remote: addr,
state: TlsState::Handshaking(self.acceptor.accept(io)),
// These are empty and filled in after handshake is complete.
certs: Certificates::default(),
})),
Err(e) => Poll::Ready(Err(e)),
}
}
}
impl Connection for TlsStream {
fn peer_address(&self) -> Option<SocketAddr> {
Some(self.remote)
}
fn enable_nodelay(&self) -> io::Result<()> {
// If `Handshaking` is `None`, it either failed, so we returned an `Err`
// from `poll_accept()` and there's no connection to enable `NODELAY`
// on, or it succeeded, so we're in the `Streaming` stage and we have
// infallible access to the connection.
match &self.state {
TlsState::Handshaking(accept) => match accept.get_ref() {
None => Ok(()),
Some(s) => s.enable_nodelay(),
},
TlsState::Streaming(stream) => stream.get_ref().0.enable_nodelay()
}
}
fn peer_certificates(&self) -> Option<Certificates> {
Some(self.certs.clone())
}
}
impl TlsStream {
fn poll_accept_then<F, T>(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
mut f: F
) -> Poll<io::Result<T>>
where F: FnMut(&mut BareTlsStream<TcpStream>, &mut Context<'_>) -> Poll<io::Result<T>>
{
loop {
match &mut self.state {
State::Listening => {
match ready!(self.listener.poll_accept(cx)) {
Err(e) => return Poll::Ready(Err(e)),
Ok((stream, addr)) => {
let accept = self.acceptor.accept(stream);
self.state = State::Accepting(accept, addr);
}
}
}
State::Accepting(accept, addr) => {
match ready!(Pin::new(accept).poll(cx)) {
match self.state {
TlsState::Handshaking(ref mut accept) => {
match futures::ready!(Pin::new(accept).poll(cx)) {
Ok(stream) => {
self.state = State::Listening;
return Poll::Ready(Ok(stream));
},
if let Some(cert_chain) = stream.get_ref().1.peer_certificates() {
self.certs.set(cert_chain.to_vec());
}
self.state = TlsState::Streaming(stream);
}
Err(e) => {
log::warn!("TLS accept {} failure: {}", addr, e);
self.state = State::Listening;
log::warn!("tls handshake with {} failed: {}", self.remote, e);
return Poll::Ready(Err(e));
}
}
}
},
TlsState::Streaming(ref mut stream) => return f(stream, cx),
}
}
}
}
impl Connection for TlsStream<TcpStream> {
fn peer_address(&self) -> Option<SocketAddr> {
self.get_ref().0.peer_address()
}
fn peer_certificates(&self) -> Option<&[RawCertificate]> {
self.get_ref().1.peer_certificates()
}
fn enable_nodelay(&self) -> io::Result<()> {
self.get_ref().0.enable_nodelay()
impl AsyncRead for TlsStream {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut tokio::io::ReadBuf<'_>,
) -> Poll<io::Result<()>> {
self.poll_accept_then(cx, |stream, cx| Pin::new(stream).poll_read(cx, buf))
}
}
impl AsyncWrite for TlsStream {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
self.poll_accept_then(cx, |stream, cx| Pin::new(stream).poll_write(cx, buf))
}
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
match &mut self.state {
TlsState::Handshaking(accept) => match accept.get_mut() {
Some(io) => Pin::new(io).poll_flush(cx),
None => Poll::Ready(Ok(())),
}
TlsState::Streaming(stream) => Pin::new(stream).poll_flush(cx),
}
}
fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
match &mut self.state {
TlsState::Handshaking(accept) => match accept.get_mut() {
Some(io) => Pin::new(io).poll_shutdown(cx),
None => Poll::Ready(Ok(())),
}
TlsState::Streaming(stream) => Pin::new(stream).poll_shutdown(cx),
}
}
}

View File

@ -42,50 +42,12 @@ use x509_parser::nom;
use x509::{ParsedExtension, X509Name, X509Certificate, TbsCertificate, X509Error, FromDer};
use oid::OID_X509_EXT_SUBJECT_ALT_NAME as SUBJECT_ALT_NAME;
use crate::listener::RawCertificate;
use crate::listener::CertificateData;
/// A type alias for [`Result`](std::result::Result) with the error type set to
/// [`Error`].
pub type Result<T, E = Error> = std::result::Result<T, E>;
/// An error returned by the [`Certificate`] request guard.
///
/// To retrieve this error in a handler, use an `mtls::Result<Certificate>`
/// guard type:
///
/// ```rust
/// # extern crate rocket;
/// # use rocket::get;
/// use rocket::mtls::{self, Certificate};
///
/// #[get("/auth")]
/// fn auth(cert: mtls::Result<Certificate<'_>>) {
/// match cert {
/// Ok(cert) => { /* do something with the client cert */ },
/// Err(e) => { /* do something with the error */ },
/// }
/// }
/// ```
#[derive(Debug, Clone)]
#[non_exhaustive]
pub enum Error {
/// The certificate chain presented by the client had no certificates.
Empty,
/// The certificate contained neither a subject nor a subjectAlt extension.
NoSubject,
/// There is no subject and the subjectAlt is not marked as critical.
NonCriticalSubjectAlt,
/// An error occurred while parsing the certificate.
Parse(X509Error),
/// The certificate parsed partially but is incomplete.
///
/// If `Some(n)`, then `n` more bytes were expected. Otherwise, the number
/// of expected bytes is unknown.
Incomplete(Option<NonZeroUsize>),
/// The certificate contained `.0` bytes of trailing data.
Trailing(usize),
}
/// A request guard for validated, verified client certificates.
///
/// This type is a wrapper over [`x509::TbsCertificate`] with convenient
@ -191,6 +153,44 @@ pub struct Certificate<'a>(X509Certificate<'a>);
#[derive(Debug, PartialEq, RefCast)]
pub struct Name<'a>(X509Name<'a>);
/// An error returned by the [`Certificate`] request guard.
///
/// To retrieve this error in a handler, use an `mtls::Result<Certificate>`
/// guard type:
///
/// ```rust
/// # extern crate rocket;
/// # use rocket::get;
/// use rocket::mtls::{self, Certificate};
///
/// #[get("/auth")]
/// fn auth(cert: mtls::Result<Certificate<'_>>) {
/// match cert {
/// Ok(cert) => { /* do something with the client cert */ },
/// Err(e) => { /* do something with the error */ },
/// }
/// }
/// ```
#[derive(Debug, Clone)]
#[non_exhaustive]
pub enum Error {
/// The certificate chain presented by the client had no certificates.
Empty,
/// The certificate contained neither a subject nor a subjectAlt extension.
NoSubject,
/// There is no subject and the subjectAlt is not marked as critical.
NonCriticalSubjectAlt,
/// An error occurred while parsing the certificate.
Parse(X509Error),
/// The certificate parsed partially but is incomplete.
///
/// If `Some(n)`, then `n` more bytes were expected. Otherwise, the number
/// of expected bytes is unknown.
Incomplete(Option<NonZeroUsize>),
/// The certificate contained `.0` bytes of trailing data.
Trailing(usize),
}
impl<'a> Certificate<'a> {
fn parse_one(raw: &[u8]) -> Result<X509Certificate<'_>> {
let (left, x509) = X509Certificate::from_der(raw)?;
@ -221,7 +221,7 @@ impl<'a> Certificate<'a> {
/// PRIVATE: For internal Rocket use only!
#[doc(hidden)]
pub fn parse(chain: &[RawCertificate]) -> Result<Certificate<'_>> {
pub fn parse(chain: &[CertificateData]) -> Result<Certificate<'_>> {
match chain.first() {
Some(cert) => Certificate::parse_one(&cert.0).map(Certificate),
None => Err(Error::Empty)

View File

@ -265,14 +265,14 @@ impl<F: Future, I: AsyncWrite> AsyncWrite for CancellableIo<F, I> {
}
}
use crate::http::private::{Listener, Connection, RawCertificate};
use crate::http::private::{Listener, Connection, Certificates};
impl<F: Future, C: Connection> Connection for CancellableIo<F, C> {
fn peer_address(&self) -> Option<std::net::SocketAddr> {
self.io().and_then(|io| io.peer_address())
}
fn peer_certificates(&self) -> Option<&[RawCertificate]> {
fn peer_certificates(&self) -> Option<Certificates> {
self.io().and_then(|io| io.peer_certificates())
}

View File

@ -19,6 +19,7 @@ impl<'r> FromRequest<'r> for Certificate<'r> {
async fn from_request(req: &'r Request<'_>) -> Outcome<Self, Self::Error> {
let certs = try_outcome!(req.connection.client_certificates.as_ref().or_forward(()));
Certificate::parse(certs).into_outcome(Status::Unauthorized)
let data = try_outcome!(certs.chain_data().or_forward(()));
Certificate::parse(data).into_outcome(Status::Unauthorized)
}
}

View File

@ -16,7 +16,7 @@ use crate::data::Limits;
use crate::http::{hyper, Method, Header, HeaderMap};
use crate::http::{ContentType, Accept, MediaType, CookieJar, Cookie};
use crate::http::uncased::UncasedStr;
use crate::http::private::RawCertificate;
use crate::http::private::Certificates;
use crate::http::uri::{fmt::Path, Origin, Segments, Host, Authority};
/// The type of an incoming web request.
@ -38,7 +38,7 @@ pub struct Request<'r> {
pub(crate) struct ConnectionMeta {
pub remote: Option<SocketAddr>,
#[cfg_attr(not(feature = "mtls"), allow(dead_code))]
pub client_certificates: Option<Arc<Vec<RawCertificate>>>,
pub client_certificates: Option<Certificates>,
}
/// Information derived from the request.

View File

@ -440,7 +440,7 @@ impl Rocket<Orbit> {
let rocket = rocket.clone();
let connection = ConnectionMeta {
remote: conn.peer_address(),
client_certificates: conn.peer_certificates().map(|certs| Arc::new(certs.to_vec())),
client_certificates: conn.peer_certificates(),
};
async move {