Fix header lookups for connection upgrades.

This commit is contained in:
Sergio Benitez 2023-03-30 15:09:25 -07:00
parent 2abddd923e
commit 64a7bfb37c
3 changed files with 33 additions and 6 deletions

View File

@ -794,11 +794,29 @@ impl<'r> Response<'r> {
&self.body
}
/// Returns `Ok(Some(_))` if `self` contains a suitable handler for any of
/// the comma-separated protocols any of the strings in `I`. Returns
/// `Ok(None)` if `self` doesn't support any kind of upgrade. Returns
/// `Err(_)` if `protocols` is non-empty but no match was found in `self`.
pub(crate) fn take_upgrade<I: Iterator<Item = &'r str>>(
&mut self,
mut protocols: I
) -> Option<(Uncased<'r>, Box<dyn IoHandler + 'r>)> {
protocols.find_map(|p| self.upgrade.remove_entry(p.as_uncased()))
protocols: I
) -> Result<Option<(Uncased<'r>, Box<dyn IoHandler + 'r>)>, ()> {
if self.upgrade.is_empty() {
return Ok(None);
}
let mut protocols = protocols.peekable();
let have_protocols = protocols.peek().is_some();
let found = protocols
.flat_map(|v| v.split(',').map(str::trim))
.find_map(|p| self.upgrade.remove_entry(p.as_uncased()));
match found {
Some(handler) => Ok(Some(handler)),
None if have_protocols => Err(()),
None => Ok(None)
}
}
/// Returns the [`IoHandler`] for the protocol `proto`.

View File

@ -86,9 +86,14 @@ async fn hyper_service_fn(
let token = rocket.preprocess_request(&mut req, &mut data).await;
let mut response = rocket.dispatch(token, &req, data).await;
let upgrade = response.take_upgrade(req.headers().get("upgrade"));
if let Some((proto, handler)) = upgrade {
if let Ok(Some((proto, handler))) = upgrade {
rocket.handle_upgrade(response, proto, handler, pending_upgrade, tx).await;
} else {
if upgrade.is_err() {
warn_!("Request wants upgrade but no I/O handler matched.");
info_!("Request is not being upgraded.");
}
rocket.send_response(response, tx).await;
}
},

View File

@ -25,8 +25,12 @@ impl<'r> FromRequest<'r> for WebSocket {
use rocket::http::uncased::eq;
let headers = req.headers();
let is_upgrade = headers.get_one("Connection").map_or(false, |c| eq(c, "upgrade"));
let is_ws = headers.get("Upgrade").any(|p| eq(p, "websocket"));
let is_upgrade = headers.get("Connection")
.any(|h| h.split(',').any(|v| eq(v.trim(), "upgrade")));
let is_ws = headers.get("Upgrade")
.any(|h| h.split(',').any(|v| eq(v.trim(), "websocket")));
let is_ws_13 = headers.get_one("Sec-WebSocket-Version").map_or(false, |v| v == "13");
let key = headers.get_one("Sec-WebSocket-Key").map(|k| derive_accept_key(k.as_bytes()));
match key {