Allow the Connection to connect itself

This commit is contained in:
Dirkjan Ochtman 2022-02-04 22:38:45 +01:00 committed by masalachai
parent 95b62891c9
commit 36558c429c
4 changed files with 119 additions and 67 deletions

View File

@ -11,6 +11,7 @@ repository = "https://github.com/masalachai/epp-client"
default = ["tokio-rustls"] default = ["tokio-rustls"]
[dependencies] [dependencies]
async-trait = "0.1.52"
celes = "2.1" celes = "2.1"
chrono = "0.4" chrono = "0.4"
quick-xml = { version = "0.22", features = [ "serialize" ] } quick-xml = { version = "0.22", features = [ "serialize" ] }

View File

@ -4,17 +4,18 @@ use std::net::SocketAddr;
use std::sync::Arc; use std::sync::Arc;
use std::time::Duration; use std::time::Duration;
use tokio::io::{AsyncRead, AsyncWrite}; use async_trait::async_trait;
use tokio::net::TcpStream; use tokio::net::TcpStream;
#[cfg(feature = "tokio-rustls")] #[cfg(feature = "tokio-rustls")]
use tokio_rustls::client::TlsStream; use tokio_rustls::client::TlsStream;
#[cfg(feature = "tokio-rustls")] #[cfg(feature = "tokio-rustls")]
use tokio_rustls::rustls::{ClientConfig, OwnedTrustAnchor, RootCertStore}; use tokio_rustls::rustls::{ClientConfig, OwnedTrustAnchor, RootCertStore, ServerName};
#[cfg(feature = "tokio-rustls")] #[cfg(feature = "tokio-rustls")]
use tokio_rustls::TlsConnector; use tokio_rustls::TlsConnector;
use tracing::info; use tracing::info;
use crate::common::{Certificate, NoExtension, PrivateKey}; use crate::common::{Certificate, NoExtension, PrivateKey};
pub use crate::connection::Connector;
use crate::connection::{self, EppConnection}; use crate::connection::{self, EppConnection};
use crate::error::Error; use crate::error::Error;
use crate::hello::{Greeting, GreetingDocument, HelloDocument}; use crate::hello::{Greeting, GreetingDocument, HelloDocument};
@ -68,12 +69,12 @@ use crate::xml::EppXml;
/// Domain: eppdev.com, Available: 1 /// Domain: eppdev.com, Available: 1
/// Domain: eppdev.net, Available: 1 /// Domain: eppdev.net, Available: 1
/// ``` /// ```
pub struct EppClient<IO> { pub struct EppClient<C: Connector> {
connection: EppConnection<IO>, connection: EppConnection<C>,
} }
#[cfg(feature = "tokio-rustls")] #[cfg(feature = "tokio-rustls")]
impl EppClient<TlsStream<TcpStream>> { impl EppClient<RustlsConnector> {
/// Connect to the specified `addr` and `hostname` over TLS /// Connect to the specified `addr` and `hostname` over TLS
/// ///
/// The `registry` is used as a name in internal logging; `addr` provides the address to /// The `registry` is used as a name in internal logging; `addr` provides the address to
@ -91,52 +92,16 @@ impl EppClient<TlsStream<TcpStream>> {
timeout: Duration, timeout: Duration,
) -> Result<Self, Error> { ) -> Result<Self, Error> {
info!("Connecting to server: {:?}", addr); info!("Connecting to server: {:?}", addr);
let connector = RustlsConnector::new(addr, hostname, identity)?;
let mut roots = RootCertStore::empty(); Self::new(connector, registry, timeout).await
roots.add_server_trust_anchors(webpki_roots::TLS_SERVER_ROOTS.0.iter().map(|ta| {
OwnedTrustAnchor::from_subject_spki_name_constraints(
ta.subject,
ta.spki,
ta.name_constraints,
)
}));
let builder = ClientConfig::builder()
.with_safe_defaults()
.with_root_certificates(roots);
let config = match identity {
Some((certs, key)) => {
let certs = certs
.into_iter()
.map(|cert| tokio_rustls::rustls::Certificate(cert.0))
.collect();
builder
.with_single_cert(certs, tokio_rustls::rustls::PrivateKey(key.0))
.map_err(|e| Error::Other(e.into()))?
}
None => builder.with_no_client_auth(),
};
let domain = hostname.try_into().map_err(|_| {
io::Error::new(
io::ErrorKind::InvalidInput,
format!("Invalid domain: {}", hostname),
)
})?;
let connector = TlsConnector::from(Arc::new(config));
let future = connector.connect(domain, TcpStream::connect(&addr).await?);
let stream = connection::timeout(timeout, future).await?;
Self::new(registry, stream, timeout).await
} }
} }
impl<IO: AsyncRead + AsyncWrite + Unpin> EppClient<IO> { impl<C: Connector> EppClient<C> {
/// Create an `EppClient` from an already established connection /// Create an `EppClient` from an already established connection
pub async fn new(registry: String, stream: IO, timeout: Duration) -> Result<Self, Error> { pub async fn new(connector: C, registry: String, timeout: Duration) -> Result<Self, Error> {
Ok(Self { Ok(Self {
connection: EppConnection::new(registry, stream, timeout).await?, connection: EppConnection::new(connector, registry, timeout).await?,
}) })
} }
@ -149,21 +114,22 @@ impl<IO: AsyncRead + AsyncWrite + Unpin> EppClient<IO> {
Ok(GreetingDocument::deserialize(&response)?.data) Ok(GreetingDocument::deserialize(&response)?.data)
} }
pub async fn transact<'c, 'e, C, E>( pub async fn transact<'c, 'e, Cmd, Ext>(
&mut self, &mut self,
data: impl Into<RequestData<'c, 'e, C, E>>, data: impl Into<RequestData<'c, 'e, Cmd, Ext>>,
id: &str, id: &str,
) -> Result<Response<C::Response, E::Response>, Error> ) -> Result<Response<Cmd::Response, Ext::Response>, Error>
where where
C: Transaction<E> + Command + 'c, Cmd: Transaction<Ext> + Command + 'c,
E: Extension + 'e, Ext: Extension + 'e,
{ {
let data = data.into(); let data = data.into();
let epp_xml = <C as Transaction<E>>::serialize_request(data.command, data.extension, id)?; let epp_xml =
<Cmd as Transaction<Ext>>::serialize_request(data.command, data.extension, id)?;
let response = self.connection.transact(&epp_xml).await?; let response = self.connection.transact(&epp_xml).await?;
C::deserialize_response(&response) Cmd::deserialize_response(&response)
} }
/// Accepts raw EPP XML and returns the raw EPP XML response to it. /// Accepts raw EPP XML and returns the raw EPP XML response to it.
@ -209,3 +175,69 @@ impl<'c, 'e, C: Command, E: Extension> From<(&'c C, &'e E)> for RequestData<'c,
} }
} }
} }
#[cfg(feature = "tokio-rustls")]
pub struct RustlsConnector {
inner: TlsConnector,
domain: ServerName,
addr: SocketAddr,
}
impl RustlsConnector {
pub fn new(
addr: SocketAddr,
hostname: &str,
identity: Option<(Vec<Certificate>, PrivateKey)>,
) -> Result<Self, Error> {
let mut roots = RootCertStore::empty();
roots.add_server_trust_anchors(webpki_roots::TLS_SERVER_ROOTS.0.iter().map(|ta| {
OwnedTrustAnchor::from_subject_spki_name_constraints(
ta.subject,
ta.spki,
ta.name_constraints,
)
}));
let builder = ClientConfig::builder()
.with_safe_defaults()
.with_root_certificates(roots);
let config = match identity {
Some((certs, key)) => {
let certs = certs
.into_iter()
.map(|cert| tokio_rustls::rustls::Certificate(cert.0))
.collect();
builder
.with_single_cert(certs, tokio_rustls::rustls::PrivateKey(key.0))
.map_err(|e| Error::Other(e.into()))?
}
None => builder.with_no_client_auth(),
};
let domain = hostname.try_into().map_err(|_| {
io::Error::new(
io::ErrorKind::InvalidInput,
format!("Invalid domain: {}", hostname),
)
})?;
Ok(Self {
inner: TlsConnector::from(Arc::new(config)),
domain,
addr,
})
}
}
#[cfg(feature = "tokio-rustls")]
#[async_trait]
impl Connector for RustlsConnector {
type Connection = TlsStream<TcpStream>;
async fn connect(&self, timeout: Duration) -> Result<Self::Connection, Error> {
let stream = TcpStream::connect(&self.addr).await?;
let future = self.inner.connect(self.domain.clone(), stream);
connection::timeout(timeout, future).await
}
}

