From 7d895eb9f674ac493942cc2a56dea36556aa87ac Mon Sep 17 00:00:00 2001 From: Sergio Benitez Date: Sat, 1 Apr 2023 15:02:24 -0700 Subject: [PATCH] Add initial implementation of 'rocket_ws'. This provides WebSocket support in Rocket's official 'contrib'. --- Cargo.toml | 1 + contrib/ws/Cargo.toml | 28 +++ contrib/ws/README.md | 35 ++++ contrib/ws/src/duplex.rs | 91 +++++++++ contrib/ws/src/lib.rs | 182 ++++++++++++++++++ .../src/ws.rs => contrib/ws/src/websocket.rs | 145 +++++++------- examples/upgrade/Cargo.toml | 2 +- examples/upgrade/src/main.rs | 5 +- scripts/config.sh | 1 + scripts/mk-docs.sh | 2 +- scripts/test.sh | 9 + 11 files changed, 420 insertions(+), 81 deletions(-) create mode 100644 contrib/ws/Cargo.toml create mode 100644 contrib/ws/README.md create mode 100644 contrib/ws/src/duplex.rs create mode 100644 contrib/ws/src/lib.rs rename examples/upgrade/src/ws.rs => contrib/ws/src/websocket.rs (60%) diff --git a/Cargo.toml b/Cargo.toml index 8ec081ef..7260ca1c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,5 +8,6 @@ members = [ "contrib/sync_db_pools/codegen/", "contrib/sync_db_pools/lib/", "contrib/dyn_templates/", + "contrib/ws/", "site/tests", ] diff --git a/contrib/ws/Cargo.toml b/contrib/ws/Cargo.toml new file mode 100644 index 00000000..b4a9481f --- /dev/null +++ b/contrib/ws/Cargo.toml @@ -0,0 +1,28 @@ +[package] +name = "rocket_ws" +version = "0.1.0-rc.3" +authors = ["Sergio Benitez "] +description = "WebSocket support for Rocket." +documentation = "https://api.rocket.rs/v0.5-rc/rocket_ws/" +homepage = "https://rocket.rs" +repository = "https://github.com/SergioBenitez/Rocket/tree/master/contrib/ws" +readme = "README.md" +keywords = ["rocket", "web", "framework", "websocket"] +license = "MIT OR Apache-2.0" +edition = "2021" +rust-version = "1.56" + +[features] +default = ["tungstenite"] +tungstenite = ["tokio-tungstenite"] + +[dependencies] +tokio-tungstenite = { version = "0.18", optional = true } + +[dependencies.rocket] +version = "=0.5.0-rc.3" +path = "../../core/lib" +default-features = false + +[package.metadata.docs.rs] +all-features = true diff --git a/contrib/ws/README.md b/contrib/ws/README.md new file mode 100644 index 00000000..37759707 --- /dev/null +++ b/contrib/ws/README.md @@ -0,0 +1,35 @@ +# `ws` [![ci.svg]][ci] [![crates.io]][crate] [![docs.svg]][crate docs] + +[crates.io]: https://img.shields.io/crates/v/rocket_ws.svg +[crate]: https://crates.io/crates/rocket_ws +[docs.svg]: https://img.shields.io/badge/web-master-red.svg?style=flat&label=docs&colorB=d33847 +[crate docs]: https://api.rocket.rs/v0.5-rc/rocket_ws +[ci.svg]: https://github.com/SergioBenitez/Rocket/workflows/CI/badge.svg +[ci]: https://github.com/SergioBenitez/Rocket/actions + +This crate provides WebSocket support for Rocket via integration with Rocket's +[connection upgrades] API. + +# Usage + + 1. Depend on `rocket_ws`, renamed here to `ws`: + + ```toml + [dependencies] + ws = { package = "rocket_ws", version ="=0.1.0-rc.3" } + ``` + + 2. Use it! + + ```rust + #[get("/echo")] + fn echo_stream(ws: ws::WebSocket) -> ws::Stream!['static] { + ws::stream! { ws => + for await message in ws { + yield message?; + } + } + } + ``` + +See the [crate docs] for full details. diff --git a/contrib/ws/src/duplex.rs b/contrib/ws/src/duplex.rs new file mode 100644 index 00000000..2a57ae90 --- /dev/null +++ b/contrib/ws/src/duplex.rs @@ -0,0 +1,91 @@ +use std::pin::Pin; +use std::task::{Context, Poll}; + +use rocket::data::IoStream; +use rocket::futures::{StreamExt, SinkExt, Sink}; +use rocket::futures::stream::{Stream, FusedStream}; + +use crate::frame::{Message, CloseFrame}; +use crate::result::{Result, Error}; + +/// A readable and writeable WebSocket [`Message`] `async` stream. +/// +/// This struct implements [`Stream`] and [`Sink`], allowing for `async` reading +/// and writing of [`Message`]s. The [`StreamExt`] and [`SinkExt`] traits can be +/// imported to provide additional functionality for streams and sinks: +/// +/// ```rust +/// # use rocket::get; +/// use rocket_ws as ws; +/// +/// use rocket::futures::{SinkExt, StreamExt}; +/// +/// #[get("/echo/manual")] +/// fn echo_manual<'r>(ws: ws::WebSocket) -> ws::Channel<'r> { +/// ws.channel(move |mut stream| Box::pin(async move { +/// while let Some(message) = stream.next().await { +/// let _ = stream.send(message?).await; +/// } +/// +/// Ok(()) +/// })) +/// } +/// ```` +/// +/// [`StreamExt`]: rocket::futures::StreamExt +/// [`SinkExt`]: rocket::futures::SinkExt + +pub struct DuplexStream(tokio_tungstenite::WebSocketStream); + +impl DuplexStream { + pub(crate) async fn new(stream: IoStream, config: crate::Config) -> Self { + use tokio_tungstenite::WebSocketStream; + use crate::tungstenite::protocol::Role; + + let inner = WebSocketStream::from_raw_socket(stream, Role::Server, Some(config)); + DuplexStream(inner.await) + } + + /// Close the stream now. This does not typically need to be called. + pub async fn close(&mut self, msg: Option>) -> Result<()> { + self.0.close(msg).await + } +} + +impl Stream for DuplexStream { + type Item = Result; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.get_mut().0.poll_next_unpin(cx) + } + + fn size_hint(&self) -> (usize, Option) { + self.0.size_hint() + } +} + +impl FusedStream for DuplexStream { + fn is_terminated(&self) -> bool { + self.0.is_terminated() + } +} + +impl Sink for DuplexStream { + type Error = Error; + + fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.get_mut().0.poll_ready_unpin(cx) + } + + fn start_send(self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> { + self.get_mut().0.start_send_unpin(item) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.get_mut().0.poll_flush_unpin(cx) + } + + fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.get_mut().0.poll_close_unpin(cx) + } +} diff --git a/contrib/ws/src/lib.rs b/contrib/ws/src/lib.rs new file mode 100644 index 00000000..7c180fb6 --- /dev/null +++ b/contrib/ws/src/lib.rs @@ -0,0 +1,182 @@ +//! WebSocket support for Rocket. + +#![doc(html_root_url = "https://api.rocket.rs/v0.5-rc/rocket_ws")] +#![doc(html_favicon_url = "https://rocket.rs/images/favicon.ico")] +#![doc(html_logo_url = "https://rocket.rs/images/logo-boxed.png")] + +mod tungstenite { + #[doc(inline)] pub use tokio_tungstenite::tungstenite::*; +} + +mod duplex; +mod websocket; + +pub use self::tungstenite::Message; +pub use self::tungstenite::protocol::WebSocketConfig as Config; +pub use self::websocket::{WebSocket, Channel}; + +/// Structures for constructing raw WebSocket frames. +pub mod frame { + #[doc(hidden)] pub use crate::Message; + pub use crate::tungstenite::protocol::frame::{CloseFrame, Frame}; + pub use crate::tungstenite::protocol::frame::coding::CloseCode; +} + +/// Types representing incoming and/or outgoing `async` [`Message`] streams. +pub mod stream { + pub use crate::duplex::DuplexStream; + pub use crate::websocket::MessageStream; +} + +/// Library [`Error`](crate::result::Error) and +/// [`Result`](crate::result::Result) types. +pub mod result { + pub use crate::tungstenite::error::{Result, Error}; +} + +/// Type and expression macro for `async` WebSocket [`Message`] streams. +/// +/// This macro can be used both where types are expected or +/// where expressions are expected. +/// +/// # Type Position +/// +/// When used in a type position, the macro invoked as `Stream['r]` expands to: +/// +/// - [`MessageStream`]`<'r, impl `[`Stream`]`>> + 'r>` +/// +/// The lifetime need not be specified as `'r`. For instance, `Stream['request]` +/// is valid and expands as expected: +/// +/// - [`MessageStream`]`<'request, impl `[`Stream`]`>> + 'request>` +/// +/// As a convenience, when the macro is invoked as `Stream![]`, the lifetime +/// defaults to `'static`. That is, `Stream![]` is equivalent to +/// `Stream!['static]`. +/// +/// [`MessageStream`]: crate::stream::MessageStream +/// [`Stream`]: rocket::futures::stream::Stream +/// [`Result`]: crate::result::Result +/// [`Message`]: crate::Message +/// +/// # Expression Position +/// +/// When invoked as an expression, the macro behaves similarly to Rocket's +/// [`stream!`](rocket::response::stream::stream) macro. Specifically, it +/// supports `yield` and `for await` syntax. It is invoked as follows: +/// +/// ```rust +/// # use rocket::get; +/// use rocket_ws as ws; +/// +/// #[get("/")] +/// fn echo(ws: ws::WebSocket) -> ws::Stream![] { +/// ws::Stream! { ws => +/// for await message in ws { +/// yield message?; +/// yield "foo".into(); +/// yield vec![1, 2, 3, 4].into(); +/// } +/// } +/// } +/// ``` +/// +/// It enjoins the following type requirements: +/// +/// * The type of `ws` _must_ be [`WebSocket`]. `ws` can be any ident. +/// * The type of yielded expressions (`expr` in `yield expr`) _must_ be [`Message`]. +/// * The `Err` type of expressions short-circuited with `?` _must_ be [`Error`]. +/// +/// [`Error`]: crate::result::Error +/// +/// The macro takes any series of statements and expands them into an expression +/// of type `impl Stream>`, a stream that `yield`s elements of +/// type [`Result`]``. It automatically converts yielded items of type `T` into +/// `Ok(T)`. It supports any Rust statement syntax with the following +/// extensions: +/// +/// * `?` short-circuits stream termination on `Err` +/// +/// The type of the error value must be [`Error`]. +///

