From eab64aa740fe12b14a543ec9872684e3264a8188 Mon Sep 17 00:00:00 2001 From: Dirkjan Ochtman Date: Mon, 12 Dec 2022 13:56:52 +0100 Subject: [PATCH] Look up server IP address on every connect --- src/client.rs | 47 +++++++++++++++++++++++++++-------------------- src/lib.rs | 4 +--- 2 files changed, 28 insertions(+), 23 deletions(-) diff --git a/src/client.rs b/src/client.rs index e9fa9d3..a63ce6a 100644 --- a/src/client.rs +++ b/src/client.rs @@ -1,10 +1,11 @@ use std::convert::TryInto; use std::io; -use std::net::SocketAddr; use std::sync::Arc; use std::time::Duration; use async_trait::async_trait; +#[cfg(feature = "tokio-rustls")] +use tokio::net::lookup_host; use tokio::net::TcpStream; #[cfg(feature = "tokio-rustls")] use tokio_rustls::client::TlsStream; @@ -42,10 +43,8 @@ use crate::xml; /// # #[tokio::main] /// # async fn main() { /// // Create an instance of EppClient -/// let host = "example.com"; -/// let addr = (host, 7000).to_socket_addrs().unwrap().next().unwrap(); /// let timeout = Duration::from_secs(5); -/// let mut client = match EppClient::connect("registry_name".to_string(), addr, host, None, timeout).await { +/// let mut client = match EppClient::connect("registry_name".to_string(), ("example.com".to_owned(), 7000), None, timeout).await { /// Ok(client) => client, /// Err(e) => panic!("Failed to create EppClient: {}", e) /// }; @@ -77,22 +76,20 @@ pub struct EppClient { impl EppClient { /// Connect to the specified `addr` and `hostname` over TLS /// - /// The `registry` is used as a name in internal logging; `addr` provides the address to - /// connect to, `hostname` is sent as the TLS server name indication and `identity` provides - /// optional TLS client authentication (using) rustls as the TLS implementation. - /// The `timeout` limits the time spent on any underlying network operations. + /// The `registry` is used as a name in internal logging; `host` provides the host name + /// and port to connect to), `hostname` is sent as the TLS server name indication and + /// `identity` provides optional TLS client authentication (using) rustls as the TLS + /// implementation. The `timeout` limits the time spent on any underlying network operations. /// /// Alternatively, use `EppClient::new()` with any established `AsyncRead + AsyncWrite + Unpin` /// implementation. pub async fn connect( registry: String, - addr: SocketAddr, - hostname: &str, + server: (String, u16), identity: Option<(Vec, PrivateKey)>, timeout: Duration, ) -> Result { - info!("Connecting to server: {:?}", addr); - let connector = RustlsConnector::new(addr, hostname, identity)?; + let connector = RustlsConnector::new(server, identity).await?; Self::new(connector, registry, timeout).await } } @@ -213,13 +210,12 @@ impl<'c, 'e, C, E> Copy for RequestData<'c, 'e, C, E> {} pub struct RustlsConnector { inner: TlsConnector, domain: ServerName, - addr: SocketAddr, + server: (String, u16), } impl RustlsConnector { - pub fn new( - addr: SocketAddr, - hostname: &str, + pub async fn new( + server: (String, u16), identity: Option<(Vec, PrivateKey)>, ) -> Result { let mut roots = RootCertStore::empty(); @@ -248,17 +244,17 @@ impl RustlsConnector { None => builder.with_no_client_auth(), }; - let domain = hostname.try_into().map_err(|_| { + let domain = server.0.as_str().try_into().map_err(|_| { io::Error::new( io::ErrorKind::InvalidInput, - format!("Invalid domain: {}", hostname), + format!("Invalid domain: {}", server.0), ) })?; Ok(Self { inner: TlsConnector::from(Arc::new(config)), domain, - addr, + server, }) } } @@ -269,7 +265,18 @@ impl Connector for RustlsConnector { type Connection = TlsStream; async fn connect(&self, timeout: Duration) -> Result { - let stream = TcpStream::connect(&self.addr).await?; + info!("Connecting to server: {}:{}", self.server.0, self.server.1); + let addr = match lookup_host(&self.server).await?.next() { + Some(addr) => addr, + None => { + return Err(Error::Io(io::Error::new( + io::ErrorKind::InvalidInput, + format!("Invalid host: {}", &self.server.0), + ))) + } + }; + + let stream = TcpStream::connect(addr).await?; let future = self.inner.connect(self.domain.clone(), stream); connection::timeout(timeout, future).await } diff --git a/src/lib.rs b/src/lib.rs index f49d3a9..88304a3 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -66,10 +66,8 @@ //! #[tokio::main] //! async fn main() { //! // Create an instance of EppClient -//! let host = "example.com"; -//! let addr = (host, 700).to_socket_addrs().unwrap().next().unwrap(); //! let timeout = Duration::from_secs(5); -//! let mut client = match EppClient::connect("registry_name".to_string(), addr, host, None, timeout).await { +//! let mut client = match EppClient::connect("registry_name".to_string(), ("example.com".to_owned(), 7000), None, timeout).await { //! Ok(client) => client, //! Err(e) => panic!("Failed to create EppClient: {}", e) //! };