From d97c83d7e0989f4a11a026048392de7ead776a67 Mon Sep 17 00:00:00 2001 From: Sergio Benitez Date: Wed, 29 Mar 2023 16:59:57 -0700 Subject: [PATCH] Finalize support for external connection upgrades. --- core/lib/src/data/io_stream.rs | 150 +++++++++++++++++++++++++ core/lib/src/data/mod.rs | 2 + core/lib/src/lib.rs | 2 - core/lib/src/response/response.rs | 179 ++++++++++++++++++++++++++---- core/lib/src/server.rs | 70 ++++++------ core/lib/src/upgrade/mod.rs | 14 --- 6 files changed, 344 insertions(+), 73 deletions(-) create mode 100644 core/lib/src/data/io_stream.rs delete mode 100644 core/lib/src/upgrade/mod.rs diff --git a/core/lib/src/data/io_stream.rs b/core/lib/src/data/io_stream.rs new file mode 100644 index 00000000..d965b957 --- /dev/null +++ b/core/lib/src/data/io_stream.rs @@ -0,0 +1,150 @@ +use std::io; +use std::task::{Context, Poll}; +use std::pin::Pin; + +use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; + +use crate::http::hyper::upgrade::Upgraded; + +/// A bidirectional, raw stream to the client. +/// +/// An instance of `IoStream` is passed to an [`IoHandler`] in response to a +/// successful upgrade request initiated by responders via +/// [`Response::add_upgrade()`] or the equivalent builder method +/// [`Builder::upgrade()`]. For details on upgrade connections, see +/// [`Response`#upgrading]. +/// +/// An `IoStream` is guaranteed to be [`AsyncRead`], [`AsyncWrite`], and +/// `Unpin`. Bytes written to the stream are sent directly to the client. Bytes +/// read from the stream are those sent directly _by_ the client. See +/// [`IoHandler`] for one example of how values of this type are used. +/// +/// [`Response::add_upgrade()`]: crate::Response::add_upgrade() +/// [`Builder::upgrade()`]: crate::response::Builder::upgrade() +/// [`Response`#upgrading]: crate::response::Response#upgrading +pub struct IoStream { + kind: IoStreamKind, +} + +/// Just in case we want to add stream kinds in the future. +enum IoStreamKind { + Upgraded(Upgraded) +} + +/// An upgraded connection I/O handler. +/// +/// An I/O handler performs raw I/O via the passed in [`IoStream`], which is +/// [`AsyncRead`], [`AsyncWrite`], and `Unpin`. +/// +/// # Example +/// +/// The example below implements an `EchoHandler` that echos the raw bytes back +/// to the client. +/// +/// ```rust +/// use rocket::tokio::io; +/// use rocket::data::{IoHandler, IoStream}; +/// +/// struct EchoHandler; +/// +/// #[rocket::async_trait] +/// impl IoHandler for EchoHandler { +/// async fn io(&mut self, io: IoStream) -> io::Result<()> { +/// let (mut reader, mut writer) = io::split(io); +/// io::copy(&mut reader, &mut writer).await?; +/// Ok(()) +/// } +/// } +/// +/// # use rocket::Response; +/// # rocket::async_test(async { +/// # let mut response = Response::new(); +/// # response.add_upgrade("raw-echo", EchoHandler); +/// # assert!(response.upgrade("raw-echo").is_some()); +/// # }) +/// ``` +#[crate::async_trait] +pub trait IoHandler: Send { + /// Performs the raw I/O. + async fn io(&mut self, io: IoStream) -> io::Result<()>; +} + +#[doc(hidden)] +impl From for IoStream { + fn from(io: Upgraded) -> Self { + IoStream { kind: IoStreamKind::Upgraded(io) } + } +} + +/// A "trait alias" of sorts so we can use `AsyncRead + AsyncWrite + Unpin` in `dyn`. +pub trait AsyncReadWrite: AsyncRead + AsyncWrite + Unpin { } + +/// Implemented for all `AsyncRead + AsyncWrite + Unpin`, of course. +impl AsyncReadWrite for T { } + +impl IoStream { + /// Returns the internal I/O stream. + fn inner_mut(&mut self) -> Pin<&mut dyn AsyncReadWrite> { + match self.kind { + IoStreamKind::Upgraded(ref mut io) => Pin::new(io), + } + } + + /// Returns `true` if the inner I/O stream is write vectored. + fn inner_is_write_vectored(&self) -> bool { + match self.kind { + IoStreamKind::Upgraded(ref io) => io.is_write_vectored(), + } + } +} + +impl AsyncRead for IoStream { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + self.get_mut().inner_mut().poll_read(cx, buf) + } +} + +impl AsyncWrite for IoStream { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + self.get_mut().inner_mut().poll_write(cx, buf) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.get_mut().inner_mut().poll_flush(cx) + } + + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.get_mut().inner_mut().poll_shutdown(cx) + } + + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[io::IoSlice<'_>], + ) -> Poll> { + self.get_mut().inner_mut().poll_write_vectored(cx, bufs) + } + + fn is_write_vectored(&self) -> bool { + self.inner_is_write_vectored() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn is_unpin() { + fn check_traits() {} + check_traits::(); + } +} diff --git a/core/lib/src/data/mod.rs b/core/lib/src/data/mod.rs index 2d6691e5..9c7a3314 100644 --- a/core/lib/src/data/mod.rs +++ b/core/lib/src/data/mod.rs @@ -6,12 +6,14 @@ mod data; mod data_stream; mod from_data; mod limits; +mod io_stream; pub use self::data::Data; pub use self::data_stream::DataStream; pub use self::from_data::{FromData, Outcome}; pub use self::limits::Limits; pub use self::capped::{N, Capped}; +pub use self::io_stream::{IoHandler, IoStream}; pub use ubyte::{ByteUnit, ToByteUnit}; pub(crate) use self::data_stream::StreamReader; diff --git a/core/lib/src/lib.rs b/core/lib/src/lib.rs index 823a6e0f..e1b8a260 100644 --- a/core/lib/src/lib.rs +++ b/core/lib/src/lib.rs @@ -123,7 +123,6 @@ pub use time; #[doc(hidden)] pub mod sentinel; pub mod local; pub mod request; -pub mod upgrade; pub mod response; pub mod config; pub mod form; @@ -176,7 +175,6 @@ mod rocket; mod router; mod phase; -#[doc(inline)] pub use crate::upgrade::Upgrade; #[doc(inline)] pub use crate::response::Response; #[doc(inline)] pub use crate::data::Data; #[doc(inline)] pub use crate::config::Config; diff --git a/core/lib/src/response/response.rs b/core/lib/src/response/response.rs index 5ef2aeea..d01bc004 100644 --- a/core/lib/src/response/response.rs +++ b/core/lib/src/response/response.rs @@ -1,11 +1,13 @@ use std::{fmt, str}; use std::borrow::Cow; +use std::collections::HashMap; use tokio::io::{AsyncRead, AsyncSeek}; use crate::http::{Header, HeaderMap, Status, ContentType, Cookie}; +use crate::http::uncased::{Uncased, AsUncased}; +use crate::data::IoHandler; use crate::response::Body; -use crate::upgrade::Upgrade; /// Builder for the [`Response`] type. /// @@ -262,10 +264,43 @@ impl<'r> Builder<'r> { self } - /// Sets the upgrade of the `Response`. + /// Registers `handler` as the I/O handler for upgrade protocol `protocol`. + /// + /// This is equivalent to [`Response::add_upgrade()`]. + /// + /// **NOTE**: Responses registering I/O handlers for upgraded protocols + /// **should not** set the response status to `101 Switching Protocols`, nor set the + /// `Connection` or `Upgrade` headers. Rocket automatically sets these + /// headers as needed. See [`Response`#upgrading] for details. + /// + /// # Example + /// + /// ```rust + /// use rocket::Response; + /// use rocket::data::{IoHandler, IoStream}; + /// use rocket::tokio::io; + /// + /// struct EchoHandler; + /// + /// #[rocket::async_trait] + /// impl IoHandler for EchoHandler { + /// async fn io(&mut self, io: IoStream) -> io::Result<()> { + /// let (mut reader, mut writer) = io::split(io); + /// io::copy(&mut reader, &mut writer).await?; + /// Ok(()) + /// } + /// } + /// + /// let response = Response::build() + /// .upgrade("raw-echo", EchoHandler) + /// .streamed_body(std::io::Cursor::new("We didn't upgrade!")) + /// .finalize(); + /// ``` #[inline(always)] - pub fn upgrade(&mut self, upgrade: Option + Send>>) -> &mut Builder<'r> { - self.response.set_upgrade(upgrade); + pub fn upgrade(&mut self, protocol: P, handler: H) -> &mut Builder<'r> + where P: Into>, H: IoHandler + 'r + { + self.response.add_upgrade(protocol.into(), handler); self } @@ -415,13 +450,42 @@ impl<'r> Builder<'r> { /// A response, as returned by types implementing /// [`Responder`](crate::response::Responder). /// -/// See [`Builder`] for docs on how a `Response` is typically created. +/// See [`Builder`] for docs on how a `Response` is typically created and the +/// [module docs](crate::response) for notes on composing responses +/// +/// ## Upgrading +/// +/// A response may optionally register [`IoHandler`]s for upgraded requests via +/// [`Response::add_upgrade()`] or the corresponding builder method +/// [`Builder::upgrade()`]. If the incoming request 1) requests an upgrade via a +/// `Connection: Upgrade` header _and_ 2) includes a protocol in its `Upgrade` +/// header that is registered by the returned `Response`, the connection will be +/// upgraded. An upgrade response is sent to the client, and the registered +/// `IoHandler` for the client's preferred protocol is invoked with an +/// [`IoStream`](crate::data::IoStream) representing a raw byte stream to the +/// client. Note that protocol names are treated case-insensitively during +/// matching. +/// +/// If a connection is upgraded, Rocket automatically set the following in the +/// upgrade response: +/// * The response status to `101 Switching Protocols`. +/// * The `Connection: Upgrade` header. +/// * The `Upgrade` header's value to the selected protocol. +/// +/// As such, a response **should never** set a `101` status nor the `Connection` +/// or `Upgrade` headers: Rocket handles this automatically. Instead, it should +/// set a status and headers to use in case the connection is not upgraded, +/// either due to an error or because the client did not request an upgrade. +/// +/// If a connection _is not_ upgraded due to an error, even though there was a +/// matching, registered protocol, the `IoHandler` is not invoked, and the +/// original response is sent to the client without alteration. #[derive(Default)] pub struct Response<'r> { status: Option, headers: HeaderMap<'r>, body: Body<'r>, - upgrade: Option + Send>>, + upgrade: HashMap, Box>, } impl<'r> Response<'r> { @@ -730,6 +794,47 @@ impl<'r> Response<'r> { &self.body } + pub(crate) fn take_upgrade>( + &mut self, + mut protocols: I + ) -> Option<(Uncased<'r>, Box)> { + protocols.find_map(|p| self.upgrade.remove_entry(p.as_uncased())) + } + + /// Returns the [`IoHandler`] for the protocol `proto`. + /// + /// Returns `Some` if such a handler was registered via + /// [`Response::add_upgrade()`] or the corresponding builder method + /// [`upgrade()`](Builder::upgrade()). Otherwise returns `None`. + /// + /// ```rust + /// use rocket::Response; + /// use rocket::data::{IoHandler, IoStream}; + /// use rocket::tokio::io; + /// + /// struct EchoHandler; + /// + /// #[rocket::async_trait] + /// impl IoHandler for EchoHandler { + /// async fn io(&mut self, io: IoStream) -> io::Result<()> { + /// let (mut reader, mut writer) = io::split(io); + /// io::copy(&mut reader, &mut writer).await?; + /// Ok(()) + /// } + /// } + /// + /// # rocket::async_test(async { + /// let mut response = Response::new(); + /// assert!(response.upgrade("raw-echo").is_none()); + /// + /// response.add_upgrade("raw-echo", EchoHandler); + /// assert!(response.upgrade("raw-echo").is_some()); + /// # }) + /// ``` + pub fn upgrade(&mut self, proto: &str) -> Option<&mut (dyn IoHandler + 'r)> { + self.upgrade.get_mut(proto.as_uncased()).map(|h| &mut **h) + } + /// Returns a mutable borrow of the body of `self`, if there is one. A /// mutable borrow allows for reading the body. /// @@ -816,25 +921,51 @@ impl<'r> Response<'r> { self.body = Body::with_unsized(body); } - /// Returns a instance of the `Upgrade`-trait when the `Response` is upgradeable - #[inline(always)] - pub fn upgrade(&self) -> Option<&Box + Send>> { - self.upgrade.as_ref() - } - - /// Takes the upgrade out of the response, leaving a [`None`] in it's place. - /// With this, the caller takes ownership about the `Upgrade`-trait. - #[inline(always)] - pub fn take_upgrade(&mut self) -> Option + Send>> { - self.upgrade.take() - } - - /// Sets the upgrade contained in this `Response` + /// Registers `handler` as the I/O handler for upgrade protocol `protocol`. /// - /// While the response also need to have status 101 SwitchingProtocols in order to be a valid upgrade, - /// this method doesn't set this, and it's expected that the caller sets this. - pub fn set_upgrade(&mut self, upgrade: Option + Send>>) { - self.upgrade = upgrade; + /// Responses registering I/O handlers for upgraded protocols **should not** + /// set the response status to `101`, nor set the `Connection` or `Upgrade` + /// headers. Rocket automatically sets these headers as needed. See + /// [`Response`#upgrading] for details. + /// + /// If a handler was previously registered for `protocol`, this `handler` + /// replaces it. If the connection is upgraded to `protocol`, the last + /// `handler` registered for the protocol is used to handle the connection. + /// See [`IoHandler`] for details on implementing an I/O handler. For + /// details on connection upgrading, see [`Response`#upgrading]. + /// + /// [`Response`#upgrading]: Response#upgrading + /// + /// # Example + /// + /// ```rust + /// use rocket::Response; + /// use rocket::data::{IoHandler, IoStream}; + /// use rocket::tokio::io; + /// + /// struct EchoHandler; + /// + /// #[rocket::async_trait] + /// impl IoHandler for EchoHandler { + /// async fn io(&mut self, io: IoStream) -> io::Result<()> { + /// let (mut reader, mut writer) = io::split(io); + /// io::copy(&mut reader, &mut writer).await?; + /// Ok(()) + /// } + /// } + /// + /// # rocket::async_test(async { + /// let mut response = Response::new(); + /// assert!(response.upgrade("raw-echo").is_none()); + /// + /// response.add_upgrade("raw-echo", EchoHandler); + /// assert!(response.upgrade("raw-echo").is_some()); + /// # }) + /// ``` + pub fn add_upgrade(&mut self, protocol: N, handler: H) + where N: Into>, H: IoHandler + 'r + { + self.upgrade.insert(protocol.into(), Box::new(handler)); } /// Sets the body's maximum chunk size to `size` bytes. diff --git a/core/lib/src/server.rs b/core/lib/src/server.rs index d0414c0f..9a473a2c 100644 --- a/core/lib/src/server.rs +++ b/core/lib/src/server.rs @@ -14,8 +14,9 @@ use crate::outcome::Outcome; use crate::error::{Error, ErrorKind}; use crate::ext::{AsyncReadExt, CancellableListener, CancellableIo}; use crate::request::ConnectionMeta; +use crate::data::IoHandler; -use crate::http::{hyper, Method, Status, Header}; +use crate::http::{hyper, uncased, Method, Status, Header}; use crate::http::private::{TcpListener, Listener, Connection, Incoming}; // A token returned to force the execution of one method before another. @@ -71,9 +72,10 @@ async fn hyper_service_fn( // sends the response metadata (and a body channel) prior. let (tx, rx) = oneshot::channel(); + debug!("Received request: {:#?}", hyp_req); tokio::spawn(async move { - // Upgrade before do any other; we handle errors below - let hyp_upgraded = hyper::upgrade::on(&mut hyp_req); + // We move the request next, so get the upgrade future now. + let pending_upgrade = hyper::upgrade::on(&mut hyp_req); // Convert a Hyper request into a Rocket request. let (h_parts, mut h_body) = hyp_req.into_parts(); @@ -83,36 +85,9 @@ async fn hyper_service_fn( let mut data = Data::from(&mut h_body); let token = rocket.preprocess_request(&mut req, &mut data).await; let mut response = rocket.dispatch(token, &req, data).await; - - if response.status() == Status::SwitchingProtocols { - let may_upgrade = response.take_upgrade(); - match may_upgrade { - Some(upgrade) => { - - // send the finishing response; needed so that hyper can upgrade the request - rocket.send_response(response, tx).await; - - match hyp_upgraded.await { - Ok(hyp_upgraded) => { - // let the upgrade take the upgraded hyper request - let fu = upgrade.start(hyp_upgraded); - fu.await; - } - Err(e) => { - error_!("Failed to upgrade request: {e}"); - // NOTE: we *should* send a response here but since we send one earlier AND upgraded the request, - // this cannot be done easily at this point... - // let response = rocket.handle_error(Status::InternalServerError, &req).await; - // rocket.send_response(response, tx).await; - } - } - } - None => { - error_!("Status is 101 switching protocols, but response dosn't hold a upgrade"); - let response = rocket.handle_error(Status::InternalServerError, &req).await; - rocket.send_response(response, tx).await; - } - } + let upgrade = response.take_upgrade(req.headers().get("upgrade")); + if let Some((proto, handler)) = upgrade { + rocket.handle_upgrade(response, proto, handler, pending_upgrade, tx).await; } else { rocket.send_response(response, tx).await; } @@ -179,6 +154,7 @@ impl Rocket { let hyp_response = hyp_res.body(hyp_body) .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?; + debug!("sending response: {:#?}", hyp_response); tx.send(hyp_response).map_err(|_| { let msg = "client disconnect before response started"; io::Error::new(io::ErrorKind::BrokenPipe, msg) @@ -194,6 +170,34 @@ impl Rocket { Ok(()) } + async fn handle_upgrade<'r>( + &self, + mut response: Response<'r>, + protocol: uncased::Uncased<'r>, + mut io_handler: Box, + pending_upgrade: hyper::upgrade::OnUpgrade, + tx: oneshot::Sender>, + ) { + info_!("Upgrading connection to {}.", Paint::white(&protocol)); + response.set_status(Status::SwitchingProtocols); + response.set_raw_header("Connection", "Upgrade"); + response.set_raw_header("Upgrade", protocol.into_cow()); + self.send_response(response, tx).await; + + match pending_upgrade.await { + Ok(io_stream) => { + info_!("Upgrade successful."); + if let Err(e) = io_handler.io(io_stream.into()).await { + error!("Upgraded I/O handler failed: {}", e); + } + }, + Err(e) => { + warn!("Response indicated upgrade, but upgrade failed."); + warn_!("Upgrade error: {}", e); + } + } + } + /// Preprocess the request for Rocket things. Currently, this means: /// /// * Rewriting the method in the request if _method form field exists. diff --git a/core/lib/src/upgrade/mod.rs b/core/lib/src/upgrade/mod.rs deleted file mode 100644 index 08627f15..00000000 --- a/core/lib/src/upgrade/mod.rs +++ /dev/null @@ -1,14 +0,0 @@ -//! Upgrade wrapper to deal with hyper::upgarde::Upgraded - -use crate::http::hyper; - -/// Trait to determine if any given response in rocket is upgradeable. -/// -/// When a response has the http code 101 SwitchingProtocols, and the response implements the Upgrade trait, -/// then rocket aquires the hyper::upgarde::Upgraded struct and calls the start() method of the trait with the hyper upgrade -/// and awaits the result. -#[crate::async_trait] -pub trait Upgrade<'a> { - /// Called with the hyper::upgarde::Upgraded struct when a rocket response should be upgraded - async fn start(&self, upgraded: hyper::upgrade::Upgraded); -}