From 2abddd923e36b49b01821f09f4dc82ee19d69dcc Mon Sep 17 00:00:00 2001 From: Sergio Benitez Date: Thu, 30 Mar 2023 12:48:20 -0700 Subject: [PATCH] Implement stream websocket API in upgrade example. --- examples/upgrade/index.html | 69 ---------------------- examples/upgrade/src/main.rs | 25 ++++---- examples/upgrade/src/ws.rs | 92 +++++++++++++++++++++++++++--- examples/upgrade/static/index.html | 74 ++++++++++++++++++++++++ 4 files changed, 172 insertions(+), 88 deletions(-) delete mode 100644 examples/upgrade/index.html create mode 100644 examples/upgrade/static/index.html diff --git a/examples/upgrade/index.html b/examples/upgrade/index.html deleted file mode 100644 index 0293461c..00000000 --- a/examples/upgrade/index.html +++ /dev/null @@ -1,69 +0,0 @@ - - - - WebSocket client test - - - -

WebSocket Client Test

-
- - - diff --git a/examples/upgrade/src/main.rs b/examples/upgrade/src/main.rs index fd358945..b23c2ca8 100644 --- a/examples/upgrade/src/main.rs +++ b/examples/upgrade/src/main.rs @@ -1,18 +1,13 @@ #[macro_use] extern crate rocket; +use rocket::fs::{self, FileServer}; use rocket::futures::{SinkExt, StreamExt}; -use rocket::response::content::RawHtml; mod ws; -#[get("/")] -fn index() -> RawHtml<&'static str> { - RawHtml(include_str!("../index.html")) -} - -#[get("/echo")] -fn echo(ws: ws::WebSocket) -> ws::Channel { - ws.channel(|mut stream| Box::pin(async move { +#[get("/echo/manual")] +fn echo_manual<'r>(ws: ws::WebSocket) -> ws::Channel<'r> { + ws.channel(move |mut stream| Box::pin(async move { while let Some(message) = stream.next().await { let _ = stream.send(message?).await; } @@ -21,8 +16,18 @@ fn echo(ws: ws::WebSocket) -> ws::Channel { })) } +#[get("/echo")] +fn echo_stream<'r>(ws: ws::WebSocket) -> ws::Stream!['r] { + ws::stream! { ws => + for await message in ws { + yield message?; + } + } +} + #[launch] fn rocket() -> _ { rocket::build() - .mount("/", routes![index, echo]) + .mount("/", routes![echo_manual, echo_stream]) + .mount("/", FileServer::from(fs::relative!("static"))) } diff --git a/examples/upgrade/src/ws.rs b/examples/upgrade/src/ws.rs index c5342089..7122a801 100644 --- a/examples/upgrade/src/ws.rs +++ b/examples/upgrade/src/ws.rs @@ -1,15 +1,19 @@ use std::io; +use rocket::futures::{StreamExt, SinkExt}; +use rocket::futures::stream::SplitStream; use rocket::{Request, response}; use rocket::data::{IoHandler, IoStream}; use rocket::request::{FromRequest, Outcome}; use rocket::response::{Responder, Response}; -use rocket::futures::future::BoxFuture; +use rocket::futures::{self, future::BoxFuture}; use tokio_tungstenite::WebSocketStream; use tokio_tungstenite::tungstenite::handshake::derive_accept_key; use tokio_tungstenite::tungstenite::protocol::Role; -use tokio_tungstenite::tungstenite::error::{Result, Error}; + +pub use tokio_tungstenite::tungstenite::error::{Result, Error}; +pub use tokio_tungstenite::tungstenite::Message; pub struct WebSocket(String); @@ -32,21 +36,45 @@ impl<'r> FromRequest<'r> for WebSocket { } } -pub struct Channel { +pub struct Channel<'r> { ws: WebSocket, - handler: Box) -> BoxFuture<'static, Result<()>> + Send>, + handler: Box) -> BoxFuture<'r, Result<()>> + Send + 'r>, +} + +pub struct MessageStream<'r, S> { + ws: WebSocket, + handler: Box>) -> S + Send + 'r> } impl WebSocket { - pub fn channel(self, handler: F) -> Channel - where F: FnMut(WebSocketStream) -> BoxFuture<'static, Result<()>> + pub fn channel<'r, F: Send + 'r>(self, handler: F) -> Channel<'r> + where F: FnMut(WebSocketStream) -> BoxFuture<'r, Result<()>> + 'r { Channel { ws: self, handler: Box::new(handler), } } + + pub fn stream<'r, F, S>(self, stream: F) -> MessageStream<'r, S> + where F: FnMut(SplitStream>) -> S + Send + 'r, + S: futures::Stream> + Send + 'r + { + MessageStream { ws: self, handler: Box::new(stream), } + } } -impl<'r> Responder<'r, 'static> for Channel { - fn respond_to(self, _: &'r Request<'_>) -> response::Result<'static> { +impl<'r, 'o: 'r> Responder<'r, 'o> for Channel<'o> { + fn respond_to(self, _: &'r Request<'_>) -> response::Result<'o> { + Response::build() + .raw_header("Sec-Websocket-Version", "13") + .raw_header("Sec-WebSocket-Accept", self.ws.0.clone()) + .upgrade("websocket", self) + .ok() + } +} + +impl<'r, 'o: 'r, S> Responder<'r, 'o> for MessageStream<'o, S> + where S: futures::Stream> + Send + 'o +{ + fn respond_to(self, _: &'r Request<'_>) -> response::Result<'o> { Response::build() .raw_header("Sec-Websocket-Version", "13") .raw_header("Sec-WebSocket-Accept", self.ws.0.clone()) @@ -56,7 +84,7 @@ impl<'r> Responder<'r, 'static> for Channel { } #[rocket::async_trait] -impl IoHandler for Channel { +impl IoHandler for Channel<'_> { async fn io(&mut self, io: IoStream) -> io::Result<()> { let stream = WebSocketStream::from_raw_socket(io, Role::Server, None).await; (self.handler)(stream).await.map_err(|e| match e { @@ -65,3 +93,49 @@ impl IoHandler for Channel { }) } } + +#[rocket::async_trait] +impl<'r, S> IoHandler for MessageStream<'r, S> + where S: futures::Stream> + Send + 'r +{ + async fn io(&mut self, io: IoStream) -> io::Result<()> { + let stream = WebSocketStream::from_raw_socket(io, Role::Server, None).await; + let (mut sink, stream) = stream.split(); + let mut stream = std::pin::pin!((self.handler)(stream)); + while let Some(msg) = stream.next().await { + let result = match msg { + Ok(msg) => sink.send(msg).await, + Err(e) => Err(e) + }; + + result.map_err(|e| match e { + Error::Io(e) => e, + other => io::Error::new(io::ErrorKind::Other, other) + })?; + } + + Ok(()) + } +} + +#[macro_export] +macro_rules! Stream { + ($l:lifetime) => ( + $crate::ws::MessageStream<$l, impl rocket::futures::Stream< + Item = $crate::ws::Result<$crate::ws::Message> + > + $l> + ) +} + +#[macro_export] +macro_rules! stream { + ($channel:ident => $($token:tt)*) => ( + let ws: $crate::ws::WebSocket = $channel; + ws.stream(move |$channel| rocket::async_stream::try_stream! { + $($token)* + }) + ) +} + +pub use Stream as Stream; +pub use stream as stream; diff --git a/examples/upgrade/static/index.html b/examples/upgrade/static/index.html new file mode 100644 index 00000000..f4bf4694 --- /dev/null +++ b/examples/upgrade/static/index.html @@ -0,0 +1,74 @@ + + + + WebSocket Client Test + + + + +

WebSocket Client Test

+
+ + +
+
+ + +