Look up server IP address on every connect

This commit is contained in:
Dirkjan Ochtman 2022-12-12 13:56:52 +01:00
parent ed3bfdbcfa
commit eab64aa740
2 changed files with 28 additions and 23 deletions

View File

@ -1,10 +1,11 @@
use std::convert::TryInto; use std::convert::TryInto;
use std::io; use std::io;
use std::net::SocketAddr;
use std::sync::Arc; use std::sync::Arc;
use std::time::Duration; use std::time::Duration;
use async_trait::async_trait; use async_trait::async_trait;
#[cfg(feature = "tokio-rustls")]
use tokio::net::lookup_host;
use tokio::net::TcpStream; use tokio::net::TcpStream;
#[cfg(feature = "tokio-rustls")] #[cfg(feature = "tokio-rustls")]
use tokio_rustls::client::TlsStream; use tokio_rustls::client::TlsStream;
@ -42,10 +43,8 @@ use crate::xml;
/// # #[tokio::main] /// # #[tokio::main]
/// # async fn main() { /// # async fn main() {
/// // Create an instance of EppClient /// // Create an instance of EppClient
/// let host = "example.com";
/// let addr = (host, 7000).to_socket_addrs().unwrap().next().unwrap();
/// let timeout = Duration::from_secs(5); /// let timeout = Duration::from_secs(5);
/// let mut client = match EppClient::connect("registry_name".to_string(), addr, host, None, timeout).await { /// let mut client = match EppClient::connect("registry_name".to_string(), ("example.com".to_owned(), 7000), None, timeout).await {
/// Ok(client) => client, /// Ok(client) => client,
/// Err(e) => panic!("Failed to create EppClient: {}", e) /// Err(e) => panic!("Failed to create EppClient: {}", e)
/// }; /// };
@ -77,22 +76,20 @@ pub struct EppClient<C: Connector> {
impl EppClient<RustlsConnector> { impl EppClient<RustlsConnector> {
/// Connect to the specified `addr` and `hostname` over TLS /// Connect to the specified `addr` and `hostname` over TLS
/// ///
/// The `registry` is used as a name in internal logging; `addr` provides the address to /// The `registry` is used as a name in internal logging; `host` provides the host name
/// connect to, `hostname` is sent as the TLS server name indication and `identity` provides /// and port to connect to), `hostname` is sent as the TLS server name indication and
/// optional TLS client authentication (using) rustls as the TLS implementation. /// `identity` provides optional TLS client authentication (using) rustls as the TLS
/// The `timeout` limits the time spent on any underlying network operations. /// implementation. The `timeout` limits the time spent on any underlying network operations.
/// ///
/// Alternatively, use `EppClient::new()` with any established `AsyncRead + AsyncWrite + Unpin` /// Alternatively, use `EppClient::new()` with any established `AsyncRead + AsyncWrite + Unpin`
/// implementation. /// implementation.
pub async fn connect( pub async fn connect(
registry: String, registry: String,
addr: SocketAddr, server: (String, u16),
hostname: &str,
identity: Option<(Vec<Certificate>, PrivateKey)>, identity: Option<(Vec<Certificate>, PrivateKey)>,
timeout: Duration, timeout: Duration,
) -> Result<Self, Error> { ) -> Result<Self, Error> {
info!("Connecting to server: {:?}", addr); let connector = RustlsConnector::new(server, identity).await?;
let connector = RustlsConnector::new(addr, hostname, identity)?;
Self::new(connector, registry, timeout).await Self::new(connector, registry, timeout).await
} }
} }
@ -213,13 +210,12 @@ impl<'c, 'e, C, E> Copy for RequestData<'c, 'e, C, E> {}
pub struct RustlsConnector { pub struct RustlsConnector {
inner: TlsConnector, inner: TlsConnector,
domain: ServerName, domain: ServerName,
addr: SocketAddr, server: (String, u16),
} }
impl RustlsConnector { impl RustlsConnector {
pub fn new( pub async fn new(
addr: SocketAddr, server: (String, u16),
hostname: &str,
identity: Option<(Vec<Certificate>, PrivateKey)>, identity: Option<(Vec<Certificate>, PrivateKey)>,
) -> Result<Self, Error> { ) -> Result<Self, Error> {
let mut roots = RootCertStore::empty(); let mut roots = RootCertStore::empty();
@ -248,17 +244,17 @@ impl RustlsConnector {
None => builder.with_no_client_auth(), None => builder.with_no_client_auth(),
}; };
let domain = hostname.try_into().map_err(|_| { let domain = server.0.as_str().try_into().map_err(|_| {
io::Error::new( io::Error::new(
io::ErrorKind::InvalidInput, io::ErrorKind::InvalidInput,
format!("Invalid domain: {}", hostname), format!("Invalid domain: {}", server.0),
) )
})?; })?;
Ok(Self { Ok(Self {
inner: TlsConnector::from(Arc::new(config)), inner: TlsConnector::from(Arc::new(config)),
domain, domain,
addr, server,
}) })
} }
} }
@ -269,7 +265,18 @@ impl Connector for RustlsConnector {
type Connection = TlsStream<TcpStream>; type Connection = TlsStream<TcpStream>;
async fn connect(&self, timeout: Duration) -> Result<Self::Connection, Error> { async fn connect(&self, timeout: Duration) -> Result<Self::Connection, Error> {
let stream = TcpStream::connect(&self.addr).await?; info!("Connecting to server: {}:{}", self.server.0, self.server.1);
let addr = match lookup_host(&self.server).await?.next() {
Some(addr) => addr,
None => {
return Err(Error::Io(io::Error::new(
io::ErrorKind::InvalidInput,
format!("Invalid host: {}", &self.server.0),
)))
}
};
let stream = TcpStream::connect(addr).await?;
let future = self.inner.connect(self.domain.clone(), stream); let future = self.inner.connect(self.domain.clone(), stream);
connection::timeout(timeout, future).await connection::timeout(timeout, future).await
} }

View File

@ -66,10 +66,8 @@
//! #[tokio::main] //! #[tokio::main]
//! async fn main() { //! async fn main() {
//! // Create an instance of EppClient //! // Create an instance of EppClient
//! let host = "example.com";
//! let addr = (host, 700).to_socket_addrs().unwrap().next().unwrap();
//! let timeout = Duration::from_secs(5); //! let timeout = Duration::from_secs(5);
//! let mut client = match EppClient::connect("registry_name".to_string(), addr, host, None, timeout).await { //! let mut client = match EppClient::connect("registry_name".to_string(), ("example.com".to_owned(), 7000), None, timeout).await {
//! Ok(client) => client, //! Ok(client) => client,
//! Err(e) => panic!("Failed to create EppClient: {}", e) //! Err(e) => panic!("Failed to create EppClient: {}", e)
//! }; //! };