From 8cad81ab3da4aeda53d7fc35484f4de9e7e59be9 Mon Sep 17 00:00:00 2001 From: Dirkjan Ochtman Date: Fri, 4 Feb 2022 22:38:45 +0100 Subject: [PATCH] Allow the Connection to connect itself --- Cargo.toml | 1 + src/client.rs | 138 ++++++++++++++++++++++++++++------------------ src/connection.rs | 18 ++++-- tests/basic.rs | 29 +++++++--- 4 files changed, 119 insertions(+), 67 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 9ee9029..5cf6429 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,6 +11,7 @@ repository = "https://github.com/masalachai/epp-client" default = ["tokio-rustls"] [dependencies] +async-trait = "0.1.52" celes = "2.1" chrono = "0.4" quick-xml = { version = "0.22", features = [ "serialize" ] } diff --git a/src/client.rs b/src/client.rs index 99d363f..d5ca481 100644 --- a/src/client.rs +++ b/src/client.rs @@ -4,17 +4,18 @@ use std::net::SocketAddr; use std::sync::Arc; use std::time::Duration; -use tokio::io::{AsyncRead, AsyncWrite}; +use async_trait::async_trait; use tokio::net::TcpStream; #[cfg(feature = "tokio-rustls")] use tokio_rustls::client::TlsStream; #[cfg(feature = "tokio-rustls")] -use tokio_rustls::rustls::{ClientConfig, OwnedTrustAnchor, RootCertStore}; +use tokio_rustls::rustls::{ClientConfig, OwnedTrustAnchor, RootCertStore, ServerName}; #[cfg(feature = "tokio-rustls")] use tokio_rustls::TlsConnector; use tracing::info; use crate::common::{Certificate, NoExtension, PrivateKey}; +pub use crate::connection::Connector; use crate::connection::{self, EppConnection}; use crate::error::Error; use crate::hello::{Greeting, GreetingDocument, HelloDocument}; @@ -68,12 +69,12 @@ use crate::xml::EppXml; /// Domain: eppdev.com, Available: 1 /// Domain: eppdev.net, Available: 1 /// ``` -pub struct EppClient { - connection: EppConnection, +pub struct EppClient { + connection: EppConnection, } #[cfg(feature = "tokio-rustls")] -impl EppClient> { +impl EppClient { /// Connect to the specified `addr` and `hostname` over TLS /// /// The `registry` is used as a name in internal logging; `addr` provides the address to @@ -91,52 +92,16 @@ impl EppClient> { timeout: Duration, ) -> Result { info!("Connecting to server: {:?}", addr); - - let mut roots = RootCertStore::empty(); - roots.add_server_trust_anchors(webpki_roots::TLS_SERVER_ROOTS.0.iter().map(|ta| { - OwnedTrustAnchor::from_subject_spki_name_constraints( - ta.subject, - ta.spki, - ta.name_constraints, - ) - })); - - let builder = ClientConfig::builder() - .with_safe_defaults() - .with_root_certificates(roots); - - let config = match identity { - Some((certs, key)) => { - let certs = certs - .into_iter() - .map(|cert| tokio_rustls::rustls::Certificate(cert.0)) - .collect(); - builder - .with_single_cert(certs, tokio_rustls::rustls::PrivateKey(key.0)) - .map_err(|e| Error::Other(e.into()))? - } - None => builder.with_no_client_auth(), - }; - - let domain = hostname.try_into().map_err(|_| { - io::Error::new( - io::ErrorKind::InvalidInput, - format!("Invalid domain: {}", hostname), - ) - })?; - - let connector = TlsConnector::from(Arc::new(config)); - let future = connector.connect(domain, TcpStream::connect(&addr).await?); - let stream = connection::timeout(timeout, future).await?; - Self::new(registry, stream, timeout).await + let connector = RustlsConnector::new(addr, hostname, identity)?; + Self::new(connector, registry, timeout).await } } -impl EppClient { +impl EppClient { /// Create an `EppClient` from an already established connection - pub async fn new(registry: String, stream: IO, timeout: Duration) -> Result { + pub async fn new(connector: C, registry: String, timeout: Duration) -> Result { Ok(Self { - connection: EppConnection::new(registry, stream, timeout).await?, + connection: EppConnection::new(connector, registry, timeout).await?, }) } @@ -149,21 +114,22 @@ impl EppClient { Ok(GreetingDocument::deserialize(&response)?.data) } - pub async fn transact<'c, 'e, C, E>( + pub async fn transact<'c, 'e, Cmd, Ext>( &mut self, - data: impl Into>, + data: impl Into>, id: &str, - ) -> Result, Error> + ) -> Result, Error> where - C: Transaction + Command + 'c, - E: Extension + 'e, + Cmd: Transaction + Command + 'c, + Ext: Extension + 'e, { let data = data.into(); - let epp_xml = >::serialize_request(data.command, data.extension, id)?; + let epp_xml = + >::serialize_request(data.command, data.extension, id)?; let response = self.connection.transact(&epp_xml).await?; - C::deserialize_response(&response) + Cmd::deserialize_response(&response) } /// Accepts raw EPP XML and returns the raw EPP XML response to it. @@ -209,3 +175,69 @@ impl<'c, 'e, C: Command, E: Extension> From<(&'c C, &'e E)> for RequestData<'c, } } } + +#[cfg(feature = "tokio-rustls")] +pub struct RustlsConnector { + inner: TlsConnector, + domain: ServerName, + addr: SocketAddr, +} + +impl RustlsConnector { + pub fn new( + addr: SocketAddr, + hostname: &str, + identity: Option<(Vec, PrivateKey)>, + ) -> Result { + let mut roots = RootCertStore::empty(); + roots.add_server_trust_anchors(webpki_roots::TLS_SERVER_ROOTS.0.iter().map(|ta| { + OwnedTrustAnchor::from_subject_spki_name_constraints( + ta.subject, + ta.spki, + ta.name_constraints, + ) + })); + + let builder = ClientConfig::builder() + .with_safe_defaults() + .with_root_certificates(roots); + + let config = match identity { + Some((certs, key)) => { + let certs = certs + .into_iter() + .map(|cert| tokio_rustls::rustls::Certificate(cert.0)) + .collect(); + builder + .with_single_cert(certs, tokio_rustls::rustls::PrivateKey(key.0)) + .map_err(|e| Error::Other(e.into()))? + } + None => builder.with_no_client_auth(), + }; + + let domain = hostname.try_into().map_err(|_| { + io::Error::new( + io::ErrorKind::InvalidInput, + format!("Invalid domain: {}", hostname), + ) + })?; + + Ok(Self { + inner: TlsConnector::from(Arc::new(config)), + domain, + addr, + }) + } +} + +#[cfg(feature = "tokio-rustls")] +#[async_trait] +impl Connector for RustlsConnector { + type Connection = TlsStream; + + async fn connect(&self, timeout: Duration) -> Result { + let stream = TcpStream::connect(&self.addr).await?; + let future = self.inner.connect(self.domain.clone(), stream); + connection::timeout(timeout, future).await + } +} diff --git a/src/connection.rs b/src/connection.rs index 088152c..490e5a8 100644 --- a/src/connection.rs +++ b/src/connection.rs @@ -5,28 +5,29 @@ use std::future::Future; use std::time::Duration; use std::{io, str, u32}; +use async_trait::async_trait; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; use tracing::{debug, info}; use crate::error::Error; /// EPP Connection struct with some metadata for the connection -pub(crate) struct EppConnection { +pub(crate) struct EppConnection { registry: String, - stream: IO, + stream: C::Connection, pub greeting: String, timeout: Duration, } -impl EppConnection { +impl EppConnection { pub(crate) async fn new( + connector: C, registry: String, - stream: IO, timeout: Duration, ) -> Result { let mut this = Self { registry, - stream, + stream: connector.connect(timeout).await?, greeting: String::new(), timeout, }; @@ -118,3 +119,10 @@ pub(crate) async fn timeout>( Err(_) => Err(Error::Timeout), } } + +#[async_trait] +pub trait Connector { + type Connection: AsyncRead + AsyncWrite + Unpin; + + async fn connect(&self, timeout: Duration) -> Result; +} diff --git a/tests/basic.rs b/tests/basic.rs index c18f2b6..58d56fb 100644 --- a/tests/basic.rs +++ b/tests/basic.rs @@ -3,6 +3,7 @@ use std::io::{self, Read, Write}; use std::str; use std::time::Duration; +use async_trait::async_trait; use regex::Regex; use tokio_test::io::Builder; @@ -78,16 +79,26 @@ fn build_stream(units: &[&str]) -> Builder { #[tokio::test] async fn client() { let _guard = log_to_stdout(); - let stream = build_stream(&[ - "response/greeting.xml", - "request/login.xml", - "response/login.xml", - "request/domain/check.xml", - "response/domain/check.xml", - ]) - .build(); - let mut client = EppClient::new("test".into(), stream, Duration::from_secs(5)) + struct FakeConnector; + + #[async_trait] + impl epp_client::client::Connector for FakeConnector { + type Connection = tokio_test::io::Mock; + + async fn connect(&self, _: Duration) -> Result { + Ok(build_stream(&[ + "response/greeting.xml", + "request/login.xml", + "response/login.xml", + "request/domain/check.xml", + "response/domain/check.xml", + ]) + .build()) + } + } + + let mut client = EppClient::new(FakeConnector, "test".into(), Duration::from_secs(5)) .await .unwrap();