From bd482081ad58b68fdcf2398d87ddda4d8e850964 Mon Sep 17 00:00:00 2001 From: Sergio Benitez Date: Wed, 29 Mar 2023 17:01:33 -0700 Subject: [PATCH] Add 'upgrade' example with WebSocket support. This is an initial example that showcases using the new connection upgrade API to implement WebSocket support outside of Rocket's core. --- examples/Cargo.toml | 1 + examples/README.md | 3 ++ examples/upgrade/Cargo.toml | 10 ++++++ examples/upgrade/index.html | 69 ++++++++++++++++++++++++++++++++++++ examples/upgrade/src/main.rs | 28 +++++++++++++++ examples/upgrade/src/ws.rs | 67 ++++++++++++++++++++++++++++++++++ 6 files changed, 178 insertions(+) create mode 100644 examples/upgrade/Cargo.toml create mode 100644 examples/upgrade/index.html create mode 100644 examples/upgrade/src/main.rs create mode 100644 examples/upgrade/src/ws.rs diff --git a/examples/Cargo.toml b/examples/Cargo.toml index 34d66552..c4c7d4ce 100644 --- a/examples/Cargo.toml +++ b/examples/Cargo.toml @@ -15,6 +15,7 @@ members = [ "templating", "testing", "tls", + "upgrade", "pastebin", "todo", diff --git a/examples/README.md b/examples/README.md index 9a3a0997..f265c780 100644 --- a/examples/README.md +++ b/examples/README.md @@ -87,3 +87,6 @@ This directory contains projects showcasing Rocket's features. * **[`tls`](./tls)** - Illustrates configuring TLS with a variety of key pair kinds. + + * **[`upgrade`](./upgrade)** - Uses the connection upgrade API to implement + WebSocket support using tungstenite. diff --git a/examples/upgrade/Cargo.toml b/examples/upgrade/Cargo.toml new file mode 100644 index 00000000..0b70adf0 --- /dev/null +++ b/examples/upgrade/Cargo.toml @@ -0,0 +1,10 @@ +[package] +name = "upgrade" +version = "0.0.0" +workspace = "../" +edition = "2021" +publish = false + +[dependencies] +rocket = { path = "../../core/lib" } +tokio-tungstenite = "0.18" diff --git a/examples/upgrade/index.html b/examples/upgrade/index.html new file mode 100644 index 00000000..0293461c --- /dev/null +++ b/examples/upgrade/index.html @@ -0,0 +1,69 @@ + + + + WebSocket client test + + + +

WebSocket Client Test

+
+ + + diff --git a/examples/upgrade/src/main.rs b/examples/upgrade/src/main.rs new file mode 100644 index 00000000..fd358945 --- /dev/null +++ b/examples/upgrade/src/main.rs @@ -0,0 +1,28 @@ +#[macro_use] extern crate rocket; + +use rocket::futures::{SinkExt, StreamExt}; +use rocket::response::content::RawHtml; + +mod ws; + +#[get("/")] +fn index() -> RawHtml<&'static str> { + RawHtml(include_str!("../index.html")) +} + +#[get("/echo")] +fn echo(ws: ws::WebSocket) -> ws::Channel { + ws.channel(|mut stream| Box::pin(async move { + while let Some(message) = stream.next().await { + let _ = stream.send(message?).await; + } + + Ok(()) + })) +} + +#[launch] +fn rocket() -> _ { + rocket::build() + .mount("/", routes![index, echo]) +} diff --git a/examples/upgrade/src/ws.rs b/examples/upgrade/src/ws.rs new file mode 100644 index 00000000..c5342089 --- /dev/null +++ b/examples/upgrade/src/ws.rs @@ -0,0 +1,67 @@ +use std::io; + +use rocket::{Request, response}; +use rocket::data::{IoHandler, IoStream}; +use rocket::request::{FromRequest, Outcome}; +use rocket::response::{Responder, Response}; +use rocket::futures::future::BoxFuture; + +use tokio_tungstenite::WebSocketStream; +use tokio_tungstenite::tungstenite::handshake::derive_accept_key; +use tokio_tungstenite::tungstenite::protocol::Role; +use tokio_tungstenite::tungstenite::error::{Result, Error}; + +pub struct WebSocket(String); + +#[rocket::async_trait] +impl<'r> FromRequest<'r> for WebSocket { + type Error = std::convert::Infallible; + + async fn from_request(req: &'r Request<'_>) -> Outcome { + use rocket::http::uncased::eq; + + let headers = req.headers(); + let is_upgrade = headers.get_one("Connection").map_or(false, |c| eq(c, "upgrade")); + let is_ws = headers.get("Upgrade").any(|p| eq(p, "websocket")); + let is_ws_13 = headers.get_one("Sec-WebSocket-Version").map_or(false, |v| v == "13"); + let key = headers.get_one("Sec-WebSocket-Key").map(|k| derive_accept_key(k.as_bytes())); + match key { + Some(key) if is_upgrade && is_ws && is_ws_13 => Outcome::Success(WebSocket(key)), + Some(_) | None => Outcome::Forward(()) + } + } +} + +pub struct Channel { + ws: WebSocket, + handler: Box) -> BoxFuture<'static, Result<()>> + Send>, +} + +impl WebSocket { + pub fn channel(self, handler: F) -> Channel + where F: FnMut(WebSocketStream) -> BoxFuture<'static, Result<()>> + { + Channel { ws: self, handler: Box::new(handler), } + } +} + +impl<'r> Responder<'r, 'static> for Channel { + fn respond_to(self, _: &'r Request<'_>) -> response::Result<'static> { + Response::build() + .raw_header("Sec-Websocket-Version", "13") + .raw_header("Sec-WebSocket-Accept", self.ws.0.clone()) + .upgrade("websocket", self) + .ok() + } +} + +#[rocket::async_trait] +impl IoHandler for Channel { + async fn io(&mut self, io: IoStream) -> io::Result<()> { + let stream = WebSocketStream::from_raw_socket(io, Role::Server, None).await; + (self.handler)(stream).await.map_err(|e| match e { + Error::Io(e) => e, + other => io::Error::new(io::ErrorKind::Other, other) + }) + } +}