Initial implementation of mTLS.

Co-authored-by: Howard Su <howard0su@gmail.com>
Co-authored-by: Mikail Bagishov <bagishov.mikail@yandex.ru>
This commit is contained in:
Sergio Benitez 2021-07-08 23:58:25 -07:00
parent 71823915db
commit bbc36ba27f
12 changed files with 557 additions and 60 deletions

View File

@ -17,6 +17,7 @@ edition = "2018"
[features] [features]
default = [] default = []
tls = ["rustls", "tokio-rustls"] tls = ["rustls", "tokio-rustls"]
mtls = ["tls", "x509-parser"]
private-cookies = ["cookie/private", "cookie/key-expansion"] private-cookies = ["cookie/private", "cookie/key-expansion"]
serde = ["uncased/with-serde-alloc", "serde_"] serde = ["uncased/with-serde-alloc", "serde_"]
uuid = ["uuid_"] uuid = ["uuid_"]
@ -43,6 +44,10 @@ stable-pattern = "0.1"
cookie = { version = "0.15", features = ["percent-encode"] } cookie = { version = "0.15", features = ["percent-encode"] }
state = "0.5.1" state = "0.5.1"
[dependencies.x509-parser]
version = "0.9.2"
optional = true
[dependencies.hyper] [dependencies.hyper]
version = "0.14.9" version = "0.14.9"
default-features = false default-features = false
@ -62,4 +67,4 @@ optional = true
default-features = false default-features = false
[dev-dependencies] [dev-dependencies]
rocket = { version = "0.5.0-rc.1", path = "../lib" } rocket = { version = "0.5.0-rc.1", path = "../lib", features = ["mtls"] }

View File

