Finalize support for external connection upgrades.

This commit is contained in:
Sergio Benitez 2023-03-29 16:59:57 -07:00
parent 19e7e82fd6
commit d97c83d7e0
6 changed files with 344 additions and 73 deletions

View File

@ -0,0 +1,150 @@
use std::io;
use std::task::{Context, Poll};
use std::pin::Pin;
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use crate::http::hyper::upgrade::Upgraded;
/// A bidirectional, raw stream to the client.
///
/// An instance of `IoStream` is passed to an [`IoHandler`] in response to a
/// successful upgrade request initiated by responders via
/// [`Response::add_upgrade()`] or the equivalent builder method
/// [`Builder::upgrade()`]. For details on upgrade connections, see
/// [`Response`#upgrading].
///
/// An `IoStream` is guaranteed to be [`AsyncRead`], [`AsyncWrite`], and
/// `Unpin`. Bytes written to the stream are sent directly to the client. Bytes
/// read from the stream are those sent directly _by_ the client. See
/// [`IoHandler`] for one example of how values of this type are used.
///
/// [`Response::add_upgrade()`]: crate::Response::add_upgrade()
/// [`Builder::upgrade()`]: crate::response::Builder::upgrade()
/// [`Response`#upgrading]: crate::response::Response#upgrading
pub struct IoStream {
kind: IoStreamKind,
}
/// Just in case we want to add stream kinds in the future.
enum IoStreamKind {
Upgraded(Upgraded)
}
/// An upgraded connection I/O handler.
///
/// An I/O handler performs raw I/O via the passed in [`IoStream`], which is
/// [`AsyncRead`], [`AsyncWrite`], and `Unpin`.
///
/// # Example
///
/// The example below implements an `EchoHandler` that echos the raw bytes back
/// to the client.
///
/// ```rust
/// use rocket::tokio::io;
/// use rocket::data::{IoHandler, IoStream};
///
/// struct EchoHandler;
///
/// #[rocket::async_trait]
/// impl IoHandler for EchoHandler {
/// async fn io(&mut self, io: IoStream) -> io::Result<()> {
/// let (mut reader, mut writer) = io::split(io);
/// io::copy(&mut reader, &mut writer).await?;
/// Ok(())
/// }
/// }
///
/// # use rocket::Response;
/// # rocket::async_test(async {
/// # let mut response = Response::new();
/// # response.add_upgrade("raw-echo", EchoHandler);
/// # assert!(response.upgrade("raw-echo").is_some());
/// # })
/// ```
#[crate::async_trait]
pub trait IoHandler: Send {
/// Performs the raw I/O.
async fn io(&mut self, io: IoStream) -> io::Result<()>;
}
#[doc(hidden)]
impl From<Upgraded> for IoStream {
fn from(io: Upgraded) -> Self {
IoStream { kind: IoStreamKind::Upgraded(io) }
}
}
/// A "trait alias" of sorts so we can use `AsyncRead + AsyncWrite + Unpin` in `dyn`.
pub trait AsyncReadWrite: AsyncRead + AsyncWrite + Unpin { }
/// Implemented for all `AsyncRead + AsyncWrite + Unpin`, of course.
impl<T: AsyncRead + AsyncWrite + Unpin> AsyncReadWrite for T { }
impl IoStream {
/// Returns the internal I/O stream.
fn inner_mut(&mut self) -> Pin<&mut dyn AsyncReadWrite> {
match self.kind {
IoStreamKind::Upgraded(ref mut io) => Pin::new(io),
}
}
/// Returns `true` if the inner I/O stream is write vectored.
fn inner_is_write_vectored(&self) -> bool {
match self.kind {
IoStreamKind::Upgraded(ref io) => io.is_write_vectored(),
}
}
}
impl AsyncRead for IoStream {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
self.get_mut().inner_mut().poll_read(cx, buf)
}
}
impl AsyncWrite for IoStream {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
self.get_mut().inner_mut().poll_write(cx, buf)
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
self.get_mut().inner_mut().poll_flush(cx)
}
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
self.get_mut().inner_mut().poll_shutdown(cx)
}
fn poll_write_vectored(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
bufs: &[io::IoSlice<'_>],
) -> Poll<io::Result<usize>> {
self.get_mut().inner_mut().poll_write_vectored(cx, bufs)
}
fn is_write_vectored(&self) -> bool {
self.inner_is_write_vectored()
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn is_unpin() {
fn check_traits<T: AsyncRead + AsyncWrite + Unpin + Send>() {}
check_traits::<IoStream>();
}
}

View File

@ -6,12 +6,14 @@ mod data;
mod data_stream; mod data_stream;
mod from_data; mod from_data;
mod limits; mod limits;
mod io_stream;
pub use self::data::Data; pub use self::data::Data;
pub use self::data_stream::DataStream; pub use self::data_stream::DataStream;
pub use self::from_data::{FromData, Outcome}; pub use self::from_data::{FromData, Outcome};
pub use self::limits::Limits; pub use self::limits::Limits;
pub use self::capped::{N, Capped}; pub use self::capped::{N, Capped};
pub use self::io_stream::{IoHandler, IoStream};
pub use ubyte::{ByteUnit, ToByteUnit}; pub use ubyte::{ByteUnit, ToByteUnit};
pub(crate) use self::data_stream::StreamReader; pub(crate) use self::data_stream::StreamReader;

View File

@ -123,7 +123,6 @@ pub use time;
#[doc(hidden)] pub mod sentinel; #[doc(hidden)] pub mod sentinel;
pub mod local; pub mod local;
pub mod request; pub mod request;
pub mod upgrade;
pub mod response; pub mod response;
pub mod config; pub mod config;
pub mod form; pub mod form;
@ -176,7 +175,6 @@ mod rocket;
mod router; mod router;
mod phase; mod phase;
#[doc(inline)] pub use crate::upgrade::Upgrade;
#[doc(inline)] pub use crate::response::Response; #[doc(inline)] pub use crate::response::Response;
#[doc(inline)] pub use crate::data::Data; #[doc(inline)] pub use crate::data::Data;
#[doc(inline)] pub use crate::config::Config; #[doc(inline)] pub use crate::config::Config;

View File

@ -1,11 +1,13 @@
use std::{fmt, str}; use std::{fmt, str};
use std::borrow::Cow; use std::borrow::Cow;
use std::collections::HashMap;
use tokio::io::{AsyncRead, AsyncSeek}; use tokio::io::{AsyncRead, AsyncSeek};
use crate::http::{Header, HeaderMap, Status, ContentType, Cookie}; use crate::http::{Header, HeaderMap, Status, ContentType, Cookie};
use crate::http::uncased::{Uncased, AsUncased};
use crate::data::IoHandler;
use crate::response::Body; use crate::response::Body;
use crate::upgrade::Upgrade;
/// Builder for the [`Response`] type. /// Builder for the [`Response`] type.
/// ///
@ -262,10 +264,43 @@ impl<'r> Builder<'r> {
self self
} }
/// Sets the upgrade of the `Response`. /// Registers `handler` as the I/O handler for upgrade protocol `protocol`.
///
/// This is equivalent to [`Response::add_upgrade()`].
///
/// **NOTE**: Responses registering I/O handlers for upgraded protocols
/// **should not** set the response status to `101 Switching Protocols`, nor set the
/// `Connection` or `Upgrade` headers. Rocket automatically sets these
/// headers as needed. See [`Response`#upgrading] for details.
///
/// # Example
///
/// ```rust
/// use rocket::Response;
/// use rocket::data::{IoHandler, IoStream};
/// use rocket::tokio::io;
///
/// struct EchoHandler;
///
/// #[rocket::async_trait]
/// impl IoHandler for EchoHandler {
/// async fn io(&mut self, io: IoStream) -> io::Result<()> {
/// let (mut reader, mut writer) = io::split(io);
/// io::copy(&mut reader, &mut writer).await?;
/// Ok(())
/// }
/// }
///
/// let response = Response::build()
/// .upgrade("raw-echo", EchoHandler)
/// .streamed_body(std::io::Cursor::new("We didn't upgrade!"))
/// .finalize();
/// ```
#[inline(always)] #[inline(always)]
pub fn upgrade(&mut self, upgrade: Option<Box<dyn Upgrade<'static> + Send>>) -> &mut Builder<'r> { pub fn upgrade<P, H>(&mut self, protocol: P, handler: H) -> &mut Builder<'r>
self.response.set_upgrade(upgrade); where P: Into<Uncased<'r>>, H: IoHandler + 'r
{
self.response.add_upgrade(protocol.into(), handler);
self self
} }
@ -415,13 +450,42 @@ impl<'r> Builder<'r> {
/// A response, as returned by types implementing /// A response, as returned by types implementing
/// [`Responder`](crate::response::Responder). /// [`Responder`](crate::response::Responder).
/// ///
/// See [`Builder`] for docs on how a `Response` is typically created. /// See [`Builder`] for docs on how a `Response` is typically created and the
/// [module docs](crate::response) for notes on composing responses
///
/// ## Upgrading
///
/// A response may optionally register [`IoHandler`]s for upgraded requests via
/// [`Response::add_upgrade()`] or the corresponding builder method
/// [`Builder::upgrade()`]. If the incoming request 1) requests an upgrade via a
/// `Connection: Upgrade` header _and_ 2) includes a protocol in its `Upgrade`
/// header that is registered by the returned `Response`, the connection will be
/// upgraded. An upgrade response is sent to the client, and the registered
/// `IoHandler` for the client's preferred protocol is invoked with an
/// [`IoStream`](crate::data::IoStream) representing a raw byte stream to the
/// client. Note that protocol names are treated case-insensitively during
/// matching.
///
/// If a connection is upgraded, Rocket automatically set the following in the
/// upgrade response:
/// * The response status to `101 Switching Protocols`.
/// * The `Connection: Upgrade` header.
/// * The `Upgrade` header's value to the selected protocol.
///
/// As such, a response **should never** set a `101` status nor the `Connection`
/// or `Upgrade` headers: Rocket handles this automatically. Instead, it should
/// set a status and headers to use in case the connection is not upgraded,
/// either due to an error or because the client did not request an upgrade.
///
/// If a connection _is not_ upgraded due to an error, even though there was a
/// matching, registered protocol, the `IoHandler` is not invoked, and the
/// original response is sent to the client without alteration.
#[derive(Default)] #[derive(Default)]
pub struct Response<'r> { pub struct Response<'r> {
status: Option<Status>, status: Option<Status>,
headers: HeaderMap<'r>, headers: HeaderMap<'r>,
body: Body<'r>, body: Body<'r>,
upgrade: Option<Box<dyn Upgrade<'static> + Send>>, upgrade: HashMap<Uncased<'r>, Box<dyn IoHandler + 'r>>,
} }
impl<'r> Response<'r> { impl<'r> Response<'r> {
@ -730,6 +794,47 @@ impl<'r> Response<'r> {
&self.body &self.body
} }
pub(crate) fn take_upgrade<I: Iterator<Item = &'r str>>(
&mut self,
mut protocols: I
) -> Option<(Uncased<'r>, Box<dyn IoHandler + 'r>)> {
protocols.find_map(|p| self.upgrade.remove_entry(p.as_uncased()))
}
/// Returns the [`IoHandler`] for the protocol `proto`.
///
/// Returns `Some` if such a handler was registered via
/// [`Response::add_upgrade()`] or the corresponding builder method
/// [`upgrade()`](Builder::upgrade()). Otherwise returns `None`.
///
/// ```rust
/// use rocket::Response;
/// use rocket::data::{IoHandler, IoStream};
/// use rocket::tokio::io;
///
/// struct EchoHandler;
///
/// #[rocket::async_trait]
/// impl IoHandler for EchoHandler {
/// async fn io(&mut self, io: IoStream) -> io::Result<()> {
/// let (mut reader, mut writer) = io::split(io);
/// io::copy(&mut reader, &mut writer).await?;
/// Ok(())
/// }
/// }
///
/// # rocket::async_test(async {
/// let mut response = Response::new();
/// assert!(response.upgrade("raw-echo").is_none());
///
/// response.add_upgrade("raw-echo", EchoHandler);
/// assert!(response.upgrade("raw-echo").is_some());
/// # })
/// ```
pub fn upgrade(&mut self, proto: &str) -> Option<&mut (dyn IoHandler + 'r)> {
self.upgrade.get_mut(proto.as_uncased()).map(|h| &mut **h)
}
/// Returns a mutable borrow of the body of `self`, if there is one. A /// Returns a mutable borrow of the body of `self`, if there is one. A
/// mutable borrow allows for reading the body. /// mutable borrow allows for reading the body.
/// ///
@ -816,25 +921,51 @@ impl<'r> Response<'r> {
self.body = Body::with_unsized(body); self.body = Body::with_unsized(body);
} }
/// Returns a instance of the `Upgrade`-trait when the `Response` is upgradeable /// Registers `handler` as the I/O handler for upgrade protocol `protocol`.
#[inline(always)]
pub fn upgrade(&self) -> Option<&Box<dyn Upgrade<'static> + Send>> {
self.upgrade.as_ref()
}
/// Takes the upgrade out of the response, leaving a [`None`] in it's place.
/// With this, the caller takes ownership about the `Upgrade`-trait.
#[inline(always)]
pub fn take_upgrade(&mut self) -> Option<Box<dyn Upgrade<'static> + Send>> {
self.upgrade.take()
}
/// Sets the upgrade contained in this `Response`
/// ///
/// While the response also need to have status 101 SwitchingProtocols in order to be a valid upgrade, /// Responses registering I/O handlers for upgraded protocols **should not**
/// this method doesn't set this, and it's expected that the caller sets this. /// set the response status to `101`, nor set the `Connection` or `Upgrade`
pub fn set_upgrade(&mut self, upgrade: Option<Box<dyn Upgrade<'static> + Send>>) { /// headers. Rocket automatically sets these headers as needed. See
self.upgrade = upgrade; /// [`Response`#upgrading] for details.
///
/// If a handler was previously registered for `protocol`, this `handler`
/// replaces it. If the connection is upgraded to `protocol`, the last
/// `handler` registered for the protocol is used to handle the connection.
/// See [`IoHandler`] for details on implementing an I/O handler. For
/// details on connection upgrading, see [`Response`#upgrading].
///
/// [`Response`#upgrading]: Response#upgrading
///
/// # Example
///
/// ```rust
/// use rocket::Response;
/// use rocket::data::{IoHandler, IoStream};
/// use rocket::tokio::io;
///
/// struct EchoHandler;
///
/// #[rocket::async_trait]
/// impl IoHandler for EchoHandler {
/// async fn io(&mut self, io: IoStream) -> io::Result<()> {
/// let (mut reader, mut writer) = io::split(io);
/// io::copy(&mut reader, &mut writer).await?;
/// Ok(())
/// }
/// }
///
/// # rocket::async_test(async {
/// let mut response = Response::new();
/// assert!(response.upgrade("raw-echo").is_none());
///
/// response.add_upgrade("raw-echo", EchoHandler);
/// assert!(response.upgrade("raw-echo").is_some());
/// # })
/// ```
pub fn add_upgrade<N, H>(&mut self, protocol: N, handler: H)
where N: Into<Uncased<'r>>, H: IoHandler + 'r
{
self.upgrade.insert(protocol.into(), Box::new(handler));
} }
/// Sets the body's maximum chunk size to `size` bytes. /// Sets the body's maximum chunk size to `size` bytes.

