Use default 'CryptoProvider' for all TLS ops.

Prior to this commit, some TLS related operations used 'ring' even when
a different default 'CryptoProvider' was installed. This commit fixes
that by refactoring 'TlsConfig' such that all utility methods are
required to use the default 'CryptoProvider'.

This commit also cleans up code related to the rustls 0.23 update.
This commit is contained in:
Sergio Benitez 2024-03-26 15:38:36 -07:00 committed by Sergio Benitez
parent ce92c5dd76
commit edce8bd656
11 changed files with 183 additions and 183 deletions

View File

@ -49,11 +49,6 @@ serde_json = { version = "1.0.26", optional = true }
rmp-serde = { version = "1", optional = true } rmp-serde = { version = "1", optional = true }
uuid_ = { package = "uuid", version = "1", optional = true, features = ["serde"] } uuid_ = { package = "uuid", version = "1", optional = true, features = ["serde"] }
# Optional TLS dependencies
rustls = { version = "0.23", default-features = false, features = ["ring", "logging", "std", "tls12"], optional = true }
tokio-rustls = { version = "0.26", default-features = false, features = ["logging", "tls12", "ring"], optional = true }
rustls-pemfile = { version = "2.0.0", optional = true }
# Optional MTLS dependencies # Optional MTLS dependencies
x509-parser = { version = "0.16", optional = true } x509-parser = { version = "0.16", optional = true }
@ -111,6 +106,22 @@ version = "0.6.0-dev"
path = "../http" path = "../http"
features = ["serde"] features = ["serde"]
[dependencies.rustls]
version = "0.23"
default-features = false
features = ["ring", "logging", "std", "tls12"]
optional = true
[dependencies.tokio-rustls]
version = "0.26"
default-features = false
features = ["logging", "tls12", "ring"]
optional = true
[dependencies.rustls-pemfile]
version = "2.1.0"
optional = true
[dependencies.s2n-quic] [dependencies.s2n-quic]
version = "1.32" version = "1.32"
default-features = false default-features = false

View File

