mirror of https://github.com/rwf2/Rocket.git
Update to hyper 1. Enable custom + unix listeners.
This commit completely rewrites Rocket's HTTP serving. In addition to significant internal cleanup, this commit introduces the following major features: * Support for custom, external listeners in the `listener` module. The new `listener` module contains new `Bindable`, `Listener`, and `Connection` traits which enable composable, external implementations of connection listeners. Rocket can launch on any `Listener`, or anything that can be used to create a listener (`Bindable`), via a new `launch_on()` method. * Support for Unix domain socket listeners out of the box. The default listener backwards compatibly supports listening on Unix domain sockets. To do so, configure an `address` of `unix:path/to/socket` and optional set `reuse` to `true` (the default) or `false` which controls whether Rocket will handle creating and deleting the unix domain socket. In addition to these new features, this commit makes the following major improvements: * Rocket now depends on hyper 1. * Rocket no longer depends on hyper to handle connections. This allows us to handle more connection failure conditions which results in an overall more robust server with fewer dependencies. * Logic to work around hyper's inability to reference incoming request data in the response results in a 15% performance improvement. * `Client`s can be marked secure with `Client::{un}tracked_secure()`, allowing Rocket to treat local connections as running under TLS. * The `macros` feature of `tokio` is no longer used by Rocket itself. Dependencies can take advantage of this reduction in compile-time cost by disabling the new default feature `tokio-macros`. * A new `TlsConfig::validate()` method allows checking a TLS config. * New `TlsConfig::{certs,key}_reader()`, `MtlsConfig::ca_certs_reader()` methods return `BufReader`s, which allow reading the configured certs and key directly. * A new `NamedFile::open_with()` constructor allows specifying `OpenOptions`. These improvements resulted in the following breaking changes: * The MSRV is now 1.74. * `hyper` is no longer exported from `rocket::http`. * `IoHandler::io` takes `Box<Self>` instead of `Pin<Box<Self>>`. - Use `Box::into_pin(self)` to recover the previous type. * `Response::upgrade()` now returns an `&mut dyn IoHandler`, not `Pin<& mut _>`. * `Config::{address,port,tls,mtls}` methods have been removed. - Use methods on `Rocket::endpoint()` instead. * `TlsConfig` was moved to `tls::TlsConfig`. * `MutualTls` was renamed and moved to `mtls::MtlsConfig`. * `ErrorKind::TlsBind` was removed. * The second field of `ErrorKind::Shutdown` was removed. * `{Local}Request::{set_}remote()` methods take/return an `Endpoint`. * `Client::new()` was removed; it was previously deprecated. Internally, the following major changes were made: * A new `async_bound` attribute macro was introduced to allow setting bounds on futures returned by `async fn`s in traits while maintaining good docs. * All utility functionality was moved to a new `util` module. Resolves #2671. Resolves #1070.
This commit is contained in:
parent
e9b568d9b2
commit
fd294049c7
|
@ -33,7 +33,6 @@ use crate::result::{Result, Error};
|
||||||
///
|
///
|
||||||
/// [`StreamExt`]: rocket::futures::StreamExt
|
/// [`StreamExt`]: rocket::futures::StreamExt
|
||||||
/// [`SinkExt`]: rocket::futures::SinkExt
|
/// [`SinkExt`]: rocket::futures::SinkExt
|
||||||
|
|
||||||
pub struct DuplexStream(tokio_tungstenite::WebSocketStream<IoStream>);
|
pub struct DuplexStream(tokio_tungstenite::WebSocketStream<IoStream>);
|
||||||
|
|
||||||
impl DuplexStream {
|
impl DuplexStream {
|
||||||
|
|
|
@ -1,5 +1,4 @@
|
||||||
use std::io;
|
use std::io;
|
||||||
use std::pin::Pin;
|
|
||||||
|
|
||||||
use rocket::data::{IoHandler, IoStream};
|
use rocket::data::{IoHandler, IoStream};
|
||||||
use rocket::futures::{self, StreamExt, SinkExt, future::BoxFuture, stream::SplitStream};
|
use rocket::futures::{self, StreamExt, SinkExt, future::BoxFuture, stream::SplitStream};
|
||||||
|
@ -37,10 +36,6 @@ pub struct WebSocket {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl WebSocket {
|
impl WebSocket {
|
||||||
fn new(key: String) -> WebSocket {
|
|
||||||
WebSocket { config: Config::default(), key }
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Change the default connection configuration to `config`.
|
/// Change the default connection configuration to `config`.
|
||||||
///
|
///
|
||||||
/// # Example
|
/// # Example
|
||||||
|
@ -202,7 +197,9 @@ impl<'r> FromRequest<'r> for WebSocket {
|
||||||
let is_13 = headers.get_one("Sec-WebSocket-Version").map_or(false, |v| v == "13");
|
let is_13 = headers.get_one("Sec-WebSocket-Version").map_or(false, |v| v == "13");
|
||||||
let key = headers.get_one("Sec-WebSocket-Key").map(|k| derive_accept_key(k.as_bytes()));
|
let key = headers.get_one("Sec-WebSocket-Key").map(|k| derive_accept_key(k.as_bytes()));
|
||||||
match key {
|
match key {
|
||||||
Some(key) if is_upgrade && is_ws && is_13 => Outcome::Success(WebSocket::new(key)),
|
Some(key) if is_upgrade && is_ws && is_13 => {
|
||||||
|
Outcome::Success(WebSocket { key, config: Config::default() })
|
||||||
|
},
|
||||||
Some(_) | None => Outcome::Forward(Status::BadRequest)
|
Some(_) | None => Outcome::Forward(Status::BadRequest)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -232,9 +229,9 @@ impl<'r, 'o: 'r, S> Responder<'r, 'o> for MessageStream<'o, S>
|
||||||
|
|
||||||
#[rocket::async_trait]
|
#[rocket::async_trait]
|
||||||
impl IoHandler for Channel<'_> {
|
impl IoHandler for Channel<'_> {
|
||||||
async fn io(self: Pin<Box<Self>>, io: IoStream) -> io::Result<()> {
|
async fn io(self: Box<Self>, io: IoStream) -> io::Result<()> {
|
||||||
let channel = Pin::into_inner(self);
|
let stream = DuplexStream::new(io, self.ws.config).await;
|
||||||
let result = (channel.handler)(DuplexStream::new(io, channel.ws.config).await).await;
|
let result = (self.handler)(stream).await;
|
||||||
handle_result(result).map(|_| ())
|
handle_result(result).map(|_| ())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -243,9 +240,9 @@ impl IoHandler for Channel<'_> {
|
||||||
impl<'r, S> IoHandler for MessageStream<'r, S>
|
impl<'r, S> IoHandler for MessageStream<'r, S>
|
||||||
where S: futures::Stream<Item = Result<Message>> + Send + 'r
|
where S: futures::Stream<Item = Result<Message>> + Send + 'r
|
||||||
{
|
{
|
||||||
async fn io(self: Pin<Box<Self>>, io: IoStream) -> io::Result<()> {
|
async fn io(self: Box<Self>, io: IoStream) -> io::Result<()> {
|
||||||
let (mut sink, source) = DuplexStream::new(io, self.ws.config).await.split();
|
let (mut sink, source) = DuplexStream::new(io, self.ws.config).await.split();
|
||||||
let stream = (Pin::into_inner(self).handler)(source);
|
let stream = (self.handler)(source);
|
||||||
rocket::tokio::pin!(stream);
|
rocket::tokio::pin!(stream);
|
||||||
while let Some(msg) = stream.next().await {
|
while let Some(msg) = stream.next().await {
|
||||||
let result = match msg {
|
let result = match msg {
|
||||||
|
|
|
@ -0,0 +1,61 @@
|
||||||
|
use proc_macro2::{TokenStream, Span};
|
||||||
|
use devise::{Spanned, Result, ext::SpanDiagnosticExt};
|
||||||
|
use syn::{Token, parse_quote, parse_quote_spanned};
|
||||||
|
use syn::{TraitItemFn, TypeParamBound, ReturnType, Attribute};
|
||||||
|
use syn::punctuated::Punctuated;
|
||||||
|
use syn::parse::Parser;
|
||||||
|
|
||||||
|
fn _async_bound(
|
||||||
|
args: proc_macro::TokenStream,
|
||||||
|
input: proc_macro::TokenStream
|
||||||
|
) -> Result<TokenStream> {
|
||||||
|
let bounds = <Punctuated<TypeParamBound, Token![+]>>::parse_terminated.parse(args)?;
|
||||||
|
if bounds.is_empty() {
|
||||||
|
return Ok(input.into());
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut func: TraitItemFn = syn::parse(input)?;
|
||||||
|
let original: TraitItemFn = func.clone();
|
||||||
|
if !func.sig.asyncness.is_some() {
|
||||||
|
let diag = Span::call_site()
|
||||||
|
.error("attribute can only be applied to async fns")
|
||||||
|
.span_help(func.sig.span(), "this fn declaration must be `async`");
|
||||||
|
|
||||||
|
return Err(diag);
|
||||||
|
}
|
||||||
|
|
||||||
|
let doc: Attribute = parse_quote! {
|
||||||
|
#[doc = concat!(
|
||||||
|
"# Future Bounds",
|
||||||
|
"\n",
|
||||||
|
"**The `Future` generated by this `async fn` must be `", stringify!(#bounds), "`**."
|
||||||
|
)]
|
||||||
|
};
|
||||||
|
|
||||||
|
func.sig.asyncness = None;
|
||||||
|
func.sig.output = match func.sig.output {
|
||||||
|
ReturnType::Type(arrow, ty) => parse_quote_spanned!(ty.span() =>
|
||||||
|
#arrow impl ::core::future::Future<Output = #ty> + #bounds
|
||||||
|
),
|
||||||
|
default@ReturnType::Default => default
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(quote! {
|
||||||
|
#[cfg(all(not(doc), rust_analyzer))]
|
||||||
|
#original
|
||||||
|
|
||||||
|
#[cfg(all(doc, not(rust_analyzer)))]
|
||||||
|
#doc
|
||||||
|
#original
|
||||||
|
|
||||||
|
#[cfg(not(any(doc, rust_analyzer)))]
|
||||||
|
#func
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn async_bound(
|
||||||
|
args: proc_macro::TokenStream,
|
||||||
|
input: proc_macro::TokenStream
|
||||||
|
) -> TokenStream {
|
||||||
|
_async_bound(args, input).unwrap_or_else(|d| d.emit_as_item_tokens())
|
||||||
|
}
|
|
@ -2,3 +2,4 @@ pub mod entry;
|
||||||
pub mod catch;
|
pub mod catch;
|
||||||
pub mod route;
|
pub mod route;
|
||||||
pub mod param;
|
pub mod param;
|
||||||
|
pub mod async_bound;
|
||||||
|
|
|
@ -331,7 +331,7 @@ fn codegen_route(route: Route) -> Result<TokenStream> {
|
||||||
let internal_uri_macro = internal_uri_macro_decl(&route);
|
let internal_uri_macro = internal_uri_macro_decl(&route);
|
||||||
let responder_outcome = responder_outcome_expr(&route);
|
let responder_outcome = responder_outcome_expr(&route);
|
||||||
|
|
||||||
let method = route.attr.method;
|
let method = &route.attr.method;
|
||||||
let uri = route.attr.uri.to_string();
|
let uri = route.attr.uri.to_string();
|
||||||
let rank = Optional(route.attr.rank);
|
let rank = Optional(route.attr.rank);
|
||||||
let format = Optional(route.attr.format.as_ref());
|
let format = Optional(route.attr.format.as_ref());
|
||||||
|
|
|
@ -13,7 +13,7 @@ pub struct Status(pub http::Status);
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub struct MediaType(pub http::MediaType);
|
pub struct MediaType(pub http::MediaType);
|
||||||
|
|
||||||
#[derive(Debug, Copy, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub struct Method(pub http::Method);
|
pub struct Method(pub http::Method);
|
||||||
|
|
||||||
#[derive(Clone, Debug)]
|
#[derive(Clone, Debug)]
|
||||||
|
@ -108,7 +108,7 @@ const VALID_METHODS: &[http::Method] = &[
|
||||||
impl FromMeta for Method {
|
impl FromMeta for Method {
|
||||||
fn from_meta(meta: &MetaItem) -> Result<Self> {
|
fn from_meta(meta: &MetaItem) -> Result<Self> {
|
||||||
let span = meta.value_span();
|
let span = meta.value_span();
|
||||||
let help_text = format!("method must be one of: {}", VALID_METHODS_STR);
|
let help_text = format!("method must be one of: {VALID_METHODS_STR}");
|
||||||
|
|
||||||
if let MetaItem::Path(path) = meta {
|
if let MetaItem::Path(path) = meta {
|
||||||
if let Some(ident) = path.last_ident() {
|
if let Some(ident) = path.last_ident() {
|
||||||
|
@ -131,19 +131,13 @@ impl FromMeta for Method {
|
||||||
|
|
||||||
impl ToTokens for Method {
|
impl ToTokens for Method {
|
||||||
fn to_tokens(&self, tokens: &mut TokenStream) {
|
fn to_tokens(&self, tokens: &mut TokenStream) {
|
||||||
let method_tokens = match self.0 {
|
let mut chars = self.0.as_str().chars();
|
||||||
http::Method::Get => quote!(::rocket::http::Method::Get),
|
let variant_str = chars.next()
|
||||||
http::Method::Put => quote!(::rocket::http::Method::Put),
|
.map(|c| c.to_ascii_uppercase().to_string() + &chars.as_str().to_lowercase())
|
||||||
http::Method::Post => quote!(::rocket::http::Method::Post),
|
.unwrap_or_default();
|
||||||
http::Method::Delete => quote!(::rocket::http::Method::Delete),
|
|
||||||
http::Method::Options => quote!(::rocket::http::Method::Options),
|
|
||||||
http::Method::Head => quote!(::rocket::http::Method::Head),
|
|
||||||
http::Method::Trace => quote!(::rocket::http::Method::Trace),
|
|
||||||
http::Method::Connect => quote!(::rocket::http::Method::Connect),
|
|
||||||
http::Method::Patch => quote!(::rocket::http::Method::Patch),
|
|
||||||
};
|
|
||||||
|
|
||||||
tokens.extend(method_tokens);
|
let variant = syn::Ident::new(&variant_str, Span::call_site());
|
||||||
|
tokens.extend(quote!(::rocket::http::Method::#variant));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1497,3 +1497,10 @@ pub fn internal_guide_tests(input: TokenStream) -> TokenStream {
|
||||||
pub fn export(input: TokenStream) -> TokenStream {
|
pub fn export(input: TokenStream) -> TokenStream {
|
||||||
emit!(bang::export_internal(input))
|
emit!(bang::export_internal(input))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Private Rocket attribute: `async_bound(Bounds + On + Returned + Future)`.
|
||||||
|
#[doc(hidden)]
|
||||||
|
#[proc_macro_attribute]
|
||||||
|
pub fn async_bound(args: TokenStream, input: TokenStream) -> TokenStream {
|
||||||
|
emit!(attribute::async_bound::async_bound(args, input))
|
||||||
|
}
|
||||||
|
|
|
@ -17,43 +17,22 @@ rust-version = "1.64"
|
||||||
|
|
||||||
[features]
|
[features]
|
||||||
default = []
|
default = []
|
||||||
tls = ["rustls", "tokio-rustls", "rustls-pemfile"]
|
|
||||||
mtls = ["tls", "x509-parser"]
|
|
||||||
http2 = ["hyper/http2"]
|
|
||||||
private-cookies = ["cookie/private", "cookie/key-expansion"]
|
|
||||||
serde = ["uncased/with-serde-alloc", "serde_"]
|
serde = ["uncased/with-serde-alloc", "serde_"]
|
||||||
uuid = ["uuid_"]
|
uuid = ["uuid_"]
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
smallvec = { version = "1.11", features = ["const_generics", "const_new"] }
|
smallvec = { version = "1.11", features = ["const_generics", "const_new"] }
|
||||||
percent-encoding = "2"
|
percent-encoding = "2"
|
||||||
http = "0.2"
|
|
||||||
time = { version = "0.3", features = ["formatting", "macros"] }
|
time = { version = "0.3", features = ["formatting", "macros"] }
|
||||||
indexmap = "2"
|
indexmap = "2"
|
||||||
rustls = { version = "0.22", optional = true }
|
|
||||||
tokio-rustls = { version = "0.25", optional = true }
|
|
||||||
rustls-pemfile = { version = "2.0.0", optional = true }
|
|
||||||
tokio = { version = "1.6.1", features = ["net", "sync", "time"] }
|
|
||||||
log = "0.4"
|
|
||||||
ref-cast = "1.0"
|
ref-cast = "1.0"
|
||||||
uncased = "0.9.6"
|
uncased = "0.9.10"
|
||||||
either = "1"
|
either = "1"
|
||||||
pear = "0.2.8"
|
pear = "0.2.8"
|
||||||
pin-project-lite = "0.2"
|
|
||||||
memchr = "2"
|
memchr = "2"
|
||||||
stable-pattern = "0.1"
|
stable-pattern = "0.1"
|
||||||
cookie = { version = "0.18", features = ["percent-encode"] }
|
cookie = { version = "0.18", features = ["percent-encode"] }
|
||||||
state = "0.6"
|
state = "0.6"
|
||||||
futures = { version = "0.3", default-features = false }
|
|
||||||
|
|
||||||
[dependencies.x509-parser]
|
|
||||||
version = "0.13"
|
|
||||||
optional = true
|
|
||||||
|
|
||||||
[dependencies.hyper]
|
|
||||||
version = "0.14.9"
|
|
||||||
default-features = false
|
|
||||||
features = ["http1", "runtime", "server", "stream"]
|
|
||||||
|
|
||||||
[dependencies.serde_]
|
[dependencies.serde_]
|
||||||
package = "serde"
|
package = "serde"
|
||||||
|
|
|
@ -745,8 +745,7 @@ impl<'h> HeaderMap<'h> {
|
||||||
/// WARNING: This is unstable! Do not use this method outside of Rocket!
|
/// WARNING: This is unstable! Do not use this method outside of Rocket!
|
||||||
#[doc(hidden)]
|
#[doc(hidden)]
|
||||||
#[inline]
|
#[inline]
|
||||||
pub fn into_iter_raw(self)
|
pub fn into_iter_raw(self) -> impl Iterator<Item=(Uncased<'h>, Vec<Cow<'h, str>>)> {
|
||||||
-> impl Iterator<Item=(Uncased<'h>, Vec<Cow<'h, str>>)> {
|
|
||||||
self.headers.into_iter()
|
self.headers.into_iter()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,35 +0,0 @@
|
||||||
//! Re-exported hyper HTTP library types.
|
|
||||||
//!
|
|
||||||
//! All types that are re-exported from Hyper reside inside of this module.
|
|
||||||
//! These types will, with certainty, be removed with time, but they reside here
|
|
||||||
//! while necessary.
|
|
||||||
|
|
||||||
pub use hyper::{Method, Error, Body, Uri, Version, Request, Response};
|
|
||||||
pub use hyper::{body, server, service, upgrade};
|
|
||||||
pub use http::{HeaderValue, request, uri};
|
|
||||||
|
|
||||||
/// Reexported Hyper HTTP header types.
|
|
||||||
pub mod header {
|
|
||||||
macro_rules! import_http_headers {
|
|
||||||
($($name:ident),*) => ($(
|
|
||||||
pub use hyper::header::$name as $name;
|
|
||||||
)*)
|
|
||||||
}
|
|
||||||
|
|
||||||
import_http_headers! {
|
|
||||||
ACCEPT, ACCEPT_CHARSET, ACCEPT_ENCODING, ACCEPT_LANGUAGE, ACCEPT_RANGES,
|
|
||||||
ACCESS_CONTROL_ALLOW_CREDENTIALS, ACCESS_CONTROL_ALLOW_HEADERS,
|
|
||||||
ACCESS_CONTROL_ALLOW_METHODS, ACCESS_CONTROL_ALLOW_ORIGIN,
|
|
||||||
ACCESS_CONTROL_EXPOSE_HEADERS, ACCESS_CONTROL_MAX_AGE,
|
|
||||||
ACCESS_CONTROL_REQUEST_HEADERS, ACCESS_CONTROL_REQUEST_METHOD, ALLOW,
|
|
||||||
AUTHORIZATION, CACHE_CONTROL, CONNECTION, CONTENT_DISPOSITION,
|
|
||||||
CONTENT_ENCODING, CONTENT_LANGUAGE, CONTENT_LENGTH, CONTENT_LOCATION,
|
|
||||||
CONTENT_RANGE, CONTENT_SECURITY_POLICY,
|
|
||||||
CONTENT_SECURITY_POLICY_REPORT_ONLY, CONTENT_TYPE, DATE, ETAG, EXPECT,
|
|
||||||
EXPIRES, FORWARDED, FROM, HOST, IF_MATCH, IF_MODIFIED_SINCE,
|
|
||||||
IF_NONE_MATCH, IF_RANGE, IF_UNMODIFIED_SINCE, LAST_MODIFIED, LINK,
|
|
||||||
LOCATION, ORIGIN, PRAGMA, RANGE, REFERER, REFERRER_POLICY, REFRESH,
|
|
||||||
STRICT_TRANSPORT_SECURITY, TE, TRANSFER_ENCODING, UPGRADE, USER_AGENT,
|
|
||||||
VARY
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -4,15 +4,11 @@
|
||||||
//! Types that map to concepts in HTTP.
|
//! Types that map to concepts in HTTP.
|
||||||
//!
|
//!
|
||||||
//! This module exports types that map to HTTP concepts or to the underlying
|
//! This module exports types that map to HTTP concepts or to the underlying
|
||||||
//! HTTP library when needed. Because the underlying HTTP library is likely to
|
//! HTTP library when needed.
|
||||||
//! change (see [#17]), types in [`hyper`] should be considered unstable.
|
|
||||||
//!
|
|
||||||
//! [#17]: https://github.com/rwf2/Rocket/issues/17
|
|
||||||
|
|
||||||
#[macro_use]
|
#[macro_use]
|
||||||
extern crate pear;
|
extern crate pear;
|
||||||
|
|
||||||
pub mod hyper;
|
|
||||||
pub mod uri;
|
pub mod uri;
|
||||||
pub mod ext;
|
pub mod ext;
|
||||||
|
|
||||||
|
@ -22,7 +18,6 @@ mod method;
|
||||||
mod status;
|
mod status;
|
||||||
mod raw_str;
|
mod raw_str;
|
||||||
mod parse;
|
mod parse;
|
||||||
mod listener;
|
|
||||||
|
|
||||||
/// Case-preserving, ASCII case-insensitive string types.
|
/// Case-preserving, ASCII case-insensitive string types.
|
||||||
///
|
///
|
||||||
|
@ -39,14 +34,8 @@ pub mod uncased {
|
||||||
pub mod private {
|
pub mod private {
|
||||||
pub use crate::parse::Indexed;
|
pub use crate::parse::Indexed;
|
||||||
pub use smallvec::{SmallVec, Array};
|
pub use smallvec::{SmallVec, Array};
|
||||||
pub use crate::listener::{TcpListener, Incoming, Listener, Connection, Certificates};
|
|
||||||
pub use cookie;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[doc(hidden)]
|
|
||||||
#[cfg(feature = "tls")]
|
|
||||||
pub mod tls;
|
|
||||||
|
|
||||||
pub use crate::method::Method;
|
pub use crate::method::Method;
|
||||||
pub use crate::status::{Status, StatusClass};
|
pub use crate::status::{Status, StatusClass};
|
||||||
pub use crate::raw_str::{RawStr, RawStrBuf};
|
pub use crate::raw_str::{RawStr, RawStrBuf};
|
||||||
|
|
|
@ -1,257 +0,0 @@
|
||||||
use std::fmt;
|
|
||||||
use std::future::Future;
|
|
||||||
use std::io;
|
|
||||||
use std::net::SocketAddr;
|
|
||||||
use std::pin::Pin;
|
|
||||||
use std::task::{Context, Poll};
|
|
||||||
use std::time::Duration;
|
|
||||||
use std::sync::Arc;
|
|
||||||
|
|
||||||
use log::warn;
|
|
||||||
use tokio::time::Sleep;
|
|
||||||
use tokio::io::{AsyncRead, AsyncWrite};
|
|
||||||
use tokio::net::TcpStream;
|
|
||||||
use hyper::server::accept::Accept;
|
|
||||||
use state::InitCell;
|
|
||||||
|
|
||||||
pub use tokio::net::TcpListener;
|
|
||||||
|
|
||||||
/// A thin wrapper over raw, DER-encoded X.509 client certificate data.
|
|
||||||
#[cfg(not(feature = "tls"))]
|
|
||||||
#[derive(Debug, Clone, Eq, PartialEq)]
|
|
||||||
pub struct CertificateDer(pub(crate) Vec<u8>);
|
|
||||||
|
|
||||||
/// A thin wrapper over raw, DER-encoded X.509 client certificate data.
|
|
||||||
#[cfg(feature = "tls")]
|
|
||||||
#[derive(Debug, Clone, Eq, PartialEq)]
|
|
||||||
#[repr(transparent)]
|
|
||||||
pub struct CertificateDer(pub(crate) rustls::pki_types::CertificateDer<'static>);
|
|
||||||
|
|
||||||
/// A collection of raw certificate data.
|
|
||||||
#[derive(Clone, Default)]
|
|
||||||
pub struct Certificates(Arc<InitCell<Vec<CertificateDer>>>);
|
|
||||||
|
|
||||||
impl From<Vec<CertificateDer>> for Certificates {
|
|
||||||
fn from(value: Vec<CertificateDer>) -> Self {
|
|
||||||
Certificates(Arc::new(value.into()))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(feature = "tls")]
|
|
||||||
impl From<Vec<rustls::pki_types::CertificateDer<'static>>> for Certificates {
|
|
||||||
fn from(value: Vec<rustls::pki_types::CertificateDer<'static>>) -> Self {
|
|
||||||
let value: Vec<_> = value.into_iter().map(CertificateDer).collect();
|
|
||||||
Certificates(Arc::new(value.into()))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[doc(hidden)]
|
|
||||||
impl Certificates {
|
|
||||||
/// Set the the raw certificate chain data. Only the first call actually
|
|
||||||
/// sets the data; the remaining do nothing.
|
|
||||||
#[cfg(feature = "tls")]
|
|
||||||
pub(crate) fn set(&self, data: Vec<CertificateDer>) {
|
|
||||||
self.0.set(data);
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Returns the raw certificate chain data, if any is available.
|
|
||||||
pub fn chain_data(&self) -> Option<&[CertificateDer]> {
|
|
||||||
self.0.try_get().map(|v| v.as_slice())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO.async: 'Listener' and 'Connection' provide common enough functionality
|
|
||||||
// that they could be introduced in upstream libraries.
|
|
||||||
/// A 'Listener' yields incoming connections
|
|
||||||
pub trait Listener {
|
|
||||||
/// The connection type returned by this listener.
|
|
||||||
type Connection: Connection;
|
|
||||||
|
|
||||||
/// Return the actual address this listener bound to.
|
|
||||||
fn local_addr(&self) -> Option<SocketAddr>;
|
|
||||||
|
|
||||||
/// Try to accept an incoming Connection if ready. This should only return
|
|
||||||
/// an `Err` when a fatal problem occurs as Hyper kills the server on `Err`.
|
|
||||||
fn poll_accept(
|
|
||||||
self: Pin<&mut Self>,
|
|
||||||
cx: &mut Context<'_>
|
|
||||||
) -> Poll<io::Result<Self::Connection>>;
|
|
||||||
}
|
|
||||||
|
|
||||||
/// A 'Connection' represents an open connection to a client
|
|
||||||
pub trait Connection: AsyncRead + AsyncWrite {
|
|
||||||
/// The remote address, i.e. the client's socket address, if it is known.
|
|
||||||
fn peer_address(&self) -> Option<SocketAddr>;
|
|
||||||
|
|
||||||
/// Requests that the connection not delay reading or writing data as much
|
|
||||||
/// as possible. For connections backed by TCP, this corresponds to setting
|
|
||||||
/// `TCP_NODELAY`.
|
|
||||||
fn enable_nodelay(&self) -> io::Result<()>;
|
|
||||||
|
|
||||||
/// DER-encoded X.509 certificate chain presented by the client, if any.
|
|
||||||
///
|
|
||||||
/// The certificate order must be as it appears in the TLS protocol: the
|
|
||||||
/// first certificate relates to the peer, the second certifies the first,
|
|
||||||
/// the third certifies the second, and so on.
|
|
||||||
///
|
|
||||||
/// Defaults to an empty vector to indicate that no certificates were
|
|
||||||
/// presented.
|
|
||||||
fn peer_certificates(&self) -> Option<Certificates> { None }
|
|
||||||
}
|
|
||||||
|
|
||||||
pin_project_lite::pin_project! {
|
|
||||||
/// This is a generic version of hyper's AddrIncoming that is intended to be
|
|
||||||
/// usable with listeners other than a plain TCP stream, e.g. TLS and/or Unix
|
|
||||||
/// sockets. It does so by bridging the `Listener` trait to what hyper wants (an
|
|
||||||
/// Accept). This type is internal to Rocket.
|
|
||||||
#[must_use = "streams do nothing unless polled"]
|
|
||||||
pub struct Incoming<L> {
|
|
||||||
sleep_on_errors: Option<Duration>,
|
|
||||||
nodelay: bool,
|
|
||||||
#[pin]
|
|
||||||
pending_error_delay: Option<Sleep>,
|
|
||||||
#[pin]
|
|
||||||
listener: L,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<L: Listener> Incoming<L> {
|
|
||||||
/// Construct an `Incoming` from an existing `Listener`.
|
|
||||||
pub fn new(listener: L) -> Self {
|
|
||||||
Self {
|
|
||||||
listener,
|
|
||||||
sleep_on_errors: Some(Duration::from_millis(250)),
|
|
||||||
pending_error_delay: None,
|
|
||||||
nodelay: false,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Set whether and how long to sleep on accept errors.
|
|
||||||
///
|
|
||||||
/// A possible scenario is that the process has hit the max open files
|
|
||||||
/// allowed, and so trying to accept a new connection will fail with
|
|
||||||
/// `EMFILE`. In some cases, it's preferable to just wait for some time, if
|
|
||||||
/// the application will likely close some files (or connections), and try
|
|
||||||
/// to accept the connection again. If this option is `true`, the error
|
|
||||||
/// will be logged at the `error` level, since it is still a big deal,
|
|
||||||
/// and then the listener will sleep for 1 second.
|
|
||||||
///
|
|
||||||
/// In other cases, hitting the max open files should be treat similarly
|
|
||||||
/// to being out-of-memory, and simply error (and shutdown). Setting
|
|
||||||
/// this option to `None` will allow that.
|
|
||||||
///
|
|
||||||
/// Default is 1 second.
|
|
||||||
pub fn sleep_on_errors(mut self, val: Option<Duration>) -> Self {
|
|
||||||
self.sleep_on_errors = val;
|
|
||||||
self
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Set whether to request no delay on all incoming connections. The default
|
|
||||||
/// is `false`. See [`Connection::enable_nodelay()`] for details.
|
|
||||||
pub fn nodelay(mut self, nodelay: bool) -> Self {
|
|
||||||
self.nodelay = nodelay;
|
|
||||||
self
|
|
||||||
}
|
|
||||||
|
|
||||||
fn poll_accept_next(
|
|
||||||
self: Pin<&mut Self>,
|
|
||||||
cx: &mut Context<'_>
|
|
||||||
) -> Poll<io::Result<L::Connection>> {
|
|
||||||
/// This function defines per-connection errors: errors that affect only
|
|
||||||
/// a single connection's accept() and don't imply anything about the
|
|
||||||
/// success probability of the next accept(). Thus, we can attempt to
|
|
||||||
/// `accept()` another connection immediately. All other errors will
|
|
||||||
/// incur a delay before the next `accept()` is performed. The delay is
|
|
||||||
/// useful to handle resource exhaustion errors like ENFILE and EMFILE.
|
|
||||||
/// Otherwise, could enter into tight loop.
|
|
||||||
fn is_connection_error(e: &io::Error) -> bool {
|
|
||||||
matches!(e.kind(),
|
|
||||||
| io::ErrorKind::ConnectionRefused
|
|
||||||
| io::ErrorKind::ConnectionAborted
|
|
||||||
| io::ErrorKind::ConnectionReset)
|
|
||||||
}
|
|
||||||
|
|
||||||
let mut this = self.project();
|
|
||||||
loop {
|
|
||||||
// Check if a previous sleep timer is active, set on I/O errors.
|
|
||||||
if let Some(delay) = this.pending_error_delay.as_mut().as_pin_mut() {
|
|
||||||
futures::ready!(delay.poll(cx));
|
|
||||||
}
|
|
||||||
|
|
||||||
this.pending_error_delay.set(None);
|
|
||||||
|
|
||||||
match futures::ready!(this.listener.as_mut().poll_accept(cx)) {
|
|
||||||
Ok(stream) => {
|
|
||||||
if *this.nodelay {
|
|
||||||
if let Err(e) = stream.enable_nodelay() {
|
|
||||||
warn!("failed to enable NODELAY: {}", e);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return Poll::Ready(Ok(stream));
|
|
||||||
},
|
|
||||||
Err(e) => {
|
|
||||||
if is_connection_error(&e) {
|
|
||||||
warn!("single connection accept error {}; accepting next now", e);
|
|
||||||
} else if let Some(duration) = this.sleep_on_errors {
|
|
||||||
// We might be able to recover. Try again in a bit.
|
|
||||||
warn!("accept error {}; recovery attempt in {}ms", e, duration.as_millis());
|
|
||||||
this.pending_error_delay.set(Some(tokio::time::sleep(*duration)));
|
|
||||||
} else {
|
|
||||||
return Poll::Ready(Err(e));
|
|
||||||
}
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<L: Listener> Accept for Incoming<L> {
|
|
||||||
type Conn = L::Connection;
|
|
||||||
type Error = io::Error;
|
|
||||||
|
|
||||||
#[inline]
|
|
||||||
fn poll_accept(
|
|
||||||
self: Pin<&mut Self>,
|
|
||||||
cx: &mut Context<'_>
|
|
||||||
) -> Poll<Option<io::Result<Self::Conn>>> {
|
|
||||||
self.poll_accept_next(cx).map(Some)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<L: fmt::Debug> fmt::Debug for Incoming<L> {
|
|
||||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
|
||||||
f.debug_struct("Incoming")
|
|
||||||
.field("listener", &self.listener)
|
|
||||||
.finish()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Listener for TcpListener {
|
|
||||||
type Connection = TcpStream;
|
|
||||||
|
|
||||||
#[inline]
|
|
||||||
fn local_addr(&self) -> Option<SocketAddr> {
|
|
||||||
self.local_addr().ok()
|
|
||||||
}
|
|
||||||
|
|
||||||
#[inline]
|
|
||||||
fn poll_accept(
|
|
||||||
self: Pin<&mut Self>,
|
|
||||||
cx: &mut Context<'_>
|
|
||||||
) -> Poll<io::Result<Self::Connection>> {
|
|
||||||
(*self).poll_accept(cx).map_ok(|(stream, _addr)| stream)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Connection for TcpStream {
|
|
||||||
#[inline]
|
|
||||||
fn peer_address(&self) -> Option<SocketAddr> {
|
|
||||||
self.peer_addr().ok()
|
|
||||||
}
|
|
||||||
|
|
||||||
#[inline]
|
|
||||||
fn enable_nodelay(&self) -> io::Result<()> {
|
|
||||||
self.set_nodelay(true)
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -3,8 +3,6 @@ use std::str::FromStr;
|
||||||
|
|
||||||
use self::Method::*;
|
use self::Method::*;
|
||||||
|
|
||||||
use crate::hyper;
|
|
||||||
|
|
||||||
// TODO: Support non-standard methods, here and in codegen?
|
// TODO: Support non-standard methods, here and in codegen?
|
||||||
|
|
||||||
/// Representation of HTTP methods.
|
/// Representation of HTTP methods.
|
||||||
|
@ -29,6 +27,7 @@ use crate::hyper;
|
||||||
/// }
|
/// }
|
||||||
/// # }
|
/// # }
|
||||||
/// ```
|
/// ```
|
||||||
|
#[repr(u8)]
|
||||||
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
|
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
|
||||||
pub enum Method {
|
pub enum Method {
|
||||||
/// The `GET` variant.
|
/// The `GET` variant.
|
||||||
|
@ -52,23 +51,6 @@ pub enum Method {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Method {
|
impl Method {
|
||||||
/// WARNING: This is unstable! Do not use this method outside of Rocket!
|
|
||||||
#[doc(hidden)]
|
|
||||||
pub fn from_hyp(method: &hyper::Method) -> Option<Method> {
|
|
||||||
match *method {
|
|
||||||
hyper::Method::GET => Some(Get),
|
|
||||||
hyper::Method::PUT => Some(Put),
|
|
||||||
hyper::Method::POST => Some(Post),
|
|
||||||
hyper::Method::DELETE => Some(Delete),
|
|
||||||
hyper::Method::OPTIONS => Some(Options),
|
|
||||||
hyper::Method::HEAD => Some(Head),
|
|
||||||
hyper::Method::TRACE => Some(Trace),
|
|
||||||
hyper::Method::CONNECT => Some(Connect),
|
|
||||||
hyper::Method::PATCH => Some(Patch),
|
|
||||||
_ => None,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Returns `true` if an HTTP request with the method represented by `self`
|
/// Returns `true` if an HTTP request with the method represented by `self`
|
||||||
/// always supports a payload.
|
/// always supports a payload.
|
||||||
///
|
///
|
||||||
|
|
|
@ -1,235 +0,0 @@
|
||||||
use std::io;
|
|
||||||
use std::pin::Pin;
|
|
||||||
use std::sync::Arc;
|
|
||||||
use std::task::{Context, Poll};
|
|
||||||
use std::future::Future;
|
|
||||||
use std::net::SocketAddr;
|
|
||||||
|
|
||||||
use tokio::net::{TcpListener, TcpStream};
|
|
||||||
use tokio::io::{AsyncRead, AsyncWrite};
|
|
||||||
use tokio_rustls::{Accept, TlsAcceptor, server::TlsStream as BareTlsStream};
|
|
||||||
use rustls::server::{ServerSessionMemoryCache, ServerConfig, WebPkiClientVerifier};
|
|
||||||
|
|
||||||
use crate::tls::util::{load_cert_chain, load_key, load_ca_certs};
|
|
||||||
use crate::listener::{Connection, Listener, Certificates, CertificateDer};
|
|
||||||
|
|
||||||
/// A TLS listener over TCP.
|
|
||||||
pub struct TlsListener {
|
|
||||||
listener: TcpListener,
|
|
||||||
acceptor: TlsAcceptor,
|
|
||||||
}
|
|
||||||
|
|
||||||
/// This implementation exists so that ROCKET_WORKERS=1 can make progress while
|
|
||||||
/// a TLS handshake is being completed. It does this by returning `Ready` from
|
|
||||||
/// `poll_accept()` as soon as we have a TCP connection and performing the
|
|
||||||
/// handshake in the `AsyncRead` and `AsyncWrite` implementations.
|
|
||||||
///
|
|
||||||
/// A straight-forward implementation of this strategy results in none of the
|
|
||||||
/// TLS information being available at the time the connection is "established",
|
|
||||||
/// that is, when `poll_accept()` returns, since the handshake has yet to occur.
|
|
||||||
/// Importantly, certificate information isn't available at the time that we
|
|
||||||
/// request it.
|
|
||||||
///
|
|
||||||
/// The underlying problem is hyper's "Accept" trait. Were we to manage
|
|
||||||
/// connections ourselves, we'd likely want to:
|
|
||||||
///
|
|
||||||
/// 1. Stop blocking the worker as soon as we have a TCP connection.
|
|
||||||
/// 2. Perform the handshake in the background.
|
|
||||||
/// 3. Give the connection to Rocket when/if the handshake is done.
|
|
||||||
///
|
|
||||||
/// See hyperium/hyper/issues/2321 for more details.
|
|
||||||
///
|
|
||||||
/// To work around this, we "lie" when `peer_certificates()` are requested and
|
|
||||||
/// always return `Some(Certificates)`. Internally, `Certificates` is an
|
|
||||||
/// `Arc<InitCell<Vec<CertificateDer>>>`, effectively a shared, thread-safe,
|
|
||||||
/// `OnceCell`. The cell is initially empty and is filled as soon as the
|
|
||||||
/// handshake is complete. If the certificate data were to be requested prior to
|
|
||||||
/// this point, it would be empty. However, in Rocket, we only request
|
|
||||||
/// certificate data when we have a `Request` object, which implies we're
|
|
||||||
/// receiving payload data, which implies the TLS handshake has finished, so the
|
|
||||||
/// certificate data as seen by a Rocket application will always be "fresh".
|
|
||||||
pub struct TlsStream {
|
|
||||||
remote: SocketAddr,
|
|
||||||
state: TlsState,
|
|
||||||
certs: Certificates,
|
|
||||||
}
|
|
||||||
|
|
||||||
/// State of `TlsStream`.
|
|
||||||
pub enum TlsState {
|
|
||||||
/// The TLS handshake is taking place. We don't have a full connection yet.
|
|
||||||
Handshaking(Accept<TcpStream>),
|
|
||||||
/// TLS handshake completed successfully; we're getting payload data.
|
|
||||||
Streaming(BareTlsStream<TcpStream>),
|
|
||||||
}
|
|
||||||
|
|
||||||
/// TLS as ~configured by `TlsConfig` in `rocket` core.
|
|
||||||
pub struct Config<R> {
|
|
||||||
pub cert_chain: R,
|
|
||||||
pub private_key: R,
|
|
||||||
pub ciphersuites: Vec<rustls::SupportedCipherSuite>,
|
|
||||||
pub prefer_server_order: bool,
|
|
||||||
pub ca_certs: Option<R>,
|
|
||||||
pub mandatory_mtls: bool,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl TlsListener {
|
|
||||||
pub async fn bind<R>(addr: SocketAddr, mut c: Config<R>) -> crate::tls::Result<TlsListener>
|
|
||||||
where R: io::BufRead
|
|
||||||
{
|
|
||||||
let provider = rustls::crypto::CryptoProvider {
|
|
||||||
cipher_suites: c.ciphersuites,
|
|
||||||
..rustls::crypto::ring::default_provider()
|
|
||||||
};
|
|
||||||
|
|
||||||
let verifier = match c.ca_certs {
|
|
||||||
Some(ref mut ca_certs) => {
|
|
||||||
let ca_roots = Arc::new(load_ca_certs(ca_certs)?);
|
|
||||||
let verifier = WebPkiClientVerifier::builder(ca_roots);
|
|
||||||
match c.mandatory_mtls {
|
|
||||||
true => verifier.build()?,
|
|
||||||
false => verifier.allow_unauthenticated().build()?,
|
|
||||||
}
|
|
||||||
},
|
|
||||||
None => WebPkiClientVerifier::no_client_auth(),
|
|
||||||
};
|
|
||||||
|
|
||||||
let key = load_key(&mut c.private_key)?;
|
|
||||||
let cert_chain = load_cert_chain(&mut c.cert_chain)?;
|
|
||||||
let mut config = ServerConfig::builder_with_provider(Arc::new(provider))
|
|
||||||
.with_safe_default_protocol_versions()?
|
|
||||||
.with_client_cert_verifier(verifier)
|
|
||||||
.with_single_cert(cert_chain, key)?;
|
|
||||||
|
|
||||||
config.ignore_client_order = c.prefer_server_order;
|
|
||||||
config.session_storage = ServerSessionMemoryCache::new(1024);
|
|
||||||
config.ticketer = rustls::crypto::ring::Ticketer::new()?;
|
|
||||||
config.alpn_protocols = vec![b"http/1.1".to_vec()];
|
|
||||||
if cfg!(feature = "http2") {
|
|
||||||
config.alpn_protocols.insert(0, b"h2".to_vec());
|
|
||||||
}
|
|
||||||
|
|
||||||
let listener = TcpListener::bind(addr).await?;
|
|
||||||
let acceptor = TlsAcceptor::from(Arc::new(config));
|
|
||||||
Ok(TlsListener { listener, acceptor })
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Listener for TlsListener {
|
|
||||||
type Connection = TlsStream;
|
|
||||||
|
|
||||||
fn local_addr(&self) -> Option<SocketAddr> {
|
|
||||||
self.listener.local_addr().ok()
|
|
||||||
}
|
|
||||||
|
|
||||||
fn poll_accept(
|
|
||||||
self: Pin<&mut Self>,
|
|
||||||
cx: &mut Context<'_>
|
|
||||||
) -> Poll<io::Result<Self::Connection>> {
|
|
||||||
match futures::ready!(self.listener.poll_accept(cx)) {
|
|
||||||
Ok((io, addr)) => Poll::Ready(Ok(TlsStream {
|
|
||||||
remote: addr,
|
|
||||||
state: TlsState::Handshaking(self.acceptor.accept(io)),
|
|
||||||
// These are empty and filled in after handshake is complete.
|
|
||||||
certs: Certificates::default(),
|
|
||||||
})),
|
|
||||||
Err(e) => Poll::Ready(Err(e)),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Connection for TlsStream {
|
|
||||||
fn peer_address(&self) -> Option<SocketAddr> {
|
|
||||||
Some(self.remote)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn enable_nodelay(&self) -> io::Result<()> {
|
|
||||||
// If `Handshaking` is `None`, it either failed, so we returned an `Err`
|
|
||||||
// from `poll_accept()` and there's no connection to enable `NODELAY`
|
|
||||||
// on, or it succeeded, so we're in the `Streaming` stage and we have
|
|
||||||
// infallible access to the connection.
|
|
||||||
match &self.state {
|
|
||||||
TlsState::Handshaking(accept) => match accept.get_ref() {
|
|
||||||
None => Ok(()),
|
|
||||||
Some(s) => s.enable_nodelay(),
|
|
||||||
},
|
|
||||||
TlsState::Streaming(stream) => stream.get_ref().0.enable_nodelay()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn peer_certificates(&self) -> Option<Certificates> {
|
|
||||||
Some(self.certs.clone())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl TlsStream {
|
|
||||||
fn poll_accept_then<F, T>(
|
|
||||||
mut self: Pin<&mut Self>,
|
|
||||||
cx: &mut Context<'_>,
|
|
||||||
mut f: F
|
|
||||||
) -> Poll<io::Result<T>>
|
|
||||||
where F: FnMut(&mut BareTlsStream<TcpStream>, &mut Context<'_>) -> Poll<io::Result<T>>
|
|
||||||
{
|
|
||||||
loop {
|
|
||||||
match self.state {
|
|
||||||
TlsState::Handshaking(ref mut accept) => {
|
|
||||||
match futures::ready!(Pin::new(accept).poll(cx)) {
|
|
||||||
Ok(stream) => {
|
|
||||||
if let Some(peer_certs) = stream.get_ref().1.peer_certificates() {
|
|
||||||
self.certs.set(peer_certs.into_iter()
|
|
||||||
.map(|v| CertificateDer(v.clone().into_owned()))
|
|
||||||
.collect());
|
|
||||||
}
|
|
||||||
|
|
||||||
self.state = TlsState::Streaming(stream);
|
|
||||||
}
|
|
||||||
Err(e) => {
|
|
||||||
log::warn!("tls handshake with {} failed: {}", self.remote, e);
|
|
||||||
return Poll::Ready(Err(e));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
},
|
|
||||||
TlsState::Streaming(ref mut stream) => return f(stream, cx),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl AsyncRead for TlsStream {
|
|
||||||
fn poll_read(
|
|
||||||
self: Pin<&mut Self>,
|
|
||||||
cx: &mut Context<'_>,
|
|
||||||
buf: &mut tokio::io::ReadBuf<'_>,
|
|
||||||
) -> Poll<io::Result<()>> {
|
|
||||||
self.poll_accept_then(cx, |stream, cx| Pin::new(stream).poll_read(cx, buf))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl AsyncWrite for TlsStream {
|
|
||||||
fn poll_write(
|
|
||||||
self: Pin<&mut Self>,
|
|
||||||
cx: &mut Context<'_>,
|
|
||||||
buf: &[u8],
|
|
||||||
) -> Poll<io::Result<usize>> {
|
|
||||||
self.poll_accept_then(cx, |stream, cx| Pin::new(stream).poll_write(cx, buf))
|
|
||||||
}
|
|
||||||
|
|
||||||
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
|
||||||
match &mut self.state {
|
|
||||||
TlsState::Handshaking(accept) => match accept.get_mut() {
|
|
||||||
Some(io) => Pin::new(io).poll_flush(cx),
|
|
||||||
None => Poll::Ready(Ok(())),
|
|
||||||
}
|
|
||||||
TlsState::Streaming(stream) => Pin::new(stream).poll_flush(cx),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
|
||||||
match &mut self.state {
|
|
||||||
TlsState::Handshaking(accept) => match accept.get_mut() {
|
|
||||||
Some(io) => Pin::new(io).poll_shutdown(cx),
|
|
||||||
None => Poll::Ready(Ok(())),
|
|
||||||
}
|
|
||||||
TlsState::Streaming(stream) => Pin::new(stream).poll_shutdown(cx),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,11 +0,0 @@
|
||||||
mod listener;
|
|
||||||
|
|
||||||
#[cfg(feature = "mtls")]
|
|
||||||
pub mod mtls;
|
|
||||||
|
|
||||||
pub use rustls;
|
|
||||||
pub use listener::{TlsListener, Config};
|
|
||||||
pub mod util;
|
|
||||||
pub mod error;
|
|
||||||
|
|
||||||
pub use error::Result;
|
|
|
@ -20,23 +20,36 @@ rust-version = "1.64"
|
||||||
all-features = true
|
all-features = true
|
||||||
|
|
||||||
[features]
|
[features]
|
||||||
default = ["http2"]
|
default = ["http2", "tokio-macros"]
|
||||||
tls = ["rocket_http/tls"]
|
http2 = ["hyper/http2", "hyper-util/http2"]
|
||||||
mtls = ["rocket_http/mtls", "tls"]
|
secrets = ["cookie/private", "cookie/key-expansion"]
|
||||||
http2 = ["rocket_http/http2"]
|
json = ["serde_json"]
|
||||||
secrets = ["rocket_http/private-cookies"]
|
msgpack = ["rmp-serde"]
|
||||||
json = ["serde_json", "tokio/io-util"]
|
|
||||||
msgpack = ["rmp-serde", "tokio/io-util"]
|
|
||||||
uuid = ["uuid_", "rocket_http/uuid"]
|
uuid = ["uuid_", "rocket_http/uuid"]
|
||||||
|
tls = ["rustls", "tokio-rustls", "rustls-pemfile"]
|
||||||
|
mtls = ["tls", "x509-parser"]
|
||||||
|
tokio-macros = ["tokio/macros"]
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
# Serialization dependencies.
|
# Optional serialization dependencies.
|
||||||
serde_json = { version = "1.0.26", optional = true }
|
serde_json = { version = "1.0.26", optional = true }
|
||||||
rmp-serde = { version = "1", optional = true }
|
rmp-serde = { version = "1", optional = true }
|
||||||
uuid_ = { package = "uuid", version = "1", optional = true, features = ["serde"] }
|
uuid_ = { package = "uuid", version = "1", optional = true, features = ["serde"] }
|
||||||
|
|
||||||
|
# Optional TLS dependencies
|
||||||
|
rustls = { version = "0.22", optional = true }
|
||||||
|
tokio-rustls = { version = "0.25", optional = true }
|
||||||
|
rustls-pemfile = { version = "2.0.0", optional = true }
|
||||||
|
|
||||||
|
# Optional MTLS dependencies
|
||||||
|
x509-parser = { version = "0.13", optional = true }
|
||||||
|
|
||||||
|
# Hyper dependencies
|
||||||
|
http = "1"
|
||||||
|
bytes = "1.4"
|
||||||
|
hyper = { version = "1.1", default-features = false, features = ["http1", "server"] }
|
||||||
|
|
||||||
# Non-optional, core dependencies from here on out.
|
# Non-optional, core dependencies from here on out.
|
||||||
futures = { version = "0.3.0", default-features = false, features = ["std"] }
|
|
||||||
yansi = { version = "1.0.0-rc", features = ["detect-tty"] }
|
yansi = { version = "1.0.0-rc", features = ["detect-tty"] }
|
||||||
log = { version = "0.4", features = ["std"] }
|
log = { version = "0.4", features = ["std"] }
|
||||||
num_cpus = "1.0"
|
num_cpus = "1.0"
|
||||||
|
@ -44,11 +57,11 @@ time = { version = "0.3", features = ["macros", "parsing"] }
|
||||||
memchr = "2" # TODO: Use pear instead.
|
memchr = "2" # TODO: Use pear instead.
|
||||||
binascii = "0.1"
|
binascii = "0.1"
|
||||||
ref-cast = "1.0"
|
ref-cast = "1.0"
|
||||||
atomic = "0.5"
|
ref-swap = "0.1.2"
|
||||||
parking_lot = "0.12"
|
parking_lot = "0.12"
|
||||||
ubyte = {version = "0.10.2", features = ["serde"] }
|
ubyte = {version = "0.10.2", features = ["serde"] }
|
||||||
serde = { version = "1.0", features = ["derive"] }
|
serde = { version = "1.0", features = ["derive"] }
|
||||||
figment = { version = "0.10.6", features = ["toml", "env"] }
|
figment = { version = "0.10.13", features = ["toml", "env"] }
|
||||||
rand = "0.8"
|
rand = "0.8"
|
||||||
either = "1"
|
either = "1"
|
||||||
pin-project-lite = "0.2"
|
pin-project-lite = "0.2"
|
||||||
|
@ -58,8 +71,25 @@ async-trait = "0.1.43"
|
||||||
async-stream = "0.3.2"
|
async-stream = "0.3.2"
|
||||||
multer = { version = "3.0.0", features = ["tokio-io"] }
|
multer = { version = "3.0.0", features = ["tokio-io"] }
|
||||||
tokio-stream = { version = "0.1.6", features = ["signal", "time"] }
|
tokio-stream = { version = "0.1.6", features = ["signal", "time"] }
|
||||||
|
cookie = { version = "0.18", features = ["percent-encode"] }
|
||||||
|
futures = { version = "0.3.30", default-features = false, features = ["std"] }
|
||||||
state = "0.6"
|
state = "0.6"
|
||||||
|
|
||||||
|
[dependencies.hyper-util]
|
||||||
|
git = "https://github.com/SergioBenitez/hyper-util.git"
|
||||||
|
branch = "fix-readversion"
|
||||||
|
default-features = false
|
||||||
|
features = ["http1", "server", "tokio"]
|
||||||
|
|
||||||
|
[dependencies.tokio]
|
||||||
|
version = "1.35.1"
|
||||||
|
features = ["rt-multi-thread", "net", "io-util", "fs", "time", "sync", "signal", "parking_lot"]
|
||||||
|
|
||||||
|
[dependencies.tokio-util]
|
||||||
|
version = "0.7"
|
||||||
|
default-features = false
|
||||||
|
features = ["io"]
|
||||||
|
|
||||||
[dependencies.rocket_codegen]
|
[dependencies.rocket_codegen]
|
||||||
version = "0.6.0-dev"
|
version = "0.6.0-dev"
|
||||||
path = "../codegen"
|
path = "../codegen"
|
||||||
|
@ -69,21 +99,13 @@ version = "0.6.0-dev"
|
||||||
path = "../http"
|
path = "../http"
|
||||||
features = ["serde"]
|
features = ["serde"]
|
||||||
|
|
||||||
[dependencies.tokio]
|
[target.'cfg(unix)'.dependencies]
|
||||||
version = "1.6.1"
|
libc = "0.2.149"
|
||||||
features = ["fs", "io-std", "io-util", "rt-multi-thread", "sync", "signal", "macros"]
|
|
||||||
|
|
||||||
[dependencies.tokio-util]
|
|
||||||
version = "0.7"
|
|
||||||
default-features = false
|
|
||||||
features = ["io"]
|
|
||||||
|
|
||||||
[dependencies.bytes]
|
|
||||||
version = "1.0"
|
|
||||||
|
|
||||||
[build-dependencies]
|
[build-dependencies]
|
||||||
version_check = "0.9.1"
|
version_check = "0.9.1"
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
|
tokio = { version = "1", features = ["macros", "io-std"] }
|
||||||
figment = { version = "0.10", features = ["test"] }
|
figment = { version = "0.10", features = ["test"] }
|
||||||
pretty_assertions = "1"
|
pretty_assertions = "1"
|
||||||
|
|
|
@ -1,5 +1,3 @@
|
||||||
use std::net::{IpAddr, Ipv4Addr};
|
|
||||||
|
|
||||||
use figment::{Figment, Profile, Provider, Metadata, error::Result};
|
use figment::{Figment, Profile, Provider, Metadata, error::Result};
|
||||||
use figment::providers::{Serialized, Env, Toml, Format};
|
use figment::providers::{Serialized, Env, Toml, Format};
|
||||||
use figment::value::{Map, Dict, magic::RelativePathBuf};
|
use figment::value::{Map, Dict, magic::RelativePathBuf};
|
||||||
|
@ -12,9 +10,6 @@ use crate::request::{self, Request, FromRequest};
|
||||||
use crate::http::uncased::Uncased;
|
use crate::http::uncased::Uncased;
|
||||||
use crate::data::Limits;
|
use crate::data::Limits;
|
||||||
|
|
||||||
#[cfg(feature = "tls")]
|
|
||||||
use crate::config::TlsConfig;
|
|
||||||
|
|
||||||
#[cfg(feature = "secrets")]
|
#[cfg(feature = "secrets")]
|
||||||
use crate::config::SecretKey;
|
use crate::config::SecretKey;
|
||||||
|
|
||||||
|
@ -66,10 +61,6 @@ pub struct Config {
|
||||||
/// set to the extracting Figment's selected `Profile`._
|
/// set to the extracting Figment's selected `Profile`._
|
||||||
#[serde(skip)]
|
#[serde(skip)]
|
||||||
pub profile: Profile,
|
pub profile: Profile,
|
||||||
/// IP address to serve on. **(default: `127.0.0.1`)**
|
|
||||||
pub address: IpAddr,
|
|
||||||
/// Port to serve on. **(default: `8000`)**
|
|
||||||
pub port: u16,
|
|
||||||
/// Number of threads to use for executing futures. **(default: `num_cores`)**
|
/// Number of threads to use for executing futures. **(default: `num_cores`)**
|
||||||
///
|
///
|
||||||
/// _**Note:** Rocket only reads this value from sources in the [default
|
/// _**Note:** Rocket only reads this value from sources in the [default
|
||||||
|
@ -121,10 +112,6 @@ pub struct Config {
|
||||||
pub temp_dir: RelativePathBuf,
|
pub temp_dir: RelativePathBuf,
|
||||||
/// Keep-alive timeout in seconds; disabled when `0`. **(default: `5`)**
|
/// Keep-alive timeout in seconds; disabled when `0`. **(default: `5`)**
|
||||||
pub keep_alive: u32,
|
pub keep_alive: u32,
|
||||||
/// The TLS configuration, if any. **(default: `None`)**
|
|
||||||
#[cfg(feature = "tls")]
|
|
||||||
#[cfg_attr(nightly, doc(cfg(feature = "tls")))]
|
|
||||||
pub tls: Option<TlsConfig>,
|
|
||||||
/// The secret key for signing and encrypting. **(default: `0`)**
|
/// The secret key for signing and encrypting. **(default: `0`)**
|
||||||
///
|
///
|
||||||
/// _**Note:** This field _always_ serializes as a 256-bit array of `0`s to
|
/// _**Note:** This field _always_ serializes as a 256-bit array of `0`s to
|
||||||
|
@ -148,7 +135,6 @@ pub struct Config {
|
||||||
/// use rocket::Config;
|
/// use rocket::Config;
|
||||||
///
|
///
|
||||||
/// let config = Config {
|
/// let config = Config {
|
||||||
/// port: 1024,
|
|
||||||
/// keep_alive: 10,
|
/// keep_alive: 10,
|
||||||
/// ..Default::default()
|
/// ..Default::default()
|
||||||
/// };
|
/// };
|
||||||
|
@ -204,8 +190,6 @@ impl Config {
|
||||||
pub fn debug_default() -> Config {
|
pub fn debug_default() -> Config {
|
||||||
Config {
|
Config {
|
||||||
profile: Self::DEBUG_PROFILE,
|
profile: Self::DEBUG_PROFILE,
|
||||||
address: Ipv4Addr::new(127, 0, 0, 1).into(),
|
|
||||||
port: 8000,
|
|
||||||
workers: num_cpus::get(),
|
workers: num_cpus::get(),
|
||||||
max_blocking: 512,
|
max_blocking: 512,
|
||||||
ident: Ident::default(),
|
ident: Ident::default(),
|
||||||
|
@ -214,8 +198,6 @@ impl Config {
|
||||||
limits: Limits::default(),
|
limits: Limits::default(),
|
||||||
temp_dir: std::env::temp_dir().into(),
|
temp_dir: std::env::temp_dir().into(),
|
||||||
keep_alive: 5,
|
keep_alive: 5,
|
||||||
#[cfg(feature = "tls")]
|
|
||||||
tls: None,
|
|
||||||
#[cfg(feature = "secrets")]
|
#[cfg(feature = "secrets")]
|
||||||
secret_key: SecretKey::zero(),
|
secret_key: SecretKey::zero(),
|
||||||
shutdown: Shutdown::default(),
|
shutdown: Shutdown::default(),
|
||||||
|
@ -331,59 +313,6 @@ impl Config {
|
||||||
Self::try_from(provider).unwrap_or_else(bail_with_config_error)
|
Self::try_from(provider).unwrap_or_else(bail_with_config_error)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns `true` if TLS is enabled.
|
|
||||||
///
|
|
||||||
/// TLS is enabled when the `tls` feature is enabled and TLS has been
|
|
||||||
/// configured with at least one ciphersuite. Note that without changing
|
|
||||||
/// defaults, all supported ciphersuites are enabled in the recommended
|
|
||||||
/// configuration.
|
|
||||||
///
|
|
||||||
/// # Example
|
|
||||||
///
|
|
||||||
/// ```rust
|
|
||||||
/// let config = rocket::Config::default();
|
|
||||||
/// if config.tls_enabled() {
|
|
||||||
/// println!("TLS is enabled!");
|
|
||||||
/// } else {
|
|
||||||
/// println!("TLS is disabled.");
|
|
||||||
/// }
|
|
||||||
/// ```
|
|
||||||
pub fn tls_enabled(&self) -> bool {
|
|
||||||
#[cfg(feature = "tls")] {
|
|
||||||
self.tls.as_ref().map_or(false, |tls| !tls.ciphers.is_empty())
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(not(feature = "tls"))] { false }
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Returns `true` if mTLS is enabled.
|
|
||||||
///
|
|
||||||
/// mTLS is enabled when TLS is enabled ([`Config::tls_enabled()`]) _and_
|
|
||||||
/// the `mtls` feature is enabled _and_ mTLS has been configured with a CA
|
|
||||||
/// certificate chain.
|
|
||||||
///
|
|
||||||
/// # Example
|
|
||||||
///
|
|
||||||
/// ```rust
|
|
||||||
/// let config = rocket::Config::default();
|
|
||||||
/// if config.mtls_enabled() {
|
|
||||||
/// println!("mTLS is enabled!");
|
|
||||||
/// } else {
|
|
||||||
/// println!("mTLS is disabled.");
|
|
||||||
/// }
|
|
||||||
/// ```
|
|
||||||
pub fn mtls_enabled(&self) -> bool {
|
|
||||||
if !self.tls_enabled() {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(feature = "mtls")] {
|
|
||||||
self.tls.as_ref().map_or(false, |tls| tls.mutual.is_some())
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(not(feature = "mtls"))] { false }
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(feature = "secrets")]
|
#[cfg(feature = "secrets")]
|
||||||
pub(crate) fn known_secret_key_used(&self) -> bool {
|
pub(crate) fn known_secret_key_used(&self) -> bool {
|
||||||
const KNOWN_SECRET_KEYS: &'static [&'static str] = &[
|
const KNOWN_SECRET_KEYS: &'static [&'static str] = &[
|
||||||
|
@ -420,8 +349,6 @@ impl Config {
|
||||||
|
|
||||||
self.trace_print(figment);
|
self.trace_print(figment);
|
||||||
launch_meta!("{}Configured for {}.", "🔧 ".emoji(), self.profile.underline());
|
launch_meta!("{}Configured for {}.", "🔧 ".emoji(), self.profile.underline());
|
||||||
launch_meta_!("address: {}", self.address.paint(VAL));
|
|
||||||
launch_meta_!("port: {}", self.port.paint(VAL));
|
|
||||||
launch_meta_!("workers: {}", self.workers.paint(VAL));
|
launch_meta_!("workers: {}", self.workers.paint(VAL));
|
||||||
launch_meta_!("max blocking threads: {}", self.max_blocking.paint(VAL));
|
launch_meta_!("max blocking threads: {}", self.max_blocking.paint(VAL));
|
||||||
launch_meta_!("ident: {}", self.ident.paint(VAL));
|
launch_meta_!("ident: {}", self.ident.paint(VAL));
|
||||||
|
@ -445,12 +372,6 @@ impl Config {
|
||||||
ka => launch_meta_!("keep-alive: {}{}", ka.paint(VAL), "s".paint(VAL)),
|
ka => launch_meta_!("keep-alive: {}{}", ka.paint(VAL), "s".paint(VAL)),
|
||||||
}
|
}
|
||||||
|
|
||||||
match (self.tls_enabled(), self.mtls_enabled()) {
|
|
||||||
(true, true) => launch_meta_!("tls: {}", "enabled w/mtls".paint(VAL)),
|
|
||||||
(true, false) => launch_meta_!("tls: {} w/o mtls", "enabled".paint(VAL)),
|
|
||||||
(false, _) => launch_meta_!("tls: {}", "disabled".paint(VAL)),
|
|
||||||
}
|
|
||||||
|
|
||||||
launch_meta_!("shutdown: {}", self.shutdown.paint(VAL));
|
launch_meta_!("shutdown: {}", self.shutdown.paint(VAL));
|
||||||
launch_meta_!("log level: {}", self.log_level.paint(VAL));
|
launch_meta_!("log level: {}", self.log_level.paint(VAL));
|
||||||
launch_meta_!("cli colors: {}", self.cli_colors.paint(VAL));
|
launch_meta_!("cli colors: {}", self.cli_colors.paint(VAL));
|
||||||
|
@ -519,12 +440,6 @@ impl Config {
|
||||||
/// This isn't `pub` because setting it directly does nothing.
|
/// This isn't `pub` because setting it directly does nothing.
|
||||||
const PROFILE: &'static str = "profile";
|
const PROFILE: &'static str = "profile";
|
||||||
|
|
||||||
/// The stringy parameter name for setting/extracting [`Config::address`].
|
|
||||||
pub const ADDRESS: &'static str = "address";
|
|
||||||
|
|
||||||
/// The stringy parameter name for setting/extracting [`Config::port`].
|
|
||||||
pub const PORT: &'static str = "port";
|
|
||||||
|
|
||||||
/// The stringy parameter name for setting/extracting [`Config::workers`].
|
/// The stringy parameter name for setting/extracting [`Config::workers`].
|
||||||
pub const WORKERS: &'static str = "workers";
|
pub const WORKERS: &'static str = "workers";
|
||||||
|
|
||||||
|
@ -546,9 +461,6 @@ impl Config {
|
||||||
/// The stringy parameter name for setting/extracting [`Config::limits`].
|
/// The stringy parameter name for setting/extracting [`Config::limits`].
|
||||||
pub const LIMITS: &'static str = "limits";
|
pub const LIMITS: &'static str = "limits";
|
||||||
|
|
||||||
/// The stringy parameter name for setting/extracting [`Config::tls`].
|
|
||||||
pub const TLS: &'static str = "tls";
|
|
||||||
|
|
||||||
/// The stringy parameter name for setting/extracting [`Config::secret_key`].
|
/// The stringy parameter name for setting/extracting [`Config::secret_key`].
|
||||||
pub const SECRET_KEY: &'static str = "secret_key";
|
pub const SECRET_KEY: &'static str = "secret_key";
|
||||||
|
|
||||||
|
@ -566,9 +478,10 @@ impl Config {
|
||||||
|
|
||||||
/// An array of all of the stringy parameter names.
|
/// An array of all of the stringy parameter names.
|
||||||
pub const PARAMETERS: &'static [&'static str] = &[
|
pub const PARAMETERS: &'static [&'static str] = &[
|
||||||
Self::ADDRESS, Self::PORT, Self::WORKERS, Self::MAX_BLOCKING, Self::KEEP_ALIVE,
|
Self::WORKERS, Self::MAX_BLOCKING, Self::KEEP_ALIVE, Self::IDENT,
|
||||||
Self::IDENT, Self::IP_HEADER, Self::PROXY_PROTO_HEADER, Self::LIMITS, Self::TLS,
|
Self::IP_HEADER, Self::PROXY_PROTO_HEADER, Self::LIMITS,
|
||||||
Self::SECRET_KEY, Self::TEMP_DIR, Self::LOG_LEVEL, Self::SHUTDOWN, Self::CLI_COLORS,
|
Self::SECRET_KEY, Self::TEMP_DIR, Self::LOG_LEVEL, Self::SHUTDOWN,
|
||||||
|
Self::CLI_COLORS,
|
||||||
];
|
];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -117,9 +117,6 @@ mod shutdown;
|
||||||
mod cli_colors;
|
mod cli_colors;
|
||||||
mod http_header;
|
mod http_header;
|
||||||
|
|
||||||
#[cfg(feature = "tls")]
|
|
||||||
mod tls;
|
|
||||||
|
|
||||||
#[cfg(feature = "secrets")]
|
#[cfg(feature = "secrets")]
|
||||||
mod secret_key;
|
mod secret_key;
|
||||||
|
|
||||||
|
@ -132,12 +129,6 @@ pub use shutdown::Shutdown;
|
||||||
pub use ident::Ident;
|
pub use ident::Ident;
|
||||||
pub use cli_colors::CliColors;
|
pub use cli_colors::CliColors;
|
||||||
|
|
||||||
#[cfg(feature = "tls")]
|
|
||||||
pub use tls::{TlsConfig, CipherSuite};
|
|
||||||
|
|
||||||
#[cfg(feature = "mtls")]
|
|
||||||
pub use tls::MutualTls;
|
|
||||||
|
|
||||||
#[cfg(feature = "secrets")]
|
#[cfg(feature = "secrets")]
|
||||||
pub use secret_key::SecretKey;
|
pub use secret_key::SecretKey;
|
||||||
|
|
||||||
|
@ -146,7 +137,6 @@ pub use shutdown::Sig;
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use std::net::Ipv4Addr;
|
|
||||||
use figment::{Figment, Profile};
|
use figment::{Figment, Profile};
|
||||||
use pretty_assertions::assert_eq;
|
use pretty_assertions::assert_eq;
|
||||||
|
|
||||||
|
@ -202,9 +192,7 @@ mod tests {
|
||||||
figment::Jail::expect_with(|jail| {
|
figment::Jail::expect_with(|jail| {
|
||||||
jail.create_file("Rocket.toml", r#"
|
jail.create_file("Rocket.toml", r#"
|
||||||
[default]
|
[default]
|
||||||
address = "1.2.3.4"
|
|
||||||
ident = "Something Cool"
|
ident = "Something Cool"
|
||||||
port = 1234
|
|
||||||
workers = 20
|
workers = 20
|
||||||
keep_alive = 10
|
keep_alive = 10
|
||||||
log_level = "off"
|
log_level = "off"
|
||||||
|
@ -213,8 +201,6 @@ mod tests {
|
||||||
|
|
||||||
let config = Config::from(Config::figment());
|
let config = Config::from(Config::figment());
|
||||||
assert_eq!(config, Config {
|
assert_eq!(config, Config {
|
||||||
address: Ipv4Addr::new(1, 2, 3, 4).into(),
|
|
||||||
port: 1234,
|
|
||||||
workers: 20,
|
workers: 20,
|
||||||
ident: ident!("Something Cool"),
|
ident: ident!("Something Cool"),
|
||||||
keep_alive: 10,
|
keep_alive: 10,
|
||||||
|
@ -225,9 +211,7 @@ mod tests {
|
||||||
|
|
||||||
jail.create_file("Rocket.toml", r#"
|
jail.create_file("Rocket.toml", r#"
|
||||||
[global]
|
[global]
|
||||||
address = "1.2.3.4"
|
|
||||||
ident = "Something Else Cool"
|
ident = "Something Else Cool"
|
||||||
port = 1234
|
|
||||||
workers = 20
|
workers = 20
|
||||||
keep_alive = 10
|
keep_alive = 10
|
||||||
log_level = "off"
|
log_level = "off"
|
||||||
|
@ -236,8 +220,6 @@ mod tests {
|
||||||
|
|
||||||
let config = Config::from(Config::figment());
|
let config = Config::from(Config::figment());
|
||||||
assert_eq!(config, Config {
|
assert_eq!(config, Config {
|
||||||
address: Ipv4Addr::new(1, 2, 3, 4).into(),
|
|
||||||
port: 1234,
|
|
||||||
workers: 20,
|
workers: 20,
|
||||||
ident: ident!("Something Else Cool"),
|
ident: ident!("Something Else Cool"),
|
||||||
keep_alive: 10,
|
keep_alive: 10,
|
||||||
|
@ -249,8 +231,6 @@ mod tests {
|
||||||
jail.set_env("ROCKET_CONFIG", "Other.toml");
|
jail.set_env("ROCKET_CONFIG", "Other.toml");
|
||||||
jail.create_file("Other.toml", r#"
|
jail.create_file("Other.toml", r#"
|
||||||
[default]
|
[default]
|
||||||
address = "1.2.3.4"
|
|
||||||
port = 1234
|
|
||||||
workers = 20
|
workers = 20
|
||||||
keep_alive = 10
|
keep_alive = 10
|
||||||
log_level = "off"
|
log_level = "off"
|
||||||
|
@ -259,8 +239,6 @@ mod tests {
|
||||||
|
|
||||||
let config = Config::from(Config::figment());
|
let config = Config::from(Config::figment());
|
||||||
assert_eq!(config, Config {
|
assert_eq!(config, Config {
|
||||||
address: Ipv4Addr::new(1, 2, 3, 4).into(),
|
|
||||||
port: 1234,
|
|
||||||
workers: 20,
|
workers: 20,
|
||||||
keep_alive: 10,
|
keep_alive: 10,
|
||||||
log_level: LogLevel::Off,
|
log_level: LogLevel::Off,
|
||||||
|
@ -367,228 +345,6 @@ mod tests {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
|
||||||
#[cfg(feature = "tls")]
|
|
||||||
fn test_tls_config_from_file() {
|
|
||||||
use crate::config::{TlsConfig, CipherSuite, Ident, Shutdown};
|
|
||||||
|
|
||||||
figment::Jail::expect_with(|jail| {
|
|
||||||
jail.create_file("Rocket.toml", r#"
|
|
||||||
[global]
|
|
||||||
shutdown.ctrlc = 0
|
|
||||||
ident = false
|
|
||||||
|
|
||||||
[global.tls]
|
|
||||||
certs = "/ssl/cert.pem"
|
|
||||||
key = "/ssl/key.pem"
|
|
||||||
|
|
||||||
[global.limits]
|
|
||||||
forms = "1mib"
|
|
||||||
json = "10mib"
|
|
||||||
stream = "50kib"
|
|
||||||
"#)?;
|
|
||||||
|
|
||||||
let config = Config::from(Config::figment());
|
|
||||||
assert_eq!(config, Config {
|
|
||||||
shutdown: Shutdown { ctrlc: false, ..Default::default() },
|
|
||||||
ident: Ident::none(),
|
|
||||||
tls: Some(TlsConfig::from_paths("/ssl/cert.pem", "/ssl/key.pem")),
|
|
||||||
limits: Limits::default()
|
|
||||||
.limit("forms", 1.mebibytes())
|
|
||||||
.limit("json", 10.mebibytes())
|
|
||||||
.limit("stream", 50.kibibytes()),
|
|
||||||
..Config::default()
|
|
||||||
});
|
|
||||||
|
|
||||||
jail.create_file("Rocket.toml", r#"
|
|
||||||
[global.tls]
|
|
||||||
certs = "cert.pem"
|
|
||||||
key = "key.pem"
|
|
||||||
"#)?;
|
|
||||||
|
|
||||||
let config = Config::from(Config::figment());
|
|
||||||
assert_eq!(config, Config {
|
|
||||||
tls: Some(TlsConfig::from_paths(
|
|
||||||
jail.directory().join("cert.pem"),
|
|
||||||
jail.directory().join("key.pem")
|
|
||||||
)),
|
|
||||||
..Config::default()
|
|
||||||
});
|
|
||||||
|
|
||||||
jail.create_file("Rocket.toml", r#"
|
|
||||||
[global.tls]
|
|
||||||
certs = "cert.pem"
|
|
||||||
key = "key.pem"
|
|
||||||
prefer_server_cipher_order = true
|
|
||||||
ciphers = [
|
|
||||||
"TLS_CHACHA20_POLY1305_SHA256",
|
|
||||||
"TLS_AES_256_GCM_SHA384",
|
|
||||||
"TLS_AES_128_GCM_SHA256",
|
|
||||||
"TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256",
|
|
||||||
"TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384",
|
|
||||||
"TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256",
|
|
||||||
]
|
|
||||||
"#)?;
|
|
||||||
|
|
||||||
let config = Config::from(Config::figment());
|
|
||||||
let cert_path = jail.directory().join("cert.pem");
|
|
||||||
let key_path = jail.directory().join("key.pem");
|
|
||||||
assert_eq!(config, Config {
|
|
||||||
tls: Some(TlsConfig::from_paths(cert_path, key_path)
|
|
||||||
.with_preferred_server_cipher_order(true)
|
|
||||||
.with_ciphers([
|
|
||||||
CipherSuite::TLS_CHACHA20_POLY1305_SHA256,
|
|
||||||
CipherSuite::TLS_AES_256_GCM_SHA384,
|
|
||||||
CipherSuite::TLS_AES_128_GCM_SHA256,
|
|
||||||
CipherSuite::TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256,
|
|
||||||
CipherSuite::TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
|
|
||||||
CipherSuite::TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
|
|
||||||
])),
|
|
||||||
..Config::default()
|
|
||||||
});
|
|
||||||
|
|
||||||
jail.create_file("Rocket.toml", r#"
|
|
||||||
[global]
|
|
||||||
shutdown.ctrlc = 0
|
|
||||||
ident = false
|
|
||||||
|
|
||||||
[global.tls]
|
|
||||||
certs = "/ssl/cert.pem"
|
|
||||||
key = "/ssl/key.pem"
|
|
||||||
|
|
||||||
[global.limits]
|
|
||||||
forms = "1mib"
|
|
||||||
json = "10mib"
|
|
||||||
stream = "50kib"
|
|
||||||
"#)?;
|
|
||||||
|
|
||||||
let config = Config::from(Config::figment());
|
|
||||||
assert_eq!(config, Config {
|
|
||||||
shutdown: Shutdown { ctrlc: false, ..Default::default() },
|
|
||||||
ident: Ident::none(),
|
|
||||||
tls: Some(TlsConfig::from_paths("/ssl/cert.pem", "/ssl/key.pem")),
|
|
||||||
limits: Limits::default()
|
|
||||||
.limit("forms", 1.mebibytes())
|
|
||||||
.limit("json", 10.mebibytes())
|
|
||||||
.limit("stream", 50.kibibytes()),
|
|
||||||
..Config::default()
|
|
||||||
});
|
|
||||||
|
|
||||||
jail.create_file("Rocket.toml", r#"
|
|
||||||
[global.tls]
|
|
||||||
certs = "cert.pem"
|
|
||||||
key = "key.pem"
|
|
||||||
"#)?;
|
|
||||||
|
|
||||||
let config = Config::from(Config::figment());
|
|
||||||
assert_eq!(config, Config {
|
|
||||||
tls: Some(TlsConfig::from_paths(
|
|
||||||
jail.directory().join("cert.pem"),
|
|
||||||
jail.directory().join("key.pem")
|
|
||||||
)),
|
|
||||||
..Config::default()
|
|
||||||
});
|
|
||||||
|
|
||||||
jail.create_file("Rocket.toml", r#"
|
|
||||||
[global.tls]
|
|
||||||
certs = "cert.pem"
|
|
||||||
key = "key.pem"
|
|
||||||
prefer_server_cipher_order = true
|
|
||||||
ciphers = [
|
|
||||||
"TLS_CHACHA20_POLY1305_SHA256",
|
|
||||||
"TLS_AES_256_GCM_SHA384",
|
|
||||||
"TLS_AES_128_GCM_SHA256",
|
|
||||||
"TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256",
|
|
||||||
"TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384",
|
|
||||||
"TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256",
|
|
||||||
]
|
|
||||||
"#)?;
|
|
||||||
|
|
||||||
let config = Config::from(Config::figment());
|
|
||||||
let cert_path = jail.directory().join("cert.pem");
|
|
||||||
let key_path = jail.directory().join("key.pem");
|
|
||||||
assert_eq!(config, Config {
|
|
||||||
tls: Some(TlsConfig::from_paths(cert_path, key_path)
|
|
||||||
.with_preferred_server_cipher_order(true)
|
|
||||||
.with_ciphers([
|
|
||||||
CipherSuite::TLS_CHACHA20_POLY1305_SHA256,
|
|
||||||
CipherSuite::TLS_AES_256_GCM_SHA384,
|
|
||||||
CipherSuite::TLS_AES_128_GCM_SHA256,
|
|
||||||
CipherSuite::TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256,
|
|
||||||
CipherSuite::TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
|
|
||||||
CipherSuite::TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
|
|
||||||
])),
|
|
||||||
..Config::default()
|
|
||||||
});
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
#[cfg(feature = "mtls")]
|
|
||||||
fn test_mtls_config() {
|
|
||||||
use std::path::Path;
|
|
||||||
|
|
||||||
figment::Jail::expect_with(|jail| {
|
|
||||||
jail.create_file("Rocket.toml", r#"
|
|
||||||
[default.tls]
|
|
||||||
certs = "/ssl/cert.pem"
|
|
||||||
key = "/ssl/key.pem"
|
|
||||||
"#)?;
|
|
||||||
|
|
||||||
let config = Config::from(Config::figment());
|
|
||||||
assert!(config.tls.is_some());
|
|
||||||
assert!(config.tls.as_ref().unwrap().mutual.is_none());
|
|
||||||
assert!(config.tls_enabled());
|
|
||||||
assert!(!config.mtls_enabled());
|
|
||||||
|
|
||||||
jail.create_file("Rocket.toml", r#"
|
|
||||||
[default.tls]
|
|
||||||
certs = "/ssl/cert.pem"
|
|
||||||
key = "/ssl/key.pem"
|
|
||||||
mutual = { ca_certs = "/ssl/ca.pem" }
|
|
||||||
"#)?;
|
|
||||||
|
|
||||||
let config = Config::from(Config::figment());
|
|
||||||
assert!(config.tls_enabled());
|
|
||||||
assert!(config.mtls_enabled());
|
|
||||||
|
|
||||||
let mtls = config.tls.as_ref().unwrap().mutual.as_ref().unwrap();
|
|
||||||
assert_eq!(mtls.ca_certs().unwrap_left(), Path::new("/ssl/ca.pem"));
|
|
||||||
assert!(!mtls.mandatory);
|
|
||||||
|
|
||||||
jail.create_file("Rocket.toml", r#"
|
|
||||||
[default.tls]
|
|
||||||
certs = "/ssl/cert.pem"
|
|
||||||
key = "/ssl/key.pem"
|
|
||||||
|
|
||||||
[default.tls.mutual]
|
|
||||||
ca_certs = "/ssl/ca.pem"
|
|
||||||
mandatory = true
|
|
||||||
"#)?;
|
|
||||||
|
|
||||||
let config = Config::from(Config::figment());
|
|
||||||
let mtls = config.tls.as_ref().unwrap().mutual.as_ref().unwrap();
|
|
||||||
assert_eq!(mtls.ca_certs().unwrap_left(), Path::new("/ssl/ca.pem"));
|
|
||||||
assert!(mtls.mandatory);
|
|
||||||
|
|
||||||
jail.create_file("Rocket.toml", r#"
|
|
||||||
[default.tls]
|
|
||||||
certs = "/ssl/cert.pem"
|
|
||||||
key = "/ssl/key.pem"
|
|
||||||
mutual = { ca_certs = "relative/ca.pem" }
|
|
||||||
"#)?;
|
|
||||||
|
|
||||||
let config = Config::from(Config::figment());
|
|
||||||
let mtls = config.tls.as_ref().unwrap().mutual().unwrap();
|
|
||||||
assert_eq!(mtls.ca_certs().unwrap_left(),
|
|
||||||
jail.directory().join("relative/ca.pem"));
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_profiles_merge() {
|
fn test_profiles_merge() {
|
||||||
figment::Jail::expect_with(|jail| {
|
figment::Jail::expect_with(|jail| {
|
||||||
|
@ -629,42 +385,41 @@ mod tests {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
#[cfg(feature = "tls")]
|
|
||||||
fn test_env_vars_merge() {
|
fn test_env_vars_merge() {
|
||||||
use crate::config::{TlsConfig, Ident};
|
use crate::config::{Ident, Shutdown};
|
||||||
|
|
||||||
figment::Jail::expect_with(|jail| {
|
figment::Jail::expect_with(|jail| {
|
||||||
jail.set_env("ROCKET_PORT", 9999);
|
jail.set_env("ROCKET_KEEP_ALIVE", 9999);
|
||||||
let config = Config::from(Config::figment());
|
let config = Config::from(Config::figment());
|
||||||
assert_eq!(config, Config {
|
assert_eq!(config, Config {
|
||||||
port: 9999,
|
keep_alive: 9999,
|
||||||
..Config::default()
|
..Config::default()
|
||||||
});
|
});
|
||||||
|
|
||||||
jail.set_env("ROCKET_TLS", r#"{certs="certs.pem"}"#);
|
jail.set_env("ROCKET_SHUTDOWN", r#"{grace=7}"#);
|
||||||
let first_figment = Config::figment();
|
let first_figment = Config::figment();
|
||||||
jail.set_env("ROCKET_TLS", r#"{key="key.pem"}"#);
|
jail.set_env("ROCKET_SHUTDOWN", r#"{mercy=10}"#);
|
||||||
let prev_figment = Config::figment().join(&first_figment);
|
let prev_figment = Config::figment().join(&first_figment);
|
||||||
let config = Config::from(&prev_figment);
|
let config = Config::from(&prev_figment);
|
||||||
assert_eq!(config, Config {
|
assert_eq!(config, Config {
|
||||||
port: 9999,
|
keep_alive: 9999,
|
||||||
tls: Some(TlsConfig::from_paths("certs.pem", "key.pem")),
|
shutdown: Shutdown { grace: 7, mercy: 10, ..Default::default() },
|
||||||
..Config::default()
|
..Config::default()
|
||||||
});
|
});
|
||||||
|
|
||||||
jail.set_env("ROCKET_TLS", r#"{certs="new.pem"}"#);
|
jail.set_env("ROCKET_SHUTDOWN", r#"{mercy=20}"#);
|
||||||
let config = Config::from(Config::figment().join(&prev_figment));
|
let config = Config::from(Config::figment().join(&prev_figment));
|
||||||
assert_eq!(config, Config {
|
assert_eq!(config, Config {
|
||||||
port: 9999,
|
keep_alive: 9999,
|
||||||
tls: Some(TlsConfig::from_paths("new.pem", "key.pem")),
|
shutdown: Shutdown { grace: 7, mercy: 20, ..Default::default() },
|
||||||
..Config::default()
|
..Config::default()
|
||||||
});
|
});
|
||||||
|
|
||||||
jail.set_env("ROCKET_LIMITS", r#"{stream=100kiB}"#);
|
jail.set_env("ROCKET_LIMITS", r#"{stream=100kiB}"#);
|
||||||
let config = Config::from(Config::figment().join(&prev_figment));
|
let config = Config::from(Config::figment().join(&prev_figment));
|
||||||
assert_eq!(config, Config {
|
assert_eq!(config, Config {
|
||||||
port: 9999,
|
keep_alive: 9999,
|
||||||
tls: Some(TlsConfig::from_paths("new.pem", "key.pem")),
|
shutdown: Shutdown { grace: 7, mercy: 20, ..Default::default() },
|
||||||
limits: Limits::default().limit("stream", 100.kibibytes()),
|
limits: Limits::default().limit("stream", 100.kibibytes()),
|
||||||
..Config::default()
|
..Config::default()
|
||||||
});
|
});
|
||||||
|
@ -672,8 +427,8 @@ mod tests {
|
||||||
jail.set_env("ROCKET_IDENT", false);
|
jail.set_env("ROCKET_IDENT", false);
|
||||||
let config = Config::from(Config::figment().join(&prev_figment));
|
let config = Config::from(Config::figment().join(&prev_figment));
|
||||||
assert_eq!(config, Config {
|
assert_eq!(config, Config {
|
||||||
port: 9999,
|
keep_alive: 9999,
|
||||||
tls: Some(TlsConfig::from_paths("new.pem", "key.pem")),
|
shutdown: Shutdown { grace: 7, mercy: 20, ..Default::default() },
|
||||||
limits: Limits::default().limit("stream", 100.kibibytes()),
|
limits: Limits::default().limit("stream", 100.kibibytes()),
|
||||||
ident: Ident::none(),
|
ident: Ident::none(),
|
||||||
..Config::default()
|
..Config::default()
|
||||||
|
|
|
@ -1,8 +1,8 @@
|
||||||
use std::fmt;
|
use std::fmt;
|
||||||
|
|
||||||
|
use cookie::Key;
|
||||||
use serde::{de, ser, Deserialize, Serialize};
|
use serde::{de, ser, Deserialize, Serialize};
|
||||||
|
|
||||||
use crate::http::private::cookie::Key;
|
|
||||||
use crate::request::{Outcome, Request, FromRequest};
|
use crate::request::{Outcome, Request, FromRequest};
|
||||||
|
|
||||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
use std::fmt;
|
use std::{fmt, time::Duration};
|
||||||
|
|
||||||
#[cfg(unix)]
|
#[cfg(unix)]
|
||||||
use std::collections::HashSet;
|
use std::collections::HashSet;
|
||||||
|
@ -291,6 +291,14 @@ impl Default for Shutdown {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Shutdown {
|
impl Shutdown {
|
||||||
|
pub(crate) fn grace(&self) -> Duration {
|
||||||
|
Duration::from_secs(self.grace as u64)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn mercy(&self) -> Duration {
|
||||||
|
Duration::from_secs(self.mercy as u64)
|
||||||
|
}
|
||||||
|
|
||||||
#[cfg(unix)]
|
#[cfg(unix)]
|
||||||
pub(crate) fn signal_stream(&self) -> Option<impl Stream<Item = Sig>> {
|
pub(crate) fn signal_stream(&self) -> Option<impl Stream<Item = Sig>> {
|
||||||
use tokio_stream::{StreamExt, StreamMap, wrappers::SignalStream};
|
use tokio_stream::{StreamExt, StreamMap, wrappers::SignalStream};
|
||||||
|
|
|
@ -3,16 +3,16 @@ use std::task::{Context, Poll};
|
||||||
use std::path::Path;
|
use std::path::Path;
|
||||||
use std::io::{self, Cursor};
|
use std::io::{self, Cursor};
|
||||||
|
|
||||||
|
use futures::ready;
|
||||||
|
use futures::stream::Stream;
|
||||||
use tokio::fs::File;
|
use tokio::fs::File;
|
||||||
use tokio::io::{AsyncRead, AsyncWrite, AsyncReadExt, ReadBuf, Take};
|
use tokio::io::{AsyncRead, AsyncWrite, AsyncReadExt, ReadBuf, Take};
|
||||||
use tokio_util::io::StreamReader;
|
use tokio_util::io::StreamReader;
|
||||||
use futures::{ready, stream::Stream};
|
use hyper::body::{Body, Bytes, Incoming as HyperBody};
|
||||||
|
|
||||||
use crate::http::hyper;
|
|
||||||
use crate::ext::{PollExt, Chain};
|
|
||||||
use crate::data::{Capped, N};
|
use crate::data::{Capped, N};
|
||||||
use crate::http::hyper::body::Bytes;
|
|
||||||
use crate::data::transform::Transform;
|
use crate::data::transform::Transform;
|
||||||
|
use crate::util::Chain;
|
||||||
|
|
||||||
use super::peekable::Peekable;
|
use super::peekable::Peekable;
|
||||||
use super::transform::TransformBuf;
|
use super::transform::TransformBuf;
|
||||||
|
@ -68,7 +68,7 @@ pub type RawReader<'r> = StreamReader<RawStream<'r>, Bytes>;
|
||||||
/// Raw underlying data stream.
|
/// Raw underlying data stream.
|
||||||
pub enum RawStream<'r> {
|
pub enum RawStream<'r> {
|
||||||
Empty,
|
Empty,
|
||||||
Body(&'r mut hyper::Body),
|
Body(&'r mut HyperBody),
|
||||||
Multipart(multer::Field<'r>),
|
Multipart(multer::Field<'r>),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -154,8 +154,14 @@ impl<'r> DataStream<'r> {
|
||||||
/// ```
|
/// ```
|
||||||
pub fn hint(&self) -> usize {
|
pub fn hint(&self) -> usize {
|
||||||
let base = self.base();
|
let base = self.base();
|
||||||
let buf_len = base.get_ref().get_ref().0.get_ref().len();
|
if let (Some(cursor), _) = base.get_ref().get_ref() {
|
||||||
std::cmp::min(buf_len, base.limit() as usize)
|
let len = cursor.get_ref().len() as u64;
|
||||||
|
let position = cursor.position().min(len);
|
||||||
|
let remaining = len - position;
|
||||||
|
remaining.min(base.limit()) as usize
|
||||||
|
} else {
|
||||||
|
0
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// A helper method to write the body of the request to any `AsyncWrite`
|
/// A helper method to write the body of the request to any `AsyncWrite`
|
||||||
|
@ -331,17 +337,25 @@ impl Stream for RawStream<'_> {
|
||||||
|
|
||||||
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
|
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
|
||||||
match self.get_mut() {
|
match self.get_mut() {
|
||||||
RawStream::Body(body) => Pin::new(body).poll_next(cx)
|
// TODO: Expose trailer headers, somehow.
|
||||||
.map_err_ext(|e| io::Error::new(io::ErrorKind::Other, e)),
|
RawStream::Body(body) => {
|
||||||
RawStream::Multipart(mp) => Pin::new(mp).poll_next(cx)
|
Pin::new(body)
|
||||||
.map_err_ext(|e| io::Error::new(io::ErrorKind::Other, e)),
|
.poll_frame(cx)
|
||||||
|
.map_ok(|frame| frame.into_data().unwrap_or_else(|_| Bytes::new()))
|
||||||
|
.map_err(io::Error::other)
|
||||||
|
}
|
||||||
|
RawStream::Multipart(s) => Pin::new(s).poll_next(cx).map_err(io::Error::other),
|
||||||
RawStream::Empty => Poll::Ready(None),
|
RawStream::Empty => Poll::Ready(None),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn size_hint(&self) -> (usize, Option<usize>) {
|
fn size_hint(&self) -> (usize, Option<usize>) {
|
||||||
match self {
|
match self {
|
||||||
RawStream::Body(body) => body.size_hint(),
|
RawStream::Body(body) => {
|
||||||
|
let hint = body.size_hint();
|
||||||
|
let (lower, upper) = (hint.lower(), hint.upper());
|
||||||
|
(lower as usize, upper.map(|x| x as usize))
|
||||||
|
},
|
||||||
RawStream::Multipart(mp) => mp.size_hint(),
|
RawStream::Multipart(mp) => mp.size_hint(),
|
||||||
RawStream::Empty => (0, Some(0)),
|
RawStream::Empty => (0, Some(0)),
|
||||||
}
|
}
|
||||||
|
@ -358,8 +372,8 @@ impl std::fmt::Display for RawStream<'_> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'r> From<&'r mut hyper::Body> for RawStream<'r> {
|
impl<'r> From<&'r mut HyperBody> for RawStream<'r> {
|
||||||
fn from(value: &'r mut hyper::Body) -> Self {
|
fn from(value: &'r mut HyperBody) -> Self {
|
||||||
Self::Body(value)
|
Self::Body(value)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -3,8 +3,8 @@ use std::task::{Context, Poll};
|
||||||
use std::pin::Pin;
|
use std::pin::Pin;
|
||||||
|
|
||||||
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
|
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
|
||||||
|
use hyper::upgrade::Upgraded;
|
||||||
use crate::http::hyper::upgrade::Upgraded;
|
use hyper_util::rt::TokioIo;
|
||||||
|
|
||||||
/// A bidirectional, raw stream to the client.
|
/// A bidirectional, raw stream to the client.
|
||||||
///
|
///
|
||||||
|
@ -28,7 +28,7 @@ pub struct IoStream {
|
||||||
|
|
||||||
/// Just in case we want to add stream kinds in the future.
|
/// Just in case we want to add stream kinds in the future.
|
||||||
enum IoStreamKind {
|
enum IoStreamKind {
|
||||||
Upgraded(Upgraded)
|
Upgraded(TokioIo<Upgraded>)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// An upgraded connection I/O handler.
|
/// An upgraded connection I/O handler.
|
||||||
|
@ -51,7 +51,7 @@ enum IoStreamKind {
|
||||||
///
|
///
|
||||||
/// #[rocket::async_trait]
|
/// #[rocket::async_trait]
|
||||||
/// impl IoHandler for EchoHandler {
|
/// impl IoHandler for EchoHandler {
|
||||||
/// async fn io(self: Pin<Box<Self>>, io: IoStream) -> io::Result<()> {
|
/// async fn io(self: Box<Self>, io: IoStream) -> io::Result<()> {
|
||||||
/// let (mut reader, mut writer) = io::split(io);
|
/// let (mut reader, mut writer) = io::split(io);
|
||||||
/// io::copy(&mut reader, &mut writer).await?;
|
/// io::copy(&mut reader, &mut writer).await?;
|
||||||
/// Ok(())
|
/// Ok(())
|
||||||
|
@ -68,13 +68,20 @@ enum IoStreamKind {
|
||||||
#[crate::async_trait]
|
#[crate::async_trait]
|
||||||
pub trait IoHandler: Send {
|
pub trait IoHandler: Send {
|
||||||
/// Performs the raw I/O.
|
/// Performs the raw I/O.
|
||||||
async fn io(self: Pin<Box<Self>>, io: IoStream) -> io::Result<()>;
|
async fn io(self: Box<Self>, io: IoStream) -> io::Result<()>;
|
||||||
|
}
|
||||||
|
|
||||||
|
#[crate::async_trait]
|
||||||
|
impl IoHandler for () {
|
||||||
|
async fn io(self: Box<Self>, _: IoStream) -> io::Result<()> {
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[doc(hidden)]
|
#[doc(hidden)]
|
||||||
impl From<Upgraded> for IoStream {
|
impl From<Upgraded> for IoStream {
|
||||||
fn from(io: Upgraded) -> Self {
|
fn from(io: Upgraded) -> Self {
|
||||||
IoStream { kind: IoStreamKind::Upgraded(io) }
|
IoStream { kind: IoStreamKind::Upgraded(TokioIo::new(io)) }
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -178,7 +178,7 @@ impl<'a, 'b> DerefMut for TransformBuf<'a, 'b> {
|
||||||
#[allow(deprecated)]
|
#[allow(deprecated)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use std::hash::SipHasher;
|
use std::hash::SipHasher;
|
||||||
use std::sync::{Arc, atomic::{AtomicU64, AtomicU8}};
|
use std::sync::{Arc, atomic::{AtomicU8, AtomicU64, Ordering}};
|
||||||
|
|
||||||
use parking_lot::Mutex;
|
use parking_lot::Mutex;
|
||||||
use ubyte::ToByteUnit;
|
use ubyte::ToByteUnit;
|
||||||
|
@ -264,41 +264,41 @@ mod tests {
|
||||||
assert_eq!(bytes.len(), 8);
|
assert_eq!(bytes.len(), 8);
|
||||||
let bytes: [u8; 8] = bytes.try_into().expect("[u8; 8]");
|
let bytes: [u8; 8] = bytes.try_into().expect("[u8; 8]");
|
||||||
let value = u64::from_be_bytes(bytes);
|
let value = u64::from_be_bytes(bytes);
|
||||||
hash1.store(value, atomic::Ordering::Release);
|
hash1.store(value, Ordering::Release);
|
||||||
})
|
})
|
||||||
.chain_inspect(move |bytes| {
|
.chain_inspect(move |bytes| {
|
||||||
assert_eq!(bytes.len(), 8);
|
assert_eq!(bytes.len(), 8);
|
||||||
let bytes: [u8; 8] = bytes.try_into().expect("[u8; 8]");
|
let bytes: [u8; 8] = bytes.try_into().expect("[u8; 8]");
|
||||||
let value = u64::from_be_bytes(bytes);
|
let value = u64::from_be_bytes(bytes);
|
||||||
let prev = hash2.load(atomic::Ordering::Acquire);
|
let prev = hash2.load(Ordering::Acquire);
|
||||||
assert_eq!(prev, value);
|
assert_eq!(prev, value);
|
||||||
inspect2.fetch_add(1, atomic::Ordering::Release);
|
inspect2.fetch_add(1, Ordering::Release);
|
||||||
});
|
});
|
||||||
})));
|
})));
|
||||||
|
|
||||||
// Make sure nothing has happened yet.
|
// Make sure nothing has happened yet.
|
||||||
assert!(raw_data.lock().is_empty());
|
assert!(raw_data.lock().is_empty());
|
||||||
assert_eq!(hash.load(atomic::Ordering::Acquire), 0);
|
assert_eq!(hash.load(Ordering::Acquire), 0);
|
||||||
assert_eq!(inspect2.load(atomic::Ordering::Acquire), 0);
|
assert_eq!(inspect2.load(Ordering::Acquire), 0);
|
||||||
|
|
||||||
// Check that nothing happens if the data isn't read.
|
// Check that nothing happens if the data isn't read.
|
||||||
let client = Client::debug(rocket).unwrap();
|
let client = Client::debug(rocket).unwrap();
|
||||||
client.get("/").body("Hello, world!").dispatch();
|
client.get("/").body("Hello, world!").dispatch();
|
||||||
assert!(raw_data.lock().is_empty());
|
assert!(raw_data.lock().is_empty());
|
||||||
assert_eq!(hash.load(atomic::Ordering::Acquire), 0);
|
assert_eq!(hash.load(Ordering::Acquire), 0);
|
||||||
assert_eq!(inspect2.load(atomic::Ordering::Acquire), 0);
|
assert_eq!(inspect2.load(Ordering::Acquire), 0);
|
||||||
|
|
||||||
// Check inspect + hash + inspect + inspect.
|
// Check inspect + hash + inspect + inspect.
|
||||||
client.post("/").body("Hello, world!").dispatch();
|
client.post("/").body("Hello, world!").dispatch();
|
||||||
assert_eq!(raw_data.lock().as_slice(), "Hello, world!".as_bytes());
|
assert_eq!(raw_data.lock().as_slice(), "Hello, world!".as_bytes());
|
||||||
assert_eq!(hash.load(atomic::Ordering::Acquire), 0xae5020d7cf49d14f);
|
assert_eq!(hash.load(Ordering::Acquire), 0xae5020d7cf49d14f);
|
||||||
assert_eq!(inspect2.load(atomic::Ordering::Acquire), 1);
|
assert_eq!(inspect2.load(Ordering::Acquire), 1);
|
||||||
|
|
||||||
// Check inspect + hash + inspect + inspect, round 2.
|
// Check inspect + hash + inspect + inspect, round 2.
|
||||||
let string = "Rocket, Rocket, where art thee? Oh, tis in the sky, I see!";
|
let string = "Rocket, Rocket, where art thee? Oh, tis in the sky, I see!";
|
||||||
client.post("/").body(string).dispatch();
|
client.post("/").body(string).dispatch();
|
||||||
assert_eq!(raw_data.lock().as_slice(), string.as_bytes());
|
assert_eq!(raw_data.lock().as_slice(), string.as_bytes());
|
||||||
assert_eq!(hash.load(atomic::Ordering::Acquire), 0x323f9aa98f907faf);
|
assert_eq!(hash.load(Ordering::Acquire), 0x323f9aa98f907faf);
|
||||||
assert_eq!(inspect2.load(atomic::Ordering::Acquire), 2);
|
assert_eq!(inspect2.load(Ordering::Acquire), 2);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,193 @@
|
||||||
|
use std::io;
|
||||||
|
use std::mem::transmute;
|
||||||
|
use std::pin::Pin;
|
||||||
|
use std::sync::Arc;
|
||||||
|
use std::task::{Poll, Context};
|
||||||
|
|
||||||
|
use futures::future::BoxFuture;
|
||||||
|
use http::request::Parts;
|
||||||
|
use hyper::body::Incoming;
|
||||||
|
use tokio::io::{AsyncRead, ReadBuf};
|
||||||
|
|
||||||
|
use crate::data::{Data, IoHandler};
|
||||||
|
use crate::{Request, Response, Rocket, Orbit};
|
||||||
|
|
||||||
|
// TODO: Magic with trait async fn to get rid of the box pin.
|
||||||
|
// TODO: Write safety proofs.
|
||||||
|
|
||||||
|
macro_rules! static_assert_covariance {
|
||||||
|
($T:tt) => (
|
||||||
|
const _: () = {
|
||||||
|
fn _assert_covariance<'x: 'y, 'y>(x: &'y $T<'x>) -> &'y $T<'y> { x }
|
||||||
|
};
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct ErasedRequest {
|
||||||
|
// XXX: SAFETY: This (dependent) field must come first due to drop order!
|
||||||
|
request: Request<'static>,
|
||||||
|
_rocket: Arc<Rocket<Orbit>>,
|
||||||
|
_parts: Box<Parts>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Drop for ErasedRequest {
|
||||||
|
fn drop(&mut self) { }
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct ErasedResponse {
|
||||||
|
// XXX: SAFETY: This (dependent) field must come first due to drop order!
|
||||||
|
response: Response<'static>,
|
||||||
|
_request: Arc<ErasedRequest>,
|
||||||
|
_incoming: Box<Incoming>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Drop for ErasedResponse {
|
||||||
|
fn drop(&mut self) { }
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct ErasedIoHandler {
|
||||||
|
// XXX: SAFETY: This (dependent) field must come first due to drop order!
|
||||||
|
io: Box<dyn IoHandler + 'static>,
|
||||||
|
_request: Arc<ErasedRequest>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Drop for ErasedIoHandler {
|
||||||
|
fn drop(&mut self) { }
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ErasedRequest {
|
||||||
|
pub fn new(
|
||||||
|
rocket: Arc<Rocket<Orbit>>,
|
||||||
|
parts: Parts,
|
||||||
|
constructor: impl for<'r> FnOnce(
|
||||||
|
&'r Rocket<Orbit>,
|
||||||
|
&'r Parts
|
||||||
|
) -> Request<'r>,
|
||||||
|
) -> ErasedRequest {
|
||||||
|
let rocket: Arc<Rocket<Orbit>> = rocket;
|
||||||
|
let parts: Box<Parts> = Box::new(parts);
|
||||||
|
let request: Request<'_> = {
|
||||||
|
let rocket: &Rocket<Orbit> = &*rocket;
|
||||||
|
let rocket: &'static Rocket<Orbit> = unsafe { transmute(rocket) };
|
||||||
|
let parts: &Parts = &*parts;
|
||||||
|
let parts: &'static Parts = unsafe { transmute(parts) };
|
||||||
|
constructor(&rocket, &parts)
|
||||||
|
};
|
||||||
|
|
||||||
|
ErasedRequest { _rocket: rocket, _parts: parts, request, }
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn into_response<T: Send + Sync + 'static>(
|
||||||
|
self,
|
||||||
|
incoming: Incoming,
|
||||||
|
data_builder: impl for<'r> FnOnce(&'r mut Incoming) -> Data<'r>,
|
||||||
|
preprocess: impl for<'r, 'x> FnOnce(
|
||||||
|
&'r Rocket<Orbit>,
|
||||||
|
&'r mut Request<'x>,
|
||||||
|
&'r mut Data<'x>
|
||||||
|
) -> BoxFuture<'r, T>,
|
||||||
|
dispatch: impl for<'r> FnOnce(
|
||||||
|
T,
|
||||||
|
&'r Rocket<Orbit>,
|
||||||
|
&'r Request<'r>,
|
||||||
|
Data<'r>
|
||||||
|
) -> BoxFuture<'r, Response<'r>>,
|
||||||
|
) -> ErasedResponse {
|
||||||
|
let mut incoming = Box::new(incoming);
|
||||||
|
let mut data: Data<'_> = {
|
||||||
|
let incoming: &mut Incoming = &mut *incoming;
|
||||||
|
let incoming: &'static mut Incoming = unsafe { transmute(incoming) };
|
||||||
|
data_builder(incoming)
|
||||||
|
};
|
||||||
|
|
||||||
|
let mut parent = Arc::new(self);
|
||||||
|
let token: T = {
|
||||||
|
let parent: &mut ErasedRequest = Arc::get_mut(&mut parent).unwrap();
|
||||||
|
let rocket: &Rocket<Orbit> = &*parent._rocket;
|
||||||
|
let request: &mut Request<'_> = &mut parent.request;
|
||||||
|
let data: &mut Data<'_> = &mut data;
|
||||||
|
preprocess(rocket, request, data).await
|
||||||
|
};
|
||||||
|
|
||||||
|
let parent = parent;
|
||||||
|
let response: Response<'_> = {
|
||||||
|
let parent: &ErasedRequest = &*parent;
|
||||||
|
let parent: &'static ErasedRequest = unsafe { transmute(parent) };
|
||||||
|
let rocket: &Rocket<Orbit> = &*parent._rocket;
|
||||||
|
let request: &Request<'_> = &parent.request;
|
||||||
|
dispatch(token, rocket, request, data).await
|
||||||
|
};
|
||||||
|
|
||||||
|
ErasedResponse {
|
||||||
|
_request: parent,
|
||||||
|
_incoming: incoming,
|
||||||
|
response: response,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ErasedResponse {
|
||||||
|
pub fn inner<'a>(&'a self) -> &'a Response<'a> {
|
||||||
|
static_assert_covariance!(Response);
|
||||||
|
&self.response
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn with_inner_mut<'a, T>(
|
||||||
|
&'a mut self,
|
||||||
|
f: impl for<'r> FnOnce(&'a mut Response<'r>) -> T
|
||||||
|
) -> T {
|
||||||
|
static_assert_covariance!(Response);
|
||||||
|
f(&mut self.response)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn to_io_handler<'a>(
|
||||||
|
&'a mut self,
|
||||||
|
constructor: impl for<'r> FnOnce(
|
||||||
|
&'r Request<'r>,
|
||||||
|
&'a mut Response<'r>,
|
||||||
|
) -> Option<Box<dyn IoHandler + 'r>>
|
||||||
|
) -> Option<ErasedIoHandler> {
|
||||||
|
let parent: Arc<ErasedRequest> = self._request.clone();
|
||||||
|
let io: Option<Box<dyn IoHandler + '_>> = {
|
||||||
|
let parent: &ErasedRequest = &*parent;
|
||||||
|
let parent: &'static ErasedRequest = unsafe { transmute(parent) };
|
||||||
|
let request: &Request<'_> = &parent.request;
|
||||||
|
constructor(request, &mut self.response)
|
||||||
|
};
|
||||||
|
|
||||||
|
io.map(|io| ErasedIoHandler { _request: parent, io })
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ErasedIoHandler {
|
||||||
|
pub fn with_inner_mut<'a, T: 'a>(
|
||||||
|
&'a mut self,
|
||||||
|
f: impl for<'r> FnOnce(&'a mut Box<dyn IoHandler + 'r>) -> T
|
||||||
|
) -> T {
|
||||||
|
fn _assert_covariance<'x: 'y, 'y>(
|
||||||
|
x: &'y Box<dyn IoHandler + 'x>
|
||||||
|
) -> &'y Box<dyn IoHandler + 'y> { x }
|
||||||
|
|
||||||
|
f(&mut self.io)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn take<'a>(&'a mut self) -> Box<dyn IoHandler + 'a> {
|
||||||
|
fn _assert_covariance<'x: 'y, 'y>(
|
||||||
|
x: &'y Box<dyn IoHandler + 'x>
|
||||||
|
) -> &'y Box<dyn IoHandler + 'y> { x }
|
||||||
|
|
||||||
|
self.with_inner_mut(|handler| std::mem::replace(handler, Box::new(())))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl AsyncRead for ErasedResponse {
|
||||||
|
fn poll_read(
|
||||||
|
self: Pin<&mut Self>,
|
||||||
|
cx: &mut Context<'_>,
|
||||||
|
buf: &mut ReadBuf<'_>,
|
||||||
|
) -> Poll<io::Result<()>> {
|
||||||
|
self.get_mut().with_inner_mut(|r| Pin::new(r.body_mut()).poll_read(cx, buf))
|
||||||
|
}
|
||||||
|
}
|
|
@ -74,11 +74,8 @@ pub struct Error {
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
#[non_exhaustive]
|
#[non_exhaustive]
|
||||||
pub enum ErrorKind {
|
pub enum ErrorKind {
|
||||||
/// Binding to the provided address/port failed.
|
/// Binding to the network interface failed.
|
||||||
Bind(io::Error),
|
Bind(Box<dyn StdError + Send>),
|
||||||
/// Binding via TLS to the provided address/port failed.
|
|
||||||
#[cfg(feature = "tls")]
|
|
||||||
TlsBind(crate::http::tls::error::Error),
|
|
||||||
/// An I/O error occurred during launch.
|
/// An I/O error occurred during launch.
|
||||||
Io(io::Error),
|
Io(io::Error),
|
||||||
/// A valid [`Config`](crate::Config) could not be extracted from the
|
/// A valid [`Config`](crate::Config) could not be extracted from the
|
||||||
|
@ -90,15 +87,10 @@ pub enum ErrorKind {
|
||||||
FailedFairings(Vec<crate::fairing::Info>),
|
FailedFairings(Vec<crate::fairing::Info>),
|
||||||
/// Sentinels requested abort.
|
/// Sentinels requested abort.
|
||||||
SentinelAborts(Vec<crate::sentinel::Sentry>),
|
SentinelAborts(Vec<crate::sentinel::Sentry>),
|
||||||
/// The configuration profile is not debug but not secret key is configured.
|
/// The configuration profile is not debug but no secret key is configured.
|
||||||
InsecureSecretKey(Profile),
|
InsecureSecretKey(Profile),
|
||||||
/// Shutdown failed.
|
/// Shutdown failed. Contains the Rocket instance that failed to shutdown.
|
||||||
Shutdown(
|
Shutdown(Arc<Rocket<Orbit>>),
|
||||||
/// The instance of Rocket that failed to shutdown.
|
|
||||||
Arc<Rocket<Orbit>>,
|
|
||||||
/// The error that occurred during shutdown, if any.
|
|
||||||
Option<Box<dyn StdError + Send + Sync>>
|
|
||||||
),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// An error that occurs when a value was unexpectedly empty.
|
/// An error that occurs when a value was unexpectedly empty.
|
||||||
|
@ -111,20 +103,24 @@ impl From<ErrorKind> for Error {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl From<figment::Error> for Error {
|
||||||
|
fn from(e: figment::Error) -> Self {
|
||||||
|
Error::new(ErrorKind::Config(e))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<io::Error> for Error {
|
||||||
|
fn from(e: io::Error) -> Self {
|
||||||
|
Error::new(ErrorKind::Io(e))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl Error {
|
impl Error {
|
||||||
#[inline(always)]
|
#[inline(always)]
|
||||||
pub(crate) fn new(kind: ErrorKind) -> Error {
|
pub(crate) fn new(kind: ErrorKind) -> Error {
|
||||||
Error { handled: AtomicBool::new(false), kind }
|
Error { handled: AtomicBool::new(false), kind }
|
||||||
}
|
}
|
||||||
|
|
||||||
#[inline(always)]
|
|
||||||
pub(crate) fn shutdown<E>(rocket: Arc<Rocket<Orbit>>, error: E) -> Error
|
|
||||||
where E: Into<Option<crate::http::hyper::Error>>
|
|
||||||
{
|
|
||||||
let error = error.into().map(|e| Box::new(e) as Box<dyn StdError + Sync + Send>);
|
|
||||||
Error::new(ErrorKind::Shutdown(rocket, error))
|
|
||||||
}
|
|
||||||
|
|
||||||
#[inline(always)]
|
#[inline(always)]
|
||||||
fn was_handled(&self) -> bool {
|
fn was_handled(&self) -> bool {
|
||||||
self.handled.load(Ordering::Acquire)
|
self.handled.load(Ordering::Acquire)
|
||||||
|
@ -176,9 +172,9 @@ impl Error {
|
||||||
self.mark_handled();
|
self.mark_handled();
|
||||||
match self.kind() {
|
match self.kind() {
|
||||||
ErrorKind::Bind(ref e) => {
|
ErrorKind::Bind(ref e) => {
|
||||||
error!("Rocket failed to bind network socket to given address/port.");
|
error!("Binding to the network interface failed.");
|
||||||
info_!("{}", e);
|
info_!("{}", e);
|
||||||
"aborting due to socket bind error"
|
"aborting due to bind error"
|
||||||
}
|
}
|
||||||
ErrorKind::Io(ref e) => {
|
ErrorKind::Io(ref e) => {
|
||||||
error!("Rocket failed to launch due to an I/O error.");
|
error!("Rocket failed to launch due to an I/O error.");
|
||||||
|
@ -229,20 +225,10 @@ impl Error {
|
||||||
|
|
||||||
"aborting due to sentinel-triggered abort(s)"
|
"aborting due to sentinel-triggered abort(s)"
|
||||||
}
|
}
|
||||||
ErrorKind::Shutdown(_, error) => {
|
ErrorKind::Shutdown(_) => {
|
||||||
error!("Rocket failed to shutdown gracefully.");
|
error!("Rocket failed to shutdown gracefully.");
|
||||||
if let Some(e) = error {
|
|
||||||
info_!("{}", e);
|
|
||||||
}
|
|
||||||
|
|
||||||
"aborting due to failed shutdown"
|
"aborting due to failed shutdown"
|
||||||
}
|
}
|
||||||
#[cfg(feature = "tls")]
|
|
||||||
ErrorKind::TlsBind(e) => {
|
|
||||||
error!("Rocket failed to bind via TLS to network socket.");
|
|
||||||
info_!("{}", e);
|
|
||||||
"aborting due to TLS bind error"
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -260,10 +246,7 @@ impl fmt::Display for ErrorKind {
|
||||||
ErrorKind::InsecureSecretKey(_) => "insecure secret key config".fmt(f),
|
ErrorKind::InsecureSecretKey(_) => "insecure secret key config".fmt(f),
|
||||||
ErrorKind::Config(_) => "failed to extract configuration".fmt(f),
|
ErrorKind::Config(_) => "failed to extract configuration".fmt(f),
|
||||||
ErrorKind::SentinelAborts(_) => "sentinel(s) aborted".fmt(f),
|
ErrorKind::SentinelAborts(_) => "sentinel(s) aborted".fmt(f),
|
||||||
ErrorKind::Shutdown(_, Some(e)) => write!(f, "shutdown failed: {e}"),
|
ErrorKind::Shutdown(_) => "shutdown failed".fmt(f),
|
||||||
ErrorKind::Shutdown(_, None) => "shutdown failed".fmt(f),
|
|
||||||
#[cfg(feature = "tls")]
|
|
||||||
ErrorKind::TlsBind(e) => write!(f, "TLS bind failed: {e}"),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -308,3 +291,42 @@ impl fmt::Display for Empty {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl StdError for Empty { }
|
impl StdError for Empty { }
|
||||||
|
|
||||||
|
/// Log an error that occurs during request processing
|
||||||
|
pub(crate) fn log_server_error(error: &Box<dyn StdError + Send + Sync>) {
|
||||||
|
struct ServerError<'a>(&'a (dyn StdError + 'static));
|
||||||
|
|
||||||
|
impl fmt::Display for ServerError<'_> {
|
||||||
|
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||||
|
let error = &self.0;
|
||||||
|
if let Some(e) = error.downcast_ref::<hyper::Error>() {
|
||||||
|
write!(f, "request processing failed: {e}")?;
|
||||||
|
} else if let Some(e) = error.downcast_ref::<io::Error>() {
|
||||||
|
write!(f, "connection I/O error: ")?;
|
||||||
|
|
||||||
|
match e.kind() {
|
||||||
|
io::ErrorKind::NotConnected => write!(f, "remote disconnected")?,
|
||||||
|
io::ErrorKind::UnexpectedEof => write!(f, "remote sent early eof")?,
|
||||||
|
io::ErrorKind::ConnectionReset
|
||||||
|
| io::ErrorKind::ConnectionAborted
|
||||||
|
| io::ErrorKind::BrokenPipe => write!(f, "terminated by remote")?,
|
||||||
|
_ => write!(f, "{e}")?,
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
write!(f, "http server error: {error}")?;
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(e) = error.source() {
|
||||||
|
write!(f, " ({})", ServerError(e))?;
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if error.downcast_ref::<hyper::Error>().is_some() {
|
||||||
|
warn!("{}", ServerError(&**error))
|
||||||
|
} else {
|
||||||
|
error!("{}", ServerError(&**error))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -1,404 +0,0 @@
|
||||||
use std::{io, time::Duration};
|
|
||||||
use std::task::{Poll, Context};
|
|
||||||
use std::pin::Pin;
|
|
||||||
|
|
||||||
use bytes::{Bytes, BytesMut};
|
|
||||||
use pin_project_lite::pin_project;
|
|
||||||
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
|
|
||||||
use tokio::time::{sleep, Sleep};
|
|
||||||
|
|
||||||
use futures::stream::Stream;
|
|
||||||
use futures::future::{self, Future, FutureExt};
|
|
||||||
|
|
||||||
pin_project! {
|
|
||||||
pub struct ReaderStream<R> {
|
|
||||||
#[pin]
|
|
||||||
reader: Option<R>,
|
|
||||||
buf: BytesMut,
|
|
||||||
cap: usize,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<R: AsyncRead> Stream for ReaderStream<R> {
|
|
||||||
type Item = std::io::Result<Bytes>;
|
|
||||||
|
|
||||||
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
|
|
||||||
use tokio_util::io::poll_read_buf;
|
|
||||||
|
|
||||||
let mut this = self.as_mut().project();
|
|
||||||
|
|
||||||
let reader = match this.reader.as_pin_mut() {
|
|
||||||
Some(r) => r,
|
|
||||||
None => return Poll::Ready(None),
|
|
||||||
};
|
|
||||||
|
|
||||||
if this.buf.capacity() == 0 {
|
|
||||||
this.buf.reserve(*this.cap);
|
|
||||||
}
|
|
||||||
|
|
||||||
match poll_read_buf(reader, cx, &mut this.buf) {
|
|
||||||
Poll::Pending => Poll::Pending,
|
|
||||||
Poll::Ready(Err(err)) => {
|
|
||||||
self.project().reader.set(None);
|
|
||||||
Poll::Ready(Some(Err(err)))
|
|
||||||
}
|
|
||||||
Poll::Ready(Ok(0)) => {
|
|
||||||
self.project().reader.set(None);
|
|
||||||
Poll::Ready(None)
|
|
||||||
}
|
|
||||||
Poll::Ready(Ok(_)) => {
|
|
||||||
let chunk = this.buf.split();
|
|
||||||
Poll::Ready(Some(Ok(chunk.freeze())))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub trait AsyncReadExt: AsyncRead + Sized {
|
|
||||||
fn into_bytes_stream(self, cap: usize) -> ReaderStream<Self> {
|
|
||||||
ReaderStream { reader: Some(self), cap, buf: BytesMut::with_capacity(cap) }
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<T: AsyncRead> AsyncReadExt for T { }
|
|
||||||
|
|
||||||
pub trait PollExt<T, E> {
|
|
||||||
fn map_err_ext<U, F>(self, f: F) -> Poll<Option<Result<T, U>>>
|
|
||||||
where F: FnOnce(E) -> U;
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<T, E> PollExt<T, E> for Poll<Option<Result<T, E>>> {
|
|
||||||
/// Changes the error value of this `Poll` with the closure provided.
|
|
||||||
fn map_err_ext<U, F>(self, f: F) -> Poll<Option<Result<T, U>>>
|
|
||||||
where F: FnOnce(E) -> U
|
|
||||||
{
|
|
||||||
match self {
|
|
||||||
Poll::Ready(Some(Ok(t))) => Poll::Ready(Some(Ok(t))),
|
|
||||||
Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(f(e)))),
|
|
||||||
Poll::Ready(None) => Poll::Ready(None),
|
|
||||||
Poll::Pending => Poll::Pending,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pin_project! {
|
|
||||||
/// Stream for the [`chain`](super::AsyncReadExt::chain) method.
|
|
||||||
#[must_use = "streams do nothing unless polled"]
|
|
||||||
pub struct Chain<T, U> {
|
|
||||||
#[pin]
|
|
||||||
first: T,
|
|
||||||
#[pin]
|
|
||||||
second: U,
|
|
||||||
done_first: bool,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<T: AsyncRead, U: AsyncRead> Chain<T, U> {
|
|
||||||
pub(crate) fn new(first: T, second: U) -> Self {
|
|
||||||
Self { first, second, done_first: false }
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<T: AsyncRead, U: AsyncRead> Chain<T, U> {
|
|
||||||
/// Gets references to the underlying readers in this `Chain`.
|
|
||||||
pub fn get_ref(&self) -> (&T, &U) {
|
|
||||||
(&self.first, &self.second)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<T: AsyncRead, U: AsyncRead> AsyncRead for Chain<T, U> {
|
|
||||||
fn poll_read(
|
|
||||||
self: Pin<&mut Self>,
|
|
||||||
cx: &mut Context<'_>,
|
|
||||||
buf: &mut ReadBuf<'_>,
|
|
||||||
) -> Poll<io::Result<()>> {
|
|
||||||
let me = self.project();
|
|
||||||
|
|
||||||
if !*me.done_first {
|
|
||||||
let init_rem = buf.remaining();
|
|
||||||
futures::ready!(me.first.poll_read(cx, buf))?;
|
|
||||||
if buf.remaining() == init_rem {
|
|
||||||
*me.done_first = true;
|
|
||||||
} else {
|
|
||||||
return Poll::Ready(Ok(()));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
me.second.poll_read(cx, buf)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
enum State {
|
|
||||||
/// I/O has not been cancelled. Proceed as normal.
|
|
||||||
Active,
|
|
||||||
/// I/O has been cancelled. See if we can finish before the timer expires.
|
|
||||||
Grace(Pin<Box<Sleep>>),
|
|
||||||
/// Grace period elapsed. Shutdown the connection, waiting for the timer
|
|
||||||
/// until we force close.
|
|
||||||
Mercy(Pin<Box<Sleep>>),
|
|
||||||
}
|
|
||||||
|
|
||||||
pin_project! {
|
|
||||||
/// I/O that can be cancelled when a future `F` resolves.
|
|
||||||
#[must_use = "futures do nothing unless polled"]
|
|
||||||
pub struct CancellableIo<F, I> {
|
|
||||||
#[pin]
|
|
||||||
io: Option<I>,
|
|
||||||
#[pin]
|
|
||||||
trigger: future::Fuse<F>,
|
|
||||||
state: State,
|
|
||||||
grace: Duration,
|
|
||||||
mercy: Duration,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<F: Future, I: AsyncWrite> CancellableIo<F, I> {
|
|
||||||
pub fn new(trigger: F, io: I, grace: Duration, mercy: Duration) -> Self {
|
|
||||||
CancellableIo {
|
|
||||||
grace, mercy,
|
|
||||||
io: Some(io),
|
|
||||||
trigger: trigger.fuse(),
|
|
||||||
state: State::Active,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn io(&self) -> Option<&I> {
|
|
||||||
self.io.as_ref()
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Run `do_io` while connection processing should continue.
|
|
||||||
fn poll_trigger_then<T>(
|
|
||||||
mut self: Pin<&mut Self>,
|
|
||||||
cx: &mut Context<'_>,
|
|
||||||
do_io: impl FnOnce(Pin<&mut I>, &mut Context<'_>) -> Poll<io::Result<T>>,
|
|
||||||
) -> Poll<io::Result<T>> {
|
|
||||||
let mut me = self.as_mut().project();
|
|
||||||
let io = match me.io.as_pin_mut() {
|
|
||||||
Some(io) => io,
|
|
||||||
None => return Poll::Ready(Err(gone())),
|
|
||||||
};
|
|
||||||
|
|
||||||
loop {
|
|
||||||
match me.state {
|
|
||||||
State::Active => {
|
|
||||||
if me.trigger.as_mut().poll(cx).is_ready() {
|
|
||||||
*me.state = State::Grace(Box::pin(sleep(*me.grace)));
|
|
||||||
} else {
|
|
||||||
return do_io(io, cx);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
State::Grace(timer) => {
|
|
||||||
if timer.as_mut().poll(cx).is_ready() {
|
|
||||||
*me.state = State::Mercy(Box::pin(sleep(*me.mercy)));
|
|
||||||
} else {
|
|
||||||
return do_io(io, cx);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
State::Mercy(timer) => {
|
|
||||||
if timer.as_mut().poll(cx).is_ready() {
|
|
||||||
self.project().io.set(None);
|
|
||||||
return Poll::Ready(Err(time_out()));
|
|
||||||
} else {
|
|
||||||
let result = futures::ready!(io.poll_shutdown(cx));
|
|
||||||
self.project().io.set(None);
|
|
||||||
return match result {
|
|
||||||
Err(e) => Poll::Ready(Err(e)),
|
|
||||||
Ok(()) => Poll::Ready(Err(gone()))
|
|
||||||
};
|
|
||||||
}
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn time_out() -> io::Error {
|
|
||||||
io::Error::new(io::ErrorKind::TimedOut, "Shutdown grace timed out")
|
|
||||||
}
|
|
||||||
|
|
||||||
fn gone() -> io::Error {
|
|
||||||
io::Error::new(io::ErrorKind::BrokenPipe, "IO driver has terminated")
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<F: Future, I: AsyncRead + AsyncWrite> AsyncRead for CancellableIo<F, I> {
|
|
||||||
fn poll_read(
|
|
||||||
mut self: Pin<&mut Self>,
|
|
||||||
cx: &mut Context<'_>,
|
|
||||||
buf: &mut ReadBuf<'_>,
|
|
||||||
) -> Poll<io::Result<()>> {
|
|
||||||
self.as_mut().poll_trigger_then(cx, |io, cx| io.poll_read(cx, buf))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<F: Future, I: AsyncWrite> AsyncWrite for CancellableIo<F, I> {
|
|
||||||
fn poll_write(
|
|
||||||
mut self: Pin<&mut Self>,
|
|
||||||
cx: &mut Context<'_>,
|
|
||||||
buf: &[u8],
|
|
||||||
) -> Poll<io::Result<usize>> {
|
|
||||||
self.as_mut().poll_trigger_then(cx, |io, cx| io.poll_write(cx, buf))
|
|
||||||
}
|
|
||||||
|
|
||||||
fn poll_flush(
|
|
||||||
mut self: Pin<&mut Self>,
|
|
||||||
cx: &mut Context<'_>
|
|
||||||
) -> Poll<io::Result<()>> {
|
|
||||||
self.as_mut().poll_trigger_then(cx, |io, cx| io.poll_flush(cx))
|
|
||||||
}
|
|
||||||
|
|
||||||
fn poll_shutdown(
|
|
||||||
mut self: Pin<&mut Self>,
|
|
||||||
cx: &mut Context<'_>
|
|
||||||
) -> Poll<io::Result<()>> {
|
|
||||||
self.as_mut().poll_trigger_then(cx, |io, cx| io.poll_shutdown(cx))
|
|
||||||
}
|
|
||||||
|
|
||||||
fn poll_write_vectored(
|
|
||||||
mut self: Pin<&mut Self>,
|
|
||||||
cx: &mut Context<'_>,
|
|
||||||
bufs: &[io::IoSlice<'_>],
|
|
||||||
) -> Poll<io::Result<usize>> {
|
|
||||||
self.as_mut().poll_trigger_then(cx, |io, cx| io.poll_write_vectored(cx, bufs))
|
|
||||||
}
|
|
||||||
|
|
||||||
fn is_write_vectored(&self) -> bool {
|
|
||||||
self.io().map(|io| io.is_write_vectored()).unwrap_or(false)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
use crate::http::private::{Listener, Connection, Certificates};
|
|
||||||
|
|
||||||
impl<F: Future, C: Connection> Connection for CancellableIo<F, C> {
|
|
||||||
fn peer_address(&self) -> Option<std::net::SocketAddr> {
|
|
||||||
self.io().and_then(|io| io.peer_address())
|
|
||||||
}
|
|
||||||
|
|
||||||
fn peer_certificates(&self) -> Option<Certificates> {
|
|
||||||
self.io().and_then(|io| io.peer_certificates())
|
|
||||||
}
|
|
||||||
|
|
||||||
fn enable_nodelay(&self) -> io::Result<()> {
|
|
||||||
match self.io() {
|
|
||||||
Some(io) => io.enable_nodelay(),
|
|
||||||
None => Ok(())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pin_project! {
|
|
||||||
pub struct CancellableListener<F, L> {
|
|
||||||
pub trigger: F,
|
|
||||||
#[pin]
|
|
||||||
pub listener: L,
|
|
||||||
pub grace: Duration,
|
|
||||||
pub mercy: Duration,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<F, L> CancellableListener<F, L> {
|
|
||||||
pub fn new(trigger: F, listener: L, grace: u64, mercy: u64) -> Self {
|
|
||||||
let (grace, mercy) = (Duration::from_secs(grace), Duration::from_secs(mercy));
|
|
||||||
CancellableListener { trigger, listener, grace, mercy }
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<L: Listener, F: Future + Clone> Listener for CancellableListener<F, L> {
|
|
||||||
type Connection = CancellableIo<F, L::Connection>;
|
|
||||||
|
|
||||||
fn local_addr(&self) -> Option<std::net::SocketAddr> {
|
|
||||||
self.listener.local_addr()
|
|
||||||
}
|
|
||||||
|
|
||||||
fn poll_accept(
|
|
||||||
mut self: Pin<&mut Self>,
|
|
||||||
cx: &mut Context<'_>
|
|
||||||
) -> Poll<io::Result<Self::Connection>> {
|
|
||||||
self.as_mut().project().listener
|
|
||||||
.poll_accept(cx)
|
|
||||||
.map(|res| res.map(|conn| {
|
|
||||||
CancellableIo::new(self.trigger.clone(), conn, self.grace, self.mercy)
|
|
||||||
}))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub trait StreamExt: Sized + Stream {
|
|
||||||
fn join<U>(self, other: U) -> Join<Self, U>
|
|
||||||
where U: Stream<Item = Self::Item>;
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<S: Stream> StreamExt for S {
|
|
||||||
fn join<U>(self, other: U) -> Join<Self, U>
|
|
||||||
where U: Stream<Item = Self::Item>
|
|
||||||
{
|
|
||||||
Join::new(self, other)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pin_project! {
|
|
||||||
/// Stream returned by the [`join`](super::StreamExt::join) method.
|
|
||||||
pub struct Join<T, U> {
|
|
||||||
#[pin]
|
|
||||||
a: T,
|
|
||||||
#[pin]
|
|
||||||
b: U,
|
|
||||||
// When `true`, poll `a` first, otherwise, `poll` b`.
|
|
||||||
toggle: bool,
|
|
||||||
// Set when either `a` or `b` return `None`.
|
|
||||||
done: bool,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<T, U> Join<T, U> {
|
|
||||||
pub(super) fn new(a: T, b: U) -> Join<T, U>
|
|
||||||
where T: Stream, U: Stream,
|
|
||||||
{
|
|
||||||
Join { a, b, toggle: false, done: false, }
|
|
||||||
}
|
|
||||||
|
|
||||||
fn poll_next<A: Stream, B: Stream<Item = A::Item>>(
|
|
||||||
first: Pin<&mut A>,
|
|
||||||
second: Pin<&mut B>,
|
|
||||||
done: &mut bool,
|
|
||||||
cx: &mut Context<'_>,
|
|
||||||
) -> Poll<Option<A::Item>> {
|
|
||||||
match first.poll_next(cx) {
|
|
||||||
Poll::Ready(opt) => { *done = opt.is_none(); Poll::Ready(opt) }
|
|
||||||
Poll::Pending => match second.poll_next(cx) {
|
|
||||||
Poll::Ready(opt) => { *done = opt.is_none(); Poll::Ready(opt) }
|
|
||||||
Poll::Pending => Poll::Pending
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<T, U> Stream for Join<T, U>
|
|
||||||
where T: Stream,
|
|
||||||
U: Stream<Item = T::Item>,
|
|
||||||
{
|
|
||||||
type Item = T::Item;
|
|
||||||
|
|
||||||
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<T::Item>> {
|
|
||||||
if self.done {
|
|
||||||
return Poll::Ready(None);
|
|
||||||
}
|
|
||||||
|
|
||||||
let me = self.project();
|
|
||||||
*me.toggle = !*me.toggle;
|
|
||||||
match *me.toggle {
|
|
||||||
true => Self::poll_next(me.a, me.b, me.done, cx),
|
|
||||||
false => Self::poll_next(me.b, me.a, me.done, cx),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn size_hint(&self) -> (usize, Option<usize>) {
|
|
||||||
let (left_low, left_high) = self.a.size_hint();
|
|
||||||
let (right_low, right_high) = self.b.size_hint();
|
|
||||||
|
|
||||||
let low = left_low.saturating_add(right_low);
|
|
||||||
let high = match (left_high, right_high) {
|
|
||||||
(Some(h1), Some(h2)) => h1.checked_add(h2),
|
|
||||||
_ => None,
|
|
||||||
};
|
|
||||||
|
|
||||||
(low, high)
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -341,6 +341,7 @@
|
||||||
// `key_contexts: Vec<K::Context>`, a vector of `value_contexts:
|
// `key_contexts: Vec<K::Context>`, a vector of `value_contexts:
|
||||||
// Vec<V::Context>`, a `mapping` from a string index to an integer index
|
// Vec<V::Context>`, a `mapping` from a string index to an integer index
|
||||||
// into the `contexts`, and a vector of `errors`.
|
// into the `contexts`, and a vector of `errors`.
|
||||||
|
//
|
||||||
// 2. **Push.** An index is required; an error is emitted and `push` returns
|
// 2. **Push.** An index is required; an error is emitted and `push` returns
|
||||||
// if they field's first key does not contain an index. If the first key
|
// if they field's first key does not contain an index. If the first key
|
||||||
// contains _one_ index, a new `K::Context` and `V::Context` are created.
|
// contains _one_ index, a new `K::Context` and `V::Context` are created.
|
||||||
|
@ -356,9 +357,9 @@
|
||||||
// to `second` in `mapping`. If the first index is `k`, the field,
|
// to `second` in `mapping`. If the first index is `k`, the field,
|
||||||
// stripped of the first key, is pushed to the key's context; the same is
|
// stripped of the first key, is pushed to the key's context; the same is
|
||||||
// done for the value's context is the first index is `v`.
|
// done for the value's context is the first index is `v`.
|
||||||
|
//
|
||||||
// 3. **Finalization.** Every context is finalized; errors and `Ok` values
|
// 3. **Finalization.** Every context is finalized; errors and `Ok` values
|
||||||
// are collected. TODO: FINISH. Split this into two: one for single-index,
|
// are collected.
|
||||||
// another for two-indices.
|
|
||||||
|
|
||||||
mod field;
|
mod field;
|
||||||
mod options;
|
mod options;
|
||||||
|
|
|
@ -2,7 +2,7 @@ use std::io;
|
||||||
use std::path::{Path, PathBuf};
|
use std::path::{Path, PathBuf};
|
||||||
use std::ops::{Deref, DerefMut};
|
use std::ops::{Deref, DerefMut};
|
||||||
|
|
||||||
use tokio::fs::File;
|
use tokio::fs::{File, OpenOptions};
|
||||||
|
|
||||||
use crate::request::Request;
|
use crate::request::Request;
|
||||||
use crate::response::{self, Responder};
|
use crate::response::{self, Responder};
|
||||||
|
@ -60,7 +60,7 @@ impl NamedFile {
|
||||||
/// }
|
/// }
|
||||||
/// ```
|
/// ```
|
||||||
pub async fn open<P: AsRef<Path>>(path: P) -> io::Result<NamedFile> {
|
pub async fn open<P: AsRef<Path>>(path: P) -> io::Result<NamedFile> {
|
||||||
// FIXME: Grab the file size here and prohibit `seek`ing later (or else
|
// TODO: Grab the file size here and prohibit `seek`ing later (or else
|
||||||
// the file's effective size may change), to save on the cost of doing
|
// the file's effective size may change), to save on the cost of doing
|
||||||
// all of those `seek`s to determine the file size. But, what happens if
|
// all of those `seek`s to determine the file size. But, what happens if
|
||||||
// the file gets changed between now and then?
|
// the file gets changed between now and then?
|
||||||
|
@ -68,6 +68,11 @@ impl NamedFile {
|
||||||
Ok(NamedFile(path.as_ref().to_path_buf(), file))
|
Ok(NamedFile(path.as_ref().to_path_buf(), file))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub async fn open_with<P: AsRef<Path>>(path: P, opts: &OpenOptions) -> io::Result<NamedFile> {
|
||||||
|
let file = opts.open(path.as_ref()).await?;
|
||||||
|
Ok(NamedFile(path.as_ref().to_path_buf(), file))
|
||||||
|
}
|
||||||
|
|
||||||
/// Retrieve the underlying `File`.
|
/// Retrieve the underlying `File`.
|
||||||
///
|
///
|
||||||
/// # Example
|
/// # Example
|
||||||
|
|
|
@ -2,11 +2,10 @@ use std::fmt;
|
||||||
|
|
||||||
use parking_lot::Mutex;
|
use parking_lot::Mutex;
|
||||||
|
|
||||||
use crate::http::private::cookie;
|
|
||||||
use crate::{Rocket, Orbit};
|
use crate::{Rocket, Orbit};
|
||||||
|
|
||||||
#[doc(inline)]
|
#[doc(inline)]
|
||||||
pub use self::cookie::{Cookie, SameSite, Iter};
|
pub use cookie::{Cookie, SameSite, Iter};
|
||||||
|
|
||||||
/// Collection of one or more HTTP cookies.
|
/// Collection of one or more HTTP cookies.
|
||||||
///
|
///
|
||||||
|
@ -167,7 +166,7 @@ pub(crate) struct CookieState<'a> {
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
enum Op {
|
enum Op {
|
||||||
Add(Cookie<'static>, bool),
|
Add(Cookie<'static>, bool),
|
||||||
Remove(Cookie<'static>, bool),
|
Remove(Cookie<'static>),
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a> CookieJar<'a> {
|
impl<'a> CookieJar<'a> {
|
||||||
|
@ -177,7 +176,7 @@ impl<'a> CookieJar<'a> {
|
||||||
ops: Mutex::new(Vec::new()),
|
ops: Mutex::new(Vec::new()),
|
||||||
state: CookieState {
|
state: CookieState {
|
||||||
// This is updated dynamically when headers are received.
|
// This is updated dynamically when headers are received.
|
||||||
secure: rocket.config().tls_enabled(),
|
secure: rocket.endpoint().is_tls(),
|
||||||
config: rocket.config(),
|
config: rocket.config(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -256,7 +255,7 @@ impl<'a> CookieJar<'a> {
|
||||||
for op in ops.iter().rev().filter(|op| op.cookie().name() == name) {
|
for op in ops.iter().rev().filter(|op| op.cookie().name() == name) {
|
||||||
match op {
|
match op {
|
||||||
Op::Add(c, _) => return Some(c.clone()),
|
Op::Add(c, _) => return Some(c.clone()),
|
||||||
Op::Remove(_, _) => return None,
|
Op::Remove(_) => return None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -389,7 +388,7 @@ impl<'a> CookieJar<'a> {
|
||||||
pub fn remove<C: Into<Cookie<'static>>>(&self, cookie: C) {
|
pub fn remove<C: Into<Cookie<'static>>>(&self, cookie: C) {
|
||||||
let mut cookie = cookie.into();
|
let mut cookie = cookie.into();
|
||||||
Self::set_removal_defaults(&mut cookie);
|
Self::set_removal_defaults(&mut cookie);
|
||||||
self.ops.lock().push(Op::Remove(cookie, false));
|
self.ops.lock().push(Op::Remove(cookie));
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Removes the private `cookie` from the collection.
|
/// Removes the private `cookie` from the collection.
|
||||||
|
@ -432,7 +431,7 @@ impl<'a> CookieJar<'a> {
|
||||||
pub fn remove_private<C: Into<Cookie<'static>>>(&self, cookie: C) {
|
pub fn remove_private<C: Into<Cookie<'static>>>(&self, cookie: C) {
|
||||||
let mut cookie = cookie.into();
|
let mut cookie = cookie.into();
|
||||||
Self::set_removal_defaults(&mut cookie);
|
Self::set_removal_defaults(&mut cookie);
|
||||||
self.ops.lock().push(Op::Remove(cookie, true));
|
self.ops.lock().push(Op::Remove(cookie));
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns an iterator over all of the _original_ cookies present in this
|
/// Returns an iterator over all of the _original_ cookies present in this
|
||||||
|
@ -477,7 +476,7 @@ impl<'a> CookieJar<'a> {
|
||||||
Op::Add(c, true) => {
|
Op::Add(c, true) => {
|
||||||
jar.private_mut(&self.state.config.secret_key.key).add(c);
|
jar.private_mut(&self.state.config.secret_key.key).add(c);
|
||||||
}
|
}
|
||||||
Op::Remove(mut c, _) => {
|
Op::Remove(mut c) => {
|
||||||
if self.jar.get(c.name()).is_some() {
|
if self.jar.get(c.name()).is_some() {
|
||||||
c.make_removal();
|
c.make_removal();
|
||||||
jar.add(c);
|
jar.add(c);
|
||||||
|
@ -595,7 +594,7 @@ impl<'a> Clone for CookieJar<'a> {
|
||||||
impl Op {
|
impl Op {
|
||||||
fn cookie(&self) -> &Cookie<'static> {
|
fn cookie(&self) -> &Cookie<'static> {
|
||||||
match self {
|
match self {
|
||||||
Op::Add(c, _) | Op::Remove(c, _) => c
|
Op::Add(c, _) | Op::Remove(c) => c
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
|
@ -0,0 +1,12 @@
|
||||||
|
//! Types that map to concepts in HTTP.
|
||||||
|
//!
|
||||||
|
//! This module exports types that map to HTTP concepts or to the underlying
|
||||||
|
//! HTTP library when needed.
|
||||||
|
|
||||||
|
mod cookies;
|
||||||
|
|
||||||
|
#[doc(inline)]
|
||||||
|
pub use rocket_http::*;
|
||||||
|
|
||||||
|
#[doc(inline)]
|
||||||
|
pub use cookies::*;
|
|
@ -7,7 +7,9 @@
|
||||||
#![cfg_attr(nightly, feature(decl_macro))]
|
#![cfg_attr(nightly, feature(decl_macro))]
|
||||||
|
|
||||||
#![warn(rust_2018_idioms)]
|
#![warn(rust_2018_idioms)]
|
||||||
#![warn(missing_docs)]
|
// #![warn(missing_docs)]
|
||||||
|
#![allow(async_fn_in_trait)]
|
||||||
|
#![allow(refining_impl_trait)]
|
||||||
|
|
||||||
//! # Rocket - Core API Documentation
|
//! # Rocket - Core API Documentation
|
||||||
//!
|
//!
|
||||||
|
@ -109,18 +111,24 @@
|
||||||
|
|
||||||
/// These are public dependencies! Update docs if these are changed, especially
|
/// These are public dependencies! Update docs if these are changed, especially
|
||||||
/// figment's version number in docs.
|
/// figment's version number in docs.
|
||||||
#[doc(hidden)] pub use yansi;
|
#[doc(hidden)]
|
||||||
#[doc(hidden)] pub use async_stream;
|
pub use yansi;
|
||||||
|
#[doc(hidden)]
|
||||||
|
pub use async_stream;
|
||||||
pub use futures;
|
pub use futures;
|
||||||
pub use tokio;
|
pub use tokio;
|
||||||
pub use figment;
|
pub use figment;
|
||||||
pub use time;
|
pub use time;
|
||||||
|
|
||||||
#[doc(hidden)]
|
#[doc(hidden)]
|
||||||
#[macro_use] pub mod log;
|
#[macro_use]
|
||||||
#[macro_use] pub mod outcome;
|
pub mod log;
|
||||||
#[macro_use] pub mod data;
|
#[macro_use]
|
||||||
#[doc(hidden)] pub mod sentinel;
|
pub mod outcome;
|
||||||
|
#[macro_use]
|
||||||
|
pub mod data;
|
||||||
|
#[doc(hidden)]
|
||||||
|
pub mod sentinel;
|
||||||
pub mod local;
|
pub mod local;
|
||||||
pub mod request;
|
pub mod request;
|
||||||
pub mod response;
|
pub mod response;
|
||||||
|
@ -133,74 +141,41 @@ pub mod route;
|
||||||
pub mod serde;
|
pub mod serde;
|
||||||
pub mod shield;
|
pub mod shield;
|
||||||
pub mod fs;
|
pub mod fs;
|
||||||
|
pub mod http;
|
||||||
// Reexport of HTTP everything.
|
pub mod listener;
|
||||||
pub mod http {
|
#[cfg(feature = "tls")]
|
||||||
//! Types that map to concepts in HTTP.
|
#[cfg_attr(nightly, doc(cfg(feature = "tls")))]
|
||||||
//!
|
pub mod tls;
|
||||||
//! This module exports types that map to HTTP concepts or to the underlying
|
|
||||||
//! HTTP library when needed.
|
|
||||||
|
|
||||||
#[doc(inline)]
|
|
||||||
pub use rocket_http::*;
|
|
||||||
|
|
||||||
/// Re-exported hyper HTTP library types.
|
|
||||||
///
|
|
||||||
/// All types that are re-exported from Hyper reside inside of this module.
|
|
||||||
/// These types will, with certainty, be removed with time, but they reside here
|
|
||||||
/// while necessary.
|
|
||||||
pub mod hyper {
|
|
||||||
#[doc(hidden)]
|
|
||||||
pub use rocket_http::hyper::*;
|
|
||||||
|
|
||||||
pub use rocket_http::hyper::header;
|
|
||||||
}
|
|
||||||
|
|
||||||
#[doc(inline)]
|
|
||||||
pub use crate::cookies::*;
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(feature = "mtls")]
|
#[cfg(feature = "mtls")]
|
||||||
#[cfg_attr(nightly, doc(cfg(feature = "mtls")))]
|
#[cfg_attr(nightly, doc(cfg(feature = "mtls")))]
|
||||||
pub mod mtls;
|
pub mod mtls;
|
||||||
|
|
||||||
/// TODO: We need a futures mod or something.
|
mod util;
|
||||||
mod trip_wire;
|
|
||||||
mod shutdown;
|
mod shutdown;
|
||||||
mod server;
|
mod server;
|
||||||
mod ext;
|
mod lifecycle;
|
||||||
mod state;
|
mod state;
|
||||||
mod cookies;
|
|
||||||
mod rocket;
|
mod rocket;
|
||||||
mod router;
|
mod router;
|
||||||
mod phase;
|
mod phase;
|
||||||
|
mod erased;
|
||||||
|
|
||||||
|
#[doc(hidden)] pub use either::Either;
|
||||||
|
|
||||||
|
#[doc(inline)] pub use rocket_codegen::*;
|
||||||
|
|
||||||
#[doc(inline)] pub use crate::response::Response;
|
#[doc(inline)] pub use crate::response::Response;
|
||||||
#[doc(inline)] pub use crate::data::Data;
|
#[doc(inline)] pub use crate::data::Data;
|
||||||
#[doc(inline)] pub use crate::config::Config;
|
#[doc(inline)] pub use crate::config::Config;
|
||||||
#[doc(inline)] pub use crate::catcher::Catcher;
|
#[doc(inline)] pub use crate::catcher::Catcher;
|
||||||
#[doc(inline)] pub use crate::route::Route;
|
#[doc(inline)] pub use crate::route::Route;
|
||||||
#[doc(hidden)] pub use either::Either;
|
#[doc(inline)] pub use crate::phase::{Phase, Build, Ignite, Orbit};
|
||||||
#[doc(inline)] pub use phase::{Phase, Build, Ignite, Orbit};
|
#[doc(inline)] pub use crate::error::Error;
|
||||||
#[doc(inline)] pub use error::Error;
|
#[doc(inline)] pub use crate::sentinel::Sentinel;
|
||||||
#[doc(inline)] pub use sentinel::Sentinel;
|
|
||||||
#[doc(inline)] pub use crate::request::Request;
|
#[doc(inline)] pub use crate::request::Request;
|
||||||
#[doc(inline)] pub use crate::rocket::Rocket;
|
#[doc(inline)] pub use crate::rocket::Rocket;
|
||||||
#[doc(inline)] pub use crate::shutdown::Shutdown;
|
#[doc(inline)] pub use crate::shutdown::Shutdown;
|
||||||
#[doc(inline)] pub use crate::state::State;
|
#[doc(inline)] pub use crate::state::State;
|
||||||
#[doc(inline)] pub use rocket_codegen::*;
|
|
||||||
|
|
||||||
/// Creates a [`Rocket`] instance with the default config provider: aliases
|
|
||||||
/// [`Rocket::build()`].
|
|
||||||
pub fn build() -> Rocket<Build> {
|
|
||||||
Rocket::build()
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Creates a [`Rocket`] instance with a custom config provider: aliases
|
|
||||||
/// [`Rocket::custom()`].
|
|
||||||
pub fn custom<T: figment::Provider>(provider: T) -> Rocket<Build> {
|
|
||||||
Rocket::custom(provider)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Retrofits support for `async fn` in trait impls and declarations.
|
/// Retrofits support for `async fn` in trait impls and declarations.
|
||||||
///
|
///
|
||||||
|
@ -231,6 +206,20 @@ pub fn custom<T: figment::Provider>(provider: T) -> Rocket<Build> {
|
||||||
#[doc(inline)]
|
#[doc(inline)]
|
||||||
pub use async_trait::async_trait;
|
pub use async_trait::async_trait;
|
||||||
|
|
||||||
|
const WORKER_PREFIX: &'static str = "rocket-worker";
|
||||||
|
|
||||||
|
/// Creates a [`Rocket`] instance with the default config provider: aliases
|
||||||
|
/// [`Rocket::build()`].
|
||||||
|
pub fn build() -> Rocket<Build> {
|
||||||
|
Rocket::build()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Creates a [`Rocket`] instance with a custom config provider: aliases
|
||||||
|
/// [`Rocket::custom()`].
|
||||||
|
pub fn custom<T: figment::Provider>(provider: T) -> Rocket<Build> {
|
||||||
|
Rocket::custom(provider)
|
||||||
|
}
|
||||||
|
|
||||||
/// WARNING: This is unstable! Do not use this method outside of Rocket!
|
/// WARNING: This is unstable! Do not use this method outside of Rocket!
|
||||||
#[doc(hidden)]
|
#[doc(hidden)]
|
||||||
pub fn async_run<F, R>(fut: F, workers: usize, sync: usize, force_end: bool, name: &str) -> R
|
pub fn async_run<F, R>(fut: F, workers: usize, sync: usize, force_end: bool, name: &str) -> R
|
||||||
|
@ -255,7 +244,7 @@ pub fn async_run<F, R>(fut: F, workers: usize, sync: usize, force_end: bool, nam
|
||||||
/// WARNING: This is unstable! Do not use this method outside of Rocket!
|
/// WARNING: This is unstable! Do not use this method outside of Rocket!
|
||||||
#[doc(hidden)]
|
#[doc(hidden)]
|
||||||
pub fn async_test<R>(fut: impl std::future::Future<Output = R>) -> R {
|
pub fn async_test<R>(fut: impl std::future::Future<Output = R>) -> R {
|
||||||
async_run(fut, 1, 32, true, "rocket-worker-test-thread")
|
async_run(fut, 1, 32, true, &format!("{WORKER_PREFIX}-test-thread"))
|
||||||
}
|
}
|
||||||
|
|
||||||
/// WARNING: This is unstable! Do not use this method outside of Rocket!
|
/// WARNING: This is unstable! Do not use this method outside of Rocket!
|
||||||
|
@ -276,7 +265,7 @@ pub fn async_main<R>(fut: impl std::future::Future<Output = R> + Send) -> R {
|
||||||
let workers = fig.extract_inner(Config::WORKERS).unwrap_or_else(bail);
|
let workers = fig.extract_inner(Config::WORKERS).unwrap_or_else(bail);
|
||||||
let max_blocking = fig.extract_inner(Config::MAX_BLOCKING).unwrap_or_else(bail);
|
let max_blocking = fig.extract_inner(Config::MAX_BLOCKING).unwrap_or_else(bail);
|
||||||
let force = fig.focus(Config::SHUTDOWN).extract_inner("force").unwrap_or_else(bail);
|
let force = fig.focus(Config::SHUTDOWN).extract_inner("force").unwrap_or_else(bail);
|
||||||
async_run(fut, workers, max_blocking, force, "rocket-worker-thread")
|
async_run(fut, workers, max_blocking, force, &format!("{WORKER_PREFIX}-thread"))
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Executes a `future` to completion on a new tokio-based Rocket async runtime.
|
/// Executes a `future` to completion on a new tokio-based Rocket async runtime.
|
||||||
|
@ -359,3 +348,14 @@ pub fn execute<R, F>(future: F) -> R
|
||||||
{
|
{
|
||||||
async_main(future)
|
async_main(future)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Returns a future that evalutes to `true` exactly when there is a presently
|
||||||
|
/// running tokio async runtime that was likely started by Rocket.
|
||||||
|
fn running_within_rocket_async_rt() -> impl std::future::Future<Output = bool> {
|
||||||
|
use futures::FutureExt;
|
||||||
|
|
||||||
|
tokio::task::spawn_blocking(|| {
|
||||||
|
let this = std::thread::current();
|
||||||
|
this.name().map_or(false, |s| s.starts_with(WORKER_PREFIX))
|
||||||
|
}).map(|r| r.unwrap_or(false))
|
||||||
|
}
|
||||||
|
|
|
@ -0,0 +1,272 @@
|
||||||
|
use yansi::Paint;
|
||||||
|
use futures::future::{FutureExt, Future};
|
||||||
|
|
||||||
|
use crate::{route, Rocket, Orbit, Request, Response, Data};
|
||||||
|
use crate::data::IoHandler;
|
||||||
|
use crate::http::{Method, Status, Header};
|
||||||
|
use crate::outcome::Outcome;
|
||||||
|
use crate::form::Form;
|
||||||
|
|
||||||
|
// A token returned to force the execution of one method before another.
|
||||||
|
pub(crate) struct RequestToken;
|
||||||
|
|
||||||
|
async fn catch_handle<Fut, T, F>(name: Option<&str>, run: F) -> Option<T>
|
||||||
|
where F: FnOnce() -> Fut, Fut: Future<Output = T>,
|
||||||
|
{
|
||||||
|
macro_rules! panic_info {
|
||||||
|
($name:expr, $e:expr) => {{
|
||||||
|
match $name {
|
||||||
|
Some(name) => error_!("Handler {} panicked.", name.primary()),
|
||||||
|
None => error_!("A handler panicked.")
|
||||||
|
};
|
||||||
|
|
||||||
|
info_!("This is an application bug.");
|
||||||
|
info_!("A panic in Rust must be treated as an exceptional event.");
|
||||||
|
info_!("Panicking is not a suitable error handling mechanism.");
|
||||||
|
info_!("Unwinding, the result of a panic, is an expensive operation.");
|
||||||
|
info_!("Panics will degrade application performance.");
|
||||||
|
info_!("Instead of panicking, return `Option` and/or `Result`.");
|
||||||
|
info_!("Values of either type can be returned directly from handlers.");
|
||||||
|
warn_!("A panic is treated as an internal server error.");
|
||||||
|
$e
|
||||||
|
}}
|
||||||
|
}
|
||||||
|
|
||||||
|
let run = std::panic::AssertUnwindSafe(run);
|
||||||
|
let fut = std::panic::catch_unwind(move || run())
|
||||||
|
.map_err(|e| panic_info!(name, e))
|
||||||
|
.ok()?;
|
||||||
|
|
||||||
|
std::panic::AssertUnwindSafe(fut)
|
||||||
|
.catch_unwind()
|
||||||
|
.await
|
||||||
|
.map_err(|e| panic_info!(name, e))
|
||||||
|
.ok()
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Rocket<Orbit> {
|
||||||
|
/// Preprocess the request for Rocket things. Currently, this means:
|
||||||
|
///
|
||||||
|
/// * Rewriting the method in the request if _method form field exists.
|
||||||
|
/// * Run the request fairings.
|
||||||
|
///
|
||||||
|
/// This is the only place during lifecycle processing that `Request` is
|
||||||
|
/// mutable. Keep this in-sync with the `FromForm` derive.
|
||||||
|
pub(crate) async fn preprocess(
|
||||||
|
&self,
|
||||||
|
req: &mut Request<'_>,
|
||||||
|
data: &mut Data<'_>
|
||||||
|
) -> RequestToken {
|
||||||
|
// Check if this is a form and if the form contains the special _method
|
||||||
|
// field which we use to reinterpret the request's method.
|
||||||
|
let (min_len, max_len) = ("_method=get".len(), "_method=delete".len());
|
||||||
|
let peek_buffer = data.peek(max_len).await;
|
||||||
|
let is_form = req.content_type().map_or(false, |ct| ct.is_form());
|
||||||
|
|
||||||
|
if is_form && req.method() == Method::Post && peek_buffer.len() >= min_len {
|
||||||
|
let method = std::str::from_utf8(peek_buffer).ok()
|
||||||
|
.and_then(|raw_form| Form::values(raw_form).next())
|
||||||
|
.filter(|field| field.name == "_method")
|
||||||
|
.and_then(|field| field.value.parse().ok());
|
||||||
|
|
||||||
|
if let Some(method) = method {
|
||||||
|
req.set_method(method);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Run request fairings.
|
||||||
|
self.fairings.handle_request(req, data).await;
|
||||||
|
|
||||||
|
RequestToken
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Dispatches the request to the router and processes the outcome to
|
||||||
|
/// produce a response. If the initial outcome is a *forward* and the
|
||||||
|
/// request was a HEAD request, the request is rewritten and rerouted as a
|
||||||
|
/// GET. This is automatic HEAD handling.
|
||||||
|
///
|
||||||
|
/// After performing the above, if the outcome is a forward or error, the
|
||||||
|
/// appropriate error catcher is invoked to produce the response. Otherwise,
|
||||||
|
/// the successful response is used directly.
|
||||||
|
///
|
||||||
|
/// Finally, new cookies in the cookie jar are added to the response,
|
||||||
|
/// Rocket-specific headers are written, and response fairings are run. Note
|
||||||
|
/// that error responses have special cookie handling. See `handle_error`.
|
||||||
|
pub(crate) async fn dispatch<'r, 's: 'r>(
|
||||||
|
&'s self,
|
||||||
|
_token: RequestToken,
|
||||||
|
request: &'r Request<'s>,
|
||||||
|
data: Data<'r>,
|
||||||
|
// io_stream: impl Future<Output = io::Result<IoStream>> + Send,
|
||||||
|
) -> Response<'r> {
|
||||||
|
info!("{}:", request);
|
||||||
|
|
||||||
|
// Remember if the request is `HEAD` for later body stripping.
|
||||||
|
let was_head_request = request.method() == Method::Head;
|
||||||
|
|
||||||
|
// Route the request and run the user's handlers.
|
||||||
|
let mut response = match self.route(request, data).await {
|
||||||
|
Outcome::Success(response) => response,
|
||||||
|
Outcome::Forward((data, _)) if request.method() == Method::Head => {
|
||||||
|
info_!("Autohandling {} request.", "HEAD".primary().bold());
|
||||||
|
|
||||||
|
// Dispatch the request again with Method `GET`.
|
||||||
|
request._set_method(Method::Get);
|
||||||
|
match self.route(request, data).await {
|
||||||
|
Outcome::Success(response) => response,
|
||||||
|
Outcome::Error(status) => self.dispatch_error(status, request).await,
|
||||||
|
Outcome::Forward((_, status)) => self.dispatch_error(status, request).await,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Outcome::Forward((_, status)) => self.dispatch_error(status, request).await,
|
||||||
|
Outcome::Error(status) => self.dispatch_error(status, request).await,
|
||||||
|
};
|
||||||
|
|
||||||
|
// Set the cookies. Note that error responses will only include cookies
|
||||||
|
// set by the error handler. See `handle_error` for more.
|
||||||
|
let delta_jar = request.cookies().take_delta_jar();
|
||||||
|
for cookie in delta_jar.delta() {
|
||||||
|
response.adjoin_header(cookie);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add a default 'Server' header if it isn't already there.
|
||||||
|
// TODO: If removing Hyper, write out `Date` header too.
|
||||||
|
if let Some(ident) = request.rocket().config.ident.as_str() {
|
||||||
|
if !response.headers().contains("Server") {
|
||||||
|
response.set_header(Header::new("Server", ident));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Run the response fairings.
|
||||||
|
self.fairings.handle_response(request, &mut response).await;
|
||||||
|
|
||||||
|
// Strip the body if this is a `HEAD` request.
|
||||||
|
if was_head_request {
|
||||||
|
response.strip_body();
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: Should upgrades be handled here? We miss them on local clients.
|
||||||
|
response
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn extract_io_handler<'r>(
|
||||||
|
request: &Request<'_>,
|
||||||
|
response: &mut Response<'r>,
|
||||||
|
// io_stream: impl Future<Output = io::Result<IoStream>> + Send,
|
||||||
|
) -> Option<Box<dyn IoHandler + 'r>> {
|
||||||
|
let upgrades = request.headers().get("upgrade");
|
||||||
|
let Ok(upgrade) = response.search_upgrades(upgrades) else {
|
||||||
|
warn_!("Request wants upgrade but no I/O handler matched.");
|
||||||
|
info_!("Request is not being upgraded.");
|
||||||
|
return None;
|
||||||
|
};
|
||||||
|
|
||||||
|
if let Some((proto, io_handler)) = upgrade {
|
||||||
|
info_!("Attemping upgrade with {proto} I/O handler.");
|
||||||
|
response.set_status(Status::SwitchingProtocols);
|
||||||
|
response.set_raw_header("Connection", "Upgrade");
|
||||||
|
response.set_raw_header("Upgrade", proto.to_string());
|
||||||
|
return Some(io_handler);
|
||||||
|
}
|
||||||
|
|
||||||
|
None
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Calls the handler for each matching route until one of the handlers
|
||||||
|
/// returns success or error, or there are no additional routes to try, in
|
||||||
|
/// which case a `Forward` with the last forwarding state is returned.
|
||||||
|
#[inline]
|
||||||
|
async fn route<'s, 'r: 's>(
|
||||||
|
&'s self,
|
||||||
|
request: &'r Request<'s>,
|
||||||
|
mut data: Data<'r>,
|
||||||
|
) -> route::Outcome<'r> {
|
||||||
|
// Go through all matching routes until we fail or succeed or run out of
|
||||||
|
// routes to try, in which case we forward with the last status.
|
||||||
|
let mut status = Status::NotFound;
|
||||||
|
for route in self.router.route(request) {
|
||||||
|
// Retrieve and set the requests parameters.
|
||||||
|
info_!("Matched: {}", route);
|
||||||
|
request.set_route(route);
|
||||||
|
|
||||||
|
let name = route.name.as_deref();
|
||||||
|
let outcome = catch_handle(name, || route.handler.handle(request, data)).await
|
||||||
|
.unwrap_or(Outcome::Error(Status::InternalServerError));
|
||||||
|
|
||||||
|
// Check if the request processing completed (Some) or if the
|
||||||
|
// request needs to be forwarded. If it does, continue the loop
|
||||||
|
// (None) to try again.
|
||||||
|
info_!("{}", outcome.log_display());
|
||||||
|
match outcome {
|
||||||
|
o@Outcome::Success(_) | o@Outcome::Error(_) => return o,
|
||||||
|
Outcome::Forward(forwarded) => (data, status) = forwarded,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
error_!("No matching routes for {}.", request);
|
||||||
|
Outcome::Forward((data, status))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Invokes the catcher for `status`. Returns the response on success.
|
||||||
|
//
|
||||||
|
// Resets the cookie jar delta state to prevent any modifications from
|
||||||
|
// earlier unsuccessful paths from being reflected in the error response.
|
||||||
|
//
|
||||||
|
// On catcher error, the 500 error catcher is attempted. If _that_ errors,
|
||||||
|
// the (infallible) default 500 error cather is used.
|
||||||
|
pub(crate) async fn dispatch_error<'r, 's: 'r>(
|
||||||
|
&'s self,
|
||||||
|
mut status: Status,
|
||||||
|
req: &'r Request<'s>
|
||||||
|
) -> Response<'r> {
|
||||||
|
// We may wish to relax this in the future.
|
||||||
|
req.cookies().reset_delta();
|
||||||
|
|
||||||
|
// Dispatch to the `status` catcher.
|
||||||
|
if let Ok(r) = self.invoke_catcher(status, req).await {
|
||||||
|
return r;
|
||||||
|
}
|
||||||
|
|
||||||
|
// If it fails and it's not a 500, try the 500 catcher.
|
||||||
|
if status != Status::InternalServerError {
|
||||||
|
error_!("Catcher failed. Attempting 500 error catcher.");
|
||||||
|
status = Status::InternalServerError;
|
||||||
|
if let Ok(r) = self.invoke_catcher(status, req).await {
|
||||||
|
return r;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If it failed again or if it was already a 500, use Rocket's default.
|
||||||
|
error_!("{} catcher failed. Using Rocket default 500.", status.code);
|
||||||
|
crate::catcher::default_handler(Status::InternalServerError, req)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Invokes the handler with `req` for catcher with status `status`.
|
||||||
|
///
|
||||||
|
/// In order of preference, invoked handler is:
|
||||||
|
/// * the user's registered handler for `status`
|
||||||
|
/// * the user's registered `default` handler
|
||||||
|
/// * Rocket's default handler for `status`
|
||||||
|
///
|
||||||
|
/// Return `Ok(result)` if the handler succeeded. Returns `Ok(Some(Status))`
|
||||||
|
/// if the handler ran to completion but failed. Returns `Ok(None)` if the
|
||||||
|
/// handler panicked while executing.
|
||||||
|
async fn invoke_catcher<'s, 'r: 's>(
|
||||||
|
&'s self,
|
||||||
|
status: Status,
|
||||||
|
req: &'r Request<'s>
|
||||||
|
) -> Result<Response<'r>, Option<Status>> {
|
||||||
|
if let Some(catcher) = self.router.catch(status, req) {
|
||||||
|
warn_!("Responding with registered {} catcher.", catcher);
|
||||||
|
let name = catcher.name.as_deref();
|
||||||
|
catch_handle(name, || catcher.handler.handle(status, req)).await
|
||||||
|
.map(|result| result.map_err(Some))
|
||||||
|
.unwrap_or_else(|| Err(None))
|
||||||
|
} else {
|
||||||
|
let code = status.code.blue().bold();
|
||||||
|
warn_!("No {} catcher registered. Using Rocket default.", code);
|
||||||
|
Ok(crate::catcher::default_handler(status, req))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
|
@ -0,0 +1,40 @@
|
||||||
|
use futures::TryFutureExt;
|
||||||
|
|
||||||
|
use crate::listener::Listener;
|
||||||
|
|
||||||
|
pub trait Bindable: Sized {
|
||||||
|
type Listener: Listener + 'static;
|
||||||
|
|
||||||
|
type Error: std::error::Error + Send + 'static;
|
||||||
|
|
||||||
|
async fn bind(self) -> Result<Self::Listener, Self::Error>;
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<L: Listener + 'static> Bindable for L {
|
||||||
|
type Listener = L;
|
||||||
|
|
||||||
|
type Error = std::convert::Infallible;
|
||||||
|
|
||||||
|
async fn bind(self) -> Result<Self::Listener, Self::Error> {
|
||||||
|
Ok(self)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<A: Bindable, B: Bindable> Bindable for either::Either<A, B> {
|
||||||
|
type Listener = tokio_util::either::Either<A::Listener, B::Listener>;
|
||||||
|
|
||||||
|
type Error = either::Either<A::Error, B::Error>;
|
||||||
|
|
||||||
|
async fn bind(self) -> Result<Self::Listener, Self::Error> {
|
||||||
|
match self {
|
||||||
|
either::Either::Left(a) => a.bind()
|
||||||
|
.map_ok(tokio_util::either::Either::Left)
|
||||||
|
.map_err(either::Either::Left)
|
||||||
|
.await,
|
||||||
|
either::Either::Right(b) => b.bind()
|
||||||
|
.map_ok(tokio_util::either::Either::Right)
|
||||||
|
.map_err(either::Either::Right)
|
||||||
|
.await,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,58 @@
|
||||||
|
use std::{io, time::Duration};
|
||||||
|
|
||||||
|
use crate::listener::{Listener, Endpoint};
|
||||||
|
|
||||||
|
static DURATION: Duration = Duration::from_millis(250);
|
||||||
|
|
||||||
|
pub struct Bounced<L> {
|
||||||
|
listener: L,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub trait BouncedExt: Sized {
|
||||||
|
fn bounced(self) -> Bounced<Self> {
|
||||||
|
Bounced { listener: self }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<L> BouncedExt for L { }
|
||||||
|
|
||||||
|
fn is_recoverable(e: &io::Error) -> bool {
|
||||||
|
matches!(e.kind(),
|
||||||
|
| io::ErrorKind::ConnectionRefused
|
||||||
|
| io::ErrorKind::ConnectionAborted
|
||||||
|
| io::ErrorKind::ConnectionReset)
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<L: Listener + Sync> Bounced<L> {
|
||||||
|
#[inline]
|
||||||
|
pub async fn accept_next(&self) -> <Self as Listener>::Accept {
|
||||||
|
loop {
|
||||||
|
match self.listener.accept().await {
|
||||||
|
Ok(accept) => return accept,
|
||||||
|
Err(e) if is_recoverable(&e) => warn!("recoverable connection error: {e}"),
|
||||||
|
Err(e) => {
|
||||||
|
warn!("accept error: {e} [retrying in {}ms]", DURATION.as_millis());
|
||||||
|
tokio::time::sleep(DURATION).await;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<L: Listener + Sync> Listener for Bounced<L> {
|
||||||
|
type Accept = L::Accept;
|
||||||
|
|
||||||
|
type Connection = L::Connection;
|
||||||
|
|
||||||
|
async fn accept(&self) -> io::Result<Self::Accept> {
|
||||||
|
Ok(self.accept_next().await)
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn connect(&self, accept: Self::Accept) -> io::Result<Self::Connection> {
|
||||||
|
self.listener.connect(accept).await
|
||||||
|
}
|
||||||
|
|
||||||
|
fn socket_addr(&self) -> io::Result<Endpoint> {
|
||||||
|
self.listener.socket_addr()
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,273 @@
|
||||||
|
use std::io;
|
||||||
|
use std::time::Duration;
|
||||||
|
use std::task::{Poll, Context};
|
||||||
|
use std::pin::Pin;
|
||||||
|
|
||||||
|
use tokio::time::{sleep, Sleep};
|
||||||
|
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
|
||||||
|
use futures::{StreamExt, future::{select, Either, Fuse, Future, FutureExt}};
|
||||||
|
use pin_project_lite::pin_project;
|
||||||
|
|
||||||
|
use crate::{config, Shutdown};
|
||||||
|
use crate::listener::{Listener, Connection, Certificates, Bounced, Endpoint};
|
||||||
|
|
||||||
|
// Rocket wraps all connections in a `CancellableIo` struct, an internal
|
||||||
|
// structure that gracefully closes I/O when it receives a signal. That signal
|
||||||
|
// is the `shutdown` future. When the future resolves, `CancellableIo` begins to
|
||||||
|
// terminate in grace, mercy, and finally force close phases. Since all
|
||||||
|
// connections are wrapped in `CancellableIo`, this eventually ends all I/O.
|
||||||
|
//
|
||||||
|
// At that point, unless a user spawned an infinite, stand-alone task that isn't
|
||||||
|
// monitoring `Shutdown`, all tasks should resolve. This means that all
|
||||||
|
// instances of the shared `Arc<Rocket>` are dropped and we can return the owned
|
||||||
|
// instance of `Rocket`.
|
||||||
|
//
|
||||||
|
// Unfortunately, the Hyper `server` future resolves as soon as it has finished
|
||||||
|
// processing requests without respect for ongoing responses. That is, `server`
|
||||||
|
// resolves even when there are running tasks that are generating a response.
|
||||||
|
// So, `server` resolving implies little to nothing about the state of
|
||||||
|
// connections. As a result, we depend on the timing of grace + mercy + some
|
||||||
|
// buffer to determine when all connections should be closed, thus all tasks
|
||||||
|
// should be complete, thus all references to `Arc<Rocket>` should be dropped
|
||||||
|
// and we can get a unique reference.
|
||||||
|
pin_project! {
|
||||||
|
pub struct CancellableListener<F, L> {
|
||||||
|
pub trigger: F,
|
||||||
|
#[pin]
|
||||||
|
pub listener: L,
|
||||||
|
pub grace: Duration,
|
||||||
|
pub mercy: Duration,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pin_project! {
|
||||||
|
/// I/O that can be cancelled when a future `F` resolves.
|
||||||
|
#[must_use = "futures do nothing unless polled"]
|
||||||
|
pub struct CancellableIo<F, I> {
|
||||||
|
#[pin]
|
||||||
|
io: Option<I>,
|
||||||
|
#[pin]
|
||||||
|
trigger: Fuse<F>,
|
||||||
|
state: State,
|
||||||
|
grace: Duration,
|
||||||
|
mercy: Duration,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
enum State {
|
||||||
|
/// I/O has not been cancelled. Proceed as normal.
|
||||||
|
Active,
|
||||||
|
/// I/O has been cancelled. See if we can finish before the timer expires.
|
||||||
|
Grace(Pin<Box<Sleep>>),
|
||||||
|
/// Grace period elapsed. Shutdown the connection, waiting for the timer
|
||||||
|
/// until we force close.
|
||||||
|
Mercy(Pin<Box<Sleep>>),
|
||||||
|
}
|
||||||
|
|
||||||
|
pub trait CancellableExt: Sized {
|
||||||
|
fn cancellable(
|
||||||
|
self,
|
||||||
|
trigger: Shutdown,
|
||||||
|
config: &config::Shutdown
|
||||||
|
) -> CancellableListener<Shutdown, Self> {
|
||||||
|
if let Some(mut stream) = config.signal_stream() {
|
||||||
|
let trigger = trigger.clone();
|
||||||
|
tokio::spawn(async move {
|
||||||
|
while let Some(sig) = stream.next().await {
|
||||||
|
if trigger.0.tripped() {
|
||||||
|
warn!("Received {}. Shutdown already in progress.", sig);
|
||||||
|
} else {
|
||||||
|
warn!("Received {}. Requesting shutdown.", sig);
|
||||||
|
}
|
||||||
|
|
||||||
|
trigger.0.trip();
|
||||||
|
}
|
||||||
|
});
|
||||||
|
};
|
||||||
|
|
||||||
|
CancellableListener {
|
||||||
|
trigger,
|
||||||
|
listener: self,
|
||||||
|
grace: config.grace(),
|
||||||
|
mercy: config.mercy(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<L: Listener> CancellableExt for L { }
|
||||||
|
|
||||||
|
fn time_out() -> io::Error {
|
||||||
|
io::Error::new(io::ErrorKind::TimedOut, "Shutdown grace timed out")
|
||||||
|
}
|
||||||
|
|
||||||
|
fn gone() -> io::Error {
|
||||||
|
io::Error::new(io::ErrorKind::BrokenPipe, "IO driver has terminated")
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<L, F> CancellableListener<F, Bounced<L>>
|
||||||
|
where L: Listener + Sync,
|
||||||
|
F: Future + Unpin + Clone + Send + Sync + 'static
|
||||||
|
{
|
||||||
|
pub async fn accept_next(&self) -> Option<<Self as Listener>::Accept> {
|
||||||
|
let next = std::pin::pin!(self.listener.accept_next());
|
||||||
|
match select(next, self.trigger.clone()).await {
|
||||||
|
Either::Left((next, _)) => Some(next),
|
||||||
|
Either::Right(_) => None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<L, F> CancellableListener<F, L>
|
||||||
|
where L: Listener + Sync,
|
||||||
|
F: Future + Clone + Send + Sync + 'static
|
||||||
|
{
|
||||||
|
fn io<C>(&self, conn: C) -> CancellableIo<F, C> {
|
||||||
|
CancellableIo {
|
||||||
|
io: Some(conn),
|
||||||
|
trigger: self.trigger.clone().fuse(),
|
||||||
|
state: State::Active,
|
||||||
|
grace: self.grace,
|
||||||
|
mercy: self.mercy,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<L, F> Listener for CancellableListener<F, L>
|
||||||
|
where L: Listener + Sync,
|
||||||
|
F: Future + Clone + Send + Sync + Unpin + 'static
|
||||||
|
{
|
||||||
|
type Accept = L::Accept;
|
||||||
|
|
||||||
|
type Connection = CancellableIo<F, L::Connection>;
|
||||||
|
|
||||||
|
async fn accept(&self) -> io::Result<Self::Accept> {
|
||||||
|
let accept = std::pin::pin!(self.listener.accept());
|
||||||
|
match select(accept, self.trigger.clone()).await {
|
||||||
|
Either::Left((result, _)) => result,
|
||||||
|
Either::Right(_) => Err(gone()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn connect(&self, accept: Self::Accept) -> io::Result<Self::Connection> {
|
||||||
|
let conn = std::pin::pin!(self.listener.connect(accept));
|
||||||
|
match select(conn, self.trigger.clone()).await {
|
||||||
|
Either::Left((conn, _)) => Ok(self.io(conn?)),
|
||||||
|
Either::Right(_) => Err(gone()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn socket_addr(&self) -> io::Result<Endpoint> {
|
||||||
|
self.listener.socket_addr()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<F: Future, I: AsyncWrite> CancellableIo<F, I> {
|
||||||
|
fn inner(&self) -> Option<&I> {
|
||||||
|
self.io.as_ref()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Run `do_io` while connection processing should continue.
|
||||||
|
fn poll_trigger_then<T>(
|
||||||
|
mut self: Pin<&mut Self>,
|
||||||
|
cx: &mut Context<'_>,
|
||||||
|
do_io: impl FnOnce(Pin<&mut I>, &mut Context<'_>) -> Poll<io::Result<T>>,
|
||||||
|
) -> Poll<io::Result<T>> {
|
||||||
|
let mut me = self.as_mut().project();
|
||||||
|
let io = match me.io.as_pin_mut() {
|
||||||
|
Some(io) => io,
|
||||||
|
None => return Poll::Ready(Err(gone())),
|
||||||
|
};
|
||||||
|
|
||||||
|
loop {
|
||||||
|
match me.state {
|
||||||
|
State::Active => {
|
||||||
|
if me.trigger.as_mut().poll(cx).is_ready() {
|
||||||
|
*me.state = State::Grace(Box::pin(sleep(*me.grace)));
|
||||||
|
} else {
|
||||||
|
return do_io(io, cx);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
State::Grace(timer) => {
|
||||||
|
if timer.as_mut().poll(cx).is_ready() {
|
||||||
|
*me.state = State::Mercy(Box::pin(sleep(*me.mercy)));
|
||||||
|
} else {
|
||||||
|
return do_io(io, cx);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
State::Mercy(timer) => {
|
||||||
|
if timer.as_mut().poll(cx).is_ready() {
|
||||||
|
self.project().io.set(None);
|
||||||
|
return Poll::Ready(Err(time_out()));
|
||||||
|
} else {
|
||||||
|
let result = futures::ready!(io.poll_shutdown(cx));
|
||||||
|
self.project().io.set(None);
|
||||||
|
return match result {
|
||||||
|
Err(e) => Poll::Ready(Err(e)),
|
||||||
|
Ok(()) => Poll::Ready(Err(gone()))
|
||||||
|
};
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<F: Future, I: AsyncRead + AsyncWrite> AsyncRead for CancellableIo<F, I> {
|
||||||
|
fn poll_read(
|
||||||
|
mut self: Pin<&mut Self>,
|
||||||
|
cx: &mut Context<'_>,
|
||||||
|
buf: &mut ReadBuf<'_>,
|
||||||
|
) -> Poll<io::Result<()>> {
|
||||||
|
self.as_mut().poll_trigger_then(cx, |io, cx| io.poll_read(cx, buf))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<F: Future, I: AsyncWrite> AsyncWrite for CancellableIo<F, I> {
|
||||||
|
fn poll_write(
|
||||||
|
mut self: Pin<&mut Self>,
|
||||||
|
cx: &mut Context<'_>,
|
||||||
|
buf: &[u8],
|
||||||
|
) -> Poll<io::Result<usize>> {
|
||||||
|
self.as_mut().poll_trigger_then(cx, |io, cx| io.poll_write(cx, buf))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn poll_flush(
|
||||||
|
mut self: Pin<&mut Self>,
|
||||||
|
cx: &mut Context<'_>
|
||||||
|
) -> Poll<io::Result<()>> {
|
||||||
|
self.as_mut().poll_trigger_then(cx, |io, cx| io.poll_flush(cx))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn poll_shutdown(
|
||||||
|
mut self: Pin<&mut Self>,
|
||||||
|
cx: &mut Context<'_>
|
||||||
|
) -> Poll<io::Result<()>> {
|
||||||
|
self.as_mut().poll_trigger_then(cx, |io, cx| io.poll_shutdown(cx))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn poll_write_vectored(
|
||||||
|
mut self: Pin<&mut Self>,
|
||||||
|
cx: &mut Context<'_>,
|
||||||
|
bufs: &[io::IoSlice<'_>],
|
||||||
|
) -> Poll<io::Result<usize>> {
|
||||||
|
self.as_mut().poll_trigger_then(cx, |io, cx| io.poll_write_vectored(cx, bufs))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn is_write_vectored(&self) -> bool {
|
||||||
|
self.inner().map(|io| io.is_write_vectored()).unwrap_or(false)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<F: Future, C: Connection> Connection for CancellableIo<F, C>
|
||||||
|
where F: Unpin + Send + 'static
|
||||||
|
{
|
||||||
|
fn peer_address(&self) -> io::Result<Endpoint> {
|
||||||
|
self.inner()
|
||||||
|
.ok_or_else(|| gone())
|
||||||
|
.and_then(|io| io.peer_address())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn peer_certificates(&self) -> Option<Certificates<'_>> {
|
||||||
|
self.inner().and_then(|io| io.peer_certificates())
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,93 @@
|
||||||
|
use std::io;
|
||||||
|
use std::borrow::Cow;
|
||||||
|
|
||||||
|
use tokio_util::either::Either;
|
||||||
|
use tokio::io::{AsyncRead, AsyncWrite};
|
||||||
|
|
||||||
|
use super::Endpoint;
|
||||||
|
|
||||||
|
/// A collection of raw certificate data.
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct Certificates<'r>(Cow<'r, [der::CertificateDer<'r>]>);
|
||||||
|
|
||||||
|
pub trait Connection: AsyncRead + AsyncWrite + Send + Unpin {
|
||||||
|
fn peer_address(&self) -> io::Result<Endpoint>;
|
||||||
|
|
||||||
|
/// DER-encoded X.509 certificate chain presented by the client, if any.
|
||||||
|
///
|
||||||
|
/// The certificate order must be as it appears in the TLS protocol: the
|
||||||
|
/// first certificate relates to the peer, the second certifies the first,
|
||||||
|
/// the third certifies the second, and so on.
|
||||||
|
///
|
||||||
|
/// Defaults to an empty vector to indicate that no certificates were
|
||||||
|
/// presented.
|
||||||
|
fn peer_certificates(&self) -> Option<Certificates<'_>> { None }
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<A: Connection, B: Connection> Connection for Either<A, B> {
|
||||||
|
fn peer_address(&self) -> io::Result<Endpoint> {
|
||||||
|
match self {
|
||||||
|
Either::Left(c) => c.peer_address(),
|
||||||
|
Either::Right(c) => c.peer_address(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn peer_certificates(&self) -> Option<Certificates<'_>> {
|
||||||
|
match self {
|
||||||
|
Either::Left(c) => c.peer_certificates(),
|
||||||
|
Either::Right(c) => c.peer_certificates(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Certificates<'_> {
|
||||||
|
pub fn into_owned(self) -> Certificates<'static> {
|
||||||
|
let cow = self.0.into_iter()
|
||||||
|
.map(|der| der.clone().into_owned())
|
||||||
|
.collect::<Vec<_>>()
|
||||||
|
.into();
|
||||||
|
|
||||||
|
Certificates(cow)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(feature = "mtls")]
|
||||||
|
#[cfg_attr(nightly, doc(cfg(feature = "mtls")))]
|
||||||
|
mod der {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
pub use crate::mtls::CertificateDer;
|
||||||
|
|
||||||
|
impl<'r> Certificates<'r> {
|
||||||
|
pub(crate) fn inner(&self) -> &[CertificateDer<'r>] {
|
||||||
|
&self.0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'r> From<&'r [CertificateDer<'r>]> for Certificates<'r> {
|
||||||
|
fn from(value: &'r [CertificateDer<'r>]) -> Self {
|
||||||
|
Certificates(value.into())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<Vec<CertificateDer<'static>>> for Certificates<'static> {
|
||||||
|
fn from(value: Vec<CertificateDer<'static>>) -> Self {
|
||||||
|
Certificates(value.into())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(not(feature = "mtls"))]
|
||||||
|
mod der {
|
||||||
|
use std::marker::PhantomData;
|
||||||
|
|
||||||
|
/// A thin wrapper over raw, DER-encoded X.509 client certificate data.
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct CertificateDer<'r>(PhantomData<&'r [u8]>);
|
||||||
|
|
||||||
|
impl CertificateDer<'_> {
|
||||||
|
pub fn into_owned(self) -> CertificateDer<'static> {
|
||||||
|
CertificateDer(PhantomData)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,61 @@
|
||||||
|
use either::Either;
|
||||||
|
|
||||||
|
use crate::listener::{Bindable, Endpoint};
|
||||||
|
use crate::error::{Error, ErrorKind};
|
||||||
|
|
||||||
|
#[derive(serde::Deserialize)]
|
||||||
|
pub struct DefaultListener {
|
||||||
|
#[serde(default)]
|
||||||
|
pub address: Endpoint,
|
||||||
|
pub port: Option<u16>,
|
||||||
|
pub reuse: Option<bool>,
|
||||||
|
#[cfg(feature = "tls")]
|
||||||
|
pub tls: Option<crate::tls::TlsConfig>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(not(unix))] type BaseBindable = Either<std::net::SocketAddr, std::net::SocketAddr>;
|
||||||
|
#[cfg(unix)] type BaseBindable = Either<std::net::SocketAddr, super::unix::UdsConfig>;
|
||||||
|
|
||||||
|
#[cfg(not(feature = "tls"))] type TlsBindable<T> = Either<T, T>;
|
||||||
|
#[cfg(feature = "tls")] type TlsBindable<T> = Either<super::tls::TlsBindable<T>, T>;
|
||||||
|
|
||||||
|
impl DefaultListener {
|
||||||
|
pub(crate) fn base_bindable(&self) -> Result<BaseBindable, crate::Error> {
|
||||||
|
match &self.address {
|
||||||
|
Endpoint::Tcp(mut address) => {
|
||||||
|
self.port.map(|port| address.set_port(port));
|
||||||
|
Ok(BaseBindable::Left(address))
|
||||||
|
},
|
||||||
|
#[cfg(unix)]
|
||||||
|
Endpoint::Unix(path) => {
|
||||||
|
let uds = super::unix::UdsConfig { path: path.clone(), reuse: self.reuse, };
|
||||||
|
Ok(BaseBindable::Right(uds))
|
||||||
|
},
|
||||||
|
#[cfg(not(unix))]
|
||||||
|
Endpoint::Unix(_) => {
|
||||||
|
let msg = "Unix domain sockets unavailable on non-unix platforms.";
|
||||||
|
let boxed = Box::<dyn std::error::Error + Send + Sync>::from(msg);
|
||||||
|
Err(Error::new(ErrorKind::Bind(boxed)))
|
||||||
|
},
|
||||||
|
other => {
|
||||||
|
let msg = format!("unsupported default listener address: {other}");
|
||||||
|
let boxed = Box::<dyn std::error::Error + Send + Sync>::from(msg);
|
||||||
|
Err(Error::new(ErrorKind::Bind(boxed)))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn tls_bindable<T>(&self, inner: T) -> TlsBindable<T> {
|
||||||
|
#[cfg(feature = "tls")]
|
||||||
|
if let Some(tls) = self.tls.clone() {
|
||||||
|
return TlsBindable::Left(super::tls::TlsBindable { inner, tls });
|
||||||
|
}
|
||||||
|
|
||||||
|
TlsBindable::Right(inner)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn bindable(&self) -> Result<impl Bindable, crate::Error> {
|
||||||
|
self.base_bindable()
|
||||||
|
.map(|b| b.map_either(|b| self.tls_bindable(b), |b| self.tls_bindable(b)))
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,281 @@
|
||||||
|
use std::fmt;
|
||||||
|
use std::path::{Path, PathBuf};
|
||||||
|
use std::any::Any;
|
||||||
|
use std::net::{SocketAddr as TcpAddr, Ipv4Addr, AddrParseError};
|
||||||
|
use std::str::FromStr;
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
use serde::de;
|
||||||
|
|
||||||
|
use crate::http::uncased::AsUncased;
|
||||||
|
|
||||||
|
pub trait EndpointAddr: fmt::Display + fmt::Debug + Sync + Send + Any { }
|
||||||
|
|
||||||
|
impl<T: fmt::Display + fmt::Debug + Sync + Send + Any> EndpointAddr for T {}
|
||||||
|
|
||||||
|
#[cfg(not(feature = "tls"))] type TlsInfo = Option<()>;
|
||||||
|
#[cfg(feature = "tls")] type TlsInfo = Option<crate::tls::TlsConfig>;
|
||||||
|
|
||||||
|
/// # Conversions
|
||||||
|
///
|
||||||
|
/// * [`&str`] - parse with [`FromStr`]
|
||||||
|
/// * [`tokio::net::unix::SocketAddr`] - must be path: [`ListenerAddr::Unix`]
|
||||||
|
/// * [`std::net::SocketAddr`] - infallibly as [ListenerAddr::Tcp]
|
||||||
|
/// * [`PathBuf`] - infallibly as [`ListenerAddr::Unix`]
|
||||||
|
// TODO: Rename to something better. `Endpoint`?
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub enum Endpoint {
|
||||||
|
Tcp(TcpAddr),
|
||||||
|
Unix(PathBuf),
|
||||||
|
Tls(Arc<Endpoint>, TlsInfo),
|
||||||
|
Custom(Arc<dyn EndpointAddr>),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Endpoint {
|
||||||
|
pub fn new<T: EndpointAddr>(value: T) -> Endpoint {
|
||||||
|
Endpoint::Custom(Arc::new(value))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn tcp(&self) -> Option<TcpAddr> {
|
||||||
|
match self {
|
||||||
|
Endpoint::Tcp(addr) => Some(*addr),
|
||||||
|
_ => None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn unix(&self) -> Option<&Path> {
|
||||||
|
match self {
|
||||||
|
Endpoint::Unix(addr) => Some(addr),
|
||||||
|
_ => None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn tls(&self) -> Option<&Endpoint> {
|
||||||
|
match self {
|
||||||
|
Endpoint::Tls(addr, _) => Some(addr),
|
||||||
|
_ => None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(feature = "tls")]
|
||||||
|
pub fn tls_config(&self) -> Option<&crate::tls::TlsConfig> {
|
||||||
|
match self {
|
||||||
|
Endpoint::Tls(_, Some(ref config)) => Some(config),
|
||||||
|
_ => None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(feature = "mtls")]
|
||||||
|
pub fn mtls_config(&self) -> Option<&crate::mtls::MtlsConfig> {
|
||||||
|
match self {
|
||||||
|
Endpoint::Tls(_, Some(config)) => config.mutual(),
|
||||||
|
_ => None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn downcast<T: 'static>(&self) -> Option<&T> {
|
||||||
|
match self {
|
||||||
|
Endpoint::Tcp(addr) => (&*addr as &dyn Any).downcast_ref(),
|
||||||
|
Endpoint::Unix(addr) => (&*addr as &dyn Any).downcast_ref(),
|
||||||
|
Endpoint::Custom(addr) => (&*addr as &dyn Any).downcast_ref(),
|
||||||
|
Endpoint::Tls(inner, ..) => inner.downcast(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn is_tcp(&self) -> bool {
|
||||||
|
self.tcp().is_some()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn is_unix(&self) -> bool {
|
||||||
|
self.unix().is_some()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn is_tls(&self) -> bool {
|
||||||
|
self.tls().is_some()
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(feature = "tls")]
|
||||||
|
pub fn with_tls(self, config: crate::tls::TlsConfig) -> Endpoint {
|
||||||
|
if self.is_tls() {
|
||||||
|
return self;
|
||||||
|
}
|
||||||
|
|
||||||
|
Self::Tls(Arc::new(self), Some(config))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn assume_tls(self) -> Endpoint {
|
||||||
|
if self.is_tls() {
|
||||||
|
return self;
|
||||||
|
}
|
||||||
|
|
||||||
|
Self::Tls(Arc::new(self), None)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl fmt::Display for Endpoint {
|
||||||
|
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||||
|
use Endpoint::*;
|
||||||
|
|
||||||
|
match self {
|
||||||
|
Tcp(addr) => write!(f, "http://{addr}"),
|
||||||
|
Unix(addr) => write!(f, "unix:{}", addr.display()),
|
||||||
|
Custom(inner) => inner.fmt(f),
|
||||||
|
Tls(inner, c) => match (&**inner, c.as_ref()) {
|
||||||
|
#[cfg(feature = "mtls")]
|
||||||
|
(Tcp(i), Some(c)) if c.mutual().is_some() => write!(f, "https://{i} (TLS + MTLS)"),
|
||||||
|
(Tcp(i), _) => write!(f, "https://{i} (TLS)"),
|
||||||
|
#[cfg(feature = "mtls")]
|
||||||
|
(i, Some(c)) if c.mutual().is_some() => write!(f, "{i} (TLS + MTLS)"),
|
||||||
|
(inner, _) => write!(f, "{inner} (TLS)"),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<std::net::SocketAddr> for Endpoint {
|
||||||
|
fn from(value: std::net::SocketAddr) -> Self {
|
||||||
|
Self::Tcp(value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<std::net::SocketAddrV4> for Endpoint {
|
||||||
|
fn from(value: std::net::SocketAddrV4) -> Self {
|
||||||
|
Self::Tcp(value.into())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<std::net::SocketAddrV6> for Endpoint {
|
||||||
|
fn from(value: std::net::SocketAddrV6) -> Self {
|
||||||
|
Self::Tcp(value.into())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<PathBuf> for Endpoint {
|
||||||
|
fn from(value: PathBuf) -> Self {
|
||||||
|
Self::Unix(value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(unix)]
|
||||||
|
impl TryFrom<tokio::net::unix::SocketAddr> for Endpoint {
|
||||||
|
type Error = std::io::Error;
|
||||||
|
|
||||||
|
fn try_from(v: tokio::net::unix::SocketAddr) -> Result<Self, Self::Error> {
|
||||||
|
v.as_pathname()
|
||||||
|
.ok_or_else(|| std::io::Error::other("unix socket is not path"))
|
||||||
|
.map(|path| Endpoint::Unix(path.to_path_buf()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TryFrom<&str> for Endpoint {
|
||||||
|
type Error = AddrParseError;
|
||||||
|
|
||||||
|
fn try_from(value: &str) -> Result<Self, Self::Error> {
|
||||||
|
value.parse()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for Endpoint {
|
||||||
|
fn default() -> Self {
|
||||||
|
Endpoint::Tcp(TcpAddr::new(Ipv4Addr::LOCALHOST.into(), 8000))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Parses an address into a `ListenerAddr`.
|
||||||
|
///
|
||||||
|
/// The syntax is:
|
||||||
|
///
|
||||||
|
/// ```text
|
||||||
|
/// listener_addr = 'tcp' ':' tcp_addr | 'unix' ':' unix_addr | tcp_addr
|
||||||
|
/// tcp_addr := IP_ADDR | SOCKET_ADDR
|
||||||
|
/// unix_addr := PATH
|
||||||
|
///
|
||||||
|
/// IP_ADDR := `std::net::IpAddr` string as defined by Rust
|
||||||
|
/// SOCKET_ADDR := `std::net::SocketAddr` string as defined by Rust
|
||||||
|
/// PATH := `PathBuf` (any UTF-8) string as defined by Rust
|
||||||
|
/// ```
|
||||||
|
///
|
||||||
|
/// If `IP_ADDR` is specified, the port defaults to `8000`.
|
||||||
|
impl FromStr for Endpoint {
|
||||||
|
type Err = AddrParseError;
|
||||||
|
|
||||||
|
fn from_str(string: &str) -> Result<Self, Self::Err> {
|
||||||
|
fn parse_tcp(string: &str, def_port: u16) -> Result<TcpAddr, AddrParseError> {
|
||||||
|
string.parse().or_else(|_| string.parse().map(|ip| TcpAddr::new(ip, def_port)))
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some((proto, string)) = string.split_once(':') {
|
||||||
|
if proto.trim().as_uncased() == "tcp" {
|
||||||
|
return parse_tcp(string.trim(), 8000).map(Self::Tcp);
|
||||||
|
} else if proto.trim().as_uncased() == "unix" {
|
||||||
|
return Ok(Self::Unix(PathBuf::from(string.trim())));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
parse_tcp(string.trim(), 8000).map(Self::Tcp)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'de> de::Deserialize<'de> for Endpoint {
|
||||||
|
fn deserialize<D: de::Deserializer<'de>>(de: D) -> Result<Self, D::Error> {
|
||||||
|
struct Visitor;
|
||||||
|
|
||||||
|
impl<'de> de::Visitor<'de> for Visitor {
|
||||||
|
type Value = Endpoint;
|
||||||
|
|
||||||
|
fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||||
|
formatter.write_str("TCP or Unix address")
|
||||||
|
}
|
||||||
|
|
||||||
|
fn visit_str<E: de::Error>(self, v: &str) -> Result<Self::Value, E> {
|
||||||
|
v.parse::<Endpoint>().map_err(|e| E::custom(e.to_string()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
de.deserialize_any(Visitor)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Eq for Endpoint { }
|
||||||
|
|
||||||
|
impl PartialEq for Endpoint {
|
||||||
|
fn eq(&self, other: &Self) -> bool {
|
||||||
|
match (self, other) {
|
||||||
|
(Self::Tcp(l0), Self::Tcp(r0)) => l0 == r0,
|
||||||
|
(Self::Unix(l0), Self::Unix(r0)) => l0 == r0,
|
||||||
|
(Self::Tls(l0, _), Self::Tls(r0, _)) => l0 == r0,
|
||||||
|
(Self::Custom(l0), Self::Custom(r0)) => l0.to_string() == r0.to_string(),
|
||||||
|
_ => false,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl PartialEq<std::net::SocketAddr> for Endpoint {
|
||||||
|
fn eq(&self, other: &std::net::SocketAddr) -> bool {
|
||||||
|
self.tcp() == Some(*other)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl PartialEq<std::net::SocketAddrV4> for Endpoint {
|
||||||
|
fn eq(&self, other: &std::net::SocketAddrV4) -> bool {
|
||||||
|
self.tcp() == Some((*other).into())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl PartialEq<std::net::SocketAddrV6> for Endpoint {
|
||||||
|
fn eq(&self, other: &std::net::SocketAddrV6) -> bool {
|
||||||
|
self.tcp() == Some((*other).into())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl PartialEq<PathBuf> for Endpoint {
|
||||||
|
fn eq(&self, other: &PathBuf) -> bool {
|
||||||
|
self.unix() == Some(other.as_path())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl PartialEq<Path> for Endpoint {
|
||||||
|
fn eq(&self, other: &Path) -> bool {
|
||||||
|
self.unix() == Some(other)
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,65 @@
|
||||||
|
use std::io;
|
||||||
|
|
||||||
|
use futures::TryFutureExt;
|
||||||
|
use tokio_util::either::Either;
|
||||||
|
|
||||||
|
use crate::listener::{Connection, Endpoint};
|
||||||
|
|
||||||
|
pub trait Listener: Send + Sync {
|
||||||
|
type Accept: Send;
|
||||||
|
|
||||||
|
type Connection: Connection;
|
||||||
|
|
||||||
|
async fn accept(&self) -> io::Result<Self::Accept>;
|
||||||
|
|
||||||
|
#[crate::async_bound(Send)]
|
||||||
|
async fn connect(&self, accept: Self::Accept) -> io::Result<Self::Connection>;
|
||||||
|
|
||||||
|
fn socket_addr(&self) -> io::Result<Endpoint>;
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<L: Listener> Listener for &L {
|
||||||
|
type Accept = L::Accept;
|
||||||
|
|
||||||
|
type Connection = L::Connection;
|
||||||
|
|
||||||
|
async fn accept(&self) -> io::Result<Self::Accept> {
|
||||||
|
<L as Listener>::accept(self).await
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn connect(&self, accept: Self::Accept) -> io::Result<Self::Connection> {
|
||||||
|
<L as Listener>::connect(self, accept).await
|
||||||
|
}
|
||||||
|
|
||||||
|
fn socket_addr(&self) -> io::Result<Endpoint> {
|
||||||
|
<L as Listener>::socket_addr(self)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<A: Listener, B: Listener> Listener for Either<A, B> {
|
||||||
|
type Accept = Either<A::Accept, B::Accept>;
|
||||||
|
|
||||||
|
type Connection = Either<A::Connection, B::Connection>;
|
||||||
|
|
||||||
|
async fn accept(&self) -> io::Result<Self::Accept> {
|
||||||
|
match self {
|
||||||
|
Either::Left(l) => l.accept().map_ok(Either::Left).await,
|
||||||
|
Either::Right(l) => l.accept().map_ok(Either::Right).await,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn connect(&self, accept: Self::Accept) -> io::Result<Self::Connection> {
|
||||||
|
match (self, accept) {
|
||||||
|
(Either::Left(l), Either::Left(a)) => l.connect(a).map_ok(Either::Left).await,
|
||||||
|
(Either::Right(l), Either::Right(a)) => l.connect(a).map_ok(Either::Right).await,
|
||||||
|
_ => unreachable!()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn socket_addr(&self) -> io::Result<Endpoint> {
|
||||||
|
match self {
|
||||||
|
Either::Left(l) => l.socket_addr(),
|
||||||
|
Either::Right(l) => l.socket_addr(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,24 @@
|
||||||
|
mod cancellable;
|
||||||
|
mod bounced;
|
||||||
|
mod listener;
|
||||||
|
mod endpoint;
|
||||||
|
mod connection;
|
||||||
|
mod bindable;
|
||||||
|
mod default;
|
||||||
|
|
||||||
|
#[cfg(unix)]
|
||||||
|
#[cfg_attr(nightly, doc(cfg(unix)))]
|
||||||
|
pub mod unix;
|
||||||
|
#[cfg(feature = "tls")]
|
||||||
|
#[cfg_attr(nightly, doc(cfg(feature = "tls")))]
|
||||||
|
pub mod tls;
|
||||||
|
pub mod tcp;
|
||||||
|
|
||||||
|
pub use endpoint::*;
|
||||||
|
pub use listener::*;
|
||||||
|
pub use connection::*;
|
||||||
|
pub use bindable::*;
|
||||||
|
pub use default::*;
|
||||||
|
|
||||||
|
pub(crate) use cancellable::*;
|
||||||
|
pub(crate) use bounced::*;
|
|
@ -0,0 +1,43 @@
|
||||||
|
use std::io;
|
||||||
|
|
||||||
|
#[doc(inline)]
|
||||||
|
pub use tokio::net::{TcpListener, TcpStream};
|
||||||
|
|
||||||
|
use crate::listener::{Listener, Bindable, Connection, Endpoint};
|
||||||
|
|
||||||
|
impl Bindable for std::net::SocketAddr {
|
||||||
|
type Listener = TcpListener;
|
||||||
|
|
||||||
|
type Error = io::Error;
|
||||||
|
|
||||||
|
async fn bind(self) -> Result<Self::Listener, Self::Error> {
|
||||||
|
TcpListener::bind(self).await
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Listener for TcpListener {
|
||||||
|
type Accept = Self::Connection;
|
||||||
|
|
||||||
|
type Connection = TcpStream;
|
||||||
|
|
||||||
|
async fn accept(&self) -> io::Result<Self::Accept> {
|
||||||
|
let conn = self.accept().await?.0;
|
||||||
|
let _ = conn.set_nodelay(true);
|
||||||
|
let _ = conn.set_linger(None);
|
||||||
|
Ok(conn)
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn connect(&self, conn: Self::Connection) -> io::Result<Self::Connection> {
|
||||||
|
Ok(conn)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn socket_addr(&self) -> io::Result<Endpoint> {
|
||||||
|
self.local_addr().map(Endpoint::Tcp)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Connection for TcpStream {
|
||||||
|
fn peer_address(&self) -> io::Result<Endpoint> {
|
||||||
|
self.peer_addr().map(Endpoint::Tcp)
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,116 @@
|
||||||
|
use std::io;
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
use serde::Deserialize;
|
||||||
|
use rustls::server::{ServerSessionMemoryCache, ServerConfig, WebPkiClientVerifier};
|
||||||
|
use tokio_rustls::TlsAcceptor;
|
||||||
|
|
||||||
|
use crate::tls::{TlsConfig, Error};
|
||||||
|
use crate::tls::util::{load_cert_chain, load_key, load_ca_certs};
|
||||||
|
use crate::listener::{Listener, Bindable, Connection, Certificates, Endpoint};
|
||||||
|
|
||||||
|
#[doc(inline)]
|
||||||
|
pub use tokio_rustls::server::TlsStream;
|
||||||
|
|
||||||
|
/// A TLS listener over some listener interface L.
|
||||||
|
pub struct TlsListener<L> {
|
||||||
|
listener: L,
|
||||||
|
acceptor: TlsAcceptor,
|
||||||
|
config: TlsConfig,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Deserialize)]
|
||||||
|
pub struct TlsBindable<I> {
|
||||||
|
#[serde(flatten)]
|
||||||
|
pub inner: I,
|
||||||
|
pub tls: TlsConfig,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TlsConfig {
|
||||||
|
pub(crate) fn acceptor(&self) -> Result<tokio_rustls::TlsAcceptor, Error> {
|
||||||
|
let provider = rustls::crypto::CryptoProvider {
|
||||||
|
cipher_suites: self.ciphers().map(|c| c.into()).collect(),
|
||||||
|
..rustls::crypto::ring::default_provider()
|
||||||
|
};
|
||||||
|
|
||||||
|
#[cfg(feature = "mtls")]
|
||||||
|
let verifier = match self.mutual {
|
||||||
|
Some(ref mtls) => {
|
||||||
|
let ca_certs = load_ca_certs(&mut mtls.ca_certs_reader()?)?;
|
||||||
|
let verifier = WebPkiClientVerifier::builder(Arc::new(ca_certs));
|
||||||
|
match mtls.mandatory {
|
||||||
|
true => verifier.build()?,
|
||||||
|
false => verifier.allow_unauthenticated().build()?,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
None => WebPkiClientVerifier::no_client_auth(),
|
||||||
|
};
|
||||||
|
|
||||||
|
#[cfg(not(feature = "mtls"))]
|
||||||
|
let verifier = WebPkiClientVerifier::no_client_auth();
|
||||||
|
|
||||||
|
let key = load_key(&mut self.key_reader()?)?;
|
||||||
|
let cert_chain = load_cert_chain(&mut self.certs_reader()?)?;
|
||||||
|
let mut tls_config = ServerConfig::builder_with_provider(Arc::new(provider))
|
||||||
|
.with_safe_default_protocol_versions()?
|
||||||
|
.with_client_cert_verifier(verifier)
|
||||||
|
.with_single_cert(cert_chain, key)?;
|
||||||
|
|
||||||
|
tls_config.ignore_client_order = self.prefer_server_cipher_order;
|
||||||
|
tls_config.session_storage = ServerSessionMemoryCache::new(1024);
|
||||||
|
tls_config.ticketer = rustls::crypto::ring::Ticketer::new()?;
|
||||||
|
tls_config.alpn_protocols = vec![b"http/1.1".to_vec()];
|
||||||
|
if cfg!(feature = "http2") {
|
||||||
|
tls_config.alpn_protocols.insert(0, b"h2".to_vec());
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(TlsAcceptor::from(Arc::new(tls_config)))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<I: Bindable> Bindable for TlsBindable<I> {
|
||||||
|
type Listener = TlsListener<I::Listener>;
|
||||||
|
|
||||||
|
type Error = Error;
|
||||||
|
|
||||||
|
async fn bind(self) -> Result<Self::Listener, Self::Error> {
|
||||||
|
Ok(TlsListener {
|
||||||
|
acceptor: self.tls.acceptor()?,
|
||||||
|
listener: self.inner.bind().await.map_err(|e| Error::Bind(Box::new(e)))?,
|
||||||
|
config: self.tls,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<L: Listener + Sync> Listener for TlsListener<L>
|
||||||
|
where L::Connection: Unpin
|
||||||
|
{
|
||||||
|
type Accept = L::Accept;
|
||||||
|
|
||||||
|
type Connection = TlsStream<L::Connection>;
|
||||||
|
|
||||||
|
async fn accept(&self) -> io::Result<Self::Accept> {
|
||||||
|
self.listener.accept().await
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn connect(&self, accept: L::Accept) -> io::Result<Self::Connection> {
|
||||||
|
let conn = self.listener.connect(accept).await?;
|
||||||
|
self.acceptor.accept(conn).await
|
||||||
|
}
|
||||||
|
|
||||||
|
fn socket_addr(&self) -> io::Result<Endpoint> {
|
||||||
|
Ok(self.listener.socket_addr()?.with_tls(self.config.clone()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<C: Connection + Unpin> Connection for TlsStream<C> {
|
||||||
|
fn peer_address(&self) -> io::Result<Endpoint> {
|
||||||
|
Ok(self.get_ref().0.peer_address()?.assume_tls())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(feature = "mtls")]
|
||||||
|
fn peer_certificates(&self) -> Option<Certificates<'_>> {
|
||||||
|
let cert_chain = self.get_ref().1.peer_certificates()?;
|
||||||
|
Some(Certificates::from(cert_chain))
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,107 @@
|
||||||
|
use std::io;
|
||||||
|
use std::path::PathBuf;
|
||||||
|
|
||||||
|
use tokio::time::{sleep, Duration};
|
||||||
|
|
||||||
|
use crate::fs::NamedFile;
|
||||||
|
use crate::listener::{Listener, Bindable, Connection, Endpoint};
|
||||||
|
use crate::util::unix;
|
||||||
|
|
||||||
|
pub use tokio::net::UnixStream;
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct UdsConfig {
|
||||||
|
/// Socket address.
|
||||||
|
pub path: PathBuf,
|
||||||
|
/// Recreate a socket that already exists.
|
||||||
|
pub reuse: Option<bool>,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct UdsListener {
|
||||||
|
path: PathBuf,
|
||||||
|
lock: Option<NamedFile>,
|
||||||
|
listener: tokio::net::UnixListener,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Bindable for UdsConfig {
|
||||||
|
type Listener = UdsListener;
|
||||||
|
|
||||||
|
type Error = io::Error;
|
||||||
|
|
||||||
|
async fn bind(self) -> Result<Self::Listener, Self::Error> {
|
||||||
|
let lock = if self.reuse.unwrap_or(true) {
|
||||||
|
let lock_ext = match self.path.extension().and_then(|s| s.to_str()) {
|
||||||
|
Some(ext) if !ext.is_empty() => format!("{}.lock", ext),
|
||||||
|
_ => "lock".to_string()
|
||||||
|
};
|
||||||
|
|
||||||
|
let mut opts = tokio::fs::File::options();
|
||||||
|
opts.create(true).write(true);
|
||||||
|
let lock_path = self.path.with_extension(lock_ext);
|
||||||
|
let lock_file = NamedFile::open_with(lock_path, &opts).await?;
|
||||||
|
|
||||||
|
unix::lock_exlusive_nonblocking(lock_file.file())?;
|
||||||
|
if self.path.exists() {
|
||||||
|
tokio::fs::remove_file(&self.path).await?;
|
||||||
|
}
|
||||||
|
|
||||||
|
Some(lock_file)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
|
||||||
|
// Sometimes, we get `AddrInUse`, even though we've tried deleting the
|
||||||
|
// socket. If all is well, eventually the socket will _really_ be gone,
|
||||||
|
// and this will succeed. So let's try a few times.
|
||||||
|
let mut retries = 5;
|
||||||
|
let listener = loop {
|
||||||
|
match tokio::net::UnixListener::bind(&self.path) {
|
||||||
|
Ok(listener) => break listener,
|
||||||
|
Err(e) if self.path.exists() && lock.is_none() => return Err(e),
|
||||||
|
Err(_) if retries > 0 => {
|
||||||
|
retries -= 1;
|
||||||
|
sleep(Duration::from_millis(100)).await;
|
||||||
|
},
|
||||||
|
Err(e) => return Err(e),
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(UdsListener { lock, listener, path: self.path, })
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Listener for UdsListener {
|
||||||
|
type Accept = UnixStream;
|
||||||
|
|
||||||
|
type Connection = Self::Accept;
|
||||||
|
|
||||||
|
async fn accept(&self) -> io::Result<Self::Accept> {
|
||||||
|
Ok(self.listener.accept().await?.0)
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn connect(&self, accept:Self::Accept) -> io::Result<Self::Connection> {
|
||||||
|
Ok(accept)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn socket_addr(&self) -> io::Result<Endpoint> {
|
||||||
|
self.listener.local_addr()?.try_into()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Connection for UnixStream {
|
||||||
|
fn peer_address(&self) -> io::Result<Endpoint> {
|
||||||
|
self.local_addr()?.try_into()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Drop for UdsListener {
|
||||||
|
fn drop(&mut self) {
|
||||||
|
if let Some(lock) = &self.lock {
|
||||||
|
let _ = std::fs::remove_file(&self.path);
|
||||||
|
let _ = std::fs::remove_file(lock.path());
|
||||||
|
let _ = unix::unlock_nonblocking(lock.file());
|
||||||
|
} else {
|
||||||
|
let _ = std::fs::remove_file(&self.path);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -4,7 +4,8 @@ use parking_lot::RwLock;
|
||||||
|
|
||||||
use crate::{Rocket, Phase, Orbit, Ignite, Error};
|
use crate::{Rocket, Phase, Orbit, Ignite, Error};
|
||||||
use crate::local::asynchronous::{LocalRequest, LocalResponse};
|
use crate::local::asynchronous::{LocalRequest, LocalResponse};
|
||||||
use crate::http::{Method, uri::Origin, private::cookie};
|
use crate::http::{Method, uri::Origin};
|
||||||
|
use crate::listener::Endpoint;
|
||||||
|
|
||||||
/// An `async` client to construct and dispatch local requests.
|
/// An `async` client to construct and dispatch local requests.
|
||||||
///
|
///
|
||||||
|
@ -55,9 +56,15 @@ pub struct Client {
|
||||||
impl Client {
|
impl Client {
|
||||||
pub(crate) async fn _new<P: Phase>(
|
pub(crate) async fn _new<P: Phase>(
|
||||||
rocket: Rocket<P>,
|
rocket: Rocket<P>,
|
||||||
tracked: bool
|
tracked: bool,
|
||||||
|
secure: bool,
|
||||||
) -> Result<Client, Error> {
|
) -> Result<Client, Error> {
|
||||||
let rocket = rocket.local_launch().await?;
|
let mut listener = Endpoint::new("local client");
|
||||||
|
if secure {
|
||||||
|
listener = listener.assume_tls();
|
||||||
|
}
|
||||||
|
|
||||||
|
let rocket = rocket.local_launch(listener).await?;
|
||||||
let cookies = RwLock::new(cookie::CookieJar::new());
|
let cookies = RwLock::new(cookie::CookieJar::new());
|
||||||
Ok(Client { rocket, cookies, tracked })
|
Ok(Client { rocket, cookies, tracked })
|
||||||
}
|
}
|
||||||
|
|
|
@ -23,7 +23,7 @@ use super::{Client, LocalResponse};
|
||||||
/// let client = Client::tracked(rocket::build()).await.expect("valid rocket");
|
/// let client = Client::tracked(rocket::build()).await.expect("valid rocket");
|
||||||
/// let req = client.post("/")
|
/// let req = client.post("/")
|
||||||
/// .header(ContentType::JSON)
|
/// .header(ContentType::JSON)
|
||||||
/// .remote("127.0.0.1:8000".parse().unwrap())
|
/// .remote("127.0.0.1:8000")
|
||||||
/// .cookie(("name", "value"))
|
/// .cookie(("name", "value"))
|
||||||
/// .body(r#"{ "value": 42 }"#);
|
/// .body(r#"{ "value": 42 }"#);
|
||||||
///
|
///
|
||||||
|
@ -86,14 +86,14 @@ impl<'c> LocalRequest<'c> {
|
||||||
if self.inner().uri() == invalid {
|
if self.inner().uri() == invalid {
|
||||||
error!("invalid request URI: {:?}", invalid.path());
|
error!("invalid request URI: {:?}", invalid.path());
|
||||||
return LocalResponse::new(self.request, move |req| {
|
return LocalResponse::new(self.request, move |req| {
|
||||||
rocket.handle_error(Status::BadRequest, req)
|
rocket.dispatch_error(Status::BadRequest, req)
|
||||||
}).await
|
}).await
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Actually dispatch the request.
|
// Actually dispatch the request.
|
||||||
let mut data = Data::local(self.data);
|
let mut data = Data::local(self.data);
|
||||||
let token = rocket.preprocess_request(&mut self.request, &mut data).await;
|
let token = rocket.preprocess(&mut self.request, &mut data).await;
|
||||||
let response = LocalResponse::new(self.request, move |req| {
|
let response = LocalResponse::new(self.request, move |req| {
|
||||||
rocket.dispatch(token, req, data)
|
rocket.dispatch(token, req, data)
|
||||||
}).await;
|
}).await;
|
||||||
|
|
|
@ -53,9 +53,14 @@ use crate::{Request, Response};
|
||||||
///
|
///
|
||||||
/// For more, see [the top-level documentation](../index.html#localresponse).
|
/// For more, see [the top-level documentation](../index.html#localresponse).
|
||||||
pub struct LocalResponse<'c> {
|
pub struct LocalResponse<'c> {
|
||||||
_request: Box<Request<'c>>,
|
// XXX: SAFETY: This (dependent) field must come first due to drop order!
|
||||||
response: Response<'c>,
|
response: Response<'c>,
|
||||||
cookies: CookieJar<'c>,
|
cookies: CookieJar<'c>,
|
||||||
|
_request: Box<Request<'c>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Drop for LocalResponse<'_> {
|
||||||
|
fn drop(&mut self) { }
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'c> LocalResponse<'c> {
|
impl<'c> LocalResponse<'c> {
|
||||||
|
@ -64,7 +69,8 @@ impl<'c> LocalResponse<'c> {
|
||||||
O: Future<Output = Response<'c>> + Send
|
O: Future<Output = Response<'c>> + Send
|
||||||
{
|
{
|
||||||
// `LocalResponse` is a self-referential structure. In particular,
|
// `LocalResponse` is a self-referential structure. In particular,
|
||||||
// `inner` can refer to `_request` and its contents. As such, we must
|
// `response` and `cookies` can refer to `_request` and its contents. As
|
||||||
|
// such, we must
|
||||||
// 1) Ensure `Request` has a stable address.
|
// 1) Ensure `Request` has a stable address.
|
||||||
//
|
//
|
||||||
// This is done by `Box`ing the `Request`, using only the stable
|
// This is done by `Box`ing the `Request`, using only the stable
|
||||||
|
@ -97,7 +103,7 @@ impl<'c> LocalResponse<'c> {
|
||||||
cookies.add_original(cookie.into_owned());
|
cookies.add_original(cookie.into_owned());
|
||||||
}
|
}
|
||||||
|
|
||||||
LocalResponse { cookies, _request: boxed_req, response, }
|
LocalResponse { _request: boxed_req, cookies, response, }
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -30,7 +30,7 @@ pub struct Client {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Client {
|
impl Client {
|
||||||
fn _new<P: Phase>(rocket: Rocket<P>, tracked: bool) -> Result<Client, Error> {
|
fn _new<P: Phase>(rocket: Rocket<P>, tracked: bool, secure: bool) -> Result<Client, Error> {
|
||||||
let runtime = tokio::runtime::Builder::new_multi_thread()
|
let runtime = tokio::runtime::Builder::new_multi_thread()
|
||||||
.thread_name("rocket-local-client-worker-thread")
|
.thread_name("rocket-local-client-worker-thread")
|
||||||
.worker_threads(1)
|
.worker_threads(1)
|
||||||
|
@ -39,7 +39,7 @@ impl Client {
|
||||||
.expect("create tokio runtime");
|
.expect("create tokio runtime");
|
||||||
|
|
||||||
// Initialize the Rocket instance
|
// Initialize the Rocket instance
|
||||||
let inner = Some(runtime.block_on(asynchronous::Client::_new(rocket, tracked))?);
|
let inner = Some(runtime.block_on(asynchronous::Client::_new(rocket, tracked, secure))?);
|
||||||
Ok(Self { inner, runtime: RefCell::new(runtime) })
|
Ok(Self { inner, runtime: RefCell::new(runtime) })
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -73,7 +73,7 @@ impl Client {
|
||||||
|
|
||||||
#[inline(always)]
|
#[inline(always)]
|
||||||
pub(crate) fn _with_raw_cookies<F, T>(&self, f: F) -> T
|
pub(crate) fn _with_raw_cookies<F, T>(&self, f: F) -> T
|
||||||
where F: FnOnce(&crate::http::private::cookie::CookieJar) -> T
|
where F: FnOnce(&cookie::CookieJar) -> T
|
||||||
{
|
{
|
||||||
self.inner()._with_raw_cookies(f)
|
self.inner()._with_raw_cookies(f)
|
||||||
}
|
}
|
||||||
|
|
|
@ -21,7 +21,7 @@ use super::{Client, LocalResponse};
|
||||||
/// let client = Client::tracked(rocket::build()).expect("valid rocket");
|
/// let client = Client::tracked(rocket::build()).expect("valid rocket");
|
||||||
/// let req = client.post("/")
|
/// let req = client.post("/")
|
||||||
/// .header(ContentType::JSON)
|
/// .header(ContentType::JSON)
|
||||||
/// .remote("127.0.0.1:8000".parse().unwrap())
|
/// .remote("127.0.0.1:8000")
|
||||||
/// .cookie(("name", "value"))
|
/// .cookie(("name", "value"))
|
||||||
/// .body(r#"{ "value": 42 }"#);
|
/// .body(r#"{ "value": 42 }"#);
|
||||||
///
|
///
|
||||||
|
|
|
@ -68,7 +68,12 @@ macro_rules! pub_client_impl {
|
||||||
/// ```
|
/// ```
|
||||||
#[inline(always)]
|
#[inline(always)]
|
||||||
pub $($prefix)? fn tracked<P: Phase>(rocket: Rocket<P>) -> Result<Self, Error> {
|
pub $($prefix)? fn tracked<P: Phase>(rocket: Rocket<P>) -> Result<Self, Error> {
|
||||||
Self::_new(rocket, true) $(.$suffix)?
|
Self::_new(rocket, true, false) $(.$suffix)?
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline(always)]
|
||||||
|
pub $($prefix)? fn tracked_secure<P: Phase>(rocket: Rocket<P>) -> Result<Self, Error> {
|
||||||
|
Self::_new(rocket, true, true) $(.$suffix)?
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Construct a new `Client` from an instance of `Rocket` _without_
|
/// Construct a new `Client` from an instance of `Rocket` _without_
|
||||||
|
@ -92,7 +97,11 @@ macro_rules! pub_client_impl {
|
||||||
/// let client = Client::untracked(rocket);
|
/// let client = Client::untracked(rocket);
|
||||||
/// ```
|
/// ```
|
||||||
pub $($prefix)? fn untracked<P: Phase>(rocket: Rocket<P>) -> Result<Self, Error> {
|
pub $($prefix)? fn untracked<P: Phase>(rocket: Rocket<P>) -> Result<Self, Error> {
|
||||||
Self::_new(rocket, false) $(.$suffix)?
|
Self::_new(rocket, false, false) $(.$suffix)?
|
||||||
|
}
|
||||||
|
|
||||||
|
pub $($prefix)? fn untracked_secure<P: Phase>(rocket: Rocket<P>) -> Result<Self, Error> {
|
||||||
|
Self::_new(rocket, false, true) $(.$suffix)?
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Terminates `Client` by initiating a graceful shutdown via
|
/// Terminates `Client` by initiating a graceful shutdown via
|
||||||
|
@ -135,15 +144,6 @@ macro_rules! pub_client_impl {
|
||||||
Self::tracked(rocket.configure(figment)) $(.$suffix)?
|
Self::tracked(rocket.configure(figment)) $(.$suffix)?
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Deprecated alias to [`Client::tracked()`].
|
|
||||||
#[deprecated(
|
|
||||||
since = "0.6.0-dev",
|
|
||||||
note = "choose between `Client::untracked()` and `Client::tracked()`"
|
|
||||||
)]
|
|
||||||
pub $($prefix)? fn new<P: Phase>(rocket: Rocket<P>) -> Result<Self, Error> {
|
|
||||||
Self::tracked(rocket) $(.$suffix)?
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Returns a reference to the `Rocket` this client is creating requests
|
/// Returns a reference to the `Rocket` this client is creating requests
|
||||||
/// for.
|
/// for.
|
||||||
///
|
///
|
||||||
|
|
|
@ -97,24 +97,40 @@ macro_rules! pub_request_impl {
|
||||||
self._request_mut().add_header(header.into());
|
self._request_mut().add_header(header.into());
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Set the remote address of this request.
|
/// Set the remote address of this request to `address`.
|
||||||
|
///
|
||||||
|
/// `address` may be any type that [can be converted into a `ListenerAddr`].
|
||||||
|
/// If `address` fails to convert, the remote is left unchanged.
|
||||||
|
///
|
||||||
|
/// [can be converted into a `ListenerAddr`]: crate::listener::ListenerAddr#conversions
|
||||||
///
|
///
|
||||||
/// # Examples
|
/// # Examples
|
||||||
///
|
///
|
||||||
/// Set the remote address to "8.8.8.8:80":
|
/// Set the remote address to "8.8.8.8:80":
|
||||||
///
|
///
|
||||||
/// ```rust
|
/// ```rust
|
||||||
|
/// use std::net::{SocketAddrV4, Ipv4Addr};
|
||||||
|
///
|
||||||
#[doc = $import]
|
#[doc = $import]
|
||||||
///
|
///
|
||||||
/// # Client::_test(|_, request, _| {
|
/// # Client::_test(|_, request, _| {
|
||||||
/// let request: LocalRequest = request;
|
/// let request: LocalRequest = request;
|
||||||
/// let address = "8.8.8.8:80".parse().unwrap();
|
/// let req = request.remote("8.8.8.8:80");
|
||||||
/// let req = request.remote(address);
|
///
|
||||||
|
/// let addr = SocketAddrV4::new(Ipv4Addr::new(8, 8, 8, 8).into(), 80);
|
||||||
|
/// assert_eq!(req.inner().remote().unwrap(), &addr);
|
||||||
/// # });
|
/// # });
|
||||||
/// ```
|
/// ```
|
||||||
#[inline]
|
#[inline]
|
||||||
pub fn remote(mut self, address: std::net::SocketAddr) -> Self {
|
pub fn remote<T>(mut self, endpoint: T) -> Self
|
||||||
self.set_remote(address);
|
where T: TryInto<crate::listener::Endpoint>
|
||||||
|
{
|
||||||
|
if let Ok(endpoint) = endpoint.try_into() {
|
||||||
|
self.set_remote(endpoint);
|
||||||
|
} else {
|
||||||
|
warn!("remote failed to convert");
|
||||||
|
}
|
||||||
|
|
||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -228,11 +244,13 @@ macro_rules! pub_request_impl {
|
||||||
#[cfg(feature = "mtls")]
|
#[cfg(feature = "mtls")]
|
||||||
#[cfg_attr(nightly, doc(cfg(feature = "mtls")))]
|
#[cfg_attr(nightly, doc(cfg(feature = "mtls")))]
|
||||||
pub fn identity<C: std::io::Read>(mut self, reader: C) -> Self {
|
pub fn identity<C: std::io::Read>(mut self, reader: C) -> Self {
|
||||||
use crate::http::{tls::util::load_cert_chain, private::Certificates};
|
use std::sync::Arc;
|
||||||
|
use crate::tls::util::load_cert_chain;
|
||||||
|
use crate::listener::Certificates;
|
||||||
|
|
||||||
let mut reader = std::io::BufReader::new(reader);
|
let mut reader = std::io::BufReader::new(reader);
|
||||||
let certs = load_cert_chain(&mut reader).map(Certificates::from);
|
let certs = load_cert_chain(&mut reader).map(Certificates::from);
|
||||||
self._request_mut().connection.client_certificates = certs.ok();
|
self._request_mut().connection.peer_certs = certs.ok().map(Arc::new);
|
||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,25 +0,0 @@
|
||||||
//! Support for mutual TLS client certificates.
|
|
||||||
//!
|
|
||||||
//! For details on how to configure mutual TLS, see
|
|
||||||
//! [`MutualTls`](crate::config::MutualTls) and the [TLS
|
|
||||||
//! guide](https://rocket.rs/master/guide/configuration/#tls). See
|
|
||||||
//! [`Certificate`] for a request guard that validated, verifies, and retrieves
|
|
||||||
//! client certificates.
|
|
||||||
|
|
||||||
#[doc(inline)]
|
|
||||||
pub use crate::http::tls::mtls::*;
|
|
||||||
|
|
||||||
use crate::request::{Request, FromRequest, Outcome};
|
|
||||||
use crate::outcome::{try_outcome, IntoOutcome};
|
|
||||||
use crate::http::Status;
|
|
||||||
|
|
||||||
#[crate::async_trait]
|
|
||||||
impl<'r> FromRequest<'r> for Certificate<'r> {
|
|
||||||
type Error = Error;
|
|
||||||
|
|
||||||
async fn from_request(req: &'r Request<'_>) -> Outcome<Self, Self::Error> {
|
|
||||||
let certs = req.connection.client_certificates.as_ref().or_forward(Status::Unauthorized);
|
|
||||||
let data = try_outcome!(try_outcome!(certs).chain_data().or_forward(Status::Unauthorized));
|
|
||||||
Certificate::parse(data).or_error(Status::Unauthorized)
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,51 +1,8 @@
|
||||||
pub mod oid {
|
|
||||||
//! Lower-level OID types re-exported from
|
|
||||||
//! [`oid_registry`](https://docs.rs/oid-registry/0.4) and
|
|
||||||
//! [`der-parser`](https://docs.rs/der-parser/7).
|
|
||||||
|
|
||||||
pub use x509_parser::oid_registry::*;
|
|
||||||
pub use x509_parser::objects::*;
|
|
||||||
}
|
|
||||||
|
|
||||||
pub mod bigint {
|
|
||||||
//! Signed and unsigned big integer types re-exported from
|
|
||||||
//! [`num_bigint`](https://docs.rs/num-bigint/0.4).
|
|
||||||
pub use x509_parser::der_parser::num_bigint::*;
|
|
||||||
}
|
|
||||||
|
|
||||||
pub mod x509 {
|
|
||||||
//! Lower-level X.509 types re-exported from
|
|
||||||
//! [`x509_parser`](https://docs.rs/x509-parser/0.13).
|
|
||||||
//!
|
|
||||||
//! Lack of documentation is directly inherited from the source crate.
|
|
||||||
//! Prefer to use Rocket's wrappers when possible.
|
|
||||||
|
|
||||||
pub use x509_parser::certificate::*;
|
|
||||||
pub use x509_parser::cri_attributes::*;
|
|
||||||
pub use x509_parser::error::*;
|
|
||||||
pub use x509_parser::extensions::*;
|
|
||||||
pub use x509_parser::revocation_list::*;
|
|
||||||
pub use x509_parser::time::*;
|
|
||||||
pub use x509_parser::x509::*;
|
|
||||||
pub use x509_parser::der_parser::der;
|
|
||||||
pub use x509_parser::der_parser::ber;
|
|
||||||
pub use x509_parser::traits::*;
|
|
||||||
}
|
|
||||||
|
|
||||||
use std::fmt;
|
|
||||||
use std::ops::Deref;
|
|
||||||
use std::num::NonZeroUsize;
|
|
||||||
|
|
||||||
use ref_cast::RefCast;
|
use ref_cast::RefCast;
|
||||||
use x509_parser::nom;
|
|
||||||
use x509::{ParsedExtension, X509Name, X509Certificate, TbsCertificate, X509Error, FromDer};
|
|
||||||
use oid::OID_X509_EXT_SUBJECT_ALT_NAME as SUBJECT_ALT_NAME;
|
|
||||||
|
|
||||||
use crate::listener::CertificateDer;
|
use crate::mtls::{x509, oid, bigint, Name, Result, Error};
|
||||||
|
use crate::request::{Request, FromRequest, Outcome};
|
||||||
/// A type alias for [`Result`](std::result::Result) with the error type set to
|
use crate::http::Status;
|
||||||
/// [`Error`].
|
|
||||||
pub type Result<T, E = Error> = std::result::Result<T, E>;
|
|
||||||
|
|
||||||
/// A request guard for validated, verified client certificates.
|
/// A request guard for validated, verified client certificates.
|
||||||
///
|
///
|
||||||
|
@ -143,60 +100,42 @@ pub type Result<T, E = Error> = std::result::Result<T, E>;
|
||||||
/// ```
|
/// ```
|
||||||
#[derive(Debug, PartialEq)]
|
#[derive(Debug, PartialEq)]
|
||||||
pub struct Certificate<'a> {
|
pub struct Certificate<'a> {
|
||||||
x509: X509Certificate<'a>,
|
x509: x509::X509Certificate<'a>,
|
||||||
data: &'a CertificateDer,
|
data: &'a CertificateDer<'a>,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// An X.509 Distinguished Name (DN) found in a [`Certificate`].
|
pub use rustls::pki_types::CertificateDer;
|
||||||
///
|
|
||||||
/// This type is a wrapper over [`x509::X509Name`] with convenient methods and
|
|
||||||
/// complete documentation. Should the data exposed by the inherent methods not
|
|
||||||
/// suffice, this type derefs to [`x509::X509Name`].
|
|
||||||
#[repr(transparent)]
|
|
||||||
#[derive(Debug, PartialEq, RefCast)]
|
|
||||||
pub struct Name<'a>(X509Name<'a>);
|
|
||||||
|
|
||||||
/// An error returned by the [`Certificate`] request guard.
|
#[crate::async_trait]
|
||||||
///
|
impl<'r> FromRequest<'r> for Certificate<'r> {
|
||||||
/// To retrieve this error in a handler, use an `mtls::Result<Certificate>`
|
type Error = Error;
|
||||||
/// guard type:
|
|
||||||
///
|
async fn from_request(req: &'r Request<'_>) -> Outcome<Self, Self::Error> {
|
||||||
/// ```rust
|
use crate::outcome::{try_outcome, IntoOutcome};
|
||||||
/// # extern crate rocket;
|
|
||||||
/// # use rocket::get;
|
let certs = req.connection
|
||||||
/// use rocket::mtls::{self, Certificate};
|
.peer_certs
|
||||||
///
|
.as_ref()
|
||||||
/// #[get("/auth")]
|
.or_forward(Status::Unauthorized);
|
||||||
/// fn auth(cert: mtls::Result<Certificate<'_>>) {
|
|
||||||
/// match cert {
|
let chain = try_outcome!(certs);
|
||||||
/// Ok(cert) => { /* do something with the client cert */ },
|
Certificate::parse(chain.inner()).or_error(Status::Unauthorized)
|
||||||
/// Err(e) => { /* do something with the error */ },
|
}
|
||||||
/// }
|
|
||||||
/// }
|
|
||||||
/// ```
|
|
||||||
#[derive(Debug, Clone)]
|
|
||||||
#[non_exhaustive]
|
|
||||||
pub enum Error {
|
|
||||||
/// The certificate chain presented by the client had no certificates.
|
|
||||||
Empty,
|
|
||||||
/// The certificate contained neither a subject nor a subjectAlt extension.
|
|
||||||
NoSubject,
|
|
||||||
/// There is no subject and the subjectAlt is not marked as critical.
|
|
||||||
NonCriticalSubjectAlt,
|
|
||||||
/// An error occurred while parsing the certificate.
|
|
||||||
Parse(X509Error),
|
|
||||||
/// The certificate parsed partially but is incomplete.
|
|
||||||
///
|
|
||||||
/// If `Some(n)`, then `n` more bytes were expected. Otherwise, the number
|
|
||||||
/// of expected bytes is unknown.
|
|
||||||
Incomplete(Option<NonZeroUsize>),
|
|
||||||
/// The certificate contained `.0` bytes of trailing data.
|
|
||||||
Trailing(usize),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a> Certificate<'a> {
|
impl<'a> Certificate<'a> {
|
||||||
fn parse_one(raw: &[u8]) -> Result<X509Certificate<'_>> {
|
/// PRIVATE: For internal Rocket use only!
|
||||||
let (left, x509) = X509Certificate::from_der(raw)?;
|
fn parse<'r>(chain: &'r [CertificateDer<'r>]) -> Result<Certificate<'r>> {
|
||||||
|
let data = chain.first().ok_or_else(|| Error::Empty)?;
|
||||||
|
let x509 = Certificate::parse_one(&*data)?;
|
||||||
|
Ok(Certificate { x509, data })
|
||||||
|
}
|
||||||
|
|
||||||
|
fn parse_one(raw: &[u8]) -> Result<x509::X509Certificate<'_>> {
|
||||||
|
use oid::OID_X509_EXT_SUBJECT_ALT_NAME as SUBJECT_ALT_NAME;
|
||||||
|
use x509_parser::traits::FromDer;
|
||||||
|
|
||||||
|
let (left, x509) = x509::X509Certificate::from_der(raw)?;
|
||||||
if !left.is_empty() {
|
if !left.is_empty() {
|
||||||
return Err(Error::Trailing(left.len()));
|
return Err(Error::Trailing(left.len()));
|
||||||
}
|
}
|
||||||
|
@ -204,7 +143,7 @@ impl<'a> Certificate<'a> {
|
||||||
// Ensure we have a subject or a subjectAlt.
|
// Ensure we have a subject or a subjectAlt.
|
||||||
if x509.subject().as_raw().is_empty() {
|
if x509.subject().as_raw().is_empty() {
|
||||||
if let Some(ext) = x509.extensions().iter().find(|e| e.oid == SUBJECT_ALT_NAME) {
|
if let Some(ext) = x509.extensions().iter().find(|e| e.oid == SUBJECT_ALT_NAME) {
|
||||||
if !matches!(ext.parsed_extension(), ParsedExtension::SubjectAlternativeName(..)) {
|
if let x509::ParsedExtension::SubjectAlternativeName(..) = ext.parsed_extension() {
|
||||||
return Err(Error::NoSubject);
|
return Err(Error::NoSubject);
|
||||||
} else if !ext.critical {
|
} else if !ext.critical {
|
||||||
return Err(Error::NonCriticalSubjectAlt);
|
return Err(Error::NonCriticalSubjectAlt);
|
||||||
|
@ -218,18 +157,10 @@ impl<'a> Certificate<'a> {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[inline(always)]
|
#[inline(always)]
|
||||||
fn inner(&self) -> &TbsCertificate<'a> {
|
fn inner(&self) -> &x509::TbsCertificate<'a> {
|
||||||
&self.x509.tbs_certificate
|
&self.x509.tbs_certificate
|
||||||
}
|
}
|
||||||
|
|
||||||
/// PRIVATE: For internal Rocket use only!
|
|
||||||
#[doc(hidden)]
|
|
||||||
pub fn parse(chain: &[CertificateDer]) -> Result<Certificate<'_>> {
|
|
||||||
let data = chain.first().ok_or_else(|| Error::Empty)?;
|
|
||||||
let x509 = Certificate::parse_one(&data.0)?;
|
|
||||||
Ok(Certificate { x509, data })
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Returns the serial number of the X.509 certificate.
|
/// Returns the serial number of the X.509 certificate.
|
||||||
///
|
///
|
||||||
/// # Example
|
/// # Example
|
||||||
|
@ -387,176 +318,14 @@ impl<'a> Certificate<'a> {
|
||||||
/// }
|
/// }
|
||||||
/// ```
|
/// ```
|
||||||
pub fn as_bytes(&self) -> &'a [u8] {
|
pub fn as_bytes(&self) -> &'a [u8] {
|
||||||
&self.data.0
|
&*self.data
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a> Deref for Certificate<'a> {
|
impl<'a> std::ops::Deref for Certificate<'a> {
|
||||||
type Target = TbsCertificate<'a>;
|
type Target = x509::TbsCertificate<'a>;
|
||||||
|
|
||||||
fn deref(&self) -> &Self::Target {
|
fn deref(&self) -> &Self::Target {
|
||||||
self.inner()
|
self.inner()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a> Name<'a> {
|
|
||||||
/// Returns the _first_ UTF-8 _string_ common name, if any.
|
|
||||||
///
|
|
||||||
/// Note that common names need not be UTF-8 strings, or strings at all.
|
|
||||||
/// This method returns the first common name attribute that is.
|
|
||||||
///
|
|
||||||
/// # Example
|
|
||||||
///
|
|
||||||
/// ```rust
|
|
||||||
/// # #[macro_use] extern crate rocket;
|
|
||||||
/// use rocket::mtls::Certificate;
|
|
||||||
///
|
|
||||||
/// #[get("/auth")]
|
|
||||||
/// fn auth(cert: Certificate<'_>) {
|
|
||||||
/// if let Some(name) = cert.subject().common_name() {
|
|
||||||
/// println!("Hello, {}!", name);
|
|
||||||
/// }
|
|
||||||
/// }
|
|
||||||
/// ```
|
|
||||||
pub fn common_name(&self) -> Option<&'a str> {
|
|
||||||
self.common_names().next()
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Returns an iterator over all of the UTF-8 _string_ common names in
|
|
||||||
/// `self`.
|
|
||||||
///
|
|
||||||
/// Note that common names need not be UTF-8 strings, or strings at all.
|
|
||||||
/// This method filters the common names in `self` to those that are. Use
|
|
||||||
/// the raw [`iter_common_name()`](#method.iter_common_name) to iterate over
|
|
||||||
/// all value types.
|
|
||||||
///
|
|
||||||
/// # Example
|
|
||||||
///
|
|
||||||
/// ```rust
|
|
||||||
/// # #[macro_use] extern crate rocket;
|
|
||||||
/// use rocket::mtls::Certificate;
|
|
||||||
///
|
|
||||||
/// #[get("/auth")]
|
|
||||||
/// fn auth(cert: Certificate<'_>) {
|
|
||||||
/// for name in cert.issuer().common_names() {
|
|
||||||
/// println!("Issued by {}.", name);
|
|
||||||
/// }
|
|
||||||
/// }
|
|
||||||
/// ```
|
|
||||||
pub fn common_names(&self) -> impl Iterator<Item = &'a str> + '_ {
|
|
||||||
self.iter_by_oid(&oid::OID_X509_COMMON_NAME).filter_map(|n| n.as_str().ok())
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Returns the _first_ UTF-8 _string_ email address, if any.
|
|
||||||
///
|
|
||||||
/// Note that email addresses need not be UTF-8 strings, or strings at all.
|
|
||||||
/// This method returns the first email address attribute that is.
|
|
||||||
///
|
|
||||||
/// # Example
|
|
||||||
///
|
|
||||||
/// ```rust
|
|
||||||
/// # #[macro_use] extern crate rocket;
|
|
||||||
/// use rocket::mtls::Certificate;
|
|
||||||
///
|
|
||||||
/// #[get("/auth")]
|
|
||||||
/// fn auth(cert: Certificate<'_>) {
|
|
||||||
/// if let Some(email) = cert.subject().email() {
|
|
||||||
/// println!("Hello, {}!", email);
|
|
||||||
/// }
|
|
||||||
/// }
|
|
||||||
/// ```
|
|
||||||
pub fn email(&self) -> Option<&'a str> {
|
|
||||||
self.emails().next()
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Returns an iterator over all of the UTF-8 _string_ email addresses in
|
|
||||||
/// `self`.
|
|
||||||
///
|
|
||||||
/// Note that email addresses need not be UTF-8 strings, or strings at all.
|
|
||||||
/// This method filters the email address in `self` to those that are. Use
|
|
||||||
/// the raw [`iter_email()`](#method.iter_email) to iterate over all value
|
|
||||||
/// types.
|
|
||||||
///
|
|
||||||
/// # Example
|
|
||||||
///
|
|
||||||
/// ```rust
|
|
||||||
/// # #[macro_use] extern crate rocket;
|
|
||||||
/// use rocket::mtls::Certificate;
|
|
||||||
///
|
|
||||||
/// #[get("/auth")]
|
|
||||||
/// fn auth(cert: Certificate<'_>) {
|
|
||||||
/// for email in cert.subject().emails() {
|
|
||||||
/// println!("Reach me at: {}", email);
|
|
||||||
/// }
|
|
||||||
/// }
|
|
||||||
/// ```
|
|
||||||
pub fn emails(&self) -> impl Iterator<Item = &'a str> + '_ {
|
|
||||||
self.iter_by_oid(&oid::OID_PKCS9_EMAIL_ADDRESS).filter_map(|n| n.as_str().ok())
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Returns `true` if `self` has no data.
|
|
||||||
///
|
|
||||||
/// When this is the case for a `subject()`, the subject data can be found
|
|
||||||
/// in the `subjectAlt` [`extension()`](Certificate::extensions()).
|
|
||||||
///
|
|
||||||
/// # Example
|
|
||||||
///
|
|
||||||
/// ```rust
|
|
||||||
/// # #[macro_use] extern crate rocket;
|
|
||||||
/// use rocket::mtls::Certificate;
|
|
||||||
///
|
|
||||||
/// #[get("/auth")]
|
|
||||||
/// fn auth(cert: Certificate<'_>) {
|
|
||||||
/// let no_data = cert.subject().is_empty();
|
|
||||||
/// }
|
|
||||||
/// ```
|
|
||||||
pub fn is_empty(&self) -> bool {
|
|
||||||
self.0.as_raw().is_empty()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<'a> Deref for Name<'a> {
|
|
||||||
type Target = X509Name<'a>;
|
|
||||||
|
|
||||||
fn deref(&self) -> &Self::Target {
|
|
||||||
&self.0
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl fmt::Display for Name<'_> {
|
|
||||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
|
||||||
self.0.fmt(f)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl fmt::Display for Error {
|
|
||||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
|
||||||
match self {
|
|
||||||
Error::Parse(e) => write!(f, "parse error: {}", e),
|
|
||||||
Error::Incomplete(_) => write!(f, "incomplete certificate data"),
|
|
||||||
Error::Trailing(n) => write!(f, "found {} trailing bytes", n),
|
|
||||||
Error::Empty => write!(f, "empty certificate chain"),
|
|
||||||
Error::NoSubject => write!(f, "empty subject without subjectAlt"),
|
|
||||||
Error::NonCriticalSubjectAlt => write!(f, "empty subject without critical subjectAlt"),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl From<nom::Err<X509Error>> for Error {
|
|
||||||
fn from(e: nom::Err<X509Error>) -> Self {
|
|
||||||
match e {
|
|
||||||
nom::Err::Incomplete(nom::Needed::Unknown) => Error::Incomplete(None),
|
|
||||||
nom::Err::Incomplete(nom::Needed::Size(n)) => Error::Incomplete(Some(n)),
|
|
||||||
nom::Err::Error(e) | nom::Err::Failure(e) => Error::Parse(e),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl std::error::Error for Error {
|
|
||||||
// fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
|
|
||||||
// match self {
|
|
||||||
// Error::Parse(e) => Some(e),
|
|
||||||
// _ => None
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
}
|
|
|
@ -0,0 +1,212 @@
|
||||||
|
use std::io;
|
||||||
|
|
||||||
|
use figment::value::magic::{RelativePathBuf, Either};
|
||||||
|
use serde::{Serialize, Deserialize};
|
||||||
|
|
||||||
|
/// Mutual TLS configuration.
|
||||||
|
///
|
||||||
|
/// Configuration works in concert with the [`mtls`](crate::mtls) module, which
|
||||||
|
/// provides a request guard to validate, verify, and retrieve client
|
||||||
|
/// certificates in routes.
|
||||||
|
///
|
||||||
|
/// By default, mutual TLS is disabled and client certificates are not required,
|
||||||
|
/// validated or verified. To enable mutual TLS, the `mtls` feature must be
|
||||||
|
/// enabled and support configured via two `tls.mutual` parameters:
|
||||||
|
///
|
||||||
|
/// * `ca_certs`
|
||||||
|
///
|
||||||
|
/// A required path to a PEM file or raw bytes to a DER-encoded X.509 TLS
|
||||||
|
/// certificate chain for the certificate authority to verify client
|
||||||
|
/// certificates against. When a path is configured in a file, such as
|
||||||
|
/// `Rocket.toml`, relative paths are interpreted as relative to the source
|
||||||
|
/// file's directory.
|
||||||
|
///
|
||||||
|
/// * `mandatory`
|
||||||
|
///
|
||||||
|
/// An optional boolean that control whether client authentication is
|
||||||
|
/// required.
|
||||||
|
///
|
||||||
|
/// When `true`, client authentication is required. TLS connections where
|
||||||
|
/// the client does not present a certificate are immediately terminated.
|
||||||
|
/// When `false`, the client is not required to present a certificate. In
|
||||||
|
/// either case, if a certificate _is_ presented, it must be valid or the
|
||||||
|
/// connection is terminated.
|
||||||
|
///
|
||||||
|
/// In a `Rocket.toml`, configuration might look like:
|
||||||
|
///
|
||||||
|
/// ```toml
|
||||||
|
/// [default.tls.mutual]
|
||||||
|
/// ca_certs = "/ssl/ca_cert.pem"
|
||||||
|
/// mandatory = true # when absent, defaults to false
|
||||||
|
/// ```
|
||||||
|
///
|
||||||
|
/// Programmatically, configuration might look like:
|
||||||
|
///
|
||||||
|
/// ```rust
|
||||||
|
/// # #[macro_use] extern crate rocket;
|
||||||
|
/// use rocket::mtls::MtlsConfig;
|
||||||
|
/// use rocket::figment::providers::Serialized;
|
||||||
|
///
|
||||||
|
/// #[launch]
|
||||||
|
/// fn rocket() -> _ {
|
||||||
|
/// let mtls = MtlsConfig::from_path("/ssl/ca_cert.pem");
|
||||||
|
/// rocket::custom(rocket::Config::figment().merge(("tls.mutual", mtls)))
|
||||||
|
/// }
|
||||||
|
/// ```
|
||||||
|
///
|
||||||
|
/// Once mTLS is configured, the [`mtls::Certificate`](crate::mtls::Certificate)
|
||||||
|
/// request guard can be used to retrieve client certificates in routes.
|
||||||
|
#[derive(PartialEq, Debug, Clone, Deserialize, Serialize)]
|
||||||
|
pub struct MtlsConfig {
|
||||||
|
/// Path to a PEM file with, or raw bytes for, DER-encoded Certificate
|
||||||
|
/// Authority certificates which will be used to verify client-presented
|
||||||
|
/// certificates.
|
||||||
|
// TODO: Support more than one CA root.
|
||||||
|
pub(crate) ca_certs: Either<RelativePathBuf, Vec<u8>>,
|
||||||
|
/// Whether the client is required to present a certificate.
|
||||||
|
///
|
||||||
|
/// When `true`, the client is required to present a valid certificate to
|
||||||
|
/// proceed with TLS. When `false`, the client is not required to present a
|
||||||
|
/// certificate. In either case, if a certificate _is_ presented, it must be
|
||||||
|
/// valid or the connection is terminated.
|
||||||
|
#[serde(default)]
|
||||||
|
#[serde(deserialize_with = "figment::util::bool_from_str_or_int")]
|
||||||
|
pub mandatory: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl MtlsConfig {
|
||||||
|
/// Constructs a `MtlsConfig` from a path to a PEM file with a certificate
|
||||||
|
/// authority `ca_certs` DER-encoded X.509 TLS certificate chain. This
|
||||||
|
/// method does no validation; it simply creates a structure suitable for
|
||||||
|
/// passing into a [`TlsConfig`].
|
||||||
|
///
|
||||||
|
/// These certificates will be used to verify client-presented certificates
|
||||||
|
/// in TLS connections.
|
||||||
|
///
|
||||||
|
/// # Example
|
||||||
|
///
|
||||||
|
/// ```rust
|
||||||
|
/// use rocket::mtls::MtlsConfig;
|
||||||
|
///
|
||||||
|
/// let tls_config = MtlsConfig::from_path("/ssl/ca_certs.pem");
|
||||||
|
/// ```
|
||||||
|
pub fn from_path<C: AsRef<std::path::Path>>(ca_certs: C) -> Self {
|
||||||
|
MtlsConfig {
|
||||||
|
ca_certs: Either::Left(ca_certs.as_ref().to_path_buf().into()),
|
||||||
|
mandatory: Default::default()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Constructs a `MtlsConfig` from a byte buffer to a certificate authority
|
||||||
|
/// `ca_certs` DER-encoded X.509 TLS certificate chain. This method does no
|
||||||
|
/// validation; it simply creates a structure suitable for passing into a
|
||||||
|
/// [`TlsConfig`].
|
||||||
|
///
|
||||||
|
/// These certificates will be used to verify client-presented certificates
|
||||||
|
/// in TLS connections.
|
||||||
|
///
|
||||||
|
/// # Example
|
||||||
|
///
|
||||||
|
/// ```rust
|
||||||
|
/// use rocket::mtls::MtlsConfig;
|
||||||
|
///
|
||||||
|
/// # let ca_certs_buf = &[];
|
||||||
|
/// let mtls_config = MtlsConfig::from_bytes(ca_certs_buf);
|
||||||
|
/// ```
|
||||||
|
pub fn from_bytes(ca_certs: &[u8]) -> Self {
|
||||||
|
MtlsConfig {
|
||||||
|
ca_certs: Either::Right(ca_certs.to_vec()),
|
||||||
|
mandatory: Default::default()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Sets whether client authentication is required. Disabled by default.
|
||||||
|
///
|
||||||
|
/// When `true`, client authentication will be required. TLS connections
|
||||||
|
/// where the client does not present a certificate will be immediately
|
||||||
|
/// terminated. When `false`, the client is not required to present a
|
||||||
|
/// certificate. In either case, if a certificate _is_ presented, it must be
|
||||||
|
/// valid or the connection is terminated.
|
||||||
|
///
|
||||||
|
/// # Example
|
||||||
|
///
|
||||||
|
/// ```rust
|
||||||
|
/// use rocket::mtls::MtlsConfig;
|
||||||
|
///
|
||||||
|
/// # let ca_certs_buf = &[];
|
||||||
|
/// let mtls_config = MtlsConfig::from_bytes(ca_certs_buf).mandatory(true);
|
||||||
|
/// ```
|
||||||
|
pub fn mandatory(mut self, mandatory: bool) -> Self {
|
||||||
|
self.mandatory = mandatory;
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns the value of the `ca_certs` parameter.
|
||||||
|
/// # Example
|
||||||
|
///
|
||||||
|
/// ```rust
|
||||||
|
/// use rocket::mtls::MtlsConfig;
|
||||||
|
///
|
||||||
|
/// # let ca_certs_buf = &[];
|
||||||
|
/// let mtls_config = MtlsConfig::from_bytes(ca_certs_buf).mandatory(true);
|
||||||
|
/// assert_eq!(mtls_config.ca_certs().unwrap_right(), ca_certs_buf);
|
||||||
|
/// ```
|
||||||
|
pub fn ca_certs(&self) -> either::Either<std::path::PathBuf, &[u8]> {
|
||||||
|
match &self.ca_certs {
|
||||||
|
Either::Left(path) => either::Either::Left(path.relative()),
|
||||||
|
Either::Right(bytes) => either::Either::Right(&bytes),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline(always)]
|
||||||
|
pub fn ca_certs_reader(&self) -> io::Result<Box<dyn io::BufRead + Sync + Send>> {
|
||||||
|
crate::tls::config::to_reader(&self.ca_certs)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use std::path::Path;
|
||||||
|
use figment::{Figment, providers::{Toml, Format}};
|
||||||
|
|
||||||
|
use crate::mtls::MtlsConfig;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_mtls_config() {
|
||||||
|
figment::Jail::expect_with(|jail| {
|
||||||
|
jail.create_file("MTLS.toml", r#"
|
||||||
|
certs = "/ssl/cert.pem"
|
||||||
|
key = "/ssl/key.pem"
|
||||||
|
"#)?;
|
||||||
|
|
||||||
|
let figment = || Figment::from(Toml::file("MTLS.toml"));
|
||||||
|
figment().extract::<MtlsConfig>().expect_err("no ca");
|
||||||
|
|
||||||
|
jail.create_file("MTLS.toml", r#"
|
||||||
|
ca_certs = "/ssl/ca.pem"
|
||||||
|
"#)?;
|
||||||
|
|
||||||
|
let mtls: MtlsConfig = figment().extract()?;
|
||||||
|
assert_eq!(mtls.ca_certs().unwrap_left(), Path::new("/ssl/ca.pem"));
|
||||||
|
assert!(!mtls.mandatory);
|
||||||
|
|
||||||
|
jail.create_file("MTLS.toml", r#"
|
||||||
|
ca_certs = "/ssl/ca.pem"
|
||||||
|
mandatory = true
|
||||||
|
"#)?;
|
||||||
|
|
||||||
|
let mtls: MtlsConfig = figment().extract()?;
|
||||||
|
assert_eq!(mtls.ca_certs().unwrap_left(), Path::new("/ssl/ca.pem"));
|
||||||
|
assert!(mtls.mandatory);
|
||||||
|
|
||||||
|
jail.create_file("MTLS.toml", r#"
|
||||||
|
ca_certs = "relative/ca.pem"
|
||||||
|
"#)?;
|
||||||
|
|
||||||
|
let mtls: MtlsConfig = figment().extract()?;
|
||||||
|
assert_eq!(mtls.ca_certs().unwrap_left(), jail.directory().join("relative/ca.pem"));
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,74 @@
|
||||||
|
use std::fmt;
|
||||||
|
use std::num::NonZeroUsize;
|
||||||
|
|
||||||
|
use crate::mtls::x509::{self, nom};
|
||||||
|
|
||||||
|
/// An error returned by the [`Certificate`] request guard.
|
||||||
|
///
|
||||||
|
/// To retrieve this error in a handler, use an `mtls::Result<Certificate>`
|
||||||
|
/// guard type:
|
||||||
|
///
|
||||||
|
/// ```rust
|
||||||
|
/// # extern crate rocket;
|
||||||
|
/// # use rocket::get;
|
||||||
|
/// use rocket::mtls::{self, Certificate};
|
||||||
|
///
|
||||||
|
/// #[get("/auth")]
|
||||||
|
/// fn auth(cert: mtls::Result<Certificate<'_>>) {
|
||||||
|
/// match cert {
|
||||||
|
/// Ok(cert) => { /* do something with the client cert */ },
|
||||||
|
/// Err(e) => { /* do something with the error */ },
|
||||||
|
/// }
|
||||||
|
/// }
|
||||||
|
/// ```
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
#[non_exhaustive]
|
||||||
|
pub enum Error {
|
||||||
|
/// The certificate chain presented by the client had no certificates.
|
||||||
|
Empty,
|
||||||
|
/// The certificate contained neither a subject nor a subjectAlt extension.
|
||||||
|
NoSubject,
|
||||||
|
/// There is no subject and the subjectAlt is not marked as critical.
|
||||||
|
NonCriticalSubjectAlt,
|
||||||
|
/// An error occurred while parsing the certificate.
|
||||||
|
Parse(x509::X509Error),
|
||||||
|
/// The certificate parsed partially but is incomplete.
|
||||||
|
///
|
||||||
|
/// If `Some(n)`, then `n` more bytes were expected. Otherwise, the number
|
||||||
|
/// of expected bytes is unknown.
|
||||||
|
Incomplete(Option<NonZeroUsize>),
|
||||||
|
/// The certificate contained `.0` bytes of trailing data.
|
||||||
|
Trailing(usize),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl fmt::Display for Error {
|
||||||
|
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||||
|
match self {
|
||||||
|
Error::Parse(e) => write!(f, "parse error: {}", e),
|
||||||
|
Error::Incomplete(_) => write!(f, "incomplete certificate data"),
|
||||||
|
Error::Trailing(n) => write!(f, "found {} trailing bytes", n),
|
||||||
|
Error::Empty => write!(f, "empty certificate chain"),
|
||||||
|
Error::NoSubject => write!(f, "empty subject without subjectAlt"),
|
||||||
|
Error::NonCriticalSubjectAlt => write!(f, "empty subject without critical subjectAlt"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<nom::Err<x509::X509Error>> for Error {
|
||||||
|
fn from(e: nom::Err<x509::X509Error>) -> Self {
|
||||||
|
match e {
|
||||||
|
nom::Err::Incomplete(nom::Needed::Unknown) => Error::Incomplete(None),
|
||||||
|
nom::Err::Incomplete(nom::Needed::Size(n)) => Error::Incomplete(Some(n)),
|
||||||
|
nom::Err::Error(e) | nom::Err::Failure(e) => Error::Parse(e),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl std::error::Error for Error {
|
||||||
|
// fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
|
||||||
|
// match self {
|
||||||
|
// Error::Parse(e) => Some(e),
|
||||||
|
// _ => None
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
}
|
|
@ -0,0 +1,56 @@
|
||||||
|
//! Support for mutual TLS client certificates.
|
||||||
|
//!
|
||||||
|
//! For details on how to configure mutual TLS, see
|
||||||
|
//! [`MutualTls`](crate::config::MutualTls) and the [TLS
|
||||||
|
//! guide](https://rocket.rs/master/guide/configuration/#tls). See
|
||||||
|
//! [`Certificate`] for a request guard that validated, verifies, and retrieves
|
||||||
|
//! client certificates.
|
||||||
|
|
||||||
|
pub mod oid {
|
||||||
|
//! Lower-level OID types re-exported from
|
||||||
|
//! [`oid_registry`](https://docs.rs/oid-registry/0.4) and
|
||||||
|
//! [`der-parser`](https://docs.rs/der-parser/7).
|
||||||
|
|
||||||
|
pub use x509_parser::oid_registry::*;
|
||||||
|
pub use x509_parser::objects::*;
|
||||||
|
}
|
||||||
|
|
||||||
|
pub mod bigint {
|
||||||
|
//! Signed and unsigned big integer types re-exported from
|
||||||
|
//! [`num_bigint`](https://docs.rs/num-bigint/0.4).
|
||||||
|
pub use x509_parser::der_parser::num_bigint::*;
|
||||||
|
}
|
||||||
|
|
||||||
|
pub mod x509 {
|
||||||
|
//! Lower-level X.509 types re-exported from
|
||||||
|
//! [`x509_parser`](https://docs.rs/x509-parser/0.13).
|
||||||
|
//!
|
||||||
|
//! Lack of documentation is directly inherited from the source crate.
|
||||||
|
//! Prefer to use Rocket's wrappers when possible.
|
||||||
|
|
||||||
|
pub(crate) use x509_parser::nom;
|
||||||
|
pub use x509_parser::certificate::*;
|
||||||
|
pub use x509_parser::cri_attributes::*;
|
||||||
|
pub use x509_parser::error::*;
|
||||||
|
pub use x509_parser::extensions::*;
|
||||||
|
pub use x509_parser::revocation_list::*;
|
||||||
|
pub use x509_parser::time::*;
|
||||||
|
pub use x509_parser::x509::*;
|
||||||
|
pub use x509_parser::der_parser::der;
|
||||||
|
pub use x509_parser::der_parser::ber;
|
||||||
|
pub use x509_parser::traits::*;
|
||||||
|
}
|
||||||
|
|
||||||
|
mod certificate;
|
||||||
|
mod error;
|
||||||
|
mod name;
|
||||||
|
mod config;
|
||||||
|
|
||||||
|
pub use error::Error;
|
||||||
|
pub use name::Name;
|
||||||
|
pub use config::MtlsConfig;
|
||||||
|
pub use certificate::{Certificate, CertificateDer};
|
||||||
|
|
||||||
|
/// A type alias for [`Result`](std::result::Result) with the error type set to
|
||||||
|
/// [`Error`].
|
||||||
|
pub type Result<T, E = Error> = std::result::Result<T, E>;
|
|
@ -0,0 +1,146 @@
|
||||||
|
use std::fmt;
|
||||||
|
use std::ops::Deref;
|
||||||
|
|
||||||
|
use ref_cast::RefCast;
|
||||||
|
|
||||||
|
use crate::mtls::x509::X509Name;
|
||||||
|
use crate::mtls::oid;
|
||||||
|
|
||||||
|
/// An X.509 Distinguished Name (DN) found in a [`Certificate`].
|
||||||
|
///
|
||||||
|
/// This type is a wrapper over [`x509::X509Name`] with convenient methods and
|
||||||
|
/// complete documentation. Should the data exposed by the inherent methods not
|
||||||
|
/// suffice, this type derefs to [`x509::X509Name`].
|
||||||
|
#[repr(transparent)]
|
||||||
|
#[derive(Debug, PartialEq, RefCast)]
|
||||||
|
pub struct Name<'a>(X509Name<'a>);
|
||||||
|
|
||||||
|
impl<'a> Name<'a> {
|
||||||
|
/// Returns the _first_ UTF-8 _string_ common name, if any.
|
||||||
|
///
|
||||||
|
/// Note that common names need not be UTF-8 strings, or strings at all.
|
||||||
|
/// This method returns the first common name attribute that is.
|
||||||
|
///
|
||||||
|
/// # Example
|
||||||
|
///
|
||||||
|
/// ```rust
|
||||||
|
/// # #[macro_use] extern crate rocket;
|
||||||
|
/// use rocket::mtls::Certificate;
|
||||||
|
///
|
||||||
|
/// #[get("/auth")]
|
||||||
|
/// fn auth(cert: Certificate<'_>) {
|
||||||
|
/// if let Some(name) = cert.subject().common_name() {
|
||||||
|
/// println!("Hello, {}!", name);
|
||||||
|
/// }
|
||||||
|
/// }
|
||||||
|
/// ```
|
||||||
|
pub fn common_name(&self) -> Option<&'a str> {
|
||||||
|
self.common_names().next()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns an iterator over all of the UTF-8 _string_ common names in
|
||||||
|
/// `self`.
|
||||||
|
///
|
||||||
|
/// Note that common names need not be UTF-8 strings, or strings at all.
|
||||||
|
/// This method filters the common names in `self` to those that are. Use
|
||||||
|
/// the raw [`iter_common_name()`](#method.iter_common_name) to iterate over
|
||||||
|
/// all value types.
|
||||||
|
///
|
||||||
|
/// # Example
|
||||||
|
///
|
||||||
|
/// ```rust
|
||||||
|
/// # #[macro_use] extern crate rocket;
|
||||||
|
/// use rocket::mtls::Certificate;
|
||||||
|
///
|
||||||
|
/// #[get("/auth")]
|
||||||
|
/// fn auth(cert: Certificate<'_>) {
|
||||||
|
/// for name in cert.issuer().common_names() {
|
||||||
|
/// println!("Issued by {}.", name);
|
||||||
|
/// }
|
||||||
|
/// }
|
||||||
|
/// ```
|
||||||
|
pub fn common_names(&self) -> impl Iterator<Item = &'a str> + '_ {
|
||||||
|
self.iter_by_oid(&oid::OID_X509_COMMON_NAME).filter_map(|n| n.as_str().ok())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns the _first_ UTF-8 _string_ email address, if any.
|
||||||
|
///
|
||||||
|
/// Note that email addresses need not be UTF-8 strings, or strings at all.
|
||||||
|
/// This method returns the first email address attribute that is.
|
||||||
|
///
|
||||||
|
/// # Example
|
||||||
|
///
|
||||||
|
/// ```rust
|
||||||
|
/// # #[macro_use] extern crate rocket;
|
||||||
|
/// use rocket::mtls::Certificate;
|
||||||
|
///
|
||||||
|
/// #[get("/auth")]
|
||||||
|
/// fn auth(cert: Certificate<'_>) {
|
||||||
|
/// if let Some(email) = cert.subject().email() {
|
||||||
|
/// println!("Hello, {}!", email);
|
||||||
|
/// }
|
||||||
|
/// }
|
||||||
|
/// ```
|
||||||
|
pub fn email(&self) -> Option<&'a str> {
|
||||||
|
self.emails().next()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns an iterator over all of the UTF-8 _string_ email addresses in
|
||||||
|
/// `self`.
|
||||||
|
///
|
||||||
|
/// Note that email addresses need not be UTF-8 strings, or strings at all.
|
||||||
|
/// This method filters the email address in `self` to those that are. Use
|
||||||
|
/// the raw [`iter_email()`](#method.iter_email) to iterate over all value
|
||||||
|
/// types.
|
||||||
|
///
|
||||||
|
/// # Example
|
||||||
|
///
|
||||||
|
/// ```rust
|
||||||
|
/// # #[macro_use] extern crate rocket;
|
||||||
|
/// use rocket::mtls::Certificate;
|
||||||
|
///
|
||||||
|
/// #[get("/auth")]
|
||||||
|
/// fn auth(cert: Certificate<'_>) {
|
||||||
|
/// for email in cert.subject().emails() {
|
||||||
|
/// println!("Reach me at: {}", email);
|
||||||
|
/// }
|
||||||
|
/// }
|
||||||
|
/// ```
|
||||||
|
pub fn emails(&self) -> impl Iterator<Item = &'a str> + '_ {
|
||||||
|
self.iter_by_oid(&oid::OID_PKCS9_EMAIL_ADDRESS).filter_map(|n| n.as_str().ok())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns `true` if `self` has no data.
|
||||||
|
///
|
||||||
|
/// When this is the case for a `subject()`, the subject data can be found
|
||||||
|
/// in the `subjectAlt` [`extension()`](Certificate::extensions()).
|
||||||
|
///
|
||||||
|
/// # Example
|
||||||
|
///
|
||||||
|
/// ```rust
|
||||||
|
/// # #[macro_use] extern crate rocket;
|
||||||
|
/// use rocket::mtls::Certificate;
|
||||||
|
///
|
||||||
|
/// #[get("/auth")]
|
||||||
|
/// fn auth(cert: Certificate<'_>) {
|
||||||
|
/// let no_data = cert.subject().is_empty();
|
||||||
|
/// }
|
||||||
|
/// ```
|
||||||
|
pub fn is_empty(&self) -> bool {
|
||||||
|
self.0.as_raw().is_empty()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a> Deref for Name<'a> {
|
||||||
|
type Target = X509Name<'a>;
|
||||||
|
|
||||||
|
fn deref(&self) -> &Self::Target {
|
||||||
|
&self.0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl fmt::Display for Name<'_> {
|
||||||
|
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||||
|
self.0.fmt(f)
|
||||||
|
}
|
||||||
|
}
|
|
@ -1,6 +1,7 @@
|
||||||
use state::TypeMap;
|
use state::TypeMap;
|
||||||
use figment::Figment;
|
use figment::Figment;
|
||||||
|
|
||||||
|
use crate::listener::Endpoint;
|
||||||
use crate::{Catcher, Config, Rocket, Route, Shutdown};
|
use crate::{Catcher, Config, Rocket, Route, Shutdown};
|
||||||
use crate::router::Router;
|
use crate::router::Router;
|
||||||
use crate::fairing::Fairings;
|
use crate::fairing::Fairings;
|
||||||
|
@ -113,5 +114,6 @@ phases! {
|
||||||
pub(crate) config: Config,
|
pub(crate) config: Config,
|
||||||
pub(crate) state: TypeMap![Send + Sync],
|
pub(crate) state: TypeMap![Send + Sync],
|
||||||
pub(crate) shutdown: Shutdown,
|
pub(crate) shutdown: Shutdown,
|
||||||
|
pub(crate) endpoint: Endpoint,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,43 @@
|
||||||
|
use crate::http::Method;
|
||||||
|
|
||||||
|
pub struct AtomicMethod(ref_swap::RefSwap<'static, Method>);
|
||||||
|
|
||||||
|
#[inline(always)]
|
||||||
|
const fn makeref(method: Method) -> &'static Method {
|
||||||
|
match method {
|
||||||
|
Method::Get => &Method::Get,
|
||||||
|
Method::Put => &Method::Put,
|
||||||
|
Method::Post => &Method::Post,
|
||||||
|
Method::Delete => &Method::Delete,
|
||||||
|
Method::Options => &Method::Options,
|
||||||
|
Method::Head => &Method::Head,
|
||||||
|
Method::Trace => &Method::Trace,
|
||||||
|
Method::Connect => &Method::Connect,
|
||||||
|
Method::Patch => &Method::Patch,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl AtomicMethod {
|
||||||
|
pub fn new(value: Method) -> Self {
|
||||||
|
Self(ref_swap::RefSwap::new(makeref(value)))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn load(&self) -> Method {
|
||||||
|
*self.0.load(std::sync::atomic::Ordering::Acquire)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn set(&mut self, new: Method) {
|
||||||
|
*self = Self::new(new);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn store(&self, new: Method) {
|
||||||
|
self.0.store(makeref(new), std::sync::atomic::Ordering::Release)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Clone for AtomicMethod {
|
||||||
|
fn clone(&self) -> Self {
|
||||||
|
let inner = self.0.load(std::sync::atomic::Ordering::Acquire);
|
||||||
|
Self(ref_swap::RefSwap::new(inner))
|
||||||
|
}
|
||||||
|
}
|
|
@ -1,12 +1,13 @@
|
||||||
use std::convert::Infallible;
|
use std::convert::Infallible;
|
||||||
use std::fmt::Debug;
|
use std::fmt::Debug;
|
||||||
use std::net::{IpAddr, SocketAddr};
|
use std::net::IpAddr;
|
||||||
|
|
||||||
use crate::{Request, Route};
|
use crate::{Request, Route};
|
||||||
use crate::outcome::{self, IntoOutcome, Outcome::*};
|
use crate::outcome::{self, IntoOutcome, Outcome::*};
|
||||||
|
|
||||||
use crate::http::uri::{Host, Origin};
|
use crate::http::uri::{Host, Origin};
|
||||||
use crate::http::{Status, ContentType, Accept, Method, ProxyProto, CookieJar};
|
use crate::http::{Status, ContentType, Accept, Method, ProxyProto, CookieJar};
|
||||||
|
use crate::listener::Endpoint;
|
||||||
|
|
||||||
/// Type alias for the `Outcome` of a `FromRequest` conversion.
|
/// Type alias for the `Outcome` of a `FromRequest` conversion.
|
||||||
pub type Outcome<S, E> = outcome::Outcome<S, (Status, E), Status>;
|
pub type Outcome<S, E> = outcome::Outcome<S, (Status, E), Status>;
|
||||||
|
@ -486,14 +487,22 @@ impl<'r> FromRequest<'r> for ProxyProto<'r> {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[crate::async_trait]
|
#[crate::async_trait]
|
||||||
impl<'r> FromRequest<'r> for SocketAddr {
|
impl<'r> FromRequest<'r> for &'r Endpoint {
|
||||||
type Error = Infallible;
|
type Error = Infallible;
|
||||||
|
|
||||||
async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Infallible> {
|
async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Infallible> {
|
||||||
match request.remote() {
|
request.remote().or_forward(Status::InternalServerError)
|
||||||
Some(addr) => Success(addr),
|
}
|
||||||
None => Forward(Status::InternalServerError)
|
}
|
||||||
}
|
|
||||||
|
#[crate::async_trait]
|
||||||
|
impl<'r> FromRequest<'r> for std::net::SocketAddr {
|
||||||
|
type Error = Infallible;
|
||||||
|
|
||||||
|
async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Infallible> {
|
||||||
|
request.remote()
|
||||||
|
.and_then(|r| r.tcp())
|
||||||
|
.or_forward(Status::InternalServerError)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -3,6 +3,7 @@
|
||||||
mod request;
|
mod request;
|
||||||
mod from_param;
|
mod from_param;
|
||||||
mod from_request;
|
mod from_request;
|
||||||
|
mod atomic_method;
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests;
|
mod tests;
|
||||||
|
@ -15,6 +16,7 @@ pub use self::from_param::{FromParam, FromSegments};
|
||||||
pub use crate::response::flash::FlashMessage;
|
pub use crate::response::flash::FlashMessage;
|
||||||
|
|
||||||
pub(crate) use self::request::ConnectionMeta;
|
pub(crate) use self::request::ConnectionMeta;
|
||||||
|
pub(crate) use self::atomic_method::AtomicMethod;
|
||||||
|
|
||||||
crate::export! {
|
crate::export! {
|
||||||
/// Store and immediately retrieve a vector-like value `$v` (`String` or
|
/// Store and immediately retrieve a vector-like value `$v` (`String` or
|
||||||
|
|
|
@ -1,22 +1,24 @@
|
||||||
use std::fmt;
|
use std::fmt;
|
||||||
use std::ops::RangeFrom;
|
use std::ops::RangeFrom;
|
||||||
use std::{future::Future, borrow::Cow, sync::Arc};
|
use std::sync::{Arc, atomic::Ordering};
|
||||||
use std::net::{IpAddr, SocketAddr};
|
use std::borrow::Cow;
|
||||||
|
use std::future::Future;
|
||||||
|
use std::net::IpAddr;
|
||||||
|
|
||||||
use yansi::Paint;
|
use yansi::Paint;
|
||||||
use state::{TypeMap, InitCell};
|
use state::{TypeMap, InitCell};
|
||||||
use futures::future::BoxFuture;
|
use futures::future::BoxFuture;
|
||||||
use atomic::{Atomic, Ordering};
|
use ref_swap::OptionRefSwap;
|
||||||
|
|
||||||
use crate::{Rocket, Route, Orbit};
|
use crate::{Rocket, Route, Orbit};
|
||||||
use crate::request::{FromParam, FromSegments, FromRequest, Outcome};
|
use crate::request::{FromParam, FromSegments, FromRequest, Outcome, AtomicMethod};
|
||||||
use crate::form::{self, ValueField, FromForm};
|
use crate::form::{self, ValueField, FromForm};
|
||||||
use crate::data::Limits;
|
use crate::data::Limits;
|
||||||
|
|
||||||
use crate::http::{hyper, Method, Header, HeaderMap, ProxyProto};
|
use crate::http::ProxyProto;
|
||||||
use crate::http::{ContentType, Accept, MediaType, CookieJar, Cookie};
|
use crate::http::{Method, Header, HeaderMap, ContentType, Accept, MediaType, CookieJar, Cookie};
|
||||||
use crate::http::private::Certificates;
|
|
||||||
use crate::http::uri::{fmt::Path, Origin, Segments, Host, Authority};
|
use crate::http::uri::{fmt::Path, Origin, Segments, Host, Authority};
|
||||||
|
use crate::listener::{Certificates, Endpoint, Connection};
|
||||||
|
|
||||||
/// The type of an incoming web request.
|
/// The type of an incoming web request.
|
||||||
///
|
///
|
||||||
|
@ -24,26 +26,37 @@ use crate::http::uri::{fmt::Path, Origin, Segments, Host, Authority};
|
||||||
/// should likely only be used when writing [`FromRequest`] implementations. It
|
/// should likely only be used when writing [`FromRequest`] implementations. It
|
||||||
/// contains all of the information for a given web request except for the body
|
/// contains all of the information for a given web request except for the body
|
||||||
/// data. This includes the HTTP method, URI, cookies, headers, and more.
|
/// data. This includes the HTTP method, URI, cookies, headers, and more.
|
||||||
|
#[derive(Clone)]
|
||||||
pub struct Request<'r> {
|
pub struct Request<'r> {
|
||||||
method: Atomic<Method>,
|
method: AtomicMethod,
|
||||||
uri: Origin<'r>,
|
uri: Origin<'r>,
|
||||||
headers: HeaderMap<'r>,
|
headers: HeaderMap<'r>,
|
||||||
|
pub(crate) errors: Vec<RequestError>,
|
||||||
pub(crate) connection: ConnectionMeta,
|
pub(crate) connection: ConnectionMeta,
|
||||||
pub(crate) state: RequestState<'r>,
|
pub(crate) state: RequestState<'r>,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Information derived from an incoming connection, if any.
|
/// Information derived from an incoming connection, if any.
|
||||||
#[derive(Clone)]
|
#[derive(Clone, Default)]
|
||||||
pub(crate) struct ConnectionMeta {
|
pub(crate) struct ConnectionMeta {
|
||||||
pub remote: Option<SocketAddr>,
|
pub peer_address: Option<Arc<Endpoint>>,
|
||||||
#[cfg_attr(not(feature = "mtls"), allow(dead_code))]
|
#[cfg_attr(not(feature = "mtls"), allow(dead_code))]
|
||||||
pub client_certificates: Option<Certificates>,
|
pub peer_certs: Option<Arc<Certificates<'static>>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<C: Connection> From<&C> for ConnectionMeta {
|
||||||
|
fn from(conn: &C) -> Self {
|
||||||
|
ConnectionMeta {
|
||||||
|
peer_address: conn.peer_address().ok().map(Arc::new),
|
||||||
|
peer_certs: conn.peer_certificates().map(|c| c.into_owned()).map(Arc::new),
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Information derived from the request.
|
/// Information derived from the request.
|
||||||
pub(crate) struct RequestState<'r> {
|
pub(crate) struct RequestState<'r> {
|
||||||
pub rocket: &'r Rocket<Orbit>,
|
pub rocket: &'r Rocket<Orbit>,
|
||||||
pub route: Atomic<Option<&'r Route>>,
|
pub route: OptionRefSwap<'r, Route>,
|
||||||
pub cookies: CookieJar<'r>,
|
pub cookies: CookieJar<'r>,
|
||||||
pub accept: InitCell<Option<Accept>>,
|
pub accept: InitCell<Option<Accept>>,
|
||||||
pub content_type: InitCell<Option<ContentType>>,
|
pub content_type: InitCell<Option<ContentType>>,
|
||||||
|
@ -51,23 +64,11 @@ pub(crate) struct RequestState<'r> {
|
||||||
pub host: Option<Host<'r>>,
|
pub host: Option<Host<'r>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Request<'_> {
|
impl Clone for RequestState<'_> {
|
||||||
pub(crate) fn clone(&self) -> Self {
|
|
||||||
Request {
|
|
||||||
method: Atomic::new(self.method()),
|
|
||||||
uri: self.uri.clone(),
|
|
||||||
headers: self.headers.clone(),
|
|
||||||
connection: self.connection.clone(),
|
|
||||||
state: self.state.clone(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl RequestState<'_> {
|
|
||||||
fn clone(&self) -> Self {
|
fn clone(&self) -> Self {
|
||||||
RequestState {
|
RequestState {
|
||||||
rocket: self.rocket,
|
rocket: self.rocket,
|
||||||
route: Atomic::new(self.route.load(Ordering::Acquire)),
|
route: OptionRefSwap::new(self.route.load(Ordering::Acquire)),
|
||||||
cookies: self.cookies.clone(),
|
cookies: self.cookies.clone(),
|
||||||
accept: self.accept.clone(),
|
accept: self.accept.clone(),
|
||||||
content_type: self.content_type.clone(),
|
content_type: self.content_type.clone(),
|
||||||
|
@ -87,15 +88,13 @@ impl<'r> Request<'r> {
|
||||||
) -> Request<'r> {
|
) -> Request<'r> {
|
||||||
Request {
|
Request {
|
||||||
uri,
|
uri,
|
||||||
method: Atomic::new(method),
|
method: AtomicMethod::new(method),
|
||||||
headers: HeaderMap::new(),
|
headers: HeaderMap::new(),
|
||||||
connection: ConnectionMeta {
|
errors: Vec::new(),
|
||||||
remote: None,
|
connection: ConnectionMeta::default(),
|
||||||
client_certificates: None,
|
|
||||||
},
|
|
||||||
state: RequestState {
|
state: RequestState {
|
||||||
rocket,
|
rocket,
|
||||||
route: Atomic::new(None),
|
route: OptionRefSwap::new(None),
|
||||||
cookies: CookieJar::new(None, rocket),
|
cookies: CookieJar::new(None, rocket),
|
||||||
accept: InitCell::new(),
|
accept: InitCell::new(),
|
||||||
content_type: InitCell::new(),
|
content_type: InitCell::new(),
|
||||||
|
@ -120,7 +119,7 @@ impl<'r> Request<'r> {
|
||||||
/// ```
|
/// ```
|
||||||
#[inline(always)]
|
#[inline(always)]
|
||||||
pub fn method(&self) -> Method {
|
pub fn method(&self) -> Method {
|
||||||
self.method.load(Ordering::Acquire)
|
self.method.load()
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Set the method of `self` to `method`.
|
/// Set the method of `self` to `method`.
|
||||||
|
@ -140,7 +139,7 @@ impl<'r> Request<'r> {
|
||||||
/// ```
|
/// ```
|
||||||
#[inline(always)]
|
#[inline(always)]
|
||||||
pub fn set_method(&mut self, method: Method) {
|
pub fn set_method(&mut self, method: Method) {
|
||||||
self._set_method(method);
|
self.method.set(method);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Borrow the [`Origin`] URI from `self`.
|
/// Borrow the [`Origin`] URI from `self`.
|
||||||
|
@ -324,20 +323,20 @@ impl<'r> Request<'r> {
|
||||||
///
|
///
|
||||||
/// assert_eq!(request.remote(), None);
|
/// assert_eq!(request.remote(), None);
|
||||||
///
|
///
|
||||||
/// let localhost = SocketAddrV4::new(Ipv4Addr::LOCALHOST, 8000).into();
|
/// let localhost = SocketAddrV4::new(Ipv4Addr::LOCALHOST, 8000);
|
||||||
/// request.set_remote(localhost);
|
/// request.set_remote(localhost);
|
||||||
/// assert_eq!(request.remote(), Some(localhost));
|
/// assert_eq!(request.remote().unwrap(), &localhost);
|
||||||
/// ```
|
/// ```
|
||||||
#[inline(always)]
|
#[inline(always)]
|
||||||
pub fn remote(&self) -> Option<SocketAddr> {
|
pub fn remote(&self) -> Option<&Endpoint> {
|
||||||
self.connection.remote
|
self.connection.peer_address.as_deref()
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Sets the remote address of `self` to `address`.
|
/// Sets the remote address of `self` to `address`.
|
||||||
///
|
///
|
||||||
/// # Example
|
/// # Example
|
||||||
///
|
///
|
||||||
/// Set the remote address to be 127.0.0.1:8000:
|
/// Set the remote address to be 127.0.0.1:8111:
|
||||||
///
|
///
|
||||||
/// ```rust
|
/// ```rust
|
||||||
/// use std::net::{SocketAddrV4, Ipv4Addr};
|
/// use std::net::{SocketAddrV4, Ipv4Addr};
|
||||||
|
@ -347,13 +346,13 @@ impl<'r> Request<'r> {
|
||||||
///
|
///
|
||||||
/// assert_eq!(request.remote(), None);
|
/// assert_eq!(request.remote(), None);
|
||||||
///
|
///
|
||||||
/// let localhost = SocketAddrV4::new(Ipv4Addr::LOCALHOST, 8000).into();
|
/// let localhost = SocketAddrV4::new(Ipv4Addr::LOCALHOST, 8111);
|
||||||
/// request.set_remote(localhost);
|
/// request.set_remote(localhost);
|
||||||
/// assert_eq!(request.remote(), Some(localhost));
|
/// assert_eq!(request.remote().unwrap(), &localhost);
|
||||||
/// ```
|
/// ```
|
||||||
#[inline(always)]
|
#[inline(always)]
|
||||||
pub fn set_remote(&mut self, address: SocketAddr) {
|
pub fn set_remote<A: Into<Endpoint>>(&mut self, address: A) {
|
||||||
self.connection.remote = Some(address);
|
self.connection.peer_address = Some(Arc::new(address.into()));
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns the IP address of the configured
|
/// Returns the IP address of the configured
|
||||||
|
@ -489,25 +488,26 @@ impl<'r> Request<'r> {
|
||||||
///
|
///
|
||||||
/// ```rust
|
/// ```rust
|
||||||
/// # use rocket::http::Header;
|
/// # use rocket::http::Header;
|
||||||
/// # use std::net::{SocketAddr, IpAddr, Ipv4Addr};
|
|
||||||
/// # let c = rocket::local::blocking::Client::debug_with(vec![]).unwrap();
|
/// # let c = rocket::local::blocking::Client::debug_with(vec![]).unwrap();
|
||||||
/// # let mut req = c.get("/");
|
/// # let mut req = c.get("/");
|
||||||
/// # let request = req.inner_mut();
|
/// # let request = req.inner_mut();
|
||||||
|
/// # use std::net::{SocketAddrV4, Ipv4Addr};
|
||||||
///
|
///
|
||||||
/// // starting without an "X-Real-IP" header or remote address
|
/// // starting without an "X-Real-IP" header or remote address
|
||||||
/// assert!(request.client_ip().is_none());
|
/// assert!(request.client_ip().is_none());
|
||||||
///
|
///
|
||||||
/// // add a remote address; this is done by Rocket automatically
|
/// // add a remote address; this is done by Rocket automatically
|
||||||
/// request.set_remote("127.0.0.1:8000".parse().unwrap());
|
/// let localhost_9190 = SocketAddrV4::new(Ipv4Addr::LOCALHOST, 9190);
|
||||||
/// assert_eq!(request.client_ip(), Some("127.0.0.1".parse().unwrap()));
|
/// request.set_remote(localhost_9190);
|
||||||
|
/// assert_eq!(request.client_ip().unwrap(), Ipv4Addr::LOCALHOST);
|
||||||
///
|
///
|
||||||
/// // now with an X-Real-IP header, the default value for `ip_header`.
|
/// // now with an X-Real-IP header, the default value for `ip_header`.
|
||||||
/// request.add_header(Header::new("X-Real-IP", "8.8.8.8"));
|
/// request.add_header(Header::new("X-Real-IP", "8.8.8.8"));
|
||||||
/// assert_eq!(request.client_ip(), Some("8.8.8.8".parse().unwrap()));
|
/// assert_eq!(request.client_ip().unwrap(), Ipv4Addr::new(8, 8, 8, 8));
|
||||||
/// ```
|
/// ```
|
||||||
#[inline]
|
#[inline]
|
||||||
pub fn client_ip(&self) -> Option<IpAddr> {
|
pub fn client_ip(&self) -> Option<IpAddr> {
|
||||||
self.real_ip().or_else(|| self.remote().map(|r| r.ip()))
|
self.real_ip().or_else(|| Some(self.remote()?.tcp()?.ip()))
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns a wrapped borrow to the cookies in `self`.
|
/// Returns a wrapped borrow to the cookies in `self`.
|
||||||
|
@ -691,7 +691,7 @@ impl<'r> Request<'r> {
|
||||||
if self.method().supports_payload() {
|
if self.method().supports_payload() {
|
||||||
self.content_type().map(|ct| ct.media_type())
|
self.content_type().map(|ct| ct.media_type())
|
||||||
} else {
|
} else {
|
||||||
// FIXME: Should we be using `accept_first` or `preferred`? Or
|
// TODO: Should we be using `accept_first` or `preferred`? Or
|
||||||
// should we be checking neither and instead pass things through
|
// should we be checking neither and instead pass things through
|
||||||
// where the client accepts the thing at all?
|
// where the client accepts the thing at all?
|
||||||
self.accept()
|
self.accept()
|
||||||
|
@ -1056,11 +1056,9 @@ impl<'r> Request<'r> {
|
||||||
self.state.route.store(Some(route), Ordering::Release)
|
self.state.route.store(Some(route), Ordering::Release)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Set the method of `self`, even when `self` is a shared reference. Used
|
|
||||||
/// during routing to override methods for re-routing.
|
|
||||||
#[inline(always)]
|
#[inline(always)]
|
||||||
pub(crate) fn _set_method(&self, method: Method) {
|
pub(crate) fn _set_method(&self, method: Method) {
|
||||||
self.method.store(method, Ordering::Release)
|
self.method.store(method)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn cookies_mut(&mut self) -> &mut CookieJar<'r> {
|
pub(crate) fn cookies_mut(&mut self) -> &mut CookieJar<'r> {
|
||||||
|
@ -1070,18 +1068,28 @@ impl<'r> Request<'r> {
|
||||||
/// Convert from Hyper types into a Rocket Request.
|
/// Convert from Hyper types into a Rocket Request.
|
||||||
pub(crate) fn from_hyp(
|
pub(crate) fn from_hyp(
|
||||||
rocket: &'r Rocket<Orbit>,
|
rocket: &'r Rocket<Orbit>,
|
||||||
hyper: &'r hyper::request::Parts,
|
hyper: &'r hyper::http::request::Parts,
|
||||||
connection: Option<ConnectionMeta>,
|
connection: ConnectionMeta,
|
||||||
) -> Result<Request<'r>, BadRequest<'r>> {
|
) -> Result<Request<'r>, Request<'r>> {
|
||||||
// Keep track of parsing errors; emit a `BadRequest` if any exist.
|
// Keep track of parsing errors; emit a `BadRequest` if any exist.
|
||||||
let mut errors = vec![];
|
let mut errors = vec![];
|
||||||
|
|
||||||
// Ensure that the method is known. TODO: Allow made-up methods?
|
// Ensure that the method is known. TODO: Allow made-up methods?
|
||||||
let method = Method::from_hyp(&hyper.method)
|
let method = match hyper.method {
|
||||||
.unwrap_or_else(|| {
|
hyper::Method::GET => Method::Get,
|
||||||
errors.push(Kind::BadMethod(&hyper.method));
|
hyper::Method::PUT => Method::Put,
|
||||||
|
hyper::Method::POST => Method::Post,
|
||||||
|
hyper::Method::DELETE => Method::Delete,
|
||||||
|
hyper::Method::OPTIONS => Method::Options,
|
||||||
|
hyper::Method::HEAD => Method::Head,
|
||||||
|
hyper::Method::TRACE => Method::Trace,
|
||||||
|
hyper::Method::CONNECT => Method::Connect,
|
||||||
|
hyper::Method::PATCH => Method::Patch,
|
||||||
|
_ => {
|
||||||
|
errors.push(RequestError::BadMethod(hyper.method.clone()));
|
||||||
Method::Get
|
Method::Get
|
||||||
});
|
}
|
||||||
|
};
|
||||||
|
|
||||||
// TODO: Keep around not just the path/query, but the rest, if there?
|
// TODO: Keep around not just the path/query, but the rest, if there?
|
||||||
let uri = hyper.uri.path_and_query()
|
let uri = hyper.uri.path_and_query()
|
||||||
|
@ -1100,20 +1108,20 @@ impl<'r> Request<'r> {
|
||||||
Origin::new(uri.path(), uri.query().map(Cow::Borrowed))
|
Origin::new(uri.path(), uri.query().map(Cow::Borrowed))
|
||||||
})
|
})
|
||||||
.unwrap_or_else(|| {
|
.unwrap_or_else(|| {
|
||||||
errors.push(Kind::InvalidUri(&hyper.uri));
|
errors.push(RequestError::InvalidUri(hyper.uri.clone()));
|
||||||
Origin::ROOT
|
Origin::ROOT
|
||||||
});
|
});
|
||||||
|
|
||||||
// Construct the request object; fill in metadata and headers next.
|
// Construct the request object; fill in metadata and headers next.
|
||||||
let mut request = Request::new(rocket, method, uri);
|
let mut request = Request::new(rocket, method, uri);
|
||||||
|
request.errors = errors;
|
||||||
|
|
||||||
// Set the passed in connection metadata.
|
// Set the passed in connection metadata.
|
||||||
if let Some(connection) = connection {
|
request.connection = connection;
|
||||||
request.connection = connection;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Determine + set host. On HTTP < 2, use the `HOST` header. Otherwise,
|
// Determine + set host. On HTTP < 2, use the `HOST` header. Otherwise,
|
||||||
// use the `:authority` pseudo-header which hyper makes part of the URI.
|
// use the `:authority` pseudo-header which hyper makes part of the URI.
|
||||||
|
// TODO: Use an `InitCell` to compute this later.
|
||||||
request.state.host = if hyper.version < hyper::Version::HTTP_2 {
|
request.state.host = if hyper.version < hyper::Version::HTTP_2 {
|
||||||
hyper.headers.get("host").and_then(|h| Host::parse_bytes(h.as_bytes()).ok())
|
hyper.headers.get("host").and_then(|h| Host::parse_bytes(h.as_bytes()).ok())
|
||||||
} else {
|
} else {
|
||||||
|
@ -1122,9 +1130,8 @@ impl<'r> Request<'r> {
|
||||||
|
|
||||||
// Set the request cookies, if they exist.
|
// Set the request cookies, if they exist.
|
||||||
for header in hyper.headers.get_all("Cookie") {
|
for header in hyper.headers.get_all("Cookie") {
|
||||||
let raw_str = match std::str::from_utf8(header.as_bytes()) {
|
let Ok(raw_str) = std::str::from_utf8(header.as_bytes()) else {
|
||||||
Ok(string) => string,
|
continue
|
||||||
Err(_) => continue
|
|
||||||
};
|
};
|
||||||
|
|
||||||
for cookie_str in raw_str.split(';').map(|s| s.trim()) {
|
for cookie_str in raw_str.split(';').map(|s| s.trim()) {
|
||||||
|
@ -1137,43 +1144,33 @@ impl<'r> Request<'r> {
|
||||||
// Set the rest of the headers. This is rather unfortunate and slow.
|
// Set the rest of the headers. This is rather unfortunate and slow.
|
||||||
for (name, value) in hyper.headers.iter() {
|
for (name, value) in hyper.headers.iter() {
|
||||||
// FIXME: This is rather unfortunate. Header values needn't be UTF8.
|
// FIXME: This is rather unfortunate. Header values needn't be UTF8.
|
||||||
let value = match std::str::from_utf8(value.as_bytes()) {
|
let Ok(value) = std::str::from_utf8(value.as_bytes()) else {
|
||||||
Ok(value) => value,
|
warn!("Header '{}' contains invalid UTF-8", name);
|
||||||
Err(_) => {
|
warn_!("Rocket only supports UTF-8 header values. Dropping header.");
|
||||||
warn!("Header '{}' contains invalid UTF-8", name);
|
continue;
|
||||||
warn_!("Rocket only supports UTF-8 header values. Dropping header.");
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
};
|
};
|
||||||
|
|
||||||
request.add_header(Header::new(name.as_str(), value));
|
request.add_header(Header::new(name.as_str(), value));
|
||||||
}
|
}
|
||||||
|
|
||||||
if errors.is_empty() {
|
match request.errors.is_empty() {
|
||||||
Ok(request)
|
true => Ok(request),
|
||||||
} else {
|
false => Err(request),
|
||||||
Err(BadRequest { request, errors })
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug, Clone)]
|
||||||
pub(crate) struct BadRequest<'r> {
|
pub(crate) enum RequestError {
|
||||||
pub request: Request<'r>,
|
InvalidUri(hyper::Uri),
|
||||||
pub errors: Vec<Kind<'r>>,
|
BadMethod(hyper::Method),
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
impl fmt::Display for RequestError {
|
||||||
pub(crate) enum Kind<'r> {
|
|
||||||
InvalidUri(&'r hyper::Uri),
|
|
||||||
BadMethod(&'r hyper::Method),
|
|
||||||
}
|
|
||||||
|
|
||||||
impl fmt::Display for Kind<'_> {
|
|
||||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||||
match self {
|
match self {
|
||||||
Kind::InvalidUri(u) => write!(f, "invalid origin URI: {}", u),
|
RequestError::InvalidUri(u) => write!(f, "invalid origin URI: {}", u),
|
||||||
Kind::BadMethod(m) => write!(f, "invalid or unrecognized method: {}", m),
|
RequestError::BadMethod(m) => write!(f, "invalid or unrecognized method: {}", m),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1181,8 +1178,8 @@ impl fmt::Display for Kind<'_> {
|
||||||
impl fmt::Debug for Request<'_> {
|
impl fmt::Debug for Request<'_> {
|
||||||
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
|
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||||
fmt.debug_struct("Request")
|
fmt.debug_struct("Request")
|
||||||
.field("method", &self.method)
|
.field("method", &self.method())
|
||||||
.field("uri", &self.uri)
|
.field("uri", &self.uri())
|
||||||
.field("headers", &self.headers())
|
.field("headers", &self.headers())
|
||||||
.field("remote", &self.remote())
|
.field("remote", &self.remote())
|
||||||
.field("cookies", &self.cookies())
|
.field("cookies", &self.cookies())
|
||||||
|
|
|
@ -1,14 +1,16 @@
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
|
|
||||||
use crate::Request;
|
use crate::request::{Request, ConnectionMeta};
|
||||||
use crate::local::blocking::Client;
|
use crate::local::blocking::Client;
|
||||||
use crate::http::hyper;
|
|
||||||
|
|
||||||
macro_rules! assert_headers {
|
macro_rules! assert_headers {
|
||||||
($($key:expr => [$($value:expr),+]),+) => ({
|
($($key:expr => [$($value:expr),+]),+) => ({
|
||||||
// Create a new Hyper request. Add all of the passed in headers.
|
// Create a new Hyper request. Add all of the passed in headers.
|
||||||
let mut req = hyper::Request::get("/test").body(()).unwrap();
|
let mut req = hyper::Request::get("/test").body(()).unwrap();
|
||||||
$($(req.headers_mut().append($key, hyper::HeaderValue::from_str($value).unwrap());)+)+
|
$($(
|
||||||
|
req.headers_mut()
|
||||||
|
.append($key, hyper::header::HeaderValue::from_str($value).unwrap());
|
||||||
|
)+)+
|
||||||
|
|
||||||
// Build up what we expect the headers to actually be.
|
// Build up what we expect the headers to actually be.
|
||||||
let mut expected = HashMap::new();
|
let mut expected = HashMap::new();
|
||||||
|
@ -17,7 +19,8 @@ macro_rules! assert_headers {
|
||||||
// Create a valid `Rocket` and convert the hyper req to a Rocket one.
|
// Create a valid `Rocket` and convert the hyper req to a Rocket one.
|
||||||
let client = Client::debug_with(vec![]).unwrap();
|
let client = Client::debug_with(vec![]).unwrap();
|
||||||
let hyper = req.into_parts().0;
|
let hyper = req.into_parts().0;
|
||||||
let req = Request::from_hyp(client.rocket(), &hyper, None).unwrap();
|
let meta = ConnectionMeta::default();
|
||||||
|
let req = Request::from_hyp(client.rocket(), &hyper, meta).unwrap();
|
||||||
|
|
||||||
// Dispatch the request and check that the headers match.
|
// Dispatch the request and check that the headers match.
|
||||||
let actual_headers = req.headers();
|
let actual_headers = req.headers();
|
||||||
|
|
|
@ -1,7 +1,6 @@
|
||||||
use std::{fmt, str};
|
use std::{fmt, str};
|
||||||
use std::borrow::Cow;
|
use std::borrow::Cow;
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::pin::Pin;
|
|
||||||
|
|
||||||
use tokio::io::{AsyncRead, AsyncSeek};
|
use tokio::io::{AsyncRead, AsyncSeek};
|
||||||
|
|
||||||
|
@ -146,19 +145,18 @@ impl<'r> Builder<'r> {
|
||||||
/// potentially different values to be present in the `Response`.
|
/// potentially different values to be present in the `Response`.
|
||||||
///
|
///
|
||||||
/// The type of `header` can be any type that implements `Into<Header>`.
|
/// The type of `header` can be any type that implements `Into<Header>`.
|
||||||
/// This includes `Header` itself, [`ContentType`](crate::http::ContentType) and
|
/// This includes `Header` itself, [`ContentType`](crate::http::ContentType)
|
||||||
/// [hyper::header types](crate::http::hyper::header).
|
/// and [`Accept`](crate::http::Accept).
|
||||||
///
|
///
|
||||||
/// # Example
|
/// # Example
|
||||||
///
|
///
|
||||||
/// ```rust
|
/// ```rust
|
||||||
/// use rocket::Response;
|
/// use rocket::Response;
|
||||||
/// use rocket::http::Header;
|
/// use rocket::http::{Header, Accept};
|
||||||
/// use rocket::http::hyper::header::ACCEPT;
|
|
||||||
///
|
///
|
||||||
/// let response = Response::build()
|
/// let response = Response::build()
|
||||||
/// .header_adjoin(Header::new(ACCEPT.as_str(), "application/json"))
|
/// .header_adjoin(Header::new("Accept", "application/json"))
|
||||||
/// .header_adjoin(Header::new(ACCEPT.as_str(), "text/plain"))
|
/// .header_adjoin(Accept::XML)
|
||||||
/// .finalize();
|
/// .finalize();
|
||||||
///
|
///
|
||||||
/// assert_eq!(response.headers().get("Accept").count(), 2);
|
/// assert_eq!(response.headers().get("Accept").count(), 2);
|
||||||
|
@ -287,7 +285,7 @@ impl<'r> Builder<'r> {
|
||||||
///
|
///
|
||||||
/// #[rocket::async_trait]
|
/// #[rocket::async_trait]
|
||||||
/// impl IoHandler for EchoHandler {
|
/// impl IoHandler for EchoHandler {
|
||||||
/// async fn io(self: Pin<Box<Self>>, io: IoStream) -> io::Result<()> {
|
/// async fn io(self: Box<Self>, io: IoStream) -> io::Result<()> {
|
||||||
/// let (mut reader, mut writer) = io::split(io);
|
/// let (mut reader, mut writer) = io::split(io);
|
||||||
/// io::copy(&mut reader, &mut writer).await?;
|
/// io::copy(&mut reader, &mut writer).await?;
|
||||||
/// Ok(())
|
/// Ok(())
|
||||||
|
@ -488,7 +486,7 @@ pub struct Response<'r> {
|
||||||
status: Option<Status>,
|
status: Option<Status>,
|
||||||
headers: HeaderMap<'r>,
|
headers: HeaderMap<'r>,
|
||||||
body: Body<'r>,
|
body: Body<'r>,
|
||||||
upgrade: HashMap<Uncased<'r>, Pin<Box<dyn IoHandler + 'r>>>,
|
upgrade: HashMap<Uncased<'r>, Box<dyn IoHandler + 'r>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'r> Response<'r> {
|
impl<'r> Response<'r> {
|
||||||
|
@ -700,23 +698,22 @@ impl<'r> Response<'r> {
|
||||||
/// name `header.name`, another header with the same name and value
|
/// name `header.name`, another header with the same name and value
|
||||||
/// `header.value` is added. The type of `header` can be any type that
|
/// `header.value` is added. The type of `header` can be any type that
|
||||||
/// implements `Into<Header>`. This includes `Header` itself,
|
/// implements `Into<Header>`. This includes `Header` itself,
|
||||||
/// [`ContentType`](crate::http::ContentType) and [`hyper::header`
|
/// [`ContentType`](crate::http::ContentType),
|
||||||
/// types](crate::http::hyper::header).
|
/// [`Accept`](crate::http::Accept).
|
||||||
///
|
///
|
||||||
/// # Example
|
/// # Example
|
||||||
///
|
///
|
||||||
/// ```rust
|
/// ```rust
|
||||||
/// use rocket::Response;
|
/// use rocket::Response;
|
||||||
/// use rocket::http::Header;
|
/// use rocket::http::{Header, Accept};
|
||||||
/// use rocket::http::hyper::header::ACCEPT;
|
|
||||||
///
|
///
|
||||||
/// let mut response = Response::new();
|
/// let mut response = Response::new();
|
||||||
/// response.adjoin_header(Header::new(ACCEPT.as_str(), "application/json"));
|
/// response.adjoin_header(Accept::JSON);
|
||||||
/// response.adjoin_header(Header::new(ACCEPT.as_str(), "text/plain"));
|
/// response.adjoin_header(Header::new("Accept", "text/plain"));
|
||||||
///
|
///
|
||||||
/// let mut accept_headers = response.headers().iter();
|
/// let mut accept_headers = response.headers().iter();
|
||||||
/// assert_eq!(accept_headers.next(), Some(Header::new(ACCEPT.as_str(), "application/json")));
|
/// assert_eq!(accept_headers.next(), Some(Header::new("Accept", "application/json")));
|
||||||
/// assert_eq!(accept_headers.next(), Some(Header::new(ACCEPT.as_str(), "text/plain")));
|
/// assert_eq!(accept_headers.next(), Some(Header::new("Accept", "text/plain")));
|
||||||
/// assert_eq!(accept_headers.next(), None);
|
/// assert_eq!(accept_headers.next(), None);
|
||||||
/// ```
|
/// ```
|
||||||
#[inline(always)]
|
#[inline(always)]
|
||||||
|
@ -801,10 +798,10 @@ impl<'r> Response<'r> {
|
||||||
/// the comma-separated protocols any of the strings in `I`. Returns
|
/// the comma-separated protocols any of the strings in `I`. Returns
|
||||||
/// `Ok(None)` if `self` doesn't support any kind of upgrade. Returns
|
/// `Ok(None)` if `self` doesn't support any kind of upgrade. Returns
|
||||||
/// `Err(_)` if `protocols` is non-empty but no match was found in `self`.
|
/// `Err(_)` if `protocols` is non-empty but no match was found in `self`.
|
||||||
pub(crate) fn take_upgrade<I: Iterator<Item = &'r str>>(
|
pub(crate) fn search_upgrades<'a, I: Iterator<Item = &'a str>>(
|
||||||
&mut self,
|
&mut self,
|
||||||
protocols: I
|
protocols: I
|
||||||
) -> Result<Option<(Uncased<'r>, Pin<Box<dyn IoHandler + 'r>>)>, ()> {
|
) -> Result<Option<(Uncased<'r>, Box<dyn IoHandler + 'r>)>, ()> {
|
||||||
if self.upgrade.is_empty() {
|
if self.upgrade.is_empty() {
|
||||||
return Ok(None);
|
return Ok(None);
|
||||||
}
|
}
|
||||||
|
@ -839,7 +836,7 @@ impl<'r> Response<'r> {
|
||||||
///
|
///
|
||||||
/// #[rocket::async_trait]
|
/// #[rocket::async_trait]
|
||||||
/// impl IoHandler for EchoHandler {
|
/// impl IoHandler for EchoHandler {
|
||||||
/// async fn io(self: Pin<Box<Self>>, io: IoStream) -> io::Result<()> {
|
/// async fn io(self: Box<Self>, io: IoStream) -> io::Result<()> {
|
||||||
/// let (mut reader, mut writer) = io::split(io);
|
/// let (mut reader, mut writer) = io::split(io);
|
||||||
/// io::copy(&mut reader, &mut writer).await?;
|
/// io::copy(&mut reader, &mut writer).await?;
|
||||||
/// Ok(())
|
/// Ok(())
|
||||||
|
@ -854,7 +851,7 @@ impl<'r> Response<'r> {
|
||||||
/// assert!(response.upgrade("raw-echo").is_some());
|
/// assert!(response.upgrade("raw-echo").is_some());
|
||||||
/// # })
|
/// # })
|
||||||
/// ```
|
/// ```
|
||||||
pub fn upgrade(&mut self, proto: &str) -> Option<Pin<&mut (dyn IoHandler + 'r)>> {
|
pub fn upgrade(&mut self, proto: &str) -> Option<&mut (dyn IoHandler + 'r)> {
|
||||||
self.upgrade.get_mut(proto.as_uncased()).map(|h| h.as_mut())
|
self.upgrade.get_mut(proto.as_uncased()).map(|h| h.as_mut())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -972,7 +969,7 @@ impl<'r> Response<'r> {
|
||||||
///
|
///
|
||||||
/// #[rocket::async_trait]
|
/// #[rocket::async_trait]
|
||||||
/// impl IoHandler for EchoHandler {
|
/// impl IoHandler for EchoHandler {
|
||||||
/// async fn io(self: Pin<Box<Self>>, io: IoStream) -> io::Result<()> {
|
/// async fn io(self: Box<Self>, io: IoStream) -> io::Result<()> {
|
||||||
/// let (mut reader, mut writer) = io::split(io);
|
/// let (mut reader, mut writer) = io::split(io);
|
||||||
/// io::copy(&mut reader, &mut writer).await?;
|
/// io::copy(&mut reader, &mut writer).await?;
|
||||||
/// Ok(())
|
/// Ok(())
|
||||||
|
@ -990,7 +987,7 @@ impl<'r> Response<'r> {
|
||||||
pub fn add_upgrade<N, H>(&mut self, protocol: N, handler: H)
|
pub fn add_upgrade<N, H>(&mut self, protocol: N, handler: H)
|
||||||
where N: Into<Uncased<'r>>, H: IoHandler + 'r
|
where N: Into<Uncased<'r>>, H: IoHandler + 'r
|
||||||
{
|
{
|
||||||
self.upgrade.insert(protocol.into(), Box::pin(handler));
|
self.upgrade.insert(protocol.into(), Box::new(handler));
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Sets the body's maximum chunk size to `size` bytes.
|
/// Sets the body's maximum chunk size to `size` bytes.
|
||||||
|
|
|
@ -1,9 +1,9 @@
|
||||||
use std::borrow::Cow;
|
use std::borrow::Cow;
|
||||||
|
|
||||||
use tokio::io::AsyncRead;
|
use tokio::io::AsyncRead;
|
||||||
use tokio::time::Duration;
|
use tokio::time::{interval, Duration};
|
||||||
use futures::stream::{self, Stream, StreamExt};
|
use futures::{stream::{self, Stream}, future::Either};
|
||||||
use futures::future::ready;
|
use tokio_stream::{StreamExt, wrappers::IntervalStream};
|
||||||
|
|
||||||
use crate::request::Request;
|
use crate::request::Request;
|
||||||
use crate::response::{self, Response, Responder, stream::{ReaderStream, RawLinedEvent}};
|
use crate::response::{self, Response, Responder, stream::{ReaderStream, RawLinedEvent}};
|
||||||
|
@ -336,7 +336,7 @@ impl Event {
|
||||||
Some(RawLinedEvent::raw("")),
|
Some(RawLinedEvent::raw("")),
|
||||||
];
|
];
|
||||||
|
|
||||||
stream::iter(events).filter_map(ready)
|
stream::iter(events).filter_map(|x| x)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -528,25 +528,19 @@ impl<S: Stream<Item = Event>> EventStream<S> {
|
||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
fn heartbeat_stream(&self) -> Option<impl Stream<Item = RawLinedEvent>> {
|
fn heartbeat_stream(&self) -> impl Stream<Item = RawLinedEvent> {
|
||||||
use tokio::time::interval;
|
|
||||||
use tokio_stream::wrappers::IntervalStream;
|
|
||||||
|
|
||||||
self.heartbeat
|
self.heartbeat
|
||||||
.map(|beat| IntervalStream::new(interval(beat)))
|
.map(|beat| IntervalStream::new(interval(beat)))
|
||||||
.map(|stream| stream.map(|_| RawLinedEvent::raw(":")))
|
.map(|stream| stream.map(|_| RawLinedEvent::raw(":")))
|
||||||
|
.map_or_else(|| Either::Right(stream::empty()), Either::Left)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn into_stream(self) -> impl Stream<Item = RawLinedEvent> {
|
fn into_stream(self) -> impl Stream<Item = RawLinedEvent> {
|
||||||
use futures::future::Either;
|
use futures::StreamExt;
|
||||||
use crate::ext::StreamExt;
|
|
||||||
|
|
||||||
let heartbeat_stream = self.heartbeat_stream();
|
let heartbeats = self.heartbeat_stream();
|
||||||
let raw_events = self.stream.map(|e| e.into_stream()).flatten();
|
let events = StreamExt::map(self.stream, |e| e.into_stream()).flatten();
|
||||||
match heartbeat_stream {
|
crate::util::join(events, heartbeats)
|
||||||
Some(heartbeat) => Either::Left(raw_events.join(heartbeat)),
|
|
||||||
None => Either::Right(raw_events)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn into_reader(self) -> impl AsyncRead {
|
fn into_reader(self) -> impl AsyncRead {
|
||||||
|
@ -621,10 +615,11 @@ mod sse_tests {
|
||||||
|
|
||||||
impl<S: Stream<Item = Event>> EventStream<S> {
|
impl<S: Stream<Item = Event>> EventStream<S> {
|
||||||
fn into_string(self) -> String {
|
fn into_string(self) -> String {
|
||||||
|
use std::pin::pin;
|
||||||
|
|
||||||
crate::async_test(async move {
|
crate::async_test(async move {
|
||||||
let mut string = String::new();
|
let mut string = String::new();
|
||||||
let reader = self.into_reader();
|
let mut reader = pin!(self.into_reader());
|
||||||
tokio::pin!(reader);
|
|
||||||
reader.read_to_string(&mut string).await.expect("event stream -> string");
|
reader.read_to_string(&mut string).await.expect("event stream -> string");
|
||||||
string
|
string
|
||||||
})
|
})
|
||||||
|
|
|
@ -1,14 +1,14 @@
|
||||||
use std::fmt;
|
use std::fmt;
|
||||||
use std::ops::{Deref, DerefMut};
|
use std::ops::{Deref, DerefMut};
|
||||||
use std::net::SocketAddr;
|
|
||||||
|
|
||||||
use yansi::Paint;
|
use yansi::Paint;
|
||||||
use either::Either;
|
use either::Either;
|
||||||
use figment::{Figment, Provider};
|
use figment::{Figment, Provider};
|
||||||
|
|
||||||
use crate::{Catcher, Config, Route, Shutdown, sentinel, shield::Shield};
|
use crate::{Catcher, Config, Route, Shutdown, sentinel, shield::Shield};
|
||||||
|
use crate::listener::{Endpoint, Bindable, DefaultListener};
|
||||||
use crate::router::Router;
|
use crate::router::Router;
|
||||||
use crate::trip_wire::TripWire;
|
use crate::util::TripWire;
|
||||||
use crate::fairing::{Fairing, Fairings};
|
use crate::fairing::{Fairing, Fairings};
|
||||||
use crate::phase::{Phase, Build, Building, Ignite, Igniting, Orbit, Orbiting};
|
use crate::phase::{Phase, Build, Building, Ignite, Igniting, Orbit, Orbiting};
|
||||||
use crate::phase::{Stateful, StateRef, State};
|
use crate::phase::{Stateful, StateRef, State};
|
||||||
|
@ -203,35 +203,31 @@ impl Rocket<Build> {
|
||||||
/// # Example
|
/// # Example
|
||||||
///
|
///
|
||||||
/// ```rust
|
/// ```rust
|
||||||
/// use rocket::Config;
|
/// use rocket::config::{Config, Ident};
|
||||||
/// # use std::net::Ipv4Addr;
|
/// # use std::net::Ipv4Addr;
|
||||||
/// # use std::path::{Path, PathBuf};
|
/// # use std::path::{Path, PathBuf};
|
||||||
/// # type Result = std::result::Result<(), rocket::Error>;
|
/// # type Result = std::result::Result<(), rocket::Error>;
|
||||||
///
|
///
|
||||||
/// let config = Config {
|
/// let config = Config {
|
||||||
/// port: 7777,
|
/// ident: Ident::try_new("MyServer").expect("valid ident"),
|
||||||
/// address: Ipv4Addr::new(18, 127, 0, 1).into(),
|
|
||||||
/// temp_dir: "/tmp/config-example".into(),
|
/// temp_dir: "/tmp/config-example".into(),
|
||||||
/// ..Config::debug_default()
|
/// ..Config::debug_default()
|
||||||
/// };
|
/// };
|
||||||
///
|
///
|
||||||
/// # let _: Result = rocket::async_test(async move {
|
/// # let _: Result = rocket::async_test(async move {
|
||||||
/// let rocket = rocket::custom(&config).ignite().await?;
|
/// let rocket = rocket::custom(&config).ignite().await?;
|
||||||
/// assert_eq!(rocket.config().port, 7777);
|
/// assert_eq!(rocket.config().ident.as_str(), Some("MyServer"));
|
||||||
/// assert_eq!(rocket.config().address, Ipv4Addr::new(18, 127, 0, 1));
|
|
||||||
/// assert_eq!(rocket.config().temp_dir.relative(), Path::new("/tmp/config-example"));
|
/// assert_eq!(rocket.config().temp_dir.relative(), Path::new("/tmp/config-example"));
|
||||||
///
|
///
|
||||||
/// // Create a new figment which modifies _some_ keys the existing figment:
|
/// // Create a new figment which modifies _some_ keys the existing figment:
|
||||||
/// let figment = rocket.figment().clone()
|
/// let figment = rocket.figment().clone()
|
||||||
/// .merge((Config::PORT, 8888))
|
/// .merge((Config::IDENT, "Example"));
|
||||||
/// .merge((Config::ADDRESS, "171.64.200.10"));
|
|
||||||
///
|
///
|
||||||
/// let rocket = rocket::custom(&config)
|
/// let rocket = rocket::custom(&config)
|
||||||
/// .configure(figment)
|
/// .configure(figment)
|
||||||
/// .ignite().await?;
|
/// .ignite().await?;
|
||||||
///
|
///
|
||||||
/// assert_eq!(rocket.config().port, 8888);
|
/// assert_eq!(rocket.config().ident.as_str(), Some("Example"));
|
||||||
/// assert_eq!(rocket.config().address, Ipv4Addr::new(171, 64, 200, 10));
|
|
||||||
/// assert_eq!(rocket.config().temp_dir.relative(), Path::new("/tmp/config-example"));
|
/// assert_eq!(rocket.config().temp_dir.relative(), Path::new("/tmp/config-example"));
|
||||||
/// # Ok(())
|
/// # Ok(())
|
||||||
/// # });
|
/// # });
|
||||||
|
@ -664,8 +660,9 @@ impl Rocket<Ignite> {
|
||||||
self.shutdown.clone()
|
self.shutdown.clone()
|
||||||
}
|
}
|
||||||
|
|
||||||
fn into_orbit(self) -> Rocket<Orbit> {
|
pub(crate) fn into_orbit(self, address: Endpoint) -> Rocket<Orbit> {
|
||||||
Rocket(Orbiting {
|
Rocket(Orbiting {
|
||||||
|
endpoint: address,
|
||||||
router: self.0.router,
|
router: self.0.router,
|
||||||
fairings: self.0.fairings,
|
fairings: self.0.fairings,
|
||||||
figment: self.0.figment,
|
figment: self.0.figment,
|
||||||
|
@ -675,28 +672,24 @@ impl Rocket<Ignite> {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn _local_launch(self) -> Rocket<Orbit> {
|
async fn _local_launch(self, addr: Endpoint) -> Rocket<Orbit> {
|
||||||
let rocket = self.into_orbit();
|
let rocket = self.into_orbit(addr);
|
||||||
rocket.fairings.handle_liftoff(&rocket).await;
|
Rocket::liftoff(&rocket).await;
|
||||||
launch_info!("{}{}", "🚀 ".emoji(), "Rocket has launched locally".primary().bold());
|
|
||||||
rocket
|
rocket
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn _launch(self) -> Result<Rocket<Ignite>, Error> {
|
async fn _launch(self) -> Result<Rocket<Ignite>, Error> {
|
||||||
self.into_orbit()
|
let config = self.figment().extract::<DefaultListener>()?;
|
||||||
.default_tcp_http_server(|rkt| Box::pin(async move {
|
either::for_both!(config.base_bindable()?, base => {
|
||||||
rkt.fairings.handle_liftoff(&rkt).await;
|
either::for_both!(config.tls_bindable(base), bindable => {
|
||||||
|
self._launch_on(bindable).await
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
let proto = rkt.config.tls_enabled().then(|| "https").unwrap_or("http");
|
async fn _launch_on<B: Bindable>(self, bindable: B) -> Result<Rocket<Ignite>, Error> {
|
||||||
let socket_addr = SocketAddr::new(rkt.config.address, rkt.config.port);
|
let listener = bindable.bind().await.map_err(|e| ErrorKind::Bind(Box::new(e)))?;
|
||||||
let addr = format!("{}://{}", proto, socket_addr);
|
self.serve(listener).await
|
||||||
launch_info!("{}{} {}",
|
|
||||||
"🚀 ".emoji(),
|
|
||||||
"Rocket has launched from".bold().primary().linger(),
|
|
||||||
addr.underline());
|
|
||||||
}))
|
|
||||||
.await
|
|
||||||
.map(|rocket| rocket.into_ignite())
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -712,6 +705,21 @@ impl Rocket<Orbit> {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub(crate) async fn liftoff<R: Deref<Target = Self>>(rocket: R) {
|
||||||
|
let rocket = rocket.deref();
|
||||||
|
rocket.fairings.handle_liftoff(rocket).await;
|
||||||
|
|
||||||
|
if !crate::running_within_rocket_async_rt().await {
|
||||||
|
warn!("Rocket is executing inside of a custom runtime.");
|
||||||
|
info_!("Rocket's runtime is enabled via `#[rocket::main]` or `#[launch]`.");
|
||||||
|
info_!("Forced shutdown is disabled. Runtime settings may be suboptimal.");
|
||||||
|
}
|
||||||
|
|
||||||
|
launch_info!("{}{} {}", "🚀 ".emoji(),
|
||||||
|
"Rocket has launched on".bold().primary().linger(),
|
||||||
|
rocket.endpoint().underline());
|
||||||
|
}
|
||||||
|
|
||||||
/// Returns the finalized, active configuration. This is guaranteed to
|
/// Returns the finalized, active configuration. This is guaranteed to
|
||||||
/// remain stable after [`Rocket::ignite()`], through ignition and into
|
/// remain stable after [`Rocket::ignite()`], through ignition and into
|
||||||
/// orbit.
|
/// orbit.
|
||||||
|
@ -734,6 +742,10 @@ impl Rocket<Orbit> {
|
||||||
&self.config
|
&self.config
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn endpoint(&self) -> &Endpoint {
|
||||||
|
&self.endpoint
|
||||||
|
}
|
||||||
|
|
||||||
/// Returns a handle which can be used to trigger a shutdown and detect a
|
/// Returns a handle which can be used to trigger a shutdown and detect a
|
||||||
/// triggered shutdown.
|
/// triggered shutdown.
|
||||||
///
|
///
|
||||||
|
@ -867,10 +879,10 @@ impl<P: Phase> Rocket<P> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) async fn local_launch(self) -> Result<Rocket<Orbit>, Error> {
|
pub(crate) async fn local_launch(self, l: Endpoint) -> Result<Rocket<Orbit>, Error> {
|
||||||
let rocket = match self.0.into_state() {
|
let rocket = match self.0.into_state() {
|
||||||
State::Build(s) => Rocket::from(s).ignite().await?._local_launch().await,
|
State::Build(s) => Rocket::from(s).ignite().await?._local_launch(l).await,
|
||||||
State::Ignite(s) => Rocket::from(s)._local_launch().await,
|
State::Ignite(s) => Rocket::from(s)._local_launch(l).await,
|
||||||
State::Orbit(s) => Rocket::from(s)
|
State::Orbit(s) => Rocket::from(s)
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -928,6 +940,14 @@ impl<P: Phase> Rocket<P> {
|
||||||
State::Orbit(s) => Ok(Rocket::from(s).into_ignite())
|
State::Orbit(s) => Ok(Rocket::from(s).into_ignite())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub async fn launch_on<B: Bindable>(self, bindable: B) -> Result<Rocket<Ignite>, Error> {
|
||||||
|
match self.0.into_state() {
|
||||||
|
State::Build(s) => Rocket::from(s).ignite().await?._launch_on(bindable).await,
|
||||||
|
State::Ignite(s) => Rocket::from(s)._launch_on(bindable).await,
|
||||||
|
State::Orbit(s) => Ok(Rocket::from(s).into_ignite())
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[doc(hidden)]
|
#[doc(hidden)]
|
||||||
|
|
|
@ -167,7 +167,6 @@ impl<F: Clone + Sync + Send + 'static> Handler for F
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// FIXME!
|
|
||||||
impl<'r, 'o: 'r> Outcome<'o> {
|
impl<'r, 'o: 'r> Outcome<'o> {
|
||||||
/// Return the `Outcome` of response to `req` from `responder`.
|
/// Return the `Outcome` of response to `req` from `responder`.
|
||||||
///
|
///
|
||||||
|
|
|
@ -1,540 +1,142 @@
|
||||||
use std::io;
|
use std::io;
|
||||||
|
use std::pin::pin;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use std::time::Duration;
|
use std::time::Duration;
|
||||||
use std::pin::Pin;
|
|
||||||
|
|
||||||
use yansi::Paint;
|
use hyper::service::service_fn;
|
||||||
use tokio::sync::oneshot;
|
use hyper_util::rt::{TokioExecutor, TokioIo, TokioTimer};
|
||||||
|
use hyper_util::server::conn::auto::Builder;
|
||||||
|
use futures::{Future, TryFutureExt, future::{select, Either::*}};
|
||||||
use tokio::time::sleep;
|
use tokio::time::sleep;
|
||||||
use futures::stream::StreamExt;
|
|
||||||
use futures::future::{FutureExt, Future, BoxFuture};
|
|
||||||
|
|
||||||
use crate::{route, Rocket, Orbit, Request, Response, Data, Config};
|
use crate::{Request, Rocket, Orbit, Data, Ignite};
|
||||||
use crate::form::Form;
|
|
||||||
use crate::outcome::Outcome;
|
|
||||||
use crate::error::{Error, ErrorKind};
|
|
||||||
use crate::ext::{AsyncReadExt, CancellableListener, CancellableIo};
|
|
||||||
use crate::request::ConnectionMeta;
|
use crate::request::ConnectionMeta;
|
||||||
use crate::data::IoHandler;
|
use crate::erased::{ErasedRequest, ErasedResponse, ErasedIoHandler};
|
||||||
|
use crate::listener::{Listener, CancellableExt, BouncedExt};
|
||||||
use crate::http::{hyper, uncased, Method, Status, Header};
|
use crate::error::{Error, ErrorKind};
|
||||||
use crate::http::private::{TcpListener, Listener, Connection, Incoming};
|
use crate::data::IoStream;
|
||||||
|
use crate::util::ReaderStream;
|
||||||
// A token returned to force the execution of one method before another.
|
use crate::http::Status;
|
||||||
pub(crate) struct RequestToken;
|
|
||||||
|
|
||||||
async fn handle<Fut, T, F>(name: Option<&str>, run: F) -> Option<T>
|
|
||||||
where F: FnOnce() -> Fut, Fut: Future<Output = T>,
|
|
||||||
{
|
|
||||||
use std::panic::AssertUnwindSafe;
|
|
||||||
|
|
||||||
macro_rules! panic_info {
|
|
||||||
($name:expr, $e:expr) => {{
|
|
||||||
match $name {
|
|
||||||
Some(name) => error_!("Handler {} panicked.", name.primary()),
|
|
||||||
None => error_!("A handler panicked.")
|
|
||||||
};
|
|
||||||
|
|
||||||
info_!("This is an application bug.");
|
|
||||||
info_!("A panic in Rust must be treated as an exceptional event.");
|
|
||||||
info_!("Panicking is not a suitable error handling mechanism.");
|
|
||||||
info_!("Unwinding, the result of a panic, is an expensive operation.");
|
|
||||||
info_!("Panics will degrade application performance.");
|
|
||||||
info_!("Instead of panicking, return `Option` and/or `Result`.");
|
|
||||||
info_!("Values of either type can be returned directly from handlers.");
|
|
||||||
warn_!("A panic is treated as an internal server error.");
|
|
||||||
$e
|
|
||||||
}}
|
|
||||||
}
|
|
||||||
|
|
||||||
let run = AssertUnwindSafe(run);
|
|
||||||
let fut = std::panic::catch_unwind(move || run())
|
|
||||||
.map_err(|e| panic_info!(name, e))
|
|
||||||
.ok()?;
|
|
||||||
|
|
||||||
AssertUnwindSafe(fut)
|
|
||||||
.catch_unwind()
|
|
||||||
.await
|
|
||||||
.map_err(|e| panic_info!(name, e))
|
|
||||||
.ok()
|
|
||||||
}
|
|
||||||
|
|
||||||
// This function tries to hide all of the Hyper-ness from Rocket. It essentially
|
|
||||||
// converts Hyper types into Rocket types, then calls the `dispatch` function,
|
|
||||||
// which knows nothing about Hyper. Because responding depends on the
|
|
||||||
// `HyperResponse` type, this function does the actual response processing.
|
|
||||||
async fn hyper_service_fn(
|
|
||||||
rocket: Arc<Rocket<Orbit>>,
|
|
||||||
conn: ConnectionMeta,
|
|
||||||
mut hyp_req: hyper::Request<hyper::Body>,
|
|
||||||
) -> Result<hyper::Response<hyper::Body>, io::Error> {
|
|
||||||
// This future must return a hyper::Response, but the response body might
|
|
||||||
// borrow from the request. Instead, write the body in another future that
|
|
||||||
// sends the response metadata (and a body channel) prior.
|
|
||||||
let (tx, rx) = oneshot::channel();
|
|
||||||
|
|
||||||
#[cfg(not(broken_fmt))]
|
|
||||||
debug!("received request: {:#?}", hyp_req);
|
|
||||||
|
|
||||||
tokio::spawn(async move {
|
|
||||||
// We move the request next, so get the upgrade future now.
|
|
||||||
let pending_upgrade = hyper::upgrade::on(&mut hyp_req);
|
|
||||||
|
|
||||||
// Convert a Hyper request into a Rocket request.
|
|
||||||
let (h_parts, mut h_body) = hyp_req.into_parts();
|
|
||||||
match Request::from_hyp(&rocket, &h_parts, Some(conn)) {
|
|
||||||
Ok(mut req) => {
|
|
||||||
// Convert into Rocket `Data`, dispatch request, write response.
|
|
||||||
let mut data = Data::from(&mut h_body);
|
|
||||||
let token = rocket.preprocess_request(&mut req, &mut data).await;
|
|
||||||
let mut response = rocket.dispatch(token, &req, data).await;
|
|
||||||
let upgrade = response.take_upgrade(req.headers().get("upgrade"));
|
|
||||||
if let Ok(Some((proto, handler))) = upgrade {
|
|
||||||
rocket.handle_upgrade(response, proto, handler, pending_upgrade, tx).await;
|
|
||||||
} else {
|
|
||||||
if upgrade.is_err() {
|
|
||||||
warn_!("Request wants upgrade but no I/O handler matched.");
|
|
||||||
info_!("Request is not being upgraded.");
|
|
||||||
}
|
|
||||||
|
|
||||||
rocket.send_response(response, tx).await;
|
|
||||||
}
|
|
||||||
},
|
|
||||||
Err(e) => {
|
|
||||||
warn!("Bad incoming HTTP request.");
|
|
||||||
e.errors.iter().for_each(|e| warn_!("Error: {}.", e));
|
|
||||||
warn_!("Dispatching salvaged request to catcher: {}.", e.request);
|
|
||||||
|
|
||||||
let response = rocket.handle_error(Status::BadRequest, &e.request).await;
|
|
||||||
rocket.send_response(response, tx).await;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
// Receive the response written to `tx` by the task above.
|
|
||||||
rx.await.map_err(|e| io::Error::new(io::ErrorKind::BrokenPipe, e))
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Rocket<Orbit> {
|
impl Rocket<Orbit> {
|
||||||
/// Wrapper around `_send_response` to log a success or error.
|
async fn service(
|
||||||
#[inline]
|
self: Arc<Self>,
|
||||||
async fn send_response(
|
mut req: hyper::Request<hyper::body::Incoming>,
|
||||||
&self,
|
connection: ConnectionMeta,
|
||||||
response: Response<'_>,
|
) -> Result<hyper::Response<ReaderStream<ErasedResponse>>, http::Error> {
|
||||||
tx: oneshot::Sender<hyper::Response<hyper::Body>>,
|
let upgrade = hyper::upgrade::on(&mut req);
|
||||||
) {
|
let (parts, incoming) = req.into_parts();
|
||||||
let remote_hungup = |e: &io::Error| match e.kind() {
|
let request = ErasedRequest::new(self, parts, |rocket, parts| {
|
||||||
| io::ErrorKind::BrokenPipe
|
Request::from_hyp(rocket, parts, connection).unwrap_or_else(|e| e)
|
||||||
| io::ErrorKind::ConnectionReset
|
|
||||||
| io::ErrorKind::ConnectionAborted => true,
|
|
||||||
_ => false,
|
|
||||||
};
|
|
||||||
|
|
||||||
match self._send_response(response, tx).await {
|
|
||||||
Ok(()) => info_!("{}", "Response succeeded.".green()),
|
|
||||||
Err(e) if remote_hungup(&e) => warn_!("Remote left: {}.", e),
|
|
||||||
Err(e) => warn_!("Failed to write response: {}.", e),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Attempts to create a hyper response from `response` and send it to `tx`.
|
|
||||||
#[inline]
|
|
||||||
async fn _send_response(
|
|
||||||
&self,
|
|
||||||
mut response: Response<'_>,
|
|
||||||
tx: oneshot::Sender<hyper::Response<hyper::Body>>,
|
|
||||||
) -> io::Result<()> {
|
|
||||||
let mut hyp_res = hyper::Response::builder();
|
|
||||||
|
|
||||||
hyp_res = hyp_res.status(response.status().code);
|
|
||||||
for header in response.headers().iter() {
|
|
||||||
let name = header.name.as_str();
|
|
||||||
let value = header.value.as_bytes();
|
|
||||||
hyp_res = hyp_res.header(name, value);
|
|
||||||
}
|
|
||||||
|
|
||||||
let body = response.body_mut();
|
|
||||||
if let Some(n) = body.size().await {
|
|
||||||
hyp_res = hyp_res.header(hyper::header::CONTENT_LENGTH, n);
|
|
||||||
}
|
|
||||||
|
|
||||||
let (mut sender, hyp_body) = hyper::Body::channel();
|
|
||||||
let hyp_response = hyp_res.body(hyp_body)
|
|
||||||
.map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
|
|
||||||
|
|
||||||
#[cfg(not(broken_fmt))]
|
|
||||||
debug!("sending response: {:#?}", hyp_response);
|
|
||||||
|
|
||||||
tx.send(hyp_response).map_err(|_| {
|
|
||||||
let msg = "client disconnect before response started";
|
|
||||||
io::Error::new(io::ErrorKind::BrokenPipe, msg)
|
|
||||||
})?;
|
|
||||||
|
|
||||||
let max_chunk_size = body.max_chunk_size();
|
|
||||||
let mut stream = body.into_bytes_stream(max_chunk_size);
|
|
||||||
while let Some(next) = stream.next().await {
|
|
||||||
sender.send_data(next?).await
|
|
||||||
.map_err(|e| io::Error::new(io::ErrorKind::BrokenPipe, e))?;
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn handle_upgrade<'r>(
|
|
||||||
&self,
|
|
||||||
mut response: Response<'r>,
|
|
||||||
proto: uncased::Uncased<'r>,
|
|
||||||
io_handler: Pin<Box<dyn IoHandler + 'r>>,
|
|
||||||
pending_upgrade: hyper::upgrade::OnUpgrade,
|
|
||||||
tx: oneshot::Sender<hyper::Response<hyper::Body>>,
|
|
||||||
) {
|
|
||||||
info_!("Upgrading connection to {}.", Paint::white(&proto).bold());
|
|
||||||
response.set_status(Status::SwitchingProtocols);
|
|
||||||
response.set_raw_header("Connection", "Upgrade");
|
|
||||||
response.set_raw_header("Upgrade", proto.clone().into_cow());
|
|
||||||
self.send_response(response, tx).await;
|
|
||||||
|
|
||||||
match pending_upgrade.await {
|
|
||||||
Ok(io_stream) => {
|
|
||||||
info_!("Upgrade successful.");
|
|
||||||
if let Err(e) = io_handler.io(io_stream.into()).await {
|
|
||||||
if e.kind() == io::ErrorKind::BrokenPipe {
|
|
||||||
warn!("Upgraded {} I/O handler was closed.", proto);
|
|
||||||
} else {
|
|
||||||
error!("Upgraded {} I/O handler failed: {}", proto, e);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
},
|
|
||||||
Err(e) => {
|
|
||||||
warn!("Response indicated upgrade, but upgrade failed.");
|
|
||||||
warn_!("Upgrade error: {}", e);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Preprocess the request for Rocket things. Currently, this means:
|
|
||||||
///
|
|
||||||
/// * Rewriting the method in the request if _method form field exists.
|
|
||||||
/// * Run the request fairings.
|
|
||||||
///
|
|
||||||
/// Keep this in-sync with derive_form when preprocessing form fields.
|
|
||||||
pub(crate) async fn preprocess_request(
|
|
||||||
&self,
|
|
||||||
req: &mut Request<'_>,
|
|
||||||
data: &mut Data<'_>
|
|
||||||
) -> RequestToken {
|
|
||||||
// Check if this is a form and if the form contains the special _method
|
|
||||||
// field which we use to reinterpret the request's method.
|
|
||||||
let (min_len, max_len) = ("_method=get".len(), "_method=delete".len());
|
|
||||||
let peek_buffer = data.peek(max_len).await;
|
|
||||||
let is_form = req.content_type().map_or(false, |ct| ct.is_form());
|
|
||||||
|
|
||||||
if is_form && req.method() == Method::Post && peek_buffer.len() >= min_len {
|
|
||||||
let method = std::str::from_utf8(peek_buffer).ok()
|
|
||||||
.and_then(|raw_form| Form::values(raw_form).next())
|
|
||||||
.filter(|field| field.name == "_method")
|
|
||||||
.and_then(|field| field.value.parse().ok());
|
|
||||||
|
|
||||||
if let Some(method) = method {
|
|
||||||
req._set_method(method);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Run request fairings.
|
|
||||||
self.fairings.handle_request(req, data).await;
|
|
||||||
|
|
||||||
RequestToken
|
|
||||||
}
|
|
||||||
|
|
||||||
#[inline]
|
|
||||||
pub(crate) async fn dispatch<'s, 'r: 's>(
|
|
||||||
&'s self,
|
|
||||||
_token: RequestToken,
|
|
||||||
request: &'r Request<'s>,
|
|
||||||
data: Data<'r>
|
|
||||||
) -> Response<'r> {
|
|
||||||
info!("{}:", request);
|
|
||||||
|
|
||||||
// Remember if the request is `HEAD` for later body stripping.
|
|
||||||
let was_head_request = request.method() == Method::Head;
|
|
||||||
|
|
||||||
// Route the request and run the user's handlers.
|
|
||||||
let mut response = self.route_and_process(request, data).await;
|
|
||||||
|
|
||||||
// Add a default 'Server' header if it isn't already there.
|
|
||||||
// TODO: If removing Hyper, write out `Date` header too.
|
|
||||||
if let Some(ident) = request.rocket().config.ident.as_str() {
|
|
||||||
if !response.headers().contains("Server") {
|
|
||||||
response.set_header(Header::new("Server", ident));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Run the response fairings.
|
|
||||||
self.fairings.handle_response(request, &mut response).await;
|
|
||||||
|
|
||||||
// Strip the body if this is a `HEAD` request.
|
|
||||||
if was_head_request {
|
|
||||||
response.strip_body();
|
|
||||||
}
|
|
||||||
|
|
||||||
response
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn route_and_process<'s, 'r: 's>(
|
|
||||||
&'s self,
|
|
||||||
request: &'r Request<'s>,
|
|
||||||
data: Data<'r>
|
|
||||||
) -> Response<'r> {
|
|
||||||
let mut response = match self.route(request, data).await {
|
|
||||||
Outcome::Success(response) => response,
|
|
||||||
Outcome::Forward((data, _)) if request.method() == Method::Head => {
|
|
||||||
info_!("Autohandling {} request.", "HEAD".primary().bold());
|
|
||||||
|
|
||||||
// Dispatch the request again with Method `GET`.
|
|
||||||
request._set_method(Method::Get);
|
|
||||||
match self.route(request, data).await {
|
|
||||||
Outcome::Success(response) => response,
|
|
||||||
Outcome::Error(status) => self.handle_error(status, request).await,
|
|
||||||
Outcome::Forward((_, status)) => self.handle_error(status, request).await,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Outcome::Forward((_, status)) => self.handle_error(status, request).await,
|
|
||||||
Outcome::Error(status) => self.handle_error(status, request).await,
|
|
||||||
};
|
|
||||||
|
|
||||||
// Set the cookies. Note that error responses will only include cookies
|
|
||||||
// set by the error handler. See `handle_error` for more.
|
|
||||||
let delta_jar = request.cookies().take_delta_jar();
|
|
||||||
for cookie in delta_jar.delta() {
|
|
||||||
response.adjoin_header(cookie);
|
|
||||||
}
|
|
||||||
|
|
||||||
response
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Tries to find a `Responder` for a given `request`. It does this by
|
|
||||||
/// routing the request and calling the handler for each matching route
|
|
||||||
/// until one of the handlers returns success or error, or there are no
|
|
||||||
/// additional routes to try (forward). The corresponding outcome for each
|
|
||||||
/// condition is returned.
|
|
||||||
#[inline]
|
|
||||||
async fn route<'s, 'r: 's>(
|
|
||||||
&'s self,
|
|
||||||
request: &'r Request<'s>,
|
|
||||||
mut data: Data<'r>,
|
|
||||||
) -> route::Outcome<'r> {
|
|
||||||
// Go through all matching routes until we fail or succeed or run out of
|
|
||||||
// routes to try, in which case we forward with the last status.
|
|
||||||
let mut status = Status::NotFound;
|
|
||||||
for route in self.router.route(request) {
|
|
||||||
// Retrieve and set the requests parameters.
|
|
||||||
info_!("Matched: {}", route);
|
|
||||||
request.set_route(route);
|
|
||||||
|
|
||||||
let name = route.name.as_deref();
|
|
||||||
let outcome = handle(name, || route.handler.handle(request, data)).await
|
|
||||||
.unwrap_or(Outcome::Error(Status::InternalServerError));
|
|
||||||
|
|
||||||
// Check if the request processing completed (Some) or if the
|
|
||||||
// request needs to be forwarded. If it does, continue the loop
|
|
||||||
// (None) to try again.
|
|
||||||
info_!("{}", outcome.log_display());
|
|
||||||
match outcome {
|
|
||||||
o@Outcome::Success(_) | o@Outcome::Error(_) => return o,
|
|
||||||
Outcome::Forward(forwarded) => (data, status) = forwarded,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
error_!("No matching routes for {}.", request);
|
|
||||||
Outcome::Forward((data, status))
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Invokes the handler with `req` for catcher with status `status`.
|
|
||||||
///
|
|
||||||
/// In order of preference, invoked handler is:
|
|
||||||
/// * the user's registered handler for `status`
|
|
||||||
/// * the user's registered `default` handler
|
|
||||||
/// * Rocket's default handler for `status`
|
|
||||||
///
|
|
||||||
/// Return `Ok(result)` if the handler succeeded. Returns `Ok(Some(Status))`
|
|
||||||
/// if the handler ran to completion but failed. Returns `Ok(None)` if the
|
|
||||||
/// handler panicked while executing.
|
|
||||||
async fn invoke_catcher<'s, 'r: 's>(
|
|
||||||
&'s self,
|
|
||||||
status: Status,
|
|
||||||
req: &'r Request<'s>
|
|
||||||
) -> Result<Response<'r>, Option<Status>> {
|
|
||||||
// For now, we reset the delta state to prevent any modifications
|
|
||||||
// from earlier, unsuccessful paths from being reflected in error
|
|
||||||
// response. We may wish to relax this in the future.
|
|
||||||
req.cookies().reset_delta();
|
|
||||||
|
|
||||||
if let Some(catcher) = self.router.catch(status, req) {
|
|
||||||
warn_!("Responding with registered {} catcher.", catcher);
|
|
||||||
let name = catcher.name.as_deref();
|
|
||||||
handle(name, || catcher.handler.handle(status, req)).await
|
|
||||||
.map(|result| result.map_err(Some))
|
|
||||||
.unwrap_or_else(|| Err(None))
|
|
||||||
} else {
|
|
||||||
let code = status.code.blue().bold();
|
|
||||||
warn_!("No {} catcher registered. Using Rocket default.", code);
|
|
||||||
Ok(crate::catcher::default_handler(status, req))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Invokes the catcher for `status`. Returns the response on success.
|
|
||||||
//
|
|
||||||
// On catcher error, the 500 error catcher is attempted. If _that_ errors,
|
|
||||||
// the (infallible) default 500 error cather is used.
|
|
||||||
pub(crate) async fn handle_error<'s, 'r: 's>(
|
|
||||||
&'s self,
|
|
||||||
mut status: Status,
|
|
||||||
req: &'r Request<'s>
|
|
||||||
) -> Response<'r> {
|
|
||||||
// Dispatch to the `status` catcher.
|
|
||||||
if let Ok(r) = self.invoke_catcher(status, req).await {
|
|
||||||
return r;
|
|
||||||
}
|
|
||||||
|
|
||||||
// If it fails and it's not a 500, try the 500 catcher.
|
|
||||||
if status != Status::InternalServerError {
|
|
||||||
error_!("Catcher failed. Attempting 500 error catcher.");
|
|
||||||
status = Status::InternalServerError;
|
|
||||||
if let Ok(r) = self.invoke_catcher(status, req).await {
|
|
||||||
return r;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// If it failed again or if it was already a 500, use Rocket's default.
|
|
||||||
error_!("{} catcher failed. Using Rocket default 500.", status.code);
|
|
||||||
crate::catcher::default_handler(Status::InternalServerError, req)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub(crate) async fn default_tcp_http_server<C>(mut self, ready: C) -> Result<Self, Error>
|
|
||||||
where C: for<'a> Fn(&'a Self) -> BoxFuture<'a, ()>
|
|
||||||
{
|
|
||||||
use std::net::ToSocketAddrs;
|
|
||||||
|
|
||||||
// Determine the address we're going to serve on.
|
|
||||||
let addr = format!("{}:{}", self.config.address, self.config.port);
|
|
||||||
let mut addr = addr.to_socket_addrs()
|
|
||||||
.map(|mut addrs| addrs.next().expect(">= 1 socket addr"))
|
|
||||||
.map_err(|e| Error::new(ErrorKind::Io(e)))?;
|
|
||||||
|
|
||||||
#[cfg(feature = "tls")]
|
|
||||||
if self.config.tls_enabled() {
|
|
||||||
if let Some(ref config) = self.config.tls {
|
|
||||||
use crate::http::tls::TlsListener;
|
|
||||||
|
|
||||||
let conf = config.to_native_config().map_err(ErrorKind::Io)?;
|
|
||||||
let l = TlsListener::bind(addr, conf).await.map_err(ErrorKind::TlsBind)?;
|
|
||||||
addr = l.local_addr().unwrap_or(addr);
|
|
||||||
self.config.address = addr.ip();
|
|
||||||
self.config.port = addr.port();
|
|
||||||
ready(&mut self).await;
|
|
||||||
return self.http_server(l).await;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
let l = TcpListener::bind(addr).await.map_err(ErrorKind::Bind)?;
|
|
||||||
addr = l.local_addr().unwrap_or(addr);
|
|
||||||
self.config.address = addr.ip();
|
|
||||||
self.config.port = addr.port();
|
|
||||||
ready(&mut self).await;
|
|
||||||
self.http_server(l).await
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO.async: Solidify the Listener APIs and make this function public
|
|
||||||
pub(crate) async fn http_server<L>(self, listener: L) -> Result<Self, Error>
|
|
||||||
where L: Listener + Send, <L as Listener>::Connection: Send + Unpin + 'static
|
|
||||||
{
|
|
||||||
// Emit a warning if we're not running inside of Rocket's async runtime.
|
|
||||||
if self.config.profile == Config::DEBUG_PROFILE {
|
|
||||||
tokio::task::spawn_blocking(|| {
|
|
||||||
let this = std::thread::current();
|
|
||||||
if !this.name().map_or(false, |s| s.starts_with("rocket-worker")) {
|
|
||||||
warn!("Rocket is executing inside of a custom runtime.");
|
|
||||||
info_!("Rocket's runtime is enabled via `#[rocket::main]` or `#[launch]`.");
|
|
||||||
info_!("Forced shutdown is disabled. Runtime settings may be suboptimal.");
|
|
||||||
}
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
// Set up cancellable I/O from the given listener. Shutdown occurs when
|
|
||||||
// `Shutdown` (`TripWire`) resolves. This can occur directly through a
|
|
||||||
// notification or indirectly through an external signal which, when
|
|
||||||
// received, results in triggering the notify.
|
|
||||||
let shutdown = self.shutdown();
|
|
||||||
let sig_stream = self.config.shutdown.signal_stream();
|
|
||||||
let grace = self.config.shutdown.grace as u64;
|
|
||||||
let mercy = self.config.shutdown.mercy as u64;
|
|
||||||
|
|
||||||
// Start a task that listens for external signals and notifies shutdown.
|
|
||||||
if let Some(mut stream) = sig_stream {
|
|
||||||
let shutdown = shutdown.clone();
|
|
||||||
tokio::spawn(async move {
|
|
||||||
while let Some(sig) = stream.next().await {
|
|
||||||
if shutdown.0.tripped() {
|
|
||||||
warn!("Received {}. Shutdown already in progress.", sig);
|
|
||||||
} else {
|
|
||||||
warn!("Received {}. Requesting shutdown.", sig);
|
|
||||||
}
|
|
||||||
|
|
||||||
shutdown.0.trip();
|
|
||||||
}
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
// Save the keep-alive value for later use; we're about to move `self`.
|
|
||||||
let keep_alive = self.config.keep_alive;
|
|
||||||
|
|
||||||
// Create the Hyper `Service`.
|
|
||||||
let rocket = Arc::new(self);
|
|
||||||
let service_fn = |conn: &CancellableIo<_, L::Connection>| {
|
|
||||||
let rocket = rocket.clone();
|
|
||||||
let connection = ConnectionMeta {
|
|
||||||
remote: conn.peer_address(),
|
|
||||||
client_certificates: conn.peer_certificates(),
|
|
||||||
};
|
|
||||||
|
|
||||||
async move {
|
|
||||||
Ok::<_, std::convert::Infallible>(hyper::service::service_fn(move |req| {
|
|
||||||
hyper_service_fn(rocket.clone(), connection.clone(), req)
|
|
||||||
}))
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
// NOTE: `hyper` uses `tokio::spawn()` as the default executor.
|
|
||||||
let listener = CancellableListener::new(shutdown.clone(), listener, grace, mercy);
|
|
||||||
let builder = hyper::server::Server::builder(Incoming::new(listener).nodelay(true));
|
|
||||||
|
|
||||||
#[cfg(feature = "http2")]
|
|
||||||
let builder = builder.http2_keep_alive_interval(match keep_alive {
|
|
||||||
0 => None,
|
|
||||||
n => Some(Duration::from_secs(n as u64))
|
|
||||||
});
|
});
|
||||||
|
|
||||||
let server = builder
|
let mut response = request.into_response(
|
||||||
.http1_keepalive(keep_alive != 0)
|
incoming,
|
||||||
.http1_preserve_header_case(true)
|
|incoming| Data::from(incoming),
|
||||||
.serve(hyper::service::make_service_fn(service_fn))
|
|rocket, request, data| Box::pin(rocket.preprocess(request, data)),
|
||||||
.with_graceful_shutdown(shutdown.clone());
|
|token, rocket, request, data| Box::pin(async move {
|
||||||
|
if !request.errors.is_empty() {
|
||||||
|
return rocket.dispatch_error(Status::BadRequest, request).await;
|
||||||
|
}
|
||||||
|
|
||||||
// This deserves some explanation.
|
let mut response = rocket.dispatch(token, request, data).await;
|
||||||
//
|
response.body_mut().size().await;
|
||||||
// This is largely to deal with Hyper's dreadful and largely nonexistent
|
response
|
||||||
// handling of shutdown, in general, nevermind graceful.
|
})
|
||||||
//
|
).await;
|
||||||
// When Hyper receives a "graceful shutdown" request, it stops accepting
|
|
||||||
// new requests. That's it. It continues to process existing requests
|
let io_handler = response.to_io_handler(Rocket::extract_io_handler);
|
||||||
// and outgoing responses forever and never cancels them. As a result,
|
if let Some(handler) = io_handler {
|
||||||
// Rocket must take it upon itself to cancel any existing I/O.
|
let upgrade = upgrade.map_ok(IoStream::from).map_err(io::Error::other);
|
||||||
//
|
tokio::task::spawn(io_handler_task(upgrade, handler));
|
||||||
// To do so, Rocket wraps all connections in a `CancellableIo` struct,
|
}
|
||||||
// an internal structure that gracefully closes I/O when it receives a
|
|
||||||
// signal. That signal is the `shutdown` future. When the future
|
let mut builder = hyper::Response::builder();
|
||||||
// resolves, `CancellableIo` begins to terminate in grace, mercy, and
|
builder = builder.status(response.inner().status().code);
|
||||||
// finally force close phases. Since all connections are wrapped in
|
for header in response.inner().headers().iter() {
|
||||||
|
builder = builder.header(header.name().as_str(), header.value());
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(size) = response.inner().body().preset_size() {
|
||||||
|
builder = builder.header("Content-Length", size);
|
||||||
|
}
|
||||||
|
|
||||||
|
let chunk_size = response.inner().body().max_chunk_size();
|
||||||
|
builder.body(ReaderStream::with_capacity(response, chunk_size))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn io_handler_task<S>(stream: S, mut handler: ErasedIoHandler)
|
||||||
|
where S: Future<Output = io::Result<IoStream>>
|
||||||
|
{
|
||||||
|
let stream = match stream.await {
|
||||||
|
Ok(stream) => stream,
|
||||||
|
Err(e) => return warn_!("Upgrade failed: {e}"),
|
||||||
|
};
|
||||||
|
|
||||||
|
info_!("Upgrade succeeded.");
|
||||||
|
if let Err(e) = handler.take().io(stream).await {
|
||||||
|
match e.kind() {
|
||||||
|
io::ErrorKind::BrokenPipe => warn!("Upgrade I/O handler was closed."),
|
||||||
|
e => error!("Upgrade I/O handler failed: {e}"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Rocket<Ignite> {
|
||||||
|
pub(crate) async fn serve<L>(self, listener: L) -> Result<Self, crate::Error>
|
||||||
|
where L: Listener + 'static
|
||||||
|
{
|
||||||
|
let mut builder = Builder::new(TokioExecutor::new());
|
||||||
|
let keep_alive = Duration::from_secs(self.config.keep_alive.into());
|
||||||
|
builder.http1()
|
||||||
|
.half_close(true)
|
||||||
|
.timer(TokioTimer::new())
|
||||||
|
.keep_alive(keep_alive > Duration::ZERO)
|
||||||
|
.preserve_header_case(true)
|
||||||
|
.header_read_timeout(Duration::from_secs(15));
|
||||||
|
|
||||||
|
#[cfg(feature = "http2")] {
|
||||||
|
builder.http2().timer(TokioTimer::new());
|
||||||
|
if keep_alive > Duration::ZERO {
|
||||||
|
builder.http2()
|
||||||
|
.timer(TokioTimer::new())
|
||||||
|
.keep_alive_interval(keep_alive / 4)
|
||||||
|
.keep_alive_timeout(keep_alive);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let listener = listener.bounced().cancellable(self.shutdown(), &self.config.shutdown);
|
||||||
|
let rocket = Arc::new(self.into_orbit(listener.socket_addr()?));
|
||||||
|
let _ = tokio::spawn(Rocket::liftoff(rocket.clone())).await;
|
||||||
|
|
||||||
|
let (server, listener) = (Arc::new(builder), Arc::new(listener));
|
||||||
|
while let Some(accept) = listener.accept_next().await {
|
||||||
|
let (listener, rocket, server) = (listener.clone(), rocket.clone(), server.clone());
|
||||||
|
tokio::spawn({
|
||||||
|
let result = async move {
|
||||||
|
let conn = TokioIo::new(listener.connect(accept).await?);
|
||||||
|
let meta = ConnectionMeta::from(conn.inner());
|
||||||
|
let service = service_fn(|req| rocket.clone().service(req, meta.clone()));
|
||||||
|
let serve = pin!(server.serve_connection_with_upgrades(conn, service));
|
||||||
|
match select(serve, rocket.shutdown()).await {
|
||||||
|
Left((result, _)) => result,
|
||||||
|
Right((_, mut conn)) => {
|
||||||
|
conn.as_mut().graceful_shutdown();
|
||||||
|
conn.await
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
result.inspect_err(crate::error::log_server_error)
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
// Rocket wraps all connections in a `CancellableIo` struct, an internal
|
||||||
|
// structure that gracefully closes I/O when it receives a signal. That
|
||||||
|
// signal is the `shutdown` future. When the future resolves,
|
||||||
|
// `CancellableIo` begins to terminate in grace, mercy, and finally
|
||||||
|
// force close phases. Since all connections are wrapped in
|
||||||
// `CancellableIo`, this eventually ends all I/O.
|
// `CancellableIo`, this eventually ends all I/O.
|
||||||
//
|
//
|
||||||
// At that point, unless a user spawned an infinite, stand-alone task
|
// At that point, unless a user spawned an infinite, stand-alone task
|
||||||
|
@ -543,69 +145,35 @@ impl Rocket<Orbit> {
|
||||||
// we can return the owned instance of `Rocket`.
|
// we can return the owned instance of `Rocket`.
|
||||||
//
|
//
|
||||||
// Unfortunately, the Hyper `server` future resolves as soon as it has
|
// Unfortunately, the Hyper `server` future resolves as soon as it has
|
||||||
// finishes processing requests without respect for ongoing responses.
|
// finished processing requests without respect for ongoing responses.
|
||||||
// That is, `server` resolves even when there are running tasks that are
|
// That is, `server` resolves even when there are running tasks that are
|
||||||
// generating a response. So, `server` resolving implies little to
|
// generating a response. So, `server` resolving implies little to
|
||||||
// nothing about the state of connections. As a result, we depend on the
|
// nothing about the state of connections. As a result, we depend on the
|
||||||
// timing of grace + mercy + some buffer to determine when all
|
// timing of grace + mercy + some buffer to determine when all
|
||||||
// connections should be closed, thus all tasks should be complete, thus
|
// connections should be closed, thus all tasks should be complete, thus
|
||||||
// all references to `Arc<Rocket>` should be dropped and we can get a
|
// all references to `Arc<Rocket>` should be dropped and we can get back
|
||||||
// unique reference.
|
// a unique reference.
|
||||||
tokio::pin!(server);
|
info!("Shutting down. Waiting for shutdown fairings and pending I/O...");
|
||||||
tokio::select! {
|
tokio::spawn({
|
||||||
biased;
|
let rocket = rocket.clone();
|
||||||
|
async move { rocket.fairings.handle_shutdown(&*rocket).await }
|
||||||
|
});
|
||||||
|
|
||||||
_ = shutdown => {
|
let config = &rocket.config.shutdown;
|
||||||
// Run shutdown fairings. We compute `sleep()` for grace periods
|
let wait = Duration::from_micros(250);
|
||||||
// beforehand to ensure we don't add shutdown fairing completion
|
for period in [wait, config.grace(), wait, config.mercy(), wait * 4] {
|
||||||
// time, which is arbitrary, to these periods.
|
if Arc::strong_count(&rocket) == 1 { break }
|
||||||
info!("Shutdown requested. Waiting for pending I/O...");
|
sleep(period).await;
|
||||||
let grace_timer = sleep(Duration::from_secs(grace));
|
}
|
||||||
let mercy_timer = sleep(Duration::from_secs(grace + mercy));
|
|
||||||
let shutdown_timer = sleep(Duration::from_secs(grace + mercy + 1));
|
|
||||||
rocket.fairings.handle_shutdown(&*rocket).await;
|
|
||||||
|
|
||||||
tokio::pin!(grace_timer, mercy_timer, shutdown_timer);
|
match Arc::try_unwrap(rocket) {
|
||||||
tokio::select! {
|
Ok(rocket) => {
|
||||||
biased;
|
info!("Graceful shutdown completed successfully.");
|
||||||
|
Ok(rocket.into_ignite())
|
||||||
result = &mut server => {
|
|
||||||
if let Err(e) = result {
|
|
||||||
warn!("Server failed while shutting down: {}", e);
|
|
||||||
return Err(Error::shutdown(rocket.clone(), e));
|
|
||||||
}
|
|
||||||
|
|
||||||
if Arc::strong_count(&rocket) != 1 { grace_timer.await; }
|
|
||||||
if Arc::strong_count(&rocket) != 1 { mercy_timer.await; }
|
|
||||||
if Arc::strong_count(&rocket) != 1 { shutdown_timer.await; }
|
|
||||||
match Arc::try_unwrap(rocket) {
|
|
||||||
Ok(rocket) => {
|
|
||||||
info!("Graceful shutdown completed successfully.");
|
|
||||||
Ok(rocket)
|
|
||||||
}
|
|
||||||
Err(rocket) => {
|
|
||||||
warn!("Shutdown failed: outstanding background I/O.");
|
|
||||||
Err(Error::shutdown(rocket, None))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
_ = &mut shutdown_timer => {
|
|
||||||
warn!("Shutdown failed: server executing after timeouts.");
|
|
||||||
return Err(Error::shutdown(rocket.clone(), None));
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
result = &mut server => {
|
Err(rocket) => {
|
||||||
match result {
|
warn!("Shutdown failed: outstanding background I/O.");
|
||||||
Ok(()) => {
|
Err(Error::new(ErrorKind::Shutdown(rocket)))
|
||||||
info!("Server shutdown nominally.");
|
|
||||||
Ok(Arc::try_unwrap(rocket).map_err(|r| Error::shutdown(r, None))?)
|
|
||||||
}
|
|
||||||
Err(e) => {
|
|
||||||
info!("Server failed prior to shutdown: {}:", e);
|
|
||||||
Err(Error::shutdown(rocket.clone(), e))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -198,7 +198,7 @@ impl Fairing for Shield {
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn on_liftoff(&self, rocket: &Rocket<Orbit>) {
|
async fn on_liftoff(&self, rocket: &Rocket<Orbit>) {
|
||||||
let force_hsts = rocket.config().tls_enabled()
|
let force_hsts = rocket.endpoint().is_tls()
|
||||||
&& rocket.figment().profile() != Config::DEBUG_PROFILE
|
&& rocket.figment().profile() != Config::DEBUG_PROFILE
|
||||||
&& !self.is_enabled::<Hsts>();
|
&& !self.is_enabled::<Hsts>();
|
||||||
|
|
||||||
|
|
|
@ -5,7 +5,7 @@ use std::pin::Pin;
|
||||||
use futures::FutureExt;
|
use futures::FutureExt;
|
||||||
|
|
||||||
use crate::request::{FromRequest, Outcome, Request};
|
use crate::request::{FromRequest, Outcome, Request};
|
||||||
use crate::trip_wire::TripWire;
|
use crate::util::TripWire;
|
||||||
|
|
||||||
/// A request guard and future for graceful shutdown.
|
/// A request guard and future for graceful shutdown.
|
||||||
///
|
///
|
||||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -11,6 +11,7 @@ pub enum KeyError {
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub enum Error {
|
pub enum Error {
|
||||||
Io(std::io::Error),
|
Io(std::io::Error),
|
||||||
|
Bind(Box<dyn std::error::Error + Send + 'static>),
|
||||||
Tls(rustls::Error),
|
Tls(rustls::Error),
|
||||||
Mtls(rustls::server::VerifierBuilderError),
|
Mtls(rustls::server::VerifierBuilderError),
|
||||||
CertChain(std::io::Error),
|
CertChain(std::io::Error),
|
||||||
|
@ -29,6 +30,7 @@ impl std::fmt::Display for Error {
|
||||||
CertChain(e) => write!(f, "failed to process certificate chain: {e}"),
|
CertChain(e) => write!(f, "failed to process certificate chain: {e}"),
|
||||||
PrivKey(e) => write!(f, "failed to process private key: {e}"),
|
PrivKey(e) => write!(f, "failed to process private key: {e}"),
|
||||||
CertAuth(e) => write!(f, "failed to process certificate authority: {e}"),
|
CertAuth(e) => write!(f, "failed to process certificate authority: {e}"),
|
||||||
|
Bind(e) => write!(f, "failed to bind to network interface: {e}"),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -66,6 +68,7 @@ impl std::error::Error for Error {
|
||||||
Error::CertChain(e) => Some(e),
|
Error::CertChain(e) => Some(e),
|
||||||
Error::PrivKey(e) => Some(e),
|
Error::PrivKey(e) => Some(e),
|
||||||
Error::CertAuth(e) => Some(e),
|
Error::CertAuth(e) => Some(e),
|
||||||
|
Error::Bind(e) => Some(&**e),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
|
@ -0,0 +1,7 @@
|
||||||
|
mod error;
|
||||||
|
pub(crate) mod config;
|
||||||
|
pub(crate) mod util;
|
||||||
|
|
||||||
|
pub use error::Result;
|
||||||
|
pub use config::{TlsConfig, CipherSuite};
|
||||||
|
pub use error::Error;
|
|
@ -0,0 +1,52 @@
|
||||||
|
use std::io;
|
||||||
|
use std::task::{Poll, Context};
|
||||||
|
use std::pin::Pin;
|
||||||
|
|
||||||
|
use pin_project_lite::pin_project;
|
||||||
|
use tokio::io::{AsyncRead, ReadBuf};
|
||||||
|
|
||||||
|
pin_project! {
|
||||||
|
/// Stream for the [`chain`](super::AsyncReadExt::chain) method.
|
||||||
|
#[must_use = "streams do nothing unless polled"]
|
||||||
|
pub struct Chain<T, U> {
|
||||||
|
#[pin]
|
||||||
|
first: Option<T>,
|
||||||
|
#[pin]
|
||||||
|
second: U,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T, U> Chain<T, U> {
|
||||||
|
pub(crate) fn new(first: T, second: U) -> Self {
|
||||||
|
Self { first: Some(first), second }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T: AsyncRead, U: AsyncRead> Chain<T, U> {
|
||||||
|
/// Gets references to the underlying readers in this `Chain`.
|
||||||
|
pub fn get_ref(&self) -> (Option<&T>, &U) {
|
||||||
|
(self.first.as_ref(), &self.second)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T: AsyncRead, U: AsyncRead> AsyncRead for Chain<T, U> {
|
||||||
|
fn poll_read(
|
||||||
|
mut self: Pin<&mut Self>,
|
||||||
|
cx: &mut Context<'_>,
|
||||||
|
buf: &mut ReadBuf<'_>,
|
||||||
|
) -> Poll<io::Result<()>> {
|
||||||
|
let me = self.as_mut().project();
|
||||||
|
if let Some(first) = me.first.as_pin_mut() {
|
||||||
|
let init_rem = buf.remaining();
|
||||||
|
futures::ready!(first.poll_read(cx, buf))?;
|
||||||
|
if buf.remaining() == init_rem {
|
||||||
|
self.as_mut().project().first.set(None);
|
||||||
|
} else {
|
||||||
|
return Poll::Ready(Ok(()));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let me = self.as_mut().project();
|
||||||
|
me.second.poll_read(cx, buf)
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,77 @@
|
||||||
|
use std::pin::Pin;
|
||||||
|
use std::task::{Poll, Context};
|
||||||
|
|
||||||
|
use pin_project_lite::pin_project;
|
||||||
|
|
||||||
|
use futures::stream::Stream;
|
||||||
|
use futures::ready;
|
||||||
|
|
||||||
|
/// Join two streams, `a` and `b`, into a new `Join` stream that returns items
|
||||||
|
/// from both streams, biased to `a`, until `a` finishes. The joined stream
|
||||||
|
/// completes when `a` completes, irrespective of `b`. If `b` stops producing
|
||||||
|
/// values, then the joined stream acts exactly like a fused `a`.
|
||||||
|
///
|
||||||
|
/// Values are biased to those of `a`: if `a` provides a value, it is always
|
||||||
|
/// emitted before a value provided by `b`. In other words, values from `b` are
|
||||||
|
/// emitted when and only when `a` is not producing a value.
|
||||||
|
pub fn join<A: Stream, B: Stream>(a: A, b: B) -> Join<A, B> {
|
||||||
|
Join { a, b: Some(b), done: false, }
|
||||||
|
}
|
||||||
|
|
||||||
|
pin_project! {
|
||||||
|
/// Stream returned by [`join`].
|
||||||
|
pub struct Join<T, U> {
|
||||||
|
#[pin]
|
||||||
|
a: T,
|
||||||
|
#[pin]
|
||||||
|
b: Option<U>,
|
||||||
|
// Set when `a` returns `None`.
|
||||||
|
done: bool,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T, U> Stream for Join<T, U>
|
||||||
|
where T: Stream,
|
||||||
|
U: Stream<Item = T::Item>,
|
||||||
|
{
|
||||||
|
type Item = T::Item;
|
||||||
|
|
||||||
|
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<T::Item>> {
|
||||||
|
if self.done {
|
||||||
|
return Poll::Ready(None);
|
||||||
|
}
|
||||||
|
|
||||||
|
let me = self.as_mut().project();
|
||||||
|
match me.a.poll_next(cx) {
|
||||||
|
Poll::Ready(opt) => {
|
||||||
|
*me.done = opt.is_none();
|
||||||
|
Poll::Ready(opt)
|
||||||
|
},
|
||||||
|
Poll::Pending => match me.b.as_pin_mut() {
|
||||||
|
None => Poll::Pending,
|
||||||
|
Some(b) => match ready!(b.poll_next(cx)) {
|
||||||
|
Some(value) => Poll::Ready(Some(value)),
|
||||||
|
None => {
|
||||||
|
self.as_mut().project().b.set(None);
|
||||||
|
Poll::Pending
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn size_hint(&self) -> (usize, Option<usize>) {
|
||||||
|
let (left_low, left_high) = self.a.size_hint();
|
||||||
|
let (right_low, right_high) = self.b.as_ref()
|
||||||
|
.map(|b| b.size_hint())
|
||||||
|
.unwrap_or_default();
|
||||||
|
|
||||||
|
let low = left_low.saturating_add(right_low);
|
||||||
|
let high = match (left_high, right_high) {
|
||||||
|
(Some(h1), Some(h2)) => h1.checked_add(h2),
|
||||||
|
_ => None,
|
||||||
|
};
|
||||||
|
|
||||||
|
(low, high)
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,12 @@
|
||||||
|
mod chain;
|
||||||
|
mod tripwire;
|
||||||
|
mod reader_stream;
|
||||||
|
mod join;
|
||||||
|
|
||||||
|
#[cfg(unix)]
|
||||||
|
pub mod unix;
|
||||||
|
|
||||||
|
pub use chain::Chain;
|
||||||
|
pub use tripwire::TripWire;
|
||||||
|
pub use reader_stream::ReaderStream;
|
||||||
|
pub use join::join;
|
|
@ -0,0 +1,124 @@
|
||||||
|
use std::pin::Pin;
|
||||||
|
use std::task::{Context, Poll};
|
||||||
|
|
||||||
|
use bytes::{Bytes, BytesMut};
|
||||||
|
use futures::stream::Stream;
|
||||||
|
use pin_project_lite::pin_project;
|
||||||
|
use tokio::io::AsyncRead;
|
||||||
|
|
||||||
|
pin_project! {
|
||||||
|
/// Convert an [`AsyncRead`] into a [`Stream`] of byte chunks.
|
||||||
|
///
|
||||||
|
/// This stream is fused. It performs the inverse operation of
|
||||||
|
/// [`StreamReader`].
|
||||||
|
///
|
||||||
|
/// # Example
|
||||||
|
///
|
||||||
|
/// ```
|
||||||
|
/// # #[tokio::main]
|
||||||
|
/// # async fn main() -> std::io::Result<()> {
|
||||||
|
/// use tokio_stream::StreamExt;
|
||||||
|
/// use tokio_util::io::ReaderStream;
|
||||||
|
///
|
||||||
|
/// // Create a stream of data.
|
||||||
|
/// let data = b"hello, world!";
|
||||||
|
/// let mut stream = ReaderStream::new(&data[..]);
|
||||||
|
///
|
||||||
|
/// // Read all of the chunks into a vector.
|
||||||
|
/// let mut stream_contents = Vec::new();
|
||||||
|
/// while let Some(chunk) = stream.next().await {
|
||||||
|
/// stream_contents.extend_from_slice(&chunk?);
|
||||||
|
/// }
|
||||||
|
///
|
||||||
|
/// // Once the chunks are concatenated, we should have the
|
||||||
|
/// // original data.
|
||||||
|
/// assert_eq!(stream_contents, data);
|
||||||
|
/// # Ok(())
|
||||||
|
/// # }
|
||||||
|
/// ```
|
||||||
|
///
|
||||||
|
/// [`AsyncRead`]: tokio::io::AsyncRead
|
||||||
|
/// [`StreamReader`]: crate::io::StreamReader
|
||||||
|
/// [`Stream`]: futures_core::Stream
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct ReaderStream<R> {
|
||||||
|
// Reader itself.
|
||||||
|
//
|
||||||
|
// This value is `None` if the stream has terminated.
|
||||||
|
#[pin]
|
||||||
|
reader: R,
|
||||||
|
// Working buffer, used to optimize allocations.
|
||||||
|
buf: BytesMut,
|
||||||
|
capacity: usize,
|
||||||
|
done: bool,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<R: AsyncRead> ReaderStream<R> {
|
||||||
|
/// Convert an [`AsyncRead`] into a [`Stream`] with item type
|
||||||
|
/// `Result<Bytes, std::io::Error>`,
|
||||||
|
/// with a specific read buffer initial capacity.
|
||||||
|
///
|
||||||
|
/// [`AsyncRead`]: tokio::io::AsyncRead
|
||||||
|
/// [`Stream`]: futures_core::Stream
|
||||||
|
pub fn with_capacity(reader: R, capacity: usize) -> Self {
|
||||||
|
ReaderStream {
|
||||||
|
reader: reader,
|
||||||
|
buf: BytesMut::with_capacity(capacity),
|
||||||
|
capacity,
|
||||||
|
done: false,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<R: AsyncRead> Stream for ReaderStream<R> {
|
||||||
|
type Item = std::io::Result<Bytes>;
|
||||||
|
|
||||||
|
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
|
||||||
|
use tokio_util::io::poll_read_buf;
|
||||||
|
|
||||||
|
let mut this = self.as_mut().project();
|
||||||
|
|
||||||
|
if *this.done {
|
||||||
|
return Poll::Ready(None);
|
||||||
|
}
|
||||||
|
|
||||||
|
if this.buf.capacity() == 0 {
|
||||||
|
this.buf.reserve(*this.capacity);
|
||||||
|
}
|
||||||
|
|
||||||
|
let reader = this.reader;
|
||||||
|
match poll_read_buf(reader, cx, &mut this.buf) {
|
||||||
|
Poll::Pending => Poll::Pending,
|
||||||
|
Poll::Ready(Err(err)) => {
|
||||||
|
*this.done = true;
|
||||||
|
Poll::Ready(Some(Err(err)))
|
||||||
|
}
|
||||||
|
Poll::Ready(Ok(0)) => {
|
||||||
|
*this.done = true;
|
||||||
|
Poll::Ready(None)
|
||||||
|
}
|
||||||
|
Poll::Ready(Ok(_)) => {
|
||||||
|
let chunk = this.buf.split();
|
||||||
|
Poll::Ready(Some(Ok(chunk.freeze())))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// fn size_hint(&self) -> (usize, Option<usize>) {
|
||||||
|
// self.reader.size_hint()
|
||||||
|
// }
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<R: AsyncRead> hyper::body::Body for ReaderStream<R> {
|
||||||
|
type Data = bytes::Bytes;
|
||||||
|
|
||||||
|
type Error = std::io::Error;
|
||||||
|
|
||||||
|
fn poll_frame(
|
||||||
|
self: Pin<&mut Self>,
|
||||||
|
cx: &mut Context<'_>,
|
||||||
|
) -> Poll<Option<Result<hyper::body::Frame<Self::Data>, Self::Error>>> {
|
||||||
|
self.poll_next(cx).map_ok(hyper::body::Frame::data)
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,25 @@
|
||||||
|
use std::io;
|
||||||
|
use std::os::fd::AsRawFd;
|
||||||
|
|
||||||
|
pub fn lock_exlusive_nonblocking<T: AsRawFd>(file: &T) -> io::Result<()> {
|
||||||
|
let raw_fd = file.as_raw_fd();
|
||||||
|
let res = unsafe {
|
||||||
|
libc::flock(raw_fd, libc::LOCK_EX | libc::LOCK_NB)
|
||||||
|
};
|
||||||
|
|
||||||
|
match res {
|
||||||
|
0 => Ok(()),
|
||||||
|
_ => Err(io::Error::last_os_error()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn unlock_nonblocking<T: AsRawFd>(file: &T) -> io::Result<()> {
|
||||||
|
let res = unsafe {
|
||||||
|
libc::flock(file.as_raw_fd(), libc::LOCK_UN | libc::LOCK_NB)
|
||||||
|
};
|
||||||
|
|
||||||
|
match res {
|
||||||
|
0 => Ok(()),
|
||||||
|
_ => Err(io::Error::last_os_error()),
|
||||||
|
}
|
||||||
|
}
|
|
@ -1,8 +1,9 @@
|
||||||
#![cfg(feature = "tls")]
|
#![cfg(feature = "tls")]
|
||||||
|
|
||||||
use rocket::fs::relative;
|
use rocket::fs::relative;
|
||||||
use rocket::config::{Config, TlsConfig, CipherSuite};
|
|
||||||
use rocket::local::asynchronous::Client;
|
use rocket::local::asynchronous::Client;
|
||||||
|
use rocket::tls::{TlsConfig, CipherSuite};
|
||||||
|
use rocket::figment::providers::Serialized;
|
||||||
|
|
||||||
#[rocket::async_test]
|
#[rocket::async_test]
|
||||||
async fn can_launch_tls() {
|
async fn can_launch_tls() {
|
||||||
|
@ -15,9 +16,8 @@ async fn can_launch_tls() {
|
||||||
CipherSuite::TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256,
|
CipherSuite::TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256,
|
||||||
]);
|
]);
|
||||||
|
|
||||||
let rocket = rocket::custom(Config { tls: Some(tls), ..Config::debug_default() });
|
let config = rocket::Config::figment().merge(Serialized::defaults(tls));
|
||||||
let client = Client::debug(rocket).await.unwrap();
|
let client = Client::debug(rocket::custom(config)).await.unwrap();
|
||||||
|
|
||||||
client.rocket().shutdown().notify();
|
client.rocket().shutdown().notify();
|
||||||
client.rocket().shutdown().await;
|
client.rocket().shutdown().await;
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,3 +1,5 @@
|
||||||
|
use std::net::{SocketAddr, Ipv4Addr};
|
||||||
|
|
||||||
use rocket::config::Config;
|
use rocket::config::Config;
|
||||||
use rocket::fairing::AdHoc;
|
use rocket::fairing::AdHoc;
|
||||||
use rocket::futures::channel::oneshot;
|
use rocket::futures::channel::oneshot;
|
||||||
|
@ -5,13 +7,13 @@ use rocket::futures::channel::oneshot;
|
||||||
#[rocket::async_test]
|
#[rocket::async_test]
|
||||||
async fn on_ignite_fairing_can_inspect_port() {
|
async fn on_ignite_fairing_can_inspect_port() {
|
||||||
let (tx, rx) = oneshot::channel();
|
let (tx, rx) = oneshot::channel();
|
||||||
let rocket = rocket::custom(Config { port: 0, ..Config::debug_default() })
|
let rocket = rocket::custom(Config::debug_default())
|
||||||
.attach(AdHoc::on_liftoff("Send Port -> Channel", move |rocket| {
|
.attach(AdHoc::on_liftoff("Send Port -> Channel", move |rocket| {
|
||||||
Box::pin(async move {
|
Box::pin(async move {
|
||||||
tx.send(rocket.config().port).unwrap();
|
tx.send(rocket.endpoint().tcp().unwrap().port()).unwrap();
|
||||||
})
|
})
|
||||||
}));
|
}));
|
||||||
|
|
||||||
rocket::tokio::spawn(rocket.launch());
|
rocket::tokio::spawn(rocket.launch_on(SocketAddr::from((Ipv4Addr::LOCALHOST, 0))));
|
||||||
assert_ne!(rx.await.unwrap(), 0);
|
assert_ne!(rx.await.unwrap(), 0);
|
||||||
}
|
}
|
||||||
|
|
|
@ -155,7 +155,7 @@ fn inner_sentinels_detected() {
|
||||||
|
|
||||||
impl<'r, 'o: 'r> response::Responder<'r, 'o> for ResponderSentinel {
|
impl<'r, 'o: 'r> response::Responder<'r, 'o> for ResponderSentinel {
|
||||||
fn respond_to(self, _: &'r Request<'_>) -> response::Result<'o> {
|
fn respond_to(self, _: &'r Request<'_>) -> response::Result<'o> {
|
||||||
todo!()
|
unimplemented!()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -8,19 +8,14 @@ macro_rules! relative {
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn tls_config_from_source() {
|
fn tls_config_from_source() {
|
||||||
use rocket::config::{Config, TlsConfig};
|
use rocket::tls::TlsConfig;
|
||||||
use rocket::figment::Figment;
|
use rocket::figment::{Figment, providers::Serialized};
|
||||||
|
|
||||||
let cert_path = relative!("examples/tls/private/cert.pem");
|
let cert_path = relative!("examples/tls/private/cert.pem");
|
||||||
let key_path = relative!("examples/tls/private/key.pem");
|
let key_path = relative!("examples/tls/private/key.pem");
|
||||||
|
let config = TlsConfig::from_paths(cert_path, key_path);
|
||||||
|
|
||||||
let rocket_config = Config {
|
let tls: TlsConfig = Figment::from(Serialized::globals(config)).extract().unwrap();
|
||||||
tls: Some(TlsConfig::from_paths(cert_path, key_path)),
|
|
||||||
..Default::default()
|
|
||||||
};
|
|
||||||
|
|
||||||
let config: Config = Figment::from(rocket_config).extract().unwrap();
|
|
||||||
let tls = config.tls.expect("have TLS config");
|
|
||||||
assert_eq!(tls.certs().unwrap_left(), cert_path);
|
assert_eq!(tls.certs().unwrap_left(), cert_path);
|
||||||
assert_eq!(tls.key().unwrap_left(), key_path);
|
assert_eq!(tls.key().unwrap_left(), key_path);
|
||||||
}
|
}
|
||||||
|
|
|
@ -6,15 +6,11 @@ async fn test_config(profile: &str) {
|
||||||
let config = rocket.config();
|
let config = rocket.config();
|
||||||
match &*profile {
|
match &*profile {
|
||||||
"debug" => {
|
"debug" => {
|
||||||
assert_eq!(config.address, std::net::Ipv4Addr::LOCALHOST);
|
|
||||||
assert_eq!(config.port, 8000);
|
|
||||||
assert_eq!(config.workers, 1);
|
assert_eq!(config.workers, 1);
|
||||||
assert_eq!(config.keep_alive, 0);
|
assert_eq!(config.keep_alive, 0);
|
||||||
assert_eq!(config.log_level, LogLevel::Normal);
|
assert_eq!(config.log_level, LogLevel::Normal);
|
||||||
}
|
}
|
||||||
"release" => {
|
"release" => {
|
||||||
assert_eq!(config.address, std::net::Ipv4Addr::LOCALHOST);
|
|
||||||
assert_eq!(config.port, 8000);
|
|
||||||
assert_eq!(config.workers, 12);
|
assert_eq!(config.workers, 12);
|
||||||
assert_eq!(config.keep_alive, 5);
|
assert_eq!(config.keep_alive, 5);
|
||||||
assert_eq!(config.log_level, LogLevel::Critical);
|
assert_eq!(config.log_level, LogLevel::Critical);
|
||||||
|
|
|
@ -74,19 +74,8 @@ fn hello(lang: Option<Lang>, opt: Options<'_>) -> String {
|
||||||
|
|
||||||
#[launch]
|
#[launch]
|
||||||
fn rocket() -> _ {
|
fn rocket() -> _ {
|
||||||
use rocket::fairing::AdHoc;
|
|
||||||
|
|
||||||
rocket::build()
|
rocket::build()
|
||||||
.mount("/", routes![hello])
|
.mount("/", routes![hello])
|
||||||
.mount("/hello", routes![world, mir])
|
.mount("/hello", routes![world, mir])
|
||||||
.mount("/wave", routes![wave])
|
.mount("/wave", routes![wave])
|
||||||
.attach(AdHoc::on_request("Compatibility Normalizer", |req, _| Box::pin(async move {
|
|
||||||
if !req.uri().is_normalized_nontrailing() {
|
|
||||||
let normal = req.uri().clone().into_normalized_nontrailing();
|
|
||||||
warn!("Incoming request URI was normalized for compatibility.");
|
|
||||||
info_!("{} -> {}", req.uri(), normal);
|
|
||||||
req.set_uri(normal);
|
|
||||||
}
|
|
||||||
})))
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,33 +1,38 @@
|
||||||
//! Redirect all HTTP requests to HTTPs.
|
//! Redirect all HTTP requests to HTTPs.
|
||||||
|
|
||||||
use std::sync::OnceLock;
|
use std::net::SocketAddr;
|
||||||
|
|
||||||
use rocket::http::Status;
|
use rocket::http::Status;
|
||||||
use rocket::log::LogLevel;
|
use rocket::log::LogLevel;
|
||||||
use rocket::{route, Error, Request, Data, Route, Orbit, Rocket, Ignite, Config};
|
use rocket::{route, Error, Request, Data, Route, Orbit, Rocket, Ignite};
|
||||||
use rocket::fairing::{Fairing, Info, Kind};
|
use rocket::fairing::{Fairing, Info, Kind};
|
||||||
use rocket::response::Redirect;
|
use rocket::response::Redirect;
|
||||||
|
|
||||||
|
use yansi::Paint;
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Copy, Default)]
|
||||||
|
pub struct Redirector(u16);
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub struct Redirector {
|
pub struct Config {
|
||||||
pub listen_port: u16,
|
server: rocket::Config,
|
||||||
pub tls_port: OnceLock<u16>,
|
tls_addr: SocketAddr,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Redirector {
|
impl Redirector {
|
||||||
pub fn on(port: u16) -> Self {
|
pub fn on(port: u16) -> Self {
|
||||||
Redirector { listen_port: port, tls_port: OnceLock::new() }
|
Redirector(port)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Route function that gets called on every single request.
|
// Route function that gets called on every single request.
|
||||||
fn redirect<'r>(req: &'r Request, _: Data<'r>) -> route::BoxFuture<'r> {
|
fn redirect<'r>(req: &'r Request, _: Data<'r>) -> route::BoxFuture<'r> {
|
||||||
// FIXME: Check the host against a whitelist!
|
// FIXME: Check the host against a whitelist!
|
||||||
let redirector = req.rocket().state::<Self>().expect("managed Self");
|
let config = req.rocket().state::<Config>().expect("managed Self");
|
||||||
if let Some(host) = req.host() {
|
if let Some(host) = req.host() {
|
||||||
let domain = host.domain();
|
let domain = host.domain();
|
||||||
let https_uri = match redirector.tls_port.get() {
|
let https_uri = match config.tls_addr.port() {
|
||||||
Some(443) | None => format!("https://{domain}{}", req.uri()),
|
443 => format!("https://{domain}{}", req.uri()),
|
||||||
Some(port) => format!("https://{domain}:{port}{}", req.uri()),
|
port => format!("https://{domain}:{port}{}", req.uri()),
|
||||||
};
|
};
|
||||||
|
|
||||||
route::Outcome::from(req, Redirect::permanent(https_uri)).pin()
|
route::Outcome::from(req, Redirect::permanent(https_uri)).pin()
|
||||||
|
@ -37,21 +42,12 @@ impl Redirector {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Launch an instance of Rocket than handles redirection on `self.port`.
|
// Launch an instance of Rocket than handles redirection on `self.port`.
|
||||||
pub async fn try_launch(self, mut config: Config) -> Result<Rocket<Ignite>, Error> {
|
pub async fn try_launch(self, config: Config) -> Result<Rocket<Ignite>, Error> {
|
||||||
use yansi::Paint;
|
|
||||||
use rocket::http::Method::*;
|
use rocket::http::Method::*;
|
||||||
|
|
||||||
// Determine the port TLS is being served on.
|
|
||||||
let tls_port = self.tls_port.get_or_init(|| config.port);
|
|
||||||
|
|
||||||
// Adjust config for redirector: disable TLS, set port, disable logging.
|
|
||||||
config.tls = None;
|
|
||||||
config.port = self.listen_port;
|
|
||||||
config.log_level = LogLevel::Critical;
|
|
||||||
|
|
||||||
info!("{}{}", "🔒 ".mask(), "HTTP -> HTTPS Redirector:".magenta());
|
info!("{}{}", "🔒 ".mask(), "HTTP -> HTTPS Redirector:".magenta());
|
||||||
info_!("redirecting on insecure port {} to TLS port {}",
|
info_!("redirecting insecure port {} to TLS port {}",
|
||||||
self.listen_port.yellow(), tls_port.green());
|
self.0.yellow(), config.tls_addr.port().green());
|
||||||
|
|
||||||
// Build a vector of routes to `redirect` on `<path..>` for each method.
|
// Build a vector of routes to `redirect` on `<path..>` for each method.
|
||||||
let redirects = [Get, Put, Post, Delete, Options, Head, Trace, Connect, Patch]
|
let redirects = [Get, Put, Post, Delete, Options, Head, Trace, Connect, Patch]
|
||||||
|
@ -59,10 +55,11 @@ impl Redirector {
|
||||||
.map(|m| Route::new(m, "/<path..>", Self::redirect))
|
.map(|m| Route::new(m, "/<path..>", Self::redirect))
|
||||||
.collect::<Vec<_>>();
|
.collect::<Vec<_>>();
|
||||||
|
|
||||||
rocket::custom(config)
|
let addr = SocketAddr::new(config.tls_addr.ip(), self.0);
|
||||||
.manage(self)
|
rocket::custom(&config.server)
|
||||||
|
.manage(config)
|
||||||
.mount("/", redirects)
|
.mount("/", redirects)
|
||||||
.launch()
|
.launch_on(addr)
|
||||||
.await
|
.await
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -76,8 +73,24 @@ impl Fairing for Redirector {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn on_liftoff(&self, rkt: &Rocket<Orbit>) {
|
async fn on_liftoff(&self, rocket: &Rocket<Orbit>) {
|
||||||
let (this, shutdown, config) = (self.clone(), rkt.shutdown(), rkt.config().clone());
|
let Some(tls_addr) = rocket.endpoint().tls().and_then(|tls| tls.tcp()) else {
|
||||||
|
info!("{}{}", "🔒 ".mask(), "HTTP -> HTTPS Redirector:".magenta());
|
||||||
|
warn_!("Main instance is not being served over TLS/TCP.");
|
||||||
|
warn_!("Redirector refusing to start.");
|
||||||
|
return;
|
||||||
|
};
|
||||||
|
|
||||||
|
let config = Config {
|
||||||
|
tls_addr,
|
||||||
|
server: rocket::Config {
|
||||||
|
log_level: LogLevel::Critical,
|
||||||
|
..rocket.config().clone()
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
let this = *self;
|
||||||
|
let shutdown = rocket.shutdown();
|
||||||
let _ = rocket::tokio::spawn(async move {
|
let _ = rocket::tokio::spawn(async move {
|
||||||
if let Err(e) = this.try_launch(config).await {
|
if let Err(e) = this.try_launch(config).await {
|
||||||
error!("Failed to start HTTP -> HTTPS redirector.");
|
error!("Failed to start HTTP -> HTTPS redirector.");
|
||||||
|
|
|
@ -1,11 +1,21 @@
|
||||||
use std::fs::{self, File};
|
use std::fs::{self, File};
|
||||||
|
|
||||||
|
use rocket::http::{CookieJar, Cookie};
|
||||||
use rocket::local::blocking::Client;
|
use rocket::local::blocking::Client;
|
||||||
use rocket::fs::relative;
|
use rocket::fs::relative;
|
||||||
|
|
||||||
|
#[get("/cookie")]
|
||||||
|
fn cookie(jar: &CookieJar<'_>) {
|
||||||
|
jar.add(("k1", "v1"));
|
||||||
|
jar.add_private(("k2", "v2"));
|
||||||
|
|
||||||
|
jar.add(Cookie::build(("k1u", "v1u")).secure(false));
|
||||||
|
jar.add_private(Cookie::build(("k2u", "v2u")).secure(false));
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn hello_mutual() {
|
fn hello_mutual() {
|
||||||
let client = Client::tracked(super::rocket()).unwrap();
|
let client = Client::tracked_secure(super::rocket()).unwrap();
|
||||||
let cert_paths = fs::read_dir(relative!("private")).unwrap()
|
let cert_paths = fs::read_dir(relative!("private")).unwrap()
|
||||||
.map(|entry| entry.unwrap().path().to_string_lossy().into_owned())
|
.map(|entry| entry.unwrap().path().to_string_lossy().into_owned())
|
||||||
.filter(|path| path.ends_with("_cert.pem") && !path.ends_with("ca_cert.pem"));
|
.filter(|path| path.ends_with("_cert.pem") && !path.ends_with("ca_cert.pem"));
|
||||||
|
@ -23,35 +33,43 @@ fn hello_mutual() {
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn secure_cookies() {
|
fn secure_cookies() {
|
||||||
use rocket::http::{CookieJar, Cookie};
|
let rocket = super::rocket().mount("/", routes![cookie]);
|
||||||
|
let client = Client::tracked_secure(rocket).unwrap();
|
||||||
|
|
||||||
#[get("/cookie")]
|
|
||||||
fn cookie(jar: &CookieJar<'_>) {
|
|
||||||
jar.add(("k1", "v1"));
|
|
||||||
jar.add_private(("k2", "v2"));
|
|
||||||
|
|
||||||
jar.add(Cookie::build(("k1u", "v1u")).secure(false));
|
|
||||||
jar.add_private(Cookie::build(("k2u", "v2u")).secure(false));
|
|
||||||
}
|
|
||||||
|
|
||||||
let client = Client::tracked(super::rocket().mount("/", routes![cookie])).unwrap();
|
|
||||||
let response = client.get("/cookie").dispatch();
|
let response = client.get("/cookie").dispatch();
|
||||||
|
|
||||||
let c1 = response.cookies().get("k1").unwrap();
|
let c1 = response.cookies().get("k1").unwrap();
|
||||||
assert_eq!(c1.secure(), Some(true));
|
|
||||||
|
|
||||||
let c2 = response.cookies().get_private("k2").unwrap();
|
let c2 = response.cookies().get_private("k2").unwrap();
|
||||||
|
let c3 = response.cookies().get("k1u").unwrap();
|
||||||
|
let c4 = response.cookies().get_private("k2u").unwrap();
|
||||||
|
|
||||||
|
assert_eq!(c1.secure(), Some(true));
|
||||||
assert_eq!(c2.secure(), Some(true));
|
assert_eq!(c2.secure(), Some(true));
|
||||||
|
assert_ne!(c3.secure(), Some(true));
|
||||||
|
assert_ne!(c4.secure(), Some(true));
|
||||||
|
}
|
||||||
|
|
||||||
let c1 = response.cookies().get("k1u").unwrap();
|
#[test]
|
||||||
assert_ne!(c1.secure(), Some(true));
|
fn insecure_cookies() {
|
||||||
|
let rocket = super::rocket().mount("/", routes![cookie]);
|
||||||
|
let client = Client::tracked(rocket).unwrap();
|
||||||
|
|
||||||
let c2 = response.cookies().get_private("k2u").unwrap();
|
let response = client.get("/cookie").dispatch();
|
||||||
assert_ne!(c2.secure(), Some(true));
|
let c1 = response.cookies().get("k1").unwrap();
|
||||||
|
let c2 = response.cookies().get_private("k2").unwrap();
|
||||||
|
let c3 = response.cookies().get("k1u").unwrap();
|
||||||
|
let c4 = response.cookies().get_private("k2u").unwrap();
|
||||||
|
|
||||||
|
assert_eq!(c1.secure(), None);
|
||||||
|
assert_eq!(c2.secure(), None);
|
||||||
|
assert_eq!(c3.secure(), None);
|
||||||
|
assert_eq!(c4.secure(), None);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn hello_world() {
|
fn hello_world() {
|
||||||
|
use rocket::listener::DefaultListener;
|
||||||
|
use rocket::config::{Config, SecretKey};
|
||||||
|
|
||||||
let profiles = [
|
let profiles = [
|
||||||
"rsa_sha256",
|
"rsa_sha256",
|
||||||
"ecdsa_nistp256_sha256_pkcs8",
|
"ecdsa_nistp256_sha256_pkcs8",
|
||||||
|
@ -61,11 +79,20 @@ fn hello_world() {
|
||||||
"ed25519",
|
"ed25519",
|
||||||
];
|
];
|
||||||
|
|
||||||
// TODO: Testing doesn't actually read keys since we don't do TLS locally.
|
|
||||||
for profile in profiles {
|
for profile in profiles {
|
||||||
let config = rocket::Config::figment().select(profile);
|
let config = Config {
|
||||||
let client = Client::tracked(super::rocket().configure(config)).unwrap();
|
secret_key: SecretKey::generate().unwrap(),
|
||||||
|
..Config::debug_default()
|
||||||
|
};
|
||||||
|
|
||||||
|
let figment = Config::figment().merge(config).select(profile);
|
||||||
|
let client = Client::tracked_secure(super::rocket().configure(figment)).unwrap();
|
||||||
let response = client.get("/").dispatch();
|
let response = client.get("/").dispatch();
|
||||||
assert_eq!(response.into_string().unwrap(), "Hello, world!");
|
assert_eq!(response.into_string().unwrap(), "Hello, world!");
|
||||||
|
|
||||||
|
let figment = client.rocket().figment();
|
||||||
|
let listener: DefaultListener = figment.extract().unwrap();
|
||||||
|
assert_eq!(figment.profile(), profile);
|
||||||
|
listener.tls.as_ref().unwrap().validate().expect("valid TLS config");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -14,7 +14,7 @@
|
||||||
<div id="log"></div>
|
<div id="log"></div>
|
||||||
</body>
|
</body>
|
||||||
<script language="javascript" type="text/javascript">
|
<script language="javascript" type="text/javascript">
|
||||||
var wsUri = "ws://127.0.0.1:8000/echo";
|
var wsUri = "ws://127.0.0.1:8000/echo?raw";
|
||||||
var log;
|
var log;
|
||||||
|
|
||||||
function init() {
|
function init() {
|
||||||
|
|
|
@ -20,8 +20,10 @@ fi
|
||||||
echo ":::: Generating the docs..."
|
echo ":::: Generating the docs..."
|
||||||
pushd "${PROJECT_ROOT}" > /dev/null 2>&1
|
pushd "${PROJECT_ROOT}" > /dev/null 2>&1
|
||||||
# Set the crate version and fill in missing doc URLs with docs.rs links.
|
# Set the crate version and fill in missing doc URLs with docs.rs links.
|
||||||
RUSTDOCFLAGS="-Zunstable-options --crate-version ${DOC_VERSION} --extern-html-root-url rocket=https://api.rocket.rs/rocket/" \
|
RUSTDOCFLAGS="-Z unstable-options \
|
||||||
cargo doc -Zrustdoc-map --no-deps --all-features \
|
--crate-version ${DOC_VERSION} \
|
||||||
|
--enable-index-page" \
|
||||||
|
cargo doc -Zrustdoc-map --no-deps --all-features \
|
||||||
-p rocket \
|
-p rocket \
|
||||||
-p rocket_db_pools \
|
-p rocket_db_pools \
|
||||||
-p rocket_sync_db_pools \
|
-p rocket_sync_db_pools \
|
||||||
|
|
|
@ -126,10 +126,11 @@ function test_contrib() {
|
||||||
|
|
||||||
function test_core() {
|
function test_core() {
|
||||||
FEATURES=(
|
FEATURES=(
|
||||||
|
tokio-macros
|
||||||
|
http2
|
||||||
secrets
|
secrets
|
||||||
tls
|
tls
|
||||||
mtls
|
mtls
|
||||||
http2
|
|
||||||
json
|
json
|
||||||
msgpack
|
msgpack
|
||||||
uuid
|
uuid
|
||||||
|
|
Loading…
Reference in New Issue