From bbc36ba27fea494cbab20eaaff73a204aac93c12 Mon Sep 17 00:00:00 2001 From: Sergio Benitez Date: Thu, 8 Jul 2021 23:58:25 -0700 Subject: [PATCH] Initial implementation of mTLS. Co-authored-by: Howard Su Co-authored-by: Mikail Bagishov --- core/http/Cargo.toml | 7 +- core/http/src/lib.rs | 2 +- core/http/src/listener.rs | 24 +++- core/http/src/tls/listener.rs | 56 +++++--- core/http/src/tls/mod.rs | 5 +- core/http/src/tls/mtls.rs | 241 ++++++++++++++++++++++++++++++++++ core/http/src/tls/util.rs | 31 +++-- core/lib/src/config/tls.rs | 225 +++++++++++++++++++++++++++---- core/lib/src/ext.rs | 10 +- examples/tls/Cargo.toml | 2 +- examples/tls/Rocket.toml | 4 + examples/tls/src/main.rs | 10 +- 12 files changed, 557 insertions(+), 60 deletions(-) create mode 100644 core/http/src/tls/mtls.rs diff --git a/core/http/Cargo.toml b/core/http/Cargo.toml index ec5a578a..108de8c1 100644 --- a/core/http/Cargo.toml +++ b/core/http/Cargo.toml @@ -17,6 +17,7 @@ edition = "2018" [features] default = [] tls = ["rustls", "tokio-rustls"] +mtls = ["tls", "x509-parser"] private-cookies = ["cookie/private", "cookie/key-expansion"] serde = ["uncased/with-serde-alloc", "serde_"] uuid = ["uuid_"] @@ -43,6 +44,10 @@ stable-pattern = "0.1" cookie = { version = "0.15", features = ["percent-encode"] } state = "0.5.1" +[dependencies.x509-parser] +version = "0.9.2" +optional = true + [dependencies.hyper] version = "0.14.9" default-features = false @@ -62,4 +67,4 @@ optional = true default-features = false [dev-dependencies] -rocket = { version = "0.5.0-rc.1", path = "../lib" } +rocket = { version = "0.5.0-rc.1", path = "../lib", features = ["mtls"] } diff --git a/core/http/src/lib.rs b/core/http/src/lib.rs index cdf2c9e7..11289d43 100644 --- a/core/http/src/lib.rs +++ b/core/http/src/lib.rs @@ -44,7 +44,7 @@ pub mod uncased { pub mod private { pub use crate::parse::Indexed; pub use smallvec::{SmallVec, Array}; - pub use crate::listener::{Incoming, Listener, Connection, bind_tcp}; + pub use crate::listener::{bind_tcp, Incoming, Listener, Connection, RawCertificate}; pub use cookie; } diff --git a/core/http/src/listener.rs b/core/http/src/listener.rs index c574c45e..2c608119 100644 --- a/core/http/src/listener.rs +++ b/core/http/src/listener.rs @@ -30,10 +30,31 @@ pub trait Listener { ) -> Poll>; } +/// A thin wrapper over raw, DER-encoded X.509 client certificate data. +#[cfg(not(feature = "tls"))] +#[derive(Clone, Eq, PartialEq)] +pub struct RawCertificate(pub Vec); + +/// A thin wrapper over raw, DER-encoded X.509 client certificate data. +// NOTE: `rustls::Certificate` is exactly isomorphic to `RawCertificate`. +#[doc(inline)] +#[cfg(feature = "tls")] +pub use rustls::Certificate as RawCertificate; + /// A 'Connection' represents an open connection to a client pub trait Connection: AsyncRead + AsyncWrite { /// The remote address, i.e. the client's socket address, if it is known. fn peer_address(&self) -> Option; + + /// DER-encoded X.509 certificate chain presented by the client, if any. + /// + /// The certificate order must be as it appears in the TLS protocol: the + /// first certificate relates to the peer, the second certifies the first, + /// the third certifies the second, and so on. + /// + /// Defaults to an empty vector to indicate that no certificates were + /// presented. + fn peer_certificates(&self) -> Option> { None } } pin_project_lite::pin_project! { @@ -114,9 +135,8 @@ impl Incoming { } if let Some(duration) = me.sleep_on_errors { - error!("connection accept error: {}", e); - // Sleep for the specified duration + error!("connection accept error: {}", e); me.pending_error_delay.set(Some(tokio::time::sleep(*duration))); } else { return Poll::Ready(Err(e)); diff --git a/core/http/src/tls/listener.rs b/core/http/src/tls/listener.rs index 3d53baf2..90d14f22 100644 --- a/core/http/src/tls/listener.rs +++ b/core/http/src/tls/listener.rs @@ -9,8 +9,8 @@ use rustls::{ServerConfig, SupportedCipherSuite}; use tokio_rustls::{TlsAcceptor, Accept, server::TlsStream}; use tokio::net::{TcpListener, TcpStream}; -use crate::tls::util::{load_certs, load_private_key}; -use crate::listener::{Connection, Listener}; +use crate::tls::util::{load_certs, load_private_key, load_ca_certs}; +use crate::listener::{Connection, Listener, RawCertificate}; /// A TLS listener over TCP. pub struct TlsListener { @@ -24,35 +24,55 @@ enum State { Accepting(Accept), } +pub struct Config { + pub cert_chain: R, + pub private_key: R, + pub ciphersuites: Vec<&'static SupportedCipherSuite>, + pub prefer_server_order: bool, + pub ca_certs: Option, + pub mandatory_mtls: bool, +} + impl TlsListener { - pub async fn bind( - address: SocketAddr, - mut cert_chain: impl io::BufRead + Send, - mut private_key: impl io::BufRead + Send, - ciphersuites: impl Iterator, - prefer_server_order: bool, - ) -> io::Result { - let cert_chain = load_certs(&mut cert_chain).map_err(|e| { + pub async fn bind(addr: SocketAddr, mut c: Config) -> io::Result + where R: io::BufRead + { + let cert_chain = load_certs(&mut c.cert_chain).map_err(|e| { let msg = format!("malformed TLS certificate chain: {}", e); io::Error::new(e.kind(), msg) })?; - let key = load_private_key(&mut private_key).map_err(|e| { + let key = load_private_key(&mut c.private_key).map_err(|e| { let msg = format!("malformed TLS private key: {}", e); io::Error::new(e.kind(), msg) })?; - let client_auth = rustls::NoClientAuth::new(); + let client_auth = match c.ca_certs { + Some(ref mut ca_certs) => { + let roots = load_ca_certs(ca_certs).map_err(|e| { + let msg = format!("malformed CA certificate(s): {}", e); + io::Error::new(e.kind(), msg) + })?; + + if c.mandatory_mtls { + rustls::AllowAnyAuthenticatedClient::new(roots) + } else { + rustls::AllowAnyAnonymousOrAuthenticatedClient::new(roots) + } + } + None => rustls::NoClientAuth::new(), + }; + let mut tls_config = ServerConfig::new(client_auth); let cache = rustls::ServerSessionMemoryCache::new(1024); tls_config.set_persistence(cache); tls_config.ticketer = rustls::Ticketer::new(); - tls_config.ciphersuites = ciphersuites.collect(); - tls_config.ignore_client_order = prefer_server_order; + tls_config.ciphersuites = c.ciphersuites; + tls_config.ignore_client_order = c.prefer_server_order; tls_config.set_single_cert(cert_chain, key).expect("invalid key"); tls_config.set_protocols(&[b"h2".to_vec(), b"http/1.1".to_vec()]); - let listener = TcpListener::bind(address).await?; + let listener = TcpListener::bind(addr).await?; let acceptor = TlsAcceptor::from(Arc::new(tls_config)); Ok(TlsListener { listener, acceptor, state: State::Listening }) } @@ -99,4 +119,10 @@ impl Connection for TlsStream { fn peer_address(&self) -> Option { self.get_ref().0.peer_address() } + + fn peer_certificates(&self) -> Option> { + use rustls::Session; + + self.get_ref().1.get_peer_certificates() + } } diff --git a/core/http/src/tls/mod.rs b/core/http/src/tls/mod.rs index 4e21e40f..b529ee40 100644 --- a/core/http/src/tls/mod.rs +++ b/core/http/src/tls/mod.rs @@ -1,5 +1,8 @@ mod listener; mod util; +#[cfg(feature = "mtls")] +pub mod mtls; + pub use rustls; -pub use listener::TlsListener; +pub use listener::{TlsListener, Config}; diff --git a/core/http/src/tls/mtls.rs b/core/http/src/tls/mtls.rs new file mode 100644 index 00000000..9f712d22 --- /dev/null +++ b/core/http/src/tls/mtls.rs @@ -0,0 +1,241 @@ +pub mod oid { + //! Lower-level OID types re-exported from + //! [`oid_registry`](https://docs.rs/oid-registry/0.1) and + //! [`der-parser`](https://docs.rs/der-parser/5). + + pub use x509_parser::oid_registry::*; + pub use x509_parser::der_parser::oid::*; + pub use x509_parser::objects::*; +} + +pub mod bigint { + //! Signed and unsigned big integer types re-exported from + //! [`num_bigint`](https://docs.rs/num-bigint/0.4). + pub use x509_parser::der_parser::num_bigint::*; +} + +pub mod x509 { + //! Lower-level X.509 types re-exported from + //! [`x509_parser`](https://docs.rs/x509-parser/0.9). + //! + //! Lack of documentation is directly inherited from the source crate. + //! Prefer to use Rocket's wrappers when possible. + + pub use x509_parser::certificate::*; + pub use x509_parser::cri_attributes::*; + pub use x509_parser::error::*; + pub use x509_parser::extensions::*; + pub use x509_parser::revocation_list::*; + pub use x509_parser::time::*; + pub use x509_parser::x509::*; + pub use x509_parser::der_parser::der; + pub use x509_parser::der_parser::ber; +} + +use std::fmt; +use std::ops::Deref; +use std::collections::HashMap; +use std::num::NonZeroUsize; + +use ref_cast::RefCast; +use x509_parser::nom; +use x509::{ParsedExtension, X509Name, X509Certificate, TbsCertificate, X509Error}; +use oid::OID_X509_EXT_SUBJECT_ALT_NAME as SUBJECT_ALT_NAME; + +use crate::listener::RawCertificate; + +/// A type alias for [`Result`](std::result::Result) with the error type set to +/// [`Error`]. +pub type Result = std::result::Result; + +/// An error returned by the [`Certificate`] request guard. +/// +/// To retrieve this error in a handler, use an `mtls::Result` +/// guard type: +/// +/// ```rust +/// # extern crate rocket; +/// # use rocket::get; +/// use rocket::mtls::{self, Certificate}; +/// +/// #[get("/auth")] +/// fn auth(cert: mtls::Result>) { +/// match cert { +/// Ok(cert) => { /* do something with the client cert */ }, +/// Err(e) => { /* do something with the error */ }, +/// } +/// } +/// ``` +#[derive(Debug, Clone)] +#[non_exhaustive] +pub enum Error { + /// The certificate chain presented by the client had no certificates. + Empty, + /// The certificate contained neither a subject nor a subjectAlt extension. + NoSubject, + /// There is no subject and the subjectAlt is not marked as critical. + NonCriticalSubjectAlt, + // FIXME: Waiting on https://github.com/rusticata/x509-parser/pull/92. + // Parse(X509Error), + /// An error occurred while parsing the certificate. + #[doc(hidden)] + Parse(String), + /// The certificate parsed partially but is incomplete. + /// + /// If `Some(n)`, then `n` more bytes were expected. Otherwise, the number + /// of expected bytes is unknown. + Incomplete(Option), + /// The certificate contained `.0` bytes of trailing data. + Trailing(usize), +} + +#[repr(transparent)] +#[derive(Debug, PartialEq)] +pub struct Certificate<'a>(X509Certificate<'a>); + +/// An X.509 Distinguished Name (DN) found in a [`Certificate`]. +/// +/// This type is a wrapper over [`x509::X509Name`] with convenient methods and +/// complete documentation. Should the data exposed by the inherent methods not +/// suffice, this type derefs to [`x509::X509Name`]. +#[repr(transparent)] +#[derive(Debug, PartialEq, RefCast)] +pub struct Name<'a>(X509Name<'a>); + +impl<'a> Certificate<'a> { + fn parse_one(raw: &[u8]) -> Result> { + let (left, x509) = X509Certificate::from_der(raw)?; + if !left.is_empty() { + return Err(Error::Trailing(left.len())); + } + + if x509.subject().as_raw().is_empty() { + if let Some(ext) = x509.extensions().get(&SUBJECT_ALT_NAME) { + if !matches!(ext.parsed_extension(), ParsedExtension::SubjectAlternativeName(..)) { + return Err(Error::NoSubject); + } else if !ext.critical { + return Err(Error::NonCriticalSubjectAlt); + } + } else { + return Err(Error::NoSubject); + } + } + + Ok(x509) + } + + #[inline(always)] + fn inner(&self) -> &TbsCertificate<'a> { + &self.0.tbs_certificate + } + + /// PRIVATE: For internal Rocket use only! + #[doc(hidden)] + pub fn parse(chain: &[RawCertificate]) -> Result> { + match chain.first() { + Some(cert) => Certificate::parse_one(&cert.0).map(Certificate), + None => Err(Error::Empty) + } + } + + pub fn serial(&self) -> &bigint::BigUint { + &self.inner().serial + } + + pub fn version(&self) -> u32 { + self.inner().version.0 + } + + pub fn subject(&self) -> &Name<'a> { + Name::ref_cast(&self.inner().subject) + } + + pub fn issuer(&self) -> &Name<'a> { + Name::ref_cast(&self.inner().issuer) + } + + pub fn extensions(&self) -> &HashMap, x509::X509Extension<'a>> { + &self.inner().extensions + } + + pub fn has_serial(&self, number: &str) -> Option { + let uint: bigint::BigUint = number.parse().ok()?; + Some(&uint == self.serial()) + } +} + +impl<'a> Deref for Certificate<'a> { + type Target = TbsCertificate<'a>; + + fn deref(&self) -> &Self::Target { + self.inner() + } +} + +impl<'a> Name<'a> { + pub fn common_name(&self) -> Option<&'a str> { + self.common_names().next() + } + + pub fn common_names(&self) -> impl Iterator + '_ { + self.iter_by_oid(&oid::OID_X509_COMMON_NAME).filter_map(|n| n.as_str().ok()) + } + + pub fn email(&self) -> Option<&'a str> { + self.emails().next() + } + + pub fn emails(&self) -> impl Iterator + '_ { + self.iter_by_oid(&oid::OID_PKCS9_EMAIL_ADDRESS).filter_map(|n| n.as_str().ok()) + } + + pub fn is_empty(&self) -> bool { + self.0.as_raw().is_empty() + } +} + +impl<'a> Deref for Name<'a> { + type Target = X509Name<'a>; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl fmt::Display for Name<'_> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.0.fmt(f) + } +} + +impl fmt::Display for Error { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Error::Parse(e) => write!(f, "parse failure: {}", e), + Error::Incomplete(_) => write!(f, "incomplete certificate data"), + Error::Trailing(n) => write!(f, "found {} trailing bytes", n), + Error::Empty => write!(f, "empty certificate chain"), + Error::NoSubject => write!(f, "empty subject without subjectAlt"), + Error::NonCriticalSubjectAlt => write!(f, "empty subject without critical subjectAlt"), + } + } +} + +impl From> for Error { + fn from(e: nom::Err) -> Self { + match e { + nom::Err::Incomplete(nom::Needed::Unknown) => Error::Incomplete(None), + nom::Err::Incomplete(nom::Needed::Size(n)) => Error::Incomplete(Some(n)), + nom::Err::Error(e) | nom::Err::Failure(e) => Error::Parse(e.to_string()), + } + } +} + +impl std::error::Error for Error { + // fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + // match self { + // Error::Parse(e) => Some(e), + // _ => None + // } + // } +} diff --git a/core/http/src/tls/util.rs b/core/http/src/tls/util.rs index 650f693d..fa598f29 100644 --- a/core/http/src/tls/util.rs +++ b/core/http/src/tls/util.rs @@ -1,10 +1,14 @@ -use std::io::{self, ErrorKind::Other, Cursor, Error, Read}; +use std::io::{self, Cursor, Read}; -use rustls::{internal::pemfile, Certificate, PrivateKey}; +use rustls::{internal::pemfile, Certificate, PrivateKey, RootCertStore}; + +fn err(message: impl Into>) -> io::Error { + io::Error::new(io::ErrorKind::Other, message.into()) +} /// Loads certificates from `reader`. pub fn load_certs(reader: &mut dyn io::BufRead) -> io::Result> { - pemfile::certs(reader).map_err(|_| Error::new(Other, "invalid certificate")) + pemfile::certs(reader).map_err(|_| err("invalid certificate")) } /// Load and decode the private key from `reader`. @@ -17,23 +21,34 @@ pub fn load_private_key(reader: &mut dyn io::BufRead) -> io::Result let private_keys_fn = match first_line.trim_end() { "-----BEGIN RSA PRIVATE KEY-----" => pemfile::rsa_private_keys, "-----BEGIN PRIVATE KEY-----" => pemfile::pkcs8_private_keys, - _ => return Err(Error::new(Other, "invalid key header")) + _ => return Err(err("invalid key header")) }; let key = private_keys_fn(&mut Cursor::new(first_line).chain(reader)) - .map_err(|_| Error::new(Other, "invalid key file")) + .map_err(|_| err("invalid key file")) .and_then(|mut keys| match keys.len() { - 0 => Err(Error::new(Other, "no valid keys found; is the file malformed?")), + 0 => Err(err("no valid keys found; is the file malformed?")), 1 => Ok(keys.remove(0)), - n => Err(Error::new(Other, format!("expected 1 key, found {}", n))), + n => Err(err(format!("expected 1 key, found {}", n))), })?; // Ensure we can use the key. rustls::sign::any_supported_type(&key) - .map_err(|_| Error::new(Other, "key parsed but is unusable")) + .map_err(|_| err("key parsed but is unusable")) .map(|_| key) } +/// Load and decode CA certificates from `reader`. +pub fn load_ca_certs(reader: &mut dyn io::BufRead) -> io::Result { + let mut roots = rustls::RootCertStore::empty(); + let (_, e) = roots.add_pem_file(reader).map_err(|_| err("PEM format error"))?; + if e != 0 { + return Err(err("validity checks failed")); + } + + Ok(roots) +} + #[cfg(test)] mod test { use super::*; diff --git a/core/lib/src/config/tls.rs b/core/lib/src/config/tls.rs index 41554802..67fb0bbd 100644 --- a/core/lib/src/config/tls.rs +++ b/core/lib/src/config/tls.rs @@ -71,6 +71,21 @@ pub struct TlsConfig { /// Whether to prefer the server's cipher suite order over the client's. #[serde(default)] pub(crate) prefer_server_cipher_order: bool, + /// Configuration for mutual TLS, if any. + #[serde(default)] + #[cfg(feature = "mtls")] + #[cfg_attr(nightly, doc(cfg(feature = "mtls")))] + pub(crate) mutual: Option, +} + +#[derive(PartialEq, Debug, Clone, Deserialize, Serialize)] +#[cfg(feature = "mtls")] +#[cfg_attr(nightly, doc(cfg(feature = "mtls")))] +pub struct MutualTls { + pub(crate) ca_certs: Either>, + #[serde(default)] + #[serde(deserialize_with = "figment::util::bool_from_str_or_int")] + pub mandatory: bool, } /// A supported TLS cipher suite. @@ -144,7 +159,18 @@ impl CipherSuite { } impl TlsConfig { - /// Constructs a `TlsConfig` from paths to a `certs` certificate-chain + fn default() -> Self { + TlsConfig { + certs: Either::Right(vec![]), + key: Either::Right(vec![]), + ciphers: CipherSuite::default_set(), + prefer_server_cipher_order: false, + #[cfg(feature = "mtls")] + mutual: None, + } + } + + /// Constructs a `TlsConfig` from paths to a `certs` certificate chain /// a `key` private-key. This method does no validation; it simply creates a /// structure suitable for passing into a [`Config`](crate::Config). /// @@ -161,13 +187,12 @@ impl TlsConfig { TlsConfig { certs: Either::Left(certs.as_ref().to_path_buf().into()), key: Either::Left(key.as_ref().to_path_buf().into()), - ciphers: CipherSuite::default_set(), - prefer_server_cipher_order: Default::default(), + ..TlsConfig::default() } } /// Constructs a `TlsConfig` from byte buffers to a `certs` - /// certificate-chain a `key` private-key. This method does no validation; + /// certificate chain a `key` private-key. This method does no validation; /// it simply creates a structure suitable for passing into a /// [`Config`](crate::Config). /// @@ -184,8 +209,7 @@ impl TlsConfig { TlsConfig { certs: Either::Right(certs.to_vec()), key: Either::Right(key.to_vec()), - ciphers: CipherSuite::default_set(), - prefer_server_cipher_order: Default::default(), + ..TlsConfig::default() } } @@ -285,6 +309,26 @@ impl TlsConfig { self } + /// Configures mutual TLS. See [`MutualTls`] for details. + /// + /// # Example + /// + /// ```rust + /// use rocket::config::{TlsConfig, MutualTls}; + /// + /// # let certs = &[]; + /// # let key = &[]; + /// let mtls_config = MutualTls::from_path("path/to/cert.pem").mandatory(true); + /// let tls_config = TlsConfig::from_bytes(certs, key).with_mutual(mtls_config); + /// assert!(tls_config.mutual().is_some()); + /// ``` + #[cfg(feature = "mtls")] + #[cfg_attr(nightly, doc(cfg(feature = "mtls")))] + pub fn with_mutual(mut self, config: MutualTls) -> Self { + self.mutual = Some(config); + self + } + /// Returns the value of the `certs` parameter. /// /// # Example @@ -380,41 +424,170 @@ impl TlsConfig { pub fn prefer_server_cipher_order(&self) -> bool { self.prefer_server_cipher_order } + + /// Returns the value of the `mutual` parameter. + /// + /// # Example + /// + /// ```rust + /// use std::path::Path; + /// use rocket::config::{TlsConfig, MutualTls}; + /// + /// # let certs = &[]; + /// # let key = &[]; + /// let mtls_config = MutualTls::from_path("path/to/cert.pem").mandatory(true); + /// let tls_config = TlsConfig::from_bytes(certs, key).with_mutual(mtls_config); + /// + /// let mtls = tls_config.mutual().unwrap(); + /// assert_eq!(mtls.ca_certs().unwrap_left(), Path::new("path/to/cert.pem")); + /// assert!(mtls.mandatory); + /// ``` + #[cfg(feature = "mtls")] + #[cfg_attr(nightly, doc(cfg(feature = "mtls")))] + pub fn mutual(&self) -> Option<&MutualTls> { + self.mutual.as_ref() + } +} + +#[cfg(feature = "mtls")] +impl MutualTls { + /// Constructs a `MutualTls` from a path to a PEM file with a certificate + /// authority `ca_certs` DER-encoded X.509 TLS certificate chain. This + /// method does no validation; it simply creates a structure suitable for + /// passing into a [`TlsConfig`]. + /// + /// These certificates will be used to verify client-presented certificates + /// in TLS connections. + /// + /// # Example + /// + /// ```rust + /// use rocket::config::MutualTls; + /// + /// let tls_config = MutualTls::from_path("/ssl/ca_certs.pem"); + /// ``` + pub fn from_path>(ca_certs: C) -> Self { + MutualTls { + ca_certs: Either::Left(ca_certs.as_ref().to_path_buf().into()), + mandatory: Default::default() + } + } + + /// Constructs a `MutualTls` from a byte buffer to a certificate authority + /// `ca_certs` DER-encoded X.509 TLS certificate chain. This method does no + /// validation; it simply creates a structure suitable for passing into a + /// [`TlsConfig`]. + /// + /// These certificates will be used to verify client-presented certificates + /// in TLS connections. + /// + /// # Example + /// + /// ```rust + /// use rocket::config::MutualTls; + /// + /// # let ca_certs_buf = &[]; + /// let mtls_config = MutualTls::from_bytes(ca_certs_buf); + /// ``` + pub fn from_bytes(ca_certs: &[u8]) -> Self { + MutualTls { + ca_certs: Either::Right(ca_certs.to_vec()), + mandatory: Default::default() + } + } + + /// Sets whether client authentication is required. Disabled by default. + /// + /// When `true`, client authentication will be required. TLS connections + /// where the client does not present a certificate will be immediately + /// terminated. When `false`, the client is not required to present a + /// certificate. In either case, if a certificate _is_ presented, it must be + /// valid or the connection is terminated. + /// + /// # Example + /// + /// ```rust + /// use rocket::config::MutualTls; + /// + /// # let ca_certs_buf = &[]; + /// let mtls_config = MutualTls::from_bytes(ca_certs_buf).mandatory(true); + /// ``` + pub fn mandatory(mut self, mandatory: bool) -> Self { + self.mandatory = mandatory; + self + } + + /// Returns the value of the `ca_certs` parameter. + /// # Example + /// + /// ```rust + /// use rocket::config::MutualTls; + /// + /// # let ca_certs_buf = &[]; + /// let mtls_config = MutualTls::from_bytes(ca_certs_buf).mandatory(true); + /// assert_eq!(mtls_config.ca_certs().unwrap_right(), ca_certs_buf); + /// ``` + pub fn ca_certs(&self) -> either::Either { + match &self.ca_certs { + Either::Left(path) => either::Either::Left(path.relative()), + Either::Right(bytes) => either::Either::Right(&bytes), + } + } } #[cfg(feature = "tls")] mod with_tls_feature { + use std::fs; + use std::io::{self, Error}; + + use crate::http::tls::Config; use crate::http::tls::rustls::SupportedCipherSuite as RustlsCipher; use crate::http::tls::rustls::ciphersuite as rustls; - use super::*; + use yansi::Paint; + + use super::{Either, RelativePathBuf, TlsConfig, CipherSuite}; type Reader = Box; - impl TlsConfig { - pub(crate) fn to_readers(&self) -> std::io::Result<(Reader, Reader)> { - use std::{io::{self, Error}, fs}; - use yansi::Paint; + fn to_reader(value: &Either>) -> io::Result { + match value { + Either::Left(path) => { + let path = path.relative(); + let file = fs::File::open(&path).map_err(move |e| { + Error::new(e.kind(), format!("error reading TLS file `{}`: {}", + Paint::white(figment::Source::File(path)), e)) + })?; - fn to_reader(value: &Either>) -> io::Result { - match value { - Either::Left(path) => { - let path = path.relative(); - let file = fs::File::open(&path).map_err(move |e| { - Error::new(e.kind(), format!("error reading TLS file `{}`: {}", - Paint::white(figment::Source::File(path)), e)) - })?; - - Ok(Box::new(io::BufReader::new(file))) - } - Either::Right(vec) => Ok(Box::new(io::Cursor::new(vec.clone()))), - } + Ok(Box::new(io::BufReader::new(file))) } + Either::Right(vec) => Ok(Box::new(io::Cursor::new(vec.clone()))), + } + } - Ok((to_reader(&self.certs)?, to_reader(&self.key)?)) + impl TlsConfig { + /// This is only called when TLS is enabled. + pub(crate) fn to_native_config(&self) -> io::Result> { + Ok(Config { + cert_chain: to_reader(&self.certs)?, + private_key: to_reader(&self.key)?, + ciphersuites: self.rustls_ciphers().collect(), + prefer_server_order: self.prefer_server_cipher_order, + #[cfg(not(feature = "mtls"))] + mandatory_mtls: false, + #[cfg(not(feature = "mtls"))] + ca_certs: None, + #[cfg(feature = "mtls")] + mandatory_mtls: self.mutual.as_ref().map_or(false, |m| m.mandatory), + #[cfg(feature = "mtls")] + ca_certs: match self.mutual { + Some(ref mtls) => Some(to_reader(&mtls.ca_certs)?), + None => None + }, + }) } - pub(crate) fn rustls_ciphers(&self) -> impl Iterator + '_ { + fn rustls_ciphers(&self) -> impl Iterator + '_ { self.ciphers().map(|ciphersuite| match ciphersuite { CipherSuite::TLS_CHACHA20_POLY1305_SHA256 => &rustls::TLS13_CHACHA20_POLY1305_SHA256, diff --git a/core/lib/src/ext.rs b/core/lib/src/ext.rs index c272d3b2..0433c698 100644 --- a/core/lib/src/ext.rs +++ b/core/lib/src/ext.rs @@ -2,7 +2,7 @@ use std::{io, time::Duration}; use std::task::{Poll, Context}; use std::pin::Pin; -use bytes::BytesMut; +use bytes::{Bytes, BytesMut}; use pin_project_lite::pin_project; use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; use tokio::time::{sleep, Sleep}; @@ -10,8 +10,6 @@ use tokio::time::{sleep, Sleep}; use futures::stream::Stream; use futures::future::{self, Future, FutureExt}; -use crate::http::hyper::Bytes; - pin_project! { pub struct ReaderStream { #[pin] @@ -293,12 +291,16 @@ impl AsyncWrite for CancellableIo { } } -use crate::http::private::{Listener, Connection}; +use crate::http::private::{Listener, Connection, RawCertificate}; impl Connection for CancellableIo { fn peer_address(&self) -> Option { self.io.peer_address() } + + fn peer_certificates(&self) -> Option> { + self.io.peer_certificates() + } } pin_project! { diff --git a/examples/tls/Cargo.toml b/examples/tls/Cargo.toml index bb9cb820..f179fbfa 100644 --- a/examples/tls/Cargo.toml +++ b/examples/tls/Cargo.toml @@ -6,4 +6,4 @@ edition = "2018" publish = false [dependencies] -rocket = { path = "../../core/lib", features = ["tls"] } +rocket = { path = "../../core/lib", features = ["tls", "mtls"] } diff --git a/examples/tls/Rocket.toml b/examples/tls/Rocket.toml index 0a1a0827..85f081ba 100644 --- a/examples/tls/Rocket.toml +++ b/examples/tls/Rocket.toml @@ -9,6 +9,10 @@ certs = "private/rsa_sha256_cert.pem" key = "private/rsa_sha256_key.pem" +[default.tls.mutual] +ca_certs = "private/ca_cert.pem" +mandatory = false + [rsa_sha256.tls] certs = "private/rsa_sha256_cert.pem" key = "private/rsa_sha256_key.pem" diff --git a/examples/tls/src/main.rs b/examples/tls/src/main.rs index 9ec33254..7cd59828 100644 --- a/examples/tls/src/main.rs +++ b/examples/tls/src/main.rs @@ -2,7 +2,14 @@ #[cfg(test)] mod tests; +use rocket::mtls::Certificate; + #[get("/")] +fn mutual(cert: Certificate<'_>) -> String { + format!("Hello! Here's what we know: [{}] {}", cert.serial(), cert.subject()) +} + +#[get("/", rank = 2)] fn hello() -> &'static str { "Hello, world!" } @@ -10,5 +17,6 @@ fn hello() -> &'static str { #[launch] fn rocket() -> _ { // See `Rocket.toml` and `Cargo.toml` for TLS configuration. - rocket::build().mount("/", routes![hello]) + // Run `./private/gen_certs.sh` to generate a CA and key pairs. + rocket::build().mount("/", routes![hello, mutual]) }