mirror of https://github.com/rwf2/Rocket.git
Split TLS module into logical submodules.
This commit is contained in:
parent
de4b3e3891
commit
b5e3569554
|
@ -16,7 +16,7 @@ edition = "2018"
|
|||
|
||||
[features]
|
||||
default = []
|
||||
tls = ["tokio-rustls"]
|
||||
tls = ["rustls", "tokio-rustls"]
|
||||
private-cookies = ["cookie/private", "cookie/key-expansion"]
|
||||
serde = ["uncased/with-serde-alloc", "serde_"]
|
||||
uuid = ["uuid_"]
|
||||
|
@ -28,6 +28,7 @@ http = "0.2"
|
|||
mime = "0.3.13"
|
||||
time = "0.2.11"
|
||||
indexmap = { version = "1.5.2", features = ["std"] }
|
||||
rustls = { version = "0.19", optional = true }
|
||||
tokio-rustls = { version = "0.22.0", optional = true }
|
||||
tokio = { version = "1.6.1", features = ["net", "sync", "time"] }
|
||||
log = "0.4"
|
||||
|
|
|
@ -0,0 +1,102 @@
|
|||
use std::io;
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
use std::task::{Context, Poll};
|
||||
use std::net::SocketAddr;
|
||||
use std::future::Future;
|
||||
|
||||
use rustls::{ServerConfig, SupportedCipherSuite};
|
||||
use tokio_rustls::{TlsAcceptor, Accept, server::TlsStream};
|
||||
use tokio::net::{TcpListener, TcpStream};
|
||||
|
||||
use crate::tls::util::{load_certs, load_private_key};
|
||||
use crate::listener::{Connection, Listener};
|
||||
|
||||
/// A TLS listener over TCP.
|
||||
pub struct TlsListener {
|
||||
listener: TcpListener,
|
||||
acceptor: TlsAcceptor,
|
||||
state: State,
|
||||
}
|
||||
|
||||
enum State {
|
||||
Listening,
|
||||
Accepting(Accept<TcpStream>),
|
||||
}
|
||||
|
||||
impl TlsListener {
|
||||
pub async fn bind(
|
||||
address: SocketAddr,
|
||||
mut cert_chain: impl io::BufRead + Send,
|
||||
mut private_key: impl io::BufRead + Send,
|
||||
ciphersuites: impl Iterator<Item = &'static SupportedCipherSuite>,
|
||||
prefer_server_order: bool,
|
||||
) -> io::Result<TlsListener> {
|
||||
let cert_chain = load_certs(&mut cert_chain).map_err(|e| {
|
||||
let msg = format!("malformed TLS certificate chain: {}", e);
|
||||
io::Error::new(e.kind(), msg)
|
||||
})?;
|
||||
|
||||
let key = load_private_key(&mut private_key).map_err(|e| {
|
||||
let msg = format!("malformed TLS private key: {}", e);
|
||||
io::Error::new(e.kind(), msg)
|
||||
})?;
|
||||
|
||||
let client_auth = rustls::NoClientAuth::new();
|
||||
let mut tls_config = ServerConfig::new(client_auth);
|
||||
let cache = rustls::ServerSessionMemoryCache::new(1024);
|
||||
tls_config.set_persistence(cache);
|
||||
tls_config.ticketer = rustls::Ticketer::new();
|
||||
tls_config.ciphersuites = ciphersuites.collect();
|
||||
tls_config.ignore_client_order = prefer_server_order;
|
||||
tls_config.set_single_cert(cert_chain, key).expect("invalid key");
|
||||
tls_config.set_protocols(&[b"h2".to_vec(), b"http/1.1".to_vec()]);
|
||||
|
||||
let listener = TcpListener::bind(address).await?;
|
||||
let acceptor = TlsAcceptor::from(Arc::new(tls_config));
|
||||
Ok(TlsListener { listener, acceptor, state: State::Listening })
|
||||
}
|
||||
}
|
||||
|
||||
impl Listener for TlsListener {
|
||||
type Connection = TlsStream<TcpStream>;
|
||||
|
||||
fn local_addr(&self) -> Option<SocketAddr> {
|
||||
self.listener.local_addr().ok()
|
||||
}
|
||||
|
||||
fn poll_accept(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>
|
||||
) -> Poll<io::Result<Self::Connection>> {
|
||||
loop {
|
||||
match self.state {
|
||||
State::Listening => {
|
||||
match self.listener.poll_accept(cx) {
|
||||
Poll::Pending => return Poll::Pending,
|
||||
Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
|
||||
Poll::Ready(Ok((stream, _addr))) => {
|
||||
let fut = self.acceptor.accept(stream);
|
||||
self.state = State::Accepting(fut);
|
||||
}
|
||||
}
|
||||
}
|
||||
State::Accepting(ref mut fut) => {
|
||||
match Pin::new(fut).poll(cx) {
|
||||
Poll::Pending => return Poll::Pending,
|
||||
Poll::Ready(result) => {
|
||||
self.state = State::Listening;
|
||||
return Poll::Ready(result);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Connection for TlsStream<TcpStream> {
|
||||
fn remote_addr(&self) -> Option<SocketAddr> {
|
||||
self.get_ref().0.remote_addr()
|
||||
}
|
||||
}
|
|
@ -1,180 +1,5 @@
|
|||
use std::io;
|
||||
use std::future::Future;
|
||||
use std::net::SocketAddr;
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
use std::task::{Context, Poll};
|
||||
mod listener;
|
||||
mod util;
|
||||
|
||||
pub use tokio_rustls::rustls;
|
||||
|
||||
use rustls::internal::pemfile;
|
||||
use rustls::{Certificate, PrivateKey, ServerConfig, SupportedCipherSuite};
|
||||
use tokio_rustls::{TlsAcceptor, Accept, server::TlsStream};
|
||||
use tokio::net::{TcpListener, TcpStream};
|
||||
|
||||
use crate::listener::{Connection, Listener};
|
||||
|
||||
fn load_certs(reader: &mut dyn io::BufRead) -> io::Result<Vec<Certificate>> {
|
||||
pemfile::certs(reader)
|
||||
.map_err(|_| io::Error::new(io::ErrorKind::Other, "invalid certificate"))
|
||||
}
|
||||
|
||||
fn load_private_key(reader: &mut dyn io::BufRead) -> io::Result<PrivateKey> {
|
||||
use std::io::{Cursor, Error, Read, ErrorKind::Other};
|
||||
|
||||
// "rsa" (PKCS1) PEM files have a different first-line header than PKCS8
|
||||
// PEM files, use that to determine the parse function to use.
|
||||
let mut first_line = String::new();
|
||||
reader.read_line(&mut first_line)?;
|
||||
|
||||
let private_keys_fn = match first_line.trim_end() {
|
||||
"-----BEGIN RSA PRIVATE KEY-----" => pemfile::rsa_private_keys,
|
||||
"-----BEGIN PRIVATE KEY-----" => pemfile::pkcs8_private_keys,
|
||||
_ => return Err(Error::new(Other, "invalid key header"))
|
||||
};
|
||||
|
||||
let key = private_keys_fn(&mut Cursor::new(first_line).chain(reader))
|
||||
.map_err(|_| Error::new(Other, "invalid key file"))
|
||||
.and_then(|mut keys| match keys.len() {
|
||||
0 => Err(Error::new(Other, "no valid keys found; is the file malformed?")),
|
||||
1 => Ok(keys.remove(0)),
|
||||
n => Err(Error::new(Other, format!("expected 1 key, found {}", n))),
|
||||
})?;
|
||||
|
||||
// Ensure we can use the key.
|
||||
rustls::sign::any_supported_type(&key)
|
||||
.map_err(|_| Error::new(Other, "key parsed but is unusable"))
|
||||
.map(|_| key)
|
||||
}
|
||||
|
||||
pub struct TlsListener {
|
||||
listener: TcpListener,
|
||||
acceptor: TlsAcceptor,
|
||||
state: TlsListenerState,
|
||||
}
|
||||
|
||||
enum TlsListenerState {
|
||||
Listening,
|
||||
Accepting(Accept<TcpStream>),
|
||||
}
|
||||
|
||||
impl Listener for TlsListener {
|
||||
type Connection = TlsStream<TcpStream>;
|
||||
|
||||
fn local_addr(&self) -> Option<SocketAddr> {
|
||||
self.listener.local_addr().ok()
|
||||
}
|
||||
|
||||
fn poll_accept(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>
|
||||
) -> Poll<io::Result<Self::Connection>> {
|
||||
loop {
|
||||
match self.state {
|
||||
TlsListenerState::Listening => {
|
||||
match self.listener.poll_accept(cx) {
|
||||
Poll::Pending => return Poll::Pending,
|
||||
Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
|
||||
Poll::Ready(Ok((stream, _addr))) => {
|
||||
let fut = self.acceptor.accept(stream);
|
||||
self.state = TlsListenerState::Accepting(fut);
|
||||
}
|
||||
}
|
||||
}
|
||||
TlsListenerState::Accepting(ref mut fut) => {
|
||||
match Pin::new(fut).poll(cx) {
|
||||
Poll::Pending => return Poll::Pending,
|
||||
Poll::Ready(result) => {
|
||||
self.state = TlsListenerState::Listening;
|
||||
return Poll::Ready(result);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn bind_tls(
|
||||
address: SocketAddr,
|
||||
mut cert_chain: impl io::BufRead + Send,
|
||||
mut private_key: impl io::BufRead + Send,
|
||||
ciphersuites: impl Iterator<Item = &'static SupportedCipherSuite>,
|
||||
prefer_server_order: bool,
|
||||
) -> io::Result<TlsListener> {
|
||||
let cert_chain = load_certs(&mut cert_chain).map_err(|e| {
|
||||
let msg = format!("malformed TLS certificate chain: {}", e);
|
||||
io::Error::new(e.kind(), msg)
|
||||
})?;
|
||||
|
||||
let key = load_private_key(&mut private_key).map_err(|e| {
|
||||
let msg = format!("malformed TLS private key: {}", e);
|
||||
io::Error::new(e.kind(), msg)
|
||||
})?;
|
||||
|
||||
let listener = TcpListener::bind(address).await?;
|
||||
|
||||
let client_auth = rustls::NoClientAuth::new();
|
||||
let mut tls_config = ServerConfig::new(client_auth);
|
||||
let cache = rustls::ServerSessionMemoryCache::new(1024);
|
||||
tls_config.set_persistence(cache);
|
||||
tls_config.ticketer = rustls::Ticketer::new();
|
||||
tls_config.ciphersuites = ciphersuites.collect();
|
||||
tls_config.ignore_client_order = prefer_server_order;
|
||||
tls_config.set_single_cert(cert_chain, key).expect("invalid key");
|
||||
tls_config.set_protocols(&[b"h2".to_vec(), b"http/1.1".to_vec()]);
|
||||
|
||||
let acceptor = TlsAcceptor::from(Arc::new(tls_config));
|
||||
let state = TlsListenerState::Listening;
|
||||
|
||||
Ok(TlsListener { listener, acceptor, state })
|
||||
}
|
||||
|
||||
impl Connection for TlsStream<TcpStream> {
|
||||
fn remote_addr(&self) -> Option<SocketAddr> {
|
||||
self.get_ref().0.remote_addr()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use super::*;
|
||||
|
||||
use std::io::Cursor;
|
||||
|
||||
macro_rules! tls_example_key {
|
||||
($k:expr) => {
|
||||
include_bytes!(concat!(env!("CARGO_MANIFEST_DIR"), "/../../examples/tls/private/", $k))
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn verify_load_private_keys_of_different_types() -> io::Result<()> {
|
||||
let rsa_sha256_key = tls_example_key!("rsa_sha256_key.pem");
|
||||
let ecdsa_nistp256_sha256_key = tls_example_key!("ecdsa_nistp256_sha256_key_pkcs8.pem");
|
||||
let ecdsa_nistp384_sha384_key = tls_example_key!("ecdsa_nistp384_sha384_key_pkcs8.pem");
|
||||
let ed2551_key = tls_example_key!("ed25519_key.pem");
|
||||
|
||||
load_private_key(&mut Cursor::new(rsa_sha256_key))?;
|
||||
load_private_key(&mut Cursor::new(ecdsa_nistp256_sha256_key))?;
|
||||
load_private_key(&mut Cursor::new(ecdsa_nistp384_sha384_key))?;
|
||||
load_private_key(&mut Cursor::new(ed2551_key))?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn verify_load_certs_of_different_types() -> io::Result<()> {
|
||||
let rsa_sha256_cert = tls_example_key!("rsa_sha256_cert.pem");
|
||||
let ecdsa_nistp256_sha256_cert = tls_example_key!("ecdsa_nistp256_sha256_cert.pem");
|
||||
let ecdsa_nistp384_sha384_cert = tls_example_key!("ecdsa_nistp384_sha384_cert.pem");
|
||||
let ed2551_cert = tls_example_key!("ed25519_cert.pem");
|
||||
|
||||
load_certs(&mut Cursor::new(rsa_sha256_cert))?;
|
||||
load_certs(&mut Cursor::new(ecdsa_nistp256_sha256_cert))?;
|
||||
load_certs(&mut Cursor::new(ecdsa_nistp384_sha384_cert))?;
|
||||
load_certs(&mut Cursor::new(ed2551_cert))?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
pub use rustls;
|
||||
pub use listener::TlsListener;
|
||||
|
|
|
@ -0,0 +1,76 @@
|
|||
use std::io::{self, ErrorKind::Other, Cursor, Error, Read};
|
||||
|
||||
use rustls::{internal::pemfile, Certificate, PrivateKey};
|
||||
|
||||
/// Loads certificates from `reader`.
|
||||
pub fn load_certs(reader: &mut dyn io::BufRead) -> io::Result<Vec<Certificate>> {
|
||||
pemfile::certs(reader).map_err(|_| Error::new(Other, "invalid certificate"))
|
||||
}
|
||||
|
||||
/// Load and decode the private key from `reader`.
|
||||
pub fn load_private_key(reader: &mut dyn io::BufRead) -> io::Result<PrivateKey> {
|
||||
// "rsa" (PKCS1) PEM files have a different first-line header than PKCS8
|
||||
// PEM files, use that to determine the parse function to use.
|
||||
let mut first_line = String::new();
|
||||
reader.read_line(&mut first_line)?;
|
||||
|
||||
let private_keys_fn = match first_line.trim_end() {
|
||||
"-----BEGIN RSA PRIVATE KEY-----" => pemfile::rsa_private_keys,
|
||||
"-----BEGIN PRIVATE KEY-----" => pemfile::pkcs8_private_keys,
|
||||
_ => return Err(Error::new(Other, "invalid key header"))
|
||||
};
|
||||
|
||||
let key = private_keys_fn(&mut Cursor::new(first_line).chain(reader))
|
||||
.map_err(|_| Error::new(Other, "invalid key file"))
|
||||
.and_then(|mut keys| match keys.len() {
|
||||
0 => Err(Error::new(Other, "no valid keys found; is the file malformed?")),
|
||||
1 => Ok(keys.remove(0)),
|
||||
n => Err(Error::new(Other, format!("expected 1 key, found {}", n))),
|
||||
})?;
|
||||
|
||||
// Ensure we can use the key.
|
||||
rustls::sign::any_supported_type(&key)
|
||||
.map_err(|_| Error::new(Other, "key parsed but is unusable"))
|
||||
.map(|_| key)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use super::*;
|
||||
|
||||
macro_rules! tls_example_key {
|
||||
($k:expr) => {
|
||||
include_bytes!(concat!(env!("CARGO_MANIFEST_DIR"), "/../../examples/tls/private/", $k))
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn verify_load_private_keys_of_different_types() -> io::Result<()> {
|
||||
let rsa_sha256_key = tls_example_key!("rsa_sha256_key.pem");
|
||||
let ecdsa_nistp256_sha256_key = tls_example_key!("ecdsa_nistp256_sha256_key_pkcs8.pem");
|
||||
let ecdsa_nistp384_sha384_key = tls_example_key!("ecdsa_nistp384_sha384_key_pkcs8.pem");
|
||||
let ed2551_key = tls_example_key!("ed25519_key.pem");
|
||||
|
||||
load_private_key(&mut Cursor::new(rsa_sha256_key))?;
|
||||
load_private_key(&mut Cursor::new(ecdsa_nistp256_sha256_key))?;
|
||||
load_private_key(&mut Cursor::new(ecdsa_nistp384_sha384_key))?;
|
||||
load_private_key(&mut Cursor::new(ed2551_key))?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn verify_load_certs_of_different_types() -> io::Result<()> {
|
||||
let rsa_sha256_cert = tls_example_key!("rsa_sha256_cert.pem");
|
||||
let ecdsa_nistp256_sha256_cert = tls_example_key!("ecdsa_nistp256_sha256_cert.pem");
|
||||
let ecdsa_nistp384_sha384_cert = tls_example_key!("ecdsa_nistp384_sha384_cert.pem");
|
||||
let ed2551_cert = tls_example_key!("ed25519_cert.pem");
|
||||
|
||||
load_certs(&mut Cursor::new(rsa_sha256_cert))?;
|
||||
load_certs(&mut Cursor::new(ecdsa_nistp256_sha256_cert))?;
|
||||
load_certs(&mut Cursor::new(ecdsa_nistp384_sha384_cert))?;
|
||||
load_certs(&mut Cursor::new(ed2551_cert))?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
|
@ -367,12 +367,12 @@ impl Rocket<Orbit> {
|
|||
|
||||
#[cfg(feature = "tls")]
|
||||
if let Some(ref config) = self.config.tls {
|
||||
use crate::http::tls::bind_tls;
|
||||
use crate::http::tls::TlsListener;
|
||||
|
||||
let (certs, key) = config.to_readers().map_err(ErrorKind::Io)?;
|
||||
let ciphers = config.rustls_ciphers();
|
||||
let server_order = config.prefer_server_cipher_order;
|
||||
let l = bind_tls(addr, certs, key, ciphers, server_order).await
|
||||
let l = TlsListener::bind(addr, certs, key, ciphers, server_order).await
|
||||
.map_err(ErrorKind::Bind)?;
|
||||
|
||||
addr = l.local_addr().unwrap_or(addr);
|
||||
|
|
Loading…
Reference in New Issue