Improve robustness against dropping futures

This commit is contained in:
Dirkjan Ochtman 2022-03-09 17:49:35 +01:00 committed by masalachai
parent 478e686f24
commit ae4c9869ab
4 changed files with 277 additions and 88 deletions

View File

@ -109,7 +109,7 @@ impl<C: Connector> EppClient<C> {
pub async fn hello(&mut self) -> Result<Greeting, Error> {
let hello_xml = HelloDocument::default().serialize()?;
let response = self.connection.transact(&hello_xml).await?;
let response = self.connection.transact(&hello_xml)?.await?;
Ok(GreetingDocument::deserialize(&response)?.data)
}
@ -127,7 +127,7 @@ impl<C: Connector> EppClient<C> {
let epp_xml =
<Cmd as Transaction<Ext>>::serialize_request(data.command, data.extension, id)?;
let response = self.connection.transact(&epp_xml).await?;
let response = self.connection.transact(&epp_xml)?.await?;
match Cmd::deserialize_response(&response) {
Ok(response) => Ok(response),
@ -141,7 +141,7 @@ impl<C: Connector> EppClient<C> {
/// Accepts raw EPP XML and returns the raw EPP XML response to it.
/// Not recommended for direct use but sometimes can be useful for debugging
pub async fn transact_xml(&mut self, xml: &str) -> Result<String, Error> {
self.connection.transact(xml).await
self.connection.transact(xml)?.await
}
/// Returns the greeting received on establishment of the connection in raw xml form

View File

@ -2,11 +2,13 @@
use std::convert::TryInto;
use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::time::Duration;
use std::{io, str, u32};
use std::{io, mem, str, u32};
use async_trait::async_trait;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadBuf};
use tracing::{debug, info};
use crate::error::Error;
@ -18,8 +20,16 @@ pub(crate) struct EppConnection<C: Connector> {
stream: C::Connection,
pub greeting: String,
timeout: Duration,
// Whether the connection is in a good state to start sending a request
ready: bool,
// A request that is currently in flight
//
// Because the code here currently depends on only one request being in flight at a time,
// this needs to be finished (written, and response read) before we start another one.
current: Option<RequestState>,
// The next request to be sent
//
// If we get a request while another request is in flight (because its future was dropped),
// we will store it here until the current request is finished.
next: Option<RequestState>,
}
impl<C: Connector> EppConnection<C> {
@ -34,17 +44,259 @@ impl<C: Connector> EppConnection<C> {
connector,
greeting: String::new(),
timeout,
ready: false,
current: None,
next: None,
};
this.greeting = this.get_epp_response().await?;
this.ready = true;
this.read_greeting().await?;
Ok(this)
}
/// Constructs an EPP XML request in the required form and sends it to the server
async fn send_epp_request(&mut self, content: &str) -> Result<(), Error> {
let len = content.len();
async fn read_greeting(&mut self) -> Result<(), Error> {
assert!(self.current.is_none());
self.current = Some(RequestState::ReadLength {
read: 0,
buf: vec![0; 256],
});
self.greeting = RequestFuture { conn: self }.await?;
Ok(())
}
pub(crate) async fn reconnect(&mut self) -> Result<(), Error> {
debug!("{}: reconnecting", self.registry);
let _ = self.current.take();
let _ = self.next.take();
self.stream = self.connector.connect(self.timeout).await?;
self.read_greeting().await?;
Ok(())
}
/// Sends an EPP XML request to the registry and returns the response
pub(crate) fn transact<'a>(&'a mut self, command: &str) -> Result<RequestFuture<'a, C>, Error> {
debug!("{}: request: {}", self.registry, command);
let new = RequestState::new(command)?;
// If we have a request currently in flight, finish that first
// If another request was queued up behind the one in flight, just replace it
match self.current.is_some() {
true => {
debug!(
"{}: Queueing up request in order to finish in-flight request",
self.registry
);
self.next = Some(new);
}
false => self.current = Some(new),
}
Ok(RequestFuture { conn: self })
}
/// Closes the socket and shuts down the connection
pub(crate) async fn shutdown(&mut self) -> Result<(), Error> {
info!("{}: Closing connection", self.registry);
timeout(self.timeout, self.stream.shutdown()).await?;
Ok(())
}
fn handle(
&mut self,
mut state: RequestState,
cx: &mut Context<'_>,
) -> Result<Transition, Error> {
match &mut state {
RequestState::Writing { mut start, buf } => {
let wrote = match Pin::new(&mut self.stream).poll_write(cx, &buf[start..]) {
Poll::Ready(Ok(wrote)) => wrote,
Poll::Ready(Err(err)) => return Err(err.into()),
Poll::Pending => return Ok(Transition::Pending(state)),
};
if wrote == 0 {
return Err(io::Error::new(
io::ErrorKind::UnexpectedEof,
format!("{}: Unexpected EOF while writing", self.registry),
)
.into());
}
start += wrote;
debug!(
"{}: Wrote {} bytes, {} out of {} done",
self.registry,
wrote,
start,
buf.len()
);
// Transition to reading the response's frame header once
// we've written the entire request
if start < buf.len() {
return Ok(Transition::Next(state));
}
Ok(Transition::Next(RequestState::ReadLength {
read: 0,
buf: vec![0; 256],
}))
}
RequestState::ReadLength { mut read, buf } => {
let mut read_buf = ReadBuf::new(&mut buf[read..]);
match Pin::new(&mut self.stream).poll_read(cx, &mut read_buf) {
Poll::Ready(Ok(())) => {}
Poll::Ready(Err(err)) => return Err(err.into()),
Poll::Pending => return Ok(Transition::Pending(state)),
};
let filled = read_buf.filled();
if filled.is_empty() {
return Err(io::Error::new(
io::ErrorKind::UnexpectedEof,
format!("{}: Unexpected EOF while reading length", self.registry),
)
.into());
}
// We're looking for the frame header which tells us how long the response will be.
// The frame header is a 32-bit (4-byte) big-endian unsigned integer. If we don't
// have 4 bytes yet, stay in the `ReadLength` state, otherwise we transition to `Reading`.
read += filled.len();
if read < 4 {
return Ok(Transition::Next(state));
}
let expected = u32::from_be_bytes(filled[..4].try_into()?) as usize;
debug!("{}: Expected response length: {}", self.registry, expected);
buf.resize(expected, 0);
Ok(Transition::Next(RequestState::Reading {
read,
buf: mem::take(buf),
expected,
}))
}
RequestState::Reading {
mut read,
buf,
expected,
} => {
let mut read_buf = ReadBuf::new(&mut buf[read..]);
match Pin::new(&mut self.stream).poll_read(cx, &mut read_buf) {
Poll::Ready(Ok(())) => {}
Poll::Ready(Err(err)) => return Err(err.into()),
Poll::Pending => return Ok(Transition::Pending(state)),
}
let filled = read_buf.filled();
if filled.is_empty() {
return Err(io::Error::new(
io::ErrorKind::UnexpectedEof,
format!("{}: Unexpected EOF while reading", self.registry),
)
.into());
}
read += filled.len();
debug!(
"{}: Read {} bytes, {} out of {} done",
self.registry,
filled.len(),
read,
expected
);
//
Ok(if read < *expected {
// If we haven't received the entire response yet, stick to the `Reading` state.
Transition::Next(state)
} else if let Some(next) = self.next.take() {
// Otherwise, if we were just pushing through this request because it was already
// in flight when we started a new one, ignore this response and move to the
// next request (the one this `RequestFuture` is actually for).
Transition::Next(next)
} else {
// Otherwise, drain off the frame header and convert the rest to a `String`.
buf.drain(..4);
Transition::Done(String::from_utf8(mem::take(buf))?)
})
}
}
}
}
pub(crate) struct RequestFuture<'a, C: Connector> {
conn: &'a mut EppConnection<C>,
}
impl<'a, C: Connector> Future for RequestFuture<'a, C> {
type Output = Result<String, Error>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.get_mut();
loop {
let state = this.conn.current.take().unwrap();
match this.conn.handle(state, cx) {
Ok(Transition::Next(next)) => {
this.conn.current = Some(next);
continue;
}
Ok(Transition::Pending(state)) => {
this.conn.current = Some(state);
return Poll::Pending;
}
Ok(Transition::Done(rsp)) => return Poll::Ready(Ok(rsp)),
Err(err) => {
// Assume the error means the connection can no longer be used
this.conn.next = None;
return Poll::Ready(Err(err));
}
}
}
}
}
// Transitions between `RequestState`s
enum Transition {
Pending(RequestState),
Next(RequestState),
Done(String),
}
#[derive(Debug)]
enum RequestState {
// Writing the request command out to the peer
Writing {
// The amount of bytes we've already written
start: usize,
// The full XML request
buf: Vec<u8>,
},
// Reading the frame header (32-bit big-endian unsigned integer)
ReadLength {
// The amount of bytes we've already read
read: usize,
// The buffer we're using to read into
buf: Vec<u8>,
},
// Reading the entire frame
Reading {
// The amount of bytes we've already read
read: usize,
// The buffer we're using to read into
//
// This will still have the frame header in it, needs to be cut off before
// yielding the response to the caller.
buf: Vec<u8>,
// The expected length of the response according to the frame header
expected: usize,
},
}
impl RequestState {
fn new(command: &str) -> Result<Self, Error> {
let len = command.len();
let buf_size = len + 4;
let mut buf: Vec<u8> = vec![0u8; buf_size];
@ -53,80 +305,8 @@ impl<C: Connector> EppConnection<C> {
let len_u32: [u8; 4] = u32::to_be_bytes(len.try_into()?);
buf[..4].clone_from_slice(&len_u32);
buf[4..].clone_from_slice(content.as_bytes());
let wrote = timeout(self.timeout, self.stream.write(&buf)).await?;
debug!("{}: Wrote {} bytes", self.registry, wrote);
Ok(())
}
/// Receives response from the socket and converts it into an EPP XML string
async fn get_epp_response(&mut self) -> Result<String, Error> {
let mut buf = [0u8; 4];
timeout(self.timeout, self.stream.read_exact(&mut buf)).await?;
let buf_size: usize = u32::from_be_bytes(buf).try_into()?;
let message_size = buf_size - 4;
debug!("{}: Response buffer size: {}", self.registry, message_size);
let mut buf = vec![0; message_size];
let mut read_size: usize = 0;
loop {
let read = timeout(self.timeout, self.stream.read(&mut buf[read_size..])).await?;
debug!("{}: Read: {} bytes", self.registry, read);
read_size += read;
debug!("{}: Total read: {} bytes", self.registry, read_size);
if read == 0 {
return Err(io::Error::new(
io::ErrorKind::UnexpectedEof,
format!("{}: unexpected eof", self.registry),
)
.into());
} else if read_size >= message_size {
break;
}
}
self.ready = true;
Ok(String::from_utf8(buf)?)
}
pub(crate) async fn reconnect(&mut self) -> Result<(), Error> {
debug!("{}: reconnecting", self.registry);
self.ready = false;
self.stream = self.connector.connect(self.timeout).await?;
self.greeting = self.get_epp_response().await?;
self.ready = true;
Ok(())
}
/// Sends an EPP XML request to the registry and return the response
/// receieved to the request
pub(crate) async fn transact(&mut self, content: &str) -> Result<String, Error> {
if !self.ready {
debug!("{}: connection not ready", self.registry);
self.reconnect().await?;
}
debug!("{}: request: {}", self.registry, content);
self.send_epp_request(content).await?;
let response = self.get_epp_response().await?;
debug!("{}: response: {}", self.registry, response);
Ok(response)
}
/// Closes the socket and shuts the connection
pub(crate) async fn shutdown(&mut self) -> Result<(), Error> {
info!("{}: Closing connection", self.registry);
self.ready = false;
timeout(self.timeout, self.stream.shutdown()).await?;
Ok(())
buf[4..].clone_from_slice(command.as_bytes());
Ok(Self::Writing { start: 0, buf })
}
}

View File

@ -1,5 +1,6 @@
//! Error types to wrap internal errors and make EPP errors easier to read
use std::array::TryFromSliceError;
use std::error::Error as StdError;
use std::fmt::{self, Display};
use std::io;
@ -70,3 +71,9 @@ impl From<Utf8Error> for Error {
Self::Other(e.into())
}
}
impl From<TryFromSliceError> for Error {
fn from(e: TryFromSliceError) -> Self {
Self::Other(e.into())
}
}

View File

@ -103,7 +103,7 @@ async fn client() {
.unwrap();
assert_eq!(client.xml_greeting(), xml("response/greeting.xml"));
client
let rsp = client
.transact(
&Login::new(
"username",
@ -115,6 +115,8 @@ async fn client() {
.await
.unwrap();
assert_eq!(rsp.result.code, ResultCode::CommandCompletedSuccessfully);
let rsp = client
.transact(
&DomainCheck {