Look up server IP address on every connect

This commit is contained in:
Dirkjan Ochtman 2022-12-12 13:56:52 +01:00
parent ed3bfdbcfa
commit eab64aa740
2 changed files with 28 additions and 23 deletions

View File

@ -1,10 +1,11 @@
use std::convert::TryInto;
use std::io;
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::Duration;
use async_trait::async_trait;
#[cfg(feature = "tokio-rustls")]
use tokio::net::lookup_host;
use tokio::net::TcpStream;
#[cfg(feature = "tokio-rustls")]
use tokio_rustls::client::TlsStream;
@ -42,10 +43,8 @@ use crate::xml;
/// # #[tokio::main]
/// # async fn main() {
/// // Create an instance of EppClient
/// let host = "example.com";
/// let addr = (host, 7000).to_socket_addrs().unwrap().next().unwrap();
/// let timeout = Duration::from_secs(5);
/// let mut client = match EppClient::connect("registry_name".to_string(), addr, host, None, timeout).await {
/// let mut client = match EppClient::connect("registry_name".to_string(), ("example.com".to_owned(), 7000), None, timeout).await {
/// Ok(client) => client,
/// Err(e) => panic!("Failed to create EppClient: {}", e)
/// };
@ -77,22 +76,20 @@ pub struct EppClient<C: Connector> {
impl EppClient<RustlsConnector> {
/// Connect to the specified `addr` and `hostname` over TLS
///
/// The `registry` is used as a name in internal logging; `addr` provides the address to
/// connect to, `hostname` is sent as the TLS server name indication and `identity` provides
/// optional TLS client authentication (using) rustls as the TLS implementation.
/// The `timeout` limits the time spent on any underlying network operations.
/// The `registry` is used as a name in internal logging; `host` provides the host name
/// and port to connect to), `hostname` is sent as the TLS server name indication and
/// `identity` provides optional TLS client authentication (using) rustls as the TLS
/// implementation. The `timeout` limits the time spent on any underlying network operations.
///
/// Alternatively, use `EppClient::new()` with any established `AsyncRead + AsyncWrite + Unpin`
/// implementation.
pub async fn connect(
registry: String,
addr: SocketAddr,
hostname: &str,
server: (String, u16),
identity: Option<(Vec<Certificate>, PrivateKey)>,
timeout: Duration,
) -> Result<Self, Error> {
info!("Connecting to server: {:?}", addr);
let connector = RustlsConnector::new(addr, hostname, identity)?;
let connector = RustlsConnector::new(server, identity).await?;
Self::new(connector, registry, timeout).await
}
}
@ -213,13 +210,12 @@ impl<'c, 'e, C, E> Copy for RequestData<'c, 'e, C, E> {}
pub struct RustlsConnector {
inner: TlsConnector,
domain: ServerName,
addr: SocketAddr,
server: (String, u16),
}
impl RustlsConnector {
pub fn new(
addr: SocketAddr,
hostname: &str,
pub async fn new(
server: (String, u16),
identity: Option<(Vec<Certificate>, PrivateKey)>,
) -> Result<Self, Error> {
let mut roots = RootCertStore::empty();
@ -248,17 +244,17 @@ impl RustlsConnector {
None => builder.with_no_client_auth(),
};
let domain = hostname.try_into().map_err(|_| {
let domain = server.0.as_str().try_into().map_err(|_| {
io::Error::new(
io::ErrorKind::InvalidInput,
format!("Invalid domain: {}", hostname),
format!("Invalid domain: {}", server.0),
)
})?;
Ok(Self {
inner: TlsConnector::from(Arc::new(config)),
domain,
addr,
server,
})
}
}
@ -269,7 +265,18 @@ impl Connector for RustlsConnector {
type Connection = TlsStream<TcpStream>;
async fn connect(&self, timeout: Duration) -> Result<Self::Connection, Error> {
let stream = TcpStream::connect(&self.addr).await?;
info!("Connecting to server: {}:{}", self.server.0, self.server.1);
let addr = match lookup_host(&self.server).await?.next() {
Some(addr) => addr,
None => {
return Err(Error::Io(io::Error::new(
io::ErrorKind::InvalidInput,
format!("Invalid host: {}", &self.server.0),
)))
}
};
let stream = TcpStream::connect(addr).await?;
let future = self.inner.connect(self.domain.clone(), stream);
connection::timeout(timeout, future).await
}

View File

@ -66,10 +66,8 @@
//! #[tokio::main]
//! async fn main() {
//! // Create an instance of EppClient
//! let host = "example.com";
//! let addr = (host, 700).to_socket_addrs().unwrap().next().unwrap();
//! let timeout = Duration::from_secs(5);
//! let mut client = match EppClient::connect("registry_name".to_string(), addr, host, None, timeout).await {
//! let mut client = match EppClient::connect("registry_name".to_string(), ("example.com".to_owned(), 7000), None, timeout).await {
//! Ok(client) => client,
//! Err(e) => panic!("Failed to create EppClient: {}", e)
//! };