+/// +/// * `yield expr` +/// +/// Yields the result of evaluating `expr` to the caller (the stream +/// consumer) wrapped in `Ok`. +/// +/// `expr` must be of type `T`. +///

+/// +/// * `for await x in stream { .. }` +/// +/// `await`s the next element in `stream`, binds it to `x`, and executes the +/// block with the binding. +/// +/// `stream` must implement `Stream`; the type of `x` is `T`. +/// +/// ### Examples +/// +/// Borrow from the request. Send a single message and close: +/// +/// ```rust +/// # use rocket::get; +/// use rocket_ws as ws; +/// +/// #[get("/hello/")] +/// fn ws_hello(ws: ws::WebSocket, user: &str) -> ws::Stream!['_] { +/// ws::Stream! { ws => +/// yield user.into(); +/// } +/// } +/// ``` +/// +/// Borrow from the request with explicit lifetime: +/// +/// ```rust +/// # use rocket::get; +/// use rocket_ws as ws; +/// +/// #[get("/hello/")] +/// fn ws_hello<'r>(ws: ws::WebSocket, user: &'r str) -> ws::Stream!['r] { +/// ws::Stream! { ws => +/// yield user.into(); +/// } +/// } +/// ``` +/// +/// Emit several messages and short-circuit if the client sends a bad message: +/// +/// ```rust +/// # use rocket::get; +/// use rocket_ws as ws; +/// +/// #[get("/")] +/// fn echo(ws: ws::WebSocket) -> ws::Stream![] { +/// ws::Stream! { ws => +/// for await message in ws { +/// for i in 0..5u8 { +/// yield i.to_string().into(); +/// } +/// +/// yield message?; +/// } +/// } +/// } +/// ``` +/// +#[macro_export] +macro_rules! Stream { + () => ($crate::Stream!['static]); + ($l:lifetime) => ( + $crate::stream::MessageStream<$l, impl rocket::futures::Stream< + Item = $crate::result::Result<$crate::Message> + > + $l> + ); + ($channel:ident => $($token:tt)*) => ( + let ws: $crate::WebSocket = $channel; + ws.stream(move |$channel| rocket::async_stream::try_stream! { + $($token)* + }) + ); +} diff --git a/examples/upgrade/src/ws.rs b/contrib/ws/src/websocket.rs similarity index 60% rename from examples/upgrade/src/ws.rs rename to contrib/ws/src/websocket.rs index a2758977..ce03cae1 100644 --- a/examples/upgrade/src/ws.rs +++ b/contrib/ws/src/websocket.rs @@ -1,27 +1,70 @@ use std::io; -use rocket::futures::{StreamExt, SinkExt}; -use rocket::futures::stream::SplitStream; -use rocket::{Request, response}; use rocket::data::{IoHandler, IoStream}; +use rocket::futures::{self, StreamExt, SinkExt, future::BoxFuture, stream::SplitStream}; +use rocket::response::{self, Responder, Response}; use rocket::request::{FromRequest, Outcome}; -use rocket::response::{Responder, Response}; -use rocket::futures::{self, future::BoxFuture}; +use rocket::request::Request; -use tokio_tungstenite::WebSocketStream; -use tokio_tungstenite::tungstenite::handshake::derive_accept_key; -use tokio_tungstenite::tungstenite::protocol::Role; +use crate::{Config, Message}; +use crate::stream::DuplexStream; +use crate::result::{Result, Error}; -pub use tokio_tungstenite::tungstenite::error::{Result, Error}; -pub use tokio_tungstenite::tungstenite::Message; +/// A request guard that identifies WebSocket requests. Converts into a +/// [`Channel`] or [`MessageStream`]. +pub struct WebSocket { + config: Config, + key: String, +} -pub struct WebSocket(String); +impl WebSocket { + fn new(key: String) -> WebSocket { + WebSocket { config: Config::default(), key } + } + + pub fn config(mut self, config: Config) -> Self { + self.config = config; + self + } + + pub fn channel<'r, F: Send + 'r>(self, handler: F) -> Channel<'r> + where F: FnMut(DuplexStream) -> BoxFuture<'r, Result<()>> + 'r + { + Channel { ws: self, handler: Box::new(handler), } + } + + pub fn stream<'r, F, S>(self, stream: F) -> MessageStream<'r, S> + where F: FnMut(SplitStream) -> S + Send + 'r, + S: futures::Stream> + Send + 'r + { + MessageStream { ws: self, handler: Box::new(stream), } + } +} + +/// A streaming channel, returned by [`WebSocket::channel()`]. +pub struct Channel<'r> { + ws: WebSocket, + handler: Box BoxFuture<'r, Result<()>> + Send + 'r>, +} + +/// A [`Stream`](futures::Stream) of [`Message`]s, returned by +/// [`WebSocket::stream()`], used via [`Stream!`]. +/// +/// This type is not typically used directly. Instead, it is used via the +/// [`Stream!`] macro, which expands to both the type itself and an expression +/// which evaluates to this type. +// TODO: Get rid of this or `Channel` via a single `enum`. +pub struct MessageStream<'r, S> { + ws: WebSocket, + handler: Box) -> S + Send + 'r> +} #[rocket::async_trait] impl<'r> FromRequest<'r> for WebSocket { type Error = std::convert::Infallible; async fn from_request(req: &'r Request<'_>) -> Outcome { + use crate::tungstenite::handshake::derive_accept_key; use rocket::http::uncased::eq; let headers = req.headers(); @@ -31,45 +74,20 @@ impl<'r> FromRequest<'r> for WebSocket { let is_ws = headers.get("Upgrade") .any(|h| h.split(',').any(|v| eq(v.trim(), "websocket"))); - let is_ws_13 = headers.get_one("Sec-WebSocket-Version").map_or(false, |v| v == "13"); + let is_13 = headers.get_one("Sec-WebSocket-Version").map_or(false, |v| v == "13"); let key = headers.get_one("Sec-WebSocket-Key").map(|k| derive_accept_key(k.as_bytes())); match key { - Some(key) if is_upgrade && is_ws && is_ws_13 => Outcome::Success(WebSocket(key)), + Some(key) if is_upgrade && is_ws && is_13 => Outcome::Success(WebSocket::new(key)), Some(_) | None => Outcome::Forward(()) } } } -pub struct Channel<'r> { - ws: WebSocket, - handler: Box) -> BoxFuture<'r, Result<()>> + Send + 'r>, -} - -pub struct MessageStream<'r, S> { - ws: WebSocket, - handler: Box>) -> S + Send + 'r> -} - -impl WebSocket { - pub fn channel<'r, F: Send + 'r>(self, handler: F) -> Channel<'r> - where F: FnMut(WebSocketStream) -> BoxFuture<'r, Result<()>> + 'r - { - Channel { ws: self, handler: Box::new(handler), } - } - - pub fn stream<'r, F, S>(self, stream: F) -> MessageStream<'r, S> - where F: FnMut(SplitStream>) -> S + Send + 'r, - S: futures::Stream> + Send + 'r - { - MessageStream { ws: self, handler: Box::new(stream), } - } -} - impl<'r, 'o: 'r> Responder<'r, 'o> for Channel<'o> { fn respond_to(self, _: &'r Request<'_>) -> response::Result<'o> { Response::build() .raw_header("Sec-Websocket-Version", "13") - .raw_header("Sec-WebSocket-Accept", self.ws.0.clone()) + .raw_header("Sec-WebSocket-Accept", self.ws.key.clone()) .upgrade("websocket", self) .ok() } @@ -81,28 +99,16 @@ impl<'r, 'o: 'r, S> Responder<'r, 'o> for MessageStream<'o, S> fn respond_to(self, _: &'r Request<'_>) -> response::Result<'o> { Response::build() .raw_header("Sec-Websocket-Version", "13") - .raw_header("Sec-WebSocket-Accept", self.ws.0.clone()) + .raw_header("Sec-WebSocket-Accept", self.ws.key.clone()) .upgrade("websocket", self) .ok() } } -/// Returns `Ok(true)` if processing should continue, `Ok(false)` if processing -/// has terminated without error, and `Err(e)` if an error has occurred. -fn handle_result(result: Result<()>) -> io::Result { - match result { - Ok(_) => Ok(true), - Err(Error::ConnectionClosed) => Ok(false), - Err(Error::Io(e)) => Err(e), - Err(e) => Err(io::Error::new(io::ErrorKind::Other, e)) - } -} - #[rocket::async_trait] impl IoHandler for Channel<'_> { async fn io(&mut self, io: IoStream) -> io::Result<()> { - let stream = WebSocketStream::from_raw_socket(io, Role::Server, None).await; - let result = (self.handler)(stream).await; + let result = (self.handler)(DuplexStream::new(io, self.ws.config).await).await; handle_result(result).map(|_| ()) } } @@ -112,8 +118,7 @@ impl<'r, S> IoHandler for MessageStream<'r, S> where S: futures::Stream> + Send + 'r { async fn io(&mut self, io: IoStream) -> io::Result<()> { - let stream = WebSocketStream::from_raw_socket(io, Role::Server, None).await; - let (mut sink, stream) = stream.split(); + let (mut sink, stream) = DuplexStream::new(io, self.ws.config).await.split(); let mut stream = std::pin::pin!((self.handler)(stream)); while let Some(msg) = stream.next().await { let result = match msg { @@ -130,25 +135,13 @@ impl<'r, S> IoHandler for MessageStream<'r, S> } } -#[macro_export] -macro_rules! Stream { - () => (Stream!['static]); - ($l:lifetime) => ( - $crate::ws::MessageStream<$l, impl rocket::futures::Stream< - Item = $crate::ws::Result<$crate::ws::Message> - > + $l> - ); +/// Returns `Ok(true)` if processing should continue, `Ok(false)` if processing +/// has terminated without error, and `Err(e)` if an error has occurred. +fn handle_result(result: Result<()>) -> io::Result { + match result { + Ok(_) => Ok(true), + Err(Error::ConnectionClosed) => Ok(false), + Err(Error::Io(e)) => Err(e), + Err(e) => Err(io::Error::new(io::ErrorKind::Other, e)) + } } - -#[macro_export] -macro_rules! stream { - ($channel:ident => $($token:tt)*) => ( - let ws: $crate::ws::WebSocket = $channel; - ws.stream(move |$channel| rocket::async_stream::try_stream! { - $($token)* - }) - ) -} - -pub use Stream as Stream; -pub use stream as stream; diff --git a/examples/upgrade/Cargo.toml b/examples/upgrade/Cargo.toml index 0b70adf0..e34a504f 100644 --- a/examples/upgrade/Cargo.toml +++ b/examples/upgrade/Cargo.toml @@ -7,4 +7,4 @@ publish = false [dependencies] rocket = { path = "../../core/lib" } -tokio-tungstenite = "0.18" +ws = { package = "rocket_ws", path = "../../contrib/ws" } diff --git a/examples/upgrade/src/main.rs b/examples/upgrade/src/main.rs index bb6b6617..a572e022 100644 --- a/examples/upgrade/src/main.rs +++ b/examples/upgrade/src/main.rs @@ -3,8 +3,6 @@ use rocket::fs::{self, FileServer}; use rocket::futures::{SinkExt, StreamExt}; -mod ws; - #[get("/echo/manual")] fn echo_manual<'r>(ws: ws::WebSocket) -> ws::Channel<'r> { ws.channel(move |mut stream| Box::pin(async move { @@ -18,7 +16,8 @@ fn echo_manual<'r>(ws: ws::WebSocket) -> ws::Channel<'r> { #[get("/echo")] fn echo_stream(ws: ws::WebSocket) -> ws::Stream!['static] { - ws::stream! { ws => + let ws = ws.config(ws::Config { max_send_queue: Some(5), ..Default::default() }); + ws::Stream! { ws => for await message in ws { yield message?; } diff --git a/scripts/config.sh b/scripts/config.sh index 07c3ab65..bfae2413 100755 --- a/scripts/config.sh +++ b/scripts/config.sh @@ -98,6 +98,7 @@ ALL_CRATE_ROOTS=( "${CONTRIB_ROOT}/db_pools/codegen" "${CONTRIB_ROOT}/db_pools/lib" "${CONTRIB_ROOT}/dyn_templates" + "${CONTRIB_ROOT}/ws" ) function print_environment() { diff --git a/scripts/mk-docs.sh b/scripts/mk-docs.sh index bb0c7375..93659d09 100755 --- a/scripts/mk-docs.sh +++ b/scripts/mk-docs.sh @@ -22,7 +22,7 @@ pushd "${PROJECT_ROOT}" > /dev/null 2>&1 # Set the crate version and fill in missing doc URLs with docs.rs links. RUSTDOCFLAGS="-Zunstable-options --crate-version ${DOC_VERSION}" \ cargo doc -p rocket \ - -p rocket_sync_db_pools -p rocket_dyn_templates -p rocket_db_pools \ + -p rocket_sync_db_pools -p rocket_dyn_templates -p rocket_db_pools -p rocket_ws \ -Zrustdoc-map --no-deps --all-features popd > /dev/null 2>&1 diff --git a/scripts/test.sh b/scripts/test.sh index bb21d184..b691dab9 100755 --- a/scripts/test.sh +++ b/scripts/test.sh @@ -98,6 +98,10 @@ function test_contrib() { handlebars ) + WS_FEATURES=( + tungstenite + ) + for feature in "${DB_POOLS_FEATURES[@]}"; do echo ":: Building and testing db_pools [$feature]..." $CARGO test -p rocket_db_pools --no-default-features --features $feature $@ @@ -112,6 +116,11 @@ function test_contrib() { echo ":: Building and testing dyn_templates [$feature]..." $CARGO test -p rocket_dyn_templates --no-default-features --features $feature $@ done + + for feature in "${WS_FEATURES[@]}"; do + echo ":: Building and testing ws [$feature]..." + $CARGO test -p rocket_ws --no-default-features --features $feature $@ + done } function test_core() {