From 3bfc4ca644630d43d73011d4bfd495e7de1b6512 Mon Sep 17 00:00:00 2001 From: Sergio Benitez Date: Mon, 22 Apr 2024 17:03:18 -0700 Subject: [PATCH] Simplify 'Bind'. Allow try-launching on Futures. --- core/lib/src/listener/bind.rs | 13 +++--- core/lib/src/listener/default.rs | 20 +++++----- core/lib/src/listener/tcp.rs | 18 ++------- core/lib/src/listener/unix.rs | 6 +-- core/lib/src/rocket.rs | 25 ++++++------ core/lib/src/tls/listener.rs | 31 ++++++-------- .../on_launch_fairing_can_inspect_port.rs | 2 +- testbench/src/main.rs | 40 +++++++++++++++---- testbench/src/server.rs | 19 +++++++-- 9 files changed, 99 insertions(+), 75 deletions(-) diff --git a/core/lib/src/listener/bind.rs b/core/lib/src/listener/bind.rs index 67e4cf7d..72193af6 100644 --- a/core/lib/src/listener/bind.rs +++ b/core/lib/src/listener/bind.rs @@ -1,10 +1,13 @@ -use crate::listener::{Endpoint, Listener}; +use std::error::Error; -pub trait Bind: Listener + 'static { - type Error: std::error::Error + Send + 'static; +use crate::listener::{Endpoint, Listener}; +use crate::{Rocket, Ignite}; + +pub trait Bind: Listener + 'static { + type Error: Error + Send + 'static; #[crate::async_bound(Send)] - async fn bind(to: T) -> Result; + async fn bind(rocket: &Rocket) -> Result; - fn bind_endpoint(to: &T) -> Result; + fn bind_endpoint(to: &Rocket) -> Result; } diff --git a/core/lib/src/listener/default.rs b/core/lib/src/listener/default.rs index 24d39715..7d3a6111 100644 --- a/core/lib/src/listener/default.rs +++ b/core/lib/src/listener/default.rs @@ -74,10 +74,10 @@ pub use private::DefaultListener; type Connection = crate::listener::tcp::TcpStream; #[cfg(doc)] -impl<'r> Bind<&'r Rocket> for DefaultListener { +impl Bind for DefaultListener { type Error = Error; - async fn bind(_: &'r Rocket) -> Result { unreachable!() } - fn bind_endpoint(_: &&'r Rocket) -> Result { unreachable!() } + async fn bind(_: &Rocket) -> Result { unreachable!() } + fn bind_endpoint(_: &Rocket) -> Result { unreachable!() } } #[cfg(doc)] @@ -96,36 +96,36 @@ impl super::Listener for DefaultListener { pub type DefaultListener = private::Listener; #[cfg(not(doc))] -impl<'r> Bind<&'r Rocket> for DefaultListener { +impl Bind for DefaultListener { type Error = Error; - async fn bind(rocket: &'r Rocket) -> Result { + async fn bind(rocket: &Rocket) -> Result { let config: Config = rocket.figment().extract()?; match config.address { #[cfg(feature = "tls")] Endpoint::Tcp(_) if config.tls.is_some() => { - let listener = as Bind<_>>::bind(rocket).await?; + let listener = as Bind>::bind(rocket).await?; Ok(Left(Left(listener))) } Endpoint::Tcp(_) => { - let listener = >::bind(rocket).await?; + let listener = ::bind(rocket).await?; Ok(Right(Left(listener))) } #[cfg(all(unix, feature = "tls"))] Endpoint::Unix(_) if config.tls.is_some() => { - let listener = as Bind<_>>::bind(rocket).await?; + let listener = as Bind>::bind(rocket).await?; Ok(Left(Right(listener))) } #[cfg(unix)] Endpoint::Unix(_) => { - let listener = >::bind(rocket).await?; + let listener = ::bind(rocket).await?; Ok(Right(Right(listener))) } endpoint => Err(Error::Unsupported(endpoint)), } } - fn bind_endpoint(rocket: &&'r Rocket) -> Result { + fn bind_endpoint(rocket: &Rocket) -> Result { let config: Config = rocket.figment().extract()?; Ok(config.address) } diff --git a/core/lib/src/listener/tcp.rs b/core/lib/src/listener/tcp.rs index 09348cba..ceeaa445 100644 --- a/core/lib/src/listener/tcp.rs +++ b/core/lib/src/listener/tcp.rs @@ -20,22 +20,10 @@ pub use tokio::net::{TcpListener, TcpStream}; use crate::{Ignite, Rocket}; use crate::listener::{Bind, Connection, Endpoint, Listener}; -impl Bind for TcpListener { - type Error = std::io::Error; - - async fn bind(addr: SocketAddr) -> Result { - Self::bind(addr).await - } - - fn bind_endpoint(addr: &SocketAddr) -> Result { - Ok(Endpoint::Tcp(*addr)) - } -} - -impl<'r> Bind<&'r Rocket> for TcpListener { +impl Bind for TcpListener { type Error = Either; - async fn bind(rocket: &'r Rocket) -> Result { + async fn bind(rocket: &Rocket) -> Result { let endpoint = Self::bind_endpoint(&rocket)?; let addr = endpoint.tcp() .ok_or_else(|| io::Error::other("internal error: invalid endpoint")) @@ -44,7 +32,7 @@ impl<'r> Bind<&'r Rocket> for TcpListener { Self::bind(addr).await.map_err(Right) } - fn bind_endpoint(rocket: &&'r Rocket) -> Result { + fn bind_endpoint(rocket: &Rocket) -> Result { let figment = rocket.figment(); let mut address = Endpoint::fetch(figment, "tcp", "address", |e| { let default = SocketAddr::new(Ipv4Addr::LOCALHOST.into(), 8000); diff --git a/core/lib/src/listener/unix.rs b/core/lib/src/listener/unix.rs index dac1faec..a6992a7a 100644 --- a/core/lib/src/listener/unix.rs +++ b/core/lib/src/listener/unix.rs @@ -71,10 +71,10 @@ impl UnixListener { } } -impl<'r> Bind<&'r Rocket> for UnixListener { +impl Bind for UnixListener { type Error = Either; - async fn bind(rocket: &'r Rocket) -> Result { + async fn bind(rocket: &Rocket) -> Result { let endpoint = Self::bind_endpoint(&rocket)?; let path = endpoint.unix() .ok_or_else(|| Right(io::Error::other("internal error: invalid endpoint")))?; @@ -83,7 +83,7 @@ impl<'r> Bind<&'r Rocket> for UnixListener { Ok(Self::bind(path, reuse.unwrap_or(true)).await.map_err(Right)?) } - fn bind_endpoint(rocket: &&'r Rocket) -> Result { + fn bind_endpoint(rocket: &Rocket) -> Result { let as_pathbuf = |e: Option<&Endpoint>| e.and_then(|e| e.unix().map(|p| p.to_path_buf())); Endpoint::fetch(rocket.figment(), "unix", "address", as_pathbuf) .map(Endpoint::Unix) diff --git a/core/lib/src/rocket.rs b/core/lib/src/rocket.rs index d7e0adda..fa192815 100644 --- a/core/lib/src/rocket.rs +++ b/core/lib/src/rocket.rs @@ -4,7 +4,7 @@ use std::sync::Arc; use std::time::Duration; use std::any::Any; -use futures::TryFutureExt; +use futures::{Future, TryFutureExt}; use yansi::Paint; use either::Either; use figment::{Figment, Provider}; @@ -682,9 +682,7 @@ impl Rocket { rocket } - async fn _launch_with(self) -> Result, Error> - where B: for<'r> Bind<&'r Rocket> - { + async fn _launch_with(self) -> Result, Error> { let bind_endpoint = B::bind_endpoint(&&self).ok(); let listener: B = B::bind(&self).await .map_err(|e| ErrorKind::Bind(bind_endpoint, Box::new(e)))?; @@ -1015,15 +1013,7 @@ impl Rocket

