diff --git a/examples/upgrade/index.html b/examples/upgrade/index.html
deleted file mode 100644
index 0293461c..00000000
--- a/examples/upgrade/index.html
+++ /dev/null
@@ -1,69 +0,0 @@
-
-
-
- WebSocket client test
-
-
-
- WebSocket Client Test
-
-
-
-
diff --git a/examples/upgrade/src/main.rs b/examples/upgrade/src/main.rs
index fd358945..b23c2ca8 100644
--- a/examples/upgrade/src/main.rs
+++ b/examples/upgrade/src/main.rs
@@ -1,18 +1,13 @@
#[macro_use] extern crate rocket;
+use rocket::fs::{self, FileServer};
use rocket::futures::{SinkExt, StreamExt};
-use rocket::response::content::RawHtml;
mod ws;
-#[get("/")]
-fn index() -> RawHtml<&'static str> {
- RawHtml(include_str!("../index.html"))
-}
-
-#[get("/echo")]
-fn echo(ws: ws::WebSocket) -> ws::Channel {
- ws.channel(|mut stream| Box::pin(async move {
+#[get("/echo/manual")]
+fn echo_manual<'r>(ws: ws::WebSocket) -> ws::Channel<'r> {
+ ws.channel(move |mut stream| Box::pin(async move {
while let Some(message) = stream.next().await {
let _ = stream.send(message?).await;
}
@@ -21,8 +16,18 @@ fn echo(ws: ws::WebSocket) -> ws::Channel {
}))
}
+#[get("/echo")]
+fn echo_stream<'r>(ws: ws::WebSocket) -> ws::Stream!['r] {
+ ws::stream! { ws =>
+ for await message in ws {
+ yield message?;
+ }
+ }
+}
+
#[launch]
fn rocket() -> _ {
rocket::build()
- .mount("/", routes![index, echo])
+ .mount("/", routes![echo_manual, echo_stream])
+ .mount("/", FileServer::from(fs::relative!("static")))
}
diff --git a/examples/upgrade/src/ws.rs b/examples/upgrade/src/ws.rs
index c5342089..7122a801 100644
--- a/examples/upgrade/src/ws.rs
+++ b/examples/upgrade/src/ws.rs
@@ -1,15 +1,19 @@
use std::io;
+use rocket::futures::{StreamExt, SinkExt};
+use rocket::futures::stream::SplitStream;
use rocket::{Request, response};
use rocket::data::{IoHandler, IoStream};
use rocket::request::{FromRequest, Outcome};
use rocket::response::{Responder, Response};
-use rocket::futures::future::BoxFuture;
+use rocket::futures::{self, future::BoxFuture};
use tokio_tungstenite::WebSocketStream;
use tokio_tungstenite::tungstenite::handshake::derive_accept_key;
use tokio_tungstenite::tungstenite::protocol::Role;
-use tokio_tungstenite::tungstenite::error::{Result, Error};
+
+pub use tokio_tungstenite::tungstenite::error::{Result, Error};
+pub use tokio_tungstenite::tungstenite::Message;
pub struct WebSocket(String);
@@ -32,21 +36,45 @@ impl<'r> FromRequest<'r> for WebSocket {
}
}
-pub struct Channel {
+pub struct Channel<'r> {
ws: WebSocket,
- handler: Box) -> BoxFuture<'static, Result<()>> + Send>,
+ handler: Box) -> BoxFuture<'r, Result<()>> + Send + 'r>,
+}
+
+pub struct MessageStream<'r, S> {
+ ws: WebSocket,
+ handler: Box>) -> S + Send + 'r>
}
impl WebSocket {
- pub fn channel(self, handler: F) -> Channel
- where F: FnMut(WebSocketStream) -> BoxFuture<'static, Result<()>>
+ pub fn channel<'r, F: Send + 'r>(self, handler: F) -> Channel<'r>
+ where F: FnMut(WebSocketStream) -> BoxFuture<'r, Result<()>> + 'r
{
Channel { ws: self, handler: Box::new(handler), }
}
+
+ pub fn stream<'r, F, S>(self, stream: F) -> MessageStream<'r, S>
+ where F: FnMut(SplitStream>) -> S + Send + 'r,
+ S: futures::Stream- > + Send + 'r
+ {
+ MessageStream { ws: self, handler: Box::new(stream), }
+ }
}
-impl<'r> Responder<'r, 'static> for Channel {
- fn respond_to(self, _: &'r Request<'_>) -> response::Result<'static> {
+impl<'r, 'o: 'r> Responder<'r, 'o> for Channel<'o> {
+ fn respond_to(self, _: &'r Request<'_>) -> response::Result<'o> {
+ Response::build()
+ .raw_header("Sec-Websocket-Version", "13")
+ .raw_header("Sec-WebSocket-Accept", self.ws.0.clone())
+ .upgrade("websocket", self)
+ .ok()
+ }
+}
+
+impl<'r, 'o: 'r, S> Responder<'r, 'o> for MessageStream<'o, S>
+ where S: futures::Stream
- > + Send + 'o
+{
+ fn respond_to(self, _: &'r Request<'_>) -> response::Result<'o> {
Response::build()
.raw_header("Sec-Websocket-Version", "13")
.raw_header("Sec-WebSocket-Accept", self.ws.0.clone())
@@ -56,7 +84,7 @@ impl<'r> Responder<'r, 'static> for Channel {
}
#[rocket::async_trait]
-impl IoHandler for Channel {
+impl IoHandler for Channel<'_> {
async fn io(&mut self, io: IoStream) -> io::Result<()> {
let stream = WebSocketStream::from_raw_socket(io, Role::Server, None).await;
(self.handler)(stream).await.map_err(|e| match e {
@@ -65,3 +93,49 @@ impl IoHandler for Channel {
})
}
}
+
+#[rocket::async_trait]
+impl<'r, S> IoHandler for MessageStream<'r, S>
+ where S: futures::Stream
- > + Send + 'r
+{
+ async fn io(&mut self, io: IoStream) -> io::Result<()> {
+ let stream = WebSocketStream::from_raw_socket(io, Role::Server, None).await;
+ let (mut sink, stream) = stream.split();
+ let mut stream = std::pin::pin!((self.handler)(stream));
+ while let Some(msg) = stream.next().await {
+ let result = match msg {
+ Ok(msg) => sink.send(msg).await,
+ Err(e) => Err(e)
+ };
+
+ result.map_err(|e| match e {
+ Error::Io(e) => e,
+ other => io::Error::new(io::ErrorKind::Other, other)
+ })?;
+ }
+
+ Ok(())
+ }
+}
+
+#[macro_export]
+macro_rules! Stream {
+ ($l:lifetime) => (
+ $crate::ws::MessageStream<$l, impl rocket::futures::Stream<
+ Item = $crate::ws::Result<$crate::ws::Message>
+ > + $l>
+ )
+}
+
+#[macro_export]
+macro_rules! stream {
+ ($channel:ident => $($token:tt)*) => (
+ let ws: $crate::ws::WebSocket = $channel;
+ ws.stream(move |$channel| rocket::async_stream::try_stream! {
+ $($token)*
+ })
+ )
+}
+
+pub use Stream as Stream;
+pub use stream as stream;
diff --git a/examples/upgrade/static/index.html b/examples/upgrade/static/index.html
new file mode 100644
index 00000000..f4bf4694
--- /dev/null
+++ b/examples/upgrade/static/index.html
@@ -0,0 +1,74 @@
+
+
+
+ WebSocket Client Test
+
+
+
+
+
WebSocket Client Test
+
+
+
+
+