mirror of https://github.com/rwf2/Rocket.git
Fix header lookups for connection upgrades.
This commit is contained in:
parent
2abddd923e
commit
64a7bfb37c
|
@ -794,11 +794,29 @@ impl<'r> Response<'r> {
|
|||
&self.body
|
||||
}
|
||||
|
||||
/// Returns `Ok(Some(_))` if `self` contains a suitable handler for any of
|
||||
/// the comma-separated protocols any of the strings in `I`. Returns
|
||||
/// `Ok(None)` if `self` doesn't support any kind of upgrade. Returns
|
||||
/// `Err(_)` if `protocols` is non-empty but no match was found in `self`.
|
||||
pub(crate) fn take_upgrade<I: Iterator<Item = &'r str>>(
|
||||
&mut self,
|
||||
mut protocols: I
|
||||
) -> Option<(Uncased<'r>, Box<dyn IoHandler + 'r>)> {
|
||||
protocols.find_map(|p| self.upgrade.remove_entry(p.as_uncased()))
|
||||
protocols: I
|
||||
) -> Result<Option<(Uncased<'r>, Box<dyn IoHandler + 'r>)>, ()> {
|
||||
if self.upgrade.is_empty() {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
let mut protocols = protocols.peekable();
|
||||
let have_protocols = protocols.peek().is_some();
|
||||
let found = protocols
|
||||
.flat_map(|v| v.split(',').map(str::trim))
|
||||
.find_map(|p| self.upgrade.remove_entry(p.as_uncased()));
|
||||
|
||||
match found {
|
||||
Some(handler) => Ok(Some(handler)),
|
||||
None if have_protocols => Err(()),
|
||||
None => Ok(None)
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the [`IoHandler`] for the protocol `proto`.
|
||||
|
|
|
@ -86,9 +86,14 @@ async fn hyper_service_fn(
|
|||
let token = rocket.preprocess_request(&mut req, &mut data).await;
|
||||
let mut response = rocket.dispatch(token, &req, data).await;
|
||||
let upgrade = response.take_upgrade(req.headers().get("upgrade"));
|
||||
if let Some((proto, handler)) = upgrade {
|
||||
if let Ok(Some((proto, handler))) = upgrade {
|
||||
rocket.handle_upgrade(response, proto, handler, pending_upgrade, tx).await;
|
||||
} else {
|
||||
if upgrade.is_err() {
|
||||
warn_!("Request wants upgrade but no I/O handler matched.");
|
||||
info_!("Request is not being upgraded.");
|
||||
}
|
||||
|
||||
rocket.send_response(response, tx).await;
|
||||
}
|
||||
},
|
||||
|
|
|
@ -25,8 +25,12 @@ impl<'r> FromRequest<'r> for WebSocket {
|
|||
use rocket::http::uncased::eq;
|
||||
|
||||
let headers = req.headers();
|
||||
let is_upgrade = headers.get_one("Connection").map_or(false, |c| eq(c, "upgrade"));
|
||||
let is_ws = headers.get("Upgrade").any(|p| eq(p, "websocket"));
|
||||
let is_upgrade = headers.get("Connection")
|
||||
.any(|h| h.split(',').any(|v| eq(v.trim(), "upgrade")));
|
||||
|
||||
let is_ws = headers.get("Upgrade")
|
||||
.any(|h| h.split(',').any(|v| eq(v.trim(), "websocket")));
|
||||
|
||||
let is_ws_13 = headers.get_one("Sec-WebSocket-Version").map_or(false, |v| v == "13");
|
||||
let key = headers.get_one("Sec-WebSocket-Key").map(|k| derive_accept_key(k.as_bytes()));
|
||||
match key {
|
||||
|
|
Loading…
Reference in New Issue