mirror of https://github.com/rwf2/Rocket.git
Initial implementation of mTLS.
Co-authored-by: Howard Su <howard0su@gmail.com> Co-authored-by: Mikail Bagishov <bagishov.mikail@yandex.ru>
This commit is contained in:
parent
71823915db
commit
bbc36ba27f
|
@ -17,6 +17,7 @@ edition = "2018"
|
|||
[features]
|
||||
default = []
|
||||
tls = ["rustls", "tokio-rustls"]
|
||||
mtls = ["tls", "x509-parser"]
|
||||
private-cookies = ["cookie/private", "cookie/key-expansion"]
|
||||
serde = ["uncased/with-serde-alloc", "serde_"]
|
||||
uuid = ["uuid_"]
|
||||
|
@ -43,6 +44,10 @@ stable-pattern = "0.1"
|
|||
cookie = { version = "0.15", features = ["percent-encode"] }
|
||||
state = "0.5.1"
|
||||
|
||||
[dependencies.x509-parser]
|
||||
version = "0.9.2"
|
||||
optional = true
|
||||
|
||||
[dependencies.hyper]
|
||||
version = "0.14.9"
|
||||
default-features = false
|
||||
|
@ -62,4 +67,4 @@ optional = true
|
|||
default-features = false
|
||||
|
||||
[dev-dependencies]
|
||||
rocket = { version = "0.5.0-rc.1", path = "../lib" }
|
||||
rocket = { version = "0.5.0-rc.1", path = "../lib", features = ["mtls"] }
|
||||
|
|
|
@ -44,7 +44,7 @@ pub mod uncased {
|
|||
pub mod private {
|
||||
pub use crate::parse::Indexed;
|
||||
pub use smallvec::{SmallVec, Array};
|
||||
pub use crate::listener::{Incoming, Listener, Connection, bind_tcp};
|
||||
pub use crate::listener::{bind_tcp, Incoming, Listener, Connection, RawCertificate};
|
||||
pub use cookie;
|
||||
}
|
||||
|
||||
|
|
|
@ -30,10 +30,31 @@ pub trait Listener {
|
|||
) -> Poll<io::Result<Self::Connection>>;
|
||||
}
|
||||
|
||||
/// A thin wrapper over raw, DER-encoded X.509 client certificate data.
|
||||
#[cfg(not(feature = "tls"))]
|
||||
#[derive(Clone, Eq, PartialEq)]
|
||||
pub struct RawCertificate(pub Vec<u8>);
|
||||
|
||||
/// A thin wrapper over raw, DER-encoded X.509 client certificate data.
|
||||
// NOTE: `rustls::Certificate` is exactly isomorphic to `RawCertificate`.
|
||||
#[doc(inline)]
|
||||
#[cfg(feature = "tls")]
|
||||
pub use rustls::Certificate as RawCertificate;
|
||||
|
||||
/// A 'Connection' represents an open connection to a client
|
||||
pub trait Connection: AsyncRead + AsyncWrite {
|
||||
/// The remote address, i.e. the client's socket address, if it is known.
|
||||
fn peer_address(&self) -> Option<SocketAddr>;
|
||||
|
||||
/// DER-encoded X.509 certificate chain presented by the client, if any.
|
||||
///
|
||||
/// The certificate order must be as it appears in the TLS protocol: the
|
||||
/// first certificate relates to the peer, the second certifies the first,
|
||||
/// the third certifies the second, and so on.
|
||||
///
|
||||
/// Defaults to an empty vector to indicate that no certificates were
|
||||
/// presented.
|
||||
fn peer_certificates(&self) -> Option<Vec<RawCertificate>> { None }
|
||||
}
|
||||
|
||||
pin_project_lite::pin_project! {
|
||||
|
@ -114,9 +135,8 @@ impl<L: Listener> Incoming<L> {
|
|||
}
|
||||
|
||||
if let Some(duration) = me.sleep_on_errors {
|
||||
error!("connection accept error: {}", e);
|
||||
|
||||
// Sleep for the specified duration
|
||||
error!("connection accept error: {}", e);
|
||||
me.pending_error_delay.set(Some(tokio::time::sleep(*duration)));
|
||||
} else {
|
||||
return Poll::Ready(Err(e));
|
||||
|
|
|
@ -9,8 +9,8 @@ use rustls::{ServerConfig, SupportedCipherSuite};
|
|||
use tokio_rustls::{TlsAcceptor, Accept, server::TlsStream};
|
||||
use tokio::net::{TcpListener, TcpStream};
|
||||
|
||||
use crate::tls::util::{load_certs, load_private_key};
|
||||
use crate::listener::{Connection, Listener};
|
||||
use crate::tls::util::{load_certs, load_private_key, load_ca_certs};
|
||||
use crate::listener::{Connection, Listener, RawCertificate};
|
||||
|
||||
/// A TLS listener over TCP.
|
||||
pub struct TlsListener {
|
||||
|
@ -24,35 +24,55 @@ enum State {
|
|||
Accepting(Accept<TcpStream>),
|
||||
}
|
||||
|
||||
pub struct Config<R> {
|
||||
pub cert_chain: R,
|
||||
pub private_key: R,
|
||||
pub ciphersuites: Vec<&'static SupportedCipherSuite>,
|
||||
pub prefer_server_order: bool,
|
||||
pub ca_certs: Option<R>,
|
||||
pub mandatory_mtls: bool,
|
||||
}
|
||||
|
||||
impl TlsListener {
|
||||
pub async fn bind(
|
||||
address: SocketAddr,
|
||||
mut cert_chain: impl io::BufRead + Send,
|
||||
mut private_key: impl io::BufRead + Send,
|
||||
ciphersuites: impl Iterator<Item = &'static SupportedCipherSuite>,
|
||||
prefer_server_order: bool,
|
||||
) -> io::Result<TlsListener> {
|
||||
let cert_chain = load_certs(&mut cert_chain).map_err(|e| {
|
||||
pub async fn bind<R>(addr: SocketAddr, mut c: Config<R>) -> io::Result<TlsListener>
|
||||
where R: io::BufRead
|
||||
{
|
||||
let cert_chain = load_certs(&mut c.cert_chain).map_err(|e| {
|
||||
let msg = format!("malformed TLS certificate chain: {}", e);
|
||||
io::Error::new(e.kind(), msg)
|
||||
})?;
|
||||
|
||||
let key = load_private_key(&mut private_key).map_err(|e| {
|
||||
let key = load_private_key(&mut c.private_key).map_err(|e| {
|
||||
let msg = format!("malformed TLS private key: {}", e);
|
||||
io::Error::new(e.kind(), msg)
|
||||
})?;
|
||||
|
||||
let client_auth = rustls::NoClientAuth::new();
|
||||
let client_auth = match c.ca_certs {
|
||||
Some(ref mut ca_certs) => {
|
||||
let roots = load_ca_certs(ca_certs).map_err(|e| {
|
||||
let msg = format!("malformed CA certificate(s): {}", e);
|
||||
io::Error::new(e.kind(), msg)
|
||||
})?;
|
||||
|
||||
if c.mandatory_mtls {
|
||||
rustls::AllowAnyAuthenticatedClient::new(roots)
|
||||
} else {
|
||||
rustls::AllowAnyAnonymousOrAuthenticatedClient::new(roots)
|
||||
}
|
||||
}
|
||||
None => rustls::NoClientAuth::new(),
|
||||
};
|
||||
|
||||
let mut tls_config = ServerConfig::new(client_auth);
|
||||
let cache = rustls::ServerSessionMemoryCache::new(1024);
|
||||
tls_config.set_persistence(cache);
|
||||
tls_config.ticketer = rustls::Ticketer::new();
|
||||
tls_config.ciphersuites = ciphersuites.collect();
|
||||
tls_config.ignore_client_order = prefer_server_order;
|
||||
tls_config.ciphersuites = c.ciphersuites;
|
||||
tls_config.ignore_client_order = c.prefer_server_order;
|
||||
tls_config.set_single_cert(cert_chain, key).expect("invalid key");
|
||||
tls_config.set_protocols(&[b"h2".to_vec(), b"http/1.1".to_vec()]);
|
||||
|
||||
let listener = TcpListener::bind(address).await?;
|
||||
let listener = TcpListener::bind(addr).await?;
|
||||
let acceptor = TlsAcceptor::from(Arc::new(tls_config));
|
||||
Ok(TlsListener { listener, acceptor, state: State::Listening })
|
||||
}
|
||||
|
@ -99,4 +119,10 @@ impl Connection for TlsStream<TcpStream> {
|
|||
fn peer_address(&self) -> Option<SocketAddr> {
|
||||
self.get_ref().0.peer_address()
|
||||
}
|
||||
|
||||
fn peer_certificates(&self) -> Option<Vec<RawCertificate>> {
|
||||
use rustls::Session;
|
||||
|
||||
self.get_ref().1.get_peer_certificates()
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,5 +1,8 @@
|
|||
mod listener;
|
||||
mod util;
|
||||
|
||||
#[cfg(feature = "mtls")]
|
||||
pub mod mtls;
|
||||
|
||||
pub use rustls;
|
||||
pub use listener::TlsListener;
|
||||
pub use listener::{TlsListener, Config};
|
||||
|
|
|
@ -0,0 +1,241 @@
|
|||
pub mod oid {
|
||||
//! Lower-level OID types re-exported from
|
||||
//! [`oid_registry`](https://docs.rs/oid-registry/0.1) and
|
||||
//! [`der-parser`](https://docs.rs/der-parser/5).
|
||||
|
||||
pub use x509_parser::oid_registry::*;
|
||||
pub use x509_parser::der_parser::oid::*;
|
||||
pub use x509_parser::objects::*;
|
||||
}
|
||||
|
||||
pub mod bigint {
|
||||
//! Signed and unsigned big integer types re-exported from
|
||||
//! [`num_bigint`](https://docs.rs/num-bigint/0.4).
|
||||
pub use x509_parser::der_parser::num_bigint::*;
|
||||
}
|
||||
|
||||
pub mod x509 {
|
||||
//! Lower-level X.509 types re-exported from
|
||||
//! [`x509_parser`](https://docs.rs/x509-parser/0.9).
|
||||
//!
|
||||
//! Lack of documentation is directly inherited from the source crate.
|
||||
//! Prefer to use Rocket's wrappers when possible.
|
||||
|
||||
pub use x509_parser::certificate::*;
|
||||
pub use x509_parser::cri_attributes::*;
|
||||
pub use x509_parser::error::*;
|
||||
pub use x509_parser::extensions::*;
|
||||
pub use x509_parser::revocation_list::*;
|
||||
pub use x509_parser::time::*;
|
||||
pub use x509_parser::x509::*;
|
||||
pub use x509_parser::der_parser::der;
|
||||
pub use x509_parser::der_parser::ber;
|
||||
}
|
||||
|
||||
use std::fmt;
|
||||
use std::ops::Deref;
|
||||
use std::collections::HashMap;
|
||||
use std::num::NonZeroUsize;
|
||||
|
||||
use ref_cast::RefCast;
|
||||
use x509_parser::nom;
|
||||
use x509::{ParsedExtension, X509Name, X509Certificate, TbsCertificate, X509Error};
|
||||
use oid::OID_X509_EXT_SUBJECT_ALT_NAME as SUBJECT_ALT_NAME;
|
||||
|
||||
use crate::listener::RawCertificate;
|
||||
|
||||
/// A type alias for [`Result`](std::result::Result) with the error type set to
|
||||
/// [`Error`].
|
||||
pub type Result<T, E = Error> = std::result::Result<T, E>;
|
||||
|
||||
/// An error returned by the [`Certificate`] request guard.
|
||||
///
|
||||
/// To retrieve this error in a handler, use an `mtls::Result<Certificate>`
|
||||
/// guard type:
|
||||
///
|
||||
/// ```rust
|
||||
/// # extern crate rocket;
|
||||
/// # use rocket::get;
|
||||
/// use rocket::mtls::{self, Certificate};
|
||||
///
|
||||
/// #[get("/auth")]
|
||||
/// fn auth(cert: mtls::Result<Certificate<'_>>) {
|
||||
/// match cert {
|
||||
/// Ok(cert) => { /* do something with the client cert */ },
|
||||
/// Err(e) => { /* do something with the error */ },
|
||||
/// }
|
||||
/// }
|
||||
/// ```
|
||||
#[derive(Debug, Clone)]
|
||||
#[non_exhaustive]
|
||||
pub enum Error {
|
||||
/// The certificate chain presented by the client had no certificates.
|
||||
Empty,
|
||||
/// The certificate contained neither a subject nor a subjectAlt extension.
|
||||
NoSubject,
|
||||
/// There is no subject and the subjectAlt is not marked as critical.
|
||||
NonCriticalSubjectAlt,
|
||||
// FIXME: Waiting on https://github.com/rusticata/x509-parser/pull/92.
|
||||
// Parse(X509Error),
|
||||
/// An error occurred while parsing the certificate.
|
||||
#[doc(hidden)]
|
||||
Parse(String),
|
||||
/// The certificate parsed partially but is incomplete.
|
||||
///
|
||||
/// If `Some(n)`, then `n` more bytes were expected. Otherwise, the number
|
||||
/// of expected bytes is unknown.
|
||||
Incomplete(Option<NonZeroUsize>),
|
||||
/// The certificate contained `.0` bytes of trailing data.
|
||||
Trailing(usize),
|
||||
}
|
||||
|
||||
#[repr(transparent)]
|
||||
#[derive(Debug, PartialEq)]
|
||||
pub struct Certificate<'a>(X509Certificate<'a>);
|
||||
|
||||
/// An X.509 Distinguished Name (DN) found in a [`Certificate`].
|
||||
///
|
||||
/// This type is a wrapper over [`x509::X509Name`] with convenient methods and
|
||||
/// complete documentation. Should the data exposed by the inherent methods not
|
||||
/// suffice, this type derefs to [`x509::X509Name`].
|
||||
#[repr(transparent)]
|
||||
#[derive(Debug, PartialEq, RefCast)]
|
||||
pub struct Name<'a>(X509Name<'a>);
|
||||
|
||||
impl<'a> Certificate<'a> {
|
||||
fn parse_one(raw: &[u8]) -> Result<X509Certificate<'_>> {
|
||||
let (left, x509) = X509Certificate::from_der(raw)?;
|
||||
if !left.is_empty() {
|
||||
return Err(Error::Trailing(left.len()));
|
||||
}
|
||||
|
||||
if x509.subject().as_raw().is_empty() {
|
||||
if let Some(ext) = x509.extensions().get(&SUBJECT_ALT_NAME) {
|
||||
if !matches!(ext.parsed_extension(), ParsedExtension::SubjectAlternativeName(..)) {
|
||||
return Err(Error::NoSubject);
|
||||
} else if !ext.critical {
|
||||
return Err(Error::NonCriticalSubjectAlt);
|
||||
}
|
||||
} else {
|
||||
return Err(Error::NoSubject);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(x509)
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn inner(&self) -> &TbsCertificate<'a> {
|
||||
&self.0.tbs_certificate
|
||||
}
|
||||
|
||||
/// PRIVATE: For internal Rocket use only!
|
||||
#[doc(hidden)]
|
||||
pub fn parse(chain: &[RawCertificate]) -> Result<Certificate<'_>> {
|
||||
match chain.first() {
|
||||
Some(cert) => Certificate::parse_one(&cert.0).map(Certificate),
|
||||
None => Err(Error::Empty)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn serial(&self) -> &bigint::BigUint {
|
||||
&self.inner().serial
|
||||
}
|
||||
|
||||
pub fn version(&self) -> u32 {
|
||||
self.inner().version.0
|
||||
}
|
||||
|
||||
pub fn subject(&self) -> &Name<'a> {
|
||||
Name::ref_cast(&self.inner().subject)
|
||||
}
|
||||
|
||||
pub fn issuer(&self) -> &Name<'a> {
|
||||
Name::ref_cast(&self.inner().issuer)
|
||||
}
|
||||
|
||||
pub fn extensions(&self) -> &HashMap<oid::Oid<'a>, x509::X509Extension<'a>> {
|
||||
&self.inner().extensions
|
||||
}
|
||||
|
||||
pub fn has_serial(&self, number: &str) -> Option<bool> {
|
||||
let uint: bigint::BigUint = number.parse().ok()?;
|
||||
Some(&uint == self.serial())
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> Deref for Certificate<'a> {
|
||||
type Target = TbsCertificate<'a>;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
self.inner()
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> Name<'a> {
|
||||
pub fn common_name(&self) -> Option<&'a str> {
|
||||
self.common_names().next()
|
||||
}
|
||||
|
||||
pub fn common_names(&self) -> impl Iterator<Item = &'a str> + '_ {
|
||||
self.iter_by_oid(&oid::OID_X509_COMMON_NAME).filter_map(|n| n.as_str().ok())
|
||||
}
|
||||
|
||||
pub fn email(&self) -> Option<&'a str> {
|
||||
self.emails().next()
|
||||
}
|
||||
|
||||
pub fn emails(&self) -> impl Iterator<Item = &'a str> + '_ {
|
||||
self.iter_by_oid(&oid::OID_PKCS9_EMAIL_ADDRESS).filter_map(|n| n.as_str().ok())
|
||||
}
|
||||
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.0.as_raw().is_empty()
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> Deref for Name<'a> {
|
||||
type Target = X509Name<'a>;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.0
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for Name<'_> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
self.0.fmt(f)
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for Error {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
match self {
|
||||
Error::Parse(e) => write!(f, "parse failure: {}", e),
|
||||
Error::Incomplete(_) => write!(f, "incomplete certificate data"),
|
||||
Error::Trailing(n) => write!(f, "found {} trailing bytes", n),
|
||||
Error::Empty => write!(f, "empty certificate chain"),
|
||||
Error::NoSubject => write!(f, "empty subject without subjectAlt"),
|
||||
Error::NonCriticalSubjectAlt => write!(f, "empty subject without critical subjectAlt"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<nom::Err<X509Error>> for Error {
|
||||
fn from(e: nom::Err<X509Error>) -> Self {
|
||||
match e {
|
||||
nom::Err::Incomplete(nom::Needed::Unknown) => Error::Incomplete(None),
|
||||
nom::Err::Incomplete(nom::Needed::Size(n)) => Error::Incomplete(Some(n)),
|
||||
nom::Err::Error(e) | nom::Err::Failure(e) => Error::Parse(e.to_string()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::error::Error for Error {
|
||||
// fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
|
||||
// match self {
|
||||
// Error::Parse(e) => Some(e),
|
||||
// _ => None
|
||||
// }
|
||||
// }
|
||||
}
|
|
@ -1,10 +1,14 @@
|
|||
use std::io::{self, ErrorKind::Other, Cursor, Error, Read};
|
||||
use std::io::{self, Cursor, Read};
|
||||
|
||||
use rustls::{internal::pemfile, Certificate, PrivateKey};
|
||||
use rustls::{internal::pemfile, Certificate, PrivateKey, RootCertStore};
|
||||
|
||||
fn err(message: impl Into<std::borrow::Cow<'static, str>>) -> io::Error {
|
||||
io::Error::new(io::ErrorKind::Other, message.into())
|
||||
}
|
||||
|
||||
/// Loads certificates from `reader`.
|
||||
pub fn load_certs(reader: &mut dyn io::BufRead) -> io::Result<Vec<Certificate>> {
|
||||
pemfile::certs(reader).map_err(|_| Error::new(Other, "invalid certificate"))
|
||||
pemfile::certs(reader).map_err(|_| err("invalid certificate"))
|
||||
}
|
||||
|
||||
/// Load and decode the private key from `reader`.
|
||||
|
@ -17,23 +21,34 @@ pub fn load_private_key(reader: &mut dyn io::BufRead) -> io::Result<PrivateKey>
|
|||
let private_keys_fn = match first_line.trim_end() {
|
||||
"-----BEGIN RSA PRIVATE KEY-----" => pemfile::rsa_private_keys,
|
||||
"-----BEGIN PRIVATE KEY-----" => pemfile::pkcs8_private_keys,
|
||||
_ => return Err(Error::new(Other, "invalid key header"))
|
||||
_ => return Err(err("invalid key header"))
|
||||
};
|
||||
|
||||
let key = private_keys_fn(&mut Cursor::new(first_line).chain(reader))
|
||||
.map_err(|_| Error::new(Other, "invalid key file"))
|
||||
.map_err(|_| err("invalid key file"))
|
||||
.and_then(|mut keys| match keys.len() {
|
||||
0 => Err(Error::new(Other, "no valid keys found; is the file malformed?")),
|
||||
0 => Err(err("no valid keys found; is the file malformed?")),
|
||||
1 => Ok(keys.remove(0)),
|
||||
n => Err(Error::new(Other, format!("expected 1 key, found {}", n))),
|
||||
n => Err(err(format!("expected 1 key, found {}", n))),
|
||||
})?;
|
||||
|
||||
// Ensure we can use the key.
|
||||
rustls::sign::any_supported_type(&key)
|
||||
.map_err(|_| Error::new(Other, "key parsed but is unusable"))
|
||||
.map_err(|_| err("key parsed but is unusable"))
|
||||
.map(|_| key)
|
||||
}
|
||||
|
||||
/// Load and decode CA certificates from `reader`.
|
||||
pub fn load_ca_certs(reader: &mut dyn io::BufRead) -> io::Result<RootCertStore> {
|
||||
let mut roots = rustls::RootCertStore::empty();
|
||||
let (_, e) = roots.add_pem_file(reader).map_err(|_| err("PEM format error"))?;
|
||||
if e != 0 {
|
||||
return Err(err("validity checks failed"));
|
||||
}
|
||||
|
||||
Ok(roots)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use super::*;
|
||||
|
|
|
@ -71,6 +71,21 @@ pub struct TlsConfig {
|
|||
/// Whether to prefer the server's cipher suite order over the client's.
|
||||
#[serde(default)]
|
||||
pub(crate) prefer_server_cipher_order: bool,
|
||||
/// Configuration for mutual TLS, if any.
|
||||
#[serde(default)]
|
||||
#[cfg(feature = "mtls")]
|
||||
#[cfg_attr(nightly, doc(cfg(feature = "mtls")))]
|
||||
pub(crate) mutual: Option<MutualTls>,
|
||||
}
|
||||
|
||||
#[derive(PartialEq, Debug, Clone, Deserialize, Serialize)]
|
||||
#[cfg(feature = "mtls")]
|
||||
#[cfg_attr(nightly, doc(cfg(feature = "mtls")))]
|
||||
pub struct MutualTls {
|
||||
pub(crate) ca_certs: Either<RelativePathBuf, Vec<u8>>,
|
||||
#[serde(default)]
|
||||
#[serde(deserialize_with = "figment::util::bool_from_str_or_int")]
|
||||
pub mandatory: bool,
|
||||
}
|
||||
|
||||
/// A supported TLS cipher suite.
|
||||
|
@ -144,7 +159,18 @@ impl CipherSuite {
|
|||
}
|
||||
|
||||
impl TlsConfig {
|
||||
/// Constructs a `TlsConfig` from paths to a `certs` certificate-chain
|
||||
fn default() -> Self {
|
||||
TlsConfig {
|
||||
certs: Either::Right(vec![]),
|
||||
key: Either::Right(vec![]),
|
||||
ciphers: CipherSuite::default_set(),
|
||||
prefer_server_cipher_order: false,
|
||||
#[cfg(feature = "mtls")]
|
||||
mutual: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Constructs a `TlsConfig` from paths to a `certs` certificate chain
|
||||
/// a `key` private-key. This method does no validation; it simply creates a
|
||||
/// structure suitable for passing into a [`Config`](crate::Config).
|
||||
///
|
||||
|
@ -161,13 +187,12 @@ impl TlsConfig {
|
|||
TlsConfig {
|
||||
certs: Either::Left(certs.as_ref().to_path_buf().into()),
|
||||
key: Either::Left(key.as_ref().to_path_buf().into()),
|
||||
ciphers: CipherSuite::default_set(),
|
||||
prefer_server_cipher_order: Default::default(),
|
||||
..TlsConfig::default()
|
||||
}
|
||||
}
|
||||
|
||||
/// Constructs a `TlsConfig` from byte buffers to a `certs`
|
||||
/// certificate-chain a `key` private-key. This method does no validation;
|
||||
/// certificate chain a `key` private-key. This method does no validation;
|
||||
/// it simply creates a structure suitable for passing into a
|
||||
/// [`Config`](crate::Config).
|
||||
///
|
||||
|
@ -184,8 +209,7 @@ impl TlsConfig {
|
|||
TlsConfig {
|
||||
certs: Either::Right(certs.to_vec()),
|
||||
key: Either::Right(key.to_vec()),
|
||||
ciphers: CipherSuite::default_set(),
|
||||
prefer_server_cipher_order: Default::default(),
|
||||
..TlsConfig::default()
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -285,6 +309,26 @@ impl TlsConfig {
|
|||
self
|
||||
}
|
||||
|
||||
/// Configures mutual TLS. See [`MutualTls`] for details.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust
|
||||
/// use rocket::config::{TlsConfig, MutualTls};
|
||||
///
|
||||
/// # let certs = &[];
|
||||
/// # let key = &[];
|
||||
/// let mtls_config = MutualTls::from_path("path/to/cert.pem").mandatory(true);
|
||||
/// let tls_config = TlsConfig::from_bytes(certs, key).with_mutual(mtls_config);
|
||||
/// assert!(tls_config.mutual().is_some());
|
||||
/// ```
|
||||
#[cfg(feature = "mtls")]
|
||||
#[cfg_attr(nightly, doc(cfg(feature = "mtls")))]
|
||||
pub fn with_mutual(mut self, config: MutualTls) -> Self {
|
||||
self.mutual = Some(config);
|
||||
self
|
||||
}
|
||||
|
||||
/// Returns the value of the `certs` parameter.
|
||||
///
|
||||
/// # Example
|
||||
|
@ -380,22 +424,132 @@ impl TlsConfig {
|
|||
pub fn prefer_server_cipher_order(&self) -> bool {
|
||||
self.prefer_server_cipher_order
|
||||
}
|
||||
|
||||
/// Returns the value of the `mutual` parameter.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust
|
||||
/// use std::path::Path;
|
||||
/// use rocket::config::{TlsConfig, MutualTls};
|
||||
///
|
||||
/// # let certs = &[];
|
||||
/// # let key = &[];
|
||||
/// let mtls_config = MutualTls::from_path("path/to/cert.pem").mandatory(true);
|
||||
/// let tls_config = TlsConfig::from_bytes(certs, key).with_mutual(mtls_config);
|
||||
///
|
||||
/// let mtls = tls_config.mutual().unwrap();
|
||||
/// assert_eq!(mtls.ca_certs().unwrap_left(), Path::new("path/to/cert.pem"));
|
||||
/// assert!(mtls.mandatory);
|
||||
/// ```
|
||||
#[cfg(feature = "mtls")]
|
||||
#[cfg_attr(nightly, doc(cfg(feature = "mtls")))]
|
||||
pub fn mutual(&self) -> Option<&MutualTls> {
|
||||
self.mutual.as_ref()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "mtls")]
|
||||
impl MutualTls {
|
||||
/// Constructs a `MutualTls` from a path to a PEM file with a certificate
|
||||
/// authority `ca_certs` DER-encoded X.509 TLS certificate chain. This
|
||||
/// method does no validation; it simply creates a structure suitable for
|
||||
/// passing into a [`TlsConfig`].
|
||||
///
|
||||
/// These certificates will be used to verify client-presented certificates
|
||||
/// in TLS connections.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust
|
||||
/// use rocket::config::MutualTls;
|
||||
///
|
||||
/// let tls_config = MutualTls::from_path("/ssl/ca_certs.pem");
|
||||
/// ```
|
||||
pub fn from_path<C: AsRef<std::path::Path>>(ca_certs: C) -> Self {
|
||||
MutualTls {
|
||||
ca_certs: Either::Left(ca_certs.as_ref().to_path_buf().into()),
|
||||
mandatory: Default::default()
|
||||
}
|
||||
}
|
||||
|
||||
/// Constructs a `MutualTls` from a byte buffer to a certificate authority
|
||||
/// `ca_certs` DER-encoded X.509 TLS certificate chain. This method does no
|
||||
/// validation; it simply creates a structure suitable for passing into a
|
||||
/// [`TlsConfig`].
|
||||
///
|
||||
/// These certificates will be used to verify client-presented certificates
|
||||
/// in TLS connections.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust
|
||||
/// use rocket::config::MutualTls;
|
||||
///
|
||||
/// # let ca_certs_buf = &[];
|
||||
/// let mtls_config = MutualTls::from_bytes(ca_certs_buf);
|
||||
/// ```
|
||||
pub fn from_bytes(ca_certs: &[u8]) -> Self {
|
||||
MutualTls {
|
||||
ca_certs: Either::Right(ca_certs.to_vec()),
|
||||
mandatory: Default::default()
|
||||
}
|
||||
}
|
||||
|
||||
/// Sets whether client authentication is required. Disabled by default.
|
||||
///
|
||||
/// When `true`, client authentication will be required. TLS connections
|
||||
/// where the client does not present a certificate will be immediately
|
||||
/// terminated. When `false`, the client is not required to present a
|
||||
/// certificate. In either case, if a certificate _is_ presented, it must be
|
||||
/// valid or the connection is terminated.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust
|
||||
/// use rocket::config::MutualTls;
|
||||
///
|
||||
/// # let ca_certs_buf = &[];
|
||||
/// let mtls_config = MutualTls::from_bytes(ca_certs_buf).mandatory(true);
|
||||
/// ```
|
||||
pub fn mandatory(mut self, mandatory: bool) -> Self {
|
||||
self.mandatory = mandatory;
|
||||
self
|
||||
}
|
||||
|
||||
/// Returns the value of the `ca_certs` parameter.
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust
|
||||
/// use rocket::config::MutualTls;
|
||||
///
|
||||
/// # let ca_certs_buf = &[];
|
||||
/// let mtls_config = MutualTls::from_bytes(ca_certs_buf).mandatory(true);
|
||||
/// assert_eq!(mtls_config.ca_certs().unwrap_right(), ca_certs_buf);
|
||||
/// ```
|
||||
pub fn ca_certs(&self) -> either::Either<std::path::PathBuf, &[u8]> {
|
||||
match &self.ca_certs {
|
||||
Either::Left(path) => either::Either::Left(path.relative()),
|
||||
Either::Right(bytes) => either::Either::Right(&bytes),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "tls")]
|
||||
mod with_tls_feature {
|
||||
use std::fs;
|
||||
use std::io::{self, Error};
|
||||
|
||||
use crate::http::tls::Config;
|
||||
use crate::http::tls::rustls::SupportedCipherSuite as RustlsCipher;
|
||||
use crate::http::tls::rustls::ciphersuite as rustls;
|
||||
|
||||
use super::*;
|
||||
use yansi::Paint;
|
||||
|
||||
use super::{Either, RelativePathBuf, TlsConfig, CipherSuite};
|
||||
|
||||
type Reader = Box<dyn std::io::BufRead + Sync + Send>;
|
||||
|
||||
impl TlsConfig {
|
||||
pub(crate) fn to_readers(&self) -> std::io::Result<(Reader, Reader)> {
|
||||
use std::{io::{self, Error}, fs};
|
||||
use yansi::Paint;
|
||||
|
||||
fn to_reader(value: &Either<RelativePathBuf, Vec<u8>>) -> io::Result<Reader> {
|
||||
match value {
|
||||
Either::Left(path) => {
|
||||
|
@ -411,10 +565,29 @@ mod with_tls_feature {
|
|||
}
|
||||
}
|
||||
|
||||
Ok((to_reader(&self.certs)?, to_reader(&self.key)?))
|
||||
impl TlsConfig {
|
||||
/// This is only called when TLS is enabled.
|
||||
pub(crate) fn to_native_config(&self) -> io::Result<Config<Reader>> {
|
||||
Ok(Config {
|
||||
cert_chain: to_reader(&self.certs)?,
|
||||
private_key: to_reader(&self.key)?,
|
||||
ciphersuites: self.rustls_ciphers().collect(),
|
||||
prefer_server_order: self.prefer_server_cipher_order,
|
||||
#[cfg(not(feature = "mtls"))]
|
||||
mandatory_mtls: false,
|
||||
#[cfg(not(feature = "mtls"))]
|
||||
ca_certs: None,
|
||||
#[cfg(feature = "mtls")]
|
||||
mandatory_mtls: self.mutual.as_ref().map_or(false, |m| m.mandatory),
|
||||
#[cfg(feature = "mtls")]
|
||||
ca_certs: match self.mutual {
|
||||
Some(ref mtls) => Some(to_reader(&mtls.ca_certs)?),
|
||||
None => None
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
pub(crate) fn rustls_ciphers(&self) -> impl Iterator<Item = &'static RustlsCipher> + '_ {
|
||||
fn rustls_ciphers(&self) -> impl Iterator<Item = &'static RustlsCipher> + '_ {
|
||||
self.ciphers().map(|ciphersuite| match ciphersuite {
|
||||
CipherSuite::TLS_CHACHA20_POLY1305_SHA256 =>
|
||||
&rustls::TLS13_CHACHA20_POLY1305_SHA256,
|
||||
|
|
|
@ -2,7 +2,7 @@ use std::{io, time::Duration};
|
|||
use std::task::{Poll, Context};
|
||||
use std::pin::Pin;
|
||||
|
||||
use bytes::BytesMut;
|
||||
use bytes::{Bytes, BytesMut};
|
||||
use pin_project_lite::pin_project;
|
||||
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
|
||||
use tokio::time::{sleep, Sleep};
|
||||
|
@ -10,8 +10,6 @@ use tokio::time::{sleep, Sleep};
|
|||
use futures::stream::Stream;
|
||||
use futures::future::{self, Future, FutureExt};
|
||||
|
||||
use crate::http::hyper::Bytes;
|
||||
|
||||
pin_project! {
|
||||
pub struct ReaderStream<R> {
|
||||
#[pin]
|
||||
|
@ -293,12 +291,16 @@ impl<F: Future, I: AsyncWrite> AsyncWrite for CancellableIo<F, I> {
|
|||
}
|
||||
}
|
||||
|
||||
use crate::http::private::{Listener, Connection};
|
||||
use crate::http::private::{Listener, Connection, RawCertificate};
|
||||
|
||||
impl<F: Future, C: Connection> Connection for CancellableIo<F, C> {
|
||||
fn peer_address(&self) -> Option<std::net::SocketAddr> {
|
||||
self.io.peer_address()
|
||||
}
|
||||
|
||||
fn peer_certificates(&self) -> Option<Vec<RawCertificate>> {
|
||||
self.io.peer_certificates()
|
||||
}
|
||||
}
|
||||
|
||||
pin_project! {
|
||||
|
|
|
@ -6,4 +6,4 @@ edition = "2018"
|
|||
publish = false
|
||||
|
||||
[dependencies]
|
||||
rocket = { path = "../../core/lib", features = ["tls"] }
|
||||
rocket = { path = "../../core/lib", features = ["tls", "mtls"] }
|
||||
|
|
|
@ -9,6 +9,10 @@
|
|||
certs = "private/rsa_sha256_cert.pem"
|
||||
key = "private/rsa_sha256_key.pem"
|
||||
|
||||
[default.tls.mutual]
|
||||
ca_certs = "private/ca_cert.pem"
|
||||
mandatory = false
|
||||
|
||||
[rsa_sha256.tls]
|
||||
certs = "private/rsa_sha256_cert.pem"
|
||||
key = "private/rsa_sha256_key.pem"
|
||||
|
|
|
@ -2,7 +2,14 @@
|
|||
|
||||
#[cfg(test)] mod tests;
|
||||
|
||||
use rocket::mtls::Certificate;
|
||||
|
||||
#[get("/")]
|
||||
fn mutual(cert: Certificate<'_>) -> String {
|
||||
format!("Hello! Here's what we know: [{}] {}", cert.serial(), cert.subject())
|
||||
}
|
||||
|
||||
#[get("/", rank = 2)]
|
||||
fn hello() -> &'static str {
|
||||
"Hello, world!"
|
||||
}
|
||||
|
@ -10,5 +17,6 @@ fn hello() -> &'static str {
|
|||
#[launch]
|
||||
fn rocket() -> _ {
|
||||
// See `Rocket.toml` and `Cargo.toml` for TLS configuration.
|
||||
rocket::build().mount("/", routes![hello])
|
||||
// Run `./private/gen_certs.sh` to generate a CA and key pairs.
|
||||
rocket::build().mount("/", routes![hello, mutual])
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue