From edce8bd656577a8ee49cbf0deb106e07cfc6605f Mon Sep 17 00:00:00 2001 From: Sergio Benitez Date: Tue, 26 Mar 2024 15:38:36 -0700 Subject: [PATCH] Use default 'CryptoProvider' for all TLS ops. Prior to this commit, some TLS related operations used 'ring' even when a different default 'CryptoProvider' was installed. This commit fixes that by refactoring 'TlsConfig' such that all utility methods are required to use the default 'CryptoProvider'. This commit also cleans up code related to the rustls 0.23 update. --- core/lib/Cargo.toml | 21 +++-- core/lib/src/listener/quic.rs | 8 +- core/lib/src/listener/tls.rs | 16 ++-- core/lib/src/local/request.rs | 8 +- core/lib/src/mtls/config.rs | 13 +++ core/lib/src/tls/config.rs | 136 ++++++++++++++++++++++++------ core/lib/src/tls/mod.rs | 1 - core/lib/src/tls/util.rs | 104 ----------------------- examples/tls/Cargo.toml | 4 +- examples/tls/private/gen_certs.sh | 1 + examples/tls/src/tests.rs | 54 ++++++------ 11 files changed, 183 insertions(+), 183 deletions(-) delete mode 100644 core/lib/src/tls/util.rs diff --git a/core/lib/Cargo.toml b/core/lib/Cargo.toml index c3820b0e..bfe6b67a 100644 --- a/core/lib/Cargo.toml +++ b/core/lib/Cargo.toml @@ -49,11 +49,6 @@ serde_json = { version = "1.0.26", optional = true } rmp-serde = { version = "1", optional = true } uuid_ = { package = "uuid", version = "1", optional = true, features = ["serde"] } -# Optional TLS dependencies -rustls = { version = "0.23", default-features = false, features = ["ring", "logging", "std", "tls12"], optional = true } -tokio-rustls = { version = "0.26", default-features = false, features = ["logging", "tls12", "ring"], optional = true } -rustls-pemfile = { version = "2.0.0", optional = true } - # Optional MTLS dependencies x509-parser = { version = "0.16", optional = true } @@ -111,6 +106,22 @@ version = "0.6.0-dev" path = "../http" features = ["serde"] +[dependencies.rustls] +version = "0.23" +default-features = false +features = ["ring", "logging", "std", "tls12"] +optional = true + +[dependencies.tokio-rustls] +version = "0.26" +default-features = false +features = ["logging", "tls12", "ring"] +optional = true + +[dependencies.rustls-pemfile] +version = "2.1.0" +optional = true + [dependencies.s2n-quic] version = "1.32" default-features = false diff --git a/core/lib/src/listener/quic.rs b/core/lib/src/listener/quic.rs index e43cabc9..f237a88e 100644 --- a/core/lib/src/listener/quic.rs +++ b/core/lib/src/listener/quic.rs @@ -66,15 +66,15 @@ impl QuicListener { use quic::provider::tls::rustls::{rustls, DEFAULT_CIPHERSUITES, Server as H3TlsServer}; // FIXME: Remove this as soon as `s2n_quic` is on rustls >= 0.22. - let cert_chain = crate::tls::util::load_cert_chain(&mut tls.certs_reader().unwrap()) - .unwrap() + let cert_chain = tls.load_certs() + .map_err(|e| io::Error::new(io::ErrorKind::Other, e.to_string()))? .into_iter() .map(|v| v.to_vec()) .map(rustls::Certificate) .collect::>(); - let key = crate::tls::util::load_key(&mut tls.key_reader().unwrap()) - .unwrap() + let key = tls.load_key() + .map_err(|e| io::Error::new(io::ErrorKind::Other, e.to_string()))? .secret_der() .to_vec(); diff --git a/core/lib/src/listener/tls.rs b/core/lib/src/listener/tls.rs index 25dd1a6d..220011e4 100644 --- a/core/lib/src/listener/tls.rs +++ b/core/lib/src/listener/tls.rs @@ -7,7 +7,6 @@ use tokio::io::{AsyncRead, AsyncWrite}; use tokio_rustls::TlsAcceptor; use crate::tls::{TlsConfig, Error}; -use crate::tls::util::{self, load_cert_chain, load_key, load_ca_certs}; use crate::listener::{Listener, Bindable, Connection, Certificates, Endpoint}; #[doc(inline)] @@ -29,16 +28,13 @@ pub struct TlsBindable { impl TlsConfig { pub(crate) fn server_config(&self) -> Result { - let provider = rustls::crypto::CryptoProvider { - cipher_suites: self.ciphers().map(|c| c.into()).collect(), - ..util::get_crypto_provider() - }; + let provider = Arc::new(self.default_crypto_provider()); #[cfg(feature = "mtls")] let verifier = match self.mutual { Some(ref mtls) => { - let ca_certs = load_ca_certs(&mut mtls.ca_certs_reader()?)?; - let verifier = WebPkiClientVerifier::builder(Arc::new(ca_certs)); + let ca = Arc::new(mtls.load_ca_certs()?); + let verifier = WebPkiClientVerifier::builder_with_provider(ca, provider.clone()); match mtls.mandatory { true => verifier.build()?, false => verifier.allow_unauthenticated().build()?, @@ -50,12 +46,10 @@ impl TlsConfig { #[cfg(not(feature = "mtls"))] let verifier = WebPkiClientVerifier::no_client_auth(); - let key = load_key(&mut self.key_reader()?)?; - let cert_chain = load_cert_chain(&mut self.certs_reader()?)?; - let mut tls_config = ServerConfig::builder_with_provider(Arc::new(provider)) + let mut tls_config = ServerConfig::builder_with_provider(provider) .with_safe_default_protocol_versions()? .with_client_cert_verifier(verifier) - .with_single_cert(cert_chain, key)?; + .with_single_cert(self.load_certs()?, self.load_key()?)?; tls_config.ignore_client_order = self.prefer_server_cipher_order; tls_config.session_storage = ServerSessionMemoryCache::new(1024); diff --git a/core/lib/src/local/request.rs b/core/lib/src/local/request.rs index 7a156e28..9fbed290 100644 --- a/core/lib/src/local/request.rs +++ b/core/lib/src/local/request.rs @@ -246,12 +246,14 @@ macro_rules! pub_request_impl { #[cfg_attr(nightly, doc(cfg(feature = "mtls")))] pub fn identity(mut self, reader: C) -> Self { use std::sync::Arc; - use crate::tls::util::load_cert_chain; use crate::listener::Certificates; let mut reader = std::io::BufReader::new(reader); - let certs = load_cert_chain(&mut reader).map(Certificates::from); - self._request_mut().connection.peer_certs = certs.ok().map(Arc::new); + self._request_mut().connection.peer_certs = rustls_pemfile::certs(&mut reader) + .collect::, _>>() + .map(|certs| Arc::new(Certificates::from(certs))) + .ok(); + self } diff --git a/core/lib/src/mtls/config.rs b/core/lib/src/mtls/config.rs index a1e1e722..8bcbf0c0 100644 --- a/core/lib/src/mtls/config.rs +++ b/core/lib/src/mtls/config.rs @@ -3,6 +3,8 @@ use std::io; use figment::value::magic::{RelativePathBuf, Either}; use serde::{Serialize, Deserialize}; +use crate::tls::{Result, Error}; + /// Mutual TLS configuration. /// /// Configuration works in concert with the [`mtls`](crate::mtls) module, which @@ -142,6 +144,7 @@ impl MtlsConfig { } /// Returns the value of the `ca_certs` parameter. + /// /// # Example /// /// ```rust @@ -162,6 +165,16 @@ impl MtlsConfig { pub fn ca_certs_reader(&self) -> io::Result> { crate::tls::config::to_reader(&self.ca_certs) } + + /// Load and decode CA certificates from `reader`. + pub(crate) fn load_ca_certs(&self) -> Result { + let mut roots = rustls::RootCertStore::empty(); + for cert in rustls_pemfile::certs(&mut self.ca_certs_reader()?) { + roots.add(cert?).map_err(Error::CertAuth)?; + } + + Ok(roots) + } } #[cfg(test)] diff --git a/core/lib/src/tls/config.rs b/core/lib/src/tls/config.rs index c94bb332..5762e8f1 100644 --- a/core/lib/src/tls/config.rs +++ b/core/lib/src/tls/config.rs @@ -1,9 +1,13 @@ use std::io; +use rustls::crypto::{ring, CryptoProvider}; +use rustls::pki_types::{CertificateDer, PrivateKeyDer}; use figment::value::magic::{Either, RelativePathBuf}; use serde::{Deserialize, Serialize}; use indexmap::IndexSet; +use crate::tls::error::{Result, Error, KeyError}; + /// TLS configuration: certificate chain, key, and ciphersuites. /// /// Four parameters control `tls` configuration: @@ -431,6 +435,72 @@ impl TlsConfig { } } +/// Loads certificates from `reader`. +impl TlsConfig { + pub(crate) fn load_certs(&self) -> Result>> { + rustls_pemfile::certs(&mut self.certs_reader()?) + .collect::>() + .map_err(Error::CertChain) + } + + /// Load and decode the private key from `reader`. + pub(crate) fn load_key(&self) -> Result> { + use rustls_pemfile::Item::*; + + let mut keys = rustls_pemfile::read_all(&mut self.key_reader()?) + .map(|result| result.map_err(KeyError::Io) + .and_then(|item| match item { + Pkcs1Key(key) => Ok(key.into()), + Pkcs8Key(key) => Ok(key.into()), + Sec1Key(key) => Ok(key.into()), + _ => Err(KeyError::BadItem(item)) + }) + ) + .collect::>, _>>()?; + + if keys.len() != 1 { + return Err(KeyError::BadKeyCount(keys.len()).into()); + } + + // Ensure we can use the key. + let key = keys.remove(0); + self.default_crypto_provider() + .key_provider + .load_private_key(key.clone_key()) + .map_err(KeyError::Unsupported)?; + + Ok(key) + } + + pub(crate) fn default_crypto_provider(&self) -> CryptoProvider { + CryptoProvider::get_default() + .map(|arc| (**arc).clone()) + .unwrap_or_else(|| rustls::crypto::CryptoProvider { + cipher_suites: self.ciphers().map(|cipher| match cipher { + CipherSuite::TLS_CHACHA20_POLY1305_SHA256 => + ring::cipher_suite::TLS13_CHACHA20_POLY1305_SHA256, + CipherSuite::TLS_AES_256_GCM_SHA384 => + ring::cipher_suite::TLS13_AES_256_GCM_SHA384, + CipherSuite::TLS_AES_128_GCM_SHA256 => + ring::cipher_suite::TLS13_AES_128_GCM_SHA256, + CipherSuite::TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256 => + ring::cipher_suite::TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256, + CipherSuite::TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256 => + ring::cipher_suite::TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256, + CipherSuite::TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384 => + ring::cipher_suite::TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, + CipherSuite::TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256 => + ring::cipher_suite::TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, + CipherSuite::TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384 => + ring::cipher_suite::TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, + CipherSuite::TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 => + ring::cipher_suite::TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, + }).collect(), + ..ring::default_provider() + }) + } +} + impl CipherSuite { /// The default set and order of cipher suites. These are all of the /// variants in [`CipherSuite`] in their declaration order. @@ -474,33 +544,6 @@ impl CipherSuite { } } -impl From for rustls::SupportedCipherSuite { - fn from(cipher: CipherSuite) -> Self { - use rustls::crypto::ring::cipher_suite; - - match cipher { - CipherSuite::TLS_CHACHA20_POLY1305_SHA256 => - cipher_suite::TLS13_CHACHA20_POLY1305_SHA256, - CipherSuite::TLS_AES_256_GCM_SHA384 => - cipher_suite::TLS13_AES_256_GCM_SHA384, - CipherSuite::TLS_AES_128_GCM_SHA256 => - cipher_suite::TLS13_AES_128_GCM_SHA256, - CipherSuite::TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256 => - cipher_suite::TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256, - CipherSuite::TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256 => - cipher_suite::TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256, - CipherSuite::TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384 => - cipher_suite::TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, - CipherSuite::TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256 => - cipher_suite::TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, - CipherSuite::TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384 => - cipher_suite::TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, - CipherSuite::TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 => - cipher_suite::TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, - } - } -} - pub(crate) fn to_reader( value: &Either> ) -> io::Result> { @@ -522,6 +565,7 @@ pub(crate) fn to_reader( #[cfg(test)] mod tests { + use super::*; use figment::{Figment, providers::{Toml, Format}}; #[test] @@ -650,4 +694,42 @@ mod tests { Ok(()) }); } + + macro_rules! tls_example_private_pem { + ($k:expr) => { + concat!(env!("CARGO_MANIFEST_DIR"), "/../../examples/tls/private/", $k) + } + } + + #[test] + fn verify_load_private_keys_of_different_types() -> Result<()> { + let key_paths = [ + tls_example_private_pem!("rsa_sha256_key.pem"), + tls_example_private_pem!("ecdsa_nistp256_sha256_key_pkcs8.pem"), + tls_example_private_pem!("ecdsa_nistp384_sha384_key_pkcs8.pem"), + tls_example_private_pem!("ed25519_key.pem"), + ]; + + for key in key_paths { + TlsConfig::from_paths("", key).load_key()?; + } + + Ok(()) + } + + #[test] + fn verify_load_certs_of_different_types() -> Result<()> { + let cert_paths = [ + tls_example_private_pem!("rsa_sha256_cert.pem"), + tls_example_private_pem!("ecdsa_nistp256_sha256_cert.pem"), + tls_example_private_pem!("ecdsa_nistp384_sha384_cert.pem"), + tls_example_private_pem!("ed25519_cert.pem"), + ]; + + for cert in cert_paths { + TlsConfig::from_paths(cert, "").load_certs()?; + } + + Ok(()) + } } diff --git a/core/lib/src/tls/mod.rs b/core/lib/src/tls/mod.rs index d6128e3b..7f5a05de 100644 --- a/core/lib/src/tls/mod.rs +++ b/core/lib/src/tls/mod.rs @@ -1,6 +1,5 @@ mod error; pub(crate) mod config; -pub(crate) mod util; pub use error::Result; pub use config::{TlsConfig, CipherSuite}; diff --git a/core/lib/src/tls/util.rs b/core/lib/src/tls/util.rs deleted file mode 100644 index 497c5201..00000000 --- a/core/lib/src/tls/util.rs +++ /dev/null @@ -1,104 +0,0 @@ -use std::io; - -use rustls::RootCertStore; -use rustls::crypto::CryptoProvider; -use rustls::pki_types::{CertificateDer, PrivateKeyDer}; - -use crate::tls::error::{Result, Error, KeyError}; - -/// Loads certificates from `reader`. -pub fn load_cert_chain(reader: &mut dyn io::BufRead) -> Result>> { - rustls_pemfile::certs(reader) - .collect::>() - .map_err(Error::CertChain) -} - -/// Load and decode the private key from `reader`. -pub fn load_key(reader: &mut dyn io::BufRead) -> Result> { - use rustls_pemfile::Item::*; - - let mut keys: Vec> = rustls_pemfile::read_all(reader) - .map(|result| result.map_err(KeyError::Io) - .and_then(|item| match item { - Pkcs1Key(key) => Ok(key.into()), - Pkcs8Key(key) => Ok(key.into()), - Sec1Key(key) => Ok(key.into()), - _ => Err(KeyError::BadItem(item)) - }) - ) - .collect::>()?; - - if keys.len() != 1 { - return Err(KeyError::BadKeyCount(keys.len()).into()); - } - - // Ensure we can use the key. - let key = keys.remove(0); - get_crypto_provider().key_provider.load_private_key(key.clone_key()) - .map_err(KeyError::Unsupported)?; - Ok(key) -} - -/// Load and decode CA certificates from `reader`. -pub fn load_ca_certs(reader: &mut dyn io::BufRead) -> Result { - let mut roots = rustls::RootCertStore::empty(); - for cert in load_cert_chain(reader)? { - roots.add(cert).map_err(Error::CertAuth)?; - } - - Ok(roots) -} - -pub(crate) fn get_crypto_provider() -> CryptoProvider { - if let Some(crypto_provider) = rustls::crypto::CryptoProvider::get_default() { - CryptoProvider::clone(crypto_provider) - } else { - let crypto_provider = rustls::crypto::ring::default_provider(); - // Should only fail due to other concurrent install, so we ignore it - let _ = crypto_provider.clone().install_default(); - - crypto_provider - } - -} - -#[cfg(test)] -mod test { - use super::*; - - macro_rules! tls_example_key { - ($k:expr) => { - include_bytes!(concat!(env!("CARGO_MANIFEST_DIR"), "/../../examples/tls/private/", $k)) - } - } - - #[test] - fn verify_load_private_keys_of_different_types() -> Result<()> { - let rsa_sha256_key = tls_example_key!("rsa_sha256_key.pem"); - let ecdsa_nistp256_sha256_key = tls_example_key!("ecdsa_nistp256_sha256_key_pkcs8.pem"); - let ecdsa_nistp384_sha384_key = tls_example_key!("ecdsa_nistp384_sha384_key_pkcs8.pem"); - let ed2551_key = tls_example_key!("ed25519_key.pem"); - - load_key(&mut &rsa_sha256_key[..])?; - load_key(&mut &ecdsa_nistp256_sha256_key[..])?; - load_key(&mut &ecdsa_nistp384_sha384_key[..])?; - load_key(&mut &ed2551_key[..])?; - - Ok(()) - } - - #[test] - fn verify_load_certs_of_different_types() -> Result<()> { - let rsa_sha256_cert = tls_example_key!("rsa_sha256_cert.pem"); - let ecdsa_nistp256_sha256_cert = tls_example_key!("ecdsa_nistp256_sha256_cert.pem"); - let ecdsa_nistp384_sha384_cert = tls_example_key!("ecdsa_nistp384_sha384_cert.pem"); - let ed2551_cert = tls_example_key!("ed25519_cert.pem"); - - load_cert_chain(&mut &rsa_sha256_cert[..])?; - load_cert_chain(&mut &ecdsa_nistp256_sha256_cert[..])?; - load_cert_chain(&mut &ecdsa_nistp384_sha384_cert[..])?; - load_cert_chain(&mut &ed2551_cert[..])?; - - Ok(()) - } -} diff --git a/examples/tls/Cargo.toml b/examples/tls/Cargo.toml index a505647f..aa6e502f 100644 --- a/examples/tls/Cargo.toml +++ b/examples/tls/Cargo.toml @@ -7,5 +7,7 @@ publish = false [dependencies] rocket = { path = "../../core/lib", features = ["tls", "mtls", "secrets", "http3-preview"] } -rustls = { version = "0.23", features = ["aws_lc_rs"] } yansi = "1.0.1" + +[target.'cfg(unix)'.dependencies] +rustls = { version = "0.23", features = ["aws_lc_rs"] } diff --git a/examples/tls/private/gen_certs.sh b/examples/tls/private/gen_certs.sh index 2418c27b..d98d152e 100755 --- a/examples/tls/private/gen_certs.sh +++ b/examples/tls/private/gen_certs.sh @@ -8,6 +8,7 @@ # rsa_sha256 # ecdsa_nistp256_sha256 # ecdsa_nistp384_sha384 +# ecdsa_nistp521_sha512 # # Generate a certificate of the [cert-kind] key type, or if no cert-kind is # specified, all of the certificates. diff --git a/examples/tls/src/tests.rs b/examples/tls/src/tests.rs index cec2e9f7..413e64fd 100644 --- a/examples/tls/src/tests.rs +++ b/examples/tls/src/tests.rs @@ -65,13 +65,31 @@ fn insecure_cookies() { assert_eq!(c4.secure(), None); } -#[test] -fn hello_world() { +fn validate_profiles(profiles: &[&str]) { use rocket::listener::DefaultListener; use rocket::config::{Config, SecretKey}; - use rustls::crypto::aws_lc_rs; - let mut profiles = vec![ + for profile in profiles { + let config = Config { + secret_key: SecretKey::generate().unwrap(), + ..Config::debug_default() + }; + + let figment = Config::figment().merge(config).select(profile); + let client = Client::tracked_secure(super::rocket().configure(figment)).unwrap(); + let response = client.get("/").dispatch(); + assert_eq!(response.into_string().unwrap(), "Hello, world!"); + + let figment = client.rocket().figment(); + let listener: DefaultListener = figment.extract().unwrap(); + assert_eq!(figment.profile(), profile); + listener.tls.as_ref().unwrap().validate().expect("valid TLS config"); + } +} + +#[test] +fn validate_tls_profiles() { + const DEFAULT_PROFILES: &[&str] = &[ "rsa_sha256", "ecdsa_nistp256_sha256_pkcs8", "ecdsa_nistp384_sha384_pkcs8", @@ -80,29 +98,11 @@ fn hello_world() { "ed25519", ]; - for use_aws_lc in [false, true] { - if use_aws_lc { - let crypto_provider = aws_lc_rs::default_provider(); - crypto_provider.install_default().unwrap(); + validate_profiles(DEFAULT_PROFILES); - profiles.push("ecdsa_nistp521_sha512_pkcs8"); - } - - for profile in &profiles { - let config = Config { - secret_key: SecretKey::generate().unwrap(), - ..Config::debug_default() - }; - - let figment = Config::figment().merge(config).select(profile); - let client = Client::tracked_secure(super::rocket().configure(figment)).unwrap(); - let response = client.get("/").dispatch(); - assert_eq!(response.into_string().unwrap(), "Hello, world!"); - - let figment = client.rocket().figment(); - let listener: DefaultListener = figment.extract().unwrap(); - assert_eq!(figment.profile(), profile); - listener.tls.as_ref().unwrap().validate().expect("valid TLS config"); - } + #[cfg(unix)] { + rustls::crypto::aws_lc_rs::default_provider().install_default().unwrap(); + validate_profiles(DEFAULT_PROFILES); + validate_profiles(&["ecdsa_nistp521_sha512_pkcs8"]); } }