Introduce dynamic TLS resolvers.

This commit introduces the ability to dynamically select a TLS
configuration based on the client's TLS hello via the new `Resolver`
trait. In support of this, it also makes the following changes:

  * Added `Authority::set_port()`.
  * `UdsListener` is now `UnixListener`.
  * `Bindable` removed in favor of new `Bind`.
  * All built-in listeners now implement `Bind<&Rocket>`.
  * `Connection` requires `AsyncRead + AsyncWrite`.
  * The `Debug` impl for `Endpoint` displays the underlying address.
  * `Listener` must be `Sized`.
  * The TLS listener was moved to `tls::TlsListener`.
  * The preview `quic` listener no longer implements `Listener`.
  * Added `TlsConfig::server_config()`.
  * Added `race` future helpers.
  * Added `Rocket::launch_with()`, `Rocket::bind_launch()`.
  * Added a default `client.pem` to the TLS example.
  * Various unnecessary listener `Config` structures removed.

In addition, the testbench was revamped to support more scenarios. This
resulted in the following issues being found and fixed:

  * Fix an issue where the logger would ignore color requests.
  * Clarified docs for `mtls::Certificate` guard.
  * Improved error messages on listener misconfiguration.

Resolves #2730.
Resolves #2363.
Closes #2748.
Closes #2683.
Closes #2577.
This commit is contained in:
Sergio Benitez 2024-04-16 02:39:52 -07:00
parent 60f3cd57b0
commit 7cc818cd85
45 changed files with 1635 additions and 718 deletions

View File

@ -117,6 +117,8 @@
//! to an `Object` (a dictionary) value. The [`context!`] macro can be used to //! to an `Object` (a dictionary) value. The [`context!`] macro can be used to
//! create inline `Serialize`-able context objects. //! create inline `Serialize`-able context objects.
//! //!
//! [`Serialize`]: rocket::serde::Serialize
//!
//! ```rust //! ```rust
//! # #[macro_use] extern crate rocket; //! # #[macro_use] extern crate rocket;
//! use rocket::serde::Serialize; //! use rocket::serde::Serialize;
@ -165,7 +167,7 @@
//! builds, template reloading is disabled to improve performance and cannot be //! builds, template reloading is disabled to improve performance and cannot be
//! enabled. //! enabled.
//! //!
//! [attached]: Rocket::attach() //! [attached]: rocket::Rocket::attach()
//! //!
//! ### Metadata and Rendering to `String` //! ### Metadata and Rendering to `String`
//! //!

View File