View File

@ -5,28 +5,29 @@ use std::future::Future;
use std::time::Duration; use std::time::Duration;
use std::{io, str, u32}; use std::{io, str, u32};
use async_trait::async_trait;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use tracing::{debug, info}; use tracing::{debug, info};
use crate::error::Error; use crate::error::Error;
/// EPP Connection struct with some metadata for the connection /// EPP Connection struct with some metadata for the connection
pub(crate) struct EppConnection<IO> { pub(crate) struct EppConnection<C: Connector> {
registry: String, registry: String,
stream: IO, stream: C::Connection,
pub greeting: String, pub greeting: String,
timeout: Duration, timeout: Duration,
} }
impl<IO: AsyncRead + AsyncWrite + Unpin> EppConnection<IO> { impl<C: Connector> EppConnection<C> {
pub(crate) async fn new( pub(crate) async fn new(
connector: C,
registry: String, registry: String,
stream: IO,
timeout: Duration, timeout: Duration,
) -> Result<Self, Error> { ) -> Result<Self, Error> {
let mut this = Self { let mut this = Self {
registry, registry,
stream, stream: connector.connect(timeout).await?,
greeting: String::new(), greeting: String::new(),
timeout, timeout,
}; };
@ -118,3 +119,10 @@ pub(crate) async fn timeout<T, E: Into<Error>>(
Err(_) => Err(Error::Timeout), Err(_) => Err(Error::Timeout),
} }
} }
#[async_trait]
pub trait Connector {
type Connection: AsyncRead + AsyncWrite + Unpin;
async fn connect(&self, timeout: Duration) -> Result<Self::Connection, Error>;
}

View File

@ -3,6 +3,7 @@ use std::io::{self, Read, Write};
use std::str; use std::str;
use std::time::Duration; use std::time::Duration;
use async_trait::async_trait;
use regex::Regex; use regex::Regex;
use tokio_test::io::Builder; use tokio_test::io::Builder;
@ -78,16 +79,26 @@ fn build_stream(units: &[&str]) -> Builder {
#[tokio::test] #[tokio::test]
async fn client() { async fn client() {
let _guard = log_to_stdout(); let _guard = log_to_stdout();
let stream = build_stream(&[
struct FakeConnector;
#[async_trait]
impl epp_client::client::Connector for FakeConnector {
type Connection = tokio_test::io::Mock;
async fn connect(&self, _: Duration) -> Result<Self::Connection, epp_client::Error> {
Ok(build_stream(&[
"response/greeting.xml", "response/greeting.xml",
"request/login.xml", "request/login.xml",
"response/login.xml", "response/login.xml",
"request/domain/check.xml", "request/domain/check.xml",
"response/domain/check.xml", "response/domain/check.xml",
]) ])
.build(); .build())
}
}
let mut client = EppClient::new("test".into(), stream, Duration::from_secs(5)) let mut client = EppClient::new(FakeConnector, "test".into(), Duration::from_secs(5))
.await .await
.unwrap(); .unwrap();