Simplify 'Bind'. Allow try-launching on Futures.

This commit is contained in:
Sergio Benitez 2024-04-22 17:03:18 -07:00
parent 7cc818cd85
commit 3bfc4ca644
9 changed files with 99 additions and 75 deletions

View File

@ -1,10 +1,13 @@
use crate::listener::{Endpoint, Listener};
use std::error::Error;
pub trait Bind<T>: Listener + 'static {
type Error: std::error::Error + Send + 'static;
use crate::listener::{Endpoint, Listener};
use crate::{Rocket, Ignite};
pub trait Bind: Listener + 'static {
type Error: Error + Send + 'static;
#[crate::async_bound(Send)]
async fn bind(to: T) -> Result<Self, Self::Error>;
async fn bind(rocket: &Rocket<Ignite>) -> Result<Self, Self::Error>;
fn bind_endpoint(to: &T) -> Result<Endpoint, Self::Error>;
fn bind_endpoint(to: &Rocket<Ignite>) -> Result<Endpoint, Self::Error>;
}

View File

@ -74,10 +74,10 @@ pub use private::DefaultListener;
type Connection = crate::listener::tcp::TcpStream;
#[cfg(doc)]
impl<'r> Bind<&'r Rocket<Ignite>> for DefaultListener {
impl Bind for DefaultListener {
type Error = Error;
async fn bind(_: &'r Rocket<Ignite>) -> Result<Self, Error> { unreachable!() }
fn bind_endpoint(_: &&'r Rocket<Ignite>) -> Result<Endpoint, Error> { unreachable!() }
async fn bind(_: &Rocket<Ignite>) -> Result<Self, Error> { unreachable!() }
fn bind_endpoint(_: &Rocket<Ignite>) -> Result<Endpoint, Error> { unreachable!() }
}
#[cfg(doc)]
@ -96,36 +96,36 @@ impl super::Listener for DefaultListener {
pub type DefaultListener = private::Listener;
#[cfg(not(doc))]
impl<'r> Bind<&'r Rocket<Ignite>> for DefaultListener {
impl Bind for DefaultListener {
type Error = Error;
async fn bind(rocket: &'r Rocket<Ignite>) -> Result<Self, Self::Error> {
async fn bind(rocket: &Rocket<Ignite>) -> Result<Self, Self::Error> {
let config: Config = rocket.figment().extract()?;
match config.address {
#[cfg(feature = "tls")]
Endpoint::Tcp(_) if config.tls.is_some() => {
let listener = <TlsListener<TcpListener> as Bind<_>>::bind(rocket).await?;
let listener = <TlsListener<TcpListener> as Bind>::bind(rocket).await?;
Ok(Left(Left(listener)))
}
Endpoint::Tcp(_) => {
let listener = <TcpListener as Bind<_>>::bind(rocket).await?;
let listener = <TcpListener as Bind>::bind(rocket).await?;
Ok(Right(Left(listener)))
}
#[cfg(all(unix, feature = "tls"))]
Endpoint::Unix(_) if config.tls.is_some() => {
let listener = <TlsListener<UnixListener> as Bind<_>>::bind(rocket).await?;
let listener = <TlsListener<UnixListener> as Bind>::bind(rocket).await?;
Ok(Left(Right(listener)))
}
#[cfg(unix)]
Endpoint::Unix(_) => {
let listener = <UnixListener as Bind<_>>::bind(rocket).await?;
let listener = <UnixListener as Bind>::bind(rocket).await?;
Ok(Right(Right(listener)))
}
endpoint => Err(Error::Unsupported(endpoint)),
}
}
fn bind_endpoint(rocket: &&'r Rocket<Ignite>) -> Result<Endpoint, Self::Error> {
fn bind_endpoint(rocket: &Rocket<Ignite>) -> Result<Endpoint, Self::Error> {
let config: Config = rocket.figment().extract()?;
Ok(config.address)
}

View File

@ -20,22 +20,10 @@ pub use tokio::net::{TcpListener, TcpStream};
use crate::{Ignite, Rocket};
use crate::listener::{Bind, Connection, Endpoint, Listener};
impl Bind<SocketAddr> for TcpListener {
type Error = std::io::Error;
async fn bind(addr: SocketAddr) -> Result<Self, Self::Error> {
Self::bind(addr).await
}
fn bind_endpoint(addr: &SocketAddr) -> Result<Endpoint, Self::Error> {
Ok(Endpoint::Tcp(*addr))
}
}
impl<'r> Bind<&'r Rocket<Ignite>> for TcpListener {
impl Bind for TcpListener {
type Error = Either<figment::Error, io::Error>;
async fn bind(rocket: &'r Rocket<Ignite>) -> Result<Self, Self::Error> {
async fn bind(rocket: &Rocket<Ignite>) -> Result<Self, Self::Error> {
let endpoint = Self::bind_endpoint(&rocket)?;
let addr = endpoint.tcp()
.ok_or_else(|| io::Error::other("internal error: invalid endpoint"))
@ -44,7 +32,7 @@ impl<'r> Bind<&'r Rocket<Ignite>> for TcpListener {
Self::bind(addr).await.map_err(Right)
}
fn bind_endpoint(rocket: &&'r Rocket<Ignite>) -> Result<Endpoint, Self::Error> {
fn bind_endpoint(rocket: &Rocket<Ignite>) -> Result<Endpoint, Self::Error> {
let figment = rocket.figment();
let mut address = Endpoint::fetch(figment, "tcp", "address", |e| {
let default = SocketAddr::new(Ipv4Addr::LOCALHOST.into(), 8000);

View File

@ -71,10 +71,10 @@ impl UnixListener {
}
}
impl<'r> Bind<&'r Rocket<Ignite>> for UnixListener {
impl Bind for UnixListener {
type Error = Either<figment::Error, io::Error>;
async fn bind(rocket: &'r Rocket<Ignite>) -> Result<Self, Self::Error> {
async fn bind(rocket: &Rocket<Ignite>) -> Result<Self, Self::Error> {
let endpoint = Self::bind_endpoint(&rocket)?;
let path = endpoint.unix()
.ok_or_else(|| Right(io::Error::other("internal error: invalid endpoint")))?;
@ -83,7 +83,7 @@ impl<'r> Bind<&'r Rocket<Ignite>> for UnixListener {
Ok(Self::bind(path, reuse.unwrap_or(true)).await.map_err(Right)?)
}
fn bind_endpoint(rocket: &&'r Rocket<Ignite>) -> Result<Endpoint, Self::Error> {
fn bind_endpoint(rocket: &Rocket<Ignite>) -> Result<Endpoint, Self::Error> {
let as_pathbuf = |e: Option<&Endpoint>| e.and_then(|e| e.unix().map(|p| p.to_path_buf()));
Endpoint::fetch(rocket.figment(), "unix", "address", as_pathbuf)
.map(Endpoint::Unix)

View File

@ -4,7 +4,7 @@ use std::sync::Arc;
use std::time::Duration;
use std::any::Any;
use futures::TryFutureExt;
use futures::{Future, TryFutureExt};
use yansi::Paint;
use either::Either;
use figment::{Figment, Provider};
@ -682,9 +682,7 @@ impl Rocket<Ignite> {
rocket
}
async fn _launch_with<B>(self) -> Result<Rocket<Ignite>, Error>
where B: for<'r> Bind<&'r Rocket<Ignite>>
{
async fn _launch_with<B: Bind>(self) -> Result<Rocket<Ignite>, Error> {
let bind_endpoint = B::bind_endpoint(&&self).ok();
let listener: B = B::bind(&self).await
.map_err(|e| ErrorKind::Bind(bind_endpoint, Box::new(e)))?;
@ -1015,15 +1013,7 @@ impl<P: Phase> Rocket<P> {
self.launch_with::<DefaultListener>().await
}
pub async fn bind_launch<T, B: Bind<T>>(self, value: T) -> Result<Rocket<Ignite>, Error> {
let endpoint = B::bind_endpoint(&value).ok();
let listener = B::bind(value).map_err(|e| ErrorKind::Bind(endpoint, Box::new(e)));
self.launch_on(listener.await?).await
}
pub async fn launch_with<B>(self) -> Result<Rocket<Ignite>, Error>
where B: for<'r> Bind<&'r Rocket<Ignite>>
{
pub async fn launch_with<B: Bind>(self) -> Result<Rocket<Ignite>, Error> {
match self.0.into_state() {
State::Build(s) => Rocket::from(s).ignite().await?._launch_with::<B>().await,
State::Ignite(s) => Rocket::from(s)._launch_with::<B>().await,
@ -1031,6 +1021,15 @@ impl<P: Phase> Rocket<P> {
}
}
pub async fn try_launch_on<L, F, E>(self, listener: F) -> Result<Rocket<Ignite>, Error>
where L: Listener + 'static,
F: Future<Output = Result<L, E>>,
E: std::error::Error + Send + 'static
{
let listener = listener.map_err(|e| ErrorKind::Bind(None, Box::new(e))).await?;
self.launch_on(listener).await
}
pub async fn launch_on<L>(self, listener: L) -> Result<Rocket<Ignite>, Error>
where L: Listener + 'static,
{

View File

@ -8,7 +8,7 @@ use rustls::server::{Acceptor, ServerConfig};
use crate::{Ignite, Rocket};
use crate::listener::{Bind, Certificates, Connection, Endpoint, Listener};
use crate::tls::{Error, TlsConfig};
use crate::tls::{TlsConfig, Result, Error};
use super::resolver::DynResolver;
#[doc(inline)]
@ -21,40 +21,35 @@ pub struct TlsListener<L> {
default: Arc<ServerConfig>,
}
impl<T: Send, L: Bind<T>> Bind<(T, TlsConfig)> for TlsListener<L>
impl<L> TlsListener<L>
where L: Listener<Accept = <L as Listener>::Connection>,
{
type Error = Error;
async fn bind((inner, config): (T, TlsConfig)) -> Result<Self, Self::Error> {
pub async fn from(listener: L, config: TlsConfig) -> Result<TlsListener<L>> {
Ok(TlsListener {
default: Arc::new(config.server_config().await?),
listener: L::bind(inner).map_err(|e| Error::Bind(Box::new(e))).await?,
listener,
config,
})
}
fn bind_endpoint((inner, config): &(T, TlsConfig)) -> Result<Endpoint, Self::Error> {
L::bind_endpoint(inner)
.map(|e| e.with_tls(config))
.map_err(|e| Error::Bind(Box::new(e)))
}
}
impl<'r, L> Bind<&'r Rocket<Ignite>> for TlsListener<L>
where L: Bind<&'r Rocket<Ignite>> + Listener<Accept = <L as Listener>::Connection>
impl<L: Bind> Bind for TlsListener<L>
where L: Listener<Accept = <L as Listener>::Connection>
{
type Error = Error;
async fn bind(rocket: &'r Rocket<Ignite>) -> Result<Self, Self::Error> {
async fn bind(rocket: &Rocket<Ignite>) -> Result<Self, Self::Error> {
let listener = L::bind(rocket).map_err(|e| Error::Bind(Box::new(e))).await?;
let mut config: TlsConfig = rocket.figment().extract_inner("tls")?;
config.resolver = DynResolver::extract(rocket);
<Self as Bind<_>>::bind((rocket, config)).await
Self::from(listener, config).await
}
fn bind_endpoint(rocket: &&'r Rocket<Ignite>) -> Result<Endpoint, Self::Error> {
fn bind_endpoint(rocket: &Rocket<Ignite>) -> Result<Endpoint, Self::Error> {
let config: TlsConfig = rocket.figment().extract_inner("tls")?;
<Self as Bind<_>>::bind_endpoint(&(*rocket, config))
L::bind_endpoint(rocket)
.map(|e| e.with_tls(&config))
.map_err(|e| Error::Bind(Box::new(e)))
}
}

View File

@ -17,6 +17,6 @@ async fn on_ignite_fairing_can_inspect_port() {
}));
let addr = SocketAddr::from((Ipv4Addr::LOCALHOST, 0));
rocket::tokio::spawn(rocket.bind_launch::<_, TcpListener>(addr));
rocket::tokio::spawn(rocket.try_launch_on(TcpListener::bind(addr)));
assert_ne!(rx.await.unwrap(), 0);
}

View File

@ -1,11 +1,15 @@
use std::net::{Ipv4Addr, SocketAddr};
use std::process::ExitCode;
use std::time::Duration;
use rocket::listener::unix::UnixListener;
use rocket::tokio::net::TcpListener;
use rocket::yansi::Paint;
use rocket::{get, routes, Build, Rocket, State};
use rocket::listener::{unix::UnixListener, Endpoint};
use rocket::tls::TlsListener;
use reqwest::{tls::TlsInfo, Identity};
use testbench::*;
static DEFAULT_CONFIG: &str = r#"
@ -112,12 +116,12 @@ fn infinite() -> Result<()> {
}
fn tls_info() -> Result<()> {
let mut server = spawn! {
#[get("/")]
fn hello_world() -> &'static str {
"Hello, world!"
}
#[get("/")]
fn hello_world(endpoint: &Endpoint) -> String {
format!("Hello, {endpoint}!")
}
let mut server = spawn! {
Rocket::tls_default().mount("/", routes![hello_world])
}?;
@ -125,13 +129,35 @@ fn tls_info() -> Result<()> {
let response = client.get(&server, "/")?.send()?;
let tls = response.extensions().get::<TlsInfo>().unwrap();
assert!(!tls.peer_certificate().unwrap().is_empty());
assert_eq!(response.text()?, "Hello, world!");
assert!(response.text()?.starts_with("Hello, https://127.0.0.1"));
server.terminate()?;
let stdout = server.read_stdout()?;
assert!(stdout.contains("Rocket has launched on https"));
assert!(stdout.contains("Graceful shutdown completed"));
assert!(stdout.contains("GET /"));
let server = Server::spawn((), |(token, _)| {
let rocket = rocket::build()
.configure_with_toml(TLS_CONFIG)
.mount("/", routes![hello_world]);
token.with_launch(rocket, |rocket| {
let config = rocket.figment().extract_inner("tls");
rocket.try_launch_on(async move {
let addr = SocketAddr::new(Ipv4Addr::LOCALHOST.into(), 0);
let listener = TcpListener::bind(addr).await?;
TlsListener::from(listener, config?).await
})
})
}).unwrap();
let client = Client::default();
let response = client.get(&server, "/")?.send()?;
let tls = response.extensions().get::<TlsInfo>().unwrap();
assert!(!tls.peer_certificate().unwrap().is_empty());
assert!(response.text()?.starts_with("Hello, https://127.0.0.1"));
Ok(())
}

View File

@ -1,3 +1,4 @@
use std::future::Future;
use std::net::{Ipv4Addr, SocketAddr};
use std::time::Duration;
use std::sync::Once;
@ -111,8 +112,9 @@ impl Server {
}
impl Token {
pub fn launch_with<B>(self, rocket: Rocket<Build>) -> Launched
where B: for<'r> Bind<&'r Rocket<Ignite>> + Sync + Send + 'static
pub fn with_launch<F, Fut>(self, rocket: Rocket<Build>, launch: F) -> Launched
where F: FnOnce(Rocket<Ignite>) -> Fut + Send + Sync + 'static,
Fut: Future<Output = Result<Rocket<Ignite>, rocket::Error>> + Send
{
let server = self.0.clone();
let rocket = rocket.attach(AdHoc::on_liftoff("Liftoff", move |rocket| Box::pin(async move {
@ -124,7 +126,12 @@ impl Token {
})));
let server = self.0.clone();
if let Err(e) = rocket::execute(rocket.launch_with::<B>()) {
let launch = async move {
let rocket = rocket.ignite().await?;
launch(rocket).await
};
if let Err(e) = rocket::execute(launch) {
let sender = IpcSender::<Message>::connect(server).unwrap();
let _ = sender.send(Message::Failure);
let _ = sender.send(Message::Failure);
@ -135,6 +142,12 @@ impl Token {
Launched(())
}
pub fn launch_with<B: Bind>(self, rocket: Rocket<Build>) -> Launched
where B: Send + Sync + 'static
{
self.with_launch(rocket, |rocket| rocket.launch_with::<B>())
}
pub fn launch(self, rocket: Rocket<Build>) -> Launched {
self.launch_with::<DefaultListener>(rocket)
}