Pin I/O handlers. Allow 'FnOnce' in 'ws' handlers.

This modifies the 'IoHandler::io()' method so that it takes a
'Pin<Box<Self>>', allowing handlers to move internally and assume that
the data is pinned.

The change is then used in the 'ws' contrib crate to allow 'FnOnce'
handlers instead of 'FnMut'. The net effect is that streams, such as
those crated by 'Stream!', are now allowed to move internally.
This commit is contained in:
Sergio Benitez 2023-04-04 15:11:30 -07:00
parent 5e7a75e1a5
commit c3520fb4a1
4 changed files with 35 additions and 22 deletions

View File

@ -1,4 +1,5 @@
use std::io;
use std::pin::Pin;
use rocket::data::{IoHandler, IoStream};
use rocket::futures::{self, StreamExt, SinkExt, future::BoxFuture, stream::SplitStream};
@ -68,7 +69,7 @@ impl WebSocket {
/// Create a read/write channel to the client and call `handler` with it.
///
/// This method takes a `FnMut`, `handler`, that consumes a read/write
/// This method takes a `FnOnce`, `handler`, that consumes a read/write
/// WebSocket channel, [`DuplexStream`] to the client. See [`DuplexStream`]
/// for details on how to make use of the channel.
///
@ -113,14 +114,14 @@ impl WebSocket {
/// }
/// ```
pub fn channel<'r, F: Send + 'r>(self, handler: F) -> Channel<'r>
where F: FnMut(DuplexStream) -> BoxFuture<'r, Result<()>> + 'r
where F: FnOnce(DuplexStream) -> BoxFuture<'r, Result<()>> + 'r
{
Channel { ws: self, handler: Box::new(handler), }
}
/// Create a stream that consumes client [`Message`]s and emits its own.
///
/// This method takes a `FnMut` `stream` that consumes a read-only stream
/// This method takes a `FnOnce` `stream` that consumes a read-only stream
/// and returns a stream of [`Message`]s. While the returned stream can be
/// constructed in any manner, the [`Stream!`] macro is the preferred
/// method. In any case, the stream must be `Send`.
@ -153,7 +154,7 @@ impl WebSocket {
/// }
/// ```
pub fn stream<'r, F, S>(self, stream: F) -> MessageStream<'r, S>
where F: FnMut(SplitStream<DuplexStream>) -> S + Send + 'r,
where F: FnOnce(SplitStream<DuplexStream>) -> S + Send + 'r,
S: futures::Stream<Item = Result<Message>> + Send + 'r
{
MessageStream { ws: self, handler: Box::new(stream), }
@ -165,7 +166,7 @@ impl WebSocket {
/// `Channel` has no methods or functionality beyond its trait implementations.
pub struct Channel<'r> {
ws: WebSocket,
handler: Box<dyn FnMut(DuplexStream) -> BoxFuture<'r, Result<()>> + Send + 'r>,
handler: Box<dyn FnOnce(DuplexStream) -> BoxFuture<'r, Result<()>> + Send + 'r>,
}
/// A [`Stream`](futures::Stream) of [`Message`]s, returned by
@ -177,7 +178,7 @@ pub struct Channel<'r> {
// TODO: Get rid of this or `Channel` via a single `enum`.
pub struct MessageStream<'r, S> {
ws: WebSocket,
handler: Box<dyn FnMut(SplitStream<DuplexStream>) -> S + Send + 'r>
handler: Box<dyn FnOnce(SplitStream<DuplexStream>) -> S + Send + 'r>
}
#[rocket::async_trait]
@ -228,8 +229,9 @@ impl<'r, 'o: 'r, S> Responder<'r, 'o> for MessageStream<'o, S>
#[rocket::async_trait]
impl IoHandler for Channel<'_> {
async fn io(&mut self, io: IoStream) -> io::Result<()> {
let result = (self.handler)(DuplexStream::new(io, self.ws.config).await).await;
async fn io(self: Pin<Box<Self>>, io: IoStream) -> io::Result<()> {
let channel = Pin::into_inner(self);
let result = (channel.handler)(DuplexStream::new(io, channel.ws.config).await).await;
handle_result(result).map(|_| ())
}
}
@ -238,9 +240,10 @@ impl IoHandler for Channel<'_> {
impl<'r, S> IoHandler for MessageStream<'r, S>
where S: futures::Stream<Item = Result<Message>> + Send + 'r
{
async fn io(&mut self, io: IoStream) -> io::Result<()> {
let (mut sink, stream) = DuplexStream::new(io, self.ws.config).await.split();
let mut stream = std::pin::pin!((self.handler)(stream));
async fn io(self: Pin<Box<Self>>, io: IoStream) -> io::Result<()> {
let (mut sink, source) = DuplexStream::new(io, self.ws.config).await.split();
let handler = Pin::into_inner(self).handler;
let mut stream = std::pin::pin!((handler)(source));
while let Some(msg) = stream.next().await {
let result = match msg {
Ok(msg) => sink.send(msg).await,

View File

@ -42,6 +42,8 @@ enum IoStreamKind {
/// to the client.
///
/// ```rust
/// use std::pin::Pin;
///
/// use rocket::tokio::io;
/// use rocket::data::{IoHandler, IoStream};
///
@ -49,7 +51,7 @@ enum IoStreamKind {
///
/// #[rocket::async_trait]
/// impl IoHandler for EchoHandler {
/// async fn io(&mut self, io: IoStream) -> io::Result<()> {
/// async fn io(self: Pin<Box<Self>>, io: IoStream) -> io::Result<()> {
/// let (mut reader, mut writer) = io::split(io);
/// io::copy(&mut reader, &mut writer).await?;
/// Ok(())
@ -66,7 +68,7 @@ enum IoStreamKind {
#[crate::async_trait]
pub trait IoHandler: Send {
/// Performs the raw I/O.
async fn io(&mut self, io: IoStream) -> io::Result<()>;
async fn io(self: Pin<Box<Self>>, io: IoStream) -> io::Result<()>;
}
#[doc(hidden)]

View File

@ -1,6 +1,7 @@
use std::{fmt, str};
use std::borrow::Cow;
use std::collections::HashMap;
use std::pin::Pin;
use tokio::io::{AsyncRead, AsyncSeek};
@ -276,6 +277,8 @@ impl<'r> Builder<'r> {
/// # Example
///
/// ```rust
/// use std::pin::Pin;
///
/// use rocket::Response;
/// use rocket::data::{IoHandler, IoStream};
/// use rocket::tokio::io;
@ -284,7 +287,7 @@ impl<'r> Builder<'r> {
///
/// #[rocket::async_trait]
/// impl IoHandler for EchoHandler {
/// async fn io(&mut self, io: IoStream) -> io::Result<()> {
/// async fn io(self: Pin<Box<Self>>, io: IoStream) -> io::Result<()> {
/// let (mut reader, mut writer) = io::split(io);
/// io::copy(&mut reader, &mut writer).await?;
/// Ok(())
@ -485,7 +488,7 @@ pub struct Response<'r> {
status: Option<Status>,
headers: HeaderMap<'r>,
body: Body<'r>,
upgrade: HashMap<Uncased<'r>, Box<dyn IoHandler + 'r>>,
upgrade: HashMap<Uncased<'r>, Pin<Box<dyn IoHandler + 'r>>>,
}
impl<'r> Response<'r> {
@ -801,7 +804,7 @@ impl<'r> Response<'r> {
pub(crate) fn take_upgrade<I: Iterator<Item = &'r str>>(
&mut self,
protocols: I
) -> Result<Option<(Uncased<'r>, Box<dyn IoHandler + 'r>)>, ()> {
) -> Result<Option<(Uncased<'r>, Pin<Box<dyn IoHandler + 'r>>)>, ()> {
if self.upgrade.is_empty() {
return Ok(None);
}
@ -826,6 +829,8 @@ impl<'r> Response<'r> {
/// [`upgrade()`](Builder::upgrade()). Otherwise returns `None`.
///
/// ```rust
/// use std::pin::Pin;
///
/// use rocket::Response;
/// use rocket::data::{IoHandler, IoStream};
/// use rocket::tokio::io;
@ -834,7 +839,7 @@ impl<'r> Response<'r> {
///
/// #[rocket::async_trait]
/// impl IoHandler for EchoHandler {
/// async fn io(&mut self, io: IoStream) -> io::Result<()> {
/// async fn io(self: Pin<Box<Self>>, io: IoStream) -> io::Result<()> {
/// let (mut reader, mut writer) = io::split(io);
/// io::copy(&mut reader, &mut writer).await?;
/// Ok(())
@ -849,8 +854,8 @@ impl<'r> Response<'r> {
/// assert!(response.upgrade("raw-echo").is_some());
/// # })
/// ```
pub fn upgrade(&mut self, proto: &str) -> Option<&mut (dyn IoHandler + 'r)> {
self.upgrade.get_mut(proto.as_uncased()).map(|h| &mut **h)
pub fn upgrade(&mut self, proto: &str) -> Option<Pin<&mut (dyn IoHandler + 'r)>> {
self.upgrade.get_mut(proto.as_uncased()).map(|h| h.as_mut())
}
/// Returns a mutable borrow of the body of `self`, if there is one. A
@ -957,6 +962,8 @@ impl<'r> Response<'r> {
/// # Example
///
/// ```rust
/// use std::pin::Pin;
///
/// use rocket::Response;
/// use rocket::data::{IoHandler, IoStream};
/// use rocket::tokio::io;
@ -965,7 +972,7 @@ impl<'r> Response<'r> {
///
/// #[rocket::async_trait]
/// impl IoHandler for EchoHandler {
/// async fn io(&mut self, io: IoStream) -> io::Result<()> {
/// async fn io(self: Pin<Box<Self>>, io: IoStream) -> io::Result<()> {
/// let (mut reader, mut writer) = io::split(io);
/// io::copy(&mut reader, &mut writer).await?;
/// Ok(())
@ -983,7 +990,7 @@ impl<'r> Response<'r> {
pub fn add_upgrade<N, H>(&mut self, protocol: N, handler: H)
where N: Into<Uncased<'r>>, H: IoHandler + 'r
{
self.upgrade.insert(protocol.into(), Box::new(handler));
self.upgrade.insert(protocol.into(), Box::pin(handler));
}
/// Sets the body's maximum chunk size to `size` bytes.

View File

@ -1,6 +1,7 @@
use std::io;
use std::sync::Arc;
use std::time::Duration;
use std::pin::Pin;
use yansi::Paint;
use tokio::sync::oneshot;
@ -179,7 +180,7 @@ impl Rocket<Orbit> {
&self,
mut response: Response<'r>,
proto: uncased::Uncased<'r>,
mut io_handler: Box<dyn IoHandler + 'r>,
io_handler: Pin<Box<dyn IoHandler + 'r>>,
pending_upgrade: hyper::upgrade::OnUpgrade,
tx: oneshot::Sender<hyper::Response<hyper::Body>>,
) {