View File

@ -14,8 +14,9 @@ use crate::outcome::Outcome;
use crate::error::{Error, ErrorKind}; use crate::error::{Error, ErrorKind};
use crate::ext::{AsyncReadExt, CancellableListener, CancellableIo}; use crate::ext::{AsyncReadExt, CancellableListener, CancellableIo};
use crate::request::ConnectionMeta; use crate::request::ConnectionMeta;
use crate::data::IoHandler;
use crate::http::{hyper, Method, Status, Header}; use crate::http::{hyper, uncased, Method, Status, Header};
use crate::http::private::{TcpListener, Listener, Connection, Incoming}; use crate::http::private::{TcpListener, Listener, Connection, Incoming};
// A token returned to force the execution of one method before another. // A token returned to force the execution of one method before another.
@ -71,9 +72,10 @@ async fn hyper_service_fn(
// sends the response metadata (and a body channel) prior. // sends the response metadata (and a body channel) prior.
let (tx, rx) = oneshot::channel(); let (tx, rx) = oneshot::channel();
debug!("Received request: {:#?}", hyp_req);
tokio::spawn(async move { tokio::spawn(async move {
// Upgrade before do any other; we handle errors below // We move the request next, so get the upgrade future now.
let hyp_upgraded = hyper::upgrade::on(&mut hyp_req); let pending_upgrade = hyper::upgrade::on(&mut hyp_req);
// Convert a Hyper request into a Rocket request. // Convert a Hyper request into a Rocket request.
let (h_parts, mut h_body) = hyp_req.into_parts(); let (h_parts, mut h_body) = hyp_req.into_parts();
@ -83,36 +85,9 @@ async fn hyper_service_fn(
let mut data = Data::from(&mut h_body); let mut data = Data::from(&mut h_body);
let token = rocket.preprocess_request(&mut req, &mut data).await; let token = rocket.preprocess_request(&mut req, &mut data).await;
let mut response = rocket.dispatch(token, &req, data).await; let mut response = rocket.dispatch(token, &req, data).await;
let upgrade = response.take_upgrade(req.headers().get("upgrade"));
if response.status() == Status::SwitchingProtocols { if let Some((proto, handler)) = upgrade {
let may_upgrade = response.take_upgrade(); rocket.handle_upgrade(response, proto, handler, pending_upgrade, tx).await;
match may_upgrade {
Some(upgrade) => {
// send the finishing response; needed so that hyper can upgrade the request
rocket.send_response(response, tx).await;
match hyp_upgraded.await {
Ok(hyp_upgraded) => {
// let the upgrade take the upgraded hyper request
let fu = upgrade.start(hyp_upgraded);
fu.await;
}
Err(e) => {
error_!("Failed to upgrade request: {e}");
// NOTE: we *should* send a response here but since we send one earlier AND upgraded the request,
// this cannot be done easily at this point...
// let response = rocket.handle_error(Status::InternalServerError, &req).await;
// rocket.send_response(response, tx).await;
}
}
}
None => {
error_!("Status is 101 switching protocols, but response dosn't hold a upgrade");
let response = rocket.handle_error(Status::InternalServerError, &req).await;
rocket.send_response(response, tx).await;
}
}
} else { } else {
rocket.send_response(response, tx).await; rocket.send_response(response, tx).await;
} }
@ -179,6 +154,7 @@ impl Rocket<Orbit> {
let hyp_response = hyp_res.body(hyp_body) let hyp_response = hyp_res.body(hyp_body)
.map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?; .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
debug!("sending response: {:#?}", hyp_response);
tx.send(hyp_response).map_err(|_| { tx.send(hyp_response).map_err(|_| {
let msg = "client disconnect before response started"; let msg = "client disconnect before response started";
io::Error::new(io::ErrorKind::BrokenPipe, msg) io::Error::new(io::ErrorKind::BrokenPipe, msg)
@ -194,6 +170,34 @@ impl Rocket<Orbit> {
Ok(()) Ok(())
} }
async fn handle_upgrade<'r>(
&self,
mut response: Response<'r>,
protocol: uncased::Uncased<'r>,
mut io_handler: Box<dyn IoHandler + 'r>,
pending_upgrade: hyper::upgrade::OnUpgrade,
tx: oneshot::Sender<hyper::Response<hyper::Body>>,
) {
info_!("Upgrading connection to {}.", Paint::white(&protocol));
response.set_status(Status::SwitchingProtocols);
response.set_raw_header("Connection", "Upgrade");
response.set_raw_header("Upgrade", protocol.into_cow());
self.send_response(response, tx).await;
match pending_upgrade.await {
Ok(io_stream) => {
info_!("Upgrade successful.");
if let Err(e) = io_handler.io(io_stream.into()).await {
error!("Upgraded I/O handler failed: {}", e);
}
},
Err(e) => {
warn!("Response indicated upgrade, but upgrade failed.");
warn_!("Upgrade error: {}", e);
}
}
}
/// Preprocess the request for Rocket things. Currently, this means: /// Preprocess the request for Rocket things. Currently, this means:
/// ///
/// * Rewriting the method in the request if _method form field exists. /// * Rewriting the method in the request if _method form field exists.

View File

@ -1,14 +0,0 @@
//! Upgrade wrapper to deal with hyper::upgarde::Upgraded
use crate::http::hyper;
/// Trait to determine if any given response in rocket is upgradeable.
///
/// When a response has the http code 101 SwitchingProtocols, and the response implements the Upgrade trait,
/// then rocket aquires the hyper::upgarde::Upgraded struct and calls the start() method of the trait with the hyper upgrade
/// and awaits the result.
#[crate::async_trait]
pub trait Upgrade<'a> {
/// Called with the hyper::upgarde::Upgraded struct when a rocket response should be upgraded
async fn start(&self, upgraded: hyper::upgrade::Upgraded);
}