@ -44,7 +44,7 @@ pub mod uncased {
pub mod private { pub mod private {
pub use crate::parse::Indexed; pub use crate::parse::Indexed;
pub use smallvec::{SmallVec, Array}; pub use smallvec::{SmallVec, Array};
pub use crate::listener::{Incoming, Listener, Connection, bind_tcp}; pub use crate::listener::{bind_tcp, Incoming, Listener, Connection, RawCertificate};
pub use cookie; pub use cookie;
} }

View File

@ -30,10 +30,31 @@ pub trait Listener {
) -> Poll<io::Result<Self::Connection>>; ) -> Poll<io::Result<Self::Connection>>;
} }
/// A thin wrapper over raw, DER-encoded X.509 client certificate data.
#[cfg(not(feature = "tls"))]
#[derive(Clone, Eq, PartialEq)]
pub struct RawCertificate(pub Vec<u8>);
/// A thin wrapper over raw, DER-encoded X.509 client certificate data.
// NOTE: `rustls::Certificate` is exactly isomorphic to `RawCertificate`.
#[doc(inline)]
#[cfg(feature = "tls")]
pub use rustls::Certificate as RawCertificate;
/// A 'Connection' represents an open connection to a client /// A 'Connection' represents an open connection to a client
pub trait Connection: AsyncRead + AsyncWrite { pub trait Connection: AsyncRead + AsyncWrite {
/// The remote address, i.e. the client's socket address, if it is known. /// The remote address, i.e. the client's socket address, if it is known.
fn peer_address(&self) -> Option<SocketAddr>; fn peer_address(&self) -> Option<SocketAddr>;
/// DER-encoded X.509 certificate chain presented by the client, if any.
///
/// The certificate order must be as it appears in the TLS protocol: the
/// first certificate relates to the peer, the second certifies the first,
/// the third certifies the second, and so on.
///
/// Defaults to an empty vector to indicate that no certificates were
/// presented.
fn peer_certificates(&self) -> Option<Vec<RawCertificate>> { None }
} }
pin_project_lite::pin_project! { pin_project_lite::pin_project! {
@ -114,9 +135,8 @@ impl<L: Listener> Incoming<L> {
} }
if let Some(duration) = me.sleep_on_errors { if let Some(duration) = me.sleep_on_errors {
error!("connection accept error: {}", e);
// Sleep for the specified duration // Sleep for the specified duration
error!("connection accept error: {}", e);
me.pending_error_delay.set(Some(tokio::time::sleep(*duration))); me.pending_error_delay.set(Some(tokio::time::sleep(*duration)));
} else { } else {
return Poll::Ready(Err(e)); return Poll::Ready(Err(e));

View File

@ -9,8 +9,8 @@ use rustls::{ServerConfig, SupportedCipherSuite};
use tokio_rustls::{TlsAcceptor, Accept, server::TlsStream}; use tokio_rustls::{TlsAcceptor, Accept, server::TlsStream};
use tokio::net::{TcpListener, TcpStream}; use tokio::net::{TcpListener, TcpStream};
use crate::tls::util::{load_certs, load_private_key}; use crate::tls::util::{load_certs, load_private_key, load_ca_certs};
use crate::listener::{Connection, Listener}; use crate::listener::{Connection, Listener, RawCertificate};
/// A TLS listener over TCP. /// A TLS listener over TCP.
pub struct TlsListener { pub struct TlsListener {
@ -24,35 +24,55 @@ enum State {
Accepting(Accept<TcpStream>), Accepting(Accept<TcpStream>),
} }
pub struct Config<R> {
pub cert_chain: R,
pub private_key: R,
pub ciphersuites: Vec<&'static SupportedCipherSuite>,
pub prefer_server_order: bool,
pub ca_certs: Option<R>,
pub mandatory_mtls: bool,
}
impl TlsListener { impl TlsListener {
pub async fn bind( pub async fn bind<R>(addr: SocketAddr, mut c: Config<R>) -> io::Result<TlsListener>
address: SocketAddr, where R: io::BufRead
mut cert_chain: impl io::BufRead + Send, {
mut private_key: impl io::BufRead + Send, let cert_chain = load_certs(&mut c.cert_chain).map_err(|e| {
ciphersuites: impl Iterator<Item = &'static SupportedCipherSuite>,
prefer_server_order: bool,
) -> io::Result<TlsListener> {
let cert_chain = load_certs(&mut cert_chain).map_err(|e| {
let msg = format!("malformed TLS certificate chain: {}", e); let msg = format!("malformed TLS certificate chain: {}", e);
io::Error::new(e.kind(), msg) io::Error::new(e.kind(), msg)
})?; })?;
let key = load_private_key(&mut private_key).map_err(|e| { let key = load_private_key(&mut c.private_key).map_err(|e| {
let msg = format!("malformed TLS private key: {}", e); let msg = format!("malformed TLS private key: {}", e);
io::Error::new(e.kind(), msg) io::Error::new(e.kind(), msg)
})?; })?;
let client_auth = rustls::NoClientAuth::new(); let client_auth = match c.ca_certs {
Some(ref mut ca_certs) => {
let roots = load_ca_certs(ca_certs).map_err(|e| {
let msg = format!("malformed CA certificate(s): {}", e);
io::Error::new(e.kind(), msg)
})?;
if c.mandatory_mtls {
rustls::AllowAnyAuthenticatedClient::new(roots)
} else {
rustls::AllowAnyAnonymousOrAuthenticatedClient::new(roots)
}
}
None => rustls::NoClientAuth::new(),
};
let mut tls_config = ServerConfig::new(client_auth); let mut tls_config = ServerConfig::new(client_auth);
let cache = rustls::ServerSessionMemoryCache::new(1024); let cache = rustls::ServerSessionMemoryCache::new(1024);
tls_config.set_persistence(cache); tls_config.set_persistence(cache);
tls_config.ticketer = rustls::Ticketer::new(); tls_config.ticketer = rustls::Ticketer::new();
tls_config.ciphersuites = ciphersuites.collect(); tls_config.ciphersuites = c.ciphersuites;
tls_config.ignore_client_order = prefer_server_order; tls_config.ignore_client_order = c.prefer_server_order;
tls_config.set_single_cert(cert_chain, key).expect("invalid key"); tls_config.set_single_cert(cert_chain, key).expect("invalid key");
tls_config.set_protocols(&[b"h2".to_vec(), b"http/1.1".to_vec()]); tls_config.set_protocols(&[b"h2".to_vec(), b"http/1.1".to_vec()]);
let listener = TcpListener::bind(address).await?; let listener = TcpListener::bind(addr).await?;
let acceptor = TlsAcceptor::from(Arc::new(tls_config)); let acceptor = TlsAcceptor::from(Arc::new(tls_config));
Ok(TlsListener { listener, acceptor, state: State::Listening }) Ok(TlsListener { listener, acceptor, state: State::Listening })
} }
@ -99,4 +119,10 @@ impl Connection for TlsStream<TcpStream> {
fn peer_address(&self) -> Option<SocketAddr> { fn peer_address(&self) -> Option<SocketAddr> {
self.get_ref().0.peer_address() self.get_ref().0.peer_address()
} }
fn peer_certificates(&self) -> Option<Vec<RawCertificate>> {
use rustls::Session;
self.get_ref().1.get_peer_certificates()
}
} }

View File

@ -1,5 +1,8 @@
mod listener; mod listener;
mod util; mod util;
#[cfg(feature = "mtls")]
pub mod mtls;
pub use rustls; pub use rustls;
pub use listener::TlsListener; pub use listener::{TlsListener, Config};

241
core/http/src/tls/mtls.rs Normal file
View File

@ -0,0 +1,241 @@
pub mod oid {
//! Lower-level OID types re-exported from
//! [`oid_registry`](https://docs.rs/oid-registry/0.1) and
//! [`der-parser`](https://docs.rs/der-parser/5).
pub use x509_parser::oid_registry::*;
pub use x509_parser::der_parser::oid::*;
pub use x509_parser::objects::*;
}
pub mod bigint {
//! Signed and unsigned big integer types re-exported from
//! [`num_bigint`](https://docs.rs/num-bigint/0.4).
pub use x509_parser::der_parser::num_bigint::*;
}
pub mod x509 {
//! Lower-level X.509 types re-exported from
//! [`x509_parser`](https://docs.rs/x509-parser/0.9).
//!
//! Lack of documentation is directly inherited from the source crate.
//! Prefer to use Rocket's wrappers when possible.
pub use x509_parser::certificate::*;
pub use x509_parser::cri_attributes::*;
pub use x509_parser::error::*;
pub use x509_parser::extensions::*;
pub use x509_parser::revocation_list::*;
pub use x509_parser::time::*;
pub use x509_parser::x509::*;
pub use x509_parser::der_parser::der;
pub use x509_parser::der_parser::ber;
}
use std::fmt;
use std::ops::Deref;
use std::collections::HashMap;
use std::num::NonZeroUsize;
use ref_cast::RefCast;
use x509_parser::nom;
use x509::{ParsedExtension, X509Name, X509Certificate, TbsCertificate, X509Error};
use oid::OID_X509_EXT_SUBJECT_ALT_NAME as SUBJECT_ALT_NAME;
use crate::listener::RawCertificate;
/// A type alias for [`Result`](std::result::Result) with the error type set to
/// [`Error`].
pub type Result<T, E = Error> = std::result::Result<T, E>;
/// An error returned by the [`Certificate`] request guard.
///
/// To retrieve this error in a handler, use an `mtls::Result<Certificate>`
/// guard type:
///
/// ```rust
/// # extern crate rocket;
/// # use rocket::get;
/// use rocket::mtls::{self, Certificate};
///
/// #[get("/auth")]
/// fn auth(cert: mtls::Result<Certificate<'_>>) {
/// match cert {
/// Ok(cert) => { /* do something with the client cert */ },
/// Err(e) => { /* do something with the error */ },
/// }
/// }
/// ```
#[derive(Debug, Clone)]
#[non_exhaustive]
pub enum Error {
/// The certificate chain presented by the client had no certificates.
Empty,
/// The certificate contained neither a subject nor a subjectAlt extension.
NoSubject,
/// There is no subject and the subjectAlt is not marked as critical.
NonCriticalSubjectAlt,
// FIXME: Waiting on https://github.com/rusticata/x509-parser/pull/92.
// Parse(X509Error),
/// An error occurred while parsing the certificate.
#[doc(hidden)]
Parse(String),
/// The certificate parsed partially but is incomplete.
///
/// If `Some(n)`, then `n` more bytes were expected. Otherwise, the number
/// of expected bytes is unknown.
Incomplete(Option<NonZeroUsize>),
/// The certificate contained `.0` bytes of trailing data.
Trailing(usize),
}
#[repr(transparent)]
#[derive(Debug, PartialEq)]
pub struct Certificate<'a>(X509Certificate<'a>);
/// An X.509 Distinguished Name (DN) found in a [`Certificate`].
///
/// This type is a wrapper over [`x509::X509Name`] with convenient methods and
/// complete documentation. Should the data exposed by the inherent methods not
/// suffice, this type derefs to [`x509::X509Name`].
#[repr(transparent)]
#[derive(Debug, PartialEq, RefCast)]
pub struct Name<'a>(X509Name<'a>);
impl<'a> Certificate<'a> {
fn parse_one(raw: &[u8]) -> Result<X509Certificate<'_>> {
let (left, x509) = X509Certificate::from_der(raw)?;
if !left.is_empty() {
return Err(Error::Trailing(left.len()));
}
if x509.subject().as_raw().is_empty() {
if let Some(ext) = x509.extensions().get(&SUBJECT_ALT_NAME) {
if !matches!(ext.parsed_extension(), ParsedExtension::SubjectAlternativeName(..)) {
return Err(Error::NoSubject);
} else if !ext.critical {
return Err(Error::NonCriticalSubjectAlt);
}
} else {
return Err(Error::NoSubject);
}
}
Ok(x509)
}
#[inline(always)]
fn inner(&self) -> &TbsCertificate<'a> {
&self.0.tbs_certificate
}
/// PRIVATE: For internal Rocket use only!
#[doc(hidden)]
pub fn parse(chain: &[RawCertificate]) -> Result<Certificate<'_>> {
match chain.first() {
Some(cert) => Certificate::parse_one(&cert.0).map(Certificate),
None => Err(Error::Empty)
}
}
pub fn serial(&self) -> &bigint::BigUint {
&self.inner().serial
}
pub fn version(&self) -> u32 {
self.inner().version.0
}
pub fn subject(&self) -> &Name<'a> {
Name::ref_cast(&self.inner().subject)
}
pub fn issuer(&self) -> &Name<'a> {
Name::ref_cast(&self.inner().issuer)
}
pub fn extensions(&self) -> &HashMap<oid::Oid<'a>, x509::X509Extension<'a>> {
&self.inner().extensions
}
pub fn has_serial(&self, number: &str) -> Option<bool> {
let uint: bigint::BigUint = number.parse().ok()?;
Some(&uint == self.serial())
}
}
impl<'a> Deref for Certificate<'a> {
type Target = TbsCertificate<'a>;
fn deref(&self) -> &Self::Target {
self.inner()
}
}
impl<'a> Name<'a> {
pub fn common_name(&self) -> Option<&'a str> {
self.common_names().next()
}
pub fn common_names(&self) -> impl Iterator<Item = &'a str> + '_ {
self.iter_by_oid(&oid::OID_X509_COMMON_NAME).filter_map(|n| n.as_str().ok())
}
pub fn email(&self) -> Option<&'a str> {
self.emails().next()
}
pub fn emails(&self) -> impl Iterator<Item = &'a str> + '_ {
self.iter_by_oid(&oid::OID_PKCS9_EMAIL_ADDRESS).filter_map(|n| n.as_str().ok())
}
pub fn is_empty(&self) -> bool {
self.0.as_raw().is_empty()
}
}
impl<'a> Deref for Name<'a> {
type Target = X509Name<'a>;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl fmt::Display for Name<'_> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
self.0.fmt(f)
}
}
impl fmt::Display for Error {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Error::Parse(e) => write!(f, "parse failure: {}", e),
Error::Incomplete(_) => write!(f, "incomplete certificate data"),
Error::Trailing(n) => write!(f, "found {} trailing bytes", n),
Error::Empty => write!(f, "empty certificate chain"),
Error::NoSubject => write!(f, "empty subject without subjectAlt"),
Error::NonCriticalSubjectAlt => write!(f, "empty subject without critical subjectAlt"),
}
}
}
impl From<nom::Err<X509Error>> for Error {
fn from(e: nom::Err<X509Error>) -> Self {
match e {
nom::Err::Incomplete(nom::Needed::Unknown) => Error::Incomplete(None),
nom::Err::Incomplete(nom::Needed::Size(n)) => Error::Incomplete(Some(n)),
nom::Err::Error(e) | nom::Err::Failure(e) => Error::Parse(e.to_string()),
}
}
}
impl std::error::Error for Error {
// fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
// match self {
// Error::Parse(e) => Some(e),
// _ => None
// }
// }
}

