diff --git a/src/client.rs b/src/client.rs index 986a08c..99c64a3 100644 --- a/src/client.rs +++ b/src/client.rs @@ -109,7 +109,7 @@ impl EppClient { pub async fn hello(&mut self) -> Result { let hello_xml = HelloDocument::default().serialize()?; - let response = self.connection.transact(&hello_xml).await?; + let response = self.connection.transact(&hello_xml)?.await?; Ok(GreetingDocument::deserialize(&response)?.data) } @@ -127,7 +127,7 @@ impl EppClient { let epp_xml = >::serialize_request(data.command, data.extension, id)?; - let response = self.connection.transact(&epp_xml).await?; + let response = self.connection.transact(&epp_xml)?.await?; match Cmd::deserialize_response(&response) { Ok(response) => Ok(response), @@ -141,7 +141,7 @@ impl EppClient { /// Accepts raw EPP XML and returns the raw EPP XML response to it. /// Not recommended for direct use but sometimes can be useful for debugging pub async fn transact_xml(&mut self, xml: &str) -> Result { - self.connection.transact(xml).await + self.connection.transact(xml)?.await } /// Returns the greeting received on establishment of the connection in raw xml form diff --git a/src/connection.rs b/src/connection.rs index 139870a..ae0eb6a 100644 --- a/src/connection.rs +++ b/src/connection.rs @@ -2,11 +2,13 @@ use std::convert::TryInto; use std::future::Future; +use std::pin::Pin; +use std::task::{Context, Poll}; use std::time::Duration; -use std::{io, str, u32}; +use std::{io, mem, str, u32}; use async_trait::async_trait; -use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; +use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadBuf}; use tracing::{debug, info}; use crate::error::Error; @@ -18,8 +20,16 @@ pub(crate) struct EppConnection { stream: C::Connection, pub greeting: String, timeout: Duration, - // Whether the connection is in a good state to start sending a request - ready: bool, + // A request that is currently in flight + // + // Because the code here currently depends on only one request being in flight at a time, + // this needs to be finished (written, and response read) before we start another one. + current: Option, + // The next request to be sent + // + // If we get a request while another request is in flight (because its future was dropped), + // we will store it here until the current request is finished. + next: Option, } impl EppConnection { @@ -34,17 +44,259 @@ impl EppConnection { connector, greeting: String::new(), timeout, - ready: false, + current: None, + next: None, }; - this.greeting = this.get_epp_response().await?; - this.ready = true; + this.read_greeting().await?; Ok(this) } - /// Constructs an EPP XML request in the required form and sends it to the server - async fn send_epp_request(&mut self, content: &str) -> Result<(), Error> { - let len = content.len(); + async fn read_greeting(&mut self) -> Result<(), Error> { + assert!(self.current.is_none()); + self.current = Some(RequestState::ReadLength { + read: 0, + buf: vec![0; 256], + }); + + self.greeting = RequestFuture { conn: self }.await?; + Ok(()) + } + + pub(crate) async fn reconnect(&mut self) -> Result<(), Error> { + debug!("{}: reconnecting", self.registry); + let _ = self.current.take(); + let _ = self.next.take(); + self.stream = self.connector.connect(self.timeout).await?; + self.read_greeting().await?; + Ok(()) + } + + /// Sends an EPP XML request to the registry and returns the response + pub(crate) fn transact<'a>(&'a mut self, command: &str) -> Result, Error> { + debug!("{}: request: {}", self.registry, command); + let new = RequestState::new(command)?; + + // If we have a request currently in flight, finish that first + // If another request was queued up behind the one in flight, just replace it + match self.current.is_some() { + true => { + debug!( + "{}: Queueing up request in order to finish in-flight request", + self.registry + ); + self.next = Some(new); + } + false => self.current = Some(new), + } + + Ok(RequestFuture { conn: self }) + } + + /// Closes the socket and shuts down the connection + pub(crate) async fn shutdown(&mut self) -> Result<(), Error> { + info!("{}: Closing connection", self.registry); + timeout(self.timeout, self.stream.shutdown()).await?; + Ok(()) + } + + fn handle( + &mut self, + mut state: RequestState, + cx: &mut Context<'_>, + ) -> Result { + match &mut state { + RequestState::Writing { mut start, buf } => { + let wrote = match Pin::new(&mut self.stream).poll_write(cx, &buf[start..]) { + Poll::Ready(Ok(wrote)) => wrote, + Poll::Ready(Err(err)) => return Err(err.into()), + Poll::Pending => return Ok(Transition::Pending(state)), + }; + + if wrote == 0 { + return Err(io::Error::new( + io::ErrorKind::UnexpectedEof, + format!("{}: Unexpected EOF while writing", self.registry), + ) + .into()); + } + + start += wrote; + debug!( + "{}: Wrote {} bytes, {} out of {} done", + self.registry, + wrote, + start, + buf.len() + ); + + // Transition to reading the response's frame header once + // we've written the entire request + if start < buf.len() { + return Ok(Transition::Next(state)); + } + + Ok(Transition::Next(RequestState::ReadLength { + read: 0, + buf: vec![0; 256], + })) + } + RequestState::ReadLength { mut read, buf } => { + let mut read_buf = ReadBuf::new(&mut buf[read..]); + match Pin::new(&mut self.stream).poll_read(cx, &mut read_buf) { + Poll::Ready(Ok(())) => {} + Poll::Ready(Err(err)) => return Err(err.into()), + Poll::Pending => return Ok(Transition::Pending(state)), + }; + + let filled = read_buf.filled(); + if filled.is_empty() { + return Err(io::Error::new( + io::ErrorKind::UnexpectedEof, + format!("{}: Unexpected EOF while reading length", self.registry), + ) + .into()); + } + + // We're looking for the frame header which tells us how long the response will be. + // The frame header is a 32-bit (4-byte) big-endian unsigned integer. If we don't + // have 4 bytes yet, stay in the `ReadLength` state, otherwise we transition to `Reading`. + + read += filled.len(); + if read < 4 { + return Ok(Transition::Next(state)); + } + + let expected = u32::from_be_bytes(filled[..4].try_into()?) as usize; + debug!("{}: Expected response length: {}", self.registry, expected); + buf.resize(expected, 0); + Ok(Transition::Next(RequestState::Reading { + read, + buf: mem::take(buf), + expected, + })) + } + RequestState::Reading { + mut read, + buf, + expected, + } => { + let mut read_buf = ReadBuf::new(&mut buf[read..]); + match Pin::new(&mut self.stream).poll_read(cx, &mut read_buf) { + Poll::Ready(Ok(())) => {} + Poll::Ready(Err(err)) => return Err(err.into()), + Poll::Pending => return Ok(Transition::Pending(state)), + } + + let filled = read_buf.filled(); + if filled.is_empty() { + return Err(io::Error::new( + io::ErrorKind::UnexpectedEof, + format!("{}: Unexpected EOF while reading", self.registry), + ) + .into()); + } + + read += filled.len(); + debug!( + "{}: Read {} bytes, {} out of {} done", + self.registry, + filled.len(), + read, + expected + ); + + // + + Ok(if read < *expected { + // If we haven't received the entire response yet, stick to the `Reading` state. + Transition::Next(state) + } else if let Some(next) = self.next.take() { + // Otherwise, if we were just pushing through this request because it was already + // in flight when we started a new one, ignore this response and move to the + // next request (the one this `RequestFuture` is actually for). + Transition::Next(next) + } else { + // Otherwise, drain off the frame header and convert the rest to a `String`. + buf.drain(..4); + Transition::Done(String::from_utf8(mem::take(buf))?) + }) + } + } + } +} + +pub(crate) struct RequestFuture<'a, C: Connector> { + conn: &'a mut EppConnection, +} + +impl<'a, C: Connector> Future for RequestFuture<'a, C> { + type Output = Result; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.get_mut(); + loop { + let state = this.conn.current.take().unwrap(); + match this.conn.handle(state, cx) { + Ok(Transition::Next(next)) => { + this.conn.current = Some(next); + continue; + } + Ok(Transition::Pending(state)) => { + this.conn.current = Some(state); + return Poll::Pending; + } + Ok(Transition::Done(rsp)) => return Poll::Ready(Ok(rsp)), + Err(err) => { + // Assume the error means the connection can no longer be used + this.conn.next = None; + return Poll::Ready(Err(err)); + } + } + } + } +} + +// Transitions between `RequestState`s +enum Transition { + Pending(RequestState), + Next(RequestState), + Done(String), +} + +#[derive(Debug)] +enum RequestState { + // Writing the request command out to the peer + Writing { + // The amount of bytes we've already written + start: usize, + // The full XML request + buf: Vec, + }, + // Reading the frame header (32-bit big-endian unsigned integer) + ReadLength { + // The amount of bytes we've already read + read: usize, + // The buffer we're using to read into + buf: Vec, + }, + // Reading the entire frame + Reading { + // The amount of bytes we've already read + read: usize, + // The buffer we're using to read into + // + // This will still have the frame header in it, needs to be cut off before + // yielding the response to the caller. + buf: Vec, + // The expected length of the response according to the frame header + expected: usize, + }, +} + +impl RequestState { + fn new(command: &str) -> Result { + let len = command.len(); let buf_size = len + 4; let mut buf: Vec = vec![0u8; buf_size]; @@ -53,80 +305,8 @@ impl EppConnection { let len_u32: [u8; 4] = u32::to_be_bytes(len.try_into()?); buf[..4].clone_from_slice(&len_u32); - buf[4..].clone_from_slice(content.as_bytes()); - - let wrote = timeout(self.timeout, self.stream.write(&buf)).await?; - debug!("{}: Wrote {} bytes", self.registry, wrote); - Ok(()) - } - - /// Receives response from the socket and converts it into an EPP XML string - async fn get_epp_response(&mut self) -> Result { - let mut buf = [0u8; 4]; - timeout(self.timeout, self.stream.read_exact(&mut buf)).await?; - - let buf_size: usize = u32::from_be_bytes(buf).try_into()?; - - let message_size = buf_size - 4; - debug!("{}: Response buffer size: {}", self.registry, message_size); - - let mut buf = vec![0; message_size]; - let mut read_size: usize = 0; - - loop { - let read = timeout(self.timeout, self.stream.read(&mut buf[read_size..])).await?; - debug!("{}: Read: {} bytes", self.registry, read); - - read_size += read; - debug!("{}: Total read: {} bytes", self.registry, read_size); - - if read == 0 { - return Err(io::Error::new( - io::ErrorKind::UnexpectedEof, - format!("{}: unexpected eof", self.registry), - ) - .into()); - } else if read_size >= message_size { - break; - } - } - - self.ready = true; - Ok(String::from_utf8(buf)?) - } - - pub(crate) async fn reconnect(&mut self) -> Result<(), Error> { - debug!("{}: reconnecting", self.registry); - self.ready = false; - self.stream = self.connector.connect(self.timeout).await?; - self.greeting = self.get_epp_response().await?; - self.ready = true; - Ok(()) - } - - /// Sends an EPP XML request to the registry and return the response - /// receieved to the request - pub(crate) async fn transact(&mut self, content: &str) -> Result { - if !self.ready { - debug!("{}: connection not ready", self.registry); - self.reconnect().await?; - } - - debug!("{}: request: {}", self.registry, content); - self.send_epp_request(content).await?; - - let response = self.get_epp_response().await?; - debug!("{}: response: {}", self.registry, response); - - Ok(response) - } - - /// Closes the socket and shuts the connection - pub(crate) async fn shutdown(&mut self) -> Result<(), Error> { - info!("{}: Closing connection", self.registry); - self.ready = false; - timeout(self.timeout, self.stream.shutdown()).await?; - Ok(()) + buf[4..].clone_from_slice(command.as_bytes()); + Ok(Self::Writing { start: 0, buf }) } } diff --git a/src/error.rs b/src/error.rs index 30ddb3f..d170579 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,5 +1,6 @@ //! Error types to wrap internal errors and make EPP errors easier to read +use std::array::TryFromSliceError; use std::error::Error as StdError; use std::fmt::{self, Display}; use std::io; @@ -70,3 +71,9 @@ impl From for Error { Self::Other(e.into()) } } + +impl From for Error { + fn from(e: TryFromSliceError) -> Self { + Self::Other(e.into()) + } +} diff --git a/tests/basic.rs b/tests/basic.rs index 58d56fb..9a063c1 100644 --- a/tests/basic.rs +++ b/tests/basic.rs @@ -103,7 +103,7 @@ async fn client() { .unwrap(); assert_eq!(client.xml_greeting(), xml("response/greeting.xml")); - client + let rsp = client .transact( &Login::new( "username", @@ -115,6 +115,8 @@ async fn client() { .await .unwrap(); + assert_eq!(rsp.result.code, ResultCode::CommandCompletedSuccessfully); + let rsp = client .transact( &DomainCheck {