Allow setting mTLS certificates on local 'Client'.

This allows testing with client certificates.

Co-authored-by: Brett Buford <blbuford@gmail.com>
This commit is contained in:
Sergio Benitez 2023-03-31 11:13:40 -07:00
parent 847e87d5c9
commit aa6ad7030a
5 changed files with 70 additions and 4 deletions

View File

@ -31,6 +31,12 @@ pub struct CertificateData(pub Vec<u8>);
#[derive(Clone, Default)]
pub struct Certificates(Arc<Storage<Vec<CertificateData>>>);
impl From<Vec<CertificateData>> for Certificates {
fn from(value: Vec<CertificateData>) -> Self {
Certificates(Arc::new(value.into()))
}
}
impl Certificates {
/// Set the the raw certificate chain data. Only the first call actually
/// sets the data; the remaining do nothing.

View File

@ -1,8 +1,8 @@
mod listener;
mod util;
#[cfg(feature = "mtls")]
pub mod mtls;
pub use rustls;
pub use listener::{TlsListener, Config};
pub mod util;

View File

@ -62,7 +62,7 @@ pub type Result<T, E = Error> = std::result::Result<T, E>;
/// * The certificates are active and not yet expired.
/// * The client's certificate chain was signed by the CA identified by the
/// configured `ca_certs` and with respect to SNI, if any. See [module level
/// docs](self) for configuration details.
/// docs](crate::mtls) for configuration details.
///
/// If the client does not present certificates, the guard _forwards_.
///

View File

@ -191,8 +191,47 @@ macro_rules! pub_request_impl {
self
}
/// Sets the body data of the request.
/// Set mTLS client certificates to send along with the request.
///
/// If the request already contained certificates, they are replaced with
/// thsoe in `reader.`
///
/// `reader` is expected to be PEM-formatted and contain X509 certificates.
/// If it contains more than one certificate, the entire chain is set on the
/// request. If it contains items other than certificates, the certificate
/// chain up to the first non-certificate item is set on the request. If
/// `reader` is syntactically invalid PEM, certificates are cleared on the
/// request.
///
/// The type `C` can be anything that implements [`std::io::Read`]. This
/// includes: `&[u8]`, `File`, `&File`, `Stdin`, and so on. To read a file
/// in at compile-time, use [`include_bytes!()`].
///
/// ```rust
/// use std::fs::File;
///
#[doc = $import]
/// use rocket::fs::relative;
///
/// # Client::_test(|_, request, _| {
/// let request: LocalRequest = request;
/// let path = relative!("../../examples/tls/private/ed25519_cert.pem");
/// let req = request.identity(File::open(path).unwrap());
/// # });
/// ```
#[cfg(feature = "mtls")]
#[cfg_attr(nightly, doc(cfg(feature = "mtls")))]
pub fn identity<C: std::io::Read>(mut self, reader: C) -> Self {
use crate::http::{tls::util::load_certs, private::Certificates};
let mut reader = std::io::BufReader::new(reader);
let certs = load_certs(&mut reader).map(Certificates::from);
self._request_mut().connection.client_certificates = certs.ok();
self
}
/// Sets the body data of the request.
///core/lib/src/local/request.rs
/// # Examples
///
/// ```rust

View File

@ -1,4 +1,25 @@
use std::fs::{self, File};
use rocket::local::blocking::Client;
use rocket::fs::relative;
#[test]
fn hello_mutual() {
let client = Client::tracked(super::rocket()).unwrap();
let cert_paths = fs::read_dir(relative!("private")).unwrap()
.map(|entry| entry.unwrap().path().to_string_lossy().into_owned())
.filter(|path| path.ends_with("_cert.pem") && !path.ends_with("ca_cert.pem"));
for path in cert_paths {
let response = client.get("/")
.identity(File::open(&path).unwrap())
.dispatch();
let response = response.into_string().unwrap();
let subject = response.split(']').nth(1).unwrap().trim();
assert_eq!(subject, "C=US, ST=CA, O=Rocket, CN=localhost");
}
}
#[test]
fn hello_world() {
@ -16,6 +37,6 @@ fn hello_world() {
let config = rocket::Config::figment().select(profile);
let client = Client::tracked(super::rocket().configure(config)).unwrap();
let response = client.get("/").dispatch();
assert_eq!(response.into_string(), Some("Hello, world!".into()));
assert_eq!(response.into_string().unwrap(), "Hello, world!");
}
}