Fix compilation with --no-default-features

This commit is contained in:
Dirkjan Ochtman 2023-08-02 15:30:52 +02:00
parent c1b118f378
commit ed73ff546f
2 changed files with 100 additions and 82 deletions

View File

@ -34,6 +34,10 @@ jobs:
with: with:
command: test command: test
args: --all-features args: --all-features
- uses: actions-rs/cargo@v1
with:
command: test
args: --no-default-features
lint: lint:
runs-on: ubuntu-latest runs-on: ubuntu-latest

View File

@ -1,22 +1,12 @@
use std::io;
use std::sync::Arc;
use std::time::Duration; use std::time::Duration;
use async_trait::async_trait; use tracing::{debug, error};
#[cfg(feature = "tokio-rustls")]
use tokio::net::lookup_host;
use tokio::net::TcpStream;
#[cfg(feature = "tokio-rustls")]
use tokio_rustls::client::TlsStream;
#[cfg(feature = "tokio-rustls")]
use tokio_rustls::rustls::{ClientConfig, OwnedTrustAnchor, RootCertStore, ServerName};
#[cfg(feature = "tokio-rustls")]
use tokio_rustls::TlsConnector;
use tracing::{debug, error, info};
use crate::common::{Certificate, NoExtension, PrivateKey}; use crate::common::NoExtension;
#[cfg(feature = "tokio-rustls")]
use crate::common::{Certificate, PrivateKey};
pub use crate::connection::Connector; pub use crate::connection::Connector;
use crate::connection::{self, EppConnection}; use crate::connection::EppConnection;
use crate::error::Error; use crate::error::Error;
use crate::hello::{Greeting, Hello}; use crate::hello::{Greeting, Hello};
use crate::request::{Command, CommandWrapper, Extension, Transaction}; use crate::request::{Command, CommandWrapper, Extension, Transaction};
@ -39,6 +29,7 @@ use crate::xml;
/// use instant_epp::domain::DomainCheck; /// use instant_epp::domain::DomainCheck;
/// use instant_epp::common::NoExtension; /// use instant_epp::common::NoExtension;
/// ///
/// # #[cfg(feature = "tokio-rustls")]
/// # #[tokio::main] /// # #[tokio::main]
/// # async fn main() { /// # async fn main() {
/// // Create an instance of EppClient /// // Create an instance of EppClient
@ -62,6 +53,9 @@ use crate::xml;
/// .iter() /// .iter()
/// .for_each(|chk| println!("Domain: {}, Available: {}", chk.inner.id, chk.inner.available)); /// .for_each(|chk| println!("Domain: {}, Available: {}", chk.inner.id, chk.inner.available));
/// # } /// # }
/// #
/// # #[cfg(not(feature = "tokio-rustls"))]
/// # fn main() {}
/// ``` /// ```
/// ///
/// The output would look like this: /// The output would look like this:
@ -215,77 +209,97 @@ impl<'c, 'e, C, E> Clone for RequestData<'c, 'e, C, E> {
impl<'c, 'e, C, E> Copy for RequestData<'c, 'e, C, E> {} impl<'c, 'e, C, E> Copy for RequestData<'c, 'e, C, E> {}
#[cfg(feature = "tokio-rustls")] #[cfg(feature = "tokio-rustls")]
pub struct RustlsConnector { use rustls_connector::RustlsConnector;
inner: TlsConnector,
domain: ServerName,
server: (String, u16),
}
impl RustlsConnector {
pub async fn new(
server: (String, u16),
identity: Option<(Vec<Certificate>, PrivateKey)>,
) -> Result<Self, Error> {
let mut roots = RootCertStore::empty();
roots.add_server_trust_anchors(webpki_roots::TLS_SERVER_ROOTS.0.iter().map(|ta| {
OwnedTrustAnchor::from_subject_spki_name_constraints(
ta.subject,
ta.spki,
ta.name_constraints,
)
}));
let builder = ClientConfig::builder()
.with_safe_defaults()
.with_root_certificates(roots);
let config = match identity {
Some((certs, key)) => {
let certs = certs
.into_iter()
.map(|cert| tokio_rustls::rustls::Certificate(cert.0))
.collect();
builder
.with_client_auth_cert(certs, tokio_rustls::rustls::PrivateKey(key.0))
.map_err(|e| Error::Other(e.into()))?
}
None => builder.with_no_client_auth(),
};
let domain = server.0.as_str().try_into().map_err(|_| {
io::Error::new(
io::ErrorKind::InvalidInput,
format!("Invalid domain: {}", server.0),
)
})?;
Ok(Self {
inner: TlsConnector::from(Arc::new(config)),
domain,
server,
})
}
}
#[cfg(feature = "tokio-rustls")] #[cfg(feature = "tokio-rustls")]
#[async_trait] mod rustls_connector {
impl Connector for RustlsConnector { use std::io;
type Connection = TlsStream<TcpStream>; use std::sync::Arc;
use std::time::Duration;
async fn connect(&self, timeout: Duration) -> Result<Self::Connection, Error> { use async_trait::async_trait;
info!("Connecting to server: {}:{}", self.server.0, self.server.1); use tokio::net::lookup_host;
let addr = match lookup_host(&self.server).await?.next() { use tokio::net::TcpStream;
Some(addr) => addr, use tokio_rustls::client::TlsStream;
None => { use tokio_rustls::rustls::{ClientConfig, OwnedTrustAnchor, RootCertStore, ServerName};
return Err(Error::Io(io::Error::new( use tokio_rustls::TlsConnector;
use tracing::info;
use crate::common::{Certificate, PrivateKey};
use crate::connection::{self, Connector};
use crate::error::Error;
pub struct RustlsConnector {
inner: TlsConnector,
domain: ServerName,
server: (String, u16),
}
impl RustlsConnector {
pub async fn new(
server: (String, u16),
identity: Option<(Vec<Certificate>, PrivateKey)>,
) -> Result<Self, Error> {
let mut roots = RootCertStore::empty();
roots.add_server_trust_anchors(webpki_roots::TLS_SERVER_ROOTS.0.iter().map(|ta| {
OwnedTrustAnchor::from_subject_spki_name_constraints(
ta.subject,
ta.spki,
ta.name_constraints,
)
}));
let builder = ClientConfig::builder()
.with_safe_defaults()
.with_root_certificates(roots);
let config = match identity {
Some((certs, key)) => {
let certs = certs
.into_iter()
.map(|cert| tokio_rustls::rustls::Certificate(cert.0))
.collect();
builder
.with_client_auth_cert(certs, tokio_rustls::rustls::PrivateKey(key.0))
.map_err(|e| Error::Other(e.into()))?
}
None => builder.with_no_client_auth(),
};
let domain = server.0.as_str().try_into().map_err(|_| {
io::Error::new(
io::ErrorKind::InvalidInput, io::ErrorKind::InvalidInput,
format!("Invalid host: {}", &self.server.0), format!("Invalid domain: {}", server.0),
))) )
} })?;
};
let stream = TcpStream::connect(addr).await?; Ok(Self {
let future = self.inner.connect(self.domain.clone(), stream); inner: TlsConnector::from(Arc::new(config)),
connection::timeout(timeout, future).await domain,
server,
})
}
}
#[async_trait]
impl Connector for RustlsConnector {
type Connection = TlsStream<TcpStream>;
async fn connect(&self, timeout: Duration) -> Result<Self::Connection, Error> {
info!("Connecting to server: {}:{}", self.server.0, self.server.1);
let addr = match lookup_host(&self.server).await?.next() {
Some(addr) => addr,
None => {
return Err(Error::Io(io::Error::new(
io::ErrorKind::InvalidInput,
format!("Invalid host: {}", &self.server.0),
)))
}
};
let stream = TcpStream::connect(addr).await?;
let future = self.inner.connect(self.domain.clone(), stream);
connection::timeout(timeout, future).await
}
} }
} }