From c3520fb4a1f00d8705123a03e9188ec892b9153d Mon Sep 17 00:00:00 2001 From: Sergio Benitez Date: Tue, 4 Apr 2023 15:11:30 -0700 Subject: [PATCH] Pin I/O handlers. Allow 'FnOnce' in 'ws' handlers. This modifies the 'IoHandler::io()' method so that it takes a 'Pin>', allowing handlers to move internally and assume that the data is pinned. The change is then used in the 'ws' contrib crate to allow 'FnOnce' handlers instead of 'FnMut'. The net effect is that streams, such as those crated by 'Stream!', are now allowed to move internally. --- contrib/ws/src/websocket.rs | 25 ++++++++++++++----------- core/lib/src/data/io_stream.rs | 6 ++++-- core/lib/src/response/response.rs | 23 +++++++++++++++-------- core/lib/src/server.rs | 3 ++- 4 files changed, 35 insertions(+), 22 deletions(-) diff --git a/contrib/ws/src/websocket.rs b/contrib/ws/src/websocket.rs index fbf4ad11..c819f49f 100644 --- a/contrib/ws/src/websocket.rs +++ b/contrib/ws/src/websocket.rs @@ -1,4 +1,5 @@ use std::io; +use std::pin::Pin; use rocket::data::{IoHandler, IoStream}; use rocket::futures::{self, StreamExt, SinkExt, future::BoxFuture, stream::SplitStream}; @@ -68,7 +69,7 @@ impl WebSocket { /// Create a read/write channel to the client and call `handler` with it. /// - /// This method takes a `FnMut`, `handler`, that consumes a read/write + /// This method takes a `FnOnce`, `handler`, that consumes a read/write /// WebSocket channel, [`DuplexStream`] to the client. See [`DuplexStream`] /// for details on how to make use of the channel. /// @@ -113,14 +114,14 @@ impl WebSocket { /// } /// ``` pub fn channel<'r, F: Send + 'r>(self, handler: F) -> Channel<'r> - where F: FnMut(DuplexStream) -> BoxFuture<'r, Result<()>> + 'r + where F: FnOnce(DuplexStream) -> BoxFuture<'r, Result<()>> + 'r { Channel { ws: self, handler: Box::new(handler), } } /// Create a stream that consumes client [`Message`]s and emits its own. /// - /// This method takes a `FnMut` `stream` that consumes a read-only stream + /// This method takes a `FnOnce` `stream` that consumes a read-only stream /// and returns a stream of [`Message`]s. While the returned stream can be /// constructed in any manner, the [`Stream!`] macro is the preferred /// method. In any case, the stream must be `Send`. @@ -153,7 +154,7 @@ impl WebSocket { /// } /// ``` pub fn stream<'r, F, S>(self, stream: F) -> MessageStream<'r, S> - where F: FnMut(SplitStream) -> S + Send + 'r, + where F: FnOnce(SplitStream) -> S + Send + 'r, S: futures::Stream> + Send + 'r { MessageStream { ws: self, handler: Box::new(stream), } @@ -165,7 +166,7 @@ impl WebSocket { /// `Channel` has no methods or functionality beyond its trait implementations. pub struct Channel<'r> { ws: WebSocket, - handler: Box BoxFuture<'r, Result<()>> + Send + 'r>, + handler: Box BoxFuture<'r, Result<()>> + Send + 'r>, } /// A [`Stream`](futures::Stream) of [`Message`]s, returned by @@ -177,7 +178,7 @@ pub struct Channel<'r> { // TODO: Get rid of this or `Channel` via a single `enum`. pub struct MessageStream<'r, S> { ws: WebSocket, - handler: Box) -> S + Send + 'r> + handler: Box) -> S + Send + 'r> } #[rocket::async_trait] @@ -228,8 +229,9 @@ impl<'r, 'o: 'r, S> Responder<'r, 'o> for MessageStream<'o, S> #[rocket::async_trait] impl IoHandler for Channel<'_> { - async fn io(&mut self, io: IoStream) -> io::Result<()> { - let result = (self.handler)(DuplexStream::new(io, self.ws.config).await).await; + async fn io(self: Pin>, io: IoStream) -> io::Result<()> { + let channel = Pin::into_inner(self); + let result = (channel.handler)(DuplexStream::new(io, channel.ws.config).await).await; handle_result(result).map(|_| ()) } } @@ -238,9 +240,10 @@ impl IoHandler for Channel<'_> { impl<'r, S> IoHandler for MessageStream<'r, S> where S: futures::Stream> + Send + 'r { - async fn io(&mut self, io: IoStream) -> io::Result<()> { - let (mut sink, stream) = DuplexStream::new(io, self.ws.config).await.split(); - let mut stream = std::pin::pin!((self.handler)(stream)); + async fn io(self: Pin>, io: IoStream) -> io::Result<()> { + let (mut sink, source) = DuplexStream::new(io, self.ws.config).await.split(); + let handler = Pin::into_inner(self).handler; + let mut stream = std::pin::pin!((handler)(source)); while let Some(msg) = stream.next().await { let result = match msg { Ok(msg) => sink.send(msg).await, diff --git a/core/lib/src/data/io_stream.rs b/core/lib/src/data/io_stream.rs index d965b957..0945c5c0 100644 --- a/core/lib/src/data/io_stream.rs +++ b/core/lib/src/data/io_stream.rs @@ -42,6 +42,8 @@ enum IoStreamKind { /// to the client. /// /// ```rust +/// use std::pin::Pin; +/// /// use rocket::tokio::io; /// use rocket::data::{IoHandler, IoStream}; /// @@ -49,7 +51,7 @@ enum IoStreamKind { /// /// #[rocket::async_trait] /// impl IoHandler for EchoHandler { -/// async fn io(&mut self, io: IoStream) -> io::Result<()> { +/// async fn io(self: Pin>, io: IoStream) -> io::Result<()> { /// let (mut reader, mut writer) = io::split(io); /// io::copy(&mut reader, &mut writer).await?; /// Ok(()) @@ -66,7 +68,7 @@ enum IoStreamKind { #[crate::async_trait] pub trait IoHandler: Send { /// Performs the raw I/O. - async fn io(&mut self, io: IoStream) -> io::Result<()>; + async fn io(self: Pin>, io: IoStream) -> io::Result<()>; } #[doc(hidden)] diff --git a/core/lib/src/response/response.rs b/core/lib/src/response/response.rs index 2e1fdfa8..588497e1 100644 --- a/core/lib/src/response/response.rs +++ b/core/lib/src/response/response.rs @@ -1,6 +1,7 @@ use std::{fmt, str}; use std::borrow::Cow; use std::collections::HashMap; +use std::pin::Pin; use tokio::io::{AsyncRead, AsyncSeek}; @@ -276,6 +277,8 @@ impl<'r> Builder<'r> { /// # Example /// /// ```rust + /// use std::pin::Pin; + /// /// use rocket::Response; /// use rocket::data::{IoHandler, IoStream}; /// use rocket::tokio::io; @@ -284,7 +287,7 @@ impl<'r> Builder<'r> { /// /// #[rocket::async_trait] /// impl IoHandler for EchoHandler { - /// async fn io(&mut self, io: IoStream) -> io::Result<()> { + /// async fn io(self: Pin>, io: IoStream) -> io::Result<()> { /// let (mut reader, mut writer) = io::split(io); /// io::copy(&mut reader, &mut writer).await?; /// Ok(()) @@ -485,7 +488,7 @@ pub struct Response<'r> { status: Option, headers: HeaderMap<'r>, body: Body<'r>, - upgrade: HashMap, Box>, + upgrade: HashMap, Pin>>, } impl<'r> Response<'r> { @@ -801,7 +804,7 @@ impl<'r> Response<'r> { pub(crate) fn take_upgrade>( &mut self, protocols: I - ) -> Result, Box)>, ()> { + ) -> Result, Pin>)>, ()> { if self.upgrade.is_empty() { return Ok(None); } @@ -826,6 +829,8 @@ impl<'r> Response<'r> { /// [`upgrade()`](Builder::upgrade()). Otherwise returns `None`. /// /// ```rust + /// use std::pin::Pin; + /// /// use rocket::Response; /// use rocket::data::{IoHandler, IoStream}; /// use rocket::tokio::io; @@ -834,7 +839,7 @@ impl<'r> Response<'r> { /// /// #[rocket::async_trait] /// impl IoHandler for EchoHandler { - /// async fn io(&mut self, io: IoStream) -> io::Result<()> { + /// async fn io(self: Pin>, io: IoStream) -> io::Result<()> { /// let (mut reader, mut writer) = io::split(io); /// io::copy(&mut reader, &mut writer).await?; /// Ok(()) @@ -849,8 +854,8 @@ impl<'r> Response<'r> { /// assert!(response.upgrade("raw-echo").is_some()); /// # }) /// ``` - pub fn upgrade(&mut self, proto: &str) -> Option<&mut (dyn IoHandler + 'r)> { - self.upgrade.get_mut(proto.as_uncased()).map(|h| &mut **h) + pub fn upgrade(&mut self, proto: &str) -> Option> { + self.upgrade.get_mut(proto.as_uncased()).map(|h| h.as_mut()) } /// Returns a mutable borrow of the body of `self`, if there is one. A @@ -957,6 +962,8 @@ impl<'r> Response<'r> { /// # Example /// /// ```rust + /// use std::pin::Pin; + /// /// use rocket::Response; /// use rocket::data::{IoHandler, IoStream}; /// use rocket::tokio::io; @@ -965,7 +972,7 @@ impl<'r> Response<'r> { /// /// #[rocket::async_trait] /// impl IoHandler for EchoHandler { - /// async fn io(&mut self, io: IoStream) -> io::Result<()> { + /// async fn io(self: Pin>, io: IoStream) -> io::Result<()> { /// let (mut reader, mut writer) = io::split(io); /// io::copy(&mut reader, &mut writer).await?; /// Ok(()) @@ -983,7 +990,7 @@ impl<'r> Response<'r> { pub fn add_upgrade(&mut self, protocol: N, handler: H) where N: Into>, H: IoHandler + 'r { - self.upgrade.insert(protocol.into(), Box::new(handler)); + self.upgrade.insert(protocol.into(), Box::pin(handler)); } /// Sets the body's maximum chunk size to `size` bytes. diff --git a/core/lib/src/server.rs b/core/lib/src/server.rs index 6b315ed0..faae6ddc 100644 --- a/core/lib/src/server.rs +++ b/core/lib/src/server.rs @@ -1,6 +1,7 @@ use std::io; use std::sync::Arc; use std::time::Duration; +use std::pin::Pin; use yansi::Paint; use tokio::sync::oneshot; @@ -179,7 +180,7 @@ impl Rocket { &self, mut response: Response<'r>, proto: uncased::Uncased<'r>, - mut io_handler: Box, + io_handler: Pin>, pending_upgrade: hyper::upgrade::OnUpgrade, tx: oneshot::Sender>, ) {