Simplify 'Bind'. Allow try-launching on Futures.

This commit is contained in:
Sergio Benitez 2024-04-22 17:03:18 -07:00
parent 7cc818cd85
commit 3bfc4ca644
9 changed files with 99 additions and 75 deletions

View File

@ -1,10 +1,13 @@
use crate::listener::{Endpoint, Listener}; use std::error::Error;
pub trait Bind<T>: Listener + 'static { use crate::listener::{Endpoint, Listener};
type Error: std::error::Error + Send + 'static; use crate::{Rocket, Ignite};
pub trait Bind: Listener + 'static {
type Error: Error + Send + 'static;
#[crate::async_bound(Send)] #[crate::async_bound(Send)]
async fn bind(to: T) -> Result<Self, Self::Error>; async fn bind(rocket: &Rocket<Ignite>) -> Result<Self, Self::Error>;
fn bind_endpoint(to: &T) -> Result<Endpoint, Self::Error>; fn bind_endpoint(to: &Rocket<Ignite>) -> Result<Endpoint, Self::Error>;
} }

View File

@ -74,10 +74,10 @@ pub use private::DefaultListener;
type Connection = crate::listener::tcp::TcpStream; type Connection = crate::listener::tcp::TcpStream;
#[cfg(doc)] #[cfg(doc)]
impl<'r> Bind<&'r Rocket<Ignite>> for DefaultListener { impl Bind for DefaultListener {
type Error = Error; type Error = Error;
async fn bind(_: &'r Rocket<Ignite>) -> Result<Self, Error> { unreachable!() } async fn bind(_: &Rocket<Ignite>) -> Result<Self, Error> { unreachable!() }
fn bind_endpoint(_: &&'r Rocket<Ignite>) -> Result<Endpoint, Error> { unreachable!() } fn bind_endpoint(_: &Rocket<Ignite>) -> Result<Endpoint, Error> { unreachable!() }
} }
#[cfg(doc)] #[cfg(doc)]
@ -96,36 +96,36 @@ impl super::Listener for DefaultListener {
pub type DefaultListener = private::Listener; pub type DefaultListener = private::Listener;
#[cfg(not(doc))] #[cfg(not(doc))]
impl<'r> Bind<&'r Rocket<Ignite>> for DefaultListener { impl Bind for DefaultListener {
type Error = Error; type Error = Error;
async fn bind(rocket: &'r Rocket<Ignite>) -> Result<Self, Self::Error> { async fn bind(rocket: &Rocket<Ignite>) -> Result<Self, Self::Error> {
let config: Config = rocket.figment().extract()?; let config: Config = rocket.figment().extract()?;
match config.address { match config.address {
#[cfg(feature = "tls")] #[cfg(feature = "tls")]
Endpoint::Tcp(_) if config.tls.is_some() => { Endpoint::Tcp(_) if config.tls.is_some() => {
let listener = <TlsListener<TcpListener> as Bind<_>>::bind(rocket).await?; let listener = <TlsListener<TcpListener> as Bind>::bind(rocket).await?;
Ok(Left(Left(listener))) Ok(Left(Left(listener)))
} }
Endpoint::Tcp(_) => { Endpoint::Tcp(_) => {
let listener = <TcpListener as Bind<_>>::bind(rocket).await?; let listener = <TcpListener as Bind>::bind(rocket).await?;
Ok(Right(Left(listener))) Ok(Right(Left(listener)))
} }
#[cfg(all(unix, feature = "tls"))] #[cfg(all(unix, feature = "tls"))]
Endpoint::Unix(_) if config.tls.is_some() => { Endpoint::Unix(_) if config.tls.is_some() => {
let listener = <TlsListener<UnixListener> as Bind<_>>::bind(rocket).await?; let listener = <TlsListener<UnixListener> as Bind>::bind(rocket).await?;
Ok(Left(Right(listener))) Ok(Left(Right(listener)))
} }
#[cfg(unix)] #[cfg(unix)]
Endpoint::Unix(_) => { Endpoint::Unix(_) => {
let listener = <UnixListener as Bind<_>>::bind(rocket).await?; let listener = <UnixListener as Bind>::bind(rocket).await?;
Ok(Right(Right(listener))) Ok(Right(Right(listener)))
} }
endpoint => Err(Error::Unsupported(endpoint)), endpoint => Err(Error::Unsupported(endpoint)),
} }
} }
fn bind_endpoint(rocket: &&'r Rocket<Ignite>) -> Result<Endpoint, Self::Error> { fn bind_endpoint(rocket: &Rocket<Ignite>) -> Result<Endpoint, Self::Error> {
let config: Config = rocket.figment().extract()?; let config: Config = rocket.figment().extract()?;
Ok(config.address) Ok(config.address)
} }

View File

@ -20,22 +20,10 @@ pub use tokio::net::{TcpListener, TcpStream};
use crate::{Ignite, Rocket}; use crate::{Ignite, Rocket};
use crate::listener::{Bind, Connection, Endpoint, Listener}; use crate::listener::{Bind, Connection, Endpoint, Listener};
impl Bind<SocketAddr> for TcpListener { impl Bind for TcpListener {
type Error = std::io::Error;
async fn bind(addr: SocketAddr) -> Result<Self, Self::Error> {
Self::bind(addr).await
}
fn bind_endpoint(addr: &SocketAddr) -> Result<Endpoint, Self::Error> {
Ok(Endpoint::Tcp(*addr))
}
}
impl<'r> Bind<&'r Rocket<Ignite>> for TcpListener {
type Error = Either<figment::Error, io::Error>; type Error = Either<figment::Error, io::Error>;
async fn bind(rocket: &'r Rocket<Ignite>) -> Result<Self, Self::Error> { async fn bind(rocket: &Rocket<Ignite>) -> Result<Self, Self::Error> {
let endpoint = Self::bind_endpoint(&rocket)?; let endpoint = Self::bind_endpoint(&rocket)?;
let addr = endpoint.tcp() let addr = endpoint.tcp()
.ok_or_else(|| io::Error::other("internal error: invalid endpoint")) .ok_or_else(|| io::Error::other("internal error: invalid endpoint"))
@ -44,7 +32,7 @@ impl<'r> Bind<&'r Rocket<Ignite>> for TcpListener {
Self::bind(addr).await.map_err(Right) Self::bind(addr).await.map_err(Right)
} }
fn bind_endpoint(rocket: &&'r Rocket<Ignite>) -> Result<Endpoint, Self::Error> { fn bind_endpoint(rocket: &Rocket<Ignite>) -> Result<Endpoint, Self::Error> {
let figment = rocket.figment(); let figment = rocket.figment();
let mut address = Endpoint::fetch(figment, "tcp", "address", |e| { let mut address = Endpoint::fetch(figment, "tcp", "address", |e| {
let default = SocketAddr::new(Ipv4Addr::LOCALHOST.into(), 8000); let default = SocketAddr::new(Ipv4Addr::LOCALHOST.into(), 8000);

View File

@ -71,10 +71,10 @@ impl UnixListener {
} }
} }
impl<'r> Bind<&'r Rocket<Ignite>> for UnixListener { impl Bind for UnixListener {
type Error = Either<figment::Error, io::Error>; type Error = Either<figment::Error, io::Error>;
async fn bind(rocket: &'r Rocket<Ignite>) -> Result<Self, Self::Error> { async fn bind(rocket: &Rocket<Ignite>) -> Result<Self, Self::Error> {
let endpoint = Self::bind_endpoint(&rocket)?; let endpoint = Self::bind_endpoint(&rocket)?;
let path = endpoint.unix() let path = endpoint.unix()
.ok_or_else(|| Right(io::Error::other("internal error: invalid endpoint")))?; .ok_or_else(|| Right(io::Error::other("internal error: invalid endpoint")))?;
@ -83,7 +83,7 @@ impl<'r> Bind<&'r Rocket<Ignite>> for UnixListener {
Ok(Self::bind(path, reuse.unwrap_or(true)).await.map_err(Right)?) Ok(Self::bind(path, reuse.unwrap_or(true)).await.map_err(Right)?)
} }
fn bind_endpoint(rocket: &&'r Rocket<Ignite>) -> Result<Endpoint, Self::Error> { fn bind_endpoint(rocket: &Rocket<Ignite>) -> Result<Endpoint, Self::Error> {
let as_pathbuf = |e: Option<&Endpoint>| e.and_then(|e| e.unix().map(|p| p.to_path_buf())); let as_pathbuf = |e: Option<&Endpoint>| e.and_then(|e| e.unix().map(|p| p.to_path_buf()));
Endpoint::fetch(rocket.figment(), "unix", "address", as_pathbuf) Endpoint::fetch(rocket.figment(), "unix", "address", as_pathbuf)
.map(Endpoint::Unix) .map(Endpoint::Unix)

View File

@ -4,7 +4,7 @@ use std::sync::Arc;
use std::time::Duration; use std::time::Duration;
use std::any::Any; use std::any::Any;
use futures::TryFutureExt; use futures::{Future, TryFutureExt};
use yansi::Paint; use yansi::Paint;
use either::Either; use either::Either;
use figment::{Figment, Provider}; use figment::{Figment, Provider};
@ -682,9 +682,7 @@ impl Rocket<Ignite> {
rocket rocket
} }
async fn _launch_with<B>(self) -> Result<Rocket<Ignite>, Error> async fn _launch_with<B: Bind>(self) -> Result<Rocket<Ignite>, Error> {
where B: for<'r> Bind<&'r Rocket<Ignite>>
{
let bind_endpoint = B::bind_endpoint(&&self).ok(); let bind_endpoint = B::bind_endpoint(&&self).ok();
let listener: B = B::bind(&self).await let listener: B = B::bind(&self).await
.map_err(|e| ErrorKind::Bind(bind_endpoint, Box::new(e)))?; .map_err(|e| ErrorKind::Bind(bind_endpoint, Box::new(e)))?;
@ -1015,15 +1013,7 @@ impl<P: Phase> Rocket<P> {
self.launch_with::<DefaultListener>().await self.launch_with::<DefaultListener>().await
} }
pub async fn bind_launch<T, B: Bind<T>>(self, value: T) -> Result<Rocket<Ignite>, Error> { pub async fn launch_with<B: Bind>(self) -> Result<Rocket<Ignite>, Error> {
let endpoint = B::bind_endpoint(&value).ok();
let listener = B::bind(value).map_err(|e| ErrorKind::Bind(endpoint, Box::new(e)));
self.launch_on(listener.await?).await
}
pub async fn launch_with<B>(self) -> Result<Rocket<Ignite>, Error>
where B: for<'r> Bind<&'r Rocket<Ignite>>
{
match self.0.into_state() { match self.0.into_state() {
State::Build(s) => Rocket::from(s).ignite().await?._launch_with::<B>().await, State::Build(s) => Rocket::from(s).ignite().await?._launch_with::<B>().await,
State::Ignite(s) => Rocket::from(s)._launch_with::<B>().await, State::Ignite(s) => Rocket::from(s)._launch_with::<B>().await,
@ -1031,6 +1021,15 @@ impl<P: Phase> Rocket<P> {
} }
} }
pub async fn try_launch_on<L, F, E>(self, listener: F) -> Result<Rocket<Ignite>, Error>
where L: Listener + 'static,
F: Future<Output = Result<L, E>>,
E: std::error::Error + Send + 'static
{
let listener = listener.map_err(|e| ErrorKind::Bind(None, Box::new(e))).await?;
self.launch_on(listener).await
}
pub async fn launch_on<L>(self, listener: L) -> Result<Rocket<Ignite>, Error> pub async fn launch_on<L>(self, listener: L) -> Result<Rocket<Ignite>, Error>
where L: Listener + 'static, where L: Listener + 'static,
{ {

View File

@ -8,7 +8,7 @@ use rustls::server::{Acceptor, ServerConfig};
use crate::{Ignite, Rocket}; use crate::{Ignite, Rocket};
use crate::listener::{Bind, Certificates, Connection, Endpoint, Listener}; use crate::listener::{Bind, Certificates, Connection, Endpoint, Listener};
use crate::tls::{Error, TlsConfig}; use crate::tls::{TlsConfig, Result, Error};
use super::resolver::DynResolver; use super::resolver::DynResolver;
#[doc(inline)] #[doc(inline)]
@ -21,40 +21,35 @@ pub struct TlsListener<L> {
default: Arc<ServerConfig>, default: Arc<ServerConfig>,
} }
impl<T: Send, L: Bind<T>> Bind<(T, TlsConfig)> for TlsListener<L> impl<L> TlsListener<L>
where L: Listener<Accept = <L as Listener>::Connection>, where L: Listener<Accept = <L as Listener>::Connection>,
{ {
type Error = Error; pub async fn from(listener: L, config: TlsConfig) -> Result<TlsListener<L>> {
async fn bind((inner, config): (T, TlsConfig)) -> Result<Self, Self::Error> {
Ok(TlsListener { Ok(TlsListener {
default: Arc::new(config.server_config().await?), default: Arc::new(config.server_config().await?),
listener: L::bind(inner).map_err(|e| Error::Bind(Box::new(e))).await?, listener,
config, config,
}) })
} }
fn bind_endpoint((inner, config): &(T, TlsConfig)) -> Result<Endpoint, Self::Error> {
L::bind_endpoint(inner)
.map(|e| e.with_tls(config))
.map_err(|e| Error::Bind(Box::new(e)))
}
} }
impl<'r, L> Bind<&'r Rocket<Ignite>> for TlsListener<L> impl<L: Bind> Bind for TlsListener<L>
where L: Bind<&'r Rocket<Ignite>> + Listener<Accept = <L as Listener>::Connection> where L: Listener<Accept = <L as Listener>::Connection>
{ {
type Error = Error; type Error = Error;
async fn bind(rocket: &'r Rocket<Ignite>) -> Result<Self, Self::Error> { async fn bind(rocket: &Rocket<Ignite>) -> Result<Self, Self::Error> {
let listener = L::bind(rocket).map_err(|e| Error::Bind(Box::new(e))).await?;
let mut config: TlsConfig = rocket.figment().extract_inner("tls")?; let mut config: TlsConfig = rocket.figment().extract_inner("tls")?;
config.resolver = DynResolver::extract(rocket); config.resolver = DynResolver::extract(rocket);
<Self as Bind<_>>::bind((rocket, config)).await Self::from(listener, config).await
} }
fn bind_endpoint(rocket: &&'r Rocket<Ignite>) -> Result<Endpoint, Self::Error> { fn bind_endpoint(rocket: &Rocket<Ignite>) -> Result<Endpoint, Self::Error> {
let config: TlsConfig = rocket.figment().extract_inner("tls")?; let config: TlsConfig = rocket.figment().extract_inner("tls")?;
<Self as Bind<_>>::bind_endpoint(&(*rocket, config)) L::bind_endpoint(rocket)
.map(|e| e.with_tls(&config))
.map_err(|e| Error::Bind(Box::new(e)))
} }
} }

View File

@ -17,6 +17,6 @@ async fn on_ignite_fairing_can_inspect_port() {
})); }));
let addr = SocketAddr::from((Ipv4Addr::LOCALHOST, 0)); let addr = SocketAddr::from((Ipv4Addr::LOCALHOST, 0));
rocket::tokio::spawn(rocket.bind_launch::<_, TcpListener>(addr)); rocket::tokio::spawn(rocket.try_launch_on(TcpListener::bind(addr)));
assert_ne!(rx.await.unwrap(), 0); assert_ne!(rx.await.unwrap(), 0);
} }

View File

@ -1,11 +1,15 @@
use std::net::{Ipv4Addr, SocketAddr};
use std::process::ExitCode; use std::process::ExitCode;
use std::time::Duration; use std::time::Duration;
use rocket::listener::unix::UnixListener;
use rocket::tokio::net::TcpListener; use rocket::tokio::net::TcpListener;
use rocket::yansi::Paint; use rocket::yansi::Paint;
use rocket::{get, routes, Build, Rocket, State}; use rocket::{get, routes, Build, Rocket, State};
use rocket::listener::{unix::UnixListener, Endpoint};
use rocket::tls::TlsListener;
use reqwest::{tls::TlsInfo, Identity}; use reqwest::{tls::TlsInfo, Identity};
use testbench::*; use testbench::*;
static DEFAULT_CONFIG: &str = r#" static DEFAULT_CONFIG: &str = r#"
@ -112,12 +116,12 @@ fn infinite() -> Result<()> {
} }
fn tls_info() -> Result<()> { fn tls_info() -> Result<()> {
let mut server = spawn! {
#[get("/")] #[get("/")]
fn hello_world() -> &'static str { fn hello_world(endpoint: &Endpoint) -> String {
"Hello, world!" format!("Hello, {endpoint}!")
} }
let mut server = spawn! {
Rocket::tls_default().mount("/", routes![hello_world]) Rocket::tls_default().mount("/", routes![hello_world])
}?; }?;
@ -125,13 +129,35 @@ fn tls_info() -> Result<()> {
let response = client.get(&server, "/")?.send()?; let response = client.get(&server, "/")?.send()?;
let tls = response.extensions().get::<TlsInfo>().unwrap(); let tls = response.extensions().get::<TlsInfo>().unwrap();
assert!(!tls.peer_certificate().unwrap().is_empty()); assert!(!tls.peer_certificate().unwrap().is_empty());
assert_eq!(response.text()?, "Hello, world!"); assert!(response.text()?.starts_with("Hello, https://127.0.0.1"));
server.terminate()?; server.terminate()?;
let stdout = server.read_stdout()?; let stdout = server.read_stdout()?;
assert!(stdout.contains("Rocket has launched on https")); assert!(stdout.contains("Rocket has launched on https"));
assert!(stdout.contains("Graceful shutdown completed")); assert!(stdout.contains("Graceful shutdown completed"));
assert!(stdout.contains("GET /")); assert!(stdout.contains("GET /"));
let server = Server::spawn((), |(token, _)| {
let rocket = rocket::build()
.configure_with_toml(TLS_CONFIG)
.mount("/", routes![hello_world]);
token.with_launch(rocket, |rocket| {
let config = rocket.figment().extract_inner("tls");
rocket.try_launch_on(async move {
let addr = SocketAddr::new(Ipv4Addr::LOCALHOST.into(), 0);
let listener = TcpListener::bind(addr).await?;
TlsListener::from(listener, config?).await
})
})
}).unwrap();
let client = Client::default();
let response = client.get(&server, "/")?.send()?;
let tls = response.extensions().get::<TlsInfo>().unwrap();
assert!(!tls.peer_certificate().unwrap().is_empty());
assert!(response.text()?.starts_with("Hello, https://127.0.0.1"));
Ok(()) Ok(())
} }

View File

@ -1,3 +1,4 @@
use std::future::Future;
use std::net::{Ipv4Addr, SocketAddr}; use std::net::{Ipv4Addr, SocketAddr};
use std::time::Duration; use std::time::Duration;
use std::sync::Once; use std::sync::Once;
@ -111,8 +112,9 @@ impl Server {
} }
impl Token { impl Token {
pub fn launch_with<B>(self, rocket: Rocket<Build>) -> Launched pub fn with_launch<F, Fut>(self, rocket: Rocket<Build>, launch: F) -> Launched
where B: for<'r> Bind<&'r Rocket<Ignite>> + Sync + Send + 'static where F: FnOnce(Rocket<Ignite>) -> Fut + Send + Sync + 'static,
Fut: Future<Output = Result<Rocket<Ignite>, rocket::Error>> + Send
{ {
let server = self.0.clone(); let server = self.0.clone();
let rocket = rocket.attach(AdHoc::on_liftoff("Liftoff", move |rocket| Box::pin(async move { let rocket = rocket.attach(AdHoc::on_liftoff("Liftoff", move |rocket| Box::pin(async move {
@ -124,7 +126,12 @@ impl Token {
}))); })));
let server = self.0.clone(); let server = self.0.clone();
if let Err(e) = rocket::execute(rocket.launch_with::<B>()) { let launch = async move {
let rocket = rocket.ignite().await?;
launch(rocket).await
};
if let Err(e) = rocket::execute(launch) {
let sender = IpcSender::<Message>::connect(server).unwrap(); let sender = IpcSender::<Message>::connect(server).unwrap();
let _ = sender.send(Message::Failure); let _ = sender.send(Message::Failure);
let _ = sender.send(Message::Failure); let _ = sender.send(Message::Failure);
@ -135,6 +142,12 @@ impl Token {
Launched(()) Launched(())
} }
pub fn launch_with<B: Bind>(self, rocket: Rocket<Build>) -> Launched
where B: Send + Sync + 'static
{
self.with_launch(rocket, |rocket| rocket.launch_with::<B>())
}
pub fn launch(self, rocket: Rocket<Build>) -> Launched { pub fn launch(self, rocket: Rocket<Build>) -> Launched {
self.launch_with::<DefaultListener>(rocket) self.launch_with::<DefaultListener>(rocket)
} }