From 19e7e82fd64ae98c3c53a87f941a8ba98e97e0d8 Mon Sep 17 00:00:00 2001 From: Mai-Lapyst Date: Fri, 10 Mar 2023 22:15:10 +0100 Subject: [PATCH] Initial connection upgrade API implementation. --- core/http/src/hyper.rs | 2 +- core/lib/src/lib.rs | 2 ++ core/lib/src/response/response.rs | 30 ++++++++++++++++++++++ core/lib/src/server.rs | 41 ++++++++++++++++++++++++++++--- core/lib/src/upgrade/mod.rs | 14 +++++++++++ 5 files changed, 85 insertions(+), 4 deletions(-) create mode 100644 core/lib/src/upgrade/mod.rs diff --git a/core/http/src/hyper.rs b/core/http/src/hyper.rs index f117c757..2e98e1f0 100644 --- a/core/http/src/hyper.rs +++ b/core/http/src/hyper.rs @@ -5,7 +5,7 @@ //! while necessary. pub use hyper::{Method, Error, Body, Uri, Version, Request, Response}; -pub use hyper::{body, server, service}; +pub use hyper::{body, server, service, upgrade}; pub use http::{HeaderValue, request, uri}; /// Reexported Hyper HTTP header types. diff --git a/core/lib/src/lib.rs b/core/lib/src/lib.rs index e1b8a260..823a6e0f 100644 --- a/core/lib/src/lib.rs +++ b/core/lib/src/lib.rs @@ -123,6 +123,7 @@ pub use time; #[doc(hidden)] pub mod sentinel; pub mod local; pub mod request; +pub mod upgrade; pub mod response; pub mod config; pub mod form; @@ -175,6 +176,7 @@ mod rocket; mod router; mod phase; +#[doc(inline)] pub use crate::upgrade::Upgrade; #[doc(inline)] pub use crate::response::Response; #[doc(inline)] pub use crate::data::Data; #[doc(inline)] pub use crate::config::Config; diff --git a/core/lib/src/response/response.rs b/core/lib/src/response/response.rs index 31b18fb9..5ef2aeea 100644 --- a/core/lib/src/response/response.rs +++ b/core/lib/src/response/response.rs @@ -5,6 +5,7 @@ use tokio::io::{AsyncRead, AsyncSeek}; use crate::http::{Header, HeaderMap, Status, ContentType, Cookie}; use crate::response::Body; +use crate::upgrade::Upgrade; /// Builder for the [`Response`] type. /// @@ -261,6 +262,13 @@ impl<'r> Builder<'r> { self } + /// Sets the upgrade of the `Response`. + #[inline(always)] + pub fn upgrade(&mut self, upgrade: Option + Send>>) -> &mut Builder<'r> { + self.response.set_upgrade(upgrade); + self + } + /// Sets the max chunk size of a body, if any, to `size`. /// /// See [`Response::set_max_chunk_size()`] for notes. @@ -413,6 +421,7 @@ pub struct Response<'r> { status: Option, headers: HeaderMap<'r>, body: Body<'r>, + upgrade: Option + Send>>, } impl<'r> Response<'r> { @@ -807,6 +816,27 @@ impl<'r> Response<'r> { self.body = Body::with_unsized(body); } + /// Returns a instance of the `Upgrade`-trait when the `Response` is upgradeable + #[inline(always)] + pub fn upgrade(&self) -> Option<&Box + Send>> { + self.upgrade.as_ref() + } + + /// Takes the upgrade out of the response, leaving a [`None`] in it's place. + /// With this, the caller takes ownership about the `Upgrade`-trait. + #[inline(always)] + pub fn take_upgrade(&mut self) -> Option + Send>> { + self.upgrade.take() + } + + /// Sets the upgrade contained in this `Response` + /// + /// While the response also need to have status 101 SwitchingProtocols in order to be a valid upgrade, + /// this method doesn't set this, and it's expected that the caller sets this. + pub fn set_upgrade(&mut self, upgrade: Option + Send>>) { + self.upgrade = upgrade; + } + /// Sets the body's maximum chunk size to `size` bytes. /// /// The default max chunk size is [`Body::DEFAULT_MAX_CHUNK`]. The max chunk diff --git a/core/lib/src/server.rs b/core/lib/src/server.rs index b84a6c80..d0414c0f 100644 --- a/core/lib/src/server.rs +++ b/core/lib/src/server.rs @@ -64,7 +64,7 @@ async fn handle(name: Option<&str>, run: F) -> Option async fn hyper_service_fn( rocket: Arc>, conn: ConnectionMeta, - hyp_req: hyper::Request, + mut hyp_req: hyper::Request, ) -> Result, io::Error> { // This future must return a hyper::Response, but the response body might // borrow from the request. Instead, write the body in another future that @@ -72,6 +72,9 @@ async fn hyper_service_fn( let (tx, rx) = oneshot::channel(); tokio::spawn(async move { + // Upgrade before do any other; we handle errors below + let hyp_upgraded = hyper::upgrade::on(&mut hyp_req); + // Convert a Hyper request into a Rocket request. let (h_parts, mut h_body) = hyp_req.into_parts(); match Request::from_hyp(&rocket, &h_parts, Some(conn)) { @@ -79,8 +82,40 @@ async fn hyper_service_fn( // Convert into Rocket `Data`, dispatch request, write response. let mut data = Data::from(&mut h_body); let token = rocket.preprocess_request(&mut req, &mut data).await; - let response = rocket.dispatch(token, &mut req, data).await; - rocket.send_response(response, tx).await; + let mut response = rocket.dispatch(token, &req, data).await; + + if response.status() == Status::SwitchingProtocols { + let may_upgrade = response.take_upgrade(); + match may_upgrade { + Some(upgrade) => { + + // send the finishing response; needed so that hyper can upgrade the request + rocket.send_response(response, tx).await; + + match hyp_upgraded.await { + Ok(hyp_upgraded) => { + // let the upgrade take the upgraded hyper request + let fu = upgrade.start(hyp_upgraded); + fu.await; + } + Err(e) => { + error_!("Failed to upgrade request: {e}"); + // NOTE: we *should* send a response here but since we send one earlier AND upgraded the request, + // this cannot be done easily at this point... + // let response = rocket.handle_error(Status::InternalServerError, &req).await; + // rocket.send_response(response, tx).await; + } + } + } + None => { + error_!("Status is 101 switching protocols, but response dosn't hold a upgrade"); + let response = rocket.handle_error(Status::InternalServerError, &req).await; + rocket.send_response(response, tx).await; + } + } + } else { + rocket.send_response(response, tx).await; + } }, Err(e) => { warn!("Bad incoming HTTP request."); diff --git a/core/lib/src/upgrade/mod.rs b/core/lib/src/upgrade/mod.rs new file mode 100644 index 00000000..08627f15 --- /dev/null +++ b/core/lib/src/upgrade/mod.rs @@ -0,0 +1,14 @@ +//! Upgrade wrapper to deal with hyper::upgarde::Upgraded + +use crate::http::hyper; + +/// Trait to determine if any given response in rocket is upgradeable. +/// +/// When a response has the http code 101 SwitchingProtocols, and the response implements the Upgrade trait, +/// then rocket aquires the hyper::upgarde::Upgraded struct and calls the start() method of the trait with the hyper upgrade +/// and awaits the result. +#[crate::async_trait] +pub trait Upgrade<'a> { + /// Called with the hyper::upgarde::Upgraded struct when a rocket response should be upgraded + async fn start(&self, upgraded: hyper::upgrade::Upgraded); +}