diff --git a/core/lib/Cargo.toml b/core/lib/Cargo.toml index c3820b0e..bfe6b67a 100644 --- a/core/lib/Cargo.toml +++ b/core/lib/Cargo.toml @@ -49,11 +49,6 @@ serde_json = { version = "1.0.26", optional = true } rmp-serde = { version = "1", optional = true } uuid_ = { package = "uuid", version = "1", optional = true, features = ["serde"] } -# Optional TLS dependencies -rustls = { version = "0.23", default-features = false, features = ["ring", "logging", "std", "tls12"], optional = true } -tokio-rustls = { version = "0.26", default-features = false, features = ["logging", "tls12", "ring"], optional = true } -rustls-pemfile = { version = "2.0.0", optional = true } - # Optional MTLS dependencies x509-parser = { version = "0.16", optional = true } @@ -111,6 +106,22 @@ version = "0.6.0-dev" path = "../http" features = ["serde"] +[dependencies.rustls] +version = "0.23" +default-features = false +features = ["ring", "logging", "std", "tls12"] +optional = true + +[dependencies.tokio-rustls] +version = "0.26" +default-features = false +features = ["logging", "tls12", "ring"] +optional = true + +[dependencies.rustls-pemfile] +version = "2.1.0" +optional = true + [dependencies.s2n-quic] version = "1.32" default-features = false diff --git a/core/lib/src/listener/quic.rs b/core/lib/src/listener/quic.rs index e43cabc9..f237a88e 100644 --- a/core/lib/src/listener/quic.rs +++ b/core/lib/src/listener/quic.rs @@ -66,15 +66,15 @@ impl QuicListener { use quic::provider::tls::rustls::{rustls, DEFAULT_CIPHERSUITES, Server as H3TlsServer}; // FIXME: Remove this as soon as `s2n_quic` is on rustls >= 0.22. - let cert_chain = crate::tls::util::load_cert_chain(&mut tls.certs_reader().unwrap()) - .unwrap() + let cert_chain = tls.load_certs() + .map_err(|e| io::Error::new(io::ErrorKind::Other, e.to_string()))? .into_iter() .map(|v| v.to_vec()) .map(rustls::Certificate) .collect::>(); - let key = crate::tls::util::load_key(&mut tls.key_reader().unwrap()) - .unwrap() + let key = tls.load_key() + .map_err(|e| io::Error::new(io::ErrorKind::Other, e.to_string()))? .secret_der() .to_vec(); diff --git a/core/lib/src/listener/tls.rs b/core/lib/src/listener/tls.rs index 25dd1a6d..220011e4 100644 --- a/core/lib/src/listener/tls.rs +++ b/core/lib/src/listener/tls.rs @@ -7,7 +7,6 @@ use tokio::io::{AsyncRead, AsyncWrite}; use tokio_rustls::TlsAcceptor; use crate::tls::{TlsConfig, Error}; -use crate::tls::util::{self, load_cert_chain, load_key, load_ca_certs}; use crate::listener::{Listener, Bindable, Connection, Certificates, Endpoint}; #[doc(inline)] @@ -29,16 +28,13 @@ pub struct TlsBindable { impl TlsConfig { pub(crate) fn server_config(&self) -> Result { - let provider = rustls::crypto::CryptoProvider { - cipher_suites: self.ciphers().map(|c| c.into()).collect(), - ..util::get_crypto_provider() - }; + let provider = Arc::new(self.default_crypto_provider()); #[cfg(feature = "mtls")] let verifier = match self.mutual { Some(ref mtls) => { - let ca_certs = load_ca_certs(&mut mtls.ca_certs_reader()?)?; - let verifier = WebPkiClientVerifier::builder(Arc::new(ca_certs)); + let ca = Arc::new(mtls.load_ca_certs()?); + let verifier = WebPkiClientVerifier::builder_with_provider(ca, provider.clone()); match mtls.mandatory { true => verifier.build()?, false => verifier.allow_unauthenticated().build()?, @@ -50,12 +46,10 @@ impl TlsConfig { #[cfg(not(feature = "mtls"))] let verifier = WebPkiClientVerifier::no_client_auth(); - let key = load_key(&mut self.key_reader()?)?; - let cert_chain = load_cert_chain(&mut self.certs_reader()?)?; - let mut tls_config = ServerConfig::builder_with_provider(Arc::new(provider)) + let mut tls_config = ServerConfig::builder_with_provider(provider) .with_safe_default_protocol_versions()? .with_client_cert_verifier(verifier) - .with_single_cert(cert_chain, key)?; + .with_single_cert(self.load_certs()?, self.load_key()?)?; tls_config.ignore_client_order = self.prefer_server_cipher_order; tls_config.session_storage = ServerSessionMemoryCache::new(1024); diff --git a/core/lib/src/local/request.rs b/core/lib/src/local/request.rs index 7a156e28..9fbed290 100644 --- a/core/lib/src/local/request.rs +++ b/core/lib/src/local/request.rs @@ -246,12 +246,14 @@ macro_rules! pub_request_impl { #[cfg_attr(nightly, doc(cfg(feature = "mtls")))] pub fn identity(mut self, reader: C) -> Self { use std::sync::Arc; - use crate::tls::util::load_cert_chain; use crate::listener::Certificates; let mut reader = std::io::BufReader::new(reader); - let certs = load_cert_chain(&mut reader).map(Certificates::from); - self._request_mut().connection.peer_certs = certs.ok().map(Arc::new); + self._request_mut().connection.peer_certs = rustls_pemfile::certs(&mut reader) + .collect::, _>>() + .map(|certs| Arc::new(Certificates::from(certs))) + .ok(); + self } diff --git a/core/lib/src/mtls/config.rs b/core/lib/src/mtls/config.rs index a1e1e722..8bcbf0c0 100644 --- a/core/lib/src/mtls/config.rs +++ b/core/lib/src/mtls/config.rs @@ -3,6 +3,8 @@ use std::io; use figment::value::magic::{RelativePathBuf, Either}; use serde::{Serialize, Deserialize}; +use crate::tls::{Result, Error}; + /// Mutual TLS configuration. /// /// Configuration works in concert with the [`mtls`](crate::mtls) module, which @@ -142,6 +144,7 @@ impl MtlsConfig { } /// Returns the value of the `ca_certs` parameter. + /// /// # Example /// /// ```rust @@ -162,6 +165,16 @@ impl MtlsConfig { pub fn ca_certs_reader(&self) -> io::Result> { crate::tls::config::to_reader(&self.ca_certs) } + + /// Load and decode CA certificates from `reader`. + pub(crate) fn load_ca_certs(&self) -> Result { + let mut roots = rustls::RootCertStore::empty(); + for cert in rustls_pemfile::certs(&mut self.ca_certs_reader()?) { + roots.add(cert?).map_err(Error::CertAuth)?; + } + + Ok(roots) + } } #[cfg(test)] diff --git a/core/lib/src/tls/config.rs b/core/lib/src/tls/config.rs index c94bb332..5762e8f1 100644 --- a/core/lib/src/tls/config.rs +++ b/core/lib/src/tls/config.rs @@ -1,9 +1,13 @@ use std::io; +use rustls::crypto::{ring, CryptoProvider}; +use rustls::pki_types::{CertificateDer, PrivateKeyDer}; use figment::value::magic::{Either, RelativePathBuf}; use serde::{Deserialize, Serialize}; use indexmap::IndexSet; +use crate::tls::error::{Result, Error, KeyError}; + /// TLS configuration: certificate chain, key, and ciphersuites. /// /// Four parameters control `tls` configuration: @@ -431,6 +435,72 @@ impl TlsConfig { } } +/// Loads certificates from `reader`. +impl TlsConfig { + pub(crate) fn load_certs(&self) -> Result>> { + rustls_pemfile::certs(&mut self.certs_reader()?) + .collect::>() + .map_err(Error::CertChain) + } + + /// Load and decode the private key from `reader`. + pub(crate) fn load_key(&self) -> Result> { + use rustls_pemfile::Item::*; + + let mut keys = rustls_pemfile::read_all(&mut self.key_reader()?) + .map(|result| result.map_err(KeyError::Io) + .and_then(|item| match item { + Pkcs1Key(key) => Ok(key.into()), + Pkcs8Key(key) => Ok(key.into()), + Sec1Key(key) => Ok(key.into()), + _ => Err(KeyError::BadItem(item)) + }) + ) + .collect::>, _>>()?; + + if keys.len() != 1 { + return Err(KeyError::BadKeyCount(keys.len()).into()); + } + + // Ensure we can use the key. + let key = keys.remove(0); + self.default_crypto_provider() + .key_provider + .load_private_key(key.clone_key()) + .map_err(KeyError::Unsupported)?; + + Ok(key) + } + + pub(crate) fn default_crypto_provider(&self) -> CryptoProvider { + CryptoProvider::get_default() + .map(|arc| (**arc).clone()) + .unwrap_or_else(|| rustls::crypto::CryptoProvider { + cipher_suites: self.ciphers().map(|cipher| match cipher { + CipherSuite::TLS_CHACHA20_POLY1305_SHA256 => + ring::cipher_suite::TLS13_CHACHA20_POLY1305_SHA256, + CipherSuite::TLS_AES_256_GCM_SHA384 => + ring::cipher_suite::TLS13_AES_256_GCM_SHA384, + CipherSuite::TLS_AES_128_GCM_SHA256 => + ring::cipher_suite::TLS13_AES_128_GCM_SHA256, + CipherSuite::TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256 => + ring::cipher_suite::TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256, + CipherSuite::TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256 => + ring::cipher_suite::TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256, + CipherSuite::TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384 => + ring::cipher_suite::TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, + CipherSuite::TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256 => + ring::cipher_suite::TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, + CipherSuite::TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384 => + ring::cipher_suite::TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, + CipherSuite::TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 => + ring::cipher_suite::TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, + }).collect(), + ..ring::default_provider() + }) + } +} + impl CipherSuite { /// The default set and order of cipher suites. These are all of the /// variants in [`CipherSuite`] in their declaration order. @@ -474,33 +544,6 @@ impl CipherSuite { } } -impl From for rustls::SupportedCipherSuite { - fn from(cipher: CipherSuite) -> Self { - use rustls::crypto::ring::cipher_suite; - - match cipher { - CipherSuite::TLS_CHACHA20_POLY1305_SHA256 => - cipher_suite::TLS13_CHACHA20_POLY1305_SHA256, - CipherSuite::TLS_AES_256_GCM_SHA384 => - cipher_suite::TLS13_AES_256_GCM_SHA384, - CipherSuite::TLS_AES_128_GCM_SHA256 => - cipher_suite::TLS13_AES_128_GCM_SHA256, - CipherSuite::TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256 => - cipher_suite::TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256, - CipherSuite::TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256 => - cipher_suite::TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256, - CipherSuite::TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384 => - cipher_suite::TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, - CipherSuite::TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256 => - cipher_suite::TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, - CipherSuite::TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384 => - cipher_suite::TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, - CipherSuite::TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 => - cipher_suite::TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, - } - } -} - pub(crate) fn to_reader( value: &Either> ) -> io::Result> { @@ -522,6 +565,7 @@ pub(crate) fn to_reader( #[cfg(test)] mod tests { + use super::*; use figment::{Figment, providers::{Toml, Format}}; #[test] @@ -650,4 +694,42 @@ mod tests { Ok(()) }); } + + macro_rules! tls_example_private_pem { + ($k:expr) => { + concat!(env!("CARGO_MANIFEST_DIR"), "/../../examples/tls/private/", $k) + } + } + + #[test] + fn verify_load_private_keys_of_different_types() -> Result<()> { + let key_paths = [ + tls_example_private_pem!("rsa_sha256_key.pem"), + tls_example_private_pem!("ecdsa_nistp256_sha256_key_pkcs8.pem"), + tls_example_private_pem!("ecdsa_nistp384_sha384_key_pkcs8.pem"), + tls_example_private_pem!("ed25519_key.pem"), + ]; + + for key in key_paths { + TlsConfig::from_paths("", key).load_key()?; + } + + Ok(()) + } + + #[test] + fn verify_load_certs_of_different_types() -> Result<()> { + let cert_paths = [ + tls_example_private_pem!("rsa_sha256_cert.pem"), + tls_example_private_pem!("ecdsa_nistp256_sha256_cert.pem"), + tls_example_private_pem!("ecdsa_nistp384_sha384_cert.pem"), + tls_example_private_pem!("ed25519_cert.pem"), + ]; + + for cert in cert_paths { + TlsConfig::from_paths(cert, "").load_certs()?; + } + + Ok(()) + } } diff --git a/core/lib/src/tls/mod.rs b/core/lib/src/tls/mod.rs index d6128e3b..7f5a05de 100644 --- a/core/lib/src/tls/mod.rs +++ b/core/lib/src/tls/mod.rs @@ -1,6 +1,5 @@ mod error; pub(crate) mod config; -pub(crate) mod util; pub use error::Result; pub use config::{TlsConfig, CipherSuite}; diff --git a/core/lib/src/tls/util.rs b/core/lib/src/tls/util.rs deleted file mode 100644 index 497c5201..00000000 --- a/core/lib/src/tls/util.rs +++ /dev/null @@ -1,104 +0,0 @@ -use std::io; - -use rustls::RootCertStore; -use rustls::crypto::CryptoProvider; -use rustls::pki_types::{CertificateDer, PrivateKeyDer}; - -use crate::tls::error::{Result, Error, KeyError}; - -/// Loads certificates from `reader`. -pub fn load_cert_chain(reader: &mut dyn io::BufRead) -> Result>> { - rustls_pemfile::certs(reader) - .collect::>() - .map_err(Error::CertChain) -} - -/// Load and decode the private key from `reader`. -pub fn load_key(reader: &mut dyn io::BufRead) -> Result> { - use rustls_pemfile::Item::*; - - let mut keys: Vec> = rustls_pemfile::read_all(reader) - .map(|result| result.map_err(KeyError::Io) - .and_then(|item| match item { - Pkcs1Key(key) => Ok(key.into()), - Pkcs8Key(key) => Ok(key.into()), - Sec1Key(key) => Ok(key.into()), - _ => Err(KeyError::BadItem(item)) - }) - ) - .collect::>()?; - - if keys.len() != 1 { - return Err(KeyError::BadKeyCount(keys.len()).into()); - } - - // Ensure we can use the key. - let key = keys.remove(0); - get_crypto_provider().key_provider.load_private_key(key.clone_key()) - .map_err(KeyError::Unsupported)?; - Ok(key) -} - -/// Load and decode CA certificates from `reader`. -pub fn load_ca_certs(reader: &mut dyn io::BufRead) -> Result { - let mut roots = rustls::RootCertStore::empty(); - for cert in load_cert_chain(reader)? { - roots.add(cert).map_err(Error::CertAuth)?; - } - - Ok(roots) -} - -pub(crate) fn get_crypto_provider() -> CryptoProvider { - if let Some(crypto_provider) = rustls::crypto::CryptoProvider::get_default() { - CryptoProvider::clone(crypto_provider) - } else { - let crypto_provider = rustls::crypto::ring::default_provider(); - // Should only fail due to other concurrent install, so we ignore it - let _ = crypto_provider.clone().install_default(); - - crypto_provider - } - -} - -#[cfg(test)] -mod test { - use super::*; - - macro_rules! tls_example_key { - ($k:expr) => { - include_bytes!(concat!(env!("CARGO_MANIFEST_DIR"), "/../../examples/tls/private/", $k)) - } - } - - #[test] - fn verify_load_private_keys_of_different_types() -> Result<()> { - let rsa_sha256_key = tls_example_key!("rsa_sha256_key.pem"); - let ecdsa_nistp256_sha256_key = tls_example_key!("ecdsa_nistp256_sha256_key_pkcs8.pem"); - let ecdsa_nistp384_sha384_key = tls_example_key!("ecdsa_nistp384_sha384_key_pkcs8.pem"); - let ed2551_key = tls_example_key!("ed25519_key.pem"); - - load_key(&mut &rsa_sha256_key[..])?; - load_key(&mut &ecdsa_nistp256_sha256_key[..])?; - load_key(&mut &ecdsa_nistp384_sha384_key[..])?; - load_key(&mut &ed2551_key[..])?; - - Ok(()) - } - - #[test] - fn verify_load_certs_of_different_types() -> Result<()> { - let rsa_sha256_cert = tls_example_key!("rsa_sha256_cert.pem"); - let ecdsa_nistp256_sha256_cert = tls_example_key!("ecdsa_nistp256_sha256_cert.pem"); - let ecdsa_nistp384_sha384_cert = tls_example_key!("ecdsa_nistp384_sha384_cert.pem"); - let ed2551_cert = tls_example_key!("ed25519_cert.pem"); - - load_cert_chain(&mut &rsa_sha256_cert[..])?; - load_cert_chain(&mut &ecdsa_nistp256_sha256_cert[..])?; - load_cert_chain(&mut &ecdsa_nistp384_sha384_cert[..])?; - load_cert_chain(&mut &ed2551_cert[..])?; - - Ok(()) - } -} diff --git a/examples/tls/Cargo.toml b/examples/tls/Cargo.toml index a505647f..aa6e502f 100644 --- a/examples/tls/Cargo.toml +++ b/examples/tls/Cargo.toml @@ -7,5 +7,7 @@ publish = false [dependencies] rocket = { path = "../../core/lib", features = ["tls", "mtls", "secrets", "http3-preview"] } -rustls = { version = "0.23", features = ["aws_lc_rs"] } yansi = "1.0.1" + +[target.'cfg(unix)'.dependencies] +rustls = { version = "0.23", features = ["aws_lc_rs"] } diff --git a/examples/tls/private/gen_certs.sh b/examples/tls/private/gen_certs.sh index 2418c27b..d98d152e 100755 --- a/examples/tls/private/gen_certs.sh +++ b/examples/tls/private/gen_certs.sh @@ -8,6 +8,7 @@ # rsa_sha256 # ecdsa_nistp256_sha256 # ecdsa_nistp384_sha384 +# ecdsa_nistp521_sha512 # # Generate a certificate of the [cert-kind] key type, or if no cert-kind is # specified, all of the certificates. diff --git a/examples/tls/src/tests.rs b/examples/tls/src/tests.rs index cec2e9f7..413e64fd 100644 --- a/examples/tls/src/tests.rs +++ b/examples/tls/src/tests.rs @@ -65,13 +65,31 @@ fn insecure_cookies() { assert_eq!(c4.secure(), None); } -#[test] -fn hello_world() { +fn validate_profiles(profiles: &[&str]) { use rocket::listener::DefaultListener; use rocket::config::{Config, SecretKey}; - use rustls::crypto::aws_lc_rs; - let mut profiles = vec![ + for profile in profiles { + let config = Config { + secret_key: SecretKey::generate().unwrap(), + ..Config::debug_default() + }; + + let figment = Config::figment().merge(config).select(profile); + let client = Client::tracked_secure(super::rocket().configure(figment)).unwrap(); + let response = client.get("/").dispatch(); + assert_eq!(response.into_string().unwrap(), "Hello, world!"); + + let figment = client.rocket().figment(); + let listener: DefaultListener = figment.extract().unwrap(); + assert_eq!(figment.profile(), profile); + listener.tls.as_ref().unwrap().validate().expect("valid TLS config"); + } +} + +#[test] +fn validate_tls_profiles() { + const DEFAULT_PROFILES: &[&str] = &[ "rsa_sha256", "ecdsa_nistp256_sha256_pkcs8", "ecdsa_nistp384_sha384_pkcs8", @@ -80,29 +98,11 @@ fn hello_world() { "ed25519", ]; - for use_aws_lc in [false, true] { - if use_aws_lc { - let crypto_provider = aws_lc_rs::default_provider(); - crypto_provider.install_default().unwrap(); + validate_profiles(DEFAULT_PROFILES); - profiles.push("ecdsa_nistp521_sha512_pkcs8"); - } - - for profile in &profiles { - let config = Config { - secret_key: SecretKey::generate().unwrap(), - ..Config::debug_default() - }; - - let figment = Config::figment().merge(config).select(profile); - let client = Client::tracked_secure(super::rocket().configure(figment)).unwrap(); - let response = client.get("/").dispatch(); - assert_eq!(response.into_string().unwrap(), "Hello, world!"); - - let figment = client.rocket().figment(); - let listener: DefaultListener = figment.extract().unwrap(); - assert_eq!(figment.profile(), profile); - listener.tls.as_ref().unwrap().validate().expect("valid TLS config"); - } + #[cfg(unix)] { + rustls::crypto::aws_lc_rs::default_provider().install_default().unwrap(); + validate_profiles(DEFAULT_PROFILES); + validate_profiles(&["ecdsa_nistp521_sha512_pkcs8"]); } }