Move responsibility for stream setup into client

This commit is contained in:
Dirkjan Ochtman 2021-12-22 11:41:34 +01:00 committed by masalachai
parent d69439ff24
commit 44f3fbef53
2 changed files with 59 additions and 69 deletions

View File

@ -33,10 +33,16 @@
//! } //! }
//! ``` //! ```
use std::convert::TryInto;
use std::net::SocketAddr; use std::net::SocketAddr;
use std::sync::Arc;
use std::io;
use tokio::net::TcpStream; use tokio::net::TcpStream;
use tokio_rustls::client::TlsStream; use tokio_rustls::client::TlsStream;
use tokio_rustls::rustls::{ClientConfig, OwnedTrustAnchor, RootCertStore};
use tokio_rustls::TlsConnector;
use tracing::info;
use crate::common::{Certificate, NoExtension, PrivateKey}; use crate::common::{Certificate, NoExtension, PrivateKey};
use crate::connection::EppConnection; use crate::connection::EppConnection;
@ -62,8 +68,9 @@ impl EppClient {
hostname: &str, hostname: &str,
identity: Option<(Vec<Certificate>, PrivateKey)>, identity: Option<(Vec<Certificate>, PrivateKey)>,
) -> Result<Self, Error> { ) -> Result<Self, Error> {
let stream = epp_connect(addr, hostname, identity).await?;
Ok(Self { Ok(Self {
connection: EppConnection::connect(registry, addr, hostname, identity).await?, connection: EppConnection::new(registry, stream).await?,
}) })
} }
@ -114,6 +121,54 @@ impl EppClient {
} }
} }
/// Establishes a TLS connection to a registry and returns a ConnectionStream instance containing the
/// socket stream to read/write to the connection
async fn epp_connect(
addr: SocketAddr,
hostname: &str,
identity: Option<(Vec<Certificate>, PrivateKey)>,
) -> Result<TlsStream<TcpStream>, Error> {
info!("Connecting to server: {:?}", addr,);
let mut roots = RootCertStore::empty();
roots.add_server_trust_anchors(webpki_roots::TLS_SERVER_ROOTS.0.iter().map(|ta| {
OwnedTrustAnchor::from_subject_spki_name_constraints(
ta.subject,
ta.spki,
ta.name_constraints,
)
}));
let builder = ClientConfig::builder()
.with_safe_defaults()
.with_root_certificates(roots);
let config = match identity {
Some((certs, key)) => {
let certs = certs
.into_iter()
.map(|cert| rustls::Certificate(cert.0))
.collect();
builder
.with_single_cert(certs, rustls::PrivateKey(key.0))
.map_err(|e| Error::Other(e.into()))?
}
None => builder.with_no_client_auth(),
};
let connector = TlsConnector::from(Arc::new(config));
let stream = TcpStream::connect(&addr).await?;
let domain = hostname.try_into().map_err(|_| {
io::Error::new(
io::ErrorKind::InvalidInput,
format!("Invalid domain: {}", hostname),
)
})?;
Ok(connector.connect(domain, stream).await?)
}
pub struct RequestData<'a, C, E> { pub struct RequestData<'a, C, E> {
command: &'a C, command: &'a C,
extension: Option<&'a E>, extension: Option<&'a E>,

View File

@ -1,17 +1,11 @@
//! Manages registry connections and reading/writing to them //! Manages registry connections and reading/writing to them
use std::convert::TryInto; use std::convert::TryInto;
use std::net::SocketAddr;
use std::sync::Arc;
use std::{io, str, u32}; use std::{io, str, u32};
use rustls::{OwnedTrustAnchor, RootCertStore};
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use tokio::net::TcpStream;
use tokio_rustls::{client::TlsStream, rustls::ClientConfig, TlsConnector};
use tracing::{debug, info}; use tracing::{debug, info};
use crate::common::{Certificate, PrivateKey};
use crate::error::Error; use crate::error::Error;
/// EPP Connection struct with some metadata for the connection /// EPP Connection struct with some metadata for the connection
@ -21,31 +15,20 @@ pub(crate) struct EppConnection<IO> {
pub greeting: String, pub greeting: String,
} }
impl EppConnection<TlsStream<TcpStream>> { impl<IO: AsyncRead + AsyncWrite + Unpin> EppConnection<IO> {
/// Create an EppConnection instance with the stream to the registry pub(crate) async fn new(registry: String, mut stream: IO) -> Result<Self, Error> {
pub(crate) async fn connect(
registry: String,
addr: SocketAddr,
hostname: &str,
identity: Option<(Vec<Certificate>, PrivateKey)>,
) -> Result<Self, Error> {
let mut stream = epp_connect(addr, hostname, identity).await?;
let mut buf = vec![0u8; 4096]; let mut buf = vec![0u8; 4096];
stream.read(&mut buf).await?; stream.read(&mut buf).await?;
let greeting = str::from_utf8(&buf[4..])?.to_string(); let greeting = str::from_utf8(&buf[4..])?.to_string();
debug!("{}: greeting: {}", registry, greeting); debug!("{}: greeting: {}", registry, greeting);
Ok(EppConnection { Ok(Self {
registry, registry,
stream, stream,
greeting, greeting,
}) })
} }
}
impl<IO: AsyncRead + AsyncWrite + Unpin> EppConnection<IO> {
/// Constructs an EPP XML request in the required form and sends it to the server /// Constructs an EPP XML request in the required form and sends it to the server
async fn send_epp_request(&mut self, content: &str) -> Result<(), Error> { async fn send_epp_request(&mut self, content: &str) -> Result<(), Error> {
let len = content.len(); let len = content.len();
@ -118,51 +101,3 @@ impl<IO: AsyncRead + AsyncWrite + Unpin> EppConnection<IO> {
Ok(()) Ok(())
} }
} }
/// Establishes a TLS connection to a registry and returns a ConnectionStream instance containing the
/// socket stream to read/write to the connection
async fn epp_connect(
addr: SocketAddr,
hostname: &str,
identity: Option<(Vec<Certificate>, PrivateKey)>,
) -> Result<TlsStream<TcpStream>, Error> {
info!("Connecting to server: {:?}", addr,);
let mut roots = RootCertStore::empty();
roots.add_server_trust_anchors(webpki_roots::TLS_SERVER_ROOTS.0.iter().map(|ta| {
OwnedTrustAnchor::from_subject_spki_name_constraints(
ta.subject,
ta.spki,
ta.name_constraints,
)
}));
let builder = ClientConfig::builder()
.with_safe_defaults()
.with_root_certificates(roots);
let config = match identity {
Some((certs, key)) => {
let certs = certs
.into_iter()
.map(|cert| rustls::Certificate(cert.0))
.collect();
builder
.with_single_cert(certs, rustls::PrivateKey(key.0))
.map_err(|e| Error::Other(e.into()))?
}
None => builder.with_no_client_auth(),
};
let connector = TlsConnector::from(Arc::new(config));
let stream = TcpStream::connect(&addr).await?;
let domain = hostname.try_into().map_err(|_| {
io::Error::new(
io::ErrorKind::InvalidInput,
format!("Invalid domain: {}", hostname),
)
})?;
Ok(connector.connect(domain, stream).await?)
}