Pin I/O handlers. Allow 'FnOnce' in 'ws' handlers.

This modifies the 'IoHandler::io()' method so that it takes a
'Pin<Box<Self>>', allowing handlers to move internally and assume that
the data is pinned.

The change is then used in the 'ws' contrib crate to allow 'FnOnce'
handlers instead of 'FnMut'. The net effect is that streams, such as
those crated by 'Stream!', are now allowed to move internally.
This commit is contained in:
Sergio Benitez 2023-04-04 15:11:30 -07:00
parent 5e7a75e1a5
commit c3520fb4a1
4 changed files with 35 additions and 22 deletions

View File

@ -1,4 +1,5 @@
use std::io; use std::io;
use std::pin::Pin;
use rocket::data::{IoHandler, IoStream}; use rocket::data::{IoHandler, IoStream};
use rocket::futures::{self, StreamExt, SinkExt, future::BoxFuture, stream::SplitStream}; use rocket::futures::{self, StreamExt, SinkExt, future::BoxFuture, stream::SplitStream};
@ -68,7 +69,7 @@ impl WebSocket {
/// Create a read/write channel to the client and call `handler` with it. /// Create a read/write channel to the client and call `handler` with it.
/// ///
/// This method takes a `FnMut`, `handler`, that consumes a read/write /// This method takes a `FnOnce`, `handler`, that consumes a read/write
/// WebSocket channel, [`DuplexStream`] to the client. See [`DuplexStream`] /// WebSocket channel, [`DuplexStream`] to the client. See [`DuplexStream`]
/// for details on how to make use of the channel. /// for details on how to make use of the channel.
/// ///
@ -113,14 +114,14 @@ impl WebSocket {
/// } /// }
/// ``` /// ```
pub fn channel<'r, F: Send + 'r>(self, handler: F) -> Channel<'r> pub fn channel<'r, F: Send + 'r>(self, handler: F) -> Channel<'r>
where F: FnMut(DuplexStream) -> BoxFuture<'r, Result<()>> + 'r where F: FnOnce(DuplexStream) -> BoxFuture<'r, Result<()>> + 'r
{ {
Channel { ws: self, handler: Box::new(handler), } Channel { ws: self, handler: Box::new(handler), }
} }
/// Create a stream that consumes client [`Message`]s and emits its own. /// Create a stream that consumes client [`Message`]s and emits its own.
/// ///
/// This method takes a `FnMut` `stream` that consumes a read-only stream /// This method takes a `FnOnce` `stream` that consumes a read-only stream
/// and returns a stream of [`Message`]s. While the returned stream can be /// and returns a stream of [`Message`]s. While the returned stream can be
/// constructed in any manner, the [`Stream!`] macro is the preferred /// constructed in any manner, the [`Stream!`] macro is the preferred
/// method. In any case, the stream must be `Send`. /// method. In any case, the stream must be `Send`.
@ -153,7 +154,7 @@ impl WebSocket {
/// } /// }
/// ``` /// ```
pub fn stream<'r, F, S>(self, stream: F) -> MessageStream<'r, S> pub fn stream<'r, F, S>(self, stream: F) -> MessageStream<'r, S>
where F: FnMut(SplitStream<DuplexStream>) -> S + Send + 'r, where F: FnOnce(SplitStream<DuplexStream>) -> S + Send + 'r,
S: futures::Stream<Item = Result<Message>> + Send + 'r S: futures::Stream<Item = Result<Message>> + Send + 'r
{ {
MessageStream { ws: self, handler: Box::new(stream), } MessageStream { ws: self, handler: Box::new(stream), }
@ -165,7 +166,7 @@ impl WebSocket {
/// `Channel` has no methods or functionality beyond its trait implementations. /// `Channel` has no methods or functionality beyond its trait implementations.
pub struct Channel<'r> { pub struct Channel<'r> {
ws: WebSocket, ws: WebSocket,
handler: Box<dyn FnMut(DuplexStream) -> BoxFuture<'r, Result<()>> + Send + 'r>, handler: Box<dyn FnOnce(DuplexStream) -> BoxFuture<'r, Result<()>> + Send + 'r>,
} }
/// A [`Stream`](futures::Stream) of [`Message`]s, returned by /// A [`Stream`](futures::Stream) of [`Message`]s, returned by
@ -177,7 +178,7 @@ pub struct Channel<'r> {
// TODO: Get rid of this or `Channel` via a single `enum`. // TODO: Get rid of this or `Channel` via a single `enum`.
pub struct MessageStream<'r, S> { pub struct MessageStream<'r, S> {
ws: WebSocket, ws: WebSocket,
handler: Box<dyn FnMut(SplitStream<DuplexStream>) -> S + Send + 'r> handler: Box<dyn FnOnce(SplitStream<DuplexStream>) -> S + Send + 'r>
} }
#[rocket::async_trait] #[rocket::async_trait]
@ -228,8 +229,9 @@ impl<'r, 'o: 'r, S> Responder<'r, 'o> for MessageStream<'o, S>
#[rocket::async_trait] #[rocket::async_trait]
impl IoHandler for Channel<'_> { impl IoHandler for Channel<'_> {
async fn io(&mut self, io: IoStream) -> io::Result<()> { async fn io(self: Pin<Box<Self>>, io: IoStream) -> io::Result<()> {
let result = (self.handler)(DuplexStream::new(io, self.ws.config).await).await; let channel = Pin::into_inner(self);
let result = (channel.handler)(DuplexStream::new(io, channel.ws.config).await).await;
handle_result(result).map(|_| ()) handle_result(result).map(|_| ())
} }
} }
@ -238,9 +240,10 @@ impl IoHandler for Channel<'_> {
impl<'r, S> IoHandler for MessageStream<'r, S> impl<'r, S> IoHandler for MessageStream<'r, S>
where S: futures::Stream<Item = Result<Message>> + Send + 'r where S: futures::Stream<Item = Result<Message>> + Send + 'r
{ {
async fn io(&mut self, io: IoStream) -> io::Result<()> { async fn io(self: Pin<Box<Self>>, io: IoStream) -> io::Result<()> {
let (mut sink, stream) = DuplexStream::new(io, self.ws.config).await.split(); let (mut sink, source) = DuplexStream::new(io, self.ws.config).await.split();
let mut stream = std::pin::pin!((self.handler)(stream)); let handler = Pin::into_inner(self).handler;
let mut stream = std::pin::pin!((handler)(source));
while let Some(msg) = stream.next().await { while let Some(msg) = stream.next().await {
let result = match msg { let result = match msg {
Ok(msg) => sink.send(msg).await, Ok(msg) => sink.send(msg).await,

View File

@ -42,6 +42,8 @@ enum IoStreamKind {
/// to the client. /// to the client.
/// ///
/// ```rust /// ```rust
/// use std::pin::Pin;
///
/// use rocket::tokio::io; /// use rocket::tokio::io;
/// use rocket::data::{IoHandler, IoStream}; /// use rocket::data::{IoHandler, IoStream};
/// ///
@ -49,7 +51,7 @@ enum IoStreamKind {
/// ///
/// #[rocket::async_trait] /// #[rocket::async_trait]
/// impl IoHandler for EchoHandler { /// impl IoHandler for EchoHandler {
/// async fn io(&mut self, io: IoStream) -> io::Result<()> { /// async fn io(self: Pin<Box<Self>>, io: IoStream) -> io::Result<()> {
/// let (mut reader, mut writer) = io::split(io); /// let (mut reader, mut writer) = io::split(io);
/// io::copy(&mut reader, &mut writer).await?; /// io::copy(&mut reader, &mut writer).await?;
/// Ok(()) /// Ok(())
@ -66,7 +68,7 @@ enum IoStreamKind {
#[crate::async_trait] #[crate::async_trait]
pub trait IoHandler: Send { pub trait IoHandler: Send {
/// Performs the raw I/O. /// Performs the raw I/O.
async fn io(&mut self, io: IoStream) -> io::Result<()>; async fn io(self: Pin<Box<Self>>, io: IoStream) -> io::Result<()>;
} }
#[doc(hidden)] #[doc(hidden)]

View File

@ -1,6 +1,7 @@
use std::{fmt, str}; use std::{fmt, str};
use std::borrow::Cow; use std::borrow::Cow;
use std::collections::HashMap; use std::collections::HashMap;
use std::pin::Pin;
use tokio::io::{AsyncRead, AsyncSeek}; use tokio::io::{AsyncRead, AsyncSeek};
@ -276,6 +277,8 @@ impl<'r> Builder<'r> {
/// # Example /// # Example
/// ///
/// ```rust /// ```rust
/// use std::pin::Pin;
///
/// use rocket::Response; /// use rocket::Response;
/// use rocket::data::{IoHandler, IoStream}; /// use rocket::data::{IoHandler, IoStream};
/// use rocket::tokio::io; /// use rocket::tokio::io;
@ -284,7 +287,7 @@ impl<'r> Builder<'r> {
/// ///
/// #[rocket::async_trait] /// #[rocket::async_trait]
/// impl IoHandler for EchoHandler { /// impl IoHandler for EchoHandler {
/// async fn io(&mut self, io: IoStream) -> io::Result<()> { /// async fn io(self: Pin<Box<Self>>, io: IoStream) -> io::Result<()> {
/// let (mut reader, mut writer) = io::split(io); /// let (mut reader, mut writer) = io::split(io);
/// io::copy(&mut reader, &mut writer).await?; /// io::copy(&mut reader, &mut writer).await?;
/// Ok(()) /// Ok(())
@ -485,7 +488,7 @@ pub struct Response<'r> {
status: Option<Status>, status: Option<Status>,
headers: HeaderMap<'r>, headers: HeaderMap<'r>,
body: Body<'r>, body: Body<'r>,
upgrade: HashMap<Uncased<'r>, Box<dyn IoHandler + 'r>>, upgrade: HashMap<Uncased<'r>, Pin<Box<dyn IoHandler + 'r>>>,
} }
impl<'r> Response<'r> { impl<'r> Response<'r> {
@ -801,7 +804,7 @@ impl<'r> Response<'r> {
pub(crate) fn take_upgrade<I: Iterator<Item = &'r str>>( pub(crate) fn take_upgrade<I: Iterator<Item = &'r str>>(
&mut self, &mut self,
protocols: I protocols: I
) -> Result<Option<(Uncased<'r>, Box<dyn IoHandler + 'r>)>, ()> { ) -> Result<Option<(Uncased<'r>, Pin<Box<dyn IoHandler + 'r>>)>, ()> {
if self.upgrade.is_empty() { if self.upgrade.is_empty() {
return Ok(None); return Ok(None);
} }
@ -826,6 +829,8 @@ impl<'r> Response<'r> {
/// [`upgrade()`](Builder::upgrade()). Otherwise returns `None`. /// [`upgrade()`](Builder::upgrade()). Otherwise returns `None`.
/// ///
/// ```rust /// ```rust
/// use std::pin::Pin;
///
/// use rocket::Response; /// use rocket::Response;
/// use rocket::data::{IoHandler, IoStream}; /// use rocket::data::{IoHandler, IoStream};
/// use rocket::tokio::io; /// use rocket::tokio::io;
@ -834,7 +839,7 @@ impl<'r> Response<'r> {
/// ///
/// #[rocket::async_trait] /// #[rocket::async_trait]
/// impl IoHandler for EchoHandler { /// impl IoHandler for EchoHandler {
/// async fn io(&mut self, io: IoStream) -> io::Result<()> { /// async fn io(self: Pin<Box<Self>>, io: IoStream) -> io::Result<()> {
/// let (mut reader, mut writer) = io::split(io); /// let (mut reader, mut writer) = io::split(io);
/// io::copy(&mut reader, &mut writer).await?; /// io::copy(&mut reader, &mut writer).await?;
/// Ok(()) /// Ok(())
@ -849,8 +854,8 @@ impl<'r> Response<'r> {
/// assert!(response.upgrade("raw-echo").is_some()); /// assert!(response.upgrade("raw-echo").is_some());
/// # }) /// # })
/// ``` /// ```
pub fn upgrade(&mut self, proto: &str) -> Option<&mut (dyn IoHandler + 'r)> { pub fn upgrade(&mut self, proto: &str) -> Option<Pin<&mut (dyn IoHandler + 'r)>> {
self.upgrade.get_mut(proto.as_uncased()).map(|h| &mut **h) self.upgrade.get_mut(proto.as_uncased()).map(|h| h.as_mut())
} }
/// Returns a mutable borrow of the body of `self`, if there is one. A /// Returns a mutable borrow of the body of `self`, if there is one. A
@ -957,6 +962,8 @@ impl<'r> Response<'r> {
/// # Example /// # Example
/// ///
/// ```rust /// ```rust
/// use std::pin::Pin;
///
/// use rocket::Response; /// use rocket::Response;
/// use rocket::data::{IoHandler, IoStream}; /// use rocket::data::{IoHandler, IoStream};
/// use rocket::tokio::io; /// use rocket::tokio::io;
@ -965,7 +972,7 @@ impl<'r> Response<'r> {
/// ///
/// #[rocket::async_trait] /// #[rocket::async_trait]
/// impl IoHandler for EchoHandler { /// impl IoHandler for EchoHandler {
/// async fn io(&mut self, io: IoStream) -> io::Result<()> { /// async fn io(self: Pin<Box<Self>>, io: IoStream) -> io::Result<()> {
/// let (mut reader, mut writer) = io::split(io); /// let (mut reader, mut writer) = io::split(io);
/// io::copy(&mut reader, &mut writer).await?; /// io::copy(&mut reader, &mut writer).await?;
/// Ok(()) /// Ok(())
@ -983,7 +990,7 @@ impl<'r> Response<'r> {
pub fn add_upgrade<N, H>(&mut self, protocol: N, handler: H) pub fn add_upgrade<N, H>(&mut self, protocol: N, handler: H)
where N: Into<Uncased<'r>>, H: IoHandler + 'r where N: Into<Uncased<'r>>, H: IoHandler + 'r
{ {
self.upgrade.insert(protocol.into(), Box::new(handler)); self.upgrade.insert(protocol.into(), Box::pin(handler));
} }
/// Sets the body's maximum chunk size to `size` bytes. /// Sets the body's maximum chunk size to `size` bytes.

View File

@ -1,6 +1,7 @@
use std::io; use std::io;
use std::sync::Arc; use std::sync::Arc;
use std::time::Duration; use std::time::Duration;
use std::pin::Pin;
use yansi::Paint; use yansi::Paint;
use tokio::sync::oneshot; use tokio::sync::oneshot;
@ -179,7 +180,7 @@ impl Rocket<Orbit> {
&self, &self,
mut response: Response<'r>, mut response: Response<'r>,
proto: uncased::Uncased<'r>, proto: uncased::Uncased<'r>,
mut io_handler: Box<dyn IoHandler + 'r>, io_handler: Pin<Box<dyn IoHandler + 'r>>,
pending_upgrade: hyper::upgrade::OnUpgrade, pending_upgrade: hyper::upgrade::OnUpgrade,
tx: oneshot::Sender<hyper::Response<hyper::Body>>, tx: oneshot::Sender<hyper::Response<hyper::Body>>,
) { ) {