View File

@ -1,10 +1,14 @@
use std::io::{self, ErrorKind::Other, Cursor, Error, Read}; use std::io::{self, Cursor, Read};
use rustls::{internal::pemfile, Certificate, PrivateKey}; use rustls::{internal::pemfile, Certificate, PrivateKey, RootCertStore};
fn err(message: impl Into<std::borrow::Cow<'static, str>>) -> io::Error {
io::Error::new(io::ErrorKind::Other, message.into())
}
/// Loads certificates from `reader`. /// Loads certificates from `reader`.
pub fn load_certs(reader: &mut dyn io::BufRead) -> io::Result<Vec<Certificate>> { pub fn load_certs(reader: &mut dyn io::BufRead) -> io::Result<Vec<Certificate>> {
pemfile::certs(reader).map_err(|_| Error::new(Other, "invalid certificate")) pemfile::certs(reader).map_err(|_| err("invalid certificate"))
} }
/// Load and decode the private key from `reader`. /// Load and decode the private key from `reader`.
@ -17,23 +21,34 @@ pub fn load_private_key(reader: &mut dyn io::BufRead) -> io::Result<PrivateKey>
let private_keys_fn = match first_line.trim_end() { let private_keys_fn = match first_line.trim_end() {
"-----BEGIN RSA PRIVATE KEY-----" => pemfile::rsa_private_keys, "-----BEGIN RSA PRIVATE KEY-----" => pemfile::rsa_private_keys,
"-----BEGIN PRIVATE KEY-----" => pemfile::pkcs8_private_keys, "-----BEGIN PRIVATE KEY-----" => pemfile::pkcs8_private_keys,
_ => return Err(Error::new(Other, "invalid key header")) _ => return Err(err("invalid key header"))
}; };
let key = private_keys_fn(&mut Cursor::new(first_line).chain(reader)) let key = private_keys_fn(&mut Cursor::new(first_line).chain(reader))
.map_err(|_| Error::new(Other, "invalid key file")) .map_err(|_| err("invalid key file"))
.and_then(|mut keys| match keys.len() { .and_then(|mut keys| match keys.len() {
0 => Err(Error::new(Other, "no valid keys found; is the file malformed?")), 0 => Err(err("no valid keys found; is the file malformed?")),
1 => Ok(keys.remove(0)), 1 => Ok(keys.remove(0)),
n => Err(Error::new(Other, format!("expected 1 key, found {}", n))), n => Err(err(format!("expected 1 key, found {}", n))),
})?; })?;
// Ensure we can use the key. // Ensure we can use the key.
rustls::sign::any_supported_type(&key) rustls::sign::any_supported_type(&key)
.map_err(|_| Error::new(Other, "key parsed but is unusable")) .map_err(|_| err("key parsed but is unusable"))
.map(|_| key) .map(|_| key)
} }
/// Load and decode CA certificates from `reader`.
pub fn load_ca_certs(reader: &mut dyn io::BufRead) -> io::Result<RootCertStore> {
let mut roots = rustls::RootCertStore::empty();
let (_, e) = roots.add_pem_file(reader).map_err(|_| err("PEM format error"))?;
if e != 0 {
return Err(err("validity checks failed"));
}
Ok(roots)
}
#[cfg(test)] #[cfg(test)]
mod test { mod test {
use super::*; use super::*;

View File

@ -71,6 +71,21 @@ pub struct TlsConfig {
/// Whether to prefer the server's cipher suite order over the client's. /// Whether to prefer the server's cipher suite order over the client's.
#[serde(default)] #[serde(default)]
pub(crate) prefer_server_cipher_order: bool, pub(crate) prefer_server_cipher_order: bool,
/// Configuration for mutual TLS, if any.
#[serde(default)]
#[cfg(feature = "mtls")]
#[cfg_attr(nightly, doc(cfg(feature = "mtls")))]
pub(crate) mutual: Option<MutualTls>,
}
#[derive(PartialEq, Debug, Clone, Deserialize, Serialize)]
#[cfg(feature = "mtls")]
#[cfg_attr(nightly, doc(cfg(feature = "mtls")))]
pub struct MutualTls {
pub(crate) ca_certs: Either<RelativePathBuf, Vec<u8>>,
#[serde(default)]
#[serde(deserialize_with = "figment::util::bool_from_str_or_int")]
pub mandatory: bool,
} }
/// A supported TLS cipher suite. /// A supported TLS cipher suite.
@ -144,7 +159,18 @@ impl CipherSuite {
} }
impl TlsConfig { impl TlsConfig {
/// Constructs a `TlsConfig` from paths to a `certs` certificate-chain fn default() -> Self {
TlsConfig {
certs: Either::Right(vec![]),
key: Either::Right(vec![]),
ciphers: CipherSuite::default_set(),
prefer_server_cipher_order: false,
#[cfg(feature = "mtls")]
mutual: None,
}
}
/// Constructs a `TlsConfig` from paths to a `certs` certificate chain
/// a `key` private-key. This method does no validation; it simply creates a /// a `key` private-key. This method does no validation; it simply creates a
/// structure suitable for passing into a [`Config`](crate::Config). /// structure suitable for passing into a [`Config`](crate::Config).
/// ///
@ -161,13 +187,12 @@ impl TlsConfig {
TlsConfig { TlsConfig {
certs: Either::Left(certs.as_ref().to_path_buf().into()), certs: Either::Left(certs.as_ref().to_path_buf().into()),
key: Either::Left(key.as_ref().to_path_buf().into()), key: Either::Left(key.as_ref().to_path_buf().into()),
ciphers: CipherSuite::default_set(), ..TlsConfig::default()
prefer_server_cipher_order: Default::default(),
} }
} }
/// Constructs a `TlsConfig` from byte buffers to a `certs` /// Constructs a `TlsConfig` from byte buffers to a `certs`
/// certificate-chain a `key` private-key. This method does no validation; /// certificate chain a `key` private-key. This method does no validation;
/// it simply creates a structure suitable for passing into a /// it simply creates a structure suitable for passing into a
/// [`Config`](crate::Config). /// [`Config`](crate::Config).
/// ///
@ -184,8 +209,7 @@ impl TlsConfig {
TlsConfig { TlsConfig {
certs: Either::Right(certs.to_vec()), certs: Either::Right(certs.to_vec()),
key: Either::Right(key.to_vec()), key: Either::Right(key.to_vec()),
ciphers: CipherSuite::default_set(), ..TlsConfig::default()
prefer_server_cipher_order: Default::default(),
} }
} }
@ -285,6 +309,26 @@ impl TlsConfig {
self self
} }
/// Configures mutual TLS. See [`MutualTls`] for details.
///
/// # Example
///
/// ```rust
/// use rocket::config::{TlsConfig, MutualTls};
///
/// # let certs = &[];
/// # let key = &[];
/// let mtls_config = MutualTls::from_path("path/to/cert.pem").mandatory(true);
/// let tls_config = TlsConfig::from_bytes(certs, key).with_mutual(mtls_config);
/// assert!(tls_config.mutual().is_some());
/// ```
#[cfg(feature = "mtls")]
#[cfg_attr(nightly, doc(cfg(feature = "mtls")))]
pub fn with_mutual(mut self, config: MutualTls) -> Self {
self.mutual = Some(config);
self
}
/// Returns the value of the `certs` parameter. /// Returns the value of the `certs` parameter.
/// ///
/// # Example /// # Example
@ -380,41 +424,170 @@ impl TlsConfig {
pub fn prefer_server_cipher_order(&self) -> bool { pub fn prefer_server_cipher_order(&self) -> bool {
self.prefer_server_cipher_order self.prefer_server_cipher_order
} }
/// Returns the value of the `mutual` parameter.
///
/// # Example
///
/// ```rust
/// use std::path::Path;
/// use rocket::config::{TlsConfig, MutualTls};
///
/// # let certs = &[];
/// # let key = &[];
/// let mtls_config = MutualTls::from_path("path/to/cert.pem").mandatory(true);
/// let tls_config = TlsConfig::from_bytes(certs, key).with_mutual(mtls_config);
///
/// let mtls = tls_config.mutual().unwrap();
/// assert_eq!(mtls.ca_certs().unwrap_left(), Path::new("path/to/cert.pem"));
/// assert!(mtls.mandatory);
/// ```
#[cfg(feature = "mtls")]
#[cfg_attr(nightly, doc(cfg(feature = "mtls")))]
pub fn mutual(&self) -> Option<&MutualTls> {
self.mutual.as_ref()
}
}
#[cfg(feature = "mtls")]
impl MutualTls {
/// Constructs a `MutualTls` from a path to a PEM file with a certificate
/// authority `ca_certs` DER-encoded X.509 TLS certificate chain. This
/// method does no validation; it simply creates a structure suitable for
/// passing into a [`TlsConfig`].
///
/// These certificates will be used to verify client-presented certificates
/// in TLS connections.
///
/// # Example
///
/// ```rust
/// use rocket::config::MutualTls;
///
/// let tls_config = MutualTls::from_path("/ssl/ca_certs.pem");
/// ```
pub fn from_path<C: AsRef<std::path::Path>>(ca_certs: C) -> Self {
MutualTls {
ca_certs: Either::Left(ca_certs.as_ref().to_path_buf().into()),
mandatory: Default::default()
}
}
/// Constructs a `MutualTls` from a byte buffer to a certificate authority
/// `ca_certs` DER-encoded X.509 TLS certificate chain. This method does no
/// validation; it simply creates a structure suitable for passing into a
/// [`TlsConfig`].
///
/// These certificates will be used to verify client-presented certificates
/// in TLS connections.
///
/// # Example
///
/// ```rust
/// use rocket::config::MutualTls;
///
/// # let ca_certs_buf = &[];
/// let mtls_config = MutualTls::from_bytes(ca_certs_buf);
/// ```
pub fn from_bytes(ca_certs: &[u8]) -> Self {
MutualTls {
ca_certs: Either::Right(ca_certs.to_vec()),
mandatory: Default::default()
}
}
/// Sets whether client authentication is required. Disabled by default.
///
/// When `true`, client authentication will be required. TLS connections
/// where the client does not present a certificate will be immediately
/// terminated. When `false`, the client is not required to present a
/// certificate. In either case, if a certificate _is_ presented, it must be
/// valid or the connection is terminated.
///
/// # Example
///
/// ```rust
/// use rocket::config::MutualTls;
///
/// # let ca_certs_buf = &[];
/// let mtls_config = MutualTls::from_bytes(ca_certs_buf).mandatory(true);
/// ```
pub fn mandatory(mut self, mandatory: bool) -> Self {
self.mandatory = mandatory;
self
}
/// Returns the value of the `ca_certs` parameter.
/// # Example
///
/// ```rust
/// use rocket::config::MutualTls;
///
/// # let ca_certs_buf = &[];
/// let mtls_config = MutualTls::from_bytes(ca_certs_buf).mandatory(true);
/// assert_eq!(mtls_config.ca_certs().unwrap_right(), ca_certs_buf);
/// ```
pub fn ca_certs(&self) -> either::Either<std::path::PathBuf, &[u8]> {
match &self.ca_certs {
Either::Left(path) => either::Either::Left(path.relative()),
Either::Right(bytes) => either::Either::Right(&bytes),
}
}
} }
#[cfg(feature = "tls")] #[cfg(feature = "tls")]
mod with_tls_feature { mod with_tls_feature {
use std::fs;
use std::io::{self, Error};
use crate::http::tls::Config;
use crate::http::tls::rustls::SupportedCipherSuite as RustlsCipher; use crate::http::tls::rustls::SupportedCipherSuite as RustlsCipher;
use crate::http::tls::rustls::ciphersuite as rustls; use crate::http::tls::rustls::ciphersuite as rustls;
use super::*; use yansi::Paint;
use super::{Either, RelativePathBuf, TlsConfig, CipherSuite};
type Reader = Box<dyn std::io::BufRead + Sync + Send>; type Reader = Box<dyn std::io::BufRead + Sync + Send>;
impl TlsConfig { fn to_reader(value: &Either<RelativePathBuf, Vec<u8>>) -> io::Result<Reader> {
pub(crate) fn to_readers(&self) -> std::io::Result<(Reader, Reader)> { match value {
use std::{io::{self, Error}, fs}; Either::Left(path) => {
use yansi::Paint; let path = path.relative();
let file = fs::File::open(&path).map_err(move |e| {
Error::new(e.kind(), format!("error reading TLS file `{}`: {}",
Paint::white(figment::Source::File(path)), e))
})?;
fn to_reader(value: &Either<RelativePathBuf, Vec<u8>>) -> io::Result<Reader> { Ok(Box::new(io::BufReader::new(file)))
match value {
Either::Left(path) => {
let path = path.relative();
let file = fs::File::open(&path).map_err(move |e| {
Error::new(e.kind(), format!("error reading TLS file `{}`: {}",
Paint::white(figment::Source::File(path)), e))
})?;
Ok(Box::new(io::BufReader::new(file)))
}
Either::Right(vec) => Ok(Box::new(io::Cursor::new(vec.clone()))),
}
} }
Either::Right(vec) => Ok(Box::new(io::Cursor::new(vec.clone()))),
}
}
Ok((to_reader(&self.certs)?, to_reader(&self.key)?)) impl TlsConfig {
/// This is only called when TLS is enabled.
pub(crate) fn to_native_config(&self) -> io::Result<Config<Reader>> {
Ok(Config {
cert_chain: to_reader(&self.certs)?,
private_key: to_reader(&self.key)?,
ciphersuites: self.rustls_ciphers().collect(),
prefer_server_order: self.prefer_server_cipher_order,
#[cfg(not(feature = "mtls"))]
mandatory_mtls: false,
#[cfg(not(feature = "mtls"))]
ca_certs: None,
#[cfg(feature = "mtls")]
mandatory_mtls: self.mutual.as_ref().map_or(false, |m| m.mandatory),
#[cfg(feature = "mtls")]
ca_certs: match self.mutual {
Some(ref mtls) => Some(to_reader(&mtls.ca_certs)?),
None => None
},
})
} }
pub(crate) fn rustls_ciphers(&self) -> impl Iterator<Item = &'static RustlsCipher> + '_ { fn rustls_ciphers(&self) -> impl Iterator<Item = &'static RustlsCipher> + '_ {
self.ciphers().map(|ciphersuite| match ciphersuite { self.ciphers().map(|ciphersuite| match ciphersuite {
CipherSuite::TLS_CHACHA20_POLY1305_SHA256 => CipherSuite::TLS_CHACHA20_POLY1305_SHA256 =>
&rustls::TLS13_CHACHA20_POLY1305_SHA256, &rustls::TLS13_CHACHA20_POLY1305_SHA256,

View File

@ -2,7 +2,7 @@ use std::{io, time::Duration};
use std::task::{Poll, Context}; use std::task::{Poll, Context};
use std::pin::Pin; use std::pin::Pin;
use bytes::BytesMut; use bytes::{Bytes, BytesMut};
use pin_project_lite::pin_project; use pin_project_lite::pin_project;
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tokio::time::{sleep, Sleep}; use tokio::time::{sleep, Sleep};
@ -10,8 +10,6 @@ use tokio::time::{sleep, Sleep};
use futures::stream::Stream; use futures::stream::Stream;
use futures::future::{self, Future, FutureExt}; use futures::future::{self, Future, FutureExt};
use crate::http::hyper::Bytes;
pin_project! { pin_project! {
pub struct ReaderStream<R> { pub struct ReaderStream<R> {
#[pin] #[pin]
@ -293,12 +291,16 @@ impl<F: Future, I: AsyncWrite> AsyncWrite for CancellableIo<F, I> {
} }
} }
use crate::http::private::{Listener, Connection}; use crate::http::private::{Listener, Connection, RawCertificate};
impl<F: Future, C: Connection> Connection for CancellableIo<F, C> { impl<F: Future, C: Connection> Connection for CancellableIo<F, C> {
fn peer_address(&self) -> Option<std::net::SocketAddr> { fn peer_address(&self) -> Option<std::net::SocketAddr> {
self.io.peer_address() self.io.peer_address()
} }
fn peer_certificates(&self) -> Option<Vec<RawCertificate>> {
self.io.peer_certificates()
}
} }
pin_project! { pin_project! {

View File

@ -6,4 +6,4 @@ edition = "2018"
publish = false publish = false
[dependencies] [dependencies]
rocket = { path = "../../core/lib", features = ["tls"] } rocket = { path = "../../core/lib", features = ["tls", "mtls"] }

View File

@ -9,6 +9,10 @@
certs = "private/rsa_sha256_cert.pem" certs = "private/rsa_sha256_cert.pem"
key = "private/rsa_sha256_key.pem" key = "private/rsa_sha256_key.pem"
[default.tls.mutual]
ca_certs = "private/ca_cert.pem"
mandatory = false
[rsa_sha256.tls] [rsa_sha256.tls]
certs = "private/rsa_sha256_cert.pem" certs = "private/rsa_sha256_cert.pem"
key = "private/rsa_sha256_key.pem" key = "private/rsa_sha256_key.pem"

View File

@ -2,7 +2,14 @@
#[cfg(test)] mod tests; #[cfg(test)] mod tests;
use rocket::mtls::Certificate;
#[get("/")] #[get("/")]
fn mutual(cert: Certificate<'_>) -> String {
format!("Hello! Here's what we know: [{}] {}", cert.serial(), cert.subject())
}
#[get("/", rank = 2)]
fn hello() -> &'static str { fn hello() -> &'static str {
"Hello, world!" "Hello, world!"
} }
@ -10,5 +17,6 @@ fn hello() -> &'static str {
#[launch] #[launch]
fn rocket() -> _ { fn rocket() -> _ {
// See `Rocket.toml` and `Cargo.toml` for TLS configuration. // See `Rocket.toml` and `Cargo.toml` for TLS configuration.
rocket::build().mount("/", routes![hello]) // Run `./private/gen_certs.sh` to generate a CA and key pairs.
rocket::build().mount("/", routes![hello, mutual])
} }