@ -140,11 +140,12 @@ impl Template {
} }
/// Render the template named `name` with the context `context`. The /// Render the template named `name` with the context `context`. The
/// `context` is typically created using the [`context!`] macro, but it can /// `context` is typically created using the [`context!()`](crate::context!)
/// be of any type that implements `Serialize`, such as `HashMap` or a /// macro, but it can be of any type that implements `Serialize`, such as
/// custom `struct`. /// `HashMap` or a custom `struct`.
/// ///
/// To render a template directly into a string, use [`Metadata::render()`]. /// To render a template directly into a string, use
/// [`Metadata::render()`](crate::Metadata::render()).
/// ///
/// # Examples /// # Examples
/// ///
@ -291,8 +292,8 @@ impl Sentinel for Template {
/// A macro to easily create a template rendering context. /// A macro to easily create a template rendering context.
/// ///
/// Invocations of this macro expand to a value of an anonymous type which /// Invocations of this macro expand to a value of an anonymous type which
/// implements [`serde::Serialize`]. Fields can be literal expressions or /// implements [`Serialize`]. Fields can be literal expressions or variables
/// variables captured from a surrounding scope, as long as all fields implement /// captured from a surrounding scope, as long as all fields implement
/// `Serialize`. /// `Serialize`.
/// ///
/// # Examples /// # Examples

View File

@ -1,5 +1,5 @@
#[cfg(all(feature = "diesel_sqlite_pool"))]
#[cfg(test)] #[cfg(test)]
#[cfg(all(feature = "diesel_sqlite_pool"))]
mod sqlite_shutdown_test { mod sqlite_shutdown_test {
use rocket::{async_test, Build, Rocket}; use rocket::{async_test, Build, Rocket};
use rocket_sync_db_pools::database; use rocket_sync_db_pools::database;

View File

@ -185,7 +185,7 @@ impl<'a> Authority<'a> {
self.host.from_cow_source(&self.source) self.host.from_cow_source(&self.source)
} }
/// Returns the port part of the authority URI, if there is one. /// Returns the `port` part of the authority URI, if there is one.
/// ///
/// # Example /// # Example
/// ///
@ -206,6 +206,28 @@ impl<'a> Authority<'a> {
pub fn port(&self) -> Option<u16> { pub fn port(&self) -> Option<u16> {
self.port self.port
} }
/// Set the `port` of the authority URI.
///
/// # Example
///
/// ```rust
/// # #[macro_use] extern crate rocket;
/// let mut uri = uri!("username:password@host:123");
/// assert_eq!(uri.port(), Some(123));
///
/// uri.set_port(1024);
/// assert_eq!(uri.port(), Some(1024));
/// assert_eq!(uri, "username:password@host:1024");
///
/// uri.set_port(None);
/// assert_eq!(uri.port(), None);
/// assert_eq!(uri, "username:password@host");
/// ```
#[inline(always)]
pub fn set_port<T: Into<Option<u16>>>(&mut self, port: T) {
self.port = port.into();
}
} }
impl_serde!(Authority<'a>, "an authority-form URI"); impl_serde!(Authority<'a>, "an authority-form URI");

View File

@ -69,7 +69,7 @@ ref-swap = "0.1.2"
parking_lot = "0.12" parking_lot = "0.12"
ubyte = {version = "0.10.2", features = ["serde"] } ubyte = {version = "0.10.2", features = ["serde"] }
serde = { version = "1.0", features = ["derive"] } serde = { version = "1.0", features = ["derive"] }
figment = { version = "0.10.13", features = ["toml", "env"] } figment = { version = "0.10.17", features = ["toml", "env"] }
rand = "0.8" rand = "0.8"
either = "1" either = "1"
pin-project-lite = "0.2" pin-project-lite = "0.2"
@ -140,5 +140,5 @@ version_check = "0.9.1"
[dev-dependencies] [dev-dependencies]
tokio = { version = "1", features = ["macros", "io-std"] } tokio = { version = "1", features = ["macros", "io-std"] }
figment = { version = "0.10", features = ["test"] } figment = { version = "0.10.17", features = ["test"] }
pretty_assertions = "1" pretty_assertions = "1"

View File

@ -137,9 +137,6 @@ mod secret_key;
#[cfg(unix)] #[cfg(unix)]
pub use crate::shutdown::Sig; pub use crate::shutdown::Sig;
#[cfg(unix)]
pub use crate::listener::unix::UdsConfig;
#[cfg(feature = "secrets")] #[cfg(feature = "secrets")]
pub use secret_key::SecretKey; pub use secret_key::SecretKey;

View File

@ -178,6 +178,9 @@ impl Error {
self.mark_handled(); self.mark_handled();
match self.kind() { match self.kind() {
ErrorKind::Bind(ref a, ref e) => { ErrorKind::Bind(ref a, ref e) => {
if let Some(e) = e.downcast_ref::<Self>() {
e.pretty_print()
} else {
match a { match a {
Some(a) => error!("Binding to {} failed.", a.primary().underline()), Some(a) => error!("Binding to {} failed.", a.primary().underline()),
None => error!("Binding to network interface failed."), None => error!("Binding to network interface failed."),
@ -186,6 +189,7 @@ impl Error {
info_!("{}", e); info_!("{}", e);
"aborting due to bind error" "aborting due to bind error"
} }
}
ErrorKind::Io(ref e) => { ErrorKind::Io(ref e) => {
error!("Rocket failed to launch due to an I/O error."); error!("Rocket failed to launch due to an I/O error.");
info_!("{}", e); info_!("{}", e);

View File

@ -0,0 +1,10 @@
use crate::listener::{Endpoint, Listener};
pub trait Bind<T>: Listener + 'static {
type Error: std::error::Error + Send + 'static;
#[crate::async_bound(Send)]
async fn bind(to: T) -> Result<Self, Self::Error>;
fn bind_endpoint(to: &T) -> Result<Endpoint, Self::Error>;
}

View File

@ -1,52 +0,0 @@
use std::io;
use futures::TryFutureExt;
use crate::listener::{Listener, Endpoint};
pub trait Bindable: Sized {
type Listener: Listener + 'static;
type Error: std::error::Error + Send + 'static;
async fn bind(self) -> Result<Self::Listener, Self::Error>;
/// The endpoint that `self` binds on.
fn bind_endpoint(&self) -> io::Result<Endpoint>;
}
impl<L: Listener + 'static> Bindable for L {
type Listener = L;
type Error = std::convert::Infallible;
async fn bind(self) -> Result<Self::Listener, Self::Error> {
Ok(self)
}
fn bind_endpoint(&self) -> io::Result<Endpoint> {
L::endpoint(self)
}
}
impl<A: Bindable, B: Bindable> Bindable for either::Either<A, B> {
type Listener = tokio_util::either::Either<A::Listener, B::Listener>;
type Error = either::Either<A::Error, B::Error>;
async fn bind(self) -> Result<Self::Listener, Self::Error> {
match self {
either::Either::Left(a) => a.bind()
.map_ok(tokio_util::either::Either::Left)
.map_err(either::Either::Left)
.await,
either::Either::Right(b) => b.bind()
.map_ok(tokio_util::either::Either::Right)
.map_err(either::Either::Right)
.await,
}
}
fn bind_endpoint(&self) -> io::Result<Endpoint> {
either::for_both!(self, a => a.bind_endpoint())
}
}

View File

@ -1,6 +1,7 @@
use std::io; use std::io;
use std::borrow::Cow; use std::borrow::Cow;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio_util::either::Either; use tokio_util::either::Either;
use super::Endpoint; use super::Endpoint;
@ -9,7 +10,7 @@ use super::Endpoint;
#[derive(Clone)] #[derive(Clone)]
pub struct Certificates<'r>(Cow<'r, [der::CertificateDer<'r>]>); pub struct Certificates<'r>(Cow<'r, [der::CertificateDer<'r>]>);
pub trait Connection: Send + Unpin { pub trait Connection: AsyncRead + AsyncWrite + Send + Unpin {
fn endpoint(&self) -> io::Result<Endpoint>; fn endpoint(&self) -> io::Result<Endpoint>;
/// DER-encoded X.509 certificate chain presented by the client, if any. /// DER-encoded X.509 certificate chain presented by the client, if any.

View File

@ -1,64 +1,190 @@
use core::fmt;
use serde::Deserialize;
use tokio_util::either::Either::{Left, Right};
use either::Either; use either::Either;
use crate::listener::{Bindable, Endpoint}; use crate::{Ignite, Rocket};
use crate::error::{Error, ErrorKind}; use crate::listener::{Bind, Endpoint, tcp::TcpListener};
#[derive(serde::Deserialize)] #[cfg(unix)] use crate::listener::unix::UnixListener;
pub struct DefaultListener { #[cfg(feature = "tls")] use crate::tls::{TlsListener, TlsConfig};
mod private {
use super::*;
use tokio_util::either::Either;
#[cfg(feature = "tls")] type TlsListener<T> = super::TlsListener<T>;
#[cfg(not(feature = "tls"))] type TlsListener<T> = T;
#[cfg(unix)] type UnixListener = super::UnixListener;
#[cfg(not(unix))] type UnixListener = TcpListener;
pub type Listener = Either<
Either<TlsListener<TcpListener>, TlsListener<UnixListener>>,
Either<TcpListener, UnixListener>,
>;
/// The default connection listener.
///
/// # Configuration
///
/// Reads the following optional configuration parameters:
///
/// | parameter | type | default |
/// | ----------- | ----------------- | --------------------- |
/// | `address` | [`Endpoint`] | `tcp:127.0.0.1:8000` |
/// | `tls` | [`TlsConfig`] | None |
/// | `reuse` | boolean | `true` |
///
/// # Listener
///
/// Based on the above configuration, this listener defers to one of the
/// following existing listeners:
///
/// | listener | `address` type | `tls` enabled |
/// |-------------------------------|--------------------|---------------|
/// | [`TcpListener`] | [`Endpoint::Tcp`] | no |
/// | [`UnixListener`] | [`Endpoint::Unix`] | no |
/// | [`TlsListener<TcpListener>`] | [`Endpoint::Tcp`] | yes |
/// | [`TlsListener<UnixListener>`] | [`Endpoint::Unix`] | yes |
///
/// [`UnixListener`]: crate::listener::unix::UnixListener
/// [`TlsListener<TcpListener>`]: crate::tls::TlsListener
/// [`TlsListener<UnixListener>`]: crate::tls::TlsListener
///
/// * **address type** is the variant the `address` parameter parses as.
/// * **`tls` enabled** is `yes` when the `tls` feature is enabled _and_ a
/// `tls` configuration is provided.
#[cfg(doc)]
pub struct DefaultListener(());
}
#[derive(Deserialize)]
struct Config {
#[serde(default)] #[serde(default)]
pub address: Endpoint, address: Endpoint,
pub port: Option<u16>,
pub reuse: Option<bool>,
#[cfg(feature = "tls")] #[cfg(feature = "tls")]
pub tls: Option<crate::tls::TlsConfig>, tls: Option<TlsConfig>,
} }
#[cfg(not(unix))] type BaseBindable = Either<std::net::SocketAddr, std::net::SocketAddr>; #[cfg(doc)]
#[cfg(unix)] type BaseBindable = Either<std::net::SocketAddr, super::unix::UdsConfig>; pub use private::DefaultListener;
#[cfg(not(feature = "tls"))] type TlsBindable<T> = Either<T, T>; #[cfg(doc)]
#[cfg(feature = "tls")] type TlsBindable<T> = Either<super::tls::TlsBindable<T>, T>; type Connection = crate::listener::tcp::TcpStream;
impl DefaultListener { #[cfg(doc)]
pub(crate) fn base_bindable(&self) -> Result<BaseBindable, crate::Error> { impl<'r> Bind<&'r Rocket<Ignite>> for DefaultListener {
match &self.address { type Error = Error;
Endpoint::Tcp(mut address) => { async fn bind(_: &'r Rocket<Ignite>) -> Result<Self, Error> { unreachable!() }
if let Some(port) = self.port { fn bind_endpoint(_: &&'r Rocket<Ignite>) -> Result<Endpoint, Error> { unreachable!() }
address.set_port(port);
} }
Ok(BaseBindable::Left(address)) #[cfg(doc)]
}, impl super::Listener for DefaultListener {
#[doc(hidden)] type Accept = Connection;
#[doc(hidden)] type Connection = Connection;
#[doc(hidden)]
async fn accept(&self) -> std::io::Result<Connection> { unreachable!() }
#[doc(hidden)]
async fn connect(&self, _: Self::Accept) -> std::io::Result<Connection> { unreachable!() }
#[doc(hidden)]
fn endpoint(&self) -> std::io::Result<Endpoint> { unreachable!() }
}
#[cfg(not(doc))]
pub type DefaultListener = private::Listener;
#[cfg(not(doc))]
impl<'r> Bind<&'r Rocket<Ignite>> for DefaultListener {
type Error = Error;
async fn bind(rocket: &'r Rocket<Ignite>) -> Result<Self, Self::Error> {
let config: Config = rocket.figment().extract()?;
match config.address {
#[cfg(feature = "tls")]
Endpoint::Tcp(_) if config.tls.is_some() => {
let listener = <TlsListener<TcpListener> as Bind<_>>::bind(rocket).await?;
Ok(Left(Left(listener)))
}
Endpoint::Tcp(_) => {
let listener = <TcpListener as Bind<_>>::bind(rocket).await?;
Ok(Right(Left(listener)))
}
#[cfg(all(unix, feature = "tls"))]
Endpoint::Unix(_) if config.tls.is_some() => {
let listener = <TlsListener<UnixListener> as Bind<_>>::bind(rocket).await?;
Ok(Left(Right(listener)))
}
#[cfg(unix)] #[cfg(unix)]
Endpoint::Unix(path) => { Endpoint::Unix(_) => {
let uds = super::unix::UdsConfig { path: path.clone(), reuse: self.reuse, }; let listener = <UnixListener as Bind<_>>::bind(rocket).await?;
Ok(BaseBindable::Right(uds)) Ok(Right(Right(listener)))
},
#[cfg(not(unix))]
e@Endpoint::Unix(_) => {
let msg = "Unix domain sockets unavailable on non-unix platforms.";
let boxed = Box::<dyn std::error::Error + Send + Sync>::from(msg);
Err(Error::new(ErrorKind::Bind(Some(e.clone()), boxed)))
},
other => {
let msg = format!("unsupported default listener address: {other}");
let boxed = Box::<dyn std::error::Error + Send + Sync>::from(msg);
Err(Error::new(ErrorKind::Bind(Some(other.clone()), boxed)))
} }
endpoint => Err(Error::Unsupported(endpoint)),
} }
} }
pub(crate) fn tls_bindable<T>(&self, inner: T) -> TlsBindable<T> { fn bind_endpoint(rocket: &&'r Rocket<Ignite>) -> Result<Endpoint, Self::Error> {
let config: Config = rocket.figment().extract()?;
Ok(config.address)
}
}
#[derive(Debug)]
pub enum Error {
Config(figment::Error),
Io(std::io::Error),
Unsupported(Endpoint),
#[cfg(feature = "tls")] #[cfg(feature = "tls")]
if let Some(tls) = self.tls.clone() { Tls(crate::tls::Error),
return TlsBindable::Left(super::tls::TlsBindable { inner, tls });
} }
TlsBindable::Right(inner) impl From<figment::Error> for Error {
fn from(value: figment::Error) -> Self {
Error::Config(value)
}
} }
pub fn bindable(&self) -> Result<impl Bindable, crate::Error> { impl From<std::io::Error> for Error {
self.base_bindable() fn from(value: std::io::Error) -> Self {
.map(|b| b.map_either(|b| self.tls_bindable(b), |b| self.tls_bindable(b))) Error::Io(value)
}
}
#[cfg(feature = "tls")]
impl From<crate::tls::Error> for Error {
fn from(value: crate::tls::Error) -> Self {
Error::Tls(value)
}
}
impl From<Either<figment::Error, std::io::Error>> for Error {
fn from(value: Either<figment::Error, std::io::Error>) -> Self {
value.either(Error::Config, Error::Io)
}
}
impl fmt::Display for Error {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Error::Config(e) => e.fmt(f),
Error::Io(e) => e.fmt(f),
Error::Unsupported(e) => write!(f, "unsupported endpoint: {e:?}"),
#[cfg(feature = "tls")]
Error::Tls(error) => error.fmt(f),
}
}
}
impl std::error::Error for Error {
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
match self {
Error::Config(e) => Some(e),
Error::Io(e) => Some(e),
Error::Unsupported(_) => None,
#[cfg(feature = "tls")]
Error::Tls(e) => Some(e),
}
} }
} }

View File

@ -5,6 +5,7 @@ use std::path::{Path, PathBuf};
use std::str::FromStr; use std::str::FromStr;
use std::sync::Arc; use std::sync::Arc;
use figment::Figment;
use serde::de; use serde::de;
use crate::http::uncased::AsUncased; use crate::http::uncased::AsUncased;
@ -12,27 +13,43 @@ use crate::http::uncased::AsUncased;
#[cfg(feature = "tls")] type TlsInfo = Option<Box<crate::tls::TlsConfig>>; #[cfg(feature = "tls")] type TlsInfo = Option<Box<crate::tls::TlsConfig>>;
#[cfg(not(feature = "tls"))] type TlsInfo = Option<()>; #[cfg(not(feature = "tls"))] type TlsInfo = Option<()>;
pub trait EndpointAddr: fmt::Display + fmt::Debug + Sync + Send + Any { } pub trait CustomEndpoint: fmt::Display + fmt::Debug + Sync + Send + Any { }
impl<T: fmt::Display + fmt::Debug + Sync + Send + Any> EndpointAddr for T {} impl<T: fmt::Display + fmt::Debug + Sync + Send + Any> CustomEndpoint for T {}
/// # Conversions /// # Conversions
/// ///
/// * [`&str`] - parse with [`FromStr`] /// * [`&str`] - parse with [`FromStr`]
/// * [`tokio::net::unix::SocketAddr`] - must be path: [`Endpoint::Unix`] /// * [`tokio::net::unix::SocketAddr`] - must be path: [`Endpoint::Unix`]
/// * [`PathBuf`] - infallibly as [`Endpoint::Unix`] /// * [`PathBuf`] - infallibly as [`Endpoint::Unix`]
#[derive(Debug, Clone)] ///
/// # Syntax
///
/// The string syntax is:
///
/// ```text
/// endpoint = 'tcp' ':' socket | 'quic' ':' socket | 'unix' ':' path | socket
/// socket := IP_ADDR | SOCKET_ADDR
/// path := PATH
///
/// IP_ADDR := `std::net::IpAddr` string as defined by Rust
/// SOCKET_ADDR := `std::net::SocketAddr` string as defined by Rust
/// PATH := `PathBuf` (any UTF-8) string as defined by Rust
/// ```
///
/// If `IP_ADDR` is specified in socket, port defaults to `8000`.
#[derive(Clone)]
#[non_exhaustive] #[non_exhaustive]
pub enum Endpoint { pub enum Endpoint {
Tcp(net::SocketAddr), Tcp(net::SocketAddr),
Quic(net::SocketAddr), Quic(net::SocketAddr),
Unix(PathBuf), Unix(PathBuf),
Tls(Arc<Endpoint>, TlsInfo), Tls(Arc<Endpoint>, TlsInfo),
Custom(Arc<dyn EndpointAddr>), Custom(Arc<dyn CustomEndpoint>),
} }
impl Endpoint { impl Endpoint {
pub fn new<T: EndpointAddr>(value: T) -> Endpoint { pub fn new<T: CustomEndpoint>(value: T) -> Endpoint {
Endpoint::Custom(Arc::new(value)) Endpoint::Custom(Arc::new(value))
} }
@ -152,6 +169,29 @@ impl Endpoint {
Self::Tls(Arc::new(self), None) Self::Tls(Arc::new(self), None)
} }
/// Fetch the endpoint at `path` in `figment` of kind `kind` (e.g, "tcp")
/// then map the value using `f(Some(value))` if present and `f(None)` if
/// missing into a different value of typr `T`.
///
/// If the conversion succeeds, returns `Ok(value)`. If the conversion fails
/// and `Some` value was passed in, returns an error indicating the endpoint
/// was an invalid `kind` and otherwise returns a "missing field" error.
pub(crate) fn fetch<T, F>(figment: &Figment, kind: &str, path: &str, f: F) -> figment::Result<T>
where F: FnOnce(Option<&Endpoint>) -> Option<T>
{
match figment.extract_inner::<Endpoint>(path) {
Ok(endpoint) => f(Some(&endpoint)).ok_or_else(|| {
let msg = format!("invalid {kind} endpoint: {endpoint:?}");
let mut error = figment::Error::from(msg).with_path(path);
error.profile = Some(figment.profile().clone());
error.metadata = figment.find_metadata(path).cloned();
error
}),
Err(e) if e.missing() => f(None).ok_or(e),
Err(e) => Err(e)
}
}
} }
impl fmt::Display for Endpoint { impl fmt::Display for Endpoint {
@ -180,29 +220,16 @@ impl fmt::Display for Endpoint {
} }
} }
impl From<PathBuf> for Endpoint { impl fmt::Debug for Endpoint {
fn from(value: PathBuf) -> Self { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
Self::Unix(value) match self {
Self::Tcp(a) => write!(f, "tcp:{a}"),
Self::Quic(a) => write!(f, "quic:{a}]"),
Self::Unix(a) => write!(f, "unix:{}", a.display()),
Self::Tls(e, _) => write!(f, "unix:{:?}", &**e),
Self::Custom(e) => e.fmt(f),
} }
} }
#[cfg(unix)]
impl TryFrom<tokio::net::unix::SocketAddr> for Endpoint {
type Error = std::io::Error;
fn try_from(v: tokio::net::unix::SocketAddr) -> Result<Self, Self::Error> {
v.as_pathname()
.ok_or_else(|| std::io::Error::other("unix socket is not path"))
.map(|path| Endpoint::Unix(path.to_path_buf()))
}
}
impl TryFrom<&str> for Endpoint {
type Error = AddrParseError;
fn try_from(value: &str) -> Result<Self, Self::Error> {
value.parse()
}
} }
impl Default for Endpoint { impl Default for Endpoint {
@ -211,21 +238,6 @@ impl Default for Endpoint {
} }
} }
/// Parses an address into a `Endpoint`.
///
/// The syntax is:
///
/// ```text
/// endpoint = 'tcp' ':' socket | 'quic' ':' socket | 'unix' ':' path | socket
/// socket := IP_ADDR | SOCKET_ADDR
/// path := PATH
///
/// IP_ADDR := `std::net::IpAddr` string as defined by Rust
/// SOCKET_ADDR := `std::net::SocketAddr` string as defined by Rust
/// PATH := `PathBuf` (any UTF-8) string as defined by Rust
/// ```
///
/// If `IP_ADDR` is specified in socket, port defaults to `8000`.
impl FromStr for Endpoint { impl FromStr for Endpoint {
type Err = AddrParseError; type Err = AddrParseError;
@ -237,8 +249,6 @@ impl FromStr for Endpoint {
if let Some((proto, string)) = string.split_once(':') { if let Some((proto, string)) = string.split_once(':') {
if proto.trim().as_uncased() == "tcp" { if proto.trim().as_uncased() == "tcp" {
return parse_tcp(string.trim(), 8000).map(Self::Tcp); return parse_tcp(string.trim(), 8000).map(Self::Tcp);
} else if proto.trim().as_uncased() == "quic" {
return parse_tcp(string.trim(), 8000).map(Self::Quic);
} else if proto.trim().as_uncased() == "unix" { } else if proto.trim().as_uncased() == "unix" {
return Ok(Self::Unix(PathBuf::from(string.trim()))); return Ok(Self::Unix(PathBuf::from(string.trim())));
} }
@ -256,7 +266,7 @@ impl<'de> de::Deserialize<'de> for Endpoint {
type Value = Endpoint; type Value = Endpoint;
fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result { fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
formatter.write_str("TCP or Unix address") formatter.write_str("valid TCP (ip) or unix (path) endpoint")
} }
fn visit_str<E: de::Error>(self, v: &str) -> Result<Self::Value, E> { fn visit_str<E: de::Error>(self, v: &str) -> Result<Self::Value, E> {
@ -294,3 +304,37 @@ impl PartialEq<Path> for Endpoint {
self.unix() == Some(other) self.unix() == Some(other)
} }
} }
#[cfg(unix)]
impl TryFrom<tokio::net::unix::SocketAddr> for Endpoint {
type Error = std::io::Error;
fn try_from(v: tokio::net::unix::SocketAddr) -> Result<Self, Self::Error> {
v.as_pathname()
.ok_or_else(|| std::io::Error::other("unix socket is not path"))
.map(|path| Endpoint::Unix(path.to_path_buf()))
}
}
impl TryFrom<&str> for Endpoint {
type Error = AddrParseError;
fn try_from(value: &str) -> Result<Self, Self::Error> {
value.parse()
}
}
macro_rules! impl_from {
($T:ty => $V:ident) => {
impl From<$T> for Endpoint {
fn from(value: $T) -> Self {
Self::$V(value.into())
}
}
}
}
impl_from!(std::net::SocketAddr => Tcp);
impl_from!(std::net::SocketAddrV4 => Tcp);
impl_from!(std::net::SocketAddrV6 => Tcp);
impl_from!(PathBuf => Unix);

View File

@ -5,7 +5,7 @@ use tokio_util::either::Either;
use crate::listener::{Connection, Endpoint}; use crate::listener::{Connection, Endpoint};
pub trait Listener: Send + Sync { pub trait Listener: Sized + Send + Sync {
type Accept: Send; type Accept: Send;
type Connection: Connection; type Connection: Connection;

View File

@ -3,15 +3,12 @@ mod bounced;
mod listener; mod listener;
mod endpoint; mod endpoint;
mod connection; mod connection;
mod bindable; mod bind;
mod default; mod default;
#[cfg(unix)] #[cfg(unix)]
#[cfg_attr(nightly, doc(cfg(unix)))] #[cfg_attr(nightly, doc(cfg(unix)))]
pub mod unix; pub mod unix;
#[cfg(feature = "tls")]
#[cfg_attr(nightly, doc(cfg(feature = "tls")))]
pub mod tls;
pub mod tcp; pub mod tcp;
#[cfg(feature = "http3-preview")] #[cfg(feature = "http3-preview")]
pub mod quic; pub mod quic;
@ -19,7 +16,7 @@ pub mod quic;
pub use endpoint::*; pub use endpoint::*;
pub use listener::*; pub use listener::*;
pub use connection::*; pub use connection::*;
pub use bindable::*; pub use bind::*;
pub use default::*; pub use default::*;
pub(crate) use cancellable::*; pub(crate) use cancellable::*;

View File

@ -38,7 +38,7 @@ use tokio::sync::Mutex;
use tokio_stream::StreamExt; use tokio_stream::StreamExt;
use crate::tls::{TlsConfig, Error}; use crate::tls::{TlsConfig, Error};
use crate::listener::{Listener, Connection, Endpoint}; use crate::listener::Endpoint;
type H3Conn = h3::server::Connection<quic_h3::Connection, bytes::Bytes>; type H3Conn = h3::server::Connection<quic_h3::Connection, bytes::Bytes>;
@ -51,14 +51,16 @@ pub struct QuicListener {
pub struct H3Stream(H3Conn); pub struct H3Stream(H3Conn);
pub struct H3Connection { pub struct H3Connection {
pub handle: quic::connection::Handle, pub(crate) handle: quic::connection::Handle,
pub parts: http::request::Parts, pub(crate) parts: http::request::Parts,
pub tx: QuicTx, pub(crate) tx: QuicTx,
pub rx: QuicRx, pub(crate) rx: QuicRx,
} }
#[doc(hidden)]
pub struct QuicRx(h3::server::RequestStream<quic_h3::RecvStream, Bytes>); pub struct QuicRx(h3::server::RequestStream<quic_h3::RecvStream, Bytes>);
#[doc(hidden)]
pub struct QuicTx(h3::server::RequestStream<quic_h3::SendStream<Bytes>, Bytes>); pub struct QuicTx(h3::server::RequestStream<quic_h3::SendStream<Bytes>, Bytes>);
impl QuicListener { impl QuicListener {
@ -94,25 +96,20 @@ impl QuicListener {
} }
} }
impl Listener for QuicListener { impl QuicListener {
type Accept = quic::Connection; pub async fn accept(&self) -> Option<quic::Connection> {
type Connection = H3Stream;
async fn accept(&self) -> io::Result<Self::Accept> {
self.listener self.listener
.lock().await .lock().await
.accept().await .accept().await
.ok_or_else(|| io::Error::new(io::ErrorKind::BrokenPipe, "closed"))
} }
async fn connect(&self, accept: Self::Accept) -> io::Result<Self::Connection> { pub async fn connect(&self, accept: quic::Connection) -> io::Result<H3Stream> {
let quic_conn = quic_h3::Connection::new(accept); let quic_conn = quic_h3::Connection::new(accept);
let conn = H3Conn::new(quic_conn).await.map_err(io::Error::other)?; let conn = H3Conn::new(quic_conn).await.map_err(io::Error::other)?;
Ok(H3Stream(conn)) Ok(H3Stream(conn))
} }
fn endpoint(&self) -> io::Result<Endpoint> { pub fn endpoint(&self) -> io::Result<Endpoint> {
Ok(Endpoint::Quic(self.endpoint).with_tls(&self.tls)) Ok(Endpoint::Quic(self.endpoint).with_tls(&self.tls))
} }
} }
@ -159,16 +156,8 @@ impl QuicTx {
} }
// FIXME: Expose certificates when possible. // FIXME: Expose certificates when possible.
impl Connection for H3Stream { impl H3Connection {
fn endpoint(&self) -> io::Result<Endpoint> { pub fn endpoint(&self) -> io::Result<Endpoint> {
let addr = self.0.inner.conn.handle().remote_addr()?;
Ok(Endpoint::Quic(addr).assume_tls())
}
}
// FIXME: Expose certificates when possible.
impl Connection for H3Connection {
fn endpoint(&self) -> io::Result<Endpoint> {
let addr = self.handle.remote_addr()?; let addr = self.handle.remote_addr()?;
Ok(Endpoint::Quic(addr).assume_tls()) Ok(Endpoint::Quic(addr).assume_tls())
} }

View File

@ -1,21 +1,61 @@
//! TCP listener.
//!
//! # Configuration
//!
//! Reads the following configuration parameters:
//!
//! | parameter | type | default | note |
//! |-----------|--------------|-------------|---------------------------------|
//! | `address` | [`Endpoint`] | `127.0.0.1` | must be `tcp:ip` |
//! | `port` | `u16` | `8000` | replaces the port in `address ` |
use std::io; use std::io;
use std::net::{Ipv4Addr, SocketAddr};
use either::{Either, Left, Right};
#[doc(inline)] #[doc(inline)]
pub use tokio::net::{TcpListener, TcpStream}; pub use tokio::net::{TcpListener, TcpStream};
use crate::listener::{Listener, Bindable, Connection, Endpoint}; use crate::{Ignite, Rocket};
use crate::listener::{Bind, Connection, Endpoint, Listener};
impl Bindable for std::net::SocketAddr { impl Bind<SocketAddr> for TcpListener {
type Listener = TcpListener; type Error = std::io::Error;
type Error = io::Error; async fn bind(addr: SocketAddr) -> Result<Self, Self::Error> {
Self::bind(addr).await
async fn bind(self) -> Result<Self::Listener, Self::Error> {
TcpListener::bind(self).await
} }
fn bind_endpoint(&self) -> io::Result<Endpoint> { fn bind_endpoint(addr: &SocketAddr) -> Result<Endpoint, Self::Error> {
Ok(Endpoint::Tcp(*self)) Ok(Endpoint::Tcp(*addr))
}
}
impl<'r> Bind<&'r Rocket<Ignite>> for TcpListener {
type Error = Either<figment::Error, io::Error>;
async fn bind(rocket: &'r Rocket<Ignite>) -> Result<Self, Self::Error> {
let endpoint = Self::bind_endpoint(&rocket)?;
let addr = endpoint.tcp()
.ok_or_else(|| io::Error::other("internal error: invalid endpoint"))
.map_err(Right)?;
Self::bind(addr).await.map_err(Right)
}
fn bind_endpoint(rocket: &&'r Rocket<Ignite>) -> Result<Endpoint, Self::Error> {
let figment = rocket.figment();
let mut address = Endpoint::fetch(figment, "tcp", "address", |e| {
let default = SocketAddr::new(Ipv4Addr::LOCALHOST.into(), 8000);
e.map(|e| e.tcp()).unwrap_or(Some(default))
}).map_err(Left)?;
if let Some(port) = figment.extract_inner("port").map_err(Left)? {
address.set_port(port);
}
Ok(Endpoint::Tcp(address))
} }
} }

View File

@ -1,119 +0,0 @@
use std::io;
use std::sync::Arc;
use serde::Deserialize;
use rustls::server::{ServerSessionMemoryCache, ServerConfig, WebPkiClientVerifier};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio_rustls::TlsAcceptor;
use crate::tls::{TlsConfig, Error};
use crate::listener::{Listener, Bindable, Connection, Certificates, Endpoint};
#[doc(inline)]
pub use tokio_rustls::server::TlsStream;
/// A TLS listener over some listener interface L.
pub struct TlsListener<L> {
listener: L,
acceptor: TlsAcceptor,
config: TlsConfig,
}
#[derive(Clone, Deserialize)]
pub struct TlsBindable<I> {
#[serde(flatten)]
pub inner: I,
pub tls: TlsConfig,
}
impl TlsConfig {
pub(crate) fn server_config(&self) -> Result<ServerConfig, Error> {
let provider = Arc::new(self.default_crypto_provider());
#[cfg(feature = "mtls")]
let verifier = match self.mutual {
Some(ref mtls) => {
let ca = Arc::new(mtls.load_ca_certs()?);
let verifier = WebPkiClientVerifier::builder_with_provider(ca, provider.clone());
match mtls.mandatory {
true => verifier.build()?,
false => verifier.allow_unauthenticated().build()?,
}
},
None => WebPkiClientVerifier::no_client_auth(),
};
#[cfg(not(feature = "mtls"))]
let verifier = WebPkiClientVerifier::no_client_auth();
let mut tls_config = ServerConfig::builder_with_provider(provider)
.with_safe_default_protocol_versions()?
.with_client_cert_verifier(verifier)
.with_single_cert(self.load_certs()?, self.load_key()?)?;
tls_config.ignore_client_order = self.prefer_server_cipher_order;
tls_config.session_storage = ServerSessionMemoryCache::new(1024);
tls_config.ticketer = rustls::crypto::ring::Ticketer::new()?;
tls_config.alpn_protocols = vec![b"http/1.1".to_vec()];
if cfg!(feature = "http2") {
tls_config.alpn_protocols.insert(0, b"h2".to_vec());
}
Ok(tls_config)
}
}
impl<I: Bindable> Bindable for TlsBindable<I>
where I::Listener: Listener<Accept = <I::Listener as Listener>::Connection>,
<I::Listener as Listener>::Connection: AsyncRead + AsyncWrite
{
type Listener = TlsListener<I::Listener>;
type Error = Error;
async fn bind(self) -> Result<Self::Listener, Self::Error> {
Ok(TlsListener {
acceptor: TlsAcceptor::from(Arc::new(self.tls.server_config()?)),
listener: self.inner.bind().await.map_err(|e| Error::Bind(Box::new(e)))?,
config: self.tls,
})
}
fn bind_endpoint(&self) -> io::Result<Endpoint> {
let inner = self.inner.bind_endpoint()?;
Ok(inner.with_tls(&self.tls))
}
}
impl<L> Listener for TlsListener<L>
where L: Listener<Accept = <L as Listener>::Connection>,
L::Connection: AsyncRead + AsyncWrite
{
type Accept = L::Connection;
type Connection = TlsStream<L::Connection>;
async fn accept(&self) -> io::Result<Self::Accept> {
self.listener.accept().await
}
async fn connect(&self, conn: L::Connection) -> io::Result<Self::Connection> {
self.acceptor.accept(conn).await
}
fn endpoint(&self) -> io::Result<Endpoint> {
Ok(self.listener.endpoint()?.with_tls(&self.config))
}
}
impl<C: Connection> Connection for TlsStream<C> {
fn endpoint(&self) -> io::Result<Endpoint> {
Ok(self.get_ref().0.endpoint()?.assume_tls())
}
#[cfg(feature = "mtls")]
fn certificates(&self) -> Option<Certificates<'_>> {
let cert_chain = self.get_ref().1.peer_certificates()?;
Some(Certificates::from(cert_chain))
}
}

View File

@ -1,48 +1,49 @@
use std::io; use std::io;
use std::path::PathBuf; use std::path::{Path, PathBuf};
use either::{Either, Left, Right};
use tokio::time::{sleep, Duration}; use tokio::time::{sleep, Duration};
use crate::fs::NamedFile; use crate::fs::NamedFile;
use crate::listener::{Listener, Bindable, Connection, Endpoint}; use crate::listener::{Listener, Bind, Connection, Endpoint};
use crate::util::unix; use crate::util::unix;
use crate::{Ignite, Rocket};
pub use tokio::net::UnixStream; pub use tokio::net::UnixStream;
#[derive(Debug, Clone)] /// Unix domain sockets listener.
pub struct UdsConfig { ///
/// Socket address. /// # Configuration
pub path: PathBuf, ///
/// Recreate a socket that already exists. /// Reads the following configuration parameters:
pub reuse: Option<bool>, ///
} /// | parameter | type | default | note |
/// |-----------|--------------|---------|-------------------------------------------|
pub struct UdsListener { /// | `address` | [`Endpoint`] | | required: must be `unix:path` |
/// | `reuse` | boolean | `true` | whether to create/reuse/delete the socket |
pub struct UnixListener {
path: PathBuf, path: PathBuf,
lock: Option<NamedFile>, lock: Option<NamedFile>,
listener: tokio::net::UnixListener, listener: tokio::net::UnixListener,
} }
impl Bindable for UdsConfig { impl UnixListener {
type Listener = UdsListener; pub async fn bind<P: AsRef<Path>>(path: P, reuse: bool) -> io::Result<Self> {
let path = path.as_ref();
type Error = io::Error; let lock = if reuse {
let lock_ext = match path.extension().and_then(|s| s.to_str()) {
async fn bind(self) -> Result<Self::Listener, Self::Error> {
let lock = if self.reuse.unwrap_or(true) {
let lock_ext = match self.path.extension().and_then(|s| s.to_str()) {
Some(ext) if !ext.is_empty() => format!("{}.lock", ext), Some(ext) if !ext.is_empty() => format!("{}.lock", ext),
_ => "lock".to_string() _ => "lock".to_string()
}; };
let mut opts = tokio::fs::File::options(); let mut opts = tokio::fs::File::options();
opts.create(true).write(true); opts.create(true).write(true);
let lock_path = self.path.with_extension(lock_ext); let lock_path = path.with_extension(lock_ext);
let lock_file = NamedFile::open_with(lock_path, &opts).await?; let lock_file = NamedFile::open_with(lock_path, &opts).await?;
unix::lock_exclusive_nonblocking(lock_file.file())?; unix::lock_exclusive_nonblocking(lock_file.file())?;
if self.path.exists() { if path.exists() {
tokio::fs::remove_file(&self.path).await?; tokio::fs::remove_file(&path).await?;
} }
Some(lock_file) Some(lock_file)
@ -55,9 +56,9 @@ impl Bindable for UdsConfig {
// and this will succeed. So let's try a few times. // and this will succeed. So let's try a few times.
let mut retries = 5; let mut retries = 5;
let listener = loop { let listener = loop {
match tokio::net::UnixListener::bind(&self.path) { match tokio::net::UnixListener::bind(&path) {
Ok(listener) => break listener, Ok(listener) => break listener,
Err(e) if self.path.exists() && lock.is_none() => return Err(e), Err(e) if path.exists() && lock.is_none() => return Err(e),
Err(_) if retries > 0 => { Err(_) if retries > 0 => {
retries -= 1; retries -= 1;
sleep(Duration::from_millis(100)).await; sleep(Duration::from_millis(100)).await;
@ -66,15 +67,31 @@ impl Bindable for UdsConfig {
} }
}; };
Ok(UdsListener { lock, listener, path: self.path, }) Ok(UnixListener { lock, listener, path: path.into() })
}
fn bind_endpoint(&self) -> io::Result<Endpoint> {
Ok(Endpoint::Unix(self.path.clone()))
} }
} }
impl Listener for UdsListener { impl<'r> Bind<&'r Rocket<Ignite>> for UnixListener {
type Error = Either<figment::Error, io::Error>;
async fn bind(rocket: &'r Rocket<Ignite>) -> Result<Self, Self::Error> {
let endpoint = Self::bind_endpoint(&rocket)?;
let path = endpoint.unix()
.ok_or_else(|| Right(io::Error::other("internal error: invalid endpoint")))?;
let reuse: Option<bool> = rocket.figment().extract_inner("reuse").map_err(Left)?;
Ok(Self::bind(path, reuse.unwrap_or(true)).await.map_err(Right)?)
}
fn bind_endpoint(rocket: &&'r Rocket<Ignite>) -> Result<Endpoint, Self::Error> {
let as_pathbuf = |e: Option<&Endpoint>| e.and_then(|e| e.unix().map(|p| p.to_path_buf()));
Endpoint::fetch(rocket.figment(), "unix", "address", as_pathbuf)
.map(Endpoint::Unix)
.map_err(Left)
}
}
impl Listener for UnixListener {
type Accept = UnixStream; type Accept = UnixStream;
type Connection = Self::Accept; type Connection = Self::Accept;
@ -98,7 +115,7 @@ impl Connection for UnixStream {
} }
} }
impl Drop for UdsListener { impl Drop for UnixListener {
fn drop(&mut self) { fn drop(&mut self) {
if let Some(lock) = &self.lock { if let Some(lock) = &self.lock {
let _ = std::fs::remove_file(&self.path); let _ = std::fs::remove_file(&self.path);

View File

@ -154,13 +154,15 @@ impl log::Log for RocketLogger {
} }
} }
static ROCKET_LOGGER_SET: AtomicBool = AtomicBool::new(false);
pub(crate) fn init_default() { pub(crate) fn init_default() {
if !ROCKET_LOGGER_SET.load(Ordering::Acquire) {
crate::log::init(&crate::Config::debug_default()) crate::log::init(&crate::Config::debug_default())
} }
}
pub(crate) fn init(config: &crate::Config) { pub(crate) fn init(config: &crate::Config) {
static ROCKET_LOGGER_SET: AtomicBool = AtomicBool::new(false);
// Try to initialize Rocket's logger, recording if we succeeded. // Try to initialize Rocket's logger, recording if we succeeded.
if log::set_boxed_logger(Box::new(RocketLogger)).is_ok() { if log::set_boxed_logger(Box::new(RocketLogger)).is_ok() {
ROCKET_LOGGER_SET.store(true, Ordering::Release); ROCKET_LOGGER_SET.store(true, Ordering::Release);

View File

@ -14,8 +14,9 @@ use crate::http::Status;
/// ///
/// The request guard implementation succeeds if: /// The request guard implementation succeeds if:
/// ///
/// * MTLS is [configured](crate::mtls).
/// * The client presents certificates. /// * The client presents certificates.
/// * The certificates are active and not yet expired. /// * The certificates are valid and not expired.
/// * The client's certificate chain was signed by the CA identified by the /// * The client's certificate chain was signed by the CA identified by the
/// configured `ca_certs` and with respect to SNI, if any. See [module level /// configured `ca_certs` and with respect to SNI, if any. See [module level
/// docs](crate::mtls) for configuration details. /// docs](crate::mtls) for configuration details.
@ -24,7 +25,7 @@ use crate::http::Status;
/// status of 401 Unauthorized. /// status of 401 Unauthorized.
/// ///
/// If the certificate chain fails to validate or verify, the guard _fails_ with /// If the certificate chain fails to validate or verify, the guard _fails_ with
/// the respective [`Error`]. /// the respective [`Error`] a status of 401 Unauthorized.
/// ///
/// # Wrapping /// # Wrapping
/// ///

View File

@ -79,8 +79,8 @@ pub struct MtlsConfig {
impl MtlsConfig { impl MtlsConfig {
/// Constructs a `MtlsConfig` from a path to a PEM file with a certificate /// Constructs a `MtlsConfig` from a path to a PEM file with a certificate
/// authority `ca_certs` DER-encoded X.509 TLS certificate chain. This /// authority `ca_certs` DER-encoded X.509 TLS certificate chain. This
/// method does no validation; it simply creates a structure suitable for /// method does no validation; it simply creates an [`MtlsConfig`] for later
/// passing into a [`TlsConfig`]. /// use.
/// ///
/// These certificates will be used to verify client-presented certificates /// These certificates will be used to verify client-presented certificates
/// in TLS connections. /// in TLS connections.
@ -101,8 +101,7 @@ impl MtlsConfig {
/// Constructs a `MtlsConfig` from a byte buffer to a certificate authority /// Constructs a `MtlsConfig` from a byte buffer to a certificate authority
/// `ca_certs` DER-encoded X.509 TLS certificate chain. This method does no /// `ca_certs` DER-encoded X.509 TLS certificate chain. This method does no
/// validation; it simply creates a structure suitable for passing into a /// validation; it simply creates an [`MtlsConfig`] for later use.
/// [`TlsConfig`].
/// ///
/// These certificates will be used to verify client-presented certificates /// These certificates will be used to verify client-presented certificates
/// in TLS connections. /// in TLS connections.

View File

@ -1,6 +1,6 @@
use std::convert::Infallible;
use std::fmt::Debug; use std::fmt::Debug;
use std::net::IpAddr; use std::convert::Infallible;
use std::net::{IpAddr, SocketAddr};
use crate::{Request, Route}; use crate::{Request, Route};
use crate::outcome::{self, IntoOutcome, Outcome::*}; use crate::outcome::{self, IntoOutcome, Outcome::*};
@ -496,7 +496,7 @@ impl<'r> FromRequest<'r> for &'r Endpoint {
} }
#[crate::async_trait] #[crate::async_trait]
impl<'r> FromRequest<'r> for std::net::SocketAddr { impl<'r> FromRequest<'r> for SocketAddr {
type Error = Infallible; type Error = Infallible;
async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Infallible> { async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Infallible> {

View File

@ -1,4 +1,4 @@
use std::fmt; use std::{io, fmt};
use std::ops::RangeFrom; use std::ops::RangeFrom;
use std::sync::{Arc, atomic::Ordering}; use std::sync::{Arc, atomic::Ordering};
use std::borrow::Cow; use std::borrow::Cow;
@ -18,7 +18,7 @@ use crate::data::Limits;
use crate::http::ProxyProto; use crate::http::ProxyProto;
use crate::http::{Method, Header, HeaderMap, ContentType, Accept, MediaType, CookieJar, Cookie}; use crate::http::{Method, Header, HeaderMap, ContentType, Accept, MediaType, CookieJar, Cookie};
use crate::http::uri::{fmt::Path, Origin, Segments, Host, Authority}; use crate::http::uri::{fmt::Path, Origin, Segments, Host, Authority};
use crate::listener::{Certificates, Endpoint, Connection}; use crate::listener::{Certificates, Endpoint};
/// The type of an incoming web request. /// The type of an incoming web request.
/// ///
@ -44,11 +44,11 @@ pub(crate) struct ConnectionMeta {
pub peer_certs: Option<Arc<Certificates<'static>>>, pub peer_certs: Option<Arc<Certificates<'static>>>,
} }
impl<C: Connection> From<&C> for ConnectionMeta { impl ConnectionMeta {
fn from(conn: &C) -> Self { pub fn new(endpoint: io::Result<Endpoint>, certs: Option<Certificates<'_>>) -> Self {
ConnectionMeta { ConnectionMeta {
peer_endpoint: conn.endpoint().ok(), peer_endpoint: endpoint.ok(),
peer_certs: conn.certificates().map(|c| c.into_owned()).map(Arc::new), peer_certs: certs.map(|c| c.into_owned()).map(Arc::new),
} }
} }
} }

View File

@ -114,9 +114,8 @@ impl<'r> Builder<'r> {
/// the same name exist, they are all removed, and only the new header and /// the same name exist, they are all removed, and only the new header and
/// value will remain. /// value will remain.
/// ///
/// The type of `header` can be any type that implements `Into<Header>`. /// The type of `header` can be any type that implements `Into<Header>`. See
/// This includes `Header` itself, [`ContentType`](crate::http::ContentType) and /// [trait implementations](Header#trait-implementations).
/// [hyper::header types](crate::http::hyper::header).
/// ///
/// # Example /// # Example
/// ///
@ -144,9 +143,8 @@ impl<'r> Builder<'r> {
/// `Response`. This allows for multiple headers with the same name and /// `Response`. This allows for multiple headers with the same name and
/// potentially different values to be present in the `Response`. /// potentially different values to be present in the `Response`.
/// ///
/// The type of `header` can be any type that implements `Into<Header>`. /// The type of `header` can be any type that implements `Into<Header>`. See
/// This includes `Header` itself, [`ContentType`](crate::http::ContentType) /// [trait implementations](Header#trait-implementations).
/// and [`Accept`](crate::http::Accept).
/// ///
/// # Example /// # Example
/// ///
@ -641,9 +639,8 @@ impl<'r> Response<'r> {
/// Sets the header `header` in `self`. Any existing headers with the name /// Sets the header `header` in `self`. Any existing headers with the name
/// `header.name` will be lost, and only `header` will remain. The type of /// `header.name` will be lost, and only `header` will remain. The type of
/// `header` can be any type that implements `Into<Header>`. This includes /// `header` can be any type that implements `Into<Header>`. See [trait
/// `Header` itself, [`ContentType`](crate::http::ContentType) and /// implementations](Header#trait-implementations).
/// [`hyper::header` types](crate::http::hyper::header).
/// ///
/// # Example /// # Example
/// ///
@ -723,10 +720,7 @@ impl<'r> Response<'r> {
/// Adds a custom header with name `name` and value `value` to `self`. If /// Adds a custom header with name `name` and value `value` to `self`. If
/// `self` already contains headers with the name `name`, another header /// `self` already contains headers with the name `name`, another header
/// with the same `name` and `value` is added. The type of `header` can be /// with the same `name` and `value` is added.
/// any type implements `Into<Header>`. This includes `Header` itself,
/// [`ContentType`](crate::http::ContentType) and [`hyper::header`
/// types](crate::http::hyper::header).
/// ///
/// # Example /// # Example
/// ///

View File

@ -2,15 +2,16 @@ use std::fmt;
use std::ops::{Deref, DerefMut}; use std::ops::{Deref, DerefMut};
use std::sync::Arc; use std::sync::Arc;
use std::time::Duration; use std::time::Duration;
use std::any::Any;
use futures::TryFutureExt;
use yansi::Paint; use yansi::Paint;
use either::Either; use either::Either;
use figment::{Figment, Provider}; use figment::{Figment, Provider};
use tokio::io::{AsyncRead, AsyncWrite};
use crate::shutdown::{Stages, Shutdown}; use crate::shutdown::{Stages, Shutdown};
use crate::{sentinel, shield::Shield, Catcher, Config, Route}; use crate::{sentinel, shield::Shield, Catcher, Config, Route};
use crate::listener::{Bindable, DefaultListener, Endpoint, Listener}; use crate::listener::{Bind, DefaultListener, Endpoint, Listener};
use crate::router::Router; use crate::router::Router;
use crate::fairing::{Fairing, Fairings}; use crate::fairing::{Fairing, Fairings};
use crate::phase::{Phase, Build, Building, Ignite, Igniting, Orbit, Orbiting}; use crate::phase::{Phase, Build, Building, Ignite, Igniting, Orbit, Orbiting};
@ -681,19 +682,34 @@ impl Rocket<Ignite> {
rocket rocket
} }
async fn _launch(self) -> Result<Rocket<Ignite>, Error> { async fn _launch_with<B>(self) -> Result<Rocket<Ignite>, Error>
let config = self.figment().extract::<DefaultListener>()?; where B: for<'r> Bind<&'r Rocket<Ignite>>
either::for_both!(config.base_bindable()?, base => { {
either::for_both!(config.tls_bindable(base), bindable => { let bind_endpoint = B::bind_endpoint(&&self).ok();
self._launch_on(bindable).await let listener: B = B::bind(&self).await
.map_err(|e| ErrorKind::Bind(bind_endpoint, Box::new(e)))?;
let any: Box<dyn Any + Send + Sync> = Box::new(listener);
match any.downcast::<DefaultListener>() {
Ok(listener) => {
let listener = *listener;
crate::util::for_both!(listener, listener => {
crate::util::for_both!(listener, listener => {
self._launch_on(listener).await
}) })
}) })
} }
Err(any) => {
let listener = *any.downcast::<B>().unwrap();
self._launch_on(listener).await
}
}
}
async fn _launch_on<B: Bindable>(self, bindable: B) -> Result<Rocket<Ignite>, Error> async fn _launch_on<L>(self, listener: L) -> Result<Rocket<Ignite>, Error>
where <B::Listener as Listener>::Connection: AsyncRead + AsyncWrite where L: Listener + 'static,
{ {
let rocket = self.bind_and_serve(bindable, |rocket| async move { let rocket = self.listen_and_serve(listener, |rocket| async move {
let rocket = Arc::new(rocket); let rocket = Arc::new(rocket);
rocket.shutdown.spawn_listener(&rocket.config.shutdown); rocket.shutdown.spawn_listener(&rocket.config.shutdown);
@ -996,19 +1012,31 @@ impl<P: Phase> Rocket<P> {
/// } /// }
/// ``` /// ```
pub async fn launch(self) -> Result<Rocket<Ignite>, Error> { pub async fn launch(self) -> Result<Rocket<Ignite>, Error> {
self.launch_with::<DefaultListener>().await
}
pub async fn bind_launch<T, B: Bind<T>>(self, value: T) -> Result<Rocket<Ignite>, Error> {
let endpoint = B::bind_endpoint(&value).ok();
let listener = B::bind(value).map_err(|e| ErrorKind::Bind(endpoint, Box::new(e)));
self.launch_on(listener.await?).await
}
pub async fn launch_with<B>(self) -> Result<Rocket<Ignite>, Error>
where B: for<'r> Bind<&'r Rocket<Ignite>>
{
match self.0.into_state() { match self.0.into_state() {
State::Build(s) => Rocket::from(s).ignite().await?._launch().await, State::Build(s) => Rocket::from(s).ignite().await?._launch_with::<B>().await,
State::Ignite(s) => Rocket::from(s)._launch().await, State::Ignite(s) => Rocket::from(s)._launch_with::<B>().await,
State::Orbit(s) => Ok(Rocket::from(s).into_ignite()) State::Orbit(s) => Ok(Rocket::from(s).into_ignite())
} }
} }
pub async fn launch_on<B: Bindable>(self, bindable: B) -> Result<Rocket<Ignite>, Error> pub async fn launch_on<L>(self, listener: L) -> Result<Rocket<Ignite>, Error>
where <B::Listener as Listener>::Connection: AsyncRead + AsyncWrite where L: Listener + 'static,
{ {
match self.0.into_state() { match self.0.into_state() {
State::Build(s) => Rocket::from(s).ignite().await?._launch_on(bindable).await, State::Build(s) => Rocket::from(s).ignite().await?._launch_on(listener).await,
State::Ignite(s) => Rocket::from(s)._launch_on(bindable).await, State::Ignite(s) => Rocket::from(s)._launch_on(listener).await,
State::Orbit(s) => Ok(Rocket::from(s).into_ignite()) State::Orbit(s) => Ok(Rocket::from(s).into_ignite())
} }
} }

View File

@ -6,14 +6,14 @@ use std::time::Duration;
use hyper::service::service_fn; use hyper::service::service_fn;
use hyper_util::rt::{TokioExecutor, TokioIo, TokioTimer}; use hyper_util::rt::{TokioExecutor, TokioIo, TokioTimer};
use hyper_util::server::conn::auto::Builder; use hyper_util::server::conn::auto::Builder;
use futures::{Future, TryFutureExt, future::Either::*}; use futures::{Future, TryFutureExt};
use tokio::io::{AsyncRead, AsyncWrite}; use tokio::io::{AsyncRead, AsyncWrite};
use crate::{Ignite, Orbit, Request, Rocket}; use crate::{Ignite, Orbit, Request, Rocket};
use crate::request::ConnectionMeta; use crate::request::ConnectionMeta;
use crate::erased::{ErasedRequest, ErasedResponse, ErasedIoHandler}; use crate::erased::{ErasedRequest, ErasedResponse, ErasedIoHandler};
use crate::listener::{Bindable, BouncedExt, CancellableExt, Listener}; use crate::listener::{Listener, Connection, BouncedExt, CancellableExt};
use crate::error::{log_server_error, ErrorKind}; use crate::error::log_server_error;
use crate::data::{IoStream, RawStream}; use crate::data::{IoStream, RawStream};
use crate::util::{spawn_inspect, FutureExt, ReaderStream}; use crate::util::{spawn_inspect, FutureExt, ReaderStream};
use crate::http::Status; use crate::http::Status;
@ -91,31 +91,28 @@ async fn io_handler_task<S>(stream: S, mut handler: ErasedIoHandler)
} }
impl Rocket<Ignite> { impl Rocket<Ignite> {
pub(crate) async fn bind_and_serve<B, R>( pub(crate) async fn listen_and_serve<L, R>(
self, self,
bindable: B, listener: L,
post_bind_callback: impl FnOnce(Rocket<Orbit>) -> R, orbit_callback: impl FnOnce(Rocket<Orbit>) -> R,
) -> Result<Arc<Rocket<Orbit>>> ) -> Result<Arc<Rocket<Orbit>>>
where B: Bindable, where L: Listener + 'static,
<B::Listener as Listener>::Connection: AsyncRead + AsyncWrite,
R: Future<Output = Result<Arc<Rocket<Orbit>>>> R: Future<Output = Result<Arc<Rocket<Orbit>>>>
{ {
let binding_endpoint = bindable.bind_endpoint().ok(); let endpoint = listener.endpoint()?;
let h12listener = bindable.bind()
.map_err(|e| ErrorKind::Bind(binding_endpoint, Box::new(e)))
.await?;
let endpoint = h12listener.endpoint()?;
#[cfg(feature = "http3-preview")] #[cfg(feature = "http3-preview")]
if let (Some(addr), Some(tls)) = (endpoint.tcp(), endpoint.tls_config()) { if let (Some(addr), Some(tls)) = (endpoint.tcp(), endpoint.tls_config()) {
use crate::error::ErrorKind;
let h3listener = crate::listener::quic::QuicListener::bind(addr, tls.clone()) let h3listener = crate::listener::quic::QuicListener::bind(addr, tls.clone())
.map_err(|e| ErrorKind::Bind(Some(endpoint.clone()), Box::new(e))) .map_err(|e| ErrorKind::Bind(Some(endpoint.clone()), Box::new(e)))
.await?; .await?;
let rocket = self.into_orbit(vec![h3listener.endpoint()?, endpoint]); let rocket = self.into_orbit(vec![h3listener.endpoint()?, endpoint]);
let rocket = post_bind_callback(rocket).await?; let rocket = orbit_callback(rocket).await?;
let http12 = tokio::task::spawn(rocket.clone().serve12(h12listener)); let http12 = tokio::task::spawn(rocket.clone().serve12(listener));
let http3 = tokio::task::spawn(rocket.clone().serve3(h3listener)); let http3 = tokio::task::spawn(rocket.clone().serve3(h3listener));
let (r1, r2) = tokio::join!(http12, http3); let (r1, r2) = tokio::join!(http12, http3);
r1.map_err(|e| ErrorKind::Liftoff(Err(rocket.clone()), Box::new(e)))??; r1.map_err(|e| ErrorKind::Liftoff(Err(rocket.clone()), Box::new(e)))??;
@ -129,8 +126,8 @@ impl Rocket<Ignite> {
} }
let rocket = self.into_orbit(vec![endpoint]); let rocket = self.into_orbit(vec![endpoint]);
let rocket = post_bind_callback(rocket).await?; let rocket = orbit_callback(rocket).await?;
rocket.clone().serve12(h12listener).await?; rocket.clone().serve12(listener).await?;
Ok(rocket) Ok(rocket)
} }
} }
@ -160,11 +157,11 @@ impl Rocket<Orbit> {
} }
let (listener, server) = (Arc::new(listener.bounced()), Arc::new(builder)); let (listener, server) = (Arc::new(listener.bounced()), Arc::new(builder));
while let Some(accept) = listener.accept().unless(self.shutdown()).await? { while let Some(accept) = listener.accept().race(self.shutdown()).await.left().transpose()? {
let (listener, rocket, server) = (listener.clone(), self.clone(), server.clone()); let (listener, rocket, server) = (listener.clone(), self.clone(), server.clone());
spawn_inspect(|e| log_server_error(&**e), async move { spawn_inspect(|e| log_server_error(&**e), async move {
let conn = listener.connect(accept).io_unless(rocket.shutdown()).await?; let conn = listener.connect(accept).race_io(rocket.shutdown()).await?;
let meta = ConnectionMeta::from(&conn); let meta = ConnectionMeta::new(conn.endpoint(), conn.certificates());
let service = service_fn(|mut req| { let service = service_fn(|mut req| {
let upgrade = hyper::upgrade::on(&mut req); let upgrade = hyper::upgrade::on(&mut req);
let (parts, incoming) = req.into_parts(); let (parts, incoming) = req.into_parts();
@ -173,9 +170,9 @@ impl Rocket<Orbit> {
let io = TokioIo::new(conn.cancellable(rocket.shutdown.clone())); let io = TokioIo::new(conn.cancellable(rocket.shutdown.clone()));
let mut server = pin!(server.serve_connection_with_upgrades(io, service)); let mut server = pin!(server.serve_connection_with_upgrades(io, service));
match server.as_mut().or(rocket.shutdown()).await { match server.as_mut().race(rocket.shutdown()).await.left() {
Left(result) => result, Some(result) => result,
Right(()) => { None => {
server.as_mut().graceful_shutdown(); server.as_mut().graceful_shutdown();
server.await server.await
}, },
@ -189,26 +186,26 @@ impl Rocket<Orbit> {
#[cfg(feature = "http3-preview")] #[cfg(feature = "http3-preview")]
async fn serve3(self: Arc<Self>, listener: crate::listener::quic::QuicListener) -> Result<()> { async fn serve3(self: Arc<Self>, listener: crate::listener::quic::QuicListener) -> Result<()> {
let rocket = self.clone(); let rocket = self.clone();
let listener = Arc::new(listener.bounced()); let listener = Arc::new(listener);
while let Some(accept) = listener.accept().unless(rocket.shutdown()).await? { while let Some(Some(accept)) = listener.accept().race(rocket.shutdown()).await.left() {
let (listener, rocket) = (listener.clone(), rocket.clone()); let (listener, rocket) = (listener.clone(), rocket.clone());
spawn_inspect(|e: &io::Error| log_server_error(e), async move { spawn_inspect(|e: &io::Error| log_server_error(e), async move {
let mut stream = listener.connect(accept).io_unless(rocket.shutdown()).await?; let mut stream = listener.connect(accept).race_io(rocket.shutdown()).await?;
while let Some(mut conn) = stream.accept().io_unless(rocket.shutdown()).await? { while let Some(mut conn) = stream.accept().race_io(rocket.shutdown()).await? {
let rocket = rocket.clone(); let rocket = rocket.clone();
spawn_inspect(|e: &io::Error| log_server_error(e), async move { spawn_inspect(|e: &io::Error| log_server_error(e), async move {
let meta = ConnectionMeta::from(&conn); let meta = ConnectionMeta::new(conn.endpoint(), None);
let rx = conn.rx.cancellable(rocket.shutdown.clone()); let rx = conn.rx.cancellable(rocket.shutdown.clone());
let response = rocket.clone() let response = rocket.clone()
.service(conn.parts, rx, None, meta) .service(conn.parts, rx, None, meta)
.map_err(io::Error::other) .map_err(io::Error::other)
.io_unless(rocket.shutdown.mercy.clone()) .race_io(rocket.shutdown.mercy.clone())
.await?; .await?;
let grace = rocket.shutdown.grace.clone(); let grace = rocket.shutdown.grace.clone();
match conn.tx.send_response(response).or(grace).await { match conn.tx.send_response(response).race(grace).await.left() {
Left(result) => result, Some(result) => result,
Right(_) => Ok(conn.tx.cancel()), None => Ok(conn.tx.cancel()),
} }
}); });
} }

View File

@ -88,7 +88,7 @@ impl Shutdown {
/// This function returns immediately; pending requests will continue to run /// This function returns immediately; pending requests will continue to run
/// until completion or expiration of the grace period, which ever comes /// until completion or expiration of the grace period, which ever comes
/// first, before the actual shutdown occurs. The grace period can be /// first, before the actual shutdown occurs. The grace period can be
/// configured via [`Shutdown::grace`](crate::config::ShutdownConfig::grace). /// configured via [`ShutdownConfig`]'s `grace` field.
/// ///
/// ```rust /// ```rust
/// # use rocket::*; /// # use rocket::*;

View File

@ -1,11 +1,15 @@
use std::io; use std::io;
use std::sync::Arc;
use rustls::crypto::{ring, CryptoProvider}; use futures::TryFutureExt;
use rustls::pki_types::{CertificateDer, PrivateKeyDer};
use figment::value::magic::{Either, RelativePathBuf}; use figment::value::magic::{Either, RelativePathBuf};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use indexmap::IndexSet; use indexmap::IndexSet;
use rustls::crypto::{ring, CryptoProvider};
use rustls::pki_types::{CertificateDer, PrivateKeyDer};
use rustls::server::{ServerSessionMemoryCache, ServerConfig, WebPkiClientVerifier};
use crate::tls::resolver::DynResolver;
use crate::tls::error::{Result, Error, KeyError}; use crate::tls::error::{Result, Error, KeyError};
/// TLS configuration: certificate chain, key, and ciphersuites. /// TLS configuration: certificate chain, key, and ciphersuites.
@ -35,7 +39,8 @@ use crate::tls::error::{Result, Error, KeyError};
/// ///
/// Additionally, the `mutual` parameter controls if and how the server /// Additionally, the `mutual` parameter controls if and how the server
/// authenticates clients via mutual TLS. It works in concert with the /// authenticates clients via mutual TLS. It works in concert with the
/// [`mtls`](crate::mtls) module. See [`MtlsConfig`] for configuration details. /// [`mtls`](crate::mtls) module. See [`MtlsConfig`](crate::mtls::MtlsConfig)
/// for configuration details.
/// ///
/// In `Rocket.toml`, configuration might look like: /// In `Rocket.toml`, configuration might look like:
/// ///
@ -78,7 +83,7 @@ use crate::tls::error::{Result, Error, KeyError};
/// # assert_eq!(tls_config.ciphers().count(), 9); /// # assert_eq!(tls_config.ciphers().count(), 9);
/// # assert!(!tls_config.prefer_server_cipher_order()); /// # assert!(!tls_config.prefer_server_cipher_order());
/// ``` /// ```
#[derive(PartialEq, Debug, Clone, Deserialize, Serialize)] #[derive(Debug, PartialEq, Clone, Deserialize, Serialize)]
pub struct TlsConfig { pub struct TlsConfig {
/// Path to a PEM file with, or raw bytes for, a DER-encoded X.509 TLS /// Path to a PEM file with, or raw bytes for, a DER-encoded X.509 TLS
/// certificate chain. /// certificate chain.
@ -97,6 +102,8 @@ pub struct TlsConfig {
#[cfg(feature = "mtls")] #[cfg(feature = "mtls")]
#[cfg_attr(nightly, doc(cfg(feature = "mtls")))] #[cfg_attr(nightly, doc(cfg(feature = "mtls")))]
pub(crate) mutual: Option<crate::mtls::MtlsConfig>, pub(crate) mutual: Option<crate::mtls::MtlsConfig>,
#[serde(skip)]
pub(crate) resolver: Option<DynResolver>,
} }
/// A supported TLS cipher suite. /// A supported TLS cipher suite.
@ -134,6 +141,7 @@ impl Default for TlsConfig {
prefer_server_cipher_order: false, prefer_server_cipher_order: false,
#[cfg(feature = "mtls")] #[cfg(feature = "mtls")]
mutual: None, mutual: None,
resolver: None,
} }
} }
} }
@ -430,8 +438,57 @@ impl TlsConfig {
self.mutual.as_ref() self.mutual.as_ref()
} }
pub fn validate(&self) -> Result<(), crate::tls::Error> { /// Try to convert `self` into a [rustls] [`ServerConfig`].
self.server_config().map(|_| ()) ///
/// [`ServerConfig`]: rustls::server::ServerConfig
pub async fn server_config(&self) -> Result<rustls::server::ServerConfig> {
let this = self.clone();
tokio::task::spawn_blocking(move || this._server_config())
.map_err(io::Error::other)
.await?
}
/// Try to convert `self` into a [rustls] [`ServerConfig`].
///
/// [`ServerConfig`]: rustls::server::ServerConfig
pub(crate) fn _server_config(&self) -> Result<rustls::server::ServerConfig> {
let provider = Arc::new(self.default_crypto_provider());
#[cfg(feature = "mtls")]
let verifier = match self.mutual {
Some(ref mtls) => {
let ca = Arc::new(mtls.load_ca_certs()?);
let verifier = WebPkiClientVerifier::builder_with_provider(ca, provider.clone());
match mtls.mandatory {
true => verifier.build()?,
false => verifier.allow_unauthenticated().build()?,
}
},
None => WebPkiClientVerifier::no_client_auth(),
};
#[cfg(not(feature = "mtls"))]
let verifier = WebPkiClientVerifier::no_client_auth();
let mut tls_config = ServerConfig::builder_with_provider(provider)
.with_safe_default_protocol_versions()?
.with_client_cert_verifier(verifier)
.with_single_cert(self.load_certs()?, self.load_key()?)?;
tls_config.ignore_client_order = self.prefer_server_cipher_order;
tls_config.session_storage = ServerSessionMemoryCache::new(1024);
tls_config.ticketer = rustls::crypto::ring::Ticketer::new()?;
tls_config.alpn_protocols = vec![b"http/1.1".to_vec()];
if cfg!(feature = "http2") {
tls_config.alpn_protocols.insert(0, b"h2".to_vec());
}
Ok(tls_config)
}
/// NOTE: This is a blocking function.
pub fn validate(&self) -> Result<()> {
self._server_config().map(|_| ())
} }
} }

View File

@ -17,6 +17,7 @@ pub enum Error {
CertChain(std::io::Error), CertChain(std::io::Error),
PrivKey(KeyError), PrivKey(KeyError),
CertAuth(rustls::Error), CertAuth(rustls::Error),
Config(figment::Error),
} }
impl std::fmt::Display for Error { impl std::fmt::Display for Error {
@ -31,6 +32,7 @@ impl std::fmt::Display for Error {
PrivKey(e) => write!(f, "failed to process private key: {e}"), PrivKey(e) => write!(f, "failed to process private key: {e}"),
CertAuth(e) => write!(f, "failed to process certificate authority: {e}"), CertAuth(e) => write!(f, "failed to process certificate authority: {e}"),
Bind(e) => write!(f, "failed to bind to network interface: {e}"), Bind(e) => write!(f, "failed to bind to network interface: {e}"),
Config(e) => write!(f, "failed to read tls configuration: {e}"),
} }
} }
} }
@ -69,6 +71,7 @@ impl std::error::Error for Error {
Error::PrivKey(e) => Some(e), Error::PrivKey(e) => Some(e),
Error::CertAuth(e) => Some(e), Error::CertAuth(e) => Some(e),
Error::Bind(e) => Some(&**e), Error::Bind(e) => Some(&**e),
Error::Config(e) => Some(e),
} }
} }
} }
@ -102,3 +105,9 @@ impl From<std::convert::Infallible> for Error {
v.into() v.into()
} }
} }
impl From<figment::Error> for Error {
fn from(value: figment::Error) -> Self {
Error::Config(value)
}
}

View File

@ -0,0 +1,104 @@
use std::io;
use std::sync::Arc;
use futures::TryFutureExt;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio_rustls::LazyConfigAcceptor;
use rustls::server::{Acceptor, ServerConfig};
use crate::{Ignite, Rocket};
use crate::listener::{Bind, Certificates, Connection, Endpoint, Listener};
use crate::tls::{Error, TlsConfig};
use super::resolver::DynResolver;
#[doc(inline)]
pub use tokio_rustls::server::TlsStream;
/// A TLS listener over some listener interface L.
pub struct TlsListener<L> {
listener: L,
config: TlsConfig,
default: Arc<ServerConfig>,
}
impl<T: Send, L: Bind<T>> Bind<(T, TlsConfig)> for TlsListener<L>
where L: Listener<Accept = <L as Listener>::Connection>,
{
type Error = Error;
async fn bind((inner, config): (T, TlsConfig)) -> Result<Self, Self::Error> {
Ok(TlsListener {
default: Arc::new(config.server_config().await?),
listener: L::bind(inner).map_err(|e| Error::Bind(Box::new(e))).await?,
config,
})
}
fn bind_endpoint((inner, config): &(T, TlsConfig)) -> Result<Endpoint, Self::Error> {
L::bind_endpoint(inner)
.map(|e| e.with_tls(config))
.map_err(|e| Error::Bind(Box::new(e)))
}
}
impl<'r, L> Bind<&'r Rocket<Ignite>> for TlsListener<L>
where L: Bind<&'r Rocket<Ignite>> + Listener<Accept = <L as Listener>::Connection>
{
type Error = Error;
async fn bind(rocket: &'r Rocket<Ignite>) -> Result<Self, Self::Error> {
let mut config: TlsConfig = rocket.figment().extract_inner("tls")?;
config.resolver = DynResolver::extract(rocket);
<Self as Bind<_>>::bind((rocket, config)).await
}
fn bind_endpoint(rocket: &&'r Rocket<Ignite>) -> Result<Endpoint, Self::Error> {
let config: TlsConfig = rocket.figment().extract_inner("tls")?;
<Self as Bind<_>>::bind_endpoint(&(*rocket, config))
}
}
impl<L> Listener for TlsListener<L>
where L: Listener<Accept = <L as Listener>::Connection>,
L::Connection: AsyncRead + AsyncWrite
{
type Accept = L::Connection;
type Connection = TlsStream<L::Connection>;
async fn accept(&self) -> io::Result<Self::Accept> {
self.listener.accept().await
}
async fn connect(&self, conn: L::Connection) -> io::Result<Self::Connection> {
let acceptor = LazyConfigAcceptor::new(Acceptor::default(), conn);
let handshake = acceptor.await?;
let hello = handshake.client_hello();
let config = match &self.config.resolver {
Some(r) => r.resolve(hello).await.unwrap_or_else(|| self.default.clone()),
None => self.default.clone(),
};
handshake.into_stream(config).await
}
fn endpoint(&self) -> io::Result<Endpoint> {
Ok(self.listener.endpoint()?.with_tls(&self.config))
}
}
impl<C: Connection> Connection for TlsStream<C> {
fn endpoint(&self) -> io::Result<Endpoint> {
Ok(self.get_ref().0.endpoint()?.assume_tls())
}
fn certificates(&self) -> Option<Certificates<'_>> {
#[cfg(feature = "mtls")] {
let cert_chain = self.get_ref().1.peer_certificates()?;
Some(Certificates::from(cert_chain))
}
#[cfg(not(feature = "mtls"))]
None
}
}

View File

@ -1,6 +1,9 @@
mod error; mod error;
mod resolver;
mod listener;
pub(crate) mod config; pub(crate) mod config;
pub use error::Result; pub use error::{Error, Result};
pub use config::{TlsConfig, CipherSuite}; pub use config::{TlsConfig, CipherSuite};
pub use error::Error; pub use resolver::{Resolver, ClientHello, ServerConfig};
pub use listener::{TlsListener, TlsStream};

View File

@ -0,0 +1,115 @@
use std::fmt;
use std::marker::PhantomData;
use std::ops::Deref;
use std::sync::Arc;
pub use rustls::server::{ClientHello, ServerConfig};
use crate::{Build, Ignite, Rocket};
use crate::fairing::{self, Info, Kind};
/// Proxy type to get PartialEq + Debug impls.
#[derive(Clone)]
pub(crate) struct DynResolver(Arc<dyn Resolver>);
pub struct Fairing<T: ?Sized>(PhantomData<T>);
/// A dynamic TLS configuration resolver.
///
/// # Example
///
/// This is an async trait. Implement it as follows:
///
/// ```rust
/// # #[macro_use] extern crate rocket;
/// use std::sync::Arc;
/// use rocket::tls::{self, Resolver, TlsConfig, ClientHello, ServerConfig};
/// use rocket::{Rocket, Build};
///
/// struct MyResolver(Arc<ServerConfig>);
///
/// #[rocket::async_trait]
/// impl Resolver for MyResolver {
/// async fn init(rocket: &Rocket<Build>) -> tls::Result<Self> {
/// // This is equivalent to what the default resolver would do.
/// let config: TlsConfig = rocket.figment().extract_inner("tls")?;
/// let server_config = config.server_config().await?;
/// Ok(MyResolver(Arc::new(server_config)))
/// }
///
/// async fn resolve(&self, hello: ClientHello<'_>) -> Option<Arc<ServerConfig>> {
/// // return a `ServerConfig` based on `hello`; here we ignore it
/// Some(self.0.clone())
/// }
/// }
///
/// #[launch]
/// fn rocket() -> _ {
/// rocket::build().attach(MyResolver::fairing())
/// }
/// ```
#[crate::async_trait]
pub trait Resolver: Send + Sync + 'static {
async fn init(rocket: &Rocket<Build>) -> crate::tls::Result<Self> where Self: Sized {
let _rocket = rocket;
let type_name = std::any::type_name::<Self>();
Err(figment::Error::from(format!("{type_name}: Resolver::init() unimplemented")).into())
}
async fn resolve(&self, hello: ClientHello<'_>) -> Option<Arc<ServerConfig>>;
fn fairing() -> Fairing<Self> where Self: Sized {
Fairing(PhantomData)
}
}
#[crate::async_trait]
impl<T: Resolver> fairing::Fairing for Fairing<T> {
fn info(&self) -> Info {
Info {
name: "Resolver Fairing",
kind: Kind::Ignite | Kind::Singleton
}
}
async fn on_ignite(&self, rocket: Rocket<Build>) -> fairing::Result {
use yansi::Paint;
let result = T::init(&rocket).await;
match result {
Ok(resolver) => Ok(rocket.manage(Arc::new(resolver) as Arc<dyn Resolver>)),
Err(e) => {
let name = std::any::type_name::<T>();
error!("TLS resolver {} failed to initialize.", name.primary().bold());
error_!("{e}");
Err(rocket)
}
}
}
}
impl DynResolver {
pub fn extract(rocket: &Rocket<Ignite>) -> Option<Self> {
rocket.state::<Arc<dyn Resolver>>().map(|r| Self(r.clone()))
}
}
impl fmt::Debug for DynResolver {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_tuple("Resolver").finish()
}
}
impl PartialEq for DynResolver {
fn eq(&self, _: &Self) -> bool {
false
}
}
impl Deref for DynResolver {
type Target = dyn Resolver;
fn deref(&self) -> &Self::Target {
&*self.0
}
}

View File

@ -22,39 +22,39 @@ pub fn spawn_inspect<E, F, Fut>(or: F, future: Fut)
use std::io; use std::io;
use std::pin::pin; use std::pin::pin;
use std::future::Future; use std::future::Future;
use futures::future::{select, Either}; use either::Either;
use futures::future;
pub trait FutureExt: Future + Sized { pub trait FutureExt: Future + Sized {
/// Await `self` or `other`, whichever finishes first. /// Await `self` or `other`, whichever finishes first.
async fn or<B: Future>(self, other: B) -> Either<Self::Output, B::Output> { async fn race<B: Future>(self, other: B) -> Either<Self::Output, B::Output> {
match futures::future::select(pin!(self), pin!(other)).await { match future::select(pin!(self), pin!(other)).await {
Either::Left((v, _)) => Either::Left(v), future::Either::Left((v, _)) => Either::Left(v),
Either::Right((v, _)) => Either::Right(v), future::Either::Right((v, _)) => Either::Right(v),
} }
} }
/// Await `self` unless `trigger` completes. Returns `Ok(Some(T))` if `self` async fn race_io<T, K: Future>(self, trigger: K) -> io::Result<T>
/// completes successfully before `trigger`, `Err(E)` if `self` completes where Self: Future<Output = io::Result<T>>
/// unsuccessfully, and `Ok(None)` if `trigger` completes before `self`.
async fn unless<T, E, K: Future>(self, trigger: K) -> Result<Option<T>, E>
where Self: Future<Output = Result<T, E>>
{ {
match select(pin!(self), pin!(trigger)).await { match future::select(pin!(self), pin!(trigger)).await {
Either::Left((v, _)) => Ok(Some(v?)), future::Either::Left((v, _)) => v,
Either::Right((_, _)) => Ok(None), future::Either::Right((_, _)) => Err(io::Error::other("i/o terminated")),
}
}
/// Await `self` unless `trigger` completes. If `self` completes before
/// `trigger`, returns the result. Otherwise, always returns an `Err`.
async fn io_unless<T, K: Future>(self, trigger: K) -> std::io::Result<T>
where Self: Future<Output = std::io::Result<T>>
{
match select(pin!(self), pin!(trigger)).await {
Either::Left((v, _)) => v,
Either::Right((_, _)) => Err(io::Error::other("I/O terminated")),
} }
} }
} }
impl<F: Future + Sized> FutureExt for F { } impl<F: Future + Sized> FutureExt for F { }
#[doc(hidden)]
#[macro_export]
macro_rules! for_both {
($value:expr, $pattern:pat => $result:expr) => {
match $value {
tokio_util::either::Either::Left($pattern) => $result,
tokio_util::either::Either::Right($pattern) => $result,
}
};
}
pub use for_both;

View File

@ -3,6 +3,7 @@ use std::net::{SocketAddr, Ipv4Addr};
use rocket::config::Config; use rocket::config::Config;
use rocket::fairing::AdHoc; use rocket::fairing::AdHoc;
use rocket::futures::channel::oneshot; use rocket::futures::channel::oneshot;
use rocket::listener::tcp::TcpListener;
#[rocket::async_test] #[rocket::async_test]
async fn on_ignite_fairing_can_inspect_port() { async fn on_ignite_fairing_can_inspect_port() {
@ -15,6 +16,7 @@ async fn on_ignite_fairing_can_inspect_port() {
}) })
})); }));
rocket::tokio::spawn(rocket.launch_on(SocketAddr::from((Ipv4Addr::LOCALHOST, 0)))); let addr = SocketAddr::from((Ipv4Addr::LOCALHOST, 0));
rocket::tokio::spawn(rocket.bind_launch::<_, TcpListener>(addr));
assert_ne!(rx.await.unwrap(), 0); assert_ne!(rx.await.unwrap(), 0);
} }

View File

@ -10,7 +10,7 @@ rocket = { path = "../../core/lib", features = ["secrets"] }
[dev-dependencies] [dev-dependencies]
rocket = { path = "../../core/lib", features = ["secrets", "json", "mtls"] } rocket = { path = "../../core/lib", features = ["secrets", "json", "mtls"] }
figment = { version = "0.10", features = ["toml", "env"] } figment = { version = "0.10.17", features = ["toml", "env"] }
tokio = { version = "1", features = ["macros", "io-std"] } tokio = { version = "1", features = ["macros", "io-std"] }
rand = "0.8" rand = "0.8"

View File

@ -0,0 +1,88 @@
-----BEGIN PRIVATE KEY-----
MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQCjle13u/R/0+zw
eycXhdF7ZNYQfqXfkMpw9GlerbqRrxSLEc/YXXBuIO5AZKkXYeP8iM9KbSBD4p8F
wZD7LL47601c5WwWpNfOravCaSjYgvaYyhnoNzmG8NYaVYKB9kup6lOyQmesNXEK
NGNSrKpsoaQ7jBk+l+VV1jNBjMhNVWuz4AdFMsVD09QyL1GvQ0OvT/BbUKypaFFw
YcHruYvHuKGnrlXkvw05aZmKtKiSE6UQoDKtZWfV8yV2M6Sr75i9GKaGMyUZIl88
MxVLGcGwO6To2wNFKfLkHLOGIWrKA7m/Bb2n1k2OT+6iOnDzU62BoAzG/j8dhNPL
mZ6a7cZfAgMBAAECggEANwiZe06gUuDZNY44+JDsiLbDzYjOBQiREq8nQ9LukVR1
dNPpOME2sdYiUUeMG3GzYaIlGsTbtfrnxOf5/oZu+XmP7VDBrFyIvd9viVgXhb+J
dp2HWbg6gktDvFhIL7DMg71xqubsOeNAxE4bnBS6wREgT2gylfxECzykwci7Gki4
AkeihvaxqdHk9WP8dtFOuCYhX5pyKd9veS1/L01dVMpoFrq72PHupplKYb3HIo+v
ga02DhNVcH3fomEbXzazC64k2h5Vz+8mgpu5/V1thKiB2izOwt/hv4tkf2iDNz43
xdSYUEFsk80M97VI1dM1+TBe/JO0auZvKLkuOWUjAQKBgQDlBMr+d+guajgQ863I
uEFK4veEXrD51L6AKT+cqFUi894fhOodnnmK8l3JBKzO0zjgsaez8exKZPRN8An8
4MejM+hMYciJsP7uDpPkhlI5zHd9CR7EFPWXXpt4PecQLvBbnJ/lDnWCrE4m5Zhs
9OR7izLMBAmaiPlTNAaXj22iqwKBgQC226wzXGr//lnTggZX+u9UdkZKewAYlgnB
Ywj3+JB6Q/kDDS8C6fdlAvWyHShxtO3gx2pJSI3hk7J8fZu/kbojlLF16ayO+tgg
t3EoTZxN5zncygPaULstdKHhnMp8a4AO8lLrHtackFbbX7fuUJft0w457FpARvM8
DONjWI8LHQKBgBBY5TyAxpv5jQL4weDf9hkoVk6mi69plieDyjyeb2VNTv+k9yki
FL7sSfF9WfBxd0/innvjuuAckKu3hJ7+VIG7xMse97eMYMYRWFEpnVju1WChdAa/
EEC7yhEtKf8nupRve6JYA99N+U4heV3dpSmEaB3T8/OJ73IW9pl+7W59AoGADxM/
OCDHZYF3sFtI4Jn8fy8dDmjjkiNUfJAInkDs0FeoQNsmZAwb7ET5Moz615z9+4kV
NyN3JwDBN0g3vexqtyI8Gyd/pW4CwXe+KX90gmustoolFSuQsueprOr7OpS2QwUx
Vtb9BH1V29IhXNFiJSZARwA4VJJE3U+Gs5sKd/UCgYEAoCPE3gVaa89nOqQtalhT
9SISOGQxxMknjNFrEuF3UaGuR0cxDRLX6lSEneAATEpho0QB2Fj4vO8PiyYOGvH+
5ouJD97rcU77OOixlLFt4+TAWI9AvT0mN7y+SHJ22RkwWGQyF4TIfkg0tQvu36D+
35W26Li1WteB2O4wV9qVReA=
-----END PRIVATE KEY-----
-----BEGIN CERTIFICATE-----
MIIEwDCCAqigAwIBAgIUay5Z8sVQUkSTFpacn6o4iq2ElGowDQYJKoZIhvcNAQEL
BQAwRzELMAkGA1UEBhMCVVMxCzAJBgNVBAgMAkNBMRIwEAYDVQQKDAlSb2NrZXQg
Q0ExFzAVBgNVBAMMDlJvY2tldCBSb290IENBMB4XDTI0MDQxNDA4MTU0MVoXDTI1
MDQxNDA4MTU0MVowgY4xCzAJBgNVBAYTAlVTMRMwEQYDVQQIDApDYWxpZm9ybmlh
MRcwFQYDVQQHDA5TaWxpY29uIFZhbGxleTEPMA0GA1UECgwGUm9ja2V0MRswGQYD
VQQDDBJSb2NrZXQgVExTIEV4YW1wbGUxIzAhBgkqhkiG9w0BCQEWFGV4YW1wbGVA
cm9ja2V0LmxvY2FsMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAo5Xt
d7v0f9Ps8HsnF4XRe2TWEH6l35DKcPRpXq26ka8UixHP2F1wbiDuQGSpF2Hj/IjP
Sm0gQ+KfBcGQ+yy+O+tNXOVsFqTXzq2rwmko2IL2mMoZ6Dc5hvDWGlWCgfZLqepT
skJnrDVxCjRjUqyqbKGkO4wZPpflVdYzQYzITVVrs+AHRTLFQ9PUMi9Rr0NDr0/w
W1CsqWhRcGHB67mLx7ihp65V5L8NOWmZirSokhOlEKAyrWVn1fMldjOkq++YvRim
hjMlGSJfPDMVSxnBsDuk6NsDRSny5ByzhiFqygO5vwW9p9ZNjk/uojpw81OtgaAM
xv4/HYTTy5memu3GXwIDAQABo1wwWjAYBgNVHREEETAPgg1ETlM6bG9jYWxob3N0
MB0GA1UdDgQWBBSowDBXM26C7VogwXNB1F0vLpYO7DAfBgNVHSMEGDAWgBREAyUj
0lTwopZ2B1VmnvMPfUtCkzANBgkqhkiG9w0BAQsFAAOCAgEAbjF11+t8qVEF72ey
19p1sRkG9ygb0gE2UpLzVpPilucioIOwQuT4rvsVYZQxK+smQZURDI4uNXODIeoS
r3maL82VryLSYbkQADyShYjF0uCX8AfCI0YtOKOschNZDcZEJ5mUpHjJE0lEZnkO
x8ZVXwWf4pv1/8DZoCkMN3gDHwhQGPtrls4q7O38rI7zK9DNrzu7R1ZdGjQSDasL
6DqHee90O2ejpELUxO6lRl2EUosfklRvjV7hfrDHlpN9EuweXt0JiaKw3WZzHSLa
dKS8wtTMq5aWzOWrew1ZEhRr+B3KS6BSC5o9xSQMfcDyS0KJcIJI9bNh3nElWFhM
IBVtGxM/EYAwNJ++jLD10WHvaqW0epMV2cUu+dGJX+TPuI0c/wNehisS4ahvR64m
UpjAwNUBlYpR/Gb15/i2fVk2BbUyU3AcpZfWFDopQ8UqC8ALVcNjbNHq+yVkuTpj
gn5iiTTcTqb6qNfie4oDX4KR6ZgpNiTl/PWZo58qxSwdGiJwrINACkPJ6Qg6Qrpd
hp3CanTWjioHfvTSdiubqw5/XRnqa2Iav0Sttc6TPnTimodmtWkaYA8mvjS+jq8N
f9l2UYQz8yLabMkn98BM+gRJYwrVt6sCbVuEaHgPwq/qX9mQFhUrfw3iEPKlmezt
T3AhgPhybUpMFpu+4Tp8JE2JlKQ=
-----END CERTIFICATE-----
-----BEGIN CERTIFICATE-----
MIIFbzCCA1egAwIBAgIURX345HUrWikAysSTFd8xoV5GSIYwDQYJKoZIhvcNAQEL
BQAwRzELMAkGA1UEBhMCVVMxCzAJBgNVBAgMAkNBMRIwEAYDVQQKDAlSb2NrZXQg
Q0ExFzAVBgNVBAMMDlJvY2tldCBSb290IENBMB4XDTIxMDcwOTIzMzMzM1oXDTMx
MDcwNzIzMzMzM1owRzELMAkGA1UEBhMCVVMxCzAJBgNVBAgMAkNBMRIwEAYDVQQK
DAlSb2NrZXQgQ0ExFzAVBgNVBAMMDlJvY2tldCBSb290IENBMIICIjANBgkqhkiG
9w0BAQEFAAOCAg8AMIICCgKCAgEAybxw0cVrq8yn9W9xRDHdci8rnA5CxPcxAyM5
y5LCFOV/tTY0IZgrmegRLo4blmz8QNiFta9Ydt7zsm7XUTm6BhJ7TkOUAcfP7eSv
3jNEIEJQLU+k5SepV7pwFPRjUr6+a7yypS2xXAkDEVoyvzsuKYwzj+x6HvDuVhOF
2zv4Kk0sLfS/3UelMdilKa5VBCL/WMEXaCpb7/BMUUwn868LVU8E9+1H6uDQMxKo
ZH1mH98yeFODjzM9Ko6n2ghXx8qbe+wab4mSHn/SPgFnDFU+ujyPXIQqrS4PSQW3
5lkCn70hOw2K+8LHDBmgxOLk2Jb8o8PJWX6v346dlRcZr9VzMqCyKvEf1i5oT2hg
NZrkDdUOgyMZeq6H7pQpSxSFSMtkaombSm816V0rg7/sXwS66KyaYJY7x8eYEpgd
GuQKXkyIwp687TGLul97amoy/J3jIDnQOuf/YEcdyHCKojh20E5AERC4sCg6l+qs
5Nbol7jZclzBFf+70JOsUFmCfVYd5e0LKWdYV9UhYABc3yQqJyzy/eyihWihUNZU
LXStjd+XIkhKs+b7uKaBp1poFfgjpdboxmREyppWexua1t0eAReBgMU43bEGoy+B
iWoTFjyeQijd6M++npzsqwknYyv+7VjX3EfijyTFgIpZUL196PTJ5SGJMf7eJmaG
BO0g2W0CAwEAAaNTMFEwHQYDVR0OBBYEFEQDJSPSVPCilnYHVWae8w99S0KTMB8G
A1UdIwQYMBaAFEQDJSPSVPCilnYHVWae8w99S0KTMA8GA1UdEwEB/wQFMAMBAf8w
DQYJKoZIhvcNAQELBQADggIBACCArR/ArOyoh97Pgie37miFJEJNtAe+ApzhDevh
11P0Vn5hbu+dR7vftCJ7e+0u0irbPgfdxuW3IpEDgL+fttqCIRdAT6MTKLiUmrVS
x0fQJqC4Hw4o+bIeRsaNAfB/SAEvOxBbyu5szOFak1r/sXVs4vzBINIF3NdtbNtj
Bhac0Fiy/+DlfTHJSRGvzYo+GljXHkrG02mF4aOWx9x97y/6UzbLqHJPINgyAIlN
ts29QIHVNtQQyUN292xC1F4TSrBNB+GziGt3XZ8YEASCkMEnIvs3Lpzsjjm9TrkE
W/b9ee3C6RWg+RW3pokORMM7Q/lSOMWUmPrzI7CBCKaQUNN9g+iimLkPyp386sCS
zXJDd0OKb0xkpxhrauEvzNfEJxGDQbxs8s598ZofhVo9ehdmmXcJAw/zUZjHSrI2
PW+vHJ4kslBmKtH1oyAW3zYiFyYYPu4ohkeSrq8z8351upxwJUm4m/ndByXTrPwz
Yj6dEHaysjoRl0wOJgQ7G2ikw1QtWja2apJN9Q66i98vEDmtoEyOqOLMSjKjFL7c
sSJ6vAittYtIziIeMK7E8lDc1rtzMT5MOAoTriVyIGBgHFs96YOoL0Vi5QmVtQtc
8dkFUapFAUj8pREVxnJoLGose/FxBvF2FQZ5Sb25pyTPAeXk7y56noF78nusiVSF
xRjI
-----END CERTIFICATE-----

View File

@ -9,6 +9,7 @@
# ecdsa_nistp256_sha256 # ecdsa_nistp256_sha256
# ecdsa_nistp384_sha384 # ecdsa_nistp384_sha384
# ecdsa_nistp521_sha512 # ecdsa_nistp521_sha512
# client
# #
# Generate a certificate of the [cert-kind] key type, or if no cert-kind is # Generate a certificate of the [cert-kind] key type, or if no cert-kind is
# specified, all of the certificates. # specified, all of the certificates.
@ -136,12 +137,23 @@ function gen_ecdsa_nistp521_sha512() {
rm ca_cert.srl server.csr ecdsa_nistp521_sha512_key.pem rm ca_cert.srl server.csr ecdsa_nistp521_sha512_key.pem
} }
function gen_client_cert() {
openssl req -newkey rsa:2048 -nodes -keyout client.key -out client.csr
openssl x509 -req -extfile <(printf "subjectAltName=DNS:${ALT}") -days 365 \
-in client.csr -CA ca_cert.pem -CAkey ca_key.pem -CAcreateserial \
-out client.crt
cat client.key client.crt ca_cert.pem > client.pem
rm client.key client.crt client.csr ca_cert.srl
}
case $1 in case $1 in
ed25519) gen_ed25519 ;; ed25519) gen_ed25519 ;;
rsa_sha256) gen_rsa_sha256 ;; rsa_sha256) gen_rsa_sha256 ;;
ecdsa_nistp256_sha256) gen_ecdsa_nistp256_sha256 ;; ecdsa_nistp256_sha256) gen_ecdsa_nistp256_sha256 ;;
ecdsa_nistp384_sha384) gen_ecdsa_nistp384_sha384 ;; ecdsa_nistp384_sha384) gen_ecdsa_nistp384_sha384 ;;
ecdsa_nistp521_sha512) gen_ecdsa_nistp521_sha512 ;; ecdsa_nistp521_sha512) gen_ecdsa_nistp521_sha512 ;;
client) gen_client_cert ;;
*) *)
gen_ed25519 gen_ed25519
gen_rsa_sha256 gen_rsa_sha256

View File

@ -7,6 +7,7 @@ use rocket::log::LogLevel;
use rocket::{route, Error, Request, Data, Route, Orbit, Rocket, Ignite}; use rocket::{route, Error, Request, Data, Route, Orbit, Rocket, Ignite};
use rocket::fairing::{Fairing, Info, Kind}; use rocket::fairing::{Fairing, Info, Kind};
use rocket::response::Redirect; use rocket::response::Redirect;
use rocket::listener::tcp::TcpListener;
use yansi::Paint; use yansi::Paint;
@ -59,7 +60,7 @@ impl Redirector {
rocket::custom(&config.server) rocket::custom(&config.server)
.manage(config) .manage(config)
.mount("/", redirects) .mount("/", redirects)
.launch_on(addr) .bind_launch::<_, TcpListener>(addr)
.await .await
} }
} }

View File

@ -66,8 +66,7 @@ fn insecure_cookies() {
} }
fn validate_profiles(profiles: &[&str]) { fn validate_profiles(profiles: &[&str]) {
use rocket::listener::DefaultListener; use rocket::config::{Config, TlsConfig, SecretKey};
use rocket::config::{Config, SecretKey};
for profile in profiles { for profile in profiles {
let config = Config { let config = Config {
@ -81,9 +80,8 @@ fn validate_profiles(profiles: &[&str]) {
assert_eq!(response.into_string().unwrap(), "Hello, world!"); assert_eq!(response.into_string().unwrap(), "Hello, world!");
let figment = client.rocket().figment(); let figment = client.rocket().figment();
let listener: DefaultListener = figment.extract().unwrap(); let config: TlsConfig = figment.extract_inner("tls").unwrap();
assert_eq!(figment.profile(), profile); config.validate().expect("valid TLS config");
listener.tls.as_ref().unwrap().validate().expect("valid TLS config");
} }
} }

View File

@ -171,6 +171,15 @@ function test_default() {
echo ":: Checking fuzzers..." echo ":: Checking fuzzers..."
indir "${FUZZ_ROOT}" $CARGO update indir "${FUZZ_ROOT}" $CARGO update
indir "${FUZZ_ROOT}" $CARGO check --all --all-features $@ indir "${FUZZ_ROOT}" $CARGO check --all --all-features $@
case "$OSTYPE" in
darwin* | linux*)
echo ":: Checking testbench..."
indir "${TESTBENCH_ROOT}" $CARGO update
indir "${TESTBENCH_ROOT}" $CARGO check $@
;;
*) echo ":: Skipping testbench [$OSTYPE]" ;;
esac
} }
function test_ui() { function test_ui() {

View File

@ -1,6 +1,6 @@
[package] [package]
name = "rocket-testbench" name = "testbench"
description = "end-to-end HTTP testbench for Rocket" description = "End-to-end HTTP Rocket testbench."
version = "0.0.0" version = "0.0.0"
edition = "2021" edition = "2021"
publish = false publish = false
@ -12,6 +12,7 @@ thiserror = "1.0"
procspawn = "1" procspawn = "1"
pretty_assertions = "1.4.0" pretty_assertions = "1.4.0"
ipc-channel = "0.18" ipc-channel = "0.18"
rustls-pemfile = "2.1"
[dependencies.nix] [dependencies.nix]
version = "0.28" version = "0.28"

View File

@ -1,206 +1,64 @@
use std::time::Duration; use std::time::Duration;
use std::sync::Once;
use std::process::Stdio;
use std::io::{self, Read};
use rocket::fairing::AdHoc; use reqwest::blocking::{ClientBuilder, RequestBuilder};
use rocket::http::ext::IntoOwned; use rocket::http::{ext::IntoOwned, uri::{Absolute, Uri}};
use rocket::http::uri::{self, Absolute, Uri};
use rocket::serde::{Deserialize, Serialize};
use rocket::{Build, Rocket};
use procspawn::SpawnError; use crate::{Result, Error, Server};
use thiserror::Error;
use ipc_channel::ipc::{IpcOneShotServer, IpcReceiver, IpcSender};
static DEFAULT_CONFIG: &str = r#"
[default]
address = "tcp:127.0.0.1"
workers = 2
port = 0
cli_colors = false
secret_key = "itlYmFR2vYKrOmFhupMIn/hyB6lYCCTXz4yaQX89XVg="
[default.shutdown]
grace = 1
mercy = 1
"#;
#[derive(Debug)] #[derive(Debug)]
#[allow(unused)]
pub struct Client { pub struct Client {
client: reqwest::blocking::Client, client: reqwest::blocking::Client,
server: procspawn::JoinHandle<()>,
tls: bool,
port: u16,
rx: IpcReceiver<Message>,
} }
#[derive(Error, Debug)] impl Client {
pub enum Error { pub fn default() -> Client {
#[error("join/kill failed: {0}")] Client::build()
JoinError(#[from] SpawnError), .try_into()
#[error("kill failed: {0}")] .expect("default builder ok")
TermFailure(#[from] nix::errno::Errno),
#[error("i/o error: {0}")]
Io(#[from] io::Error),
#[error("invalid URI: {0}")]
Uri(#[from] uri::Error<'static>),
#[error("the URI is invalid")]
InvalidUri,
#[error("bad request: {0}")]
Request(#[from] reqwest::Error),
#[error("IPC failure: {0}")]
Ipc(#[from] ipc_channel::ipc::IpcError),
#[error("liftoff failed")]
Liftoff(String, String),
} }
#[derive(Debug, Serialize, Deserialize)] pub fn build() -> ClientBuilder {
#[serde(crate = "rocket::serde")] reqwest::blocking::Client::builder()
pub enum Message {
Liftoff(bool, u16),
Failure,
}
#[derive(Serialize, Deserialize)]
#[serde(crate = "rocket::serde")]
#[must_use]
pub struct Token(String);
pub type Result<T, E = Error> = std::result::Result<T, E>;
impl Token {
fn configure(&self, toml: &str, rocket: Rocket<Build>) -> Rocket<Build> {
use rocket::figment::{Figment, providers::{Format, Toml}};
let toml = toml.replace("{CRATE}", env!("CARGO_MANIFEST_DIR"));
let config = Figment::from(rocket.figment())
.merge(Toml::string(DEFAULT_CONFIG).nested())
.merge(Toml::string(&toml).nested());
let server = self.0.clone();
rocket.configure(config)
.attach(AdHoc::on_liftoff("Liftoff", move |rocket| Box::pin(async move {
let tcp = rocket.endpoints().find_map(|e| e.tcp()).unwrap();
let tls = rocket.endpoints().any(|e| e.is_tls());
let sender = IpcSender::<Message>::connect(server).unwrap();
let _ = sender.send(Message::Liftoff(tls, tcp.port()));
let _ = sender.send(Message::Liftoff(tls, tcp.port()));
})))
}
pub fn rocket(&self, toml: &str) -> Rocket<Build> {
self.configure(toml, rocket::build())
}
pub fn configured_launch(self, toml: &str, rocket: Rocket<Build>) {
let rocket = self.configure(toml, rocket);
if let Err(e) = rocket::execute(rocket.launch()) {
let sender = IpcSender::<Message>::connect(self.0).unwrap();
let _ = sender.send(Message::Failure);
let _ = sender.send(Message::Failure);
e.pretty_print();
std::process::exit(1);
}
}
pub fn launch(self, rocket: Rocket<Build>) {
self.configured_launch(DEFAULT_CONFIG, rocket)
}
}
pub fn start(f: fn(Token)) -> Result<Client> {
static INIT: Once = Once::new();
INIT.call_once(procspawn::init);
let (ipc, server) = IpcOneShotServer::new()?;
let mut server = procspawn::Builder::new()
.stdin(Stdio::null())
.stdout(Stdio::piped())
.stderr(Stdio::piped())
.spawn(Token(server), f);
let client = reqwest::blocking::Client::builder()
.danger_accept_invalid_certs(true) .danger_accept_invalid_certs(true)
.cookie_store(true) .cookie_store(true)
.tls_info(true) .tls_info(true)
.timeout(Duration::from_secs(5)) .timeout(Duration::from_secs(5))
.connect_timeout(Duration::from_secs(5)) .connect_timeout(Duration::from_secs(5))
.build()?;
let (rx, _) = ipc.accept().unwrap();
match rx.recv() {
Ok(Message::Liftoff(tls, port)) => Ok(Client { client, server, tls, port, rx }),
Ok(Message::Failure) => {
let stdout = server.stdout().unwrap();
let mut out = String::new();
stdout.read_to_string(&mut out)?;
let stderr = server.stderr().unwrap();
let mut err = String::new();
stderr.read_to_string(&mut err)?;
Err(Error::Liftoff(out, err))
}
Err(e) => Err(e.into()),
} }
} pub fn get(&self, server: &Server, url: &str) -> Result<RequestBuilder> {
pub fn default() -> Result<Client> {
start(|token| token.launch(rocket::build()))
}
impl Client {
pub fn read_stdout(&mut self) -> Result<String> {
let Some(stdout) = self.server.stdout() else {
return Ok(String::new());
};
let mut string = String::new();
stdout.read_to_string(&mut string)?;
Ok(string)
}
pub fn read_stderr(&mut self) -> Result<String> {
let Some(stderr) = self.server.stderr() else {
return Ok(String::new());
};
let mut string = String::new();
stderr.read_to_string(&mut string)?;
Ok(string)
}
pub fn kill(&mut self) -> Result<()> {
Ok(self.server.kill()?)
}
pub fn terminate(&mut self) -> Result<()> {
use nix::{sys::signal, unistd::Pid};
let pid = Pid::from_raw(self.server.pid().unwrap() as i32);
Ok(signal::kill(pid, signal::SIGTERM)?)
}
pub fn wait(&mut self) -> Result<()> {
match self.server.join_timeout(Duration::from_secs(5)) {
Ok(_) => Ok(()),
Err(e) if e.is_remote_close() => Ok(()),
Err(e) => Err(e.into()),
}
}
pub fn get(&self, url: &str) -> Result<reqwest::blocking::RequestBuilder> {
let uri = match Uri::parse_any(url).map_err(|e| e.into_owned())? { let uri = match Uri::parse_any(url).map_err(|e| e.into_owned())? {
Uri::Origin(uri) => { Uri::Origin(uri) => {
let proto = if self.tls { "https" } else { "http" }; let proto = if server.tls { "https" } else { "http" };
let uri = format!("{proto}://127.0.0.1:{}{uri}", self.port); let uri = format!("{proto}://127.0.0.1:{}{uri}", server.port);
Absolute::parse_owned(uri)? Absolute::parse_owned(uri)?
} }
Uri::Absolute(uri) => uri, Uri::Absolute(mut uri) => {
_ => return Err(Error::InvalidUri), if let Some(auth) = uri.authority() {
let mut auth = auth.clone();
auth.set_port(server.port);
uri.set_authority(auth);
}
uri
}
uri => return Err(Error::InvalidUri(uri.into_owned())),
}; };
Ok(self.client.get(uri.to_string())) Ok(self.client.get(uri.to_string()))
} }
} }
impl From<reqwest::blocking::Client> for Client {
fn from(client: reqwest::blocking::Client) -> Self {
Client { client }
}
}
impl TryFrom<ClientBuilder> for Client {
type Error = Error;
fn try_from(builder: ClientBuilder) -> Result<Self, Self::Error> {
Ok(Client { client: builder.build()? })
}
}

View File

@ -1,3 +1,35 @@
pub mod client; // pub mod session;
mod client;
mod server;
pub use server::*;
pub use client::*; pub use client::*;
use std::io;
use thiserror::Error;
use procspawn::SpawnError;
use rocket::http::uri;
pub type Result<T, E = Error> = std::result::Result<T, E>;
#[derive(Error, Debug)]
pub enum Error {
#[error("join/kill failed: {0}")]
JoinError(#[from] SpawnError),
#[error("kill failed: {0}")]
TermFailure(#[from] nix::errno::Errno),
#[error("i/o error: {0}")]
Io(#[from] io::Error),
#[error("invalid URI: {0}")]
Uri(#[from] uri::Error<'static>),
#[error("invalid uri: {0}")]
InvalidUri(uri::Uri<'static>),
#[error("expected certificates are not present")]
MissingCertificate,
#[error("bad request: {0}")]
Request(#[from] reqwest::Error),
#[error("IPC failure: {0}")]
Ipc(#[from] ipc_channel::ipc::IpcError),
#[error("liftoff failed")]
Liftoff(String, String),
}

View File

@ -1,54 +1,93 @@
use rocket::{fairing::AdHoc, *}; use std::process::ExitCode;
use rocket_testbench::client::{self, Error}; use std::time::Duration;
use reqwest::tls::TlsInfo;
fn run() -> client::Result<()> { use rocket::listener::unix::UnixListener;
let mut client = client::start(|token| { use rocket::tokio::net::TcpListener;
#[get("/")] use rocket::yansi::Paint;
fn index() -> &'static str { use rocket::{get, routes, Build, Rocket, State};
"Hello, world!" use reqwest::{tls::TlsInfo, Identity};
} use testbench::*;
token.configured_launch(r#" static DEFAULT_CONFIG: &str = r#"
[default]
address = "tcp:127.0.0.1"
workers = 2
port = 0
cli_colors = false
secret_key = "itlYmFR2vYKrOmFhupMIn/hyB6lYCCTXz4yaQX89XVg="
[default.shutdown]
grace = 1
mercy = 1
"#;
static TLS_CONFIG: &str = r#"
[default.tls] [default.tls]
certs = "{CRATE}/../examples/tls/private/rsa_sha256_cert.pem" certs = "{ROCKET}/examples/tls/private/rsa_sha256_cert.pem"
key = "{CRATE}/../examples/tls/private/rsa_sha256_key.pem" key = "{ROCKET}/examples/tls/private/rsa_sha256_key.pem"
"#, rocket::build().mount("/", routes![index])); "#;
})?;
let response = client.get("/")?.send()?; trait RocketExt {
let tls = response.extensions().get::<TlsInfo>().unwrap(); fn default() -> Self;
assert!(!tls.peer_certificate().unwrap().is_empty()); fn tls_default() -> Self;
assert_eq!(response.text()?, "Hello, world!"); fn configure_with_toml(self, toml: &str) -> Self;
client.terminate()?;
let stdout = client.read_stdout()?;
assert!(stdout.contains("Rocket has launched on https"));
assert!(stdout.contains("Graceful shutdown completed"));
assert!(stdout.contains("GET /"));
Ok(())
} }
fn run_fail() -> client::Result<()> { impl RocketExt for Rocket<Build> {
let client = client::start(|token| { fn default() -> Self {
let fail = AdHoc::try_on_ignite("FailNow", |rocket| async { Err(rocket) }); rocket::build().configure_with_toml(DEFAULT_CONFIG)
token.launch(rocket::build().attach(fail)); }
});
if let Err(Error::Liftoff(stdout, _)) = client { fn tls_default() -> Self {
rocket::build()
.configure_with_toml(DEFAULT_CONFIG)
.configure_with_toml(TLS_CONFIG)
}
fn configure_with_toml(self, toml: &str) -> Self {
use rocket::figment::{Figment, providers::{Format, Toml}};
let toml = toml.replace("{ROCKET}", rocket::fs::relative!("../"));
let config = Figment::from(self.figment())
.merge(Toml::string(&toml).nested());
self.configure(config)
}
}
fn read(path: &str) -> Result<Vec<u8>> {
let path = path.replace("{ROCKET}", rocket::fs::relative!("../"));
Ok(std::fs::read(path)?)
}
fn cert(path: &str) -> Result<Vec<u8>> {
let mut data = std::io::Cursor::new(read(path)?);
let cert = rustls_pemfile::certs(&mut data).last();
Ok(cert.ok_or(Error::MissingCertificate)??.to_vec())
}
fn run_fail() -> Result<()> {
use rocket::fairing::AdHoc;
let server = spawn! {
let fail = AdHoc::try_on_ignite("FailNow", |rocket| async { Err(rocket) });
Rocket::default().attach(fail)
};
if let Err(Error::Liftoff(stdout, _)) = server {
assert!(stdout.contains("Rocket failed to launch due to failing fairings")); assert!(stdout.contains("Rocket failed to launch due to failing fairings"));
assert!(stdout.contains("FailNow")); assert!(stdout.contains("FailNow"));
} else { } else {
panic!("unexpected result: {client:#?}"); panic!("unexpected result: {server:#?}");
} }
Ok(()) Ok(())
} }
fn infinite() -> client::Result<()> { fn infinite() -> Result<()> {
use rocket::response::stream::TextStream; use rocket::response::stream::TextStream;
let mut client = client::start(|token| { let mut server = spawn! {
#[get("/")] #[get("/")]
fn infinite() -> TextStream![&'static str] { fn infinite() -> TextStream![&'static str] {
TextStream! { TextStream! {
@ -58,37 +97,358 @@ fn infinite() -> client::Result<()> {
} }
} }
token.launch(rocket::build().mount("/", routes![infinite])); Rocket::default().mount("/", routes![infinite])
})?; }?;
client.get("/")?.send()?; let client = Client::default();
client.terminate()?; client.get(&server, "/")?.send()?;
let stdout = client.read_stdout()?; server.terminate()?;
let stdout = server.read_stdout()?;
assert!(stdout.contains("Rocket has launched on http")); assert!(stdout.contains("Rocket has launched on http"));
assert!(stdout.contains("GET /")); assert!(stdout.contains("GET /"));
assert!(stdout.contains("Graceful shutdown completed")); assert!(stdout.contains("Graceful shutdown completed"));
Ok(()) Ok(())
} }
fn main() { fn tls_info() -> Result<()> {
let names = ["run", "run_fail", "infinite"]; let mut server = spawn! {
let tests = [run, run_fail, infinite]; #[get("/")]
let handles = tests.into_iter() fn hello_world() -> &'static str {
.map(|test| std::thread::spawn(test)) "Hello, world!"
.collect::<Vec<_>>(); }
let mut failure = false; Rocket::tls_default().mount("/", routes![hello_world])
for (handle, name) in handles.into_iter().zip(names) { }?;
let result = handle.join();
failure = failure || matches!(result, Ok(Err(_)) | Err(_)); let client = Client::default();
match result { let response = client.get(&server, "/")?.send()?;
Ok(Ok(_)) => continue, let tls = response.extensions().get::<TlsInfo>().unwrap();
Ok(Err(e)) => eprintln!("{name} failed: {e}"), assert!(!tls.peer_certificate().unwrap().is_empty());
Err(_) => eprintln!("{name} failed (see panic above)"), assert_eq!(response.text()?, "Hello, world!");
server.terminate()?;
let stdout = server.read_stdout()?;
assert!(stdout.contains("Rocket has launched on https"));
assert!(stdout.contains("Graceful shutdown completed"));
assert!(stdout.contains("GET /"));
Ok(())
}
fn tls_resolver() -> Result<()> {
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
use rocket::tls::{Resolver, TlsConfig, ClientHello, ServerConfig};
struct CountingResolver {
config: Arc<ServerConfig>,
counter: Arc<AtomicUsize>,
}
#[rocket::async_trait]
impl Resolver for CountingResolver {
async fn init(rocket: &Rocket<Build>) -> rocket::tls::Result<Self> {
let config: TlsConfig = rocket.figment().extract_inner("tls")?;
let config = Arc::new(config.server_config().await?);
let counter = rocket.state::<Arc<AtomicUsize>>().unwrap().clone();
Ok(Self { config, counter })
}
async fn resolve(&self, _: ClientHello<'_>) -> Option<Arc<ServerConfig>> {
self.counter.fetch_add(1, Ordering::Release);
Some(self.config.clone())
} }
} }
if failure { let server = spawn! {
std::process::exit(1); #[get("/count")]
fn count(counter: &State<Arc<AtomicUsize>>) -> String {
counter.load(Ordering::Acquire).to_string()
}
let counter = Arc::new(AtomicUsize::new(0));
Rocket::tls_default()
.manage(counter)
.mount("/", routes![count])
.attach(CountingResolver::fairing())
}?;
let client = Client::default();
let response = client.get(&server, "/count")?.send()?;
assert_eq!(response.text()?, "1");
// Use a new client so we get a new TLS session.
let client = Client::default();
let response = client.get(&server, "/count")?.send()?;
assert_eq!(response.text()?, "2");
Ok(())
}
fn test_mtls(mandatory: bool) -> Result<()> {
let server = spawn!(mandatory: bool => {
let mtls_config = format!(r#"
[default.tls.mutual]
ca_certs = "{{ROCKET}}/examples/tls/private/ca_cert.pem"
mandatory = {mandatory}
"#);
#[get("/")]
fn hello(cert: rocket::mtls::Certificate<'_>) -> String {
format!("{}:{}[{}] {}", cert.serial(), cert.version(), cert.issuer(), cert.subject())
}
#[get("/", rank = 2)]
fn hi() -> &'static str {
"Hello!"
}
Rocket::tls_default()
.configure_with_toml(&mtls_config)
.mount("/", routes![hello, hi])
})?;
let pem = read("{ROCKET}/examples/tls/private/client.pem")?;
let client: Client = Client::build()
.identity(Identity::from_pem(&pem)?)
.try_into()?;
let response = client.get(&server, "/")?.send()?;
assert_eq!(response.text()?,
"611895682361338926795452113263857440769284805738:2\
[C=US, ST=CA, O=Rocket CA, CN=Rocket Root CA] \
C=US, ST=California, L=Silicon Valley, O=Rocket, \
CN=Rocket TLS Example, Email=example@rocket.local");
let client = Client::default();
let response = client.get(&server, "/")?.send();
if mandatory {
assert!(response.unwrap_err().is_request());
} else {
assert_eq!(response?.text()?, "Hello!");
}
Ok(())
}
fn tls_mtls() -> Result<()> {
test_mtls(false)?;
test_mtls(true)
}
fn sni_resolver() -> Result<()> {
use std::sync::Arc;
use std::collections::HashMap;
use rocket::http::uri::Host;
use rocket::tls::{Resolver, TlsConfig, ClientHello, ServerConfig};
struct SniResolver {
default: Arc<ServerConfig>,
map: HashMap<Host<'static>, Arc<ServerConfig>>
}
#[rocket::async_trait]
impl Resolver for SniResolver {
async fn init(rocket: &Rocket<Build>) -> rocket::tls::Result<Self> {
let default: TlsConfig = rocket.figment().extract_inner("tls")?;
let sni: HashMap<Host<'_>, TlsConfig> = rocket.figment().extract_inner("tls.sni")?;
let default = Arc::new(default.server_config().await?);
let mut map = HashMap::new();
for (host, config) in sni {
let config = config.server_config().await?;
map.insert(host, Arc::new(config));
}
Ok(SniResolver { default, map })
}
async fn resolve(&self, hello: ClientHello<'_>) -> Option<Arc<ServerConfig>> {
if let Some(Ok(host)) = hello.server_name().map(Host::parse) {
if let Some(config) = self.map.get(&host) {
return Some(config.clone());
} }
} }
Some(self.default.clone())
}
}
static SNI_TLS_CONFIG: &str = r#"
[default.tls]
certs = "{ROCKET}/examples/tls/private/rsa_sha256_cert.pem"
key = "{ROCKET}/examples/tls/private/rsa_sha256_key.pem"
[default.tls.sni."sni1.dev"]
certs = "{ROCKET}/examples/tls/private/ecdsa_nistp256_sha256_cert.pem"
key = "{ROCKET}/examples/tls/private/ecdsa_nistp256_sha256_key_pkcs8.pem"
[default.tls.sni."sni2.dev"]
certs = "{ROCKET}/examples/tls/private/ed25519_cert.pem"
key = "{ROCKET}/examples/tls/private/ed25519_key.pem"
"#;
let server = spawn! {
#[get("/")] fn index() { }
Rocket::default()
.configure_with_toml(SNI_TLS_CONFIG)
.mount("/", routes![index])
.attach(SniResolver::fairing())
}?;
let client: Client = Client::build()
.resolve("unknown.dev", server.socket_addr())
.resolve("sni1.dev", server.socket_addr())
.resolve("sni2.dev", server.socket_addr())
.try_into()?;
let response = client.get(&server, "https://unknown.dev")?.send()?;
let tls = response.extensions().get::<TlsInfo>().unwrap();
let expected = cert("{ROCKET}/examples/tls/private/rsa_sha256_cert.pem")?;
assert_eq!(tls.peer_certificate().unwrap(), expected);
let response = client.get(&server, "https://sni1.dev")?.send()?;
let tls = response.extensions().get::<TlsInfo>().unwrap();
let expected = cert("{ROCKET}/examples/tls/private/ecdsa_nistp256_sha256_cert.pem")?;
assert_eq!(tls.peer_certificate().unwrap(), expected);
let response = client.get(&server, "https://sni2.dev")?.send()?;
let tls = response.extensions().get::<TlsInfo>().unwrap();
let expected = cert("{ROCKET}/examples/tls/private/ed25519_cert.pem")?;
assert_eq!(tls.peer_certificate().unwrap(), expected);
Ok(())
}
fn tcp_unix_listener_fail() -> Result<()> {
let server = spawn! {
Rocket::default().configure_with_toml("[default]\naddress = 123")
};
if let Err(Error::Liftoff(stdout, _)) = server {
assert!(stdout.contains("expected valid TCP (ip) or unix (path)"));
assert!(stdout.contains("default.address"));
} else {
panic!("unexpected result: {server:#?}");
}
let server = Server::spawn((), |(token, _)| {
let rocket = Rocket::default().configure_with_toml("[default]\naddress = \"unix:foo\"");
token.launch_with::<TcpListener>(rocket)
});
if let Err(Error::Liftoff(stdout, _)) = server {
assert!(stdout.contains("invalid tcp endpoint: unix:foo"));
} else {
panic!("unexpected result: {server:#?}");
}
let server = Server::spawn((), |(token, _)| {
token.launch_with::<UnixListener>(Rocket::default())
});
if let Err(Error::Liftoff(stdout, _)) = server {
assert!(stdout.contains("invalid unix endpoint: tcp:127.0.0.1:8000"));
} else {
panic!("unexpected result: {server:#?}");
}
Ok(())
}
macro_rules! tests {
($($f:ident),* $(,)?) => {[
$(Test {
name: stringify!($f),
run: |_: ()| $f().map_err(|e| e.to_string()),
}),*
]};
}
#[derive(Copy, Clone)]
struct Test {
name: &'static str,
run: fn(()) -> Result<(), String>,
}
static TESTS: &[Test] = &tests![
run_fail, infinite, tls_info, tls_resolver, tls_mtls, sni_resolver,
tcp_unix_listener_fail
];
fn main() -> ExitCode {
procspawn::init();
let filter = std::env::args().nth(1).unwrap_or_default();
let filtered = TESTS.into_iter().filter(|test| test.name.contains(&filter));
println!("running {}/{} tests", filtered.clone().count(), TESTS.len());
let handles = filtered.map(|test| (test, std::thread::spawn(|| {
let name = test.name;
let start = std::time::SystemTime::now();
let mut proc = procspawn::spawn((), test.run);
let result = loop {
match proc.join_timeout(Duration::from_secs(10)) {
Err(e) if e.is_timeout() => {
let elapsed = start.elapsed().unwrap().as_secs();
println!("{name} has been running for {elapsed} seconds...");
if elapsed >= 30 {
println!("{name} timeout");
break Err(e);
}
},
result => break result,
}
};
match result.as_ref().map_err(|e| e.panic_info()) {
Ok(Ok(_)) => println!("test {name} ... {}", "ok".green()),
Ok(Err(e)) => println!("test {name} ... {}\n {e}", "fail".red()),
Err(Some(_)) => println!("test {name} ... {}", "panic".red().underline()),
Err(None) => println!("test {name} ... {}", "error".magenta()),
}
matches!(result, Ok(Ok(())))
})));
let mut success = true;
for (_, handle) in handles {
success &= handle.join().unwrap_or(false);
}
match success {
true => ExitCode::SUCCESS,
false => {
println!("note: use `NOCAPTURE=1` to see test output");
ExitCode::FAILURE
}
}
}
// TODO: Implement an `UpdatingResolver`. Expose `SniResolver` and
// `UpdatingResolver` in a `contrib` library or as part of `rocket`.
//
// struct UpdatingResolver {
// timestamp: AtomicU64,
// config: ArcSwap<ServerConfig>
// }
//
// #[crate::async_trait]
// impl Resolver for UpdatingResolver {
// async fn resolve(&self, _: ClientHello<'_>) -> Option<Arc<ServerConfig>> {
// if let Either::Left(path) = self.tls_config.certs() {
// let metadata = tokio::fs::metadata(&path).await.ok()?;
// let modtime = metadata.modified().ok()?;
// let timestamp = modtime.duration_since(UNIX_EPOCH).ok()?.as_secs();
// let old_timestamp = self.timestamp.load(Ordering::Acquire);
// if timestamp > old_timestamp {
// let new_config = self.tls_config.to_server_config().await.ok()?;
// self.server_config.store(Arc::new(new_config));
// self.timestamp.store(timestamp, Ordering::Release);
// }
// }
//
// Some(self.server_config.load_full())
// }
// }

168
testbench/src/server.rs Normal file
View File

@ -0,0 +1,168 @@
use std::net::{Ipv4Addr, SocketAddr};
use std::time::Duration;
use std::sync::Once;
use std::process::Stdio;
use std::io::Read;
use rocket::fairing::AdHoc;
use rocket::listener::{Bind, DefaultListener};
use rocket::serde::{Deserialize, DeserializeOwned, Serialize};
use rocket::{Build, Ignite, Rocket};
use ipc_channel::ipc::{IpcOneShotServer, IpcReceiver, IpcSender};
use crate::{Result, Error};
#[derive(Debug)]
pub struct Server {
proc: procspawn::JoinHandle<Launched>,
pub tls: bool,
pub port: u16,
_rx: IpcReceiver<Message>,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(crate = "rocket::serde")]
pub enum Message {
Liftoff(bool, u16),
Failure,
}
#[derive(Serialize, Deserialize)]
#[serde(crate = "rocket::serde")]
pub struct Token(String);
#[derive(Serialize, Deserialize)]
#[serde(crate = "rocket::serde")]
pub struct Launched(());
fn stdio() -> Stdio {
std::env::var_os("NOCAPTURE")
.map(|_| Stdio::inherit())
.unwrap_or_else(Stdio::piped)
}
fn read<T: Read>(io: Option<T>) -> Result<String> {
if let Some(mut io) = io {
let mut string = String::new();
io.read_to_string(&mut string)?;
return Ok(string);
}
Ok(String::new())
}
impl Server {
pub fn spawn<T>(ctxt: T, f: fn((Token, T)) -> Launched) -> Result<Server>
where T: Serialize + DeserializeOwned
{
static INIT: Once = Once::new();
INIT.call_once(procspawn::init);
let (ipc, server) = IpcOneShotServer::new()?;
let mut proc = procspawn::Builder::new()
.stdin(Stdio::null())
.stdout(stdio())
.stderr(stdio())
.spawn((Token(server), ctxt), f);
let (rx, _) = ipc.accept().unwrap();
match rx.recv()? {
Message::Liftoff(tls, port) => {
Ok(Server { proc, tls, port, _rx: rx })
},
Message::Failure => {
Err(Error::Liftoff(read(proc.stdout())?, read(proc.stderr())?))
}
}
}
pub fn socket_addr(&self) -> SocketAddr {
let ip = Ipv4Addr::LOCALHOST;
SocketAddr::new(ip.into(), self.port)
}
pub fn read_stdout(&mut self) -> Result<String> {
read(self.proc.stdout())
}
pub fn read_stderr(&mut self) -> Result<String> {
read(self.proc.stderr())
}
pub fn kill(&mut self) -> Result<()> {
Ok(self.proc.kill()?)
}
pub fn terminate(&mut self) -> Result<()> {
use nix::{sys::signal, unistd::Pid};
let pid = Pid::from_raw(self.proc.pid().unwrap() as i32);
Ok(signal::kill(pid, signal::SIGTERM)?)
}
pub fn join(&mut self, duration: Duration) -> Result<()> {
match self.proc.join_timeout(duration) {
Ok(_) => Ok(()),
Err(e) if e.is_remote_close() => Ok(()),
Err(e) => Err(e.into()),
}
}
}
impl Token {
pub fn launch_with<B>(self, rocket: Rocket<Build>) -> Launched
where B: for<'r> Bind<&'r Rocket<Ignite>> + Sync + Send + 'static
{
let server = self.0.clone();
let rocket = rocket.attach(AdHoc::on_liftoff("Liftoff", move |rocket| Box::pin(async move {
let tcp = rocket.endpoints().find_map(|e| e.tcp()).unwrap();
let tls = rocket.endpoints().any(|e| e.is_tls());
let sender = IpcSender::<Message>::connect(server).unwrap();
let _ = sender.send(Message::Liftoff(tls, tcp.port()));
let _ = sender.send(Message::Liftoff(tls, tcp.port()));
})));
let server = self.0.clone();
if let Err(e) = rocket::execute(rocket.launch_with::<B>()) {
let sender = IpcSender::<Message>::connect(server).unwrap();
let _ = sender.send(Message::Failure);
let _ = sender.send(Message::Failure);
e.pretty_print();
std::process::exit(1);
}
Launched(())
}
pub fn launch(self, rocket: Rocket<Build>) -> Launched {
self.launch_with::<DefaultListener>(rocket)
}
}
impl Drop for Server {
fn drop(&mut self) {
let _ = self.terminate();
if self.join(Duration::from_secs(3)).is_err() {
let _ = self.kill();
}
}
}
#[macro_export]
macro_rules! spawn {
($($arg:ident : $t:ty),* => $rocket:block) => {{
#[allow(unused_parens)]
fn _server((token, $($arg),*): ($crate::Token, $($t),*)) -> $crate::Launched {
let rocket: rocket::Rocket<rocket::Build> = $rocket;
token.launch(rocket)
}
Server::spawn(($($arg),*), _server)
}};
($($token:tt)*) => {{
let _unit = ();
spawn!(_unit: () => { $($token)* } )
}};
}