From aa6ad7030ad3b374295ae4a11465a4c14e0dbf0f Mon Sep 17 00:00:00 2001 From: Sergio Benitez Date: Fri, 31 Mar 2023 11:13:40 -0700 Subject: [PATCH] Allow setting mTLS certificates on local 'Client'. This allows testing with client certificates. Co-authored-by: Brett Buford --- core/http/src/listener.rs | 6 +++++ core/http/src/tls/mod.rs | 2 +- core/http/src/tls/mtls.rs | 2 +- core/lib/src/local/request.rs | 41 ++++++++++++++++++++++++++++++++++- examples/tls/src/tests.rs | 23 +++++++++++++++++++- 5 files changed, 70 insertions(+), 4 deletions(-) diff --git a/core/http/src/listener.rs b/core/http/src/listener.rs index 7258721e..e958702e 100644 --- a/core/http/src/listener.rs +++ b/core/http/src/listener.rs @@ -31,6 +31,12 @@ pub struct CertificateData(pub Vec); #[derive(Clone, Default)] pub struct Certificates(Arc>>); +impl From> for Certificates { + fn from(value: Vec) -> Self { + Certificates(Arc::new(value.into())) + } +} + impl Certificates { /// Set the the raw certificate chain data. Only the first call actually /// sets the data; the remaining do nothing. diff --git a/core/http/src/tls/mod.rs b/core/http/src/tls/mod.rs index b529ee40..04959ba2 100644 --- a/core/http/src/tls/mod.rs +++ b/core/http/src/tls/mod.rs @@ -1,8 +1,8 @@ mod listener; -mod util; #[cfg(feature = "mtls")] pub mod mtls; pub use rustls; pub use listener::{TlsListener, Config}; +pub mod util; diff --git a/core/http/src/tls/mtls.rs b/core/http/src/tls/mtls.rs index 2269a397..65246d4d 100644 --- a/core/http/src/tls/mtls.rs +++ b/core/http/src/tls/mtls.rs @@ -62,7 +62,7 @@ pub type Result = std::result::Result; /// * The certificates are active and not yet expired. /// * The client's certificate chain was signed by the CA identified by the /// configured `ca_certs` and with respect to SNI, if any. See [module level -/// docs](self) for configuration details. +/// docs](crate::mtls) for configuration details. /// /// If the client does not present certificates, the guard _forwards_. /// diff --git a/core/lib/src/local/request.rs b/core/lib/src/local/request.rs index 76ce4af3..4ba38634 100644 --- a/core/lib/src/local/request.rs +++ b/core/lib/src/local/request.rs @@ -191,8 +191,47 @@ macro_rules! pub_request_impl { self } - /// Sets the body data of the request. + /// Set mTLS client certificates to send along with the request. /// + /// If the request already contained certificates, they are replaced with + /// thsoe in `reader.` + /// + /// `reader` is expected to be PEM-formatted and contain X509 certificates. + /// If it contains more than one certificate, the entire chain is set on the + /// request. If it contains items other than certificates, the certificate + /// chain up to the first non-certificate item is set on the request. If + /// `reader` is syntactically invalid PEM, certificates are cleared on the + /// request. + /// + /// The type `C` can be anything that implements [`std::io::Read`]. This + /// includes: `&[u8]`, `File`, `&File`, `Stdin`, and so on. To read a file + /// in at compile-time, use [`include_bytes!()`]. + /// + /// ```rust + /// use std::fs::File; + /// + #[doc = $import] + /// use rocket::fs::relative; + /// + /// # Client::_test(|_, request, _| { + /// let request: LocalRequest = request; + /// let path = relative!("../../examples/tls/private/ed25519_cert.pem"); + /// let req = request.identity(File::open(path).unwrap()); + /// # }); + /// ``` + #[cfg(feature = "mtls")] + #[cfg_attr(nightly, doc(cfg(feature = "mtls")))] + pub fn identity(mut self, reader: C) -> Self { + use crate::http::{tls::util::load_certs, private::Certificates}; + + let mut reader = std::io::BufReader::new(reader); + let certs = load_certs(&mut reader).map(Certificates::from); + self._request_mut().connection.client_certificates = certs.ok(); + self + } + + /// Sets the body data of the request. + ///core/lib/src/local/request.rs /// # Examples /// /// ```rust diff --git a/examples/tls/src/tests.rs b/examples/tls/src/tests.rs index ae0f12d7..171b5b5d 100644 --- a/examples/tls/src/tests.rs +++ b/examples/tls/src/tests.rs @@ -1,4 +1,25 @@ +use std::fs::{self, File}; + use rocket::local::blocking::Client; +use rocket::fs::relative; + +#[test] +fn hello_mutual() { + let client = Client::tracked(super::rocket()).unwrap(); + let cert_paths = fs::read_dir(relative!("private")).unwrap() + .map(|entry| entry.unwrap().path().to_string_lossy().into_owned()) + .filter(|path| path.ends_with("_cert.pem") && !path.ends_with("ca_cert.pem")); + + for path in cert_paths { + let response = client.get("/") + .identity(File::open(&path).unwrap()) + .dispatch(); + + let response = response.into_string().unwrap(); + let subject = response.split(']').nth(1).unwrap().trim(); + assert_eq!(subject, "C=US, ST=CA, O=Rocket, CN=localhost"); + } +} #[test] fn hello_world() { @@ -16,6 +37,6 @@ fn hello_world() { let config = rocket::Config::figment().select(profile); let client = Client::tracked(super::rocket().configure(config)).unwrap(); let response = client.get("/").dispatch(); - assert_eq!(response.into_string(), Some("Hello, world!".into())); + assert_eq!(response.into_string().unwrap(), "Hello, world!"); } }