Implement stream websocket API in upgrade example.

This commit is contained in:
Sergio Benitez 2023-03-30 12:48:20 -07:00
parent 2a63b1a41f
commit 2abddd923e
4 changed files with 172 additions and 88 deletions

View File

@ -1,69 +0,0 @@
<!DOCTYPE html>
<head>
<title>WebSocket client test</title>
</head>
<body>
<h1>WebSocket Client Test</h1>
<div id="log"></div>
</body>
<script language="javascript" type="text/javascript">
var wsUri = "ws://127.0.0.1:8000/echo/";
var log;
function init()
{
log = document.getElementById("log");
testWebSocket();
}
function testWebSocket()
{
websocket = new WebSocket(wsUri);
websocket.onopen = function(evt) { onOpen(evt) };
websocket.onclose = function(evt) { onClose(evt) };
websocket.onmessage = function(evt) { onMessage(evt) };
websocket.onerror = function(evt) { onError(evt) };
}
function onOpen(evt)
{
writeLog("CONNECTED");
sendMessage("Hello world");
}
function onClose(evt)
{
writeLog("Websocket DISCONNECTED");
}
function onMessage(evt)
{
writeLog('<span style="color: blue;">RESPONSE: ' + evt.data+'</span>');
websocket.close();
}
function onError(evt)
{
writeLog('<span style="color: red;">ERROR:</span> ' + evt.data);
}
function sendMessage(message)
{
writeLog("SENT: " + message);
websocket.send(message);
}
function writeLog(message)
{
var pre = document.createElement("p");
pre.innerHTML = message;
log.appendChild(pre);
}
window.addEventListener("load", init, false);
</script>

View File

@ -1,18 +1,13 @@
#[macro_use] extern crate rocket; #[macro_use] extern crate rocket;
use rocket::fs::{self, FileServer};
use rocket::futures::{SinkExt, StreamExt}; use rocket::futures::{SinkExt, StreamExt};
use rocket::response::content::RawHtml;
mod ws; mod ws;
#[get("/")] #[get("/echo/manual")]
fn index() -> RawHtml<&'static str> { fn echo_manual<'r>(ws: ws::WebSocket) -> ws::Channel<'r> {
RawHtml(include_str!("../index.html")) ws.channel(move |mut stream| Box::pin(async move {
}
#[get("/echo")]
fn echo(ws: ws::WebSocket) -> ws::Channel {
ws.channel(|mut stream| Box::pin(async move {
while let Some(message) = stream.next().await { while let Some(message) = stream.next().await {
let _ = stream.send(message?).await; let _ = stream.send(message?).await;
} }
@ -21,8 +16,18 @@ fn echo(ws: ws::WebSocket) -> ws::Channel {
})) }))
} }
#[get("/echo")]
fn echo_stream<'r>(ws: ws::WebSocket) -> ws::Stream!['r] {
ws::stream! { ws =>
for await message in ws {
yield message?;
}
}
}
#[launch] #[launch]
fn rocket() -> _ { fn rocket() -> _ {
rocket::build() rocket::build()
.mount("/", routes![index, echo]) .mount("/", routes![echo_manual, echo_stream])
.mount("/", FileServer::from(fs::relative!("static")))
} }

View File