@ -66,15 +66,15 @@ impl QuicListener {
use quic::provider::tls::rustls::{rustls, DEFAULT_CIPHERSUITES, Server as H3TlsServer}; use quic::provider::tls::rustls::{rustls, DEFAULT_CIPHERSUITES, Server as H3TlsServer};
// FIXME: Remove this as soon as `s2n_quic` is on rustls >= 0.22. // FIXME: Remove this as soon as `s2n_quic` is on rustls >= 0.22.
let cert_chain = crate::tls::util::load_cert_chain(&mut tls.certs_reader().unwrap()) let cert_chain = tls.load_certs()
.unwrap() .map_err(|e| io::Error::new(io::ErrorKind::Other, e.to_string()))?
.into_iter() .into_iter()
.map(|v| v.to_vec()) .map(|v| v.to_vec())
.map(rustls::Certificate) .map(rustls::Certificate)
.collect::<Vec<_>>(); .collect::<Vec<_>>();
let key = crate::tls::util::load_key(&mut tls.key_reader().unwrap()) let key = tls.load_key()
.unwrap() .map_err(|e| io::Error::new(io::ErrorKind::Other, e.to_string()))?
.secret_der() .secret_der()
.to_vec(); .to_vec();

View File

@ -7,7 +7,6 @@ use tokio::io::{AsyncRead, AsyncWrite};
use tokio_rustls::TlsAcceptor; use tokio_rustls::TlsAcceptor;
use crate::tls::{TlsConfig, Error}; use crate::tls::{TlsConfig, Error};
use crate::tls::util::{self, load_cert_chain, load_key, load_ca_certs};
use crate::listener::{Listener, Bindable, Connection, Certificates, Endpoint}; use crate::listener::{Listener, Bindable, Connection, Certificates, Endpoint};
#[doc(inline)] #[doc(inline)]
@ -29,16 +28,13 @@ pub struct TlsBindable<I> {
impl TlsConfig { impl TlsConfig {
pub(crate) fn server_config(&self) -> Result<ServerConfig, Error> { pub(crate) fn server_config(&self) -> Result<ServerConfig, Error> {
let provider = rustls::crypto::CryptoProvider { let provider = Arc::new(self.default_crypto_provider());
cipher_suites: self.ciphers().map(|c| c.into()).collect(),
..util::get_crypto_provider()
};
#[cfg(feature = "mtls")] #[cfg(feature = "mtls")]
let verifier = match self.mutual { let verifier = match self.mutual {
Some(ref mtls) => { Some(ref mtls) => {
let ca_certs = load_ca_certs(&mut mtls.ca_certs_reader()?)?; let ca = Arc::new(mtls.load_ca_certs()?);
let verifier = WebPkiClientVerifier::builder(Arc::new(ca_certs)); let verifier = WebPkiClientVerifier::builder_with_provider(ca, provider.clone());
match mtls.mandatory { match mtls.mandatory {
true => verifier.build()?, true => verifier.build()?,
false => verifier.allow_unauthenticated().build()?, false => verifier.allow_unauthenticated().build()?,
@ -50,12 +46,10 @@ impl TlsConfig {
#[cfg(not(feature = "mtls"))] #[cfg(not(feature = "mtls"))]
let verifier = WebPkiClientVerifier::no_client_auth(); let verifier = WebPkiClientVerifier::no_client_auth();
let key = load_key(&mut self.key_reader()?)?; let mut tls_config = ServerConfig::builder_with_provider(provider)
let cert_chain = load_cert_chain(&mut self.certs_reader()?)?;
let mut tls_config = ServerConfig::builder_with_provider(Arc::new(provider))
.with_safe_default_protocol_versions()? .with_safe_default_protocol_versions()?
.with_client_cert_verifier(verifier) .with_client_cert_verifier(verifier)
.with_single_cert(cert_chain, key)?; .with_single_cert(self.load_certs()?, self.load_key()?)?;
tls_config.ignore_client_order = self.prefer_server_cipher_order; tls_config.ignore_client_order = self.prefer_server_cipher_order;
tls_config.session_storage = ServerSessionMemoryCache::new(1024); tls_config.session_storage = ServerSessionMemoryCache::new(1024);

View File

@ -246,12 +246,14 @@ macro_rules! pub_request_impl {
#[cfg_attr(nightly, doc(cfg(feature = "mtls")))] #[cfg_attr(nightly, doc(cfg(feature = "mtls")))]
pub fn identity<C: std::io::Read>(mut self, reader: C) -> Self { pub fn identity<C: std::io::Read>(mut self, reader: C) -> Self {
use std::sync::Arc; use std::sync::Arc;
use crate::tls::util::load_cert_chain;
use crate::listener::Certificates; use crate::listener::Certificates;
let mut reader = std::io::BufReader::new(reader); let mut reader = std::io::BufReader::new(reader);
let certs = load_cert_chain(&mut reader).map(Certificates::from); self._request_mut().connection.peer_certs = rustls_pemfile::certs(&mut reader)
self._request_mut().connection.peer_certs = certs.ok().map(Arc::new); .collect::<Result<Vec<_>, _>>()
.map(|certs| Arc::new(Certificates::from(certs)))
.ok();
self self
} }

View File

@ -3,6 +3,8 @@ use std::io;
use figment::value::magic::{RelativePathBuf, Either}; use figment::value::magic::{RelativePathBuf, Either};
use serde::{Serialize, Deserialize}; use serde::{Serialize, Deserialize};
use crate::tls::{Result, Error};
/// Mutual TLS configuration. /// Mutual TLS configuration.
/// ///
/// Configuration works in concert with the [`mtls`](crate::mtls) module, which /// Configuration works in concert with the [`mtls`](crate::mtls) module, which
@ -142,6 +144,7 @@ impl MtlsConfig {
} }
/// Returns the value of the `ca_certs` parameter. /// Returns the value of the `ca_certs` parameter.
///
/// # Example /// # Example
/// ///
/// ```rust /// ```rust
@ -162,6 +165,16 @@ impl MtlsConfig {
pub fn ca_certs_reader(&self) -> io::Result<Box<dyn io::BufRead + Sync + Send>> { pub fn ca_certs_reader(&self) -> io::Result<Box<dyn io::BufRead + Sync + Send>> {
crate::tls::config::to_reader(&self.ca_certs) crate::tls::config::to_reader(&self.ca_certs)
} }
/// Load and decode CA certificates from `reader`.
pub(crate) fn load_ca_certs(&self) -> Result<rustls::RootCertStore> {
let mut roots = rustls::RootCertStore::empty();
for cert in rustls_pemfile::certs(&mut self.ca_certs_reader()?) {
roots.add(cert?).map_err(Error::CertAuth)?;
}
Ok(roots)
}
} }
#[cfg(test)] #[cfg(test)]

View File

@ -1,9 +1,13 @@
use std::io; use std::io;
use rustls::crypto::{ring, CryptoProvider};
use rustls::pki_types::{CertificateDer, PrivateKeyDer};
use figment::value::magic::{Either, RelativePathBuf}; use figment::value::magic::{Either, RelativePathBuf};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use indexmap::IndexSet; use indexmap::IndexSet;
use crate::tls::error::{Result, Error, KeyError};
/// TLS configuration: certificate chain, key, and ciphersuites. /// TLS configuration: certificate chain, key, and ciphersuites.
/// ///
/// Four parameters control `tls` configuration: /// Four parameters control `tls` configuration:
@ -431,6 +435,72 @@ impl TlsConfig {
} }
} }
/// Loads certificates from `reader`.
impl TlsConfig {
pub(crate) fn load_certs(&self) -> Result<Vec<CertificateDer<'static>>> {
rustls_pemfile::certs(&mut self.certs_reader()?)
.collect::<Result<_, _>>()
.map_err(Error::CertChain)
}
/// Load and decode the private key from `reader`.
pub(crate) fn load_key(&self) -> Result<PrivateKeyDer<'static>> {
use rustls_pemfile::Item::*;
let mut keys = rustls_pemfile::read_all(&mut self.key_reader()?)
.map(|result| result.map_err(KeyError::Io)
.and_then(|item| match item {
Pkcs1Key(key) => Ok(key.into()),
Pkcs8Key(key) => Ok(key.into()),
Sec1Key(key) => Ok(key.into()),
_ => Err(KeyError::BadItem(item))
})
)
.collect::<Result<Vec<PrivateKeyDer<'static>>, _>>()?;
if keys.len() != 1 {
return Err(KeyError::BadKeyCount(keys.len()).into());
}
// Ensure we can use the key.
let key = keys.remove(0);
self.default_crypto_provider()
.key_provider
.load_private_key(key.clone_key())
.map_err(KeyError::Unsupported)?;
Ok(key)
}
pub(crate) fn default_crypto_provider(&self) -> CryptoProvider {
CryptoProvider::get_default()
.map(|arc| (**arc).clone())
.unwrap_or_else(|| rustls::crypto::CryptoProvider {
cipher_suites: self.ciphers().map(|cipher| match cipher {
CipherSuite::TLS_CHACHA20_POLY1305_SHA256 =>
ring::cipher_suite::TLS13_CHACHA20_POLY1305_SHA256,
CipherSuite::TLS_AES_256_GCM_SHA384 =>
ring::cipher_suite::TLS13_AES_256_GCM_SHA384,
CipherSuite::TLS_AES_128_GCM_SHA256 =>
ring::cipher_suite::TLS13_AES_128_GCM_SHA256,
CipherSuite::TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256 =>
ring::cipher_suite::TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256,
CipherSuite::TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256 =>
ring::cipher_suite::TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256,
CipherSuite::TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384 =>
ring::cipher_suite::TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
CipherSuite::TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256 =>
ring::cipher_suite::TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
CipherSuite::TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384 =>
ring::cipher_suite::TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
CipherSuite::TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 =>
ring::cipher_suite::TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
}).collect(),
..ring::default_provider()
})
}
}
impl CipherSuite { impl CipherSuite {
/// The default set and order of cipher suites. These are all of the /// The default set and order of cipher suites. These are all of the
/// variants in [`CipherSuite`] in their declaration order. /// variants in [`CipherSuite`] in their declaration order.
@ -474,33 +544,6 @@ impl CipherSuite {
} }
} }
impl From<CipherSuite> for rustls::SupportedCipherSuite {
fn from(cipher: CipherSuite) -> Self {
use rustls::crypto::ring::cipher_suite;
match cipher {
CipherSuite::TLS_CHACHA20_POLY1305_SHA256 =>
cipher_suite::TLS13_CHACHA20_POLY1305_SHA256,
CipherSuite::TLS_AES_256_GCM_SHA384 =>
cipher_suite::TLS13_AES_256_GCM_SHA384,
CipherSuite::TLS_AES_128_GCM_SHA256 =>
cipher_suite::TLS13_AES_128_GCM_SHA256,
CipherSuite::TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256 =>
cipher_suite::TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256,
CipherSuite::TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256 =>
cipher_suite::TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256,
CipherSuite::TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384 =>
cipher_suite::TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
CipherSuite::TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256 =>
cipher_suite::TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
CipherSuite::TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384 =>
cipher_suite::TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
CipherSuite::TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 =>
cipher_suite::TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
}
}
}
pub(crate) fn to_reader( pub(crate) fn to_reader(
value: &Either<RelativePathBuf, Vec<u8>> value: &Either<RelativePathBuf, Vec<u8>>
) -> io::Result<Box<dyn io::BufRead + Sync + Send>> { ) -> io::Result<Box<dyn io::BufRead + Sync + Send>> {
@ -522,6 +565,7 @@ pub(crate) fn to_reader(
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*;
use figment::{Figment, providers::{Toml, Format}}; use figment::{Figment, providers::{Toml, Format}};
#[test] #[test]
@ -650,4 +694,42 @@ mod tests {
Ok(()) Ok(())
}); });
} }
macro_rules! tls_example_private_pem {
($k:expr) => {
concat!(env!("CARGO_MANIFEST_DIR"), "/../../examples/tls/private/", $k)
}
}
#[test]
fn verify_load_private_keys_of_different_types() -> Result<()> {
let key_paths = [
tls_example_private_pem!("rsa_sha256_key.pem"),
tls_example_private_pem!("ecdsa_nistp256_sha256_key_pkcs8.pem"),
tls_example_private_pem!("ecdsa_nistp384_sha384_key_pkcs8.pem"),
tls_example_private_pem!("ed25519_key.pem"),
];
for key in key_paths {
TlsConfig::from_paths("", key).load_key()?;
}
Ok(())
}
#[test]
fn verify_load_certs_of_different_types() -> Result<()> {
let cert_paths = [
tls_example_private_pem!("rsa_sha256_cert.pem"),
tls_example_private_pem!("ecdsa_nistp256_sha256_cert.pem"),
tls_example_private_pem!("ecdsa_nistp384_sha384_cert.pem"),
tls_example_private_pem!("ed25519_cert.pem"),
];
for cert in cert_paths {
TlsConfig::from_paths(cert, "").load_certs()?;
}
Ok(())
}
} }

View File

@ -1,6 +1,5 @@
mod error; mod error;
pub(crate) mod config; pub(crate) mod config;
pub(crate) mod util;
pub use error::Result; pub use error::Result;
pub use config::{TlsConfig, CipherSuite}; pub use config::{TlsConfig, CipherSuite};

View File

@ -1,104 +0,0 @@
use std::io;
use rustls::RootCertStore;
use rustls::crypto::CryptoProvider;
use rustls::pki_types::{CertificateDer, PrivateKeyDer};
use crate::tls::error::{Result, Error, KeyError};
/// Loads certificates from `reader`.
pub fn load_cert_chain(reader: &mut dyn io::BufRead) -> Result<Vec<CertificateDer<'static>>> {
rustls_pemfile::certs(reader)
.collect::<Result<_, _>>()
.map_err(Error::CertChain)
}
/// Load and decode the private key from `reader`.
pub fn load_key(reader: &mut dyn io::BufRead) -> Result<PrivateKeyDer<'static>> {
use rustls_pemfile::Item::*;
let mut keys: Vec<PrivateKeyDer<'static>> = rustls_pemfile::read_all(reader)
.map(|result| result.map_err(KeyError::Io)
.and_then(|item| match item {
Pkcs1Key(key) => Ok(key.into()),
Pkcs8Key(key) => Ok(key.into()),
Sec1Key(key) => Ok(key.into()),
_ => Err(KeyError::BadItem(item))
})
)
.collect::<Result<_, _>>()?;
if keys.len() != 1 {
return Err(KeyError::BadKeyCount(keys.len()).into());
}
// Ensure we can use the key.
let key = keys.remove(0);
get_crypto_provider().key_provider.load_private_key(key.clone_key())
.map_err(KeyError::Unsupported)?;
Ok(key)
}
/// Load and decode CA certificates from `reader`.
pub fn load_ca_certs(reader: &mut dyn io::BufRead) -> Result<RootCertStore> {
let mut roots = rustls::RootCertStore::empty();
for cert in load_cert_chain(reader)? {
roots.add(cert).map_err(Error::CertAuth)?;
}
Ok(roots)
}
pub(crate) fn get_crypto_provider() -> CryptoProvider {
if let Some(crypto_provider) = rustls::crypto::CryptoProvider::get_default() {
CryptoProvider::clone(crypto_provider)
} else {
let crypto_provider = rustls::crypto::ring::default_provider();
// Should only fail due to other concurrent install, so we ignore it
let _ = crypto_provider.clone().install_default();
crypto_provider
}
}
#[cfg(test)]
mod test {
use super::*;
macro_rules! tls_example_key {
($k:expr) => {
include_bytes!(concat!(env!("CARGO_MANIFEST_DIR"), "/../../examples/tls/private/", $k))
}
}
#[test]
fn verify_load_private_keys_of_different_types() -> Result<()> {
let rsa_sha256_key = tls_example_key!("rsa_sha256_key.pem");
let ecdsa_nistp256_sha256_key = tls_example_key!("ecdsa_nistp256_sha256_key_pkcs8.pem");
let ecdsa_nistp384_sha384_key = tls_example_key!("ecdsa_nistp384_sha384_key_pkcs8.pem");
let ed2551_key = tls_example_key!("ed25519_key.pem");
load_key(&mut &rsa_sha256_key[..])?;
load_key(&mut &ecdsa_nistp256_sha256_key[..])?;
load_key(&mut &ecdsa_nistp384_sha384_key[..])?;
load_key(&mut &ed2551_key[..])?;
Ok(())
}
#[test]
fn verify_load_certs_of_different_types() -> Result<()> {
let rsa_sha256_cert = tls_example_key!("rsa_sha256_cert.pem");
let ecdsa_nistp256_sha256_cert = tls_example_key!("ecdsa_nistp256_sha256_cert.pem");
let ecdsa_nistp384_sha384_cert = tls_example_key!("ecdsa_nistp384_sha384_cert.pem");
let ed2551_cert = tls_example_key!("ed25519_cert.pem");
load_cert_chain(&mut &rsa_sha256_cert[..])?;
load_cert_chain(&mut &ecdsa_nistp256_sha256_cert[..])?;
load_cert_chain(&mut &ecdsa_nistp384_sha384_cert[..])?;
load_cert_chain(&mut &ed2551_cert[..])?;
Ok(())
}
}

