Wrap timeouts around network operations

This commit is contained in:
Dirkjan Ochtman 2022-02-01 11:31:08 +01:00 committed by masalachai
parent f3aab578e7
commit dc3f10cae3
5 changed files with 49 additions and 17 deletions

View File

@ -5,6 +5,7 @@
//! ```no_run //! ```no_run
//! use std::collections::HashMap; //! use std::collections::HashMap;
//! use std::net::ToSocketAddrs; //! use std::net::ToSocketAddrs;
//! use std::time::Duration;
//! //!
//! use epp_client::EppClient; //! use epp_client::EppClient;
//! use epp_client::domain::DomainCheck; //! use epp_client::domain::DomainCheck;
@ -16,7 +17,8 @@
//! // Create an instance of EppClient //! // Create an instance of EppClient
//! let host = "example.com"; //! let host = "example.com";
//! let addr = (host, 7000).to_socket_addrs().unwrap().next().unwrap(); //! let addr = (host, 7000).to_socket_addrs().unwrap().next().unwrap();
//! let mut client = match EppClient::connect("registry_name".to_string(), addr, host, None).await { //! let timeout = Duration::from_secs(5);
//! let mut client = match EppClient::connect("registry_name".to_string(), addr, host, None, timeout).await {
//! Ok(client) => client, //! Ok(client) => client,
//! Err(e) => panic!("Failed to create EppClient: {}", e) //! Err(e) => panic!("Failed to create EppClient: {}", e)
//! }; //! };
@ -37,6 +39,7 @@ use std::convert::TryInto;
use std::io; use std::io;
use std::net::SocketAddr; use std::net::SocketAddr;
use std::sync::Arc; use std::sync::Arc;
use std::time::Duration;
use tokio::io::{AsyncRead, AsyncWrite}; use tokio::io::{AsyncRead, AsyncWrite};
use tokio::net::TcpStream; use tokio::net::TcpStream;
@ -49,7 +52,7 @@ use tokio_rustls::TlsConnector;
use tracing::info; use tracing::info;
use crate::common::{Certificate, NoExtension, PrivateKey}; use crate::common::{Certificate, NoExtension, PrivateKey};
use crate::connection::EppConnection; use crate::connection::{self, EppConnection};
use crate::error::Error; use crate::error::Error;
use crate::hello::{Greeting, GreetingDocument, HelloDocument}; use crate::hello::{Greeting, GreetingDocument, HelloDocument};
use crate::request::{Command, Extension, Transaction}; use crate::request::{Command, Extension, Transaction};
@ -69,7 +72,8 @@ impl EppClient<TlsStream<TcpStream>> {
/// ///
/// The `registry` is used as a name in internal logging; `addr` provides the address to /// The `registry` is used as a name in internal logging; `addr` provides the address to
/// connect to, `hostname` is sent as the TLS server name indication and `identity` provides /// connect to, `hostname` is sent as the TLS server name indication and `identity` provides
/// optional TLS client authentication. Uses rustls as the TLS implementation. /// optional TLS client authentication (using) rustls as the TLS implementation.
/// The `timeout` limits the time spent on any underlying network operations.
/// ///
/// Alternatively, use `EppClient::new()` with any established `AsyncRead + AsyncWrite + Unpin` /// Alternatively, use `EppClient::new()` with any established `AsyncRead + AsyncWrite + Unpin`
/// implementation. /// implementation.
@ -78,6 +82,7 @@ impl EppClient<TlsStream<TcpStream>> {
addr: SocketAddr, addr: SocketAddr,
hostname: &str, hostname: &str,
identity: Option<(Vec<Certificate>, PrivateKey)>, identity: Option<(Vec<Certificate>, PrivateKey)>,
timeout: Duration,
) -> Result<Self, Error> { ) -> Result<Self, Error> {
info!("Connecting to server: {:?}", addr); info!("Connecting to server: {:?}", addr);
@ -115,17 +120,17 @@ impl EppClient<TlsStream<TcpStream>> {
})?; })?;
let connector = TlsConnector::from(Arc::new(config)); let connector = TlsConnector::from(Arc::new(config));
let tcp = TcpStream::connect(&addr).await?; let future = connector.connect(domain, TcpStream::connect(&addr).await?);
let stream = connector.connect(domain, tcp).await?; let stream = connection::timeout(timeout, future).await?;
Self::new(registry, stream).await Self::new(registry, stream, timeout).await
} }
} }
impl<IO: AsyncRead + AsyncWrite + Unpin> EppClient<IO> { impl<IO: AsyncRead + AsyncWrite + Unpin> EppClient<IO> {
/// Create an `EppClient` from an already established connection /// Create an `EppClient` from an already established connection
pub async fn new(registry: String, stream: IO) -> Result<Self, Error> { pub async fn new(registry: String, stream: IO, timeout: Duration) -> Result<Self, Error> {
Ok(Self { Ok(Self {
connection: EppConnection::new(registry, stream).await?, connection: EppConnection::new(registry, stream, timeout).await?,
}) })
} }

View File

@ -1,6 +1,8 @@
//! Manages registry connections and reading/writing to them //! Manages registry connections and reading/writing to them
use std::convert::TryInto; use std::convert::TryInto;
use std::future::Future;
use std::time::Duration;
use std::{io, str, u32}; use std::{io, str, u32};
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
@ -13,14 +15,20 @@ pub(crate) struct EppConnection<IO> {
registry: String, registry: String,
stream: IO, stream: IO,
pub greeting: String, pub greeting: String,
timeout: Duration,
} }
impl<IO: AsyncRead + AsyncWrite + Unpin> EppConnection<IO> { impl<IO: AsyncRead + AsyncWrite + Unpin> EppConnection<IO> {
pub(crate) async fn new(registry: String, stream: IO) -> Result<Self, Error> { pub(crate) async fn new(
registry: String,
stream: IO,
timeout: Duration,
) -> Result<Self, Error> {
let mut this = Self { let mut this = Self {
registry, registry,
stream, stream,
greeting: String::new(), greeting: String::new(),
timeout,
}; };
this.greeting = this.get_epp_response().await?; this.greeting = this.get_epp_response().await?;
@ -40,7 +48,7 @@ impl<IO: AsyncRead + AsyncWrite + Unpin> EppConnection<IO> {
buf[..4].clone_from_slice(&len_u32); buf[..4].clone_from_slice(&len_u32);
buf[4..].clone_from_slice(content.as_bytes()); buf[4..].clone_from_slice(content.as_bytes());
let wrote = self.stream.write(&buf).await?; let wrote = timeout(self.timeout, self.stream.write(&buf)).await?;
debug!("{}: Wrote {} bytes", self.registry, wrote); debug!("{}: Wrote {} bytes", self.registry, wrote);
Ok(()) Ok(())
} }
@ -48,7 +56,7 @@ impl<IO: AsyncRead + AsyncWrite + Unpin> EppConnection<IO> {
/// Receives response from the socket and converts it into an EPP XML string /// Receives response from the socket and converts it into an EPP XML string
async fn get_epp_response(&mut self) -> Result<String, Error> { async fn get_epp_response(&mut self) -> Result<String, Error> {
let mut buf = [0u8; 4]; let mut buf = [0u8; 4];
self.stream.read_exact(&mut buf).await?; timeout(self.timeout, self.stream.read_exact(&mut buf)).await?;
let buf_size: usize = u32::from_be_bytes(buf).try_into()?; let buf_size: usize = u32::from_be_bytes(buf).try_into()?;
@ -59,7 +67,7 @@ impl<IO: AsyncRead + AsyncWrite + Unpin> EppConnection<IO> {
let mut read_size: usize = 0; let mut read_size: usize = 0;
loop { loop {
let read = self.stream.read(&mut buf[read_size..]).await?; let read = timeout(self.timeout, self.stream.read(&mut buf[read_size..])).await?;
debug!("{}: Read: {} bytes", self.registry, read); debug!("{}: Read: {} bytes", self.registry, read);
read_size += read; read_size += read;
@ -95,7 +103,18 @@ impl<IO: AsyncRead + AsyncWrite + Unpin> EppConnection<IO> {
pub(crate) async fn shutdown(&mut self) -> Result<(), Error> { pub(crate) async fn shutdown(&mut self) -> Result<(), Error> {
info!("{}: Closing connection", self.registry); info!("{}: Closing connection", self.registry);
self.stream.shutdown().await?; timeout(self.timeout, self.stream.shutdown()).await?;
Ok(()) Ok(())
} }
} }
pub(crate) async fn timeout<T, E: Into<Error>>(
timeout: Duration,
fut: impl Future<Output = Result<T, E>>,
) -> Result<T, Error> {
match tokio::time::timeout(timeout, fut).await {
Ok(Ok(t)) => Ok(t),
Ok(Err(e)) => Err(e.into()),
Err(_) => Err(Error::Timeout),
}
}

View File

@ -12,8 +12,9 @@ use crate::response::ResponseStatus;
/// Error enum holding the possible error types /// Error enum holding the possible error types
#[derive(Debug)] #[derive(Debug)]
pub enum Error { pub enum Error {
Io(std::io::Error),
Command(ResponseStatus), Command(ResponseStatus),
Io(std::io::Error),
Timeout,
Xml(Box<dyn StdError + Send + Sync>), Xml(Box<dyn StdError + Send + Sync>),
Other(Box<dyn StdError + Send + Sync>), Other(Box<dyn StdError + Send + Sync>),
} }
@ -23,10 +24,11 @@ impl StdError for Error {}
impl Display for Error { impl Display for Error {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self { match self {
Error::Io(e) => write!(f, "I/O error: {}", e),
Error::Command(e) => { Error::Command(e) => {
write!(f, "command error: {}", e.result.message) write!(f, "command error: {}", e.result.message)
} }
Error::Io(e) => write!(f, "I/O error: {}", e),
Error::Timeout => write!(f, "timeout"),
Error::Xml(e) => write!(f, "(de)serialization error: {}", e), Error::Xml(e) => write!(f, "(de)serialization error: {}", e),
Error::Other(e) => write!(f, "error: {}", e), Error::Other(e) => write!(f, "error: {}", e),
} }

View File

@ -41,6 +41,7 @@
//! ```no_run //! ```no_run
//! use std::collections::HashMap; //! use std::collections::HashMap;
//! use std::net::ToSocketAddrs; //! use std::net::ToSocketAddrs;
//! use std::time::Duration;
//! //!
//! use epp_client::EppClient; //! use epp_client::EppClient;
//! use epp_client::domain::DomainCheck; //! use epp_client::domain::DomainCheck;
@ -54,7 +55,8 @@
//! // Create an instance of EppClient //! // Create an instance of EppClient
//! let host = "example.com"; //! let host = "example.com";
//! let addr = (host, 7000).to_socket_addrs().unwrap().next().unwrap(); //! let addr = (host, 7000).to_socket_addrs().unwrap().next().unwrap();
//! let mut client = match EppClient::connect("registry_name".to_string(), addr, host, None).await { //! let timeout = Duration::from_secs(5);
//! let mut client = match EppClient::connect("registry_name".to_string(), addr, host, None, timeout).await {
//! Ok(client) => client, //! Ok(client) => client,
//! Err(e) => panic!("Failed to create EppClient: {}", e) //! Err(e) => panic!("Failed to create EppClient: {}", e)
//! }; //! };

View File

@ -1,6 +1,7 @@
use std::fs::File; use std::fs::File;
use std::io::{self, Read, Write}; use std::io::{self, Read, Write};
use std::str; use std::str;
use std::time::Duration;
use regex::Regex; use regex::Regex;
use tokio_test::io::Builder; use tokio_test::io::Builder;
@ -86,7 +87,10 @@ async fn client() {
]) ])
.build(); .build();
let mut client = EppClient::new("test".into(), stream).await.unwrap(); let mut client = EppClient::new("test".into(), stream, Duration::from_secs(5))
.await
.unwrap();
assert_eq!(client.xml_greeting(), xml("response/greeting.xml")); assert_eq!(client.xml_greeting(), xml("response/greeting.xml"));
client client
.transact( .transact(