diff --git a/examples/upgrade/src/ws.rs b/examples/upgrade/src/ws.rs index ace8c5fa..48353186 100644 --- a/examples/upgrade/src/ws.rs +++ b/examples/upgrade/src/ws.rs @@ -87,14 +87,23 @@ impl<'r, 'o: 'r, S> Responder<'r, 'o> for MessageStream<'o, S> } } +/// Returns `Ok(true)` if processing should continue, `Ok(false)` if processing +/// has terminated without error, and `Err(e)` if an error has occurred. +fn handle_result(result: Result<()>) -> io::Result { + match result { + Ok(_) => Ok(true), + Err(Error::ConnectionClosed) => Ok(false), + Err(Error::Io(e)) => Err(e), + Err(e) => Err(io::Error::new(io::ErrorKind::Other, e)) + } +} + #[rocket::async_trait] impl IoHandler for Channel<'_> { async fn io(&mut self, io: IoStream) -> io::Result<()> { let stream = WebSocketStream::from_raw_socket(io, Role::Server, None).await; - (self.handler)(stream).await.map_err(|e| match e { - Error::Io(e) => e, - other => io::Error::new(io::ErrorKind::Other, other) - }) + let result = (self.handler)(stream).await; + handle_result(result).map(|_| ()) } } @@ -112,10 +121,9 @@ impl<'r, S> IoHandler for MessageStream<'r, S> Err(e) => Err(e) }; - result.map_err(|e| match e { - Error::Io(e) => e, - other => io::Error::new(io::ErrorKind::Other, other) - })?; + if !handle_result(result)? { + return Ok(()); + } } Ok(())