diff --git a/Cargo.toml b/Cargo.toml index 9e144c6..64ff86f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -15,17 +15,20 @@ categories = ["web-programming", "api-bindings"] default = ["hyper-rustls", "ring"] aws-lc-rs = ["dep:aws-lc-rs", "hyper-rustls?/aws-lc-rs", "rcgen/aws_lc_rs"] fips = ["aws-lc-rs", "aws-lc-rs?/fips"] +hyper-rustls = ["dep:hyper", "dep:hyper-rustls", "dep:hyper-util"] ring = ["dep:ring", "hyper-rustls?/ring", "rcgen/ring"] [dependencies] +async-trait = "0.1" aws-lc-rs = { version = "1.8.0", optional = true } base64 = "0.21.0" bytes = "1" http = "1" +http-body = "1" http-body-util = "0.1.2" -hyper = { version = "1.3.1", features = ["client", "http1", "http2"] } +hyper = { version = "1.3.1", features = ["client", "http1", "http2"], optional = true } hyper-rustls = { version = "0.27", default-features = false, features = ["http1", "http2", "native-tokio", "tls12", "rustls-native-certs"], optional = true } -hyper-util = { version = "0.1.5", features = ["client", "client-legacy", "http1", "http2", "tokio"] } +hyper-util = { version = "0.1.5", features = ["client", "client-legacy", "http1", "http2", "tokio"], optional = true } ring = { version = "0.17", features = ["std"], optional = true } rustls-pki-types = "1.1.0" serde = { version = "1.0.104", features = ["derive"] } diff --git a/src/lib.rs b/src/lib.rs index 79b9ac0..37b4b84 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -3,17 +3,21 @@ #![warn(unreachable_pub)] #![warn(missing_docs)] +use std::error::Error as StdError; use std::fmt; use std::future::Future; use std::pin::Pin; use std::sync::Arc; +use async_trait::async_trait; use base64::prelude::{Engine, BASE64_URL_SAFE_NO_PAD}; +use bytes::Bytes; use http::header::{CONTENT_TYPE, LOCATION}; use http::{Method, Request, Response, StatusCode}; use http_body_util::{BodyExt, Full}; -use hyper::body::{Bytes, Incoming}; +#[cfg(feature = "hyper-rustls")] use hyper_util::client::legacy::connect::Connect; +#[cfg(feature = "hyper-rustls")] use hyper_util::client::legacy::Client as HyperClient; #[cfg(feature = "hyper-rustls")] use hyper_util::{client::legacy::connect::HttpConnector, rt::TokioExecutor}; @@ -310,7 +314,8 @@ impl Account { .await?; let account_url = rsp - .headers() + .parts + .headers .get(LOCATION) .and_then(|hv| hv.to_str().ok()) .map(|s| s.to_owned()); @@ -352,7 +357,8 @@ impl Account { let nonce = nonce_from_response(&rsp); let order_url = rsp - .headers() + .parts + .headers .get(LOCATION) .and_then(|hv| hv.to_str().ok()) .map(|s| s.to_owned()); @@ -441,7 +447,7 @@ impl AccountInner { payload: Option<&impl Serialize>, nonce: Option, url: &str, - ) -> Result, Error> { + ) -> Result { self.client.post(payload, nonce, self, url).await } } @@ -476,7 +482,7 @@ impl Client { .body(Full::default()) .expect("infallible error should not occur"); let rsp = http.request(req).await?; - let body = rsp.into_body().collect().await?.to_bytes(); + let body = rsp.body().await.map_err(Error::Other)?; Ok(Client { http, urls: serde_json::from_slice(&body)?, @@ -489,7 +495,7 @@ impl Client { nonce: Option, signer: &impl Signer, url: &str, - ) -> Result, Error> { + ) -> Result { let nonce = self.nonce(nonce).await?; let body = JoseJson::new(payload, signer.header(Some(&nonce), url), signer)?; let request = Request::builder() @@ -516,7 +522,7 @@ impl Client { // https://datatracker.ietf.org/doc/html/rfc8555#section-7.2 // "The server's response MUST include a Replay-Nonce header field containing a fresh // nonce and SHOULD have status code 200 (OK)." - if rsp.status() != StatusCode::OK { + if rsp.parts.status != StatusCode::OK { return Err("error response from newNonce resource".into()); } @@ -663,8 +669,9 @@ impl Signer for ExternalAccountKey { } } -fn nonce_from_response(rsp: &Response) -> Option { - rsp.headers() +fn nonce_from_response(rsp: &BytesResponse) -> Option { + rsp.parts + .headers .get(REPLAY_NONCE) .and_then(|hv| String::from_utf8(hv.as_ref().to_vec()).ok()) } @@ -694,34 +701,104 @@ impl HttpClient for DefaultClient { fn request( &self, req: Request>, - ) -> Pin, Error>> + Send>> { + ) -> Pin> + Send>> { let fut = self.0.request(req); - Box::pin(async move { fut.await.map_err(Error::from) }) + Box::pin(async move { + match fut.await { + Ok(rsp) => Ok(BytesResponse::from(rsp)), + Err(e) => Err(e.into()), + } + }) } } -/// A HTTP client based on [`hyper::Client`] +/// A HTTP client abstraction pub trait HttpClient: Send + Sync + 'static { /// Send the given request and return the response fn request( &self, req: Request>, - ) -> Pin, Error>> + Send>>; + ) -> Pin> + Send>>; } -impl HttpClient for HyperClient> -where - C: Connect + Clone + Send + Sync + 'static, -{ +#[cfg(feature = "hyper-rustls")] +impl HttpClient for HyperClient> { fn request( &self, req: Request>, - ) -> Pin, Error>> + Send>> { - let fut = >>::request(self, req); - Box::pin(async move { fut.await.map_err(Error::from) }) + ) -> Pin> + Send>> { + let fut = self.request(req); + Box::pin(async move { + match fut.await { + Ok(rsp) => Ok(BytesResponse::from(rsp)), + Err(e) => Err(e.into()), + } + }) } } +/// Response with object safe body type +pub struct BytesResponse { + /// Response status and header + pub parts: http::response::Parts, + /// Response body + pub body: Box, +} + +impl BytesResponse { + pub(crate) async fn body(mut self) -> Result> { + self.body.into_bytes().await + } +} + +impl From> for BytesResponse +where + B: http_body::Body + Send + Unpin + 'static, + B::Data: Send, + B::Error: Into>, +{ + fn from(rsp: Response) -> Self { + let (parts, body) = rsp.into_parts(); + Self { + parts, + body: Box::new(BodyWrapper { inner: Some(body) }), + } + } +} + +struct BodyWrapper { + inner: Option, +} + +#[async_trait] +impl BytesBody for BodyWrapper +where + B: http_body::Body + Send + Unpin + 'static, + B::Data: Send, + B::Error: Into>, +{ + async fn into_bytes(&mut self) -> Result> { + let Some(body) = self.inner.take() else { + return Ok(Bytes::new()); + }; + + match body.collect().await { + Ok(body) => Ok(body.to_bytes()), + Err(e) => Err(e.into()), + } + } +} + +/// Object safe body trait +#[async_trait] +pub trait BytesBody { + /// Convert the body into [`Bytes`] + /// + /// This consumes the body. The behavior for calling this method multiple times is undefined. + #[allow(clippy::wrong_self_convention)] // async_trait doesn't support taking `self` + async fn into_bytes(&mut self) -> Result>; +} + mod crypto { #[cfg(all(feature = "aws-lc-rs", not(feature = "ring")))] pub(crate) use aws_lc_rs as ring_like; diff --git a/src/types.rs b/src/types.rs index 2f7d16f..0f6ee7b 100644 --- a/src/types.rs +++ b/src/types.rs @@ -2,9 +2,6 @@ use std::fmt; use base64::prelude::{Engine, BASE64_URL_SAFE_NO_PAD}; use bytes::Bytes; -use http_body_util::BodyExt; -use hyper::body::Incoming; -use hyper::Response; use rustls_pki_types::CertificateDer; use serde::de::DeserializeOwned; use serde::ser::SerializeMap; @@ -12,6 +9,7 @@ use serde::{Deserialize, Serialize}; use thiserror::Error; use crate::crypto::{self, KeyPair}; +use crate::BytesResponse; /// Error type for instant-acme #[derive(Debug, Error)] @@ -34,6 +32,7 @@ pub enum Error { #[error("HTTP request failure: {0}")] Http(#[from] http::Error), /// Hyper request failure + #[cfg(feature = "hyper-rustls")] #[error("HTTP request failure: {0}")] Hyper(#[from] hyper::Error), /// Invalid ACME server URL @@ -56,6 +55,7 @@ impl From<&'static str> for Error { } } +#[cfg(feature = "hyper-rustls")] impl From for Error { fn from(value: hyper_util::client::legacy::Error) -> Self { Self::Other(Box::new(value)) @@ -134,13 +134,13 @@ pub struct Problem { } impl Problem { - pub(crate) async fn check(rsp: Response) -> Result { + pub(crate) async fn check(rsp: BytesResponse) -> Result { Ok(serde_json::from_slice(&Self::from_response(rsp).await?)?) } - pub(crate) async fn from_response(rsp: Response) -> Result { - let status = rsp.status(); - let body = rsp.into_body().collect().await?.to_bytes(); + pub(crate) async fn from_response(rsp: BytesResponse) -> Result { + let status = rsp.parts.status; + let body = rsp.body().await.map_err(Error::Other)?; match status.is_informational() || status.is_success() || status.is_redirection() { true => Ok(body), false => Err(serde_json::from_slice::(&body)?.into()),