@ -1,15 +1,19 @@
use std::io; use std::io;
use rocket::futures::{StreamExt, SinkExt};
use rocket::futures::stream::SplitStream;
use rocket::{Request, response}; use rocket::{Request, response};
use rocket::data::{IoHandler, IoStream}; use rocket::data::{IoHandler, IoStream};
use rocket::request::{FromRequest, Outcome}; use rocket::request::{FromRequest, Outcome};
use rocket::response::{Responder, Response}; use rocket::response::{Responder, Response};
use rocket::futures::future::BoxFuture; use rocket::futures::{self, future::BoxFuture};
use tokio_tungstenite::WebSocketStream; use tokio_tungstenite::WebSocketStream;
use tokio_tungstenite::tungstenite::handshake::derive_accept_key; use tokio_tungstenite::tungstenite::handshake::derive_accept_key;
use tokio_tungstenite::tungstenite::protocol::Role; use tokio_tungstenite::tungstenite::protocol::Role;
use tokio_tungstenite::tungstenite::error::{Result, Error};
pub use tokio_tungstenite::tungstenite::error::{Result, Error};
pub use tokio_tungstenite::tungstenite::Message;
pub struct WebSocket(String); pub struct WebSocket(String);
@ -32,21 +36,45 @@ impl<'r> FromRequest<'r> for WebSocket {
} }
} }
pub struct Channel { pub struct Channel<'r> {
ws: WebSocket, ws: WebSocket,
handler: Box<dyn FnMut(WebSocketStream<IoStream>) -> BoxFuture<'static, Result<()>> + Send>, handler: Box<dyn FnMut(WebSocketStream<IoStream>) -> BoxFuture<'r, Result<()>> + Send + 'r>,
}
pub struct MessageStream<'r, S> {
ws: WebSocket,
handler: Box<dyn FnMut(SplitStream<WebSocketStream<IoStream>>) -> S + Send + 'r>
} }
impl WebSocket { impl WebSocket {
pub fn channel<F: Send + 'static>(self, handler: F) -> Channel pub fn channel<'r, F: Send + 'r>(self, handler: F) -> Channel<'r>
where F: FnMut(WebSocketStream<IoStream>) -> BoxFuture<'static, Result<()>> where F: FnMut(WebSocketStream<IoStream>) -> BoxFuture<'r, Result<()>> + 'r
{ {
Channel { ws: self, handler: Box::new(handler), } Channel { ws: self, handler: Box::new(handler), }
} }
pub fn stream<'r, F, S>(self, stream: F) -> MessageStream<'r, S>
where F: FnMut(SplitStream<WebSocketStream<IoStream>>) -> S + Send + 'r,
S: futures::Stream<Item = Result<Message>> + Send + 'r
{
MessageStream { ws: self, handler: Box::new(stream), }
}
} }
impl<'r> Responder<'r, 'static> for Channel { impl<'r, 'o: 'r> Responder<'r, 'o> for Channel<'o> {
fn respond_to(self, _: &'r Request<'_>) -> response::Result<'static> { fn respond_to(self, _: &'r Request<'_>) -> response::Result<'o> {
Response::build()
.raw_header("Sec-Websocket-Version", "13")
.raw_header("Sec-WebSocket-Accept", self.ws.0.clone())
.upgrade("websocket", self)
.ok()
}
}
impl<'r, 'o: 'r, S> Responder<'r, 'o> for MessageStream<'o, S>
where S: futures::Stream<Item = Result<Message>> + Send + 'o
{
fn respond_to(self, _: &'r Request<'_>) -> response::Result<'o> {
Response::build() Response::build()
.raw_header("Sec-Websocket-Version", "13") .raw_header("Sec-Websocket-Version", "13")
.raw_header("Sec-WebSocket-Accept", self.ws.0.clone()) .raw_header("Sec-WebSocket-Accept", self.ws.0.clone())
@ -56,7 +84,7 @@ impl<'r> Responder<'r, 'static> for Channel {
} }
#[rocket::async_trait] #[rocket::async_trait]
impl IoHandler for Channel { impl IoHandler for Channel<'_> {
async fn io(&mut self, io: IoStream) -> io::Result<()> { async fn io(&mut self, io: IoStream) -> io::Result<()> {
let stream = WebSocketStream::from_raw_socket(io, Role::Server, None).await; let stream = WebSocketStream::from_raw_socket(io, Role::Server, None).await;
(self.handler)(stream).await.map_err(|e| match e { (self.handler)(stream).await.map_err(|e| match e {
@ -65,3 +93,49 @@ impl IoHandler for Channel {
}) })
} }
} }
#[rocket::async_trait]
impl<'r, S> IoHandler for MessageStream<'r, S>
where S: futures::Stream<Item = Result<Message>> + Send + 'r
{
async fn io(&mut self, io: IoStream) -> io::Result<()> {
let stream = WebSocketStream::from_raw_socket(io, Role::Server, None).await;
let (mut sink, stream) = stream.split();
let mut stream = std::pin::pin!((self.handler)(stream));
while let Some(msg) = stream.next().await {
let result = match msg {
Ok(msg) => sink.send(msg).await,
Err(e) => Err(e)
};
result.map_err(|e| match e {
Error::Io(e) => e,
other => io::Error::new(io::ErrorKind::Other, other)
})?;
}
Ok(())
}
}
#[macro_export]
macro_rules! Stream {
($l:lifetime) => (
$crate::ws::MessageStream<$l, impl rocket::futures::Stream<
Item = $crate::ws::Result<$crate::ws::Message>
> + $l>
)
}
#[macro_export]
macro_rules! stream {
($channel:ident => $($token:tt)*) => (
let ws: $crate::ws::WebSocket = $channel;
ws.stream(move |$channel| rocket::async_stream::try_stream! {
$($token)*
})
)
}
pub use Stream as Stream;
pub use stream as stream;

View File

@ -0,0 +1,74 @@
<!DOCTYPE html>
<html lang="en">
<head>
<title>WebSocket Client Test</title>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1">
</head>
<body>
<h1>WebSocket Client Test</h1>
<form action="#">
<input type="text" id="message" name="message" value="" placeholder="Send a message...">
<input type="submit" value="Submit">
</form>
<div id="log"></div>
</body>
<script language="javascript" type="text/javascript">
var wsUri = "ws://127.0.0.1:8000/echo";
var log;
function init() {
log = document.getElementById("log");
form = document.getElementsByTagName("form")[0];
message = document.getElementById("message");
testWebSocket();
form.addEventListener("submit", (e) => {
e.preventDefault();
if (message.value !== "") {
sendMessage(message.value);
message.value = "";
}
});
}
function testWebSocket() {
websocket = new WebSocket(wsUri);
websocket.onopen = onOpen;
websocket.onclose = onClose;
websocket.onmessage = onMessage;
websocket.onerror = onError;
}
function onOpen(evt) {
writeLog("CONNECTED");
sendMessage("Hello, Rocket!");
}
function onClose(evt) {
writeLog("Websocket DISCONNECTED");
}
function onMessage(evt) {
writeLog('<span style="color: blue;">RESPONSE: ' + evt.data+'</span>');
}
function onError(evt) {
writeLog('<span style="color: red;">ERROR:</span> ' + evt.data);
}
function sendMessage(message) {
writeLog("SENT: " + message);
websocket.send(message);
}
function writeLog(message) {
var pre = document.createElement("p");
pre.innerHTML = message;
log.prepend(pre);
}
window.addEventListener("load", init, false);
</script>
</html>