From ed73ff546fb81cbdd61b5cdc4dd6a1e099e08c2e Mon Sep 17 00:00:00 2001 From: Dirkjan Ochtman Date: Wed, 2 Aug 2023 15:30:52 +0200 Subject: [PATCH] Fix compilation with --no-default-features --- .github/workflows/rust.yml | 4 + src/client.rs | 178 ++++++++++++++++++++----------------- 2 files changed, 100 insertions(+), 82 deletions(-) diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index 9918ed7..5d3dfb8 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -34,6 +34,10 @@ jobs: with: command: test args: --all-features + - uses: actions-rs/cargo@v1 + with: + command: test + args: --no-default-features lint: runs-on: ubuntu-latest diff --git a/src/client.rs b/src/client.rs index 94a2176..82e6e4d 100644 --- a/src/client.rs +++ b/src/client.rs @@ -1,22 +1,12 @@ -use std::io; -use std::sync::Arc; use std::time::Duration; -use async_trait::async_trait; -#[cfg(feature = "tokio-rustls")] -use tokio::net::lookup_host; -use tokio::net::TcpStream; -#[cfg(feature = "tokio-rustls")] -use tokio_rustls::client::TlsStream; -#[cfg(feature = "tokio-rustls")] -use tokio_rustls::rustls::{ClientConfig, OwnedTrustAnchor, RootCertStore, ServerName}; -#[cfg(feature = "tokio-rustls")] -use tokio_rustls::TlsConnector; -use tracing::{debug, error, info}; +use tracing::{debug, error}; -use crate::common::{Certificate, NoExtension, PrivateKey}; +use crate::common::NoExtension; +#[cfg(feature = "tokio-rustls")] +use crate::common::{Certificate, PrivateKey}; pub use crate::connection::Connector; -use crate::connection::{self, EppConnection}; +use crate::connection::EppConnection; use crate::error::Error; use crate::hello::{Greeting, Hello}; use crate::request::{Command, CommandWrapper, Extension, Transaction}; @@ -39,6 +29,7 @@ use crate::xml; /// use instant_epp::domain::DomainCheck; /// use instant_epp::common::NoExtension; /// +/// # #[cfg(feature = "tokio-rustls")] /// # #[tokio::main] /// # async fn main() { /// // Create an instance of EppClient @@ -62,6 +53,9 @@ use crate::xml; /// .iter() /// .for_each(|chk| println!("Domain: {}, Available: {}", chk.inner.id, chk.inner.available)); /// # } +/// # +/// # #[cfg(not(feature = "tokio-rustls"))] +/// # fn main() {} /// ``` /// /// The output would look like this: @@ -215,77 +209,97 @@ impl<'c, 'e, C, E> Clone for RequestData<'c, 'e, C, E> { impl<'c, 'e, C, E> Copy for RequestData<'c, 'e, C, E> {} #[cfg(feature = "tokio-rustls")] -pub struct RustlsConnector { - inner: TlsConnector, - domain: ServerName, - server: (String, u16), -} - -impl RustlsConnector { - pub async fn new( - server: (String, u16), - identity: Option<(Vec, PrivateKey)>, - ) -> Result { - let mut roots = RootCertStore::empty(); - roots.add_server_trust_anchors(webpki_roots::TLS_SERVER_ROOTS.0.iter().map(|ta| { - OwnedTrustAnchor::from_subject_spki_name_constraints( - ta.subject, - ta.spki, - ta.name_constraints, - ) - })); - - let builder = ClientConfig::builder() - .with_safe_defaults() - .with_root_certificates(roots); - - let config = match identity { - Some((certs, key)) => { - let certs = certs - .into_iter() - .map(|cert| tokio_rustls::rustls::Certificate(cert.0)) - .collect(); - builder - .with_client_auth_cert(certs, tokio_rustls::rustls::PrivateKey(key.0)) - .map_err(|e| Error::Other(e.into()))? - } - None => builder.with_no_client_auth(), - }; - - let domain = server.0.as_str().try_into().map_err(|_| { - io::Error::new( - io::ErrorKind::InvalidInput, - format!("Invalid domain: {}", server.0), - ) - })?; - - Ok(Self { - inner: TlsConnector::from(Arc::new(config)), - domain, - server, - }) - } -} +use rustls_connector::RustlsConnector; #[cfg(feature = "tokio-rustls")] -#[async_trait] -impl Connector for RustlsConnector { - type Connection = TlsStream; +mod rustls_connector { + use std::io; + use std::sync::Arc; + use std::time::Duration; - async fn connect(&self, timeout: Duration) -> Result { - info!("Connecting to server: {}:{}", self.server.0, self.server.1); - let addr = match lookup_host(&self.server).await?.next() { - Some(addr) => addr, - None => { - return Err(Error::Io(io::Error::new( + use async_trait::async_trait; + use tokio::net::lookup_host; + use tokio::net::TcpStream; + use tokio_rustls::client::TlsStream; + use tokio_rustls::rustls::{ClientConfig, OwnedTrustAnchor, RootCertStore, ServerName}; + use tokio_rustls::TlsConnector; + use tracing::info; + + use crate::common::{Certificate, PrivateKey}; + use crate::connection::{self, Connector}; + use crate::error::Error; + + pub struct RustlsConnector { + inner: TlsConnector, + domain: ServerName, + server: (String, u16), + } + + impl RustlsConnector { + pub async fn new( + server: (String, u16), + identity: Option<(Vec, PrivateKey)>, + ) -> Result { + let mut roots = RootCertStore::empty(); + roots.add_server_trust_anchors(webpki_roots::TLS_SERVER_ROOTS.0.iter().map(|ta| { + OwnedTrustAnchor::from_subject_spki_name_constraints( + ta.subject, + ta.spki, + ta.name_constraints, + ) + })); + + let builder = ClientConfig::builder() + .with_safe_defaults() + .with_root_certificates(roots); + + let config = match identity { + Some((certs, key)) => { + let certs = certs + .into_iter() + .map(|cert| tokio_rustls::rustls::Certificate(cert.0)) + .collect(); + builder + .with_client_auth_cert(certs, tokio_rustls::rustls::PrivateKey(key.0)) + .map_err(|e| Error::Other(e.into()))? + } + None => builder.with_no_client_auth(), + }; + + let domain = server.0.as_str().try_into().map_err(|_| { + io::Error::new( io::ErrorKind::InvalidInput, - format!("Invalid host: {}", &self.server.0), - ))) - } - }; + format!("Invalid domain: {}", server.0), + ) + })?; - let stream = TcpStream::connect(addr).await?; - let future = self.inner.connect(self.domain.clone(), stream); - connection::timeout(timeout, future).await + Ok(Self { + inner: TlsConnector::from(Arc::new(config)), + domain, + server, + }) + } + } + + #[async_trait] + impl Connector for RustlsConnector { + type Connection = TlsStream; + + async fn connect(&self, timeout: Duration) -> Result { + info!("Connecting to server: {}:{}", self.server.0, self.server.1); + let addr = match lookup_host(&self.server).await?.next() { + Some(addr) => addr, + None => { + return Err(Error::Io(io::Error::new( + io::ErrorKind::InvalidInput, + format!("Invalid host: {}", &self.server.0), + ))) + } + }; + + let stream = TcpStream::connect(addr).await?; + let future = self.inner.connect(self.domain.clone(), stream); + connection::timeout(timeout, future).await + } } }