{ self.launch_with::().await } - pub async fn bind_launch>(self, value: T) -> Result, Error> { - let endpoint = B::bind_endpoint(&value).ok(); - let listener = B::bind(value).map_err(|e| ErrorKind::Bind(endpoint, Box::new(e))); - self.launch_on(listener.await?).await - } - - pub async fn launch_with(self) -> Result, Error> - where B: for<'r> Bind<&'r Rocket> - { + pub async fn launch_with(self) -> Result, Error> { match self.0.into_state() { State::Build(s) => Rocket::from(s).ignite().await?._launch_with::().await, State::Ignite(s) => Rocket::from(s)._launch_with::().await, @@ -1031,6 +1021,15 @@ impl Rocket

{ } } + pub async fn try_launch_on(self, listener: F) -> Result, Error> + where L: Listener + 'static, + F: Future>, + E: std::error::Error + Send + 'static + { + let listener = listener.map_err(|e| ErrorKind::Bind(None, Box::new(e))).await?; + self.launch_on(listener).await + } + pub async fn launch_on(self, listener: L) -> Result, Error> where L: Listener + 'static, { diff --git a/core/lib/src/tls/listener.rs b/core/lib/src/tls/listener.rs index aeda6afe..3bdf3f56 100644 --- a/core/lib/src/tls/listener.rs +++ b/core/lib/src/tls/listener.rs @@ -8,7 +8,7 @@ use rustls::server::{Acceptor, ServerConfig}; use crate::{Ignite, Rocket}; use crate::listener::{Bind, Certificates, Connection, Endpoint, Listener}; -use crate::tls::{Error, TlsConfig}; +use crate::tls::{TlsConfig, Result, Error}; use super::resolver::DynResolver; #[doc(inline)] @@ -21,40 +21,35 @@ pub struct TlsListener { default: Arc, } -impl> Bind<(T, TlsConfig)> for TlsListener +impl TlsListener where L: Listener::Connection>, { - type Error = Error; - - async fn bind((inner, config): (T, TlsConfig)) -> Result { + pub async fn from(listener: L, config: TlsConfig) -> Result> { Ok(TlsListener { default: Arc::new(config.server_config().await?), - listener: L::bind(inner).map_err(|e| Error::Bind(Box::new(e))).await?, + listener, config, }) } - - fn bind_endpoint((inner, config): &(T, TlsConfig)) -> Result { - L::bind_endpoint(inner) - .map(|e| e.with_tls(config)) - .map_err(|e| Error::Bind(Box::new(e))) - } } -impl<'r, L> Bind<&'r Rocket> for TlsListener - where L: Bind<&'r Rocket> + Listener::Connection> +impl Bind for TlsListener + where L: Listener::Connection> { type Error = Error; - async fn bind(rocket: &'r Rocket) -> Result { + async fn bind(rocket: &Rocket) -> Result { + let listener = L::bind(rocket).map_err(|e| Error::Bind(Box::new(e))).await?; let mut config: TlsConfig = rocket.figment().extract_inner("tls")?; config.resolver = DynResolver::extract(rocket); - >::bind((rocket, config)).await + Self::from(listener, config).await } - fn bind_endpoint(rocket: &&'r Rocket) -> Result { + fn bind_endpoint(rocket: &Rocket) -> Result { let config: TlsConfig = rocket.figment().extract_inner("tls")?; - >::bind_endpoint(&(*rocket, config)) + L::bind_endpoint(rocket) + .map(|e| e.with_tls(&config)) + .map_err(|e| Error::Bind(Box::new(e))) } } diff --git a/core/lib/tests/on_launch_fairing_can_inspect_port.rs b/core/lib/tests/on_launch_fairing_can_inspect_port.rs index a9c42cdc..d5ccc571 100644 --- a/core/lib/tests/on_launch_fairing_can_inspect_port.rs +++ b/core/lib/tests/on_launch_fairing_can_inspect_port.rs @@ -17,6 +17,6 @@ async fn on_ignite_fairing_can_inspect_port() { })); let addr = SocketAddr::from((Ipv4Addr::LOCALHOST, 0)); - rocket::tokio::spawn(rocket.bind_launch::<_, TcpListener>(addr)); + rocket::tokio::spawn(rocket.try_launch_on(TcpListener::bind(addr))); assert_ne!(rx.await.unwrap(), 0); } diff --git a/testbench/src/main.rs b/testbench/src/main.rs index 093099fe..da6fc3eb 100644 --- a/testbench/src/main.rs +++ b/testbench/src/main.rs @@ -1,11 +1,15 @@ +use std::net::{Ipv4Addr, SocketAddr}; use std::process::ExitCode; use std::time::Duration; -use rocket::listener::unix::UnixListener; use rocket::tokio::net::TcpListener; use rocket::yansi::Paint; use rocket::{get, routes, Build, Rocket, State}; +use rocket::listener::{unix::UnixListener, Endpoint}; +use rocket::tls::TlsListener; + use reqwest::{tls::TlsInfo, Identity}; + use testbench::*; static DEFAULT_CONFIG: &str = r#" @@ -112,12 +116,12 @@ fn infinite() -> Result<()> { } fn tls_info() -> Result<()> { - let mut server = spawn! { - #[get("/")] - fn hello_world() -> &'static str { - "Hello, world!" - } + #[get("/")] + fn hello_world(endpoint: &Endpoint) -> String { + format!("Hello, {endpoint}!") + } + let mut server = spawn! { Rocket::tls_default().mount("/", routes![hello_world]) }?; @@ -125,13 +129,35 @@ fn tls_info() -> Result<()> { let response = client.get(&server, "/")?.send()?; let tls = response.extensions().get::().unwrap(); assert!(!tls.peer_certificate().unwrap().is_empty()); - assert_eq!(response.text()?, "Hello, world!"); + assert!(response.text()?.starts_with("Hello, https://127.0.0.1")); server.terminate()?; let stdout = server.read_stdout()?; assert!(stdout.contains("Rocket has launched on https")); assert!(stdout.contains("Graceful shutdown completed")); assert!(stdout.contains("GET /")); + + let server = Server::spawn((), |(token, _)| { + let rocket = rocket::build() + .configure_with_toml(TLS_CONFIG) + .mount("/", routes![hello_world]); + + token.with_launch(rocket, |rocket| { + let config = rocket.figment().extract_inner("tls"); + rocket.try_launch_on(async move { + let addr = SocketAddr::new(Ipv4Addr::LOCALHOST.into(), 0); + let listener = TcpListener::bind(addr).await?; + TlsListener::from(listener, config?).await + }) + }) + }).unwrap(); + + let client = Client::default(); + let response = client.get(&server, "/")?.send()?; + let tls = response.extensions().get::().unwrap(); + assert!(!tls.peer_certificate().unwrap().is_empty()); + assert!(response.text()?.starts_with("Hello, https://127.0.0.1")); + Ok(()) } diff --git a/testbench/src/server.rs b/testbench/src/server.rs index 13b40c3e..fb15589d 100644 --- a/testbench/src/server.rs +++ b/testbench/src/server.rs @@ -1,3 +1,4 @@ +use std::future::Future; use std::net::{Ipv4Addr, SocketAddr}; use std::time::Duration; use std::sync::Once; @@ -111,8 +112,9 @@ impl Server { } impl Token { - pub fn launch_with(self, rocket: Rocket) -> Launched - where B: for<'r> Bind<&'r Rocket> + Sync + Send + 'static + pub fn with_launch(self, rocket: Rocket, launch: F) -> Launched + where F: FnOnce(Rocket) -> Fut + Send + Sync + 'static, + Fut: Future, rocket::Error>> + Send { let server = self.0.clone(); let rocket = rocket.attach(AdHoc::on_liftoff("Liftoff", move |rocket| Box::pin(async move { @@ -124,7 +126,12 @@ impl Token { }))); let server = self.0.clone(); - if let Err(e) = rocket::execute(rocket.launch_with::()) { + let launch = async move { + let rocket = rocket.ignite().await?; + launch(rocket).await + }; + + if let Err(e) = rocket::execute(launch) { let sender = IpcSender::::connect(server).unwrap(); let _ = sender.send(Message::Failure); let _ = sender.send(Message::Failure); @@ -135,6 +142,12 @@ impl Token { Launched(()) } + pub fn launch_with(self, rocket: Rocket) -> Launched + where B: Send + Sync + 'static + { + self.with_launch(rocket, |rocket| rocket.launch_with::()) + } + pub fn launch(self, rocket: Rocket) -> Launched { self.launch_with::(rocket) }