mirror of https://github.com/rwf2/Rocket.git
Introduce dynamic TLS resolvers.
This commit introduces the ability to dynamically select a TLS configuration based on the client's TLS hello via the new `Resolver` trait. In support of this, it also makes the following changes: * Added `Authority::set_port()`. * `UdsListener` is now `UnixListener`. * `Bindable` removed in favor of new `Bind`. * All built-in listeners now implement `Bind<&Rocket>`. * `Connection` requires `AsyncRead + AsyncWrite`. * The `Debug` impl for `Endpoint` displays the underlying address. * `Listener` must be `Sized`. * The TLS listener was moved to `tls::TlsListener`. * The preview `quic` listener no longer implements `Listener`. * Added `TlsConfig::server_config()`. * Added `race` future helpers. * Added `Rocket::launch_with()`, `Rocket::bind_launch()`. * Added a default `client.pem` to the TLS example. * Various unnecessary listener `Config` structures removed. In addition, the testbench was revamped to support more scenarios. This resulted in the following issues being found and fixed: * Fix an issue where the logger would ignore color requests. * Clarified docs for `mtls::Certificate` guard. * Improved error messages on listener misconfiguration. Resolves #2730. Resolves #2363. Closes #2748. Closes #2683. Closes #2577.
This commit is contained in:
parent
60f3cd57b0
commit
7cc818cd85
|
@ -117,6 +117,8 @@
|
|||
//! to an `Object` (a dictionary) value. The [`context!`] macro can be used to
|
||||
//! create inline `Serialize`-able context objects.
|
||||
//!
|
||||
//! [`Serialize`]: rocket::serde::Serialize
|
||||
//!
|
||||
//! ```rust
|
||||
//! # #[macro_use] extern crate rocket;
|
||||
//! use rocket::serde::Serialize;
|
||||
|
@ -165,7 +167,7 @@
|
|||
//! builds, template reloading is disabled to improve performance and cannot be
|
||||
//! enabled.
|
||||
//!
|
||||
//! [attached]: Rocket::attach()
|
||||
//! [attached]: rocket::Rocket::attach()
|
||||
//!
|
||||
//! ### Metadata and Rendering to `String`
|
||||
//!
|
||||
|
|
|
@ -140,11 +140,12 @@ impl Template {
|
|||
}
|
||||
|
||||
/// Render the template named `name` with the context `context`. The
|
||||
/// `context` is typically created using the [`context!`] macro, but it can
|
||||
/// be of any type that implements `Serialize`, such as `HashMap` or a
|
||||
/// custom `struct`.
|
||||
/// `context` is typically created using the [`context!()`](crate::context!)
|
||||
/// macro, but it can be of any type that implements `Serialize`, such as
|
||||
/// `HashMap` or a custom `struct`.
|
||||
///
|
||||
/// To render a template directly into a string, use [`Metadata::render()`].
|
||||
/// To render a template directly into a string, use
|
||||
/// [`Metadata::render()`](crate::Metadata::render()).
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
|
@ -291,8 +292,8 @@ impl Sentinel for Template {
|
|||
/// A macro to easily create a template rendering context.
|
||||
///
|
||||
/// Invocations of this macro expand to a value of an anonymous type which
|
||||
/// implements [`serde::Serialize`]. Fields can be literal expressions or
|
||||
/// variables captured from a surrounding scope, as long as all fields implement
|
||||
/// implements [`Serialize`]. Fields can be literal expressions or variables
|
||||
/// captured from a surrounding scope, as long as all fields implement
|
||||
/// `Serialize`.
|
||||
///
|
||||
/// # Examples
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
#[cfg(all(feature = "diesel_sqlite_pool"))]
|
||||
#[cfg(test)]
|
||||
#[cfg(all(feature = "diesel_sqlite_pool"))]
|
||||
mod sqlite_shutdown_test {
|
||||
use rocket::{async_test, Build, Rocket};
|
||||
use rocket_sync_db_pools::database;
|
||||
|
|
|
@ -185,7 +185,7 @@ impl<'a> Authority<'a> {
|
|||
self.host.from_cow_source(&self.source)
|
||||
}
|
||||
|
||||
/// Returns the port part of the authority URI, if there is one.
|
||||
/// Returns the `port` part of the authority URI, if there is one.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
|
@ -206,6 +206,28 @@ impl<'a> Authority<'a> {
|
|||
pub fn port(&self) -> Option<u16> {
|
||||
self.port
|
||||
}
|
||||
|
||||
/// Set the `port` of the authority URI.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust
|
||||
/// # #[macro_use] extern crate rocket;
|
||||
/// let mut uri = uri!("username:password@host:123");
|
||||
/// assert_eq!(uri.port(), Some(123));
|
||||
///
|
||||
/// uri.set_port(1024);
|
||||
/// assert_eq!(uri.port(), Some(1024));
|
||||
/// assert_eq!(uri, "username:password@host:1024");
|
||||
///
|
||||
/// uri.set_port(None);
|
||||
/// assert_eq!(uri.port(), None);
|
||||
/// assert_eq!(uri, "username:password@host");
|
||||
/// ```
|
||||
#[inline(always)]
|
||||
pub fn set_port<T: Into<Option<u16>>>(&mut self, port: T) {
|
||||
self.port = port.into();
|
||||
}
|
||||
}
|
||||
|
||||
impl_serde!(Authority<'a>, "an authority-form URI");
|
||||
|
|
|
@ -69,7 +69,7 @@ ref-swap = "0.1.2"
|
|||
parking_lot = "0.12"
|
||||
ubyte = {version = "0.10.2", features = ["serde"] }
|
||||
serde = { version = "1.0", features = ["derive"] }
|
||||
figment = { version = "0.10.13", features = ["toml", "env"] }
|
||||
figment = { version = "0.10.17", features = ["toml", "env"] }
|
||||
rand = "0.8"
|
||||
either = "1"
|
||||
pin-project-lite = "0.2"
|
||||
|
@ -140,5 +140,5 @@ version_check = "0.9.1"
|
|||
|
||||
[dev-dependencies]
|
||||
tokio = { version = "1", features = ["macros", "io-std"] }
|
||||
figment = { version = "0.10", features = ["test"] }
|
||||
figment = { version = "0.10.17", features = ["test"] }
|
||||
pretty_assertions = "1"
|
||||
|
|
|
@ -137,9 +137,6 @@ mod secret_key;
|
|||
#[cfg(unix)]
|
||||
pub use crate::shutdown::Sig;
|
||||
|
||||
#[cfg(unix)]
|
||||
pub use crate::listener::unix::UdsConfig;
|
||||
|
||||
#[cfg(feature = "secrets")]
|
||||
pub use secret_key::SecretKey;
|
||||
|
||||
|
|
|
@ -178,13 +178,17 @@ impl Error {
|
|||
self.mark_handled();
|
||||
match self.kind() {
|
||||
ErrorKind::Bind(ref a, ref e) => {
|
||||
match a {
|
||||
Some(a) => error!("Binding to {} failed.", a.primary().underline()),
|
||||
None => error!("Binding to network interface failed."),
|
||||
}
|
||||
if let Some(e) = e.downcast_ref::<Self>() {
|
||||
e.pretty_print()
|
||||
} else {
|
||||
match a {
|
||||
Some(a) => error!("Binding to {} failed.", a.primary().underline()),
|
||||
None => error!("Binding to network interface failed."),
|
||||
}
|
||||
|
||||
info_!("{}", e);
|
||||
"aborting due to bind error"
|
||||
info_!("{}", e);
|
||||
"aborting due to bind error"
|
||||
}
|
||||
}
|
||||
ErrorKind::Io(ref e) => {
|
||||
error!("Rocket failed to launch due to an I/O error.");
|
||||
|
|
|
@ -0,0 +1,10 @@
|
|||
use crate::listener::{Endpoint, Listener};
|
||||
|
||||
pub trait Bind<T>: Listener + 'static {
|
||||
type Error: std::error::Error + Send + 'static;
|
||||
|
||||
#[crate::async_bound(Send)]
|
||||
async fn bind(to: T) -> Result<Self, Self::Error>;
|
||||
|
||||
fn bind_endpoint(to: &T) -> Result<Endpoint, Self::Error>;
|
||||
}
|
|
@ -1,52 +0,0 @@
|
|||
use std::io;
|
||||
use futures::TryFutureExt;
|
||||
|
||||
use crate::listener::{Listener, Endpoint};
|
||||
|
||||
pub trait Bindable: Sized {
|
||||
type Listener: Listener + 'static;
|
||||
|
||||
type Error: std::error::Error + Send + 'static;
|
||||
|
||||
async fn bind(self) -> Result<Self::Listener, Self::Error>;
|
||||
|
||||
/// The endpoint that `self` binds on.
|
||||
fn bind_endpoint(&self) -> io::Result<Endpoint>;
|
||||
}
|
||||
|
||||
impl<L: Listener + 'static> Bindable for L {
|
||||
type Listener = L;
|
||||
|
||||
type Error = std::convert::Infallible;
|
||||
|
||||
async fn bind(self) -> Result<Self::Listener, Self::Error> {
|
||||
Ok(self)
|
||||
}
|
||||
|
||||
fn bind_endpoint(&self) -> io::Result<Endpoint> {
|
||||
L::endpoint(self)
|
||||
}
|
||||
}
|
||||
|
||||
impl<A: Bindable, B: Bindable> Bindable for either::Either<A, B> {
|
||||
type Listener = tokio_util::either::Either<A::Listener, B::Listener>;
|
||||
|
||||
type Error = either::Either<A::Error, B::Error>;
|
||||
|
||||
async fn bind(self) -> Result<Self::Listener, Self::Error> {
|
||||
match self {
|
||||
either::Either::Left(a) => a.bind()
|
||||
.map_ok(tokio_util::either::Either::Left)
|
||||
.map_err(either::Either::Left)
|
||||
.await,
|
||||
either::Either::Right(b) => b.bind()
|
||||
.map_ok(tokio_util::either::Either::Right)
|
||||
.map_err(either::Either::Right)
|
||||
.await,
|
||||
}
|
||||
}
|
||||
|
||||
fn bind_endpoint(&self) -> io::Result<Endpoint> {
|
||||
either::for_both!(self, a => a.bind_endpoint())
|
||||
}
|
||||
}
|
|
@ -1,6 +1,7 @@
|
|||
use std::io;
|
||||
use std::borrow::Cow;
|
||||
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tokio_util::either::Either;
|
||||
|
||||
use super::Endpoint;
|
||||
|
@ -9,7 +10,7 @@ use super::Endpoint;
|
|||
#[derive(Clone)]
|
||||
pub struct Certificates<'r>(Cow<'r, [der::CertificateDer<'r>]>);
|
||||
|
||||
pub trait Connection: Send + Unpin {
|
||||
pub trait Connection: AsyncRead + AsyncWrite + Send + Unpin {
|
||||
fn endpoint(&self) -> io::Result<Endpoint>;
|
||||
|
||||
/// DER-encoded X.509 certificate chain presented by the client, if any.
|
||||
|
|
|
@ -1,64 +1,190 @@
|
|||
use core::fmt;
|
||||
|
||||
use serde::Deserialize;
|
||||
use tokio_util::either::Either::{Left, Right};
|
||||
use either::Either;
|
||||
|
||||
use crate::listener::{Bindable, Endpoint};
|
||||
use crate::error::{Error, ErrorKind};
|
||||
use crate::{Ignite, Rocket};
|
||||
use crate::listener::{Bind, Endpoint, tcp::TcpListener};
|
||||
|
||||
#[derive(serde::Deserialize)]
|
||||
pub struct DefaultListener {
|
||||
#[cfg(unix)] use crate::listener::unix::UnixListener;
|
||||
#[cfg(feature = "tls")] use crate::tls::{TlsListener, TlsConfig};
|
||||
|
||||
mod private {
|
||||
use super::*;
|
||||
use tokio_util::either::Either;
|
||||
|
||||
#[cfg(feature = "tls")] type TlsListener<T> = super::TlsListener<T>;
|
||||
#[cfg(not(feature = "tls"))] type TlsListener<T> = T;
|
||||
#[cfg(unix)] type UnixListener = super::UnixListener;
|
||||
#[cfg(not(unix))] type UnixListener = TcpListener;
|
||||
|
||||
pub type Listener = Either<
|
||||
Either<TlsListener<TcpListener>, TlsListener<UnixListener>>,
|
||||
Either<TcpListener, UnixListener>,
|
||||
>;
|
||||
|
||||
/// The default connection listener.
|
||||
///
|
||||
/// # Configuration
|
||||
///
|
||||
/// Reads the following optional configuration parameters:
|
||||
///
|
||||
/// | parameter | type | default |
|
||||
/// | ----------- | ----------------- | --------------------- |
|
||||
/// | `address` | [`Endpoint`] | `tcp:127.0.0.1:8000` |
|
||||
/// | `tls` | [`TlsConfig`] | None |
|
||||
/// | `reuse` | boolean | `true` |
|
||||
///
|
||||
/// # Listener
|
||||
///
|
||||
/// Based on the above configuration, this listener defers to one of the
|
||||
/// following existing listeners:
|
||||
///
|
||||
/// | listener | `address` type | `tls` enabled |
|
||||
/// |-------------------------------|--------------------|---------------|
|
||||
/// | [`TcpListener`] | [`Endpoint::Tcp`] | no |
|
||||
/// | [`UnixListener`] | [`Endpoint::Unix`] | no |
|
||||
/// | [`TlsListener<TcpListener>`] | [`Endpoint::Tcp`] | yes |
|
||||
/// | [`TlsListener<UnixListener>`] | [`Endpoint::Unix`] | yes |
|
||||
///
|
||||
/// [`UnixListener`]: crate::listener::unix::UnixListener
|
||||
/// [`TlsListener<TcpListener>`]: crate::tls::TlsListener
|
||||
/// [`TlsListener<UnixListener>`]: crate::tls::TlsListener
|
||||
///
|
||||
/// * **address type** is the variant the `address` parameter parses as.
|
||||
/// * **`tls` enabled** is `yes` when the `tls` feature is enabled _and_ a
|
||||
/// `tls` configuration is provided.
|
||||
#[cfg(doc)]
|
||||
pub struct DefaultListener(());
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct Config {
|
||||
#[serde(default)]
|
||||
pub address: Endpoint,
|
||||
pub port: Option<u16>,
|
||||
pub reuse: Option<bool>,
|
||||
address: Endpoint,
|
||||
#[cfg(feature = "tls")]
|
||||
pub tls: Option<crate::tls::TlsConfig>,
|
||||
tls: Option<TlsConfig>,
|
||||
}
|
||||
|
||||
#[cfg(not(unix))] type BaseBindable = Either<std::net::SocketAddr, std::net::SocketAddr>;
|
||||
#[cfg(unix)] type BaseBindable = Either<std::net::SocketAddr, super::unix::UdsConfig>;
|
||||
#[cfg(doc)]
|
||||
pub use private::DefaultListener;
|
||||
|
||||
#[cfg(not(feature = "tls"))] type TlsBindable<T> = Either<T, T>;
|
||||
#[cfg(feature = "tls")] type TlsBindable<T> = Either<super::tls::TlsBindable<T>, T>;
|
||||
#[cfg(doc)]
|
||||
type Connection = crate::listener::tcp::TcpStream;
|
||||
|
||||
impl DefaultListener {
|
||||
pub(crate) fn base_bindable(&self) -> Result<BaseBindable, crate::Error> {
|
||||
match &self.address {
|
||||
Endpoint::Tcp(mut address) => {
|
||||
if let Some(port) = self.port {
|
||||
address.set_port(port);
|
||||
}
|
||||
#[cfg(doc)]
|
||||
impl<'r> Bind<&'r Rocket<Ignite>> for DefaultListener {
|
||||
type Error = Error;
|
||||
async fn bind(_: &'r Rocket<Ignite>) -> Result<Self, Error> { unreachable!() }
|
||||
fn bind_endpoint(_: &&'r Rocket<Ignite>) -> Result<Endpoint, Error> { unreachable!() }
|
||||
}
|
||||
|
||||
Ok(BaseBindable::Left(address))
|
||||
},
|
||||
#[cfg(unix)]
|
||||
Endpoint::Unix(path) => {
|
||||
let uds = super::unix::UdsConfig { path: path.clone(), reuse: self.reuse, };
|
||||
Ok(BaseBindable::Right(uds))
|
||||
},
|
||||
#[cfg(not(unix))]
|
||||
e@Endpoint::Unix(_) => {
|
||||
let msg = "Unix domain sockets unavailable on non-unix platforms.";
|
||||
let boxed = Box::<dyn std::error::Error + Send + Sync>::from(msg);
|
||||
Err(Error::new(ErrorKind::Bind(Some(e.clone()), boxed)))
|
||||
},
|
||||
other => {
|
||||
let msg = format!("unsupported default listener address: {other}");
|
||||
let boxed = Box::<dyn std::error::Error + Send + Sync>::from(msg);
|
||||
Err(Error::new(ErrorKind::Bind(Some(other.clone()), boxed)))
|
||||
#[cfg(doc)]
|
||||
impl super::Listener for DefaultListener {
|
||||
#[doc(hidden)] type Accept = Connection;
|
||||
#[doc(hidden)] type Connection = Connection;
|
||||
#[doc(hidden)]
|
||||
async fn accept(&self) -> std::io::Result<Connection> { unreachable!() }
|
||||
#[doc(hidden)]
|
||||
async fn connect(&self, _: Self::Accept) -> std::io::Result<Connection> { unreachable!() }
|
||||
#[doc(hidden)]
|
||||
fn endpoint(&self) -> std::io::Result<Endpoint> { unreachable!() }
|
||||
}
|
||||
|
||||
#[cfg(not(doc))]
|
||||
pub type DefaultListener = private::Listener;
|
||||
|
||||
#[cfg(not(doc))]
|
||||
impl<'r> Bind<&'r Rocket<Ignite>> for DefaultListener {
|
||||
type Error = Error;
|
||||
|
||||
async fn bind(rocket: &'r Rocket<Ignite>) -> Result<Self, Self::Error> {
|
||||
let config: Config = rocket.figment().extract()?;
|
||||
match config.address {
|
||||
#[cfg(feature = "tls")]
|
||||
Endpoint::Tcp(_) if config.tls.is_some() => {
|
||||
let listener = <TlsListener<TcpListener> as Bind<_>>::bind(rocket).await?;
|
||||
Ok(Left(Left(listener)))
|
||||
}
|
||||
Endpoint::Tcp(_) => {
|
||||
let listener = <TcpListener as Bind<_>>::bind(rocket).await?;
|
||||
Ok(Right(Left(listener)))
|
||||
}
|
||||
#[cfg(all(unix, feature = "tls"))]
|
||||
Endpoint::Unix(_) if config.tls.is_some() => {
|
||||
let listener = <TlsListener<UnixListener> as Bind<_>>::bind(rocket).await?;
|
||||
Ok(Left(Right(listener)))
|
||||
}
|
||||
#[cfg(unix)]
|
||||
Endpoint::Unix(_) => {
|
||||
let listener = <UnixListener as Bind<_>>::bind(rocket).await?;
|
||||
Ok(Right(Right(listener)))
|
||||
}
|
||||
endpoint => Err(Error::Unsupported(endpoint)),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn tls_bindable<T>(&self, inner: T) -> TlsBindable<T> {
|
||||
#[cfg(feature = "tls")]
|
||||
if let Some(tls) = self.tls.clone() {
|
||||
return TlsBindable::Left(super::tls::TlsBindable { inner, tls });
|
||||
}
|
||||
|
||||
TlsBindable::Right(inner)
|
||||
}
|
||||
|
||||
pub fn bindable(&self) -> Result<impl Bindable, crate::Error> {
|
||||
self.base_bindable()
|
||||
.map(|b| b.map_either(|b| self.tls_bindable(b), |b| self.tls_bindable(b)))
|
||||
fn bind_endpoint(rocket: &&'r Rocket<Ignite>) -> Result<Endpoint, Self::Error> {
|
||||
let config: Config = rocket.figment().extract()?;
|
||||
Ok(config.address)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum Error {
|
||||
Config(figment::Error),
|
||||
Io(std::io::Error),
|
||||
Unsupported(Endpoint),
|
||||
#[cfg(feature = "tls")]
|
||||
Tls(crate::tls::Error),
|
||||
}
|
||||
|
||||
impl From<figment::Error> for Error {
|
||||
fn from(value: figment::Error) -> Self {
|
||||
Error::Config(value)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<std::io::Error> for Error {
|
||||
fn from(value: std::io::Error) -> Self {
|
||||
Error::Io(value)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "tls")]
|
||||
impl From<crate::tls::Error> for Error {
|
||||
fn from(value: crate::tls::Error) -> Self {
|
||||
Error::Tls(value)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Either<figment::Error, std::io::Error>> for Error {
|
||||
fn from(value: Either<figment::Error, std::io::Error>) -> Self {
|
||||
value.either(Error::Config, Error::Io)
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for Error {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
match self {
|
||||
Error::Config(e) => e.fmt(f),
|
||||
Error::Io(e) => e.fmt(f),
|
||||
Error::Unsupported(e) => write!(f, "unsupported endpoint: {e:?}"),
|
||||
#[cfg(feature = "tls")]
|
||||
Error::Tls(error) => error.fmt(f),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::error::Error for Error {
|
||||
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
|
||||
match self {
|
||||
Error::Config(e) => Some(e),
|
||||
Error::Io(e) => Some(e),
|
||||
Error::Unsupported(_) => None,
|
||||
#[cfg(feature = "tls")]
|
||||
Error::Tls(e) => Some(e),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -5,6 +5,7 @@ use std::path::{Path, PathBuf};
|
|||
use std::str::FromStr;
|
||||
use std::sync::Arc;
|
||||
|
||||
use figment::Figment;
|
||||
use serde::de;
|
||||
|
||||
use crate::http::uncased::AsUncased;
|
||||
|
@ -12,27 +13,43 @@ use crate::http::uncased::AsUncased;
|
|||
#[cfg(feature = "tls")] type TlsInfo = Option<Box<crate::tls::TlsConfig>>;
|
||||
#[cfg(not(feature = "tls"))] type TlsInfo = Option<()>;
|
||||
|
||||
pub trait EndpointAddr: fmt::Display + fmt::Debug + Sync + Send + Any { }
|
||||
pub trait CustomEndpoint: fmt::Display + fmt::Debug + Sync + Send + Any { }
|
||||
|
||||
impl<T: fmt::Display + fmt::Debug + Sync + Send + Any> EndpointAddr for T {}
|
||||
impl<T: fmt::Display + fmt::Debug + Sync + Send + Any> CustomEndpoint for T {}
|
||||
|
||||
/// # Conversions
|
||||
///
|
||||
/// * [`&str`] - parse with [`FromStr`]
|
||||
/// * [`tokio::net::unix::SocketAddr`] - must be path: [`Endpoint::Unix`]
|
||||
/// * [`PathBuf`] - infallibly as [`Endpoint::Unix`]
|
||||
#[derive(Debug, Clone)]
|
||||
///
|
||||
/// # Syntax
|
||||
///
|
||||
/// The string syntax is:
|
||||
///
|
||||
/// ```text
|
||||
/// endpoint = 'tcp' ':' socket | 'quic' ':' socket | 'unix' ':' path | socket
|
||||
/// socket := IP_ADDR | SOCKET_ADDR
|
||||
/// path := PATH
|
||||
///
|
||||
/// IP_ADDR := `std::net::IpAddr` string as defined by Rust
|
||||
/// SOCKET_ADDR := `std::net::SocketAddr` string as defined by Rust
|
||||
/// PATH := `PathBuf` (any UTF-8) string as defined by Rust
|
||||
/// ```
|
||||
///
|
||||
/// If `IP_ADDR` is specified in socket, port defaults to `8000`.
|
||||
#[derive(Clone)]
|
||||
#[non_exhaustive]
|
||||
pub enum Endpoint {
|
||||
Tcp(net::SocketAddr),
|
||||
Quic(net::SocketAddr),
|
||||
Unix(PathBuf),
|
||||
Tls(Arc<Endpoint>, TlsInfo),
|
||||
Custom(Arc<dyn EndpointAddr>),
|
||||
Custom(Arc<dyn CustomEndpoint>),
|
||||
}
|
||||
|
||||
impl Endpoint {
|
||||
pub fn new<T: EndpointAddr>(value: T) -> Endpoint {
|
||||
pub fn new<T: CustomEndpoint>(value: T) -> Endpoint {
|
||||
Endpoint::Custom(Arc::new(value))
|
||||
}
|
||||
|
||||
|
@ -152,6 +169,29 @@ impl Endpoint {
|
|||
|
||||
Self::Tls(Arc::new(self), None)
|
||||
}
|
||||
|
||||
/// Fetch the endpoint at `path` in `figment` of kind `kind` (e.g, "tcp")
|
||||
/// then map the value using `f(Some(value))` if present and `f(None)` if
|
||||
/// missing into a different value of typr `T`.
|
||||
///
|
||||
/// If the conversion succeeds, returns `Ok(value)`. If the conversion fails
|
||||
/// and `Some` value was passed in, returns an error indicating the endpoint
|
||||
/// was an invalid `kind` and otherwise returns a "missing field" error.
|
||||
pub(crate) fn fetch<T, F>(figment: &Figment, kind: &str, path: &str, f: F) -> figment::Result<T>
|
||||
where F: FnOnce(Option<&Endpoint>) -> Option<T>
|
||||
{
|
||||
match figment.extract_inner::<Endpoint>(path) {
|
||||
Ok(endpoint) => f(Some(&endpoint)).ok_or_else(|| {
|
||||
let msg = format!("invalid {kind} endpoint: {endpoint:?}");
|
||||
let mut error = figment::Error::from(msg).with_path(path);
|
||||
error.profile = Some(figment.profile().clone());
|
||||
error.metadata = figment.find_metadata(path).cloned();
|
||||
error
|
||||
}),
|
||||
Err(e) if e.missing() => f(None).ok_or(e),
|
||||
Err(e) => Err(e)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for Endpoint {
|
||||
|
@ -180,28 +220,15 @@ impl fmt::Display for Endpoint {
|
|||
}
|
||||
}
|
||||
|
||||
impl From<PathBuf> for Endpoint {
|
||||
fn from(value: PathBuf) -> Self {
|
||||
Self::Unix(value)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
impl TryFrom<tokio::net::unix::SocketAddr> for Endpoint {
|
||||
type Error = std::io::Error;
|
||||
|
||||
fn try_from(v: tokio::net::unix::SocketAddr) -> Result<Self, Self::Error> {
|
||||
v.as_pathname()
|
||||
.ok_or_else(|| std::io::Error::other("unix socket is not path"))
|
||||
.map(|path| Endpoint::Unix(path.to_path_buf()))
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<&str> for Endpoint {
|
||||
type Error = AddrParseError;
|
||||
|
||||
fn try_from(value: &str) -> Result<Self, Self::Error> {
|
||||
value.parse()
|
||||
impl fmt::Debug for Endpoint {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
match self {
|
||||
Self::Tcp(a) => write!(f, "tcp:{a}"),
|
||||
Self::Quic(a) => write!(f, "quic:{a}]"),
|
||||
Self::Unix(a) => write!(f, "unix:{}", a.display()),
|
||||
Self::Tls(e, _) => write!(f, "unix:{:?}", &**e),
|
||||
Self::Custom(e) => e.fmt(f),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -211,21 +238,6 @@ impl Default for Endpoint {
|
|||
}
|
||||
}
|
||||
|
||||
/// Parses an address into a `Endpoint`.
|
||||
///
|
||||
/// The syntax is:
|
||||
///
|
||||
/// ```text
|
||||
/// endpoint = 'tcp' ':' socket | 'quic' ':' socket | 'unix' ':' path | socket
|
||||
/// socket := IP_ADDR | SOCKET_ADDR
|
||||
/// path := PATH
|
||||
///
|
||||
/// IP_ADDR := `std::net::IpAddr` string as defined by Rust
|
||||
/// SOCKET_ADDR := `std::net::SocketAddr` string as defined by Rust
|
||||
/// PATH := `PathBuf` (any UTF-8) string as defined by Rust
|
||||
/// ```
|
||||
///
|
||||
/// If `IP_ADDR` is specified in socket, port defaults to `8000`.
|
||||
impl FromStr for Endpoint {
|
||||
type Err = AddrParseError;
|
||||
|
||||
|
@ -237,8 +249,6 @@ impl FromStr for Endpoint {
|
|||
if let Some((proto, string)) = string.split_once(':') {
|
||||
if proto.trim().as_uncased() == "tcp" {
|
||||
return parse_tcp(string.trim(), 8000).map(Self::Tcp);
|
||||
} else if proto.trim().as_uncased() == "quic" {
|
||||
return parse_tcp(string.trim(), 8000).map(Self::Quic);
|
||||
} else if proto.trim().as_uncased() == "unix" {
|
||||
return Ok(Self::Unix(PathBuf::from(string.trim())));
|
||||
}
|
||||
|
@ -256,7 +266,7 @@ impl<'de> de::Deserialize<'de> for Endpoint {
|
|||
type Value = Endpoint;
|
||||
|
||||
fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
formatter.write_str("TCP or Unix address")
|
||||
formatter.write_str("valid TCP (ip) or unix (path) endpoint")
|
||||
}
|
||||
|
||||
fn visit_str<E: de::Error>(self, v: &str) -> Result<Self::Value, E> {
|
||||
|
@ -294,3 +304,37 @@ impl PartialEq<Path> for Endpoint {
|
|||
self.unix() == Some(other)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
impl TryFrom<tokio::net::unix::SocketAddr> for Endpoint {
|
||||
type Error = std::io::Error;
|
||||
|
||||
fn try_from(v: tokio::net::unix::SocketAddr) -> Result<Self, Self::Error> {
|
||||
v.as_pathname()
|
||||
.ok_or_else(|| std::io::Error::other("unix socket is not path"))
|
||||
.map(|path| Endpoint::Unix(path.to_path_buf()))
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<&str> for Endpoint {
|
||||
type Error = AddrParseError;
|
||||
|
||||
fn try_from(value: &str) -> Result<Self, Self::Error> {
|
||||
value.parse()
|
||||
}
|
||||
}
|
||||
|
||||
macro_rules! impl_from {
|
||||
($T:ty => $V:ident) => {
|
||||
impl From<$T> for Endpoint {
|
||||
fn from(value: $T) -> Self {
|
||||
Self::$V(value.into())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl_from!(std::net::SocketAddr => Tcp);
|
||||
impl_from!(std::net::SocketAddrV4 => Tcp);
|
||||
impl_from!(std::net::SocketAddrV6 => Tcp);
|
||||
impl_from!(PathBuf => Unix);
|
||||
|
|
|
@ -5,7 +5,7 @@ use tokio_util::either::Either;
|
|||
|
||||
use crate::listener::{Connection, Endpoint};
|
||||
|
||||
pub trait Listener: Send + Sync {
|
||||
pub trait Listener: Sized + Send + Sync {
|
||||
type Accept: Send;
|
||||
|
||||
type Connection: Connection;
|
||||
|
|
|
@ -3,15 +3,12 @@ mod bounced;
|
|||
mod listener;
|
||||
mod endpoint;
|
||||
mod connection;
|
||||
mod bindable;
|
||||
mod bind;
|
||||
mod default;
|
||||
|
||||
#[cfg(unix)]
|
||||
#[cfg_attr(nightly, doc(cfg(unix)))]
|
||||
pub mod unix;
|
||||
#[cfg(feature = "tls")]
|
||||
#[cfg_attr(nightly, doc(cfg(feature = "tls")))]
|
||||
pub mod tls;
|
||||
pub mod tcp;
|
||||
#[cfg(feature = "http3-preview")]
|
||||
pub mod quic;
|
||||
|
@ -19,7 +16,7 @@ pub mod quic;
|
|||
pub use endpoint::*;
|
||||
pub use listener::*;
|
||||
pub use connection::*;
|
||||
pub use bindable::*;
|
||||
pub use bind::*;
|
||||
pub use default::*;
|
||||
|
||||
pub(crate) use cancellable::*;
|
||||
|
|
|
@ -38,7 +38,7 @@ use tokio::sync::Mutex;
|
|||
use tokio_stream::StreamExt;
|
||||
|
||||
use crate::tls::{TlsConfig, Error};
|
||||
use crate::listener::{Listener, Connection, Endpoint};
|
||||
use crate::listener::Endpoint;
|
||||
|
||||
type H3Conn = h3::server::Connection<quic_h3::Connection, bytes::Bytes>;
|
||||
|
||||
|
@ -51,14 +51,16 @@ pub struct QuicListener {
|
|||
pub struct H3Stream(H3Conn);
|
||||
|
||||
pub struct H3Connection {
|
||||
pub handle: quic::connection::Handle,
|
||||
pub parts: http::request::Parts,
|
||||
pub tx: QuicTx,
|
||||
pub rx: QuicRx,
|
||||
pub(crate) handle: quic::connection::Handle,
|
||||
pub(crate) parts: http::request::Parts,
|
||||
pub(crate) tx: QuicTx,
|
||||
pub(crate) rx: QuicRx,
|
||||
}
|
||||
|
||||
#[doc(hidden)]
|
||||
pub struct QuicRx(h3::server::RequestStream<quic_h3::RecvStream, Bytes>);
|
||||
|
||||
#[doc(hidden)]
|
||||
pub struct QuicTx(h3::server::RequestStream<quic_h3::SendStream<Bytes>, Bytes>);
|
||||
|
||||
impl QuicListener {
|
||||
|
@ -94,25 +96,20 @@ impl QuicListener {
|
|||
}
|
||||
}
|
||||
|
||||
impl Listener for QuicListener {
|
||||
type Accept = quic::Connection;
|
||||
|
||||
type Connection = H3Stream;
|
||||
|
||||
async fn accept(&self) -> io::Result<Self::Accept> {
|
||||
impl QuicListener {
|
||||
pub async fn accept(&self) -> Option<quic::Connection> {
|
||||
self.listener
|
||||
.lock().await
|
||||
.accept().await
|
||||
.ok_or_else(|| io::Error::new(io::ErrorKind::BrokenPipe, "closed"))
|
||||
}
|
||||
|
||||
async fn connect(&self, accept: Self::Accept) -> io::Result<Self::Connection> {
|
||||
pub async fn connect(&self, accept: quic::Connection) -> io::Result<H3Stream> {
|
||||
let quic_conn = quic_h3::Connection::new(accept);
|
||||
let conn = H3Conn::new(quic_conn).await.map_err(io::Error::other)?;
|
||||
Ok(H3Stream(conn))
|
||||
}
|
||||
|
||||
fn endpoint(&self) -> io::Result<Endpoint> {
|
||||
pub fn endpoint(&self) -> io::Result<Endpoint> {
|
||||
Ok(Endpoint::Quic(self.endpoint).with_tls(&self.tls))
|
||||
}
|
||||
}
|
||||
|
@ -159,16 +156,8 @@ impl QuicTx {
|
|||
}
|
||||
|
||||
// FIXME: Expose certificates when possible.
|
||||
impl Connection for H3Stream {
|
||||
fn endpoint(&self) -> io::Result<Endpoint> {
|
||||
let addr = self.0.inner.conn.handle().remote_addr()?;
|
||||
Ok(Endpoint::Quic(addr).assume_tls())
|
||||
}
|
||||
}
|
||||
|
||||
// FIXME: Expose certificates when possible.
|
||||
impl Connection for H3Connection {
|
||||
fn endpoint(&self) -> io::Result<Endpoint> {
|
||||
impl H3Connection {
|
||||
pub fn endpoint(&self) -> io::Result<Endpoint> {
|
||||
let addr = self.handle.remote_addr()?;
|
||||
Ok(Endpoint::Quic(addr).assume_tls())
|
||||
}
|
||||
|
|
|
@ -1,21 +1,61 @@
|
|||
//! TCP listener.
|
||||
//!
|
||||
//! # Configuration
|
||||
//!
|
||||
//! Reads the following configuration parameters:
|
||||
//!
|
||||
//! | parameter | type | default | note |
|
||||
//! |-----------|--------------|-------------|---------------------------------|
|
||||
//! | `address` | [`Endpoint`] | `127.0.0.1` | must be `tcp:ip` |
|
||||
//! | `port` | `u16` | `8000` | replaces the port in `address ` |
|
||||
|
||||
use std::io;
|
||||
use std::net::{Ipv4Addr, SocketAddr};
|
||||
|
||||
use either::{Either, Left, Right};
|
||||
|
||||
#[doc(inline)]
|
||||
pub use tokio::net::{TcpListener, TcpStream};
|
||||
|
||||
use crate::listener::{Listener, Bindable, Connection, Endpoint};
|
||||
use crate::{Ignite, Rocket};
|
||||
use crate::listener::{Bind, Connection, Endpoint, Listener};
|
||||
|
||||
impl Bindable for std::net::SocketAddr {
|
||||
type Listener = TcpListener;
|
||||
impl Bind<SocketAddr> for TcpListener {
|
||||
type Error = std::io::Error;
|
||||
|
||||
type Error = io::Error;
|
||||
|
||||
async fn bind(self) -> Result<Self::Listener, Self::Error> {
|
||||
TcpListener::bind(self).await
|
||||
async fn bind(addr: SocketAddr) -> Result<Self, Self::Error> {
|
||||
Self::bind(addr).await
|
||||
}
|
||||
|
||||
fn bind_endpoint(&self) -> io::Result<Endpoint> {
|
||||
Ok(Endpoint::Tcp(*self))
|
||||
fn bind_endpoint(addr: &SocketAddr) -> Result<Endpoint, Self::Error> {
|
||||
Ok(Endpoint::Tcp(*addr))
|
||||
}
|
||||
}
|
||||
|
||||
impl<'r> Bind<&'r Rocket<Ignite>> for TcpListener {
|
||||
type Error = Either<figment::Error, io::Error>;
|
||||
|
||||
async fn bind(rocket: &'r Rocket<Ignite>) -> Result<Self, Self::Error> {
|
||||
let endpoint = Self::bind_endpoint(&rocket)?;
|
||||
let addr = endpoint.tcp()
|
||||
.ok_or_else(|| io::Error::other("internal error: invalid endpoint"))
|
||||
.map_err(Right)?;
|
||||
|
||||
Self::bind(addr).await.map_err(Right)
|
||||
}
|
||||
|
||||
fn bind_endpoint(rocket: &&'r Rocket<Ignite>) -> Result<Endpoint, Self::Error> {
|
||||
let figment = rocket.figment();
|
||||
let mut address = Endpoint::fetch(figment, "tcp", "address", |e| {
|
||||
let default = SocketAddr::new(Ipv4Addr::LOCALHOST.into(), 8000);
|
||||
e.map(|e| e.tcp()).unwrap_or(Some(default))
|
||||
}).map_err(Left)?;
|
||||
|
||||
if let Some(port) = figment.extract_inner("port").map_err(Left)? {
|
||||
address.set_port(port);
|
||||
}
|
||||
|
||||
Ok(Endpoint::Tcp(address))
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -1,119 +0,0 @@
|
|||
use std::io;
|
||||
use std::sync::Arc;
|
||||
|
||||
use serde::Deserialize;
|
||||
use rustls::server::{ServerSessionMemoryCache, ServerConfig, WebPkiClientVerifier};
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tokio_rustls::TlsAcceptor;
|
||||
|
||||
use crate::tls::{TlsConfig, Error};
|
||||
use crate::listener::{Listener, Bindable, Connection, Certificates, Endpoint};
|
||||
|
||||
#[doc(inline)]
|
||||
pub use tokio_rustls::server::TlsStream;
|
||||
|
||||
/// A TLS listener over some listener interface L.
|
||||
pub struct TlsListener<L> {
|
||||
listener: L,
|
||||
acceptor: TlsAcceptor,
|
||||
config: TlsConfig,
|
||||
}
|
||||
|
||||
#[derive(Clone, Deserialize)]
|
||||
pub struct TlsBindable<I> {
|
||||
#[serde(flatten)]
|
||||
pub inner: I,
|
||||
pub tls: TlsConfig,
|
||||
}
|
||||
|
||||
impl TlsConfig {
|
||||
pub(crate) fn server_config(&self) -> Result<ServerConfig, Error> {
|
||||
let provider = Arc::new(self.default_crypto_provider());
|
||||
|
||||
#[cfg(feature = "mtls")]
|
||||
let verifier = match self.mutual {
|
||||
Some(ref mtls) => {
|
||||
let ca = Arc::new(mtls.load_ca_certs()?);
|
||||
let verifier = WebPkiClientVerifier::builder_with_provider(ca, provider.clone());
|
||||
match mtls.mandatory {
|
||||
true => verifier.build()?,
|
||||
false => verifier.allow_unauthenticated().build()?,
|
||||
}
|
||||
},
|
||||
None => WebPkiClientVerifier::no_client_auth(),
|
||||
};
|
||||
|
||||
#[cfg(not(feature = "mtls"))]
|
||||
let verifier = WebPkiClientVerifier::no_client_auth();
|
||||
|
||||
let mut tls_config = ServerConfig::builder_with_provider(provider)
|
||||
.with_safe_default_protocol_versions()?
|
||||
.with_client_cert_verifier(verifier)
|
||||
.with_single_cert(self.load_certs()?, self.load_key()?)?;
|
||||
|
||||
tls_config.ignore_client_order = self.prefer_server_cipher_order;
|
||||
tls_config.session_storage = ServerSessionMemoryCache::new(1024);
|
||||
tls_config.ticketer = rustls::crypto::ring::Ticketer::new()?;
|
||||
tls_config.alpn_protocols = vec![b"http/1.1".to_vec()];
|
||||
if cfg!(feature = "http2") {
|
||||
tls_config.alpn_protocols.insert(0, b"h2".to_vec());
|
||||
}
|
||||
|
||||
Ok(tls_config)
|
||||
}
|
||||
}
|
||||
|
||||
impl<I: Bindable> Bindable for TlsBindable<I>
|
||||
where I::Listener: Listener<Accept = <I::Listener as Listener>::Connection>,
|
||||
<I::Listener as Listener>::Connection: AsyncRead + AsyncWrite
|
||||
{
|
||||
type Listener = TlsListener<I::Listener>;
|
||||
|
||||
type Error = Error;
|
||||
|
||||
async fn bind(self) -> Result<Self::Listener, Self::Error> {
|
||||
Ok(TlsListener {
|
||||
acceptor: TlsAcceptor::from(Arc::new(self.tls.server_config()?)),
|
||||
listener: self.inner.bind().await.map_err(|e| Error::Bind(Box::new(e)))?,
|
||||
config: self.tls,
|
||||
})
|
||||
}
|
||||
|
||||
fn bind_endpoint(&self) -> io::Result<Endpoint> {
|
||||
let inner = self.inner.bind_endpoint()?;
|
||||
Ok(inner.with_tls(&self.tls))
|
||||
}
|
||||
}
|
||||
|
||||
impl<L> Listener for TlsListener<L>
|
||||
where L: Listener<Accept = <L as Listener>::Connection>,
|
||||
L::Connection: AsyncRead + AsyncWrite
|
||||
{
|
||||
type Accept = L::Connection;
|
||||
|
||||
type Connection = TlsStream<L::Connection>;
|
||||
|
||||
async fn accept(&self) -> io::Result<Self::Accept> {
|
||||
self.listener.accept().await
|
||||
}
|
||||
|
||||
async fn connect(&self, conn: L::Connection) -> io::Result<Self::Connection> {
|
||||
self.acceptor.accept(conn).await
|
||||
}
|
||||
|
||||
fn endpoint(&self) -> io::Result<Endpoint> {
|
||||
Ok(self.listener.endpoint()?.with_tls(&self.config))
|
||||
}
|
||||
}
|
||||
|
||||
impl<C: Connection> Connection for TlsStream<C> {
|
||||
fn endpoint(&self) -> io::Result<Endpoint> {
|
||||
Ok(self.get_ref().0.endpoint()?.assume_tls())
|
||||
}
|
||||
|
||||
#[cfg(feature = "mtls")]
|
||||
fn certificates(&self) -> Option<Certificates<'_>> {
|
||||
let cert_chain = self.get_ref().1.peer_certificates()?;
|
||||
Some(Certificates::from(cert_chain))
|
||||
}
|
||||
}
|
|
@ -1,48 +1,49 @@
|
|||
use std::io;
|
||||
use std::path::PathBuf;
|
||||
use std::path::{Path, PathBuf};
|
||||
|
||||
use either::{Either, Left, Right};
|
||||
use tokio::time::{sleep, Duration};
|
||||
|
||||
use crate::fs::NamedFile;
|
||||
use crate::listener::{Listener, Bindable, Connection, Endpoint};
|
||||
use crate::listener::{Listener, Bind, Connection, Endpoint};
|
||||
use crate::util::unix;
|
||||
use crate::{Ignite, Rocket};
|
||||
|
||||
pub use tokio::net::UnixStream;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct UdsConfig {
|
||||
/// Socket address.
|
||||
pub path: PathBuf,
|
||||
/// Recreate a socket that already exists.
|
||||
pub reuse: Option<bool>,
|
||||
}
|
||||
|
||||
pub struct UdsListener {
|
||||
/// Unix domain sockets listener.
|
||||
///
|
||||
/// # Configuration
|
||||
///
|
||||
/// Reads the following configuration parameters:
|
||||
///
|
||||
/// | parameter | type | default | note |
|
||||
/// |-----------|--------------|---------|-------------------------------------------|
|
||||
/// | `address` | [`Endpoint`] | | required: must be `unix:path` |
|
||||
/// | `reuse` | boolean | `true` | whether to create/reuse/delete the socket |
|
||||
pub struct UnixListener {
|
||||
path: PathBuf,
|
||||
lock: Option<NamedFile>,
|
||||
listener: tokio::net::UnixListener,
|
||||
}
|
||||
|
||||
impl Bindable for UdsConfig {
|
||||
type Listener = UdsListener;
|
||||
|
||||
type Error = io::Error;
|
||||
|
||||
async fn bind(self) -> Result<Self::Listener, Self::Error> {
|
||||
let lock = if self.reuse.unwrap_or(true) {
|
||||
let lock_ext = match self.path.extension().and_then(|s| s.to_str()) {
|
||||
impl UnixListener {
|
||||
pub async fn bind<P: AsRef<Path>>(path: P, reuse: bool) -> io::Result<Self> {
|
||||
let path = path.as_ref();
|
||||
let lock = if reuse {
|
||||
let lock_ext = match path.extension().and_then(|s| s.to_str()) {
|
||||
Some(ext) if !ext.is_empty() => format!("{}.lock", ext),
|
||||
_ => "lock".to_string()
|
||||
};
|
||||
|
||||
let mut opts = tokio::fs::File::options();
|
||||
opts.create(true).write(true);
|
||||
let lock_path = self.path.with_extension(lock_ext);
|
||||
let lock_path = path.with_extension(lock_ext);
|
||||
let lock_file = NamedFile::open_with(lock_path, &opts).await?;
|
||||
|
||||
unix::lock_exclusive_nonblocking(lock_file.file())?;
|
||||
if self.path.exists() {
|
||||
tokio::fs::remove_file(&self.path).await?;
|
||||
if path.exists() {
|
||||
tokio::fs::remove_file(&path).await?;
|
||||
}
|
||||
|
||||
Some(lock_file)
|
||||
|
@ -55,9 +56,9 @@ impl Bindable for UdsConfig {
|
|||
// and this will succeed. So let's try a few times.
|
||||
let mut retries = 5;
|
||||
let listener = loop {
|
||||
match tokio::net::UnixListener::bind(&self.path) {
|
||||
match tokio::net::UnixListener::bind(&path) {
|
||||
Ok(listener) => break listener,
|
||||
Err(e) if self.path.exists() && lock.is_none() => return Err(e),
|
||||
Err(e) if path.exists() && lock.is_none() => return Err(e),
|
||||
Err(_) if retries > 0 => {
|
||||
retries -= 1;
|
||||
sleep(Duration::from_millis(100)).await;
|
||||
|
@ -66,15 +67,31 @@ impl Bindable for UdsConfig {
|
|||
}
|
||||
};
|
||||
|
||||
Ok(UdsListener { lock, listener, path: self.path, })
|
||||
}
|
||||
|
||||
fn bind_endpoint(&self) -> io::Result<Endpoint> {
|
||||
Ok(Endpoint::Unix(self.path.clone()))
|
||||
Ok(UnixListener { lock, listener, path: path.into() })
|
||||
}
|
||||
}
|
||||
|
||||
impl Listener for UdsListener {
|
||||
impl<'r> Bind<&'r Rocket<Ignite>> for UnixListener {
|
||||
type Error = Either<figment::Error, io::Error>;
|
||||
|
||||
async fn bind(rocket: &'r Rocket<Ignite>) -> Result<Self, Self::Error> {
|
||||
let endpoint = Self::bind_endpoint(&rocket)?;
|
||||
let path = endpoint.unix()
|
||||
.ok_or_else(|| Right(io::Error::other("internal error: invalid endpoint")))?;
|
||||
|
||||
let reuse: Option<bool> = rocket.figment().extract_inner("reuse").map_err(Left)?;
|
||||
Ok(Self::bind(path, reuse.unwrap_or(true)).await.map_err(Right)?)
|
||||
}
|
||||
|
||||
fn bind_endpoint(rocket: &&'r Rocket<Ignite>) -> Result<Endpoint, Self::Error> {
|
||||
let as_pathbuf = |e: Option<&Endpoint>| e.and_then(|e| e.unix().map(|p| p.to_path_buf()));
|
||||
Endpoint::fetch(rocket.figment(), "unix", "address", as_pathbuf)
|
||||
.map(Endpoint::Unix)
|
||||
.map_err(Left)
|
||||
}
|
||||
}
|
||||
|
||||
impl Listener for UnixListener {
|
||||
type Accept = UnixStream;
|
||||
|
||||
type Connection = Self::Accept;
|
||||
|
@ -98,7 +115,7 @@ impl Connection for UnixStream {
|
|||
}
|
||||
}
|
||||
|
||||
impl Drop for UdsListener {
|
||||
impl Drop for UnixListener {
|
||||
fn drop(&mut self) {
|
||||
if let Some(lock) = &self.lock {
|
||||
let _ = std::fs::remove_file(&self.path);
|
||||
|
|
|
@ -154,13 +154,15 @@ impl log::Log for RocketLogger {
|
|||
}
|
||||
}
|
||||
|
||||
static ROCKET_LOGGER_SET: AtomicBool = AtomicBool::new(false);
|
||||
|
||||
pub(crate) fn init_default() {
|
||||
crate::log::init(&crate::Config::debug_default())
|
||||
if !ROCKET_LOGGER_SET.load(Ordering::Acquire) {
|
||||
crate::log::init(&crate::Config::debug_default())
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn init(config: &crate::Config) {
|
||||
static ROCKET_LOGGER_SET: AtomicBool = AtomicBool::new(false);
|
||||
|
||||
// Try to initialize Rocket's logger, recording if we succeeded.
|
||||
if log::set_boxed_logger(Box::new(RocketLogger)).is_ok() {
|
||||
ROCKET_LOGGER_SET.store(true, Ordering::Release);
|
||||
|
|
|
@ -14,8 +14,9 @@ use crate::http::Status;
|
|||
///
|
||||
/// The request guard implementation succeeds if:
|
||||
///
|
||||
/// * MTLS is [configured](crate::mtls).
|
||||
/// * The client presents certificates.
|
||||
/// * The certificates are active and not yet expired.
|
||||
/// * The certificates are valid and not expired.
|
||||
/// * The client's certificate chain was signed by the CA identified by the
|
||||
/// configured `ca_certs` and with respect to SNI, if any. See [module level
|
||||
/// docs](crate::mtls) for configuration details.
|
||||
|
@ -24,7 +25,7 @@ use crate::http::Status;
|
|||
/// status of 401 Unauthorized.
|
||||
///
|
||||
/// If the certificate chain fails to validate or verify, the guard _fails_ with
|
||||
/// the respective [`Error`].
|
||||
/// the respective [`Error`] a status of 401 Unauthorized.
|
||||
///
|
||||
/// # Wrapping
|
||||
///
|
||||
|
|
|
@ -79,8 +79,8 @@ pub struct MtlsConfig {
|
|||
impl MtlsConfig {
|
||||
/// Constructs a `MtlsConfig` from a path to a PEM file with a certificate
|
||||
/// authority `ca_certs` DER-encoded X.509 TLS certificate chain. This
|
||||
/// method does no validation; it simply creates a structure suitable for
|
||||
/// passing into a [`TlsConfig`].
|
||||
/// method does no validation; it simply creates an [`MtlsConfig`] for later
|
||||
/// use.
|
||||
///
|
||||
/// These certificates will be used to verify client-presented certificates
|
||||
/// in TLS connections.
|
||||
|
@ -101,8 +101,7 @@ impl MtlsConfig {
|
|||
|
||||
/// Constructs a `MtlsConfig` from a byte buffer to a certificate authority
|
||||
/// `ca_certs` DER-encoded X.509 TLS certificate chain. This method does no
|
||||
/// validation; it simply creates a structure suitable for passing into a
|
||||
/// [`TlsConfig`].
|
||||
/// validation; it simply creates an [`MtlsConfig`] for later use.
|
||||
///
|
||||
/// These certificates will be used to verify client-presented certificates
|
||||
/// in TLS connections.
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
use std::convert::Infallible;
|
||||
use std::fmt::Debug;
|
||||
use std::net::IpAddr;
|
||||
use std::convert::Infallible;
|
||||
use std::net::{IpAddr, SocketAddr};
|
||||
|
||||
use crate::{Request, Route};
|
||||
use crate::outcome::{self, IntoOutcome, Outcome::*};
|
||||
|
@ -496,7 +496,7 @@ impl<'r> FromRequest<'r> for &'r Endpoint {
|
|||
}
|
||||
|
||||
#[crate::async_trait]
|
||||
impl<'r> FromRequest<'r> for std::net::SocketAddr {
|
||||
impl<'r> FromRequest<'r> for SocketAddr {
|
||||
type Error = Infallible;
|
||||
|
||||
async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Infallible> {
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
use std::fmt;
|
||||
use std::{io, fmt};
|
||||
use std::ops::RangeFrom;
|
||||
use std::sync::{Arc, atomic::Ordering};
|
||||
use std::borrow::Cow;
|
||||
|
@ -18,7 +18,7 @@ use crate::data::Limits;
|
|||
use crate::http::ProxyProto;
|
||||
use crate::http::{Method, Header, HeaderMap, ContentType, Accept, MediaType, CookieJar, Cookie};
|
||||
use crate::http::uri::{fmt::Path, Origin, Segments, Host, Authority};
|
||||
use crate::listener::{Certificates, Endpoint, Connection};
|
||||
use crate::listener::{Certificates, Endpoint};
|
||||
|
||||
/// The type of an incoming web request.
|
||||
///
|
||||
|
@ -44,11 +44,11 @@ pub(crate) struct ConnectionMeta {
|
|||
pub peer_certs: Option<Arc<Certificates<'static>>>,
|
||||
}
|
||||
|
||||
impl<C: Connection> From<&C> for ConnectionMeta {
|
||||
fn from(conn: &C) -> Self {
|
||||
impl ConnectionMeta {
|
||||
pub fn new(endpoint: io::Result<Endpoint>, certs: Option<Certificates<'_>>) -> Self {
|
||||
ConnectionMeta {
|
||||
peer_endpoint: conn.endpoint().ok(),
|
||||
peer_certs: conn.certificates().map(|c| c.into_owned()).map(Arc::new),
|
||||
peer_endpoint: endpoint.ok(),
|
||||
peer_certs: certs.map(|c| c.into_owned()).map(Arc::new),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -114,9 +114,8 @@ impl<'r> Builder<'r> {
|
|||
/// the same name exist, they are all removed, and only the new header and
|
||||
/// value will remain.
|
||||
///
|
||||
/// The type of `header` can be any type that implements `Into<Header>`.
|
||||
/// This includes `Header` itself, [`ContentType`](crate::http::ContentType) and
|
||||
/// [hyper::header types](crate::http::hyper::header).
|
||||
/// The type of `header` can be any type that implements `Into<Header>`. See
|
||||
/// [trait implementations](Header#trait-implementations).
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
|
@ -144,9 +143,8 @@ impl<'r> Builder<'r> {
|
|||
/// `Response`. This allows for multiple headers with the same name and
|
||||
/// potentially different values to be present in the `Response`.
|
||||
///
|
||||
/// The type of `header` can be any type that implements `Into<Header>`.
|
||||
/// This includes `Header` itself, [`ContentType`](crate::http::ContentType)
|
||||
/// and [`Accept`](crate::http::Accept).
|
||||
/// The type of `header` can be any type that implements `Into<Header>`. See
|
||||
/// [trait implementations](Header#trait-implementations).
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
|
@ -641,9 +639,8 @@ impl<'r> Response<'r> {
|
|||
|
||||
/// Sets the header `header` in `self`. Any existing headers with the name
|
||||
/// `header.name` will be lost, and only `header` will remain. The type of
|
||||
/// `header` can be any type that implements `Into<Header>`. This includes
|
||||
/// `Header` itself, [`ContentType`](crate::http::ContentType) and
|
||||
/// [`hyper::header` types](crate::http::hyper::header).
|
||||
/// `header` can be any type that implements `Into<Header>`. See [trait
|
||||
/// implementations](Header#trait-implementations).
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
|
@ -723,10 +720,7 @@ impl<'r> Response<'r> {
|
|||
|
||||
/// Adds a custom header with name `name` and value `value` to `self`. If
|
||||
/// `self` already contains headers with the name `name`, another header
|
||||
/// with the same `name` and `value` is added. The type of `header` can be
|
||||
/// any type implements `Into<Header>`. This includes `Header` itself,
|
||||
/// [`ContentType`](crate::http::ContentType) and [`hyper::header`
|
||||
/// types](crate::http::hyper::header).
|
||||
/// with the same `name` and `value` is added.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
|
|
|
@ -2,15 +2,16 @@ use std::fmt;
|
|||
use std::ops::{Deref, DerefMut};
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use std::any::Any;
|
||||
|
||||
use futures::TryFutureExt;
|
||||
use yansi::Paint;
|
||||
use either::Either;
|
||||
use figment::{Figment, Provider};
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
|
||||
use crate::shutdown::{Stages, Shutdown};
|
||||
use crate::{sentinel, shield::Shield, Catcher, Config, Route};
|
||||
use crate::listener::{Bindable, DefaultListener, Endpoint, Listener};
|
||||
use crate::listener::{Bind, DefaultListener, Endpoint, Listener};
|
||||
use crate::router::Router;
|
||||
use crate::fairing::{Fairing, Fairings};
|
||||
use crate::phase::{Phase, Build, Building, Ignite, Igniting, Orbit, Orbiting};
|
||||
|
@ -681,19 +682,34 @@ impl Rocket<Ignite> {
|
|||
rocket
|
||||
}
|
||||
|
||||
async fn _launch(self) -> Result<Rocket<Ignite>, Error> {
|
||||
let config = self.figment().extract::<DefaultListener>()?;
|
||||
either::for_both!(config.base_bindable()?, base => {
|
||||
either::for_both!(config.tls_bindable(base), bindable => {
|
||||
self._launch_on(bindable).await
|
||||
})
|
||||
})
|
||||
async fn _launch_with<B>(self) -> Result<Rocket<Ignite>, Error>
|
||||
where B: for<'r> Bind<&'r Rocket<Ignite>>
|
||||
{
|
||||
let bind_endpoint = B::bind_endpoint(&&self).ok();
|
||||
let listener: B = B::bind(&self).await
|
||||
.map_err(|e| ErrorKind::Bind(bind_endpoint, Box::new(e)))?;
|
||||
|
||||
let any: Box<dyn Any + Send + Sync> = Box::new(listener);
|
||||
match any.downcast::<DefaultListener>() {
|
||||
Ok(listener) => {
|
||||
let listener = *listener;
|
||||
crate::util::for_both!(listener, listener => {
|
||||
crate::util::for_both!(listener, listener => {
|
||||
self._launch_on(listener).await
|
||||
})
|
||||
})
|
||||
}
|
||||
Err(any) => {
|
||||
let listener = *any.downcast::<B>().unwrap();
|
||||
self._launch_on(listener).await
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn _launch_on<B: Bindable>(self, bindable: B) -> Result<Rocket<Ignite>, Error>
|
||||
where <B::Listener as Listener>::Connection: AsyncRead + AsyncWrite
|
||||
async fn _launch_on<L>(self, listener: L) -> Result<Rocket<Ignite>, Error>
|
||||
where L: Listener + 'static,
|
||||
{
|
||||
let rocket = self.bind_and_serve(bindable, |rocket| async move {
|
||||
let rocket = self.listen_and_serve(listener, |rocket| async move {
|
||||
let rocket = Arc::new(rocket);
|
||||
|
||||
rocket.shutdown.spawn_listener(&rocket.config.shutdown);
|
||||
|
@ -996,19 +1012,31 @@ impl<P: Phase> Rocket<P> {
|
|||
/// }
|
||||
/// ```
|
||||
pub async fn launch(self) -> Result<Rocket<Ignite>, Error> {
|
||||
self.launch_with::<DefaultListener>().await
|
||||
}
|
||||
|
||||
pub async fn bind_launch<T, B: Bind<T>>(self, value: T) -> Result<Rocket<Ignite>, Error> {
|
||||
let endpoint = B::bind_endpoint(&value).ok();
|
||||
let listener = B::bind(value).map_err(|e| ErrorKind::Bind(endpoint, Box::new(e)));
|
||||
self.launch_on(listener.await?).await
|
||||
}
|
||||
|
||||
pub async fn launch_with<B>(self) -> Result<Rocket<Ignite>, Error>
|
||||
where B: for<'r> Bind<&'r Rocket<Ignite>>
|
||||
{
|
||||
match self.0.into_state() {
|
||||
State::Build(s) => Rocket::from(s).ignite().await?._launch().await,
|
||||
State::Ignite(s) => Rocket::from(s)._launch().await,
|
||||
State::Build(s) => Rocket::from(s).ignite().await?._launch_with::<B>().await,
|
||||
State::Ignite(s) => Rocket::from(s)._launch_with::<B>().await,
|
||||
State::Orbit(s) => Ok(Rocket::from(s).into_ignite())
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn launch_on<B: Bindable>(self, bindable: B) -> Result<Rocket<Ignite>, Error>
|
||||
where <B::Listener as Listener>::Connection: AsyncRead + AsyncWrite
|
||||
pub async fn launch_on<L>(self, listener: L) -> Result<Rocket<Ignite>, Error>
|
||||
where L: Listener + 'static,
|
||||
{
|
||||
match self.0.into_state() {
|
||||
State::Build(s) => Rocket::from(s).ignite().await?._launch_on(bindable).await,
|
||||
State::Ignite(s) => Rocket::from(s)._launch_on(bindable).await,
|
||||
State::Build(s) => Rocket::from(s).ignite().await?._launch_on(listener).await,
|
||||
State::Ignite(s) => Rocket::from(s)._launch_on(listener).await,
|
||||
State::Orbit(s) => Ok(Rocket::from(s).into_ignite())
|
||||
}
|
||||
}
|
||||
|
|
|
@ -6,14 +6,14 @@ use std::time::Duration;
|
|||
use hyper::service::service_fn;
|
||||
use hyper_util::rt::{TokioExecutor, TokioIo, TokioTimer};
|
||||
use hyper_util::server::conn::auto::Builder;
|
||||
use futures::{Future, TryFutureExt, future::Either::*};
|
||||
use futures::{Future, TryFutureExt};
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
|
||||
use crate::{Ignite, Orbit, Request, Rocket};
|
||||
use crate::request::ConnectionMeta;
|
||||
use crate::erased::{ErasedRequest, ErasedResponse, ErasedIoHandler};
|
||||
use crate::listener::{Bindable, BouncedExt, CancellableExt, Listener};
|
||||
use crate::error::{log_server_error, ErrorKind};
|
||||
use crate::listener::{Listener, Connection, BouncedExt, CancellableExt};
|
||||
use crate::error::log_server_error;
|
||||
use crate::data::{IoStream, RawStream};
|
||||
use crate::util::{spawn_inspect, FutureExt, ReaderStream};
|
||||
use crate::http::Status;
|
||||
|
@ -91,31 +91,28 @@ async fn io_handler_task<S>(stream: S, mut handler: ErasedIoHandler)
|
|||
}
|
||||
|
||||
impl Rocket<Ignite> {
|
||||
pub(crate) async fn bind_and_serve<B, R>(
|
||||
pub(crate) async fn listen_and_serve<L, R>(
|
||||
self,
|
||||
bindable: B,
|
||||
post_bind_callback: impl FnOnce(Rocket<Orbit>) -> R,
|
||||
listener: L,
|
||||
orbit_callback: impl FnOnce(Rocket<Orbit>) -> R,
|
||||
) -> Result<Arc<Rocket<Orbit>>>
|
||||
where B: Bindable,
|
||||
<B::Listener as Listener>::Connection: AsyncRead + AsyncWrite,
|
||||
where L: Listener + 'static,
|
||||
R: Future<Output = Result<Arc<Rocket<Orbit>>>>
|
||||
{
|
||||
let binding_endpoint = bindable.bind_endpoint().ok();
|
||||
let h12listener = bindable.bind()
|
||||
.map_err(|e| ErrorKind::Bind(binding_endpoint, Box::new(e)))
|
||||
.await?;
|
||||
let endpoint = listener.endpoint()?;
|
||||
|
||||
let endpoint = h12listener.endpoint()?;
|
||||
#[cfg(feature = "http3-preview")]
|
||||
if let (Some(addr), Some(tls)) = (endpoint.tcp(), endpoint.tls_config()) {
|
||||
use crate::error::ErrorKind;
|
||||
|
||||
let h3listener = crate::listener::quic::QuicListener::bind(addr, tls.clone())
|
||||
.map_err(|e| ErrorKind::Bind(Some(endpoint.clone()), Box::new(e)))
|
||||
.await?;
|
||||
|
||||
let rocket = self.into_orbit(vec![h3listener.endpoint()?, endpoint]);
|
||||
let rocket = post_bind_callback(rocket).await?;
|
||||
let rocket = orbit_callback(rocket).await?;
|
||||
|
||||
let http12 = tokio::task::spawn(rocket.clone().serve12(h12listener));
|
||||
let http12 = tokio::task::spawn(rocket.clone().serve12(listener));
|
||||
let http3 = tokio::task::spawn(rocket.clone().serve3(h3listener));
|
||||
let (r1, r2) = tokio::join!(http12, http3);
|
||||
r1.map_err(|e| ErrorKind::Liftoff(Err(rocket.clone()), Box::new(e)))??;
|
||||
|
@ -129,8 +126,8 @@ impl Rocket<Ignite> {
|
|||
}
|
||||
|
||||
let rocket = self.into_orbit(vec![endpoint]);
|
||||
let rocket = post_bind_callback(rocket).await?;
|
||||
rocket.clone().serve12(h12listener).await?;
|
||||
let rocket = orbit_callback(rocket).await?;
|
||||
rocket.clone().serve12(listener).await?;
|
||||
Ok(rocket)
|
||||
}
|
||||
}
|
||||
|
@ -160,11 +157,11 @@ impl Rocket<Orbit> {
|
|||
}
|
||||
|
||||
let (listener, server) = (Arc::new(listener.bounced()), Arc::new(builder));
|
||||
while let Some(accept) = listener.accept().unless(self.shutdown()).await? {
|
||||
while let Some(accept) = listener.accept().race(self.shutdown()).await.left().transpose()? {
|
||||
let (listener, rocket, server) = (listener.clone(), self.clone(), server.clone());
|
||||
spawn_inspect(|e| log_server_error(&**e), async move {
|
||||
let conn = listener.connect(accept).io_unless(rocket.shutdown()).await?;
|
||||
let meta = ConnectionMeta::from(&conn);
|
||||
let conn = listener.connect(accept).race_io(rocket.shutdown()).await?;
|
||||
let meta = ConnectionMeta::new(conn.endpoint(), conn.certificates());
|
||||
let service = service_fn(|mut req| {
|
||||
let upgrade = hyper::upgrade::on(&mut req);
|
||||
let (parts, incoming) = req.into_parts();
|
||||
|
@ -173,9 +170,9 @@ impl Rocket<Orbit> {
|
|||
|
||||
let io = TokioIo::new(conn.cancellable(rocket.shutdown.clone()));
|
||||
let mut server = pin!(server.serve_connection_with_upgrades(io, service));
|
||||
match server.as_mut().or(rocket.shutdown()).await {
|
||||
Left(result) => result,
|
||||
Right(()) => {
|
||||
match server.as_mut().race(rocket.shutdown()).await.left() {
|
||||
Some(result) => result,
|
||||
None => {
|
||||
server.as_mut().graceful_shutdown();
|
||||
server.await
|
||||
},
|
||||
|
@ -189,26 +186,26 @@ impl Rocket<Orbit> {
|
|||
#[cfg(feature = "http3-preview")]
|
||||
async fn serve3(self: Arc<Self>, listener: crate::listener::quic::QuicListener) -> Result<()> {
|
||||
let rocket = self.clone();
|
||||
let listener = Arc::new(listener.bounced());
|
||||
while let Some(accept) = listener.accept().unless(rocket.shutdown()).await? {
|
||||
let listener = Arc::new(listener);
|
||||
while let Some(Some(accept)) = listener.accept().race(rocket.shutdown()).await.left() {
|
||||
let (listener, rocket) = (listener.clone(), rocket.clone());
|
||||
spawn_inspect(|e: &io::Error| log_server_error(e), async move {
|
||||
let mut stream = listener.connect(accept).io_unless(rocket.shutdown()).await?;
|
||||
while let Some(mut conn) = stream.accept().io_unless(rocket.shutdown()).await? {
|
||||
let mut stream = listener.connect(accept).race_io(rocket.shutdown()).await?;
|
||||
while let Some(mut conn) = stream.accept().race_io(rocket.shutdown()).await? {
|
||||
let rocket = rocket.clone();
|
||||
spawn_inspect(|e: &io::Error| log_server_error(e), async move {
|
||||
let meta = ConnectionMeta::from(&conn);
|
||||
let meta = ConnectionMeta::new(conn.endpoint(), None);
|
||||
let rx = conn.rx.cancellable(rocket.shutdown.clone());
|
||||
let response = rocket.clone()
|
||||
.service(conn.parts, rx, None, meta)
|
||||
.map_err(io::Error::other)
|
||||
.io_unless(rocket.shutdown.mercy.clone())
|
||||
.race_io(rocket.shutdown.mercy.clone())
|
||||
.await?;
|
||||
|
||||
let grace = rocket.shutdown.grace.clone();
|
||||
match conn.tx.send_response(response).or(grace).await {
|
||||
Left(result) => result,
|
||||
Right(_) => Ok(conn.tx.cancel()),
|
||||
match conn.tx.send_response(response).race(grace).await.left() {
|
||||
Some(result) => result,
|
||||
None => Ok(conn.tx.cancel()),
|
||||
}
|
||||
});
|
||||
}
|
||||
|
|
|
@ -88,7 +88,7 @@ impl Shutdown {
|
|||
/// This function returns immediately; pending requests will continue to run
|
||||
/// until completion or expiration of the grace period, which ever comes
|
||||
/// first, before the actual shutdown occurs. The grace period can be
|
||||
/// configured via [`Shutdown::grace`](crate::config::ShutdownConfig::grace).
|
||||
/// configured via [`ShutdownConfig`]'s `grace` field.
|
||||
///
|
||||
/// ```rust
|
||||
/// # use rocket::*;
|
||||
|
|
|
@ -1,11 +1,15 @@
|
|||
use std::io;
|
||||
use std::sync::Arc;
|
||||
|
||||
use rustls::crypto::{ring, CryptoProvider};
|
||||
use rustls::pki_types::{CertificateDer, PrivateKeyDer};
|
||||
use futures::TryFutureExt;
|
||||
use figment::value::magic::{Either, RelativePathBuf};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use indexmap::IndexSet;
|
||||
use rustls::crypto::{ring, CryptoProvider};
|
||||
use rustls::pki_types::{CertificateDer, PrivateKeyDer};
|
||||
use rustls::server::{ServerSessionMemoryCache, ServerConfig, WebPkiClientVerifier};
|
||||
|
||||
use crate::tls::resolver::DynResolver;
|
||||
use crate::tls::error::{Result, Error, KeyError};
|
||||
|
||||
/// TLS configuration: certificate chain, key, and ciphersuites.
|
||||
|
@ -35,7 +39,8 @@ use crate::tls::error::{Result, Error, KeyError};
|
|||
///
|
||||
/// Additionally, the `mutual` parameter controls if and how the server
|
||||
/// authenticates clients via mutual TLS. It works in concert with the
|
||||
/// [`mtls`](crate::mtls) module. See [`MtlsConfig`] for configuration details.
|
||||
/// [`mtls`](crate::mtls) module. See [`MtlsConfig`](crate::mtls::MtlsConfig)
|
||||
/// for configuration details.
|
||||
///
|
||||
/// In `Rocket.toml`, configuration might look like:
|
||||
///
|
||||
|
@ -78,7 +83,7 @@ use crate::tls::error::{Result, Error, KeyError};
|
|||
/// # assert_eq!(tls_config.ciphers().count(), 9);
|
||||
/// # assert!(!tls_config.prefer_server_cipher_order());
|
||||
/// ```
|
||||
#[derive(PartialEq, Debug, Clone, Deserialize, Serialize)]
|
||||
#[derive(Debug, PartialEq, Clone, Deserialize, Serialize)]
|
||||
pub struct TlsConfig {
|
||||
/// Path to a PEM file with, or raw bytes for, a DER-encoded X.509 TLS
|
||||
/// certificate chain.
|
||||
|
@ -97,6 +102,8 @@ pub struct TlsConfig {
|
|||
#[cfg(feature = "mtls")]
|
||||
#[cfg_attr(nightly, doc(cfg(feature = "mtls")))]
|
||||
pub(crate) mutual: Option<crate::mtls::MtlsConfig>,
|
||||
#[serde(skip)]
|
||||
pub(crate) resolver: Option<DynResolver>,
|
||||
}
|
||||
|
||||
/// A supported TLS cipher suite.
|
||||
|
@ -134,6 +141,7 @@ impl Default for TlsConfig {
|
|||
prefer_server_cipher_order: false,
|
||||
#[cfg(feature = "mtls")]
|
||||
mutual: None,
|
||||
resolver: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -430,8 +438,57 @@ impl TlsConfig {
|
|||
self.mutual.as_ref()
|
||||
}
|
||||
|
||||
pub fn validate(&self) -> Result<(), crate::tls::Error> {
|
||||
self.server_config().map(|_| ())
|
||||
/// Try to convert `self` into a [rustls] [`ServerConfig`].
|
||||
///
|
||||
/// [`ServerConfig`]: rustls::server::ServerConfig
|
||||
pub async fn server_config(&self) -> Result<rustls::server::ServerConfig> {
|
||||
let this = self.clone();
|
||||
tokio::task::spawn_blocking(move || this._server_config())
|
||||
.map_err(io::Error::other)
|
||||
.await?
|
||||
}
|
||||
|
||||
/// Try to convert `self` into a [rustls] [`ServerConfig`].
|
||||
///
|
||||
/// [`ServerConfig`]: rustls::server::ServerConfig
|
||||
pub(crate) fn _server_config(&self) -> Result<rustls::server::ServerConfig> {
|
||||
let provider = Arc::new(self.default_crypto_provider());
|
||||
|
||||
#[cfg(feature = "mtls")]
|
||||
let verifier = match self.mutual {
|
||||
Some(ref mtls) => {
|
||||
let ca = Arc::new(mtls.load_ca_certs()?);
|
||||
let verifier = WebPkiClientVerifier::builder_with_provider(ca, provider.clone());
|
||||
match mtls.mandatory {
|
||||
true => verifier.build()?,
|
||||
false => verifier.allow_unauthenticated().build()?,
|
||||
}
|
||||
},
|
||||
None => WebPkiClientVerifier::no_client_auth(),
|
||||
};
|
||||
|
||||
#[cfg(not(feature = "mtls"))]
|
||||
let verifier = WebPkiClientVerifier::no_client_auth();
|
||||
|
||||
let mut tls_config = ServerConfig::builder_with_provider(provider)
|
||||
.with_safe_default_protocol_versions()?
|
||||
.with_client_cert_verifier(verifier)
|
||||
.with_single_cert(self.load_certs()?, self.load_key()?)?;
|
||||
|
||||
tls_config.ignore_client_order = self.prefer_server_cipher_order;
|
||||
tls_config.session_storage = ServerSessionMemoryCache::new(1024);
|
||||
tls_config.ticketer = rustls::crypto::ring::Ticketer::new()?;
|
||||
tls_config.alpn_protocols = vec![b"http/1.1".to_vec()];
|
||||
if cfg!(feature = "http2") {
|
||||
tls_config.alpn_protocols.insert(0, b"h2".to_vec());
|
||||
}
|
||||
|
||||
Ok(tls_config)
|
||||
}
|
||||
|
||||
/// NOTE: This is a blocking function.
|
||||
pub fn validate(&self) -> Result<()> {
|
||||
self._server_config().map(|_| ())
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -17,6 +17,7 @@ pub enum Error {
|
|||
CertChain(std::io::Error),
|
||||
PrivKey(KeyError),
|
||||
CertAuth(rustls::Error),
|
||||
Config(figment::Error),
|
||||
}
|
||||
|
||||
impl std::fmt::Display for Error {
|
||||
|
@ -31,6 +32,7 @@ impl std::fmt::Display for Error {
|
|||
PrivKey(e) => write!(f, "failed to process private key: {e}"),
|
||||
CertAuth(e) => write!(f, "failed to process certificate authority: {e}"),
|
||||
Bind(e) => write!(f, "failed to bind to network interface: {e}"),
|
||||
Config(e) => write!(f, "failed to read tls configuration: {e}"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -69,6 +71,7 @@ impl std::error::Error for Error {
|
|||
Error::PrivKey(e) => Some(e),
|
||||
Error::CertAuth(e) => Some(e),
|
||||
Error::Bind(e) => Some(&**e),
|
||||
Error::Config(e) => Some(e),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -102,3 +105,9 @@ impl From<std::convert::Infallible> for Error {
|
|||
v.into()
|
||||
}
|
||||
}
|
||||
|
||||
impl From<figment::Error> for Error {
|
||||
fn from(value: figment::Error) -> Self {
|
||||
Error::Config(value)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,104 @@
|
|||
use std::io;
|
||||
use std::sync::Arc;
|
||||
|
||||
use futures::TryFutureExt;
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tokio_rustls::LazyConfigAcceptor;
|
||||
use rustls::server::{Acceptor, ServerConfig};
|
||||
|
||||
use crate::{Ignite, Rocket};
|
||||
use crate::listener::{Bind, Certificates, Connection, Endpoint, Listener};
|
||||
use crate::tls::{Error, TlsConfig};
|
||||
use super::resolver::DynResolver;
|
||||
|
||||
#[doc(inline)]
|
||||
pub use tokio_rustls::server::TlsStream;
|
||||
|
||||
/// A TLS listener over some listener interface L.
|
||||
pub struct TlsListener<L> {
|
||||
listener: L,
|
||||
config: TlsConfig,
|
||||
default: Arc<ServerConfig>,
|
||||
}
|
||||
|
||||
impl<T: Send, L: Bind<T>> Bind<(T, TlsConfig)> for TlsListener<L>
|
||||
where L: Listener<Accept = <L as Listener>::Connection>,
|
||||
{
|
||||
type Error = Error;
|
||||
|
||||
async fn bind((inner, config): (T, TlsConfig)) -> Result<Self, Self::Error> {
|
||||
Ok(TlsListener {
|
||||
default: Arc::new(config.server_config().await?),
|
||||
listener: L::bind(inner).map_err(|e| Error::Bind(Box::new(e))).await?,
|
||||
config,
|
||||
})
|
||||
}
|
||||
|
||||
fn bind_endpoint((inner, config): &(T, TlsConfig)) -> Result<Endpoint, Self::Error> {
|
||||
L::bind_endpoint(inner)
|
||||
.map(|e| e.with_tls(config))
|
||||
.map_err(|e| Error::Bind(Box::new(e)))
|
||||
}
|
||||
}
|
||||
|
||||
impl<'r, L> Bind<&'r Rocket<Ignite>> for TlsListener<L>
|
||||
where L: Bind<&'r Rocket<Ignite>> + Listener<Accept = <L as Listener>::Connection>
|
||||
{
|
||||
type Error = Error;
|
||||
|
||||
async fn bind(rocket: &'r Rocket<Ignite>) -> Result<Self, Self::Error> {
|
||||
let mut config: TlsConfig = rocket.figment().extract_inner("tls")?;
|
||||
config.resolver = DynResolver::extract(rocket);
|
||||
<Self as Bind<_>>::bind((rocket, config)).await
|
||||
}
|
||||
|
||||
fn bind_endpoint(rocket: &&'r Rocket<Ignite>) -> Result<Endpoint, Self::Error> {
|
||||
let config: TlsConfig = rocket.figment().extract_inner("tls")?;
|
||||
<Self as Bind<_>>::bind_endpoint(&(*rocket, config))
|
||||
}
|
||||
}
|
||||
|
||||
impl<L> Listener for TlsListener<L>
|
||||
where L: Listener<Accept = <L as Listener>::Connection>,
|
||||
L::Connection: AsyncRead + AsyncWrite
|
||||
{
|
||||
type Accept = L::Connection;
|
||||
|
||||
type Connection = TlsStream<L::Connection>;
|
||||
|
||||
async fn accept(&self) -> io::Result<Self::Accept> {
|
||||
self.listener.accept().await
|
||||
}
|
||||
|
||||
async fn connect(&self, conn: L::Connection) -> io::Result<Self::Connection> {
|
||||
let acceptor = LazyConfigAcceptor::new(Acceptor::default(), conn);
|
||||
let handshake = acceptor.await?;
|
||||
let hello = handshake.client_hello();
|
||||
let config = match &self.config.resolver {
|
||||
Some(r) => r.resolve(hello).await.unwrap_or_else(|| self.default.clone()),
|
||||
None => self.default.clone(),
|
||||
};
|
||||
|
||||
handshake.into_stream(config).await
|
||||
}
|
||||
|
||||
fn endpoint(&self) -> io::Result<Endpoint> {
|
||||
Ok(self.listener.endpoint()?.with_tls(&self.config))
|
||||
}
|
||||
}
|
||||
|
||||
impl<C: Connection> Connection for TlsStream<C> {
|
||||
fn endpoint(&self) -> io::Result<Endpoint> {
|
||||
Ok(self.get_ref().0.endpoint()?.assume_tls())
|
||||
}
|
||||
|
||||
fn certificates(&self) -> Option<Certificates<'_>> {
|
||||
#[cfg(feature = "mtls")] {
|
||||
let cert_chain = self.get_ref().1.peer_certificates()?;
|
||||
Some(Certificates::from(cert_chain))
|
||||
}
|
||||
|
||||
#[cfg(not(feature = "mtls"))]
|
||||
None
|
||||
}
|
||||
}
|
|
@ -1,6 +1,9 @@
|
|||
mod error;
|
||||
mod resolver;
|
||||
mod listener;
|
||||
pub(crate) mod config;
|
||||
|
||||
pub use error::Result;
|
||||
pub use error::{Error, Result};
|
||||
pub use config::{TlsConfig, CipherSuite};
|
||||
pub use error::Error;
|
||||
pub use resolver::{Resolver, ClientHello, ServerConfig};
|
||||
pub use listener::{TlsListener, TlsStream};
|
||||
|
|
|
@ -0,0 +1,115 @@
|
|||
use std::fmt;
|
||||
use std::marker::PhantomData;
|
||||
use std::ops::Deref;
|
||||
use std::sync::Arc;
|
||||
|
||||
pub use rustls::server::{ClientHello, ServerConfig};
|
||||
|
||||
use crate::{Build, Ignite, Rocket};
|
||||
use crate::fairing::{self, Info, Kind};
|
||||
|
||||
/// Proxy type to get PartialEq + Debug impls.
|
||||
#[derive(Clone)]
|
||||
pub(crate) struct DynResolver(Arc<dyn Resolver>);
|
||||
|
||||
pub struct Fairing<T: ?Sized>(PhantomData<T>);
|
||||
|
||||
/// A dynamic TLS configuration resolver.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// This is an async trait. Implement it as follows:
|
||||
///
|
||||
/// ```rust
|
||||
/// # #[macro_use] extern crate rocket;
|
||||
/// use std::sync::Arc;
|
||||
/// use rocket::tls::{self, Resolver, TlsConfig, ClientHello, ServerConfig};
|
||||
/// use rocket::{Rocket, Build};
|
||||
///
|
||||
/// struct MyResolver(Arc<ServerConfig>);
|
||||
///
|
||||
/// #[rocket::async_trait]
|
||||
/// impl Resolver for MyResolver {
|
||||
/// async fn init(rocket: &Rocket<Build>) -> tls::Result<Self> {
|
||||
/// // This is equivalent to what the default resolver would do.
|
||||
/// let config: TlsConfig = rocket.figment().extract_inner("tls")?;
|
||||
/// let server_config = config.server_config().await?;
|
||||
/// Ok(MyResolver(Arc::new(server_config)))
|
||||
/// }
|
||||
///
|
||||
/// async fn resolve(&self, hello: ClientHello<'_>) -> Option<Arc<ServerConfig>> {
|
||||
/// // return a `ServerConfig` based on `hello`; here we ignore it
|
||||
/// Some(self.0.clone())
|
||||
/// }
|
||||
/// }
|
||||
///
|
||||
/// #[launch]
|
||||
/// fn rocket() -> _ {
|
||||
/// rocket::build().attach(MyResolver::fairing())
|
||||
/// }
|
||||
/// ```
|
||||
#[crate::async_trait]
|
||||
pub trait Resolver: Send + Sync + 'static {
|
||||
async fn init(rocket: &Rocket<Build>) -> crate::tls::Result<Self> where Self: Sized {
|
||||
let _rocket = rocket;
|
||||
let type_name = std::any::type_name::<Self>();
|
||||
Err(figment::Error::from(format!("{type_name}: Resolver::init() unimplemented")).into())
|
||||
}
|
||||
|
||||
async fn resolve(&self, hello: ClientHello<'_>) -> Option<Arc<ServerConfig>>;
|
||||
|
||||
fn fairing() -> Fairing<Self> where Self: Sized {
|
||||
Fairing(PhantomData)
|
||||
}
|
||||
}
|
||||
|
||||
#[crate::async_trait]
|
||||
impl<T: Resolver> fairing::Fairing for Fairing<T> {
|
||||
fn info(&self) -> Info {
|
||||
Info {
|
||||
name: "Resolver Fairing",
|
||||
kind: Kind::Ignite | Kind::Singleton
|
||||
}
|
||||
}
|
||||
|
||||
async fn on_ignite(&self, rocket: Rocket<Build>) -> fairing::Result {
|
||||
use yansi::Paint;
|
||||
|
||||
let result = T::init(&rocket).await;
|
||||
match result {
|
||||
Ok(resolver) => Ok(rocket.manage(Arc::new(resolver) as Arc<dyn Resolver>)),
|
||||
Err(e) => {
|
||||
let name = std::any::type_name::<T>();
|
||||
error!("TLS resolver {} failed to initialize.", name.primary().bold());
|
||||
error_!("{e}");
|
||||
Err(rocket)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl DynResolver {
|
||||
pub fn extract(rocket: &Rocket<Ignite>) -> Option<Self> {
|
||||
rocket.state::<Arc<dyn Resolver>>().map(|r| Self(r.clone()))
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Debug for DynResolver {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
f.debug_tuple("Resolver").finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl PartialEq for DynResolver {
|
||||
fn eq(&self, _: &Self) -> bool {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
impl Deref for DynResolver {
|
||||
type Target = dyn Resolver;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&*self.0
|
||||
}
|
||||
}
|
|
@ -22,39 +22,39 @@ pub fn spawn_inspect<E, F, Fut>(or: F, future: Fut)
|
|||
use std::io;
|
||||
use std::pin::pin;
|
||||
use std::future::Future;
|
||||
use futures::future::{select, Either};
|
||||
use either::Either;
|
||||
use futures::future;
|
||||
|
||||
pub trait FutureExt: Future + Sized {
|
||||
/// Await `self` or `other`, whichever finishes first.
|
||||
async fn or<B: Future>(self, other: B) -> Either<Self::Output, B::Output> {
|
||||
match futures::future::select(pin!(self), pin!(other)).await {
|
||||
Either::Left((v, _)) => Either::Left(v),
|
||||
Either::Right((v, _)) => Either::Right(v),
|
||||
async fn race<B: Future>(self, other: B) -> Either<Self::Output, B::Output> {
|
||||
match future::select(pin!(self), pin!(other)).await {
|
||||
future::Either::Left((v, _)) => Either::Left(v),
|
||||
future::Either::Right((v, _)) => Either::Right(v),
|
||||
}
|
||||
}
|
||||
|
||||
/// Await `self` unless `trigger` completes. Returns `Ok(Some(T))` if `self`
|
||||
/// completes successfully before `trigger`, `Err(E)` if `self` completes
|
||||
/// unsuccessfully, and `Ok(None)` if `trigger` completes before `self`.
|
||||
async fn unless<T, E, K: Future>(self, trigger: K) -> Result<Option<T>, E>
|
||||
where Self: Future<Output = Result<T, E>>
|
||||
async fn race_io<T, K: Future>(self, trigger: K) -> io::Result<T>
|
||||
where Self: Future<Output = io::Result<T>>
|
||||
{
|
||||
match select(pin!(self), pin!(trigger)).await {
|
||||
Either::Left((v, _)) => Ok(Some(v?)),
|
||||
Either::Right((_, _)) => Ok(None),
|
||||
}
|
||||
}
|
||||
|
||||
/// Await `self` unless `trigger` completes. If `self` completes before
|
||||
/// `trigger`, returns the result. Otherwise, always returns an `Err`.
|
||||
async fn io_unless<T, K: Future>(self, trigger: K) -> std::io::Result<T>
|
||||
where Self: Future<Output = std::io::Result<T>>
|
||||
{
|
||||
match select(pin!(self), pin!(trigger)).await {
|
||||
Either::Left((v, _)) => v,
|
||||
Either::Right((_, _)) => Err(io::Error::other("I/O terminated")),
|
||||
match future::select(pin!(self), pin!(trigger)).await {
|
||||
future::Either::Left((v, _)) => v,
|
||||
future::Either::Right((_, _)) => Err(io::Error::other("i/o terminated")),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<F: Future + Sized> FutureExt for F { }
|
||||
|
||||
#[doc(hidden)]
|
||||
#[macro_export]
|
||||
macro_rules! for_both {
|
||||
($value:expr, $pattern:pat => $result:expr) => {
|
||||
match $value {
|
||||
tokio_util::either::Either::Left($pattern) => $result,
|
||||
tokio_util::either::Either::Right($pattern) => $result,
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
pub use for_both;
|
||||
|
|
|
@ -3,6 +3,7 @@ use std::net::{SocketAddr, Ipv4Addr};
|
|||
use rocket::config::Config;
|
||||
use rocket::fairing::AdHoc;
|
||||
use rocket::futures::channel::oneshot;
|
||||
use rocket::listener::tcp::TcpListener;
|
||||
|
||||
#[rocket::async_test]
|
||||
async fn on_ignite_fairing_can_inspect_port() {
|
||||
|
@ -15,6 +16,7 @@ async fn on_ignite_fairing_can_inspect_port() {
|
|||
})
|
||||
}));
|
||||
|
||||
rocket::tokio::spawn(rocket.launch_on(SocketAddr::from((Ipv4Addr::LOCALHOST, 0))));
|
||||
let addr = SocketAddr::from((Ipv4Addr::LOCALHOST, 0));
|
||||
rocket::tokio::spawn(rocket.bind_launch::<_, TcpListener>(addr));
|
||||
assert_ne!(rx.await.unwrap(), 0);
|
||||
}
|
||||
|
|
|
@ -10,7 +10,7 @@ rocket = { path = "../../core/lib", features = ["secrets"] }
|
|||
|
||||
[dev-dependencies]
|
||||
rocket = { path = "../../core/lib", features = ["secrets", "json", "mtls"] }
|
||||
figment = { version = "0.10", features = ["toml", "env"] }
|
||||
figment = { version = "0.10.17", features = ["toml", "env"] }
|
||||
tokio = { version = "1", features = ["macros", "io-std"] }
|
||||
rand = "0.8"
|
||||
|
||||
|
|
|
@ -0,0 +1,88 @@
|
|||
-----BEGIN PRIVATE KEY-----
|
||||
MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQCjle13u/R/0+zw
|
||||
eycXhdF7ZNYQfqXfkMpw9GlerbqRrxSLEc/YXXBuIO5AZKkXYeP8iM9KbSBD4p8F
|
||||
wZD7LL47601c5WwWpNfOravCaSjYgvaYyhnoNzmG8NYaVYKB9kup6lOyQmesNXEK
|
||||
NGNSrKpsoaQ7jBk+l+VV1jNBjMhNVWuz4AdFMsVD09QyL1GvQ0OvT/BbUKypaFFw
|
||||
YcHruYvHuKGnrlXkvw05aZmKtKiSE6UQoDKtZWfV8yV2M6Sr75i9GKaGMyUZIl88
|
||||
MxVLGcGwO6To2wNFKfLkHLOGIWrKA7m/Bb2n1k2OT+6iOnDzU62BoAzG/j8dhNPL
|
||||
mZ6a7cZfAgMBAAECggEANwiZe06gUuDZNY44+JDsiLbDzYjOBQiREq8nQ9LukVR1
|
||||
dNPpOME2sdYiUUeMG3GzYaIlGsTbtfrnxOf5/oZu+XmP7VDBrFyIvd9viVgXhb+J
|
||||
dp2HWbg6gktDvFhIL7DMg71xqubsOeNAxE4bnBS6wREgT2gylfxECzykwci7Gki4
|
||||
AkeihvaxqdHk9WP8dtFOuCYhX5pyKd9veS1/L01dVMpoFrq72PHupplKYb3HIo+v
|
||||
ga02DhNVcH3fomEbXzazC64k2h5Vz+8mgpu5/V1thKiB2izOwt/hv4tkf2iDNz43
|
||||
xdSYUEFsk80M97VI1dM1+TBe/JO0auZvKLkuOWUjAQKBgQDlBMr+d+guajgQ863I
|
||||
uEFK4veEXrD51L6AKT+cqFUi894fhOodnnmK8l3JBKzO0zjgsaez8exKZPRN8An8
|
||||
4MejM+hMYciJsP7uDpPkhlI5zHd9CR7EFPWXXpt4PecQLvBbnJ/lDnWCrE4m5Zhs
|
||||
9OR7izLMBAmaiPlTNAaXj22iqwKBgQC226wzXGr//lnTggZX+u9UdkZKewAYlgnB
|
||||
Ywj3+JB6Q/kDDS8C6fdlAvWyHShxtO3gx2pJSI3hk7J8fZu/kbojlLF16ayO+tgg
|
||||
t3EoTZxN5zncygPaULstdKHhnMp8a4AO8lLrHtackFbbX7fuUJft0w457FpARvM8
|
||||
DONjWI8LHQKBgBBY5TyAxpv5jQL4weDf9hkoVk6mi69plieDyjyeb2VNTv+k9yki
|
||||
FL7sSfF9WfBxd0/innvjuuAckKu3hJ7+VIG7xMse97eMYMYRWFEpnVju1WChdAa/
|
||||
EEC7yhEtKf8nupRve6JYA99N+U4heV3dpSmEaB3T8/OJ73IW9pl+7W59AoGADxM/
|
||||
OCDHZYF3sFtI4Jn8fy8dDmjjkiNUfJAInkDs0FeoQNsmZAwb7ET5Moz615z9+4kV
|
||||
NyN3JwDBN0g3vexqtyI8Gyd/pW4CwXe+KX90gmustoolFSuQsueprOr7OpS2QwUx
|
||||
Vtb9BH1V29IhXNFiJSZARwA4VJJE3U+Gs5sKd/UCgYEAoCPE3gVaa89nOqQtalhT
|
||||
9SISOGQxxMknjNFrEuF3UaGuR0cxDRLX6lSEneAATEpho0QB2Fj4vO8PiyYOGvH+
|
||||
5ouJD97rcU77OOixlLFt4+TAWI9AvT0mN7y+SHJ22RkwWGQyF4TIfkg0tQvu36D+
|
||||
35W26Li1WteB2O4wV9qVReA=
|
||||
-----END PRIVATE KEY-----
|
||||
-----BEGIN CERTIFICATE-----
|
||||
MIIEwDCCAqigAwIBAgIUay5Z8sVQUkSTFpacn6o4iq2ElGowDQYJKoZIhvcNAQEL
|
||||
BQAwRzELMAkGA1UEBhMCVVMxCzAJBgNVBAgMAkNBMRIwEAYDVQQKDAlSb2NrZXQg
|
||||
Q0ExFzAVBgNVBAMMDlJvY2tldCBSb290IENBMB4XDTI0MDQxNDA4MTU0MVoXDTI1
|
||||
MDQxNDA4MTU0MVowgY4xCzAJBgNVBAYTAlVTMRMwEQYDVQQIDApDYWxpZm9ybmlh
|
||||
MRcwFQYDVQQHDA5TaWxpY29uIFZhbGxleTEPMA0GA1UECgwGUm9ja2V0MRswGQYD
|
||||
VQQDDBJSb2NrZXQgVExTIEV4YW1wbGUxIzAhBgkqhkiG9w0BCQEWFGV4YW1wbGVA
|
||||
cm9ja2V0LmxvY2FsMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAo5Xt
|
||||
d7v0f9Ps8HsnF4XRe2TWEH6l35DKcPRpXq26ka8UixHP2F1wbiDuQGSpF2Hj/IjP
|
||||
Sm0gQ+KfBcGQ+yy+O+tNXOVsFqTXzq2rwmko2IL2mMoZ6Dc5hvDWGlWCgfZLqepT
|
||||
skJnrDVxCjRjUqyqbKGkO4wZPpflVdYzQYzITVVrs+AHRTLFQ9PUMi9Rr0NDr0/w
|
||||
W1CsqWhRcGHB67mLx7ihp65V5L8NOWmZirSokhOlEKAyrWVn1fMldjOkq++YvRim
|
||||
hjMlGSJfPDMVSxnBsDuk6NsDRSny5ByzhiFqygO5vwW9p9ZNjk/uojpw81OtgaAM
|
||||
xv4/HYTTy5memu3GXwIDAQABo1wwWjAYBgNVHREEETAPgg1ETlM6bG9jYWxob3N0
|
||||
MB0GA1UdDgQWBBSowDBXM26C7VogwXNB1F0vLpYO7DAfBgNVHSMEGDAWgBREAyUj
|
||||
0lTwopZ2B1VmnvMPfUtCkzANBgkqhkiG9w0BAQsFAAOCAgEAbjF11+t8qVEF72ey
|
||||
19p1sRkG9ygb0gE2UpLzVpPilucioIOwQuT4rvsVYZQxK+smQZURDI4uNXODIeoS
|
||||
r3maL82VryLSYbkQADyShYjF0uCX8AfCI0YtOKOschNZDcZEJ5mUpHjJE0lEZnkO
|
||||
x8ZVXwWf4pv1/8DZoCkMN3gDHwhQGPtrls4q7O38rI7zK9DNrzu7R1ZdGjQSDasL
|
||||
6DqHee90O2ejpELUxO6lRl2EUosfklRvjV7hfrDHlpN9EuweXt0JiaKw3WZzHSLa
|
||||
dKS8wtTMq5aWzOWrew1ZEhRr+B3KS6BSC5o9xSQMfcDyS0KJcIJI9bNh3nElWFhM
|
||||
IBVtGxM/EYAwNJ++jLD10WHvaqW0epMV2cUu+dGJX+TPuI0c/wNehisS4ahvR64m
|
||||
UpjAwNUBlYpR/Gb15/i2fVk2BbUyU3AcpZfWFDopQ8UqC8ALVcNjbNHq+yVkuTpj
|
||||
gn5iiTTcTqb6qNfie4oDX4KR6ZgpNiTl/PWZo58qxSwdGiJwrINACkPJ6Qg6Qrpd
|
||||
hp3CanTWjioHfvTSdiubqw5/XRnqa2Iav0Sttc6TPnTimodmtWkaYA8mvjS+jq8N
|
||||
f9l2UYQz8yLabMkn98BM+gRJYwrVt6sCbVuEaHgPwq/qX9mQFhUrfw3iEPKlmezt
|
||||
T3AhgPhybUpMFpu+4Tp8JE2JlKQ=
|
||||
-----END CERTIFICATE-----
|
||||
-----BEGIN CERTIFICATE-----
|
||||
MIIFbzCCA1egAwIBAgIURX345HUrWikAysSTFd8xoV5GSIYwDQYJKoZIhvcNAQEL
|
||||
BQAwRzELMAkGA1UEBhMCVVMxCzAJBgNVBAgMAkNBMRIwEAYDVQQKDAlSb2NrZXQg
|
||||
Q0ExFzAVBgNVBAMMDlJvY2tldCBSb290IENBMB4XDTIxMDcwOTIzMzMzM1oXDTMx
|
||||
MDcwNzIzMzMzM1owRzELMAkGA1UEBhMCVVMxCzAJBgNVBAgMAkNBMRIwEAYDVQQK
|
||||
DAlSb2NrZXQgQ0ExFzAVBgNVBAMMDlJvY2tldCBSb290IENBMIICIjANBgkqhkiG
|
||||
9w0BAQEFAAOCAg8AMIICCgKCAgEAybxw0cVrq8yn9W9xRDHdci8rnA5CxPcxAyM5
|
||||
y5LCFOV/tTY0IZgrmegRLo4blmz8QNiFta9Ydt7zsm7XUTm6BhJ7TkOUAcfP7eSv
|
||||
3jNEIEJQLU+k5SepV7pwFPRjUr6+a7yypS2xXAkDEVoyvzsuKYwzj+x6HvDuVhOF
|
||||
2zv4Kk0sLfS/3UelMdilKa5VBCL/WMEXaCpb7/BMUUwn868LVU8E9+1H6uDQMxKo
|
||||
ZH1mH98yeFODjzM9Ko6n2ghXx8qbe+wab4mSHn/SPgFnDFU+ujyPXIQqrS4PSQW3
|
||||
5lkCn70hOw2K+8LHDBmgxOLk2Jb8o8PJWX6v346dlRcZr9VzMqCyKvEf1i5oT2hg
|
||||
NZrkDdUOgyMZeq6H7pQpSxSFSMtkaombSm816V0rg7/sXwS66KyaYJY7x8eYEpgd
|
||||
GuQKXkyIwp687TGLul97amoy/J3jIDnQOuf/YEcdyHCKojh20E5AERC4sCg6l+qs
|
||||
5Nbol7jZclzBFf+70JOsUFmCfVYd5e0LKWdYV9UhYABc3yQqJyzy/eyihWihUNZU
|
||||
LXStjd+XIkhKs+b7uKaBp1poFfgjpdboxmREyppWexua1t0eAReBgMU43bEGoy+B
|
||||
iWoTFjyeQijd6M++npzsqwknYyv+7VjX3EfijyTFgIpZUL196PTJ5SGJMf7eJmaG
|
||||
BO0g2W0CAwEAAaNTMFEwHQYDVR0OBBYEFEQDJSPSVPCilnYHVWae8w99S0KTMB8G
|
||||
A1UdIwQYMBaAFEQDJSPSVPCilnYHVWae8w99S0KTMA8GA1UdEwEB/wQFMAMBAf8w
|
||||
DQYJKoZIhvcNAQELBQADggIBACCArR/ArOyoh97Pgie37miFJEJNtAe+ApzhDevh
|
||||
11P0Vn5hbu+dR7vftCJ7e+0u0irbPgfdxuW3IpEDgL+fttqCIRdAT6MTKLiUmrVS
|
||||
x0fQJqC4Hw4o+bIeRsaNAfB/SAEvOxBbyu5szOFak1r/sXVs4vzBINIF3NdtbNtj
|
||||
Bhac0Fiy/+DlfTHJSRGvzYo+GljXHkrG02mF4aOWx9x97y/6UzbLqHJPINgyAIlN
|
||||
ts29QIHVNtQQyUN292xC1F4TSrBNB+GziGt3XZ8YEASCkMEnIvs3Lpzsjjm9TrkE
|
||||
W/b9ee3C6RWg+RW3pokORMM7Q/lSOMWUmPrzI7CBCKaQUNN9g+iimLkPyp386sCS
|
||||
zXJDd0OKb0xkpxhrauEvzNfEJxGDQbxs8s598ZofhVo9ehdmmXcJAw/zUZjHSrI2
|
||||
PW+vHJ4kslBmKtH1oyAW3zYiFyYYPu4ohkeSrq8z8351upxwJUm4m/ndByXTrPwz
|
||||
Yj6dEHaysjoRl0wOJgQ7G2ikw1QtWja2apJN9Q66i98vEDmtoEyOqOLMSjKjFL7c
|
||||
sSJ6vAittYtIziIeMK7E8lDc1rtzMT5MOAoTriVyIGBgHFs96YOoL0Vi5QmVtQtc
|
||||
8dkFUapFAUj8pREVxnJoLGose/FxBvF2FQZ5Sb25pyTPAeXk7y56noF78nusiVSF
|
||||
xRjI
|
||||
-----END CERTIFICATE-----
|
|
@ -9,6 +9,7 @@
|
|||
# ecdsa_nistp256_sha256
|
||||
# ecdsa_nistp384_sha384
|
||||
# ecdsa_nistp521_sha512
|
||||
# client
|
||||
#
|
||||
# Generate a certificate of the [cert-kind] key type, or if no cert-kind is
|
||||
# specified, all of the certificates.
|
||||
|
@ -136,12 +137,23 @@ function gen_ecdsa_nistp521_sha512() {
|
|||
rm ca_cert.srl server.csr ecdsa_nistp521_sha512_key.pem
|
||||
}
|
||||
|
||||
function gen_client_cert() {
|
||||
openssl req -newkey rsa:2048 -nodes -keyout client.key -out client.csr
|
||||
openssl x509 -req -extfile <(printf "subjectAltName=DNS:${ALT}") -days 365 \
|
||||
-in client.csr -CA ca_cert.pem -CAkey ca_key.pem -CAcreateserial \
|
||||
-out client.crt
|
||||
|
||||
cat client.key client.crt ca_cert.pem > client.pem
|
||||
rm client.key client.crt client.csr ca_cert.srl
|
||||
}
|
||||
|
||||
case $1 in
|
||||
ed25519) gen_ed25519 ;;
|
||||
rsa_sha256) gen_rsa_sha256 ;;
|
||||
ecdsa_nistp256_sha256) gen_ecdsa_nistp256_sha256 ;;
|
||||
ecdsa_nistp384_sha384) gen_ecdsa_nistp384_sha384 ;;
|
||||
ecdsa_nistp521_sha512) gen_ecdsa_nistp521_sha512 ;;
|
||||
client) gen_client_cert ;;
|
||||
*)
|
||||
gen_ed25519
|
||||
gen_rsa_sha256
|
||||
|
|
|
@ -7,6 +7,7 @@ use rocket::log::LogLevel;
|
|||
use rocket::{route, Error, Request, Data, Route, Orbit, Rocket, Ignite};
|
||||
use rocket::fairing::{Fairing, Info, Kind};
|
||||
use rocket::response::Redirect;
|
||||
use rocket::listener::tcp::TcpListener;
|
||||
|
||||
use yansi::Paint;
|
||||
|
||||
|
@ -59,7 +60,7 @@ impl Redirector {
|
|||
rocket::custom(&config.server)
|
||||
.manage(config)
|
||||
.mount("/", redirects)
|
||||
.launch_on(addr)
|
||||
.bind_launch::<_, TcpListener>(addr)
|
||||
.await
|
||||
}
|
||||
}
|
||||
|
|
|
@ -66,8 +66,7 @@ fn insecure_cookies() {
|
|||
}
|
||||
|
||||
fn validate_profiles(profiles: &[&str]) {
|
||||
use rocket::listener::DefaultListener;
|
||||
use rocket::config::{Config, SecretKey};
|
||||
use rocket::config::{Config, TlsConfig, SecretKey};
|
||||
|
||||
for profile in profiles {
|
||||
let config = Config {
|
||||
|
@ -81,9 +80,8 @@ fn validate_profiles(profiles: &[&str]) {
|
|||
assert_eq!(response.into_string().unwrap(), "Hello, world!");
|
||||
|
||||
let figment = client.rocket().figment();
|
||||
let listener: DefaultListener = figment.extract().unwrap();
|
||||
assert_eq!(figment.profile(), profile);
|
||||
listener.tls.as_ref().unwrap().validate().expect("valid TLS config");
|
||||
let config: TlsConfig = figment.extract_inner("tls").unwrap();
|
||||
config.validate().expect("valid TLS config");
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -171,6 +171,15 @@ function test_default() {
|
|||
echo ":: Checking fuzzers..."
|
||||
indir "${FUZZ_ROOT}" $CARGO update
|
||||
indir "${FUZZ_ROOT}" $CARGO check --all --all-features $@
|
||||
|
||||
case "$OSTYPE" in
|
||||
darwin* | linux*)
|
||||
echo ":: Checking testbench..."
|
||||
indir "${TESTBENCH_ROOT}" $CARGO update
|
||||
indir "${TESTBENCH_ROOT}" $CARGO check $@
|
||||
;;
|
||||
*) echo ":: Skipping testbench [$OSTYPE]" ;;
|
||||
esac
|
||||
}
|
||||
|
||||
function test_ui() {
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
[package]
|
||||
name = "rocket-testbench"
|
||||
description = "end-to-end HTTP testbench for Rocket"
|
||||
name = "testbench"
|
||||
description = "End-to-end HTTP Rocket testbench."
|
||||
version = "0.0.0"
|
||||
edition = "2021"
|
||||
publish = false
|
||||
|
@ -12,6 +12,7 @@ thiserror = "1.0"
|
|||
procspawn = "1"
|
||||
pretty_assertions = "1.4.0"
|
||||
ipc-channel = "0.18"
|
||||
rustls-pemfile = "2.1"
|
||||
|
||||
[dependencies.nix]
|
||||
version = "0.28"
|
||||
|
|
|
@ -1,206 +1,64 @@
|
|||
use std::time::Duration;
|
||||
use std::sync::Once;
|
||||
use std::process::Stdio;
|
||||
use std::io::{self, Read};
|
||||
|
||||
use rocket::fairing::AdHoc;
|
||||
use rocket::http::ext::IntoOwned;
|
||||
use rocket::http::uri::{self, Absolute, Uri};
|
||||
use rocket::serde::{Deserialize, Serialize};
|
||||
use rocket::{Build, Rocket};
|
||||
use reqwest::blocking::{ClientBuilder, RequestBuilder};
|
||||
use rocket::http::{ext::IntoOwned, uri::{Absolute, Uri}};
|
||||
|
||||
use procspawn::SpawnError;
|
||||
use thiserror::Error;
|
||||
use ipc_channel::ipc::{IpcOneShotServer, IpcReceiver, IpcSender};
|
||||
|
||||
static DEFAULT_CONFIG: &str = r#"
|
||||
[default]
|
||||
address = "tcp:127.0.0.1"
|
||||
workers = 2
|
||||
port = 0
|
||||
cli_colors = false
|
||||
secret_key = "itlYmFR2vYKrOmFhupMIn/hyB6lYCCTXz4yaQX89XVg="
|
||||
|
||||
[default.shutdown]
|
||||
grace = 1
|
||||
mercy = 1
|
||||
"#;
|
||||
use crate::{Result, Error, Server};
|
||||
|
||||
#[derive(Debug)]
|
||||
#[allow(unused)]
|
||||
pub struct Client {
|
||||
client: reqwest::blocking::Client,
|
||||
server: procspawn::JoinHandle<()>,
|
||||
tls: bool,
|
||||
port: u16,
|
||||
rx: IpcReceiver<Message>,
|
||||
}
|
||||
|
||||
#[derive(Error, Debug)]
|
||||
pub enum Error {
|
||||
#[error("join/kill failed: {0}")]
|
||||
JoinError(#[from] SpawnError),
|
||||
#[error("kill failed: {0}")]
|
||||
TermFailure(#[from] nix::errno::Errno),
|
||||
#[error("i/o error: {0}")]
|
||||
Io(#[from] io::Error),
|
||||
#[error("invalid URI: {0}")]
|
||||
Uri(#[from] uri::Error<'static>),
|
||||
#[error("the URI is invalid")]
|
||||
InvalidUri,
|
||||
#[error("bad request: {0}")]
|
||||
Request(#[from] reqwest::Error),
|
||||
#[error("IPC failure: {0}")]
|
||||
Ipc(#[from] ipc_channel::ipc::IpcError),
|
||||
#[error("liftoff failed")]
|
||||
Liftoff(String, String),
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[serde(crate = "rocket::serde")]
|
||||
pub enum Message {
|
||||
Liftoff(bool, u16),
|
||||
Failure,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
#[serde(crate = "rocket::serde")]
|
||||
#[must_use]
|
||||
pub struct Token(String);
|
||||
|
||||
pub type Result<T, E = Error> = std::result::Result<T, E>;
|
||||
|
||||
impl Token {
|
||||
fn configure(&self, toml: &str, rocket: Rocket<Build>) -> Rocket<Build> {
|
||||
use rocket::figment::{Figment, providers::{Format, Toml}};
|
||||
|
||||
let toml = toml.replace("{CRATE}", env!("CARGO_MANIFEST_DIR"));
|
||||
let config = Figment::from(rocket.figment())
|
||||
.merge(Toml::string(DEFAULT_CONFIG).nested())
|
||||
.merge(Toml::string(&toml).nested());
|
||||
|
||||
let server = self.0.clone();
|
||||
rocket.configure(config)
|
||||
.attach(AdHoc::on_liftoff("Liftoff", move |rocket| Box::pin(async move {
|
||||
let tcp = rocket.endpoints().find_map(|e| e.tcp()).unwrap();
|
||||
let tls = rocket.endpoints().any(|e| e.is_tls());
|
||||
let sender = IpcSender::<Message>::connect(server).unwrap();
|
||||
let _ = sender.send(Message::Liftoff(tls, tcp.port()));
|
||||
let _ = sender.send(Message::Liftoff(tls, tcp.port()));
|
||||
})))
|
||||
}
|
||||
|
||||
pub fn rocket(&self, toml: &str) -> Rocket<Build> {
|
||||
self.configure(toml, rocket::build())
|
||||
}
|
||||
|
||||
pub fn configured_launch(self, toml: &str, rocket: Rocket<Build>) {
|
||||
let rocket = self.configure(toml, rocket);
|
||||
if let Err(e) = rocket::execute(rocket.launch()) {
|
||||
let sender = IpcSender::<Message>::connect(self.0).unwrap();
|
||||
let _ = sender.send(Message::Failure);
|
||||
let _ = sender.send(Message::Failure);
|
||||
e.pretty_print();
|
||||
std::process::exit(1);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn launch(self, rocket: Rocket<Build>) {
|
||||
self.configured_launch(DEFAULT_CONFIG, rocket)
|
||||
}
|
||||
}
|
||||
pub fn start(f: fn(Token)) -> Result<Client> {
|
||||
static INIT: Once = Once::new();
|
||||
INIT.call_once(procspawn::init);
|
||||
|
||||
let (ipc, server) = IpcOneShotServer::new()?;
|
||||
let mut server = procspawn::Builder::new()
|
||||
.stdin(Stdio::null())
|
||||
.stdout(Stdio::piped())
|
||||
.stderr(Stdio::piped())
|
||||
.spawn(Token(server), f);
|
||||
|
||||
let client = reqwest::blocking::Client::builder()
|
||||
.danger_accept_invalid_certs(true)
|
||||
.cookie_store(true)
|
||||
.tls_info(true)
|
||||
.timeout(Duration::from_secs(5))
|
||||
.connect_timeout(Duration::from_secs(5))
|
||||
.build()?;
|
||||
|
||||
let (rx, _) = ipc.accept().unwrap();
|
||||
match rx.recv() {
|
||||
Ok(Message::Liftoff(tls, port)) => Ok(Client { client, server, tls, port, rx }),
|
||||
Ok(Message::Failure) => {
|
||||
let stdout = server.stdout().unwrap();
|
||||
let mut out = String::new();
|
||||
stdout.read_to_string(&mut out)?;
|
||||
|
||||
let stderr = server.stderr().unwrap();
|
||||
let mut err = String::new();
|
||||
stderr.read_to_string(&mut err)?;
|
||||
Err(Error::Liftoff(out, err))
|
||||
}
|
||||
Err(e) => Err(e.into()),
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
pub fn default() -> Result<Client> {
|
||||
start(|token| token.launch(rocket::build()))
|
||||
}
|
||||
|
||||
impl Client {
|
||||
pub fn read_stdout(&mut self) -> Result<String> {
|
||||
let Some(stdout) = self.server.stdout() else {
|
||||
return Ok(String::new());
|
||||
};
|
||||
|
||||
let mut string = String::new();
|
||||
stdout.read_to_string(&mut string)?;
|
||||
Ok(string)
|
||||
pub fn default() -> Client {
|
||||
Client::build()
|
||||
.try_into()
|
||||
.expect("default builder ok")
|
||||
}
|
||||
|
||||
pub fn read_stderr(&mut self) -> Result<String> {
|
||||
let Some(stderr) = self.server.stderr() else {
|
||||
return Ok(String::new());
|
||||
};
|
||||
|
||||
let mut string = String::new();
|
||||
stderr.read_to_string(&mut string)?;
|
||||
Ok(string)
|
||||
pub fn build() -> ClientBuilder {
|
||||
reqwest::blocking::Client::builder()
|
||||
.danger_accept_invalid_certs(true)
|
||||
.cookie_store(true)
|
||||
.tls_info(true)
|
||||
.timeout(Duration::from_secs(5))
|
||||
.connect_timeout(Duration::from_secs(5))
|
||||
}
|
||||
|
||||
pub fn kill(&mut self) -> Result<()> {
|
||||
Ok(self.server.kill()?)
|
||||
}
|
||||
|
||||
pub fn terminate(&mut self) -> Result<()> {
|
||||
use nix::{sys::signal, unistd::Pid};
|
||||
|
||||
let pid = Pid::from_raw(self.server.pid().unwrap() as i32);
|
||||
Ok(signal::kill(pid, signal::SIGTERM)?)
|
||||
}
|
||||
|
||||
pub fn wait(&mut self) -> Result<()> {
|
||||
match self.server.join_timeout(Duration::from_secs(5)) {
|
||||
Ok(_) => Ok(()),
|
||||
Err(e) if e.is_remote_close() => Ok(()),
|
||||
Err(e) => Err(e.into()),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get(&self, url: &str) -> Result<reqwest::blocking::RequestBuilder> {
|
||||
pub fn get(&self, server: &Server, url: &str) -> Result<RequestBuilder> {
|
||||
let uri = match Uri::parse_any(url).map_err(|e| e.into_owned())? {
|
||||
Uri::Origin(uri) => {
|
||||
let proto = if self.tls { "https" } else { "http" };
|
||||
let uri = format!("{proto}://127.0.0.1:{}{uri}", self.port);
|
||||
let proto = if server.tls { "https" } else { "http" };
|
||||
let uri = format!("{proto}://127.0.0.1:{}{uri}", server.port);
|
||||
Absolute::parse_owned(uri)?
|
||||
}
|
||||
Uri::Absolute(uri) => uri,
|
||||
_ => return Err(Error::InvalidUri),
|
||||
Uri::Absolute(mut uri) => {
|
||||
if let Some(auth) = uri.authority() {
|
||||
let mut auth = auth.clone();
|
||||
auth.set_port(server.port);
|
||||
uri.set_authority(auth);
|
||||
}
|
||||
|
||||
uri
|
||||
}
|
||||
uri => return Err(Error::InvalidUri(uri.into_owned())),
|
||||
};
|
||||
|
||||
Ok(self.client.get(uri.to_string()))
|
||||
}
|
||||
}
|
||||
|
||||
impl From<reqwest::blocking::Client> for Client {
|
||||
fn from(client: reqwest::blocking::Client) -> Self {
|
||||
Client { client }
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<ClientBuilder> for Client {
|
||||
type Error = Error;
|
||||
|
||||
fn try_from(builder: ClientBuilder) -> Result<Self, Self::Error> {
|
||||
Ok(Client { client: builder.build()? })
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,3 +1,35 @@
|
|||
pub mod client;
|
||||
// pub mod session;
|
||||
mod client;
|
||||
mod server;
|
||||
|
||||
pub use server::*;
|
||||
pub use client::*;
|
||||
|
||||
use std::io;
|
||||
use thiserror::Error;
|
||||
use procspawn::SpawnError;
|
||||
use rocket::http::uri;
|
||||
|
||||
pub type Result<T, E = Error> = std::result::Result<T, E>;
|
||||
|
||||
#[derive(Error, Debug)]
|
||||
pub enum Error {
|
||||
#[error("join/kill failed: {0}")]
|
||||
JoinError(#[from] SpawnError),
|
||||
#[error("kill failed: {0}")]
|
||||
TermFailure(#[from] nix::errno::Errno),
|
||||
#[error("i/o error: {0}")]
|
||||
Io(#[from] io::Error),
|
||||
#[error("invalid URI: {0}")]
|
||||
Uri(#[from] uri::Error<'static>),
|
||||
#[error("invalid uri: {0}")]
|
||||
InvalidUri(uri::Uri<'static>),
|
||||
#[error("expected certificates are not present")]
|
||||
MissingCertificate,
|
||||
#[error("bad request: {0}")]
|
||||
Request(#[from] reqwest::Error),
|
||||
#[error("IPC failure: {0}")]
|
||||
Ipc(#[from] ipc_channel::ipc::IpcError),
|
||||
#[error("liftoff failed")]
|
||||
Liftoff(String, String),
|
||||
}
|
||||
|
|
|
@ -1,54 +1,93 @@
|
|||
use rocket::{fairing::AdHoc, *};
|
||||
use rocket_testbench::client::{self, Error};
|
||||
use reqwest::tls::TlsInfo;
|
||||
use std::process::ExitCode;
|
||||
use std::time::Duration;
|
||||
|
||||
fn run() -> client::Result<()> {
|
||||
let mut client = client::start(|token| {
|
||||
#[get("/")]
|
||||
fn index() -> &'static str {
|
||||
"Hello, world!"
|
||||
}
|
||||
use rocket::listener::unix::UnixListener;
|
||||
use rocket::tokio::net::TcpListener;
|
||||
use rocket::yansi::Paint;
|
||||
use rocket::{get, routes, Build, Rocket, State};
|
||||
use reqwest::{tls::TlsInfo, Identity};
|
||||
use testbench::*;
|
||||
|
||||
token.configured_launch(r#"
|
||||
[default.tls]
|
||||
certs = "{CRATE}/../examples/tls/private/rsa_sha256_cert.pem"
|
||||
key = "{CRATE}/../examples/tls/private/rsa_sha256_key.pem"
|
||||
"#, rocket::build().mount("/", routes![index]));
|
||||
})?;
|
||||
static DEFAULT_CONFIG: &str = r#"
|
||||
[default]
|
||||
address = "tcp:127.0.0.1"
|
||||
workers = 2
|
||||
port = 0
|
||||
cli_colors = false
|
||||
secret_key = "itlYmFR2vYKrOmFhupMIn/hyB6lYCCTXz4yaQX89XVg="
|
||||
|
||||
let response = client.get("/")?.send()?;
|
||||
let tls = response.extensions().get::<TlsInfo>().unwrap();
|
||||
assert!(!tls.peer_certificate().unwrap().is_empty());
|
||||
assert_eq!(response.text()?, "Hello, world!");
|
||||
[default.shutdown]
|
||||
grace = 1
|
||||
mercy = 1
|
||||
"#;
|
||||
|
||||
client.terminate()?;
|
||||
let stdout = client.read_stdout()?;
|
||||
assert!(stdout.contains("Rocket has launched on https"));
|
||||
assert!(stdout.contains("Graceful shutdown completed"));
|
||||
assert!(stdout.contains("GET /"));
|
||||
Ok(())
|
||||
static TLS_CONFIG: &str = r#"
|
||||
[default.tls]
|
||||
certs = "{ROCKET}/examples/tls/private/rsa_sha256_cert.pem"
|
||||
key = "{ROCKET}/examples/tls/private/rsa_sha256_key.pem"
|
||||
"#;
|
||||
|
||||
trait RocketExt {
|
||||
fn default() -> Self;
|
||||
fn tls_default() -> Self;
|
||||
fn configure_with_toml(self, toml: &str) -> Self;
|
||||
}
|
||||
|
||||
fn run_fail() -> client::Result<()> {
|
||||
let client = client::start(|token| {
|
||||
let fail = AdHoc::try_on_ignite("FailNow", |rocket| async { Err(rocket) });
|
||||
token.launch(rocket::build().attach(fail));
|
||||
});
|
||||
impl RocketExt for Rocket<Build> {
|
||||
fn default() -> Self {
|
||||
rocket::build().configure_with_toml(DEFAULT_CONFIG)
|
||||
}
|
||||
|
||||
if let Err(Error::Liftoff(stdout, _)) = client {
|
||||
fn tls_default() -> Self {
|
||||
rocket::build()
|
||||
.configure_with_toml(DEFAULT_CONFIG)
|
||||
.configure_with_toml(TLS_CONFIG)
|
||||
}
|
||||
|
||||
fn configure_with_toml(self, toml: &str) -> Self {
|
||||
use rocket::figment::{Figment, providers::{Format, Toml}};
|
||||
|
||||
let toml = toml.replace("{ROCKET}", rocket::fs::relative!("../"));
|
||||
let config = Figment::from(self.figment())
|
||||
.merge(Toml::string(&toml).nested());
|
||||
|
||||
self.configure(config)
|
||||
}
|
||||
}
|
||||
|
||||
fn read(path: &str) -> Result<Vec<u8>> {
|
||||
let path = path.replace("{ROCKET}", rocket::fs::relative!("../"));
|
||||
Ok(std::fs::read(path)?)
|
||||
}
|
||||
|
||||
fn cert(path: &str) -> Result<Vec<u8>> {
|
||||
let mut data = std::io::Cursor::new(read(path)?);
|
||||
let cert = rustls_pemfile::certs(&mut data).last();
|
||||
Ok(cert.ok_or(Error::MissingCertificate)??.to_vec())
|
||||
}
|
||||
|
||||
fn run_fail() -> Result<()> {
|
||||
use rocket::fairing::AdHoc;
|
||||
|
||||
let server = spawn! {
|
||||
let fail = AdHoc::try_on_ignite("FailNow", |rocket| async { Err(rocket) });
|
||||
Rocket::default().attach(fail)
|
||||
};
|
||||
|
||||
if let Err(Error::Liftoff(stdout, _)) = server {
|
||||
assert!(stdout.contains("Rocket failed to launch due to failing fairings"));
|
||||
assert!(stdout.contains("FailNow"));
|
||||
} else {
|
||||
panic!("unexpected result: {client:#?}");
|
||||
panic!("unexpected result: {server:#?}");
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn infinite() -> client::Result<()> {
|
||||
fn infinite() -> Result<()> {
|
||||
use rocket::response::stream::TextStream;
|
||||
|
||||
let mut client = client::start(|token| {
|
||||
let mut server = spawn! {
|
||||
#[get("/")]
|
||||
fn infinite() -> TextStream![&'static str] {
|
||||
TextStream! {
|
||||
|
@ -58,37 +97,358 @@ fn infinite() -> client::Result<()> {
|
|||
}
|
||||
}
|
||||
|
||||
token.launch(rocket::build().mount("/", routes![infinite]));
|
||||
})?;
|
||||
Rocket::default().mount("/", routes![infinite])
|
||||
}?;
|
||||
|
||||
client.get("/")?.send()?;
|
||||
client.terminate()?;
|
||||
let stdout = client.read_stdout()?;
|
||||
let client = Client::default();
|
||||
client.get(&server, "/")?.send()?;
|
||||
server.terminate()?;
|
||||
|
||||
let stdout = server.read_stdout()?;
|
||||
assert!(stdout.contains("Rocket has launched on http"));
|
||||
assert!(stdout.contains("GET /"));
|
||||
assert!(stdout.contains("Graceful shutdown completed"));
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn main() {
|
||||
let names = ["run", "run_fail", "infinite"];
|
||||
let tests = [run, run_fail, infinite];
|
||||
let handles = tests.into_iter()
|
||||
.map(|test| std::thread::spawn(test))
|
||||
.collect::<Vec<_>>();
|
||||
fn tls_info() -> Result<()> {
|
||||
let mut server = spawn! {
|
||||
#[get("/")]
|
||||
fn hello_world() -> &'static str {
|
||||
"Hello, world!"
|
||||
}
|
||||
|
||||
let mut failure = false;
|
||||
for (handle, name) in handles.into_iter().zip(names) {
|
||||
let result = handle.join();
|
||||
failure = failure || matches!(result, Ok(Err(_)) | Err(_));
|
||||
match result {
|
||||
Ok(Ok(_)) => continue,
|
||||
Ok(Err(e)) => eprintln!("{name} failed: {e}"),
|
||||
Err(_) => eprintln!("{name} failed (see panic above)"),
|
||||
Rocket::tls_default().mount("/", routes![hello_world])
|
||||
}?;
|
||||
|
||||
let client = Client::default();
|
||||
let response = client.get(&server, "/")?.send()?;
|
||||
let tls = response.extensions().get::<TlsInfo>().unwrap();
|
||||
assert!(!tls.peer_certificate().unwrap().is_empty());
|
||||
assert_eq!(response.text()?, "Hello, world!");
|
||||
|
||||
server.terminate()?;
|
||||
let stdout = server.read_stdout()?;
|
||||
assert!(stdout.contains("Rocket has launched on https"));
|
||||
assert!(stdout.contains("Graceful shutdown completed"));
|
||||
assert!(stdout.contains("GET /"));
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn tls_resolver() -> Result<()> {
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::{AtomicUsize, Ordering};
|
||||
use rocket::tls::{Resolver, TlsConfig, ClientHello, ServerConfig};
|
||||
|
||||
struct CountingResolver {
|
||||
config: Arc<ServerConfig>,
|
||||
counter: Arc<AtomicUsize>,
|
||||
}
|
||||
|
||||
#[rocket::async_trait]
|
||||
impl Resolver for CountingResolver {
|
||||
async fn init(rocket: &Rocket<Build>) -> rocket::tls::Result<Self> {
|
||||
let config: TlsConfig = rocket.figment().extract_inner("tls")?;
|
||||
let config = Arc::new(config.server_config().await?);
|
||||
let counter = rocket.state::<Arc<AtomicUsize>>().unwrap().clone();
|
||||
Ok(Self { config, counter })
|
||||
}
|
||||
|
||||
async fn resolve(&self, _: ClientHello<'_>) -> Option<Arc<ServerConfig>> {
|
||||
self.counter.fetch_add(1, Ordering::Release);
|
||||
Some(self.config.clone())
|
||||
}
|
||||
}
|
||||
|
||||
if failure {
|
||||
std::process::exit(1);
|
||||
let server = spawn! {
|
||||
#[get("/count")]
|
||||
fn count(counter: &State<Arc<AtomicUsize>>) -> String {
|
||||
counter.load(Ordering::Acquire).to_string()
|
||||
}
|
||||
|
||||
let counter = Arc::new(AtomicUsize::new(0));
|
||||
Rocket::tls_default()
|
||||
.manage(counter)
|
||||
.mount("/", routes![count])
|
||||
.attach(CountingResolver::fairing())
|
||||
}?;
|
||||
|
||||
let client = Client::default();
|
||||
let response = client.get(&server, "/count")?.send()?;
|
||||
assert_eq!(response.text()?, "1");
|
||||
|
||||
// Use a new client so we get a new TLS session.
|
||||
let client = Client::default();
|
||||
let response = client.get(&server, "/count")?.send()?;
|
||||
assert_eq!(response.text()?, "2");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn test_mtls(mandatory: bool) -> Result<()> {
|
||||
let server = spawn!(mandatory: bool => {
|
||||
let mtls_config = format!(r#"
|
||||
[default.tls.mutual]
|
||||
ca_certs = "{{ROCKET}}/examples/tls/private/ca_cert.pem"
|
||||
mandatory = {mandatory}
|
||||
"#);
|
||||
|
||||
#[get("/")]
|
||||
fn hello(cert: rocket::mtls::Certificate<'_>) -> String {
|
||||
format!("{}:{}[{}] {}", cert.serial(), cert.version(), cert.issuer(), cert.subject())
|
||||
}
|
||||
|
||||
#[get("/", rank = 2)]
|
||||
fn hi() -> &'static str {
|
||||
"Hello!"
|
||||
}
|
||||
|
||||
Rocket::tls_default()
|
||||
.configure_with_toml(&mtls_config)
|
||||
.mount("/", routes![hello, hi])
|
||||
})?;
|
||||
|
||||
let pem = read("{ROCKET}/examples/tls/private/client.pem")?;
|
||||
let client: Client = Client::build()
|
||||
.identity(Identity::from_pem(&pem)?)
|
||||
.try_into()?;
|
||||
|
||||
let response = client.get(&server, "/")?.send()?;
|
||||
assert_eq!(response.text()?,
|
||||
"611895682361338926795452113263857440769284805738:2\
|
||||
[C=US, ST=CA, O=Rocket CA, CN=Rocket Root CA] \
|
||||
C=US, ST=California, L=Silicon Valley, O=Rocket, \
|
||||
CN=Rocket TLS Example, Email=example@rocket.local");
|
||||
|
||||
let client = Client::default();
|
||||
let response = client.get(&server, "/")?.send();
|
||||
if mandatory {
|
||||
assert!(response.unwrap_err().is_request());
|
||||
} else {
|
||||
assert_eq!(response?.text()?, "Hello!");
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn tls_mtls() -> Result<()> {
|
||||
test_mtls(false)?;
|
||||
test_mtls(true)
|
||||
}
|
||||
|
||||
fn sni_resolver() -> Result<()> {
|
||||
use std::sync::Arc;
|
||||
use std::collections::HashMap;
|
||||
|
||||
use rocket::http::uri::Host;
|
||||
use rocket::tls::{Resolver, TlsConfig, ClientHello, ServerConfig};
|
||||
|
||||
struct SniResolver {
|
||||
default: Arc<ServerConfig>,
|
||||
map: HashMap<Host<'static>, Arc<ServerConfig>>
|
||||
}
|
||||
|
||||
#[rocket::async_trait]
|
||||
impl Resolver for SniResolver {
|
||||
async fn init(rocket: &Rocket<Build>) -> rocket::tls::Result<Self> {
|
||||
let default: TlsConfig = rocket.figment().extract_inner("tls")?;
|
||||
let sni: HashMap<Host<'_>, TlsConfig> = rocket.figment().extract_inner("tls.sni")?;
|
||||
|
||||
let default = Arc::new(default.server_config().await?);
|
||||
let mut map = HashMap::new();
|
||||
for (host, config) in sni {
|
||||
let config = config.server_config().await?;
|
||||
map.insert(host, Arc::new(config));
|
||||
}
|
||||
|
||||
Ok(SniResolver { default, map })
|
||||
}
|
||||
|
||||
async fn resolve(&self, hello: ClientHello<'_>) -> Option<Arc<ServerConfig>> {
|
||||
if let Some(Ok(host)) = hello.server_name().map(Host::parse) {
|
||||
if let Some(config) = self.map.get(&host) {
|
||||
return Some(config.clone());
|
||||
}
|
||||
}
|
||||
|
||||
Some(self.default.clone())
|
||||
}
|
||||
}
|
||||
|
||||
static SNI_TLS_CONFIG: &str = r#"
|
||||
[default.tls]
|
||||
certs = "{ROCKET}/examples/tls/private/rsa_sha256_cert.pem"
|
||||
key = "{ROCKET}/examples/tls/private/rsa_sha256_key.pem"
|
||||
|
||||
[default.tls.sni."sni1.dev"]
|
||||
certs = "{ROCKET}/examples/tls/private/ecdsa_nistp256_sha256_cert.pem"
|
||||
key = "{ROCKET}/examples/tls/private/ecdsa_nistp256_sha256_key_pkcs8.pem"
|
||||
|
||||
[default.tls.sni."sni2.dev"]
|
||||
certs = "{ROCKET}/examples/tls/private/ed25519_cert.pem"
|
||||
key = "{ROCKET}/examples/tls/private/ed25519_key.pem"
|
||||
"#;
|
||||
|
||||
let server = spawn! {
|
||||
#[get("/")] fn index() { }
|
||||
|
||||
Rocket::default()
|
||||
.configure_with_toml(SNI_TLS_CONFIG)
|
||||
.mount("/", routes![index])
|
||||
.attach(SniResolver::fairing())
|
||||
}?;
|
||||
|
||||
let client: Client = Client::build()
|
||||
.resolve("unknown.dev", server.socket_addr())
|
||||
.resolve("sni1.dev", server.socket_addr())
|
||||
.resolve("sni2.dev", server.socket_addr())
|
||||
.try_into()?;
|
||||
|
||||
let response = client.get(&server, "https://unknown.dev")?.send()?;
|
||||
let tls = response.extensions().get::<TlsInfo>().unwrap();
|
||||
let expected = cert("{ROCKET}/examples/tls/private/rsa_sha256_cert.pem")?;
|
||||
assert_eq!(tls.peer_certificate().unwrap(), expected);
|
||||
|
||||
let response = client.get(&server, "https://sni1.dev")?.send()?;
|
||||
let tls = response.extensions().get::<TlsInfo>().unwrap();
|
||||
let expected = cert("{ROCKET}/examples/tls/private/ecdsa_nistp256_sha256_cert.pem")?;
|
||||
assert_eq!(tls.peer_certificate().unwrap(), expected);
|
||||
|
||||
let response = client.get(&server, "https://sni2.dev")?.send()?;
|
||||
let tls = response.extensions().get::<TlsInfo>().unwrap();
|
||||
let expected = cert("{ROCKET}/examples/tls/private/ed25519_cert.pem")?;
|
||||
assert_eq!(tls.peer_certificate().unwrap(), expected);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn tcp_unix_listener_fail() -> Result<()> {
|
||||
let server = spawn! {
|
||||
Rocket::default().configure_with_toml("[default]\naddress = 123")
|
||||
};
|
||||
|
||||
if let Err(Error::Liftoff(stdout, _)) = server {
|
||||
assert!(stdout.contains("expected valid TCP (ip) or unix (path)"));
|
||||
assert!(stdout.contains("default.address"));
|
||||
} else {
|
||||
panic!("unexpected result: {server:#?}");
|
||||
}
|
||||
|
||||
let server = Server::spawn((), |(token, _)| {
|
||||
let rocket = Rocket::default().configure_with_toml("[default]\naddress = \"unix:foo\"");
|
||||
token.launch_with::<TcpListener>(rocket)
|
||||
});
|
||||
|
||||
if let Err(Error::Liftoff(stdout, _)) = server {
|
||||
assert!(stdout.contains("invalid tcp endpoint: unix:foo"));
|
||||
} else {
|
||||
panic!("unexpected result: {server:#?}");
|
||||
}
|
||||
|
||||
let server = Server::spawn((), |(token, _)| {
|
||||
token.launch_with::<UnixListener>(Rocket::default())
|
||||
});
|
||||
|
||||
if let Err(Error::Liftoff(stdout, _)) = server {
|
||||
assert!(stdout.contains("invalid unix endpoint: tcp:127.0.0.1:8000"));
|
||||
} else {
|
||||
panic!("unexpected result: {server:#?}");
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
macro_rules! tests {
|
||||
($($f:ident),* $(,)?) => {[
|
||||
$(Test {
|
||||
name: stringify!($f),
|
||||
run: |_: ()| $f().map_err(|e| e.to_string()),
|
||||
}),*
|
||||
]};
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone)]
|
||||
struct Test {
|
||||
name: &'static str,
|
||||
run: fn(()) -> Result<(), String>,
|
||||
}
|
||||
|
||||
static TESTS: &[Test] = &tests![
|
||||
run_fail, infinite, tls_info, tls_resolver, tls_mtls, sni_resolver,
|
||||
tcp_unix_listener_fail
|
||||
];
|
||||
|
||||
fn main() -> ExitCode {
|
||||
procspawn::init();
|
||||
|
||||
let filter = std::env::args().nth(1).unwrap_or_default();
|
||||
let filtered = TESTS.into_iter().filter(|test| test.name.contains(&filter));
|
||||
|
||||
println!("running {}/{} tests", filtered.clone().count(), TESTS.len());
|
||||
let handles = filtered.map(|test| (test, std::thread::spawn(|| {
|
||||
let name = test.name;
|
||||
let start = std::time::SystemTime::now();
|
||||
let mut proc = procspawn::spawn((), test.run);
|
||||
let result = loop {
|
||||
match proc.join_timeout(Duration::from_secs(10)) {
|
||||
Err(e) if e.is_timeout() => {
|
||||
let elapsed = start.elapsed().unwrap().as_secs();
|
||||
println!("{name} has been running for {elapsed} seconds...");
|
||||
|
||||
if elapsed >= 30 {
|
||||
println!("{name} timeout");
|
||||
break Err(e);
|
||||
}
|
||||
},
|
||||
result => break result,
|
||||
}
|
||||
};
|
||||
|
||||
match result.as_ref().map_err(|e| e.panic_info()) {
|
||||
Ok(Ok(_)) => println!("test {name} ... {}", "ok".green()),
|
||||
Ok(Err(e)) => println!("test {name} ... {}\n {e}", "fail".red()),
|
||||
Err(Some(_)) => println!("test {name} ... {}", "panic".red().underline()),
|
||||
Err(None) => println!("test {name} ... {}", "error".magenta()),
|
||||
}
|
||||
|
||||
matches!(result, Ok(Ok(())))
|
||||
})));
|
||||
|
||||
let mut success = true;
|
||||
for (_, handle) in handles {
|
||||
success &= handle.join().unwrap_or(false);
|
||||
}
|
||||
|
||||
match success {
|
||||
true => ExitCode::SUCCESS,
|
||||
false => {
|
||||
println!("note: use `NOCAPTURE=1` to see test output");
|
||||
ExitCode::FAILURE
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: Implement an `UpdatingResolver`. Expose `SniResolver` and
|
||||
// `UpdatingResolver` in a `contrib` library or as part of `rocket`.
|
||||
//
|
||||
// struct UpdatingResolver {
|
||||
// timestamp: AtomicU64,
|
||||
// config: ArcSwap<ServerConfig>
|
||||
// }
|
||||
//
|
||||
// #[crate::async_trait]
|
||||
// impl Resolver for UpdatingResolver {
|
||||
// async fn resolve(&self, _: ClientHello<'_>) -> Option<Arc<ServerConfig>> {
|
||||
// if let Either::Left(path) = self.tls_config.certs() {
|
||||
// let metadata = tokio::fs::metadata(&path).await.ok()?;
|
||||
// let modtime = metadata.modified().ok()?;
|
||||
// let timestamp = modtime.duration_since(UNIX_EPOCH).ok()?.as_secs();
|
||||
// let old_timestamp = self.timestamp.load(Ordering::Acquire);
|
||||
// if timestamp > old_timestamp {
|
||||
// let new_config = self.tls_config.to_server_config().await.ok()?;
|
||||
// self.server_config.store(Arc::new(new_config));
|
||||
// self.timestamp.store(timestamp, Ordering::Release);
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// Some(self.server_config.load_full())
|
||||
// }
|
||||
// }
|
||||
|
|
|
@ -0,0 +1,168 @@
|
|||
use std::net::{Ipv4Addr, SocketAddr};
|
||||
use std::time::Duration;
|
||||
use std::sync::Once;
|
||||
use std::process::Stdio;
|
||||
use std::io::Read;
|
||||
|
||||
use rocket::fairing::AdHoc;
|
||||
use rocket::listener::{Bind, DefaultListener};
|
||||
use rocket::serde::{Deserialize, DeserializeOwned, Serialize};
|
||||
use rocket::{Build, Ignite, Rocket};
|
||||
|
||||
use ipc_channel::ipc::{IpcOneShotServer, IpcReceiver, IpcSender};
|
||||
|
||||
use crate::{Result, Error};
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct Server {
|
||||
proc: procspawn::JoinHandle<Launched>,
|
||||
pub tls: bool,
|
||||
pub port: u16,
|
||||
_rx: IpcReceiver<Message>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[serde(crate = "rocket::serde")]
|
||||
pub enum Message {
|
||||
Liftoff(bool, u16),
|
||||
Failure,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
#[serde(crate = "rocket::serde")]
|
||||
pub struct Token(String);
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
#[serde(crate = "rocket::serde")]
|
||||
pub struct Launched(());
|
||||
|
||||
fn stdio() -> Stdio {
|
||||
std::env::var_os("NOCAPTURE")
|
||||
.map(|_| Stdio::inherit())
|
||||
.unwrap_or_else(Stdio::piped)
|
||||
}
|
||||
|
||||
fn read<T: Read>(io: Option<T>) -> Result<String> {
|
||||
if let Some(mut io) = io {
|
||||
let mut string = String::new();
|
||||
io.read_to_string(&mut string)?;
|
||||
return Ok(string);
|
||||
}
|
||||
|
||||
Ok(String::new())
|
||||
}
|
||||
|
||||
impl Server {
|
||||
pub fn spawn<T>(ctxt: T, f: fn((Token, T)) -> Launched) -> Result<Server>
|
||||
where T: Serialize + DeserializeOwned
|
||||
{
|
||||
static INIT: Once = Once::new();
|
||||
INIT.call_once(procspawn::init);
|
||||
|
||||
let (ipc, server) = IpcOneShotServer::new()?;
|
||||
let mut proc = procspawn::Builder::new()
|
||||
.stdin(Stdio::null())
|
||||
.stdout(stdio())
|
||||
.stderr(stdio())
|
||||
.spawn((Token(server), ctxt), f);
|
||||
|
||||
let (rx, _) = ipc.accept().unwrap();
|
||||
match rx.recv()? {
|
||||
Message::Liftoff(tls, port) => {
|
||||
Ok(Server { proc, tls, port, _rx: rx })
|
||||
},
|
||||
Message::Failure => {
|
||||
Err(Error::Liftoff(read(proc.stdout())?, read(proc.stderr())?))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn socket_addr(&self) -> SocketAddr {
|
||||
let ip = Ipv4Addr::LOCALHOST;
|
||||
SocketAddr::new(ip.into(), self.port)
|
||||
}
|
||||
|
||||
pub fn read_stdout(&mut self) -> Result<String> {
|
||||
read(self.proc.stdout())
|
||||
}
|
||||
|
||||
pub fn read_stderr(&mut self) -> Result<String> {
|
||||
read(self.proc.stderr())
|
||||
}
|
||||
|
||||
pub fn kill(&mut self) -> Result<()> {
|
||||
Ok(self.proc.kill()?)
|
||||
}
|
||||
|
||||
pub fn terminate(&mut self) -> Result<()> {
|
||||
use nix::{sys::signal, unistd::Pid};
|
||||
|
||||
let pid = Pid::from_raw(self.proc.pid().unwrap() as i32);
|
||||
Ok(signal::kill(pid, signal::SIGTERM)?)
|
||||
}
|
||||
|
||||
pub fn join(&mut self, duration: Duration) -> Result<()> {
|
||||
match self.proc.join_timeout(duration) {
|
||||
Ok(_) => Ok(()),
|
||||
Err(e) if e.is_remote_close() => Ok(()),
|
||||
Err(e) => Err(e.into()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Token {
|
||||
pub fn launch_with<B>(self, rocket: Rocket<Build>) -> Launched
|
||||
where B: for<'r> Bind<&'r Rocket<Ignite>> + Sync + Send + 'static
|
||||
{
|
||||
let server = self.0.clone();
|
||||
let rocket = rocket.attach(AdHoc::on_liftoff("Liftoff", move |rocket| Box::pin(async move {
|
||||
let tcp = rocket.endpoints().find_map(|e| e.tcp()).unwrap();
|
||||
let tls = rocket.endpoints().any(|e| e.is_tls());
|
||||
let sender = IpcSender::<Message>::connect(server).unwrap();
|
||||
let _ = sender.send(Message::Liftoff(tls, tcp.port()));
|
||||
let _ = sender.send(Message::Liftoff(tls, tcp.port()));
|
||||
})));
|
||||
|
||||
let server = self.0.clone();
|
||||
if let Err(e) = rocket::execute(rocket.launch_with::<B>()) {
|
||||
let sender = IpcSender::<Message>::connect(server).unwrap();
|
||||
let _ = sender.send(Message::Failure);
|
||||
let _ = sender.send(Message::Failure);
|
||||
e.pretty_print();
|
||||
std::process::exit(1);
|
||||
}
|
||||
|
||||
Launched(())
|
||||
}
|
||||
|
||||
pub fn launch(self, rocket: Rocket<Build>) -> Launched {
|
||||
self.launch_with::<DefaultListener>(rocket)
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for Server {
|
||||
fn drop(&mut self) {
|
||||
let _ = self.terminate();
|
||||
if self.join(Duration::from_secs(3)).is_err() {
|
||||
let _ = self.kill();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[macro_export]
|
||||
macro_rules! spawn {
|
||||
($($arg:ident : $t:ty),* => $rocket:block) => {{
|
||||
#[allow(unused_parens)]
|
||||
fn _server((token, $($arg),*): ($crate::Token, $($t),*)) -> $crate::Launched {
|
||||
let rocket: rocket::Rocket<rocket::Build> = $rocket;
|
||||
token.launch(rocket)
|
||||
}
|
||||
|
||||
Server::spawn(($($arg),*), _server)
|
||||
}};
|
||||
|
||||
($($token:tt)*) => {{
|
||||
let _unit = ();
|
||||
spawn!(_unit: () => { $($token)* } )
|
||||
}};
|
||||
}
|
Loading…
Reference in New Issue