mirror of https://github.com/rwf2/Rocket.git
Upgrade 'rustls' to '0.22'.
In the process, the following improvements were also made: * Error messages related to TLS were improved. * 'Redirector' in 'tls' example was improved.
This commit is contained in:
parent
a59f3c4c1f
commit
9c2b74b23c
|
@ -30,9 +30,9 @@ percent-encoding = "2"
|
|||
http = "0.2"
|
||||
time = { version = "0.3", features = ["formatting", "macros"] }
|
||||
indexmap = "2"
|
||||
rustls = { version = "0.21", optional = true }
|
||||
tokio-rustls = { version = "0.24", optional = true }
|
||||
rustls-pemfile = { version = "1.0.2", optional = true }
|
||||
rustls = { version = "0.22", optional = true }
|
||||
tokio-rustls = { version = "0.25", optional = true }
|
||||
rustls-pemfile = { version = "2.0.0", optional = true }
|
||||
tokio = { version = "1.6.1", features = ["net", "sync", "time"] }
|
||||
log = "0.4"
|
||||
ref-cast = "1.0"
|
||||
|
|
|
@ -17,36 +17,45 @@ use state::InitCell;
|
|||
pub use tokio::net::TcpListener;
|
||||
|
||||
/// A thin wrapper over raw, DER-encoded X.509 client certificate data.
|
||||
// NOTE: `rustls::Certificate` is exactly isomorphic to `CertificateData`.
|
||||
#[doc(inline)]
|
||||
#[cfg(feature = "tls")]
|
||||
pub use rustls::Certificate as CertificateData;
|
||||
#[cfg(not(feature = "tls"))]
|
||||
#[derive(Debug, Clone, Eq, PartialEq)]
|
||||
pub struct CertificateDer(pub(crate) Vec<u8>);
|
||||
|
||||
/// A thin wrapper over raw, DER-encoded X.509 client certificate data.
|
||||
#[cfg(not(feature = "tls"))]
|
||||
#[derive(Clone, Eq, Hash, Ord, PartialEq, PartialOrd)]
|
||||
pub struct CertificateData(pub Vec<u8>);
|
||||
#[cfg(feature = "tls")]
|
||||
#[derive(Debug, Clone, Eq, PartialEq)]
|
||||
#[repr(transparent)]
|
||||
pub struct CertificateDer(pub(crate) rustls::pki_types::CertificateDer<'static>);
|
||||
|
||||
/// A collection of raw certificate data.
|
||||
#[derive(Clone, Default)]
|
||||
pub struct Certificates(Arc<InitCell<Vec<CertificateData>>>);
|
||||
pub struct Certificates(Arc<InitCell<Vec<CertificateDer>>>);
|
||||
|
||||
impl From<Vec<CertificateData>> for Certificates {
|
||||
fn from(value: Vec<CertificateData>) -> Self {
|
||||
impl From<Vec<CertificateDer>> for Certificates {
|
||||
fn from(value: Vec<CertificateDer>) -> Self {
|
||||
Certificates(Arc::new(value.into()))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "tls")]
|
||||
impl From<Vec<rustls::pki_types::CertificateDer<'static>>> for Certificates {
|
||||
fn from(value: Vec<rustls::pki_types::CertificateDer<'static>>) -> Self {
|
||||
let value: Vec<_> = value.into_iter().map(CertificateDer).collect();
|
||||
Certificates(Arc::new(value.into()))
|
||||
}
|
||||
}
|
||||
|
||||
#[doc(hidden)]
|
||||
impl Certificates {
|
||||
/// Set the the raw certificate chain data. Only the first call actually
|
||||
/// sets the data; the remaining do nothing.
|
||||
#[cfg(feature = "tls")]
|
||||
pub(crate) fn set(&self, data: Vec<CertificateData>) {
|
||||
pub(crate) fn set(&self, data: Vec<CertificateDer>) {
|
||||
self.0.set(data);
|
||||
}
|
||||
|
||||
/// Returns the raw certificate chain data, if any is available.
|
||||
pub fn chain_data(&self) -> Option<&[CertificateData]> {
|
||||
pub fn chain_data(&self) -> Option<&[CertificateDer]> {
|
||||
self.0.try_get().map(|v| v.as_slice())
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,95 @@
|
|||
pub type Result<T, E = Error> = std::result::Result<T, E>;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum KeyError {
|
||||
BadKeyCount(usize),
|
||||
Io(std::io::Error),
|
||||
Unsupported(rustls::Error),
|
||||
BadItem(rustls_pemfile::Item),
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum Error {
|
||||
Io(std::io::Error),
|
||||
Tls(rustls::Error),
|
||||
Mtls(rustls::server::VerifierBuilderError),
|
||||
CertChain(std::io::Error),
|
||||
PrivKey(KeyError),
|
||||
CertAuth(rustls::Error),
|
||||
}
|
||||
|
||||
impl std::fmt::Display for Error {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
use Error::*;
|
||||
|
||||
match self {
|
||||
Io(e) => write!(f, "i/o error during tls binding: {e}"),
|
||||
Tls(e) => write!(f, "tls configuration error: {e}"),
|
||||
Mtls(e) => write!(f, "mtls verifier error: {e}"),
|
||||
CertChain(e) => write!(f, "failed to process certificate chain: {e}"),
|
||||
PrivKey(e) => write!(f, "failed to process private key: {e}"),
|
||||
CertAuth(e) => write!(f, "failed to process certificate authority: {e}"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Display for KeyError {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
use KeyError::*;
|
||||
|
||||
match self {
|
||||
Io(e) => write!(f, "error reading key file: {e}"),
|
||||
BadKeyCount(0) => write!(f, "no valid keys found. is the file malformed?"),
|
||||
BadKeyCount(n) => write!(f, "expected exactly 1 key, found {n}"),
|
||||
Unsupported(e) => write!(f, "key is valid but is unsupported: {e}"),
|
||||
BadItem(i) => write!(f, "found unexpected item in key file: {i:#?}"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::error::Error for KeyError {
|
||||
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
|
||||
match self {
|
||||
KeyError::Io(e) => Some(e),
|
||||
KeyError::Unsupported(e) => Some(e),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::error::Error for Error {
|
||||
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
|
||||
match self {
|
||||
Error::Io(e) => Some(e),
|
||||
Error::Tls(e) => Some(e),
|
||||
Error::Mtls(e) => Some(e),
|
||||
Error::CertChain(e) => Some(e),
|
||||
Error::PrivKey(e) => Some(e),
|
||||
Error::CertAuth(e) => Some(e),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<std::io::Error> for Error {
|
||||
fn from(e: std::io::Error) -> Self {
|
||||
Error::Io(e)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<rustls::Error> for Error {
|
||||
fn from(e: rustls::Error) -> Self {
|
||||
Error::Tls(e)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<rustls::server::VerifierBuilderError> for Error {
|
||||
fn from(value: rustls::server::VerifierBuilderError) -> Self {
|
||||
Error::Mtls(value)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<KeyError> for Error {
|
||||
fn from(value: KeyError) -> Self {
|
||||
Error::PrivKey(value)
|
||||
}
|
||||
}
|
|
@ -8,9 +8,10 @@ use std::net::SocketAddr;
|
|||
use tokio::net::{TcpListener, TcpStream};
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tokio_rustls::{Accept, TlsAcceptor, server::TlsStream as BareTlsStream};
|
||||
use rustls::server::{ServerSessionMemoryCache, ServerConfig, WebPkiClientVerifier};
|
||||
|
||||
use crate::tls::util::{load_certs, load_private_key, load_ca_certs};
|
||||
use crate::listener::{Connection, Listener, Certificates};
|
||||
use crate::tls::util::{load_cert_chain, load_key, load_ca_certs};
|
||||
use crate::listener::{Connection, Listener, Certificates, CertificateDer};
|
||||
|
||||
/// A TLS listener over TCP.
|
||||
pub struct TlsListener {
|
||||
|
@ -40,7 +41,7 @@ pub struct TlsListener {
|
|||
///
|
||||
/// To work around this, we "lie" when `peer_certificates()` are requested and
|
||||
/// always return `Some(Certificates)`. Internally, `Certificates` is an
|
||||
/// `Arc<InitCell<Vec<CertificateData>>>`, effectively a shared, thread-safe,
|
||||
/// `Arc<InitCell<Vec<CertificateDer>>>`, effectively a shared, thread-safe,
|
||||
/// `OnceCell`. The cell is initially empty and is filled as soon as the
|
||||
/// handshake is complete. If the certificate data were to be requested prior to
|
||||
/// this point, it would be empty. However, in Rocket, we only request
|
||||
|
@ -72,49 +73,43 @@ pub struct Config<R> {
|
|||
}
|
||||
|
||||
impl TlsListener {
|
||||
pub async fn bind<R>(addr: SocketAddr, mut c: Config<R>) -> io::Result<TlsListener>
|
||||
pub async fn bind<R>(addr: SocketAddr, mut c: Config<R>) -> crate::tls::Result<TlsListener>
|
||||
where R: io::BufRead
|
||||
{
|
||||
use rustls::server::{AllowAnyAuthenticatedClient, AllowAnyAnonymousOrAuthenticatedClient};
|
||||
use rustls::server::{NoClientAuth, ServerSessionMemoryCache, ServerConfig};
|
||||
|
||||
let cert_chain = load_certs(&mut c.cert_chain)
|
||||
.map_err(|e| io::Error::new(e.kind(), format!("bad TLS cert chain: {}", e)))?;
|
||||
|
||||
let key = load_private_key(&mut c.private_key)
|
||||
.map_err(|e| io::Error::new(e.kind(), format!("bad TLS private key: {}", e)))?;
|
||||
|
||||
let client_auth = match c.ca_certs {
|
||||
Some(ref mut ca_certs) => match load_ca_certs(ca_certs) {
|
||||
Ok(ca) if c.mandatory_mtls => AllowAnyAuthenticatedClient::new(ca).boxed(),
|
||||
Ok(ca) => AllowAnyAnonymousOrAuthenticatedClient::new(ca).boxed(),
|
||||
Err(e) => return Err(io::Error::new(e.kind(), format!("bad CA cert(s): {}", e))),
|
||||
},
|
||||
None => NoClientAuth::boxed(),
|
||||
let provider = rustls::crypto::CryptoProvider {
|
||||
cipher_suites: c.ciphersuites,
|
||||
..rustls::crypto::ring::default_provider()
|
||||
};
|
||||
|
||||
let mut tls_config = ServerConfig::builder()
|
||||
.with_cipher_suites(&c.ciphersuites)
|
||||
.with_safe_default_kx_groups()
|
||||
.with_safe_default_protocol_versions()
|
||||
.map_err(|e| io::Error::new(io::ErrorKind::Other, format!("bad TLS config: {}", e)))?
|
||||
.with_client_cert_verifier(client_auth)
|
||||
.with_single_cert(cert_chain, key)
|
||||
.map_err(|e| io::Error::new(io::ErrorKind::Other, format!("bad TLS config: {}", e)))?;
|
||||
let verifier = match c.ca_certs {
|
||||
Some(ref mut ca_certs) => {
|
||||
let ca_roots = Arc::new(load_ca_certs(ca_certs)?);
|
||||
let verifier = WebPkiClientVerifier::builder(ca_roots);
|
||||
match c.mandatory_mtls {
|
||||
true => verifier.build()?,
|
||||
false => verifier.allow_unauthenticated().build()?,
|
||||
}
|
||||
},
|
||||
None => WebPkiClientVerifier::no_client_auth(),
|
||||
};
|
||||
|
||||
tls_config.ignore_client_order = c.prefer_server_order;
|
||||
let key = load_key(&mut c.private_key)?;
|
||||
let cert_chain = load_cert_chain(&mut c.cert_chain)?;
|
||||
let mut config = ServerConfig::builder_with_provider(Arc::new(provider))
|
||||
.with_safe_default_protocol_versions()?
|
||||
.with_client_cert_verifier(verifier)
|
||||
.with_single_cert(cert_chain, key)?;
|
||||
|
||||
tls_config.alpn_protocols = vec![b"http/1.1".to_vec()];
|
||||
config.ignore_client_order = c.prefer_server_order;
|
||||
config.session_storage = ServerSessionMemoryCache::new(1024);
|
||||
config.ticketer = rustls::crypto::ring::Ticketer::new()?;
|
||||
config.alpn_protocols = vec![b"http/1.1".to_vec()];
|
||||
if cfg!(feature = "http2") {
|
||||
tls_config.alpn_protocols.insert(0, b"h2".to_vec());
|
||||
config.alpn_protocols.insert(0, b"h2".to_vec());
|
||||
}
|
||||
|
||||
tls_config.session_storage = ServerSessionMemoryCache::new(1024);
|
||||
tls_config.ticketer = rustls::Ticketer::new()
|
||||
.map_err(|e| io::Error::new(io::ErrorKind::Other, format!("bad TLS ticketer: {}", e)))?;
|
||||
|
||||
let listener = TcpListener::bind(addr).await?;
|
||||
let acceptor = TlsAcceptor::from(Arc::new(tls_config));
|
||||
let acceptor = TlsAcceptor::from(Arc::new(config));
|
||||
Ok(TlsListener { listener, acceptor })
|
||||
}
|
||||
}
|
||||
|
@ -179,8 +174,10 @@ impl TlsStream {
|
|||
TlsState::Handshaking(ref mut accept) => {
|
||||
match futures::ready!(Pin::new(accept).poll(cx)) {
|
||||
Ok(stream) => {
|
||||
if let Some(cert_chain) = stream.get_ref().1.peer_certificates() {
|
||||
self.certs.set(cert_chain.to_vec());
|
||||
if let Some(peer_certs) = stream.get_ref().1.peer_certificates() {
|
||||
self.certs.set(peer_certs.into_iter()
|
||||
.map(|v| CertificateDer(v.clone().into_owned()))
|
||||
.collect());
|
||||
}
|
||||
|
||||
self.state = TlsState::Streaming(stream);
|
||||
|
|
|
@ -6,3 +6,6 @@ pub mod mtls;
|
|||
pub use rustls;
|
||||
pub use listener::{TlsListener, Config};
|
||||
pub mod util;
|
||||
pub mod error;
|
||||
|
||||
pub use error::Result;
|
||||
|
|
|
@ -41,7 +41,7 @@ use x509_parser::nom;
|
|||
use x509::{ParsedExtension, X509Name, X509Certificate, TbsCertificate, X509Error, FromDer};
|
||||
use oid::OID_X509_EXT_SUBJECT_ALT_NAME as SUBJECT_ALT_NAME;
|
||||
|
||||
use crate::listener::CertificateData;
|
||||
use crate::listener::CertificateDer;
|
||||
|
||||
/// A type alias for [`Result`](std::result::Result) with the error type set to
|
||||
/// [`Error`].
|
||||
|
@ -144,7 +144,7 @@ pub type Result<T, E = Error> = std::result::Result<T, E>;
|
|||
#[derive(Debug, PartialEq)]
|
||||
pub struct Certificate<'a> {
|
||||
x509: X509Certificate<'a>,
|
||||
data: &'a CertificateData,
|
||||
data: &'a CertificateDer,
|
||||
}
|
||||
|
||||
/// An X.509 Distinguished Name (DN) found in a [`Certificate`].
|
||||
|
@ -224,7 +224,7 @@ impl<'a> Certificate<'a> {
|
|||
|
||||
/// PRIVATE: For internal Rocket use only!
|
||||
#[doc(hidden)]
|
||||
pub fn parse(chain: &[CertificateData]) -> Result<Certificate<'_>> {
|
||||
pub fn parse(chain: &[CertificateDer]) -> Result<Certificate<'_>> {
|
||||
let data = chain.first().ok_or_else(|| Error::Empty)?;
|
||||
let x509 = Certificate::parse_one(&data.0)?;
|
||||
Ok(Certificate { x509, data })
|
||||
|
|
|
@ -1,55 +1,47 @@
|
|||
use std::io::{self, Cursor, Read};
|
||||
use std::io;
|
||||
|
||||
use rustls::{Certificate, PrivateKey, RootCertStore};
|
||||
use rustls::RootCertStore;
|
||||
use rustls::pki_types::{CertificateDer, PrivateKeyDer};
|
||||
|
||||
fn err(message: impl Into<std::borrow::Cow<'static, str>>) -> io::Error {
|
||||
io::Error::new(io::ErrorKind::Other, message.into())
|
||||
}
|
||||
use crate::tls::error::{Result, Error, KeyError};
|
||||
|
||||
/// Loads certificates from `reader`.
|
||||
pub fn load_certs(reader: &mut dyn io::BufRead) -> io::Result<Vec<Certificate>> {
|
||||
let certs = rustls_pemfile::certs(reader).map_err(|_| err("invalid certificate"))?;
|
||||
Ok(certs.into_iter().map(Certificate).collect())
|
||||
pub fn load_cert_chain(reader: &mut dyn io::BufRead) -> Result<Vec<CertificateDer<'static>>> {
|
||||
rustls_pemfile::certs(reader)
|
||||
.collect::<Result<_, _>>()
|
||||
.map_err(Error::CertChain)
|
||||
}
|
||||
|
||||
/// Load and decode the private key from `reader`.
|
||||
pub fn load_private_key(reader: &mut dyn io::BufRead) -> io::Result<PrivateKey> {
|
||||
// "rsa" (PKCS1) PEM files have a different first-line header than PKCS8
|
||||
// PEM files, use that to determine the parse function to use.
|
||||
let mut header = String::new();
|
||||
let private_keys_fn = loop {
|
||||
header.clear();
|
||||
if reader.read_line(&mut header)? == 0 {
|
||||
return Err(err("failed to find key header; supported formats are: RSA, PKCS8, SEC1"));
|
||||
}
|
||||
pub fn load_key(reader: &mut dyn io::BufRead) -> Result<PrivateKeyDer<'static>> {
|
||||
use rustls_pemfile::Item::*;
|
||||
|
||||
break match header.trim_end() {
|
||||
"-----BEGIN RSA PRIVATE KEY-----" => rustls_pemfile::rsa_private_keys,
|
||||
"-----BEGIN PRIVATE KEY-----" => rustls_pemfile::pkcs8_private_keys,
|
||||
"-----BEGIN EC PRIVATE KEY-----" => rustls_pemfile::ec_private_keys,
|
||||
_ => continue,
|
||||
};
|
||||
};
|
||||
let mut keys: Vec<PrivateKeyDer<'static>> = rustls_pemfile::read_all(reader)
|
||||
.map(|result| result.map_err(KeyError::Io)
|
||||
.and_then(|item| match item {
|
||||
Pkcs1Key(key) => Ok(key.into()),
|
||||
Pkcs8Key(key) => Ok(key.into()),
|
||||
Sec1Key(key) => Ok(key.into()),
|
||||
_ => Err(KeyError::BadItem(item))
|
||||
})
|
||||
)
|
||||
.collect::<Result<_, _>>()?;
|
||||
|
||||
let key = private_keys_fn(&mut Cursor::new(header).chain(reader))
|
||||
.map_err(|_| err("invalid key file"))
|
||||
.and_then(|mut keys| match keys.len() {
|
||||
0 => Err(err("no valid keys found; is the file malformed?")),
|
||||
1 => Ok(PrivateKey(keys.remove(0))),
|
||||
n => Err(err(format!("expected 1 key, found {}", n))),
|
||||
})?;
|
||||
if keys.len() != 1 {
|
||||
return Err(KeyError::BadKeyCount(keys.len()).into());
|
||||
}
|
||||
|
||||
// Ensure we can use the key.
|
||||
rustls::sign::any_supported_type(&key)
|
||||
.map_err(|_| err("key parsed but is unusable"))
|
||||
.map(|_| key)
|
||||
let key = keys.remove(0);
|
||||
rustls::crypto::ring::sign::any_supported_type(&key).map_err(KeyError::Unsupported)?;
|
||||
Ok(key)
|
||||
}
|
||||
|
||||
/// Load and decode CA certificates from `reader`.
|
||||
pub fn load_ca_certs(reader: &mut dyn io::BufRead) -> io::Result<RootCertStore> {
|
||||
pub fn load_ca_certs(reader: &mut dyn io::BufRead) -> Result<RootCertStore> {
|
||||
let mut roots = rustls::RootCertStore::empty();
|
||||
for cert in load_certs(reader)? {
|
||||
roots.add(&cert).map_err(|e| err(format!("CA cert error: {}", e)))?;
|
||||
for cert in load_cert_chain(reader)? {
|
||||
roots.add(cert).map_err(Error::CertAuth)?;
|
||||
}
|
||||
|
||||
Ok(roots)
|
||||
|
@ -66,31 +58,31 @@ mod test {
|
|||
}
|
||||
|
||||
#[test]
|
||||
fn verify_load_private_keys_of_different_types() -> io::Result<()> {
|
||||
fn verify_load_private_keys_of_different_types() -> Result<()> {
|
||||
let rsa_sha256_key = tls_example_key!("rsa_sha256_key.pem");
|
||||
let ecdsa_nistp256_sha256_key = tls_example_key!("ecdsa_nistp256_sha256_key_pkcs8.pem");
|
||||
let ecdsa_nistp384_sha384_key = tls_example_key!("ecdsa_nistp384_sha384_key_pkcs8.pem");
|
||||
let ed2551_key = tls_example_key!("ed25519_key.pem");
|
||||
|
||||
load_private_key(&mut Cursor::new(rsa_sha256_key))?;
|
||||
load_private_key(&mut Cursor::new(ecdsa_nistp256_sha256_key))?;
|
||||
load_private_key(&mut Cursor::new(ecdsa_nistp384_sha384_key))?;
|
||||
load_private_key(&mut Cursor::new(ed2551_key))?;
|
||||
load_key(&mut &rsa_sha256_key[..])?;
|
||||
load_key(&mut &ecdsa_nistp256_sha256_key[..])?;
|
||||
load_key(&mut &ecdsa_nistp384_sha384_key[..])?;
|
||||
load_key(&mut &ed2551_key[..])?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn verify_load_certs_of_different_types() -> io::Result<()> {
|
||||
fn verify_load_certs_of_different_types() -> Result<()> {
|
||||
let rsa_sha256_cert = tls_example_key!("rsa_sha256_cert.pem");
|
||||
let ecdsa_nistp256_sha256_cert = tls_example_key!("ecdsa_nistp256_sha256_cert.pem");
|
||||
let ecdsa_nistp384_sha384_cert = tls_example_key!("ecdsa_nistp384_sha384_cert.pem");
|
||||
let ed2551_cert = tls_example_key!("ed25519_cert.pem");
|
||||
|
||||
load_certs(&mut Cursor::new(rsa_sha256_cert))?;
|
||||
load_certs(&mut Cursor::new(ecdsa_nistp256_sha256_cert))?;
|
||||
load_certs(&mut Cursor::new(ecdsa_nistp384_sha384_cert))?;
|
||||
load_certs(&mut Cursor::new(ed2551_cert))?;
|
||||
load_cert_chain(&mut &rsa_sha256_cert[..])?;
|
||||
load_cert_chain(&mut &ecdsa_nistp256_sha256_cert[..])?;
|
||||
load_cert_chain(&mut &ecdsa_nistp384_sha384_cert[..])?;
|
||||
load_cert_chain(&mut &ed2551_cert[..])?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
|
|
@ -631,7 +631,7 @@ mod with_tls_feature {
|
|||
|
||||
use crate::http::tls::Config;
|
||||
use crate::http::tls::rustls::SupportedCipherSuite as RustlsCipher;
|
||||
use crate::http::tls::rustls::cipher_suite;
|
||||
use crate::http::tls::rustls::crypto::ring::cipher_suite;
|
||||
|
||||
use yansi::Paint;
|
||||
|
||||
|
|
|
@ -76,6 +76,9 @@ pub struct Error {
|
|||
pub enum ErrorKind {
|
||||
/// Binding to the provided address/port failed.
|
||||
Bind(io::Error),
|
||||
/// Binding via TLS to the provided address/port failed.
|
||||
#[cfg(feature = "tls")]
|
||||
TlsBind(crate::http::tls::error::Error),
|
||||
/// An I/O error occurred during launch.
|
||||
Io(io::Error),
|
||||
/// A valid [`Config`](crate::Config) could not be extracted from the
|
||||
|
@ -234,6 +237,12 @@ impl Error {
|
|||
|
||||
"aborting due to failed shutdown"
|
||||
}
|
||||
#[cfg(feature = "tls")]
|
||||
ErrorKind::TlsBind(e) => {
|
||||
error!("Rocket failed to bind via TLS to network socket.");
|
||||
info_!("{}", e);
|
||||
"aborting due to TLS bind error"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -244,15 +253,17 @@ impl fmt::Display for ErrorKind {
|
|||
#[inline]
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
match self {
|
||||
ErrorKind::Bind(e) => write!(f, "binding failed: {}", e),
|
||||
ErrorKind::Io(e) => write!(f, "I/O error: {}", e),
|
||||
ErrorKind::Bind(e) => write!(f, "binding failed: {e}"),
|
||||
ErrorKind::Io(e) => write!(f, "I/O error: {e}"),
|
||||
ErrorKind::Collisions(_) => "collisions detected".fmt(f),
|
||||
ErrorKind::FailedFairings(_) => "launch fairing(s) failed".fmt(f),
|
||||
ErrorKind::InsecureSecretKey(_) => "insecure secret key config".fmt(f),
|
||||
ErrorKind::Config(_) => "failed to extract configuration".fmt(f),
|
||||
ErrorKind::SentinelAborts(_) => "sentinel(s) aborted".fmt(f),
|
||||
ErrorKind::Shutdown(_, Some(e)) => write!(f, "shutdown failed: {}", e),
|
||||
ErrorKind::Shutdown(_, Some(e)) => write!(f, "shutdown failed: {e}"),
|
||||
ErrorKind::Shutdown(_, None) => "shutdown failed".fmt(f),
|
||||
#[cfg(feature = "tls")]
|
||||
ErrorKind::TlsBind(e) => write!(f, "TLS bind failed: {e}"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -228,10 +228,10 @@ macro_rules! pub_request_impl {
|
|||
#[cfg(feature = "mtls")]
|
||||
#[cfg_attr(nightly, doc(cfg(feature = "mtls")))]
|
||||
pub fn identity<C: std::io::Read>(mut self, reader: C) -> Self {
|
||||
use crate::http::{tls::util::load_certs, private::Certificates};
|
||||
use crate::http::{tls::util::load_cert_chain, private::Certificates};
|
||||
|
||||
let mut reader = std::io::BufReader::new(reader);
|
||||
let certs = load_certs(&mut reader).map(Certificates::from);
|
||||
let certs = load_cert_chain(&mut reader).map(Certificates::from);
|
||||
self._request_mut().connection.client_certificates = certs.ok();
|
||||
self
|
||||
}
|
||||
|
|
|
@ -427,7 +427,7 @@ impl Rocket<Orbit> {
|
|||
use crate::http::tls::TlsListener;
|
||||
|
||||
let conf = config.to_native_config().map_err(ErrorKind::Io)?;
|
||||
let l = TlsListener::bind(addr, conf).await.map_err(ErrorKind::Bind)?;
|
||||
let l = TlsListener::bind(addr, conf).await.map_err(ErrorKind::TlsBind)?;
|
||||
addr = l.local_addr().unwrap_or(addr);
|
||||
self.config.address = addr.ip();
|
||||
self.config.port = addr.port();
|
||||
|
|
|
@ -7,3 +7,4 @@ publish = false
|
|||
|
||||
[dependencies]
|
||||
rocket = { path = "../../core/lib", features = ["tls", "mtls", "secrets"] }
|
||||
yansi = "1.0.0-rc.1"
|
||||
|
|
|
@ -22,5 +22,5 @@ fn rocket() -> _ {
|
|||
// Run `./private/gen_certs.sh` to generate a CA and key pairs.
|
||||
rocket::build()
|
||||
.mount("/", routes![hello, mutual])
|
||||
.attach(redirector::Redirector { port: 3000 })
|
||||
.attach(redirector::Redirector::on(3000))
|
||||
}
|
||||
|
|
|
@ -1,22 +1,35 @@
|
|||
//! Redirect all HTTP requests to HTTPs.
|
||||
|
||||
use std::sync::OnceLock;
|
||||
|
||||
use rocket::http::Status;
|
||||
use rocket::log::LogLevel;
|
||||
use rocket::{route, Error, Request, Data, Route, Orbit, Rocket, Ignite, Config};
|
||||
use rocket::fairing::{Fairing, Info, Kind};
|
||||
use rocket::response::Redirect;
|
||||
|
||||
#[derive(Debug, Copy, Clone)]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Redirector {
|
||||
pub port: u16
|
||||
pub listen_port: u16,
|
||||
pub tls_port: OnceLock<u16>,
|
||||
}
|
||||
|
||||
impl Redirector {
|
||||
// Route function that gets call on every single request.
|
||||
pub fn on(port: u16) -> Self {
|
||||
Redirector { listen_port: port, tls_port: OnceLock::new() }
|
||||
}
|
||||
|
||||
// Route function that gets called on every single request.
|
||||
fn redirect<'r>(req: &'r Request, _: Data<'r>) -> route::BoxFuture<'r> {
|
||||
// FIXME: Check the host against a whitelist!
|
||||
let redirector = req.rocket().state::<Self>().expect("managed Self");
|
||||
if let Some(host) = req.host() {
|
||||
let https_uri = format!("https://{}{}", host, req.uri());
|
||||
let domain = host.domain();
|
||||
let https_uri = match redirector.tls_port.get() {
|
||||
Some(443) | None => format!("https://{domain}{}", req.uri()),
|
||||
Some(port) => format!("https://{domain}:{port}{}", req.uri()),
|
||||
};
|
||||
|
||||
route::Outcome::from(req, Redirect::permanent(https_uri)).pin()
|
||||
} else {
|
||||
route::Outcome::from(req, Status::BadRequest).pin()
|
||||
|
@ -25,13 +38,21 @@ impl Redirector {
|
|||
|
||||
// Launch an instance of Rocket than handles redirection on `self.port`.
|
||||
pub async fn try_launch(self, mut config: Config) -> Result<Rocket<Ignite>, Error> {
|
||||
use yansi::Paint;
|
||||
use rocket::http::Method::*;
|
||||
|
||||
// Determine the port TLS is being served on.
|
||||
let tls_port = self.tls_port.get_or_init(|| config.port);
|
||||
|
||||
// Adjust config for redirector: disable TLS, set port, disable logging.
|
||||
config.tls = None;
|
||||
config.port = self.port;
|
||||
config.port = self.listen_port;
|
||||
config.log_level = LogLevel::Critical;
|
||||
|
||||
info!("{}{}", "🔒 ".mask(), "HTTP -> HTTPS Redirector:".magenta());
|
||||
info_!("redirecting on insecure port {} to TLS port {}",
|
||||
self.listen_port.yellow(), tls_port.green());
|
||||
|
||||
// Build a vector of routes to `redirect` on `<path..>` for each method.
|
||||
let redirects = [Get, Put, Post, Delete, Options, Head, Trace, Connect, Patch]
|
||||
.into_iter()
|
||||
|
@ -39,6 +60,7 @@ impl Redirector {
|
|||
.collect::<Vec<_>>();
|
||||
|
||||
rocket::custom(config)
|
||||
.manage(self)
|
||||
.mount("/", redirects)
|
||||
.launch()
|
||||
.await
|
||||
|
@ -48,11 +70,14 @@ impl Redirector {
|
|||
#[rocket::async_trait]
|
||||
impl Fairing for Redirector {
|
||||
fn info(&self) -> Info {
|
||||
Info { name: "HTTP -> HTTPS Redirector", kind: Kind::Liftoff }
|
||||
Info {
|
||||
name: "HTTP -> HTTPS Redirector",
|
||||
kind: Kind::Liftoff | Kind::Singleton
|
||||
}
|
||||
}
|
||||
|
||||
async fn on_liftoff(&self, rkt: &Rocket<Orbit>) {
|
||||
let (this, shutdown, config) = (*self, rkt.shutdown(), rkt.config().clone());
|
||||
let (this, shutdown, config) = (self.clone(), rkt.shutdown(), rkt.config().clone());
|
||||
let _ = rocket::tokio::spawn(async move {
|
||||
if let Err(e) = this.try_launch(config).await {
|
||||
error!("Failed to start HTTP -> HTTPS redirector.");
|
||||
|
|
Loading…
Reference in New Issue