diff --git a/examples/Cargo.toml b/examples/Cargo.toml
index 34d66552..c4c7d4ce 100644
--- a/examples/Cargo.toml
+++ b/examples/Cargo.toml
@@ -15,6 +15,7 @@ members = [
"templating",
"testing",
"tls",
+ "upgrade",
"pastebin",
"todo",
diff --git a/examples/README.md b/examples/README.md
index 9a3a0997..f265c780 100644
--- a/examples/README.md
+++ b/examples/README.md
@@ -87,3 +87,6 @@ This directory contains projects showcasing Rocket's features.
* **[`tls`](./tls)** - Illustrates configuring TLS with a variety of key pair
kinds.
+
+ * **[`upgrade`](./upgrade)** - Uses the connection upgrade API to implement
+ WebSocket support using tungstenite.
diff --git a/examples/upgrade/Cargo.toml b/examples/upgrade/Cargo.toml
new file mode 100644
index 00000000..0b70adf0
--- /dev/null
+++ b/examples/upgrade/Cargo.toml
@@ -0,0 +1,10 @@
+[package]
+name = "upgrade"
+version = "0.0.0"
+workspace = "../"
+edition = "2021"
+publish = false
+
+[dependencies]
+rocket = { path = "../../core/lib" }
+tokio-tungstenite = "0.18"
diff --git a/examples/upgrade/index.html b/examples/upgrade/index.html
new file mode 100644
index 00000000..0293461c
--- /dev/null
+++ b/examples/upgrade/index.html
@@ -0,0 +1,69 @@
+
+
+
+ WebSocket client test
+
+
+
+ WebSocket Client Test
+
+
+
+
diff --git a/examples/upgrade/src/main.rs b/examples/upgrade/src/main.rs
new file mode 100644
index 00000000..fd358945
--- /dev/null
+++ b/examples/upgrade/src/main.rs
@@ -0,0 +1,28 @@
+#[macro_use] extern crate rocket;
+
+use rocket::futures::{SinkExt, StreamExt};
+use rocket::response::content::RawHtml;
+
+mod ws;
+
+#[get("/")]
+fn index() -> RawHtml<&'static str> {
+ RawHtml(include_str!("../index.html"))
+}
+
+#[get("/echo")]
+fn echo(ws: ws::WebSocket) -> ws::Channel {
+ ws.channel(|mut stream| Box::pin(async move {
+ while let Some(message) = stream.next().await {
+ let _ = stream.send(message?).await;
+ }
+
+ Ok(())
+ }))
+}
+
+#[launch]
+fn rocket() -> _ {
+ rocket::build()
+ .mount("/", routes![index, echo])
+}
diff --git a/examples/upgrade/src/ws.rs b/examples/upgrade/src/ws.rs
new file mode 100644
index 00000000..c5342089
--- /dev/null
+++ b/examples/upgrade/src/ws.rs
@@ -0,0 +1,67 @@
+use std::io;
+
+use rocket::{Request, response};
+use rocket::data::{IoHandler, IoStream};
+use rocket::request::{FromRequest, Outcome};
+use rocket::response::{Responder, Response};
+use rocket::futures::future::BoxFuture;
+
+use tokio_tungstenite::WebSocketStream;
+use tokio_tungstenite::tungstenite::handshake::derive_accept_key;
+use tokio_tungstenite::tungstenite::protocol::Role;
+use tokio_tungstenite::tungstenite::error::{Result, Error};
+
+pub struct WebSocket(String);
+
+#[rocket::async_trait]
+impl<'r> FromRequest<'r> for WebSocket {
+ type Error = std::convert::Infallible;
+
+ async fn from_request(req: &'r Request<'_>) -> Outcome {
+ use rocket::http::uncased::eq;
+
+ let headers = req.headers();
+ let is_upgrade = headers.get_one("Connection").map_or(false, |c| eq(c, "upgrade"));
+ let is_ws = headers.get("Upgrade").any(|p| eq(p, "websocket"));
+ let is_ws_13 = headers.get_one("Sec-WebSocket-Version").map_or(false, |v| v == "13");
+ let key = headers.get_one("Sec-WebSocket-Key").map(|k| derive_accept_key(k.as_bytes()));
+ match key {
+ Some(key) if is_upgrade && is_ws && is_ws_13 => Outcome::Success(WebSocket(key)),
+ Some(_) | None => Outcome::Forward(())
+ }
+ }
+}
+
+pub struct Channel {
+ ws: WebSocket,
+ handler: Box) -> BoxFuture<'static, Result<()>> + Send>,
+}
+
+impl WebSocket {
+ pub fn channel(self, handler: F) -> Channel
+ where F: FnMut(WebSocketStream) -> BoxFuture<'static, Result<()>>
+ {
+ Channel { ws: self, handler: Box::new(handler), }
+ }
+}
+
+impl<'r> Responder<'r, 'static> for Channel {
+ fn respond_to(self, _: &'r Request<'_>) -> response::Result<'static> {
+ Response::build()
+ .raw_header("Sec-Websocket-Version", "13")
+ .raw_header("Sec-WebSocket-Accept", self.ws.0.clone())
+ .upgrade("websocket", self)
+ .ok()
+ }
+}
+
+#[rocket::async_trait]
+impl IoHandler for Channel {
+ async fn io(&mut self, io: IoStream) -> io::Result<()> {
+ let stream = WebSocketStream::from_raw_socket(io, Role::Server, None).await;
+ (self.handler)(stream).await.map_err(|e| match e {
+ Error::Io(e) => e,
+ other => io::Error::new(io::ErrorKind::Other, other)
+ })
+ }
+}