diff --git a/core/http/Cargo.toml b/core/http/Cargo.toml index 8e65b931..ec5a578a 100644 --- a/core/http/Cargo.toml +++ b/core/http/Cargo.toml @@ -16,7 +16,7 @@ edition = "2018" [features] default = [] -tls = ["tokio-rustls"] +tls = ["rustls", "tokio-rustls"] private-cookies = ["cookie/private", "cookie/key-expansion"] serde = ["uncased/with-serde-alloc", "serde_"] uuid = ["uuid_"] @@ -28,6 +28,7 @@ http = "0.2" mime = "0.3.13" time = "0.2.11" indexmap = { version = "1.5.2", features = ["std"] } +rustls = { version = "0.19", optional = true } tokio-rustls = { version = "0.22.0", optional = true } tokio = { version = "1.6.1", features = ["net", "sync", "time"] } log = "0.4" diff --git a/core/http/src/tls/listener.rs b/core/http/src/tls/listener.rs new file mode 100644 index 00000000..19b2c262 --- /dev/null +++ b/core/http/src/tls/listener.rs @@ -0,0 +1,102 @@ +use std::io; +use std::pin::Pin; +use std::sync::Arc; +use std::task::{Context, Poll}; +use std::net::SocketAddr; +use std::future::Future; + +use rustls::{ServerConfig, SupportedCipherSuite}; +use tokio_rustls::{TlsAcceptor, Accept, server::TlsStream}; +use tokio::net::{TcpListener, TcpStream}; + +use crate::tls::util::{load_certs, load_private_key}; +use crate::listener::{Connection, Listener}; + +/// A TLS listener over TCP. +pub struct TlsListener { + listener: TcpListener, + acceptor: TlsAcceptor, + state: State, +} + +enum State { + Listening, + Accepting(Accept), +} + +impl TlsListener { + pub async fn bind( + address: SocketAddr, + mut cert_chain: impl io::BufRead + Send, + mut private_key: impl io::BufRead + Send, + ciphersuites: impl Iterator, + prefer_server_order: bool, + ) -> io::Result { + let cert_chain = load_certs(&mut cert_chain).map_err(|e| { + let msg = format!("malformed TLS certificate chain: {}", e); + io::Error::new(e.kind(), msg) + })?; + + let key = load_private_key(&mut private_key).map_err(|e| { + let msg = format!("malformed TLS private key: {}", e); + io::Error::new(e.kind(), msg) + })?; + + let client_auth = rustls::NoClientAuth::new(); + let mut tls_config = ServerConfig::new(client_auth); + let cache = rustls::ServerSessionMemoryCache::new(1024); + tls_config.set_persistence(cache); + tls_config.ticketer = rustls::Ticketer::new(); + tls_config.ciphersuites = ciphersuites.collect(); + tls_config.ignore_client_order = prefer_server_order; + tls_config.set_single_cert(cert_chain, key).expect("invalid key"); + tls_config.set_protocols(&[b"h2".to_vec(), b"http/1.1".to_vec()]); + + let listener = TcpListener::bind(address).await?; + let acceptor = TlsAcceptor::from(Arc::new(tls_config)); + Ok(TlsListener { listener, acceptor, state: State::Listening }) + } +} + +impl Listener for TlsListener { + type Connection = TlsStream; + + fn local_addr(&self) -> Option { + self.listener.local_addr().ok() + } + + fn poll_accept( + mut self: Pin<&mut Self>, + cx: &mut Context<'_> + ) -> Poll> { + loop { + match self.state { + State::Listening => { + match self.listener.poll_accept(cx) { + Poll::Pending => return Poll::Pending, + Poll::Ready(Err(e)) => return Poll::Ready(Err(e)), + Poll::Ready(Ok((stream, _addr))) => { + let fut = self.acceptor.accept(stream); + self.state = State::Accepting(fut); + } + } + } + State::Accepting(ref mut fut) => { + match Pin::new(fut).poll(cx) { + Poll::Pending => return Poll::Pending, + Poll::Ready(result) => { + self.state = State::Listening; + return Poll::Ready(result); + } + } + } + } + } + } +} + +impl Connection for TlsStream { + fn remote_addr(&self) -> Option { + self.get_ref().0.remote_addr() + } +} diff --git a/core/http/src/tls/mod.rs b/core/http/src/tls/mod.rs index b009932c..4e21e40f 100644 --- a/core/http/src/tls/mod.rs +++ b/core/http/src/tls/mod.rs @@ -1,180 +1,5 @@ -use std::io; -use std::future::Future; -use std::net::SocketAddr; -use std::pin::Pin; -use std::sync::Arc; -use std::task::{Context, Poll}; +mod listener; +mod util; -pub use tokio_rustls::rustls; - -use rustls::internal::pemfile; -use rustls::{Certificate, PrivateKey, ServerConfig, SupportedCipherSuite}; -use tokio_rustls::{TlsAcceptor, Accept, server::TlsStream}; -use tokio::net::{TcpListener, TcpStream}; - -use crate::listener::{Connection, Listener}; - -fn load_certs(reader: &mut dyn io::BufRead) -> io::Result> { - pemfile::certs(reader) - .map_err(|_| io::Error::new(io::ErrorKind::Other, "invalid certificate")) -} - -fn load_private_key(reader: &mut dyn io::BufRead) -> io::Result { - use std::io::{Cursor, Error, Read, ErrorKind::Other}; - - // "rsa" (PKCS1) PEM files have a different first-line header than PKCS8 - // PEM files, use that to determine the parse function to use. - let mut first_line = String::new(); - reader.read_line(&mut first_line)?; - - let private_keys_fn = match first_line.trim_end() { - "-----BEGIN RSA PRIVATE KEY-----" => pemfile::rsa_private_keys, - "-----BEGIN PRIVATE KEY-----" => pemfile::pkcs8_private_keys, - _ => return Err(Error::new(Other, "invalid key header")) - }; - - let key = private_keys_fn(&mut Cursor::new(first_line).chain(reader)) - .map_err(|_| Error::new(Other, "invalid key file")) - .and_then(|mut keys| match keys.len() { - 0 => Err(Error::new(Other, "no valid keys found; is the file malformed?")), - 1 => Ok(keys.remove(0)), - n => Err(Error::new(Other, format!("expected 1 key, found {}", n))), - })?; - - // Ensure we can use the key. - rustls::sign::any_supported_type(&key) - .map_err(|_| Error::new(Other, "key parsed but is unusable")) - .map(|_| key) -} - -pub struct TlsListener { - listener: TcpListener, - acceptor: TlsAcceptor, - state: TlsListenerState, -} - -enum TlsListenerState { - Listening, - Accepting(Accept), -} - -impl Listener for TlsListener { - type Connection = TlsStream; - - fn local_addr(&self) -> Option { - self.listener.local_addr().ok() - } - - fn poll_accept( - mut self: Pin<&mut Self>, - cx: &mut Context<'_> - ) -> Poll> { - loop { - match self.state { - TlsListenerState::Listening => { - match self.listener.poll_accept(cx) { - Poll::Pending => return Poll::Pending, - Poll::Ready(Err(e)) => return Poll::Ready(Err(e)), - Poll::Ready(Ok((stream, _addr))) => { - let fut = self.acceptor.accept(stream); - self.state = TlsListenerState::Accepting(fut); - } - } - } - TlsListenerState::Accepting(ref mut fut) => { - match Pin::new(fut).poll(cx) { - Poll::Pending => return Poll::Pending, - Poll::Ready(result) => { - self.state = TlsListenerState::Listening; - return Poll::Ready(result); - } - } - } - } - } - } -} - -pub async fn bind_tls( - address: SocketAddr, - mut cert_chain: impl io::BufRead + Send, - mut private_key: impl io::BufRead + Send, - ciphersuites: impl Iterator, - prefer_server_order: bool, -) -> io::Result { - let cert_chain = load_certs(&mut cert_chain).map_err(|e| { - let msg = format!("malformed TLS certificate chain: {}", e); - io::Error::new(e.kind(), msg) - })?; - - let key = load_private_key(&mut private_key).map_err(|e| { - let msg = format!("malformed TLS private key: {}", e); - io::Error::new(e.kind(), msg) - })?; - - let listener = TcpListener::bind(address).await?; - - let client_auth = rustls::NoClientAuth::new(); - let mut tls_config = ServerConfig::new(client_auth); - let cache = rustls::ServerSessionMemoryCache::new(1024); - tls_config.set_persistence(cache); - tls_config.ticketer = rustls::Ticketer::new(); - tls_config.ciphersuites = ciphersuites.collect(); - tls_config.ignore_client_order = prefer_server_order; - tls_config.set_single_cert(cert_chain, key).expect("invalid key"); - tls_config.set_protocols(&[b"h2".to_vec(), b"http/1.1".to_vec()]); - - let acceptor = TlsAcceptor::from(Arc::new(tls_config)); - let state = TlsListenerState::Listening; - - Ok(TlsListener { listener, acceptor, state }) -} - -impl Connection for TlsStream { - fn remote_addr(&self) -> Option { - self.get_ref().0.remote_addr() - } -} - -#[cfg(test)] -mod test { - use super::*; - - use std::io::Cursor; - - macro_rules! tls_example_key { - ($k:expr) => { - include_bytes!(concat!(env!("CARGO_MANIFEST_DIR"), "/../../examples/tls/private/", $k)) - } - } - - #[test] - fn verify_load_private_keys_of_different_types() -> io::Result<()> { - let rsa_sha256_key = tls_example_key!("rsa_sha256_key.pem"); - let ecdsa_nistp256_sha256_key = tls_example_key!("ecdsa_nistp256_sha256_key_pkcs8.pem"); - let ecdsa_nistp384_sha384_key = tls_example_key!("ecdsa_nistp384_sha384_key_pkcs8.pem"); - let ed2551_key = tls_example_key!("ed25519_key.pem"); - - load_private_key(&mut Cursor::new(rsa_sha256_key))?; - load_private_key(&mut Cursor::new(ecdsa_nistp256_sha256_key))?; - load_private_key(&mut Cursor::new(ecdsa_nistp384_sha384_key))?; - load_private_key(&mut Cursor::new(ed2551_key))?; - - Ok(()) - } - - #[test] - fn verify_load_certs_of_different_types() -> io::Result<()> { - let rsa_sha256_cert = tls_example_key!("rsa_sha256_cert.pem"); - let ecdsa_nistp256_sha256_cert = tls_example_key!("ecdsa_nistp256_sha256_cert.pem"); - let ecdsa_nistp384_sha384_cert = tls_example_key!("ecdsa_nistp384_sha384_cert.pem"); - let ed2551_cert = tls_example_key!("ed25519_cert.pem"); - - load_certs(&mut Cursor::new(rsa_sha256_cert))?; - load_certs(&mut Cursor::new(ecdsa_nistp256_sha256_cert))?; - load_certs(&mut Cursor::new(ecdsa_nistp384_sha384_cert))?; - load_certs(&mut Cursor::new(ed2551_cert))?; - - Ok(()) - } -} +pub use rustls; +pub use listener::TlsListener; diff --git a/core/http/src/tls/util.rs b/core/http/src/tls/util.rs new file mode 100644 index 00000000..650f693d --- /dev/null +++ b/core/http/src/tls/util.rs @@ -0,0 +1,76 @@ +use std::io::{self, ErrorKind::Other, Cursor, Error, Read}; + +use rustls::{internal::pemfile, Certificate, PrivateKey}; + +/// Loads certificates from `reader`. +pub fn load_certs(reader: &mut dyn io::BufRead) -> io::Result> { + pemfile::certs(reader).map_err(|_| Error::new(Other, "invalid certificate")) +} + +/// Load and decode the private key from `reader`. +pub fn load_private_key(reader: &mut dyn io::BufRead) -> io::Result { + // "rsa" (PKCS1) PEM files have a different first-line header than PKCS8 + // PEM files, use that to determine the parse function to use. + let mut first_line = String::new(); + reader.read_line(&mut first_line)?; + + let private_keys_fn = match first_line.trim_end() { + "-----BEGIN RSA PRIVATE KEY-----" => pemfile::rsa_private_keys, + "-----BEGIN PRIVATE KEY-----" => pemfile::pkcs8_private_keys, + _ => return Err(Error::new(Other, "invalid key header")) + }; + + let key = private_keys_fn(&mut Cursor::new(first_line).chain(reader)) + .map_err(|_| Error::new(Other, "invalid key file")) + .and_then(|mut keys| match keys.len() { + 0 => Err(Error::new(Other, "no valid keys found; is the file malformed?")), + 1 => Ok(keys.remove(0)), + n => Err(Error::new(Other, format!("expected 1 key, found {}", n))), + })?; + + // Ensure we can use the key. + rustls::sign::any_supported_type(&key) + .map_err(|_| Error::new(Other, "key parsed but is unusable")) + .map(|_| key) +} + +#[cfg(test)] +mod test { + use super::*; + + macro_rules! tls_example_key { + ($k:expr) => { + include_bytes!(concat!(env!("CARGO_MANIFEST_DIR"), "/../../examples/tls/private/", $k)) + } + } + + #[test] + fn verify_load_private_keys_of_different_types() -> io::Result<()> { + let rsa_sha256_key = tls_example_key!("rsa_sha256_key.pem"); + let ecdsa_nistp256_sha256_key = tls_example_key!("ecdsa_nistp256_sha256_key_pkcs8.pem"); + let ecdsa_nistp384_sha384_key = tls_example_key!("ecdsa_nistp384_sha384_key_pkcs8.pem"); + let ed2551_key = tls_example_key!("ed25519_key.pem"); + + load_private_key(&mut Cursor::new(rsa_sha256_key))?; + load_private_key(&mut Cursor::new(ecdsa_nistp256_sha256_key))?; + load_private_key(&mut Cursor::new(ecdsa_nistp384_sha384_key))?; + load_private_key(&mut Cursor::new(ed2551_key))?; + + Ok(()) + } + + #[test] + fn verify_load_certs_of_different_types() -> io::Result<()> { + let rsa_sha256_cert = tls_example_key!("rsa_sha256_cert.pem"); + let ecdsa_nistp256_sha256_cert = tls_example_key!("ecdsa_nistp256_sha256_cert.pem"); + let ecdsa_nistp384_sha384_cert = tls_example_key!("ecdsa_nistp384_sha384_cert.pem"); + let ed2551_cert = tls_example_key!("ed25519_cert.pem"); + + load_certs(&mut Cursor::new(rsa_sha256_cert))?; + load_certs(&mut Cursor::new(ecdsa_nistp256_sha256_cert))?; + load_certs(&mut Cursor::new(ecdsa_nistp384_sha384_cert))?; + load_certs(&mut Cursor::new(ed2551_cert))?; + + Ok(()) + } +} diff --git a/core/lib/src/server.rs b/core/lib/src/server.rs index 08731c32..8830360e 100644 --- a/core/lib/src/server.rs +++ b/core/lib/src/server.rs @@ -367,12 +367,12 @@ impl Rocket { #[cfg(feature = "tls")] if let Some(ref config) = self.config.tls { - use crate::http::tls::bind_tls; + use crate::http::tls::TlsListener; let (certs, key) = config.to_readers().map_err(ErrorKind::Io)?; let ciphers = config.rustls_ciphers(); let server_order = config.prefer_server_cipher_order; - let l = bind_tls(addr, certs, key, ciphers, server_order).await + let l = TlsListener::bind(addr, certs, key, ciphers, server_order).await .map_err(ErrorKind::Bind)?; addr = l.local_addr().unwrap_or(addr);