Improve robustness against dropping futures
This commit is contained in:
parent
478e686f24
commit
ae4c9869ab
|
@ -109,7 +109,7 @@ impl<C: Connector> EppClient<C> {
|
|||
pub async fn hello(&mut self) -> Result<Greeting, Error> {
|
||||
let hello_xml = HelloDocument::default().serialize()?;
|
||||
|
||||
let response = self.connection.transact(&hello_xml).await?;
|
||||
let response = self.connection.transact(&hello_xml)?.await?;
|
||||
|
||||
Ok(GreetingDocument::deserialize(&response)?.data)
|
||||
}
|
||||
|
@ -127,7 +127,7 @@ impl<C: Connector> EppClient<C> {
|
|||
let epp_xml =
|
||||
<Cmd as Transaction<Ext>>::serialize_request(data.command, data.extension, id)?;
|
||||
|
||||
let response = self.connection.transact(&epp_xml).await?;
|
||||
let response = self.connection.transact(&epp_xml)?.await?;
|
||||
|
||||
match Cmd::deserialize_response(&response) {
|
||||
Ok(response) => Ok(response),
|
||||
|
@ -141,7 +141,7 @@ impl<C: Connector> EppClient<C> {
|
|||
/// Accepts raw EPP XML and returns the raw EPP XML response to it.
|
||||
/// Not recommended for direct use but sometimes can be useful for debugging
|
||||
pub async fn transact_xml(&mut self, xml: &str) -> Result<String, Error> {
|
||||
self.connection.transact(xml).await
|
||||
self.connection.transact(xml)?.await
|
||||
}
|
||||
|
||||
/// Returns the greeting received on establishment of the connection in raw xml form
|
||||
|
|
|
@ -2,11 +2,13 @@
|
|||
|
||||
use std::convert::TryInto;
|
||||
use std::future::Future;
|
||||
use std::pin::Pin;
|
||||
use std::task::{Context, Poll};
|
||||
use std::time::Duration;
|
||||
use std::{io, str, u32};
|
||||
use std::{io, mem, str, u32};
|
||||
|
||||
use async_trait::async_trait;
|
||||
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
|
||||
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadBuf};
|
||||
use tracing::{debug, info};
|
||||
|
||||
use crate::error::Error;
|
||||
|
@ -18,8 +20,16 @@ pub(crate) struct EppConnection<C: Connector> {
|
|||
stream: C::Connection,
|
||||
pub greeting: String,
|
||||
timeout: Duration,
|
||||
// Whether the connection is in a good state to start sending a request
|
||||
ready: bool,
|
||||
// A request that is currently in flight
|
||||
//
|
||||
// Because the code here currently depends on only one request being in flight at a time,
|
||||
// this needs to be finished (written, and response read) before we start another one.
|
||||
current: Option<RequestState>,
|
||||
// The next request to be sent
|
||||
//
|
||||
// If we get a request while another request is in flight (because its future was dropped),
|
||||
// we will store it here until the current request is finished.
|
||||
next: Option<RequestState>,
|
||||
}
|
||||
|
||||
impl<C: Connector> EppConnection<C> {
|
||||
|
@ -34,17 +44,259 @@ impl<C: Connector> EppConnection<C> {
|
|||
connector,
|
||||
greeting: String::new(),
|
||||
timeout,
|
||||
ready: false,
|
||||
current: None,
|
||||
next: None,
|
||||
};
|
||||
|
||||
this.greeting = this.get_epp_response().await?;
|
||||
this.ready = true;
|
||||
this.read_greeting().await?;
|
||||
Ok(this)
|
||||
}
|
||||
|
||||
/// Constructs an EPP XML request in the required form and sends it to the server
|
||||
async fn send_epp_request(&mut self, content: &str) -> Result<(), Error> {
|
||||
let len = content.len();
|
||||
async fn read_greeting(&mut self) -> Result<(), Error> {
|
||||
assert!(self.current.is_none());
|
||||
self.current = Some(RequestState::ReadLength {
|
||||
read: 0,
|
||||
buf: vec![0; 256],
|
||||
});
|
||||
|
||||
self.greeting = RequestFuture { conn: self }.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) async fn reconnect(&mut self) -> Result<(), Error> {
|
||||
debug!("{}: reconnecting", self.registry);
|
||||
let _ = self.current.take();
|
||||
let _ = self.next.take();
|
||||
self.stream = self.connector.connect(self.timeout).await?;
|
||||
self.read_greeting().await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Sends an EPP XML request to the registry and returns the response
|
||||
pub(crate) fn transact<'a>(&'a mut self, command: &str) -> Result<RequestFuture<'a, C>, Error> {
|
||||
debug!("{}: request: {}", self.registry, command);
|
||||
let new = RequestState::new(command)?;
|
||||
|
||||
// If we have a request currently in flight, finish that first
|
||||
// If another request was queued up behind the one in flight, just replace it
|
||||
match self.current.is_some() {
|
||||
true => {
|
||||
debug!(
|
||||
"{}: Queueing up request in order to finish in-flight request",
|
||||
self.registry
|
||||
);
|
||||
self.next = Some(new);
|
||||
}
|
||||
false => self.current = Some(new),
|
||||
}
|
||||
|
||||
Ok(RequestFuture { conn: self })
|
||||
}
|
||||
|
||||
/// Closes the socket and shuts down the connection
|
||||
pub(crate) async fn shutdown(&mut self) -> Result<(), Error> {
|
||||
info!("{}: Closing connection", self.registry);
|
||||
timeout(self.timeout, self.stream.shutdown()).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn handle(
|
||||
&mut self,
|
||||
mut state: RequestState,
|
||||
cx: &mut Context<'_>,
|
||||
) -> Result<Transition, Error> {
|
||||
match &mut state {
|
||||
RequestState::Writing { mut start, buf } => {
|
||||
let wrote = match Pin::new(&mut self.stream).poll_write(cx, &buf[start..]) {
|
||||
Poll::Ready(Ok(wrote)) => wrote,
|
||||
Poll::Ready(Err(err)) => return Err(err.into()),
|
||||
Poll::Pending => return Ok(Transition::Pending(state)),
|
||||
};
|
||||
|
||||
if wrote == 0 {
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::UnexpectedEof,
|
||||
format!("{}: Unexpected EOF while writing", self.registry),
|
||||
)
|
||||
.into());
|
||||
}
|
||||
|
||||
start += wrote;
|
||||
debug!(
|
||||
"{}: Wrote {} bytes, {} out of {} done",
|
||||
self.registry,
|
||||
wrote,
|
||||
start,
|
||||
buf.len()
|
||||
);
|
||||
|
||||
// Transition to reading the response's frame header once
|
||||
// we've written the entire request
|
||||
if start < buf.len() {
|
||||
return Ok(Transition::Next(state));
|
||||
}
|
||||
|
||||
Ok(Transition::Next(RequestState::ReadLength {
|
||||
read: 0,
|
||||
buf: vec![0; 256],
|
||||
}))
|
||||
}
|
||||
RequestState::ReadLength { mut read, buf } => {
|
||||
let mut read_buf = ReadBuf::new(&mut buf[read..]);
|
||||
match Pin::new(&mut self.stream).poll_read(cx, &mut read_buf) {
|
||||
Poll::Ready(Ok(())) => {}
|
||||
Poll::Ready(Err(err)) => return Err(err.into()),
|
||||
Poll::Pending => return Ok(Transition::Pending(state)),
|
||||
};
|
||||
|
||||
let filled = read_buf.filled();
|
||||
if filled.is_empty() {
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::UnexpectedEof,
|
||||
format!("{}: Unexpected EOF while reading length", self.registry),
|
||||
)
|
||||
.into());
|
||||
}
|
||||
|
||||
// We're looking for the frame header which tells us how long the response will be.
|
||||
// The frame header is a 32-bit (4-byte) big-endian unsigned integer. If we don't
|
||||
// have 4 bytes yet, stay in the `ReadLength` state, otherwise we transition to `Reading`.
|
||||
|
||||
read += filled.len();
|
||||
if read < 4 {
|
||||
return Ok(Transition::Next(state));
|
||||
}
|
||||
|
||||
let expected = u32::from_be_bytes(filled[..4].try_into()?) as usize;
|
||||
debug!("{}: Expected response length: {}", self.registry, expected);
|
||||
buf.resize(expected, 0);
|
||||
Ok(Transition::Next(RequestState::Reading {
|
||||
read,
|
||||
buf: mem::take(buf),
|
||||
expected,
|
||||
}))
|
||||
}
|
||||
RequestState::Reading {
|
||||
mut read,
|
||||
buf,
|
||||
expected,
|
||||
} => {
|
||||
let mut read_buf = ReadBuf::new(&mut buf[read..]);
|
||||
match Pin::new(&mut self.stream).poll_read(cx, &mut read_buf) {
|
||||
Poll::Ready(Ok(())) => {}
|
||||
Poll::Ready(Err(err)) => return Err(err.into()),
|
||||
Poll::Pending => return Ok(Transition::Pending(state)),
|
||||
}
|
||||
|
||||
let filled = read_buf.filled();
|
||||
if filled.is_empty() {
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::UnexpectedEof,
|
||||
format!("{}: Unexpected EOF while reading", self.registry),
|
||||
)
|
||||
.into());
|
||||
}
|
||||
|
||||
read += filled.len();
|
||||
debug!(
|
||||
"{}: Read {} bytes, {} out of {} done",
|
||||
self.registry,
|
||||
filled.len(),
|
||||
read,
|
||||
expected
|
||||
);
|
||||
|
||||
//
|
||||
|
||||
Ok(if read < *expected {
|
||||
// If we haven't received the entire response yet, stick to the `Reading` state.
|
||||
Transition::Next(state)
|
||||
} else if let Some(next) = self.next.take() {
|
||||
// Otherwise, if we were just pushing through this request because it was already
|
||||
// in flight when we started a new one, ignore this response and move to the
|
||||
// next request (the one this `RequestFuture` is actually for).
|
||||
Transition::Next(next)
|
||||
} else {
|
||||
// Otherwise, drain off the frame header and convert the rest to a `String`.
|
||||
buf.drain(..4);
|
||||
Transition::Done(String::from_utf8(mem::take(buf))?)
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) struct RequestFuture<'a, C: Connector> {
|
||||
conn: &'a mut EppConnection<C>,
|
||||
}
|
||||
|
||||
impl<'a, C: Connector> Future for RequestFuture<'a, C> {
|
||||
type Output = Result<String, Error>;
|
||||
|
||||
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
let this = self.get_mut();
|
||||
loop {
|
||||
let state = this.conn.current.take().unwrap();
|
||||
match this.conn.handle(state, cx) {
|
||||
Ok(Transition::Next(next)) => {
|
||||
this.conn.current = Some(next);
|
||||
continue;
|
||||
}
|
||||
Ok(Transition::Pending(state)) => {
|
||||
this.conn.current = Some(state);
|
||||
return Poll::Pending;
|
||||
}
|
||||
Ok(Transition::Done(rsp)) => return Poll::Ready(Ok(rsp)),
|
||||
Err(err) => {
|
||||
// Assume the error means the connection can no longer be used
|
||||
this.conn.next = None;
|
||||
return Poll::Ready(Err(err));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Transitions between `RequestState`s
|
||||
enum Transition {
|
||||
Pending(RequestState),
|
||||
Next(RequestState),
|
||||
Done(String),
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
enum RequestState {
|
||||
// Writing the request command out to the peer
|
||||
Writing {
|
||||
// The amount of bytes we've already written
|
||||
start: usize,
|
||||
// The full XML request
|
||||
buf: Vec<u8>,
|
||||
},
|
||||
// Reading the frame header (32-bit big-endian unsigned integer)
|
||||
ReadLength {
|
||||
// The amount of bytes we've already read
|
||||
read: usize,
|
||||
// The buffer we're using to read into
|
||||
buf: Vec<u8>,
|
||||
},
|
||||
// Reading the entire frame
|
||||
Reading {
|
||||
// The amount of bytes we've already read
|
||||
read: usize,
|
||||
// The buffer we're using to read into
|
||||
//
|
||||
// This will still have the frame header in it, needs to be cut off before
|
||||
// yielding the response to the caller.
|
||||
buf: Vec<u8>,
|
||||
// The expected length of the response according to the frame header
|
||||
expected: usize,
|
||||
},
|
||||
}
|
||||
|
||||
impl RequestState {
|
||||
fn new(command: &str) -> Result<Self, Error> {
|
||||
let len = command.len();
|
||||
|
||||
let buf_size = len + 4;
|
||||
let mut buf: Vec<u8> = vec![0u8; buf_size];
|
||||
|
@ -53,80 +305,8 @@ impl<C: Connector> EppConnection<C> {
|
|||
let len_u32: [u8; 4] = u32::to_be_bytes(len.try_into()?);
|
||||
|
||||
buf[..4].clone_from_slice(&len_u32);
|
||||
buf[4..].clone_from_slice(content.as_bytes());
|
||||
|
||||
let wrote = timeout(self.timeout, self.stream.write(&buf)).await?;
|
||||
debug!("{}: Wrote {} bytes", self.registry, wrote);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Receives response from the socket and converts it into an EPP XML string
|
||||
async fn get_epp_response(&mut self) -> Result<String, Error> {
|
||||
let mut buf = [0u8; 4];
|
||||
timeout(self.timeout, self.stream.read_exact(&mut buf)).await?;
|
||||
|
||||
let buf_size: usize = u32::from_be_bytes(buf).try_into()?;
|
||||
|
||||
let message_size = buf_size - 4;
|
||||
debug!("{}: Response buffer size: {}", self.registry, message_size);
|
||||
|
||||
let mut buf = vec![0; message_size];
|
||||
let mut read_size: usize = 0;
|
||||
|
||||
loop {
|
||||
let read = timeout(self.timeout, self.stream.read(&mut buf[read_size..])).await?;
|
||||
debug!("{}: Read: {} bytes", self.registry, read);
|
||||
|
||||
read_size += read;
|
||||
debug!("{}: Total read: {} bytes", self.registry, read_size);
|
||||
|
||||
if read == 0 {
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::UnexpectedEof,
|
||||
format!("{}: unexpected eof", self.registry),
|
||||
)
|
||||
.into());
|
||||
} else if read_size >= message_size {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
self.ready = true;
|
||||
Ok(String::from_utf8(buf)?)
|
||||
}
|
||||
|
||||
pub(crate) async fn reconnect(&mut self) -> Result<(), Error> {
|
||||
debug!("{}: reconnecting", self.registry);
|
||||
self.ready = false;
|
||||
self.stream = self.connector.connect(self.timeout).await?;
|
||||
self.greeting = self.get_epp_response().await?;
|
||||
self.ready = true;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Sends an EPP XML request to the registry and return the response
|
||||
/// receieved to the request
|
||||
pub(crate) async fn transact(&mut self, content: &str) -> Result<String, Error> {
|
||||
if !self.ready {
|
||||
debug!("{}: connection not ready", self.registry);
|
||||
self.reconnect().await?;
|
||||
}
|
||||
|
||||
debug!("{}: request: {}", self.registry, content);
|
||||
self.send_epp_request(content).await?;
|
||||
|
||||
let response = self.get_epp_response().await?;
|
||||
debug!("{}: response: {}", self.registry, response);
|
||||
|
||||
Ok(response)
|
||||
}
|
||||
|
||||
/// Closes the socket and shuts the connection
|
||||
pub(crate) async fn shutdown(&mut self) -> Result<(), Error> {
|
||||
info!("{}: Closing connection", self.registry);
|
||||
self.ready = false;
|
||||
timeout(self.timeout, self.stream.shutdown()).await?;
|
||||
Ok(())
|
||||
buf[4..].clone_from_slice(command.as_bytes());
|
||||
Ok(Self::Writing { start: 0, buf })
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
//! Error types to wrap internal errors and make EPP errors easier to read
|
||||
|
||||
use std::array::TryFromSliceError;
|
||||
use std::error::Error as StdError;
|
||||
use std::fmt::{self, Display};
|
||||
use std::io;
|
||||
|
@ -70,3 +71,9 @@ impl From<Utf8Error> for Error {
|
|||
Self::Other(e.into())
|
||||
}
|
||||
}
|
||||
|
||||
impl From<TryFromSliceError> for Error {
|
||||
fn from(e: TryFromSliceError) -> Self {
|
||||
Self::Other(e.into())
|
||||
}
|
||||
}
|
||||
|
|
|
@ -103,7 +103,7 @@ async fn client() {
|
|||
.unwrap();
|
||||
|
||||
assert_eq!(client.xml_greeting(), xml("response/greeting.xml"));
|
||||
client
|
||||
let rsp = client
|
||||
.transact(
|
||||
&Login::new(
|
||||
"username",
|
||||
|
@ -115,6 +115,8 @@ async fn client() {
|
|||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(rsp.result.code, ResultCode::CommandCompletedSuccessfully);
|
||||
|
||||
let rsp = client
|
||||
.transact(
|
||||
&DomainCheck {
|
||||
|
|
Loading…
Reference in New Issue