From 001d3d91ea4a4c6f0568003a16da94889f484ab9 Mon Sep 17 00:00:00 2001 From: Rudi Floren Date: Mon, 16 Jan 2023 14:13:43 +0100 Subject: [PATCH] Revert "Improve robustness against dropping futures" This reverts commit ae4c9869abaf2ffad1d6e440a08061f09bddcf3f. --- src/client.rs | 6 +- src/connection.rs | 347 +++++++++++----------------------------------- src/error.rs | 7 - tests/basic.rs | 4 +- 4 files changed, 88 insertions(+), 276 deletions(-) diff --git a/src/client.rs b/src/client.rs index 676aaca..dae1f25 100644 --- a/src/client.rs +++ b/src/client.rs @@ -106,7 +106,7 @@ impl EppClient { let xml = xml::serialize(&HelloDocument::default())?; debug!("{}: hello: {}", self.connection.registry, &xml); - let response = self.connection.transact(&xml)?.await?; + let response = self.connection.transact(&xml).await?; debug!("{}: greeting: {}", self.connection.registry, &response); Ok(xml::deserialize::(&response)?.data) @@ -126,7 +126,7 @@ impl EppClient { let xml = xml::serialize(&document)?; debug!("{}: request: {}", self.connection.registry, &xml); - let response = self.connection.transact(&xml)?.await?; + let response = self.connection.transact(&xml).await?; debug!("{}: response: {}", self.connection.registry, &response); let rsp = xml::deserialize::>(&response)?; @@ -146,7 +146,7 @@ impl EppClient { /// Accepts raw EPP XML and returns the raw EPP XML response to it. /// Not recommended for direct use but sometimes can be useful for debugging pub async fn transact_xml(&mut self, xml: &str) -> Result { - self.connection.transact(xml)?.await + self.connection.transact(xml).await } /// Returns the greeting received on establishment of the connection in raw xml form diff --git a/src/connection.rs b/src/connection.rs index 4ceaa57..eab0769 100644 --- a/src/connection.rs +++ b/src/connection.rs @@ -1,13 +1,11 @@ //! Manages registry connections and reading/writing to them use std::future::Future; -use std::pin::Pin; -use std::task::{Context, Poll}; use std::time::Duration; -use std::{io, mem, str, u32}; +use std::{io, str, u32}; use async_trait::async_trait; -use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadBuf}; +use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; use tracing::{debug, info}; use crate::error::Error; @@ -19,16 +17,8 @@ pub(crate) struct EppConnection { stream: C::Connection, pub greeting: String, timeout: Duration, - // A request that is currently in flight - // - // Because the code here currently depends on only one request being in flight at a time, - // this needs to be finished (written, and response read) before we start another one. - current: Option, - // The next request to be sent - // - // If we get a request while another request is in flight (because its future was dropped), - // we will store it here until the current request is finished. - next: Option, + // Whether the connection is in a good state to start sending a request + ready: bool, } impl EppConnection { @@ -43,258 +33,17 @@ impl EppConnection { connector, greeting: String::new(), timeout, - current: None, - next: None, + ready: false, }; - this.read_greeting().await?; + this.greeting = this.get_epp_response().await?; + this.ready = true; Ok(this) } - async fn read_greeting(&mut self) -> Result<(), Error> { - assert!(self.current.is_none()); - self.current = Some(RequestState::ReadLength { - read: 0, - buf: vec![0; 256], - }); - - self.greeting = RequestFuture { conn: self }.await?; - Ok(()) - } - - pub(crate) async fn reconnect(&mut self) -> Result<(), Error> { - debug!("{}: reconnecting", self.registry); - let _ = self.current.take(); - let _ = self.next.take(); - self.stream = self.connector.connect(self.timeout).await?; - self.read_greeting().await?; - Ok(()) - } - - /// Sends an EPP XML request to the registry and returns the response - pub(crate) fn transact<'a>(&'a mut self, command: &str) -> Result, Error> { - let new = RequestState::new(command)?; - - // If we have a request currently in flight, finish that first - // If another request was queued up behind the one in flight, just replace it - match self.current.is_some() { - true => { - debug!( - "{}: Queueing up request in order to finish in-flight request", - self.registry - ); - self.next = Some(new); - } - false => self.current = Some(new), - } - - Ok(RequestFuture { conn: self }) - } - - /// Closes the socket and shuts down the connection - pub(crate) async fn shutdown(&mut self) -> Result<(), Error> { - info!("{}: Closing connection", self.registry); - timeout(self.timeout, self.stream.shutdown()).await?; - Ok(()) - } - - fn handle( - &mut self, - mut state: RequestState, - cx: &mut Context<'_>, - ) -> Result { - match &mut state { - RequestState::Writing { mut start, buf } => { - let wrote = match Pin::new(&mut self.stream).poll_write(cx, &buf[start..]) { - Poll::Ready(Ok(wrote)) => wrote, - Poll::Ready(Err(err)) => return Err(err.into()), - Poll::Pending => return Ok(Transition::Pending(state)), - }; - - if wrote == 0 { - return Err(io::Error::new( - io::ErrorKind::UnexpectedEof, - format!("{}: Unexpected EOF while writing", self.registry), - ) - .into()); - } - - start += wrote; - debug!( - "{}: Wrote {} bytes, {} out of {} done", - self.registry, - wrote, - start, - buf.len() - ); - - // Transition to reading the response's frame header once - // we've written the entire request - if start < buf.len() { - return Ok(Transition::Next(state)); - } - - Ok(Transition::Next(RequestState::ReadLength { - read: 0, - buf: vec![0; 256], - })) - } - RequestState::ReadLength { mut read, buf } => { - let mut read_buf = ReadBuf::new(&mut buf[read..]); - match Pin::new(&mut self.stream).poll_read(cx, &mut read_buf) { - Poll::Ready(Ok(())) => {} - Poll::Ready(Err(err)) => return Err(err.into()), - Poll::Pending => return Ok(Transition::Pending(state)), - }; - - let filled = read_buf.filled(); - if filled.is_empty() { - return Err(io::Error::new( - io::ErrorKind::UnexpectedEof, - format!("{}: Unexpected EOF while reading length", self.registry), - ) - .into()); - } - - // We're looking for the frame header which tells us how long the response will be. - // The frame header is a 32-bit (4-byte) big-endian unsigned integer. If we don't - // have 4 bytes yet, stay in the `ReadLength` state, otherwise we transition to `Reading`. - - read += filled.len(); - if read < 4 { - return Ok(Transition::Next(state)); - } - - let expected = u32::from_be_bytes(filled[..4].try_into()?) as usize; - debug!("{}: Expected response length: {}", self.registry, expected); - buf.resize(expected, 0); - Ok(Transition::Next(RequestState::Reading { - read, - buf: mem::take(buf), - expected, - })) - } - RequestState::Reading { - mut read, - buf, - expected, - } => { - let mut read_buf = ReadBuf::new(&mut buf[read..]); - match Pin::new(&mut self.stream).poll_read(cx, &mut read_buf) { - Poll::Ready(Ok(())) => {} - Poll::Ready(Err(err)) => return Err(err.into()), - Poll::Pending => return Ok(Transition::Pending(state)), - } - - let filled = read_buf.filled(); - if filled.is_empty() { - return Err(io::Error::new( - io::ErrorKind::UnexpectedEof, - format!("{}: Unexpected EOF while reading", self.registry), - ) - .into()); - } - - read += filled.len(); - debug!( - "{}: Read {} bytes, {} out of {} done", - self.registry, - filled.len(), - read, - expected - ); - - // - - Ok(if read < *expected { - // If we haven't received the entire response yet, stick to the `Reading` state. - Transition::Next(state) - } else if let Some(next) = self.next.take() { - // Otherwise, if we were just pushing through this request because it was already - // in flight when we started a new one, ignore this response and move to the - // next request (the one this `RequestFuture` is actually for). - Transition::Next(next) - } else { - // Otherwise, drain off the frame header and convert the rest to a `String`. - buf.drain(..4); - Transition::Done(String::from_utf8(mem::take(buf))?) - }) - } - } - } -} - -pub(crate) struct RequestFuture<'a, C: Connector> { - conn: &'a mut EppConnection, -} - -impl<'a, C: Connector> Future for RequestFuture<'a, C> { - type Output = Result; - - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let this = self.get_mut(); - loop { - let state = this.conn.current.take().unwrap(); - match this.conn.handle(state, cx) { - Ok(Transition::Next(next)) => { - this.conn.current = Some(next); - continue; - } - Ok(Transition::Pending(state)) => { - this.conn.current = Some(state); - return Poll::Pending; - } - Ok(Transition::Done(rsp)) => return Poll::Ready(Ok(rsp)), - Err(err) => { - // Assume the error means the connection can no longer be used - this.conn.next = None; - return Poll::Ready(Err(err)); - } - } - } - } -} - -// Transitions between `RequestState`s -enum Transition { - Pending(RequestState), - Next(RequestState), - Done(String), -} - -#[derive(Debug)] -enum RequestState { - // Writing the request command out to the peer - Writing { - // The amount of bytes we've already written - start: usize, - // The full XML request - buf: Vec, - }, - // Reading the frame header (32-bit big-endian unsigned integer) - ReadLength { - // The amount of bytes we've already read - read: usize, - // The buffer we're using to read into - buf: Vec, - }, - // Reading the entire frame - Reading { - // The amount of bytes we've already read - read: usize, - // The buffer we're using to read into - // - // This will still have the frame header in it, needs to be cut off before - // yielding the response to the caller. - buf: Vec, - // The expected length of the response according to the frame header - expected: usize, - }, -} - -impl RequestState { - fn new(command: &str) -> Result { - let len = command.len(); + /// Constructs an EPP XML request in the required form and sends it to the server + async fn send_epp_request(&mut self, content: &str) -> Result<(), Error> { + let len = content.len(); let buf_size = len + 4; let mut buf: Vec = vec![0u8; buf_size]; @@ -303,8 +52,80 @@ impl RequestState { let len_u32: [u8; 4] = u32::to_be_bytes(len.try_into()?); buf[..4].clone_from_slice(&len_u32); - buf[4..].clone_from_slice(command.as_bytes()); - Ok(Self::Writing { start: 0, buf }) + buf[4..].clone_from_slice(content.as_bytes()); + + let wrote = timeout(self.timeout, self.stream.write(&buf)).await?; + debug!("{}: Wrote {} bytes", self.registry, wrote); + Ok(()) + } + + /// Receives response from the socket and converts it into an EPP XML string + async fn get_epp_response(&mut self) -> Result { + let mut buf = [0u8; 4]; + timeout(self.timeout, self.stream.read_exact(&mut buf)).await?; + + let buf_size: usize = u32::from_be_bytes(buf).try_into()?; + + let message_size = buf_size - 4; + debug!("{}: Response buffer size: {}", self.registry, message_size); + + let mut buf = vec![0; message_size]; + let mut read_size: usize = 0; + + loop { + let read = timeout(self.timeout, self.stream.read(&mut buf[read_size..])).await?; + debug!("{}: Read: {} bytes", self.registry, read); + + read_size += read; + debug!("{}: Total read: {} bytes", self.registry, read_size); + + if read == 0 { + return Err(io::Error::new( + io::ErrorKind::UnexpectedEof, + format!("{}: unexpected eof", self.registry), + ) + .into()); + } else if read_size >= message_size { + break; + } + } + + self.ready = true; + Ok(String::from_utf8(buf)?) + } + + pub(crate) async fn reconnect(&mut self) -> Result<(), Error> { + debug!("{}: reconnecting", self.registry); + self.ready = false; + self.stream = self.connector.connect(self.timeout).await?; + self.greeting = self.get_epp_response().await?; + self.ready = true; + Ok(()) + } + + /// Sends an EPP XML request to the registry and return the response + /// receieved to the request + pub(crate) async fn transact(&mut self, content: &str) -> Result { + if !self.ready { + debug!("{}: connection not ready", self.registry); + self.reconnect().await?; + } + + debug!("{}: request: {}", self.registry, content); + self.send_epp_request(content).await?; + + let response = self.get_epp_response().await?; + debug!("{}: response: {}", self.registry, response); + + Ok(response) + } + + /// Closes the socket and shuts the connection + pub(crate) async fn shutdown(&mut self) -> Result<(), Error> { + info!("{}: Closing connection", self.registry); + self.ready = false; + timeout(self.timeout, self.stream.shutdown()).await?; + Ok(()) } } diff --git a/src/error.rs b/src/error.rs index 5a2e3f0..806e8c9 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,6 +1,5 @@ //! Error types to wrap internal errors and make EPP errors easier to read -use std::array::TryFromSliceError; use std::error::Error as StdError; use std::fmt::{self, Display}; use std::io; @@ -71,9 +70,3 @@ impl From for Error { Self::Other(e.into()) } } - -impl From for Error { - fn from(e: TryFromSliceError) -> Self { - Self::Other(e.into()) - } -} diff --git a/tests/basic.rs b/tests/basic.rs index 9bc1bd7..bc7a546 100644 --- a/tests/basic.rs +++ b/tests/basic.rs @@ -104,7 +104,7 @@ async fn client() { .unwrap(); assert_eq!(client.xml_greeting(), xml("response/greeting.xml")); - let rsp = client + client .transact( &Login::new( "username", @@ -117,8 +117,6 @@ async fn client() { .await .unwrap(); - assert_eq!(rsp.result.code, ResultCode::CommandCompletedSuccessfully); - let rsp = client .transact( &DomainCheck {