View File

@ -7,5 +7,7 @@ publish = false
[dependencies] [dependencies]
rocket = { path = "../../core/lib", features = ["tls", "mtls", "secrets", "http3-preview"] } rocket = { path = "../../core/lib", features = ["tls", "mtls", "secrets", "http3-preview"] }
rustls = { version = "0.23", features = ["aws_lc_rs"] }
yansi = "1.0.1" yansi = "1.0.1"
[target.'cfg(unix)'.dependencies]
rustls = { version = "0.23", features = ["aws_lc_rs"] }

View File

@ -8,6 +8,7 @@
# rsa_sha256 # rsa_sha256
# ecdsa_nistp256_sha256 # ecdsa_nistp256_sha256
# ecdsa_nistp384_sha384 # ecdsa_nistp384_sha384
# ecdsa_nistp521_sha512
# #
# Generate a certificate of the [cert-kind] key type, or if no cert-kind is # Generate a certificate of the [cert-kind] key type, or if no cert-kind is
# specified, all of the certificates. # specified, all of the certificates.

View File

@ -65,13 +65,31 @@ fn insecure_cookies() {
assert_eq!(c4.secure(), None); assert_eq!(c4.secure(), None);
} }
#[test] fn validate_profiles(profiles: &[&str]) {
fn hello_world() {
use rocket::listener::DefaultListener; use rocket::listener::DefaultListener;
use rocket::config::{Config, SecretKey}; use rocket::config::{Config, SecretKey};
use rustls::crypto::aws_lc_rs;
let mut profiles = vec![ for profile in profiles {
let config = Config {
secret_key: SecretKey::generate().unwrap(),
..Config::debug_default()
};
let figment = Config::figment().merge(config).select(profile);
let client = Client::tracked_secure(super::rocket().configure(figment)).unwrap();
let response = client.get("/").dispatch();
assert_eq!(response.into_string().unwrap(), "Hello, world!");
let figment = client.rocket().figment();
let listener: DefaultListener = figment.extract().unwrap();
assert_eq!(figment.profile(), profile);
listener.tls.as_ref().unwrap().validate().expect("valid TLS config");
}
}
#[test]
fn validate_tls_profiles() {
const DEFAULT_PROFILES: &[&str] = &[
"rsa_sha256", "rsa_sha256",
"ecdsa_nistp256_sha256_pkcs8", "ecdsa_nistp256_sha256_pkcs8",
"ecdsa_nistp384_sha384_pkcs8", "ecdsa_nistp384_sha384_pkcs8",
@ -80,29 +98,11 @@ fn hello_world() {
"ed25519", "ed25519",
]; ];
for use_aws_lc in [false, true] { validate_profiles(DEFAULT_PROFILES);
if use_aws_lc {
let crypto_provider = aws_lc_rs::default_provider();
crypto_provider.install_default().unwrap();
profiles.push("ecdsa_nistp521_sha512_pkcs8"); #[cfg(unix)] {
} rustls::crypto::aws_lc_rs::default_provider().install_default().unwrap();
validate_profiles(DEFAULT_PROFILES);
for profile in &profiles { validate_profiles(&["ecdsa_nistp521_sha512_pkcs8"]);
let config = Config {
secret_key: SecretKey::generate().unwrap(),
..Config::debug_default()
};
let figment = Config::figment().merge(config).select(profile);
let client = Client::tracked_secure(super::rocket().configure(figment)).unwrap();
let response = client.get("/").dispatch();
assert_eq!(response.into_string().unwrap(), "Hello, world!");
let figment = client.rocket().figment();
let listener: DefaultListener = figment.extract().unwrap();
assert_eq!(figment.profile(), profile);
listener.tls.as_ref().unwrap().validate().expect("valid TLS config");
}
} }
} }