diff --git a/core/http/Cargo.toml b/core/http/Cargo.toml index 9aa6a726..a3537d7b 100644 --- a/core/http/Cargo.toml +++ b/core/http/Cargo.toml @@ -30,9 +30,9 @@ percent-encoding = "2" http = "0.2" time = { version = "0.3", features = ["formatting", "macros"] } indexmap = "2" -rustls = { version = "0.21", optional = true } -tokio-rustls = { version = "0.24", optional = true } -rustls-pemfile = { version = "1.0.2", optional = true } +rustls = { version = "0.22", optional = true } +tokio-rustls = { version = "0.25", optional = true } +rustls-pemfile = { version = "2.0.0", optional = true } tokio = { version = "1.6.1", features = ["net", "sync", "time"] } log = "0.4" ref-cast = "1.0" diff --git a/core/http/src/listener.rs b/core/http/src/listener.rs index f898a108..956c8ec4 100644 --- a/core/http/src/listener.rs +++ b/core/http/src/listener.rs @@ -17,36 +17,45 @@ use state::InitCell; pub use tokio::net::TcpListener; /// A thin wrapper over raw, DER-encoded X.509 client certificate data. -// NOTE: `rustls::Certificate` is exactly isomorphic to `CertificateData`. -#[doc(inline)] -#[cfg(feature = "tls")] -pub use rustls::Certificate as CertificateData; +#[cfg(not(feature = "tls"))] +#[derive(Debug, Clone, Eq, PartialEq)] +pub struct CertificateDer(pub(crate) Vec); /// A thin wrapper over raw, DER-encoded X.509 client certificate data. -#[cfg(not(feature = "tls"))] -#[derive(Clone, Eq, Hash, Ord, PartialEq, PartialOrd)] -pub struct CertificateData(pub Vec); +#[cfg(feature = "tls")] +#[derive(Debug, Clone, Eq, PartialEq)] +#[repr(transparent)] +pub struct CertificateDer(pub(crate) rustls::pki_types::CertificateDer<'static>); /// A collection of raw certificate data. #[derive(Clone, Default)] -pub struct Certificates(Arc>>); +pub struct Certificates(Arc>>); -impl From> for Certificates { - fn from(value: Vec) -> Self { +impl From> for Certificates { + fn from(value: Vec) -> Self { Certificates(Arc::new(value.into())) } } +#[cfg(feature = "tls")] +impl From>> for Certificates { + fn from(value: Vec>) -> Self { + let value: Vec<_> = value.into_iter().map(CertificateDer).collect(); + Certificates(Arc::new(value.into())) + } +} + +#[doc(hidden)] impl Certificates { /// Set the the raw certificate chain data. Only the first call actually /// sets the data; the remaining do nothing. #[cfg(feature = "tls")] - pub(crate) fn set(&self, data: Vec) { + pub(crate) fn set(&self, data: Vec) { self.0.set(data); } /// Returns the raw certificate chain data, if any is available. - pub fn chain_data(&self) -> Option<&[CertificateData]> { + pub fn chain_data(&self) -> Option<&[CertificateDer]> { self.0.try_get().map(|v| v.as_slice()) } } diff --git a/core/http/src/tls/error.rs b/core/http/src/tls/error.rs new file mode 100644 index 00000000..429f4a9d --- /dev/null +++ b/core/http/src/tls/error.rs @@ -0,0 +1,95 @@ +pub type Result = std::result::Result; + +#[derive(Debug)] +pub enum KeyError { + BadKeyCount(usize), + Io(std::io::Error), + Unsupported(rustls::Error), + BadItem(rustls_pemfile::Item), +} + +#[derive(Debug)] +pub enum Error { + Io(std::io::Error), + Tls(rustls::Error), + Mtls(rustls::server::VerifierBuilderError), + CertChain(std::io::Error), + PrivKey(KeyError), + CertAuth(rustls::Error), +} + +impl std::fmt::Display for Error { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + use Error::*; + + match self { + Io(e) => write!(f, "i/o error during tls binding: {e}"), + Tls(e) => write!(f, "tls configuration error: {e}"), + Mtls(e) => write!(f, "mtls verifier error: {e}"), + CertChain(e) => write!(f, "failed to process certificate chain: {e}"), + PrivKey(e) => write!(f, "failed to process private key: {e}"), + CertAuth(e) => write!(f, "failed to process certificate authority: {e}"), + } + } +} + +impl std::fmt::Display for KeyError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + use KeyError::*; + + match self { + Io(e) => write!(f, "error reading key file: {e}"), + BadKeyCount(0) => write!(f, "no valid keys found. is the file malformed?"), + BadKeyCount(n) => write!(f, "expected exactly 1 key, found {n}"), + Unsupported(e) => write!(f, "key is valid but is unsupported: {e}"), + BadItem(i) => write!(f, "found unexpected item in key file: {i:#?}"), + } + } +} + +impl std::error::Error for KeyError { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + match self { + KeyError::Io(e) => Some(e), + KeyError::Unsupported(e) => Some(e), + _ => None, + } + } +} + +impl std::error::Error for Error { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + match self { + Error::Io(e) => Some(e), + Error::Tls(e) => Some(e), + Error::Mtls(e) => Some(e), + Error::CertChain(e) => Some(e), + Error::PrivKey(e) => Some(e), + Error::CertAuth(e) => Some(e), + } + } +} + +impl From for Error { + fn from(e: std::io::Error) -> Self { + Error::Io(e) + } +} + +impl From for Error { + fn from(e: rustls::Error) -> Self { + Error::Tls(e) + } +} + +impl From for Error { + fn from(value: rustls::server::VerifierBuilderError) -> Self { + Error::Mtls(value) + } +} + +impl From for Error { + fn from(value: KeyError) -> Self { + Error::PrivKey(value) + } +} diff --git a/core/http/src/tls/listener.rs b/core/http/src/tls/listener.rs index 8c675d61..7ef76ebd 100644 --- a/core/http/src/tls/listener.rs +++ b/core/http/src/tls/listener.rs @@ -8,9 +8,10 @@ use std::net::SocketAddr; use tokio::net::{TcpListener, TcpStream}; use tokio::io::{AsyncRead, AsyncWrite}; use tokio_rustls::{Accept, TlsAcceptor, server::TlsStream as BareTlsStream}; +use rustls::server::{ServerSessionMemoryCache, ServerConfig, WebPkiClientVerifier}; -use crate::tls::util::{load_certs, load_private_key, load_ca_certs}; -use crate::listener::{Connection, Listener, Certificates}; +use crate::tls::util::{load_cert_chain, load_key, load_ca_certs}; +use crate::listener::{Connection, Listener, Certificates, CertificateDer}; /// A TLS listener over TCP. pub struct TlsListener { @@ -40,7 +41,7 @@ pub struct TlsListener { /// /// To work around this, we "lie" when `peer_certificates()` are requested and /// always return `Some(Certificates)`. Internally, `Certificates` is an -/// `Arc>>`, effectively a shared, thread-safe, +/// `Arc>>`, effectively a shared, thread-safe, /// `OnceCell`. The cell is initially empty and is filled as soon as the /// handshake is complete. If the certificate data were to be requested prior to /// this point, it would be empty. However, in Rocket, we only request @@ -72,49 +73,43 @@ pub struct Config { } impl TlsListener { - pub async fn bind(addr: SocketAddr, mut c: Config) -> io::Result + pub async fn bind(addr: SocketAddr, mut c: Config) -> crate::tls::Result where R: io::BufRead { - use rustls::server::{AllowAnyAuthenticatedClient, AllowAnyAnonymousOrAuthenticatedClient}; - use rustls::server::{NoClientAuth, ServerSessionMemoryCache, ServerConfig}; - - let cert_chain = load_certs(&mut c.cert_chain) - .map_err(|e| io::Error::new(e.kind(), format!("bad TLS cert chain: {}", e)))?; - - let key = load_private_key(&mut c.private_key) - .map_err(|e| io::Error::new(e.kind(), format!("bad TLS private key: {}", e)))?; - - let client_auth = match c.ca_certs { - Some(ref mut ca_certs) => match load_ca_certs(ca_certs) { - Ok(ca) if c.mandatory_mtls => AllowAnyAuthenticatedClient::new(ca).boxed(), - Ok(ca) => AllowAnyAnonymousOrAuthenticatedClient::new(ca).boxed(), - Err(e) => return Err(io::Error::new(e.kind(), format!("bad CA cert(s): {}", e))), - }, - None => NoClientAuth::boxed(), + let provider = rustls::crypto::CryptoProvider { + cipher_suites: c.ciphersuites, + ..rustls::crypto::ring::default_provider() }; - let mut tls_config = ServerConfig::builder() - .with_cipher_suites(&c.ciphersuites) - .with_safe_default_kx_groups() - .with_safe_default_protocol_versions() - .map_err(|e| io::Error::new(io::ErrorKind::Other, format!("bad TLS config: {}", e)))? - .with_client_cert_verifier(client_auth) - .with_single_cert(cert_chain, key) - .map_err(|e| io::Error::new(io::ErrorKind::Other, format!("bad TLS config: {}", e)))?; + let verifier = match c.ca_certs { + Some(ref mut ca_certs) => { + let ca_roots = Arc::new(load_ca_certs(ca_certs)?); + let verifier = WebPkiClientVerifier::builder(ca_roots); + match c.mandatory_mtls { + true => verifier.build()?, + false => verifier.allow_unauthenticated().build()?, + } + }, + None => WebPkiClientVerifier::no_client_auth(), + }; - tls_config.ignore_client_order = c.prefer_server_order; + let key = load_key(&mut c.private_key)?; + let cert_chain = load_cert_chain(&mut c.cert_chain)?; + let mut config = ServerConfig::builder_with_provider(Arc::new(provider)) + .with_safe_default_protocol_versions()? + .with_client_cert_verifier(verifier) + .with_single_cert(cert_chain, key)?; - tls_config.alpn_protocols = vec![b"http/1.1".to_vec()]; + config.ignore_client_order = c.prefer_server_order; + config.session_storage = ServerSessionMemoryCache::new(1024); + config.ticketer = rustls::crypto::ring::Ticketer::new()?; + config.alpn_protocols = vec![b"http/1.1".to_vec()]; if cfg!(feature = "http2") { - tls_config.alpn_protocols.insert(0, b"h2".to_vec()); + config.alpn_protocols.insert(0, b"h2".to_vec()); } - tls_config.session_storage = ServerSessionMemoryCache::new(1024); - tls_config.ticketer = rustls::Ticketer::new() - .map_err(|e| io::Error::new(io::ErrorKind::Other, format!("bad TLS ticketer: {}", e)))?; - let listener = TcpListener::bind(addr).await?; - let acceptor = TlsAcceptor::from(Arc::new(tls_config)); + let acceptor = TlsAcceptor::from(Arc::new(config)); Ok(TlsListener { listener, acceptor }) } } @@ -179,8 +174,10 @@ impl TlsStream { TlsState::Handshaking(ref mut accept) => { match futures::ready!(Pin::new(accept).poll(cx)) { Ok(stream) => { - if let Some(cert_chain) = stream.get_ref().1.peer_certificates() { - self.certs.set(cert_chain.to_vec()); + if let Some(peer_certs) = stream.get_ref().1.peer_certificates() { + self.certs.set(peer_certs.into_iter() + .map(|v| CertificateDer(v.clone().into_owned())) + .collect()); } self.state = TlsState::Streaming(stream); diff --git a/core/http/src/tls/mod.rs b/core/http/src/tls/mod.rs index 04959ba2..8d3bcb3d 100644 --- a/core/http/src/tls/mod.rs +++ b/core/http/src/tls/mod.rs @@ -6,3 +6,6 @@ pub mod mtls; pub use rustls; pub use listener::{TlsListener, Config}; pub mod util; +pub mod error; + +pub use error::Result; diff --git a/core/http/src/tls/mtls.rs b/core/http/src/tls/mtls.rs index 7a7cd169..417db2f8 100644 --- a/core/http/src/tls/mtls.rs +++ b/core/http/src/tls/mtls.rs @@ -41,7 +41,7 @@ use x509_parser::nom; use x509::{ParsedExtension, X509Name, X509Certificate, TbsCertificate, X509Error, FromDer}; use oid::OID_X509_EXT_SUBJECT_ALT_NAME as SUBJECT_ALT_NAME; -use crate::listener::CertificateData; +use crate::listener::CertificateDer; /// A type alias for [`Result`](std::result::Result) with the error type set to /// [`Error`]. @@ -144,7 +144,7 @@ pub type Result = std::result::Result; #[derive(Debug, PartialEq)] pub struct Certificate<'a> { x509: X509Certificate<'a>, - data: &'a CertificateData, + data: &'a CertificateDer, } /// An X.509 Distinguished Name (DN) found in a [`Certificate`]. @@ -224,7 +224,7 @@ impl<'a> Certificate<'a> { /// PRIVATE: For internal Rocket use only! #[doc(hidden)] - pub fn parse(chain: &[CertificateData]) -> Result> { + pub fn parse(chain: &[CertificateDer]) -> Result> { let data = chain.first().ok_or_else(|| Error::Empty)?; let x509 = Certificate::parse_one(&data.0)?; Ok(Certificate { x509, data }) diff --git a/core/http/src/tls/util.rs b/core/http/src/tls/util.rs index 8eb54c5d..c07135ad 100644 --- a/core/http/src/tls/util.rs +++ b/core/http/src/tls/util.rs @@ -1,55 +1,47 @@ -use std::io::{self, Cursor, Read}; +use std::io; -use rustls::{Certificate, PrivateKey, RootCertStore}; +use rustls::RootCertStore; +use rustls::pki_types::{CertificateDer, PrivateKeyDer}; -fn err(message: impl Into>) -> io::Error { - io::Error::new(io::ErrorKind::Other, message.into()) -} +use crate::tls::error::{Result, Error, KeyError}; /// Loads certificates from `reader`. -pub fn load_certs(reader: &mut dyn io::BufRead) -> io::Result> { - let certs = rustls_pemfile::certs(reader).map_err(|_| err("invalid certificate"))?; - Ok(certs.into_iter().map(Certificate).collect()) +pub fn load_cert_chain(reader: &mut dyn io::BufRead) -> Result>> { + rustls_pemfile::certs(reader) + .collect::>() + .map_err(Error::CertChain) } /// Load and decode the private key from `reader`. -pub fn load_private_key(reader: &mut dyn io::BufRead) -> io::Result { - // "rsa" (PKCS1) PEM files have a different first-line header than PKCS8 - // PEM files, use that to determine the parse function to use. - let mut header = String::new(); - let private_keys_fn = loop { - header.clear(); - if reader.read_line(&mut header)? == 0 { - return Err(err("failed to find key header; supported formats are: RSA, PKCS8, SEC1")); - } +pub fn load_key(reader: &mut dyn io::BufRead) -> Result> { + use rustls_pemfile::Item::*; - break match header.trim_end() { - "-----BEGIN RSA PRIVATE KEY-----" => rustls_pemfile::rsa_private_keys, - "-----BEGIN PRIVATE KEY-----" => rustls_pemfile::pkcs8_private_keys, - "-----BEGIN EC PRIVATE KEY-----" => rustls_pemfile::ec_private_keys, - _ => continue, - }; - }; + let mut keys: Vec> = rustls_pemfile::read_all(reader) + .map(|result| result.map_err(KeyError::Io) + .and_then(|item| match item { + Pkcs1Key(key) => Ok(key.into()), + Pkcs8Key(key) => Ok(key.into()), + Sec1Key(key) => Ok(key.into()), + _ => Err(KeyError::BadItem(item)) + }) + ) + .collect::>()?; - let key = private_keys_fn(&mut Cursor::new(header).chain(reader)) - .map_err(|_| err("invalid key file")) - .and_then(|mut keys| match keys.len() { - 0 => Err(err("no valid keys found; is the file malformed?")), - 1 => Ok(PrivateKey(keys.remove(0))), - n => Err(err(format!("expected 1 key, found {}", n))), - })?; + if keys.len() != 1 { + return Err(KeyError::BadKeyCount(keys.len()).into()); + } // Ensure we can use the key. - rustls::sign::any_supported_type(&key) - .map_err(|_| err("key parsed but is unusable")) - .map(|_| key) + let key = keys.remove(0); + rustls::crypto::ring::sign::any_supported_type(&key).map_err(KeyError::Unsupported)?; + Ok(key) } /// Load and decode CA certificates from `reader`. -pub fn load_ca_certs(reader: &mut dyn io::BufRead) -> io::Result { +pub fn load_ca_certs(reader: &mut dyn io::BufRead) -> Result { let mut roots = rustls::RootCertStore::empty(); - for cert in load_certs(reader)? { - roots.add(&cert).map_err(|e| err(format!("CA cert error: {}", e)))?; + for cert in load_cert_chain(reader)? { + roots.add(cert).map_err(Error::CertAuth)?; } Ok(roots) @@ -66,31 +58,31 @@ mod test { } #[test] - fn verify_load_private_keys_of_different_types() -> io::Result<()> { + fn verify_load_private_keys_of_different_types() -> Result<()> { let rsa_sha256_key = tls_example_key!("rsa_sha256_key.pem"); let ecdsa_nistp256_sha256_key = tls_example_key!("ecdsa_nistp256_sha256_key_pkcs8.pem"); let ecdsa_nistp384_sha384_key = tls_example_key!("ecdsa_nistp384_sha384_key_pkcs8.pem"); let ed2551_key = tls_example_key!("ed25519_key.pem"); - load_private_key(&mut Cursor::new(rsa_sha256_key))?; - load_private_key(&mut Cursor::new(ecdsa_nistp256_sha256_key))?; - load_private_key(&mut Cursor::new(ecdsa_nistp384_sha384_key))?; - load_private_key(&mut Cursor::new(ed2551_key))?; + load_key(&mut &rsa_sha256_key[..])?; + load_key(&mut &ecdsa_nistp256_sha256_key[..])?; + load_key(&mut &ecdsa_nistp384_sha384_key[..])?; + load_key(&mut &ed2551_key[..])?; Ok(()) } #[test] - fn verify_load_certs_of_different_types() -> io::Result<()> { + fn verify_load_certs_of_different_types() -> Result<()> { let rsa_sha256_cert = tls_example_key!("rsa_sha256_cert.pem"); let ecdsa_nistp256_sha256_cert = tls_example_key!("ecdsa_nistp256_sha256_cert.pem"); let ecdsa_nistp384_sha384_cert = tls_example_key!("ecdsa_nistp384_sha384_cert.pem"); let ed2551_cert = tls_example_key!("ed25519_cert.pem"); - load_certs(&mut Cursor::new(rsa_sha256_cert))?; - load_certs(&mut Cursor::new(ecdsa_nistp256_sha256_cert))?; - load_certs(&mut Cursor::new(ecdsa_nistp384_sha384_cert))?; - load_certs(&mut Cursor::new(ed2551_cert))?; + load_cert_chain(&mut &rsa_sha256_cert[..])?; + load_cert_chain(&mut &ecdsa_nistp256_sha256_cert[..])?; + load_cert_chain(&mut &ecdsa_nistp384_sha384_cert[..])?; + load_cert_chain(&mut &ed2551_cert[..])?; Ok(()) } diff --git a/core/lib/src/config/tls.rs b/core/lib/src/config/tls.rs index 41e88082..12face00 100644 --- a/core/lib/src/config/tls.rs +++ b/core/lib/src/config/tls.rs @@ -631,7 +631,7 @@ mod with_tls_feature { use crate::http::tls::Config; use crate::http::tls::rustls::SupportedCipherSuite as RustlsCipher; - use crate::http::tls::rustls::cipher_suite; + use crate::http::tls::rustls::crypto::ring::cipher_suite; use yansi::Paint; diff --git a/core/lib/src/error.rs b/core/lib/src/error.rs index 28e71fc2..ff3ef79f 100644 --- a/core/lib/src/error.rs +++ b/core/lib/src/error.rs @@ -76,6 +76,9 @@ pub struct Error { pub enum ErrorKind { /// Binding to the provided address/port failed. Bind(io::Error), + /// Binding via TLS to the provided address/port failed. + #[cfg(feature = "tls")] + TlsBind(crate::http::tls::error::Error), /// An I/O error occurred during launch. Io(io::Error), /// A valid [`Config`](crate::Config) could not be extracted from the @@ -234,6 +237,12 @@ impl Error { "aborting due to failed shutdown" } + #[cfg(feature = "tls")] + ErrorKind::TlsBind(e) => { + error!("Rocket failed to bind via TLS to network socket."); + info_!("{}", e); + "aborting due to TLS bind error" + } } } } @@ -244,15 +253,17 @@ impl fmt::Display for ErrorKind { #[inline] fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { - ErrorKind::Bind(e) => write!(f, "binding failed: {}", e), - ErrorKind::Io(e) => write!(f, "I/O error: {}", e), + ErrorKind::Bind(e) => write!(f, "binding failed: {e}"), + ErrorKind::Io(e) => write!(f, "I/O error: {e}"), ErrorKind::Collisions(_) => "collisions detected".fmt(f), ErrorKind::FailedFairings(_) => "launch fairing(s) failed".fmt(f), ErrorKind::InsecureSecretKey(_) => "insecure secret key config".fmt(f), ErrorKind::Config(_) => "failed to extract configuration".fmt(f), ErrorKind::SentinelAborts(_) => "sentinel(s) aborted".fmt(f), - ErrorKind::Shutdown(_, Some(e)) => write!(f, "shutdown failed: {}", e), + ErrorKind::Shutdown(_, Some(e)) => write!(f, "shutdown failed: {e}"), ErrorKind::Shutdown(_, None) => "shutdown failed".fmt(f), + #[cfg(feature = "tls")] + ErrorKind::TlsBind(e) => write!(f, "TLS bind failed: {e}"), } } } diff --git a/core/lib/src/local/request.rs b/core/lib/src/local/request.rs index eabe933e..78e97595 100644 --- a/core/lib/src/local/request.rs +++ b/core/lib/src/local/request.rs @@ -228,10 +228,10 @@ macro_rules! pub_request_impl { #[cfg(feature = "mtls")] #[cfg_attr(nightly, doc(cfg(feature = "mtls")))] pub fn identity(mut self, reader: C) -> Self { - use crate::http::{tls::util::load_certs, private::Certificates}; + use crate::http::{tls::util::load_cert_chain, private::Certificates}; let mut reader = std::io::BufReader::new(reader); - let certs = load_certs(&mut reader).map(Certificates::from); + let certs = load_cert_chain(&mut reader).map(Certificates::from); self._request_mut().connection.client_certificates = certs.ok(); self } diff --git a/core/lib/src/server.rs b/core/lib/src/server.rs index a8b5d6a4..da2626e3 100644 --- a/core/lib/src/server.rs +++ b/core/lib/src/server.rs @@ -427,7 +427,7 @@ impl Rocket { use crate::http::tls::TlsListener; let conf = config.to_native_config().map_err(ErrorKind::Io)?; - let l = TlsListener::bind(addr, conf).await.map_err(ErrorKind::Bind)?; + let l = TlsListener::bind(addr, conf).await.map_err(ErrorKind::TlsBind)?; addr = l.local_addr().unwrap_or(addr); self.config.address = addr.ip(); self.config.port = addr.port(); diff --git a/examples/tls/Cargo.toml b/examples/tls/Cargo.toml index 11bcc1ff..9c724939 100644 --- a/examples/tls/Cargo.toml +++ b/examples/tls/Cargo.toml @@ -7,3 +7,4 @@ publish = false [dependencies] rocket = { path = "../../core/lib", features = ["tls", "mtls", "secrets"] } +yansi = "1.0.0-rc.1" diff --git a/examples/tls/src/main.rs b/examples/tls/src/main.rs index 8ecb685c..4ce4254c 100644 --- a/examples/tls/src/main.rs +++ b/examples/tls/src/main.rs @@ -22,5 +22,5 @@ fn rocket() -> _ { // Run `./private/gen_certs.sh` to generate a CA and key pairs. rocket::build() .mount("/", routes![hello, mutual]) - .attach(redirector::Redirector { port: 3000 }) + .attach(redirector::Redirector::on(3000)) } diff --git a/examples/tls/src/redirector.rs b/examples/tls/src/redirector.rs index 0aafddf9..aeffe9ad 100644 --- a/examples/tls/src/redirector.rs +++ b/examples/tls/src/redirector.rs @@ -1,22 +1,35 @@ //! Redirect all HTTP requests to HTTPs. +use std::sync::OnceLock; + use rocket::http::Status; use rocket::log::LogLevel; use rocket::{route, Error, Request, Data, Route, Orbit, Rocket, Ignite, Config}; use rocket::fairing::{Fairing, Info, Kind}; use rocket::response::Redirect; -#[derive(Debug, Copy, Clone)] +#[derive(Debug, Clone)] pub struct Redirector { - pub port: u16 + pub listen_port: u16, + pub tls_port: OnceLock, } impl Redirector { - // Route function that gets call on every single request. + pub fn on(port: u16) -> Self { + Redirector { listen_port: port, tls_port: OnceLock::new() } + } + + // Route function that gets called on every single request. fn redirect<'r>(req: &'r Request, _: Data<'r>) -> route::BoxFuture<'r> { // FIXME: Check the host against a whitelist! + let redirector = req.rocket().state::().expect("managed Self"); if let Some(host) = req.host() { - let https_uri = format!("https://{}{}", host, req.uri()); + let domain = host.domain(); + let https_uri = match redirector.tls_port.get() { + Some(443) | None => format!("https://{domain}{}", req.uri()), + Some(port) => format!("https://{domain}:{port}{}", req.uri()), + }; + route::Outcome::from(req, Redirect::permanent(https_uri)).pin() } else { route::Outcome::from(req, Status::BadRequest).pin() @@ -25,13 +38,21 @@ impl Redirector { // Launch an instance of Rocket than handles redirection on `self.port`. pub async fn try_launch(self, mut config: Config) -> Result, Error> { + use yansi::Paint; use rocket::http::Method::*; + // Determine the port TLS is being served on. + let tls_port = self.tls_port.get_or_init(|| config.port); + // Adjust config for redirector: disable TLS, set port, disable logging. config.tls = None; - config.port = self.port; + config.port = self.listen_port; config.log_level = LogLevel::Critical; + info!("{}{}", "🔒 ".mask(), "HTTP -> HTTPS Redirector:".magenta()); + info_!("redirecting on insecure port {} to TLS port {}", + self.listen_port.yellow(), tls_port.green()); + // Build a vector of routes to `redirect` on `` for each method. let redirects = [Get, Put, Post, Delete, Options, Head, Trace, Connect, Patch] .into_iter() @@ -39,6 +60,7 @@ impl Redirector { .collect::>(); rocket::custom(config) + .manage(self) .mount("/", redirects) .launch() .await @@ -48,11 +70,14 @@ impl Redirector { #[rocket::async_trait] impl Fairing for Redirector { fn info(&self) -> Info { - Info { name: "HTTP -> HTTPS Redirector", kind: Kind::Liftoff } + Info { + name: "HTTP -> HTTPS Redirector", + kind: Kind::Liftoff | Kind::Singleton + } } async fn on_liftoff(&self, rkt: &Rocket) { - let (this, shutdown, config) = (*self, rkt.shutdown(), rkt.config().clone()); + let (this, shutdown, config) = (self.clone(), rkt.shutdown(), rkt.config().clone()); let _ = rocket::tokio::spawn(async move { if let Err(e) = this.try_launch(config).await { error!("Failed to start HTTP -> HTTPS redirector.");