mirror of https://github.com/rwf2/Rocket.git
Simplify 'Bind'. Allow try-launching on Futures.
This commit is contained in:
parent
7cc818cd85
commit
3bfc4ca644
|
@ -1,10 +1,13 @@
|
|||
use crate::listener::{Endpoint, Listener};
|
||||
use std::error::Error;
|
||||
|
||||
pub trait Bind<T>: Listener + 'static {
|
||||
type Error: std::error::Error + Send + 'static;
|
||||
use crate::listener::{Endpoint, Listener};
|
||||
use crate::{Rocket, Ignite};
|
||||
|
||||
pub trait Bind: Listener + 'static {
|
||||
type Error: Error + Send + 'static;
|
||||
|
||||
#[crate::async_bound(Send)]
|
||||
async fn bind(to: T) -> Result<Self, Self::Error>;
|
||||
async fn bind(rocket: &Rocket<Ignite>) -> Result<Self, Self::Error>;
|
||||
|
||||
fn bind_endpoint(to: &T) -> Result<Endpoint, Self::Error>;
|
||||
fn bind_endpoint(to: &Rocket<Ignite>) -> Result<Endpoint, Self::Error>;
|
||||
}
|
||||
|
|
|
@ -74,10 +74,10 @@ pub use private::DefaultListener;
|
|||
type Connection = crate::listener::tcp::TcpStream;
|
||||
|
||||
#[cfg(doc)]
|
||||
impl<'r> Bind<&'r Rocket<Ignite>> for DefaultListener {
|
||||
impl Bind for DefaultListener {
|
||||
type Error = Error;
|
||||
async fn bind(_: &'r Rocket<Ignite>) -> Result<Self, Error> { unreachable!() }
|
||||
fn bind_endpoint(_: &&'r Rocket<Ignite>) -> Result<Endpoint, Error> { unreachable!() }
|
||||
async fn bind(_: &Rocket<Ignite>) -> Result<Self, Error> { unreachable!() }
|
||||
fn bind_endpoint(_: &Rocket<Ignite>) -> Result<Endpoint, Error> { unreachable!() }
|
||||
}
|
||||
|
||||
#[cfg(doc)]
|
||||
|
@ -96,36 +96,36 @@ impl super::Listener for DefaultListener {
|
|||
pub type DefaultListener = private::Listener;
|
||||
|
||||
#[cfg(not(doc))]
|
||||
impl<'r> Bind<&'r Rocket<Ignite>> for DefaultListener {
|
||||
impl Bind for DefaultListener {
|
||||
type Error = Error;
|
||||
|
||||
async fn bind(rocket: &'r Rocket<Ignite>) -> Result<Self, Self::Error> {
|
||||
async fn bind(rocket: &Rocket<Ignite>) -> Result<Self, Self::Error> {
|
||||
let config: Config = rocket.figment().extract()?;
|
||||
match config.address {
|
||||
#[cfg(feature = "tls")]
|
||||
Endpoint::Tcp(_) if config.tls.is_some() => {
|
||||
let listener = <TlsListener<TcpListener> as Bind<_>>::bind(rocket).await?;
|
||||
let listener = <TlsListener<TcpListener> as Bind>::bind(rocket).await?;
|
||||
Ok(Left(Left(listener)))
|
||||
}
|
||||
Endpoint::Tcp(_) => {
|
||||
let listener = <TcpListener as Bind<_>>::bind(rocket).await?;
|
||||
let listener = <TcpListener as Bind>::bind(rocket).await?;
|
||||
Ok(Right(Left(listener)))
|
||||
}
|
||||
#[cfg(all(unix, feature = "tls"))]
|
||||
Endpoint::Unix(_) if config.tls.is_some() => {
|
||||
let listener = <TlsListener<UnixListener> as Bind<_>>::bind(rocket).await?;
|
||||
let listener = <TlsListener<UnixListener> as Bind>::bind(rocket).await?;
|
||||
Ok(Left(Right(listener)))
|
||||
}
|
||||
#[cfg(unix)]
|
||||
Endpoint::Unix(_) => {
|
||||
let listener = <UnixListener as Bind<_>>::bind(rocket).await?;
|
||||
let listener = <UnixListener as Bind>::bind(rocket).await?;
|
||||
Ok(Right(Right(listener)))
|
||||
}
|
||||
endpoint => Err(Error::Unsupported(endpoint)),
|
||||
}
|
||||
}
|
||||
|
||||
fn bind_endpoint(rocket: &&'r Rocket<Ignite>) -> Result<Endpoint, Self::Error> {
|
||||
fn bind_endpoint(rocket: &Rocket<Ignite>) -> Result<Endpoint, Self::Error> {
|
||||
let config: Config = rocket.figment().extract()?;
|
||||
Ok(config.address)
|
||||
}
|
||||
|
|
|
@ -20,22 +20,10 @@ pub use tokio::net::{TcpListener, TcpStream};
|
|||
use crate::{Ignite, Rocket};
|
||||
use crate::listener::{Bind, Connection, Endpoint, Listener};
|
||||
|
||||
impl Bind<SocketAddr> for TcpListener {
|
||||
type Error = std::io::Error;
|
||||
|
||||
async fn bind(addr: SocketAddr) -> Result<Self, Self::Error> {
|
||||
Self::bind(addr).await
|
||||
}
|
||||
|
||||
fn bind_endpoint(addr: &SocketAddr) -> Result<Endpoint, Self::Error> {
|
||||
Ok(Endpoint::Tcp(*addr))
|
||||
}
|
||||
}
|
||||
|
||||
impl<'r> Bind<&'r Rocket<Ignite>> for TcpListener {
|
||||
impl Bind for TcpListener {
|
||||
type Error = Either<figment::Error, io::Error>;
|
||||
|
||||
async fn bind(rocket: &'r Rocket<Ignite>) -> Result<Self, Self::Error> {
|
||||
async fn bind(rocket: &Rocket<Ignite>) -> Result<Self, Self::Error> {
|
||||
let endpoint = Self::bind_endpoint(&rocket)?;
|
||||
let addr = endpoint.tcp()
|
||||
.ok_or_else(|| io::Error::other("internal error: invalid endpoint"))
|
||||
|
@ -44,7 +32,7 @@ impl<'r> Bind<&'r Rocket<Ignite>> for TcpListener {
|
|||
Self::bind(addr).await.map_err(Right)
|
||||
}
|
||||
|
||||
fn bind_endpoint(rocket: &&'r Rocket<Ignite>) -> Result<Endpoint, Self::Error> {
|
||||
fn bind_endpoint(rocket: &Rocket<Ignite>) -> Result<Endpoint, Self::Error> {
|
||||
let figment = rocket.figment();
|
||||
let mut address = Endpoint::fetch(figment, "tcp", "address", |e| {
|
||||
let default = SocketAddr::new(Ipv4Addr::LOCALHOST.into(), 8000);
|
||||
|
|
|
@ -71,10 +71,10 @@ impl UnixListener {
|
|||
}
|
||||
}
|
||||
|
||||
impl<'r> Bind<&'r Rocket<Ignite>> for UnixListener {
|
||||
impl Bind for UnixListener {
|
||||
type Error = Either<figment::Error, io::Error>;
|
||||
|
||||
async fn bind(rocket: &'r Rocket<Ignite>) -> Result<Self, Self::Error> {
|
||||
async fn bind(rocket: &Rocket<Ignite>) -> Result<Self, Self::Error> {
|
||||
let endpoint = Self::bind_endpoint(&rocket)?;
|
||||
let path = endpoint.unix()
|
||||
.ok_or_else(|| Right(io::Error::other("internal error: invalid endpoint")))?;
|
||||
|
@ -83,7 +83,7 @@ impl<'r> Bind<&'r Rocket<Ignite>> for UnixListener {
|
|||
Ok(Self::bind(path, reuse.unwrap_or(true)).await.map_err(Right)?)
|
||||
}
|
||||
|
||||
fn bind_endpoint(rocket: &&'r Rocket<Ignite>) -> Result<Endpoint, Self::Error> {
|
||||
fn bind_endpoint(rocket: &Rocket<Ignite>) -> Result<Endpoint, Self::Error> {
|
||||
let as_pathbuf = |e: Option<&Endpoint>| e.and_then(|e| e.unix().map(|p| p.to_path_buf()));
|
||||
Endpoint::fetch(rocket.figment(), "unix", "address", as_pathbuf)
|
||||
.map(Endpoint::Unix)
|
||||
|
|
|
@ -4,7 +4,7 @@ use std::sync::Arc;
|
|||
use std::time::Duration;
|
||||
use std::any::Any;
|
||||
|
||||
use futures::TryFutureExt;
|
||||
use futures::{Future, TryFutureExt};
|
||||
use yansi::Paint;
|
||||
use either::Either;
|
||||
use figment::{Figment, Provider};
|
||||
|
@ -682,9 +682,7 @@ impl Rocket<Ignite> {
|
|||
rocket
|
||||
}
|
||||
|
||||
async fn _launch_with<B>(self) -> Result<Rocket<Ignite>, Error>
|
||||
where B: for<'r> Bind<&'r Rocket<Ignite>>
|
||||
{
|
||||
async fn _launch_with<B: Bind>(self) -> Result<Rocket<Ignite>, Error> {
|
||||
let bind_endpoint = B::bind_endpoint(&&self).ok();
|
||||
let listener: B = B::bind(&self).await
|
||||
.map_err(|e| ErrorKind::Bind(bind_endpoint, Box::new(e)))?;
|
||||
|
@ -1015,15 +1013,7 @@ impl<P: Phase> Rocket<P> {
|
|||
self.launch_with::<DefaultListener>().await
|
||||
}
|
||||
|
||||
pub async fn bind_launch<T, B: Bind<T>>(self, value: T) -> Result<Rocket<Ignite>, Error> {
|
||||
let endpoint = B::bind_endpoint(&value).ok();
|
||||
let listener = B::bind(value).map_err(|e| ErrorKind::Bind(endpoint, Box::new(e)));
|
||||
self.launch_on(listener.await?).await
|
||||
}
|
||||
|
||||
pub async fn launch_with<B>(self) -> Result<Rocket<Ignite>, Error>
|
||||
where B: for<'r> Bind<&'r Rocket<Ignite>>
|
||||
{
|
||||
pub async fn launch_with<B: Bind>(self) -> Result<Rocket<Ignite>, Error> {
|
||||
match self.0.into_state() {
|
||||
State::Build(s) => Rocket::from(s).ignite().await?._launch_with::<B>().await,
|
||||
State::Ignite(s) => Rocket::from(s)._launch_with::<B>().await,
|
||||
|
@ -1031,6 +1021,15 @@ impl<P: Phase> Rocket<P> {
|
|||
}
|
||||
}
|
||||
|
||||
pub async fn try_launch_on<L, F, E>(self, listener: F) -> Result<Rocket<Ignite>, Error>
|
||||
where L: Listener + 'static,
|
||||
F: Future<Output = Result<L, E>>,
|
||||
E: std::error::Error + Send + 'static
|
||||
{
|
||||
let listener = listener.map_err(|e| ErrorKind::Bind(None, Box::new(e))).await?;
|
||||
self.launch_on(listener).await
|
||||
}
|
||||
|
||||
pub async fn launch_on<L>(self, listener: L) -> Result<Rocket<Ignite>, Error>
|
||||
where L: Listener + 'static,
|
||||
{
|
||||
|
|
|
@ -8,7 +8,7 @@ use rustls::server::{Acceptor, ServerConfig};
|
|||
|
||||
use crate::{Ignite, Rocket};
|
||||
use crate::listener::{Bind, Certificates, Connection, Endpoint, Listener};
|
||||
use crate::tls::{Error, TlsConfig};
|
||||
use crate::tls::{TlsConfig, Result, Error};
|
||||
use super::resolver::DynResolver;
|
||||
|
||||
#[doc(inline)]
|
||||
|
@ -21,40 +21,35 @@ pub struct TlsListener<L> {
|
|||
default: Arc<ServerConfig>,
|
||||
}
|
||||
|
||||
impl<T: Send, L: Bind<T>> Bind<(T, TlsConfig)> for TlsListener<L>
|
||||
impl<L> TlsListener<L>
|
||||
where L: Listener<Accept = <L as Listener>::Connection>,
|
||||
{
|
||||
type Error = Error;
|
||||
|
||||
async fn bind((inner, config): (T, TlsConfig)) -> Result<Self, Self::Error> {
|
||||
pub async fn from(listener: L, config: TlsConfig) -> Result<TlsListener<L>> {
|
||||
Ok(TlsListener {
|
||||
default: Arc::new(config.server_config().await?),
|
||||
listener: L::bind(inner).map_err(|e| Error::Bind(Box::new(e))).await?,
|
||||
listener,
|
||||
config,
|
||||
})
|
||||
}
|
||||
|
||||
fn bind_endpoint((inner, config): &(T, TlsConfig)) -> Result<Endpoint, Self::Error> {
|
||||
L::bind_endpoint(inner)
|
||||
.map(|e| e.with_tls(config))
|
||||
.map_err(|e| Error::Bind(Box::new(e)))
|
||||
}
|
||||
}
|
||||
|
||||
impl<'r, L> Bind<&'r Rocket<Ignite>> for TlsListener<L>
|
||||
where L: Bind<&'r Rocket<Ignite>> + Listener<Accept = <L as Listener>::Connection>
|
||||
impl<L: Bind> Bind for TlsListener<L>
|
||||
where L: Listener<Accept = <L as Listener>::Connection>
|
||||
{
|
||||
type Error = Error;
|
||||
|
||||
async fn bind(rocket: &'r Rocket<Ignite>) -> Result<Self, Self::Error> {
|
||||
async fn bind(rocket: &Rocket<Ignite>) -> Result<Self, Self::Error> {
|
||||
let listener = L::bind(rocket).map_err(|e| Error::Bind(Box::new(e))).await?;
|
||||
let mut config: TlsConfig = rocket.figment().extract_inner("tls")?;
|
||||
config.resolver = DynResolver::extract(rocket);
|
||||
<Self as Bind<_>>::bind((rocket, config)).await
|
||||
Self::from(listener, config).await
|
||||
}
|
||||
|
||||
fn bind_endpoint(rocket: &&'r Rocket<Ignite>) -> Result<Endpoint, Self::Error> {
|
||||
fn bind_endpoint(rocket: &Rocket<Ignite>) -> Result<Endpoint, Self::Error> {
|
||||
let config: TlsConfig = rocket.figment().extract_inner("tls")?;
|
||||
<Self as Bind<_>>::bind_endpoint(&(*rocket, config))
|
||||
L::bind_endpoint(rocket)
|
||||
.map(|e| e.with_tls(&config))
|
||||
.map_err(|e| Error::Bind(Box::new(e)))
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -17,6 +17,6 @@ async fn on_ignite_fairing_can_inspect_port() {
|
|||
}));
|
||||
|
||||
let addr = SocketAddr::from((Ipv4Addr::LOCALHOST, 0));
|
||||
rocket::tokio::spawn(rocket.bind_launch::<_, TcpListener>(addr));
|
||||
rocket::tokio::spawn(rocket.try_launch_on(TcpListener::bind(addr)));
|
||||
assert_ne!(rx.await.unwrap(), 0);
|
||||
}
|
||||
|
|
|
@ -1,11 +1,15 @@
|
|||
use std::net::{Ipv4Addr, SocketAddr};
|
||||
use std::process::ExitCode;
|
||||
use std::time::Duration;
|
||||
|
||||
use rocket::listener::unix::UnixListener;
|
||||
use rocket::tokio::net::TcpListener;
|
||||
use rocket::yansi::Paint;
|
||||
use rocket::{get, routes, Build, Rocket, State};
|
||||
use rocket::listener::{unix::UnixListener, Endpoint};
|
||||
use rocket::tls::TlsListener;
|
||||
|
||||
use reqwest::{tls::TlsInfo, Identity};
|
||||
|
||||
use testbench::*;
|
||||
|
||||
static DEFAULT_CONFIG: &str = r#"
|
||||
|
@ -112,12 +116,12 @@ fn infinite() -> Result<()> {
|
|||
}
|
||||
|
||||
fn tls_info() -> Result<()> {
|
||||
let mut server = spawn! {
|
||||
#[get("/")]
|
||||
fn hello_world() -> &'static str {
|
||||
"Hello, world!"
|
||||
fn hello_world(endpoint: &Endpoint) -> String {
|
||||
format!("Hello, {endpoint}!")
|
||||
}
|
||||
|
||||
let mut server = spawn! {
|
||||
Rocket::tls_default().mount("/", routes![hello_world])
|
||||
}?;
|
||||
|
||||
|
@ -125,13 +129,35 @@ fn tls_info() -> Result<()> {
|
|||
let response = client.get(&server, "/")?.send()?;
|
||||
let tls = response.extensions().get::<TlsInfo>().unwrap();
|
||||
assert!(!tls.peer_certificate().unwrap().is_empty());
|
||||
assert_eq!(response.text()?, "Hello, world!");
|
||||
assert!(response.text()?.starts_with("Hello, https://127.0.0.1"));
|
||||
|
||||
server.terminate()?;
|
||||
let stdout = server.read_stdout()?;
|
||||
assert!(stdout.contains("Rocket has launched on https"));
|
||||
assert!(stdout.contains("Graceful shutdown completed"));
|
||||
assert!(stdout.contains("GET /"));
|
||||
|
||||
let server = Server::spawn((), |(token, _)| {
|
||||
let rocket = rocket::build()
|
||||
.configure_with_toml(TLS_CONFIG)
|
||||
.mount("/", routes![hello_world]);
|
||||
|
||||
token.with_launch(rocket, |rocket| {
|
||||
let config = rocket.figment().extract_inner("tls");
|
||||
rocket.try_launch_on(async move {
|
||||
let addr = SocketAddr::new(Ipv4Addr::LOCALHOST.into(), 0);
|
||||
let listener = TcpListener::bind(addr).await?;
|
||||
TlsListener::from(listener, config?).await
|
||||
})
|
||||
})
|
||||
}).unwrap();
|
||||
|
||||
let client = Client::default();
|
||||
let response = client.get(&server, "/")?.send()?;
|
||||
let tls = response.extensions().get::<TlsInfo>().unwrap();
|
||||
assert!(!tls.peer_certificate().unwrap().is_empty());
|
||||
assert!(response.text()?.starts_with("Hello, https://127.0.0.1"));
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
use std::future::Future;
|
||||
use std::net::{Ipv4Addr, SocketAddr};
|
||||
use std::time::Duration;
|
||||
use std::sync::Once;
|
||||
|
@ -111,8 +112,9 @@ impl Server {
|
|||
}
|
||||
|
||||
impl Token {
|
||||
pub fn launch_with<B>(self, rocket: Rocket<Build>) -> Launched
|
||||
where B: for<'r> Bind<&'r Rocket<Ignite>> + Sync + Send + 'static
|
||||
pub fn with_launch<F, Fut>(self, rocket: Rocket<Build>, launch: F) -> Launched
|
||||
where F: FnOnce(Rocket<Ignite>) -> Fut + Send + Sync + 'static,
|
||||
Fut: Future<Output = Result<Rocket<Ignite>, rocket::Error>> + Send
|
||||
{
|
||||
let server = self.0.clone();
|
||||
let rocket = rocket.attach(AdHoc::on_liftoff("Liftoff", move |rocket| Box::pin(async move {
|
||||
|
@ -124,7 +126,12 @@ impl Token {
|
|||
})));
|
||||
|
||||
let server = self.0.clone();
|
||||
if let Err(e) = rocket::execute(rocket.launch_with::<B>()) {
|
||||
let launch = async move {
|
||||
let rocket = rocket.ignite().await?;
|
||||
launch(rocket).await
|
||||
};
|
||||
|
||||
if let Err(e) = rocket::execute(launch) {
|
||||
let sender = IpcSender::<Message>::connect(server).unwrap();
|
||||
let _ = sender.send(Message::Failure);
|
||||
let _ = sender.send(Message::Failure);
|
||||
|
@ -135,6 +142,12 @@ impl Token {
|
|||
Launched(())
|
||||
}
|
||||
|
||||
pub fn launch_with<B: Bind>(self, rocket: Rocket<Build>) -> Launched
|
||||
where B: Send + Sync + 'static
|
||||
{
|
||||
self.with_launch(rocket, |rocket| rocket.launch_with::<B>())
|
||||
}
|
||||
|
||||
pub fn launch(self, rocket: Rocket<Build>) -> Launched {
|
||||
self.launch_with::<DefaultListener>(rocket)
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue