From dc3f10cae308b9b6aa58cce55f1c08b2040b6308 Mon Sep 17 00:00:00 2001 From: Dirkjan Ochtman Date: Tue, 1 Feb 2022 11:31:08 +0100 Subject: [PATCH] Wrap timeouts around network operations --- src/client.rs | 21 +++++++++++++-------- src/connection.rs | 29 ++++++++++++++++++++++++----- src/error.rs | 6 ++++-- src/lib.rs | 4 +++- tests/basic.rs | 6 +++++- 5 files changed, 49 insertions(+), 17 deletions(-) diff --git a/src/client.rs b/src/client.rs index 51264cf..65757df 100644 --- a/src/client.rs +++ b/src/client.rs @@ -5,6 +5,7 @@ //! ```no_run //! use std::collections::HashMap; //! use std::net::ToSocketAddrs; +//! use std::time::Duration; //! //! use epp_client::EppClient; //! use epp_client::domain::DomainCheck; @@ -16,7 +17,8 @@ //! // Create an instance of EppClient //! let host = "example.com"; //! let addr = (host, 7000).to_socket_addrs().unwrap().next().unwrap(); -//! let mut client = match EppClient::connect("registry_name".to_string(), addr, host, None).await { +//! let timeout = Duration::from_secs(5); +//! let mut client = match EppClient::connect("registry_name".to_string(), addr, host, None, timeout).await { //! Ok(client) => client, //! Err(e) => panic!("Failed to create EppClient: {}", e) //! }; @@ -37,6 +39,7 @@ use std::convert::TryInto; use std::io; use std::net::SocketAddr; use std::sync::Arc; +use std::time::Duration; use tokio::io::{AsyncRead, AsyncWrite}; use tokio::net::TcpStream; @@ -49,7 +52,7 @@ use tokio_rustls::TlsConnector; use tracing::info; use crate::common::{Certificate, NoExtension, PrivateKey}; -use crate::connection::EppConnection; +use crate::connection::{self, EppConnection}; use crate::error::Error; use crate::hello::{Greeting, GreetingDocument, HelloDocument}; use crate::request::{Command, Extension, Transaction}; @@ -69,7 +72,8 @@ impl EppClient> { /// /// The `registry` is used as a name in internal logging; `addr` provides the address to /// connect to, `hostname` is sent as the TLS server name indication and `identity` provides - /// optional TLS client authentication. Uses rustls as the TLS implementation. + /// optional TLS client authentication (using) rustls as the TLS implementation. + /// The `timeout` limits the time spent on any underlying network operations. /// /// Alternatively, use `EppClient::new()` with any established `AsyncRead + AsyncWrite + Unpin` /// implementation. @@ -78,6 +82,7 @@ impl EppClient> { addr: SocketAddr, hostname: &str, identity: Option<(Vec, PrivateKey)>, + timeout: Duration, ) -> Result { info!("Connecting to server: {:?}", addr); @@ -115,17 +120,17 @@ impl EppClient> { })?; let connector = TlsConnector::from(Arc::new(config)); - let tcp = TcpStream::connect(&addr).await?; - let stream = connector.connect(domain, tcp).await?; - Self::new(registry, stream).await + let future = connector.connect(domain, TcpStream::connect(&addr).await?); + let stream = connection::timeout(timeout, future).await?; + Self::new(registry, stream, timeout).await } } impl EppClient { /// Create an `EppClient` from an already established connection - pub async fn new(registry: String, stream: IO) -> Result { + pub async fn new(registry: String, stream: IO, timeout: Duration) -> Result { Ok(Self { - connection: EppConnection::new(registry, stream).await?, + connection: EppConnection::new(registry, stream, timeout).await?, }) } diff --git a/src/connection.rs b/src/connection.rs index 406499b..088152c 100644 --- a/src/connection.rs +++ b/src/connection.rs @@ -1,6 +1,8 @@ //! Manages registry connections and reading/writing to them use std::convert::TryInto; +use std::future::Future; +use std::time::Duration; use std::{io, str, u32}; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; @@ -13,14 +15,20 @@ pub(crate) struct EppConnection { registry: String, stream: IO, pub greeting: String, + timeout: Duration, } impl EppConnection { - pub(crate) async fn new(registry: String, stream: IO) -> Result { + pub(crate) async fn new( + registry: String, + stream: IO, + timeout: Duration, + ) -> Result { let mut this = Self { registry, stream, greeting: String::new(), + timeout, }; this.greeting = this.get_epp_response().await?; @@ -40,7 +48,7 @@ impl EppConnection { buf[..4].clone_from_slice(&len_u32); buf[4..].clone_from_slice(content.as_bytes()); - let wrote = self.stream.write(&buf).await?; + let wrote = timeout(self.timeout, self.stream.write(&buf)).await?; debug!("{}: Wrote {} bytes", self.registry, wrote); Ok(()) } @@ -48,7 +56,7 @@ impl EppConnection { /// Receives response from the socket and converts it into an EPP XML string async fn get_epp_response(&mut self) -> Result { let mut buf = [0u8; 4]; - self.stream.read_exact(&mut buf).await?; + timeout(self.timeout, self.stream.read_exact(&mut buf)).await?; let buf_size: usize = u32::from_be_bytes(buf).try_into()?; @@ -59,7 +67,7 @@ impl EppConnection { let mut read_size: usize = 0; loop { - let read = self.stream.read(&mut buf[read_size..]).await?; + let read = timeout(self.timeout, self.stream.read(&mut buf[read_size..])).await?; debug!("{}: Read: {} bytes", self.registry, read); read_size += read; @@ -95,7 +103,18 @@ impl EppConnection { pub(crate) async fn shutdown(&mut self) -> Result<(), Error> { info!("{}: Closing connection", self.registry); - self.stream.shutdown().await?; + timeout(self.timeout, self.stream.shutdown()).await?; Ok(()) } } + +pub(crate) async fn timeout>( + timeout: Duration, + fut: impl Future>, +) -> Result { + match tokio::time::timeout(timeout, fut).await { + Ok(Ok(t)) => Ok(t), + Ok(Err(e)) => Err(e.into()), + Err(_) => Err(Error::Timeout), + } +} diff --git a/src/error.rs b/src/error.rs index fff0b0d..30ddb3f 100644 --- a/src/error.rs +++ b/src/error.rs @@ -12,8 +12,9 @@ use crate::response::ResponseStatus; /// Error enum holding the possible error types #[derive(Debug)] pub enum Error { - Io(std::io::Error), Command(ResponseStatus), + Io(std::io::Error), + Timeout, Xml(Box), Other(Box), } @@ -23,10 +24,11 @@ impl StdError for Error {} impl Display for Error { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { - Error::Io(e) => write!(f, "I/O error: {}", e), Error::Command(e) => { write!(f, "command error: {}", e.result.message) } + Error::Io(e) => write!(f, "I/O error: {}", e), + Error::Timeout => write!(f, "timeout"), Error::Xml(e) => write!(f, "(de)serialization error: {}", e), Error::Other(e) => write!(f, "error: {}", e), } diff --git a/src/lib.rs b/src/lib.rs index 859e166..be9a816 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -41,6 +41,7 @@ //! ```no_run //! use std::collections::HashMap; //! use std::net::ToSocketAddrs; +//! use std::time::Duration; //! //! use epp_client::EppClient; //! use epp_client::domain::DomainCheck; @@ -54,7 +55,8 @@ //! // Create an instance of EppClient //! let host = "example.com"; //! let addr = (host, 7000).to_socket_addrs().unwrap().next().unwrap(); -//! let mut client = match EppClient::connect("registry_name".to_string(), addr, host, None).await { +//! let timeout = Duration::from_secs(5); +//! let mut client = match EppClient::connect("registry_name".to_string(), addr, host, None, timeout).await { //! Ok(client) => client, //! Err(e) => panic!("Failed to create EppClient: {}", e) //! }; diff --git a/tests/basic.rs b/tests/basic.rs index 11e17c8..c18f2b6 100644 --- a/tests/basic.rs +++ b/tests/basic.rs @@ -1,6 +1,7 @@ use std::fs::File; use std::io::{self, Read, Write}; use std::str; +use std::time::Duration; use regex::Regex; use tokio_test::io::Builder; @@ -86,7 +87,10 @@ async fn client() { ]) .build(); - let mut client = EppClient::new("test".into(), stream).await.unwrap(); + let mut client = EppClient::new("test".into(), stream, Duration::from_secs(5)) + .await + .unwrap(); + assert_eq!(client.xml_greeting(), xml("response/greeting.xml")); client .transact(